upgini 1.2.81a3832.dev7__py3-none-any.whl → 1.2.81a3832.dev9__py3-none-any.whl
This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- upgini/__about__.py +1 -1
- upgini/features_enricher.py +5 -5
- upgini/metrics.py +48 -30
- {upgini-1.2.81a3832.dev7.dist-info → upgini-1.2.81a3832.dev9.dist-info}/METADATA +1 -1
- {upgini-1.2.81a3832.dev7.dist-info → upgini-1.2.81a3832.dev9.dist-info}/RECORD +7 -7
- {upgini-1.2.81a3832.dev7.dist-info → upgini-1.2.81a3832.dev9.dist-info}/WHEEL +0 -0
- {upgini-1.2.81a3832.dev7.dist-info → upgini-1.2.81a3832.dev9.dist-info}/licenses/LICENSE +0 -0
upgini/__about__.py
CHANGED
@@ -1 +1 @@
|
|
1
|
-
__version__ = "1.2.81a3832.dev7"
|
1
|
+
__version__ = "1.2.81a3832.dev9"
|
upgini/features_enricher.py
CHANGED
@@ -1768,10 +1768,10 @@ class FeaturesEnricher(TransformerMixin):
|
|
1768
1768
|
df = generator.generate(df)
|
1769
1769
|
generated_features.extend(generator.generated_features)
|
1770
1770
|
|
1771
|
-
|
1772
|
-
|
1773
|
-
|
1774
|
-
columns_renaming = {c: c for c in df.columns}
|
1771
|
+
normalizer = Normalizer(self.bundle, self.logger)
|
1772
|
+
df, search_keys, generated_features = normalizer.normalize(df, search_keys, generated_features)
|
1773
|
+
columns_renaming = normalizer.columns_renaming
|
1774
|
+
# columns_renaming = {c: c for c in df.columns}
|
1775
1775
|
|
1776
1776
|
df, _ = clean_full_duplicates(df, logger=self.logger, bundle=self.bundle)
|
1777
1777
|
|
@@ -3881,7 +3881,7 @@ if response.status_code == 200:
|
|
3881
3881
|
if features_meta is None:
|
3882
3882
|
raise Exception(self.bundle.get("missing_features_meta"))
|
3883
3883
|
|
3884
|
-
return [f.name for f in features_meta if f.type == "categorical"
|
3884
|
+
return [f.name for f in features_meta if f.type == "categorical"]
|
3885
3885
|
|
3886
3886
|
def __prepare_feature_importances(
|
3887
3887
|
self, trace_id: str, df: pd.DataFrame, updated_shaps: Optional[Dict[str, float]] = None, silent=False
|
upgini/metrics.py
CHANGED
@@ -6,7 +6,7 @@ import re
|
|
6
6
|
from collections import defaultdict
|
7
7
|
from copy import deepcopy
|
8
8
|
from dataclasses import dataclass
|
9
|
-
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
9
|
+
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
|
10
10
|
|
11
11
|
import lightgbm as lgb
|
12
12
|
import numpy as np
|
@@ -18,6 +18,7 @@ from numpy import log1p
|
|
18
18
|
from pandas.api.types import is_numeric_dtype
|
19
19
|
from sklearn.metrics import check_scoring, get_scorer, make_scorer, roc_auc_score
|
20
20
|
|
21
|
+
from upgini.utils.blocked_time_series import BlockedTimeSeriesSplit
|
21
22
|
from upgini.utils.features_validator import FeaturesValidator
|
22
23
|
from upgini.utils.sklearn_ext import cross_validate
|
23
24
|
|
@@ -31,7 +32,7 @@ except ImportError:
|
|
31
32
|
available_scorers = SCORERS
|
32
33
|
from sklearn.metrics import mean_squared_error
|
33
34
|
from sklearn.metrics._regression import _check_reg_targets, check_consistent_length
|
34
|
-
from sklearn.model_selection import BaseCrossValidator # , TimeSeriesSplit
|
35
|
+
from sklearn.model_selection import BaseCrossValidator, TimeSeriesSplit # , TimeSeriesSplit
|
35
36
|
|
36
37
|
from upgini.errors import ValidationError
|
37
38
|
from upgini.metadata import ModelTaskType
|
@@ -250,6 +251,8 @@ class _CrossValResults:
|
|
250
251
|
|
251
252
|
|
252
253
|
class EstimatorWrapper:
|
254
|
+
default_estimator: Literal["catboost", "lightgbm"] = "catboost"
|
255
|
+
|
253
256
|
def __init__(
|
254
257
|
self,
|
255
258
|
estimator,
|
@@ -303,6 +306,8 @@ class EstimatorWrapper:
|
|
303
306
|
else:
|
304
307
|
if x[c].dtype == "category" and x[c].cat.categories.dtype == np.int64:
|
305
308
|
x[c] = x[c].astype(np.int64)
|
309
|
+
elif not is_numeric_dtype(x[c]):
|
310
|
+
x[c] = x[c].astype(str).astype("category")
|
306
311
|
|
307
312
|
if not isinstance(y, pd.Series):
|
308
313
|
raise Exception(bundle.get("metrics_unsupported_target_type").format(type(y)))
|
@@ -352,6 +357,7 @@ class EstimatorWrapper:
|
|
352
357
|
self.logger.info("Calculate baseline GINI on passed baseline_score_column and target")
|
353
358
|
metric = roc_auc_score(y, x[baseline_score_column])
|
354
359
|
else:
|
360
|
+
self.logger.info(f"Cross validate with estimeator: {self.estimator}")
|
355
361
|
cv_results = cross_validate(
|
356
362
|
estimator=self.estimator,
|
357
363
|
x=x,
|
@@ -458,31 +464,43 @@ class EstimatorWrapper:
|
|
458
464
|
"logger": logger,
|
459
465
|
}
|
460
466
|
if estimator is None:
|
461
|
-
|
462
|
-
|
463
|
-
params =
|
464
|
-
|
465
|
-
|
466
|
-
|
467
|
-
|
468
|
-
|
469
|
-
|
470
|
-
|
471
|
-
|
472
|
-
|
473
|
-
|
474
|
-
|
475
|
-
|
476
|
-
|
477
|
-
|
478
|
-
|
479
|
-
|
480
|
-
|
481
|
-
|
482
|
-
|
483
|
-
|
467
|
+
if EstimatorWrapper.default_estimator == "catboost":
|
468
|
+
logger.info("Using CatBoost as default estimator")
|
469
|
+
params = {"has_time": has_date}
|
470
|
+
if target_type == ModelTaskType.MULTICLASS:
|
471
|
+
params = _get_add_params(params, CATBOOST_MULTICLASS_PARAMS)
|
472
|
+
params = _get_add_params(params, add_params)
|
473
|
+
estimator = CatBoostWrapper(CatBoostClassifier(**params), **kwargs)
|
474
|
+
elif target_type == ModelTaskType.BINARY:
|
475
|
+
params = _get_add_params(params, CATBOOST_BINARY_PARAMS)
|
476
|
+
params = _get_add_params(params, add_params)
|
477
|
+
estimator = CatBoostWrapper(CatBoostClassifier(**params), **kwargs)
|
478
|
+
elif target_type == ModelTaskType.REGRESSION:
|
479
|
+
params = _get_add_params(params, CATBOOST_REGRESSION_PARAMS)
|
480
|
+
params = _get_add_params(params, add_params)
|
481
|
+
estimator = CatBoostWrapper(CatBoostRegressor(**params), **kwargs)
|
482
|
+
else:
|
483
|
+
raise Exception(bundle.get("metrics_unsupported_target_type").format(target_type))
|
484
|
+
elif EstimatorWrapper.default_estimator == "lightgbm":
|
485
|
+
logger.info("Using LightGBM as default estimator")
|
486
|
+
params = {"random_state": DEFAULT_RANDOM_STATE, "verbose": -1}
|
487
|
+
if target_type == ModelTaskType.MULTICLASS:
|
488
|
+
params = _get_add_params(params, LIGHTGBM_MULTICLASS_PARAMS)
|
489
|
+
params = _get_add_params(params, add_params)
|
490
|
+
estimator = LightGBMWrapper(LGBMClassifier(**params), **kwargs)
|
491
|
+
elif target_type == ModelTaskType.BINARY:
|
492
|
+
params = _get_add_params(params, LIGHTGBM_BINARY_PARAMS)
|
493
|
+
params = _get_add_params(params, add_params)
|
494
|
+
estimator = LightGBMWrapper(LGBMClassifier(**params), **kwargs)
|
495
|
+
elif target_type == ModelTaskType.REGRESSION:
|
496
|
+
if not isinstance(cv, TimeSeriesSplit) and not isinstance(cv, BlockedTimeSeriesSplit):
|
497
|
+
params = _get_add_params(params, LIGHTGBM_REGRESSION_PARAMS)
|
498
|
+
params = _get_add_params(params, add_params)
|
499
|
+
estimator = LightGBMWrapper(LGBMRegressor(**params), **kwargs)
|
500
|
+
else:
|
501
|
+
raise Exception(bundle.get("metrics_unsupported_target_type").format(target_type))
|
484
502
|
else:
|
485
|
-
raise Exception(
|
503
|
+
raise Exception("Unsupported default_estimator. Available: catboost, lightgbm")
|
486
504
|
else:
|
487
505
|
if hasattr(estimator, "copy"):
|
488
506
|
estimator_copy = estimator.copy()
|
@@ -490,8 +508,8 @@ class EstimatorWrapper:
|
|
490
508
|
estimator_copy = deepcopy(estimator)
|
491
509
|
kwargs["estimator"] = estimator_copy
|
492
510
|
if is_catboost_estimator(estimator):
|
493
|
-
if
|
494
|
-
estimator_copy.set_params(
|
511
|
+
if has_date is not None:
|
512
|
+
estimator_copy.set_params(has_time=has_date)
|
495
513
|
estimator = CatBoostWrapper(**kwargs)
|
496
514
|
else:
|
497
515
|
if isinstance(estimator, (LGBMClassifier, LGBMRegressor)):
|
@@ -941,8 +959,8 @@ def _get_cat_features(
|
|
941
959
|
|
942
960
|
logger.info(f"Selected categorical features: {cat_features}")
|
943
961
|
|
944
|
-
|
945
|
-
features_to_encode = [f for f in cat_features if f
|
962
|
+
features_to_encode = list(set(x.select_dtypes(exclude=[np.number, np.datetime64, pd.CategoricalDtype()]).columns))
|
963
|
+
features_to_encode = [f for f in cat_features if f in features_to_encode]
|
946
964
|
|
947
965
|
logger.info(f"Features to encode: {features_to_encode}")
|
948
966
|
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.3
|
2
2
|
Name: upgini
|
3
|
-
Version: 1.2.81a3832.dev7
|
3
|
+
Version: 1.2.81a3832.dev9
|
4
4
|
Summary: Intelligent data search & enrichment for Machine Learning
|
5
5
|
Project-URL: Bug Reports, https://github.com/upgini/upgini/issues
|
6
6
|
Project-URL: Homepage, https://upgini.com/
|
@@ -1,12 +1,12 @@
|
|
1
|
-
upgini/__about__.py,sha256=
|
1
|
+
upgini/__about__.py,sha256=wEcwloV3XNyxWA40HLqEb4PIXttvc8pREucBfzAKW0c,33
|
2
2
|
upgini/__init__.py,sha256=LXSfTNU0HnlOkE69VCxkgIKDhWP-JFo_eBQ71OxTr5Y,261
|
3
3
|
upgini/ads.py,sha256=nvuRxRx5MHDMgPr9SiU-fsqRdFaBv8p4_v1oqiysKpc,2714
|
4
4
|
upgini/dataset.py,sha256=aspri7ZAgwkNNUiIgQ1GRXvw8XQii3F4RfNXSrF4wrw,35365
|
5
5
|
upgini/errors.py,sha256=2b_Wbo0OYhLUbrZqdLIx5jBnAsiD1Mcenh-VjR4HCTw,950
|
6
|
-
upgini/features_enricher.py,sha256=
|
6
|
+
upgini/features_enricher.py,sha256=ZSSukaq4_mngCkJyQe-XCssXbH8nOD7ByWfSHi9nypc,210847
|
7
7
|
upgini/http.py,sha256=AfaJ3c8z_tK2hZFEehNybDKE0mp1tYcyAP_l0_p8bLQ,43933
|
8
8
|
upgini/metadata.py,sha256=Yd6iW2f7Wz6vUkg5uvR4xylN16ANnCKVKqAsAkap7p8,12354
|
9
|
-
upgini/metrics.py,sha256=
|
9
|
+
upgini/metrics.py,sha256=4ehQO8VEebKLiCuBq2LRqC2QbPIqswoe7b1pnR_-zQA,39985
|
10
10
|
upgini/search_task.py,sha256=RcvAE785yksWTsTNWuZFVNlk32jHElMoEna1T_C5N8Q,17823
|
11
11
|
upgini/spinner.py,sha256=4iMd-eIe_BnkqFEMIliULTbj6rNI2HkN_VJ4qYe0cUc,1118
|
12
12
|
upgini/version_validator.py,sha256=DvbaAvuYFoJqYt0fitpsk6Xcv-H1BYDJYHUMxaKSH_Y,1509
|
@@ -70,7 +70,7 @@ upgini/utils/target_utils.py,sha256=LRN840dzx78-wg7ftdxAkp2c1eu8-JDvkACiRThm4HE,
|
|
70
70
|
upgini/utils/track_info.py,sha256=G5Lu1xxakg2_TQjKZk4b5SvrHsATTXNVV3NbvWtT8k8,5663
|
71
71
|
upgini/utils/ts_utils.py,sha256=26vhC0pN7vLXK6R09EEkMK3Lwb9IVPH7LRdqFIQ3kPs,1383
|
72
72
|
upgini/utils/warning_counter.py,sha256=-GRY8EUggEBKODPSuXAkHn9KnEQwAORC0mmz_tim-PM,254
|
73
|
-
upgini-1.2.81a3832.
|
74
|
-
upgini-1.2.81a3832.
|
75
|
-
upgini-1.2.81a3832.
|
76
|
-
upgini-1.2.81a3832.
|
73
|
+
upgini-1.2.81a3832.dev9.dist-info/METADATA,sha256=6jP4TJl2tN98P8wuWIBARzrPtZVRT48uPukgTvZOvlA,49172
|
74
|
+
upgini-1.2.81a3832.dev9.dist-info/WHEEL,sha256=1yFddiXMmvYK7QYTqtRNtX66WJ0Mz8PYEiEUoOUUxRY,87
|
75
|
+
upgini-1.2.81a3832.dev9.dist-info/licenses/LICENSE,sha256=5RRzgvdJUu3BUDfv4bzVU6FqKgwHlIay63pPCSmSgzw,1514
|
76
|
+
upgini-1.2.81a3832.dev9.dist-info/RECORD,,
|
File without changes
|
File without changes
|