upgini 1.2.90__py3-none-any.whl → 1.2.91a3884.dev1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
upgini/__about__.py CHANGED
@@ -1 +1 @@
1
- __version__ = "1.2.90"
1
+ __version__ = "1.2.91a3884.dev1"
upgini/dataset.py CHANGED
@@ -38,11 +38,7 @@ from upgini.metadata import (
38
38
  from upgini.resource_bundle import ResourceBundle, get_custom_bundle
39
39
  from upgini.search_task import SearchTask
40
40
  from upgini.utils.email_utils import EmailSearchKeyConverter
41
- from upgini.utils.target_utils import (
42
- balance_undersample,
43
- balance_undersample_forced,
44
- balance_undersample_time_series_trunc,
45
- )
41
+ from upgini.utils.sample_utils import SampleColumns, SampleConfig, sample
46
42
 
47
43
  try:
48
44
  from upgini.utils.progress_bar import CustomProgressBar as ProgressBar
@@ -88,6 +84,7 @@ class Dataset: # (pd.DataFrame):
88
84
  date_column: Optional[str] = None,
89
85
  id_columns: Optional[List[str]] = None,
90
86
  random_state: Optional[int] = None,
87
+ sample_config: Optional[SampleConfig] = None,
91
88
  rest_client: Optional[_RestClient] = None,
92
89
  logger: Optional[logging.Logger] = None,
93
90
  bundle: Optional[ResourceBundle] = None,
@@ -95,6 +92,7 @@ class Dataset: # (pd.DataFrame):
95
92
  **kwargs,
96
93
  ):
97
94
  self.bundle = bundle or get_custom_bundle()
95
+ self.sample_config = sample_config or SampleConfig(force_sample_size=self.FORCE_SAMPLE_SIZE)
98
96
  if df is not None:
99
97
  data = df.copy()
100
98
  elif path is not None:
@@ -233,109 +231,70 @@ class Dataset: # (pd.DataFrame):
233
231
  raise ValidationError(self.bundle.get("dataset_invalid_timeseries_target").format(target.dtype))
234
232
 
235
233
  def __resample(self, force_downsampling=False):
236
- # self.logger.info("Resampling etalon")
237
- # Resample imbalanced target. Only train segment (without eval_set)
238
- if force_downsampling:
239
- target_column = self.etalon_def_checked.get(FileColumnMeaningType.TARGET.value, TARGET)
240
- self.data = balance_undersample_forced(
241
- df=self.data,
242
- target_column=target_column,
243
- task_type=self.task_type,
244
- cv_type=self.cv_type,
245
- date_column=self.date_column,
246
- id_columns=self.id_columns,
247
- random_state=self.random_state,
248
- sample_size=self.FORCE_SAMPLE_SIZE,
249
- logger=self.logger,
250
- bundle=self.bundle,
251
- warning_callback=self.warning_callback,
252
- )
253
- return
254
234
 
255
- if EVAL_SET_INDEX in self.data.columns:
235
+ if EVAL_SET_INDEX in self.data.columns and not force_downsampling:
256
236
  train_segment = self.data[self.data[EVAL_SET_INDEX] == 0]
257
237
  else:
258
238
  train_segment = self.data
259
239
 
260
- if self.task_type == ModelTaskType.MULTICLASS or (
261
- self.task_type == ModelTaskType.BINARY and len(train_segment) > self.BINARY_MIN_SAMPLE_THRESHOLD
262
- ):
263
- count = len(train_segment)
264
- target_column = self.etalon_def_checked.get(FileColumnMeaningType.TARGET.value, TARGET)
265
- target = train_segment[target_column]
266
- target_classes_count = target.nunique()
240
+ self.imbalanced = self.__is_imbalanced(train_segment)
267
241
 
268
- if target_classes_count > self.MAX_MULTICLASS_CLASS_COUNT:
269
- msg = self.bundle.get("dataset_to_many_multiclass_targets").format(
270
- target_classes_count, self.MAX_MULTICLASS_CLASS_COUNT
271
- )
272
- self.logger.warning(msg)
273
- raise ValidationError(msg)
242
+ sample_columns = SampleColumns(
243
+ ids=self.id_columns,
244
+ date=self.date_column,
245
+ target=self.etalon_def_checked.get(FileColumnMeaningType.TARGET.value, TARGET),
246
+ eval_set_index=EVAL_SET_INDEX,
247
+ )
274
248
 
275
- vc = target.value_counts()
276
- min_class_value = vc.index[len(vc) - 1]
277
- min_class_count = vc[min_class_value]
249
+ self.data = sample(
250
+ train_segment if self.imbalanced else self.data, # for imbalanced data we will be doing transform anyway
251
+ self.task_type,
252
+ self.cv_type,
253
+ self.sample_config,
254
+ sample_columns,
255
+ self.random_state,
256
+ balance=self.imbalanced,
257
+ force_downsampling=force_downsampling,
258
+ logger=self.logger,
259
+ bundle=self.bundle,
260
+ warning_callback=self.warning_callback,
261
+ )
278
262
 
279
- if min_class_count < self.MIN_TARGET_CLASS_ROWS:
280
- msg = self.bundle.get("dataset_rarest_class_less_min").format(
281
- min_class_value, min_class_count, self.MIN_TARGET_CLASS_ROWS
282
- )
283
- self.logger.warning(msg)
284
- raise ValidationError(msg)
263
+ def __is_imbalanced(self, data: pd.DataFrame) -> bool:
264
+ if self.task_type is None or not self.task_type.is_classification():
265
+ return False
285
266
 
286
- min_class_percent = self.IMBALANCE_THESHOLD / target_classes_count
287
- min_class_threshold = min_class_percent * count
288
-
289
- # If min class count less than 30% for binary or (60 / classes_count)% for multiclass
290
- if min_class_count < min_class_threshold:
291
- self.imbalanced = True
292
- self.data = balance_undersample(
293
- df=train_segment,
294
- target_column=target_column,
295
- task_type=self.task_type,
296
- random_state=self.random_state,
297
- binary_min_sample_threshold=self.BINARY_MIN_SAMPLE_THRESHOLD,
298
- multiclass_min_sample_threshold=self.MULTICLASS_MIN_SAMPLE_THRESHOLD,
299
- binary_bootstrap_loops=self.BINARY_BOOTSTRAP_LOOPS,
300
- multiclass_bootstrap_loops=self.MULTICLASS_BOOTSTRAP_LOOPS,
301
- logger=self.logger,
302
- bundle=self.bundle,
303
- warning_callback=self.warning_callback,
304
- )
267
+ if self.task_type == ModelTaskType.BINARY and len(data) <= self.sample_config.binary_min_sample_threshold:
268
+ return False
305
269
 
306
- # Resample over fit threshold
307
- if self.cv_type is not None and self.cv_type.is_time_series():
308
- sample_threshold = self.FIT_SAMPLE_THRESHOLD_TS
309
- sample_rows = self.FIT_SAMPLE_ROWS_TS
310
- elif not self.imbalanced and EVAL_SET_INDEX in self.data.columns:
311
- sample_threshold = self.FIT_SAMPLE_WITH_EVAL_SET_THRESHOLD
312
- sample_rows = self.FIT_SAMPLE_WITH_EVAL_SET_ROWS
313
- else:
314
- sample_threshold = self.FIT_SAMPLE_THRESHOLD
315
- sample_rows = self.FIT_SAMPLE_ROWS
270
+ count = len(data)
271
+ target_column = self.etalon_def_checked.get(FileColumnMeaningType.TARGET.value, TARGET)
272
+ target = data[target_column]
273
+ target_classes_count = target.nunique()
316
274
 
317
- if len(self.data) > sample_threshold:
318
- self.logger.info(
319
- f"Etalon has size {len(self.data)} more than threshold {sample_threshold} "
320
- f"and will be downsampled to {sample_rows}"
275
+ if target_classes_count > self.MAX_MULTICLASS_CLASS_COUNT:
276
+ msg = self.bundle.get("dataset_to_many_multiclass_targets").format(
277
+ target_classes_count, self.MAX_MULTICLASS_CLASS_COUNT
321
278
  )
322
- if self.cv_type is not None and self.cv_type.is_time_series():
323
- resampled_data = balance_undersample_time_series_trunc(
324
- df=self.data,
325
- id_columns=self.id_columns,
326
- date_column=next(
327
- k
328
- for k, v in self.meaning_types.items()
329
- if v in [FileColumnMeaningType.DATE, FileColumnMeaningType.DATETIME]
330
- ),
331
- sample_size=sample_rows,
332
- random_state=self.random_state,
333
- logger=self.logger,
334
- )
335
- else:
336
- resampled_data = self.data.sample(n=sample_rows, random_state=self.random_state)
337
- self.data = resampled_data
338
- self.logger.info(f"Shape after threshold resampling: {self.data.shape}")
279
+ self.logger.warning(msg)
280
+ raise ValidationError(msg)
281
+
282
+ vc = target.value_counts()
283
+ min_class_value = vc.index[len(vc) - 1]
284
+ min_class_count = vc[min_class_value]
285
+
286
+ if min_class_count < self.MIN_TARGET_CLASS_ROWS:
287
+ msg = self.bundle.get("dataset_rarest_class_less_min").format(
288
+ min_class_value, min_class_count, self.MIN_TARGET_CLASS_ROWS
289
+ )
290
+ self.logger.warning(msg)
291
+ raise ValidationError(msg)
292
+
293
+ min_class_percent = self.IMBALANCE_THESHOLD / target_classes_count
294
+ min_class_threshold = min_class_percent * count
295
+
296
+ # If min class count less than 30% for binary or (60 / classes_count)% for multiclass
297
+ return bool(min_class_count < min_class_threshold)
339
298
 
340
299
  def __validate_dataset(self, validate_target: bool, silent_mode: bool):
341
300
  """Validate DataSet"""