upgini 1.2.91a3906.dev1__py3-none-any.whl → 1.2.92__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
upgini/metadata.py CHANGED
@@ -159,6 +159,9 @@ class ModelTaskType(Enum):
     REGRESSION = "REGRESSION"
     TIMESERIES = "TIMESERIES"
 
+    def is_classification(self) -> bool:
+        return self in [ModelTaskType.BINARY, ModelTaskType.MULTICLASS]
+
 
 class ModelLabelType(Enum):
     GINI = "gini"
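The new helper groups BINARY and MULTICLASS under a single check. A minimal usage sketch (the surrounding variable names are illustrative, not part of the package):

```python
from upgini.metadata import ModelTaskType

task_type = ModelTaskType.MULTICLASS

# True only for BINARY and MULTICLASS; REGRESSION and TIMESERIES return False
if task_type.is_classification():
    metric_name = "GINI"
else:
    metric_name = "RMSE"
```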
upgini/metrics.py CHANGED
@@ -332,7 +332,7 @@ class EstimatorWrapper:
         self.groups = groups
         self.text_features = text_features
         self.logger = logger or logging.getLogger()
-        self.droped_features = []
+        self.dropped_features = []
         self.converted_to_int = []
         self.converted_to_str = []
         self.converted_to_numeric = []
@@ -381,10 +381,11 @@ class EstimatorWrapper:
         x, y, groups = self._prepare_data(x, y, groups=self.groups)
 
         self.logger.info(f"Before preparing data columns: {x.columns.to_list()}")
-        self.droped_features = []
+        self.dropped_features = []
         self.converted_to_int = []
         self.converted_to_str = []
         self.converted_to_numeric = []
+
         for c in x.columns:
 
             if _get_unique_count(x[c]) < 2:
@@ -392,7 +393,7 @@ class EstimatorWrapper:
                 if c in self.cat_features:
                     self.cat_features.remove(c)
                 x.drop(columns=[c], inplace=True)
-                self.droped_features.append(c)
+                self.dropped_features.append(c)
             elif self.text_features is not None and c in self.text_features:
                 x[c] = x[c].astype(str)
                 self.converted_to_str.append(c)
@@ -427,16 +428,16 @@ class EstimatorWrapper:
                 except (ValueError, TypeError):
                     self.logger.warning(f"Remove feature {c} because it is not numeric and not in cat_features")
                     x.drop(columns=[c], inplace=True)
-                    self.droped_features.append(c)
+                    self.dropped_features.append(c)
 
         return x, y, groups, {}
 
     def _prepare_to_calculate(self, x: pd.DataFrame, y: pd.Series) -> Tuple[pd.DataFrame, np.ndarray, dict]:
         x, y, _ = self._prepare_data(x, y)
 
-        if self.droped_features:
-            self.logger.info(f"Drop features on calculate metrics: {self.droped_features}")
-            x = x.drop(columns=self.droped_features)
+        if self.dropped_features:
+            self.logger.info(f"Drop features on calculate metrics: {self.dropped_features}")
+            x = x.drop(columns=self.dropped_features)
 
         if self.converted_to_int:
             self.logger.info(f"Convert to int features on calculate metrics: {self.converted_to_int}")
@@ -797,7 +798,7 @@ class CatBoostWrapper(EstimatorWrapper):
                 )
                 for f in high_cardinality_features:
                     self.text_features.remove(f)
-                    self.droped_features.append(f)
+                    self.dropped_features.append(f)
                     x = x.drop(columns=f, errors="ignore")
                 return super().cross_val_predict(x, y, baseline_score_column)
             else:
@@ -814,7 +815,7 @@ class CatBoostWrapper(EstimatorWrapper):
                 else:
                     encoded = cat_encoder.transform(x[self.cat_features])
                 cat_features = encoded.columns.to_list()
-                x[self.cat_features] = encoded
+                x.loc[:, self.cat_features] = encoded
             else:
                 cat_features = self.cat_features
 
@@ -897,7 +898,7 @@ class LightGBMWrapper(EstimatorWrapper):
         for c in x.columns:
             if x[c].dtype not in ["category", "int64", "float64", "bool"]:
                 self.logger.warning(f"Feature {c} is not numeric and will be dropped")
-                self.droped_features.append(c)
+                self.dropped_features.append(c)
                 x = x.drop(columns=c, errors="ignore")
         return x, y_numpy, groups, params
 
@@ -988,7 +989,7 @@ class OtherEstimatorWrapper(EstimatorWrapper):
         for c in x.columns:
             if x[c].dtype not in ["category", "int64", "float64", "bool"]:
                 self.logger.warning(f"Feature {c} is not numeric and will be dropped")
-                self.droped_features.append(c)
+                self.dropped_features.append(c)
                 x = x.drop(columns=c, errors="ignore")
         return x, y_numpy, groups, params
 
@@ -144,6 +144,7 @@ baseline_score_column_has_na=baseline_score_column contains NaN. Clear it and an
 missing_features_for_transform=Missing some features for transform that were presented on fit: {}
 missing_target_for_transform=Search contains features on target. Please add y to the call and try again
 missing_id_column=Id column {} not found in X: {}
+unknown_id_column_value_in_eval_set=Unknown values in id columns: {}
 # target validation
 empty_target=Target is empty in all rows
 # non_numeric_target=Binary target should be numerical type
@@ -195,6 +196,7 @@ timeseries_invalid_test_size_type=test_size={} should be a float in the (0, 1) r
 timeseries_splits_more_than_samples=Number of splits={} can't be more than number of samples={}
 timeseries_invalid_test_size=Wrong number of samples in a test fold: (test_size * n_samples / n_splits) <= 1
 date_and_id_columns_duplicates=Found {} duplicate rows by date and id_columns. Please remove them and try again
+missing_ids_in_eval_set=Following ids are present in eval set but not in sampled train set: {}. They will be removed from eval set.
 # Upload ads validation
 ads_upload_too_few_rows=At least 1000 records per sample are needed. Increase the sample size for evaluation and resubmit the data
 ads_upload_search_key_not_found=Search key {} wasn't found in dataframe columns
@@ -155,7 +155,7 @@ def _get_internal_source(feature_meta: FeaturesMetadataV2, is_client_feature: bo
         and not feature_meta.name.endswith("_postal_code")
         and not is_client_feature
         else ""
-        )
+    )
 
 
 def _list_or_single(lst: List[str], single: str):
@@ -179,7 +179,7 @@ def _make_links(names: List[str], links: List[str]) -> str:
 
 
 def _round_shap_value(shap: float) -> float:
-    if shap > 0.0 and shap < 0.0001:
+    if shap >= 0.0 and shap < 0.0001:
         return 0.0001
     else:
         return round(shap, 4)
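With the comparison widened to `shap >= 0.0`, values in the half-open interval [0.0, 0.0001), including an exact zero, are floored to the display value 0.0001. A small illustrative check (calling the private helper directly is done here only for demonstration):

```python
assert _round_shap_value(0.0) == 0.0001      # exact zero now falls inside the floor interval
assert _round_shap_value(0.00005) == 0.0001  # small positive values are floored as before
assert _round_shap_value(0.12341) == 0.1234  # values above the floor are rounded to 4 decimals
assert _round_shap_value(-0.5) == -0.5       # negative values pass through round()
```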
@@ -0,0 +1,416 @@
+from dataclasses import dataclass, field
+import logging
+import numbers
+from typing import Callable, List, Optional
+import numpy as np
+import pandas as pd
+
+from upgini.metadata import SYSTEM_RECORD_ID, CVType, ModelTaskType
+from upgini.resource_bundle import ResourceBundle, get_custom_bundle
+from upgini.utils.target_utils import balance_undersample
+from upgini.utils.ts_utils import get_most_frequent_time_unit, trunc_datetime
+
+
+TS_MIN_DIFFERENT_IDS_RATIO = 0.2
+TS_DEFAULT_HIGH_FREQ_TRUNC_LENGTHS = [pd.DateOffset(years=2, months=6), pd.DateOffset(years=2, days=7)]
+TS_DEFAULT_LOW_FREQ_TRUNC_LENGTHS = [pd.DateOffset(years=7), pd.DateOffset(years=5)]
+TS_DEFAULT_TIME_UNIT_THRESHOLD = pd.Timedelta(weeks=4)
+FIT_SAMPLE_ROWS_TS = 100_000
+
+BINARY_MIN_SAMPLE_THRESHOLD = 5_000
+MULTICLASS_MIN_SAMPLE_THRESHOLD = 25_000
+BINARY_BOOTSTRAP_LOOPS = 5
+MULTICLASS_BOOTSTRAP_LOOPS = 2
+
+FIT_SAMPLE_THRESHOLD = 200_000
+FIT_SAMPLE_ROWS = 200_000
+FIT_SAMPLE_ROWS_WITH_EVAL_SET = 200_000
+FIT_SAMPLE_THRESHOLD_WITH_EVAL_SET = 200_000
+
+
+@dataclass
+class SampleConfig:
+    force_sample_size: int = 7000
+    ts_min_different_ids_ratio: float = TS_MIN_DIFFERENT_IDS_RATIO
+    ts_default_high_freq_trunc_lengths: List[pd.DateOffset] = field(
+        default_factory=TS_DEFAULT_HIGH_FREQ_TRUNC_LENGTHS.copy
+    )
+    ts_default_low_freq_trunc_lengths: List[pd.DateOffset] = field(
+        default_factory=TS_DEFAULT_LOW_FREQ_TRUNC_LENGTHS.copy
+    )
+    ts_default_time_unit_threshold: pd.Timedelta = TS_DEFAULT_TIME_UNIT_THRESHOLD
+    binary_min_sample_threshold: int = BINARY_MIN_SAMPLE_THRESHOLD
+    multiclass_min_sample_threshold: int = MULTICLASS_MIN_SAMPLE_THRESHOLD
+    binary_bootstrap_loops: int = BINARY_BOOTSTRAP_LOOPS
+    multiclass_bootstrap_loops: int = MULTICLASS_BOOTSTRAP_LOOPS
+    fit_sample_threshold: int = FIT_SAMPLE_THRESHOLD
+    fit_sample_rows: int = FIT_SAMPLE_ROWS
+    fit_sample_rows_with_eval_set: int = FIT_SAMPLE_ROWS_WITH_EVAL_SET
+    fit_sample_threshold_with_eval_set: int = FIT_SAMPLE_THRESHOLD_WITH_EVAL_SET
+    fit_sample_rows_ts: int = FIT_SAMPLE_ROWS_TS
+
+
+@dataclass
+class SampleColumns:
+    date: str
+    target: str
+    ids: Optional[List[str]] = None
+    eval_set_index: Optional[str] = None
+
+
+def sample(
+    df: pd.DataFrame,
+    task_type: Optional[ModelTaskType],
+    cv_type: Optional[CVType],
+    sample_config: SampleConfig,
+    sample_columns: SampleColumns,
+    random_state: int = 42,
+    balance: bool = True,
+    force_downsampling: bool = False,
+    logger: Optional[logging.Logger] = None,
+    **kwargs,
+) -> pd.DataFrame:
+    if force_downsampling:
+        return balance_undersample_forced(
+            df,
+            sample_columns.target,
+            sample_columns.ids,
+            sample_columns.date,
+            task_type,
+            cv_type,
+            random_state,
+            sample_config.force_sample_size,
+            logger=logger,
+            **kwargs,
+        )
+
+    if sample_columns.eval_set_index in df.columns:
+        fit_sample_threshold = sample_config.fit_sample_threshold_with_eval_set
+        fit_sample_rows = sample_config.fit_sample_rows_with_eval_set
+    else:
+        fit_sample_threshold = sample_config.fit_sample_threshold
+        fit_sample_rows = sample_config.fit_sample_rows
+
+    if cv_type is not None and cv_type.is_time_series():
+        return sample_time_series_train_eval(
+            df,
+            sample_columns,
+            sample_config.fit_sample_rows_ts,
+            trim_threshold=fit_sample_threshold,
+            max_rows=fit_sample_rows,
+            random_state=random_state,
+            logger=logger,
+            **kwargs,
+        )
+
+    if task_type is not None and task_type.is_classification() and balance:
+        df = balance_undersample(
+            df=df,
+            target_column=sample_columns.target,
+            task_type=task_type,
+            random_state=random_state,
+            binary_min_sample_threshold=sample_config.binary_min_sample_threshold,
+            multiclass_min_sample_threshold=sample_config.multiclass_min_sample_threshold,
+            binary_bootstrap_loops=sample_config.binary_bootstrap_loops,
+            multiclass_bootstrap_loops=sample_config.multiclass_bootstrap_loops,
+            logger=logger,
+            **kwargs,
+        )
+
+    num_samples = _num_samples(df)
+    if num_samples > fit_sample_threshold:
+        logger.info(
+            f"Etalon has size {num_samples} more than threshold {fit_sample_threshold} "
+            f"and will be downsampled to {fit_sample_rows}"
+        )
+        df = df.sample(n=fit_sample_rows, random_state=random_state)
+        logger.info(f"Shape after threshold resampling: {df.shape}")
+
+    return df
+
+
+def sample_time_series_train_eval(
+    df: pd.DataFrame,
+    sample_columns: SampleColumns,
+    sample_size: int,
+    trim_threshold: int,
+    max_rows: int,
+    random_state: int = 42,
+    logger: Optional[logging.Logger] = None,
+    bundle: Optional[ResourceBundle] = None,
+    **kwargs,
+):
+    if sample_columns.eval_set_index in df.columns:
+        train_df = df[df[sample_columns.eval_set_index] == 0]
+        eval_df = df[df[sample_columns.eval_set_index] > 0]
+    else:
+        train_df = df
+        eval_df = None
+
+    train_df = sample_time_series_trunc(
+        train_df, sample_columns.ids, sample_columns.date, sample_size, random_state, logger=logger, **kwargs
+    )
+    if sample_columns.ids and eval_df is not None:
+        missing_ids = (
+            eval_df[~eval_df[sample_columns.ids].isin(np.unique(train_df[sample_columns.ids]))][sample_columns.ids]
+            .dropna()
+            .drop_duplicates()
+            .values.tolist()
+        )
+        if missing_ids:
+            bundle = bundle or get_custom_bundle()
+            print(bundle.get("missing_ids_in_eval_set").format(missing_ids))
+            eval_df = eval_df.merge(train_df[sample_columns.ids].drop_duplicates())
+
+    if eval_df is not None:
+        if len(eval_df) > trim_threshold - len(train_df):
+            eval_df = sample_time_series_trunc(
+                eval_df,
+                sample_columns.ids,
+                sample_columns.date,
+                max_rows - len(train_df),
+                random_state,
+                logger=logger,
+                **kwargs,
+            )
+        if logger is not None:
+            logger.info(f"Eval set size: {len(eval_df)}")
+        df = pd.concat([train_df, eval_df])
+
+    elif len(train_df) > max_rows:
+        df = sample_time_series_trunc(
+            train_df,
+            sample_columns.ids,
+            sample_columns.date,
+            max_rows,
+            random_state,
+            logger=logger,
+            **kwargs,
+        )
+    else:
+        df = train_df
+
+    if logger is not None:
+        logger.info(f"Train set size: {len(df)}")
+
+    return df
+
+
+def sample_time_series_trunc(
+    df: pd.DataFrame,
+    id_columns: Optional[List[str]],
+    date_column: str,
+    sample_size: int,
+    random_state: int = 42,
+    logger: Optional[logging.Logger] = None,
+    highfreq_trunc_lengths: List[pd.DateOffset] = TS_DEFAULT_HIGH_FREQ_TRUNC_LENGTHS,
+    lowfreq_trunc_lengths: List[pd.DateOffset] = TS_DEFAULT_LOW_FREQ_TRUNC_LENGTHS,
+    time_unit_threshold: pd.Timedelta = TS_DEFAULT_TIME_UNIT_THRESHOLD,
+    **kwargs,
+):
+    if id_columns is None:
+        id_columns = []
+    # Convert date column to datetime
+    dates_df = df[id_columns + [date_column]].copy().reset_index(drop=True)
+    if pd.api.types.is_numeric_dtype(dates_df[date_column]):
+        dates_df[date_column] = pd.to_datetime(dates_df[date_column], unit="ms")
+    else:
+        dates_df[date_column] = pd.to_datetime(dates_df[date_column])
+
+    time_unit = get_most_frequent_time_unit(dates_df, id_columns, date_column)
+    if logger is not None:
+        logger.info(f"Time unit: {time_unit}")
+
+    if time_unit is None:
+        if logger is not None:
+            logger.info("Cannot detect time unit, returning original dataset")
+        return df
+
+    if time_unit < time_unit_threshold:
+        for trunc_length in highfreq_trunc_lengths:
+            sampled_df = trunc_datetime(dates_df, id_columns, date_column, trunc_length, logger=logger)
+            if len(sampled_df) <= sample_size:
+                break
+        if len(sampled_df) > sample_size:
+            sampled_df = sample_time_series(
+                sampled_df, id_columns, date_column, sample_size, random_state, logger=logger, **kwargs
+            )
+    else:
+        for trunc_length in lowfreq_trunc_lengths:
+            sampled_df = trunc_datetime(dates_df, id_columns, date_column, trunc_length, logger=logger)
+            if len(sampled_df) <= sample_size:
+                break
+        if len(sampled_df) > sample_size:
+            sampled_df = sample_time_series(
+                sampled_df, id_columns, date_column, sample_size, random_state, logger=logger, **kwargs
+            )
+
+    return df.iloc[sampled_df.index]
+
+
+def sample_time_series(
+    df: pd.DataFrame,
+    id_columns: List[str],
+    date_column: str,
+    sample_size: int,
+    random_state: int = 42,
+    min_different_ids_ratio: float = TS_MIN_DIFFERENT_IDS_RATIO,
+    prefer_recent_dates: bool = True,
+    logger: Optional[logging.Logger] = None,
+    **kwargs,
+):
+    def ensure_tuple(x):
+        return tuple([x]) if not isinstance(x, tuple) else x
+
+    random_state = np.random.RandomState(random_state)
+
+    if not id_columns:
+        id_columns = [date_column]
+    ids_sort = df.groupby(id_columns)[date_column].aggregate(["max", "count"]).T.to_dict()
+    ids_sort = {
+        ensure_tuple(k): (
+            (v["max"], v["count"], random_state.rand()) if prefer_recent_dates else (v["count"], random_state.rand())
+        )
+        for k, v in ids_sort.items()
+    }
+    id_counts = df[id_columns].value_counts()
+    id_counts.index = [ensure_tuple(i) for i in id_counts.index]
+    id_counts = id_counts.sort_index(key=lambda x: [ids_sort[y] for y in x], ascending=False).cumsum()
+    id_counts = id_counts[id_counts <= sample_size]
+    min_different_ids = max(int(len(df[id_columns].drop_duplicates()) * min_different_ids_ratio), 1)
+
+    def id_mask(sample_index: pd.Index) -> pd.Index:
+        if isinstance(sample_index, pd.MultiIndex):
+            return pd.MultiIndex.from_frame(df[id_columns]).isin(sample_index)
+        else:
+            return df[id_columns[0]].isin(sample_index)
+
+    if len(id_counts) < min_different_ids:
+        if logger is not None:
+            logger.info(
+                f"Different ids count {len(id_counts)} for sample size {sample_size}"
+                f" is less than min different ids {min_different_ids}, sampling time window"
+            )
+        date_counts = df.groupby(id_columns)[date_column].nunique().sort_values(ascending=False)
+        ids_to_sample = date_counts.index[:min_different_ids] if len(id_counts) > 0 else date_counts.index
+        mask = id_mask(ids_to_sample)
+        df = df[mask]
+        sample_date_counts = df[date_column].value_counts().sort_index(ascending=False).cumsum()
+        sample_date_counts = sample_date_counts[sample_date_counts <= sample_size]
+        df = df[df[date_column].isin(sample_date_counts.index)]
+    else:
+        if len(id_columns) > 1:
+            id_counts.index = pd.MultiIndex.from_tuples(id_counts.index)
+        else:
+            id_counts.index = [i[0] for i in id_counts.index]
+        mask = id_mask(id_counts.index)
+        df = df[mask]
+
+    return df
+
+
+def balance_undersample_forced(
+    df: pd.DataFrame,
+    sample_columns: SampleColumns,
+    task_type: ModelTaskType,
+    cv_type: Optional[CVType],
+    random_state: int,
+    sample_size: int = 7000,
+    logger: Optional[logging.Logger] = None,
+    bundle: Optional[ResourceBundle] = None,
+    warning_callback: Optional[Callable] = None,
+):
+    if len(df) <= sample_size:
+        return df
+
+    if logger is None:
+        logger = logging.getLogger("muted_logger")
+        logger.setLevel("FATAL")
+    bundle = bundle or get_custom_bundle()
+    if SYSTEM_RECORD_ID not in df.columns:
+        raise Exception("System record id must be presented for undersampling")
+
+    msg = bundle.get("forced_balance_undersample")
+    logger.info(msg)
+    if warning_callback is not None:
+        warning_callback(msg)
+
+    target = df[sample_columns.target].copy()
+
+    vc = target.value_counts()
+    max_class_value = vc.index[0]
+    min_class_value = vc.index[len(vc) - 1]
+    max_class_count = vc[max_class_value]
+    min_class_count = vc[min_class_value]
+
+    resampled_data = df
+    df = df.copy().sort_values(by=SYSTEM_RECORD_ID)
+    if cv_type is not None and cv_type.is_time_series():
+        logger.warning(f"Sampling time series dataset from {len(df)} to {sample_size}")
+        resampled_data = sample_time_series_train_eval(
+            df,
+            sample_columns=sample_columns,
+            sample_size=sample_size,
+            trim_threshold=sample_size,
+            max_rows=sample_size,
+            random_state=random_state,
+            logger=logger,
+        )
+    elif task_type in [ModelTaskType.MULTICLASS, ModelTaskType.REGRESSION]:
+        logger.warning(f"Sampling dataset from {len(df)} to {sample_size}")
+        resampled_data = df.sample(n=sample_size, random_state=random_state)
+    else:
+        msg = bundle.get("imbalanced_target").format(min_class_value, min_class_count)
+        logger.warning(msg)
+
+        # fill up to min_sample_threshold by majority class
+        minority_class = df[df[sample_columns.target] == min_class_value]
+        majority_class = df[df[sample_columns.target] != min_class_value]
+        logger.info(
+            f"Min class count: {min_class_count}. Max class count: {max_class_count}."
+            f" Rebalance sample size: {sample_size}"
+        )
+        if len(minority_class) > (sample_size / 2):
+            sampled_minority_class = minority_class.sample(n=int(sample_size / 2), random_state=random_state)
+        else:
+            sampled_minority_class = minority_class
+
+        if len(majority_class) > (sample_size) / 2:
+            sampled_majority_class = majority_class.sample(n=int(sample_size / 2), random_state=random_state)
+
+        resampled_data = df[
+            (df[SYSTEM_RECORD_ID].isin(sampled_minority_class[SYSTEM_RECORD_ID]))
+            | (df[SYSTEM_RECORD_ID].isin(sampled_majority_class[SYSTEM_RECORD_ID]))
+        ]
+
+    logger.info(f"Shape after forced rebalance resampling: {resampled_data}")
+    return resampled_data
+
+
+def _num_samples(x):
+    """Return number of samples in array-like x."""
+    if x is None:
+        return 0
+    message = "Expected sequence or array-like, got %s" % type(x)
+    if hasattr(x, "fit") and callable(x.fit):
+        # Don't get num_samples from an ensembles length!
+        raise TypeError(message)
+
+    if not hasattr(x, "__len__") and not hasattr(x, "shape"):
+        if hasattr(x, "__array__"):
+            x = np.asarray(x)
+        else:
+            raise TypeError(message)
+
+    if hasattr(x, "shape") and x.shape is not None:
+        if len(x.shape) == 0:
+            raise TypeError("Singleton array %r cannot be considered a valid collection." % x)
+        # Check that shape is returning an integer or default to len
+        # Dask dataframes may not return numeric shape[0] value
+        if isinstance(x.shape[0], numbers.Integral):
+            return x.shape[0]
+
+    try:
+        return len(x)
+    except TypeError as type_error:
+        raise TypeError(message) from type_error
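
The new sampling module is added without its path being visible in this diff. Assuming it is importable as `upgini.utils.sample_utils` (an assumption, not confirmed by the diff), a minimal sketch of wiring `sample()` together with `SampleConfig` and `SampleColumns` could look like this; the DataFrame, column names, and thresholds are illustrative only:

```python
import logging

import pandas as pd

# Assumed import path; the diff does not show where the new module lives.
from upgini.utils.sample_utils import SampleColumns, SampleConfig, sample
from upgini.metadata import ModelTaskType

df = pd.DataFrame(
    {
        "date": pd.date_range("2024-01-01", periods=10, freq="D"),
        "store_id": ["a"] * 5 + ["b"] * 5,
        "target": [0, 1] * 5,
    }
)

sampled = sample(
    df,
    task_type=ModelTaskType.BINARY,
    cv_type=None,  # non time-series, so the time-series branch is skipped
    sample_config=SampleConfig(fit_sample_threshold=5, fit_sample_rows=5),
    sample_columns=SampleColumns(date="date", target="target", ids=["store_id"]),
    balance=False,  # skip classification undersampling in this toy example
    logger=logging.getLogger("example"),  # sample() logs when it downsamples
)
print(sampled.shape)  # (5, 3): 10 rows exceed the threshold of 5 and are downsampled
```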