upgini 1.2.80__py3-none-any.whl → 1.2.81a3832.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- upgini/__about__.py +1 -1
- upgini/features_enricher.py +64 -36
- upgini/http.py +21 -21
- upgini/mdc/__init__.py +1 -1
- upgini/metrics.py +141 -128
- upgini/utils/target_utils.py +9 -6
- {upgini-1.2.80.dist-info → upgini-1.2.81a3832.dev2.dist-info}/METADATA +3 -1
- {upgini-1.2.80.dist-info → upgini-1.2.81a3832.dev2.dist-info}/RECORD +10 -10
- {upgini-1.2.80.dist-info → upgini-1.2.81a3832.dev2.dist-info}/WHEEL +0 -0
- {upgini-1.2.80.dist-info → upgini-1.2.81a3832.dev2.dist-info}/licenses/LICENSE +0 -0
upgini/__about__.py
CHANGED
@@ -1 +1 @@
-__version__ = "1.2.80"
+__version__ = "1.2.81a3832.dev2"
upgini/features_enricher.py
CHANGED
@@ -63,7 +63,7 @@ from upgini.metadata import (
     RuntimeParameters,
     SearchKey,
 )
-from upgini.metrics import EstimatorWrapper, validate_scoring_argument
+from upgini.metrics import EstimatorWrapper, define_scorer, validate_scoring_argument
 from upgini.normalizer.normalize_utils import Normalizer
 from upgini.resource_bundle import ResourceBundle, bundle, get_custom_bundle
 from upgini.search_task import SearchTask
@@ -310,6 +310,7 @@ class FeaturesEnricher(TransformerMixin):
             self._search_task = search_task.poll_result(trace_id, quiet=True, check_fit=True)
             file_metadata = self._search_task.get_file_metadata(trace_id)
             x_columns = [c.originalName or c.name for c in file_metadata.columns]
+            self.fit_columns_renaming = {c.name: c.originalName for c in file_metadata.columns}
             df = pd.DataFrame(columns=x_columns)
             self.__prepare_feature_importances(trace_id, df, silent=True)
             # TODO validate search_keys with search_keys from file_metadata
@@ -476,7 +477,7 @@ class FeaturesEnricher(TransformerMixin):
         self.__validate_search_keys(self.search_keys)

         # Validate client estimator params
-        self._get_client_cat_features(estimator, X, self.search_keys)
+        self._get_and_validate_client_cat_features(estimator, X, self.search_keys)

         try:
             self.X = X
@@ -957,9 +958,17 @@ class FeaturesEnricher(TransformerMixin):
                 self.__display_support_link(msg)
                 return None

-            cat_features, search_keys_for_metrics = self._get_client_cat_features(
+            cat_features_from_backend = self.__get_categorical_features()
+            client_cat_features, search_keys_for_metrics = self._get_and_validate_client_cat_features(
                 estimator, validated_X, self.search_keys
             )
+            for cat_feature in cat_features_from_backend:
+                original_cat_feature = self.fit_columns_renaming.get(cat_feature)
+                if original_cat_feature in self.search_keys:
+                    if self.search_keys[original_cat_feature] in [SearchKey.COUNTRY, SearchKey.POSTAL_CODE]:
+                        search_keys_for_metrics.append(original_cat_feature)
+                    else:
+                        self.logger.warning(self.bundle.get("cat_feature_search_key").format(original_cat_feature))
             search_keys_for_metrics.extend([c for c in self.id_columns or [] if c not in search_keys_for_metrics])
             self.logger.info(f"Search keys for metrics: {search_keys_for_metrics}")
@@ -976,7 +985,7 @@ class FeaturesEnricher(TransformerMixin):
                 search_keys_for_metrics=search_keys_for_metrics,
                 progress_bar=progress_bar,
                 progress_callback=progress_callback,
-                cat_features=cat_features,
+                client_cat_features=client_cat_features,
             )
             if prepared_data is None:
                 return None
@@ -994,11 +1003,19 @@ class FeaturesEnricher(TransformerMixin):
             ) = prepared_data

             # rename cat_features
-            if cat_features:
+            if client_cat_features:
                 for new_c, old_c in columns_renaming.items():
-                    if old_c in cat_features:
-                        cat_features.remove(old_c)
-                        cat_features.append(new_c)
+                    if old_c in client_cat_features:
+                        client_cat_features.remove(old_c)
+                        client_cat_features.append(new_c)
+                for cat_feature in client_cat_features:
+                    if cat_feature not in fitting_X.columns:
+                        self.logger.error(
+                            f"Client cat_feature `{cat_feature}` not found in"
+                            f" x columns: {fitting_X.columns.to_list()}"
+                        )
+            else:
+                client_cat_features = []

             gc.collect()
@@ -1019,20 +1036,16 @@ class FeaturesEnricher(TransformerMixin):

             has_date = self._get_date_column(search_keys) is not None
             model_task_type = self.model_task_type or define_task(y_sorted, has_date, self.logger, silent=True)
+            cat_features = list(set(client_cat_features + cat_features_from_backend))
+            baseline_cat_features = [f for f in cat_features if f in fitting_X.columns]
+            enriched_cat_features = [f for f in cat_features if f in fitting_enriched_X.columns]
+            if len(enriched_cat_features) < len(cat_features):
+                missing_cat_features = [f for f in cat_features if f not in fitting_enriched_X.columns]
+                self.logger.warning(
+                    f"Some cat_features were not found in enriched_X: {missing_cat_features}"
+                )

-            wrapper = EstimatorWrapper.create(
-                estimator,
-                self.logger,
-                model_task_type,
-                _cv,
-                fitting_enriched_X,
-                scoring,
-                groups=groups,
-                text_features=text_features,
-                has_date=has_date,
-            )
-            metric = wrapper.metric_name
-            multiplier = wrapper.multiplier
+            _, metric, multiplier = define_scorer(model_task_type, scoring)

             # 1 If client features are presented - fit and predict with KFold estimator
             # on etalon features and calculate baseline metric
@@ -1050,9 +1063,8 @@ class FeaturesEnricher(TransformerMixin):
                     self.logger,
                     model_task_type,
                     _cv,
-                    fitting_X,
-                    scoring,
-                    cat_features,
+                    scoring=scoring,
+                    cat_features=baseline_cat_features,
                     add_params=custom_loss_add_params,
                     groups=groups,
                     text_features=text_features,
@@ -1085,9 +1097,8 @@ class FeaturesEnricher(TransformerMixin):
                     self.logger,
                     model_task_type,
                     _cv,
-                    fitting_enriched_X,
-                    scoring,
-                    cat_features,
+                    scoring=scoring,
+                    cat_features=enriched_cat_features,
                     add_params=custom_loss_add_params,
                     groups=groups,
                     text_features=text_features,
@@ -1420,7 +1431,7 @@ class FeaturesEnricher(TransformerMixin):

         return _cv, groups

-    def _get_client_cat_features(
+    def _get_and_validate_client_cat_features(
         self, estimator: Optional[Any], X: pd.DataFrame, search_keys: Dict[str, SearchKey]
     ) -> Tuple[Optional[List[str]], List[str]]:
         cat_features = None
@@ -1428,12 +1439,20 @@ class FeaturesEnricher(TransformerMixin):
         if (
             estimator is not None
             and hasattr(estimator, "get_param")
+            and hasattr(estimator, "_init_params")
             and estimator.get_param("cat_features") is not None
         ):
-            cat_features = estimator.get_param("cat_features")
-            if all([isinstance(c, int) for c in cat_features]):
-                cat_features = [X.columns[idx] for idx in cat_features]
-        if cat_features:
+            estimator_cat_features = estimator.get_param("cat_features")
+            if all([isinstance(c, int) for c in estimator_cat_features]):
+                cat_features = [X.columns[idx] for idx in estimator_cat_features]
+            elif all([isinstance(c, str) for c in estimator_cat_features]):
+                cat_features = estimator_cat_features
+            else:
+                print(f"WARNING: Unsupported type of cat_features in CatBoost estimator: {estimator_cat_features}")
+
+            del estimator._init_params["cat_features"]
+
+        if cat_features:
             self.logger.info(f"Collected categorical features {cat_features} from user estimator")
             for cat_feature in cat_features:
                 if cat_feature in search_keys:
@@ -1457,7 +1476,7 @@ class FeaturesEnricher(TransformerMixin):
         search_keys_for_metrics: Optional[List[str]] = None,
         progress_bar: Optional[ProgressBar] = None,
         progress_callback: Optional[Callable[[SearchProgress], Any]] = None,
-        cat_features: Optional[List[str]] = None,
+        client_cat_features: Optional[List[str]] = None,
     ):
         is_input_same_as_fit, X, y, eval_set = self._is_input_same_as_fit(X, y, eval_set)
         is_demo_dataset = hash_input(X, y, eval_set) in DEMO_DATASET_HASHES
@@ -1531,7 +1550,7 @@ class FeaturesEnricher(TransformerMixin):

         # Detect and drop high cardinality columns in train
         columns_with_high_cardinality = FeaturesValidator.find_high_cardinality(fitting_X)
-        non_excluding_columns = (self.generate_features or []) + (cat_features or [])
+        non_excluding_columns = (self.generate_features or []) + (client_cat_features or [])
         columns_with_high_cardinality = [c for c in columns_with_high_cardinality if c not in non_excluding_columns]
         if len(columns_with_high_cardinality) > 0:
             self.logger.warning(
@@ -2069,10 +2088,12 @@ class FeaturesEnricher(TransformerMixin):
         search_keys: Dict,
         columns_renaming: Dict[str, str],
     ):
+        # X_sampled - with hash-suffixes
+        reversed_renaming = {v: k for k, v in columns_renaming.items()}
         search_keys = {
-            k: v
+            reversed_renaming.get(k, k): v
             for k, v in search_keys.items()
-            if k in X_sampled.columns.to_list()
+            if reversed_renaming.get(k, k) in X_sampled.columns.to_list()
         }
         return FeaturesEnricher._SampledDataForMetrics(
             X_sampled=X_sampled,
@@ -3855,6 +3876,13 @@ if response.status_code == 200:

         return importances

+    def __get_categorical_features(self) -> List[str]:
+        features_meta = self._search_task.get_all_features_metadata_v2()
+        if features_meta is None:
+            raise Exception(self.bundle.get("missing_features_meta"))
+
+        return [f.name for f in features_meta if f.type == "categorical" and f.shap_value > 0.0]
+
     def __prepare_feature_importances(
         self, trace_id: str, df: pd.DataFrame, updated_shaps: Optional[Dict[str, float]] = None, silent=False
     ):
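Net effect of the features_enricher.py changes: categorical features now come from two sources — the client's estimator and the backend's feature metadata — and are merged, then split per dataset before metric calculation. A minimal sketch of that merge-and-split, with invented column names (`city`, `device_type`, `merchant_category` are illustrative, not from a real search task):

import pandas as pd

# Hypothetical inputs: categoricals declared by the client estimator and
# categoricals reported by the backend with a positive SHAP value.
client_cat_features = ["city", "device_type"]
backend_cat_features = ["city", "merchant_category"]

fitting_X = pd.DataFrame(columns=["city", "device_type", "amount"])
fitting_enriched_X = pd.DataFrame(columns=["city", "device_type", "amount", "merchant_category"])

# Union of both sources, then split per dataset, mirroring the hunks above:
cat_features = list(set(client_cat_features + backend_cat_features))
baseline_cat_features = [f for f in cat_features if f in fitting_X.columns]
enriched_cat_features = [f for f in cat_features if f in fitting_enriched_X.columns]

print(sorted(baseline_cat_features))   # ['city', 'device_type']
print(sorted(enriched_cat_features))   # ['city', 'device_type', 'merchant_category']

Each list is then handed to the matching EstimatorWrapper (baseline vs. enriched), which is why the two `EstimatorWrapper.create` call sites above pass different `cat_features=` arguments.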
upgini/http.py
CHANGED
@@ -20,7 +20,7 @@ import jwt
 # import pandas as pd
 import requests
 from pydantic import BaseModel
-from pythonjsonlogger import jsonlogger
+from pythonjsonlogger import json as jsonlogger
 from requests.exceptions import RequestException

 from upgini.__about__ import __version__
@@ -459,19 +459,19 @@ class _RestClient:
                 content = file.read()
                 md5_hash.update(content)
             digest = md5_hash.hexdigest()
-            metadata_with_md5 = metadata.copy(update={"checksumMD5": digest})
+            metadata_with_md5 = metadata.model_copy(update={"checksumMD5": digest})

             # digest_sha256 = hashlib.sha256(
             #     pd.util.hash_pandas_object(pd.read_parquet(file_path, engine="fastparquet")).values
             # ).hexdigest()
             digest_sha256 = self.compute_file_digest(file_path)
-            metadata_with_md5 = metadata_with_md5.copy(update={"digest": digest_sha256})
+            metadata_with_md5 = metadata_with_md5.model_copy(update={"digest": digest_sha256})

             with open(file_path, "rb") as file:
                 files = {
                     "metadata": (
                         "metadata.json",
-                        metadata_with_md5.json(exclude_none=True).encode(),
+                        metadata_with_md5.model_dump_json(exclude_none=True).encode(),
                         "application/json",
                     ),
                     "tracking": (
@@ -481,7 +481,7 @@ class _RestClient:
                     ),
                     "metrics": (
                         "metrics.json",
-                        metrics.json(exclude_none=True).encode(),
+                        metrics.model_dump_json(exclude_none=True).encode(),
                         "application/json",
                     ),
                     "file": (metadata_with_md5.name, file, "application/octet-stream"),
@@ -489,7 +489,7 @@ class _RestClient:
                 if search_customization is not None:
                     files["customization"] = (
                         "customization.json",
-                        search_customization.json(exclude_none=True).encode(),
+                        search_customization.model_dump_json(exclude_none=True).encode(),
                         "application/json",
                     )
                 additional_headers = {self.SEARCH_KEYS_HEADER_NAME: ",".join(self.search_keys_meaning_types(metadata))}
@@ -504,7 +504,7 @@ class _RestClient:
     def check_uploaded_file_v2(self, trace_id: str, file_upload_id: str, metadata: FileMetadata) -> bool:
         api_path = self.CHECK_UPLOADED_FILE_URL_FMT_V2.format(file_upload_id)
         response = self._with_unauth_retry(
-            lambda: self._send_post_req(api_path, trace_id, metadata.json(exclude_none=True))
+            lambda: self._send_post_req(api_path, trace_id, metadata.model_dump_json(exclude_none=True))
         )
         return bool(response)
@@ -518,11 +518,11 @@ class _RestClient:
     ) -> SearchTaskResponse:
         api_path = self.INITIAL_SEARCH_WITHOUT_UPLOAD_URI_FMT_V2.format(file_upload_id)
         files = {
-            "metadata": ("metadata.json", metadata.json(exclude_none=True).encode(), "application/json"),
-            "metrics": ("metrics.json", metrics.json(exclude_none=True).encode(), "application/json"),
+            "metadata": ("metadata.json", metadata.model_dump_json(exclude_none=True).encode(), "application/json"),
+            "metrics": ("metrics.json", metrics.model_dump_json(exclude_none=True).encode(), "application/json"),
         }
         if search_customization is not None:
-            files["customization"] = search_customization.json(exclude_none=True).encode()
+            files["customization"] = search_customization.model_dump_json(exclude_none=True).encode()
         additional_headers = {self.SEARCH_KEYS_HEADER_NAME: ",".join(self.search_keys_meaning_types(metadata))}
         response = self._with_unauth_retry(
             lambda: self._send_post_file_req_v2(
@@ -548,19 +548,19 @@ class _RestClient:
                 content = file.read()
                 md5_hash.update(content)
             digest = md5_hash.hexdigest()
-            metadata_with_md5 = metadata.copy(update={"checksumMD5": digest})
+            metadata_with_md5 = metadata.model_copy(update={"checksumMD5": digest})

             # digest_sha256 = hashlib.sha256(
             #     pd.util.hash_pandas_object(pd.read_parquet(file_path, engine="fastparquet")).values
             # ).hexdigest()
             digest_sha256 = self.compute_file_digest(file_path)
-            metadata_with_md5 = metadata_with_md5.copy(update={"digest": digest_sha256})
+            metadata_with_md5 = metadata_with_md5.model_copy(update={"digest": digest_sha256})

             with open(file_path, "rb") as file:
                 files = {
                     "metadata": (
                         "metadata.json",
-                        metadata_with_md5.json(exclude_none=True).encode(),
+                        metadata_with_md5.model_dump_json(exclude_none=True).encode(),
                         "application/json",
                     ),
                     "tracking": (
@@ -570,7 +570,7 @@ class _RestClient:
                     ),
                     "metrics": (
                         "metrics.json",
-                        metrics.json(exclude_none=True).encode(),
+                        metrics.model_dump_json(exclude_none=True).encode(),
                         "application/json",
                     ),
                     "file": (metadata_with_md5.name, file, "application/octet-stream"),
@@ -578,7 +578,7 @@ class _RestClient:
                 if search_customization is not None:
                     files["customization"] = (
                         "customization.json",
-                        search_customization.json(exclude_none=True).encode(),
+                        search_customization.model_dump_json(exclude_none=True).encode(),
                         "application/json",
                     )

@@ -602,11 +602,11 @@ class _RestClient:
     ) -> SearchTaskResponse:
         api_path = self.VALIDATION_SEARCH_WITHOUT_UPLOAD_URI_FMT_V2.format(file_upload_id, initial_search_task_id)
         files = {
-            "metadata": ("metadata.json", metadata.json(exclude_none=True).encode(), "application/json"),
-            "metrics": ("metrics.json", metrics.json(exclude_none=True).encode(), "application/json"),
+            "metadata": ("metadata.json", metadata.model_dump_json(exclude_none=True).encode(), "application/json"),
+            "metrics": ("metrics.json", metrics.model_dump_json(exclude_none=True).encode(), "application/json"),
         }
         if search_customization is not None:
-            files["customization"] = search_customization.json(exclude_none=True).encode()
+            files["customization"] = search_customization.model_dump_json(exclude_none=True).encode()
         additional_headers = {self.SEARCH_KEYS_HEADER_NAME: ",".join(self.search_keys_meaning_types(metadata))}
         response = self._with_unauth_retry(
             lambda: self._send_post_file_req_v2(
@@ -670,7 +670,7 @@ class _RestClient:
                 "file": (metadata.name, file, "application/octet-stream"),
                 "metadata": (
                     "metadata.json",
-                    metadata.json(exclude_none=True).encode(),
+                    metadata.model_dump_json(exclude_none=True).encode(),
                     "application/json",
                 ),
             }
@@ -682,12 +682,12 @@ class _RestClient:
     def get_search_file_metadata(self, search_task_id: str, trace_id: str) -> FileMetadata:
         api_path = self.SEARCH_FILE_METADATA_URI_FMT_V2.format(search_task_id)
         response = self._with_unauth_retry(lambda: self._send_get_req(api_path, trace_id))
-        return FileMetadata.parse_obj(response)
+        return FileMetadata.model_validate(response)

     def get_provider_search_metadata_v3(self, provider_search_task_id: str, trace_id: str) -> ProviderTaskMetadataV2:
         api_path = self.SEARCH_TASK_METADATA_FMT_V3.format(provider_search_task_id)
         response = self._with_unauth_retry(lambda: self._send_get_req(api_path, trace_id))
-        return ProviderTaskMetadataV2.parse_obj(response)
+        return ProviderTaskMetadataV2.model_validate(response)

     def get_current_transform_usage(self, trace_id) -> TransformUsage:
         track_metrics = get_track_metrics(self.client_ip, self.client_visitorid)
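All http.py changes are one mechanical migration: Pydantic v1 calls replaced by their v2 equivalents (`.copy()` → `model_copy()`, `.json()` → `model_dump_json()`, `.parse_obj()` → `model_validate()`). A self-contained sketch of the same before/after pattern on a toy model (the fields here are invented; the real `FileMetadata` lives in upgini/metadata.py):

from typing import Optional
from pydantic import BaseModel

class FileMetadata(BaseModel):
    # Illustrative fields only.
    name: str
    checksumMD5: Optional[str] = None

meta = FileMetadata(name="train.parquet")

# Pydantic v1 style (removed by this diff):
#   meta.copy(update={"checksumMD5": "abc"})
#   meta.json(exclude_none=True)
#   FileMetadata.parse_obj({"name": "train.parquet"})

# Pydantic v2 style (added by this diff):
meta2 = meta.model_copy(update={"checksumMD5": "abc"})
payload = meta2.model_dump_json(exclude_none=True)
roundtrip = FileMetadata.model_validate({"name": "train.parquet"})

The `from pythonjsonlogger import json as jsonlogger` rewrite is the matching update for python-json-logger 3.x, which moved `JsonFormatter` into the `json` submodule while keeping the old attribute name usable via the alias.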
upgini/mdc/__init__.py
CHANGED
upgini/metrics.py
CHANGED
@@ -11,15 +11,16 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 import lightgbm as lgb
 import numpy as np
 import pandas as pd
+from catboost import CatBoostClassifier, CatBoostRegressor
+from category_encoders.cat_boost import CatBoostEncoder
 from lightgbm import LGBMClassifier, LGBMRegressor
 from numpy import log1p
 from pandas.api.types import is_numeric_dtype
 from sklearn.metrics import check_scoring, get_scorer, make_scorer, roc_auc_score
-from sklearn.preprocessing import OrdinalEncoder

+# from upgini.utils.blocked_time_series import BlockedTimeSeriesSplit
 from upgini.utils.features_validator import FeaturesValidator
 from upgini.utils.sklearn_ext import cross_validate
-from upgini.utils.blocked_time_series import BlockedTimeSeriesSplit

 try:
     from sklearn.metrics import get_scorer_names
@@ -31,12 +32,12 @@ except ImportError:
     available_scorers = SCORERS
 from sklearn.metrics import mean_squared_error
 from sklearn.metrics._regression import _check_reg_targets, check_consistent_length
-from sklearn.model_selection import BaseCrossValidator, TimeSeriesSplit
+from sklearn.model_selection import BaseCrossValidator  # , TimeSeriesSplit

 from upgini.errors import ValidationError
 from upgini.metadata import ModelTaskType
 from upgini.resource_bundle import bundle
-from upgini.utils.target_utils import correct_string_target
+from upgini.utils.target_utils import prepare_target

 DEFAULT_RANDOM_STATE = 42
@@ -287,6 +288,7 @@ class EstimatorWrapper:
         self,
         estimator,
         scorer: Callable,
+        cat_features: Optional[List[str]],
         metric_name: str,
         multiplier: int,
         cv: BaseCrossValidator,
@@ -298,9 +300,8 @@ class EstimatorWrapper:
     ):
         self.estimator = estimator
         self.scorer = scorer
-        self.metric_name = (
-            "GINI" if metric_name.upper() == "ROC_AUC" and target_type == ModelTaskType.BINARY else metric_name
-        )
+        self.cat_features = cat_features
+        self.metric_name = metric_name
         self.multiplier = multiplier
         self.cv = cv
         self.target_type = target_type
@@ -328,10 +329,14 @@ class EstimatorWrapper:
     ) -> Tuple[pd.DataFrame, np.ndarray, np.ndarray]:
         self.logger.info(f"Before preparing data columns: {x.columns.to_list()}")
         for c in x.columns:
-            if is_numeric_dtype(x[c]):
-                x[c] = x[c].astype(float)
-            elif not x[c].dtype == "category":
-                x[c] = x[c].astype(str)
+            if c not in self.cat_features:
+                if is_numeric_dtype(x[c]):
+                    x[c] = x[c].astype(float)
+                elif not x[c].dtype == "category":
+                    x[c] = x[c].astype(str)
+            else:
+                if x[c].dtype == "category" and x[c].cat.categories.dtype == np.int64:
+                    x[c] = x[c].astype(np.int64)

         if not isinstance(y, pd.Series):
             raise Exception(bundle.get("metrics_unsupported_target_type").format(type(y)))
@@ -345,6 +350,8 @@ class EstimatorWrapper:
         else:
             x, y = self._remove_empty_target_rows(x, y)

+        y = prepare_target(y, self.target_type)
+
         self.logger.info(f"After preparing data columns: {x.columns.to_list()}")
         return x, y, groups
@@ -409,7 +416,6 @@ class EstimatorWrapper:
             shaps = self.calculate_shap(cv_x, cv_y, estimator)
             if shaps is not None:
                 for feature, shap_value in shaps.items():
-                    # shap_values_all_folds[feature] = shap_values_all_folds.get(feature, []) + shap_value.tolist()
                     shap_values_all_folds[feature].append(shap_value)

         if shap_values_all_folds:
@@ -465,7 +471,7 @@ class EstimatorWrapper:
         logger: logging.Logger,
         target_type: ModelTaskType,
         cv: BaseCrossValidator,
-        x: pd.DataFrame,
+        *,
         scoring: Union[Callable, str, None] = None,
         cat_features: Optional[List[str]] = None,
         text_features: Optional[List[str]] = None,
@@ -473,9 +479,10 @@ class EstimatorWrapper:
         groups: Optional[List[str]] = None,
         has_date: Optional[bool] = None,
     ) -> EstimatorWrapper:
-        scorer, metric_name, multiplier = _get_scorer(target_type, scoring)
+        scorer, metric_name, multiplier = define_scorer(target_type, scoring)
         kwargs = {
             "scorer": scorer,
+            "cat_features": cat_features,
             "metric_name": metric_name,
             "multiplier": multiplier,
             "cv": cv,
@@ -485,20 +492,29 @@ class EstimatorWrapper:
             "logger": logger,
         }
         if estimator is None:
-            params = {"random_state": DEFAULT_RANDOM_STATE}
+            params = {"has_time": has_date}
             if target_type == ModelTaskType.MULTICLASS:
-                params = _get_add_params(params, LIGHTGBM_MULTICLASS_PARAMS)
+                params = _get_add_params(params, CATBOOST_MULTICLASS_PARAMS)
                 params = _get_add_params(params, add_params)
-                estimator = LightGBMWrapper(LGBMClassifier(**params), **kwargs)
+                estimator = CatBoostWrapper(CatBoostClassifier(**params), **kwargs)
+                # params = _get_add_params(params, LIGHTGBM_MULTICLASS_PARAMS)
+                # params = _get_add_params(params, add_params)
+                # estimator = LightGBMWrapper(LGBMClassifier(**params), **kwargs)
             elif target_type == ModelTaskType.BINARY:
-                params = _get_add_params(params, LIGHTGBM_BINARY_PARAMS)
+                params = _get_add_params(params, CATBOOST_BINARY_PARAMS)
                 params = _get_add_params(params, add_params)
-                estimator = LightGBMWrapper(LGBMClassifier(**params), **kwargs)
+                estimator = CatBoostWrapper(CatBoostClassifier(**params), **kwargs)
+                # params = _get_add_params(params, LIGHTGBM_BINARY_PARAMS)
+                # params = _get_add_params(params, add_params)
+                # estimator = LightGBMWrapper(LGBMClassifier(**params), **kwargs)
             elif target_type == ModelTaskType.REGRESSION:
-                if not isinstance(cv, TimeSeriesSplit) and not isinstance(cv, BlockedTimeSeriesSplit):
-                    params = _get_add_params(params, LIGHTGBM_REGRESSION_PARAMS)
+                params = _get_add_params(params, CATBOOST_REGRESSION_PARAMS)
                 params = _get_add_params(params, add_params)
-                estimator = LightGBMWrapper(LGBMRegressor(**params), **kwargs)
+                estimator = CatBoostWrapper(CatBoostRegressor(**params), **kwargs)
+                # if not isinstance(cv, TimeSeriesSplit) and not isinstance(cv, BlockedTimeSeriesSplit):
+                #     params = _get_add_params(params, LIGHTGBM_REGRESSION_PARAMS)
+                # params = _get_add_params(params, add_params)
+                # estimator = LightGBMWrapper(LGBMRegressor(**params), **kwargs)
             else:
                 raise Exception(bundle.get("metrics_unsupported_target_type").format(target_type))
         else:
@@ -509,18 +525,11 @@ class EstimatorWrapper:
             kwargs["estimator"] = estimator_copy
             if is_catboost_estimator(estimator):
                 if cat_features is not None:
-                    for cat_feature in cat_features:
-                        if cat_feature not in x.columns:
-                            logger.error(
-                                f"Client cat_feature `{cat_feature}` not found in x columns: {x.columns.to_list()}"
-                            )
                     estimator_copy.set_params(cat_features=cat_features, has_time=has_date)
                 estimator = CatBoostWrapper(**kwargs)
             else:
                 if isinstance(estimator, (LGBMClassifier, LGBMRegressor)):
                     estimator = LightGBMWrapper(**kwargs)
-                elif is_catboost_estimator(estimator):
-                    estimator = CatBoostWrapper(**kwargs)
                 else:
                     logger.warning(
                         f"Unexpected estimator is used for metrics: {estimator}. "
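Together with the previous hunk, this flips the default metrics estimator from LightGBM to CatBoost, keeping the old LightGBM branches as comments. A hedged sketch of what the new default instantiation amounts to — the parameter values below are illustrative stand-ins, not the shipped CATBOOST_BINARY_PARAMS:

from catboost import CatBoostClassifier

# Illustrative stand-ins; the real defaults live in upgini/metrics.py.
params = {
    "has_time": False,   # set from has_date, as in the diff
    "iterations": 300,
    "random_state": 42,
    "verbose": False,
}

estimator = CatBoostClassifier(**params)
# EstimatorWrapper.create(...) then wraps this in CatBoostWrapper, passing
# cat_features through kwargs instead of the removed positional x/scoring.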
@@ -536,6 +545,7 @@ class CatBoostWrapper(EstimatorWrapper):
         self,
         estimator,
         scorer: Callable,
+        cat_features: Optional[List[str]],
         metric_name: str,
         multiplier: int,
         cv: BaseCrossValidator,
@@ -547,6 +557,7 @@ class CatBoostWrapper(EstimatorWrapper):
         super(CatBoostWrapper, self).__init__(
             estimator,
             scorer,
+            cat_features,
             metric_name,
             multiplier,
             cv,
@@ -555,10 +566,10 @@ class CatBoostWrapper(EstimatorWrapper):
             text_features=text_features,
             logger=logger,
         )
-        self.cat_features = None
         self.emb_features = None
         self.grouped_embedding_features = None
-        self.exclude_features = []
+        self.drop_cat_features = []
+        self.features_to_encode = []

     def _prepare_to_fit(self, x: pd.DataFrame, y: pd.Series) -> Tuple[pd.DataFrame, np.ndarray, np.ndarray, dict]:
         x, y, groups, params = super()._prepare_to_fit(x, y)
@@ -595,37 +606,16 @@ class CatBoostWrapper(EstimatorWrapper):
                 self.logger.warning(f"Text features are not supported by this Catboost version {catboost.__version__}")

             # Find rest categorical features
-            self.cat_features = _get_cat_features(x, self.text_features, self.grouped_embedding_features)
-            x = fill_na_cat_features(x, self.cat_features)
-            unique_cat_features = []
-            for name in self.cat_features:
-                # Remove constant categorical features
-                if x[name].nunique() > 1:
-                    unique_cat_features.append(name)
-                else:
-                    self.logger.info(f"Drop column {name} on preparing data for fit")
-                    x = x.drop(columns=name)
-                    self.exclude_features.append(name)
-            self.cat_features = unique_cat_features
-            if (
-                hasattr(self.estimator, "get_param")
-                and hasattr(self.estimator, "_init_params")
-                and self.estimator.get_param("cat_features") is not None
-            ):
-                estimator_cat_features = self.estimator.get_param("cat_features")
-                if all([isinstance(c, int) for c in estimator_cat_features]):
-                    cat_features_idx = {x.columns.get_loc(c) for c in self.cat_features}
-                    cat_features_idx.update(estimator_cat_features)
-                    self.cat_features = [x.columns[idx] for idx in cat_features_idx]
-                elif all([isinstance(c, str) for c in estimator_cat_features]):
-                    self.cat_features = list(set(self.cat_features + estimator_cat_features))
-                else:
-                    print(f"WARNING: Unsupported type of cat_features in CatBoost estimator: {estimator_cat_features}")
-
-                del self.estimator._init_params["cat_features"]
-
-            self.logger.info(f"Selected categorical features: {self.cat_features}")
-            params["cat_features"] = self.cat_features
+            self.cat_features, self.features_to_encode, self.exclude_features = _get_cat_features(
+                self.logger, x, self.cat_features, self.text_features, self.grouped_embedding_features
+            )
+            if self.features_to_encode:
+                for c in self.features_to_encode:
+                    if is_numeric_dtype(x[c]):
+                        x[c] = x[c].fillna(np.nan)
+                    else:
+                        x[c] = x[c].fillna("NA")
+            params["cat_features"] = self.features_to_encode

             return x, y, groups, params
@@ -654,9 +644,14 @@ class CatBoostWrapper(EstimatorWrapper):
         if self.grouped_embedding_features:
             x, emb_columns = self.group_embeddings(x)
             params["embedding_features"] = emb_columns
-        if self.cat_features:
-            x = fill_na_cat_features(x, self.cat_features)
-            params["cat_features"] = self.cat_features
+
+        if self.features_to_encode:
+            for c in self.features_to_encode:
+                if is_numeric_dtype(x[c]):
+                    x[c] = x[c].fillna(np.nan)
+                else:
+                    x[c] = x[c].fillna("NA")
+        params["cat_features"] = self.features_to_encode

         return x, y, params
@@ -700,23 +695,29 @@ class CatBoostWrapper(EstimatorWrapper):
                 embedding_features=self.grouped_embedding_features,
             )

-            # Get SHAP values of current estimator
-            shap_values_fold = estimator.get_feature_importance(data=fold_pool, type="ShapValues")
+            shap_values = estimator.get_feature_importance(data=fold_pool, type="ShapValues")

-            # Remove last columns (base value) and flatten
             if self.target_type == ModelTaskType.MULTICLASS:
-                all_shaps = shap_values_fold[:, :, :-1]
-                all_shaps = all_shaps.reshape(-1, all_shaps.shape[-1])
+                # For multiclass, shap_values has shape (n_samples, n_classes, n_features + 1)
+                # Last column is bias term
+                shap_values = shap_values[:, :, :-1]  # Remove bias term
+                # Average SHAP values across classes
+                shap_values = np.mean(np.abs(shap_values), axis=1)
             else:
-                all_shaps = shap_values_fold[:, :-1]
-                all_shaps = np.abs(all_shaps)
+                # For binary/regression, shap_values has shape (n_samples, n_features + 1)
+                # Last column is bias term
+                shap_values = shap_values[:, :-1]  # Remove bias term
+                # Take absolute values
+                shap_values = np.abs(shap_values)

-            mean_abs_shap = np.mean(all_shaps, axis=0)
+            feature_importance = {}
+            for i, col in enumerate(x.columns):
+                feature_importance[col] = np.mean(np.abs(shap_values[:, i]))

-            return dict(zip(x.columns, mean_abs_shap))
+            return feature_importance

-        except Exception:
-            self.logger.exception("Failed to recalculate new SHAP values")
+        except Exception as e:
+            self.logger.exception(f"Failed to recalculate new SHAP values: {str(e)}")
             return None
@@ -725,6 +726,7 @@ class LightGBMWrapper(EstimatorWrapper):
         self,
         estimator,
         scorer: Callable,
+        cat_features: Optional[List[str]],
         metric_name: str,
         multiplier: int,
         cv: BaseCrossValidator,
@@ -736,6 +738,7 @@ class LightGBMWrapper(EstimatorWrapper):
         super(LightGBMWrapper, self).__init__(
             estimator,
             scorer,
+            cat_features,
             metric_name,
             multiplier,
             cv,
@@ -744,9 +747,10 @@ class LightGBMWrapper(EstimatorWrapper):
             text_features=text_features,
             logger=logger,
         )
-        self.cat_features = None
         self.cat_encoder = None
         self.n_classes = None
+        self.exclude_features = []
+        self.features_to_encode = []

     def _prepare_to_fit(self, x: pd.DataFrame, y: pd.Series) -> Tuple[pd.DataFrame, pd.Series, np.ndarray, dict]:
         x, y_numpy, groups, params = super()._prepare_to_fit(x, y)
@@ -756,30 +760,25 @@ class LightGBMWrapper(EstimatorWrapper):
         if self.target_type == ModelTaskType.BINARY:
             params["eval_metric"] = "auc"
             params["callbacks"] = [lgb.early_stopping(stopping_rounds=LIGHTGBM_EARLY_STOPPING_ROUNDS, verbose=False)]
-        self.cat_features = _get_cat_features(x)
-        if self.cat_features:
-            params["categorical_feature"] = self.cat_features
-            encoder = OrdinalEncoder(handle_unknown="use_encoded_value", unknown_value=-1)
-            encoded = pd.DataFrame(
-                encoder.fit_transform(x[self.cat_features]), columns=self.cat_features, dtype="category"
-            )
-            x[self.cat_features] = encoded
+        self.cat_features, self.features_to_encode, self.exclude_features = _get_cat_features(
+            self.logger, x, self.cat_features
+        )
+        if self.features_to_encode:
+            encoder = CatBoostEncoder(random_state=DEFAULT_RANDOM_STATE, return_df=True)
+            encoded = encoder.fit_transform(x[self.features_to_encode].astype("object"), y_numpy).astype("category")
+            x[self.features_to_encode] = encoded
             self.cat_encoder = encoder
-        if not is_numeric_dtype(y_numpy):
-            y_numpy = correct_string_target(y_numpy)

         return x, y_numpy, groups, params

     def _prepare_to_calculate(self, x: pd.DataFrame, y: pd.Series) -> Tuple[pd.DataFrame, np.ndarray, dict]:
+        if self.exclude_features:
+            x = x.drop(columns=self.exclude_features)
         x, y_numpy, params = super()._prepare_to_calculate(x, y)
-        if self.cat_features is not None:
-            params["categorical_feature"] = self.cat_features
-            if self.cat_encoder is not None:
-                x[self.cat_features] = pd.DataFrame(
-                    self.cat_encoder.transform(x[self.cat_features]), columns=self.cat_features, dtype="category"
-                )
-        if not is_numeric_dtype(y):
-            y_numpy = correct_string_target(y_numpy)
+        if self.features_to_encode is not None and self.cat_encoder is not None:
+            x[self.features_to_encode] = self.cat_encoder.transform(x[self.features_to_encode].astype("object")).astype(
+                "category"
+            )
         return x, y_numpy, params

     def calculate_shap(self, x: pd.DataFrame, y: pd.Series, estimator) -> Optional[Dict[str, float]]:
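The wrappers now target-encode leftover string categoricals with category_encoders' CatBoostEncoder (fit once with the target, reuse at scoring time) instead of an OrdinalEncoder. A small self-contained sketch of that fit/transform pattern on toy data (column and label values are invented):

import pandas as pd
from category_encoders.cat_boost import CatBoostEncoder

x = pd.DataFrame({"color": ["red", "blue", "red", "green"]})
y = pd.Series([1, 0, 1, 0])

# Same pattern as the diff: fit on object-typed columns with the target,
# then keep the fitted encoder for reuse at scoring time.
encoder = CatBoostEncoder(random_state=42, return_df=True)
x[["color"]] = encoder.fit_transform(x[["color"]].astype("object"), y).astype("category")

# Later, on new data, only transform (no refit), as _prepare_to_calculate does:
x_new = pd.DataFrame({"color": ["blue", "green"]})
x_new[["color"]] = encoder.transform(x_new[["color"]].astype("object")).astype("category")

Unlike an ordinal encoding, the CatBoost encoding carries target statistics, so the encoded column remains informative for gradient-boosting models without exploding cardinality.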
@@ -805,20 +804,6 @@ class LightGBMWrapper(EstimatorWrapper):
             for i, col in enumerate(x.columns):
                 feature_importance[col] = np.mean(np.abs(shap_matrix[:, i]))

-            # # exclude last column (base value)
-            # shap_values_only = shap_values[:, :-1]
-            # mean_abs_shap = np.mean(np.abs(shap_values_only), axis=0)
-
-            # # For classification, shap_values is returned as a list for each class
-            # # Take values for the positive class
-            # if isinstance(shap_values, list):
-            #     shap_values = shap_values[1]
-
-            # # Calculate mean absolute SHAP value for each feature
-            # feature_importance = {}
-            # for i, col in enumerate(x.columns):
-            #     feature_importance[col] = np.mean(np.abs(shap_values[:, i]))
-
             return feature_importance

         except Exception as e:
@@ -831,6 +816,7 @@ class OtherEstimatorWrapper(EstimatorWrapper):
         self,
         estimator,
         scorer: Callable,
+        cat_features: Optional[List[str]],
         metric_name: str,
         multiplier: int,
         cv: BaseCrossValidator,
@@ -842,6 +828,7 @@ class OtherEstimatorWrapper(EstimatorWrapper):
         super(OtherEstimatorWrapper, self).__init__(
             estimator,
             scorer,
+            cat_features,
             metric_name,
             multiplier,
             cv,
@@ -850,32 +837,32 @@ class OtherEstimatorWrapper(EstimatorWrapper):
             text_features=text_features,
             logger=logger,
         )
-        self.cat_features = None

     def _prepare_to_fit(self, x: pd.DataFrame, y: np.ndarray) -> Tuple[pd.DataFrame, np.ndarray, np.ndarray, dict]:
-        x, y, groups, params = super()._prepare_to_fit(x, y)
-        self.cat_features = _get_cat_features(x)
+        x, y_numpy, groups, params = super()._prepare_to_fit(x, y)
+        self.cat_features, self.features_to_encode, self.exclude_features = _get_cat_features(
+            self.logger, x, self.cat_features
+        )
         num_features = [col for col in x.columns if col not in self.cat_features]
         x[num_features] = x[num_features].fillna(-999)
-        if self.cat_features:
-            # TODO use one-hot encoding if cardinality is less 50
-            for feature in self.cat_features:
-                x[feature] = x[feature].astype("category").cat.codes
-        if not is_numeric_dtype(y):
-            y = correct_string_target(y)
-        return x, y, groups, params
+        if self.cat_features:
+            encoder = CatBoostEncoder(random_state=DEFAULT_RANDOM_STATE, return_df=True)
+            encoded = encoder.fit_transform(x[self.cat_features].astype("object"), y_numpy).astype("category")
+            x[self.cat_features] = encoded
+            self.cat_encoder = encoder
+        return x, y_numpy, groups, params

     def _prepare_to_calculate(self, x: pd.DataFrame, y: pd.Series) -> Tuple[pd.DataFrame, np.ndarray, dict]:
+        if self.exclude_features:
+            x = x.drop(columns=self.exclude_features)
         x, y, params = super()._prepare_to_calculate(x, y)
         if self.cat_features is not None:
             num_features = [col for col in x.columns if col not in self.cat_features]
             x[num_features] = x[num_features].fillna(-999)
-            # TODO use one-hot encoding if cardinality is less 50
-            for feature in self.cat_features:
-                x[feature] = x[feature].astype("category").cat.codes
-
-            if not is_numeric_dtype(y):
-                y = correct_string_target(y)
+            if self.features_to_encode and self.cat_encoder is not None:
+                x[self.features_to_encode] = self.cat_encoder.transform(
+                    x[self.features_to_encode].astype("object")
+                ).astype("category")
         return x, y, params
@@ -938,7 +925,7 @@ def _get_scorer_by_name(scoring: str) -> Tuple[Callable, str, int]:
     return scoring, metric_name, multiplier


-def _get_scorer(target_type: ModelTaskType, scoring: Union[Callable, str, None]) -> Tuple[Callable, str, int]:
+def define_scorer(target_type: ModelTaskType, scoring: Union[Callable, str, None]) -> Tuple[Callable, str, int]:
     if scoring is None:
         if target_type == ModelTaskType.BINARY:
             scoring = "roc_auc"
@@ -957,16 +944,42 @@ def _get_scorer(target_type: ModelTaskType, scoring: Union[Callable, str, None])
     else:
         metric_name = str(scoring)

+    metric_name = "GINI" if metric_name.upper() == "ROC_AUC" and target_type == ModelTaskType.BINARY else metric_name
+
     return scoring, metric_name, multiplier


 def _get_cat_features(
-    x: pd.DataFrame, text_features: Optional[List[str]] = None, emb_features: Optional[List[str]] = None
+    logger: logging.Logger,
+    x: pd.DataFrame,
+    cat_features: Optional[List[str]],
+    text_features: Optional[List[str]] = None,
+    emb_features: Optional[List[str]] = None,
 ) -> List[str]:
+    cat_features = cat_features or []
     text_features = text_features or []
     emb_features = emb_features or []
     exclude_features = text_features + emb_features
-    return [c for c in x.columns if c not in exclude_features and not is_numeric_dtype(x[c])]
+    cat_features = [c for c in cat_features if c not in exclude_features]
+    unique_cat_features = []
+    drop_cat_features = []
+    for name in cat_features:
+        # Remove constant categorical features
+        if x[name].nunique() > 1:
+            unique_cat_features.append(name)
+        else:
+            logger.info(f"Drop column {name} on preparing data for fit")
+            x = x.drop(columns=name)
+            drop_cat_features.append(name)
+    cat_features = unique_cat_features
+
+    logger.info(f"Selected categorical features: {cat_features}")
+
+    features_to_encode = list(set(x.select_dtypes(exclude=[np.number, np.datetime64, pd.CategoricalDtype()]).columns))
+
+    logger.info(f"Features to encode: {features_to_encode}")
+
+    return cat_features, features_to_encode, drop_cat_features


 def _get_add_params(input_params, add_params):
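Factoring `_get_scorer` out into `define_scorer` also moves the GINI renaming out of `EstimatorWrapper.__init__`: the scorer, its display name, and the sign multiplier are now produced by one call. A quick illustration of the contract, assuming the default scoring resolution shown in the hunks above:

from upgini.metadata import ModelTaskType
from upgini.metrics import define_scorer

# Default scoring for a binary task resolves to roc_auc and is reported as GINI.
scorer, metric_name, multiplier = define_scorer(ModelTaskType.BINARY, None)
print(metric_name)  # "GINI"

# Loss-style metrics keep their name; the multiplier flips the sign so that
# higher is always better, e.g. define_scorer(ModelTaskType.REGRESSION, "mean_absolute_error").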
upgini/utils/target_utils.py
CHANGED
@@ -3,7 +3,7 @@ from typing import Callable, List, Optional, Union

 import numpy as np
 import pandas as pd
-from pandas.api.types import is_bool_dtype, is_numeric_dtype
+from pandas.api.types import is_bool_dtype, is_datetime64_any_dtype, is_numeric_dtype

 from upgini.errors import ValidationError
 from upgini.metadata import SYSTEM_RECORD_ID, CVType, ModelTaskType
@@ -14,11 +14,14 @@ from upgini.utils.ts_utils import get_most_frequent_time_unit, trunc_datetime
 TS_MIN_DIFFERENT_IDS_RATIO = 0.2


-def correct_string_target(y: Union[pd.Series, np.ndarray]) -> Union[pd.Series, np.ndarray]:
-    if isinstance(y, pd.Series):
-        return y.astype(str).astype("category").cat.codes
-    elif isinstance(y, np.ndarray):
-        return pd.Series(y).astype(str).astype("category").cat.codes.values
+def prepare_target(y: Union[pd.Series, np.ndarray], target_type: ModelTaskType) -> Union[pd.Series, np.ndarray]:
+    if target_type != ModelTaskType.REGRESSION or (not is_numeric_dtype(y) and not is_datetime64_any_dtype(y)):
+        if isinstance(y, pd.Series):
+            y = y.astype(str).astype("category").cat.codes
+        elif isinstance(y, np.ndarray):
+            y = pd.Series(y).astype(str).astype("category").cat.codes.values
+
+    return y


 def define_task(
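`prepare_target` generalizes the removed `correct_string_target`: any non-regression target, and any non-numeric, non-datetime regression target, is label-encoded through pandas category codes. A short demonstration of what that encoding does (toy labels):

import pandas as pd

# What prepare_target does to a string classification target:
# category codes assign a stable integer per distinct label.
y = pd.Series(["cat", "dog", "cat", "bird"])
encoded = y.astype(str).astype("category").cat.codes
print(encoded.tolist())  # [1, 2, 1, 0] - codes follow lexicographic category order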
{upgini-1.2.80.dist-info → upgini-1.2.81a3832.dev2.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: upgini
-Version: 1.2.80
+Version: 1.2.81a3832.dev2
 Summary: Intelligent data search & enrichment for Machine Learning
 Project-URL: Bug Reports, https://github.com/upgini/upgini/issues
 Project-URL: Homepage, https://upgini.com/
@@ -22,6 +22,8 @@ Classifier: Programming Language :: Python :: 3.11
 Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
 Classifier: Topic :: Scientific/Engineering :: Information Analysis
 Requires-Python: <3.12,>=3.10
+Requires-Dist: catboost>=1.2.8
+Requires-Dist: category-encoders>=2.8.1
 Requires-Dist: fastparquet>=0.8.1
 Requires-Dist: ipywidgets>=8.1.0
 Requires-Dist: jarowinkler>=2.0.0
{upgini-1.2.80.dist-info → upgini-1.2.81a3832.dev2.dist-info}/RECORD
CHANGED
@@ -1,12 +1,12 @@
-upgini/__about__.py,sha256=
+upgini/__about__.py,sha256=7ytM9g8DI6H-u5aMwPu2Qxa34E_K8afMwp4RaWapTSw,33
 upgini/__init__.py,sha256=LXSfTNU0HnlOkE69VCxkgIKDhWP-JFo_eBQ71OxTr5Y,261
 upgini/ads.py,sha256=nvuRxRx5MHDMgPr9SiU-fsqRdFaBv8p4_v1oqiysKpc,2714
 upgini/dataset.py,sha256=aspri7ZAgwkNNUiIgQ1GRXvw8XQii3F4RfNXSrF4wrw,35365
 upgini/errors.py,sha256=2b_Wbo0OYhLUbrZqdLIx5jBnAsiD1Mcenh-VjR4HCTw,950
-upgini/features_enricher.py,sha256=
-upgini/http.py,sha256=
+upgini/features_enricher.py,sha256=WiSVfmlHI9oKJQbyf46FH0yY80hBJ6hheFpugw0f_vE,210583
+upgini/http.py,sha256=AfaJ3c8z_tK2hZFEehNybDKE0mp1tYcyAP_l0_p8bLQ,43933
 upgini/metadata.py,sha256=Yd6iW2f7Wz6vUkg5uvR4xylN16ANnCKVKqAsAkap7p8,12354
-upgini/metrics.py,sha256=
+upgini/metrics.py,sha256=KxtcjiClNDNlMWpoCbAvVPveC59Nz7z2lA4b-hQozRE,39608
 upgini/search_task.py,sha256=RcvAE785yksWTsTNWuZFVNlk32jHElMoEna1T_C5N8Q,17823
 upgini/spinner.py,sha256=4iMd-eIe_BnkqFEMIliULTbj6rNI2HkN_VJ4qYe0cUc,1118
 upgini/version_validator.py,sha256=DvbaAvuYFoJqYt0fitpsk6Xcv-H1BYDJYHUMxaKSH_Y,1509
@@ -32,7 +32,7 @@ upgini/autofe/timeseries/trend.py,sha256=K1_iw2ko_LIUU8YCUgrvN3n0MkHtsi7-63-8x9e
 upgini/autofe/timeseries/volatility.py,sha256=9shUmIKjpWTHVYjj80YBsk0XheBJ9uBuLv5NW9Mchnk,7953
 upgini/data_source/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 upgini/data_source/data_source_publisher.py,sha256=4S9qwlAklD8vg9tUU_c1pHE2_glUHAh15-wr5hMwKFw,22879
-upgini/mdc/__init__.py,sha256=
+upgini/mdc/__init__.py,sha256=iHJlXQg6xRM1-ZOUtaPSJqw5SpQDszvxp4LyqviNLIQ,1027
 upgini/mdc/context.py,sha256=3u1B-jXt7tXEvNcV3qmR9SDCseudnY7KYsLclBdwVLk,1405
 upgini/normalizer/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 upgini/normalizer/normalize_utils.py,sha256=Ft2MwSgVoBilXAORAOYAuwPD79GOLfwn4qQE3IUFzzg,7218
@@ -66,11 +66,11 @@ upgini/utils/postal_code_utils.py,sha256=5M0sUqH2DAr33kARWCTXR-ACyzWbjDq_-0mmEml
 upgini/utils/progress_bar.py,sha256=N-Sfdah2Hg8lXP_fV9EfUTXz_PyRt4lo9fAHoUDOoLc,1550
 upgini/utils/sklearn_ext.py,sha256=HpaNQaKJisgNE7IZ71n7uswxTj7kbPglU2G3s1sORAc,45042
 upgini/utils/sort.py,sha256=8uuHs2nfSMVnz8GgvbOmgMB1PgEIZP1uhmeRFxcwnYw,7039
-upgini/utils/target_utils.py,sha256=
+upgini/utils/target_utils.py,sha256=LRN840dzx78-wg7ftdxAkp2c1eu8-JDvkACiRThm4HE,16832
 upgini/utils/track_info.py,sha256=G5Lu1xxakg2_TQjKZk4b5SvrHsATTXNVV3NbvWtT8k8,5663
 upgini/utils/ts_utils.py,sha256=26vhC0pN7vLXK6R09EEkMK3Lwb9IVPH7LRdqFIQ3kPs,1383
 upgini/utils/warning_counter.py,sha256=-GRY8EUggEBKODPSuXAkHn9KnEQwAORC0mmz_tim-PM,254
-upgini-1.2.80.dist-info/METADATA,sha256=
-upgini-1.2.80.dist-info/WHEEL,sha256=1yFddiXMmvYK7QYTqtRNtX66WJ0Mz8PYEiEUoOUUxRY,87
-upgini-1.2.80.dist-info/licenses/LICENSE,sha256=5RRzgvdJUu3BUDfv4bzVU6FqKgwHlIay63pPCSmSgzw,1514
-upgini-1.2.80.dist-info/RECORD,,
+upgini-1.2.81a3832.dev2.dist-info/METADATA,sha256=Kdxh014FUNln4eeF-RflHu3c_pfvPXpsoXfvb6SBneE,49172
+upgini-1.2.81a3832.dev2.dist-info/WHEEL,sha256=1yFddiXMmvYK7QYTqtRNtX66WJ0Mz8PYEiEUoOUUxRY,87
+upgini-1.2.81a3832.dev2.dist-info/licenses/LICENSE,sha256=5RRzgvdJUu3BUDfv4bzVU6FqKgwHlIay63pPCSmSgzw,1514
+upgini-1.2.81a3832.dev2.dist-info/RECORD,,
File without changes
|
File without changes
|