upgini 1.2.31a2__tar.gz → 1.2.32__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {upgini-1.2.31a2 → upgini-1.2.32}/PKG-INFO +1 -1
- upgini-1.2.32/src/upgini/__about__.py +1 -0
- {upgini-1.2.31a2 → upgini-1.2.32}/src/upgini/dataset.py +24 -10
- {upgini-1.2.31a2 → upgini-1.2.32}/src/upgini/features_enricher.py +21 -2
- {upgini-1.2.31a2 → upgini-1.2.32}/src/upgini/resource_bundle/strings.properties +1 -0
- {upgini-1.2.31a2 → upgini-1.2.32}/src/upgini/utils/target_utils.py +70 -35
- upgini-1.2.31a2/src/upgini/__about__.py +0 -1
- {upgini-1.2.31a2 → upgini-1.2.32}/.gitignore +0 -0
- {upgini-1.2.31a2 → upgini-1.2.32}/LICENSE +0 -0
- {upgini-1.2.31a2 → upgini-1.2.32}/README.md +0 -0
- {upgini-1.2.31a2 → upgini-1.2.32}/pyproject.toml +0 -0
- {upgini-1.2.31a2 → upgini-1.2.32}/src/upgini/__init__.py +0 -0
- {upgini-1.2.31a2 → upgini-1.2.32}/src/upgini/ads.py +0 -0
- {upgini-1.2.31a2 → upgini-1.2.32}/src/upgini/ads_management/__init__.py +0 -0
- {upgini-1.2.31a2 → upgini-1.2.32}/src/upgini/ads_management/ads_manager.py +0 -0
- {upgini-1.2.31a2 → upgini-1.2.32}/src/upgini/autofe/__init__.py +0 -0
- {upgini-1.2.31a2 → upgini-1.2.32}/src/upgini/autofe/all_operands.py +0 -0
- {upgini-1.2.31a2 → upgini-1.2.32}/src/upgini/autofe/binary.py +0 -0
- {upgini-1.2.31a2 → upgini-1.2.32}/src/upgini/autofe/date.py +0 -0
- {upgini-1.2.31a2 → upgini-1.2.32}/src/upgini/autofe/feature.py +0 -0
- {upgini-1.2.31a2 → upgini-1.2.32}/src/upgini/autofe/groupby.py +0 -0
- {upgini-1.2.31a2 → upgini-1.2.32}/src/upgini/autofe/operand.py +0 -0
- {upgini-1.2.31a2 → upgini-1.2.32}/src/upgini/autofe/unary.py +0 -0
- {upgini-1.2.31a2 → upgini-1.2.32}/src/upgini/autofe/vector.py +0 -0
- {upgini-1.2.31a2 → upgini-1.2.32}/src/upgini/data_source/__init__.py +0 -0
- {upgini-1.2.31a2 → upgini-1.2.32}/src/upgini/data_source/data_source_publisher.py +0 -0
- {upgini-1.2.31a2 → upgini-1.2.32}/src/upgini/errors.py +0 -0
- {upgini-1.2.31a2 → upgini-1.2.32}/src/upgini/http.py +0 -0
- {upgini-1.2.31a2 → upgini-1.2.32}/src/upgini/lazy_import.py +0 -0
- {upgini-1.2.31a2 → upgini-1.2.32}/src/upgini/mdc/__init__.py +0 -0
- {upgini-1.2.31a2 → upgini-1.2.32}/src/upgini/mdc/context.py +0 -0
- {upgini-1.2.31a2 → upgini-1.2.32}/src/upgini/metadata.py +0 -0
- {upgini-1.2.31a2 → upgini-1.2.32}/src/upgini/metrics.py +0 -0
- {upgini-1.2.31a2 → upgini-1.2.32}/src/upgini/normalizer/__init__.py +0 -0
- {upgini-1.2.31a2 → upgini-1.2.32}/src/upgini/normalizer/normalize_utils.py +0 -0
- {upgini-1.2.31a2 → upgini-1.2.32}/src/upgini/resource_bundle/__init__.py +0 -0
- {upgini-1.2.31a2 → upgini-1.2.32}/src/upgini/resource_bundle/exceptions.py +0 -0
- {upgini-1.2.31a2 → upgini-1.2.32}/src/upgini/resource_bundle/strings_widget.properties +0 -0
- {upgini-1.2.31a2 → upgini-1.2.32}/src/upgini/sampler/__init__.py +0 -0
- {upgini-1.2.31a2 → upgini-1.2.32}/src/upgini/sampler/base.py +0 -0
- {upgini-1.2.31a2 → upgini-1.2.32}/src/upgini/sampler/random_under_sampler.py +0 -0
- {upgini-1.2.31a2 → upgini-1.2.32}/src/upgini/sampler/utils.py +0 -0
- {upgini-1.2.31a2 → upgini-1.2.32}/src/upgini/search_task.py +0 -0
- {upgini-1.2.31a2 → upgini-1.2.32}/src/upgini/spinner.py +0 -0
- {upgini-1.2.31a2 → upgini-1.2.32}/src/upgini/utils/Roboto-Regular.ttf +0 -0
- {upgini-1.2.31a2 → upgini-1.2.32}/src/upgini/utils/__init__.py +0 -0
- {upgini-1.2.31a2 → upgini-1.2.32}/src/upgini/utils/base_search_key_detector.py +0 -0
- {upgini-1.2.31a2 → upgini-1.2.32}/src/upgini/utils/blocked_time_series.py +0 -0
- {upgini-1.2.31a2 → upgini-1.2.32}/src/upgini/utils/country_utils.py +0 -0
- {upgini-1.2.31a2 → upgini-1.2.32}/src/upgini/utils/custom_loss_utils.py +0 -0
- {upgini-1.2.31a2 → upgini-1.2.32}/src/upgini/utils/cv_utils.py +0 -0
- {upgini-1.2.31a2 → upgini-1.2.32}/src/upgini/utils/datetime_utils.py +0 -0
- {upgini-1.2.31a2 → upgini-1.2.32}/src/upgini/utils/deduplicate_utils.py +0 -0
- {upgini-1.2.31a2 → upgini-1.2.32}/src/upgini/utils/display_utils.py +0 -0
- {upgini-1.2.31a2 → upgini-1.2.32}/src/upgini/utils/email_utils.py +0 -0
- {upgini-1.2.31a2 → upgini-1.2.32}/src/upgini/utils/fallback_progress_bar.py +0 -0
- {upgini-1.2.31a2 → upgini-1.2.32}/src/upgini/utils/feature_info.py +0 -0
- {upgini-1.2.31a2 → upgini-1.2.32}/src/upgini/utils/features_validator.py +0 -0
- {upgini-1.2.31a2 → upgini-1.2.32}/src/upgini/utils/format.py +0 -0
- {upgini-1.2.31a2 → upgini-1.2.32}/src/upgini/utils/ip_utils.py +0 -0
- {upgini-1.2.31a2 → upgini-1.2.32}/src/upgini/utils/phone_utils.py +0 -0
- {upgini-1.2.31a2 → upgini-1.2.32}/src/upgini/utils/postal_code_utils.py +0 -0
- {upgini-1.2.31a2 → upgini-1.2.32}/src/upgini/utils/progress_bar.py +0 -0
- {upgini-1.2.31a2 → upgini-1.2.32}/src/upgini/utils/sklearn_ext.py +0 -0
- {upgini-1.2.31a2 → upgini-1.2.32}/src/upgini/utils/track_info.py +0 -0
- {upgini-1.2.31a2 → upgini-1.2.32}/src/upgini/utils/warning_counter.py +0 -0
- {upgini-1.2.31a2 → upgini-1.2.32}/src/upgini/version_validator.py +0 -0
|
@@ -0,0 +1 @@
|
|
|
1
|
+
__version__ = "1.2.32"
|
|
@@ -36,15 +36,13 @@ from upgini.metadata import (
|
|
|
36
36
|
from upgini.resource_bundle import ResourceBundle, get_custom_bundle
|
|
37
37
|
from upgini.search_task import SearchTask
|
|
38
38
|
from upgini.utils.email_utils import EmailSearchKeyConverter
|
|
39
|
-
from upgini.utils.target_utils import balance_undersample
|
|
39
|
+
from upgini.utils.target_utils import balance_undersample, balance_undersample_forced
|
|
40
40
|
|
|
41
41
|
try:
|
|
42
42
|
from upgini.utils.progress_bar import CustomProgressBar as ProgressBar
|
|
43
43
|
except Exception:
|
|
44
44
|
from upgini.utils.fallback_progress_bar import CustomFallbackProgressBar as ProgressBar
|
|
45
45
|
|
|
46
|
-
from upgini.utils.warning_counter import WarningCounter
|
|
47
|
-
|
|
48
46
|
|
|
49
47
|
class Dataset: # (pd.DataFrame):
|
|
50
48
|
MIN_ROWS_COUNT = 100
|
|
@@ -64,6 +62,7 @@ class Dataset: # (pd.DataFrame):
|
|
|
64
62
|
MAX_FEATURES_COUNT = 3500
|
|
65
63
|
MAX_UPLOADING_FILE_SIZE = 268435456 # 256 Mb
|
|
66
64
|
MAX_STRING_FEATURE_LENGTH = 24573
|
|
65
|
+
FORCE_SAMPLE_SIZE = 7_000
|
|
67
66
|
|
|
68
67
|
def __init__(
|
|
69
68
|
self,
|
|
@@ -78,8 +77,8 @@ class Dataset: # (pd.DataFrame):
|
|
|
78
77
|
random_state: Optional[int] = None,
|
|
79
78
|
rest_client: Optional[_RestClient] = None,
|
|
80
79
|
logger: Optional[logging.Logger] = None,
|
|
81
|
-
warning_counter: Optional[WarningCounter] = None,
|
|
82
80
|
bundle: Optional[ResourceBundle] = None,
|
|
81
|
+
warning_callback: Optional[Callable] = None,
|
|
83
82
|
**kwargs,
|
|
84
83
|
):
|
|
85
84
|
self.bundle = bundle or get_custom_bundle()
|
|
@@ -122,7 +121,7 @@ class Dataset: # (pd.DataFrame):
|
|
|
122
121
|
else:
|
|
123
122
|
self.logger = logging.getLogger()
|
|
124
123
|
self.logger.setLevel("FATAL")
|
|
125
|
-
self.
|
|
124
|
+
self.warning_callback = warning_callback
|
|
126
125
|
|
|
127
126
|
def __len__(self):
|
|
128
127
|
return len(self.data) if self.data is not None else None
|
|
@@ -217,9 +216,23 @@ class Dataset: # (pd.DataFrame):
|
|
|
217
216
|
self.logger.exception("Failed to cast target to float for timeseries task type")
|
|
218
217
|
raise ValidationError(self.bundle.get("dataset_invalid_timeseries_target").format(target.dtype))
|
|
219
218
|
|
|
220
|
-
def __resample(self):
|
|
219
|
+
def __resample(self, force_downsampling=False):
|
|
221
220
|
# self.logger.info("Resampling etalon")
|
|
222
221
|
# Resample imbalanced target. Only train segment (without eval_set)
|
|
222
|
+
if force_downsampling:
|
|
223
|
+
target_column = self.etalon_def_checked.get(FileColumnMeaningType.TARGET.value, TARGET)
|
|
224
|
+
self.data = balance_undersample_forced(
|
|
225
|
+
df=self.data,
|
|
226
|
+
target_column=target_column,
|
|
227
|
+
task_type=self.task_type,
|
|
228
|
+
random_state=self.random_state,
|
|
229
|
+
sample_size=self.FORCE_SAMPLE_SIZE,
|
|
230
|
+
logger=self.logger,
|
|
231
|
+
bundle=self.bundle,
|
|
232
|
+
warning_callback=self.warning_callback,
|
|
233
|
+
)
|
|
234
|
+
return
|
|
235
|
+
|
|
223
236
|
if EVAL_SET_INDEX in self.data.columns:
|
|
224
237
|
train_segment = self.data[self.data[EVAL_SET_INDEX] == 0]
|
|
225
238
|
else:
|
|
@@ -268,7 +281,7 @@ class Dataset: # (pd.DataFrame):
|
|
|
268
281
|
multiclass_bootstrap_loops=self.MULTICLASS_BOOTSTRAP_LOOPS,
|
|
269
282
|
logger=self.logger,
|
|
270
283
|
bundle=self.bundle,
|
|
271
|
-
|
|
284
|
+
warning_callback=self.warning_callback,
|
|
272
285
|
)
|
|
273
286
|
|
|
274
287
|
# Resample over fit threshold
|
|
@@ -418,13 +431,13 @@ class Dataset: # (pd.DataFrame):
|
|
|
418
431
|
if len(self.data) == 0:
|
|
419
432
|
raise ValidationError(self.bundle.get("all_search_keys_invalid"))
|
|
420
433
|
|
|
421
|
-
def validate(self, validate_target: bool = True, silent_mode: bool = False):
|
|
434
|
+
def validate(self, validate_target: bool = True, silent_mode: bool = False, force_downsampling: bool = False):
|
|
422
435
|
self.__validate_dataset(validate_target, silent_mode)
|
|
423
436
|
|
|
424
437
|
if validate_target:
|
|
425
438
|
self.__validate_target()
|
|
426
439
|
|
|
427
|
-
self.__resample()
|
|
440
|
+
self.__resample(force_downsampling)
|
|
428
441
|
|
|
429
442
|
self.__validate_min_rows_count()
|
|
430
443
|
|
|
@@ -573,9 +586,10 @@ class Dataset: # (pd.DataFrame):
|
|
|
573
586
|
max_features: Optional[int] = None, # deprecated
|
|
574
587
|
filter_features: Optional[dict] = None, # deprecated
|
|
575
588
|
runtime_parameters: Optional[RuntimeParameters] = None,
|
|
589
|
+
force_downsampling: bool = False,
|
|
576
590
|
) -> SearchTask:
|
|
577
591
|
if self.etalon_def is None:
|
|
578
|
-
self.validate()
|
|
592
|
+
self.validate(force_downsampling=force_downsampling)
|
|
579
593
|
file_metrics = FileMetrics()
|
|
580
594
|
|
|
581
595
|
runtime_parameters = self._rename_generate_features(runtime_parameters)
|
|
@@ -231,6 +231,7 @@ class FeaturesEnricher(TransformerMixin):
|
|
|
231
231
|
custom_bundle_config: Optional[str] = None,
|
|
232
232
|
add_date_if_missing: bool = True,
|
|
233
233
|
select_features: bool = False,
|
|
234
|
+
disable_force_downsampling: bool = False,
|
|
234
235
|
**kwargs,
|
|
235
236
|
):
|
|
236
237
|
self.bundle = get_custom_bundle(custom_bundle_config)
|
|
@@ -288,6 +289,7 @@ class FeaturesEnricher(TransformerMixin):
|
|
|
288
289
|
self.feature_importances_ = []
|
|
289
290
|
self.search_id = search_id
|
|
290
291
|
self.select_features = select_features
|
|
292
|
+
self.disable_force_downsampling = disable_force_downsampling
|
|
291
293
|
|
|
292
294
|
if search_id:
|
|
293
295
|
search_task = SearchTask(search_id, rest_client=self.rest_client, logger=self.logger)
|
|
@@ -2251,6 +2253,8 @@ class FeaturesEnricher(TransformerMixin):
|
|
|
2251
2253
|
date_format=self.date_format,
|
|
2252
2254
|
rest_client=self.rest_client,
|
|
2253
2255
|
logger=self.logger,
|
|
2256
|
+
bundle=self.bundle,
|
|
2257
|
+
warning_callback=self.__log_warning,
|
|
2254
2258
|
)
|
|
2255
2259
|
dataset.columns_renaming = columns_renaming
|
|
2256
2260
|
|
|
@@ -2696,6 +2700,18 @@ class FeaturesEnricher(TransformerMixin):
|
|
|
2696
2700
|
|
|
2697
2701
|
combined_search_keys = combine_search_keys(self.fit_search_keys.keys())
|
|
2698
2702
|
|
|
2703
|
+
runtime_parameters = self._get_copy_of_runtime_parameters()
|
|
2704
|
+
|
|
2705
|
+
# Force downsampling to 7000 for API features generation
|
|
2706
|
+
force_downsampling = (
|
|
2707
|
+
not self.disable_force_downsampling
|
|
2708
|
+
and self.generate_features is not None
|
|
2709
|
+
and phone_column is not None
|
|
2710
|
+
and self.fit_columns_renaming[phone_column] in self.generate_features
|
|
2711
|
+
)
|
|
2712
|
+
if force_downsampling and len(df) > Dataset.FORCE_SAMPLE_SIZE:
|
|
2713
|
+
runtime_parameters.properties["fast_fit"] = True
|
|
2714
|
+
|
|
2699
2715
|
dataset = Dataset(
|
|
2700
2716
|
"tds_" + str(uuid.uuid4()),
|
|
2701
2717
|
df=df,
|
|
@@ -2707,6 +2723,8 @@ class FeaturesEnricher(TransformerMixin):
|
|
|
2707
2723
|
random_state=self.random_state,
|
|
2708
2724
|
rest_client=self.rest_client,
|
|
2709
2725
|
logger=self.logger,
|
|
2726
|
+
bundle=self.bundle,
|
|
2727
|
+
warning_callback=self.__log_warning,
|
|
2710
2728
|
)
|
|
2711
2729
|
dataset.columns_renaming = self.fit_columns_renaming
|
|
2712
2730
|
|
|
@@ -2720,8 +2738,9 @@ class FeaturesEnricher(TransformerMixin):
|
|
|
2720
2738
|
start_time=start_time,
|
|
2721
2739
|
progress_callback=progress_callback,
|
|
2722
2740
|
extract_features=True,
|
|
2723
|
-
runtime_parameters=
|
|
2741
|
+
runtime_parameters=runtime_parameters,
|
|
2724
2742
|
exclude_features_sources=exclude_features_sources,
|
|
2743
|
+
force_downsampling=force_downsampling,
|
|
2725
2744
|
)
|
|
2726
2745
|
|
|
2727
2746
|
if search_id_callback is not None:
|
|
@@ -3521,7 +3540,7 @@ class FeaturesEnricher(TransformerMixin):
|
|
|
3521
3540
|
return result_train, result_eval_sets
|
|
3522
3541
|
|
|
3523
3542
|
def __prepare_feature_importances(
|
|
3524
|
-
|
|
3543
|
+
self, trace_id: str, x_columns: List[str], updated_shaps: Optional[Dict[str, float]] = None, silent=False
|
|
3525
3544
|
):
|
|
3526
3545
|
if self._search_task is None:
|
|
3527
3546
|
raise NotFittedError(self.bundle.get("transform_unfitted_enricher"))
|
|
@@ -215,6 +215,7 @@ imbalance_multiclass=Class {0} is on 25% quantile of classes distribution ({1} r
|
|
|
215
215
|
imbalanced_target=\nTarget is imbalanced and will be undersampled. Frequency of the rarest class `{}` is {}
|
|
216
216
|
loss_selection_info=Using loss `{}` for feature selection
|
|
217
217
|
loss_calc_metrics_info=Using loss `{}` for metrics calculation with default estimator
|
|
218
|
+
forced_balance_undersample=For quick data retrieval, your dataset has been sampled. To use data search without data sampling please contact support (sales@upgini.com)
|
|
218
219
|
|
|
219
220
|
# Validation table
|
|
220
221
|
validation_column_name_header=Column name
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import logging
|
|
2
|
-
from typing import Optional, Union
|
|
2
|
+
from typing import Callable, Optional, Union
|
|
3
3
|
|
|
4
4
|
import numpy as np
|
|
5
5
|
import pandas as pd
|
|
@@ -9,7 +9,6 @@ from upgini.errors import ValidationError
|
|
|
9
9
|
from upgini.metadata import SYSTEM_RECORD_ID, ModelTaskType
|
|
10
10
|
from upgini.resource_bundle import ResourceBundle, bundle, get_custom_bundle
|
|
11
11
|
from upgini.sampler.random_under_sampler import RandomUnderSampler
|
|
12
|
-
from upgini.utils.warning_counter import WarningCounter
|
|
13
12
|
|
|
14
13
|
|
|
15
14
|
def correct_string_target(y: Union[pd.Series, np.ndarray]) -> Union[pd.Series, np.ndarray]:
|
|
@@ -121,7 +120,7 @@ def balance_undersample(
|
|
|
121
120
|
multiclass_bootstrap_loops: int = 2,
|
|
122
121
|
logger: Optional[logging.Logger] = None,
|
|
123
122
|
bundle: Optional[ResourceBundle] = None,
|
|
124
|
-
|
|
123
|
+
warning_callback: Optional[Callable] = None,
|
|
125
124
|
) -> pd.DataFrame:
|
|
126
125
|
if logger is None:
|
|
127
126
|
logger = logging.getLogger("muted_logger")
|
|
@@ -130,9 +129,7 @@ def balance_undersample(
|
|
|
130
129
|
if SYSTEM_RECORD_ID not in df.columns:
|
|
131
130
|
raise Exception("System record id must be presented for undersampling")
|
|
132
131
|
|
|
133
|
-
# count = len(df)
|
|
134
132
|
target = df[target_column].copy()
|
|
135
|
-
# target_classes_count = target.nunique()
|
|
136
133
|
|
|
137
134
|
vc = target.value_counts()
|
|
138
135
|
max_class_value = vc.index[0]
|
|
@@ -141,9 +138,6 @@ def balance_undersample(
|
|
|
141
138
|
min_class_count = vc[min_class_value]
|
|
142
139
|
num_classes = len(vc)
|
|
143
140
|
|
|
144
|
-
# min_class_percent = imbalance_threshold / target_classes_count
|
|
145
|
-
# min_class_threshold = int(min_class_percent * count)
|
|
146
|
-
|
|
147
141
|
resampled_data = df
|
|
148
142
|
df = df.copy().sort_values(by=SYSTEM_RECORD_ID)
|
|
149
143
|
if task_type == ModelTaskType.MULTICLASS:
|
|
@@ -151,12 +145,10 @@ def balance_undersample(
|
|
|
151
145
|
min_class_count * multiclass_bootstrap_loops
|
|
152
146
|
):
|
|
153
147
|
|
|
154
|
-
# msg = bundle.get("imbalance_multiclass").format(min_class_value, min_class_count)
|
|
155
148
|
msg = bundle.get("imbalanced_target").format(min_class_value, min_class_count)
|
|
156
149
|
logger.warning(msg)
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
warning_counter.increment()
|
|
150
|
+
if warning_callback is not None:
|
|
151
|
+
warning_callback(msg)
|
|
160
152
|
|
|
161
153
|
sample_strategy = dict()
|
|
162
154
|
for class_value in vc.index:
|
|
@@ -180,19 +172,14 @@ def balance_undersample(
|
|
|
180
172
|
|
|
181
173
|
resampled_data = df[df[SYSTEM_RECORD_ID].isin(new_x[SYSTEM_RECORD_ID])]
|
|
182
174
|
elif len(df) > binary_min_sample_threshold:
|
|
183
|
-
# msg = bundle.get("dataset_rarest_class_less_threshold").format(
|
|
184
|
-
# min_class_value, min_class_count, min_class_threshold, min_class_percent * 100
|
|
185
|
-
# )
|
|
186
175
|
msg = bundle.get("imbalanced_target").format(min_class_value, min_class_count)
|
|
187
176
|
logger.warning(msg)
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
warning_counter.increment()
|
|
177
|
+
if warning_callback is not None:
|
|
178
|
+
warning_callback(msg)
|
|
191
179
|
|
|
192
180
|
# fill up to min_sample_threshold by majority class
|
|
193
181
|
minority_class = df[df[target_column] == min_class_value]
|
|
194
182
|
majority_class = df[df[target_column] != min_class_value]
|
|
195
|
-
# sample_size = min(len(majority_class), min_sample_threshold - min_class_count)
|
|
196
183
|
sample_size = min(
|
|
197
184
|
max_class_count,
|
|
198
185
|
binary_bootstrap_loops * (min_class_count + max(binary_min_sample_threshold - 2 * min_class_count, 0)),
|
|
@@ -207,25 +194,73 @@ def balance_undersample(
|
|
|
207
194
|
| (df[SYSTEM_RECORD_ID].isin(sampled_majority_class[SYSTEM_RECORD_ID]))
|
|
208
195
|
]
|
|
209
196
|
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
# min_class_value, min_class_count, min_class_threshold, min_class_percent * 100
|
|
213
|
-
# )
|
|
214
|
-
# logger.warning(msg)
|
|
215
|
-
# print(msg)
|
|
216
|
-
# if warning_counter:
|
|
217
|
-
# warning_counter.increment()
|
|
197
|
+
logger.info(f"Shape after rebalance resampling: {resampled_data}")
|
|
198
|
+
return resampled_data
|
|
218
199
|
|
|
219
|
-
# sampler = RandomUnderSampler(
|
|
220
|
-
# sampling_strategy={max_class_value: binary_bootstrap_loops * min_class_count}, random_state=random_state
|
|
221
|
-
# )
|
|
222
|
-
# X = df[SYSTEM_RECORD_ID]
|
|
223
|
-
# X = X.to_frame(SYSTEM_RECORD_ID)
|
|
224
|
-
# new_x, _ = sampler.fit_resample(X, target) # type: ignore
|
|
225
200
|
|
|
226
|
-
|
|
201
|
+
def balance_undersample_forced(
    df: pd.DataFrame,
    target_column: str,
    task_type: ModelTaskType,
    random_state: Optional[int],
    sample_size: int = 7000,
    logger: Optional[logging.Logger] = None,
    bundle: Optional[ResourceBundle] = None,
    warning_callback: Optional[Callable] = None,
) -> pd.DataFrame:
    """Forcefully downsample ``df`` to at most about ``sample_size`` rows.

    If ``df`` already has ``sample_size`` rows or fewer, it is returned unchanged.
    Otherwise the ``forced_balance_undersample`` message is logged and passed to
    ``warning_callback``. Then:

    * for MULTICLASS / REGRESSION / TIMESERIES tasks, ``sample_size`` rows are
      drawn uniformly at random;
    * for any other task type (e.g. binary), at most ``sample_size // 2`` rows
      are taken from the rarest target class, and at most ``sample_size // 2``
      rows are taken from all remaining rows. The result keeps the
      ``SYSTEM_RECORD_ID`` sort order.

    Args:
        df: Input frame; must contain ``SYSTEM_RECORD_ID`` and ``target_column``.
        target_column: Name of the target column.
        task_type: Model task type that selects the sampling strategy.
        random_state: Seed forwarded to ``DataFrame.sample``.
        sample_size: Target number of rows.
        logger: Optional logger; a muted logger is used when omitted.
        bundle: Optional resource bundle for user-facing messages.
        warning_callback: Optional callable that receives the user-facing message.

    Returns:
        The (possibly) downsampled frame.

    Raises:
        Exception: if ``SYSTEM_RECORD_ID`` is not a column of ``df``.
    """
    if len(df) <= sample_size:
        return df

    if logger is None:
        logger = logging.getLogger("muted_logger")
        logger.setLevel("FATAL")
    bundle = bundle or get_custom_bundle()
    if SYSTEM_RECORD_ID not in df.columns:
        raise Exception("System record id must be presented for undersampling")

    msg = bundle.get("forced_balance_undersample")
    logger.info(msg)
    if warning_callback is not None:
        warning_callback(msg)

    # value_counts is sorted by descending frequency: first = most common, last = rarest.
    vc = df[target_column].value_counts()
    max_class_value = vc.index[0]
    min_class_value = vc.index[-1]
    max_class_count = vc[max_class_value]
    min_class_count = vc[min_class_value]

    # sort_values returns a new frame, so the caller's df is never mutated.
    df = df.sort_values(by=SYSTEM_RECORD_ID)
    if task_type in [ModelTaskType.MULTICLASS, ModelTaskType.REGRESSION, ModelTaskType.TIMESERIES]:
        logger.warning(f"Sampling dataset from {len(df)} to {sample_size}")
        resampled_data = df.sample(n=sample_size, random_state=random_state)
    else:
        msg = bundle.get("imbalanced_target").format(min_class_value, min_class_count)
        logger.warning(msg)

        # Split into the rarest class and everything else, then cap each half.
        minority_class = df[df[target_column] == min_class_value]
        majority_class = df[df[target_column] != min_class_value]
        logger.info(
            f"Min class count: {min_class_count}. Max class count: {max_class_count}."
            f" Rebalance sample size: {sample_size}"
        )
        half = int(sample_size / 2)

        sampled_minority_class = minority_class
        if len(minority_class) > half:
            sampled_minority_class = minority_class.sample(n=half, random_state=random_state)

        # Default to the unsampled slice so the name is always bound. Previously a
        # single-class target left it undefined and raised UnboundLocalError.
        sampled_majority_class = majority_class
        if len(majority_class) > half:
            sampled_majority_class = majority_class.sample(n=half, random_state=random_state)

        # Boolean mask keeps the SYSTEM_RECORD_ID ordering of the sorted frame.
        resampled_data = df[
            (df[SYSTEM_RECORD_ID].isin(sampled_minority_class[SYSTEM_RECORD_ID]))
            | (df[SYSTEM_RECORD_ID].isin(sampled_majority_class[SYSTEM_RECORD_ID]))
        ]

    # Log the shape (as the message says), not the full DataFrame repr.
    logger.info(f"Shape after forced rebalance resampling: {resampled_data.shape}")
    return resampled_data
|
|
230
265
|
|
|
231
266
|
|
|
@@ -1 +0,0 @@
|
|
|
1
|
-
__version__ = "1.2.31a2"
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|