upgini 1.2.146a2__tar.gz → 1.2.146a9__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83)
  1. {upgini-1.2.146a2 → upgini-1.2.146a9}/PKG-INFO +1 -1
  2. upgini-1.2.146a9/src/upgini/__about__.py +1 -0
  3. {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/dataset.py +5 -4
  4. {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/features_enricher.py +187 -105
  5. {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/normalizer/normalize_utils.py +15 -0
  6. {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/resource_bundle/strings.properties +4 -2
  7. {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/utils/datetime_utils.py +4 -4
  8. {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/utils/features_validator.py +9 -50
  9. upgini-1.2.146a9/src/upgini/utils/one_hot_encoder.py +215 -0
  10. {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/utils/target_utils.py +26 -3
  11. upgini-1.2.146a2/src/upgini/__about__.py +0 -1
  12. {upgini-1.2.146a2 → upgini-1.2.146a9}/.gitignore +0 -0
  13. {upgini-1.2.146a2 → upgini-1.2.146a9}/LICENSE +0 -0
  14. {upgini-1.2.146a2 → upgini-1.2.146a9}/README.md +0 -0
  15. {upgini-1.2.146a2 → upgini-1.2.146a9}/pyproject.toml +0 -0
  16. {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/__init__.py +0 -0
  17. {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/ads.py +0 -0
  18. {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/ads_management/__init__.py +0 -0
  19. {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/ads_management/ads_manager.py +0 -0
  20. {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/autofe/__init__.py +0 -0
  21. {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/autofe/all_operators.py +0 -0
  22. {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/autofe/binary.py +0 -0
  23. {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/autofe/date.py +0 -0
  24. {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/autofe/feature.py +0 -0
  25. {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/autofe/groupby.py +0 -0
  26. {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/autofe/operator.py +0 -0
  27. {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/autofe/timeseries/__init__.py +0 -0
  28. {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/autofe/timeseries/base.py +0 -0
  29. {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/autofe/timeseries/cross.py +0 -0
  30. {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/autofe/timeseries/delta.py +0 -0
  31. {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/autofe/timeseries/lag.py +0 -0
  32. {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/autofe/timeseries/roll.py +0 -0
  33. {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/autofe/timeseries/trend.py +0 -0
  34. {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/autofe/timeseries/volatility.py +0 -0
  35. {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/autofe/unary.py +0 -0
  36. {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/autofe/utils.py +0 -0
  37. {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/autofe/vector.py +0 -0
  38. {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/data_source/__init__.py +0 -0
  39. {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/data_source/data_source_publisher.py +0 -0
  40. {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/errors.py +0 -0
  41. {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/http.py +0 -0
  42. {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/mdc/__init__.py +0 -0
  43. {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/mdc/context.py +0 -0
  44. {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/metadata.py +0 -0
  45. {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/metrics.py +0 -0
  46. {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/normalizer/__init__.py +0 -0
  47. {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/resource_bundle/__init__.py +0 -0
  48. {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/resource_bundle/exceptions.py +0 -0
  49. {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/resource_bundle/strings_widget.properties +0 -0
  50. {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/sampler/__init__.py +0 -0
  51. {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/sampler/base.py +0 -0
  52. {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/sampler/random_under_sampler.py +0 -0
  53. {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/sampler/utils.py +0 -0
  54. {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/search_task.py +0 -0
  55. {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/spinner.py +0 -0
  56. {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/utils/Roboto-Regular.ttf +0 -0
  57. {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/utils/__init__.py +0 -0
  58. {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/utils/base_search_key_detector.py +0 -0
  59. {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/utils/blocked_time_series.py +0 -0
  60. {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/utils/config.py +0 -0
  61. {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/utils/country_utils.py +0 -0
  62. {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/utils/custom_loss_utils.py +0 -0
  63. {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/utils/cv_utils.py +0 -0
  64. {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/utils/deduplicate_utils.py +0 -0
  65. {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/utils/display_utils.py +0 -0
  66. {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/utils/email_utils.py +0 -0
  67. {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/utils/fallback_progress_bar.py +0 -0
  68. {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/utils/feature_info.py +0 -0
  69. {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/utils/format.py +0 -0
  70. {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/utils/hash_utils.py +0 -0
  71. {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/utils/ip_utils.py +0 -0
  72. {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/utils/mstats.py +0 -0
  73. {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/utils/phone_utils.py +0 -0
  74. {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/utils/postal_code_utils.py +0 -0
  75. {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/utils/progress_bar.py +0 -0
  76. {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/utils/psi.py +0 -0
  77. {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/utils/sample_utils.py +0 -0
  78. {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/utils/sklearn_ext.py +0 -0
  79. {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/utils/sort.py +0 -0
  80. {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/utils/track_info.py +0 -0
  81. {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/utils/ts_utils.py +0 -0
  82. {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/utils/warning_counter.py +0 -0
  83. {upgini-1.2.146a2 → upgini-1.2.146a9}/src/upgini/version_validator.py +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: upgini
- Version: 1.2.146a2
+ Version: 1.2.146a9
  Summary: Intelligent data search & enrichment for Machine Learning
  Project-URL: Bug Reports, https://github.com/upgini/upgini/issues
  Project-URL: Homepage, https://upgini.com/
@@ -0,0 +1 @@
+ __version__ = "1.2.146a9"
@@ -58,7 +58,7 @@ class Dataset:
  MAX_ROWS = 3_000_000
  MIN_SUPPORTED_DATE_TS = 946684800000 # 2000-01-01
  MAX_FEATURES_COUNT = 3500
- MAX_UPLOADING_FILE_SIZE = 268435456 # 256 Mb
+ MAX_UPLOADING_FILE_SIZE = 536_870_912 # 512 Mb
  MAX_STRING_FEATURE_LENGTH = 24573
  FORCE_SAMPLE_SIZE = 7_000

@@ -304,10 +304,11 @@ class Dataset:
  ):
  keys_to_validate.remove(ipv4_column)

- mandatory_columns = [target]
+ mandatory_columns = {target} if target is not None else set()
  columns_to_validate = mandatory_columns.copy()
- columns_to_validate.extend(keys_to_validate)
- columns_to_validate = set([i for i in columns_to_validate if i is not None])
+ columns_to_validate.update(keys_to_validate)
+ if len(columns_to_validate) == 0:
+ return

  nrows = len(self.data)
  validation_stats = {}
@@ -11,6 +11,7 @@ import uuid
  from collections import Counter
  from copy import deepcopy
  from dataclasses import dataclass
+ from pathlib import Path
  from threading import Thread
  from typing import Any, Callable

@@ -277,6 +278,8 @@ class FeaturesEnricher(TransformerMixin):
  self.autodetected_search_keys: dict[str, SearchKey] | None = None
  self.imbalanced = False
  self.fit_select_features = True
+ self.true_one_hot_groups: dict[str, list[str]] | None = None
+ self.pseudo_one_hot_groups: dict[str, list[str]] | None = None
  self.__cached_sampled_datasets: dict[str, tuple[pd.DataFrame, pd.DataFrame, pd.Series, dict, dict, dict]] = (
  dict()
  )
@@ -679,9 +682,6 @@ class FeaturesEnricher(TransformerMixin):
  self.__set_select_features(select_features)
  self.dump_input(X, y, self.eval_set)

- if _num_samples(drop_duplicates(X)) > Dataset.MAX_ROWS:
- raise ValidationError(self.bundle.get("dataset_too_many_rows_registered").format(Dataset.MAX_ROWS))
-
  self.__inner_fit(
  X,
  y,
@@ -2049,6 +2049,9 @@ class FeaturesEnricher(TransformerMixin):
  generated_features.extend(generator.generated_features)

  normalizer = Normalizer(self.bundle, self.logger)
+ # TODO restore these properties from the server
+ normalizer.true_one_hot_groups = self.true_one_hot_groups
+ normalizer.pseudo_one_hot_groups = self.pseudo_one_hot_groups
  df, search_keys, generated_features = normalizer.normalize(df, search_keys, generated_features)
  columns_renaming = normalizer.columns_renaming

@@ -2664,6 +2667,9 @@ if response.status_code == 200:
  generated_features.extend(generator.generated_features)

  normalizer = Normalizer(self.bundle, self.logger)
+ # TODO restore these properties from the server
+ normalizer.true_one_hot_groups = self.true_one_hot_groups
+ normalizer.pseudo_one_hot_groups = self.pseudo_one_hot_groups
  df, search_keys, generated_features = normalizer.normalize(df, search_keys, generated_features)
  columns_renaming = normalizer.columns_renaming

@@ -2831,85 +2837,103 @@ if response.status_code == 200:
  del df
  gc.collect()

- dataset = Dataset(
- "sample_" + str(uuid.uuid4()),
- df=df_without_features,
- meaning_types=meaning_types,
- search_keys=combined_search_keys,
- unnest_search_keys=unnest_search_keys,
- id_columns=self.__get_renamed_id_columns(columns_renaming),
- date_column=self._get_date_column(search_keys),
- date_format=self.date_format,
- sample_config=self.sample_config,
- rest_client=self.rest_client,
- logger=self.logger,
- bundle=self.bundle,
- warning_callback=self.__log_warning,
- )
- dataset.columns_renaming = columns_renaming
+ def invoke_validation(df: pd.DataFrame):
+
+ dataset = Dataset(
+ "sample_" + str(uuid.uuid4()),
+ df=df,
+ meaning_types=meaning_types,
+ search_keys=combined_search_keys,
+ unnest_search_keys=unnest_search_keys,
+ id_columns=self.__get_renamed_id_columns(columns_renaming),
+ date_column=self._get_date_column(search_keys),
+ date_format=self.date_format,
+ sample_config=self.sample_config,
+ rest_client=self.rest_client,
+ logger=self.logger,
+ bundle=self.bundle,
+ warning_callback=self.__log_warning,
+ )
+ dataset.columns_renaming = columns_renaming
+
+ validation_task = self._search_task.validation(
+ self._get_trace_id(),
+ dataset,
+ start_time=start_time,
+ extract_features=True,
+ runtime_parameters=runtime_parameters,
+ exclude_features_sources=exclude_features_sources,
+ metrics_calculation=metrics_calculation,
+ silent_mode=silent_mode,
+ progress_bar=progress_bar,
+ progress_callback=progress_callback,
+ )

- validation_task = self._search_task.validation(
- self._get_trace_id(),
- dataset,
- start_time=start_time,
- extract_features=True,
- runtime_parameters=runtime_parameters,
- exclude_features_sources=exclude_features_sources,
- metrics_calculation=metrics_calculation,
- silent_mode=silent_mode,
- progress_bar=progress_bar,
- progress_callback=progress_callback,
- )
+ del df, dataset
+ gc.collect()

- del df_without_features, dataset
- gc.collect()
+ if not silent_mode:
+ print(self.bundle.get("polling_transform_task").format(validation_task.search_task_id))
+ if not self.__is_registered:
+ print(self.bundle.get("polling_unregister_information"))

- if not silent_mode:
- print(self.bundle.get("polling_transform_task").format(validation_task.search_task_id))
- if not self.__is_registered:
- print(self.bundle.get("polling_unregister_information"))
+ progress = self.get_progress(validation_task)
+ progress.recalculate_eta(time.time() - start_time)
+ if progress_bar is not None:
+ progress_bar.progress = progress.to_progress_bar()
+ if progress_callback is not None:
+ progress_callback(progress)
+ prev_progress: SearchProgress | None = None
+ polling_period_seconds = 1
+ try:
+ while progress.stage != ProgressStage.DOWNLOADING.value:
+ if prev_progress is None or prev_progress.percent != progress.percent:
+ progress.recalculate_eta(time.time() - start_time)
+ else:
+ progress.update_eta(prev_progress.eta - polling_period_seconds)
+ prev_progress = progress
+ if progress_bar is not None:
+ progress_bar.progress = progress.to_progress_bar()
+ if progress_callback is not None:
+ progress_callback(progress)
+ if progress.stage == ProgressStage.FAILED.value:
+ raise Exception(progress.error_message)
+ time.sleep(polling_period_seconds)
+ progress = self.get_progress(validation_task)
+ except KeyboardInterrupt as e:
+ print(self.bundle.get("search_stopping"))
+ self.rest_client.stop_search_task_v2(self._get_trace_id(), validation_task.search_task_id)
+ self.logger.warning(f"Search {validation_task.search_task_id} stopped by user")
+ print(self.bundle.get("search_stopped"))
+ raise e

- progress = self.get_progress(validation_task)
- progress.recalculate_eta(time.time() - start_time)
- if progress_bar is not None:
- progress_bar.progress = progress.to_progress_bar()
- if progress_callback is not None:
- progress_callback(progress)
- prev_progress: SearchProgress | None = None
- polling_period_seconds = 1
- try:
- while progress.stage != ProgressStage.DOWNLOADING.value:
- if prev_progress is None or prev_progress.percent != progress.percent:
- progress.recalculate_eta(time.time() - start_time)
- else:
- progress.update_eta(prev_progress.eta - polling_period_seconds)
- prev_progress = progress
- if progress_bar is not None:
- progress_bar.progress = progress.to_progress_bar()
- if progress_callback is not None:
- progress_callback(progress)
- if progress.stage == ProgressStage.FAILED.value:
- raise Exception(progress.error_message)
- time.sleep(polling_period_seconds)
- progress = self.get_progress(validation_task)
- except KeyboardInterrupt as e:
- print(self.bundle.get("search_stopping"))
- self.rest_client.stop_search_task_v2(self._get_trace_id(), validation_task.search_task_id)
- self.logger.warning(f"Search {validation_task.search_task_id} stopped by user")
- print(self.bundle.get("search_stopped"))
- raise e
-
- validation_task.poll_result(self._get_trace_id(), quiet=True)
-
- seconds_left = time.time() - start_time
- progress = SearchProgress(97.0, ProgressStage.DOWNLOADING, seconds_left)
- if progress_bar is not None:
- progress_bar.progress = progress.to_progress_bar()
- if progress_callback is not None:
- progress_callback(progress)
+ validation_task.poll_result(self._get_trace_id(), quiet=True)

- if not silent_mode:
- print(self.bundle.get("transform_start"))
+ seconds_left = time.time() - start_time
+ progress = SearchProgress(97.0, ProgressStage.DOWNLOADING, seconds_left)
+ if progress_bar is not None:
+ progress_bar.progress = progress.to_progress_bar()
+ if progress_callback is not None:
+ progress_callback(progress)
+
+ if not silent_mode:
+ print(self.bundle.get("transform_start"))
+
+ return validation_task.get_all_validation_raw_features(self._get_trace_id(), metrics_calculation)
+
+ if len(df_without_features) <= Dataset.MAX_ROWS:
+ result_features = invoke_validation(df_without_features)
+ else:
+ self.logger.warning(
+ f"Dataset has more than {Dataset.MAX_ROWS} rows: {len(df_without_features)}, "
+ f"splitting into chunks of {Dataset.MAX_ROWS} rows"
+ )
+ result_features_list = []
+
+ for i in range(0, len(df_without_features), Dataset.MAX_ROWS):
+ chunk = df_without_features.iloc[i : i + Dataset.MAX_ROWS]
+ result_features_list.append(invoke_validation(chunk))
+ result_features = pd.concat(result_features_list)

  # Prepare input DataFrame for __enrich by concatenating generated ids and client features
  df_before_explode = df_before_explode.rename(columns=columns_renaming)
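The hunk above moves the validation request into a local invoke_validation helper so that inputs larger than Dataset.MAX_ROWS can be submitted in row-bounded chunks and the per-chunk results concatenated. A minimal standalone sketch of that chunking pattern (process_chunk and the MAX_ROWS value below are illustrative stand-ins, not the upgini API):

```python
import pandas as pd

MAX_ROWS = 3  # stand-in for Dataset.MAX_ROWS


def process_chunk(chunk: pd.DataFrame) -> pd.DataFrame:
    # placeholder for invoke_validation(chunk)
    return chunk.assign(processed=True)


df = pd.DataFrame({"value": range(8)})

if len(df) <= MAX_ROWS:
    result = process_chunk(df)
else:
    # slice the frame into fixed-size chunks and stitch the results back together
    parts = [
        process_chunk(df.iloc[i : i + MAX_ROWS])
        for i in range(0, len(df), MAX_ROWS)
    ]
    result = pd.concat(parts)

print(len(result))  # 8 rows, processed in three chunks
```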
@@ -2922,8 +2946,6 @@ if response.status_code == 200:
  axis=1,
  ).set_index(validated_Xy.index)

- result_features = validation_task.get_all_validation_raw_features(self._get_trace_id(), metrics_calculation)
-
  result = self.__enrich(
  combined_df,
  result_features,
@@ -2974,12 +2996,38 @@ if response.status_code == 200:
  fit_dropped_features = self.fit_dropped_features or file_meta.droppedColumns or []
  fit_input_columns = [c.originalName for c in file_meta.columns]
  original_dropped_features = [self.fit_columns_renaming.get(c, c) for c in fit_dropped_features]
+ true_one_hot_features = (
+ [f for group in self.true_one_hot_groups.values() for f in group] if self.true_one_hot_groups else []
+ )
  new_columns_on_transform = [
- c for c in validated_Xy.columns if c not in fit_input_columns and c not in original_dropped_features
+ c
+ for c in validated_Xy.columns
+ if c not in fit_input_columns and c not in original_dropped_features and c not in true_one_hot_features
  ]
  fit_original_search_keys = self._get_fit_search_keys_with_original_names()

  selected_generated_features = [c for c in generated_features if c in self.feature_names_]
+ selected_true_one_hot_features = (
+ [
+ c
+ for cat_feature, group in self.true_one_hot_groups.items()
+ for c in group
+ if cat_feature in self.feature_names_
+ ]
+ if self.true_one_hot_groups
+ else []
+ )
+ selected_pseudo_one_hot_features = (
+ [
+ feature
+ for group in self.pseudo_one_hot_groups.values()
+ if any(f in self.feature_names_ for f in group)
+ for feature in group
+ ]
+ if self.pseudo_one_hot_groups
+ else []
+ )
+
  if keep_input is True:
  selected_input_columns = [
  c
@@ -2998,11 +3046,14 @@ if response.status_code == 200:
  if DEFAULT_INDEX in selected_input_columns:
  selected_input_columns.remove(DEFAULT_INDEX)

- return selected_input_columns + selected_generated_features
+ return (
+ selected_input_columns
+ + selected_generated_features
+ + selected_true_one_hot_features
+ + selected_pseudo_one_hot_features
+ )

- def _validate_empty_search_keys(
- self, search_keys: dict[str, SearchKey], is_transform: bool = False
- ):
+ def _validate_empty_search_keys(self, search_keys: dict[str, SearchKey], is_transform: bool = False):
  if (search_keys is None or len(search_keys) == 0) and self.country_code is None:
  if is_transform:
  self.logger.debug("Transform started without search_keys")
@@ -3169,7 +3220,7 @@ if response.status_code == 200:
  else:
  only_train_df = df

- self.imbalanced = is_imbalanced(only_train_df, self.model_task_type, self.sample_config, self.bundle)
+ self.imbalanced = is_imbalanced(only_train_df, self.model_task_type, self.sample_config, self.bundle, self.__log_warning)
  if self.imbalanced:
  # Exclude eval sets from fit because they will be transformed before metrics calculation
  df = only_train_df
@@ -3242,6 +3293,8 @@ if response.status_code == 200:
  df, self.fit_search_keys, self.fit_generated_features
  )
  self.fit_columns_renaming = normalizer.columns_renaming
+ self.true_one_hot_groups = normalizer.true_one_hot_groups
+ self.pseudo_one_hot_groups = normalizer.pseudo_one_hot_groups
  if normalizer.removed_datetime_features:
  self.fit_dropped_features.update(normalizer.removed_datetime_features)
  original_removed_datetime_features = [
@@ -3259,7 +3312,11 @@ if response.status_code == 200:
  features_columns = [c for c in df.columns if c not in non_feature_columns]

  features_to_drop, feature_validator_warnings = FeaturesValidator(self.logger).validate(
- df, features_columns, self.generate_features, self.fit_columns_renaming
+ df,
+ features_columns,
+ self.generate_features,
+ self.fit_columns_renaming,
+ [f for group in self.pseudo_one_hot_groups.values() for f in group] if self.pseudo_one_hot_groups else [],
  )
  if feature_validator_warnings:
  for warning in feature_validator_warnings:
@@ -3822,8 +3879,7 @@ if response.status_code == 200:
  elif self.columns_for_online_api:
  msg = self.bundle.get("oot_with_online_sources_not_supported").format(eval_set_index)
  if msg:
- print(msg)
- self.logger.warning(msg)
+ self.__log_warning(msg)
  df = df[df[EVAL_SET_INDEX] != eval_set_index]
  return df

@@ -4768,7 +4824,7 @@ if response.status_code == 200:
  elif self.autodetect_search_keys:
  valid_search_keys = self.__detect_missing_search_keys(x, valid_search_keys, is_demo_dataset)

- if all(k == SearchKey.CUSTOM_KEY for k in valid_search_keys.values()):
+ if len(valid_search_keys) > 0 and all(k == SearchKey.CUSTOM_KEY for k in valid_search_keys.values()):
  if self.__is_registered:
  msg = self.bundle.get("only_custom_keys")
  else:
@@ -5027,37 +5083,55 @@ if response.status_code == 200:
  X_ = X_.to_frame()

  with tempfile.TemporaryDirectory() as tmp_dir:
- X_.to_parquet(f"{tmp_dir}/x.parquet", compression="zstd")
- x_digest_sha256 = file_hash(f"{tmp_dir}/x.parquet")
+ x_file_name = f"{tmp_dir}/x.parquet"
+ X_.to_parquet(x_file_name, compression="zstd")
+ uploading_file_size = Path(x_file_name).stat().st_size
+ if uploading_file_size > Dataset.MAX_UPLOADING_FILE_SIZE:
+ self.logger.warning(
+ f"Uploading file x.parquet is too large: {uploading_file_size} bytes. Skip it"
+ )
+ return
+ x_digest_sha256 = file_hash(x_file_name)
  if self.rest_client.is_file_uploaded(trace_id_, x_digest_sha256):
  self.logger.info(
  f"File x.parquet was already uploaded with digest {x_digest_sha256}, skipping"
  )
  else:
- self.rest_client.dump_input_file(
- trace_id_, f"{tmp_dir}/x.parquet", "x.parquet", x_digest_sha256
- )
+ self.rest_client.dump_input_file(trace_id_, x_file_name, "x.parquet", x_digest_sha256)

  if y_ is not None:
  if isinstance(y_, pd.Series):
  y_ = y_.to_frame()
- y_.to_parquet(f"{tmp_dir}/y.parquet", compression="zstd")
- y_digest_sha256 = file_hash(f"{tmp_dir}/y.parquet")
+ y_file_name = f"{tmp_dir}/y.parquet"
+ y_.to_parquet(y_file_name, compression="zstd")
+ uploading_file_size = Path(y_file_name).stat().st_size
+ if uploading_file_size > Dataset.MAX_UPLOADING_FILE_SIZE:
+ self.logger.warning(
+ f"Uploading file y.parquet is too large: {uploading_file_size} bytes. Skip it"
+ )
+ return
+ y_digest_sha256 = file_hash(y_file_name)
  if self.rest_client.is_file_uploaded(trace_id_, y_digest_sha256):
  self.logger.info(
  f"File y.parquet was already uploaded with digest {y_digest_sha256}, skipping"
  )
  else:
- self.rest_client.dump_input_file(
- trace_id_, f"{tmp_dir}/y.parquet", "y.parquet", y_digest_sha256
- )
+ self.rest_client.dump_input_file(trace_id_, y_file_name, "y.parquet", y_digest_sha256)

  if eval_set_ is not None and len(eval_set_) > 0:
  for idx, (eval_x_, eval_y_) in enumerate(eval_set_):
  if isinstance(eval_x_, pd.Series):
  eval_x_ = eval_x_.to_frame()
- eval_x_.to_parquet(f"{tmp_dir}/eval_x_{idx}.parquet", compression="zstd")
- eval_x_digest_sha256 = file_hash(f"{tmp_dir}/eval_x_{idx}.parquet")
+ eval_x_file_name = f"{tmp_dir}/eval_x_{idx}.parquet"
+ eval_x_.to_parquet(eval_x_file_name, compression="zstd")
+ uploading_file_size = Path(eval_x_file_name).stat().st_size
+ if uploading_file_size > Dataset.MAX_UPLOADING_FILE_SIZE:
+ self.logger.warning(
+ f"Uploading file eval_x_{idx}.parquet is too large: "
+ f"{uploading_file_size} bytes. Skip it"
+ )
+ return
+ eval_x_digest_sha256 = file_hash(eval_x_file_name)
  if self.rest_client.is_file_uploaded(trace_id_, eval_x_digest_sha256):
  self.logger.info(
  f"File eval_x_{idx}.parquet was already uploaded with"
@@ -5066,15 +5140,23 @@ if response.status_code == 200:
  else:
  self.rest_client.dump_input_file(
  trace_id_,
- f"{tmp_dir}/eval_x_{idx}.parquet",
+ eval_x_file_name,
  f"eval_x_{idx}.parquet",
  eval_x_digest_sha256,
  )

  if isinstance(eval_y_, pd.Series):
  eval_y_ = eval_y_.to_frame()
- eval_y_.to_parquet(f"{tmp_dir}/eval_y_{idx}.parquet", compression="zstd")
- eval_y_digest_sha256 = file_hash(f"{tmp_dir}/eval_y_{idx}.parquet")
+ eval_y_file_name = f"{tmp_dir}/eval_y_{idx}.parquet"
+ eval_y_.to_parquet(eval_y_file_name, compression="zstd")
+ uploading_file_size = Path(eval_y_file_name).stat().st_size
+ if uploading_file_size > Dataset.MAX_UPLOADING_FILE_SIZE:
+ self.logger.warning(
+ f"Uploading file eval_y_{idx}.parquet is too large: "
+ f"{uploading_file_size} bytes. Skip it"
+ )
+ return
+ eval_y_digest_sha256 = file_hash(eval_y_file_name)
  if self.rest_client.is_file_uploaded(trace_id_, eval_y_digest_sha256):
  self.logger.info(
  f"File eval_y_{idx}.parquet was already uploaded"
@@ -5083,7 +5165,7 @@ if response.status_code == 200:
  else:
  self.rest_client.dump_input_file(
  trace_id_,
- f"{tmp_dir}/eval_y_{idx}.parquet",
+ eval_y_file_name,
  f"eval_y_{idx}.parquet",
  eval_y_digest_sha256,
  )
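These dump_input changes write each frame to parquet, compare the on-disk size with Dataset.MAX_UPLOADING_FILE_SIZE, and skip the upload when the file is too large. A hedged standalone sketch of that guard (upload and the limit constant below are illustrative stand-ins, not the upgini REST client):

```python
import tempfile
from pathlib import Path

import pandas as pd

MAX_UPLOADING_FILE_SIZE = 536_870_912  # 512 Mb, mirrors Dataset.MAX_UPLOADING_FILE_SIZE


def upload(path: str) -> None:
    # placeholder for rest_client.dump_input_file(...)
    print(f"uploading {path}")


def dump_frame(df: pd.DataFrame, name: str) -> None:
    with tempfile.TemporaryDirectory() as tmp_dir:
        file_name = f"{tmp_dir}/{name}.parquet"
        df.to_parquet(file_name, compression="zstd")
        # check the written file size before doing any further work
        size = Path(file_name).stat().st_size
        if size > MAX_UPLOADING_FILE_SIZE:
            print(f"{name}.parquet is too large: {size} bytes, skipping")
            return
        upload(file_name)


dump_frame(pd.DataFrame({"a": [1, 2, 3]}), "x")
```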
@@ -26,6 +26,7 @@ from upgini.utils import find_numbers_with_decimal_comma
  from upgini.utils.country_utils import CountrySearchKeyConverter
  from upgini.utils.datetime_utils import DateTimeConverter
  from upgini.utils.ip_utils import IpSearchKeyConverter
+ from upgini.utils.one_hot_encoder import OneHotDecoder
  from upgini.utils.phone_utils import PhoneSearchKeyConverter
  from upgini.utils.postal_code_utils import PostalCodeSearchKeyConverter

@@ -45,6 +46,8 @@ class Normalizer:
  self.search_keys = {}
  self.generated_features = []
  self.removed_datetime_features = []
+ self.true_one_hot_groups: dict[str, list[str]] | None = None
+ self.pseudo_one_hot_groups: dict[str, list[str]] | None = None

  def normalize(
  self, df: pd.DataFrame, search_keys: Dict[str, SearchKey], generated_features: List[str]
@@ -53,6 +56,9 @@ class Normalizer:
  self.generated_features = generated_features.copy()

  df = df.copy()
+
+ df = self._convert_one_hot_encoded_columns(df)
+
  df = self._rename_columns(df)

  df = self._remove_dates_from_features(df)
@@ -77,6 +83,15 @@ class Normalizer:

  return df, self.search_keys, self.generated_features

+ def _convert_one_hot_encoded_columns(self, df: pd.DataFrame):
+ if self.true_one_hot_groups is not None or self.pseudo_one_hot_groups is not None:
+ df = OneHotDecoder.decode_with_cached_groups(
+ df, self.true_one_hot_groups, self.pseudo_one_hot_groups
+ )
+ else:
+ df, self.true_one_hot_groups, self.pseudo_one_hot_groups = OneHotDecoder.decode(df)
+ return df
+
  def _rename_columns(self, df: pd.DataFrame):
  # logger.info("Replace restricted symbols in column names")
  new_columns = []
@@ -176,7 +176,8 @@ dataset_invalid_multiclass_target=Unexpected dtype of target for multiclass task
  dataset_invalid_regression_target=Unexpected dtype of target for regression task type: {}. Expected float
  dataset_invalid_timeseries_target=Unexpected dtype of target for timeseries task type: {}. Expected float
  dataset_to_many_multiclass_targets=The number of target classes {} exceeds the allowed threshold: {}. Please, correct your data and try again
- dataset_rarest_class_less_min=Count of rows with the rarest class `{}` is {}, minimum count must be > {} for each class\nPlease, remove rows with rarest class from your dataframe
+ dataset_rarest_class_less_min=Count of rows with the rarest class `{}` is {}, minimum count must be > {} for each class
+ #\nPlease, remove rows with rarest class from your dataframe
  dataset_rarest_class_less_threshold=Target is imbalanced and will be undersampled to the rarest class. Frequency of the rarest class `{}` is {}\nMinimum number of observations for each class to avoid undersampling {} ({}%)
  dataset_date_features=Columns {} is a datetime or period type but not used as a search key, removed from X
  dataset_too_many_features=Too many features. Maximum number of features is {}
@@ -231,7 +232,8 @@ limited_int_multiclass_reason=integer-like values with limited unique values obs
  all_ok_community_invite=❓ Support request
  too_small_for_metrics=Your train dataset or one of eval datasets contains less than 500 rows. For such dataset Upgini will not calculate accuracy metrics. Please increase the number of rows in the training dataset to calculate accuracy metrics
  imbalance_multiclass=Class {0} is on 25% quantile of classes distribution ({1} records in train dataset). \nDownsample classes with records more than {1}.
- imbalanced_target=\nTarget is imbalanced and will be undersampled. Frequency of the rarest class `{}` is {}
+ rare_target_classes_drop=Drop rare target classes with <0.01% freq: {}
+ imbalanced_target=Target is imbalanced and will be undersampled. Frequency of the rarest class `{}` is {}
  loss_selection_info=Using loss `{}` for feature selection
  loss_calc_metrics_info=Using loss `{}` for metrics calculation with default estimator
  forced_balance_undersample=For quick data retrieval, your dataset has been sampled. To use data search without data sampling please contact support (sales@upgini.com)
@@ -8,7 +8,7 @@ from dateutil.relativedelta import relativedelta
  from pandas.api.types import is_numeric_dtype

  from upgini.errors import ValidationError
- from upgini.metadata import EVAL_SET_INDEX, SearchKey
+ from upgini.metadata import CURRENT_DATE_COL, EVAL_SET_INDEX, SearchKey
  from upgini.resource_bundle import ResourceBundle, get_custom_bundle
  from upgini.utils.base_search_key_detector import BaseSearchKeyDetector

@@ -418,12 +418,12 @@ def is_dates_distribution_valid(
  except Exception:
  pass

- if maybe_date_col is None:
- return
+ if maybe_date_col is None or maybe_date_col == CURRENT_DATE_COL:
+ return True

  # Don't check if date column is constant
  if X[maybe_date_col].nunique() <= 1:
- return
+ return True

  if isinstance(X[maybe_date_col].dtype, pd.PeriodDtype):
  dates = X[maybe_date_col].dt.to_timestamp().dt.date
@@ -23,12 +23,18 @@ class FeaturesValidator:
  features: List[str],
  features_for_generate: Optional[List[str]] = None,
  columns_renaming: Optional[Dict[str, str]] = None,
+ pseudo_one_hot_encoded_features: Optional[List[str]] = None,
  ) -> Tuple[List[str], List[str]]:
- one_hot_encoded_features = []
  empty_or_constant_features = []
  high_cardinality_features = []
  warnings = []

+ pseudo_one_hot_encoded_features = [
+ renamed
+ for renamed, original in columns_renaming.items()
+ if original in pseudo_one_hot_encoded_features or []
+ ]
+
  for f in features:
  column = df[f]
  if is_object_dtype(column):
@@ -38,20 +44,11 @@ class FeaturesValidator:

  if len(value_counts) == 1:
  empty_or_constant_features.append(f)
- elif most_frequent_percent >= 0.99:
- if self.is_one_hot_encoded(column):
- one_hot_encoded_features.append(f)
- else:
- empty_or_constant_features.append(f)
+ elif most_frequent_percent >= 0.99 and f not in pseudo_one_hot_encoded_features:
+ empty_or_constant_features.append(f)

  columns_renaming = columns_renaming or {}

- if one_hot_encoded_features and len(one_hot_encoded_features) > 1:
- msg = bundle.get("one_hot_encoded_features").format(
- [columns_renaming.get(f, f) for f in one_hot_encoded_features]
- )
- warnings.append(msg)
-
  if empty_or_constant_features:
  msg = bundle.get("empty_or_contant_features").format(
  [columns_renaming.get(f, f) for f in empty_or_constant_features]
@@ -98,41 +95,3 @@ class FeaturesValidator:
  @staticmethod
  def find_constant_features(df: pd.DataFrame) -> List[str]:
  return [i for i in df if df[i].nunique() <= 1]
-
- @staticmethod
- def is_one_hot_encoded(series: pd.Series) -> bool:
- try:
- # All rows should be the same type
- if series.apply(lambda x: type(x)).nunique() != 1:
- return False
-
- # First, handle string representations of True/False
- series_copy = series.copy()
- if series_copy.dtype == "object" or series_copy.dtype == "string":
- # Convert string representations of boolean values to numeric
- series_copy = series_copy.astype(str).str.strip().str.lower()
- series_copy = series_copy.replace({"true": "1", "false": "0"})
-
- # Column contains only 0 and 1 (as strings or numbers or booleans)
- series_copy = series_copy.astype(float)
- if set(series_copy.unique()) != {0.0, 1.0}:
- return False
-
- series_copy = series_copy.astype(int)
-
- # Column doesn't contain any NaN, np.NaN, space, null, etc.
- if not (series_copy.isin([0, 1])).all():
- return False
-
- vc = series_copy.value_counts()
- # Column should contain both 0 and 1
- if len(vc) != 2:
- return False
-
- # Minority class is 1
- if vc[1] >= vc[0]:
- return False
-
- return True
- except ValueError:
- return False
@@ -0,0 +1,215 @@
+ import numpy as np
+ import pandas as pd
+
+
+ class OneHotDecoder:
+
+ def encode(df: pd.DataFrame, category_columns: list[str]) -> pd.DataFrame:
+ """
+ Encode categorical columns into one-hot encoded columns.
+ """
+ return pd.get_dummies(df, columns=category_columns, prefix_sep="")
+
+ def decode(df: pd.DataFrame) -> (pd.DataFrame, dict[str, list[str]], dict[str, list[str]]):
+ """
+ Detect one-hot encoded column groups and collapse each group into a single
+ categorical column. For each row, all active bits in the group are
+ encoded into a unique category using a bitmask over the group's columns
+ (ordered by numeric suffix). Rows with zero active bits are set to NA.
+
+ Returns a new DataFrame with transformed columns.
+ """
+ one_hot_candidate_groups = OneHotDecoder._group_one_hot_fast(df.columns)
+ true_one_hot_groups: dict[str, list[str]] = {}
+
+ # 1) Detect valid one-hot groups (filter candidates by column-level checks)
+ for group_name, column_candidates in one_hot_candidate_groups.items():
+ group_columns: list[str] = []
+ for column in column_candidates:
+ value_counts = df[column].value_counts(dropna=False, normalize=True)
+ most_frequent_percent = value_counts.iloc[0]
+ if most_frequent_percent >= 0.6 and OneHotDecoder._is_one_hot_encoded(df[column]):
+ group_columns.append(column)
+ if len(group_columns) > 1:
+ true_one_hot_groups[group_name] = group_columns
+
+ # 2) Transform: replace each detected group with one categorical column
+ if not true_one_hot_groups:
+ return df, {}, {}
+
+ result_df = df.copy()
+ pseudo_one_hot_groups: dict[str, list[str]] = {}
+ for group_name, group_columns in true_one_hot_groups.items():
+ sub = result_df[group_columns].copy()
+ for c in group_columns:
+ s = sub[c]
+ if s.dtype == "object" or s.dtype == "string":
+ s = s.astype(str).str.strip().str.lower()
+ s = s.replace({"true": "1", "false": "0"})
+ s = pd.to_numeric(s, errors="coerce")
+ sub[c] = s
+
+ # 3) Find pseudo one-hot encoded columns when there are multiple ones in one row
+ if any(sub.sum(axis=1) > 1):
+ pseudo_one_hot_groups[group_name] = group_columns
+ result_df[group_columns] = result_df[group_columns].astype("string")
+ continue
+
+ # Coerce values to numeric 0/1 handling common textual forms
+ sub = sub.fillna(0.0)
+ # Binarize strictly to 0/1
+ bin_values = (sub.to_numpy() > 0.5).astype(np.int64)
+ # Map single active bit to exact numeric suffix from column name
+ row_sums = bin_values.sum(axis=1)
+ argmax_idx = bin_values.argmax(axis=1)
+ suffix_arr = np.array(
+ [int(OneHotDecoder._split_prefix_numeric_suffix(col)[1]) for col in group_columns], dtype=np.int64
+ )
+ codes = suffix_arr[argmax_idx]
+ categorical_series = pd.Series(codes, index=sub.index)
+ # Keep only rows with exactly one active bit; else set NA
+ categorical_series = categorical_series.where(row_sums == 1, other=pd.NA)
+ # Use pandas nullable integer dtype to keep NA with integer codes
+ result_df[group_name] = categorical_series.astype("Int64").astype("string")
+
+ # Drop original one-hot columns of the group
+ result_df = result_df.drop(columns=group_columns)
+
+ for group_name in pseudo_one_hot_groups:
+ del true_one_hot_groups[group_name]
+
+ return result_df, true_one_hot_groups, pseudo_one_hot_groups
+
+ def decode_with_cached_groups(
+ df: pd.DataFrame, true_one_hot_groups: dict[str, list[str]], pseudo_one_hot_groups: dict[str, list[str]]
+ ) -> pd.DataFrame:
+ """
+ Decode one-hot encoded columns with cached groups.
+ """
+ result_df = df.copy()
+ # 1. Transform regular one-hot groups back to categorical
+ if true_one_hot_groups:
+ for group_name, group_columns in true_one_hot_groups.items():
+ sub = result_df[group_columns].copy()
+ for c in group_columns:
+ s = sub[c]
+ if s.dtype == "object" or s.dtype == "string":
+ s = s.astype(str).str.strip().str.lower()
+ s = s.replace({"true": "1", "false": "0"})
+ s = pd.to_numeric(s, errors="coerce")
+ sub[c] = s
+ sub = sub.fillna(0.0)
+ bin_values = (sub.to_numpy() > 0.5).astype(np.int64)
+ row_sums = bin_values.sum(axis=1)
+ argmax_idx = bin_values.argmax(axis=1)
+ suffix_arr = np.array(
+ [int(OneHotDecoder._split_prefix_numeric_suffix(col)[1]) for col in group_columns], dtype=np.int64
+ )
+ codes = suffix_arr[argmax_idx]
+ categorical_series = pd.Series(codes, index=sub.index)
+ categorical_series = categorical_series.where(row_sums == 1, other=pd.NA)
+ result_df[group_name] = categorical_series.astype("Int64").astype("string")
+ result_df = result_df.drop(columns=group_columns)
+ # 2. Convert pseudo-one-hot features to string
+ if pseudo_one_hot_groups:
+ for _, group_columns in pseudo_one_hot_groups.items():
+ result_df[group_columns] = result_df[group_columns].astype("string")
+ return result_df
+
+ @staticmethod
+ def _is_ascii_digit(c: str) -> bool:
+ return "0" <= c <= "9"
+
+ @staticmethod
+ def _split_prefix_numeric_suffix(name: str) -> tuple[str, str] | None:
+ """
+ Return (prefix, numeric_suffix) if name ends with ASCII digits and isn't all digits.
+ Otherwise None.
+ """
+ if not name or not OneHotDecoder._is_ascii_digit(name[-1]):
+ return None
+ i = len(name) - 1
+ while i >= 0 and OneHotDecoder._is_ascii_digit(name[i]):
+ i -= 1
+ if i < 0:
+ # Entire string is digits -> reject
+ return None
+ return name[: i + 1], name[i + 1 :] # prefix, suffix
+
+ @staticmethod
+ def _group_one_hot_fast(
+ candidates: list[str], min_group_size: int = 2, require_consecutive: bool = True
+ ) -> dict[str, list[str]]:
+ """
+ Group OHE-like columns by (prefix, numeric_suffix).
+ - Only keeps groups with size >= min_group_size (default: 2).
+ - Each group's columns are sorted by numeric suffix (int).
+ Returns: {prefix: [col_names_sorted]}.
+ """
+ if min_group_size < 2:
+ raise ValueError("min_group_size must be >= 2.")
+
+ # 1) Collect by prefix with parsed numeric suffix
+ groups: dict[str, list[(int, str)]] = {}
+ for s in candidates:
+ sp = OneHotDecoder._split_prefix_numeric_suffix(s)
+ if sp is None:
+ continue
+ prefix, sfx = sp
+ groups.setdefault(prefix, []).append((int(sfx), s))
+
+ # 2) Filter and finalize
+ out: dict[str, list[str]] = {}
+ for prefix, pairs in groups.items():
+ if len(pairs) < min_group_size:
+ continue
+ pairs.sort(key=lambda t: t[0]) # sort by numeric suffix
+ if require_consecutive:
+ suffixes = [num for num, _ in pairs]
+ # no duplicates
+ if len(suffixes) != len(set(suffixes)):
+ continue
+ # strictly consecutive run with step=1
+ start = suffixes[0]
+ if any(suffixes[i] != start + i for i in range(len(suffixes))):
+ continue
+ out[prefix] = [name for _, name in pairs]
+
+ return out
+
+ def _is_one_hot_encoded(series: pd.Series) -> bool:
+ try:
+ # All rows should be the same type
+ if series.apply(lambda x: type(x)).nunique() != 1:
+ return False
+
+ # First, handle string representations of True/False
+ series_copy = series.copy()
+ if series_copy.dtype == "object" or series_copy.dtype == "string":
+ # Convert string representations of boolean values to numeric
+ series_copy = series_copy.astype(str).str.strip().str.lower()
+ series_copy = series_copy.replace({"true": "1", "false": "0"})
+
+ # Column contains only 0 and 1 (as strings or numbers or booleans)
+ series_copy = series_copy.astype(float)
+ if set(series_copy.unique()) != {0.0, 1.0}:
+ return False
+
+ series_copy = series_copy.astype(int)
+
+ # Column doesn't contain any NaN, np.NaN, space, null, etc.
+ if not (series_copy.isin([0, 1])).all():
+ return False
+
+ vc = series_copy.value_counts()
+ # Column should contain both 0 and 1
+ if len(vc) != 2:
+ return False
+
+ # Minority class is 1
+ if vc[1] >= vc[0]:
+ return False
+
+ return True
+ except ValueError:
+ return False
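The new one_hot_encoder.py module above groups columns that share a prefix and carry consecutive numeric suffixes, checks that each column looks one-hot encoded, and collapses every valid group into a single categorical column; groups where a row activates more than one bit are kept as pseudo one-hot groups and only cast to string. A hedged usage sketch with made-up column names (the import path is the one added to normalize_utils.py in this diff):

```python
import pandas as pd

from upgini.utils.one_hot_encoder import OneHotDecoder

# Three dummy columns "color1".."color3" with exactly one active bit per row
# should be detected as a true one-hot group and collapsed into one column.
df = pd.DataFrame(
    {
        "color1": [1, 0, 0, 0, 0],
        "color2": [0, 1, 0, 1, 0],
        "color3": [0, 0, 1, 0, 1],
        "price": [10, 20, 30, 40, 50],
    }
)

decoded, true_groups, pseudo_groups = OneHotDecoder.decode(df)
print(true_groups)    # expected: {"color": ["color1", "color2", "color3"]}
print(pseudo_groups)  # expected: {} (no row has more than one active bit)
print(decoded.columns.tolist())  # "color1..3" replaced by a single "color" column
```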
@@ -117,6 +117,7 @@ def is_imbalanced(
  task_type: ModelTaskType,
  sample_config: SampleConfig,
  bundle: ResourceBundle,
+ warning_callback: Optional[Callable] = None,
  ) -> bool:
  if task_type is None or not task_type.is_classification():
  return False
@@ -144,7 +145,8 @@ def is_imbalanced(
  msg = bundle.get("dataset_rarest_class_less_min").format(
  min_class_value, min_class_count, MIN_TARGET_CLASS_ROWS
  )
- raise ValidationError(msg)
+ if warning_callback is not None:
+ warning_callback(msg)

  min_class_percent = IMBALANCE_THESHOLD / target_classes_count
  min_class_threshold = min_class_percent * count
@@ -196,14 +198,34 @@ def balance_undersample(
  resampled_data = df
  df = df.copy().sort_values(by=SYSTEM_RECORD_ID)
  if task_type == ModelTaskType.MULTICLASS:
+ # Remove rare classes which have <0.01% of samples
+ total_count = len(df)
+ # Always preserve two most frequent classes, even if they are rare
+ top_two_classes = list(vc.index[:2])
+ rare_classes_all = [cls for cls, cnt in vc.items() if cnt / total_count < 0.0001]
+ rare_classes = [cls for cls in rare_classes_all if cls not in top_two_classes]
+ if rare_classes:
+ msg = bundle.get("rare_target_classes_drop").format(rare_classes)
+ logger.warning(msg)
+ warning_callback(msg)
+ df = df[~df[target_column].isin(rare_classes)]
+ target = df[target_column].copy()
+ vc = target.value_counts()
+ max_class_value = vc.index[0]
+ min_class_value = vc.index[len(vc) - 1]
+ max_class_count = vc[max_class_value]
+ min_class_count = vc[min_class_value]
+ num_classes = len(vc)
+
  if len(df) > multiclass_min_sample_threshold and max_class_count > (
  min_class_count * multiclass_bootstrap_loops
  ):

  msg = bundle.get("imbalanced_target").format(min_class_value, min_class_count)
- logger.warning(msg)
  if warning_callback is not None:
  warning_callback(msg)
+ else:
+ logger.warning(msg)

  sample_strategy = dict()
  for class_value in vc.index:
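The multiclass branch above now drops target classes rarer than 0.01% of rows before undersampling, while always preserving the two most frequent classes. A standalone sketch of that filtering step (the threshold constant and column name below are stand-ins for the values hard-coded in target_utils.py):

```python
import pandas as pd

RARE_FRACTION = 0.0001  # 0.01% of rows


def drop_rare_classes(df: pd.DataFrame, target_column: str) -> pd.DataFrame:
    vc = df[target_column].value_counts()
    total_count = len(df)
    # the two most frequent classes are never dropped
    top_two_classes = list(vc.index[:2])
    rare_classes = [
        cls
        for cls, cnt in vc.items()
        if cnt / total_count < RARE_FRACTION and cls not in top_two_classes
    ]
    if rare_classes:
        print(f"Dropping rare target classes: {rare_classes}")
        df = df[~df[target_column].isin(rare_classes)]
    return df


# 20_001 rows: class "c" holds a single row (~0.005%), so it is dropped.
df = pd.DataFrame({"target": ["a"] * 15_000 + ["b"] * 5_000 + ["c"]})
print(drop_rare_classes(df, "target")["target"].value_counts())
```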
@@ -228,9 +250,10 @@ def balance_undersample(
  resampled_data = df[df[SYSTEM_RECORD_ID].isin(new_x[SYSTEM_RECORD_ID])]
  elif len(df) > binary_min_sample_threshold:
  msg = bundle.get("imbalanced_target").format(min_class_value, min_class_count)
- logger.warning(msg)
  if warning_callback is not None:
  warning_callback(msg)
+ else:
+ logger.warning(msg)

  # fill up to min_sample_threshold by majority class
  minority_class = df[df[target_column] == min_class_value]
@@ -1 +0,0 @@
- __version__ = "1.2.146a2"