upgini 1.2.81a3832.dev1__py3-none-any.whl → 1.2.81a3832.dev3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- upgini/__about__.py +1 -1
- upgini/features_enricher.py +20 -10
- upgini/http.py +21 -21
- upgini/mdc/__init__.py +1 -1
- upgini/metrics.py +68 -38
- {upgini-1.2.81a3832.dev1.dist-info → upgini-1.2.81a3832.dev3.dist-info}/METADATA +2 -1
- {upgini-1.2.81a3832.dev1.dist-info → upgini-1.2.81a3832.dev3.dist-info}/RECORD +9 -9
- {upgini-1.2.81a3832.dev1.dist-info → upgini-1.2.81a3832.dev3.dist-info}/WHEEL +0 -0
- {upgini-1.2.81a3832.dev1.dist-info → upgini-1.2.81a3832.dev3.dist-info}/licenses/LICENSE +0 -0
upgini/__about__.py
CHANGED
@@ -1 +1 @@
|
|
1
|
-
__version__ = "1.2.81a3832.dev1"
|
1
|
+
__version__ = "1.2.81a3832.dev3"
|
upgini/features_enricher.py
CHANGED
@@ -310,6 +310,7 @@ class FeaturesEnricher(TransformerMixin):
|
|
310
310
|
self._search_task = search_task.poll_result(trace_id, quiet=True, check_fit=True)
|
311
311
|
file_metadata = self._search_task.get_file_metadata(trace_id)
|
312
312
|
x_columns = [c.originalName or c.name for c in file_metadata.columns]
|
313
|
+
self.fit_columns_renaming = {c.name: c.originalName for c in file_metadata.columns}
|
313
314
|
df = pd.DataFrame(columns=x_columns)
|
314
315
|
self.__prepare_feature_importances(trace_id, df, silent=True)
|
315
316
|
# TODO validate search_keys with search_keys from file_metadata
|
@@ -476,7 +477,7 @@ class FeaturesEnricher(TransformerMixin):
|
|
476
477
|
self.__validate_search_keys(self.search_keys)
|
477
478
|
|
478
479
|
# Validate client estimator params
|
479
|
-
self.
|
480
|
+
self._get_and_validate_client_cat_features(estimator, X, self.search_keys)
|
480
481
|
|
481
482
|
try:
|
482
483
|
self.X = X
|
@@ -957,9 +958,17 @@ class FeaturesEnricher(TransformerMixin):
|
|
957
958
|
self.__display_support_link(msg)
|
958
959
|
return None
|
959
960
|
|
960
|
-
|
961
|
+
cat_features_from_backend = self.__get_categorical_features()
|
962
|
+
client_cat_features, search_keys_for_metrics = self._get_and_validate_client_cat_features(
|
961
963
|
estimator, validated_X, self.search_keys
|
962
964
|
)
|
965
|
+
for cat_feature in cat_features_from_backend:
|
966
|
+
original_cat_feature = self.fit_columns_renaming.get(cat_feature)
|
967
|
+
if original_cat_feature in self.search_keys:
|
968
|
+
if self.search_keys[original_cat_feature] in [SearchKey.COUNTRY, SearchKey.POSTAL_CODE]:
|
969
|
+
search_keys_for_metrics.append(original_cat_feature)
|
970
|
+
else:
|
971
|
+
self.logger.warning(self.bundle.get("cat_feature_search_key").format(original_cat_feature))
|
963
972
|
search_keys_for_metrics.extend([c for c in self.id_columns or [] if c not in search_keys_for_metrics])
|
964
973
|
self.logger.info(f"Search keys for metrics: {search_keys_for_metrics}")
|
965
974
|
|
@@ -976,7 +985,7 @@ class FeaturesEnricher(TransformerMixin):
|
|
976
985
|
search_keys_for_metrics=search_keys_for_metrics,
|
977
986
|
progress_bar=progress_bar,
|
978
987
|
progress_callback=progress_callback,
|
979
|
-
|
988
|
+
client_cat_features=client_cat_features,
|
980
989
|
)
|
981
990
|
if prepared_data is None:
|
982
991
|
return None
|
@@ -1027,7 +1036,6 @@ class FeaturesEnricher(TransformerMixin):
|
|
1027
1036
|
|
1028
1037
|
has_date = self._get_date_column(search_keys) is not None
|
1029
1038
|
model_task_type = self.model_task_type or define_task(y_sorted, has_date, self.logger, silent=True)
|
1030
|
-
cat_features_from_backend = self.__get_categorical_features()
|
1031
1039
|
cat_features = list(set(client_cat_features + cat_features_from_backend))
|
1032
1040
|
baseline_cat_features = [f for f in cat_features if f in fitting_X.columns]
|
1033
1041
|
enriched_cat_features = [f for f in cat_features if f in fitting_enriched_X.columns]
|
@@ -1423,7 +1431,7 @@ class FeaturesEnricher(TransformerMixin):
|
|
1423
1431
|
|
1424
1432
|
return _cv, groups
|
1425
1433
|
|
1426
|
-
def
|
1434
|
+
def _get_and_validate_client_cat_features(
|
1427
1435
|
self, estimator: Optional[Any], X: pd.DataFrame, search_keys: Dict[str, SearchKey]
|
1428
1436
|
) -> Tuple[Optional[List[str]], List[str]]:
|
1429
1437
|
cat_features = None
|
@@ -1468,7 +1476,7 @@ class FeaturesEnricher(TransformerMixin):
|
|
1468
1476
|
search_keys_for_metrics: Optional[List[str]] = None,
|
1469
1477
|
progress_bar: Optional[ProgressBar] = None,
|
1470
1478
|
progress_callback: Optional[Callable[[SearchProgress], Any]] = None,
|
1471
|
-
|
1479
|
+
client_cat_features: Optional[List[str]] = None,
|
1472
1480
|
):
|
1473
1481
|
is_input_same_as_fit, X, y, eval_set = self._is_input_same_as_fit(X, y, eval_set)
|
1474
1482
|
is_demo_dataset = hash_input(X, y, eval_set) in DEMO_DATASET_HASHES
|
@@ -1542,7 +1550,7 @@ class FeaturesEnricher(TransformerMixin):
|
|
1542
1550
|
|
1543
1551
|
# Detect and drop high cardinality columns in train
|
1544
1552
|
columns_with_high_cardinality = FeaturesValidator.find_high_cardinality(fitting_X)
|
1545
|
-
non_excluding_columns = (self.generate_features or []) + (
|
1553
|
+
non_excluding_columns = (self.generate_features or []) + (client_cat_features or [])
|
1546
1554
|
columns_with_high_cardinality = [c for c in columns_with_high_cardinality if c not in non_excluding_columns]
|
1547
1555
|
if len(columns_with_high_cardinality) > 0:
|
1548
1556
|
self.logger.warning(
|
@@ -2080,10 +2088,12 @@ class FeaturesEnricher(TransformerMixin):
|
|
2080
2088
|
search_keys: Dict,
|
2081
2089
|
columns_renaming: Dict[str, str],
|
2082
2090
|
):
|
2091
|
+
# X_sampled - with hash-suffixes
|
2092
|
+
reversed_renaming = {v: k for k, v in columns_renaming.items()}
|
2083
2093
|
search_keys = {
|
2084
|
-
|
2094
|
+
reversed_renaming.get(k, k): v
|
2085
2095
|
for k, v in search_keys.items()
|
2086
|
-
if
|
2096
|
+
if reversed_renaming.get(k, k) in X_sampled.columns.to_list()
|
2087
2097
|
}
|
2088
2098
|
return FeaturesEnricher._SampledDataForMetrics(
|
2089
2099
|
X_sampled=X_sampled,
|
@@ -3871,7 +3881,7 @@ if response.status_code == 200:
|
|
3871
3881
|
if features_meta is None:
|
3872
3882
|
raise Exception(self.bundle.get("missing_features_meta"))
|
3873
3883
|
|
3874
|
-
return [f.name for f in features_meta if f.type == "categorical"]
|
3884
|
+
return [f.name for f in features_meta if f.type == "categorical" and f.shap_value > 0.0]
|
3875
3885
|
|
3876
3886
|
def __prepare_feature_importances(
|
3877
3887
|
self, trace_id: str, df: pd.DataFrame, updated_shaps: Optional[Dict[str, float]] = None, silent=False
|
upgini/http.py
CHANGED
@@ -20,7 +20,7 @@ import jwt
|
|
20
20
|
# import pandas as pd
|
21
21
|
import requests
|
22
22
|
from pydantic import BaseModel
|
23
|
-
from pythonjsonlogger import jsonlogger
|
23
|
+
from pythonjsonlogger import json as jsonlogger
|
24
24
|
from requests.exceptions import RequestException
|
25
25
|
|
26
26
|
from upgini.__about__ import __version__
|
@@ -459,19 +459,19 @@ class _RestClient:
|
|
459
459
|
content = file.read()
|
460
460
|
md5_hash.update(content)
|
461
461
|
digest = md5_hash.hexdigest()
|
462
|
-
metadata_with_md5 = metadata.
|
462
|
+
metadata_with_md5 = metadata.model_copy(update={"checksumMD5": digest})
|
463
463
|
|
464
464
|
# digest_sha256 = hashlib.sha256(
|
465
465
|
# pd.util.hash_pandas_object(pd.read_parquet(file_path, engine="fastparquet")).values
|
466
466
|
# ).hexdigest()
|
467
467
|
digest_sha256 = self.compute_file_digest(file_path)
|
468
|
-
metadata_with_md5 = metadata_with_md5.
|
468
|
+
metadata_with_md5 = metadata_with_md5.model_copy(update={"digest": digest_sha256})
|
469
469
|
|
470
470
|
with open(file_path, "rb") as file:
|
471
471
|
files = {
|
472
472
|
"metadata": (
|
473
473
|
"metadata.json",
|
474
|
-
metadata_with_md5.
|
474
|
+
metadata_with_md5.model_dump_json(exclude_none=True).encode(),
|
475
475
|
"application/json",
|
476
476
|
),
|
477
477
|
"tracking": (
|
@@ -481,7 +481,7 @@ class _RestClient:
|
|
481
481
|
),
|
482
482
|
"metrics": (
|
483
483
|
"metrics.json",
|
484
|
-
metrics.
|
484
|
+
metrics.model_dump_json(exclude_none=True).encode(),
|
485
485
|
"application/json",
|
486
486
|
),
|
487
487
|
"file": (metadata_with_md5.name, file, "application/octet-stream"),
|
@@ -489,7 +489,7 @@ class _RestClient:
|
|
489
489
|
if search_customization is not None:
|
490
490
|
files["customization"] = (
|
491
491
|
"customization.json",
|
492
|
-
search_customization.
|
492
|
+
search_customization.model_dump_json(exclude_none=True).encode(),
|
493
493
|
"application/json",
|
494
494
|
)
|
495
495
|
additional_headers = {self.SEARCH_KEYS_HEADER_NAME: ",".join(self.search_keys_meaning_types(metadata))}
|
@@ -504,7 +504,7 @@ class _RestClient:
|
|
504
504
|
def check_uploaded_file_v2(self, trace_id: str, file_upload_id: str, metadata: FileMetadata) -> bool:
|
505
505
|
api_path = self.CHECK_UPLOADED_FILE_URL_FMT_V2.format(file_upload_id)
|
506
506
|
response = self._with_unauth_retry(
|
507
|
-
lambda: self._send_post_req(api_path, trace_id, metadata.
|
507
|
+
lambda: self._send_post_req(api_path, trace_id, metadata.model_dump_json(exclude_none=True))
|
508
508
|
)
|
509
509
|
return bool(response)
|
510
510
|
|
@@ -518,11 +518,11 @@ class _RestClient:
|
|
518
518
|
) -> SearchTaskResponse:
|
519
519
|
api_path = self.INITIAL_SEARCH_WITHOUT_UPLOAD_URI_FMT_V2.format(file_upload_id)
|
520
520
|
files = {
|
521
|
-
"metadata": ("metadata.json", metadata.
|
522
|
-
"metrics": ("metrics.json", metrics.
|
521
|
+
"metadata": ("metadata.json", metadata.model_dump_json(exclude_none=True).encode(), "application/json"),
|
522
|
+
"metrics": ("metrics.json", metrics.model_dump_json(exclude_none=True).encode(), "application/json"),
|
523
523
|
}
|
524
524
|
if search_customization is not None:
|
525
|
-
files["customization"] = search_customization.
|
525
|
+
files["customization"] = search_customization.model_dump_json(exclude_none=True).encode()
|
526
526
|
additional_headers = {self.SEARCH_KEYS_HEADER_NAME: ",".join(self.search_keys_meaning_types(metadata))}
|
527
527
|
response = self._with_unauth_retry(
|
528
528
|
lambda: self._send_post_file_req_v2(
|
@@ -548,19 +548,19 @@ class _RestClient:
|
|
548
548
|
content = file.read()
|
549
549
|
md5_hash.update(content)
|
550
550
|
digest = md5_hash.hexdigest()
|
551
|
-
metadata_with_md5 = metadata.
|
551
|
+
metadata_with_md5 = metadata.model_copy(update={"checksumMD5": digest})
|
552
552
|
|
553
553
|
# digest_sha256 = hashlib.sha256(
|
554
554
|
# pd.util.hash_pandas_object(pd.read_parquet(file_path, engine="fastparquet")).values
|
555
555
|
# ).hexdigest()
|
556
556
|
digest_sha256 = self.compute_file_digest(file_path)
|
557
|
-
metadata_with_md5 = metadata_with_md5.
|
557
|
+
metadata_with_md5 = metadata_with_md5.model_copy(update={"digest": digest_sha256})
|
558
558
|
|
559
559
|
with open(file_path, "rb") as file:
|
560
560
|
files = {
|
561
561
|
"metadata": (
|
562
562
|
"metadata.json",
|
563
|
-
metadata_with_md5.
|
563
|
+
metadata_with_md5.model_dump_json(exclude_none=True).encode(),
|
564
564
|
"application/json",
|
565
565
|
),
|
566
566
|
"tracking": (
|
@@ -570,7 +570,7 @@ class _RestClient:
|
|
570
570
|
),
|
571
571
|
"metrics": (
|
572
572
|
"metrics.json",
|
573
|
-
metrics.
|
573
|
+
metrics.model_dump_json(exclude_none=True).encode(),
|
574
574
|
"application/json",
|
575
575
|
),
|
576
576
|
"file": (metadata_with_md5.name, file, "application/octet-stream"),
|
@@ -578,7 +578,7 @@ class _RestClient:
|
|
578
578
|
if search_customization is not None:
|
579
579
|
files["customization"] = (
|
580
580
|
"customization.json",
|
581
|
-
search_customization.
|
581
|
+
search_customization.model_dump_json(exclude_none=True).encode(),
|
582
582
|
"application/json",
|
583
583
|
)
|
584
584
|
|
@@ -602,11 +602,11 @@ class _RestClient:
|
|
602
602
|
) -> SearchTaskResponse:
|
603
603
|
api_path = self.VALIDATION_SEARCH_WITHOUT_UPLOAD_URI_FMT_V2.format(file_upload_id, initial_search_task_id)
|
604
604
|
files = {
|
605
|
-
"metadata": ("metadata.json", metadata.
|
606
|
-
"metrics": ("metrics.json", metrics.
|
605
|
+
"metadata": ("metadata.json", metadata.model_dump_json(exclude_none=True).encode(), "application/json"),
|
606
|
+
"metrics": ("metrics.json", metrics.model_dump_json(exclude_none=True).encode(), "application/json"),
|
607
607
|
}
|
608
608
|
if search_customization is not None:
|
609
|
-
files["customization"] = search_customization.
|
609
|
+
files["customization"] = search_customization.model_dump_json(exclude_none=True).encode()
|
610
610
|
additional_headers = {self.SEARCH_KEYS_HEADER_NAME: ",".join(self.search_keys_meaning_types(metadata))}
|
611
611
|
response = self._with_unauth_retry(
|
612
612
|
lambda: self._send_post_file_req_v2(
|
@@ -670,7 +670,7 @@ class _RestClient:
|
|
670
670
|
"file": (metadata.name, file, "application/octet-stream"),
|
671
671
|
"metadata": (
|
672
672
|
"metadata.json",
|
673
|
-
metadata.
|
673
|
+
metadata.model_dump_json(exclude_none=True).encode(),
|
674
674
|
"application/json",
|
675
675
|
),
|
676
676
|
}
|
@@ -682,12 +682,12 @@ class _RestClient:
|
|
682
682
|
def get_search_file_metadata(self, search_task_id: str, trace_id: str) -> FileMetadata:
|
683
683
|
api_path = self.SEARCH_FILE_METADATA_URI_FMT_V2.format(search_task_id)
|
684
684
|
response = self._with_unauth_retry(lambda: self._send_get_req(api_path, trace_id))
|
685
|
-
return FileMetadata.
|
685
|
+
return FileMetadata.model_validate(response)
|
686
686
|
|
687
687
|
def get_provider_search_metadata_v3(self, provider_search_task_id: str, trace_id: str) -> ProviderTaskMetadataV2:
|
688
688
|
api_path = self.SEARCH_TASK_METADATA_FMT_V3.format(provider_search_task_id)
|
689
689
|
response = self._with_unauth_retry(lambda: self._send_get_req(api_path, trace_id))
|
690
|
-
return ProviderTaskMetadataV2.
|
690
|
+
return ProviderTaskMetadataV2.model_validate(response)
|
691
691
|
|
692
692
|
def get_current_transform_usage(self, trace_id) -> TransformUsage:
|
693
693
|
track_metrics = get_track_metrics(self.client_ip, self.client_visitorid)
|
upgini/mdc/__init__.py
CHANGED
upgini/metrics.py
CHANGED
@@ -11,13 +11,14 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
|
11
11
|
import lightgbm as lgb
|
12
12
|
import numpy as np
|
13
13
|
import pandas as pd
|
14
|
+
from catboost import CatBoostClassifier, CatBoostRegressor
|
14
15
|
from category_encoders.cat_boost import CatBoostEncoder
|
15
16
|
from lightgbm import LGBMClassifier, LGBMRegressor
|
16
17
|
from numpy import log1p
|
17
18
|
from pandas.api.types import is_numeric_dtype
|
18
19
|
from sklearn.metrics import check_scoring, get_scorer, make_scorer, roc_auc_score
|
19
20
|
|
20
|
-
from upgini.utils.blocked_time_series import BlockedTimeSeriesSplit
|
21
|
+
# from upgini.utils.blocked_time_series import BlockedTimeSeriesSplit
|
21
22
|
from upgini.utils.features_validator import FeaturesValidator
|
22
23
|
from upgini.utils.sklearn_ext import cross_validate
|
23
24
|
|
@@ -31,7 +32,7 @@ except ImportError:
|
|
31
32
|
available_scorers = SCORERS
|
32
33
|
from sklearn.metrics import mean_squared_error
|
33
34
|
from sklearn.metrics._regression import _check_reg_targets, check_consistent_length
|
34
|
-
from sklearn.model_selection import BaseCrossValidator, TimeSeriesSplit
|
35
|
+
from sklearn.model_selection import BaseCrossValidator # , TimeSeriesSplit
|
35
36
|
|
36
37
|
from upgini.errors import ValidationError
|
37
38
|
from upgini.metadata import ModelTaskType
|
@@ -328,10 +329,14 @@ class EstimatorWrapper:
|
|
328
329
|
) -> Tuple[pd.DataFrame, np.ndarray, np.ndarray]:
|
329
330
|
self.logger.info(f"Before preparing data columns: {x.columns.to_list()}")
|
330
331
|
for c in x.columns:
|
331
|
-
if
|
332
|
-
|
333
|
-
|
334
|
-
|
332
|
+
if c not in self.cat_features:
|
333
|
+
if is_numeric_dtype(x[c]):
|
334
|
+
x[c] = x[c].astype(float)
|
335
|
+
elif not x[c].dtype == "category":
|
336
|
+
x[c] = x[c].astype(str)
|
337
|
+
else:
|
338
|
+
if x[c].dtype == "category" and x[c].cat.categories.dtype == np.int64:
|
339
|
+
x[c] = x[c].astype(np.int64)
|
335
340
|
|
336
341
|
if not isinstance(y, pd.Series):
|
337
342
|
raise Exception(bundle.get("metrics_unsupported_target_type").format(type(y)))
|
@@ -411,7 +416,6 @@ class EstimatorWrapper:
|
|
411
416
|
shaps = self.calculate_shap(cv_x, cv_y, estimator)
|
412
417
|
if shaps is not None:
|
413
418
|
for feature, shap_value in shaps.items():
|
414
|
-
# shap_values_all_folds[feature] = shap_values_all_folds.get(feature, []) + shap_value.tolist()
|
415
419
|
shap_values_all_folds[feature].append(shap_value)
|
416
420
|
|
417
421
|
if shap_values_all_folds:
|
@@ -488,20 +492,29 @@ class EstimatorWrapper:
|
|
488
492
|
"logger": logger,
|
489
493
|
}
|
490
494
|
if estimator is None:
|
491
|
-
params = {"
|
495
|
+
params = {"has_time": has_date}
|
492
496
|
if target_type == ModelTaskType.MULTICLASS:
|
493
|
-
params = _get_add_params(params,
|
497
|
+
params = _get_add_params(params, CATBOOST_MULTICLASS_PARAMS)
|
494
498
|
params = _get_add_params(params, add_params)
|
495
|
-
estimator =
|
499
|
+
estimator = CatBoostWrapper(CatBoostClassifier(**params), **kwargs)
|
500
|
+
# params = _get_add_params(params, LIGHTGBM_MULTICLASS_PARAMS)
|
501
|
+
# params = _get_add_params(params, add_params)
|
502
|
+
# estimator = LightGBMWrapper(LGBMClassifier(**params), **kwargs)
|
496
503
|
elif target_type == ModelTaskType.BINARY:
|
497
|
-
params = _get_add_params(params,
|
504
|
+
params = _get_add_params(params, CATBOOST_BINARY_PARAMS)
|
498
505
|
params = _get_add_params(params, add_params)
|
499
|
-
estimator =
|
506
|
+
estimator = CatBoostWrapper(CatBoostClassifier(**params), **kwargs)
|
507
|
+
# params = _get_add_params(params, LIGHTGBM_BINARY_PARAMS)
|
508
|
+
# params = _get_add_params(params, add_params)
|
509
|
+
# estimator = LightGBMWrapper(LGBMClassifier(**params), **kwargs)
|
500
510
|
elif target_type == ModelTaskType.REGRESSION:
|
501
|
-
|
502
|
-
params = _get_add_params(params, LIGHTGBM_REGRESSION_PARAMS)
|
511
|
+
params = _get_add_params(params, CATBOOST_REGRESSION_PARAMS)
|
503
512
|
params = _get_add_params(params, add_params)
|
504
|
-
estimator =
|
513
|
+
estimator = CatBoostWrapper(CatBoostRegressor(**params), **kwargs)
|
514
|
+
# if not isinstance(cv, TimeSeriesSplit) and not isinstance(cv, BlockedTimeSeriesSplit):
|
515
|
+
# params = _get_add_params(params, LIGHTGBM_REGRESSION_PARAMS)
|
516
|
+
# params = _get_add_params(params, add_params)
|
517
|
+
# estimator = LightGBMWrapper(LGBMRegressor(**params), **kwargs)
|
505
518
|
else:
|
506
519
|
raise Exception(bundle.get("metrics_unsupported_target_type").format(target_type))
|
507
520
|
else:
|
@@ -517,8 +530,6 @@ class EstimatorWrapper:
|
|
517
530
|
else:
|
518
531
|
if isinstance(estimator, (LGBMClassifier, LGBMRegressor)):
|
519
532
|
estimator = LightGBMWrapper(**kwargs)
|
520
|
-
elif is_catboost_estimator(estimator):
|
521
|
-
estimator = CatBoostWrapper(**kwargs)
|
522
533
|
else:
|
523
534
|
logger.warning(
|
524
535
|
f"Unexpected estimator is used for metrics: {estimator}. "
|
@@ -558,6 +569,7 @@ class CatBoostWrapper(EstimatorWrapper):
|
|
558
569
|
self.emb_features = None
|
559
570
|
self.grouped_embedding_features = None
|
560
571
|
self.drop_cat_features = []
|
572
|
+
self.features_to_encode = []
|
561
573
|
|
562
574
|
def _prepare_to_fit(self, x: pd.DataFrame, y: pd.Series) -> Tuple[pd.DataFrame, np.ndarray, np.ndarray, dict]:
|
563
575
|
x, y, groups, params = super()._prepare_to_fit(x, y)
|
@@ -597,7 +609,13 @@ class CatBoostWrapper(EstimatorWrapper):
|
|
597
609
|
self.cat_features, self.features_to_encode, self.exclude_features = _get_cat_features(
|
598
610
|
self.logger, x, self.cat_features, self.text_features, self.grouped_embedding_features
|
599
611
|
)
|
600
|
-
|
612
|
+
if self.features_to_encode:
|
613
|
+
for c in self.features_to_encode:
|
614
|
+
if is_numeric_dtype(x[c]):
|
615
|
+
x[c] = x[c].fillna(np.nan)
|
616
|
+
else:
|
617
|
+
x[c] = x[c].fillna("NA")
|
618
|
+
params["cat_features"] = self.features_to_encode
|
601
619
|
|
602
620
|
return x, y, groups, params
|
603
621
|
|
@@ -626,8 +644,14 @@ class CatBoostWrapper(EstimatorWrapper):
|
|
626
644
|
if self.grouped_embedding_features:
|
627
645
|
x, emb_columns = self.group_embeddings(x)
|
628
646
|
params["embedding_features"] = emb_columns
|
629
|
-
|
630
|
-
|
647
|
+
|
648
|
+
if self.features_to_encode:
|
649
|
+
for c in self.features_to_encode:
|
650
|
+
if is_numeric_dtype(x[c]):
|
651
|
+
x[c] = x[c].fillna(np.nan)
|
652
|
+
else:
|
653
|
+
x[c] = x[c].fillna("NA")
|
654
|
+
params["cat_features"] = self.features_to_encode
|
631
655
|
|
632
656
|
return x, y, params
|
633
657
|
|
@@ -671,23 +695,29 @@ class CatBoostWrapper(EstimatorWrapper):
|
|
671
695
|
embedding_features=self.grouped_embedding_features,
|
672
696
|
)
|
673
697
|
|
674
|
-
|
675
|
-
shap_values_fold = estimator.get_feature_importance(data=fold_pool, type="ShapValues")
|
698
|
+
shap_values = estimator.get_feature_importance(data=fold_pool, type="ShapValues")
|
676
699
|
|
677
|
-
# Remove last columns (base value) and flatten
|
678
700
|
if self.target_type == ModelTaskType.MULTICLASS:
|
679
|
-
|
680
|
-
|
701
|
+
# For multiclass, shap_values has shape (n_samples, n_classes, n_features + 1)
|
702
|
+
# Last column is bias term
|
703
|
+
shap_values = shap_values[:, :, :-1] # Remove bias term
|
704
|
+
# Average SHAP values across classes
|
705
|
+
shap_values = np.mean(np.abs(shap_values), axis=1)
|
681
706
|
else:
|
682
|
-
|
683
|
-
|
707
|
+
# For binary/regression, shap_values has shape (n_samples, n_features + 1)
|
708
|
+
# Last column is bias term
|
709
|
+
shap_values = shap_values[:, :-1] # Remove bias term
|
710
|
+
# Take absolute values
|
711
|
+
shap_values = np.abs(shap_values)
|
684
712
|
|
685
|
-
|
713
|
+
feature_importance = {}
|
714
|
+
for i, col in enumerate(x.columns):
|
715
|
+
feature_importance[col] = np.mean(np.abs(shap_values[:, i]))
|
686
716
|
|
687
|
-
return
|
717
|
+
return feature_importance
|
688
718
|
|
689
|
-
except Exception:
|
690
|
-
self.logger.exception("Failed to recalculate new SHAP values")
|
719
|
+
except Exception as e:
|
720
|
+
self.logger.exception(f"Failed to recalculate new SHAP values: {str(e)}")
|
691
721
|
return None
|
692
722
|
|
693
723
|
|
@@ -830,9 +860,9 @@ class OtherEstimatorWrapper(EstimatorWrapper):
|
|
830
860
|
num_features = [col for col in x.columns if col not in self.cat_features]
|
831
861
|
x[num_features] = x[num_features].fillna(-999)
|
832
862
|
if self.features_to_encode and self.cat_encoder is not None:
|
833
|
-
x[self.features_to_encode] = self.cat_encoder.transform(
|
834
|
-
"
|
835
|
-
)
|
863
|
+
x[self.features_to_encode] = self.cat_encoder.transform(
|
864
|
+
x[self.features_to_encode].astype("object")
|
865
|
+
).astype("category")
|
836
866
|
return x, y, params
|
837
867
|
|
838
868
|
|
@@ -935,17 +965,17 @@ def _get_cat_features(
|
|
935
965
|
drop_cat_features = []
|
936
966
|
for name in cat_features:
|
937
967
|
# Remove constant categorical features
|
938
|
-
if x[name].nunique() > 1:
|
968
|
+
if x[name].nunique(dropna=False) > 1:
|
939
969
|
unique_cat_features.append(name)
|
940
970
|
else:
|
941
|
-
logger.
|
942
|
-
x
|
971
|
+
logger.warning(f"Drop column {name} on preparing data for fit")
|
972
|
+
x.drop(columns=name, inplace=True)
|
943
973
|
drop_cat_features.append(name)
|
944
974
|
cat_features = unique_cat_features
|
945
975
|
|
946
976
|
logger.info(f"Selected categorical features: {cat_features}")
|
947
977
|
|
948
|
-
features_to_encode = list(set(x.select_dtypes(exclude=[np.number, np.datetime64, pd.CategoricalDtype]).columns))
|
978
|
+
features_to_encode = list(set(x.select_dtypes(exclude=[np.number, np.datetime64, pd.CategoricalDtype()]).columns))
|
949
979
|
|
950
980
|
logger.info(f"Features to encode: {features_to_encode}")
|
951
981
|
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.3
|
2
2
|
Name: upgini
|
3
|
-
Version: 1.2.81a3832.dev1
|
3
|
+
Version: 1.2.81a3832.dev3
|
4
4
|
Summary: Intelligent data search & enrichment for Machine Learning
|
5
5
|
Project-URL: Bug Reports, https://github.com/upgini/upgini/issues
|
6
6
|
Project-URL: Homepage, https://upgini.com/
|
@@ -22,6 +22,7 @@ Classifier: Programming Language :: Python :: 3.11
|
|
22
22
|
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
23
23
|
Classifier: Topic :: Scientific/Engineering :: Information Analysis
|
24
24
|
Requires-Python: <3.12,>=3.10
|
25
|
+
Requires-Dist: catboost>=1.0.3
|
25
26
|
Requires-Dist: category-encoders>=2.8.1
|
26
27
|
Requires-Dist: fastparquet>=0.8.1
|
27
28
|
Requires-Dist: ipywidgets>=8.1.0
|
@@ -1,12 +1,12 @@
|
|
1
|
-
upgini/__about__.py,sha256
|
1
|
+
upgini/__about__.py,sha256=sQSOnYXU8JfHaCG4spEa8dwpUzrTX39X2sSVYCzITIk,33
|
2
2
|
upgini/__init__.py,sha256=LXSfTNU0HnlOkE69VCxkgIKDhWP-JFo_eBQ71OxTr5Y,261
|
3
3
|
upgini/ads.py,sha256=nvuRxRx5MHDMgPr9SiU-fsqRdFaBv8p4_v1oqiysKpc,2714
|
4
4
|
upgini/dataset.py,sha256=aspri7ZAgwkNNUiIgQ1GRXvw8XQii3F4RfNXSrF4wrw,35365
|
5
5
|
upgini/errors.py,sha256=2b_Wbo0OYhLUbrZqdLIx5jBnAsiD1Mcenh-VjR4HCTw,950
|
6
|
-
upgini/features_enricher.py,sha256=
|
7
|
-
upgini/http.py,sha256=
|
6
|
+
upgini/features_enricher.py,sha256=WiSVfmlHI9oKJQbyf46FH0yY80hBJ6hheFpugw0f_vE,210583
|
7
|
+
upgini/http.py,sha256=AfaJ3c8z_tK2hZFEehNybDKE0mp1tYcyAP_l0_p8bLQ,43933
|
8
8
|
upgini/metadata.py,sha256=Yd6iW2f7Wz6vUkg5uvR4xylN16ANnCKVKqAsAkap7p8,12354
|
9
|
-
upgini/metrics.py,sha256=
|
9
|
+
upgini/metrics.py,sha256=fhBhMM455C__1adECAk2H3K-zyO_WUnVqZV_AJ-rQBo,39633
|
10
10
|
upgini/search_task.py,sha256=RcvAE785yksWTsTNWuZFVNlk32jHElMoEna1T_C5N8Q,17823
|
11
11
|
upgini/spinner.py,sha256=4iMd-eIe_BnkqFEMIliULTbj6rNI2HkN_VJ4qYe0cUc,1118
|
12
12
|
upgini/version_validator.py,sha256=DvbaAvuYFoJqYt0fitpsk6Xcv-H1BYDJYHUMxaKSH_Y,1509
|
@@ -32,7 +32,7 @@ upgini/autofe/timeseries/trend.py,sha256=K1_iw2ko_LIUU8YCUgrvN3n0MkHtsi7-63-8x9e
|
|
32
32
|
upgini/autofe/timeseries/volatility.py,sha256=9shUmIKjpWTHVYjj80YBsk0XheBJ9uBuLv5NW9Mchnk,7953
|
33
33
|
upgini/data_source/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
34
34
|
upgini/data_source/data_source_publisher.py,sha256=4S9qwlAklD8vg9tUU_c1pHE2_glUHAh15-wr5hMwKFw,22879
|
35
|
-
upgini/mdc/__init__.py,sha256=
|
35
|
+
upgini/mdc/__init__.py,sha256=iHJlXQg6xRM1-ZOUtaPSJqw5SpQDszvxp4LyqviNLIQ,1027
|
36
36
|
upgini/mdc/context.py,sha256=3u1B-jXt7tXEvNcV3qmR9SDCseudnY7KYsLclBdwVLk,1405
|
37
37
|
upgini/normalizer/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
38
38
|
upgini/normalizer/normalize_utils.py,sha256=Ft2MwSgVoBilXAORAOYAuwPD79GOLfwn4qQE3IUFzzg,7218
|
@@ -70,7 +70,7 @@ upgini/utils/target_utils.py,sha256=LRN840dzx78-wg7ftdxAkp2c1eu8-JDvkACiRThm4HE,
|
|
70
70
|
upgini/utils/track_info.py,sha256=G5Lu1xxakg2_TQjKZk4b5SvrHsATTXNVV3NbvWtT8k8,5663
|
71
71
|
upgini/utils/ts_utils.py,sha256=26vhC0pN7vLXK6R09EEkMK3Lwb9IVPH7LRdqFIQ3kPs,1383
|
72
72
|
upgini/utils/warning_counter.py,sha256=-GRY8EUggEBKODPSuXAkHn9KnEQwAORC0mmz_tim-PM,254
|
73
|
-
upgini-1.2.81a3832.
|
74
|
-
upgini-1.2.81a3832.
|
75
|
-
upgini-1.2.81a3832.
|
76
|
-
upgini-1.2.81a3832.
|
73
|
+
upgini-1.2.81a3832.dev3.dist-info/METADATA,sha256=rjTrlaR6RTthHUMnhRDn3QFCs9EhW6dDUHukgwnObxI,49172
|
74
|
+
upgini-1.2.81a3832.dev3.dist-info/WHEEL,sha256=1yFddiXMmvYK7QYTqtRNtX66WJ0Mz8PYEiEUoOUUxRY,87
|
75
|
+
upgini-1.2.81a3832.dev3.dist-info/licenses/LICENSE,sha256=5RRzgvdJUu3BUDfv4bzVU6FqKgwHlIay63pPCSmSgzw,1514
|
76
|
+
upgini-1.2.81a3832.dev3.dist-info/RECORD,,
|
File without changes
|
File without changes
|