upgini 1.2.81a3832.dev1__py3-none-any.whl → 1.2.81a3832.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
upgini/__about__.py CHANGED
@@ -1 +1 @@
-__version__ = "1.2.81a3832.dev1"
+__version__ = "1.2.81a3832.dev2"
upgini/features_enricher.py CHANGED
@@ -310,6 +310,7 @@ class FeaturesEnricher(TransformerMixin):
         self._search_task = search_task.poll_result(trace_id, quiet=True, check_fit=True)
         file_metadata = self._search_task.get_file_metadata(trace_id)
         x_columns = [c.originalName or c.name for c in file_metadata.columns]
+        self.fit_columns_renaming = {c.name: c.originalName for c in file_metadata.columns}
         df = pd.DataFrame(columns=x_columns)
         self.__prepare_feature_importances(trace_id, df, silent=True)
         # TODO validate search_keys with search_keys from file_metadata
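Note: the new `fit_columns_renaming` attribute keeps a map from the internal (possibly hash-suffixed) column names used during fit back to the client's original names. A minimal sketch of its shape, using invented column names and a stand-in `Column` dataclass in place of the real `file_metadata.columns` entries:

```python
from dataclasses import dataclass

@dataclass
class Column:
    name: str          # internal name, possibly hash-suffixed
    originalName: str  # name as it appeared in the client's DataFrame

columns = [Column("phone_a1b2", "phone"), Column("country_c3d4", "country")]
fit_columns_renaming = {c.name: c.originalName for c in columns}
assert fit_columns_renaming["phone_a1b2"] == "phone"
```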
@@ -476,7 +477,7 @@ class FeaturesEnricher(TransformerMixin):
         self.__validate_search_keys(self.search_keys)

         # Validate client estimator params
-        self._get_client_cat_features(estimator, X, self.search_keys)
+        self._get_and_validate_client_cat_features(estimator, X, self.search_keys)

         try:
             self.X = X
@@ -957,9 +958,17 @@ class FeaturesEnricher(TransformerMixin):
             self.__display_support_link(msg)
             return None

-        client_cat_features, search_keys_for_metrics = self._get_client_cat_features(
+        cat_features_from_backend = self.__get_categorical_features()
+        client_cat_features, search_keys_for_metrics = self._get_and_validate_client_cat_features(
             estimator, validated_X, self.search_keys
         )
+        for cat_feature in cat_features_from_backend:
+            original_cat_feature = self.fit_columns_renaming.get(cat_feature)
+            if original_cat_feature in self.search_keys:
+                if self.search_keys[original_cat_feature] in [SearchKey.COUNTRY, SearchKey.POSTAL_CODE]:
+                    search_keys_for_metrics.append(original_cat_feature)
+                else:
+                    self.logger.warning(self.bundle.get("cat_feature_search_key").format(original_cat_feature))
         search_keys_for_metrics.extend([c for c in self.id_columns or [] if c not in search_keys_for_metrics])
         self.logger.info(f"Search keys for metrics: {search_keys_for_metrics}")

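Note: the new loop routes backend-derived categorical features that are also search keys: COUNTRY and POSTAL_CODE keys stay usable for metrics, any other key type only triggers a warning. A small self-contained illustration of the branch, with invented column names (`SearchKey` is assumed importable from upgini.metadata):

```python
from upgini.metadata import SearchKey

search_keys = {"country": SearchKey.COUNTRY, "phone": SearchKey.PHONE}
fit_columns_renaming = {"country_c3d4": "country", "phone_a1b2": "phone"}
cat_features_from_backend = ["country_c3d4", "phone_a1b2"]

search_keys_for_metrics = []
for cat_feature in cat_features_from_backend:
    original = fit_columns_renaming.get(cat_feature)
    if original in search_keys:
        if search_keys[original] in [SearchKey.COUNTRY, SearchKey.POSTAL_CODE]:
            search_keys_for_metrics.append(original)  # country survives
        # phone would only be logged as a warning

assert search_keys_for_metrics == ["country"]
```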
@@ -976,7 +985,7 @@ class FeaturesEnricher(TransformerMixin):
                 search_keys_for_metrics=search_keys_for_metrics,
                 progress_bar=progress_bar,
                 progress_callback=progress_callback,
-                cat_features=client_cat_features,
+                client_cat_features=client_cat_features,
             )
             if prepared_data is None:
                 return None
@@ -1027,7 +1036,6 @@ class FeaturesEnricher(TransformerMixin):

             has_date = self._get_date_column(search_keys) is not None
             model_task_type = self.model_task_type or define_task(y_sorted, has_date, self.logger, silent=True)
-            cat_features_from_backend = self.__get_categorical_features()
             cat_features = list(set(client_cat_features + cat_features_from_backend))
             baseline_cat_features = [f for f in cat_features if f in fitting_X.columns]
             enriched_cat_features = [f for f in cat_features if f in fitting_enriched_X.columns]
@@ -1423,7 +1431,7 @@ class FeaturesEnricher(TransformerMixin):

         return _cv, groups

-    def _get_client_cat_features(
+    def _get_and_validate_client_cat_features(
         self, estimator: Optional[Any], X: pd.DataFrame, search_keys: Dict[str, SearchKey]
     ) -> Tuple[Optional[List[str]], List[str]]:
         cat_features = None
@@ -1468,7 +1476,7 @@ class FeaturesEnricher(TransformerMixin):
         search_keys_for_metrics: Optional[List[str]] = None,
         progress_bar: Optional[ProgressBar] = None,
         progress_callback: Optional[Callable[[SearchProgress], Any]] = None,
-        cat_features: Optional[List[str]] = None,
+        client_cat_features: Optional[List[str]] = None,
     ):
         is_input_same_as_fit, X, y, eval_set = self._is_input_same_as_fit(X, y, eval_set)
         is_demo_dataset = hash_input(X, y, eval_set) in DEMO_DATASET_HASHES
@@ -1542,7 +1550,7 @@ class FeaturesEnricher(TransformerMixin):

         # Detect and drop high cardinality columns in train
         columns_with_high_cardinality = FeaturesValidator.find_high_cardinality(fitting_X)
-        non_excluding_columns = (self.generate_features or []) + (cat_features or [])
+        non_excluding_columns = (self.generate_features or []) + (client_cat_features or [])
         columns_with_high_cardinality = [c for c in columns_with_high_cardinality if c not in non_excluding_columns]
         if len(columns_with_high_cardinality) > 0:
             self.logger.warning(
@@ -2080,10 +2088,12 @@ class FeaturesEnricher(TransformerMixin):
         search_keys: Dict,
         columns_renaming: Dict[str, str],
     ):
+        # X_sampled - with hash-suffixes
+        reversed_renaming = {v: k for k, v in columns_renaming.items()}
         search_keys = {
-            columns_renaming.get(k, k): v
+            reversed_renaming.get(k, k): v
             for k, v in search_keys.items()
-            if columns_renaming.get(k, k) in X_sampled.columns.to_list()
+            if reversed_renaming.get(k, k) in X_sampled.columns.to_list()
         }
         return FeaturesEnricher._SampledDataForMetrics(
             X_sampled=X_sampled,
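Note: the fix here is a direction bug. `columns_renaming` maps internal (hash-suffixed) names back to original ones, while `search_keys` is keyed by original names and `X_sampled` carries the internal ones, so the lookup has to go through the inverted dict. A toy illustration with invented names:

```python
columns_renaming = {"phone_a1b2": "phone"}  # internal -> original
reversed_renaming = {v: k for k, v in columns_renaming.items()}  # original -> internal

search_keys = {"phone": "PHONE"}  # keyed by original names
x_sampled_columns = ["phone_a1b2", "feature_1"]  # internal names

remapped = {
    reversed_renaming.get(k, k): v
    for k, v in search_keys.items()
    if reversed_renaming.get(k, k) in x_sampled_columns
}
assert remapped == {"phone_a1b2": "PHONE"}  # the old lookup would have dropped the key
```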
@@ -3871,7 +3881,7 @@ if response.status_code == 200:
         if features_meta is None:
             raise Exception(self.bundle.get("missing_features_meta"))

-        return [f.name for f in features_meta if f.type == "categorical"]
+        return [f.name for f in features_meta if f.type == "categorical" and f.shap_value > 0.0]

     def __prepare_feature_importances(
         self, trace_id: str, df: pd.DataFrame, updated_shaps: Optional[Dict[str, float]] = None, silent=False
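Note: `__get_categorical_features` now skips features whose SHAP importance is zero, so downstream consumers only see categorical features that actually contribute. A sketch with stand-in metadata objects (the real `features_meta` entries are models with the same fields):

```python
from types import SimpleNamespace

features_meta = [
    SimpleNamespace(name="city", type="categorical", shap_value=0.12),
    SimpleNamespace(name="region", type="categorical", shap_value=0.0),
]
cats = [f.name for f in features_meta if f.type == "categorical" and f.shap_value > 0.0]
assert cats == ["city"]  # zero-importance "region" is filtered out
```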
upgini/http.py CHANGED
@@ -20,7 +20,7 @@ import jwt
 # import pandas as pd
 import requests
 from pydantic import BaseModel
-from pythonjsonlogger import jsonlogger
+from pythonjsonlogger import json as jsonlogger
 from requests.exceptions import RequestException

 from upgini.__about__ import __version__
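Note: this import change tracks python-json-logger 3.x, where `JsonFormatter` lives in the `pythonjsonlogger.json` module and the old `pythonjsonlogger.jsonlogger` path is deprecated. Aliasing the new module as `jsonlogger` keeps every existing call site untouched; a minimal usage sketch (assuming python-json-logger >= 3.1):

```python
import logging

from pythonjsonlogger import json as jsonlogger  # new location, old alias

handler = logging.StreamHandler()
handler.setFormatter(jsonlogger.JsonFormatter("%(asctime)s %(levelname)s %(message)s"))
logging.getLogger("upgini").addHandler(handler)
```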
@@ -459,19 +459,19 @@ class _RestClient:
             content = file.read()
             md5_hash.update(content)
         digest = md5_hash.hexdigest()
-        metadata_with_md5 = metadata.copy(update={"checksumMD5": digest})
+        metadata_with_md5 = metadata.model_copy(update={"checksumMD5": digest})

         # digest_sha256 = hashlib.sha256(
         #     pd.util.hash_pandas_object(pd.read_parquet(file_path, engine="fastparquet")).values
         # ).hexdigest()
         digest_sha256 = self.compute_file_digest(file_path)
-        metadata_with_md5 = metadata_with_md5.copy(update={"digest": digest_sha256})
+        metadata_with_md5 = metadata_with_md5.model_copy(update={"digest": digest_sha256})

         with open(file_path, "rb") as file:
             files = {
                 "metadata": (
                     "metadata.json",
-                    metadata_with_md5.json(exclude_none=True).encode(),
+                    metadata_with_md5.model_dump_json(exclude_none=True).encode(),
                     "application/json",
                 ),
                 "tracking": (
@@ -481,7 +481,7 @@ class _RestClient:
                 ),
                 "metrics": (
                     "metrics.json",
-                    metrics.json(exclude_none=True).encode(),
+                    metrics.model_dump_json(exclude_none=True).encode(),
                     "application/json",
                 ),
                 "file": (metadata_with_md5.name, file, "application/octet-stream"),
@@ -489,7 +489,7 @@ class _RestClient:
             if search_customization is not None:
                 files["customization"] = (
                     "customization.json",
-                    search_customization.json(exclude_none=True).encode(),
+                    search_customization.model_dump_json(exclude_none=True).encode(),
                     "application/json",
                 )
             additional_headers = {self.SEARCH_KEYS_HEADER_NAME: ",".join(self.search_keys_meaning_types(metadata))}
@@ -504,7 +504,7 @@ class _RestClient:
     def check_uploaded_file_v2(self, trace_id: str, file_upload_id: str, metadata: FileMetadata) -> bool:
         api_path = self.CHECK_UPLOADED_FILE_URL_FMT_V2.format(file_upload_id)
         response = self._with_unauth_retry(
-            lambda: self._send_post_req(api_path, trace_id, metadata.json(exclude_none=True))
+            lambda: self._send_post_req(api_path, trace_id, metadata.model_dump_json(exclude_none=True))
         )
         return bool(response)
@@ -518,11 +518,11 @@ class _RestClient:
     ) -> SearchTaskResponse:
         api_path = self.INITIAL_SEARCH_WITHOUT_UPLOAD_URI_FMT_V2.format(file_upload_id)
         files = {
-            "metadata": ("metadata.json", metadata.json(exclude_none=True).encode(), "application/json"),
-            "metrics": ("metrics.json", metrics.json(exclude_none=True).encode(), "application/json"),
+            "metadata": ("metadata.json", metadata.model_dump_json(exclude_none=True).encode(), "application/json"),
+            "metrics": ("metrics.json", metrics.model_dump_json(exclude_none=True).encode(), "application/json"),
         }
         if search_customization is not None:
-            files["customization"] = search_customization.json(exclude_none=True).encode()
+            files["customization"] = search_customization.model_dump_json(exclude_none=True).encode()
         additional_headers = {self.SEARCH_KEYS_HEADER_NAME: ",".join(self.search_keys_meaning_types(metadata))}
         response = self._with_unauth_retry(
             lambda: self._send_post_file_req_v2(
@@ -548,19 +548,19 @@ class _RestClient:
             content = file.read()
             md5_hash.update(content)
         digest = md5_hash.hexdigest()
-        metadata_with_md5 = metadata.copy(update={"checksumMD5": digest})
+        metadata_with_md5 = metadata.model_copy(update={"checksumMD5": digest})

         # digest_sha256 = hashlib.sha256(
         #     pd.util.hash_pandas_object(pd.read_parquet(file_path, engine="fastparquet")).values
         # ).hexdigest()
         digest_sha256 = self.compute_file_digest(file_path)
-        metadata_with_md5 = metadata_with_md5.copy(update={"digest": digest_sha256})
+        metadata_with_md5 = metadata_with_md5.model_copy(update={"digest": digest_sha256})

         with open(file_path, "rb") as file:
             files = {
                 "metadata": (
                     "metadata.json",
-                    metadata_with_md5.json(exclude_none=True).encode(),
+                    metadata_with_md5.model_dump_json(exclude_none=True).encode(),
                     "application/json",
                 ),
                 "tracking": (
@@ -570,7 +570,7 @@ class _RestClient:
                 ),
                 "metrics": (
                     "metrics.json",
-                    metrics.json(exclude_none=True).encode(),
+                    metrics.model_dump_json(exclude_none=True).encode(),
                     "application/json",
                 ),
                 "file": (metadata_with_md5.name, file, "application/octet-stream"),
@@ -578,7 +578,7 @@ class _RestClient:
             if search_customization is not None:
                 files["customization"] = (
                     "customization.json",
-                    search_customization.json(exclude_none=True).encode(),
+                    search_customization.model_dump_json(exclude_none=True).encode(),
                     "application/json",
                 )

@@ -602,11 +602,11 @@ class _RestClient:
     ) -> SearchTaskResponse:
         api_path = self.VALIDATION_SEARCH_WITHOUT_UPLOAD_URI_FMT_V2.format(file_upload_id, initial_search_task_id)
         files = {
-            "metadata": ("metadata.json", metadata.json(exclude_none=True).encode(), "application/json"),
-            "metrics": ("metrics.json", metrics.json(exclude_none=True).encode(), "application/json"),
+            "metadata": ("metadata.json", metadata.model_dump_json(exclude_none=True).encode(), "application/json"),
+            "metrics": ("metrics.json", metrics.model_dump_json(exclude_none=True).encode(), "application/json"),
         }
         if search_customization is not None:
-            files["customization"] = search_customization.json(exclude_none=True).encode()
+            files["customization"] = search_customization.model_dump_json(exclude_none=True).encode()
         additional_headers = {self.SEARCH_KEYS_HEADER_NAME: ",".join(self.search_keys_meaning_types(metadata))}
         response = self._with_unauth_retry(
             lambda: self._send_post_file_req_v2(
@@ -670,7 +670,7 @@ class _RestClient:
                 "file": (metadata.name, file, "application/octet-stream"),
                 "metadata": (
                     "metadata.json",
-                    metadata.json(exclude_none=True).encode(),
+                    metadata.model_dump_json(exclude_none=True).encode(),
                     "application/json",
                 ),
             }
@@ -682,12 +682,12 @@ class _RestClient:
     def get_search_file_metadata(self, search_task_id: str, trace_id: str) -> FileMetadata:
         api_path = self.SEARCH_FILE_METADATA_URI_FMT_V2.format(search_task_id)
         response = self._with_unauth_retry(lambda: self._send_get_req(api_path, trace_id))
-        return FileMetadata.parse_obj(response)
+        return FileMetadata.model_validate(response)

     def get_provider_search_metadata_v3(self, provider_search_task_id: str, trace_id: str) -> ProviderTaskMetadataV2:
         api_path = self.SEARCH_TASK_METADATA_FMT_V3.format(provider_search_task_id)
         response = self._with_unauth_retry(lambda: self._send_get_req(api_path, trace_id))
-        return ProviderTaskMetadataV2.parse_obj(response)
+        return ProviderTaskMetadataV2.model_validate(response)

     def get_current_transform_usage(self, trace_id) -> TransformUsage:
         track_metrics = get_track_metrics(self.client_ip, self.client_visitorid)
upgini/mdc/__init__.py CHANGED
@@ -5,7 +5,7 @@

 import logging

-from pythonjsonlogger import jsonlogger
+from pythonjsonlogger import json as jsonlogger

 from upgini.mdc.context import get_mdc_fields, new_log_context

upgini/metrics.py CHANGED
@@ -11,13 +11,14 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 import lightgbm as lgb
 import numpy as np
 import pandas as pd
+from catboost import CatBoostClassifier, CatBoostRegressor
 from category_encoders.cat_boost import CatBoostEncoder
 from lightgbm import LGBMClassifier, LGBMRegressor
 from numpy import log1p
 from pandas.api.types import is_numeric_dtype
 from sklearn.metrics import check_scoring, get_scorer, make_scorer, roc_auc_score

-from upgini.utils.blocked_time_series import BlockedTimeSeriesSplit
+# from upgini.utils.blocked_time_series import BlockedTimeSeriesSplit
 from upgini.utils.features_validator import FeaturesValidator
 from upgini.utils.sklearn_ext import cross_validate

@@ -31,7 +32,7 @@ except ImportError:
     available_scorers = SCORERS
 from sklearn.metrics import mean_squared_error
 from sklearn.metrics._regression import _check_reg_targets, check_consistent_length
-from sklearn.model_selection import BaseCrossValidator, TimeSeriesSplit
+from sklearn.model_selection import BaseCrossValidator  # , TimeSeriesSplit

 from upgini.errors import ValidationError
 from upgini.metadata import ModelTaskType
@@ -328,10 +329,14 @@ class EstimatorWrapper:
     ) -> Tuple[pd.DataFrame, np.ndarray, np.ndarray]:
         self.logger.info(f"Before preparing data columns: {x.columns.to_list()}")
         for c in x.columns:
-            if is_numeric_dtype(x[c]):
-                x[c] = x[c].astype(float)
-            elif not x[c].dtype == "category":
-                x[c] = x[c].astype(str)
+            if c not in self.cat_features:
+                if is_numeric_dtype(x[c]):
+                    x[c] = x[c].astype(float)
+                elif not x[c].dtype == "category":
+                    x[c] = x[c].astype(str)
+            else:
+                if x[c].dtype == "category" and x[c].cat.categories.dtype == np.int64:
+                    x[c] = x[c].astype(np.int64)

         if not isinstance(y, pd.Series):
             raise Exception(bundle.get("metrics_unsupported_target_type").format(type(y)))
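Note: the preparation loop now leaves declared categorical columns alone instead of coercing them to float or str; the only transformation they get is casting integer-backed categoricals back to plain int64, which CatBoost accepts as integer categories. A runnable sketch of the same rules on a toy frame (`cat_features` stands in for the wrapper's attribute):

```python
import numpy as np
import pandas as pd
from pandas.api.types import is_numeric_dtype

x = pd.DataFrame({
    "num": [1, 2],                       # plain numeric -> float
    "city": pd.Categorical(["a", "b"]),  # categorical -> untouched
    "code": pd.Categorical([10, 20]),    # int-backed categorical -> int64
})
cat_features = ["city", "code"]

for c in x.columns:
    if c not in cat_features:
        if is_numeric_dtype(x[c]):
            x[c] = x[c].astype(float)
        elif x[c].dtype != "category":
            x[c] = x[c].astype(str)
    elif x[c].dtype == "category" and x[c].cat.categories.dtype == np.int64:
        x[c] = x[c].astype(np.int64)

assert x["num"].dtype == float and x["code"].dtype == np.int64
```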
@@ -411,7 +416,6 @@ class EstimatorWrapper:
                 shaps = self.calculate_shap(cv_x, cv_y, estimator)
                 if shaps is not None:
                     for feature, shap_value in shaps.items():
-                        # shap_values_all_folds[feature] = shap_values_all_folds.get(feature, []) + shap_value.tolist()
                         shap_values_all_folds[feature].append(shap_value)

         if shap_values_all_folds:
@@ -488,20 +492,29 @@ class EstimatorWrapper:
             "logger": logger,
         }
         if estimator is None:
-            params = {"random_state": DEFAULT_RANDOM_STATE, "verbose": -1}
+            params = {"has_time": has_date}
             if target_type == ModelTaskType.MULTICLASS:
-                params = _get_add_params(params, LIGHTGBM_MULTICLASS_PARAMS)
+                params = _get_add_params(params, CATBOOST_MULTICLASS_PARAMS)
                 params = _get_add_params(params, add_params)
-                estimator = LightGBMWrapper(LGBMClassifier(**params), **kwargs)
+                estimator = CatBoostWrapper(CatBoostClassifier(**params), **kwargs)
+                # params = _get_add_params(params, LIGHTGBM_MULTICLASS_PARAMS)
+                # params = _get_add_params(params, add_params)
+                # estimator = LightGBMWrapper(LGBMClassifier(**params), **kwargs)
             elif target_type == ModelTaskType.BINARY:
-                params = _get_add_params(params, LIGHTGBM_BINARY_PARAMS)
+                params = _get_add_params(params, CATBOOST_BINARY_PARAMS)
                 params = _get_add_params(params, add_params)
-                estimator = LightGBMWrapper(LGBMClassifier(**params), **kwargs)
+                estimator = CatBoostWrapper(CatBoostClassifier(**params), **kwargs)
+                # params = _get_add_params(params, LIGHTGBM_BINARY_PARAMS)
+                # params = _get_add_params(params, add_params)
+                # estimator = LightGBMWrapper(LGBMClassifier(**params), **kwargs)
             elif target_type == ModelTaskType.REGRESSION:
-                if not isinstance(cv, TimeSeriesSplit) and not isinstance(cv, BlockedTimeSeriesSplit):
-                    params = _get_add_params(params, LIGHTGBM_REGRESSION_PARAMS)
+                params = _get_add_params(params, CATBOOST_REGRESSION_PARAMS)
                 params = _get_add_params(params, add_params)
-                estimator = LightGBMWrapper(LGBMRegressor(**params), **kwargs)
+                estimator = CatBoostWrapper(CatBoostRegressor(**params), **kwargs)
+                # if not isinstance(cv, TimeSeriesSplit) and not isinstance(cv, BlockedTimeSeriesSplit):
+                #     params = _get_add_params(params, LIGHTGBM_REGRESSION_PARAMS)
+                # params = _get_add_params(params, add_params)
+                # estimator = LightGBMWrapper(LGBMRegressor(**params), **kwargs)
             else:
                 raise Exception(bundle.get("metrics_unsupported_target_type").format(target_type))
         else:
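Note: with this hunk the built-in metrics estimator switches from LightGBM to CatBoost (matching the new `catboost>=1.2.8` dependency in METADATA), and `has_time` is passed so CatBoost's ordered statistics can respect date ordering when the dataset has a date key. Roughly what the binary branch now constructs, with an invented stand-in for the `CATBOOST_BINARY_PARAMS` constant defined elsewhere in metrics.py:

```python
from catboost import CatBoostClassifier

params = {"has_time": False}                 # has_date computed from the dataset
params.update({"loss_function": "Logloss"})  # stand-in for CATBOOST_BINARY_PARAMS
estimator = CatBoostClassifier(**params)     # then wrapped in CatBoostWrapper
```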
@@ -517,8 +530,6 @@ class EstimatorWrapper:
         else:
             if isinstance(estimator, (LGBMClassifier, LGBMRegressor)):
                 estimator = LightGBMWrapper(**kwargs)
-            elif is_catboost_estimator(estimator):
-                estimator = CatBoostWrapper(**kwargs)
             else:
                 logger.warning(
                     f"Unexpected estimator is used for metrics: {estimator}. "
@@ -558,6 +569,7 @@ class CatBoostWrapper(EstimatorWrapper):
         self.emb_features = None
         self.grouped_embedding_features = None
         self.drop_cat_features = []
+        self.features_to_encode = []

     def _prepare_to_fit(self, x: pd.DataFrame, y: pd.Series) -> Tuple[pd.DataFrame, np.ndarray, np.ndarray, dict]:
         x, y, groups, params = super()._prepare_to_fit(x, y)
@@ -597,7 +609,13 @@ class CatBoostWrapper(EstimatorWrapper):
         self.cat_features, self.features_to_encode, self.exclude_features = _get_cat_features(
             self.logger, x, self.cat_features, self.text_features, self.grouped_embedding_features
         )
-        params["cat_features"] = self.cat_features
+        if self.features_to_encode:
+            for c in self.features_to_encode:
+                if is_numeric_dtype(x[c]):
+                    x[c] = x[c].fillna(np.nan)
+                else:
+                    x[c] = x[c].fillna("NA")
+        params["cat_features"] = self.features_to_encode

         return x, y, groups, params

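Note: CatBoost requires categorical feature values to be integers or strings and rejects missing values in `cat_features` columns, so the columns handed to it are filled first: string-like columns get a sentinel "NA" category, numeric ones are normalized to plain NaN. A minimal reproduction of the fill logic:

```python
import numpy as np
import pandas as pd
from pandas.api.types import is_numeric_dtype

x = pd.DataFrame({"city": ["a", None], "grade": [1.0, np.nan]})
features_to_encode = ["city", "grade"]

for c in features_to_encode:
    if is_numeric_dtype(x[c]):
        x[c] = x[c].fillna(np.nan)  # normalizes missing markers to plain float NaN
    else:
        x[c] = x[c].fillna("NA")    # sentinel string category for CatBoost

assert x["city"].tolist() == ["a", "NA"]
```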
@@ -626,8 +644,14 @@ class CatBoostWrapper(EstimatorWrapper):
         if self.grouped_embedding_features:
             x, emb_columns = self.group_embeddings(x)
             params["embedding_features"] = emb_columns
-        if self.cat_features:
-            params["cat_features"] = self.cat_features
+
+        if self.features_to_encode:
+            for c in self.features_to_encode:
+                if is_numeric_dtype(x[c]):
+                    x[c] = x[c].fillna(np.nan)
+                else:
+                    x[c] = x[c].fillna("NA")
+        params["cat_features"] = self.features_to_encode

         return x, y, params

@@ -671,23 +695,29 @@ class CatBoostWrapper(EstimatorWrapper):
                 embedding_features=self.grouped_embedding_features,
             )

-            # Get SHAP values of current estimator
-            shap_values_fold = estimator.get_feature_importance(data=fold_pool, type="ShapValues")
+            shap_values = estimator.get_feature_importance(data=fold_pool, type="ShapValues")

-            # Remove last columns (base value) and flatten
             if self.target_type == ModelTaskType.MULTICLASS:
-                all_shaps = shap_values_fold[:, :, :-1]
-                all_shaps = [all_shaps[:, :, k].flatten() for k in range(all_shaps.shape[2])]
+                # For multiclass, shap_values has shape (n_samples, n_classes, n_features + 1)
+                # Last column is bias term
+                shap_values = shap_values[:, :, :-1]  # Remove bias term
+                # Average SHAP values across classes
+                shap_values = np.mean(np.abs(shap_values), axis=1)
             else:
-                all_shaps = shap_values_fold[:, :-1]
-                all_shaps = [all_shaps[:, k].flatten() for k in range(all_shaps.shape[1])]
+                # For binary/regression, shap_values has shape (n_samples, n_features + 1)
+                # Last column is bias term
+                shap_values = shap_values[:, :-1]  # Remove bias term
+                # Take absolute values
+                shap_values = np.abs(shap_values)

-            all_shaps = np.abs(all_shaps)
+            feature_importance = {}
+            for i, col in enumerate(x.columns):
+                feature_importance[col] = np.mean(np.abs(shap_values[:, i]))

-            return dict(zip(estimator.feature_names_, all_shaps))
+            return feature_importance

-        except Exception:
-            self.logger.exception("Failed to recalculate new SHAP values")
+        except Exception as e:
+            self.logger.exception(f"Failed to recalculate new SHAP values: {str(e)}")
             return None
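Note: the rewritten aggregation turns CatBoost's per-sample ShapValues matrix into one mean-absolute importance per feature; for multiclass the extra class axis is averaged away first, and in both cases the trailing bias column is dropped. The shape arithmetic on a toy array:

```python
import numpy as np

n_samples, n_classes, n_features = 4, 3, 2
rng = np.random.default_rng(0)
shap_values = rng.normal(size=(n_samples, n_classes, n_features + 1))  # +1 = bias column

shap_values = shap_values[:, :, :-1]                # drop the bias term
shap_values = np.mean(np.abs(shap_values), axis=1)  # average magnitude across classes
per_feature = shap_values.mean(axis=0)              # one importance value per feature
assert per_feature.shape == (n_features,)
```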
@@ -830,9 +860,9 @@ class OtherEstimatorWrapper(EstimatorWrapper):
         num_features = [col for col in x.columns if col not in self.cat_features]
         x[num_features] = x[num_features].fillna(-999)
         if self.features_to_encode and self.cat_encoder is not None:
-            x[self.features_to_encode] = self.cat_encoder.transform(x[self.features_to_encode].astype("object")).astype(
-                "category"
-            )
+            x[self.features_to_encode] = self.cat_encoder.transform(
+                x[self.features_to_encode].astype("object")
+            ).astype("category")
         return x, y, params

@@ -945,7 +975,7 @@ def _get_cat_features(

     logger.info(f"Selected categorical features: {cat_features}")

-    features_to_encode = list(set(x.select_dtypes(exclude=[np.number, np.datetime64, pd.CategoricalDtype]).columns))
+    features_to_encode = list(set(x.select_dtypes(exclude=[np.number, np.datetime64, pd.CategoricalDtype()]).columns))

     logger.info(f"Features to encode: {features_to_encode}")

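Note: the one-character change here swaps the `pd.CategoricalDtype` class for an instance; recent pandas releases deprecate passing dtype classes (rather than instances) to dtype-accepting APIs such as `select_dtypes`, so the instance form avoids the warning while selecting the same columns. Verifiable on a toy frame:

```python
import numpy as np
import pandas as pd

x = pd.DataFrame({
    "n": [1, 2],
    "d": pd.to_datetime(["2024-01-01", "2024-01-02"]),
    "c": pd.Categorical(["a", "b"]),
    "s": ["u", "v"],
})
to_encode = x.select_dtypes(exclude=[np.number, np.datetime64, pd.CategoricalDtype()]).columns
assert list(to_encode) == ["s"]  # only the plain object column remains
```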
upgini-1.2.81a3832.dev1.dist-info/METADATA → upgini-1.2.81a3832.dev2.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: upgini
-Version: 1.2.81a3832.dev1
+Version: 1.2.81a3832.dev2
 Summary: Intelligent data search & enrichment for Machine Learning
 Project-URL: Bug Reports, https://github.com/upgini/upgini/issues
 Project-URL: Homepage, https://upgini.com/
@@ -22,6 +22,7 @@ Classifier: Programming Language :: Python :: 3.11
 Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
 Classifier: Topic :: Scientific/Engineering :: Information Analysis
 Requires-Python: <3.12,>=3.10
+Requires-Dist: catboost>=1.2.8
 Requires-Dist: category-encoders>=2.8.1
 Requires-Dist: fastparquet>=0.8.1
 Requires-Dist: ipywidgets>=8.1.0
upgini-1.2.81a3832.dev1.dist-info/RECORD → upgini-1.2.81a3832.dev2.dist-info/RECORD CHANGED
@@ -1,12 +1,12 @@
-upgini/__about__.py,sha256=-WSXUS5Ith33qArTnDO4LmrI0wUaXbJ8bIzoMZvAsWU,33
+upgini/__about__.py,sha256=7ytM9g8DI6H-u5aMwPu2Qxa34E_K8afMwp4RaWapTSw,33
 upgini/__init__.py,sha256=LXSfTNU0HnlOkE69VCxkgIKDhWP-JFo_eBQ71OxTr5Y,261
 upgini/ads.py,sha256=nvuRxRx5MHDMgPr9SiU-fsqRdFaBv8p4_v1oqiysKpc,2714
 upgini/dataset.py,sha256=aspri7ZAgwkNNUiIgQ1GRXvw8XQii3F4RfNXSrF4wrw,35365
 upgini/errors.py,sha256=2b_Wbo0OYhLUbrZqdLIx5jBnAsiD1Mcenh-VjR4HCTw,950
-upgini/features_enricher.py,sha256=qtrQJwF2QbKdQ8Tqk5RQj3aAqOzDgygD6nIHrco3AzE,209728
-upgini/http.py,sha256=UH7nswcZ221un3O_VW9limCBO5oRsyg1eKUHiVslRPs,43737
+upgini/features_enricher.py,sha256=WiSVfmlHI9oKJQbyf46FH0yY80hBJ6hheFpugw0f_vE,210583
+upgini/http.py,sha256=AfaJ3c8z_tK2hZFEehNybDKE0mp1tYcyAP_l0_p8bLQ,43933
 upgini/metadata.py,sha256=Yd6iW2f7Wz6vUkg5uvR4xylN16ANnCKVKqAsAkap7p8,12354
-upgini/metrics.py,sha256=95sK1Kr3dYxqQcdkkoNFDe9OZY7OhgLjYwe3bhMQd38,38087
+upgini/metrics.py,sha256=KxtcjiClNDNlMWpoCbAvVPveC59Nz7z2lA4b-hQozRE,39608
 upgini/search_task.py,sha256=RcvAE785yksWTsTNWuZFVNlk32jHElMoEna1T_C5N8Q,17823
 upgini/spinner.py,sha256=4iMd-eIe_BnkqFEMIliULTbj6rNI2HkN_VJ4qYe0cUc,1118
 upgini/version_validator.py,sha256=DvbaAvuYFoJqYt0fitpsk6Xcv-H1BYDJYHUMxaKSH_Y,1509
@@ -32,7 +32,7 @@ upgini/autofe/timeseries/trend.py,sha256=K1_iw2ko_LIUU8YCUgrvN3n0MkHtsi7-63-8x9e
 upgini/autofe/timeseries/volatility.py,sha256=9shUmIKjpWTHVYjj80YBsk0XheBJ9uBuLv5NW9Mchnk,7953
 upgini/data_source/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 upgini/data_source/data_source_publisher.py,sha256=4S9qwlAklD8vg9tUU_c1pHE2_glUHAh15-wr5hMwKFw,22879
-upgini/mdc/__init__.py,sha256=aM08nIWFc2gWdWUa3_IuEnNND0cQPkBGnYpRMnfFN8k,1019
+upgini/mdc/__init__.py,sha256=iHJlXQg6xRM1-ZOUtaPSJqw5SpQDszvxp4LyqviNLIQ,1027
 upgini/mdc/context.py,sha256=3u1B-jXt7tXEvNcV3qmR9SDCseudnY7KYsLclBdwVLk,1405
 upgini/normalizer/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 upgini/normalizer/normalize_utils.py,sha256=Ft2MwSgVoBilXAORAOYAuwPD79GOLfwn4qQE3IUFzzg,7218
@@ -70,7 +70,7 @@ upgini/utils/target_utils.py,sha256=LRN840dzx78-wg7ftdxAkp2c1eu8-JDvkACiRThm4HE,
 upgini/utils/track_info.py,sha256=G5Lu1xxakg2_TQjKZk4b5SvrHsATTXNVV3NbvWtT8k8,5663
 upgini/utils/ts_utils.py,sha256=26vhC0pN7vLXK6R09EEkMK3Lwb9IVPH7LRdqFIQ3kPs,1383
 upgini/utils/warning_counter.py,sha256=-GRY8EUggEBKODPSuXAkHn9KnEQwAORC0mmz_tim-PM,254
-upgini-1.2.81a3832.dev1.dist-info/METADATA,sha256=ShIRi8EeeujsKBJ0byR2XWJ6DKFka2vrViq9d5VwjzU,49141
-upgini-1.2.81a3832.dev1.dist-info/WHEEL,sha256=1yFddiXMmvYK7QYTqtRNtX66WJ0Mz8PYEiEUoOUUxRY,87
-upgini-1.2.81a3832.dev1.dist-info/licenses/LICENSE,sha256=5RRzgvdJUu3BUDfv4bzVU6FqKgwHlIay63pPCSmSgzw,1514
-upgini-1.2.81a3832.dev1.dist-info/RECORD,,
+upgini-1.2.81a3832.dev2.dist-info/METADATA,sha256=Kdxh014FUNln4eeF-RflHu3c_pfvPXpsoXfvb6SBneE,49172
+upgini-1.2.81a3832.dev2.dist-info/WHEEL,sha256=1yFddiXMmvYK7QYTqtRNtX66WJ0Mz8PYEiEUoOUUxRY,87
+upgini-1.2.81a3832.dev2.dist-info/licenses/LICENSE,sha256=5RRzgvdJUu3BUDfv4bzVU6FqKgwHlIay63pPCSmSgzw,1514
+upgini-1.2.81a3832.dev2.dist-info/RECORD,,