upgini 1.2.146a2__tar.gz → 1.2.146a9__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {upgini-1.2.146a2 → upgini-1.2.146a9}/PKG-INFO +1 -1
- upgini-1.2.146a9/src/upgini/__about__.py +1 -0
- {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/dataset.py +5 -4
- {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/features_enricher.py +187 -105
- {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/normalizer/normalize_utils.py +15 -0
- {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/resource_bundle/strings.properties +4 -2
- {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/utils/datetime_utils.py +4 -4
- {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/utils/features_validator.py +9 -50
- upgini-1.2.146a9/src/upgini/utils/one_hot_encoder.py +215 -0
- {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/utils/target_utils.py +26 -3
- upgini-1.2.146a2/src/upgini/__about__.py +0 -1
- {upgini-1.2.146a2 → upgini-1.2.146a9}/.gitignore +0 -0
- {upgini-1.2.146a2 → upgini-1.2.146a9}/LICENSE +0 -0
- {upgini-1.2.146a2 → upgini-1.2.146a9}/README.md +0 -0
- {upgini-1.2.146a2 → upgini-1.2.146a9}/pyproject.toml +0 -0
- {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/__init__.py +0 -0
- {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/ads.py +0 -0
- {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/ads_management/__init__.py +0 -0
- {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/ads_management/ads_manager.py +0 -0
- {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/autofe/__init__.py +0 -0
- {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/autofe/all_operators.py +0 -0
- {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/autofe/binary.py +0 -0
- {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/autofe/date.py +0 -0
- {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/autofe/feature.py +0 -0
- {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/autofe/groupby.py +0 -0
- {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/autofe/operator.py +0 -0
- {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/autofe/timeseries/__init__.py +0 -0
- {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/autofe/timeseries/base.py +0 -0
- {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/autofe/timeseries/cross.py +0 -0
- {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/autofe/timeseries/delta.py +0 -0
- {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/autofe/timeseries/lag.py +0 -0
- {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/autofe/timeseries/roll.py +0 -0
- {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/autofe/timeseries/trend.py +0 -0
- {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/autofe/timeseries/volatility.py +0 -0
- {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/autofe/unary.py +0 -0
- {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/autofe/utils.py +0 -0
- {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/autofe/vector.py +0 -0
- {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/data_source/__init__.py +0 -0
- {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/data_source/data_source_publisher.py +0 -0
- {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/errors.py +0 -0
- {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/http.py +0 -0
- {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/mdc/__init__.py +0 -0
- {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/mdc/context.py +0 -0
- {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/metadata.py +0 -0
- {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/metrics.py +0 -0
- {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/normalizer/__init__.py +0 -0
- {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/resource_bundle/__init__.py +0 -0
- {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/resource_bundle/exceptions.py +0 -0
- {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/resource_bundle/strings_widget.properties +0 -0
- {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/sampler/__init__.py +0 -0
- {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/sampler/base.py +0 -0
- {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/sampler/random_under_sampler.py +0 -0
- {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/sampler/utils.py +0 -0
- {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/search_task.py +0 -0
- {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/spinner.py +0 -0
- {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/utils/Roboto-Regular.ttf +0 -0
- {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/utils/__init__.py +0 -0
- {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/utils/base_search_key_detector.py +0 -0
- {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/utils/blocked_time_series.py +0 -0
- {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/utils/config.py +0 -0
- {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/utils/country_utils.py +0 -0
- {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/utils/custom_loss_utils.py +0 -0
- {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/utils/cv_utils.py +0 -0
- {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/utils/deduplicate_utils.py +0 -0
- {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/utils/display_utils.py +0 -0
- {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/utils/email_utils.py +0 -0
- {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/utils/fallback_progress_bar.py +0 -0
- {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/utils/feature_info.py +0 -0
- {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/utils/format.py +0 -0
- {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/utils/hash_utils.py +0 -0
- {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/utils/ip_utils.py +0 -0
- {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/utils/mstats.py +0 -0
- {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/utils/phone_utils.py +0 -0
- {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/utils/postal_code_utils.py +0 -0
- {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/utils/progress_bar.py +0 -0
- {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/utils/psi.py +0 -0
- {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/utils/sample_utils.py +0 -0
- {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/utils/sklearn_ext.py +0 -0
- {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/utils/sort.py +0 -0
- {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/utils/track_info.py +0 -0
- {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/utils/ts_utils.py +0 -0
- {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/utils/warning_counter.py +0 -0
- {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/version_validator.py +0 -0
upgini-1.2.146a9/src/upgini/__about__.py (new file)

@@ -0,0 +1 @@
+__version__ = "1.2.146a9"
{upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/dataset.py

@@ -58,7 +58,7 @@ class Dataset:
     MAX_ROWS = 3_000_000
     MIN_SUPPORTED_DATE_TS = 946684800000  # 2000-01-01
     MAX_FEATURES_COUNT = 3500
-    MAX_UPLOADING_FILE_SIZE =
+    MAX_UPLOADING_FILE_SIZE = 536_870_912  # 512 Mb
     MAX_STRING_FEATURE_LENGTH = 24573
     FORCE_SAMPLE_SIZE = 7_000

@@ -304,10 +304,11 @@ class Dataset:
         ):
             keys_to_validate.remove(ipv4_column)

-        mandatory_columns =
+        mandatory_columns = {target} if target is not None else set()
         columns_to_validate = mandatory_columns.copy()
-        columns_to_validate.
-
+        columns_to_validate.update(keys_to_validate)
+        if len(columns_to_validate) == 0:
+            return

         nrows = len(self.data)
         validation_stats = {}
{upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/features_enricher.py

@@ -11,6 +11,7 @@ import uuid
 from collections import Counter
 from copy import deepcopy
 from dataclasses import dataclass
+from pathlib import Path
 from threading import Thread
 from typing import Any, Callable

@@ -277,6 +278,8 @@ class FeaturesEnricher(TransformerMixin):
         self.autodetected_search_keys: dict[str, SearchKey] | None = None
         self.imbalanced = False
         self.fit_select_features = True
+        self.true_one_hot_groups: dict[str, list[str]] | None = None
+        self.pseudo_one_hot_groups: dict[str, list[str]] | None = None
         self.__cached_sampled_datasets: dict[str, tuple[pd.DataFrame, pd.DataFrame, pd.Series, dict, dict, dict]] = (
             dict()
         )

@@ -679,9 +682,6 @@ class FeaturesEnricher(TransformerMixin):
         self.__set_select_features(select_features)
         self.dump_input(X, y, self.eval_set)

-        if _num_samples(drop_duplicates(X)) > Dataset.MAX_ROWS:
-            raise ValidationError(self.bundle.get("dataset_too_many_rows_registered").format(Dataset.MAX_ROWS))
-
         self.__inner_fit(
             X,
             y,

@@ -2049,6 +2049,9 @@ class FeaturesEnricher(TransformerMixin):
             generated_features.extend(generator.generated_features)

         normalizer = Normalizer(self.bundle, self.logger)
+        # TODO restore these properties from the server
+        normalizer.true_one_hot_groups = self.true_one_hot_groups
+        normalizer.pseudo_one_hot_groups = self.pseudo_one_hot_groups
         df, search_keys, generated_features = normalizer.normalize(df, search_keys, generated_features)
         columns_renaming = normalizer.columns_renaming

@@ -2664,6 +2667,9 @@ if response.status_code == 200:
             generated_features.extend(generator.generated_features)

         normalizer = Normalizer(self.bundle, self.logger)
+        # TODO restore these properties from the server
+        normalizer.true_one_hot_groups = self.true_one_hot_groups
+        normalizer.pseudo_one_hot_groups = self.pseudo_one_hot_groups
         df, search_keys, generated_features = normalizer.normalize(df, search_keys, generated_features)
         columns_renaming = normalizer.columns_renaming

@@ -2831,85 +2837,103 @@ if response.status_code == 200:
         del df
         gc.collect()

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        def invoke_validation(df: pd.DataFrame):
+
+            dataset = Dataset(
+                "sample_" + str(uuid.uuid4()),
+                df=df,
+                meaning_types=meaning_types,
+                search_keys=combined_search_keys,
+                unnest_search_keys=unnest_search_keys,
+                id_columns=self.__get_renamed_id_columns(columns_renaming),
+                date_column=self._get_date_column(search_keys),
+                date_format=self.date_format,
+                sample_config=self.sample_config,
+                rest_client=self.rest_client,
+                logger=self.logger,
+                bundle=self.bundle,
+                warning_callback=self.__log_warning,
+            )
+            dataset.columns_renaming = columns_renaming
+
+            validation_task = self._search_task.validation(
+                self._get_trace_id(),
+                dataset,
+                start_time=start_time,
+                extract_features=True,
+                runtime_parameters=runtime_parameters,
+                exclude_features_sources=exclude_features_sources,
+                metrics_calculation=metrics_calculation,
+                silent_mode=silent_mode,
+                progress_bar=progress_bar,
+                progress_callback=progress_callback,
+            )

-
-
-            dataset,
-            start_time=start_time,
-            extract_features=True,
-            runtime_parameters=runtime_parameters,
-            exclude_features_sources=exclude_features_sources,
-            metrics_calculation=metrics_calculation,
-            silent_mode=silent_mode,
-            progress_bar=progress_bar,
-            progress_callback=progress_callback,
-        )
+            del df, dataset
+            gc.collect()

-
-
+            if not silent_mode:
+                print(self.bundle.get("polling_transform_task").format(validation_task.search_task_id))
+                if not self.__is_registered:
+                    print(self.bundle.get("polling_unregister_information"))

-
-
-        if not
-
+            progress = self.get_progress(validation_task)
+            progress.recalculate_eta(time.time() - start_time)
+            if progress_bar is not None:
+                progress_bar.progress = progress.to_progress_bar()
+            if progress_callback is not None:
+                progress_callback(progress)
+            prev_progress: SearchProgress | None = None
+            polling_period_seconds = 1
+            try:
+                while progress.stage != ProgressStage.DOWNLOADING.value:
+                    if prev_progress is None or prev_progress.percent != progress.percent:
+                        progress.recalculate_eta(time.time() - start_time)
+                    else:
+                        progress.update_eta(prev_progress.eta - polling_period_seconds)
+                    prev_progress = progress
+                    if progress_bar is not None:
+                        progress_bar.progress = progress.to_progress_bar()
+                    if progress_callback is not None:
+                        progress_callback(progress)
+                    if progress.stage == ProgressStage.FAILED.value:
+                        raise Exception(progress.error_message)
+                    time.sleep(polling_period_seconds)
+                    progress = self.get_progress(validation_task)
+            except KeyboardInterrupt as e:
+                print(self.bundle.get("search_stopping"))
+                self.rest_client.stop_search_task_v2(self._get_trace_id(), validation_task.search_task_id)
+                self.logger.warning(f"Search {validation_task.search_task_id} stopped by user")
+                print(self.bundle.get("search_stopped"))
+                raise e

-
-        progress.recalculate_eta(time.time() - start_time)
-        if progress_bar is not None:
-            progress_bar.progress = progress.to_progress_bar()
-        if progress_callback is not None:
-            progress_callback(progress)
-        prev_progress: SearchProgress | None = None
-        polling_period_seconds = 1
-        try:
-            while progress.stage != ProgressStage.DOWNLOADING.value:
-                if prev_progress is None or prev_progress.percent != progress.percent:
-                    progress.recalculate_eta(time.time() - start_time)
-                else:
-                    progress.update_eta(prev_progress.eta - polling_period_seconds)
-                prev_progress = progress
-                if progress_bar is not None:
-                    progress_bar.progress = progress.to_progress_bar()
-                if progress_callback is not None:
-                    progress_callback(progress)
-                if progress.stage == ProgressStage.FAILED.value:
-                    raise Exception(progress.error_message)
-                time.sleep(polling_period_seconds)
-                progress = self.get_progress(validation_task)
-        except KeyboardInterrupt as e:
-            print(self.bundle.get("search_stopping"))
-            self.rest_client.stop_search_task_v2(self._get_trace_id(), validation_task.search_task_id)
-            self.logger.warning(f"Search {validation_task.search_task_id} stopped by user")
-            print(self.bundle.get("search_stopped"))
-            raise e
-
-        validation_task.poll_result(self._get_trace_id(), quiet=True)
-
-        seconds_left = time.time() - start_time
-        progress = SearchProgress(97.0, ProgressStage.DOWNLOADING, seconds_left)
-        if progress_bar is not None:
-            progress_bar.progress = progress.to_progress_bar()
-        if progress_callback is not None:
-            progress_callback(progress)
+            validation_task.poll_result(self._get_trace_id(), quiet=True)

-
-
+            seconds_left = time.time() - start_time
+            progress = SearchProgress(97.0, ProgressStage.DOWNLOADING, seconds_left)
+            if progress_bar is not None:
+                progress_bar.progress = progress.to_progress_bar()
+            if progress_callback is not None:
+                progress_callback(progress)
+
+            if not silent_mode:
+                print(self.bundle.get("transform_start"))
+
+            return validation_task.get_all_validation_raw_features(self._get_trace_id(), metrics_calculation)
+
+        if len(df_without_features) <= Dataset.MAX_ROWS:
+            result_features = invoke_validation(df_without_features)
+        else:
+            self.logger.warning(
+                f"Dataset has more than {Dataset.MAX_ROWS} rows: {len(df_without_features)}, "
+                f"splitting into chunks of {Dataset.MAX_ROWS} rows"
+            )
+            result_features_list = []
+
+            for i in range(0, len(df_without_features), Dataset.MAX_ROWS):
+                chunk = df_without_features.iloc[i : i + Dataset.MAX_ROWS]
+                result_features_list.append(invoke_validation(chunk))
+            result_features = pd.concat(result_features_list)

         # Prepare input DataFrame for __enrich by concatenating generated ids and client features
         df_before_explode = df_before_explode.rename(columns=columns_renaming)
|
|
|
2922
2946
|
axis=1,
|
|
2923
2947
|
).set_index(validated_Xy.index)
|
|
2924
2948
|
|
|
2925
|
-
result_features = validation_task.get_all_validation_raw_features(self._get_trace_id(), metrics_calculation)
|
|
2926
|
-
|
|
2927
2949
|
result = self.__enrich(
|
|
2928
2950
|
combined_df,
|
|
2929
2951
|
result_features,
|
|
@@ -2974,12 +2996,38 @@ if response.status_code == 200:
|
|
|
2974
2996
|
fit_dropped_features = self.fit_dropped_features or file_meta.droppedColumns or []
|
|
2975
2997
|
fit_input_columns = [c.originalName for c in file_meta.columns]
|
|
2976
2998
|
original_dropped_features = [self.fit_columns_renaming.get(c, c) for c in fit_dropped_features]
|
|
2999
|
+
true_one_hot_features = (
|
|
3000
|
+
[f for group in self.true_one_hot_groups.values() for f in group] if self.true_one_hot_groups else []
|
|
3001
|
+
)
|
|
2977
3002
|
new_columns_on_transform = [
|
|
2978
|
-
c
|
|
3003
|
+
c
|
|
3004
|
+
for c in validated_Xy.columns
|
|
3005
|
+
if c not in fit_input_columns and c not in original_dropped_features and c not in true_one_hot_features
|
|
2979
3006
|
]
|
|
2980
3007
|
fit_original_search_keys = self._get_fit_search_keys_with_original_names()
|
|
2981
3008
|
|
|
2982
3009
|
selected_generated_features = [c for c in generated_features if c in self.feature_names_]
|
|
3010
|
+
selected_true_one_hot_features = (
|
|
3011
|
+
[
|
|
3012
|
+
c
|
|
3013
|
+
for cat_feature, group in self.true_one_hot_groups.items()
|
|
3014
|
+
for c in group
|
|
3015
|
+
if cat_feature in self.feature_names_
|
|
3016
|
+
]
|
|
3017
|
+
if self.true_one_hot_groups
|
|
3018
|
+
else []
|
|
3019
|
+
)
|
|
3020
|
+
selected_pseudo_one_hot_features = (
|
|
3021
|
+
[
|
|
3022
|
+
feature
|
|
3023
|
+
for group in self.pseudo_one_hot_groups.values()
|
|
3024
|
+
if any(f in self.feature_names_ for f in group)
|
|
3025
|
+
for feature in group
|
|
3026
|
+
]
|
|
3027
|
+
if self.pseudo_one_hot_groups
|
|
3028
|
+
else []
|
|
3029
|
+
)
|
|
3030
|
+
|
|
2983
3031
|
if keep_input is True:
|
|
2984
3032
|
selected_input_columns = [
|
|
2985
3033
|
c
|
|
@@ -2998,11 +3046,14 @@ if response.status_code == 200:
|
|
|
2998
3046
|
if DEFAULT_INDEX in selected_input_columns:
|
|
2999
3047
|
selected_input_columns.remove(DEFAULT_INDEX)
|
|
3000
3048
|
|
|
3001
|
-
return
|
|
3049
|
+
return (
|
|
3050
|
+
selected_input_columns
|
|
3051
|
+
+ selected_generated_features
|
|
3052
|
+
+ selected_true_one_hot_features
|
|
3053
|
+
+ selected_pseudo_one_hot_features
|
|
3054
|
+
)
|
|
3002
3055
|
|
|
3003
|
-
def _validate_empty_search_keys(
|
|
3004
|
-
self, search_keys: dict[str, SearchKey], is_transform: bool = False
|
|
3005
|
-
):
|
|
3056
|
+
def _validate_empty_search_keys(self, search_keys: dict[str, SearchKey], is_transform: bool = False):
|
|
3006
3057
|
if (search_keys is None or len(search_keys) == 0) and self.country_code is None:
|
|
3007
3058
|
if is_transform:
|
|
3008
3059
|
self.logger.debug("Transform started without search_keys")
|
|
@@ -3169,7 +3220,7 @@ if response.status_code == 200:
|
|
|
3169
3220
|
else:
|
|
3170
3221
|
only_train_df = df
|
|
3171
3222
|
|
|
3172
|
-
self.imbalanced = is_imbalanced(only_train_df, self.model_task_type, self.sample_config, self.bundle)
|
|
3223
|
+
self.imbalanced = is_imbalanced(only_train_df, self.model_task_type, self.sample_config, self.bundle, self.__log_warning)
|
|
3173
3224
|
if self.imbalanced:
|
|
3174
3225
|
# Exclude eval sets from fit because they will be transformed before metrics calculation
|
|
3175
3226
|
df = only_train_df
|
|
@@ -3242,6 +3293,8 @@ if response.status_code == 200:
|
|
|
3242
3293
|
df, self.fit_search_keys, self.fit_generated_features
|
|
3243
3294
|
)
|
|
3244
3295
|
self.fit_columns_renaming = normalizer.columns_renaming
|
|
3296
|
+
self.true_one_hot_groups = normalizer.true_one_hot_groups
|
|
3297
|
+
self.pseudo_one_hot_groups = normalizer.pseudo_one_hot_groups
|
|
3245
3298
|
if normalizer.removed_datetime_features:
|
|
3246
3299
|
self.fit_dropped_features.update(normalizer.removed_datetime_features)
|
|
3247
3300
|
original_removed_datetime_features = [
|
|
@@ -3259,7 +3312,11 @@ if response.status_code == 200:
|
|
|
3259
3312
|
features_columns = [c for c in df.columns if c not in non_feature_columns]
|
|
3260
3313
|
|
|
3261
3314
|
features_to_drop, feature_validator_warnings = FeaturesValidator(self.logger).validate(
|
|
3262
|
-
df,
|
|
3315
|
+
df,
|
|
3316
|
+
features_columns,
|
|
3317
|
+
self.generate_features,
|
|
3318
|
+
self.fit_columns_renaming,
|
|
3319
|
+
[f for group in self.pseudo_one_hot_groups.values() for f in group] if self.pseudo_one_hot_groups else [],
|
|
3263
3320
|
)
|
|
3264
3321
|
if feature_validator_warnings:
|
|
3265
3322
|
for warning in feature_validator_warnings:
|
|
@@ -3822,8 +3879,7 @@ if response.status_code == 200:
|
|
|
3822
3879
|
elif self.columns_for_online_api:
|
|
3823
3880
|
msg = self.bundle.get("oot_with_online_sources_not_supported").format(eval_set_index)
|
|
3824
3881
|
if msg:
|
|
3825
|
-
|
|
3826
|
-
self.logger.warning(msg)
|
|
3882
|
+
self.__log_warning(msg)
|
|
3827
3883
|
df = df[df[EVAL_SET_INDEX] != eval_set_index]
|
|
3828
3884
|
return df
|
|
3829
3885
|
|
|
@@ -4768,7 +4824,7 @@ if response.status_code == 200:
|
|
|
4768
4824
|
elif self.autodetect_search_keys:
|
|
4769
4825
|
valid_search_keys = self.__detect_missing_search_keys(x, valid_search_keys, is_demo_dataset)
|
|
4770
4826
|
|
|
4771
|
-
if all(k == SearchKey.CUSTOM_KEY for k in valid_search_keys.values()):
|
|
4827
|
+
if len(valid_search_keys) > 0 and all(k == SearchKey.CUSTOM_KEY for k in valid_search_keys.values()):
|
|
4772
4828
|
if self.__is_registered:
|
|
4773
4829
|
msg = self.bundle.get("only_custom_keys")
|
|
4774
4830
|
else:
|
|
@@ -5027,37 +5083,55 @@ if response.status_code == 200:
|
|
|
5027
5083
|
X_ = X_.to_frame()
|
|
5028
5084
|
|
|
5029
5085
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
|
5030
|
-
|
|
5031
|
-
|
|
5086
|
+
x_file_name = f"{tmp_dir}/x.parquet"
|
|
5087
|
+
X_.to_parquet(x_file_name, compression="zstd")
|
|
5088
|
+
uploading_file_size = Path(x_file_name).stat().st_size
|
|
5089
|
+
if uploading_file_size > Dataset.MAX_UPLOADING_FILE_SIZE:
|
|
5090
|
+
self.logger.warning(
|
|
5091
|
+
f"Uploading file x.parquet is too large: {uploading_file_size} bytes. Skip it"
|
|
5092
|
+
)
|
|
5093
|
+
return
|
|
5094
|
+
x_digest_sha256 = file_hash(x_file_name)
|
|
5032
5095
|
if self.rest_client.is_file_uploaded(trace_id_, x_digest_sha256):
|
|
5033
5096
|
self.logger.info(
|
|
5034
5097
|
f"File x.parquet was already uploaded with digest {x_digest_sha256}, skipping"
|
|
5035
5098
|
)
|
|
5036
5099
|
else:
|
|
5037
|
-
self.rest_client.dump_input_file(
|
|
5038
|
-
trace_id_, f"{tmp_dir}/x.parquet", "x.parquet", x_digest_sha256
|
|
5039
|
-
)
|
|
5100
|
+
self.rest_client.dump_input_file(trace_id_, x_file_name, "x.parquet", x_digest_sha256)
|
|
5040
5101
|
|
|
5041
5102
|
if y_ is not None:
|
|
5042
5103
|
if isinstance(y_, pd.Series):
|
|
5043
5104
|
y_ = y_.to_frame()
|
|
5044
|
-
|
|
5045
|
-
|
|
5105
|
+
y_file_name = f"{tmp_dir}/y.parquet"
|
|
5106
|
+
y_.to_parquet(y_file_name, compression="zstd")
|
|
5107
|
+
uploading_file_size = Path(y_file_name).stat().st_size
|
|
5108
|
+
if uploading_file_size > Dataset.MAX_UPLOADING_FILE_SIZE:
|
|
5109
|
+
self.logger.warning(
|
|
5110
|
+
f"Uploading file y.parquet is too large: {uploading_file_size} bytes. Skip it"
|
|
5111
|
+
)
|
|
5112
|
+
return
|
|
5113
|
+
y_digest_sha256 = file_hash(y_file_name)
|
|
5046
5114
|
if self.rest_client.is_file_uploaded(trace_id_, y_digest_sha256):
|
|
5047
5115
|
self.logger.info(
|
|
5048
5116
|
f"File y.parquet was already uploaded with digest {y_digest_sha256}, skipping"
|
|
5049
5117
|
)
|
|
5050
5118
|
else:
|
|
5051
|
-
self.rest_client.dump_input_file(
|
|
5052
|
-
trace_id_, f"{tmp_dir}/y.parquet", "y.parquet", y_digest_sha256
|
|
5053
|
-
)
|
|
5119
|
+
self.rest_client.dump_input_file(trace_id_, y_file_name, "y.parquet", y_digest_sha256)
|
|
5054
5120
|
|
|
5055
5121
|
if eval_set_ is not None and len(eval_set_) > 0:
|
|
5056
5122
|
for idx, (eval_x_, eval_y_) in enumerate(eval_set_):
|
|
5057
5123
|
if isinstance(eval_x_, pd.Series):
|
|
5058
5124
|
eval_x_ = eval_x_.to_frame()
|
|
5059
|
-
|
|
5060
|
-
|
|
5125
|
+
eval_x_file_name = f"{tmp_dir}/eval_x_{idx}.parquet"
|
|
5126
|
+
eval_x_.to_parquet(eval_x_file_name, compression="zstd")
|
|
5127
|
+
uploading_file_size = Path(eval_x_file_name).stat().st_size
|
|
5128
|
+
if uploading_file_size > Dataset.MAX_UPLOADING_FILE_SIZE:
|
|
5129
|
+
self.logger.warning(
|
|
5130
|
+
f"Uploading file eval_x_{idx}.parquet is too large: "
|
|
5131
|
+
f"{uploading_file_size} bytes. Skip it"
|
|
5132
|
+
)
|
|
5133
|
+
return
|
|
5134
|
+
eval_x_digest_sha256 = file_hash(eval_x_file_name)
|
|
5061
5135
|
if self.rest_client.is_file_uploaded(trace_id_, eval_x_digest_sha256):
|
|
5062
5136
|
self.logger.info(
|
|
5063
5137
|
f"File eval_x_{idx}.parquet was already uploaded with"
|
|
@@ -5066,15 +5140,23 @@ if response.status_code == 200:
|
|
|
5066
5140
|
else:
|
|
5067
5141
|
self.rest_client.dump_input_file(
|
|
5068
5142
|
trace_id_,
|
|
5069
|
-
|
|
5143
|
+
eval_x_file_name,
|
|
5070
5144
|
f"eval_x_{idx}.parquet",
|
|
5071
5145
|
eval_x_digest_sha256,
|
|
5072
5146
|
)
|
|
5073
5147
|
|
|
5074
5148
|
if isinstance(eval_y_, pd.Series):
|
|
5075
5149
|
eval_y_ = eval_y_.to_frame()
|
|
5076
|
-
|
|
5077
|
-
|
|
5150
|
+
eval_y_file_name = f"{tmp_dir}/eval_y_{idx}.parquet"
|
|
5151
|
+
eval_y_.to_parquet(eval_y_file_name, compression="zstd")
|
|
5152
|
+
uploading_file_size = Path(eval_y_file_name).stat().st_size
|
|
5153
|
+
if uploading_file_size > Dataset.MAX_UPLOADING_FILE_SIZE:
|
|
5154
|
+
self.logger.warning(
|
|
5155
|
+
f"Uploading file eval_y_{idx}.parquet is too large: "
|
|
5156
|
+
f"{uploading_file_size} bytes. Skip it"
|
|
5157
|
+
)
|
|
5158
|
+
return
|
|
5159
|
+
eval_y_digest_sha256 = file_hash(eval_y_file_name)
|
|
5078
5160
|
if self.rest_client.is_file_uploaded(trace_id_, eval_y_digest_sha256):
|
|
5079
5161
|
self.logger.info(
|
|
5080
5162
|
f"File eval_y_{idx}.parquet was already uploaded"
|
|
@@ -5083,7 +5165,7 @@ if response.status_code == 200:
|
|
|
5083
5165
|
else:
|
|
5084
5166
|
self.rest_client.dump_input_file(
|
|
5085
5167
|
trace_id_,
|
|
5086
|
-
|
|
5168
|
+
eval_y_file_name,
|
|
5087
5169
|
f"eval_y_{idx}.parquet",
|
|
5088
5170
|
eval_y_digest_sha256,
|
|
5089
5171
|
)
|
|
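The `dump_input` hunks above now write each frame to parquet first and compare the on-disk size with the new `Dataset.MAX_UPLOADING_FILE_SIZE` (512 Mb) before hashing and uploading. A rough standalone sketch of that guard, assuming pandas with a parquet engine installed; `upload` is a hypothetical stand-in for `rest_client.dump_input_file`:

```python
import tempfile
from pathlib import Path

import pandas as pd

MAX_UPLOADING_FILE_SIZE = 536_870_912  # 512 Mb, as added to Dataset in this diff


def upload(path: str, name: str) -> None:
    # Hypothetical stand-in for rest_client.dump_input_file(trace_id, path, name, digest).
    print(f"uploading {name} from {path}")


def dump_frame(df: pd.DataFrame, name: str) -> None:
    with tempfile.TemporaryDirectory() as tmp_dir:
        file_name = f"{tmp_dir}/{name}"
        df.to_parquet(file_name, compression="zstd")
        size = Path(file_name).stat().st_size
        if size > MAX_UPLOADING_FILE_SIZE:
            # Oversized files are skipped instead of being uploaded.
            print(f"{name} is too large: {size} bytes, skipping upload")
            return
        upload(file_name, name)


dump_frame(pd.DataFrame({"x": range(100)}), "x.parquet")
```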
{upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/normalizer/normalize_utils.py

@@ -26,6 +26,7 @@ from upgini.utils import find_numbers_with_decimal_comma
 from upgini.utils.country_utils import CountrySearchKeyConverter
 from upgini.utils.datetime_utils import DateTimeConverter
 from upgini.utils.ip_utils import IpSearchKeyConverter
+from upgini.utils.one_hot_encoder import OneHotDecoder
 from upgini.utils.phone_utils import PhoneSearchKeyConverter
 from upgini.utils.postal_code_utils import PostalCodeSearchKeyConverter

@@ -45,6 +46,8 @@ class Normalizer:
         self.search_keys = {}
         self.generated_features = []
         self.removed_datetime_features = []
+        self.true_one_hot_groups: dict[str, list[str]] | None = None
+        self.pseudo_one_hot_groups: dict[str, list[str]] | None = None

     def normalize(
         self, df: pd.DataFrame, search_keys: Dict[str, SearchKey], generated_features: List[str]

@@ -53,6 +56,9 @@ class Normalizer:
         self.generated_features = generated_features.copy()

         df = df.copy()
+
+        df = self._convert_one_hot_encoded_columns(df)
+
         df = self._rename_columns(df)

         df = self._remove_dates_from_features(df)

@@ -77,6 +83,15 @@ class Normalizer:

         return df, self.search_keys, self.generated_features

+    def _convert_one_hot_encoded_columns(self, df: pd.DataFrame):
+        if self.true_one_hot_groups is not None or self.pseudo_one_hot_groups is not None:
+            df = OneHotDecoder.decode_with_cached_groups(
+                df, self.true_one_hot_groups, self.pseudo_one_hot_groups
+            )
+        else:
+            df, self.true_one_hot_groups, self.pseudo_one_hot_groups = OneHotDecoder.decode(df)
+        return df
+
     def _rename_columns(self, df: pd.DataFrame):
         # logger.info("Replace restricted symbols in column names")
         new_columns = []
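The normalizer now decodes one-hot groups before column renaming: on fit the groups are detected from the data and cached, and on transform the cached groups are replayed through `decode_with_cached_groups`. A small sketch of that fit/transform asymmetry, using the `OneHotDecoder` module added later in this diff and made-up data:

```python
import pandas as pd

from upgini.utils.one_hot_encoder import OneHotDecoder

# Fit-time: groups are detected from the data and cached (on the enricher / normalizer).
train = pd.DataFrame(
    {
        "cat1": [1, 1, 0, 0, 0, 0],
        "cat2": [0, 0, 1, 1, 0, 0],
        "cat3": [0, 0, 0, 0, 1, 1],
        "num": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
    }
)
train_decoded, true_groups, pseudo_groups = OneHotDecoder.decode(train)
print(true_groups)  # {'cat': ['cat1', 'cat2', 'cat3']}

# Transform-time: the cached groups are replayed instead of being re-detected.
test = train.head(2)
test_decoded = OneHotDecoder.decode_with_cached_groups(test, true_groups, pseudo_groups)
print(list(test_decoded))  # ['num', 'cat']
```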
{upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/resource_bundle/strings.properties

@@ -176,7 +176,8 @@ dataset_invalid_multiclass_target=Unexpected dtype of target for multiclass task
 dataset_invalid_regression_target=Unexpected dtype of target for regression task type: {}. Expected float
 dataset_invalid_timeseries_target=Unexpected dtype of target for timeseries task type: {}. Expected float
 dataset_to_many_multiclass_targets=The number of target classes {} exceeds the allowed threshold: {}. Please, correct your data and try again
-dataset_rarest_class_less_min=Count of rows with the rarest class `{}` is {}, minimum count must be > {} for each class
+dataset_rarest_class_less_min=Count of rows with the rarest class `{}` is {}, minimum count must be > {} for each class
+#\nPlease, remove rows with rarest class from your dataframe
 dataset_rarest_class_less_threshold=Target is imbalanced and will be undersampled to the rarest class. Frequency of the rarest class `{}` is {}\nMinimum number of observations for each class to avoid undersampling {} ({}%)
 dataset_date_features=Columns {} is a datetime or period type but not used as a search key, removed from X
 dataset_too_many_features=Too many features. Maximum number of features is {}

@@ -231,7 +232,8 @@ limited_int_multiclass_reason=integer-like values with limited unique values obs
 all_ok_community_invite=❓ Support request
 too_small_for_metrics=Your train dataset or one of eval datasets contains less than 500 rows. For such dataset Upgini will not calculate accuracy metrics. Please increase the number of rows in the training dataset to calculate accuracy metrics
 imbalance_multiclass=Class {0} is on 25% quantile of classes distribution ({1} records in train dataset). \nDownsample classes with records more than {1}.
-
+rare_target_classes_drop=Drop rare target classes with <0.01% freq: {}
+imbalanced_target=Target is imbalanced and will be undersampled. Frequency of the rarest class `{}` is {}
 loss_selection_info=Using loss `{}` for feature selection
 loss_calc_metrics_info=Using loss `{}` for metrics calculation with default estimator
 forced_balance_undersample=For quick data retrieval, your dataset has been sampled. To use data search without data sampling please contact support (sales@upgini.com)
{upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/utils/datetime_utils.py

@@ -8,7 +8,7 @@ from dateutil.relativedelta import relativedelta
 from pandas.api.types import is_numeric_dtype

 from upgini.errors import ValidationError
-from upgini.metadata import EVAL_SET_INDEX, SearchKey
+from upgini.metadata import CURRENT_DATE_COL, EVAL_SET_INDEX, SearchKey
 from upgini.resource_bundle import ResourceBundle, get_custom_bundle
 from upgini.utils.base_search_key_detector import BaseSearchKeyDetector

@@ -418,12 +418,12 @@ def is_dates_distribution_valid(
     except Exception:
         pass

-    if maybe_date_col is None:
-        return
+    if maybe_date_col is None or maybe_date_col == CURRENT_DATE_COL:
+        return True

     # Don't check if date column is constant
     if X[maybe_date_col].nunique() <= 1:
-        return
+        return True

     if isinstance(X[maybe_date_col].dtype, pd.PeriodDtype):
         dates = X[maybe_date_col].dt.to_timestamp().dt.date
{upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/utils/features_validator.py

@@ -23,12 +23,18 @@ class FeaturesValidator:
         features: List[str],
         features_for_generate: Optional[List[str]] = None,
         columns_renaming: Optional[Dict[str, str]] = None,
+        pseudo_one_hot_encoded_features: Optional[List[str]] = None,
     ) -> Tuple[List[str], List[str]]:
-        one_hot_encoded_features = []
         empty_or_constant_features = []
         high_cardinality_features = []
         warnings = []

+        pseudo_one_hot_encoded_features = [
+            renamed
+            for renamed, original in columns_renaming.items()
+            if original in pseudo_one_hot_encoded_features or []
+        ]
+
         for f in features:
             column = df[f]
             if is_object_dtype(column):

@@ -38,20 +44,11 @@ class FeaturesValidator:

             if len(value_counts) == 1:
                 empty_or_constant_features.append(f)
-            elif most_frequent_percent >= 0.99:
-
-                one_hot_encoded_features.append(f)
-            else:
-                empty_or_constant_features.append(f)
+            elif most_frequent_percent >= 0.99 and f not in pseudo_one_hot_encoded_features:
+                empty_or_constant_features.append(f)

         columns_renaming = columns_renaming or {}

-        if one_hot_encoded_features and len(one_hot_encoded_features) > 1:
-            msg = bundle.get("one_hot_encoded_features").format(
-                [columns_renaming.get(f, f) for f in one_hot_encoded_features]
-            )
-            warnings.append(msg)
-
         if empty_or_constant_features:
             msg = bundle.get("empty_or_contant_features").format(
                 [columns_renaming.get(f, f) for f in empty_or_constant_features]

@@ -98,41 +95,3 @@ class FeaturesValidator:
     @staticmethod
     def find_constant_features(df: pd.DataFrame) -> List[str]:
         return [i for i in df if df[i].nunique() <= 1]
-
-    @staticmethod
-    def is_one_hot_encoded(series: pd.Series) -> bool:
-        try:
-            # All rows should be the same type
-            if series.apply(lambda x: type(x)).nunique() != 1:
-                return False
-
-            # First, handle string representations of True/False
-            series_copy = series.copy()
-            if series_copy.dtype == "object" or series_copy.dtype == "string":
-                # Convert string representations of boolean values to numeric
-                series_copy = series_copy.astype(str).str.strip().str.lower()
-                series_copy = series_copy.replace({"true": "1", "false": "0"})
-
-            # Column contains only 0 and 1 (as strings or numbers or booleans)
-            series_copy = series_copy.astype(float)
-            if set(series_copy.unique()) != {0.0, 1.0}:
-                return False
-
-            series_copy = series_copy.astype(int)
-
-            # Column doesn't contain any NaN, np.NaN, space, null, etc.
-            if not (series_copy.isin([0, 1])).all():
-                return False
-
-            vc = series_copy.value_counts()
-            # Column should contain both 0 and 1
-            if len(vc) != 2:
-                return False
-
-            # Minority class is 1
-            if vc[1] >= vc[0]:
-                return False
-
-            return True
-        except ValueError:
-            return False
upgini-1.2.146a9/src/upgini/utils/one_hot_encoder.py (new file)

@@ -0,0 +1,215 @@
+import numpy as np
+import pandas as pd
+
+
+class OneHotDecoder:
+
+    def encode(df: pd.DataFrame, category_columns: list[str]) -> pd.DataFrame:
+        """
+        Encode categorical columns into one-hot encoded columns.
+        """
+        return pd.get_dummies(df, columns=category_columns, prefix_sep="")
+
+    def decode(df: pd.DataFrame) -> (pd.DataFrame, dict[str, list[str]], dict[str, list[str]]):
+        """
+        Detect one-hot encoded column groups and collapse each group into a single
+        categorical column. For each row, all active bits in the group are
+        encoded into a unique category using a bitmask over the group's columns
+        (ordered by numeric suffix). Rows with zero active bits are set to NA.
+
+        Returns a new DataFrame with transformed columns.
+        """
+        one_hot_candidate_groups = OneHotDecoder._group_one_hot_fast(df.columns)
+        true_one_hot_groups: dict[str, list[str]] = {}
+
+        # 1) Detect valid one-hot groups (filter candidates by column-level checks)
+        for group_name, column_candidates in one_hot_candidate_groups.items():
+            group_columns: list[str] = []
+            for column in column_candidates:
+                value_counts = df[column].value_counts(dropna=False, normalize=True)
+                most_frequent_percent = value_counts.iloc[0]
+                if most_frequent_percent >= 0.6 and OneHotDecoder._is_one_hot_encoded(df[column]):
+                    group_columns.append(column)
+            if len(group_columns) > 1:
+                true_one_hot_groups[group_name] = group_columns
+
+        # 2) Transform: replace each detected group with one categorical column
+        if not true_one_hot_groups:
+            return df, {}, {}
+
+        result_df = df.copy()
+        pseudo_one_hot_groups: dict[str, list[str]] = {}
+        for group_name, group_columns in true_one_hot_groups.items():
+            sub = result_df[group_columns].copy()
+            for c in group_columns:
+                s = sub[c]
+                if s.dtype == "object" or s.dtype == "string":
+                    s = s.astype(str).str.strip().str.lower()
+                    s = s.replace({"true": "1", "false": "0"})
+                    s = pd.to_numeric(s, errors="coerce")
+                sub[c] = s
+
+            # 3) Find pseudo one-hot encoded columns when there are multiple ones in one row
+            if any(sub.sum(axis=1) > 1):
+                pseudo_one_hot_groups[group_name] = group_columns
+                result_df[group_columns] = result_df[group_columns].astype("string")
+                continue
+
+            # Coerce values to numeric 0/1 handling common textual forms
+            sub = sub.fillna(0.0)
+            # Binarize strictly to 0/1
+            bin_values = (sub.to_numpy() > 0.5).astype(np.int64)
+            # Map single active bit to exact numeric suffix from column name
+            row_sums = bin_values.sum(axis=1)
+            argmax_idx = bin_values.argmax(axis=1)
+            suffix_arr = np.array(
+                [int(OneHotDecoder._split_prefix_numeric_suffix(col)[1]) for col in group_columns], dtype=np.int64
+            )
+            codes = suffix_arr[argmax_idx]
+            categorical_series = pd.Series(codes, index=sub.index)
+            # Keep only rows with exactly one active bit; else set NA
+            categorical_series = categorical_series.where(row_sums == 1, other=pd.NA)
+            # Use pandas nullable integer dtype to keep NA with integer codes
+            result_df[group_name] = categorical_series.astype("Int64").astype("string")
+
+            # Drop original one-hot columns of the group
+            result_df = result_df.drop(columns=group_columns)
+
+        for group_name in pseudo_one_hot_groups:
+            del true_one_hot_groups[group_name]
+
+        return result_df, true_one_hot_groups, pseudo_one_hot_groups
+
+    def decode_with_cached_groups(
+        df: pd.DataFrame, true_one_hot_groups: dict[str, list[str]], pseudo_one_hot_groups: dict[str, list[str]]
+    ) -> pd.DataFrame:
+        """
+        Decode one-hot encoded columns with cached groups.
+        """
+        result_df = df.copy()
+        # 1. Transform regular one-hot groups back to categorical
+        if true_one_hot_groups:
+            for group_name, group_columns in true_one_hot_groups.items():
+                sub = result_df[group_columns].copy()
+                for c in group_columns:
+                    s = sub[c]
+                    if s.dtype == "object" or s.dtype == "string":
+                        s = s.astype(str).str.strip().str.lower()
+                        s = s.replace({"true": "1", "false": "0"})
+                        s = pd.to_numeric(s, errors="coerce")
+                    sub[c] = s
+                sub = sub.fillna(0.0)
+                bin_values = (sub.to_numpy() > 0.5).astype(np.int64)
+                row_sums = bin_values.sum(axis=1)
+                argmax_idx = bin_values.argmax(axis=1)
+                suffix_arr = np.array(
+                    [int(OneHotDecoder._split_prefix_numeric_suffix(col)[1]) for col in group_columns], dtype=np.int64
+                )
+                codes = suffix_arr[argmax_idx]
+                categorical_series = pd.Series(codes, index=sub.index)
+                categorical_series = categorical_series.where(row_sums == 1, other=pd.NA)
+                result_df[group_name] = categorical_series.astype("Int64").astype("string")
+                result_df = result_df.drop(columns=group_columns)
+        # 2. Convert pseudo-one-hot features to string
+        if pseudo_one_hot_groups:
+            for _, group_columns in pseudo_one_hot_groups.items():
+                result_df[group_columns] = result_df[group_columns].astype("string")
+        return result_df
+
+    @staticmethod
+    def _is_ascii_digit(c: str) -> bool:
+        return "0" <= c <= "9"
+
+    @staticmethod
+    def _split_prefix_numeric_suffix(name: str) -> tuple[str, str] | None:
+        """
+        Return (prefix, numeric_suffix) if name ends with ASCII digits and isn't all digits.
+        Otherwise None.
+        """
+        if not name or not OneHotDecoder._is_ascii_digit(name[-1]):
+            return None
+        i = len(name) - 1
+        while i >= 0 and OneHotDecoder._is_ascii_digit(name[i]):
+            i -= 1
+        if i < 0:
+            # Entire string is digits -> reject
+            return None
+        return name[: i + 1], name[i + 1 :]  # prefix, suffix
+
+    @staticmethod
+    def _group_one_hot_fast(
+        candidates: list[str], min_group_size: int = 2, require_consecutive: bool = True
+    ) -> dict[str, list[str]]:
+        """
+        Group OHE-like columns by (prefix, numeric_suffix).
+        - Only keeps groups with size >= min_group_size (default: 2).
+        - Each group's columns are sorted by numeric suffix (int).
+        Returns: {prefix: [col_names_sorted]}.
+        """
+        if min_group_size < 2:
+            raise ValueError("min_group_size must be >= 2.")
+
+        # 1) Collect by prefix with parsed numeric suffix
+        groups: dict[str, list[(int, str)]] = {}
+        for s in candidates:
+            sp = OneHotDecoder._split_prefix_numeric_suffix(s)
+            if sp is None:
+                continue
+            prefix, sfx = sp
+            groups.setdefault(prefix, []).append((int(sfx), s))
+
+        # 2) Filter and finalize
+        out: dict[str, list[str]] = {}
+        for prefix, pairs in groups.items():
+            if len(pairs) < min_group_size:
+                continue
+            pairs.sort(key=lambda t: t[0])  # sort by numeric suffix
+            if require_consecutive:
+                suffixes = [num for num, _ in pairs]
+                # no duplicates
+                if len(suffixes) != len(set(suffixes)):
+                    continue
+                # strictly consecutive run with step=1
+                start = suffixes[0]
+                if any(suffixes[i] != start + i for i in range(len(suffixes))):
+                    continue
+            out[prefix] = [name for _, name in pairs]
+
+        return out
+
+    def _is_one_hot_encoded(series: pd.Series) -> bool:
+        try:
+            # All rows should be the same type
+            if series.apply(lambda x: type(x)).nunique() != 1:
+                return False
+
+            # First, handle string representations of True/False
+            series_copy = series.copy()
+            if series_copy.dtype == "object" or series_copy.dtype == "string":
+                # Convert string representations of boolean values to numeric
+                series_copy = series_copy.astype(str).str.strip().str.lower()
+                series_copy = series_copy.replace({"true": "1", "false": "0"})
+
+            # Column contains only 0 and 1 (as strings or numbers or booleans)
+            series_copy = series_copy.astype(float)
+            if set(series_copy.unique()) != {0.0, 1.0}:
+                return False
+
+            series_copy = series_copy.astype(int)
+
+            # Column doesn't contain any NaN, np.NaN, space, null, etc.
+            if not (series_copy.isin([0, 1])).all():
+                return False
+
+            vc = series_copy.value_counts()
+            # Column should contain both 0 and 1
+            if len(vc) != 2:
+                return False
+
+            # Minority class is 1
+            if vc[1] >= vc[0]:
+                return False
+
+            return True
+        except ValueError:
+            return False
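A short usage sketch of the new module (made-up data): `tag1`, `tag2`, `tag3` form a valid one-hot group and are collapsed into a single categorical column `tag`, while `flag1`/`flag2` have rows with two active bits at once, so they are kept as a pseudo group and only cast to strings:

```python
import pandas as pd

from upgini.utils.one_hot_encoder import OneHotDecoder

df = pd.DataFrame(
    {
        # True one-hot group: at most one active bit per row, 1 is the minority value.
        "tag1": [1, 0, 0, 0, 0, 0],
        "tag2": [0, 1, 1, 0, 0, 0],
        "tag3": [0, 0, 0, 1, 0, 0],
        # Pseudo group: the first row has both bits set.
        "flag1": [1, 1, 0, 0, 0, 0],
        "flag2": [1, 0, 1, 0, 0, 0],
    }
)

decoded, true_groups, pseudo_groups = OneHotDecoder.decode(df)
print(true_groups)              # {'tag': ['tag1', 'tag2', 'tag3']}
print(pseudo_groups)            # {'flag': ['flag1', 'flag2']}
print(decoded["tag"].tolist())  # ['1', '2', '2', '3', <NA>, <NA>] - zero-bit rows become NA
```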
{upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/utils/target_utils.py

@@ -117,6 +117,7 @@ def is_imbalanced(
     task_type: ModelTaskType,
     sample_config: SampleConfig,
     bundle: ResourceBundle,
+    warning_callback: Optional[Callable] = None,
 ) -> bool:
     if task_type is None or not task_type.is_classification():
         return False

@@ -144,7 +145,8 @@ def is_imbalanced(
        msg = bundle.get("dataset_rarest_class_less_min").format(
            min_class_value, min_class_count, MIN_TARGET_CLASS_ROWS
        )
-
+        if warning_callback is not None:
+            warning_callback(msg)

    min_class_percent = IMBALANCE_THESHOLD / target_classes_count
    min_class_threshold = min_class_percent * count

@@ -196,14 +198,34 @@ def balance_undersample(
     resampled_data = df
     df = df.copy().sort_values(by=SYSTEM_RECORD_ID)
     if task_type == ModelTaskType.MULTICLASS:
+        # Remove rare classes which have <0.01% of samples
+        total_count = len(df)
+        # Always preserve two most frequent classes, even if they are rare
+        top_two_classes = list(vc.index[:2])
+        rare_classes_all = [cls for cls, cnt in vc.items() if cnt / total_count < 0.0001]
+        rare_classes = [cls for cls in rare_classes_all if cls not in top_two_classes]
+        if rare_classes:
+            msg = bundle.get("rare_target_classes_drop").format(rare_classes)
+            logger.warning(msg)
+            warning_callback(msg)
+            df = df[~df[target_column].isin(rare_classes)]
+            target = df[target_column].copy()
+            vc = target.value_counts()
+            max_class_value = vc.index[0]
+            min_class_value = vc.index[len(vc) - 1]
+            max_class_count = vc[max_class_value]
+            min_class_count = vc[min_class_value]
+            num_classes = len(vc)
+
         if len(df) > multiclass_min_sample_threshold and max_class_count > (
             min_class_count * multiclass_bootstrap_loops
         ):

             msg = bundle.get("imbalanced_target").format(min_class_value, min_class_count)
-            logger.warning(msg)
             if warning_callback is not None:
                 warning_callback(msg)
+            else:
+                logger.warning(msg)

             sample_strategy = dict()
             for class_value in vc.index:

@@ -228,9 +250,10 @@ def balance_undersample(
         resampled_data = df[df[SYSTEM_RECORD_ID].isin(new_x[SYSTEM_RECORD_ID])]
     elif len(df) > binary_min_sample_threshold:
         msg = bundle.get("imbalanced_target").format(min_class_value, min_class_count)
-        logger.warning(msg)
         if warning_callback is not None:
             warning_callback(msg)
+        else:
+            logger.warning(msg)

         # fill up to min_sample_threshold by majority class
         minority_class = df[df[target_column] == min_class_value]
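The multiclass branch above now drops target classes covering less than 0.01% of rows (while always keeping the two most frequent classes) before the undersampling strategy is computed. A standalone sketch of just that filtering step, on made-up data and with a plain Series instead of the full `balance_undersample` signature:

```python
import pandas as pd

RARE_CLASS_FREQ = 0.0001  # <0.01% of rows, the threshold used in the diff above


def drop_rare_classes(target: pd.Series) -> pd.Series:
    vc = target.value_counts()
    total_count = len(target)
    # Always preserve the two most frequent classes, even if they are rare.
    top_two_classes = list(vc.index[:2])
    rare_classes = [
        cls
        for cls, cnt in vc.items()
        if cnt / total_count < RARE_CLASS_FREQ and cls not in top_two_classes
    ]
    if rare_classes:
        print(f"Dropping rare target classes: {rare_classes}")
        target = target[~target.isin(rare_classes)]
    return target


# 20_000 "a" rows, 9_999 "b" rows and a single "c" row: "c" is below the 0.01% threshold.
target = pd.Series(["a"] * 20_000 + ["b"] * 9_999 + ["c"])
print(drop_rare_classes(target).value_counts().to_dict())  # {'a': 20000, 'b': 9999}
```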
upgini-1.2.146a2/src/upgini/__about__.py (deleted)

@@ -1 +0,0 @@
-__version__ = "1.2.146a2"
All other files listed above (+0 -0) are unchanged between the two versions.