upgini 1.2.86a2__py3-none-any.whl → 1.2.87__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- upgini/__about__.py +1 -1
- upgini/data_source/data_source_publisher.py +21 -0
- upgini/features_enricher.py +91 -41
- upgini/metrics.py +103 -41
- upgini/resource_bundle/strings.properties +3 -1
- upgini/utils/datetime_utils.py +130 -118
- upgini/utils/deduplicate_utils.py +4 -4
- upgini/utils/sklearn_ext.py +112 -8
- {upgini-1.2.86a2.dist-info → upgini-1.2.87.dist-info}/METADATA +1 -1
- {upgini-1.2.86a2.dist-info → upgini-1.2.87.dist-info}/RECORD +12 -12
- {upgini-1.2.86a2.dist-info → upgini-1.2.87.dist-info}/WHEEL +0 -0
- {upgini-1.2.86a2.dist-info → upgini-1.2.87.dist-info}/licenses/LICENSE +0 -0
upgini/utils/sklearn_ext.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1
1
|
import functools
|
2
|
+
import inspect
|
2
3
|
import numbers
|
3
4
|
import time
|
4
5
|
import warnings
|
@@ -9,6 +10,7 @@ from traceback import format_exc
|
|
9
10
|
|
10
11
|
import numpy as np
|
11
12
|
import scipy.sparse as sp
|
13
|
+
from category_encoders import CatBoostEncoder
|
12
14
|
from joblib import Parallel, logger
|
13
15
|
from scipy.sparse import issparse
|
14
16
|
from sklearn import config_context, get_config
|
@@ -16,10 +18,13 @@ from sklearn.base import clone, is_classifier
|
|
16
18
|
from sklearn.exceptions import FitFailedWarning, NotFittedError
|
17
19
|
from sklearn.metrics import check_scoring
|
18
20
|
from sklearn.metrics._scorer import _MultimetricScorer
|
19
|
-
from sklearn.model_selection import StratifiedKFold, check_cv
|
21
|
+
from sklearn.model_selection import StratifiedKFold, TimeSeriesSplit, check_cv
|
22
|
+
from sklearn.preprocessing import OrdinalEncoder
|
20
23
|
from sklearn.utils.fixes import np_version, parse_version
|
21
24
|
from sklearn.utils.validation import indexable
|
22
25
|
|
26
|
+
from upgini.utils.blocked_time_series import BlockedTimeSeriesSplit
|
27
|
+
|
23
28
|
# from sklearn.model_selection import cross_validate as original_cross_validate
|
24
29
|
|
25
30
|
_DEFAULT_TAGS = {
|
@@ -59,6 +64,7 @@ def cross_validate(
|
|
59
64
|
return_train_score=False,
|
60
65
|
return_estimator=False,
|
61
66
|
error_score=np.nan,
|
67
|
+
random_state=None,
|
62
68
|
):
|
63
69
|
"""Evaluate metric(s) by cross-validation and also record fit/score times.
|
64
70
|
|
@@ -279,6 +285,8 @@ def cross_validate(
|
|
279
285
|
return_times=True,
|
280
286
|
return_estimator=return_estimator,
|
281
287
|
error_score=error_score,
|
288
|
+
is_timeseries=isinstance(cv, TimeSeriesSplit) or isinstance(cv, BlockedTimeSeriesSplit),
|
289
|
+
random_state=random_state,
|
282
290
|
)
|
283
291
|
for train, test in cv.split(x, y, groups)
|
284
292
|
)
|
@@ -296,6 +304,7 @@ def cross_validate(
|
|
296
304
|
ret = {}
|
297
305
|
ret["fit_time"] = results["fit_time"]
|
298
306
|
ret["score_time"] = results["score_time"]
|
307
|
+
ret["cat_encoder"] = results["cat_encoder"]
|
299
308
|
|
300
309
|
if return_estimator:
|
301
310
|
ret["estimator"] = results["estimator"]
|
@@ -320,16 +329,16 @@ def cross_validate(
|
|
320
329
|
else:
|
321
330
|
shuffle = False
|
322
331
|
if hasattr(cv, "random_state") and shuffle:
|
323
|
-
|
332
|
+
cv_random_state = cv.random_state
|
324
333
|
else:
|
325
|
-
|
334
|
+
cv_random_state = None
|
326
335
|
return cross_validate(
|
327
336
|
estimator,
|
328
337
|
x,
|
329
338
|
y,
|
330
339
|
groups=groups,
|
331
340
|
scoring=scoring,
|
332
|
-
cv=StratifiedKFold(n_splits=cv.get_n_splits(), shuffle=shuffle, random_state=
|
341
|
+
cv=StratifiedKFold(n_splits=cv.get_n_splits(), shuffle=shuffle, random_state=cv_random_state),
|
333
342
|
n_jobs=n_jobs,
|
334
343
|
verbose=verbose,
|
335
344
|
fit_params=fit_params,
|
@@ -337,21 +346,46 @@ def cross_validate(
|
|
337
346
|
return_train_score=return_train_score,
|
338
347
|
return_estimator=return_estimator,
|
339
348
|
error_score=error_score,
|
349
|
+
random_state=random_state,
|
340
350
|
)
|
341
351
|
raise e
|
342
352
|
|
343
353
|
|
344
|
-
def
|
354
|
+
def _is_catboost_estimator(estimator):
|
345
355
|
try:
|
346
356
|
from catboost import CatBoostClassifier, CatBoostRegressor
|
357
|
+
|
347
358
|
return isinstance(estimator, (CatBoostClassifier, CatBoostRegressor))
|
348
359
|
except ImportError:
|
349
360
|
return False
|
350
361
|
|
351
362
|
|
352
|
-
def
|
363
|
+
def _supports_cat_features(estimator) -> bool:
|
364
|
+
"""Check if estimator's fit method accepts cat_features parameter.
|
365
|
+
|
366
|
+
Parameters
|
367
|
+
----------
|
368
|
+
estimator : estimator object
|
369
|
+
The estimator to check.
|
370
|
+
|
371
|
+
Returns
|
372
|
+
-------
|
373
|
+
bool
|
374
|
+
True if estimator's fit method accepts cat_features parameter, False otherwise.
|
375
|
+
"""
|
376
|
+
try:
|
377
|
+
# Get the signature of the fit method
|
378
|
+
fit_params = inspect.signature(estimator.fit).parameters
|
379
|
+
# Check if cat_features is in the parameters
|
380
|
+
return "cat_features" in fit_params
|
381
|
+
except (AttributeError, ValueError):
|
382
|
+
return False
|
383
|
+
|
384
|
+
|
385
|
+
def _is_lightgbm_estimator(estimator):
|
353
386
|
try:
|
354
387
|
from lightgbm import LGBMClassifier, LGBMRegressor
|
388
|
+
|
355
389
|
return isinstance(estimator, (LGBMClassifier, LGBMRegressor))
|
356
390
|
except ImportError:
|
357
391
|
return False
|
@@ -375,6 +409,8 @@ def _fit_and_score(
|
|
375
409
|
split_progress=None,
|
376
410
|
candidate_progress=None,
|
377
411
|
error_score=np.nan,
|
412
|
+
is_timeseries=False,
|
413
|
+
random_state=None,
|
378
414
|
):
|
379
415
|
"""Fit estimator and compute scores for a given dataset split.
|
380
416
|
|
@@ -509,13 +545,24 @@ def _fit_and_score(
|
|
509
545
|
|
510
546
|
result = {}
|
511
547
|
try:
|
548
|
+
if "cat_features" in fit_params and fit_params["cat_features"]:
|
549
|
+
X_train, y_train, X_test, y_test, cat_features, cat_encoder = _encode_cat_features(
|
550
|
+
X_train, y_train, X_test, y_test, fit_params["cat_features"], estimator, is_timeseries, random_state
|
551
|
+
)
|
552
|
+
if cat_features and _supports_cat_features(estimator):
|
553
|
+
fit_params["cat_features"] = cat_features
|
554
|
+
else:
|
555
|
+
del fit_params["cat_features"]
|
556
|
+
else:
|
557
|
+
cat_encoder = None
|
558
|
+
result["cat_encoder"] = cat_encoder
|
512
559
|
if y_train is None:
|
513
560
|
estimator.fit(X_train, **fit_params)
|
514
561
|
else:
|
515
|
-
if
|
562
|
+
if _is_catboost_estimator(estimator):
|
516
563
|
fit_params = fit_params.copy()
|
517
564
|
fit_params["eval_set"] = [(X_test, y_test)]
|
518
|
-
elif
|
565
|
+
elif _is_lightgbm_estimator(estimator):
|
519
566
|
fit_params = fit_params.copy()
|
520
567
|
fit_params["eval_set"] = [(X_test, y_test)]
|
521
568
|
estimator.fit(X_train, y_train, **fit_params)
|
@@ -1245,3 +1292,60 @@ def _num_samples(x):
|
|
1245
1292
|
return len(x)
|
1246
1293
|
except TypeError as type_error:
|
1247
1294
|
raise TypeError(message) from type_error
|
1295
|
+
|
1296
|
+
|
1297
|
+
def _encode_cat_features(X_train, y_train, X_test, y_test, cat_features, estimator, is_timeseries, random_state):
    """Encode the categorical columns of one CV fold using the training part only.

    Parameters
    ----------
    X_train, X_test : pandas.DataFrame
        Train/validation features of the fold. The ``cat_features`` columns are
        overwritten in place whenever an encoder is applied.
    y_train, y_test : array-like
        Train/validation targets; returned unchanged.
    cat_features : list of str
        Names of the categorical columns.
    estimator : estimator object
        Model that will be fitted on the fold; CatBoost models get special treatment.
    is_timeseries : bool
        True when the fold comes from a time-series splitter.
    random_state : int, RandomState instance or None
        Seed forwarded to the target encoder and to the training-row shuffle.

    Returns
    -------
    tuple
        ``(X_train, y_train, X_test, y_test, fit_cat_features, encoder)``. The
        contents depend on the case:

        * Time-series folds (any estimator): the columns become integer codes
          from an ``OrdinalEncoder`` fitted on the training part, and unseen
          validation categories map to ``-1``. ``fit_cat_features`` is ``[]``.
        * CatBoost on non-time-series folds: the data is untouched,
          ``cat_features`` is returned as is and ``encoder`` is ``None``, so
          CatBoost encodes the columns itself.
        * Other estimators on non-time-series folds: the columns are
          target-encoded with ``CatBoostEncoder`` and cast to ``category``.
          ``cat_features`` is returned.
    """
    if is_timeseries:
        # Fit encoder on training fold. OrdinalEncoder ignores y in fit() and
        # its transform() does not accept y at all, so y is never passed.
        encoder = OrdinalEncoder(handle_unknown="use_encoded_value", unknown_value=-1)
        encoder.fit(X_train[cat_features])

        # NOTE(review): missing values are presumably encoded as NaN by
        # OrdinalEncoder, which would make astype(int) fail. Confirm the
        # categorical columns never contain NaN at this point.
        X_train[cat_features] = encoder.transform(X_train[cat_features]).astype(int)
        X_test[cat_features] = encoder.transform(X_test[cat_features]).astype(int)

        # Don't declare the columns as categorical, so the estimator (CatBoost
        # included) uses the codes as plain numbers instead of re-encoding them.
        return X_train, y_train, X_test, y_test, [], encoder

    if _is_catboost_estimator(estimator):
        # Leave categorical handling to CatBoost itself.
        return X_train, y_train, X_test, y_test, cat_features, None

    # Shuffle train data
    X_train_shuffled, y_train_shuffled = _shuffle_pair(X_train[cat_features], y_train, random_state)

    # Fit encoder on training fold
    encoder = CatBoostEncoder(random_state=random_state, cols=cat_features)
    encoder.fit(X_train_shuffled, y_train_shuffled)

    # Progressive encoding on train (using y).
    # NOTE(review): the fitted statistics look order-independent, and
    # transform(X, y) below runs on the unshuffled training rows. The shuffle
    # above may therefore have no effect on the encoded values; confirm the
    # intended row ordering.
    X_train[cat_features] = encoder.transform(X_train[cat_features], y_train).astype("category")

    # Static encoding on validation (no y)
    X_test[cat_features] = encoder.transform(X_test[cat_features]).astype("category")

    return X_train, y_train, X_test, y_test, cat_features, encoder
|
1341
|
+
|
1342
|
+
|
1343
|
+
def _shuffle_pair(X, y, random_state):
|
1344
|
+
# If X doesn't have reseted index there could be a problem
|
1345
|
+
# shuffled_idx = np.random.RandomState(random_state).permutation(len(X))
|
1346
|
+
# return X.iloc[shuffled_idx], pd.Series(y).iloc[shuffled_idx]
|
1347
|
+
|
1348
|
+
Xy = X.copy()
|
1349
|
+
Xy["target"] = y
|
1350
|
+
Xy_shuffled = Xy.sample(frac=1, random_state=random_state)
|
1351
|
+
return Xy_shuffled.drop(columns="target"), Xy_shuffled["target"]
|
@@ -1,12 +1,12 @@
|
|
1
|
-
upgini/__about__.py,sha256=
|
1
|
+
upgini/__about__.py,sha256=2c1xmkbQfshecLuTpCtHd1FsSA6LAdrFr8uGLjxlkKs,23
|
2
2
|
upgini/__init__.py,sha256=LXSfTNU0HnlOkE69VCxkgIKDhWP-JFo_eBQ71OxTr5Y,261
|
3
3
|
upgini/ads.py,sha256=nvuRxRx5MHDMgPr9SiU-fsqRdFaBv8p4_v1oqiysKpc,2714
|
4
4
|
upgini/dataset.py,sha256=fRtqSkXNONLnPe6cCL967GMt349FTIpXzy_u8LUKncw,35354
|
5
5
|
upgini/errors.py,sha256=2b_Wbo0OYhLUbrZqdLIx5jBnAsiD1Mcenh-VjR4HCTw,950
|
6
|
-
upgini/features_enricher.py,sha256=
|
6
|
+
upgini/features_enricher.py,sha256=eFnJVb8jM1INlT-imfjafhWtOfx9EJv2HSvlfyGy0_U,216188
|
7
7
|
upgini/http.py,sha256=6Qcepv0tDC72mBBJxYHnA2xqw6QwFaKrXN8o4vju8Es,44372
|
8
8
|
upgini/metadata.py,sha256=zt_9k0iQbWXuiRZcel4ORNPdQKt6Ou69ucZD_E1Q46o,12341
|
9
|
-
upgini/metrics.py,sha256=
|
9
|
+
upgini/metrics.py,sha256=zIOaiyfQLedU9Fk4877drnlWh-KiImSkZpPeiq6Xr1E,45295
|
10
10
|
upgini/search_task.py,sha256=Q5HjBpLIB3OCxAD1zNv5yQ3ZNJx696WCK_-H35_y7Rs,17912
|
11
11
|
upgini/spinner.py,sha256=4iMd-eIe_BnkqFEMIliULTbj6rNI2HkN_VJ4qYe0cUc,1118
|
12
12
|
upgini/version_validator.py,sha256=DvbaAvuYFoJqYt0fitpsk6Xcv-H1BYDJYHUMxaKSH_Y,1509
|
@@ -31,14 +31,14 @@ upgini/autofe/timeseries/roll.py,sha256=zADKXU-eYWQnQ5R3am1yEal8uU6Tm0jLAixwPb_a
|
|
31
31
|
upgini/autofe/timeseries/trend.py,sha256=K1_iw2ko_LIUU8YCUgrvN3n0MkHtsi7-63-8x9er1k4,2129
|
32
32
|
upgini/autofe/timeseries/volatility.py,sha256=SvZfhM_ZAWCNpTf87WjSnZsnlblARgruDlu4By4Zvhc,8078
|
33
33
|
upgini/data_source/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
34
|
-
upgini/data_source/data_source_publisher.py,sha256=
|
34
|
+
upgini/data_source/data_source_publisher.py,sha256=GJd12WDqFBjLJDYQ4nG4SgOqDXS1duI8zIg_YKycjPI,24285
|
35
35
|
upgini/mdc/__init__.py,sha256=iHJlXQg6xRM1-ZOUtaPSJqw5SpQDszvxp4LyqviNLIQ,1027
|
36
36
|
upgini/mdc/context.py,sha256=3u1B-jXt7tXEvNcV3qmR9SDCseudnY7KYsLclBdwVLk,1405
|
37
37
|
upgini/normalizer/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
38
38
|
upgini/normalizer/normalize_utils.py,sha256=g2TcDXZeJp9kAFO2sTqZ4CAsN4J1qHNgoJHZ8gtzUWo,7376
|
39
39
|
upgini/resource_bundle/__init__.py,sha256=S5F2G47pnJd2LDpmFsjDqEwiKkP8Hm-hcseDbMka6Ko,8345
|
40
40
|
upgini/resource_bundle/exceptions.py,sha256=5fRvx0_vWdE1-7HcSgF0tckB4A9AKyf5RiinZkInTsI,621
|
41
|
-
upgini/resource_bundle/strings.properties,sha256=
|
41
|
+
upgini/resource_bundle/strings.properties,sha256=xpHD-3mW1U6Nca0QghC6FSrQLDci9pInuMpOBPPiB8M,28212
|
42
42
|
upgini/resource_bundle/strings_widget.properties,sha256=gOdqvZWntP2LCza_tyVk1_yRYcG4c04K9sQOAVhF_gw,1577
|
43
43
|
upgini/sampler/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
44
44
|
upgini/sampler/base.py,sha256=7GpjYqjOp58vYcJLiX__1R5wjUlyQbxvHJ2klFnup_M,6389
|
@@ -51,8 +51,8 @@ upgini/utils/blocked_time_series.py,sha256=Uqr3vp4YqNclj2-PzEYqVy763GSXHn86sbpIl
|
|
51
51
|
upgini/utils/country_utils.py,sha256=lY-eXWwFVegdVENFttbvLcgGDjFO17Sex8hd2PyJaRk,6937
|
52
52
|
upgini/utils/custom_loss_utils.py,sha256=kieNZYBYZm5ZGBltF1F_jOSF4ea6C29rYuCyiDcqVNY,3857
|
53
53
|
upgini/utils/cv_utils.py,sha256=w6FQb9nO8BWDx88EF83NpjPLarK4eR4ia0Wg0kLBJC4,3525
|
54
|
-
upgini/utils/datetime_utils.py,sha256=
|
55
|
-
upgini/utils/deduplicate_utils.py,sha256=
|
54
|
+
upgini/utils/datetime_utils.py,sha256=UL1ernnawW0LV9mPDpCIc6sFy0HUhFscWVNwfH4V7rI,14366
|
55
|
+
upgini/utils/deduplicate_utils.py,sha256=EpBVCov42-FJIAPfa4jY_ZRct3N2MFaC7i-oJNZ_MGI,8954
|
56
56
|
upgini/utils/display_utils.py,sha256=hAeWEcJtPDg8fAVcMNrNB-azFD2WJp1nvbPAhR7SeP4,12071
|
57
57
|
upgini/utils/email_utils.py,sha256=pZ2vCfNxLIPUhxr0-OlABNXm12jjU44isBk8kGmqQzA,5277
|
58
58
|
upgini/utils/fallback_progress_bar.py,sha256=PDaKb8dYpVZaWMroNcOHsTc3pSjgi9mOm0--cOFTwJ0,1074
|
@@ -64,13 +64,13 @@ upgini/utils/mstats.py,sha256=u3gQVUtDRbyrOQK6V1UJ2Rx1QbkSNYGjXa6m3Z_dPVs,6286
|
|
64
64
|
upgini/utils/phone_utils.py,sha256=IrbztLuOJBiePqqxllfABWfYlfAjYevPhXKipl95wUI,10432
|
65
65
|
upgini/utils/postal_code_utils.py,sha256=5M0sUqH2DAr33kARWCTXR-ACyzWbjDq_-0mmEml6ZcU,1716
|
66
66
|
upgini/utils/progress_bar.py,sha256=N-Sfdah2Hg8lXP_fV9EfUTXz_PyRt4lo9fAHoUDOoLc,1550
|
67
|
-
upgini/utils/sklearn_ext.py,sha256=
|
67
|
+
upgini/utils/sklearn_ext.py,sha256=jLJWAKkqQinV15Z4y1ZnsN3c-fKFwXTsprs00COnyVU,49315
|
68
68
|
upgini/utils/sort.py,sha256=8uuHs2nfSMVnz8GgvbOmgMB1PgEIZP1uhmeRFxcwnYw,7039
|
69
69
|
upgini/utils/target_utils.py,sha256=LRN840dzx78-wg7ftdxAkp2c1eu8-JDvkACiRThm4HE,16832
|
70
70
|
upgini/utils/track_info.py,sha256=G5Lu1xxakg2_TQjKZk4b5SvrHsATTXNVV3NbvWtT8k8,5663
|
71
71
|
upgini/utils/ts_utils.py,sha256=26vhC0pN7vLXK6R09EEkMK3Lwb9IVPH7LRdqFIQ3kPs,1383
|
72
72
|
upgini/utils/warning_counter.py,sha256=-GRY8EUggEBKODPSuXAkHn9KnEQwAORC0mmz_tim-PM,254
|
73
|
-
upgini-1.2.
|
74
|
-
upgini-1.2.
|
75
|
-
upgini-1.2.
|
76
|
-
upgini-1.2.
|
73
|
+
upgini-1.2.87.dist-info/METADATA,sha256=7RwdKFD1Q_DPR057nF27EPBCwNWtQl8SLOX0dc3n0do,49162
|
74
|
+
upgini-1.2.87.dist-info/WHEEL,sha256=1yFddiXMmvYK7QYTqtRNtX66WJ0Mz8PYEiEUoOUUxRY,87
|
75
|
+
upgini-1.2.87.dist-info/licenses/LICENSE,sha256=5RRzgvdJUu3BUDfv4bzVU6FqKgwHlIay63pPCSmSgzw,1514
|
76
|
+
upgini-1.2.87.dist-info/RECORD,,
|
File without changes
|
File without changes
|