upgini 1.2.86a2__py3-none-any.whl → 1.2.87__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,4 +1,5 @@
1
1
  import functools
2
+ import inspect
2
3
  import numbers
3
4
  import time
4
5
  import warnings
@@ -9,6 +10,7 @@ from traceback import format_exc
9
10
 
10
11
  import numpy as np
11
12
  import scipy.sparse as sp
13
+ from category_encoders import CatBoostEncoder
12
14
  from joblib import Parallel, logger
13
15
  from scipy.sparse import issparse
14
16
  from sklearn import config_context, get_config
@@ -16,10 +18,13 @@ from sklearn.base import clone, is_classifier
16
18
  from sklearn.exceptions import FitFailedWarning, NotFittedError
17
19
  from sklearn.metrics import check_scoring
18
20
  from sklearn.metrics._scorer import _MultimetricScorer
19
- from sklearn.model_selection import StratifiedKFold, check_cv
21
+ from sklearn.model_selection import StratifiedKFold, TimeSeriesSplit, check_cv
22
+ from sklearn.preprocessing import OrdinalEncoder
20
23
  from sklearn.utils.fixes import np_version, parse_version
21
24
  from sklearn.utils.validation import indexable
22
25
 
26
+ from upgini.utils.blocked_time_series import BlockedTimeSeriesSplit
27
+
23
28
  # from sklearn.model_selection import cross_validate as original_cross_validate
24
29
 
25
30
  _DEFAULT_TAGS = {
@@ -59,6 +64,7 @@ def cross_validate(
59
64
  return_train_score=False,
60
65
  return_estimator=False,
61
66
  error_score=np.nan,
67
+ random_state=None,
62
68
  ):
63
69
  """Evaluate metric(s) by cross-validation and also record fit/score times.
64
70
 
@@ -279,6 +285,8 @@ def cross_validate(
279
285
  return_times=True,
280
286
  return_estimator=return_estimator,
281
287
  error_score=error_score,
288
+ is_timeseries=isinstance(cv, TimeSeriesSplit) or isinstance(cv, BlockedTimeSeriesSplit),
289
+ random_state=random_state,
282
290
  )
283
291
  for train, test in cv.split(x, y, groups)
284
292
  )
@@ -296,6 +304,7 @@ def cross_validate(
296
304
  ret = {}
297
305
  ret["fit_time"] = results["fit_time"]
298
306
  ret["score_time"] = results["score_time"]
307
+ ret["cat_encoder"] = results["cat_encoder"]
299
308
 
300
309
  if return_estimator:
301
310
  ret["estimator"] = results["estimator"]
@@ -320,16 +329,16 @@ def cross_validate(
320
329
  else:
321
330
  shuffle = False
322
331
  if hasattr(cv, "random_state") and shuffle:
323
- random_state = cv.random_state
332
+ cv_random_state = cv.random_state
324
333
  else:
325
- random_state = None
334
+ cv_random_state = None
326
335
  return cross_validate(
327
336
  estimator,
328
337
  x,
329
338
  y,
330
339
  groups=groups,
331
340
  scoring=scoring,
332
- cv=StratifiedKFold(n_splits=cv.get_n_splits(), shuffle=shuffle, random_state=random_state),
341
+ cv=StratifiedKFold(n_splits=cv.get_n_splits(), shuffle=shuffle, random_state=cv_random_state),
333
342
  n_jobs=n_jobs,
334
343
  verbose=verbose,
335
344
  fit_params=fit_params,
@@ -337,21 +346,46 @@ def cross_validate(
337
346
  return_train_score=return_train_score,
338
347
  return_estimator=return_estimator,
339
348
  error_score=error_score,
349
+ random_state=random_state,
340
350
  )
341
351
  raise e
342
352
 
343
353
 
344
- def is_catboost_estimator(estimator):
354
+ def _is_catboost_estimator(estimator):
345
355
  try:
346
356
  from catboost import CatBoostClassifier, CatBoostRegressor
357
+
347
358
  return isinstance(estimator, (CatBoostClassifier, CatBoostRegressor))
348
359
  except ImportError:
349
360
  return False
350
361
 
351
362
 
352
- def is_lightgbm_estimator(estimator):
363
+ def _supports_cat_features(estimator) -> bool:
364
+ """Check if estimator's fit method accepts cat_features parameter.
365
+
366
+ Parameters
367
+ ----------
368
+ estimator : estimator object
369
+ The estimator to check.
370
+
371
+ Returns
372
+ -------
373
+ bool
374
+ True if estimator's fit method accepts cat_features parameter, False otherwise.
375
+ """
376
+ try:
377
+ # Get the signature of the fit method
378
+ fit_params = inspect.signature(estimator.fit).parameters
379
+ # Check if cat_features is in the parameters
380
+ return "cat_features" in fit_params
381
+ except (AttributeError, ValueError):
382
+ return False
383
+
384
+
385
+ def _is_lightgbm_estimator(estimator):
353
386
  try:
354
387
  from lightgbm import LGBMClassifier, LGBMRegressor
388
+
355
389
  return isinstance(estimator, (LGBMClassifier, LGBMRegressor))
356
390
  except ImportError:
357
391
  return False
@@ -375,6 +409,8 @@ def _fit_and_score(
375
409
  split_progress=None,
376
410
  candidate_progress=None,
377
411
  error_score=np.nan,
412
+ is_timeseries=False,
413
+ random_state=None,
378
414
  ):
379
415
  """Fit estimator and compute scores for a given dataset split.
380
416
 
@@ -509,13 +545,24 @@ def _fit_and_score(
509
545
 
510
546
  result = {}
511
547
  try:
548
+ if "cat_features" in fit_params and fit_params["cat_features"]:
549
+ X_train, y_train, X_test, y_test, cat_features, cat_encoder = _encode_cat_features(
550
+ X_train, y_train, X_test, y_test, fit_params["cat_features"], estimator, is_timeseries, random_state
551
+ )
552
+ if cat_features and _supports_cat_features(estimator):
553
+ fit_params["cat_features"] = cat_features
554
+ else:
555
+ del fit_params["cat_features"]
556
+ else:
557
+ cat_encoder = None
558
+ result["cat_encoder"] = cat_encoder
512
559
  if y_train is None:
513
560
  estimator.fit(X_train, **fit_params)
514
561
  else:
515
- if is_catboost_estimator(estimator):
562
+ if _is_catboost_estimator(estimator):
516
563
  fit_params = fit_params.copy()
517
564
  fit_params["eval_set"] = [(X_test, y_test)]
518
- elif is_lightgbm_estimator(estimator):
565
+ elif _is_lightgbm_estimator(estimator):
519
566
  fit_params = fit_params.copy()
520
567
  fit_params["eval_set"] = [(X_test, y_test)]
521
568
  estimator.fit(X_train, y_train, **fit_params)
@@ -1245,3 +1292,60 @@ def _num_samples(x):
1245
1292
  return len(x)
1246
1293
  except TypeError as type_error:
1247
1294
  raise TypeError(message) from type_error
1295
+
1296
+
1297
+ def _encode_cat_features(X_train, y_train, X_test, y_test, cat_features, estimator, is_timeseries, random_state):
1298
+ if _is_catboost_estimator(estimator):
1299
+ if is_timeseries:
1300
+ # Fit encoder on training fold
1301
+ encoder = OrdinalEncoder(handle_unknown="use_encoded_value", unknown_value=-1)
1302
+ encoder.fit(X_train[cat_features], y_train)
1303
+
1304
+ X_train[cat_features] = encoder.transform(X_train[cat_features]).astype(int)
1305
+ X_test[cat_features] = encoder.transform(X_test[cat_features]).astype(int)
1306
+
1307
+ # Don't use as categorical features, so CatBoost will not encode them
1308
+ return X_train, y_train, X_test, y_test, [], encoder
1309
+ else:
1310
+ return X_train, y_train, X_test, y_test, cat_features, None
1311
+ else:
1312
+ if is_timeseries:
1313
+ # Fit encoder on training fold
1314
+ encoder = OrdinalEncoder(handle_unknown="use_encoded_value", unknown_value=-1)
1315
+ encoder.fit(X_train[cat_features], y_train)
1316
+
1317
+ # Progressive encoding on train (using y)
1318
+ X_train[cat_features] = encoder.transform(X_train[cat_features], y_train).astype(int)
1319
+
1320
+ # Static encoding on validation (no y)
1321
+ X_test[cat_features] = encoder.transform(X_test[cat_features]).astype(int)
1322
+
1323
+ return X_train, y_train, X_test, y_test, [], encoder
1324
+ else:
1325
+ # Shuffle train data
1326
+ X_train_shuffled, y_train_shuffled = _shuffle_pair(
1327
+ X_train[cat_features], y_train, random_state
1328
+ )
1329
+
1330
+ # Fit encoder on training fold
1331
+ encoder = CatBoostEncoder(random_state=random_state, cols=cat_features)
1332
+ encoder.fit(X_train_shuffled, y_train_shuffled)
1333
+
1334
+ # Progressive encoding on train (using y)
1335
+ X_train[cat_features] = encoder.transform(X_train[cat_features], y_train).astype("category")
1336
+
1337
+ # Static encoding on validation (no y)
1338
+ X_test[cat_features] = encoder.transform(X_test[cat_features]).astype("category")
1339
+
1340
+ return X_train, y_train, X_test, y_test, cat_features, encoder
1341
+
1342
+
1343
+ def _shuffle_pair(X, y, random_state):
1344
+ # If X doesn't have reseted index there could be a problem
1345
+ # shuffled_idx = np.random.RandomState(random_state).permutation(len(X))
1346
+ # return X.iloc[shuffled_idx], pd.Series(y).iloc[shuffled_idx]
1347
+
1348
+ Xy = X.copy()
1349
+ Xy["target"] = y
1350
+ Xy_shuffled = Xy.sample(frac=1, random_state=random_state)
1351
+ return Xy_shuffled.drop(columns="target"), Xy_shuffled["target"]
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: upgini
3
- Version: 1.2.86a2
3
+ Version: 1.2.87
4
4
  Summary: Intelligent data search & enrichment for Machine Learning
5
5
  Project-URL: Bug Reports, https://github.com/upgini/upgini/issues
6
6
  Project-URL: Homepage, https://upgini.com/
@@ -1,12 +1,12 @@
1
- upgini/__about__.py,sha256=yGLa0SZe61T_OjwHem32zlqsP2f3eCCrsj4uwsanjlA,25
1
+ upgini/__about__.py,sha256=2c1xmkbQfshecLuTpCtHd1FsSA6LAdrFr8uGLjxlkKs,23
2
2
  upgini/__init__.py,sha256=LXSfTNU0HnlOkE69VCxkgIKDhWP-JFo_eBQ71OxTr5Y,261
3
3
  upgini/ads.py,sha256=nvuRxRx5MHDMgPr9SiU-fsqRdFaBv8p4_v1oqiysKpc,2714
4
4
  upgini/dataset.py,sha256=fRtqSkXNONLnPe6cCL967GMt349FTIpXzy_u8LUKncw,35354
5
5
  upgini/errors.py,sha256=2b_Wbo0OYhLUbrZqdLIx5jBnAsiD1Mcenh-VjR4HCTw,950
6
- upgini/features_enricher.py,sha256=G0qbRPdlWe9p6cwYF3khP99-0kgAO8N0A2sfQxSLgmM,213446
6
+ upgini/features_enricher.py,sha256=eFnJVb8jM1INlT-imfjafhWtOfx9EJv2HSvlfyGy0_U,216188
7
7
  upgini/http.py,sha256=6Qcepv0tDC72mBBJxYHnA2xqw6QwFaKrXN8o4vju8Es,44372
8
8
  upgini/metadata.py,sha256=zt_9k0iQbWXuiRZcel4ORNPdQKt6Ou69ucZD_E1Q46o,12341
9
- upgini/metrics.py,sha256=3cip0_L6-OFew74KsRwzxJDU6UFq05h2v7IsyHLcMRc,43164
9
+ upgini/metrics.py,sha256=zIOaiyfQLedU9Fk4877drnlWh-KiImSkZpPeiq6Xr1E,45295
10
10
  upgini/search_task.py,sha256=Q5HjBpLIB3OCxAD1zNv5yQ3ZNJx696WCK_-H35_y7Rs,17912
11
11
  upgini/spinner.py,sha256=4iMd-eIe_BnkqFEMIliULTbj6rNI2HkN_VJ4qYe0cUc,1118
12
12
  upgini/version_validator.py,sha256=DvbaAvuYFoJqYt0fitpsk6Xcv-H1BYDJYHUMxaKSH_Y,1509
@@ -31,14 +31,14 @@ upgini/autofe/timeseries/roll.py,sha256=zADKXU-eYWQnQ5R3am1yEal8uU6Tm0jLAixwPb_a
31
31
  upgini/autofe/timeseries/trend.py,sha256=K1_iw2ko_LIUU8YCUgrvN3n0MkHtsi7-63-8x9er1k4,2129
32
32
  upgini/autofe/timeseries/volatility.py,sha256=SvZfhM_ZAWCNpTf87WjSnZsnlblARgruDlu4By4Zvhc,8078
33
33
  upgini/data_source/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
34
- upgini/data_source/data_source_publisher.py,sha256=4S9qwlAklD8vg9tUU_c1pHE2_glUHAh15-wr5hMwKFw,22879
34
+ upgini/data_source/data_source_publisher.py,sha256=GJd12WDqFBjLJDYQ4nG4SgOqDXS1duI8zIg_YKycjPI,24285
35
35
  upgini/mdc/__init__.py,sha256=iHJlXQg6xRM1-ZOUtaPSJqw5SpQDszvxp4LyqviNLIQ,1027
36
36
  upgini/mdc/context.py,sha256=3u1B-jXt7tXEvNcV3qmR9SDCseudnY7KYsLclBdwVLk,1405
37
37
  upgini/normalizer/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
38
38
  upgini/normalizer/normalize_utils.py,sha256=g2TcDXZeJp9kAFO2sTqZ4CAsN4J1qHNgoJHZ8gtzUWo,7376
39
39
  upgini/resource_bundle/__init__.py,sha256=S5F2G47pnJd2LDpmFsjDqEwiKkP8Hm-hcseDbMka6Ko,8345
40
40
  upgini/resource_bundle/exceptions.py,sha256=5fRvx0_vWdE1-7HcSgF0tckB4A9AKyf5RiinZkInTsI,621
41
- upgini/resource_bundle/strings.properties,sha256=U_ewTI-qPww4X3WcFG3qDf_jv2vo6RrlCehVDjqtzEI,27991
41
+ upgini/resource_bundle/strings.properties,sha256=xpHD-3mW1U6Nca0QghC6FSrQLDci9pInuMpOBPPiB8M,28212
42
42
  upgini/resource_bundle/strings_widget.properties,sha256=gOdqvZWntP2LCza_tyVk1_yRYcG4c04K9sQOAVhF_gw,1577
43
43
  upgini/sampler/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
44
44
  upgini/sampler/base.py,sha256=7GpjYqjOp58vYcJLiX__1R5wjUlyQbxvHJ2klFnup_M,6389
@@ -51,8 +51,8 @@ upgini/utils/blocked_time_series.py,sha256=Uqr3vp4YqNclj2-PzEYqVy763GSXHn86sbpIl
51
51
  upgini/utils/country_utils.py,sha256=lY-eXWwFVegdVENFttbvLcgGDjFO17Sex8hd2PyJaRk,6937
52
52
  upgini/utils/custom_loss_utils.py,sha256=kieNZYBYZm5ZGBltF1F_jOSF4ea6C29rYuCyiDcqVNY,3857
53
53
  upgini/utils/cv_utils.py,sha256=w6FQb9nO8BWDx88EF83NpjPLarK4eR4ia0Wg0kLBJC4,3525
54
- upgini/utils/datetime_utils.py,sha256=_jq-kn_dGNFfs-DGXcWCGzy9bkplfAjrZ8SsmN28zXc,13535
55
- upgini/utils/deduplicate_utils.py,sha256=AcMLoObMjhOTQ_fMS1LWy0GKp6WXnZ-FNux_8V3nbZU,8914
54
+ upgini/utils/datetime_utils.py,sha256=UL1ernnawW0LV9mPDpCIc6sFy0HUhFscWVNwfH4V7rI,14366
55
+ upgini/utils/deduplicate_utils.py,sha256=EpBVCov42-FJIAPfa4jY_ZRct3N2MFaC7i-oJNZ_MGI,8954
56
56
  upgini/utils/display_utils.py,sha256=hAeWEcJtPDg8fAVcMNrNB-azFD2WJp1nvbPAhR7SeP4,12071
57
57
  upgini/utils/email_utils.py,sha256=pZ2vCfNxLIPUhxr0-OlABNXm12jjU44isBk8kGmqQzA,5277
58
58
  upgini/utils/fallback_progress_bar.py,sha256=PDaKb8dYpVZaWMroNcOHsTc3pSjgi9mOm0--cOFTwJ0,1074
@@ -64,13 +64,13 @@ upgini/utils/mstats.py,sha256=u3gQVUtDRbyrOQK6V1UJ2Rx1QbkSNYGjXa6m3Z_dPVs,6286
64
64
  upgini/utils/phone_utils.py,sha256=IrbztLuOJBiePqqxllfABWfYlfAjYevPhXKipl95wUI,10432
65
65
  upgini/utils/postal_code_utils.py,sha256=5M0sUqH2DAr33kARWCTXR-ACyzWbjDq_-0mmEml6ZcU,1716
66
66
  upgini/utils/progress_bar.py,sha256=N-Sfdah2Hg8lXP_fV9EfUTXz_PyRt4lo9fAHoUDOoLc,1550
67
- upgini/utils/sklearn_ext.py,sha256=HpaNQaKJisgNE7IZ71n7uswxTj7kbPglU2G3s1sORAc,45042
67
+ upgini/utils/sklearn_ext.py,sha256=jLJWAKkqQinV15Z4y1ZnsN3c-fKFwXTsprs00COnyVU,49315
68
68
  upgini/utils/sort.py,sha256=8uuHs2nfSMVnz8GgvbOmgMB1PgEIZP1uhmeRFxcwnYw,7039
69
69
  upgini/utils/target_utils.py,sha256=LRN840dzx78-wg7ftdxAkp2c1eu8-JDvkACiRThm4HE,16832
70
70
  upgini/utils/track_info.py,sha256=G5Lu1xxakg2_TQjKZk4b5SvrHsATTXNVV3NbvWtT8k8,5663
71
71
  upgini/utils/ts_utils.py,sha256=26vhC0pN7vLXK6R09EEkMK3Lwb9IVPH7LRdqFIQ3kPs,1383
72
72
  upgini/utils/warning_counter.py,sha256=-GRY8EUggEBKODPSuXAkHn9KnEQwAORC0mmz_tim-PM,254
73
- upgini-1.2.86a2.dist-info/METADATA,sha256=xU87Vnwtiae10PnJMUIC5KiOMP_TUEZ8BeafznKJxCg,49164
74
- upgini-1.2.86a2.dist-info/WHEEL,sha256=1yFddiXMmvYK7QYTqtRNtX66WJ0Mz8PYEiEUoOUUxRY,87
75
- upgini-1.2.86a2.dist-info/licenses/LICENSE,sha256=5RRzgvdJUu3BUDfv4bzVU6FqKgwHlIay63pPCSmSgzw,1514
76
- upgini-1.2.86a2.dist-info/RECORD,,
73
+ upgini-1.2.87.dist-info/METADATA,sha256=7RwdKFD1Q_DPR057nF27EPBCwNWtQl8SLOX0dc3n0do,49162
74
+ upgini-1.2.87.dist-info/WHEEL,sha256=1yFddiXMmvYK7QYTqtRNtX66WJ0Mz8PYEiEUoOUUxRY,87
75
+ upgini-1.2.87.dist-info/licenses/LICENSE,sha256=5RRzgvdJUu3BUDfv4bzVU6FqKgwHlIay63pPCSmSgzw,1514
76
+ upgini-1.2.87.dist-info/RECORD,,