upgini 1.2.145a4065.dev1.tar.gz → 1.2.146a9.tar.gz

This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in the public registry.
Files changed (83)
  1. {upgini-1.2.145a4065.dev1 → upgini-1.2.146a9}/PKG-INFO +2 -2
  2. upgini-1.2.146a9/src/upgini/__about__.py +1 -0
  3. {upgini-1.2.145a4065.dev1 → upgini-1.2.146a9}/src/upgini/dataset.py +9 -5
  4. {upgini-1.2.145a4065.dev1 → upgini-1.2.146a9}/src/upgini/features_enricher.py +196 -113
  5. {upgini-1.2.145a4065.dev1 → upgini-1.2.146a9}/src/upgini/normalizer/normalize_utils.py +15 -0
  6. {upgini-1.2.145a4065.dev1 → upgini-1.2.146a9}/src/upgini/resource_bundle/strings.properties +5 -3
  7. {upgini-1.2.145a4065.dev1 → upgini-1.2.146a9}/src/upgini/utils/datetime_utils.py +4 -4
  8. {upgini-1.2.145a4065.dev1 → upgini-1.2.146a9}/src/upgini/utils/features_validator.py +9 -50
  9. upgini-1.2.146a9/src/upgini/utils/one_hot_encoder.py +215 -0
  10. {upgini-1.2.145a4065.dev1 → upgini-1.2.146a9}/src/upgini/utils/target_utils.py +26 -3
  11. upgini-1.2.145a4065.dev1/src/upgini/__about__.py +0 -1
  12. {upgini-1.2.145a4065.dev1 → upgini-1.2.146a9}/.gitignore +0 -0
  13. {upgini-1.2.145a4065.dev1 → upgini-1.2.146a9}/LICENSE +0 -0
  14. {upgini-1.2.145a4065.dev1 → upgini-1.2.146a9}/README.md +0 -0
  15. {upgini-1.2.145a4065.dev1 → upgini-1.2.146a9}/pyproject.toml +0 -0
  16. {upgini-1.2.145a4065.dev1 → upgini-1.2.146a9}/src/upgini/__init__.py +0 -0
  17. {upgini-1.2.145a4065.dev1 → upgini-1.2.146a9}/src/upgini/ads.py +0 -0
  18. {upgini-1.2.145a4065.dev1 → upgini-1.2.146a9}/src/upgini/ads_management/__init__.py +0 -0
  19. {upgini-1.2.145a4065.dev1 → upgini-1.2.146a9}/src/upgini/ads_management/ads_manager.py +0 -0
  20. {upgini-1.2.145a4065.dev1 → upgini-1.2.146a9}/src/upgini/autofe/__init__.py +0 -0
  21. {upgini-1.2.145a4065.dev1 → upgini-1.2.146a9}/src/upgini/autofe/all_operators.py +0 -0
  22. {upgini-1.2.145a4065.dev1 → upgini-1.2.146a9}/src/upgini/autofe/binary.py +0 -0
  23. {upgini-1.2.145a4065.dev1 → upgini-1.2.146a9}/src/upgini/autofe/date.py +0 -0
  24. {upgini-1.2.145a4065.dev1 → upgini-1.2.146a9}/src/upgini/autofe/feature.py +0 -0
  25. {upgini-1.2.145a4065.dev1 → upgini-1.2.146a9}/src/upgini/autofe/groupby.py +0 -0
  26. {upgini-1.2.145a4065.dev1 → upgini-1.2.146a9}/src/upgini/autofe/operator.py +0 -0
  27. {upgini-1.2.145a4065.dev1 → upgini-1.2.146a9}/src/upgini/autofe/timeseries/__init__.py +0 -0
  28. {upgini-1.2.145a4065.dev1 → upgini-1.2.146a9}/src/upgini/autofe/timeseries/base.py +0 -0
  29. {upgini-1.2.145a4065.dev1 → upgini-1.2.146a9}/src/upgini/autofe/timeseries/cross.py +0 -0
  30. {upgini-1.2.145a4065.dev1 → upgini-1.2.146a9}/src/upgini/autofe/timeseries/delta.py +0 -0
  31. {upgini-1.2.145a4065.dev1 → upgini-1.2.146a9}/src/upgini/autofe/timeseries/lag.py +0 -0
  32. {upgini-1.2.145a4065.dev1 → upgini-1.2.146a9}/src/upgini/autofe/timeseries/roll.py +0 -0
  33. {upgini-1.2.145a4065.dev1 → upgini-1.2.146a9}/src/upgini/autofe/timeseries/trend.py +0 -0
  34. {upgini-1.2.145a4065.dev1 → upgini-1.2.146a9}/src/upgini/autofe/timeseries/volatility.py +0 -0
  35. {upgini-1.2.145a4065.dev1 → upgini-1.2.146a9}/src/upgini/autofe/unary.py +0 -0
  36. {upgini-1.2.145a4065.dev1 → upgini-1.2.146a9}/src/upgini/autofe/utils.py +0 -0
  37. {upgini-1.2.145a4065.dev1 → upgini-1.2.146a9}/src/upgini/autofe/vector.py +0 -0
  38. {upgini-1.2.145a4065.dev1 → upgini-1.2.146a9}/src/upgini/data_source/__init__.py +0 -0
  39. {upgini-1.2.145a4065.dev1 → upgini-1.2.146a9}/src/upgini/data_source/data_source_publisher.py +0 -0
  40. {upgini-1.2.145a4065.dev1 → upgini-1.2.146a9}/src/upgini/errors.py +0 -0
  41. {upgini-1.2.145a4065.dev1 → upgini-1.2.146a9}/src/upgini/http.py +0 -0
  42. {upgini-1.2.145a4065.dev1 → upgini-1.2.146a9}/src/upgini/mdc/__init__.py +0 -0
  43. {upgini-1.2.145a4065.dev1 → upgini-1.2.146a9}/src/upgini/mdc/context.py +0 -0
  44. {upgini-1.2.145a4065.dev1 → upgini-1.2.146a9}/src/upgini/metadata.py +0 -0
  45. {upgini-1.2.145a4065.dev1 → upgini-1.2.146a9}/src/upgini/metrics.py +0 -0
  46. {upgini-1.2.145a4065.dev1 → upgini-1.2.146a9}/src/upgini/normalizer/__init__.py +0 -0
  47. {upgini-1.2.145a4065.dev1 → upgini-1.2.146a9}/src/upgini/resource_bundle/__init__.py +0 -0
  48. {upgini-1.2.145a4065.dev1 → upgini-1.2.146a9}/src/upgini/resource_bundle/exceptions.py +0 -0
  49. {upgini-1.2.145a4065.dev1 → upgini-1.2.146a9}/src/upgini/resource_bundle/strings_widget.properties +0 -0
  50. {upgini-1.2.145a4065.dev1 → upgini-1.2.146a9}/src/upgini/sampler/__init__.py +0 -0
  51. {upgini-1.2.145a4065.dev1 → upgini-1.2.146a9}/src/upgini/sampler/base.py +0 -0
  52. {upgini-1.2.145a4065.dev1 → upgini-1.2.146a9}/src/upgini/sampler/random_under_sampler.py +0 -0
  53. {upgini-1.2.145a4065.dev1 → upgini-1.2.146a9}/src/upgini/sampler/utils.py +0 -0
  54. {upgini-1.2.145a4065.dev1 → upgini-1.2.146a9}/src/upgini/search_task.py +0 -0
  55. {upgini-1.2.145a4065.dev1 → upgini-1.2.146a9}/src/upgini/spinner.py +0 -0
  56. {upgini-1.2.145a4065.dev1 → upgini-1.2.146a9}/src/upgini/utils/Roboto-Regular.ttf +0 -0
  57. {upgini-1.2.145a4065.dev1 → upgini-1.2.146a9}/src/upgini/utils/__init__.py +0 -0
  58. {upgini-1.2.145a4065.dev1 → upgini-1.2.146a9}/src/upgini/utils/base_search_key_detector.py +0 -0
  59. {upgini-1.2.145a4065.dev1 → upgini-1.2.146a9}/src/upgini/utils/blocked_time_series.py +0 -0
  60. {upgini-1.2.145a4065.dev1 → upgini-1.2.146a9}/src/upgini/utils/config.py +0 -0
  61. {upgini-1.2.145a4065.dev1 → upgini-1.2.146a9}/src/upgini/utils/country_utils.py +0 -0
  62. {upgini-1.2.145a4065.dev1 → upgini-1.2.146a9}/src/upgini/utils/custom_loss_utils.py +0 -0
  63. {upgini-1.2.145a4065.dev1 → upgini-1.2.146a9}/src/upgini/utils/cv_utils.py +0 -0
  64. {upgini-1.2.145a4065.dev1 → upgini-1.2.146a9}/src/upgini/utils/deduplicate_utils.py +0 -0
  65. {upgini-1.2.145a4065.dev1 → upgini-1.2.146a9}/src/upgini/utils/display_utils.py +0 -0
  66. {upgini-1.2.145a4065.dev1 → upgini-1.2.146a9}/src/upgini/utils/email_utils.py +0 -0
  67. {upgini-1.2.145a4065.dev1 → upgini-1.2.146a9}/src/upgini/utils/fallback_progress_bar.py +0 -0
  68. {upgini-1.2.145a4065.dev1 → upgini-1.2.146a9}/src/upgini/utils/feature_info.py +0 -0
  69. {upgini-1.2.145a4065.dev1 → upgini-1.2.146a9}/src/upgini/utils/format.py +0 -0
  70. {upgini-1.2.145a4065.dev1 → upgini-1.2.146a9}/src/upgini/utils/hash_utils.py +0 -0
  71. {upgini-1.2.145a4065.dev1 → upgini-1.2.146a9}/src/upgini/utils/ip_utils.py +0 -0
  72. {upgini-1.2.145a4065.dev1 → upgini-1.2.146a9}/src/upgini/utils/mstats.py +0 -0
  73. {upgini-1.2.145a4065.dev1 → upgini-1.2.146a9}/src/upgini/utils/phone_utils.py +0 -0
  74. {upgini-1.2.145a4065.dev1 → upgini-1.2.146a9}/src/upgini/utils/postal_code_utils.py +0 -0
  75. {upgini-1.2.145a4065.dev1 → upgini-1.2.146a9}/src/upgini/utils/progress_bar.py +0 -0
  76. {upgini-1.2.145a4065.dev1 → upgini-1.2.146a9}/src/upgini/utils/psi.py +0 -0
  77. {upgini-1.2.145a4065.dev1 → upgini-1.2.146a9}/src/upgini/utils/sample_utils.py +0 -0
  78. {upgini-1.2.145a4065.dev1 → upgini-1.2.146a9}/src/upgini/utils/sklearn_ext.py +0 -0
  79. {upgini-1.2.145a4065.dev1 → upgini-1.2.146a9}/src/upgini/utils/sort.py +0 -0
  80. {upgini-1.2.145a4065.dev1 → upgini-1.2.146a9}/src/upgini/utils/track_info.py +0 -0
  81. {upgini-1.2.145a4065.dev1 → upgini-1.2.146a9}/src/upgini/utils/ts_utils.py +0 -0
  82. {upgini-1.2.145a4065.dev1 → upgini-1.2.146a9}/src/upgini/utils/warning_counter.py +0 -0
  83. {upgini-1.2.145a4065.dev1 → upgini-1.2.146a9}/src/upgini/version_validator.py +0 -0
@@ -1,6 +1,6 @@
- Metadata-Version: 2.3
+ Metadata-Version: 2.4
  Name: upgini
- Version: 1.2.145a4065.dev1
+ Version: 1.2.146a9
  Summary: Intelligent data search & enrichment for Machine Learning
  Project-URL: Bug Reports, https://github.com/upgini/upgini/issues
  Project-URL: Homepage, https://upgini.com/
@@ -0,0 +1 @@
+ __version__ = "1.2.146a9"
@@ -58,7 +58,7 @@ class Dataset:
      MAX_ROWS = 3_000_000
      MIN_SUPPORTED_DATE_TS = 946684800000  # 2000-01-01
      MAX_FEATURES_COUNT = 3500
-     MAX_UPLOADING_FILE_SIZE = 268435456  # 256 Mb
+     MAX_UPLOADING_FILE_SIZE = 536_870_912  # 512 Mb
      MAX_STRING_FEATURE_LENGTH = 24573
      FORCE_SAMPLE_SIZE = 7_000

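Note: the new upload limit is exactly double the old one. A quick standalone arithmetic check (illustrative only, not part of the package):

    OLD_LIMIT = 268_435_456   # 256 MB
    NEW_LIMIT = 536_870_912   # 512 MB
    assert OLD_LIMIT == 256 * 1024 ** 2
    assert NEW_LIMIT == 512 * 1024 ** 2 == 2 * OLD_LIMIT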
@@ -304,10 +304,11 @@ class Dataset:
          ):
              keys_to_validate.remove(ipv4_column)

-         mandatory_columns = [target]
+         mandatory_columns = {target} if target is not None else set()
          columns_to_validate = mandatory_columns.copy()
-         columns_to_validate.extend(keys_to_validate)
-         columns_to_validate = set([i for i in columns_to_validate if i is not None])
+         columns_to_validate.update(keys_to_validate)
+         if len(columns_to_validate) == 0:
+             return

          nrows = len(self.data)
          validation_stats = {}
@@ -370,7 +371,10 @@ class Dataset:
              self.data["valid_keys"] = self.data["valid_keys"] + self.data[f"{col}_is_valid"]
              self.data.drop(columns=f"{col}_is_valid", inplace=True)

-         self.data["is_valid"] = self.data["valid_keys"] > 0
+         if len(keys_to_validate) > 0:
+             self.data["is_valid"] = self.data["valid_keys"] > 0
+         else:
+             self.data["is_valid"] = True
          self.data["is_valid"] = self.data["is_valid"] & self.data["valid_mandatory"]
          self.data.drop(columns=["valid_keys", "valid_mandatory"], inplace=True)

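The effect of this change: when there are no search keys to validate, rows are no longer all marked invalid by the `valid_keys > 0` test, and only the mandatory-column check applies. A minimal pandas sketch of the combined flag, using the same column names as the hunk above (the surrounding Dataset class is omitted):

    import pandas as pd

    data = pd.DataFrame({"valid_keys": [0, 2], "valid_mandatory": [True, True]})
    keys_to_validate = []  # no search keys in this example

    if len(keys_to_validate) > 0:
        data["is_valid"] = data["valid_keys"] > 0
    else:
        data["is_valid"] = True  # nothing to validate, so no row is rejected here
    data["is_valid"] = data["is_valid"] & data["valid_mandatory"]
    print(data["is_valid"].tolist())  # [True, True]; the old code gave [False, True]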
@@ -11,6 +11,7 @@ import uuid
  from collections import Counter
  from copy import deepcopy
  from dataclasses import dataclass
+ from pathlib import Path
  from threading import Thread
  from typing import Any, Callable

@@ -277,6 +278,8 @@ class FeaturesEnricher(TransformerMixin):
          self.autodetected_search_keys: dict[str, SearchKey] | None = None
          self.imbalanced = False
          self.fit_select_features = True
+         self.true_one_hot_groups: dict[str, list[str]] | None = None
+         self.pseudo_one_hot_groups: dict[str, list[str]] | None = None
          self.__cached_sampled_datasets: dict[str, tuple[pd.DataFrame, pd.DataFrame, pd.Series, dict, dict, dict]] = (
              dict()
          )
@@ -679,9 +682,6 @@
          self.__set_select_features(select_features)
          self.dump_input(X, y, self.eval_set)

-         if _num_samples(drop_duplicates(X)) > Dataset.MAX_ROWS:
-             raise ValidationError(self.bundle.get("dataset_too_many_rows_registered").format(Dataset.MAX_ROWS))
-
          self.__inner_fit(
              X,
              y,
@@ -2049,6 +2049,9 @@
              generated_features.extend(generator.generated_features)

          normalizer = Normalizer(self.bundle, self.logger)
+         # TODO restore these properties from the server
+         normalizer.true_one_hot_groups = self.true_one_hot_groups
+         normalizer.pseudo_one_hot_groups = self.pseudo_one_hot_groups
          df, search_keys, generated_features = normalizer.normalize(df, search_keys, generated_features)
          columns_renaming = normalizer.columns_renaming

@@ -2664,6 +2667,9 @@ if response.status_code == 200:
              generated_features.extend(generator.generated_features)

          normalizer = Normalizer(self.bundle, self.logger)
+         # TODO restore these properties from the server
+         normalizer.true_one_hot_groups = self.true_one_hot_groups
+         normalizer.pseudo_one_hot_groups = self.pseudo_one_hot_groups
          df, search_keys, generated_features = normalizer.normalize(df, search_keys, generated_features)
          columns_renaming = normalizer.columns_renaming

@@ -2831,85 +2837,103 @@ if response.status_code == 200:
          del df
          gc.collect()

-         dataset = Dataset(
-             "sample_" + str(uuid.uuid4()),
-             df=df_without_features,
-             meaning_types=meaning_types,
-             search_keys=combined_search_keys,
-             unnest_search_keys=unnest_search_keys,
-             id_columns=self.__get_renamed_id_columns(columns_renaming),
-             date_column=self._get_date_column(search_keys),
-             date_format=self.date_format,
-             sample_config=self.sample_config,
-             rest_client=self.rest_client,
-             logger=self.logger,
-             bundle=self.bundle,
-             warning_callback=self.__log_warning,
-         )
-         dataset.columns_renaming = columns_renaming
+         def invoke_validation(df: pd.DataFrame):
+
+             dataset = Dataset(
+                 "sample_" + str(uuid.uuid4()),
+                 df=df,
+                 meaning_types=meaning_types,
+                 search_keys=combined_search_keys,
+                 unnest_search_keys=unnest_search_keys,
+                 id_columns=self.__get_renamed_id_columns(columns_renaming),
+                 date_column=self._get_date_column(search_keys),
+                 date_format=self.date_format,
+                 sample_config=self.sample_config,
+                 rest_client=self.rest_client,
+                 logger=self.logger,
+                 bundle=self.bundle,
+                 warning_callback=self.__log_warning,
+             )
+             dataset.columns_renaming = columns_renaming
+
+             validation_task = self._search_task.validation(
+                 self._get_trace_id(),
+                 dataset,
+                 start_time=start_time,
+                 extract_features=True,
+                 runtime_parameters=runtime_parameters,
+                 exclude_features_sources=exclude_features_sources,
+                 metrics_calculation=metrics_calculation,
+                 silent_mode=silent_mode,
+                 progress_bar=progress_bar,
+                 progress_callback=progress_callback,
+             )

-         validation_task = self._search_task.validation(
-             self._get_trace_id(),
-             dataset,
-             start_time=start_time,
-             extract_features=True,
-             runtime_parameters=runtime_parameters,
-             exclude_features_sources=exclude_features_sources,
-             metrics_calculation=metrics_calculation,
-             silent_mode=silent_mode,
-             progress_bar=progress_bar,
-             progress_callback=progress_callback,
-         )
+             del df, dataset
+             gc.collect()

-         del df_without_features, dataset
-         gc.collect()
+             if not silent_mode:
+                 print(self.bundle.get("polling_transform_task").format(validation_task.search_task_id))
+                 if not self.__is_registered:
+                     print(self.bundle.get("polling_unregister_information"))

-         if not silent_mode:
-             print(self.bundle.get("polling_transform_task").format(validation_task.search_task_id))
-             if not self.__is_registered:
-                 print(self.bundle.get("polling_unregister_information"))
+             progress = self.get_progress(validation_task)
+             progress.recalculate_eta(time.time() - start_time)
+             if progress_bar is not None:
+                 progress_bar.progress = progress.to_progress_bar()
+             if progress_callback is not None:
+                 progress_callback(progress)
+             prev_progress: SearchProgress | None = None
+             polling_period_seconds = 1
+             try:
+                 while progress.stage != ProgressStage.DOWNLOADING.value:
+                     if prev_progress is None or prev_progress.percent != progress.percent:
+                         progress.recalculate_eta(time.time() - start_time)
+                     else:
+                         progress.update_eta(prev_progress.eta - polling_period_seconds)
+                     prev_progress = progress
+                     if progress_bar is not None:
+                         progress_bar.progress = progress.to_progress_bar()
+                     if progress_callback is not None:
+                         progress_callback(progress)
+                     if progress.stage == ProgressStage.FAILED.value:
+                         raise Exception(progress.error_message)
+                     time.sleep(polling_period_seconds)
+                     progress = self.get_progress(validation_task)
+             except KeyboardInterrupt as e:
+                 print(self.bundle.get("search_stopping"))
+                 self.rest_client.stop_search_task_v2(self._get_trace_id(), validation_task.search_task_id)
+                 self.logger.warning(f"Search {validation_task.search_task_id} stopped by user")
+                 print(self.bundle.get("search_stopped"))
+                 raise e

-         progress = self.get_progress(validation_task)
-         progress.recalculate_eta(time.time() - start_time)
-         if progress_bar is not None:
-             progress_bar.progress = progress.to_progress_bar()
-         if progress_callback is not None:
-             progress_callback(progress)
-         prev_progress: SearchProgress | None = None
-         polling_period_seconds = 1
-         try:
-             while progress.stage != ProgressStage.DOWNLOADING.value:
-                 if prev_progress is None or prev_progress.percent != progress.percent:
-                     progress.recalculate_eta(time.time() - start_time)
-                 else:
-                     progress.update_eta(prev_progress.eta - polling_period_seconds)
-                 prev_progress = progress
-                 if progress_bar is not None:
-                     progress_bar.progress = progress.to_progress_bar()
-                 if progress_callback is not None:
-                     progress_callback(progress)
-                 if progress.stage == ProgressStage.FAILED.value:
-                     raise Exception(progress.error_message)
-                 time.sleep(polling_period_seconds)
-                 progress = self.get_progress(validation_task)
-         except KeyboardInterrupt as e:
-             print(self.bundle.get("search_stopping"))
-             self.rest_client.stop_search_task_v2(self._get_trace_id(), validation_task.search_task_id)
-             self.logger.warning(f"Search {validation_task.search_task_id} stopped by user")
-             print(self.bundle.get("search_stopped"))
-             raise e
-
-         validation_task.poll_result(self._get_trace_id(), quiet=True)
-
-         seconds_left = time.time() - start_time
-         progress = SearchProgress(97.0, ProgressStage.DOWNLOADING, seconds_left)
-         if progress_bar is not None:
-             progress_bar.progress = progress.to_progress_bar()
-         if progress_callback is not None:
-             progress_callback(progress)
+             validation_task.poll_result(self._get_trace_id(), quiet=True)
+
+             seconds_left = time.time() - start_time
+             progress = SearchProgress(97.0, ProgressStage.DOWNLOADING, seconds_left)
+             if progress_bar is not None:
+                 progress_bar.progress = progress.to_progress_bar()
+             if progress_callback is not None:
+                 progress_callback(progress)
+
+             if not silent_mode:
+                 print(self.bundle.get("transform_start"))

-         if not silent_mode:
-             print(self.bundle.get("transform_start"))
+             return validation_task.get_all_validation_raw_features(self._get_trace_id(), metrics_calculation)
+
+         if len(df_without_features) <= Dataset.MAX_ROWS:
+             result_features = invoke_validation(df_without_features)
+         else:
+             self.logger.warning(
+                 f"Dataset has more than {Dataset.MAX_ROWS} rows: {len(df_without_features)}, "
+                 f"splitting into chunks of {Dataset.MAX_ROWS} rows"
+             )
+             result_features_list = []
+
+             for i in range(0, len(df_without_features), Dataset.MAX_ROWS):
+                 chunk = df_without_features.iloc[i : i + Dataset.MAX_ROWS]
+                 result_features_list.append(invoke_validation(chunk))
+             result_features = pd.concat(result_features_list)

          # Prepare input DataFrame for __enrich by concatenating generated ids and client features
          df_before_explode = df_before_explode.rename(columns=columns_renaming)
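The heart of this refactor: the transform-time validation request is wrapped in a local invoke_validation helper so oversized inputs can be split positionally instead of rejected. A simplified sketch of the chunking pattern, with process_chunk standing in for the Dataset construction and validation round trip:

    import pandas as pd

    MAX_ROWS = 3_000_000  # Dataset.MAX_ROWS

    def process_chunk(chunk: pd.DataFrame) -> pd.DataFrame:
        # Stand-in for building a Dataset, polling the validation task
        # and downloading the enriched features for this chunk.
        return chunk

    def enrich_in_chunks(df: pd.DataFrame) -> pd.DataFrame:
        if len(df) <= MAX_ROWS:
            return process_chunk(df)
        # iloc-based slicing puts every row in exactly one chunk.
        parts = [
            process_chunk(df.iloc[i : i + MAX_ROWS])
            for i in range(0, len(df), MAX_ROWS)
        ]
        return pd.concat(parts)

This is presumably also why the hard dataset_too_many_rows_registered check was removed from fit() earlier in this diff: the whole input no longer has to fit under the row limit at once.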
@@ -2922,8 +2946,6 @@ if response.status_code == 200:
              axis=1,
          ).set_index(validated_Xy.index)

-         result_features = validation_task.get_all_validation_raw_features(self._get_trace_id(), metrics_calculation)
-
          result = self.__enrich(
              combined_df,
              result_features,
@@ -2974,12 +2996,38 @@ if response.status_code == 200:
          fit_dropped_features = self.fit_dropped_features or file_meta.droppedColumns or []
          fit_input_columns = [c.originalName for c in file_meta.columns]
          original_dropped_features = [self.fit_columns_renaming.get(c, c) for c in fit_dropped_features]
+         true_one_hot_features = (
+             [f for group in self.true_one_hot_groups.values() for f in group] if self.true_one_hot_groups else []
+         )
          new_columns_on_transform = [
-             c for c in validated_Xy.columns if c not in fit_input_columns and c not in original_dropped_features
+             c
+             for c in validated_Xy.columns
+             if c not in fit_input_columns and c not in original_dropped_features and c not in true_one_hot_features
          ]
          fit_original_search_keys = self._get_fit_search_keys_with_original_names()

          selected_generated_features = [c for c in generated_features if c in self.feature_names_]
+         selected_true_one_hot_features = (
+             [
+                 c
+                 for cat_feature, group in self.true_one_hot_groups.items()
+                 for c in group
+                 if cat_feature in self.feature_names_
+             ]
+             if self.true_one_hot_groups
+             else []
+         )
+         selected_pseudo_one_hot_features = (
+             [
+                 feature
+                 for group in self.pseudo_one_hot_groups.values()
+                 if any(f in self.feature_names_ for f in group)
+                 for feature in group
+             ]
+             if self.pseudo_one_hot_groups
+             else []
+         )
+
          if keep_input is True:
              selected_input_columns = [
                  c
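Note the asymmetry between the two new selections: for true one-hot groups the dummy columns are kept when the decoded categorical feature was selected, while a pseudo group is kept wholesale if any of its members was selected. A toy sketch of the same comprehensions (the feature names here are invented for illustration):

    feature_names_ = ["color", "flag_b"]
    true_one_hot_groups = {"color": ["color_red", "color_blue"]}
    pseudo_one_hot_groups = {"grp1": ["flag_a", "flag_b"]}

    # True groups: keep the dummies whose decoded feature was selected.
    selected_true = [
        c
        for cat_feature, group in true_one_hot_groups.items()
        for c in group
        if cat_feature in feature_names_
    ]
    # Pseudo groups: keep the whole group if any member was selected.
    selected_pseudo = [
        feature
        for group in pseudo_one_hot_groups.values()
        if any(f in feature_names_ for f in group)
        for feature in group
    ]
    print(selected_true)    # ['color_red', 'color_blue']
    print(selected_pseudo)  # ['flag_a', 'flag_b']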
@@ -2998,17 +3046,23 @@ if response.status_code == 200:
              if DEFAULT_INDEX in selected_input_columns:
                  selected_input_columns.remove(DEFAULT_INDEX)

-         return selected_input_columns + selected_generated_features
+         return (
+             selected_input_columns
+             + selected_generated_features
+             + selected_true_one_hot_features
+             + selected_pseudo_one_hot_features
+         )

-     def __validate_search_keys(self, search_keys: dict[str, SearchKey], search_id: str | None = None):
+     def _validate_empty_search_keys(self, search_keys: dict[str, SearchKey], is_transform: bool = False):
          if (search_keys is None or len(search_keys) == 0) and self.country_code is None:
-             if search_id:
-                 self.logger.debug(f"search_id {search_id} provided without search_keys")
-                 return
-             # else:
-             #     self.logger.warning("search_keys not provided")
-             #     raise ValidationError(self.bundle.get("empty_search_keys"))
+             if is_transform:
+                 self.logger.debug("Transform started without search_keys")
+                 # return
+             else:
+                 self.logger.warning("search_keys not provided")
+                 # raise ValidationError(self.bundle.get("empty_search_keys"))

+     def __validate_search_keys(self, search_keys: dict[str, SearchKey], search_id: str | None = None):
          key_types = search_keys.values()

          # Multiple search keys allowed only for PHONE, IP, POSTAL_CODE, EMAIL, HEM
@@ -3166,7 +3220,7 @@ if response.status_code == 200:
          else:
              only_train_df = df

-         self.imbalanced = is_imbalanced(only_train_df, self.model_task_type, self.sample_config, self.bundle)
+         self.imbalanced = is_imbalanced(only_train_df, self.model_task_type, self.sample_config, self.bundle, self.__log_warning)
          if self.imbalanced:
              # Exclude eval sets from fit because they will be transformed before metrics calculation
              df = only_train_df
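is_imbalanced now also receives the enricher's warning callback, so imbalance messages (including the new rare_target_classes_drop string below) surface through the same path as other user-facing warnings. A hedged sketch of a callback with the shape this call site implies (the parameter name is an assumption):

    import logging

    logger = logging.getLogger("upgini-example")

    # Assumed shape: one positional message argument, as in
    # is_imbalanced(df, task_type, sample_config, bundle, callback).
    def log_warning(message: str) -> None:
        logger.warning(message)
        print(message)  # the real __log_warning also shows the message to the user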
@@ -3239,6 +3293,8 @@ if response.status_code == 200:
              df, self.fit_search_keys, self.fit_generated_features
          )
          self.fit_columns_renaming = normalizer.columns_renaming
+         self.true_one_hot_groups = normalizer.true_one_hot_groups
+         self.pseudo_one_hot_groups = normalizer.pseudo_one_hot_groups
          if normalizer.removed_datetime_features:
              self.fit_dropped_features.update(normalizer.removed_datetime_features)
              original_removed_datetime_features = [
@@ -3256,7 +3312,11 @@ if response.status_code == 200:
          features_columns = [c for c in df.columns if c not in non_feature_columns]

          features_to_drop, feature_validator_warnings = FeaturesValidator(self.logger).validate(
-             df, features_columns, self.generate_features, self.fit_columns_renaming
+             df,
+             features_columns,
+             self.generate_features,
+             self.fit_columns_renaming,
+             [f for group in self.pseudo_one_hot_groups.values() for f in group] if self.pseudo_one_hot_groups else [],
          )
          if feature_validator_warnings:
              for warning in feature_validator_warnings:
@@ -3819,8 +3879,7 @@ if response.status_code == 200:
          elif self.columns_for_online_api:
              msg = self.bundle.get("oot_with_online_sources_not_supported").format(eval_set_index)
          if msg:
-             print(msg)
-             self.logger.warning(msg)
+             self.__log_warning(msg)
          df = df[df[EVAL_SET_INDEX] != eval_set_index]
          return df

@@ -4765,7 +4824,7 @@ if response.status_code == 200:
          elif self.autodetect_search_keys:
              valid_search_keys = self.__detect_missing_search_keys(x, valid_search_keys, is_demo_dataset)

-         if all(k == SearchKey.CUSTOM_KEY for k in valid_search_keys.values()):
+         if len(valid_search_keys) > 0 and all(k == SearchKey.CUSTOM_KEY for k in valid_search_keys.values()):
              if self.__is_registered:
                  msg = self.bundle.get("only_custom_keys")
              else:
@@ -4801,6 +4860,8 @@ if response.status_code == 200:

          self.logger.info(f"Prepared search keys: {valid_search_keys}")

+         # x = self._validate_empty_search_keys(x, valid_search_keys, is_transform=is_transform)
+
          return valid_search_keys

      def __show_metrics(
@@ -4941,10 +5002,6 @@ if response.status_code == 200:
              )
              self.__log_warning(self.bundle.get("phone_detected_not_registered"))

-         if (search_keys is None or len(search_keys) == 0) and self.country_code is None:
-             self.logger.warning("search_keys not provided")
-             raise ValidationError(self.bundle.get("empty_search_keys"))
-
          return search_keys

      def _validate_binary_observations(self, y, task_type: ModelTaskType):
@@ -5026,37 +5083,55 @@ if response.status_code == 200:
                  X_ = X_.to_frame()

              with tempfile.TemporaryDirectory() as tmp_dir:
-                 X_.to_parquet(f"{tmp_dir}/x.parquet", compression="zstd")
-                 x_digest_sha256 = file_hash(f"{tmp_dir}/x.parquet")
+                 x_file_name = f"{tmp_dir}/x.parquet"
+                 X_.to_parquet(x_file_name, compression="zstd")
+                 uploading_file_size = Path(x_file_name).stat().st_size
+                 if uploading_file_size > Dataset.MAX_UPLOADING_FILE_SIZE:
+                     self.logger.warning(
+                         f"Uploading file x.parquet is too large: {uploading_file_size} bytes. Skip it"
+                     )
+                     return
+                 x_digest_sha256 = file_hash(x_file_name)
                  if self.rest_client.is_file_uploaded(trace_id_, x_digest_sha256):
                      self.logger.info(
                          f"File x.parquet was already uploaded with digest {x_digest_sha256}, skipping"
                      )
                  else:
-                     self.rest_client.dump_input_file(
-                         trace_id_, f"{tmp_dir}/x.parquet", "x.parquet", x_digest_sha256
-                     )
+                     self.rest_client.dump_input_file(trace_id_, x_file_name, "x.parquet", x_digest_sha256)

                  if y_ is not None:
                      if isinstance(y_, pd.Series):
                          y_ = y_.to_frame()
-                     y_.to_parquet(f"{tmp_dir}/y.parquet", compression="zstd")
-                     y_digest_sha256 = file_hash(f"{tmp_dir}/y.parquet")
+                     y_file_name = f"{tmp_dir}/y.parquet"
+                     y_.to_parquet(y_file_name, compression="zstd")
+                     uploading_file_size = Path(y_file_name).stat().st_size
+                     if uploading_file_size > Dataset.MAX_UPLOADING_FILE_SIZE:
+                         self.logger.warning(
+                             f"Uploading file y.parquet is too large: {uploading_file_size} bytes. Skip it"
+                         )
+                         return
+                     y_digest_sha256 = file_hash(y_file_name)
                      if self.rest_client.is_file_uploaded(trace_id_, y_digest_sha256):
                          self.logger.info(
                              f"File y.parquet was already uploaded with digest {y_digest_sha256}, skipping"
                          )
                      else:
-                         self.rest_client.dump_input_file(
-                             trace_id_, f"{tmp_dir}/y.parquet", "y.parquet", y_digest_sha256
-                         )
+                         self.rest_client.dump_input_file(trace_id_, y_file_name, "y.parquet", y_digest_sha256)

                  if eval_set_ is not None and len(eval_set_) > 0:
                      for idx, (eval_x_, eval_y_) in enumerate(eval_set_):
                          if isinstance(eval_x_, pd.Series):
                              eval_x_ = eval_x_.to_frame()
-                         eval_x_.to_parquet(f"{tmp_dir}/eval_x_{idx}.parquet", compression="zstd")
-                         eval_x_digest_sha256 = file_hash(f"{tmp_dir}/eval_x_{idx}.parquet")
+                         eval_x_file_name = f"{tmp_dir}/eval_x_{idx}.parquet"
+                         eval_x_.to_parquet(eval_x_file_name, compression="zstd")
+                         uploading_file_size = Path(eval_x_file_name).stat().st_size
+                         if uploading_file_size > Dataset.MAX_UPLOADING_FILE_SIZE:
+                             self.logger.warning(
+                                 f"Uploading file eval_x_{idx}.parquet is too large: "
+                                 f"{uploading_file_size} bytes. Skip it"
+                             )
+                             return
+                         eval_x_digest_sha256 = file_hash(eval_x_file_name)
                          if self.rest_client.is_file_uploaded(trace_id_, eval_x_digest_sha256):
                              self.logger.info(
                                  f"File eval_x_{idx}.parquet was already uploaded with"
@@ -5065,15 +5140,23 @@ if response.status_code == 200:
                          else:
                              self.rest_client.dump_input_file(
                                  trace_id_,
-                                 f"{tmp_dir}/eval_x_{idx}.parquet",
+                                 eval_x_file_name,
                                  f"eval_x_{idx}.parquet",
                                  eval_x_digest_sha256,
                              )

                          if isinstance(eval_y_, pd.Series):
                              eval_y_ = eval_y_.to_frame()
-                         eval_y_.to_parquet(f"{tmp_dir}/eval_y_{idx}.parquet", compression="zstd")
-                         eval_y_digest_sha256 = file_hash(f"{tmp_dir}/eval_y_{idx}.parquet")
+                         eval_y_file_name = f"{tmp_dir}/eval_y_{idx}.parquet"
+                         eval_y_.to_parquet(eval_y_file_name, compression="zstd")
+                         uploading_file_size = Path(eval_y_file_name).stat().st_size
+                         if uploading_file_size > Dataset.MAX_UPLOADING_FILE_SIZE:
+                             self.logger.warning(
+                                 f"Uploading file eval_y_{idx}.parquet is too large: "
+                                 f"{uploading_file_size} bytes. Skip it"
+                             )
+                             return
+                         eval_y_digest_sha256 = file_hash(eval_y_file_name)
                          if self.rest_client.is_file_uploaded(trace_id_, eval_y_digest_sha256):
                              self.logger.info(
                                  f"File eval_y_{idx}.parquet was already uploaded"
@@ -5082,7 +5165,7 @@ if response.status_code == 200:
                          else:
                              self.rest_client.dump_input_file(
                                  trace_id_,
-                                 f"{tmp_dir}/eval_y_{idx}.parquet",
+                                 eval_y_file_name,
                                  f"eval_y_{idx}.parquet",
                                  eval_y_digest_sha256,
                              )
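All four dumps now follow one guard pattern: serialize to a temp file, compare the on-disk size with Dataset.MAX_UPLOADING_FILE_SIZE, and return before hashing or uploading anything too large. A minimal standalone sketch of that pattern, with upload_file as a hypothetical stand-in for rest_client.dump_input_file:

    import tempfile
    from pathlib import Path

    import pandas as pd

    MAX_UPLOADING_FILE_SIZE = 536_870_912  # 512 MB, per this diff

    def upload_file(path: str) -> None:
        ...  # hypothetical uploader, e.g. hash check + REST upload

    def dump_frame(df: pd.DataFrame, name: str) -> None:
        with tempfile.TemporaryDirectory() as tmp_dir:
            file_name = f"{tmp_dir}/{name}.parquet"
            df.to_parquet(file_name, compression="zstd")
            size = Path(file_name).stat().st_size
            if size > MAX_UPLOADING_FILE_SIZE:
                print(f"{name}.parquet is too large: {size} bytes, skipping")
                return  # bail out before hashing and uploading
            upload_file(file_name)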
@@ -26,6 +26,7 @@ from upgini.utils import find_numbers_with_decimal_comma
  from upgini.utils.country_utils import CountrySearchKeyConverter
  from upgini.utils.datetime_utils import DateTimeConverter
  from upgini.utils.ip_utils import IpSearchKeyConverter
+ from upgini.utils.one_hot_encoder import OneHotDecoder
  from upgini.utils.phone_utils import PhoneSearchKeyConverter
  from upgini.utils.postal_code_utils import PostalCodeSearchKeyConverter

@@ -45,6 +46,8 @@ class Normalizer:
          self.search_keys = {}
          self.generated_features = []
          self.removed_datetime_features = []
+         self.true_one_hot_groups: dict[str, list[str]] | None = None
+         self.pseudo_one_hot_groups: dict[str, list[str]] | None = None

      def normalize(
          self, df: pd.DataFrame, search_keys: Dict[str, SearchKey], generated_features: List[str]
@@ -53,6 +56,9 @@
          self.generated_features = generated_features.copy()

          df = df.copy()
+
+         df = self._convert_one_hot_encoded_columns(df)
+
          df = self._rename_columns(df)

          df = self._remove_dates_from_features(df)
@@ -77,6 +83,15 @@

          return df, self.search_keys, self.generated_features

+     def _convert_one_hot_encoded_columns(self, df: pd.DataFrame):
+         if self.true_one_hot_groups is not None or self.pseudo_one_hot_groups is not None:
+             df = OneHotDecoder.decode_with_cached_groups(
+                 df, self.true_one_hot_groups, self.pseudo_one_hot_groups
+             )
+         else:
+             df, self.true_one_hot_groups, self.pseudo_one_hot_groups = OneHotDecoder.decode(df)
+         return df
+
      def _rename_columns(self, df: pd.DataFrame):
          # logger.info("Replace restricted symbols in column names")
          new_columns = []
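OneHotDecoder itself lives in the new src/upgini/utils/one_hot_encoder.py (added in this release but not shown in this diff), so the exact detection logic is not visible here; the groups found during fit are cached and replayed on transform via decode_with_cached_groups. A rough sketch of the decoding idea for a true one-hot group, not the actual implementation:

    import pandas as pd

    def decode_group(df: pd.DataFrame, group: list[str], name: str) -> pd.DataFrame:
        # In a true one-hot group exactly one column is 1 per row,
        # so idxmax recovers the original category label.
        decoded = df[group].idxmax(axis=1).str.replace(f"{name}_", "", regex=False)
        return df.drop(columns=group).assign(**{name: decoded})

    df = pd.DataFrame({"color_red": [1, 0], "color_blue": [0, 1]})
    print(decode_group(df, ["color_red", "color_blue"], "color"))
    #   color
    # 0   red
    # 1  blue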
@@ -176,7 +176,8 @@ dataset_invalid_multiclass_target=Unexpected dtype of target for multiclass task
  dataset_invalid_regression_target=Unexpected dtype of target for regression task type: {}. Expected float
  dataset_invalid_timeseries_target=Unexpected dtype of target for timeseries task type: {}. Expected float
  dataset_to_many_multiclass_targets=The number of target classes {} exceeds the allowed threshold: {}. Please, correct your data and try again
- dataset_rarest_class_less_min=Count of rows with the rarest class `{}` is {}, minimum count must be > {} for each class\nPlease, remove rows with rarest class from your dataframe
+ dataset_rarest_class_less_min=Count of rows with the rarest class `{}` is {}, minimum count must be > {} for each class
+ #\nPlease, remove rows with rarest class from your dataframe
  dataset_rarest_class_less_threshold=Target is imbalanced and will be undersampled to the rarest class. Frequency of the rarest class `{}` is {}\nMinimum number of observations for each class to avoid undersampling {} ({}%)
  dataset_date_features=Columns {} is a datetime or period type but not used as a search key, removed from X
  dataset_too_many_features=Too many features. Maximum number of features is {}
@@ -220,7 +221,7 @@ email_detected=Emails detected in column `{}`. It will be used as a search key\n
  email_detected_not_registered=Emails detected in column `{}`. It can be used only with api_key from profile.upgini.com\nSee docs to turn off the automatic detection: https://github.com/upgini/upgini/blob/main/README.md#turn-off-autodetection-for-search-key-columns\n
  phone_detected=Phone numbers detected in column `{}`. It can be used only with api_key from profile.upgini.com\nSee docs to turn off the automatic detection: https://github.com/upgini/upgini/blob/main/README.md#turn-off-autodetection-for-search-key-columns\n
  phone_detected_not_registered=Phone numbers detected in column `{}`. It can be used only with api_key from profile.upgini.com\nSee docs to turn off the automatic detection: https://github.com/upgini/upgini/blob/main/README.md#turn-off-autodetection-for-search-key-columns\n
- target_type_detected=\nDetected task type: {}. Reason: {}\nYou can set task type manually with argument `model_task_type` of FeaturesEnricher constructor if task type detected incorrectly\n
+ target_type_detected=Detected task type: {}. Reason: {}\nYou can set task type manually with argument `model_task_type` of FeaturesEnricher constructor if task type detected incorrectly\n
  binary_target_reason=only two unique label-values observed
  non_numeric_multiclass_reason=non-numeric label values observed
  few_unique_label_multiclass_reason=few unique label-values observed and can be considered as categorical
@@ -231,7 +232,8 @@ limited_int_multiclass_reason=integer-like values with limited unique values obs
  all_ok_community_invite=❓ Support request
  too_small_for_metrics=Your train dataset or one of eval datasets contains less than 500 rows. For such dataset Upgini will not calculate accuracy metrics. Please increase the number of rows in the training dataset to calculate accuracy metrics
  imbalance_multiclass=Class {0} is on 25% quantile of classes distribution ({1} records in train dataset). \nDownsample classes with records more than {1}.
- imbalanced_target=\nTarget is imbalanced and will be undersampled. Frequency of the rarest class `{}` is {}
+ rare_target_classes_drop=Drop rare target classes with <0.01% freq: {}
+ imbalanced_target=Target is imbalanced and will be undersampled. Frequency of the rarest class `{}` is {}
  loss_selection_info=Using loss `{}` for feature selection
  loss_calc_metrics_info=Using loss `{}` for metrics calculation with default estimator
  forced_balance_undersample=For quick data retrieval, your dataset has been sampled. To use data search without data sampling please contact support (sales@upgini.com)
@@ -8,7 +8,7 @@ from dateutil.relativedelta import relativedelta
  from pandas.api.types import is_numeric_dtype

  from upgini.errors import ValidationError
- from upgini.metadata import EVAL_SET_INDEX, SearchKey
+ from upgini.metadata import CURRENT_DATE_COL, EVAL_SET_INDEX, SearchKey
  from upgini.resource_bundle import ResourceBundle, get_custom_bundle
  from upgini.utils.base_search_key_detector import BaseSearchKeyDetector

@@ -418,12 +418,12 @@ def is_dates_distribution_valid(
      except Exception:
          pass

-     if maybe_date_col is None:
-         return
+     if maybe_date_col is None or maybe_date_col == CURRENT_DATE_COL:
+         return True

      # Don't check if date column is constant
      if X[maybe_date_col].nunique() <= 1:
-         return
+         return True

      if isinstance(X[maybe_date_col].dtype, pd.PeriodDtype):
          dates = X[maybe_date_col].dt.to_timestamp().dt.date