upgini 1.2.31a1__tar.gz → 1.2.32__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67)
  1. {upgini-1.2.31a1 → upgini-1.2.32}/PKG-INFO +1 -1
  2. {upgini-1.2.31a1 → upgini-1.2.32}/pyproject.toml +3 -3
  3. upgini-1.2.32/src/upgini/__about__.py +1 -0
  4. {upgini-1.2.31a1 → upgini-1.2.32}/src/upgini/dataset.py +24 -10
  5. {upgini-1.2.31a1 → upgini-1.2.32}/src/upgini/features_enricher.py +21 -2
  6. {upgini-1.2.31a1 → upgini-1.2.32}/src/upgini/resource_bundle/strings.properties +1 -0
  7. {upgini-1.2.31a1 → upgini-1.2.32}/src/upgini/utils/target_utils.py +70 -35
  8. upgini-1.2.31a1/src/upgini/__about__.py +0 -1
  9. {upgini-1.2.31a1 → upgini-1.2.32}/.gitignore +0 -0
  10. {upgini-1.2.31a1 → upgini-1.2.32}/LICENSE +0 -0
  11. {upgini-1.2.31a1 → upgini-1.2.32}/README.md +0 -0
  12. {upgini-1.2.31a1 → upgini-1.2.32}/src/upgini/__init__.py +0 -0
  13. {upgini-1.2.31a1 → upgini-1.2.32}/src/upgini/ads.py +0 -0
  14. {upgini-1.2.31a1 → upgini-1.2.32}/src/upgini/ads_management/__init__.py +0 -0
  15. {upgini-1.2.31a1 → upgini-1.2.32}/src/upgini/ads_management/ads_manager.py +0 -0
  16. {upgini-1.2.31a1 → upgini-1.2.32}/src/upgini/autofe/__init__.py +0 -0
  17. {upgini-1.2.31a1 → upgini-1.2.32}/src/upgini/autofe/all_operands.py +0 -0
  18. {upgini-1.2.31a1 → upgini-1.2.32}/src/upgini/autofe/binary.py +0 -0
  19. {upgini-1.2.31a1 → upgini-1.2.32}/src/upgini/autofe/date.py +0 -0
  20. {upgini-1.2.31a1 → upgini-1.2.32}/src/upgini/autofe/feature.py +0 -0
  21. {upgini-1.2.31a1 → upgini-1.2.32}/src/upgini/autofe/groupby.py +0 -0
  22. {upgini-1.2.31a1 → upgini-1.2.32}/src/upgini/autofe/operand.py +0 -0
  23. {upgini-1.2.31a1 → upgini-1.2.32}/src/upgini/autofe/unary.py +0 -0
  24. {upgini-1.2.31a1 → upgini-1.2.32}/src/upgini/autofe/vector.py +0 -0
  25. {upgini-1.2.31a1 → upgini-1.2.32}/src/upgini/data_source/__init__.py +0 -0
  26. {upgini-1.2.31a1 → upgini-1.2.32}/src/upgini/data_source/data_source_publisher.py +0 -0
  27. {upgini-1.2.31a1 → upgini-1.2.32}/src/upgini/errors.py +0 -0
  28. {upgini-1.2.31a1 → upgini-1.2.32}/src/upgini/http.py +0 -0
  29. {upgini-1.2.31a1 → upgini-1.2.32}/src/upgini/lazy_import.py +0 -0
  30. {upgini-1.2.31a1 → upgini-1.2.32}/src/upgini/mdc/__init__.py +0 -0
  31. {upgini-1.2.31a1 → upgini-1.2.32}/src/upgini/mdc/context.py +0 -0
  32. {upgini-1.2.31a1 → upgini-1.2.32}/src/upgini/metadata.py +0 -0
  33. {upgini-1.2.31a1 → upgini-1.2.32}/src/upgini/metrics.py +0 -0
  34. {upgini-1.2.31a1 → upgini-1.2.32}/src/upgini/normalizer/__init__.py +0 -0
  35. {upgini-1.2.31a1 → upgini-1.2.32}/src/upgini/normalizer/normalize_utils.py +0 -0
  36. {upgini-1.2.31a1 → upgini-1.2.32}/src/upgini/resource_bundle/__init__.py +0 -0
  37. {upgini-1.2.31a1 → upgini-1.2.32}/src/upgini/resource_bundle/exceptions.py +0 -0
  38. {upgini-1.2.31a1 → upgini-1.2.32}/src/upgini/resource_bundle/strings_widget.properties +0 -0
  39. {upgini-1.2.31a1 → upgini-1.2.32}/src/upgini/sampler/__init__.py +0 -0
  40. {upgini-1.2.31a1 → upgini-1.2.32}/src/upgini/sampler/base.py +0 -0
  41. {upgini-1.2.31a1 → upgini-1.2.32}/src/upgini/sampler/random_under_sampler.py +0 -0
  42. {upgini-1.2.31a1 → upgini-1.2.32}/src/upgini/sampler/utils.py +0 -0
  43. {upgini-1.2.31a1 → upgini-1.2.32}/src/upgini/search_task.py +0 -0
  44. {upgini-1.2.31a1 → upgini-1.2.32}/src/upgini/spinner.py +0 -0
  45. {upgini-1.2.31a1 → upgini-1.2.32}/src/upgini/utils/Roboto-Regular.ttf +0 -0
  46. {upgini-1.2.31a1 → upgini-1.2.32}/src/upgini/utils/__init__.py +0 -0
  47. {upgini-1.2.31a1 → upgini-1.2.32}/src/upgini/utils/base_search_key_detector.py +0 -0
  48. {upgini-1.2.31a1 → upgini-1.2.32}/src/upgini/utils/blocked_time_series.py +0 -0
  49. {upgini-1.2.31a1 → upgini-1.2.32}/src/upgini/utils/country_utils.py +0 -0
  50. {upgini-1.2.31a1 → upgini-1.2.32}/src/upgini/utils/custom_loss_utils.py +0 -0
  51. {upgini-1.2.31a1 → upgini-1.2.32}/src/upgini/utils/cv_utils.py +0 -0
  52. {upgini-1.2.31a1 → upgini-1.2.32}/src/upgini/utils/datetime_utils.py +0 -0
  53. {upgini-1.2.31a1 → upgini-1.2.32}/src/upgini/utils/deduplicate_utils.py +0 -0
  54. {upgini-1.2.31a1 → upgini-1.2.32}/src/upgini/utils/display_utils.py +0 -0
  55. {upgini-1.2.31a1 → upgini-1.2.32}/src/upgini/utils/email_utils.py +0 -0
  56. {upgini-1.2.31a1 → upgini-1.2.32}/src/upgini/utils/fallback_progress_bar.py +0 -0
  57. {upgini-1.2.31a1 → upgini-1.2.32}/src/upgini/utils/feature_info.py +0 -0
  58. {upgini-1.2.31a1 → upgini-1.2.32}/src/upgini/utils/features_validator.py +0 -0
  59. {upgini-1.2.31a1 → upgini-1.2.32}/src/upgini/utils/format.py +0 -0
  60. {upgini-1.2.31a1 → upgini-1.2.32}/src/upgini/utils/ip_utils.py +0 -0
  61. {upgini-1.2.31a1 → upgini-1.2.32}/src/upgini/utils/phone_utils.py +0 -0
  62. {upgini-1.2.31a1 → upgini-1.2.32}/src/upgini/utils/postal_code_utils.py +0 -0
  63. {upgini-1.2.31a1 → upgini-1.2.32}/src/upgini/utils/progress_bar.py +0 -0
  64. {upgini-1.2.31a1 → upgini-1.2.32}/src/upgini/utils/sklearn_ext.py +0 -0
  65. {upgini-1.2.31a1 → upgini-1.2.32}/src/upgini/utils/track_info.py +0 -0
  66. {upgini-1.2.31a1 → upgini-1.2.32}/src/upgini/utils/warning_counter.py +0 -0
  67. {upgini-1.2.31a1 → upgini-1.2.32}/src/upgini/version_validator.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: upgini
3
- Version: 1.2.31a1
3
+ Version: 1.2.32
4
4
  Summary: Intelligent data search & enrichment for Machine Learning
5
5
  Project-URL: Bug Reports, https://github.com/upgini/upgini/issues
6
6
  Project-URL: Homepage, https://upgini.com/
@@ -77,7 +77,7 @@ include = [
77
77
 
78
78
  [tool.hatch.envs.default]
79
79
  type = "virtual"
80
- python = "3.10"
80
+ python = "3.11"
81
81
 
82
82
  [tool.hatch.envs.test.scripts]
83
83
  cov = 'pytest --cov-report=term-missing --cov-config=pyproject.toml --cov=upgini --cov=tests'
@@ -90,11 +90,11 @@ python = ["3.8"]
90
90
  pandas = ["1.1.0"]
91
91
 
92
92
  [[tool.hatch.envs.test.matrix]]
93
- python = ["3.8", "3.9", "3.10"]
93
+ python = ["3.8", "3.9", "3.10", "3.11"]
94
94
  pandas = ["1.2.0", "1.3.0", "1.4.0", "1.5.0", "2.0.0"]
95
95
 
96
96
  [[tool.hatch.envs.test.matrix]]
97
- python = ["3.9", "3.10"]
97
+ python = ["3.9", "3.10", "3.11"]
98
98
  pandas = ["2.1.0", "2.2.0"]
99
99
 
100
100
  # from versions: 0.1, 0.2, 0.3.0, 0.4.0, 0.4.1, 0.4.2, 0.4.3, 0.5.0, 0.6.0, 0.6.1, 0.7.0, 0.7.1, 0.7.2, 0.7.3, 0.8.0, 0.8.1, 0.9.0, 0.9.1, 0.10.0, 0.10.1, 0.11.0, 0.12.0, 0.13.0, 0.13.1, 0.14.0, 0.14.1, 0.15.0, 0.15.1, 0.15.2, 0.16.0, 0.16.1, 0.16.2, 0.17.0, 0.17.1, 0.18.0, 0.18.1, 0.19.0, 0.19.1, 0.19.2, 0.20.0, 0.20.1, 0.20.2, 0.20.3, 0.21.0, 0.21.1, 0.22.0, 0.23.0, 0.23.1, 0.23.2, 0.23.3, 0.23.4, 0.24.0, 0.24.1, 0.24.2, 0.25.0, 0.25.1, 0.25.2, 0.25.3, 1.0.0, 1.0.1, 1.0.2, 1.0.3, 1.0.4, 1.0.5, 1.1.0, 1.1.1, 1.1.2, 1.1.3, 1.1.4, 1.1.5, 1.2.0, 1.2.1, 1.2.2, 1.2.3, 1.2.4, 1.2.5, 1.3.0, 1.3.1, 1.3.2, 1.3.3, 1.3.4, 1.3.5, 1.4.0rc0, 1.4.0, 1.4.1, 1.4.2, 1.4.3, 1.4.4, 1.5.0rc0, 1.5.0, 1.5.1, 1.5.2, 1.5.3, 2.0.0rc0, 2.0.0rc1, 2.0.0, 2.0.1, 2.0.2, 2.0.3
@@ -0,0 +1 @@
1
+ __version__ = "1.2.32"
@@ -36,15 +36,13 @@ from upgini.metadata import (
36
36
  from upgini.resource_bundle import ResourceBundle, get_custom_bundle
37
37
  from upgini.search_task import SearchTask
38
38
  from upgini.utils.email_utils import EmailSearchKeyConverter
39
- from upgini.utils.target_utils import balance_undersample
39
+ from upgini.utils.target_utils import balance_undersample, balance_undersample_forced
40
40
 
41
41
  try:
42
42
  from upgini.utils.progress_bar import CustomProgressBar as ProgressBar
43
43
  except Exception:
44
44
  from upgini.utils.fallback_progress_bar import CustomFallbackProgressBar as ProgressBar
45
45
 
46
- from upgini.utils.warning_counter import WarningCounter
47
-
48
46
 
49
47
  class Dataset: # (pd.DataFrame):
50
48
  MIN_ROWS_COUNT = 100
@@ -64,6 +62,7 @@ class Dataset: # (pd.DataFrame):
64
62
  MAX_FEATURES_COUNT = 3500
65
63
  MAX_UPLOADING_FILE_SIZE = 268435456 # 256 Mb
66
64
  MAX_STRING_FEATURE_LENGTH = 24573
65
+ FORCE_SAMPLE_SIZE = 7_000
67
66
 
68
67
  def __init__(
69
68
  self,
@@ -78,8 +77,8 @@ class Dataset: # (pd.DataFrame):
78
77
  random_state: Optional[int] = None,
79
78
  rest_client: Optional[_RestClient] = None,
80
79
  logger: Optional[logging.Logger] = None,
81
- warning_counter: Optional[WarningCounter] = None,
82
80
  bundle: Optional[ResourceBundle] = None,
81
+ warning_callback: Optional[Callable] = None,
83
82
  **kwargs,
84
83
  ):
85
84
  self.bundle = bundle or get_custom_bundle()
@@ -122,7 +121,7 @@ class Dataset: # (pd.DataFrame):
122
121
  else:
123
122
  self.logger = logging.getLogger()
124
123
  self.logger.setLevel("FATAL")
125
- self.warning_counter = warning_counter or WarningCounter()
124
+ self.warning_callback = warning_callback
126
125
 
127
126
  def __len__(self):
128
127
  return len(self.data) if self.data is not None else None
@@ -217,9 +216,23 @@ class Dataset: # (pd.DataFrame):
217
216
  self.logger.exception("Failed to cast target to float for timeseries task type")
218
217
  raise ValidationError(self.bundle.get("dataset_invalid_timeseries_target").format(target.dtype))
219
218
 
220
- def __resample(self):
219
+ def __resample(self, force_downsampling=False):
221
220
  # self.logger.info("Resampling etalon")
222
221
  # Resample imbalanced target. Only train segment (without eval_set)
222
+ if force_downsampling:
223
+ target_column = self.etalon_def_checked.get(FileColumnMeaningType.TARGET.value, TARGET)
224
+ self.data = balance_undersample_forced(
225
+ df=self.data,
226
+ target_column=target_column,
227
+ task_type=self.task_type,
228
+ random_state=self.random_state,
229
+ sample_size=self.FORCE_SAMPLE_SIZE,
230
+ logger=self.logger,
231
+ bundle=self.bundle,
232
+ warning_callback=self.warning_callback,
233
+ )
234
+ return
235
+
223
236
  if EVAL_SET_INDEX in self.data.columns:
224
237
  train_segment = self.data[self.data[EVAL_SET_INDEX] == 0]
225
238
  else:
@@ -268,7 +281,7 @@ class Dataset: # (pd.DataFrame):
268
281
  multiclass_bootstrap_loops=self.MULTICLASS_BOOTSTRAP_LOOPS,
269
282
  logger=self.logger,
270
283
  bundle=self.bundle,
271
- warning_counter=self.warning_counter,
284
+ warning_callback=self.warning_callback,
272
285
  )
273
286
 
274
287
  # Resample over fit threshold
@@ -418,13 +431,13 @@ class Dataset: # (pd.DataFrame):
418
431
  if len(self.data) == 0:
419
432
  raise ValidationError(self.bundle.get("all_search_keys_invalid"))
420
433
 
421
- def validate(self, validate_target: bool = True, silent_mode: bool = False):
434
+ def validate(self, validate_target: bool = True, silent_mode: bool = False, force_downsampling: bool = False):
422
435
  self.__validate_dataset(validate_target, silent_mode)
423
436
 
424
437
  if validate_target:
425
438
  self.__validate_target()
426
439
 
427
- self.__resample()
440
+ self.__resample(force_downsampling)
428
441
 
429
442
  self.__validate_min_rows_count()
430
443
 
@@ -573,9 +586,10 @@ class Dataset: # (pd.DataFrame):
573
586
  max_features: Optional[int] = None, # deprecated
574
587
  filter_features: Optional[dict] = None, # deprecated
575
588
  runtime_parameters: Optional[RuntimeParameters] = None,
589
+ force_downsampling: bool = False,
576
590
  ) -> SearchTask:
577
591
  if self.etalon_def is None:
578
- self.validate()
592
+ self.validate(force_downsampling=force_downsampling)
579
593
  file_metrics = FileMetrics()
580
594
 
581
595
  runtime_parameters = self._rename_generate_features(runtime_parameters)
@@ -231,6 +231,7 @@ class FeaturesEnricher(TransformerMixin):
231
231
  custom_bundle_config: Optional[str] = None,
232
232
  add_date_if_missing: bool = True,
233
233
  select_features: bool = False,
234
+ disable_force_downsampling: bool = False,
234
235
  **kwargs,
235
236
  ):
236
237
  self.bundle = get_custom_bundle(custom_bundle_config)
@@ -288,6 +289,7 @@ class FeaturesEnricher(TransformerMixin):
288
289
  self.feature_importances_ = []
289
290
  self.search_id = search_id
290
291
  self.select_features = select_features
292
+ self.disable_force_downsampling = disable_force_downsampling
291
293
 
292
294
  if search_id:
293
295
  search_task = SearchTask(search_id, rest_client=self.rest_client, logger=self.logger)
@@ -2251,6 +2253,8 @@ class FeaturesEnricher(TransformerMixin):
2251
2253
  date_format=self.date_format,
2252
2254
  rest_client=self.rest_client,
2253
2255
  logger=self.logger,
2256
+ bundle=self.bundle,
2257
+ warning_callback=self.__log_warning,
2254
2258
  )
2255
2259
  dataset.columns_renaming = columns_renaming
2256
2260
 
@@ -2696,6 +2700,18 @@ class FeaturesEnricher(TransformerMixin):
2696
2700
 
2697
2701
  combined_search_keys = combine_search_keys(self.fit_search_keys.keys())
2698
2702
 
2703
+ runtime_parameters = self._get_copy_of_runtime_parameters()
2704
+
2705
+ # Force downsampling to 7000 for API features generation
2706
+ force_downsampling = (
2707
+ not self.disable_force_downsampling
2708
+ and self.generate_features is not None
2709
+ and phone_column is not None
2710
+ and self.fit_columns_renaming[phone_column] in self.generate_features
2711
+ )
2712
+ if force_downsampling and len(df) > Dataset.FORCE_SAMPLE_SIZE:
2713
+ runtime_parameters.properties["fast_fit"] = True
2714
+
2699
2715
  dataset = Dataset(
2700
2716
  "tds_" + str(uuid.uuid4()),
2701
2717
  df=df,
@@ -2707,6 +2723,8 @@ class FeaturesEnricher(TransformerMixin):
2707
2723
  random_state=self.random_state,
2708
2724
  rest_client=self.rest_client,
2709
2725
  logger=self.logger,
2726
+ bundle=self.bundle,
2727
+ warning_callback=self.__log_warning,
2710
2728
  )
2711
2729
  dataset.columns_renaming = self.fit_columns_renaming
2712
2730
 
@@ -2720,8 +2738,9 @@ class FeaturesEnricher(TransformerMixin):
2720
2738
  start_time=start_time,
2721
2739
  progress_callback=progress_callback,
2722
2740
  extract_features=True,
2723
- runtime_parameters=self._get_copy_of_runtime_parameters(),
2741
+ runtime_parameters=runtime_parameters,
2724
2742
  exclude_features_sources=exclude_features_sources,
2743
+ force_downsampling=force_downsampling,
2725
2744
  )
2726
2745
 
2727
2746
  if search_id_callback is not None:
@@ -3521,7 +3540,7 @@ class FeaturesEnricher(TransformerMixin):
3521
3540
  return result_train, result_eval_sets
3522
3541
 
3523
3542
  def __prepare_feature_importances(
3524
- self, trace_id: str, x_columns: List[str], updated_shaps: Optional[Dict[str, float]] = None, silent=False
3543
+ self, trace_id: str, x_columns: List[str], updated_shaps: Optional[Dict[str, float]] = None, silent=False
3525
3544
  ):
3526
3545
  if self._search_task is None:
3527
3546
  raise NotFittedError(self.bundle.get("transform_unfitted_enricher"))
@@ -215,6 +215,7 @@ imbalance_multiclass=Class {0} is on 25% quantile of classes distribution ({1} r
215
215
  imbalanced_target=\nTarget is imbalanced and will be undersampled. Frequency of the rarest class `{}` is {}
216
216
  loss_selection_info=Using loss `{}` for feature selection
217
217
  loss_calc_metrics_info=Using loss `{}` for metrics calculation with default estimator
218
+ forced_balance_undersample=For quick data retrieval, your dataset has been sampled. To use data search without data sampling please contact support (sales@upgini.com)
218
219
 
219
220
  # Validation table
220
221
  validation_column_name_header=Column name
@@ -1,5 +1,5 @@
1
1
  import logging
2
- from typing import Optional, Union
2
+ from typing import Callable, Optional, Union
3
3
 
4
4
  import numpy as np
5
5
  import pandas as pd
@@ -9,7 +9,6 @@ from upgini.errors import ValidationError
9
9
  from upgini.metadata import SYSTEM_RECORD_ID, ModelTaskType
10
10
  from upgini.resource_bundle import ResourceBundle, bundle, get_custom_bundle
11
11
  from upgini.sampler.random_under_sampler import RandomUnderSampler
12
- from upgini.utils.warning_counter import WarningCounter
13
12
 
14
13
 
15
14
  def correct_string_target(y: Union[pd.Series, np.ndarray]) -> Union[pd.Series, np.ndarray]:
@@ -121,7 +120,7 @@ def balance_undersample(
121
120
  multiclass_bootstrap_loops: int = 2,
122
121
  logger: Optional[logging.Logger] = None,
123
122
  bundle: Optional[ResourceBundle] = None,
124
- warning_counter: Optional[WarningCounter] = None,
123
+ warning_callback: Optional[Callable] = None,
125
124
  ) -> pd.DataFrame:
126
125
  if logger is None:
127
126
  logger = logging.getLogger("muted_logger")
@@ -130,9 +129,7 @@ def balance_undersample(
130
129
  if SYSTEM_RECORD_ID not in df.columns:
131
130
  raise Exception("System record id must be presented for undersampling")
132
131
 
133
- # count = len(df)
134
132
  target = df[target_column].copy()
135
- # target_classes_count = target.nunique()
136
133
 
137
134
  vc = target.value_counts()
138
135
  max_class_value = vc.index[0]
@@ -141,9 +138,6 @@ def balance_undersample(
141
138
  min_class_count = vc[min_class_value]
142
139
  num_classes = len(vc)
143
140
 
144
- # min_class_percent = imbalance_threshold / target_classes_count
145
- # min_class_threshold = int(min_class_percent * count)
146
-
147
141
  resampled_data = df
148
142
  df = df.copy().sort_values(by=SYSTEM_RECORD_ID)
149
143
  if task_type == ModelTaskType.MULTICLASS:
@@ -151,12 +145,10 @@ def balance_undersample(
151
145
  min_class_count * multiclass_bootstrap_loops
152
146
  ):
153
147
 
154
- # msg = bundle.get("imbalance_multiclass").format(min_class_value, min_class_count)
155
148
  msg = bundle.get("imbalanced_target").format(min_class_value, min_class_count)
156
149
  logger.warning(msg)
157
- print(msg)
158
- if warning_counter:
159
- warning_counter.increment()
150
+ if warning_callback is not None:
151
+ warning_callback(msg)
160
152
 
161
153
  sample_strategy = dict()
162
154
  for class_value in vc.index:
@@ -180,19 +172,14 @@ def balance_undersample(
180
172
 
181
173
  resampled_data = df[df[SYSTEM_RECORD_ID].isin(new_x[SYSTEM_RECORD_ID])]
182
174
  elif len(df) > binary_min_sample_threshold:
183
- # msg = bundle.get("dataset_rarest_class_less_threshold").format(
184
- # min_class_value, min_class_count, min_class_threshold, min_class_percent * 100
185
- # )
186
175
  msg = bundle.get("imbalanced_target").format(min_class_value, min_class_count)
187
176
  logger.warning(msg)
188
- print(msg)
189
- if warning_counter:
190
- warning_counter.increment()
177
+ if warning_callback is not None:
178
+ warning_callback(msg)
191
179
 
192
180
  # fill up to min_sample_threshold by majority class
193
181
  minority_class = df[df[target_column] == min_class_value]
194
182
  majority_class = df[df[target_column] != min_class_value]
195
- # sample_size = min(len(majority_class), min_sample_threshold - min_class_count)
196
183
  sample_size = min(
197
184
  max_class_count,
198
185
  binary_bootstrap_loops * (min_class_count + max(binary_min_sample_threshold - 2 * min_class_count, 0)),
@@ -207,25 +194,73 @@ def balance_undersample(
207
194
  | (df[SYSTEM_RECORD_ID].isin(sampled_majority_class[SYSTEM_RECORD_ID]))
208
195
  ]
209
196
 
210
- # elif max_class_count > min_class_count * binary_bootstrap_loops:
211
- # msg = bundle.get("dataset_rarest_class_less_threshold").format(
212
- # min_class_value, min_class_count, min_class_threshold, min_class_percent * 100
213
- # )
214
- # logger.warning(msg)
215
- # print(msg)
216
- # if warning_counter:
217
- # warning_counter.increment()
197
+ logger.info(f"Shape after rebalance resampling: {resampled_data}")
198
+ return resampled_data
218
199
 
219
- # sampler = RandomUnderSampler(
220
- # sampling_strategy={max_class_value: binary_bootstrap_loops * min_class_count}, random_state=random_state
221
- # )
222
- # X = df[SYSTEM_RECORD_ID]
223
- # X = X.to_frame(SYSTEM_RECORD_ID)
224
- # new_x, _ = sampler.fit_resample(X, target) # type: ignore
225
200
 
226
- # resampled_data = df[df[SYSTEM_RECORD_ID].isin(new_x[SYSTEM_RECORD_ID])]
201
+ def balance_undersample_forced(
202
+ df: pd.DataFrame,
203
+ target_column: str,
204
+ task_type: ModelTaskType,
205
+ random_state: int,
206
+ sample_size: int = 7000,
207
+ logger: Optional[logging.Logger] = None,
208
+ bundle: Optional[ResourceBundle] = None,
209
+ warning_callback: Optional[Callable] = None,
210
+ ):
211
+ if len(df) <= sample_size:
212
+ return df
227
213
 
228
- logger.info(f"Shape after rebalance resampling: {resampled_data}")
214
+ if logger is None:
215
+ logger = logging.getLogger("muted_logger")
216
+ logger.setLevel("FATAL")
217
+ bundle = bundle or get_custom_bundle()
218
+ if SYSTEM_RECORD_ID not in df.columns:
219
+ raise Exception("System record id must be presented for undersampling")
220
+
221
+ msg = bundle.get("forced_balance_undersample")
222
+ logger.info(msg)
223
+ if warning_callback is not None:
224
+ warning_callback(msg)
225
+
226
+ target = df[target_column].copy()
227
+
228
+ vc = target.value_counts()
229
+ max_class_value = vc.index[0]
230
+ min_class_value = vc.index[len(vc) - 1]
231
+ max_class_count = vc[max_class_value]
232
+ min_class_count = vc[min_class_value]
233
+
234
+ resampled_data = df
235
+ df = df.copy().sort_values(by=SYSTEM_RECORD_ID)
236
+ if task_type in [ModelTaskType.MULTICLASS, ModelTaskType.REGRESSION, ModelTaskType.TIMESERIES]:
237
+ logger.warning(f"Sampling dataset from {len(df)} to {sample_size}")
238
+ resampled_data = df.sample(n=sample_size, random_state=random_state)
239
+ else:
240
+ msg = bundle.get("imbalanced_target").format(min_class_value, min_class_count)
241
+ logger.warning(msg)
242
+
243
+ # fill up to min_sample_threshold by majority class
244
+ minority_class = df[df[target_column] == min_class_value]
245
+ majority_class = df[df[target_column] != min_class_value]
246
+ logger.info(
247
+ f"Min class count: {min_class_count}. Max class count: {max_class_count}."
248
+ f" Rebalance sample size: {sample_size}"
249
+ )
250
+ if len(minority_class) > (sample_size / 2):
251
+ sampled_minority_class = minority_class.sample(n=int(sample_size / 2), random_state=random_state)
252
+ else:
253
+ sampled_minority_class = minority_class
254
+
255
+ if len(majority_class) > (sample_size) / 2:
256
+ sampled_majority_class = majority_class.sample(n=int(sample_size / 2), random_state=random_state)
257
+
258
+ resampled_data = df[
259
+ (df[SYSTEM_RECORD_ID].isin(sampled_minority_class[SYSTEM_RECORD_ID]))
260
+ | (df[SYSTEM_RECORD_ID].isin(sampled_majority_class[SYSTEM_RECORD_ID]))
261
+ ]
262
+
263
+ logger.info(f"Shape after forced rebalance resampling: {resampled_data}")
229
264
  return resampled_data
230
265
 
231
266
 
@@ -1 +0,0 @@
1
- __version__ = "1.2.31a1"
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes