upgini 1.2.31a2__tar.gz → 1.2.33__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of upgini might be problematic; see the registry's advisory page for more details.

Files changed (67)
  1. {upgini-1.2.31a2 → upgini-1.2.33}/PKG-INFO +2 -2
  2. {upgini-1.2.31a2 → upgini-1.2.33}/README.md +1 -1
  3. upgini-1.2.33/src/upgini/__about__.py +1 -0
  4. {upgini-1.2.31a2 → upgini-1.2.33}/src/upgini/dataset.py +24 -10
  5. {upgini-1.2.31a2 → upgini-1.2.33}/src/upgini/features_enricher.py +67 -7
  6. {upgini-1.2.31a2 → upgini-1.2.33}/src/upgini/metadata.py +1 -0
  7. {upgini-1.2.31a2 → upgini-1.2.33}/src/upgini/resource_bundle/strings.properties +1 -0
  8. {upgini-1.2.31a2 → upgini-1.2.33}/src/upgini/utils/target_utils.py +70 -35
  9. upgini-1.2.31a2/src/upgini/__about__.py +0 -1
  10. {upgini-1.2.31a2 → upgini-1.2.33}/.gitignore +0 -0
  11. {upgini-1.2.31a2 → upgini-1.2.33}/LICENSE +0 -0
  12. {upgini-1.2.31a2 → upgini-1.2.33}/pyproject.toml +0 -0
  13. {upgini-1.2.31a2 → upgini-1.2.33}/src/upgini/__init__.py +0 -0
  14. {upgini-1.2.31a2 → upgini-1.2.33}/src/upgini/ads.py +0 -0
  15. {upgini-1.2.31a2 → upgini-1.2.33}/src/upgini/ads_management/__init__.py +0 -0
  16. {upgini-1.2.31a2 → upgini-1.2.33}/src/upgini/ads_management/ads_manager.py +0 -0
  17. {upgini-1.2.31a2 → upgini-1.2.33}/src/upgini/autofe/__init__.py +0 -0
  18. {upgini-1.2.31a2 → upgini-1.2.33}/src/upgini/autofe/all_operands.py +0 -0
  19. {upgini-1.2.31a2 → upgini-1.2.33}/src/upgini/autofe/binary.py +0 -0
  20. {upgini-1.2.31a2 → upgini-1.2.33}/src/upgini/autofe/date.py +0 -0
  21. {upgini-1.2.31a2 → upgini-1.2.33}/src/upgini/autofe/feature.py +0 -0
  22. {upgini-1.2.31a2 → upgini-1.2.33}/src/upgini/autofe/groupby.py +0 -0
  23. {upgini-1.2.31a2 → upgini-1.2.33}/src/upgini/autofe/operand.py +0 -0
  24. {upgini-1.2.31a2 → upgini-1.2.33}/src/upgini/autofe/unary.py +0 -0
  25. {upgini-1.2.31a2 → upgini-1.2.33}/src/upgini/autofe/vector.py +0 -0
  26. {upgini-1.2.31a2 → upgini-1.2.33}/src/upgini/data_source/__init__.py +0 -0
  27. {upgini-1.2.31a2 → upgini-1.2.33}/src/upgini/data_source/data_source_publisher.py +0 -0
  28. {upgini-1.2.31a2 → upgini-1.2.33}/src/upgini/errors.py +0 -0
  29. {upgini-1.2.31a2 → upgini-1.2.33}/src/upgini/http.py +0 -0
  30. {upgini-1.2.31a2 → upgini-1.2.33}/src/upgini/lazy_import.py +0 -0
  31. {upgini-1.2.31a2 → upgini-1.2.33}/src/upgini/mdc/__init__.py +0 -0
  32. {upgini-1.2.31a2 → upgini-1.2.33}/src/upgini/mdc/context.py +0 -0
  33. {upgini-1.2.31a2 → upgini-1.2.33}/src/upgini/metrics.py +0 -0
  34. {upgini-1.2.31a2 → upgini-1.2.33}/src/upgini/normalizer/__init__.py +0 -0
  35. {upgini-1.2.31a2 → upgini-1.2.33}/src/upgini/normalizer/normalize_utils.py +0 -0
  36. {upgini-1.2.31a2 → upgini-1.2.33}/src/upgini/resource_bundle/__init__.py +0 -0
  37. {upgini-1.2.31a2 → upgini-1.2.33}/src/upgini/resource_bundle/exceptions.py +0 -0
  38. {upgini-1.2.31a2 → upgini-1.2.33}/src/upgini/resource_bundle/strings_widget.properties +0 -0
  39. {upgini-1.2.31a2 → upgini-1.2.33}/src/upgini/sampler/__init__.py +0 -0
  40. {upgini-1.2.31a2 → upgini-1.2.33}/src/upgini/sampler/base.py +0 -0
  41. {upgini-1.2.31a2 → upgini-1.2.33}/src/upgini/sampler/random_under_sampler.py +0 -0
  42. {upgini-1.2.31a2 → upgini-1.2.33}/src/upgini/sampler/utils.py +0 -0
  43. {upgini-1.2.31a2 → upgini-1.2.33}/src/upgini/search_task.py +0 -0
  44. {upgini-1.2.31a2 → upgini-1.2.33}/src/upgini/spinner.py +0 -0
  45. {upgini-1.2.31a2 → upgini-1.2.33}/src/upgini/utils/Roboto-Regular.ttf +0 -0
  46. {upgini-1.2.31a2 → upgini-1.2.33}/src/upgini/utils/__init__.py +0 -0
  47. {upgini-1.2.31a2 → upgini-1.2.33}/src/upgini/utils/base_search_key_detector.py +0 -0
  48. {upgini-1.2.31a2 → upgini-1.2.33}/src/upgini/utils/blocked_time_series.py +0 -0
  49. {upgini-1.2.31a2 → upgini-1.2.33}/src/upgini/utils/country_utils.py +0 -0
  50. {upgini-1.2.31a2 → upgini-1.2.33}/src/upgini/utils/custom_loss_utils.py +0 -0
  51. {upgini-1.2.31a2 → upgini-1.2.33}/src/upgini/utils/cv_utils.py +0 -0
  52. {upgini-1.2.31a2 → upgini-1.2.33}/src/upgini/utils/datetime_utils.py +0 -0
  53. {upgini-1.2.31a2 → upgini-1.2.33}/src/upgini/utils/deduplicate_utils.py +0 -0
  54. {upgini-1.2.31a2 → upgini-1.2.33}/src/upgini/utils/display_utils.py +0 -0
  55. {upgini-1.2.31a2 → upgini-1.2.33}/src/upgini/utils/email_utils.py +0 -0
  56. {upgini-1.2.31a2 → upgini-1.2.33}/src/upgini/utils/fallback_progress_bar.py +0 -0
  57. {upgini-1.2.31a2 → upgini-1.2.33}/src/upgini/utils/feature_info.py +0 -0
  58. {upgini-1.2.31a2 → upgini-1.2.33}/src/upgini/utils/features_validator.py +0 -0
  59. {upgini-1.2.31a2 → upgini-1.2.33}/src/upgini/utils/format.py +0 -0
  60. {upgini-1.2.31a2 → upgini-1.2.33}/src/upgini/utils/ip_utils.py +0 -0
  61. {upgini-1.2.31a2 → upgini-1.2.33}/src/upgini/utils/phone_utils.py +0 -0
  62. {upgini-1.2.31a2 → upgini-1.2.33}/src/upgini/utils/postal_code_utils.py +0 -0
  63. {upgini-1.2.31a2 → upgini-1.2.33}/src/upgini/utils/progress_bar.py +0 -0
  64. {upgini-1.2.31a2 → upgini-1.2.33}/src/upgini/utils/sklearn_ext.py +0 -0
  65. {upgini-1.2.31a2 → upgini-1.2.33}/src/upgini/utils/track_info.py +0 -0
  66. {upgini-1.2.31a2 → upgini-1.2.33}/src/upgini/utils/warning_counter.py +0 -0
  67. {upgini-1.2.31a2 → upgini-1.2.33}/src/upgini/version_validator.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: upgini
3
- Version: 1.2.31a2
3
+ Version: 1.2.33
4
4
  Summary: Intelligent data search & enrichment for Machine Learning
5
5
  Project-URL: Bug Reports, https://github.com/upgini/upgini/issues
6
6
  Project-URL: Homepage, https://upgini.com/
@@ -110,7 +110,7 @@ Description-Content-Type: text/markdown
110
110
  </tr>
111
111
  </table>
112
112
 
113
- ⭐️ [Simple Drag & Drop Search UI](https://upgini.com/upgini-widget):
113
+ ⭐️ [Simple Drag & Drop Search UI](https://www.upgini.com/data-search-widget):
114
114
  <a href="https://upgini.com/upgini-widget">
115
115
  <img width="710" alt="Drag & Drop Search UI" src="https://github.com/upgini/upgini/assets/95645411/36b6460c-51f3-400e-9f04-445b938bf45e">
116
116
  </a>
@@ -68,7 +68,7 @@
68
68
  </tr>
69
69
  </table>
70
70
 
71
- ⭐️ [Simple Drag & Drop Search UI](https://upgini.com/upgini-widget):
71
+ ⭐️ [Simple Drag & Drop Search UI](https://www.upgini.com/data-search-widget):
72
72
  <a href="https://upgini.com/upgini-widget">
73
73
  <img width="710" alt="Drag & Drop Search UI" src="https://github.com/upgini/upgini/assets/95645411/36b6460c-51f3-400e-9f04-445b938bf45e">
74
74
  </a>
@@ -0,0 +1 @@
1
+ __version__ = "1.2.33"
@@ -36,15 +36,13 @@ from upgini.metadata import (
36
36
  from upgini.resource_bundle import ResourceBundle, get_custom_bundle
37
37
  from upgini.search_task import SearchTask
38
38
  from upgini.utils.email_utils import EmailSearchKeyConverter
39
- from upgini.utils.target_utils import balance_undersample
39
+ from upgini.utils.target_utils import balance_undersample, balance_undersample_forced
40
40
 
41
41
  try:
42
42
  from upgini.utils.progress_bar import CustomProgressBar as ProgressBar
43
43
  except Exception:
44
44
  from upgini.utils.fallback_progress_bar import CustomFallbackProgressBar as ProgressBar
45
45
 
46
- from upgini.utils.warning_counter import WarningCounter
47
-
48
46
 
49
47
  class Dataset: # (pd.DataFrame):
50
48
  MIN_ROWS_COUNT = 100
@@ -64,6 +62,7 @@ class Dataset: # (pd.DataFrame):
64
62
  MAX_FEATURES_COUNT = 3500
65
63
  MAX_UPLOADING_FILE_SIZE = 268435456 # 256 Mb
66
64
  MAX_STRING_FEATURE_LENGTH = 24573
65
+ FORCE_SAMPLE_SIZE = 7_000
67
66
 
68
67
  def __init__(
69
68
  self,
@@ -78,8 +77,8 @@ class Dataset: # (pd.DataFrame):
78
77
  random_state: Optional[int] = None,
79
78
  rest_client: Optional[_RestClient] = None,
80
79
  logger: Optional[logging.Logger] = None,
81
- warning_counter: Optional[WarningCounter] = None,
82
80
  bundle: Optional[ResourceBundle] = None,
81
+ warning_callback: Optional[Callable] = None,
83
82
  **kwargs,
84
83
  ):
85
84
  self.bundle = bundle or get_custom_bundle()
@@ -122,7 +121,7 @@ class Dataset: # (pd.DataFrame):
122
121
  else:
123
122
  self.logger = logging.getLogger()
124
123
  self.logger.setLevel("FATAL")
125
- self.warning_counter = warning_counter or WarningCounter()
124
+ self.warning_callback = warning_callback
126
125
 
127
126
  def __len__(self):
128
127
  return len(self.data) if self.data is not None else None
@@ -217,9 +216,23 @@ class Dataset: # (pd.DataFrame):
217
216
  self.logger.exception("Failed to cast target to float for timeseries task type")
218
217
  raise ValidationError(self.bundle.get("dataset_invalid_timeseries_target").format(target.dtype))
219
218
 
220
- def __resample(self):
219
+ def __resample(self, force_downsampling=False):
221
220
  # self.logger.info("Resampling etalon")
222
221
  # Resample imbalanced target. Only train segment (without eval_set)
222
+ if force_downsampling:
223
+ target_column = self.etalon_def_checked.get(FileColumnMeaningType.TARGET.value, TARGET)
224
+ self.data = balance_undersample_forced(
225
+ df=self.data,
226
+ target_column=target_column,
227
+ task_type=self.task_type,
228
+ random_state=self.random_state,
229
+ sample_size=self.FORCE_SAMPLE_SIZE,
230
+ logger=self.logger,
231
+ bundle=self.bundle,
232
+ warning_callback=self.warning_callback,
233
+ )
234
+ return
235
+
223
236
  if EVAL_SET_INDEX in self.data.columns:
224
237
  train_segment = self.data[self.data[EVAL_SET_INDEX] == 0]
225
238
  else:
@@ -268,7 +281,7 @@ class Dataset: # (pd.DataFrame):
268
281
  multiclass_bootstrap_loops=self.MULTICLASS_BOOTSTRAP_LOOPS,
269
282
  logger=self.logger,
270
283
  bundle=self.bundle,
271
- warning_counter=self.warning_counter,
284
+ warning_callback=self.warning_callback,
272
285
  )
273
286
 
274
287
  # Resample over fit threshold
@@ -418,13 +431,13 @@ class Dataset: # (pd.DataFrame):
418
431
  if len(self.data) == 0:
419
432
  raise ValidationError(self.bundle.get("all_search_keys_invalid"))
420
433
 
421
- def validate(self, validate_target: bool = True, silent_mode: bool = False):
434
+ def validate(self, validate_target: bool = True, silent_mode: bool = False, force_downsampling: bool = False):
422
435
  self.__validate_dataset(validate_target, silent_mode)
423
436
 
424
437
  if validate_target:
425
438
  self.__validate_target()
426
439
 
427
- self.__resample()
440
+ self.__resample(force_downsampling)
428
441
 
429
442
  self.__validate_min_rows_count()
430
443
 
@@ -573,9 +586,10 @@ class Dataset: # (pd.DataFrame):
573
586
  max_features: Optional[int] = None, # deprecated
574
587
  filter_features: Optional[dict] = None, # deprecated
575
588
  runtime_parameters: Optional[RuntimeParameters] = None,
589
+ force_downsampling: bool = False,
576
590
  ) -> SearchTask:
577
591
  if self.etalon_def is None:
578
- self.validate()
592
+ self.validate(force_downsampling=force_downsampling)
579
593
  file_metrics = FileMetrics()
580
594
 
581
595
  runtime_parameters = self._rename_generate_features(runtime_parameters)
@@ -111,7 +111,11 @@ try:
111
111
  except Exception:
112
112
  from upgini.utils.fallback_progress_bar import CustomFallbackProgressBar as ProgressBar
113
113
 
114
- from upgini.utils.target_utils import calculate_psi, define_task
114
+ from upgini.utils.target_utils import (
115
+ balance_undersample_forced,
116
+ calculate_psi,
117
+ define_task,
118
+ )
115
119
  from upgini.utils.warning_counter import WarningCounter
116
120
  from upgini.version_validator import validate_version
117
121
 
@@ -231,6 +235,7 @@ class FeaturesEnricher(TransformerMixin):
231
235
  custom_bundle_config: Optional[str] = None,
232
236
  add_date_if_missing: bool = True,
233
237
  select_features: bool = False,
238
+ disable_force_downsampling: bool = False,
234
239
  **kwargs,
235
240
  ):
236
241
  self.bundle = get_custom_bundle(custom_bundle_config)
@@ -288,6 +293,7 @@ class FeaturesEnricher(TransformerMixin):
288
293
  self.feature_importances_ = []
289
294
  self.search_id = search_id
290
295
  self.select_features = select_features
296
+ self.disable_force_downsampling = disable_force_downsampling
291
297
 
292
298
  if search_id:
293
299
  search_task = SearchTask(search_id, rest_client=self.rest_client, logger=self.logger)
@@ -965,6 +971,13 @@ class FeaturesEnricher(TransformerMixin):
965
971
  self.__log_warning(self.bundle.get("metrics_no_important_free_features"))
966
972
  return None
967
973
 
974
+ maybe_phone_column = self._get_phone_column(self.search_keys)
975
+ text_features = (
976
+ [f for f in self.generate_features if f != maybe_phone_column]
977
+ if self.generate_features is not None
978
+ else None
979
+ )
980
+
968
981
  print(self.bundle.get("metrics_start"))
969
982
  with Spinner():
970
983
  self._check_train_and_eval_target_distribution(y_sorted, fitting_eval_set_dict)
@@ -980,7 +993,7 @@ class FeaturesEnricher(TransformerMixin):
980
993
  fitting_enriched_X,
981
994
  scoring,
982
995
  groups=groups,
983
- text_features=self.generate_features,
996
+ text_features=text_features,
984
997
  has_date=has_date,
985
998
  )
986
999
  metric = wrapper.metric_name
@@ -1007,7 +1020,7 @@ class FeaturesEnricher(TransformerMixin):
1007
1020
  cat_features,
1008
1021
  add_params=custom_loss_add_params,
1009
1022
  groups=groups,
1010
- text_features=self.generate_features,
1023
+ text_features=text_features,
1011
1024
  has_date=has_date,
1012
1025
  )
1013
1026
  etalon_cv_result = baseline_estimator.cross_val_predict(
@@ -1042,7 +1055,7 @@ class FeaturesEnricher(TransformerMixin):
1042
1055
  cat_features,
1043
1056
  add_params=custom_loss_add_params,
1044
1057
  groups=groups,
1045
- text_features=self.generate_features,
1058
+ text_features=text_features,
1046
1059
  has_date=has_date,
1047
1060
  )
1048
1061
  enriched_cv_result = enriched_estimator.cross_val_predict(fitting_enriched_X, enriched_y_sorted)
@@ -1825,7 +1838,27 @@ class FeaturesEnricher(TransformerMixin):
1825
1838
 
1826
1839
  # downsample if need to eval_set threshold
1827
1840
  num_samples = _num_samples(df)
1828
- if num_samples > Dataset.FIT_SAMPLE_WITH_EVAL_SET_THRESHOLD:
1841
+ phone_column = self._get_phone_column(self.search_keys)
1842
+ force_downsampling = (
1843
+ not self.disable_force_downsampling
1844
+ and self.generate_features is not None
1845
+ and phone_column is not None
1846
+ and self.fit_columns_renaming[phone_column] in self.generate_features
1847
+ and num_samples > Dataset.FORCE_SAMPLE_SIZE
1848
+ )
1849
+ if force_downsampling:
1850
+ self.logger.info(f"Force downsampling from {num_samples} to {Dataset.FORCE_SAMPLE_SIZE}")
1851
+ df = balance_undersample_forced(
1852
+ df=df,
1853
+ target_column=TARGET,
1854
+ task_type=self.model_task_type,
1855
+ random_state=self.random_state,
1856
+ sample_size=Dataset.FORCE_SAMPLE_SIZE,
1857
+ logger=self.logger,
1858
+ bundle=self.bundle,
1859
+ warning_callback=self.__log_warning,
1860
+ )
1861
+ elif num_samples > Dataset.FIT_SAMPLE_WITH_EVAL_SET_THRESHOLD:
1829
1862
  self.logger.info(f"Downsampling from {num_samples} to {Dataset.FIT_SAMPLE_WITH_EVAL_SET_ROWS}")
1830
1863
  df = df.sample(n=Dataset.FIT_SAMPLE_WITH_EVAL_SET_ROWS, random_state=self.random_state)
1831
1864
 
@@ -2061,6 +2094,15 @@ class FeaturesEnricher(TransformerMixin):
2061
2094
  self.__display_support_link(msg)
2062
2095
  return None, {c: c for c in X.columns}, []
2063
2096
 
2097
+ features_meta = self._search_task.get_all_features_metadata_v2()
2098
+ online_api_features = [fm.name for fm in features_meta if fm.from_online_api]
2099
+ if len(online_api_features) > 0:
2100
+ self.logger.warning(
2101
+ f"There are important features for transform, that generated by online API: {online_api_features}"
2102
+ )
2103
+ # TODO
2104
+ raise Exception("There are features selected that are paid. Contact support (sales@upgini.com)")
2105
+
2064
2106
  if not metrics_calculation:
2065
2107
  transform_usage = self.rest_client.get_current_transform_usage(trace_id)
2066
2108
  self.logger.info(f"Current transform usage: {transform_usage}. Transforming {len(X)} rows")
@@ -2251,6 +2293,8 @@ class FeaturesEnricher(TransformerMixin):
2251
2293
  date_format=self.date_format,
2252
2294
  rest_client=self.rest_client,
2253
2295
  logger=self.logger,
2296
+ bundle=self.bundle,
2297
+ warning_callback=self.__log_warning,
2254
2298
  )
2255
2299
  dataset.columns_renaming = columns_renaming
2256
2300
 
@@ -2696,6 +2740,19 @@ class FeaturesEnricher(TransformerMixin):
2696
2740
 
2697
2741
  combined_search_keys = combine_search_keys(self.fit_search_keys.keys())
2698
2742
 
2743
+ runtime_parameters = self._get_copy_of_runtime_parameters()
2744
+
2745
+ # Force downsampling to 7000 for API features generation
2746
+ force_downsampling = (
2747
+ not self.disable_force_downsampling
2748
+ and self.generate_features is not None
2749
+ and phone_column is not None
2750
+ and self.fit_columns_renaming[phone_column] in self.generate_features
2751
+ and len(df) > Dataset.FORCE_SAMPLE_SIZE
2752
+ )
2753
+ if force_downsampling:
2754
+ runtime_parameters.properties["fast_fit"] = True
2755
+
2699
2756
  dataset = Dataset(
2700
2757
  "tds_" + str(uuid.uuid4()),
2701
2758
  df=df,
@@ -2707,6 +2764,8 @@ class FeaturesEnricher(TransformerMixin):
2707
2764
  random_state=self.random_state,
2708
2765
  rest_client=self.rest_client,
2709
2766
  logger=self.logger,
2767
+ bundle=self.bundle,
2768
+ warning_callback=self.__log_warning,
2710
2769
  )
2711
2770
  dataset.columns_renaming = self.fit_columns_renaming
2712
2771
 
@@ -2720,8 +2779,9 @@ class FeaturesEnricher(TransformerMixin):
2720
2779
  start_time=start_time,
2721
2780
  progress_callback=progress_callback,
2722
2781
  extract_features=True,
2723
- runtime_parameters=self._get_copy_of_runtime_parameters(),
2782
+ runtime_parameters=runtime_parameters,
2724
2783
  exclude_features_sources=exclude_features_sources,
2784
+ force_downsampling=force_downsampling,
2725
2785
  )
2726
2786
 
2727
2787
  if search_id_callback is not None:
@@ -3521,7 +3581,7 @@ class FeaturesEnricher(TransformerMixin):
3521
3581
  return result_train, result_eval_sets
3522
3582
 
3523
3583
  def __prepare_feature_importances(
3524
- self, trace_id: str, x_columns: List[str], updated_shaps: Optional[Dict[str, float]] = None, silent=False
3584
+ self, trace_id: str, x_columns: List[str], updated_shaps: Optional[Dict[str, float]] = None, silent=False
3525
3585
  ):
3526
3586
  if self._search_task is None:
3527
3587
  raise NotFittedError(self.bundle.get("transform_unfitted_enricher"))
@@ -255,6 +255,7 @@ class FeaturesMetadataV2(BaseModel):
255
255
  data_source_links: Optional[List[str]] = None
256
256
  doc_link: Optional[str] = None
257
257
  update_frequency: Optional[str] = None
258
+ from_online_api: Optional[bool] = None
258
259
 
259
260
 
260
261
  class HitRateMetrics(BaseModel):
@@ -215,6 +215,7 @@ imbalance_multiclass=Class {0} is on 25% quantile of classes distribution ({1} r
215
215
  imbalanced_target=\nTarget is imbalanced and will be undersampled. Frequency of the rarest class `{}` is {}
216
216
  loss_selection_info=Using loss `{}` for feature selection
217
217
  loss_calc_metrics_info=Using loss `{}` for metrics calculation with default estimator
218
+ forced_balance_undersample=For quick data retrieval, your dataset has been sampled. To use data search without data sampling please contact support (sales@upgini.com)
218
219
 
219
220
  # Validation table
220
221
  validation_column_name_header=Column name
@@ -1,5 +1,5 @@
1
1
  import logging
2
- from typing import Optional, Union
2
+ from typing import Callable, Optional, Union
3
3
 
4
4
  import numpy as np
5
5
  import pandas as pd
@@ -9,7 +9,6 @@ from upgini.errors import ValidationError
9
9
  from upgini.metadata import SYSTEM_RECORD_ID, ModelTaskType
10
10
  from upgini.resource_bundle import ResourceBundle, bundle, get_custom_bundle
11
11
  from upgini.sampler.random_under_sampler import RandomUnderSampler
12
- from upgini.utils.warning_counter import WarningCounter
13
12
 
14
13
 
15
14
  def correct_string_target(y: Union[pd.Series, np.ndarray]) -> Union[pd.Series, np.ndarray]:
@@ -121,7 +120,7 @@ def balance_undersample(
121
120
  multiclass_bootstrap_loops: int = 2,
122
121
  logger: Optional[logging.Logger] = None,
123
122
  bundle: Optional[ResourceBundle] = None,
124
- warning_counter: Optional[WarningCounter] = None,
123
+ warning_callback: Optional[Callable] = None,
125
124
  ) -> pd.DataFrame:
126
125
  if logger is None:
127
126
  logger = logging.getLogger("muted_logger")
@@ -130,9 +129,7 @@ def balance_undersample(
130
129
  if SYSTEM_RECORD_ID not in df.columns:
131
130
  raise Exception("System record id must be presented for undersampling")
132
131
 
133
- # count = len(df)
134
132
  target = df[target_column].copy()
135
- # target_classes_count = target.nunique()
136
133
 
137
134
  vc = target.value_counts()
138
135
  max_class_value = vc.index[0]
@@ -141,9 +138,6 @@ def balance_undersample(
141
138
  min_class_count = vc[min_class_value]
142
139
  num_classes = len(vc)
143
140
 
144
- # min_class_percent = imbalance_threshold / target_classes_count
145
- # min_class_threshold = int(min_class_percent * count)
146
-
147
141
  resampled_data = df
148
142
  df = df.copy().sort_values(by=SYSTEM_RECORD_ID)
149
143
  if task_type == ModelTaskType.MULTICLASS:
@@ -151,12 +145,10 @@ def balance_undersample(
151
145
  min_class_count * multiclass_bootstrap_loops
152
146
  ):
153
147
 
154
- # msg = bundle.get("imbalance_multiclass").format(min_class_value, min_class_count)
155
148
  msg = bundle.get("imbalanced_target").format(min_class_value, min_class_count)
156
149
  logger.warning(msg)
157
- print(msg)
158
- if warning_counter:
159
- warning_counter.increment()
150
+ if warning_callback is not None:
151
+ warning_callback(msg)
160
152
 
161
153
  sample_strategy = dict()
162
154
  for class_value in vc.index:
@@ -180,19 +172,14 @@ def balance_undersample(
180
172
 
181
173
  resampled_data = df[df[SYSTEM_RECORD_ID].isin(new_x[SYSTEM_RECORD_ID])]
182
174
  elif len(df) > binary_min_sample_threshold:
183
- # msg = bundle.get("dataset_rarest_class_less_threshold").format(
184
- # min_class_value, min_class_count, min_class_threshold, min_class_percent * 100
185
- # )
186
175
  msg = bundle.get("imbalanced_target").format(min_class_value, min_class_count)
187
176
  logger.warning(msg)
188
- print(msg)
189
- if warning_counter:
190
- warning_counter.increment()
177
+ if warning_callback is not None:
178
+ warning_callback(msg)
191
179
 
192
180
  # fill up to min_sample_threshold by majority class
193
181
  minority_class = df[df[target_column] == min_class_value]
194
182
  majority_class = df[df[target_column] != min_class_value]
195
- # sample_size = min(len(majority_class), min_sample_threshold - min_class_count)
196
183
  sample_size = min(
197
184
  max_class_count,
198
185
  binary_bootstrap_loops * (min_class_count + max(binary_min_sample_threshold - 2 * min_class_count, 0)),
@@ -207,25 +194,73 @@ def balance_undersample(
207
194
  | (df[SYSTEM_RECORD_ID].isin(sampled_majority_class[SYSTEM_RECORD_ID]))
208
195
  ]
209
196
 
210
- # elif max_class_count > min_class_count * binary_bootstrap_loops:
211
- # msg = bundle.get("dataset_rarest_class_less_threshold").format(
212
- # min_class_value, min_class_count, min_class_threshold, min_class_percent * 100
213
- # )
214
- # logger.warning(msg)
215
- # print(msg)
216
- # if warning_counter:
217
- # warning_counter.increment()
197
+ logger.info(f"Shape after rebalance resampling: {resampled_data}")
198
+ return resampled_data
218
199
 
219
- # sampler = RandomUnderSampler(
220
- # sampling_strategy={max_class_value: binary_bootstrap_loops * min_class_count}, random_state=random_state
221
- # )
222
- # X = df[SYSTEM_RECORD_ID]
223
- # X = X.to_frame(SYSTEM_RECORD_ID)
224
- # new_x, _ = sampler.fit_resample(X, target) # type: ignore
225
200
 
226
- # resampled_data = df[df[SYSTEM_RECORD_ID].isin(new_x[SYSTEM_RECORD_ID])]
201
+ def balance_undersample_forced(
202
+ df: pd.DataFrame,
203
+ target_column: str,
204
+ task_type: ModelTaskType,
205
+ random_state: int,
206
+ sample_size: int = 7000,
207
+ logger: Optional[logging.Logger] = None,
208
+ bundle: Optional[ResourceBundle] = None,
209
+ warning_callback: Optional[Callable] = None,
210
+ ):
211
+ if len(df) <= sample_size:
212
+ return df
227
213
 
228
- logger.info(f"Shape after rebalance resampling: {resampled_data}")
214
+ if logger is None:
215
+ logger = logging.getLogger("muted_logger")
216
+ logger.setLevel("FATAL")
217
+ bundle = bundle or get_custom_bundle()
218
+ if SYSTEM_RECORD_ID not in df.columns:
219
+ raise Exception("System record id must be presented for undersampling")
220
+
221
+ msg = bundle.get("forced_balance_undersample")
222
+ logger.info(msg)
223
+ if warning_callback is not None:
224
+ warning_callback(msg)
225
+
226
+ target = df[target_column].copy()
227
+
228
+ vc = target.value_counts()
229
+ max_class_value = vc.index[0]
230
+ min_class_value = vc.index[len(vc) - 1]
231
+ max_class_count = vc[max_class_value]
232
+ min_class_count = vc[min_class_value]
233
+
234
+ resampled_data = df
235
+ df = df.copy().sort_values(by=SYSTEM_RECORD_ID)
236
+ if task_type in [ModelTaskType.MULTICLASS, ModelTaskType.REGRESSION, ModelTaskType.TIMESERIES]:
237
+ logger.warning(f"Sampling dataset from {len(df)} to {sample_size}")
238
+ resampled_data = df.sample(n=sample_size, random_state=random_state)
239
+ else:
240
+ msg = bundle.get("imbalanced_target").format(min_class_value, min_class_count)
241
+ logger.warning(msg)
242
+
243
+ # fill up to min_sample_threshold by majority class
244
+ minority_class = df[df[target_column] == min_class_value]
245
+ majority_class = df[df[target_column] != min_class_value]
246
+ logger.info(
247
+ f"Min class count: {min_class_count}. Max class count: {max_class_count}."
248
+ f" Rebalance sample size: {sample_size}"
249
+ )
250
+ if len(minority_class) > (sample_size / 2):
251
+ sampled_minority_class = minority_class.sample(n=int(sample_size / 2), random_state=random_state)
252
+ else:
253
+ sampled_minority_class = minority_class
254
+
255
+ if len(majority_class) > (sample_size) / 2:
256
+ sampled_majority_class = majority_class.sample(n=int(sample_size / 2), random_state=random_state)
257
+
258
+ resampled_data = df[
259
+ (df[SYSTEM_RECORD_ID].isin(sampled_minority_class[SYSTEM_RECORD_ID]))
260
+ | (df[SYSTEM_RECORD_ID].isin(sampled_majority_class[SYSTEM_RECORD_ID]))
261
+ ]
262
+
263
+ logger.info(f"Shape after forced rebalance resampling: {resampled_data}")
229
264
  return resampled_data
230
265
 
231
266
 
@@ -1 +0,0 @@
1
- __version__ = "1.2.31a2"
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes