upgini 1.2.14__tar.gz → 1.2.14a2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65)
  1. {upgini-1.2.14 → upgini-1.2.14a2}/PKG-INFO +1 -1
  2. upgini-1.2.14a2/src/upgini/__about__.py +1 -0
  3. {upgini-1.2.14 → upgini-1.2.14a2}/src/upgini/dataset.py +6 -3
  4. {upgini-1.2.14 → upgini-1.2.14a2}/src/upgini/features_enricher.py +25 -22
  5. {upgini-1.2.14 → upgini-1.2.14a2}/src/upgini/normalizer/normalize_utils.py +15 -22
  6. {upgini-1.2.14 → upgini-1.2.14a2}/src/upgini/resource_bundle/strings.properties +1 -0
  7. {upgini-1.2.14 → upgini-1.2.14a2}/src/upgini/utils/target_utils.py +52 -41
  8. upgini-1.2.14/src/upgini/__about__.py +0 -1
  9. {upgini-1.2.14 → upgini-1.2.14a2}/.gitignore +0 -0
  10. {upgini-1.2.14 → upgini-1.2.14a2}/LICENSE +0 -0
  11. {upgini-1.2.14 → upgini-1.2.14a2}/README.md +0 -0
  12. {upgini-1.2.14 → upgini-1.2.14a2}/pyproject.toml +0 -0
  13. {upgini-1.2.14 → upgini-1.2.14a2}/src/upgini/__init__.py +0 -0
  14. {upgini-1.2.14 → upgini-1.2.14a2}/src/upgini/ads.py +0 -0
  15. {upgini-1.2.14 → upgini-1.2.14a2}/src/upgini/ads_management/__init__.py +0 -0
  16. {upgini-1.2.14 → upgini-1.2.14a2}/src/upgini/ads_management/ads_manager.py +0 -0
  17. {upgini-1.2.14 → upgini-1.2.14a2}/src/upgini/autofe/__init__.py +0 -0
  18. {upgini-1.2.14 → upgini-1.2.14a2}/src/upgini/autofe/all_operands.py +0 -0
  19. {upgini-1.2.14 → upgini-1.2.14a2}/src/upgini/autofe/binary.py +0 -0
  20. {upgini-1.2.14 → upgini-1.2.14a2}/src/upgini/autofe/date.py +0 -0
  21. {upgini-1.2.14 → upgini-1.2.14a2}/src/upgini/autofe/feature.py +0 -0
  22. {upgini-1.2.14 → upgini-1.2.14a2}/src/upgini/autofe/groupby.py +0 -0
  23. {upgini-1.2.14 → upgini-1.2.14a2}/src/upgini/autofe/operand.py +0 -0
  24. {upgini-1.2.14 → upgini-1.2.14a2}/src/upgini/autofe/unary.py +0 -0
  25. {upgini-1.2.14 → upgini-1.2.14a2}/src/upgini/autofe/vector.py +0 -0
  26. {upgini-1.2.14 → upgini-1.2.14a2}/src/upgini/data_source/__init__.py +0 -0
  27. {upgini-1.2.14 → upgini-1.2.14a2}/src/upgini/data_source/data_source_publisher.py +0 -0
  28. {upgini-1.2.14 → upgini-1.2.14a2}/src/upgini/errors.py +0 -0
  29. {upgini-1.2.14 → upgini-1.2.14a2}/src/upgini/http.py +0 -0
  30. {upgini-1.2.14 → upgini-1.2.14a2}/src/upgini/lazy_import.py +0 -0
  31. {upgini-1.2.14 → upgini-1.2.14a2}/src/upgini/mdc/__init__.py +0 -0
  32. {upgini-1.2.14 → upgini-1.2.14a2}/src/upgini/mdc/context.py +0 -0
  33. {upgini-1.2.14 → upgini-1.2.14a2}/src/upgini/metadata.py +0 -0
  34. {upgini-1.2.14 → upgini-1.2.14a2}/src/upgini/metrics.py +0 -0
  35. {upgini-1.2.14 → upgini-1.2.14a2}/src/upgini/normalizer/__init__.py +0 -0
  36. {upgini-1.2.14 → upgini-1.2.14a2}/src/upgini/resource_bundle/__init__.py +0 -0
  37. {upgini-1.2.14 → upgini-1.2.14a2}/src/upgini/resource_bundle/exceptions.py +0 -0
  38. {upgini-1.2.14 → upgini-1.2.14a2}/src/upgini/resource_bundle/strings_widget.properties +0 -0
  39. {upgini-1.2.14 → upgini-1.2.14a2}/src/upgini/sampler/__init__.py +0 -0
  40. {upgini-1.2.14 → upgini-1.2.14a2}/src/upgini/sampler/base.py +0 -0
  41. {upgini-1.2.14 → upgini-1.2.14a2}/src/upgini/sampler/random_under_sampler.py +0 -0
  42. {upgini-1.2.14 → upgini-1.2.14a2}/src/upgini/sampler/utils.py +0 -0
  43. {upgini-1.2.14 → upgini-1.2.14a2}/src/upgini/search_task.py +0 -0
  44. {upgini-1.2.14 → upgini-1.2.14a2}/src/upgini/spinner.py +0 -0
  45. {upgini-1.2.14 → upgini-1.2.14a2}/src/upgini/utils/__init__.py +0 -0
  46. {upgini-1.2.14 → upgini-1.2.14a2}/src/upgini/utils/base_search_key_detector.py +0 -0
  47. {upgini-1.2.14 → upgini-1.2.14a2}/src/upgini/utils/blocked_time_series.py +0 -0
  48. {upgini-1.2.14 → upgini-1.2.14a2}/src/upgini/utils/country_utils.py +0 -0
  49. {upgini-1.2.14 → upgini-1.2.14a2}/src/upgini/utils/custom_loss_utils.py +0 -0
  50. {upgini-1.2.14 → upgini-1.2.14a2}/src/upgini/utils/cv_utils.py +0 -0
  51. {upgini-1.2.14 → upgini-1.2.14a2}/src/upgini/utils/datetime_utils.py +0 -0
  52. {upgini-1.2.14 → upgini-1.2.14a2}/src/upgini/utils/deduplicate_utils.py +0 -0
  53. {upgini-1.2.14 → upgini-1.2.14a2}/src/upgini/utils/display_utils.py +0 -0
  54. {upgini-1.2.14 → upgini-1.2.14a2}/src/upgini/utils/email_utils.py +0 -0
  55. {upgini-1.2.14 → upgini-1.2.14a2}/src/upgini/utils/fallback_progress_bar.py +0 -0
  56. {upgini-1.2.14 → upgini-1.2.14a2}/src/upgini/utils/features_validator.py +0 -0
  57. {upgini-1.2.14 → upgini-1.2.14a2}/src/upgini/utils/format.py +0 -0
  58. {upgini-1.2.14 → upgini-1.2.14a2}/src/upgini/utils/ip_utils.py +0 -0
  59. {upgini-1.2.14 → upgini-1.2.14a2}/src/upgini/utils/phone_utils.py +0 -0
  60. {upgini-1.2.14 → upgini-1.2.14a2}/src/upgini/utils/postal_code_utils.py +0 -0
  61. {upgini-1.2.14 → upgini-1.2.14a2}/src/upgini/utils/progress_bar.py +0 -0
  62. {upgini-1.2.14 → upgini-1.2.14a2}/src/upgini/utils/sklearn_ext.py +0 -0
  63. {upgini-1.2.14 → upgini-1.2.14a2}/src/upgini/utils/track_info.py +0 -0
  64. {upgini-1.2.14 → upgini-1.2.14a2}/src/upgini/utils/warning_counter.py +0 -0
  65. {upgini-1.2.14 → upgini-1.2.14a2}/src/upgini/version_validator.py +0 -0

{upgini-1.2.14 → upgini-1.2.14a2}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: upgini
-Version: 1.2.14
+Version: 1.2.14a2
 Summary: Intelligent data search & enrichment for Machine Learning
 Project-URL: Bug Reports, https://github.com/upgini/upgini/issues
 Project-URL: Homepage, https://upgini.com/

upgini-1.2.14a2/src/upgini/__about__.py (added)
@@ -0,0 +1 @@
+__version__ = "1.2.14a2"

{upgini-1.2.14 → upgini-1.2.14a2}/src/upgini/dataset.py
@@ -53,7 +53,8 @@ class Dataset: # (pd.DataFrame):
     FIT_SAMPLE_THRESHOLD = 200_000
     FIT_SAMPLE_WITH_EVAL_SET_ROWS = 200_000
     FIT_SAMPLE_WITH_EVAL_SET_THRESHOLD = 200_000
-    MIN_SAMPLE_THRESHOLD = 5_000
+    BINARY_MIN_SAMPLE_THRESHOLD = 5_000
+    MULTICLASS_MIN_SAMPLE_THRESHOLD = 25_000
     IMBALANCE_THESHOLD = 0.6
     BINARY_BOOTSTRAP_LOOPS = 5
     MULTICLASS_BOOTSTRAP_LOOPS = 2
@@ -225,7 +226,7 @@ class Dataset: # (pd.DataFrame):
        train_segment = self.data

        if self.task_type == ModelTaskType.MULTICLASS or (
-            self.task_type == ModelTaskType.BINARY and len(train_segment) > self.MIN_SAMPLE_THRESHOLD
+            self.task_type == ModelTaskType.BINARY and len(train_segment) > self.BINARY_MIN_SAMPLE_THRESHOLD
        ):
            count = len(train_segment)
            target_column = self.etalon_def_checked.get(FileColumnMeaningType.TARGET.value, TARGET)
@@ -253,6 +254,7 @@ class Dataset: # (pd.DataFrame):
            min_class_percent = self.IMBALANCE_THESHOLD / target_classes_count
            min_class_threshold = min_class_percent * count

+            # If min class count less than 30% for binary or (60 / classes_count)% for multiclass
            if min_class_count < min_class_threshold:
                self.imbalanced = True
                self.data = balance_undersample(
@@ -260,7 +262,8 @@ class Dataset: # (pd.DataFrame):
                    target_column=target_column,
                    task_type=self.task_type,
                    random_state=self.random_state,
-                    imbalance_threshold=self.IMBALANCE_THESHOLD,
+                    binary_min_sample_threshold=self.BINARY_MIN_SAMPLE_THRESHOLD,
+                    multiclass_min_sample_threshold=self.MULTICLASS_MIN_SAMPLE_THRESHOLD,
                    binary_bootstrap_loops=self.BINARY_BOOTSTRAP_LOOPS,
                    multiclass_bootstrap_loops=self.MULTICLASS_BOOTSTRAP_LOOPS,
                    logger=self.logger,
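
Note on the trigger above: with IMBALANCE_THESHOLD = 0.6 unchanged, a dataset is flagged as imbalanced when the rarest class holds less than (60 / classes_count)% of the rows, i.e. 30% for a binary target, exactly as the comment added in the diff says. A small self-contained illustration of that arithmetic (row counts invented):

    # Reproduces the imbalance check from the dataset.py hunks above.
    IMBALANCE_THESHOLD = 0.6  # constant name spelled as in the source

    def is_imbalanced(count: int, target_classes_count: int, min_class_count: int) -> bool:
        min_class_percent = IMBALANCE_THESHOLD / target_classes_count
        min_class_threshold = min_class_percent * count  # 30% of rows for binary
        return min_class_count < min_class_threshold

    print(is_imbalanced(10_000, 2, 2_500))  # True  -> balance_undersample() is called
    print(is_imbalanced(10_000, 2, 3_500))  # False -> data left as is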

{upgini-1.2.14 → upgini-1.2.14a2}/src/upgini/features_enricher.py
@@ -1577,8 +1577,8 @@ class FeaturesEnricher(TransformerMixin):
            df = generator.generate(df)
            generated_features.extend(generator.generated_features)

-        normalizer = Normalizer(self.bundle, self.logger, self.warning_counter)
-        df, search_keys, generated_features = normalizer.normalize(df, search_keys, generated_features)
+        normalizer = Normalizer(search_keys, generated_features, self.bundle, self.logger, self.warning_counter)
+        df = normalizer.normalize(df)
        columns_renaming = normalizer.columns_renaming

        df = clean_full_duplicates(df, logger=self.logger, silent=True, bundle=self.bundle)
@@ -2017,8 +2017,10 @@ class FeaturesEnricher(TransformerMixin):
            df = generator.generate(df)
            generated_features.extend(generator.generated_features)

-        normalizer = Normalizer(self.bundle, self.logger, self.warning_counter, silent_mode)
-        df, search_keys, generated_features = normalizer.normalize(df, search_keys, generated_features)
+        normalizer = Normalizer(
+            search_keys, generated_features, self.bundle, self.logger, self.warning_counter, silent_mode
+        )
+        df = normalizer.normalize(df)
        columns_renaming = normalizer.columns_renaming

        # Don't pass all features in backend on transform
@@ -2447,13 +2449,14 @@ class FeaturesEnricher(TransformerMixin):
        if is_numeric_dtype(df[self.TARGET_NAME]) and has_date:
            self._validate_PSI(df.sort_values(by=maybe_date_column))

-        normalizer = Normalizer(self.bundle, self.logger, self.warning_counter)
-        df, self.fit_search_keys, self.fit_generated_features = normalizer.normalize(
-            df, self.fit_search_keys, self.fit_generated_features
-        )
-        self.fit_columns_renaming = normalizer.columns_renaming
+        self.__adjust_cv(df, maybe_date_column, self.model_task_type)

-        self.__adjust_cv(df)
+        normalizer = Normalizer(
+            self.fit_search_keys, self.fit_generated_features, self.bundle, self.logger, self.warning_counter
+        )
+        df = normalizer.normalize(df)
+        columns_renaming = normalizer.columns_renaming
+        self.fit_columns_renaming = columns_renaming

        df = remove_fintech_duplicates(
            df, self.fit_search_keys, date_format=self.date_format, logger=self.logger, bundle=self.bundle
@@ -2467,7 +2470,7 @@ class FeaturesEnricher(TransformerMixin):
        self.df_with_original_index = df.copy()
        # TODO check maybe need to drop _time column from df_with_original_index

-        df, unnest_search_keys = self._explode_multiple_search_keys(df, self.fit_search_keys, self.fit_columns_renaming)
+        df, unnest_search_keys = self._explode_multiple_search_keys(df, self.fit_search_keys, columns_renaming)

        # Convert EMAIL to HEM after unnesting to do it only with one column
        email_column = self._get_email_column(self.fit_search_keys)
@@ -2477,7 +2480,7 @@
                email_column,
                hem_column,
                self.fit_search_keys,
-                self.fit_columns_renaming,
+                columns_renaming,
                list(unnest_search_keys.keys()),
                self.logger,
            )
@@ -2488,7 +2491,7 @@
            converter = IpSearchKeyConverter(
                ip_column,
                self.fit_search_keys,
-                self.fit_columns_renaming,
+                columns_renaming,
                list(unnest_search_keys.keys()),
                self.bundle,
                self.logger,
@@ -2519,7 +2522,7 @@
        features_columns = [c for c in df.columns if c not in non_feature_columns]

        features_to_drop = FeaturesValidator(self.logger).validate(
-            df, features_columns, self.generate_features, self.warning_counter, self.fit_columns_renaming
+            df, features_columns, self.generate_features, self.warning_counter, columns_renaming
        )
        self.fit_dropped_features.update(features_to_drop)
        df = df.drop(columns=features_to_drop)
@@ -2560,7 +2563,7 @@
            rest_client=self.rest_client,
            logger=self.logger,
        )
-        dataset.columns_renaming = self.fit_columns_renaming
+        dataset.columns_renaming = columns_renaming

        self.passed_features = [
            column for column, meaning_type in meaning_types.items() if meaning_type == FileColumnMeaningType.FEATURE
@@ -2707,24 +2710,24 @@
        if not self.warning_counter.has_warnings():
            self.__display_support_link(self.bundle.get("all_ok_community_invite"))

-    def __adjust_cv(self, df: pd.DataFrame):
-        date_column = SearchKey.find_key(self.fit_search_keys, [SearchKey.DATE, SearchKey.DATETIME])
+    def __adjust_cv(self, df: pd.DataFrame, date_column: pd.Series, model_task_type: ModelTaskType):
        # Check Multivariate time series
        if (
            self.cv is None
            and date_column
-            and self.model_task_type == ModelTaskType.REGRESSION
+            and model_task_type == ModelTaskType.REGRESSION
            and len({SearchKey.PHONE, SearchKey.EMAIL, SearchKey.HEM}.intersection(self.fit_search_keys.keys())) == 0
            and is_blocked_time_series(df, date_column, list(self.fit_search_keys.keys()) + [TARGET])
        ):
            msg = self.bundle.get("multivariate_timeseries_detected")
            self.__override_cv(CVType.blocked_time_series, msg, print_warning=False)
-        elif self.cv is None and self.model_task_type != ModelTaskType.REGRESSION:
+        elif (
+            self.cv is None
+            and model_task_type != ModelTaskType.REGRESSION
+            and self._get_group_columns(df, self.fit_search_keys)
+        ):
            msg = self.bundle.get("group_k_fold_in_classification")
            self.__override_cv(CVType.group_k_fold, msg, print_warning=self.cv is not None)
-            group_columns = self._get_group_columns(df, self.fit_search_keys)
-            self.runtime_parameters.properties["cv_params.group_columns"] = ",".join(group_columns)
-            self.runtime_parameters.properties["cv_params.shuffle_kfold"] = "True"

    def __override_cv(self, cv: CVType, msg: str, print_warning: bool = True):
        if print_warning:
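
The reworked __adjust_cv now takes the date column and task type from the caller and, for classification, switches to GroupKFold only when group columns are actually found; in 1.2.14 that branch fired for every non-regression task. A minimal standalone sketch of the resulting decision, with stand-in enums and flags in place of upgini's internals:

    from enum import Enum
    from typing import List, Optional

    class ModelTaskType(Enum):
        REGRESSION = "regression"
        BINARY = "binary"
        MULTICLASS = "multiclass"

    class CVType(Enum):
        blocked_time_series = "blocked_time_series"
        group_k_fold = "group_k_fold"

    def adjust_cv(
        cv: Optional[CVType],
        task: ModelTaskType,
        has_date_key: bool,
        is_blocked_ts: bool,       # stand-in for is_blocked_time_series(...)
        group_columns: List[str],  # stand-in for _get_group_columns(...)
    ) -> Optional[CVType]:
        # Multivariate time series: regression over a date key -> blocked time-series CV.
        if cv is None and has_date_key and task == ModelTaskType.REGRESSION and is_blocked_ts:
            return CVType.blocked_time_series
        # Classification: GroupKFold only if group columns exist (the new guard).
        if cv is None and task != ModelTaskType.REGRESSION and group_columns:
            return CVType.group_k_fold
        return cv  # an explicit user choice is never overridden

    print(adjust_cv(None, ModelTaskType.BINARY, True, False, ["phone"]))  # CVType.group_k_fold
    print(adjust_cv(None, ModelTaskType.BINARY, True, False, []))         # None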

{upgini-1.2.14 → upgini-1.2.14a2}/src/upgini/normalizer/normalize_utils.py
@@ -1,6 +1,6 @@
 import hashlib
 from logging import Logger, getLogger
-from typing import Dict, List, Tuple
+from typing import Dict, List

 import numpy as np
 import pandas as pd
@@ -35,25 +35,22 @@ class Normalizer:

    def __init__(
        self,
+        search_keys: Dict[str, SearchKey],
+        generated_features: List[str],
        bundle: ResourceBundle = None,
        logger: Logger = None,
        warnings_counter: WarningCounter = None,
        silent_mode=False,
    ):
+        self.search_keys = search_keys
+        self.generated_features = generated_features
        self.bundle = bundle or get_custom_bundle()
        self.logger = logger or getLogger()
        self.warnings_counter = warnings_counter or WarningCounter()
        self.silent_mode = silent_mode
        self.columns_renaming = {}
-        self.search_keys = {}
-        self.generated_features = []
-
-    def normalize(
-        self, df: pd.DataFrame, search_keys: Dict[str, SearchKey], generated_features: List[str]
-    ) -> Tuple[pd.DataFrame, Dict[str, SearchKey], List[str]]:
-        self.search_keys = search_keys.copy()
-        self.generated_features = generated_features.copy()

+    def normalize(self, df: pd.DataFrame) -> pd.DataFrame:
        df = df.copy()
        df = self._rename_columns(df)
@@ -71,25 +68,21 @@ class Normalizer:

        df = self.__convert_features_types(df)

-        return df, self.search_keys, self.generated_features
+        return df

    def _rename_columns(self, df: pd.DataFrame):
        # logger.info("Replace restricted symbols in column names")
        new_columns = []
        dup_counter = 0
        for column in df.columns:
-            if (
-                column
-                in [
-                    TARGET,
-                    EVAL_SET_INDEX,
-                    SYSTEM_RECORD_ID,
-                    ENTITY_SYSTEM_RECORD_ID,
-                    SEARCH_KEY_UNNEST,
-                    DateTimeSearchKeyConverter.DATETIME_COL,
-                ]
-                + self.generated_features
-            ):
+            if column in [
+                TARGET,
+                EVAL_SET_INDEX,
+                SYSTEM_RECORD_ID,
+                ENTITY_SYSTEM_RECORD_ID,
+                SEARCH_KEY_UNNEST,
+                DateTimeSearchKeyConverter.DATETIME_COL,
+            ] + self.generated_features:
                self.columns_renaming[column] = column
                new_columns.append(column)
                continue
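
Taken together with the call-site changes in features_enricher.py above, the Normalizer API now receives search_keys and generated_features in the constructor, and normalize() takes and returns only the DataFrame; renames are read off the instance afterwards. Since the old .copy() calls were dropped, the instance holds references to the caller's dict and list rather than copies. A hedged sketch of the new call pattern (tiny invented frame; not executed against the release):

    import pandas as pd

    from upgini.metadata import SearchKey
    from upgini.normalizer.normalize_utils import Normalizer

    df = pd.DataFrame({"phone num": ["+15551234567"], "target": [1]})
    search_keys = {"phone num": SearchKey.PHONE}
    generated_features: list = []

    # 1.2.14:  df, search_keys, generated_features = normalizer.normalize(df, search_keys, generated_features)
    # 1.2.14a2: state goes in once; bundle, logger, counter and silent_mode stay optional.
    normalizer = Normalizer(search_keys, generated_features)
    df = normalizer.normalize(df)

    print(normalizer.columns_renaming)  # new name -> original name mapping
    print(normalizer.search_keys)       # same object the caller passed in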

{upgini-1.2.14 → upgini-1.2.14a2}/src/upgini/resource_bundle/strings.properties
@@ -208,6 +208,7 @@ target_type_detected=\nDetected task type: {}\n
 all_ok_community_invite=❓ Support request
 too_small_for_metrics=Your train dataset or one of eval datasets contains less than 500 rows. For such dataset Upgini will not calculate accuracy metrics. Please increase the number of rows in the training dataset to calculate accuracy metrics
 imbalance_multiclass=Class {0} is on 25% quantile of classes distribution ({1} records in train dataset). \nDownsample classes with records more than {1}.
+imbalanced_target=\nWARNING: Target is imbalanced and will be undersampled. Frequency of the rarest class `{}` is {}
 loss_selection_info=Using loss `{}` for feature selection
 loss_calc_metrics_info=Using loss `{}` for metrics calculation with default estimator

{upgini-1.2.14 → upgini-1.2.14a2}/src/upgini/utils/target_utils.py
@@ -81,8 +81,8 @@ def balance_undersample(
    target_column: str,
    task_type: ModelTaskType,
    random_state: int,
-    imbalance_threshold: int = 0.2,
-    min_sample_threshold: int = 5000,
+    binary_min_sample_threshold: int = 5000,
+    multiclass_min_sample_threshold: int = 25000,
    binary_bootstrap_loops: int = 5,
    multiclass_bootstrap_loops: int = 2,
    logger: Optional[logging.Logger] = None,
@@ -96,52 +96,59 @@
    if SYSTEM_RECORD_ID not in df.columns:
        raise Exception("System record id must be presented for undersampling")

-    count = len(df)
+    # count = len(df)
    target = df[target_column].copy()
-    target_classes_count = target.nunique()
+    # target_classes_count = target.nunique()

    vc = target.value_counts()
    max_class_value = vc.index[0]
    min_class_value = vc.index[len(vc) - 1]
    max_class_count = vc[max_class_value]
    min_class_count = vc[min_class_value]
+    num_classes = len(vc)

-    min_class_percent = imbalance_threshold / target_classes_count
-    min_class_threshold = int(min_class_percent * count)
+    # min_class_percent = imbalance_threshold / target_classes_count
+    # min_class_threshold = int(min_class_percent * count)

    resampled_data = df
    df = df.copy().sort_values(by=SYSTEM_RECORD_ID)
    if task_type == ModelTaskType.MULTICLASS:
-        # Sort classes by rows count and find 25% quantile class
-        classes = vc.index
-        quantile25_idx = int(0.75 * len(classes)) - 1
-        quantile25_class = classes[quantile25_idx]
-        quantile25_class_cnt = vc[quantile25_class]
-
-        if max_class_count > (quantile25_class_cnt * multiclass_bootstrap_loops):
-            msg = bundle.get("imbalance_multiclass").format(quantile25_class, quantile25_class_cnt)
+        if len(df) > multiclass_min_sample_threshold and max_class_count > (
+            min_class_count * multiclass_bootstrap_loops
+        ):
+
+            # msg = bundle.get("imbalance_multiclass").format(min_class_value, min_class_count)
+            msg = bundle.get("imbalanced_target").format(min_class_value, min_class_count)
            logger.warning(msg)
            print(msg)
            if warning_counter:
                warning_counter.increment()

-            # 25% and lower classes will stay as is. Higher classes will be downsampled
            sample_strategy = dict()
-            for class_idx in range(quantile25_idx):
-                # compare class count with count_of_quantile25_class * 2
-                class_value = classes[class_idx]
+            for class_value in vc.index:
+                if class_value == min_class_value:
+                    continue
                class_count = vc[class_value]
-                sample_strategy[class_value] = min(class_count, quantile25_class_cnt * multiclass_bootstrap_loops)
+                sample_size = min(
+                    class_count,
+                    multiclass_bootstrap_loops
+                    * (
+                        min_class_count
+                        + max((multiclass_min_sample_threshold - num_classes * min_class_count) / (num_classes - 1), 0)
+                    ),
+                )
+                sample_strategy[class_value] = int(sample_size)
            sampler = RandomUnderSampler(sampling_strategy=sample_strategy, random_state=random_state)
            X = df[SYSTEM_RECORD_ID]
            X = X.to_frame(SYSTEM_RECORD_ID)
            new_x, _ = sampler.fit_resample(X, target)  # type: ignore

            resampled_data = df[df[SYSTEM_RECORD_ID].isin(new_x[SYSTEM_RECORD_ID])]
-    elif len(df) > min_sample_threshold and min_class_count < min_sample_threshold / 2:
-        msg = bundle.get("dataset_rarest_class_less_threshold").format(
-            min_class_value, min_class_count, min_class_threshold, min_class_percent * 100
-        )
+    elif len(df) > binary_min_sample_threshold:
+        # msg = bundle.get("dataset_rarest_class_less_threshold").format(
+        #     min_class_value, min_class_count, min_class_threshold, min_class_percent * 100
+        # )
+        msg = bundle.get("imbalanced_target").format(min_class_value, min_class_count)
        logger.warning(msg)
        print(msg)
        if warning_counter:
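
The replacement formula caps every non-minority class at multiclass_bootstrap_loops times the minority count plus an even share of whatever is missing to reach multiclass_min_sample_threshold, instead of the old 25%-quantile rule. A worked example with the defaults from the diff (loops = 2, threshold = 25_000) and invented class counts:

    multiclass_bootstrap_loops = 2
    multiclass_min_sample_threshold = 25_000

    vc = {"a": 100_000, "b": 40_000, "c": 5_000, "d": 1_000}  # hypothetical value_counts()
    num_classes = len(vc)               # 4
    min_class_count = min(vc.values())  # 1_000 (class "d")

    # Per-class cap from the hunk above:
    cap = multiclass_bootstrap_loops * (
        min_class_count
        + max((multiclass_min_sample_threshold - num_classes * min_class_count) / (num_classes - 1), 0)
    )
    print(int(cap))  # 16_000 = 2 * (1_000 + (25_000 - 4_000) / 3)

    # Sampling strategy: a -> 16_000, b -> 16_000, c -> 5_000; the rarest class "d" is kept as is.
    for cls, cnt in vc.items():
        if cls != "d":
            print(cls, min(cnt, int(cap)))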
@@ -150,30 +157,34 @@ def balance_undersample(
        # fill up to min_sample_threshold by majority class
        minority_class = df[df[target_column] == min_class_value]
        majority_class = df[df[target_column] != min_class_value]
-        sample_size = min(len(majority_class), min_sample_threshold - min_class_count)
+        # sample_size = min(len(majority_class), min_sample_threshold - min_class_count)
+        sample_size = min(
+            max_class_count,
+            binary_bootstrap_loops * (min_class_count + max(binary_min_sample_threshold - 2 * min_class_count, 0)),
+        )
        sampled_majority_class = majority_class.sample(n=sample_size, random_state=random_state)
        resampled_data = df[
            (df[SYSTEM_RECORD_ID].isin(minority_class[SYSTEM_RECORD_ID]))
            | (df[SYSTEM_RECORD_ID].isin(sampled_majority_class[SYSTEM_RECORD_ID]))
        ]

-    elif max_class_count > min_class_count * binary_bootstrap_loops:
-        msg = bundle.get("dataset_rarest_class_less_threshold").format(
-            min_class_value, min_class_count, min_class_threshold, min_class_percent * 100
-        )
-        logger.warning(msg)
-        print(msg)
-        if warning_counter:
-            warning_counter.increment()
-
-        sampler = RandomUnderSampler(
-            sampling_strategy={max_class_value: binary_bootstrap_loops * min_class_count}, random_state=random_state
-        )
-        X = df[SYSTEM_RECORD_ID]
-        X = X.to_frame(SYSTEM_RECORD_ID)
-        new_x, _ = sampler.fit_resample(X, target)  # type: ignore
-
-        resampled_data = df[df[SYSTEM_RECORD_ID].isin(new_x[SYSTEM_RECORD_ID])]
+    # elif max_class_count > min_class_count * binary_bootstrap_loops:
+    #     msg = bundle.get("dataset_rarest_class_less_threshold").format(
+    #         min_class_value, min_class_count, min_class_threshold, min_class_percent * 100
+    #     )
+    #     logger.warning(msg)
+    #     print(msg)
+    #     if warning_counter:
+    #         warning_counter.increment()
+
+    #     sampler = RandomUnderSampler(
+    #         sampling_strategy={max_class_value: binary_bootstrap_loops * min_class_count}, random_state=random_state
+    #     )
+    #     X = df[SYSTEM_RECORD_ID]
+    #     X = X.to_frame(SYSTEM_RECORD_ID)
+    #     new_x, _ = sampler.fit_resample(X, target)  # type: ignore
+
+    #     resampled_data = df[df[SYSTEM_RECORD_ID].isin(new_x[SYSTEM_RECORD_ID])]

    logger.info(f"Shape after rebalance resampling: {resampled_data}")
    return resampled_data
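
The binary branch follows the same pattern: once the frame exceeds binary_min_sample_threshold, the majority class keeps at most binary_bootstrap_loops * (min_class_count + max(binary_min_sample_threshold - 2 * min_class_count, 0)) rows, and the old RandomUnderSampler elif is retired (left commented out above). A worked example with the defaults (loops = 5, threshold = 5_000) and invented counts:

    binary_bootstrap_loops = 5
    binary_min_sample_threshold = 5_000

    def majority_sample_size(max_class_count: int, min_class_count: int) -> int:
        # Mirrors the new sample_size expression in the hunk above.
        return min(
            max_class_count,
            binary_bootstrap_loops
            * (min_class_count + max(binary_min_sample_threshold - 2 * min_class_count, 0)),
        )

    print(majority_sample_size(99_000, 1_000))  # 20_000 = 5 * (1_000 + 3_000)
    print(majority_sample_size(97_000, 3_000))  # 15_000 = 5 * 3_000 (fill-up term is 0)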

upgini-1.2.14/src/upgini/__about__.py (removed)
@@ -1 +0,0 @@
-__version__ = "1.2.14"