upgini 1.2.91a3884.dev5__py3-none-any.whl → 1.2.91a3906.dev1__py3-none-any.whl
This diff compares the contents of two publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between the package versions as they appear in their respective public registries.
- upgini/__about__.py +1 -1
- upgini/autofe/unary.py +8 -0
- upgini/dataset.py +107 -58
- upgini/features_enricher.py +191 -227
- upgini/metadata.py +0 -3
- upgini/metrics.py +11 -12
- upgini/resource_bundle/strings.properties +0 -2
- upgini/utils/target_utils.py +199 -3
- {upgini-1.2.91a3884.dev5.dist-info → upgini-1.2.91a3906.dev1.dist-info}/METADATA +1 -1
- {upgini-1.2.91a3884.dev5.dist-info → upgini-1.2.91a3906.dev1.dist-info}/RECORD +12 -13
- upgini/utils/sample_utils.py +0 -414
- {upgini-1.2.91a3884.dev5.dist-info → upgini-1.2.91a3906.dev1.dist-info}/WHEEL +0 -0
- {upgini-1.2.91a3884.dev5.dist-info → upgini-1.2.91a3906.dev1.dist-info}/licenses/LICENSE +0 -0
upgini/__about__.py
CHANGED
@@ -1 +1 @@
|
|
1
|
-
__version__ = "1.2.91a3884.dev5"
|
1
|
+
__version__ = "1.2.91a3906.dev1"
|
upgini/autofe/unary.py
CHANGED
@@ -190,3 +190,11 @@ class Bin(PandasOperator):
|
|
190
190
|
if isinstance(value, str):
|
191
191
|
return json.loads(value)
|
192
192
|
return value
|
193
|
+
|
194
|
+
|
195
|
+
class Cluster(PandasOperator):
|
196
|
+
name: str = "cluster"
|
197
|
+
is_unary: bool = True
|
198
|
+
input_type: Optional[str] = "vector"
|
199
|
+
output_type: Optional[str] = "category"
|
200
|
+
is_categorical: bool = True
|
upgini/dataset.py
CHANGED
@@ -38,7 +38,11 @@ from upgini.metadata import (
|
|
38
38
|
from upgini.resource_bundle import ResourceBundle, get_custom_bundle
|
39
39
|
from upgini.search_task import SearchTask
|
40
40
|
from upgini.utils.email_utils import EmailSearchKeyConverter
|
41
|
-
from upgini.utils.
|
41
|
+
from upgini.utils.target_utils import (
|
42
|
+
balance_undersample,
|
43
|
+
balance_undersample_forced,
|
44
|
+
balance_undersample_time_series_trunc,
|
45
|
+
)
|
42
46
|
|
43
47
|
try:
|
44
48
|
from upgini.utils.progress_bar import CustomProgressBar as ProgressBar
|
@@ -48,9 +52,17 @@ except Exception:
|
|
48
52
|
)
|
49
53
|
|
50
54
|
|
51
|
-
class Dataset:
|
55
|
+
class Dataset: # (pd.DataFrame):
|
52
56
|
MIN_ROWS_COUNT = 100
|
53
57
|
MAX_ROWS = 200_000
|
58
|
+
FIT_SAMPLE_ROWS = 200_000
|
59
|
+
FIT_SAMPLE_THRESHOLD = 200_000
|
60
|
+
FIT_SAMPLE_WITH_EVAL_SET_ROWS = 200_000
|
61
|
+
FIT_SAMPLE_WITH_EVAL_SET_THRESHOLD = 200_000
|
62
|
+
FIT_SAMPLE_THRESHOLD_TS = 54_000
|
63
|
+
FIT_SAMPLE_ROWS_TS = 54_000
|
64
|
+
BINARY_MIN_SAMPLE_THRESHOLD = 5_000
|
65
|
+
MULTICLASS_MIN_SAMPLE_THRESHOLD = 25_000
|
54
66
|
IMBALANCE_THESHOLD = 0.6
|
55
67
|
BINARY_BOOTSTRAP_LOOPS = 5
|
56
68
|
MULTICLASS_BOOTSTRAP_LOOPS = 2
|
@@ -76,7 +88,6 @@ class Dataset:
|
|
76
88
|
date_column: Optional[str] = None,
|
77
89
|
id_columns: Optional[List[str]] = None,
|
78
90
|
random_state: Optional[int] = None,
|
79
|
-
sample_config: Optional[SampleConfig] = None,
|
80
91
|
rest_client: Optional[_RestClient] = None,
|
81
92
|
logger: Optional[logging.Logger] = None,
|
82
93
|
bundle: Optional[ResourceBundle] = None,
|
@@ -84,7 +95,6 @@ class Dataset:
|
|
84
95
|
**kwargs,
|
85
96
|
):
|
86
97
|
self.bundle = bundle or get_custom_bundle()
|
87
|
-
self.sample_config = sample_config or SampleConfig(force_sample_size=self.FORCE_SAMPLE_SIZE)
|
88
98
|
if df is not None:
|
89
99
|
data = df.copy()
|
90
100
|
elif path is not None:
|
@@ -223,70 +233,109 @@ class Dataset:
|
|
223
233
|
raise ValidationError(self.bundle.get("dataset_invalid_timeseries_target").format(target.dtype))
|
224
234
|
|
225
235
|
def __resample(self, force_downsampling=False):
|
236
|
+
# self.logger.info("Resampling etalon")
|
237
|
+
# Resample imbalanced target. Only train segment (without eval_set)
|
238
|
+
if force_downsampling:
|
239
|
+
target_column = self.etalon_def_checked.get(FileColumnMeaningType.TARGET.value, TARGET)
|
240
|
+
self.data = balance_undersample_forced(
|
241
|
+
df=self.data,
|
242
|
+
target_column=target_column,
|
243
|
+
task_type=self.task_type,
|
244
|
+
cv_type=self.cv_type,
|
245
|
+
date_column=self.date_column,
|
246
|
+
id_columns=self.id_columns,
|
247
|
+
random_state=self.random_state,
|
248
|
+
sample_size=self.FORCE_SAMPLE_SIZE,
|
249
|
+
logger=self.logger,
|
250
|
+
bundle=self.bundle,
|
251
|
+
warning_callback=self.warning_callback,
|
252
|
+
)
|
253
|
+
return
|
226
254
|
|
227
|
-
if EVAL_SET_INDEX in self.data.columns
|
255
|
+
if EVAL_SET_INDEX in self.data.columns:
|
228
256
|
train_segment = self.data[self.data[EVAL_SET_INDEX] == 0]
|
229
257
|
else:
|
230
258
|
train_segment = self.data
|
231
259
|
|
232
|
-
self.
|
233
|
-
|
234
|
-
|
235
|
-
|
236
|
-
|
237
|
-
target=
|
238
|
-
|
239
|
-
)
|
240
|
-
|
241
|
-
self.data = sample(
|
242
|
-
train_segment if self.imbalanced else self.data, # for imbalanced data we will be doing transform anyway
|
243
|
-
self.task_type,
|
244
|
-
self.cv_type,
|
245
|
-
self.sample_config,
|
246
|
-
sample_columns,
|
247
|
-
self.random_state,
|
248
|
-
balance=self.imbalanced,
|
249
|
-
force_downsampling=force_downsampling,
|
250
|
-
logger=self.logger,
|
251
|
-
bundle=self.bundle,
|
252
|
-
warning_callback=self.warning_callback,
|
253
|
-
)
|
260
|
+
if self.task_type == ModelTaskType.MULTICLASS or (
|
261
|
+
self.task_type == ModelTaskType.BINARY and len(train_segment) > self.BINARY_MIN_SAMPLE_THRESHOLD
|
262
|
+
):
|
263
|
+
count = len(train_segment)
|
264
|
+
target_column = self.etalon_def_checked.get(FileColumnMeaningType.TARGET.value, TARGET)
|
265
|
+
target = train_segment[target_column]
|
266
|
+
target_classes_count = target.nunique()
|
254
267
|
|
255
|
-
|
256
|
-
|
257
|
-
|
268
|
+
if target_classes_count > self.MAX_MULTICLASS_CLASS_COUNT:
|
269
|
+
msg = self.bundle.get("dataset_to_many_multiclass_targets").format(
|
270
|
+
target_classes_count, self.MAX_MULTICLASS_CLASS_COUNT
|
271
|
+
)
|
272
|
+
self.logger.warning(msg)
|
273
|
+
raise ValidationError(msg)
|
258
274
|
|
259
|
-
|
260
|
-
|
275
|
+
vc = target.value_counts()
|
276
|
+
min_class_value = vc.index[len(vc) - 1]
|
277
|
+
min_class_count = vc[min_class_value]
|
261
278
|
|
262
|
-
|
263
|
-
|
264
|
-
|
265
|
-
|
279
|
+
if min_class_count < self.MIN_TARGET_CLASS_ROWS:
|
280
|
+
msg = self.bundle.get("dataset_rarest_class_less_min").format(
|
281
|
+
min_class_value, min_class_count, self.MIN_TARGET_CLASS_ROWS
|
282
|
+
)
|
283
|
+
self.logger.warning(msg)
|
284
|
+
raise ValidationError(msg)
|
266
285
|
|
267
|
-
|
268
|
-
|
269
|
-
|
270
|
-
)
|
271
|
-
|
272
|
-
|
286
|
+
min_class_percent = self.IMBALANCE_THESHOLD / target_classes_count
|
287
|
+
min_class_threshold = min_class_percent * count
|
288
|
+
|
289
|
+
# If min class count less than 30% for binary or (60 / classes_count)% for multiclass
|
290
|
+
if min_class_count < min_class_threshold:
|
291
|
+
self.imbalanced = True
|
292
|
+
self.data = balance_undersample(
|
293
|
+
df=train_segment,
|
294
|
+
target_column=target_column,
|
295
|
+
task_type=self.task_type,
|
296
|
+
random_state=self.random_state,
|
297
|
+
binary_min_sample_threshold=self.BINARY_MIN_SAMPLE_THRESHOLD,
|
298
|
+
multiclass_min_sample_threshold=self.MULTICLASS_MIN_SAMPLE_THRESHOLD,
|
299
|
+
binary_bootstrap_loops=self.BINARY_BOOTSTRAP_LOOPS,
|
300
|
+
multiclass_bootstrap_loops=self.MULTICLASS_BOOTSTRAP_LOOPS,
|
301
|
+
logger=self.logger,
|
302
|
+
bundle=self.bundle,
|
303
|
+
warning_callback=self.warning_callback,
|
304
|
+
)
|
273
305
|
|
274
|
-
|
275
|
-
|
276
|
-
|
306
|
+
# Resample over fit threshold
|
307
|
+
if self.cv_type is not None and self.cv_type.is_time_series():
|
308
|
+
sample_threshold = self.FIT_SAMPLE_THRESHOLD_TS
|
309
|
+
sample_rows = self.FIT_SAMPLE_ROWS_TS
|
310
|
+
elif not self.imbalanced and EVAL_SET_INDEX in self.data.columns:
|
311
|
+
sample_threshold = self.FIT_SAMPLE_WITH_EVAL_SET_THRESHOLD
|
312
|
+
sample_rows = self.FIT_SAMPLE_WITH_EVAL_SET_ROWS
|
313
|
+
else:
|
314
|
+
sample_threshold = self.FIT_SAMPLE_THRESHOLD
|
315
|
+
sample_rows = self.FIT_SAMPLE_ROWS
|
277
316
|
|
278
|
-
if
|
279
|
-
|
280
|
-
|
317
|
+
if len(self.data) > sample_threshold:
|
318
|
+
self.logger.info(
|
319
|
+
f"Etalon has size {len(self.data)} more than threshold {sample_threshold} "
|
320
|
+
f"and will be downsampled to {sample_rows}"
|
281
321
|
)
|
282
|
-
self.
|
283
|
-
|
284
|
-
|
285
|
-
|
286
|
-
|
287
|
-
|
288
|
-
|
289
|
-
|
322
|
+
if self.cv_type is not None and self.cv_type.is_time_series():
|
323
|
+
resampled_data = balance_undersample_time_series_trunc(
|
324
|
+
df=self.data,
|
325
|
+
id_columns=self.id_columns,
|
326
|
+
date_column=next(
|
327
|
+
k
|
328
|
+
for k, v in self.meaning_types.items()
|
329
|
+
if v in [FileColumnMeaningType.DATE, FileColumnMeaningType.DATETIME]
|
330
|
+
),
|
331
|
+
sample_size=sample_rows,
|
332
|
+
random_state=self.random_state,
|
333
|
+
logger=self.logger,
|
334
|
+
)
|
335
|
+
else:
|
336
|
+
resampled_data = self.data.sample(n=sample_rows, random_state=self.random_state)
|
337
|
+
self.data = resampled_data
|
338
|
+
self.logger.info(f"Shape after threshold resampling: {self.data.shape}")
|
290
339
|
|
291
340
|
def __validate_dataset(self, validate_target: bool, silent_mode: bool):
|
292
341
|
"""Validate DataSet"""
|
@@ -568,8 +617,8 @@ class Dataset:
|
|
568
617
|
def _set_sample_size(self, runtime_parameters: Optional[RuntimeParameters]) -> Optional[RuntimeParameters]:
|
569
618
|
if runtime_parameters is not None and runtime_parameters.properties is not None:
|
570
619
|
if self.cv_type is not None and self.cv_type.is_time_series():
|
571
|
-
runtime_parameters.properties["sample_size"] = self.
|
572
|
-
runtime_parameters.properties["iter0_sample_size"] = self.
|
620
|
+
runtime_parameters.properties["sample_size"] = self.FIT_SAMPLE_ROWS_TS
|
621
|
+
runtime_parameters.properties["iter0_sample_size"] = self.FIT_SAMPLE_ROWS_TS
|
573
622
|
return runtime_parameters
|
574
623
|
|
575
624
|
def _clean_generate_features(self, runtime_parameters: Optional[RuntimeParameters]) -> Optional[RuntimeParameters]:
|