upgini 1.2.87.dev2__py3-none-any.whl → 1.2.87.dev3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- upgini/__about__.py +1 -1
- upgini/metrics.py +96 -34
- upgini/utils/datetime_utils.py +86 -78
- upgini/utils/sklearn_ext.py +112 -8
- {upgini-1.2.87.dev2.dist-info → upgini-1.2.87.dev3.dist-info}/METADATA +1 -1
- {upgini-1.2.87.dev2.dist-info → upgini-1.2.87.dev3.dist-info}/RECORD +8 -8
- {upgini-1.2.87.dev2.dist-info → upgini-1.2.87.dev3.dist-info}/WHEEL +0 -0
- {upgini-1.2.87.dev2.dist-info → upgini-1.2.87.dev3.dist-info}/licenses/LICENSE +0 -0
upgini/__about__.py
CHANGED

```diff
@@ -1 +1 @@
-__version__ = "1.2.87.dev2"
+__version__ = "1.2.87.dev3"
```
upgini/metrics.py
CHANGED

```diff
@@ -6,13 +6,23 @@ import re
 from collections import defaultdict
 from copy import deepcopy
 from dataclasses import dataclass
-from typing import
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    List,
+    Literal,
+    Optional,
+    Protocol,
+    Tuple,
+    Union,
+    runtime_checkable,
+)
 
 import lightgbm as lgb
 import numpy as np
 import pandas as pd
 from catboost import CatBoostClassifier, CatBoostRegressor
-from category_encoders.cat_boost import CatBoostEncoder
 from lightgbm import LGBMClassifier, LGBMRegressor
 from numpy import log1p
 from pandas.api.types import is_float_dtype, is_integer_dtype, is_numeric_dtype
@@ -32,10 +42,7 @@ except ImportError:
     available_scorers = SCORERS
 from sklearn.metrics import mean_squared_error
 from sklearn.metrics._regression import _check_reg_targets, check_consistent_length
-from sklearn.model_selection import (
-    BaseCrossValidator,
-    TimeSeriesSplit,
-)
+from sklearn.model_selection import BaseCrossValidator, TimeSeriesSplit
 
 from upgini.errors import ValidationError
 from upgini.metadata import ModelTaskType
@@ -57,6 +64,16 @@ CATBOOST_REGRESSION_PARAMS = {
     "allow_writing_files": False,
 }
 
+CATBOOST_TS_PARAMS = {
+    "learning_rate": 0.05,
+    "early_stopping_rounds": 20,
+    "use_best_model": True,
+    "one_hot_max_size": 100,
+    "verbose": False,
+    "random_state": 42,
+    "allow_writing_files": False,
+}
+
 CATBOOST_BINARY_PARAMS = {
     "iterations": 250,
     "learning_rate": 0.05,
@@ -311,6 +328,7 @@ class EstimatorWrapper:
         self.target_type = target_type
         self.add_params = add_params
         self.cv_estimators = None
+        self.cv_cat_encoders: Optional[List[Optional[HasTransform]]] = None
         self.groups = groups
         self.text_features = text_features
         self.logger = logger or logging.getLogger()
@@ -437,7 +455,9 @@ class EstimatorWrapper:
 
         return x, y, {}
 
-    def calculate_shap(
+    def calculate_shap(
+        self, x: pd.DataFrame, y: pd.Series, estimator, cat_encoder: Optional[HasTransform]
+    ) -> Optional[Dict[str, float]]:
         return None
 
     def cross_val_predict(
@@ -468,9 +488,11 @@ class EstimatorWrapper:
             fit_params=fit_params,
             return_estimator=True,
             error_score="raise",
+            random_state=DEFAULT_RANDOM_STATE,
         )
         metrics_by_fold = cv_results["test_score"]
         self.cv_estimators = cv_results["estimator"]
+        self.cv_cat_encoders = cv_results["cat_encoder"]
 
         self.check_fold_metrics(metrics_by_fold)
@@ -478,14 +500,14 @@ class EstimatorWrapper:
 
         splits = self.cv.split(x, y, groups)
 
-        for estimator, split in zip(self.cv_estimators, splits):
+        for estimator, cat_encoder, split in zip(self.cv_estimators, self.cv_cat_encoders, splits):
             _, validation_idx = split
             cv_x = x.iloc[validation_idx]
             if isinstance(y, pd.Series):
                 cv_y = y.iloc[validation_idx]
             else:
                 cv_y = y[validation_idx]
-            shaps = self.calculate_shap(cv_x, cv_y, estimator)
+            shaps = self.calculate_shap(cv_x, cv_y, estimator, cat_encoder)
             if shaps is not None:
                 for feature, shap_value in shaps.items():
                     shap_values_all_folds[feature].append(shap_value)
@@ -525,8 +547,19 @@ class EstimatorWrapper:
             metric, metric_std = roc_auc_score(y, x[baseline_score_column]), None
         else:
             metrics = []
-            for est in self.cv_estimators:
-                …
+            for est, cat_encoder in zip(self.cv_estimators, self.cv_cat_encoders):
+                x_copy = x.copy()
+                if cat_encoder is not None:
+                    if hasattr(cat_encoder, "feature_names_in_"):
+                        encoded = cat_encoder.transform(x_copy[cat_encoder.feature_names_in_])
+                    else:
+                        encoded = cat_encoder.transform(x[self.cat_features])
+                    if isinstance(self.cv, TimeSeriesSplit) or isinstance(self.cv, BlockedTimeSeriesSplit):
+                        encoded = encoded.astype(int)
+                    else:
+                        encoded = encoded.astype("category")
+                    x_copy[self.cat_features] = encoded
+                metrics.append(self.scorer(est, x_copy, y))
 
             metric, metric_std = self._calculate_metric_from_folds(metrics)
         return _CrossValResults(metric=metric, metric_std=metric_std, shap_values=None)
@@ -549,7 +582,7 @@ class EstimatorWrapper:
         text_features: Optional[List[str]] = None,
         add_params: Optional[Dict[str, Any]] = None,
         groups: Optional[List[str]] = None,
-        has_time:
+        has_time: bool = False,
     ) -> EstimatorWrapper:
         scorer, metric_name, multiplier = define_scorer(target_type, scoring)
         kwargs = {
@@ -576,7 +609,10 @@ class EstimatorWrapper:
             params = _get_add_params(params, add_params)
             estimator = CatBoostWrapper(CatBoostClassifier(**params), **kwargs)
         elif target_type == ModelTaskType.REGRESSION:
-            …
+            if not isinstance(cv, TimeSeriesSplit) and not isinstance(cv, BlockedTimeSeriesSplit):
+                params = _get_add_params(params, CATBOOST_TS_PARAMS)
+            else:
+                params = _get_add_params(params, CATBOOST_REGRESSION_PARAMS)
             params = _get_add_params(params, add_params)
             estimator = CatBoostWrapper(CatBoostRegressor(**params), **kwargs)
         else:
@@ -767,15 +803,24 @@ class CatBoostWrapper(EstimatorWrapper):
         else:
             raise e
 
-    def calculate_shap(self, x: pd.DataFrame, y: pd.Series, estimator) -> Optional[Dict[str, float]]:
+    def calculate_shap(self, x: pd.DataFrame, y: pd.Series, estimator, cat_encoder) -> Optional[Dict[str, float]]:
         try:
             from catboost import Pool
 
+            if cat_encoder is not None:
+                if isinstance(self.cv, TimeSeriesSplit) or isinstance(self.cv, BlockedTimeSeriesSplit):
+                    encoded = cat_encoder.transform(x[self.cat_features]).astype(int)
+                    cat_features = None
+                else:
+                    encoded = cat_encoder.transform(x[self.cat_features])
+                    cat_features = encoded.columns.to_list()
+                x[self.cat_features] = encoded
+
             # Create Pool for fold data, if need (for example, when categorical features are present)
             fold_pool = Pool(
                 x,
                 y,
-                cat_features=
+                cat_features=cat_features,
                 text_features=self.text_features,
                 embedding_features=self.grouped_embedding_features,
            )
```
```diff
@@ -832,7 +877,6 @@ class LightGBMWrapper(EstimatorWrapper):
             text_features=text_features,
             logger=logger,
         )
-        self.cat_encoder = None
         self.n_classes = None
 
     def _prepare_to_fit(self, x: pd.DataFrame, y: pd.Series) -> Tuple[pd.DataFrame, pd.Series, np.ndarray, dict]:
@@ -844,10 +888,10 @@ class LightGBMWrapper(EstimatorWrapper):
             params["eval_metric"] = "auc"
             params["callbacks"] = [lgb.early_stopping(stopping_rounds=LIGHTGBM_EARLY_STOPPING_ROUNDS, verbose=False)]
         if self.cat_features:
-            …
+            for c in self.cat_features:
+                if x[c].dtype != "category":
+                    x[c] = x[c].astype("category")
+
         for c in x.columns:
             if x[c].dtype not in ["category", "int64", "float64", "bool"]:
                 self.logger.warning(f"Feature {c} is not numeric and will be dropped")
@@ -857,15 +901,26 @@ class LightGBMWrapper(EstimatorWrapper):
 
     def _prepare_to_calculate(self, x: pd.DataFrame, y: pd.Series) -> Tuple[pd.DataFrame, np.ndarray, dict]:
         x, y_numpy, params = super()._prepare_to_calculate(x, y)
-        if self.cat_features
-            …
+        if self.cat_features:
+            for c in self.cat_features:
+                if x[c].dtype != "category":
+                    x[c] = x[c].astype("category")
         return x, y_numpy, params
 
-    def calculate_shap(
+    def calculate_shap(
+        self, x: pd.DataFrame, y: pd.Series, estimator, cat_encoder: Optional[HasTransform]
+    ) -> Optional[Dict[str, float]]:
         try:
+            x_copy = x.copy()
+            if cat_encoder is not None:
+                if isinstance(self.cv, TimeSeriesSplit) or isinstance(self.cv, BlockedTimeSeriesSplit):
+                    encoded = cat_encoder.transform(x_copy[self.cat_features]).astype(int)
+                else:
+                    encoded = cat_encoder.transform(x_copy[self.cat_features]).astype("category")
+                x_copy[self.cat_features] = encoded
+
             shap_matrix = estimator.predict(
-                …
+                x_copy,
                 predict_disable_shape_check=True,
                 raw_score=True,
                 pred_leaf=False,
@@ -924,10 +979,10 @@ class OtherEstimatorWrapper(EstimatorWrapper):
         num_features = [col for col in x.columns if col not in self.cat_features]
         x[num_features] = x[num_features].fillna(-999)
         if self.cat_features:
-            …
+            for c in self.cat_features:
+                if x[c].dtype != "category":
+                    x[c] = x[c].astype("category")
+            params["cat_features"] = self.cat_features
         for c in x.columns:
             if x[c].dtype not in ["category", "int64", "float64", "bool"]:
                 self.logger.warning(f"Feature {c} is not numeric and will be dropped")
@@ -938,15 +993,22 @@ class OtherEstimatorWrapper(EstimatorWrapper):
     def _prepare_to_calculate(self, x: pd.DataFrame, y: pd.Series) -> Tuple[pd.DataFrame, np.ndarray, dict]:
         x, y_numpy, params = super()._prepare_to_calculate(x, y)
         if self.cat_features is not None:
+            for c in self.cat_features:
+                if x[c].dtype != "category":
+                    x[c] = x[c].astype("category")
             num_features = [col for col in x.columns if col not in self.cat_features]
-            …
-            ).astype("category")
+        else:
+            num_features = x.columns
+        x[num_features] = x[num_features].fillna(-999)
+
 
         return x, y_numpy, params
 
+
+@runtime_checkable
+class HasTransform(Protocol):
+    def transform(self, X: pd.DataFrame, y: Optional[Union[pd.Series, np.ndarray]] = None) -> pd.DataFrame: ...
+
+
 def validate_scoring_argument(scoring: Union[Callable, str, None]):
     if scoring is None:
         return
```
upgini/utils/datetime_utils.py
CHANGED

```diff
@@ -251,99 +251,107 @@ def is_time_series(df: pd.DataFrame, date_col: str) -> bool:
 
 
 def is_blocked_time_series(df: pd.DataFrame, date_col: str, search_keys: List[str]) -> bool:
-    …
+    try:
+        df = df.copy()
+        seconds = "datetime_seconds"
+        if isinstance(df[date_col].dtype, pd.PeriodDtype):
+            df[date_col] = df[date_col].dt.to_timestamp()
+        elif is_numeric_dtype(df[date_col]):
+            df[date_col] = pd.to_datetime(df[date_col], unit="ms")
+        else:
+            df[date_col] = pd.to_datetime(df[date_col])
+        df[date_col] = df[date_col].dt.tz_localize(None)
+        df[seconds] = (df[date_col] - df[date_col].dt.floor("D")).dt.seconds
+
+        seconds_without_na = df[seconds].dropna()
+        columns_to_drop = [c for c in search_keys if c != date_col] + [seconds]
+        df.drop(columns=columns_to_drop, inplace=True)
+        # Date, not datetime
+        if (seconds_without_na != 0).any() and seconds_without_na.nunique() > 1:
+            return False
 
-    …
+        nunique_dates = df[date_col].nunique()
+        # Unique dates count more than 270
+        if nunique_dates < 270:
+            return False
 
-    …
+        min_date = df[date_col].min()
+        max_date = df[date_col].max()
+        days_delta = (max_date - min_date).days + 1
+        # Missing dates less than 30% (unique dates count and days delta between earliest and latest dates)
+        if nunique_dates / days_delta < 0.3:
+            return False
 
-    …
+        accumulated_changing_columns = set()
 
-    …
+        def check_differences(group: pd.DataFrame):
+            changing_columns = group.columns[group.nunique(dropna=False) > 1].to_list()
+            accumulated_changing_columns.update(changing_columns)
 
-    …
+        def is_multiple_rows(group: pd.DataFrame) -> bool:
+            return group.shape[0] > 1
 
-    …
+        grouped = df.groupby(date_col)[[c for c in df.columns if c != date_col]]
+        dates_with_multiple_rows = grouped.apply(is_multiple_rows).sum()
 
-    …
+        # share of dates with more than one record is more than 99%
+        if dates_with_multiple_rows / nunique_dates < 0.99:
+            return False
 
-    …
+        if df.shape[1] <= 3:
+            return True
 
-    …
+        grouped.apply(check_differences)
+        return len(accumulated_changing_columns) <= 2
+    except Exception:
+        return False
 
 
 def is_dates_distribution_valid(
     df: pd.DataFrame,
     search_keys: Dict[str, SearchKey],
 ) -> bool:
-    …
+    try:
+        maybe_date_col = SearchKey.find_key(search_keys, [SearchKey.DATE, SearchKey.DATETIME])
 
-    …
+        if EVAL_SET_INDEX in df.columns:
+            X = df.query(f"{EVAL_SET_INDEX} == 0")
+        else:
+            X = df
 
-    …
+        if maybe_date_col is None:
+            for col in X.columns:
+                if col in search_keys:
+                    continue
+                try:
+                    if isinstance(X[col].dtype, pd.PeriodDtype):
+                        pass
+                    elif pd.__version__ >= "2.0.0":
+                        # Format mixed to avoid massive warnings
+                        pd.to_datetime(X[col], format="mixed")
+                    else:
+                        pd.to_datetime(X[col])
+                    maybe_date_col = col
+                    break
+                except Exception:
                     pass
-    …
-    date_counts = dates.value_counts().sort_index()
-    …
-    date_counts_1 = date_counts[: round(len(date_counts) / 2)]
-    date_counts_2 = date_counts[round(len(date_counts) / 2) :]
-    ratio = date_counts_2.mean() / date_counts_1.mean()
-    …
-    return ratio >= 0.8 and ratio <= 1.2
+
+        if maybe_date_col is None:
+            return
+
+        if isinstance(X[maybe_date_col].dtype, pd.PeriodDtype):
+            dates = X[maybe_date_col].dt.to_timestamp().dt.date
+        elif pd.__version__ >= "2.0.0":
+            dates = pd.to_datetime(X[maybe_date_col], format="mixed").dt.date
+        else:
+            dates = pd.to_datetime(X[maybe_date_col]).dt.date
+
+        date_counts = dates.value_counts().sort_index()
+
+        date_counts_1 = date_counts[: round(len(date_counts) / 2)]
+        date_counts_2 = date_counts[round(len(date_counts) / 2) :]
+        ratio = date_counts_2.mean() / date_counts_1.mean()
+
+        return ratio >= 0.8 and ratio <= 1.2
+    except Exception:
+        return False
```
upgini/utils/sklearn_ext.py
CHANGED

```diff
@@ -1,4 +1,5 @@
 import functools
+import inspect
 import numbers
 import time
 import warnings
@@ -9,6 +10,7 @@ from traceback import format_exc
 
 import numpy as np
 import scipy.sparse as sp
+from category_encoders import CatBoostEncoder
 from joblib import Parallel, logger
 from scipy.sparse import issparse
 from sklearn import config_context, get_config
@@ -16,10 +18,13 @@ from sklearn.base import clone, is_classifier
 from sklearn.exceptions import FitFailedWarning, NotFittedError
 from sklearn.metrics import check_scoring
 from sklearn.metrics._scorer import _MultimetricScorer
-from sklearn.model_selection import StratifiedKFold, check_cv
+from sklearn.model_selection import StratifiedKFold, TimeSeriesSplit, check_cv
+from sklearn.preprocessing import OrdinalEncoder
 from sklearn.utils.fixes import np_version, parse_version
 from sklearn.utils.validation import indexable
 
+from upgini.utils.blocked_time_series import BlockedTimeSeriesSplit
+
 # from sklearn.model_selection import cross_validate as original_cross_validate
 
 _DEFAULT_TAGS = {
@@ -59,6 +64,7 @@ def cross_validate(
     return_train_score=False,
     return_estimator=False,
     error_score=np.nan,
+    random_state=None,
 ):
     """Evaluate metric(s) by cross-validation and also record fit/score times.
 
@@ -279,6 +285,8 @@ def cross_validate(
             return_times=True,
             return_estimator=return_estimator,
             error_score=error_score,
+            is_timeseries=isinstance(cv, TimeSeriesSplit) or isinstance(cv, BlockedTimeSeriesSplit),
+            random_state=random_state,
         )
         for train, test in cv.split(x, y, groups)
     )
@@ -296,6 +304,7 @@ def cross_validate(
     ret = {}
     ret["fit_time"] = results["fit_time"]
     ret["score_time"] = results["score_time"]
+    ret["cat_encoder"] = results["cat_encoder"]
 
     if return_estimator:
         ret["estimator"] = results["estimator"]
@@ -320,16 +329,16 @@ def cross_validate(
         else:
             shuffle = False
         if hasattr(cv, "random_state") and shuffle:
-            …
+            cv_random_state = cv.random_state
         else:
-            …
+            cv_random_state = None
         return cross_validate(
             estimator,
             x,
             y,
             groups=groups,
             scoring=scoring,
-            cv=StratifiedKFold(n_splits=cv.get_n_splits(), shuffle=shuffle, random_state=
+            cv=StratifiedKFold(n_splits=cv.get_n_splits(), shuffle=shuffle, random_state=cv_random_state),
             n_jobs=n_jobs,
             verbose=verbose,
             fit_params=fit_params,
@@ -337,21 +346,46 @@ def cross_validate(
             return_train_score=return_train_score,
             return_estimator=return_estimator,
             error_score=error_score,
+            random_state=random_state,
         )
     raise e
 
 
-def
+def _is_catboost_estimator(estimator):
     try:
         from catboost import CatBoostClassifier, CatBoostRegressor
+
         return isinstance(estimator, (CatBoostClassifier, CatBoostRegressor))
     except ImportError:
         return False
 
 
-def
+def _supports_cat_features(estimator) -> bool:
+    """Check if estimator's fit method accepts cat_features parameter.
+
+    Parameters
+    ----------
+    estimator : estimator object
+        The estimator to check.
+
+    Returns
+    -------
+    bool
+        True if estimator's fit method accepts cat_features parameter, False otherwise.
+    """
+    try:
+        # Get the signature of the fit method
+        fit_params = inspect.signature(estimator.fit).parameters
+        # Check if cat_features is in the parameters
+        return "cat_features" in fit_params
+    except (AttributeError, ValueError):
+        return False
+
+
+def _is_lightgbm_estimator(estimator):
     try:
         from lightgbm import LGBMClassifier, LGBMRegressor
+
         return isinstance(estimator, (LGBMClassifier, LGBMRegressor))
     except ImportError:
         return False
@@ -375,6 +409,8 @@ def _fit_and_score(
     split_progress=None,
     candidate_progress=None,
    error_score=np.nan,
+    is_timeseries=False,
+    random_state=None,
 ):
     """Fit estimator and compute scores for a given dataset split.
 
@@ -509,13 +545,24 @@ def _fit_and_score(
 
     result = {}
     try:
+        if "cat_features" in fit_params and fit_params["cat_features"]:
+            X_train, y_train, X_test, y_test, cat_features, cat_encoder = _encode_cat_features(
+                X_train, y_train, X_test, y_test, fit_params["cat_features"], estimator, is_timeseries, random_state
+            )
+            if cat_features and _supports_cat_features(estimator):
+                fit_params["cat_features"] = cat_features
+            else:
+                del fit_params["cat_features"]
+        else:
+            cat_encoder = None
+        result["cat_encoder"] = cat_encoder
        if y_train is None:
            estimator.fit(X_train, **fit_params)
        else:
-            if
+            if _is_catboost_estimator(estimator):
                fit_params = fit_params.copy()
                fit_params["eval_set"] = [(X_test, y_test)]
-            elif
+            elif _is_lightgbm_estimator(estimator):
                fit_params = fit_params.copy()
                fit_params["eval_set"] = [(X_test, y_test)]
            estimator.fit(X_train, y_train, **fit_params)
@@ -1245,3 +1292,60 @@ def _num_samples(x):
         return len(x)
     except TypeError as type_error:
         raise TypeError(message) from type_error
+
+
+def _encode_cat_features(X_train, y_train, X_test, y_test, cat_features, estimator, is_timeseries, random_state):
+    if _is_catboost_estimator(estimator):
+        if is_timeseries:
+            # Fit encoder on training fold
+            encoder = OrdinalEncoder(handle_unknown="use_encoded_value", unknown_value=-1)
+            encoder.fit(X_train[cat_features], y_train)
+
+            X_train[cat_features] = encoder.transform(X_train[cat_features]).astype(int)
+            X_test[cat_features] = encoder.transform(X_test[cat_features]).astype(int)
+
+            # Don't use as categorical features, so CatBoost will not encode them
+            return X_train, y_train, X_test, y_test, [], encoder
+        else:
+            return X_train, y_train, X_test, y_test, cat_features, None
+    else:
+        if is_timeseries:
+            # Fit encoder on training fold
+            encoder = OrdinalEncoder(handle_unknown="use_encoded_value", unknown_value=-1)
+            encoder.fit(X_train[cat_features], y_train)
+
+            # Progressive encoding on train (using y)
+            X_train[cat_features] = encoder.transform(X_train[cat_features], y_train).astype(int)
+
+            # Static encoding on validation (no y)
+            X_test[cat_features] = encoder.transform(X_test[cat_features]).astype(int)
+
+            return X_train, y_train, X_test, y_test, [], encoder
+        else:
+            # Shuffle train data
+            X_train_shuffled, y_train_shuffled = _shuffle_pair(
+                X_train[cat_features].astype("object"), y_train, random_state
+            )
+
+            # Fit encoder on training fold
+            encoder = CatBoostEncoder(random_state=random_state, cols=cat_features)
+            encoder.fit(X_train_shuffled, y_train_shuffled)
+
+            # Progressive encoding on train (using y)
+            X_train[cat_features] = encoder.transform(X_train[cat_features], y_train).astype("category")
+
+            # Static encoding on validation (no y)
+            X_test[cat_features] = encoder.transform(X_test[cat_features]).astype("category")
+
+            return X_train, y_train, X_test, y_test, cat_features, encoder
+
+
+def _shuffle_pair(X, y, random_state):
+    # If X doesn't have reseted index there could be a problem
+    # shuffled_idx = np.random.RandomState(random_state).permutation(len(X))
+    # return X.iloc[shuffled_idx], pd.Series(y).iloc[shuffled_idx]
+
+    Xy = X.copy()
+    Xy["target"] = y
+    Xy_shuffled = Xy.sample(frac=1, random_state=random_state)
+    return Xy_shuffled.drop(columns="target"), Xy_shuffled["target"]
```
{upgini-1.2.87.dev2.dist-info → upgini-1.2.87.dev3.dist-info}/RECORD
CHANGED

```diff
@@ -1,4 +1,4 @@
-upgini/__about__.py,sha256
+upgini/__about__.py,sha256=-MoNpjvEXC0uIle8xxIgQduzBZJlNzuW-1rPMTm_xc8,28
 upgini/__init__.py,sha256=LXSfTNU0HnlOkE69VCxkgIKDhWP-JFo_eBQ71OxTr5Y,261
 upgini/ads.py,sha256=nvuRxRx5MHDMgPr9SiU-fsqRdFaBv8p4_v1oqiysKpc,2714
 upgini/dataset.py,sha256=fRtqSkXNONLnPe6cCL967GMt349FTIpXzy_u8LUKncw,35354
@@ -6,7 +6,7 @@ upgini/errors.py,sha256=2b_Wbo0OYhLUbrZqdLIx5jBnAsiD1Mcenh-VjR4HCTw,950
 upgini/features_enricher.py,sha256=n8KBoBgJApLiRv4wXeSgfS-PfbB1D5aDOJfFnL0q6v8,214487
 upgini/http.py,sha256=6Qcepv0tDC72mBBJxYHnA2xqw6QwFaKrXN8o4vju8Es,44372
 upgini/metadata.py,sha256=zt_9k0iQbWXuiRZcel4ORNPdQKt6Ou69ucZD_E1Q46o,12341
-upgini/metrics.py,sha256=
+upgini/metrics.py,sha256=CR_MKBcq1RlNMXeqc9S374JzHgunMl-mEmlTnZAm_VI,45236
 upgini/search_task.py,sha256=Q5HjBpLIB3OCxAD1zNv5yQ3ZNJx696WCK_-H35_y7Rs,17912
 upgini/spinner.py,sha256=4iMd-eIe_BnkqFEMIliULTbj6rNI2HkN_VJ4qYe0cUc,1118
 upgini/version_validator.py,sha256=DvbaAvuYFoJqYt0fitpsk6Xcv-H1BYDJYHUMxaKSH_Y,1509
@@ -51,7 +51,7 @@ upgini/utils/blocked_time_series.py,sha256=Uqr3vp4YqNclj2-PzEYqVy763GSXHn86sbpIl
 upgini/utils/country_utils.py,sha256=lY-eXWwFVegdVENFttbvLcgGDjFO17Sex8hd2PyJaRk,6937
 upgini/utils/custom_loss_utils.py,sha256=kieNZYBYZm5ZGBltF1F_jOSF4ea6C29rYuCyiDcqVNY,3857
 upgini/utils/cv_utils.py,sha256=w6FQb9nO8BWDx88EF83NpjPLarK4eR4ia0Wg0kLBJC4,3525
-upgini/utils/datetime_utils.py,sha256=
+upgini/utils/datetime_utils.py,sha256=UL1ernnawW0LV9mPDpCIc6sFy0HUhFscWVNwfH4V7rI,14366
 upgini/utils/deduplicate_utils.py,sha256=jm9ARZ0fbJFF3aJqj-xm_T6lNh-WErM0H0h6B_L1xQc,8948
 upgini/utils/display_utils.py,sha256=hAeWEcJtPDg8fAVcMNrNB-azFD2WJp1nvbPAhR7SeP4,12071
 upgini/utils/email_utils.py,sha256=pZ2vCfNxLIPUhxr0-OlABNXm12jjU44isBk8kGmqQzA,5277
@@ -64,13 +64,13 @@ upgini/utils/mstats.py,sha256=u3gQVUtDRbyrOQK6V1UJ2Rx1QbkSNYGjXa6m3Z_dPVs,6286
 upgini/utils/phone_utils.py,sha256=IrbztLuOJBiePqqxllfABWfYlfAjYevPhXKipl95wUI,10432
 upgini/utils/postal_code_utils.py,sha256=5M0sUqH2DAr33kARWCTXR-ACyzWbjDq_-0mmEml6ZcU,1716
 upgini/utils/progress_bar.py,sha256=N-Sfdah2Hg8lXP_fV9EfUTXz_PyRt4lo9fAHoUDOoLc,1550
-upgini/utils/sklearn_ext.py,sha256=
+upgini/utils/sklearn_ext.py,sha256=Mdxz0tc-9zT4QyNccA3B86fY4l0MnLDr94POVdYeCT4,49332
 upgini/utils/sort.py,sha256=8uuHs2nfSMVnz8GgvbOmgMB1PgEIZP1uhmeRFxcwnYw,7039
 upgini/utils/target_utils.py,sha256=LRN840dzx78-wg7ftdxAkp2c1eu8-JDvkACiRThm4HE,16832
 upgini/utils/track_info.py,sha256=G5Lu1xxakg2_TQjKZk4b5SvrHsATTXNVV3NbvWtT8k8,5663
 upgini/utils/ts_utils.py,sha256=26vhC0pN7vLXK6R09EEkMK3Lwb9IVPH7LRdqFIQ3kPs,1383
 upgini/utils/warning_counter.py,sha256=-GRY8EUggEBKODPSuXAkHn9KnEQwAORC0mmz_tim-PM,254
-upgini-1.2.87.
-upgini-1.2.87.
-upgini-1.2.87.
-upgini-1.2.87.
+upgini-1.2.87.dev3.dist-info/METADATA,sha256=Pm-acVK8TpDLvPsO0qluwSjmu0cb3FHmtXmqMj--2Ag,49167
+upgini-1.2.87.dev3.dist-info/WHEEL,sha256=1yFddiXMmvYK7QYTqtRNtX66WJ0Mz8PYEiEUoOUUxRY,87
+upgini-1.2.87.dev3.dist-info/licenses/LICENSE,sha256=5RRzgvdJUu3BUDfv4bzVU6FqKgwHlIay63pPCSmSgzw,1514
+upgini-1.2.87.dev3.dist-info/RECORD,,
```
{upgini-1.2.87.dev2.dist-info → upgini-1.2.87.dev3.dist-info}/WHEEL
File without changes

{upgini-1.2.87.dev2.dist-info → upgini-1.2.87.dev3.dist-info}/licenses/LICENSE
File without changes
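As an aside on reading the RECORD diff above: each row has the form `path,sha256=<hash>,size`, where the hash is the urlsafe base64 of the file's SHA-256 digest with `=` padding stripped. A small sketch for verifying an entry locally; `record_hash` is a hypothetical helper name, not part of any packaging library:

```python
# Sketch: recomputing a wheel RECORD hash for an extracted file
import base64
import hashlib


def record_hash(path: str) -> str:
    with open(path, "rb") as f:
        digest = hashlib.sha256(f.read()).digest()
    return "sha256=" + base64.urlsafe_b64encode(digest).rstrip(b"=").decode()


# e.g. record_hash("upgini/__about__.py") on the unpacked dev3 wheel
# should reproduce the "sha256=-MoNpjvEXC0uIle8..." value recorded above.
```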