upgini 1.2.14a3616.dev3__tar.gz → 1.2.16__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of upgini might be problematic. Click here for more details.

Files changed (65) hide show
  1. {upgini-1.2.14a3616.dev3 → upgini-1.2.16}/PKG-INFO +2 -2
  2. {upgini-1.2.14a3616.dev3 → upgini-1.2.16}/README.md +1 -1
  3. upgini-1.2.16/src/upgini/__about__.py +1 -0
  4. {upgini-1.2.14a3616.dev3 → upgini-1.2.16}/src/upgini/dataset.py +6 -3
  5. {upgini-1.2.14a3616.dev3 → upgini-1.2.16}/src/upgini/features_enricher.py +111 -55
  6. {upgini-1.2.14a3616.dev3 → upgini-1.2.16}/src/upgini/metrics.py +66 -9
  7. {upgini-1.2.14a3616.dev3 → upgini-1.2.16}/src/upgini/normalizer/normalize_utils.py +22 -15
  8. {upgini-1.2.14a3616.dev3 → upgini-1.2.16}/src/upgini/resource_bundle/strings.properties +8 -1
  9. {upgini-1.2.14a3616.dev3 → upgini-1.2.16}/src/upgini/utils/display_utils.py +8 -2
  10. {upgini-1.2.14a3616.dev3 → upgini-1.2.16}/src/upgini/utils/target_utils.py +96 -46
  11. upgini-1.2.14a3616.dev3/src/upgini/__about__.py +0 -1
  12. {upgini-1.2.14a3616.dev3 → upgini-1.2.16}/.gitignore +0 -0
  13. {upgini-1.2.14a3616.dev3 → upgini-1.2.16}/LICENSE +0 -0
  14. {upgini-1.2.14a3616.dev3 → upgini-1.2.16}/pyproject.toml +0 -0
  15. {upgini-1.2.14a3616.dev3 → upgini-1.2.16}/src/upgini/__init__.py +0 -0
  16. {upgini-1.2.14a3616.dev3 → upgini-1.2.16}/src/upgini/ads.py +0 -0
  17. {upgini-1.2.14a3616.dev3 → upgini-1.2.16}/src/upgini/ads_management/__init__.py +0 -0
  18. {upgini-1.2.14a3616.dev3 → upgini-1.2.16}/src/upgini/ads_management/ads_manager.py +0 -0
  19. {upgini-1.2.14a3616.dev3 → upgini-1.2.16}/src/upgini/autofe/__init__.py +0 -0
  20. {upgini-1.2.14a3616.dev3 → upgini-1.2.16}/src/upgini/autofe/all_operands.py +0 -0
  21. {upgini-1.2.14a3616.dev3 → upgini-1.2.16}/src/upgini/autofe/binary.py +0 -0
  22. {upgini-1.2.14a3616.dev3 → upgini-1.2.16}/src/upgini/autofe/date.py +0 -0
  23. {upgini-1.2.14a3616.dev3 → upgini-1.2.16}/src/upgini/autofe/feature.py +0 -0
  24. {upgini-1.2.14a3616.dev3 → upgini-1.2.16}/src/upgini/autofe/groupby.py +0 -0
  25. {upgini-1.2.14a3616.dev3 → upgini-1.2.16}/src/upgini/autofe/operand.py +0 -0
  26. {upgini-1.2.14a3616.dev3 → upgini-1.2.16}/src/upgini/autofe/unary.py +0 -0
  27. {upgini-1.2.14a3616.dev3 → upgini-1.2.16}/src/upgini/autofe/vector.py +0 -0
  28. {upgini-1.2.14a3616.dev3 → upgini-1.2.16}/src/upgini/data_source/__init__.py +0 -0
  29. {upgini-1.2.14a3616.dev3 → upgini-1.2.16}/src/upgini/data_source/data_source_publisher.py +0 -0
  30. {upgini-1.2.14a3616.dev3 → upgini-1.2.16}/src/upgini/errors.py +0 -0
  31. {upgini-1.2.14a3616.dev3 → upgini-1.2.16}/src/upgini/http.py +0 -0
  32. {upgini-1.2.14a3616.dev3 → upgini-1.2.16}/src/upgini/lazy_import.py +0 -0
  33. {upgini-1.2.14a3616.dev3 → upgini-1.2.16}/src/upgini/mdc/__init__.py +0 -0
  34. {upgini-1.2.14a3616.dev3 → upgini-1.2.16}/src/upgini/mdc/context.py +0 -0
  35. {upgini-1.2.14a3616.dev3 → upgini-1.2.16}/src/upgini/metadata.py +0 -0
  36. {upgini-1.2.14a3616.dev3 → upgini-1.2.16}/src/upgini/normalizer/__init__.py +0 -0
  37. {upgini-1.2.14a3616.dev3 → upgini-1.2.16}/src/upgini/resource_bundle/__init__.py +0 -0
  38. {upgini-1.2.14a3616.dev3 → upgini-1.2.16}/src/upgini/resource_bundle/exceptions.py +0 -0
  39. {upgini-1.2.14a3616.dev3 → upgini-1.2.16}/src/upgini/resource_bundle/strings_widget.properties +0 -0
  40. {upgini-1.2.14a3616.dev3 → upgini-1.2.16}/src/upgini/sampler/__init__.py +0 -0
  41. {upgini-1.2.14a3616.dev3 → upgini-1.2.16}/src/upgini/sampler/base.py +0 -0
  42. {upgini-1.2.14a3616.dev3 → upgini-1.2.16}/src/upgini/sampler/random_under_sampler.py +0 -0
  43. {upgini-1.2.14a3616.dev3 → upgini-1.2.16}/src/upgini/sampler/utils.py +0 -0
  44. {upgini-1.2.14a3616.dev3 → upgini-1.2.16}/src/upgini/search_task.py +0 -0
  45. {upgini-1.2.14a3616.dev3 → upgini-1.2.16}/src/upgini/spinner.py +0 -0
  46. {upgini-1.2.14a3616.dev3 → upgini-1.2.16}/src/upgini/utils/__init__.py +0 -0
  47. {upgini-1.2.14a3616.dev3 → upgini-1.2.16}/src/upgini/utils/base_search_key_detector.py +0 -0
  48. {upgini-1.2.14a3616.dev3 → upgini-1.2.16}/src/upgini/utils/blocked_time_series.py +0 -0
  49. {upgini-1.2.14a3616.dev3 → upgini-1.2.16}/src/upgini/utils/country_utils.py +0 -0
  50. {upgini-1.2.14a3616.dev3 → upgini-1.2.16}/src/upgini/utils/custom_loss_utils.py +0 -0
  51. {upgini-1.2.14a3616.dev3 → upgini-1.2.16}/src/upgini/utils/cv_utils.py +0 -0
  52. {upgini-1.2.14a3616.dev3 → upgini-1.2.16}/src/upgini/utils/datetime_utils.py +0 -0
  53. {upgini-1.2.14a3616.dev3 → upgini-1.2.16}/src/upgini/utils/deduplicate_utils.py +0 -0
  54. {upgini-1.2.14a3616.dev3 → upgini-1.2.16}/src/upgini/utils/email_utils.py +0 -0
  55. {upgini-1.2.14a3616.dev3 → upgini-1.2.16}/src/upgini/utils/fallback_progress_bar.py +0 -0
  56. {upgini-1.2.14a3616.dev3 → upgini-1.2.16}/src/upgini/utils/features_validator.py +0 -0
  57. {upgini-1.2.14a3616.dev3 → upgini-1.2.16}/src/upgini/utils/format.py +0 -0
  58. {upgini-1.2.14a3616.dev3 → upgini-1.2.16}/src/upgini/utils/ip_utils.py +0 -0
  59. {upgini-1.2.14a3616.dev3 → upgini-1.2.16}/src/upgini/utils/phone_utils.py +0 -0
  60. {upgini-1.2.14a3616.dev3 → upgini-1.2.16}/src/upgini/utils/postal_code_utils.py +0 -0
  61. {upgini-1.2.14a3616.dev3 → upgini-1.2.16}/src/upgini/utils/progress_bar.py +0 -0
  62. {upgini-1.2.14a3616.dev3 → upgini-1.2.16}/src/upgini/utils/sklearn_ext.py +0 -0
  63. {upgini-1.2.14a3616.dev3 → upgini-1.2.16}/src/upgini/utils/track_info.py +0 -0
  64. {upgini-1.2.14a3616.dev3 → upgini-1.2.16}/src/upgini/utils/warning_counter.py +0 -0
  65. {upgini-1.2.14a3616.dev3 → upgini-1.2.16}/src/upgini/version_validator.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: upgini
3
- Version: 1.2.14a3616.dev3
3
+ Version: 1.2.16
4
4
  Summary: Intelligent data search & enrichment for Machine Learning
5
5
  Project-URL: Bug Reports, https://github.com/upgini/upgini/issues
6
6
  Project-URL: Homepage, https://upgini.com/
@@ -145,7 +145,7 @@ Description-Content-Type: text/markdown
145
145
 
146
146
  ## 💼 Tutorials
147
147
 
148
- ### [Search of relevant external features & Automated feature generation for Salary predicton task (use as a template)](https://github.com/upgini/upgini/blob/main/notebooks/Upgini_Features_search%26generation.ipynb)
148
+ ### [Search of relevant external features & Automated feature generation for Salary prediction task (use as a template)](https://github.com/upgini/upgini/blob/main/notebooks/Upgini_Features_search%26generation.ipynb)
149
149
 
150
150
  * The goal is to predict salary for data science job posting based on information about employer and job description.
151
151
  * Following this guide, you'll learn how to **search & auto generate new relevant features with Upgini library**
@@ -103,7 +103,7 @@
103
103
 
104
104
  ## 💼 Tutorials
105
105
 
106
- ### [Search of relevant external features & Automated feature generation for Salary predicton task (use as a template)](https://github.com/upgini/upgini/blob/main/notebooks/Upgini_Features_search%26generation.ipynb)
106
+ ### [Search of relevant external features & Automated feature generation for Salary prediction task (use as a template)](https://github.com/upgini/upgini/blob/main/notebooks/Upgini_Features_search%26generation.ipynb)
107
107
 
108
108
  * The goal is to predict salary for data science job posting based on information about employer and job description.
109
109
  * Following this guide, you'll learn how to **search & auto generate new relevant features with Upgini library**
@@ -0,0 +1 @@
1
+ __version__ = "1.2.16"
@@ -53,7 +53,8 @@ class Dataset: # (pd.DataFrame):
53
53
  FIT_SAMPLE_THRESHOLD = 200_000
54
54
  FIT_SAMPLE_WITH_EVAL_SET_ROWS = 200_000
55
55
  FIT_SAMPLE_WITH_EVAL_SET_THRESHOLD = 200_000
56
- MIN_SAMPLE_THRESHOLD = 5_000
56
+ BINARY_MIN_SAMPLE_THRESHOLD = 5_000
57
+ MULTICLASS_MIN_SAMPLE_THRESHOLD = 25_000
57
58
  IMBALANCE_THESHOLD = 0.6
58
59
  BINARY_BOOTSTRAP_LOOPS = 5
59
60
  MULTICLASS_BOOTSTRAP_LOOPS = 2
@@ -225,7 +226,7 @@ class Dataset: # (pd.DataFrame):
225
226
  train_segment = self.data
226
227
 
227
228
  if self.task_type == ModelTaskType.MULTICLASS or (
228
- self.task_type == ModelTaskType.BINARY and len(train_segment) > self.MIN_SAMPLE_THRESHOLD
229
+ self.task_type == ModelTaskType.BINARY and len(train_segment) > self.BINARY_MIN_SAMPLE_THRESHOLD
229
230
  ):
230
231
  count = len(train_segment)
231
232
  target_column = self.etalon_def_checked.get(FileColumnMeaningType.TARGET.value, TARGET)
@@ -253,6 +254,7 @@ class Dataset: # (pd.DataFrame):
253
254
  min_class_percent = self.IMBALANCE_THESHOLD / target_classes_count
254
255
  min_class_threshold = min_class_percent * count
255
256
 
257
+ # If the min class count is less than 30% of rows for binary or (60 / classes_count)% for multiclass
256
258
  if min_class_count < min_class_threshold:
257
259
  self.imbalanced = True
258
260
  self.data = balance_undersample(
@@ -260,7 +262,8 @@ class Dataset: # (pd.DataFrame):
260
262
  target_column=target_column,
261
263
  task_type=self.task_type,
262
264
  random_state=self.random_state,
263
- imbalance_threshold=self.IMBALANCE_THESHOLD,
265
+ binary_min_sample_threshold=self.BINARY_MIN_SAMPLE_THRESHOLD,
266
+ multiclass_min_sample_threshold=self.MULTICLASS_MIN_SAMPLE_THRESHOLD,
264
267
  binary_bootstrap_loops=self.BINARY_BOOTSTRAP_LOOPS,
265
268
  multiclass_bootstrap_loops=self.MULTICLASS_BOOTSTRAP_LOOPS,
266
269
  logger=self.logger,
@@ -336,6 +336,7 @@ class FeaturesEnricher(TransformerMixin):
336
336
  self.exclude_columns = exclude_columns
337
337
  self.baseline_score_column = baseline_score_column
338
338
  self.add_date_if_missing = add_date_if_missing
339
+ self.features_info_display_handle = None
339
340
 
340
341
  def _get_api_key(self):
341
342
  return self._api_key
@@ -871,6 +872,13 @@ class FeaturesEnricher(TransformerMixin):
871
872
  else None
872
873
  )
873
874
 
875
+ if self.X is None:
876
+ self.X = X
877
+ if self.y is None:
878
+ self.y = y
879
+ if self.eval_set is None:
880
+ self.eval_set = effective_eval_set
881
+
874
882
  try:
875
883
  self.__log_debug_information(
876
884
  validated_X,
@@ -938,14 +946,14 @@ class FeaturesEnricher(TransformerMixin):
938
946
 
939
947
  gc.collect()
940
948
 
949
+ if fitting_X.shape[1] == 0 and fitting_enriched_X.shape[1] == 0:
950
+ print(self.bundle.get("metrics_no_important_free_features"))
951
+ self.logger.warning("No client or free relevant ADS features found to calculate metrics")
952
+ self.warning_counter.increment()
953
+ return None
954
+
941
955
  print(self.bundle.get("metrics_start"))
942
956
  with Spinner():
943
- if fitting_X.shape[1] == 0 and fitting_enriched_X.shape[1] == 0:
944
- print(self.bundle.get("metrics_no_important_free_features"))
945
- self.logger.warning("No client or free relevant ADS features found to calculate metrics")
946
- self.warning_counter.increment()
947
- return None
948
-
949
957
  self._check_train_and_eval_target_distribution(y_sorted, fitting_eval_set_dict)
950
958
 
951
959
  has_date = SearchKey.find_key(search_keys, [SearchKey.DATE, SearchKey.DATETIME]) is not None
@@ -989,7 +997,7 @@ class FeaturesEnricher(TransformerMixin):
989
997
  text_features=self.generate_features,
990
998
  has_date=has_date,
991
999
  )
992
- etalon_metric = baseline_estimator.cross_val_predict(
1000
+ etalon_metric, _ = baseline_estimator.cross_val_predict(
993
1001
  fitting_X, y_sorted, self.baseline_score_column
994
1002
  )
995
1003
  if etalon_metric is None:
@@ -1023,7 +1031,13 @@ class FeaturesEnricher(TransformerMixin):
1023
1031
  text_features=self.generate_features,
1024
1032
  has_date=has_date,
1025
1033
  )
1026
- enriched_metric = enriched_estimator.cross_val_predict(fitting_enriched_X, enriched_y_sorted)
1034
+ enriched_metric, enriched_shaps = enriched_estimator.cross_val_predict(
1035
+ fitting_enriched_X, enriched_y_sorted
1036
+ )
1037
+
1038
+ if enriched_shaps is not None:
1039
+ self._update_shap_values(enriched_shaps)
1040
+
1027
1041
  if enriched_metric is None:
1028
1042
  self.logger.warning(
1029
1043
  f"Enriched {metric} on train combined features is None (maybe all features was removed)"
@@ -1156,13 +1170,6 @@ class FeaturesEnricher(TransformerMixin):
1156
1170
  elif uplift_col in metrics_df.columns and (metrics_df[uplift_col] < 0).any():
1157
1171
  self.logger.warning("Uplift is negative")
1158
1172
 
1159
- if self.X is None:
1160
- self.X = X
1161
- if self.y is None:
1162
- self.y = y
1163
- if self.eval_set is None:
1164
- self.eval_set = effective_eval_set
1165
-
1166
1173
  return metrics_df
1167
1174
  except Exception as e:
1168
1175
  error_message = "Failed to calculate metrics" + (
@@ -1187,6 +1194,48 @@ class FeaturesEnricher(TransformerMixin):
1187
1194
  finally:
1188
1195
  self.logger.info(f"Calculating metrics elapsed time: {time.time() - start_time}")
1189
1196
 
1197
+ def _update_shap_values(self, new_shaps: Dict[str, float]):
1198
+ new_shaps = {
1199
+ feature: self._round_shap_value(shap)
1200
+ for feature, shap in new_shaps.items()
1201
+ if feature in self.feature_names_
1202
+ }
1203
+ features_importances = list(new_shaps.items())
1204
+ features_importances.sort(key=lambda m: (-m[1], m[0]))
1205
+ self.feature_names_, self.feature_importances_ = zip(*features_importances)
1206
+ self.feature_names_ = list(self.feature_names_)
1207
+ self.feature_importances_ = list(self.feature_importances_)
1208
+
1209
+ feature_name_header = self.bundle.get("features_info_name")
1210
+ shap_value_header = self.bundle.get("features_info_shap")
1211
+
1212
+ def update_shap(row):
1213
+ return new_shaps.get(row[feature_name_header], row[shap_value_header])
1214
+
1215
+ self.features_info[shap_value_header] = self.features_info.apply(update_shap, axis=1)
1216
+ self._internal_features_info[shap_value_header] = self._internal_features_info.apply(update_shap, axis=1)
1217
+ self._features_info_without_links[shap_value_header] = self._features_info_without_links.apply(
1218
+ update_shap, axis=1
1219
+ )
1220
+ self.logger.info(f"Recalculated SHAP values:\n{self._features_info_without_links}")
1221
+
1222
+ self.features_info.sort_values(by=shap_value_header, ascending=False, inplace=True)
1223
+ self._internal_features_info.sort_values(by=shap_value_header, ascending=False, inplace=True)
1224
+ self._features_info_without_links.sort_values(by=shap_value_header, ascending=False, inplace=True)
1225
+
1226
+ if self.features_info_display_handle:
1227
+ try:
1228
+ _ = get_ipython() # type: ignore
1229
+
1230
+ display_html_dataframe(
1231
+ self.features_info,
1232
+ self._features_info_without_links,
1233
+ self.bundle.get("relevant_features_header"),
1234
+ display_handle=self.features_info_display_handle,
1235
+ )
1236
+ except (ImportError, NameError):
1237
+ print(self._internal_features_info)
1238
+
1190
1239
  def _check_train_and_eval_target_distribution(self, y, eval_set_dict):
1191
1240
  uneven_distribution = False
1192
1241
  for eval_set in eval_set_dict.values():
@@ -1515,11 +1564,19 @@ class FeaturesEnricher(TransformerMixin):
1515
1564
  self.logger.info("No external features selected. So use only input datasets for metrics calculation")
1516
1565
  return self.__sample_only_input(validated_X, validated_y, eval_set, is_demo_dataset)
1517
1566
  # TODO save and check if dataset was deduplicated - use imbalance branch for such case
1518
- elif not self.imbalanced and not exclude_features_sources and is_input_same_as_fit:
1567
+ elif (
1568
+ not self.imbalanced
1569
+ and not exclude_features_sources
1570
+ and is_input_same_as_fit
1571
+ and self.df_with_original_index is not None
1572
+ ):
1519
1573
  self.logger.info("Dataset is not imbalanced, so use enriched_X from fit")
1520
1574
  return self.__sample_balanced(eval_set, trace_id, remove_outliers_calc_metrics)
1521
1575
  else:
1522
- self.logger.info("Dataset is imbalanced or exclude_features_sources or X was passed. Run transform")
1576
+ self.logger.info(
1577
+ "Dataset is imbalanced or exclude_features_sources or X was passed or this is saved search."
1578
+ " Run transform"
1579
+ )
1523
1580
  print(self.bundle.get("prepare_data_for_metrics"))
1524
1581
  return self.__sample_imbalanced(
1525
1582
  validated_X,
@@ -1577,8 +1634,8 @@ class FeaturesEnricher(TransformerMixin):
1577
1634
  df = generator.generate(df)
1578
1635
  generated_features.extend(generator.generated_features)
1579
1636
 
1580
- normalizer = Normalizer(search_keys, generated_features, self.bundle, self.logger, self.warning_counter)
1581
- df = normalizer.normalize(df)
1637
+ normalizer = Normalizer(self.bundle, self.logger, self.warning_counter)
1638
+ df, search_keys, generated_features = normalizer.normalize(df, search_keys, generated_features)
1582
1639
  columns_renaming = normalizer.columns_renaming
1583
1640
 
1584
1641
  df = clean_full_duplicates(df, logger=self.logger, silent=True, bundle=self.bundle)
@@ -2017,10 +2074,8 @@ class FeaturesEnricher(TransformerMixin):
2017
2074
  df = generator.generate(df)
2018
2075
  generated_features.extend(generator.generated_features)
2019
2076
 
2020
- normalizer = Normalizer(
2021
- search_keys, generated_features, self.bundle, self.logger, self.warning_counter, silent_mode
2022
- )
2023
- df = normalizer.normalize(df)
2077
+ normalizer = Normalizer(self.bundle, self.logger, self.warning_counter, silent_mode)
2078
+ df, search_keys, generated_features = normalizer.normalize(df, search_keys, generated_features)
2024
2079
  columns_renaming = normalizer.columns_renaming
2025
2080
 
2026
2081
  # Don't pass all features in backend on transform
@@ -2449,16 +2504,13 @@ class FeaturesEnricher(TransformerMixin):
2449
2504
  if is_numeric_dtype(df[self.TARGET_NAME]) and has_date:
2450
2505
  self._validate_PSI(df.sort_values(by=maybe_date_column))
2451
2506
 
2452
- normalizer = Normalizer(
2453
- self.fit_search_keys, self.fit_generated_features, self.bundle, self.logger, self.warning_counter
2507
+ normalizer = Normalizer(self.bundle, self.logger, self.warning_counter)
2508
+ df, self.fit_search_keys, self.fit_generated_features = normalizer.normalize(
2509
+ df, self.fit_search_keys, self.fit_generated_features
2454
2510
  )
2455
- df = normalizer.normalize(df)
2456
- columns_renaming = normalizer.columns_renaming
2457
- self.fit_columns_renaming = columns_renaming
2511
+ self.fit_columns_renaming = normalizer.columns_renaming
2458
2512
 
2459
- self.__adjust_cv(
2460
- df, normalizer.search_keys, self.model_task_type
2461
- )
2513
+ self.__adjust_cv(df)
2462
2514
 
2463
2515
  df = remove_fintech_duplicates(
2464
2516
  df, self.fit_search_keys, date_format=self.date_format, logger=self.logger, bundle=self.bundle
@@ -2472,7 +2524,7 @@ class FeaturesEnricher(TransformerMixin):
2472
2524
  self.df_with_original_index = df.copy()
2473
2525
  # TODO check maybe need to drop _time column from df_with_original_index
2474
2526
 
2475
- df, unnest_search_keys = self._explode_multiple_search_keys(df, self.fit_search_keys, columns_renaming)
2527
+ df, unnest_search_keys = self._explode_multiple_search_keys(df, self.fit_search_keys, self.fit_columns_renaming)
2476
2528
 
2477
2529
  # Convert EMAIL to HEM after unnesting to do it only with one column
2478
2530
  email_column = self._get_email_column(self.fit_search_keys)
@@ -2482,7 +2534,7 @@ class FeaturesEnricher(TransformerMixin):
2482
2534
  email_column,
2483
2535
  hem_column,
2484
2536
  self.fit_search_keys,
2485
- columns_renaming,
2537
+ self.fit_columns_renaming,
2486
2538
  list(unnest_search_keys.keys()),
2487
2539
  self.logger,
2488
2540
  )
@@ -2493,7 +2545,7 @@ class FeaturesEnricher(TransformerMixin):
2493
2545
  converter = IpSearchKeyConverter(
2494
2546
  ip_column,
2495
2547
  self.fit_search_keys,
2496
- columns_renaming,
2548
+ self.fit_columns_renaming,
2497
2549
  list(unnest_search_keys.keys()),
2498
2550
  self.bundle,
2499
2551
  self.logger,
@@ -2524,7 +2576,7 @@ class FeaturesEnricher(TransformerMixin):
2524
2576
  features_columns = [c for c in df.columns if c not in non_feature_columns]
2525
2577
 
2526
2578
  features_to_drop = FeaturesValidator(self.logger).validate(
2527
- df, features_columns, self.generate_features, self.warning_counter, columns_renaming
2579
+ df, features_columns, self.generate_features, self.warning_counter, self.fit_columns_renaming
2528
2580
  )
2529
2581
  self.fit_dropped_features.update(features_to_drop)
2530
2582
  df = df.drop(columns=features_to_drop)
@@ -2565,7 +2617,7 @@ class FeaturesEnricher(TransformerMixin):
2565
2617
  rest_client=self.rest_client,
2566
2618
  logger=self.logger,
2567
2619
  )
2568
- dataset.columns_renaming = columns_renaming
2620
+ dataset.columns_renaming = self.fit_columns_renaming
2569
2621
 
2570
2622
  self.passed_features = [
2571
2623
  column for column, meaning_type in meaning_types.items() if meaning_type == FileColumnMeaningType.FEATURE
@@ -2712,22 +2764,22 @@ class FeaturesEnricher(TransformerMixin):
2712
2764
  if not self.warning_counter.has_warnings():
2713
2765
  self.__display_support_link(self.bundle.get("all_ok_community_invite"))
2714
2766
 
2715
- def __adjust_cv(self, df: pd.DataFrame, search_keys: Dict[str, SearchKey], model_task_type: ModelTaskType):
2716
- date_column = SearchKey.find_key(search_keys, [SearchKey.DATE, SearchKey.DATETIME])
2767
+ def __adjust_cv(self, df: pd.DataFrame):
2768
+ date_column = SearchKey.find_key(self.fit_search_keys, [SearchKey.DATE, SearchKey.DATETIME])
2717
2769
  # Check Multivariate time series
2718
2770
  if (
2719
2771
  self.cv is None
2720
2772
  and date_column
2721
- and model_task_type == ModelTaskType.REGRESSION
2722
- and len({SearchKey.PHONE, SearchKey.EMAIL, SearchKey.HEM}.intersection(search_keys.keys())) == 0
2723
- and is_blocked_time_series(df, date_column, list(search_keys.keys()) + [TARGET])
2773
+ and self.model_task_type == ModelTaskType.REGRESSION
2774
+ and len({SearchKey.PHONE, SearchKey.EMAIL, SearchKey.HEM}.intersection(self.fit_search_keys.keys())) == 0
2775
+ and is_blocked_time_series(df, date_column, list(self.fit_search_keys.keys()) + [TARGET])
2724
2776
  ):
2725
2777
  msg = self.bundle.get("multivariate_timeseries_detected")
2726
2778
  self.__override_cv(CVType.blocked_time_series, msg, print_warning=False)
2727
- elif self.cv is None and model_task_type != ModelTaskType.REGRESSION:
2779
+ elif self.cv is None and self.model_task_type != ModelTaskType.REGRESSION:
2728
2780
  msg = self.bundle.get("group_k_fold_in_classification")
2729
2781
  self.__override_cv(CVType.group_k_fold, msg, print_warning=self.cv is not None)
2730
- group_columns = self._get_group_columns(df, search_keys)
2782
+ group_columns = self._get_group_columns(df, self.fit_search_keys)
2731
2783
  self.runtime_parameters.properties["cv_params.group_columns"] = ",".join(group_columns)
2732
2784
  self.runtime_parameters.properties["cv_params.shuffle_kfold"] = "True"
2733
2785
 
@@ -3379,6 +3431,13 @@ class FeaturesEnricher(TransformerMixin):
3379
3431
 
3380
3432
  return result_train, result_eval_sets
3381
3433
 
3434
+ @staticmethod
3435
+ def _round_shap_value(shap: float) -> float:
3436
+ if shap > 0.0 and shap < 0.0001:
3437
+ return 0.0001
3438
+ else:
3439
+ return round(shap, 4)
3440
+
3382
3441
  def __prepare_feature_importances(self, trace_id: str, x_columns: List[str], silent=False):
3383
3442
  llm_source = "LLM with external data augmentation"
3384
3443
  if self._search_task is None:
@@ -3396,12 +3455,6 @@ class FeaturesEnricher(TransformerMixin):
3396
3455
  features_info_without_links = []
3397
3456
  internal_features_info = []
3398
3457
 
3399
- def round_shap_value(shap: float) -> float:
3400
- if shap > 0.0 and shap < 0.0001:
3401
- return 0.0001
3402
- else:
3403
- return round(shap, 4)
3404
-
3405
3458
  def list_or_single(lst: List[str], single: str):
3406
3459
  return lst or ([single] if single else [])
3407
3460
 
@@ -3434,7 +3487,7 @@ class FeaturesEnricher(TransformerMixin):
3434
3487
 
3435
3488
  feature_sample = []
3436
3489
  self.feature_names_.append(feature_meta.name)
3437
- self.feature_importances_.append(round_shap_value(feature_meta.shap_value))
3490
+ self.feature_importances_.append(self._round_shap_value(feature_meta.shap_value))
3438
3491
  if feature_meta.name in features_df.columns:
3439
3492
  feature_sample = np.random.choice(features_df[feature_meta.name].dropna().unique(), 3).tolist()
3440
3493
  if len(feature_sample) > 0 and isinstance(feature_sample[0], float):
@@ -3473,7 +3526,7 @@ class FeaturesEnricher(TransformerMixin):
3473
3526
  features_info.append(
3474
3527
  {
3475
3528
  self.bundle.get("features_info_name"): feature_name,
3476
- self.bundle.get("features_info_shap"): round_shap_value(feature_meta.shap_value),
3529
+ self.bundle.get("features_info_shap"): self._round_shap_value(feature_meta.shap_value),
3477
3530
  self.bundle.get("features_info_hitrate"): feature_meta.hit_rate,
3478
3531
  self.bundle.get("features_info_value_preview"): feature_sample,
3479
3532
  self.bundle.get("features_info_provider"): provider,
@@ -3484,7 +3537,7 @@ class FeaturesEnricher(TransformerMixin):
3484
3537
  features_info_without_links.append(
3485
3538
  {
3486
3539
  self.bundle.get("features_info_name"): internal_feature_name,
3487
- self.bundle.get("features_info_shap"): round_shap_value(feature_meta.shap_value),
3540
+ self.bundle.get("features_info_shap"): self._round_shap_value(feature_meta.shap_value),
3488
3541
  self.bundle.get("features_info_hitrate"): feature_meta.hit_rate,
3489
3542
  self.bundle.get("features_info_value_preview"): feature_sample,
3490
3543
  self.bundle.get("features_info_provider"): internal_provider,
@@ -3496,7 +3549,7 @@ class FeaturesEnricher(TransformerMixin):
3496
3549
  {
3497
3550
  self.bundle.get("features_info_name"): internal_feature_name,
3498
3551
  "feature_link": feature_meta.doc_link,
3499
- self.bundle.get("features_info_shap"): round_shap_value(feature_meta.shap_value),
3552
+ self.bundle.get("features_info_shap"): self._round_shap_value(feature_meta.shap_value),
3500
3553
  self.bundle.get("features_info_hitrate"): feature_meta.hit_rate,
3501
3554
  self.bundle.get("features_info_value_preview"): feature_sample,
3502
3555
  self.bundle.get("features_info_provider"): internal_provider,
@@ -3776,8 +3829,11 @@ class FeaturesEnricher(TransformerMixin):
3776
3829
  print(Format.GREEN + Format.BOLD + msg + Format.END)
3777
3830
  self.logger.info(msg)
3778
3831
  if len(self.feature_names_) > 0:
3779
- display_html_dataframe(
3780
- self.features_info, self._features_info_without_links, self.bundle.get("relevant_features_header")
3832
+ self.features_info_display_handle = display_html_dataframe(
3833
+ self.features_info,
3834
+ self._features_info_without_links,
3835
+ self.bundle.get("relevant_features_header"),
3836
+ display_id="features_info",
3781
3837
  )
3782
3838
 
3783
3839
  display_html_dataframe(
@@ -3,13 +3,14 @@ from __future__ import annotations
3
3
  import inspect
4
4
  import logging
5
5
  import re
6
+ from collections import defaultdict
6
7
  from copy import deepcopy
7
8
  from typing import Any, Callable, Dict, List, Optional, Tuple, Union
8
9
 
9
10
  import catboost
10
11
  import numpy as np
11
12
  import pandas as pd
12
- from catboost import CatBoostClassifier, CatBoostRegressor
13
+ from catboost import CatBoost, CatBoostClassifier, CatBoostRegressor, Pool
13
14
  from numpy import log1p
14
15
  from pandas.api.types import is_numeric_dtype
15
16
  from sklearn.metrics import check_scoring, get_scorer, make_scorer, roc_auc_score
@@ -288,9 +289,12 @@ class EstimatorWrapper:
288
289
  x, y, _ = self._prepare_data(x, y)
289
290
  return x, y, {}
290
291
 
292
+ def calculate_shap(self, x: pd.DataFrame, y: pd.Series, estimator) -> Optional[Dict[str, float]]:
293
+ return None
294
+
291
295
  def cross_val_predict(
292
296
  self, x: pd.DataFrame, y: np.ndarray, baseline_score_column: Optional[Any] = None
293
- ) -> Optional[float]:
297
+ ) -> Tuple[Optional[float], Optional[Dict[str, float]]]:
294
298
  x, y, groups, fit_params = self._prepare_to_fit(x, y)
295
299
 
296
300
  if x.shape[1] == 0:
@@ -298,6 +302,7 @@ class EstimatorWrapper:
298
302
 
299
303
  scorer = check_scoring(self.estimator, scoring=self.scorer)
300
304
 
305
+ shap_values_all_folds = defaultdict(list)
301
306
  if baseline_score_column is not None and self.metric_name == "GINI":
302
307
  self.logger.info("Calculate baseline GINI on passed baseline_score_column and target")
303
308
  metric = roc_auc_score(y, x[baseline_score_column])
@@ -319,7 +324,29 @@ class EstimatorWrapper:
319
324
  self.check_fold_metrics(metrics_by_fold)
320
325
 
321
326
  metric = np.mean(metrics_by_fold) * self.multiplier
322
- return self.post_process_metric(metric)
327
+
328
+ splits = self.cv.split(x, y, groups)
329
+
330
+ for estimator, split in zip(self.cv_estimators, splits):
331
+ _, validation_idx = split
332
+ cv_x = x.iloc[validation_idx]
333
+ cv_y = y[validation_idx]
334
+ shaps = self.calculate_shap(cv_x, cv_y, estimator)
335
+ if shaps is not None:
336
+ for feature, shap_value in shaps.items():
337
+ # shap_values_all_folds[feature] = shap_values_all_folds.get(feature, []) + shap_value.tolist()
338
+ shap_values_all_folds[feature].extend(shap_value.tolist())
339
+
340
+ if shap_values_all_folds:
341
+ average_shap_values = {
342
+ feature: np.mean(np.array(shaps)) for feature, shaps in shap_values_all_folds.items() if len(shaps) > 0
343
+ }
344
+ if len(average_shap_values) == 0:
345
+ average_shap_values = None
346
+ else:
347
+ average_shap_values = None
348
+
349
+ return self.post_process_metric(metric), average_shap_values
323
350
 
324
351
  def check_fold_metrics(self, metrics_by_fold: List[float]):
325
352
  first_metric_sign = 1 if metrics_by_fold[0] >= 0 else -1
@@ -453,6 +480,7 @@ class CatBoostWrapper(EstimatorWrapper):
453
480
  )
454
481
  self.cat_features = None
455
482
  self.emb_features = None
483
+ self.grouped_embedding_features = None
456
484
  self.exclude_features = []
457
485
 
458
486
  def _prepare_to_fit(self, x: pd.DataFrame, y: pd.Series) -> Tuple[pd.DataFrame, np.ndarray, np.ndarray, dict]:
@@ -462,17 +490,16 @@ class CatBoostWrapper(EstimatorWrapper):
462
490
  if hasattr(CatBoostClassifier, "get_embedding_feature_indices"):
463
491
  emb_pattern = r"(.+)_emb\d+"
464
492
  self.emb_features = [c for c in x.columns if re.match(emb_pattern, c) and is_numeric_dtype(x[c])]
465
- embedding_features = []
466
493
  if len(self.emb_features) > 3: # There is no reason to reduce embeddings dimension with less than 4
467
494
  self.logger.info(
468
495
  "Embedding features count more than 3, so group them into one vector for CatBoost: "
469
496
  f"{self.emb_features}"
470
497
  )
471
- x, embedding_features = self.group_embeddings(x)
472
- params["embedding_features"] = embedding_features
498
+ x, self.grouped_embedding_features = self.group_embeddings(x)
499
+ params["embedding_features"] = self.grouped_embedding_features
473
500
  else:
474
501
  self.logger.info(f"Embedding features count less than 3, so use them separately: {self.emb_features}")
475
- self.emb_features = []
502
+ self.grouped_embedding_features = None
476
503
  else:
477
504
  self.logger.warning(f"Embedding features are not supported by Catboost version {catboost.__version__}")
478
505
 
@@ -488,7 +515,7 @@ class CatBoostWrapper(EstimatorWrapper):
488
515
  self.logger.warning(f"Text features are not supported by this Catboost version {catboost.__version__}")
489
516
 
490
517
  # Find rest categorical features
491
- self.cat_features = _get_cat_features(x, self.text_features, embedding_features)
518
+ self.cat_features = _get_cat_features(x, self.text_features, self.grouped_embedding_features)
492
519
  # x = fill_na_cat_features(x, self.cat_features)
493
520
  unique_cat_features = []
494
521
  for name in self.cat_features:
@@ -548,7 +575,7 @@ class CatBoostWrapper(EstimatorWrapper):
548
575
 
549
576
  def cross_val_predict(
550
577
  self, x: pd.DataFrame, y: np.ndarray, baseline_score_column: Optional[Any] = None
551
- ) -> Optional[float]:
578
+ ) -> Tuple[Optional[float], Optional[Dict[str, float]]]:
552
579
  try:
553
580
  return super().cross_val_predict(x, y, baseline_score_column)
554
581
  except Exception as e:
@@ -573,6 +600,36 @@ class CatBoostWrapper(EstimatorWrapper):
573
600
  else:
574
601
  raise e
575
602
 
603
def calculate_shap(self, x: pd.DataFrame, y: pd.Series, estimator: CatBoost) -> Optional[Dict[str, np.ndarray]]:
    """Return absolute SHAP values per feature for an already fitted CatBoost estimator.

    A CatBoost ``Pool`` is built from ``x`` and ``y`` using the categorical, text and
    grouped-embedding feature lists stored on this wrapper, and the estimator is asked
    for ``"ShapValues"`` feature importance on that pool.

    Args:
        x: Feature frame to explain.
        y: Labels passed to the ``Pool`` together with ``x``.
        estimator: Fitted estimator exposing ``get_feature_importance`` and ``feature_names_``.

    Returns:
        Mapping from each name in ``estimator.feature_names_`` to a 1-D array of absolute
        SHAP values for that feature. For multiclass targets the values of the two leading
        axes are flattened into that single array. ``None`` is returned when anything in
        the computation raises; the error is logged instead of propagated.
    """
    try:
        # Create a Pool for the fold data so that categorical, text and embedding
        # columns are declared explicitly to CatBoost.
        fold_pool = Pool(
            x,
            y,
            cat_features=self.cat_features,
            text_features=self.text_features,
            embedding_features=self.grouped_embedding_features,
        )

        # Get SHAP values of the current estimator
        shap_values_fold = estimator.get_feature_importance(data=fold_pool, type="ShapValues")

        # Drop the last entry of the last axis (base value) and flatten per feature.
        # Multiclass output is indexed with three axes, everything else with two.
        if self.target_type == ModelTaskType.MULTICLASS:
            all_shaps = shap_values_fold[:, :, :-1]
            all_shaps = [all_shaps[:, :, k].flatten() for k in range(all_shaps.shape[2])]
        else:
            all_shaps = shap_values_fold[:, :-1]
            all_shaps = [all_shaps[:, k].flatten() for k in range(all_shaps.shape[1])]

        # np.abs turns the list of equally sized arrays into one 2-D array;
        # zipping it below yields one row (1-D array) per feature name.
        all_shaps = np.abs(all_shaps)

        return dict(zip(estimator.feature_names_, all_shaps))

    except Exception:
        # Best effort: SHAP recalculation must not break the surrounding metric calculation.
        self.logger.exception("Failed to recalculate new SHAP values")
        return None
632
+
576
633
 
577
634
  class LightGBMWrapper(EstimatorWrapper):
578
635
  def __init__(
@@ -1,6 +1,6 @@
1
1
  import hashlib
2
2
  from logging import Logger, getLogger
3
- from typing import Dict, List
3
+ from typing import Dict, List, Tuple
4
4
 
5
5
  import numpy as np
6
6
  import pandas as pd
@@ -35,22 +35,25 @@ class Normalizer:
35
35
 
36
36
  def __init__(
37
37
  self,
38
- search_keys: Dict[str, SearchKey],
39
- generated_features: List[str],
40
38
  bundle: ResourceBundle = None,
41
39
  logger: Logger = None,
42
40
  warnings_counter: WarningCounter = None,
43
41
  silent_mode=False,
44
42
  ):
45
- self.search_keys = search_keys
46
- self.generated_features = generated_features
47
43
  self.bundle = bundle or get_custom_bundle()
48
44
  self.logger = logger or getLogger()
49
45
  self.warnings_counter = warnings_counter or WarningCounter()
50
46
  self.silent_mode = silent_mode
51
47
  self.columns_renaming = {}
48
+ self.search_keys = {}
49
+ self.generated_features = []
50
+
51
+ def normalize(
52
+ self, df: pd.DataFrame, search_keys: Dict[str, SearchKey], generated_features: List[str]
53
+ ) -> Tuple[pd.DataFrame, Dict[str, SearchKey], List[str]]:
54
+ self.search_keys = search_keys.copy()
55
+ self.generated_features = generated_features.copy()
52
56
 
53
- def normalize(self, df: pd.DataFrame) -> pd.DataFrame:
54
57
  df = df.copy()
55
58
  df = self._rename_columns(df)
56
59
 
@@ -68,21 +71,25 @@ class Normalizer:
68
71
 
69
72
  df = self.__convert_features_types(df)
70
73
 
71
- return df
74
+ return df, self.search_keys, self.generated_features
72
75
 
73
76
  def _rename_columns(self, df: pd.DataFrame):
74
77
  # logger.info("Replace restricted symbols in column names")
75
78
  new_columns = []
76
79
  dup_counter = 0
77
80
  for column in df.columns:
78
- if column in [
79
- TARGET,
80
- EVAL_SET_INDEX,
81
- SYSTEM_RECORD_ID,
82
- ENTITY_SYSTEM_RECORD_ID,
83
- SEARCH_KEY_UNNEST,
84
- DateTimeSearchKeyConverter.DATETIME_COL,
85
- ] + self.generated_features:
81
+ if (
82
+ column
83
+ in [
84
+ TARGET,
85
+ EVAL_SET_INDEX,
86
+ SYSTEM_RECORD_ID,
87
+ ENTITY_SYSTEM_RECORD_ID,
88
+ SEARCH_KEY_UNNEST,
89
+ DateTimeSearchKeyConverter.DATETIME_COL,
90
+ ]
91
+ + self.generated_features
92
+ ):
86
93
  self.columns_renaming[column] = column
87
94
  new_columns.append(column)
88
95
  continue
@@ -203,11 +203,18 @@ email_detected=Emails detected in column `{}`. It will be used as a search key\n
203
203
  email_detected_not_registered=Emails detected in column `{}`. It can be used only with api_key from profile.upgini.com\nSee docs to turn off the automatic detection: https://github.com/upgini/upgini/blob/main/README.md#turn-off-autodetection-for-search-key-columns
204
204
  phone_detected=Phone numbers detected in column `{}`. It can be used only with api_key from profile.upgini.com\nSee docs to turn off the automatic detection: https://github.com/upgini/upgini/blob/main/README.md#turn-off-autodetection-for-search-key-columns
205
205
  phone_detected_not_registered=\nWARNING: Phone numbers detected in column `{}`. It can be used only with api_key from profile.upgini.com\nSee docs to turn off the automatic detection: https://github.com/upgini/upgini/blob/main/README.md#turn-off-autodetection-for-search-key-columns
206
- target_type_detected=\nDetected task type: {}\n
206
+ target_type_detected=\nDetected task type: {}. Reason: {}\nYou can set task type manually with argument `model_task_type` of FeaturesEnricher constructor if task type detected incorrectly\n
207
+ binary_target_reason=only two unique label-values observed
208
+ non_numeric_multiclass_reason=non-numeric label values observed
209
+ few_unique_label_multiclass_reason=few unique label-values observed and can be considered as categorical
210
+ date_search_key_regression_reason=date search key is present, treating as regression
211
+ many_unique_label_regression_reason=many unique label-values or non-integer floating point values observed
212
+ limited_int_multiclass_reason=integer-like values with limited unique values observed
207
213
  # all_ok_community_invite=Chat with us in Slack community:
208
214
  all_ok_community_invite=❓ Support request
209
215
  too_small_for_metrics=Your train dataset or one of eval datasets contains less than 500 rows. For such dataset Upgini will not calculate accuracy metrics. Please increase the number of rows in the training dataset to calculate accuracy metrics
210
216
  imbalance_multiclass=Class {0} is on 25% quantile of classes distribution ({1} records in train dataset). \nDownsample classes with records more than {1}.
217
+ imbalanced_target=\nWARNING: Target is imbalanced and will be undersampled. Frequency of the rarest class `{}` is {}
211
218
  loss_selection_info=Using loss `{}` for feature selection
212
219
  loss_calc_metrics_info=Using loss `{}` for metrics calculation with default estimator
213
220
 
@@ -9,6 +9,7 @@ from typing import Callable, List, Optional
9
9
 
10
10
  import pandas as pd
11
11
  from xhtml2pdf import pisa
12
+
12
13
  from upgini.__about__ import __version__
13
14
 
14
15
 
@@ -72,7 +73,9 @@ def make_table(df: pd.DataFrame, wrap_long_string=None) -> str:
72
73
  )
73
74
 
74
75
 
75
- def display_html_dataframe(df: pd.DataFrame, internal_df: pd.DataFrame, header: str):
76
+ def display_html_dataframe(
77
+ df: pd.DataFrame, internal_df: pd.DataFrame, header: str, display_id: Optional[str] = None, display_handle=None
78
+ ):
76
79
  if not ipython_available():
77
80
  print(header)
78
81
  print(internal_df)
@@ -133,7 +136,10 @@ def display_html_dataframe(df: pd.DataFrame, internal_df: pd.DataFrame, header:
133
136
  {table_html}
134
137
  </div>
135
138
  """
136
- display(HTML(result_html))
139
+ if display_handle:
140
+ return display_handle.update(HTML(result_html))
141
+ else:
142
+ return display(HTML(result_html), display_id=display_id)
137
143
 
138
144
 
139
145
  def make_html_report(
@@ -24,49 +24,83 @@ def define_task(
24
24
  ) -> ModelTaskType:
25
25
  if logger is None:
26
26
  logger = logging.getLogger()
27
+
28
+ # Replace inf and -inf with NaN to handle extreme values correctly
29
+ y = y.replace([np.inf, -np.inf], np.nan, inplace=False)
30
+
31
+ # Drop NaN values from the target
27
32
  target = y.dropna()
33
+
34
+ # Check if target is numeric and finite
28
35
  if is_numeric_dtype(target):
29
36
  target = target.loc[np.isfinite(target)]
30
37
  else:
38
+ # If not numeric, drop empty strings as well
31
39
  target = target.loc[target != ""]
40
+
41
+ # Raise error if there are no valid values left in the target
32
42
  if len(target) == 0:
33
43
  raise ValidationError(bundle.get("empty_target"))
44
+
45
+ # Count unique values in the target
34
46
  target_items = target.nunique()
47
+
48
+ # Raise error if all target values are the same
35
49
  if target_items == 1:
36
50
  raise ValidationError(bundle.get("dataset_constant_target"))
51
+
52
+ reason = "" # Will store the reason for selecting the task type
53
+
54
+ # Binary classification case: exactly two unique values
37
55
  if target_items == 2:
38
56
  task = ModelTaskType.BINARY
57
+ reason = bundle.get("binary_target_reason")
39
58
  else:
59
+ # Attempt to convert target to numeric
40
60
  try:
41
61
  target = pd.to_numeric(target)
42
62
  is_numeric = True
43
63
  except Exception:
44
64
  is_numeric = False
45
65
 
46
- # If any value is non numeric - multiclass
66
+ # If target cannot be converted to numeric, assume multiclass classification
47
67
  if not is_numeric:
48
68
  task = ModelTaskType.MULTICLASS
69
+ reason = bundle.get("non_numeric_multiclass_reason")
49
70
  else:
71
+ # Multiclass classification: few unique values and integer encoding
50
72
  if target.nunique() <= 50 and is_int_encoding(target.unique()):
51
73
  task = ModelTaskType.MULTICLASS
74
+ reason = bundle.get("few_unique_label_multiclass_reason")
75
+ # Regression case: if there is date, assume regression
52
76
  elif has_date:
53
77
  task = ModelTaskType.REGRESSION
78
+ reason = bundle.get("date_search_key_regression_reason")
54
79
  else:
80
+ # Remove zero values and recalculate unique ratio
55
81
  non_zero_target = target[target != 0]
56
82
  target_items = non_zero_target.nunique()
57
83
  target_ratio = target_items / len(non_zero_target)
84
+
85
+ # Regression if values are non-integer floats, there are many unique values, or target_ratio is high; otherwise multiclass
58
86
  if (
59
- (target.dtype.kind == "f" and np.any(target != target.astype(int))) # any non integer
87
+ (target.dtype.kind == "f" and np.any(target != target.astype(int))) # Non-integer float values
60
88
  or target_items > 50
61
- or target_ratio > 0.2
89
+ or target_ratio > 0.2 # If non-zero values have high ratio of uniqueness
62
90
  ):
63
91
  task = ModelTaskType.REGRESSION
92
+ reason = bundle.get("many_unique_label_regression_reason")
64
93
  else:
65
94
  task = ModelTaskType.MULTICLASS
95
+ reason = bundle.get("limited_int_multiclass_reason")
66
96
 
67
- logger.info(f"Detected task type: {task}")
97
+ # Log or print the reason for the selected task type
98
+ logger.info(f"Detected task type: {task} (Reason: {reason})")
99
+
100
+ # Print task type and reason if silent mode is off
68
101
  if not silent:
69
- print(bundle.get("target_type_detected").format(task))
102
+ print(bundle.get("target_type_detected").format(task, reason))
103
+
70
104
  return task
71
105
 
72
106
 
@@ -81,8 +115,8 @@ def balance_undersample(
81
115
  target_column: str,
82
116
  task_type: ModelTaskType,
83
117
  random_state: int,
84
- imbalance_threshold: int = 0.2,
85
- min_sample_threshold: int = 5000,
118
+ binary_min_sample_threshold: int = 5000,
119
+ multiclass_min_sample_threshold: int = 25000,
86
120
  binary_bootstrap_loops: int = 5,
87
121
  multiclass_bootstrap_loops: int = 2,
88
122
  logger: Optional[logging.Logger] = None,
@@ -96,52 +130,60 @@ def balance_undersample(
96
130
  if SYSTEM_RECORD_ID not in df.columns:
97
131
  raise Exception("System record id must be presented for undersampling")
98
132
 
99
- count = len(df)
133
+ # count = len(df)
100
134
  target = df[target_column].copy()
101
- target_classes_count = target.nunique()
135
+ # target_classes_count = target.nunique()
102
136
 
103
137
  vc = target.value_counts()
104
138
  max_class_value = vc.index[0]
105
139
  min_class_value = vc.index[len(vc) - 1]
106
140
  max_class_count = vc[max_class_value]
107
141
  min_class_count = vc[min_class_value]
142
+ num_classes = len(vc)
108
143
 
109
- min_class_percent = imbalance_threshold / target_classes_count
110
- min_class_threshold = int(min_class_percent * count)
144
+ # min_class_percent = imbalance_threshold / target_classes_count
145
+ # min_class_threshold = int(min_class_percent * count)
111
146
 
112
147
  resampled_data = df
113
148
  df = df.copy().sort_values(by=SYSTEM_RECORD_ID)
114
149
  if task_type == ModelTaskType.MULTICLASS:
115
- # Sort classes by rows count and find 25% quantile class
116
- classes = vc.index
117
- quantile25_idx = int(0.75 * len(classes)) - 1
118
- quantile25_class = classes[quantile25_idx]
119
- quantile25_class_cnt = vc[quantile25_class]
120
-
121
- if max_class_count > (quantile25_class_cnt * multiclass_bootstrap_loops):
122
- msg = bundle.get("imbalance_multiclass").format(quantile25_class, quantile25_class_cnt)
150
+ if len(df) > multiclass_min_sample_threshold and max_class_count > (
151
+ min_class_count * multiclass_bootstrap_loops
152
+ ):
153
+
154
+ # msg = bundle.get("imbalance_multiclass").format(min_class_value, min_class_count)
155
+ msg = bundle.get("imbalanced_target").format(min_class_value, min_class_count)
123
156
  logger.warning(msg)
124
157
  print(msg)
125
158
  if warning_counter:
126
159
  warning_counter.increment()
127
160
 
128
- # 25% and lower classes will stay as is. Higher classes will be downsampled
129
161
  sample_strategy = dict()
130
- for class_idx in range(quantile25_idx):
131
- # compare class count with count_of_quantile25_class * 2
132
- class_value = classes[class_idx]
162
+ for class_value in vc.index:
163
+ if class_value == min_class_value:
164
+ continue
133
165
  class_count = vc[class_value]
134
- sample_strategy[class_value] = min(class_count, quantile25_class_cnt * multiclass_bootstrap_loops)
166
+ sample_size = min(
167
+ class_count,
168
+ multiclass_bootstrap_loops
169
+ * (
170
+ min_class_count
171
+ + max((multiclass_min_sample_threshold - num_classes * min_class_count) / (num_classes - 1), 0)
172
+ ),
173
+ )
174
+ sample_strategy[class_value] = int(sample_size)
175
+ logger.info(f"Rebalance sample strategy: {sample_strategy}. Min class count: {min_class_count}")
135
176
  sampler = RandomUnderSampler(sampling_strategy=sample_strategy, random_state=random_state)
136
177
  X = df[SYSTEM_RECORD_ID]
137
178
  X = X.to_frame(SYSTEM_RECORD_ID)
138
179
  new_x, _ = sampler.fit_resample(X, target) # type: ignore
139
180
 
140
181
  resampled_data = df[df[SYSTEM_RECORD_ID].isin(new_x[SYSTEM_RECORD_ID])]
141
- elif len(df) > min_sample_threshold and min_class_count < min_sample_threshold / 2:
142
- msg = bundle.get("dataset_rarest_class_less_threshold").format(
143
- min_class_value, min_class_count, min_class_threshold, min_class_percent * 100
144
- )
182
+ elif len(df) > binary_min_sample_threshold:
183
+ # msg = bundle.get("dataset_rarest_class_less_threshold").format(
184
+ # min_class_value, min_class_count, min_class_threshold, min_class_percent * 100
185
+ # )
186
+ msg = bundle.get("imbalanced_target").format(min_class_value, min_class_count)
145
187
  logger.warning(msg)
146
188
  print(msg)
147
189
  if warning_counter:
@@ -150,30 +192,38 @@ def balance_undersample(
150
192
  # undersample the majority class; the sample size is derived from binary_min_sample_threshold below
151
193
  minority_class = df[df[target_column] == min_class_value]
152
194
  majority_class = df[df[target_column] != min_class_value]
153
- sample_size = min(len(majority_class), min_sample_threshold - min_class_count)
195
+ # sample_size = min(len(majority_class), min_sample_threshold - min_class_count)
196
+ sample_size = min(
197
+ max_class_count,
198
+ binary_bootstrap_loops * (min_class_count + max(binary_min_sample_threshold - 2 * min_class_count, 0)),
199
+ )
200
+ logger.info(
201
+ f"Min class count: {min_class_count}. Max class count: {max_class_count}."
202
+ f" Rebalance sample size: {sample_size}"
203
+ )
154
204
  sampled_majority_class = majority_class.sample(n=sample_size, random_state=random_state)
155
205
  resampled_data = df[
156
206
  (df[SYSTEM_RECORD_ID].isin(minority_class[SYSTEM_RECORD_ID]))
157
207
  | (df[SYSTEM_RECORD_ID].isin(sampled_majority_class[SYSTEM_RECORD_ID]))
158
208
  ]
159
209
 
160
- elif max_class_count > min_class_count * binary_bootstrap_loops:
161
- msg = bundle.get("dataset_rarest_class_less_threshold").format(
162
- min_class_value, min_class_count, min_class_threshold, min_class_percent * 100
163
- )
164
- logger.warning(msg)
165
- print(msg)
166
- if warning_counter:
167
- warning_counter.increment()
168
-
169
- sampler = RandomUnderSampler(
170
- sampling_strategy={max_class_value: binary_bootstrap_loops * min_class_count}, random_state=random_state
171
- )
172
- X = df[SYSTEM_RECORD_ID]
173
- X = X.to_frame(SYSTEM_RECORD_ID)
174
- new_x, _ = sampler.fit_resample(X, target) # type: ignore
175
-
176
- resampled_data = df[df[SYSTEM_RECORD_ID].isin(new_x[SYSTEM_RECORD_ID])]
210
+ # elif max_class_count > min_class_count * binary_bootstrap_loops:
211
+ # msg = bundle.get("dataset_rarest_class_less_threshold").format(
212
+ # min_class_value, min_class_count, min_class_threshold, min_class_percent * 100
213
+ # )
214
+ # logger.warning(msg)
215
+ # print(msg)
216
+ # if warning_counter:
217
+ # warning_counter.increment()
218
+
219
+ # sampler = RandomUnderSampler(
220
+ # sampling_strategy={max_class_value: binary_bootstrap_loops * min_class_count}, random_state=random_state
221
+ # )
222
+ # X = df[SYSTEM_RECORD_ID]
223
+ # X = X.to_frame(SYSTEM_RECORD_ID)
224
+ # new_x, _ = sampler.fit_resample(X, target) # type: ignore
225
+
226
+ # resampled_data = df[df[SYSTEM_RECORD_ID].isin(new_x[SYSTEM_RECORD_ID])]
177
227
 
178
228
  logger.info(f"Shape after rebalance resampling: {resampled_data}")
179
229
  return resampled_data
@@ -1 +0,0 @@
1
- __version__ = "1.2.14a3616.dev3"
File without changes
File without changes