upgini 1.2.68a3832.dev8__py3-none-any.whl → 1.2.68a3832.dev9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
upgini/__about__.py CHANGED
@@ -1 +1 @@
1
- __version__ = "1.2.68a3832.dev8"
1
+ __version__ = "1.2.68a3832.dev9"
upgini/metrics.py CHANGED
@@ -491,30 +491,15 @@ class EstimatorWrapper:
491
491
  }
492
492
  if estimator is None:
493
493
  params = {}
494
- # emb_pattern = r"(.+)_emb\d+"
495
- # emb_features = [c for c in x.columns if re.match(emb_pattern, c) and is_numeric_dtype(x[c])]
496
- # max_bin_by_feature_type = {
497
- # feature: 63 if feature in emb_features else 255 for feature in x.columns
498
- # }
499
- # params["max_bin_by_feature_type"] = max_bin_by_feature_type
500
494
  if target_type == ModelTaskType.MULTICLASS:
501
- # params = _get_add_params(params, CATBOOST_MULTICLASS_PARAMS)
502
- # params = _get_add_params(params, add_params)
503
- # estimator = CatBoostWrapper(CatBoostClassifier(**params), **kwargs)
504
495
  params = _get_add_params(params, LIGHTGBM_MULTICLASS_PARAMS)
505
496
  params = _get_add_params(params, add_params)
506
497
  estimator = LightGBMWrapper(LGBMClassifier(**params), **kwargs)
507
498
  elif target_type == ModelTaskType.BINARY:
508
- # params = _get_add_params(params, CATBOOST_BINARY_PARAMS)
509
- # params = _get_add_params(params, add_params)
510
- # estimator = CatBoostWrapper(CatBoostClassifier(**params), **kwargs)
511
499
  params = _get_add_params(params, LIGHTGBM_BINARY_PARAMS)
512
500
  params = _get_add_params(params, add_params)
513
501
  estimator = LightGBMWrapper(LGBMClassifier(**params), **kwargs)
514
502
  elif target_type == ModelTaskType.REGRESSION:
515
- # params = _get_add_params(params, CATBOOST_REGRESSION_PARAMS)
516
- # params = _get_add_params(params, add_params)
517
- # estimator = CatBoostWrapper(CatBoostRegressor(**params), **kwargs)
518
503
  params = _get_add_params(params, LIGHTGBM_REGRESSION_PARAMS)
519
504
  params = _get_add_params(params, add_params)
520
505
  estimator = LightGBMWrapper(LGBMRegressor(**params), **kwargs)
@@ -527,18 +512,19 @@ class EstimatorWrapper:
527
512
  estimator_copy = deepcopy(estimator)
528
513
  kwargs["estimator"] = estimator_copy
529
514
  if is_catboost_estimator(estimator):
530
- params["has_time"] = has_date
531
515
  if cat_features is not None:
532
516
  for cat_feature in cat_features:
533
517
  if cat_feature not in x.columns:
534
518
  logger.error(
535
519
  f"Client cat_feature `{cat_feature}` not found in x columns: {x.columns.to_list()}"
536
520
  )
537
- estimator_copy.set_params(cat_features=cat_features)
521
+ estimator_copy.set_params(cat_features=cat_features, has_time=has_date)
538
522
  estimator = CatBoostWrapper(**kwargs)
539
523
  else:
540
524
  if isinstance(estimator, (LGBMClassifier, LGBMRegressor)):
541
525
  estimator = LightGBMWrapper(**kwargs)
526
+ elif is_catboost_estimator(estimator):
527
+ estimator = CatBoostWrapper(**kwargs)
542
528
  else:
543
529
  logger.warning(
544
530
  f"Unexpected estimator is used for metrics: {estimator}. "
@@ -765,14 +751,12 @@ class LightGBMWrapper(EstimatorWrapper):
765
751
  self.cat_features = None
766
752
 
767
753
  def _prepare_to_fit(self, x: pd.DataFrame, y: pd.Series) -> Tuple[pd.DataFrame, pd.Series, np.ndarray, dict]:
768
- x, y, groups, params = super()._prepare_to_fit(x, y)
769
- if self.target_type == ModelTaskType.MULTICLASS:
770
- params["num_class"] = y.nunique()
754
+ x, y_numpy, groups, params = super()._prepare_to_fit(x, y)
771
755
  self.cat_features = _get_cat_features(x)
772
756
  x = fill_na_cat_features(x, self.cat_features)
773
757
  for feature in self.cat_features:
774
758
  x[feature] = x[feature].astype("category").cat.codes
775
- if not is_numeric_dtype(y):
759
+ if not is_numeric_dtype(y_numpy):
776
760
  y = correct_string_target(y)
777
761
 
778
762
  return x, y, groups, params
@@ -90,7 +90,8 @@ class FeatureInfo:
90
90
  def _get_feature_sample(feature_meta: FeaturesMetadataV2, data: Optional[pd.DataFrame]) -> str:
91
91
  if data is not None and len(data) > 0 and feature_meta.name in data.columns:
92
92
  if len(data) > 3:
93
- feature_sample = np.random.choice(data[feature_meta.name].dropna().unique(), 3).tolist()
93
+ rand = np.random.RandomState(42)
94
+ feature_sample = rand.choice(data[feature_meta.name].dropna().unique(), 3).tolist()
94
95
  else:
95
96
  feature_sample = data[feature_meta.name].dropna().unique().tolist()
96
97
  if len(feature_sample) > 0 and isinstance(feature_sample[0], float):
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: upgini
3
- Version: 1.2.68a3832.dev8
3
+ Version: 1.2.68a3832.dev9
4
4
  Summary: Intelligent data search & enrichment for Machine Learning
5
5
  Project-URL: Bug Reports, https://github.com/upgini/upgini/issues
6
6
  Project-URL: Homepage, https://upgini.com/
@@ -1,4 +1,4 @@
1
- upgini/__about__.py,sha256=KMZpRXK_ksEVGZxYvE4jNHZgG-Ce5Wv3Crjnd_eiTNE,33
1
+ upgini/__about__.py,sha256=PUD99JYjzGuY1b4bkoVfexbGcQ412xzz1bitG33oTYM,33
2
2
  upgini/__init__.py,sha256=LXSfTNU0HnlOkE69VCxkgIKDhWP-JFo_eBQ71OxTr5Y,261
3
3
  upgini/ads.py,sha256=nvuRxRx5MHDMgPr9SiU-fsqRdFaBv8p4_v1oqiysKpc,2714
4
4
  upgini/dataset.py,sha256=1rb6BzyuiQFGVCTDmKL2wox3UFRNjtNaIJOwQnZ801A,34956
@@ -7,7 +7,7 @@ upgini/features_enricher.py,sha256=GXXx14jwf3F26_KrfJ6O40Vcu1hRx5iBjUB_jxy3Xvg,2
7
7
  upgini/http.py,sha256=ud0Cp7h0jNeHuuZGpU_1dAAEiabGoJjGxc1X5oeBQr4,43496
8
8
  upgini/lazy_import.py,sha256=74gQ8JuA48BGRLxAo7lNHNKY2D2emMxrUxKGdxVGhuY,1012
9
9
  upgini/metadata.py,sha256=Jh6YTaS00m_nbaOY_owvlSyn9zgkErkqu8iTr9ZjKI8,12279
10
- upgini/metrics.py,sha256=1YFj2tmnOYLL4-ZXNZJDYclZADX0w6556DlN6TOlZ44,38686
10
+ upgini/metrics.py,sha256=5g2N-z7IFH52tkgeOQjOq2-lsn-s1S4PtEWlwoS14jg,37629
11
11
  upgini/search_task.py,sha256=qxUxAD-bed-FpZYmTB_4orW7YJsW_O6a1TcgnZIRFr4,17307
12
12
  upgini/spinner.py,sha256=4iMd-eIe_BnkqFEMIliULTbj6rNI2HkN_VJ4qYe0cUc,1118
13
13
  upgini/version_validator.py,sha256=DvbaAvuYFoJqYt0fitpsk6Xcv-H1BYDJYHUMxaKSH_Y,1509
@@ -56,7 +56,7 @@ upgini/utils/deduplicate_utils.py,sha256=SMZx9IKIhWI5HqXepfKiQb3uDJrogQZtG6jcWuM
56
56
  upgini/utils/display_utils.py,sha256=DsBjJ8jEYAh8BPgfAbzq5imoGFV6IACP20PQ78BQCX0,11964
57
57
  upgini/utils/email_utils.py,sha256=pZ2vCfNxLIPUhxr0-OlABNXm12jjU44isBk8kGmqQzA,5277
58
58
  upgini/utils/fallback_progress_bar.py,sha256=PDaKb8dYpVZaWMroNcOHsTc3pSjgi9mOm0--cOFTwJ0,1074
59
- upgini/utils/feature_info.py,sha256=m1tQcT3hTChPAiXzpk0WQcEqElj8KgeCifEJFa7-gss,7247
59
+ upgini/utils/feature_info.py,sha256=Q9HN6A-fvfVD-irFWrmOqqZG9RsUSvh5MTY_k0xu-tE,7287
60
60
  upgini/utils/features_validator.py,sha256=lEfmk4DoxZ4ooOE1HC0ZXtUb_lFKRFHIrnFULZ4_rL8,3746
61
61
  upgini/utils/format.py,sha256=Yv5cvvSs2bOLUzzNu96Pu33VMDNbabio92QepUj41jU,243
62
62
  upgini/utils/ip_utils.py,sha256=TSQ_qDsLlVnm09X1HacpabEf_HNqSWpxBF4Sdc2xs08,6580
@@ -70,7 +70,7 @@ upgini/utils/target_utils.py,sha256=b1GzO8_gMcwXSZ2v98CY50MJJBzKbWHId_BJGybXfkM,
70
70
  upgini/utils/track_info.py,sha256=G5Lu1xxakg2_TQjKZk4b5SvrHsATTXNVV3NbvWtT8k8,5663
71
71
  upgini/utils/ts_utils.py,sha256=26vhC0pN7vLXK6R09EEkMK3Lwb9IVPH7LRdqFIQ3kPs,1383
72
72
  upgini/utils/warning_counter.py,sha256=-GRY8EUggEBKODPSuXAkHn9KnEQwAORC0mmz_tim-PM,254
73
- upgini-1.2.68a3832.dev8.dist-info/METADATA,sha256=LQi_ixiFjU2qIyoVie4__YDTl_2Tzp6bGlZHFLT5PP4,49149
74
- upgini-1.2.68a3832.dev8.dist-info/WHEEL,sha256=1yFddiXMmvYK7QYTqtRNtX66WJ0Mz8PYEiEUoOUUxRY,87
75
- upgini-1.2.68a3832.dev8.dist-info/licenses/LICENSE,sha256=5RRzgvdJUu3BUDfv4bzVU6FqKgwHlIay63pPCSmSgzw,1514
76
- upgini-1.2.68a3832.dev8.dist-info/RECORD,,
73
+ upgini-1.2.68a3832.dev9.dist-info/METADATA,sha256=zgbDmansWs-GTVIg9WJrS86ZRfIZ4Psx-2Ao7JVFIQs,49149
74
+ upgini-1.2.68a3832.dev9.dist-info/WHEEL,sha256=1yFddiXMmvYK7QYTqtRNtX66WJ0Mz8PYEiEUoOUUxRY,87
75
+ upgini-1.2.68a3832.dev9.dist-info/licenses/LICENSE,sha256=5RRzgvdJUu3BUDfv4bzVU6FqKgwHlIay63pPCSmSgzw,1514
76
+ upgini-1.2.68a3832.dev9.dist-info/RECORD,,