upgini 1.1.280.dev0__py3-none-any.whl → 1.2.31a1__py3-none-any.whl

This diff shows the changes between two publicly available versions of the package, as released to one of the supported registries. It is provided for informational purposes only and reflects the package contents exactly as they appear in their respective public registries.

Potentially problematic release.


This version of upgini might be problematic. Click here for more details.

Files changed (43)
  1. upgini/__about__.py +1 -1
  2. upgini/__init__.py +4 -20
  3. upgini/autofe/all_operands.py +39 -9
  4. upgini/autofe/binary.py +148 -45
  5. upgini/autofe/date.py +197 -26
  6. upgini/autofe/feature.py +102 -19
  7. upgini/autofe/groupby.py +22 -22
  8. upgini/autofe/operand.py +9 -6
  9. upgini/autofe/unary.py +83 -41
  10. upgini/autofe/vector.py +8 -8
  11. upgini/data_source/data_source_publisher.py +128 -5
  12. upgini/dataset.py +50 -386
  13. upgini/features_enricher.py +931 -542
  14. upgini/http.py +27 -16
  15. upgini/lazy_import.py +35 -0
  16. upgini/metadata.py +84 -59
  17. upgini/metrics.py +164 -34
  18. upgini/normalizer/normalize_utils.py +197 -0
  19. upgini/resource_bundle/strings.properties +66 -51
  20. upgini/search_task.py +10 -4
  21. upgini/utils/Roboto-Regular.ttf +0 -0
  22. upgini/utils/base_search_key_detector.py +14 -12
  23. upgini/utils/country_utils.py +16 -0
  24. upgini/utils/custom_loss_utils.py +39 -36
  25. upgini/utils/datetime_utils.py +98 -45
  26. upgini/utils/deduplicate_utils.py +135 -112
  27. upgini/utils/display_utils.py +46 -15
  28. upgini/utils/email_utils.py +54 -16
  29. upgini/utils/feature_info.py +172 -0
  30. upgini/utils/features_validator.py +34 -20
  31. upgini/utils/ip_utils.py +100 -1
  32. upgini/utils/phone_utils.py +343 -0
  33. upgini/utils/postal_code_utils.py +34 -0
  34. upgini/utils/sklearn_ext.py +28 -19
  35. upgini/utils/target_utils.py +113 -57
  36. upgini/utils/warning_counter.py +1 -0
  37. upgini/version_validator.py +8 -4
  38. {upgini-1.1.280.dev0.dist-info → upgini-1.2.31a1.dist-info}/METADATA +31 -16
  39. upgini-1.2.31a1.dist-info/RECORD +65 -0
  40. upgini/normalizer/phone_normalizer.py +0 -340
  41. upgini-1.1.280.dev0.dist-info/RECORD +0 -62
  42. {upgini-1.1.280.dev0.dist-info → upgini-1.2.31a1.dist-info}/WHEEL +0 -0
  43. {upgini-1.1.280.dev0.dist-info → upgini-1.2.31a1.dist-info}/licenses/LICENSE +0 -0
@@ -11,6 +11,7 @@ import sys
11
11
  import tempfile
12
12
  import time
13
13
  import uuid
14
+ from collections import Counter
14
15
  from dataclasses import dataclass
15
16
  from threading import Thread
16
17
  from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
@@ -22,7 +23,6 @@ from pandas.api.types import (
22
23
  is_datetime64_any_dtype,
23
24
  is_numeric_dtype,
24
25
  is_object_dtype,
25
- is_period_dtype,
26
26
  is_string_dtype,
27
27
  )
28
28
  from scipy.stats import ks_2samp
@@ -45,24 +45,31 @@ from upgini.mdc import MDC
45
45
  from upgini.metadata import (
46
46
  COUNTRY,
47
47
  DEFAULT_INDEX,
48
+ ENTITY_SYSTEM_RECORD_ID,
48
49
  EVAL_SET_INDEX,
49
50
  ORIGINAL_INDEX,
50
51
  RENAMED_INDEX,
52
+ SEARCH_KEY_UNNEST,
51
53
  SORT_ID,
52
54
  SYSTEM_RECORD_ID,
53
55
  TARGET,
54
56
  CVType,
57
+ FeaturesMetadataV2,
55
58
  FileColumnMeaningType,
56
59
  ModelTaskType,
57
60
  RuntimeParameters,
58
61
  SearchKey,
59
62
  )
60
63
  from upgini.metrics import EstimatorWrapper, validate_scoring_argument
64
+ from upgini.normalizer.normalize_utils import Normalizer
61
65
  from upgini.resource_bundle import ResourceBundle, bundle, get_custom_bundle
62
66
  from upgini.search_task import SearchTask
63
67
  from upgini.spinner import Spinner
64
68
  from upgini.utils import combine_search_keys, find_numbers_with_decimal_comma
65
- from upgini.utils.country_utils import CountrySearchKeyDetector
69
+ from upgini.utils.country_utils import (
70
+ CountrySearchKeyConverter,
71
+ CountrySearchKeyDetector,
72
+ )
66
73
  from upgini.utils.custom_loss_utils import (
67
74
  get_additional_params_custom_loss,
68
75
  get_runtime_params_custom_loss,
@@ -71,8 +78,8 @@ from upgini.utils.cv_utils import CVConfig, get_groups
71
78
  from upgini.utils.datetime_utils import (
72
79
  DateTimeSearchKeyConverter,
73
80
  is_blocked_time_series,
81
+ is_dates_distribution_valid,
74
82
  is_time_series,
75
- validate_dates_distribution,
76
83
  )
77
84
  from upgini.utils.deduplicate_utils import (
78
85
  clean_full_duplicates,
@@ -84,12 +91,20 @@ from upgini.utils.display_utils import (
84
91
  prepare_and_show_report,
85
92
  show_request_quote_button,
86
93
  )
87
- from upgini.utils.email_utils import EmailSearchKeyConverter, EmailSearchKeyDetector
94
+ from upgini.utils.email_utils import (
95
+ EmailDomainGenerator,
96
+ EmailSearchKeyConverter,
97
+ EmailSearchKeyDetector,
98
+ )
99
+ from upgini.utils.feature_info import FeatureInfo, _round_shap_value
88
100
  from upgini.utils.features_validator import FeaturesValidator
89
101
  from upgini.utils.format import Format
90
- from upgini.utils.ip_utils import IpToCountrySearchKeyConverter
91
- from upgini.utils.phone_utils import PhoneSearchKeyDetector
92
- from upgini.utils.postal_code_utils import PostalCodeSearchKeyDetector
102
+ from upgini.utils.ip_utils import IpSearchKeyConverter
103
+ from upgini.utils.phone_utils import PhoneSearchKeyConverter, PhoneSearchKeyDetector
104
+ from upgini.utils.postal_code_utils import (
105
+ PostalCodeSearchKeyConverter,
106
+ PostalCodeSearchKeyDetector,
107
+ )
93
108
 
94
109
  try:
95
110
  from upgini.utils.progress_bar import CustomProgressBar as ProgressBar
@@ -145,6 +160,10 @@ class FeaturesEnricher(TransformerMixin):
145
160
 
146
161
  shared_datasets: list of str, optional (default=None)
147
162
  List of private shared dataset ids for custom search
163
+
164
+ select_features: bool, optional (default=False)
165
+ If True, return only selected features both from input and data sources.
166
+ Otherwise, return all features from input and only selected features from data sources.
148
167
  """
149
168
 
150
169
  TARGET_NAME = "target"
@@ -211,11 +230,12 @@ class FeaturesEnricher(TransformerMixin):
211
230
  client_visitorid: Optional[str] = None,
212
231
  custom_bundle_config: Optional[str] = None,
213
232
  add_date_if_missing: bool = True,
233
+ select_features: bool = False,
214
234
  **kwargs,
215
235
  ):
216
236
  self.bundle = get_custom_bundle(custom_bundle_config)
217
237
  self._api_key = api_key or os.environ.get(UPGINI_API_KEY)
218
- if api_key is not None and not isinstance(api_key, str):
238
+ if self._api_key is not None and not isinstance(self._api_key, str):
219
239
  raise ValidationError(f"api_key should be `string`, but passed: `{api_key}`")
220
240
  self.rest_client = get_rest_client(endpoint, self._api_key, client_ip, client_visitorid)
221
241
  self.client_ip = client_ip
@@ -235,6 +255,7 @@ class FeaturesEnricher(TransformerMixin):
235
255
 
236
256
  self.passed_features: List[str] = []
237
257
  self.df_with_original_index: Optional[pd.DataFrame] = None
258
+ self.fit_columns_renaming: Optional[Dict[str, str]] = None
238
259
  self.country_added = False
239
260
  self.fit_generated_features: List[str] = []
240
261
  self.fit_dropped_features: Set[str] = set()
@@ -245,10 +266,12 @@ class FeaturesEnricher(TransformerMixin):
245
266
  self.eval_set: Optional[List[Tuple]] = None
246
267
  self.autodetected_search_keys: Dict[str, SearchKey] = {}
247
268
  self.imbalanced = False
248
- self.__cached_sampled_datasets: Optional[Tuple[pd.DataFrame, pd.DataFrame, pd.Series, Dict, Dict]] = None
269
+ self.__cached_sampled_datasets: Dict[str, Tuple[pd.DataFrame, pd.DataFrame, pd.Series, Dict, Dict, Dict]] = (
270
+ dict()
271
+ )
249
272
 
250
- validate_version(self.logger)
251
- self.search_keys = search_keys or dict()
273
+ validate_version(self.logger, self.__log_warning)
274
+ self.search_keys = search_keys or {}
252
275
  self.country_code = country_code
253
276
  self.__validate_search_keys(search_keys, search_id)
254
277
  self.model_task_type = model_task_type
@@ -261,8 +284,11 @@ class FeaturesEnricher(TransformerMixin):
261
284
  self._relevant_data_sources_wo_links: pd.DataFrame = self.EMPTY_DATA_SOURCES
262
285
  self.metrics: Optional[pd.DataFrame] = None
263
286
  self.feature_names_ = []
287
+ self.dropped_client_feature_names_ = []
264
288
  self.feature_importances_ = []
265
289
  self.search_id = search_id
290
+ self.select_features = select_features
291
+
266
292
  if search_id:
267
293
  search_task = SearchTask(search_id, rest_client=self.rest_client, logger=self.logger)
268
294
 
@@ -322,6 +348,10 @@ class FeaturesEnricher(TransformerMixin):
322
348
  self.exclude_columns = exclude_columns
323
349
  self.baseline_score_column = baseline_score_column
324
350
  self.add_date_if_missing = add_date_if_missing
351
+ self.features_info_display_handle = None
352
+ self.data_sources_display_handle = None
353
+ self.autofe_features_display_handle = None
354
+ self.report_button_handle = None
325
355
 
326
356
  def _get_api_key(self):
327
357
  return self._api_key
@@ -423,7 +453,7 @@ class FeaturesEnricher(TransformerMixin):
423
453
 
424
454
  self.logger.info("Start fit")
425
455
 
426
- self.__validate_search_keys(self.search_keys, self.search_id)
456
+ self.__validate_search_keys(self.search_keys)
427
457
 
428
458
  # Validate client estimator params
429
459
  self._get_client_cat_features(estimator, X, self.search_keys)
@@ -557,7 +587,7 @@ class FeaturesEnricher(TransformerMixin):
557
587
 
558
588
  self.logger.info("Start fit_transform")
559
589
 
560
- self.__validate_search_keys(self.search_keys, self.search_id)
590
+ self.__validate_search_keys(self.search_keys)
561
591
 
562
592
  search_progress = SearchProgress(0.0, ProgressStage.START_FIT)
563
593
  if progress_callback is not None:
@@ -704,7 +734,7 @@ class FeaturesEnricher(TransformerMixin):
704
734
 
705
735
  start_time = time.time()
706
736
  try:
707
- result = self.__inner_transform(
737
+ result, _, _ = self.__inner_transform(
708
738
  trace_id,
709
739
  X,
710
740
  exclude_features_sources=exclude_features_sources,
@@ -831,17 +861,44 @@ class FeaturesEnricher(TransformerMixin):
831
861
  self.logger.warning(msg)
832
862
  print(msg)
833
863
 
864
+ if X is not None and y is None:
865
+ raise ValidationError("X passed without y")
866
+
834
867
  self.__validate_search_keys(self.search_keys, self.search_id)
835
868
  effective_X = X if X is not None else self.X
836
869
  effective_y = y if y is not None else self.y
837
870
  effective_eval_set = eval_set if eval_set is not None else self.eval_set
838
871
  effective_eval_set = self._check_eval_set(effective_eval_set, effective_X, self.bundle)
839
872
 
873
+ if (
874
+ self._search_task is None
875
+ or self._search_task.provider_metadata_v2 is None
876
+ or len(self._search_task.provider_metadata_v2) == 0
877
+ or effective_X is None
878
+ or effective_y is None
879
+ ):
880
+ raise ValidationError(self.bundle.get("metrics_unfitted_enricher"))
881
+
882
+ validated_X = self._validate_X(effective_X)
883
+ validated_y = self._validate_y(validated_X, effective_y)
884
+ validated_eval_set = (
885
+ [self._validate_eval_set_pair(validated_X, eval_pair) for eval_pair in effective_eval_set]
886
+ if effective_eval_set is not None
887
+ else None
888
+ )
889
+
890
+ if self.X is None:
891
+ self.X = X
892
+ if self.y is None:
893
+ self.y = y
894
+ if self.eval_set is None:
895
+ self.eval_set = effective_eval_set
896
+
840
897
  try:
841
898
  self.__log_debug_information(
842
- effective_X,
843
- effective_y,
844
- effective_eval_set,
899
+ validated_X,
900
+ validated_y,
901
+ validated_eval_set,
845
902
  exclude_features_sources=exclude_features_sources,
846
903
  cv=cv if cv is not None else self.cv,
847
904
  importance_threshold=importance_threshold,
@@ -851,21 +908,9 @@ class FeaturesEnricher(TransformerMixin):
851
908
  remove_outliers_calc_metrics=remove_outliers_calc_metrics,
852
909
  )
853
910
 
854
- if (
855
- self._search_task is None
856
- or self._search_task.provider_metadata_v2 is None
857
- or len(self._search_task.provider_metadata_v2) == 0
858
- or effective_X is None
859
- or effective_y is None
860
- ):
861
- raise ValidationError(self.bundle.get("metrics_unfitted_enricher"))
862
-
863
- if X is not None and y is None:
864
- raise ValidationError("X passed without y")
865
-
866
911
  validate_scoring_argument(scoring)
867
912
 
868
- self._validate_baseline_score(effective_X, effective_eval_set)
913
+ self._validate_baseline_score(validated_X, validated_eval_set)
869
914
 
870
915
  if self._has_paid_features(exclude_features_sources):
871
916
  msg = self.bundle.get("metrics_with_paid_features")
@@ -874,14 +919,14 @@ class FeaturesEnricher(TransformerMixin):
874
919
  return None
875
920
 
876
921
  cat_features, search_keys_for_metrics = self._get_client_cat_features(
877
- estimator, effective_X, self.search_keys
922
+ estimator, validated_X, self.search_keys
878
923
  )
879
924
 
880
925
  prepared_data = self._prepare_data_for_metrics(
881
926
  trace_id=trace_id,
882
- X=effective_X,
883
- y=effective_y,
884
- eval_set=effective_eval_set,
927
+ X=X,
928
+ y=y,
929
+ eval_set=eval_set,
885
930
  exclude_features_sources=exclude_features_sources,
886
931
  importance_threshold=importance_threshold,
887
932
  max_features=max_features,
@@ -904,21 +949,27 @@ class FeaturesEnricher(TransformerMixin):
904
949
  search_keys,
905
950
  groups,
906
951
  _cv,
952
+ columns_renaming,
907
953
  ) = prepared_data
908
954
 
955
+ # rename cat_features
956
+ if cat_features:
957
+ for new_c, old_c in columns_renaming.items():
958
+ if old_c in cat_features:
959
+ cat_features.remove(old_c)
960
+ cat_features.append(new_c)
961
+
909
962
  gc.collect()
910
963
 
964
+ if fitting_X.shape[1] == 0 and fitting_enriched_X.shape[1] == 0:
965
+ self.__log_warning(self.bundle.get("metrics_no_important_free_features"))
966
+ return None
967
+
911
968
  print(self.bundle.get("metrics_start"))
912
969
  with Spinner():
913
- if fitting_X.shape[1] == 0 and fitting_enriched_X.shape[1] == 0:
914
- print(self.bundle.get("metrics_no_important_free_features"))
915
- self.logger.warning("No client or free relevant ADS features found to calculate metrics")
916
- self.warning_counter.increment()
917
- return None
918
-
919
970
  self._check_train_and_eval_target_distribution(y_sorted, fitting_eval_set_dict)
920
971
 
921
- has_date = self._get_date_column(search_keys) is not None
972
+ has_date = SearchKey.find_key(search_keys, [SearchKey.DATE, SearchKey.DATETIME]) is not None
922
973
  model_task_type = self.model_task_type or define_task(y_sorted, has_date, self.logger, silent=True)
923
974
 
924
975
  wrapper = EstimatorWrapper.create(
@@ -930,11 +981,12 @@ class FeaturesEnricher(TransformerMixin):
930
981
  scoring,
931
982
  groups=groups,
932
983
  text_features=self.generate_features,
984
+ has_date=has_date,
933
985
  )
934
986
  metric = wrapper.metric_name
935
987
  multiplier = wrapper.multiplier
936
988
 
937
- # 1 If client features are presented - fit and predict with KFold CatBoost model
989
+ # 1 If client features are presented - fit and predict with KFold estimator
938
990
  # on etalon features and calculate baseline metric
939
991
  etalon_metric = None
940
992
  baseline_estimator = None
@@ -956,14 +1008,24 @@ class FeaturesEnricher(TransformerMixin):
956
1008
  add_params=custom_loss_add_params,
957
1009
  groups=groups,
958
1010
  text_features=self.generate_features,
1011
+ has_date=has_date,
959
1012
  )
960
- etalon_metric = baseline_estimator.cross_val_predict(
1013
+ etalon_cv_result = baseline_estimator.cross_val_predict(
961
1014
  fitting_X, y_sorted, self.baseline_score_column
962
1015
  )
963
- self.logger.info(f"Baseline {metric} on train client features: {etalon_metric}")
1016
+ etalon_metric = etalon_cv_result.get_display_metric()
1017
+ if etalon_metric is None:
1018
+ self.logger.info(
1019
+ f"Baseline {metric} on train client features is None (maybe all features was removed)"
1020
+ )
1021
+ baseline_estimator = None
1022
+ else:
1023
+ self.logger.info(f"Baseline {metric} on train client features: {etalon_metric}")
964
1024
 
965
- # 2 Fit and predict with KFold Catboost model on enriched tds
1025
+ # 2 Fit and predict with KFold estimator on enriched tds
966
1026
  # and calculate final metric (and uplift)
1027
+ enriched_metric = None
1028
+ uplift = None
967
1029
  enriched_estimator = None
968
1030
  if set(fitting_X.columns) != set(fitting_enriched_X.columns):
969
1031
  self.logger.info(
@@ -981,16 +1043,24 @@ class FeaturesEnricher(TransformerMixin):
981
1043
  add_params=custom_loss_add_params,
982
1044
  groups=groups,
983
1045
  text_features=self.generate_features,
1046
+ has_date=has_date,
984
1047
  )
985
- enriched_metric = enriched_estimator.cross_val_predict(fitting_enriched_X, enriched_y_sorted)
986
- self.logger.info(f"Enriched {metric} on train combined features: {enriched_metric}")
987
- if etalon_metric is not None:
988
- uplift = (enriched_metric - etalon_metric) * multiplier
1048
+ enriched_cv_result = enriched_estimator.cross_val_predict(fitting_enriched_X, enriched_y_sorted)
1049
+ enriched_metric = enriched_cv_result.get_display_metric()
1050
+ enriched_shaps = enriched_cv_result.shap_values
1051
+
1052
+ if enriched_shaps is not None:
1053
+ self._update_shap_values(trace_id, validated_X.columns.to_list(), enriched_shaps)
1054
+
1055
+ if enriched_metric is None:
1056
+ self.logger.warning(
1057
+ f"Enriched {metric} on train combined features is None (maybe all features was removed)"
1058
+ )
1059
+ enriched_estimator = None
989
1060
  else:
990
- uplift = None
991
- else:
992
- enriched_metric = None
993
- uplift = None
1061
+ self.logger.info(f"Enriched {metric} on train combined features: {enriched_metric}")
1062
+ if etalon_metric is not None and enriched_metric is not None:
1063
+ uplift = (enriched_cv_result.metric - etalon_cv_result.metric) * multiplier
994
1064
 
995
1065
  train_metrics = {
996
1066
  self.bundle.get("quality_metrics_segment_header"): self.bundle.get(
@@ -999,10 +1069,10 @@ class FeaturesEnricher(TransformerMixin):
999
1069
  self.bundle.get("quality_metrics_rows_header"): _num_samples(effective_X),
1000
1070
  }
1001
1071
  if model_task_type in [ModelTaskType.BINARY, ModelTaskType.REGRESSION] and is_numeric_dtype(
1002
- y_sorted
1072
+ validated_y
1003
1073
  ):
1004
1074
  train_metrics[self.bundle.get("quality_metrics_mean_target_header")] = round(
1005
- np.mean(effective_y), 4
1075
+ np.mean(validated_y), 4
1006
1076
  )
1007
1077
  if etalon_metric is not None:
1008
1078
  train_metrics[self.bundle.get("quality_metrics_baseline_header").format(metric)] = etalon_metric
@@ -1033,9 +1103,10 @@ class FeaturesEnricher(TransformerMixin):
1033
1103
  f"Calculate baseline {metric} on eval set {idx + 1} "
1034
1104
  f"on client features: {eval_X_sorted.columns.to_list()}"
1035
1105
  )
1036
- etalon_eval_metric = baseline_estimator.calculate_metric(
1106
+ etalon_eval_results = baseline_estimator.calculate_metric(
1037
1107
  eval_X_sorted, eval_y_sorted, self.baseline_score_column
1038
1108
  )
1109
+ etalon_eval_metric = etalon_eval_results.get_display_metric()
1039
1110
  self.logger.info(
1040
1111
  f"Baseline {metric} on eval set {idx + 1} client features: {etalon_eval_metric}"
1041
1112
  )
@@ -1047,9 +1118,10 @@ class FeaturesEnricher(TransformerMixin):
1047
1118
  f"Calculate enriched {metric} on eval set {idx + 1} "
1048
1119
  f"on combined features: {enriched_eval_X_sorted.columns.to_list()}"
1049
1120
  )
1050
- enriched_eval_metric = enriched_estimator.calculate_metric(
1121
+ enriched_eval_results = enriched_estimator.calculate_metric(
1051
1122
  enriched_eval_X_sorted, enriched_eval_y_sorted
1052
1123
  )
1124
+ enriched_eval_metric = enriched_eval_results.get_display_metric()
1053
1125
  self.logger.info(
1054
1126
  f"Enriched {metric} on eval set {idx + 1} combined features: {enriched_eval_metric}"
1055
1127
  )
@@ -1057,11 +1129,11 @@ class FeaturesEnricher(TransformerMixin):
1057
1129
  enriched_eval_metric = None
1058
1130
 
1059
1131
  if etalon_eval_metric is not None and enriched_eval_metric is not None:
1060
- eval_uplift = (enriched_eval_metric - etalon_eval_metric) * multiplier
1132
+ eval_uplift = (enriched_eval_results.metric - etalon_eval_results.metric) * multiplier
1061
1133
  else:
1062
1134
  eval_uplift = None
1063
1135
 
1064
- effective_eval_set = eval_set if eval_set is not None else self.eval_set
1136
+ # effective_eval_set = eval_set if eval_set is not None else self.eval_set
1065
1137
  eval_metrics = {
1066
1138
  self.bundle.get("quality_metrics_segment_header"): self.bundle.get(
1067
1139
  "quality_metrics_eval_segment"
@@ -1072,10 +1144,10 @@ class FeaturesEnricher(TransformerMixin):
1072
1144
  # self.bundle.get("quality_metrics_match_rate_header"): eval_hit_rate,
1073
1145
  }
1074
1146
  if model_task_type in [ModelTaskType.BINARY, ModelTaskType.REGRESSION] and is_numeric_dtype(
1075
- eval_y_sorted
1147
+ validated_eval_set[idx][1]
1076
1148
  ):
1077
1149
  eval_metrics[self.bundle.get("quality_metrics_mean_target_header")] = round(
1078
- np.mean(effective_eval_set[idx][1]), 4
1150
+ np.mean(validated_eval_set[idx][1]), 4
1079
1151
  )
1080
1152
  if etalon_eval_metric is not None:
1081
1153
  eval_metrics[self.bundle.get("quality_metrics_baseline_header").format(metric)] = (
@@ -1099,7 +1171,7 @@ class FeaturesEnricher(TransformerMixin):
1099
1171
  )
1100
1172
 
1101
1173
  uplift_col = self.bundle.get("quality_metrics_uplift_header")
1102
- date_column = self._get_date_column(search_keys)
1174
+ date_column = SearchKey.find_key(search_keys, [SearchKey.DATE, SearchKey.DATETIME])
1103
1175
  if (
1104
1176
  uplift_col in metrics_df.columns
1105
1177
  and (metrics_df[uplift_col] < 0).any()
@@ -1138,6 +1210,57 @@ class FeaturesEnricher(TransformerMixin):
1138
1210
  finally:
1139
1211
  self.logger.info(f"Calculating metrics elapsed time: {time.time() - start_time}")
1140
1212
 
1213
+ def _update_shap_values(self, trace_id: str, x_columns: List[str], new_shaps: Dict[str, float]):
1214
+ new_shaps = {
1215
+ feature: _round_shap_value(shap) for feature, shap in new_shaps.items() if feature in self.feature_names_
1216
+ }
1217
+ self.__prepare_feature_importances(trace_id, x_columns, new_shaps, silent=True)
1218
+
1219
+ if self.features_info_display_handle is not None:
1220
+ try:
1221
+ _ = get_ipython() # type: ignore
1222
+
1223
+ display_html_dataframe(
1224
+ self.features_info,
1225
+ self._features_info_without_links,
1226
+ self.bundle.get("relevant_features_header"),
1227
+ display_handle=self.features_info_display_handle,
1228
+ )
1229
+ except (ImportError, NameError):
1230
+ pass
1231
+ if self.data_sources_display_handle is not None:
1232
+ try:
1233
+ _ = get_ipython() # type: ignore
1234
+
1235
+ display_html_dataframe(
1236
+ self.relevant_data_sources,
1237
+ self._relevant_data_sources_wo_links,
1238
+ self.bundle.get("relevant_data_sources_header"),
1239
+ display_handle=self.data_sources_display_handle,
1240
+ )
1241
+ except (ImportError, NameError):
1242
+ pass
1243
+ if self.autofe_features_display_handle is not None:
1244
+ try:
1245
+ _ = get_ipython() # type: ignore
1246
+ autofe_descriptions_df = self.get_autofe_features_description()
1247
+ if autofe_descriptions_df is not None:
1248
+ display_html_dataframe(
1249
+ df=autofe_descriptions_df,
1250
+ internal_df=autofe_descriptions_df,
1251
+ header=self.bundle.get("autofe_descriptions_header"),
1252
+ display_handle=self.autofe_features_display_handle,
1253
+ )
1254
+ except (ImportError, NameError):
1255
+ pass
1256
+ if self.report_button_handle is not None:
1257
+ try:
1258
+ _ = get_ipython() # type: ignore
1259
+
1260
+ self.__show_report_button(display_handle=self.report_button_handle)
1261
+ except (ImportError, NameError):
1262
+ pass
1263
+
1141
1264
  def _check_train_and_eval_target_distribution(self, y, eval_set_dict):
1142
1265
  uneven_distribution = False
1143
1266
  for eval_set in eval_set_dict.values():
@@ -1174,34 +1297,6 @@ class FeaturesEnricher(TransformerMixin):
1174
1297
  def _has_paid_features(self, exclude_features_sources: Optional[List[str]]) -> bool:
1175
1298
  return self._has_features_with_commercial_schema(CommercialSchema.PAID.value, exclude_features_sources)
1176
1299
 
1177
- def _extend_x(self, x: pd.DataFrame, is_demo_dataset: bool) -> Tuple[pd.DataFrame, Dict[str, SearchKey]]:
1178
- search_keys = self.search_keys.copy()
1179
- search_keys = self.__prepare_search_keys(x, search_keys, is_demo_dataset, is_transform=True, silent_mode=True)
1180
-
1181
- extended_X = x.copy()
1182
- generated_features = []
1183
- date_column = self._get_date_column(search_keys)
1184
- if date_column is not None:
1185
- converter = DateTimeSearchKeyConverter(date_column, self.date_format, self.logger, self.bundle)
1186
- extended_X = converter.convert(extended_X, keep_time=True)
1187
- generated_features.extend(converter.generated_features)
1188
- email_column = self._get_email_column(search_keys)
1189
- hem_column = self._get_hem_column(search_keys)
1190
- if email_column:
1191
- converter = EmailSearchKeyConverter(email_column, hem_column, search_keys, self.logger)
1192
- extended_X = converter.convert(extended_X)
1193
- generated_features.extend(converter.generated_features)
1194
- if (
1195
- self.detect_missing_search_keys
1196
- and list(search_keys.values()) == [SearchKey.DATE]
1197
- and self.country_code is None
1198
- ):
1199
- converter = IpToCountrySearchKeyConverter(search_keys, self.logger)
1200
- extended_X = converter.convert(extended_X)
1201
- generated_features = [f for f in generated_features if f in self.fit_generated_features]
1202
-
1203
- return extended_X, search_keys
1204
-
1205
1300
  def _is_input_same_as_fit(
1206
1301
  self,
1207
1302
  X: Union[pd.DataFrame, pd.Series, np.ndarray, None] = None,
@@ -1245,7 +1340,7 @@ class FeaturesEnricher(TransformerMixin):
1245
1340
  groups = None
1246
1341
 
1247
1342
  if not isinstance(_cv, BaseCrossValidator):
1248
- date_column = self._get_date_column(search_keys)
1343
+ date_column = SearchKey.find_key(search_keys, [SearchKey.DATE, SearchKey.DATETIME])
1249
1344
  date_series = X[date_column] if date_column is not None else None
1250
1345
  _cv, groups = CVConfig(
1251
1346
  _cv, date_series, self.random_state, self._search_task.get_shuffle_kfold(), group_columns=group_columns
@@ -1268,7 +1363,7 @@ class FeaturesEnricher(TransformerMixin):
1268
1363
 
1269
1364
  def _get_client_cat_features(
1270
1365
  self, estimator: Optional[Any], X: pd.DataFrame, search_keys: Dict[str, SearchKey]
1271
- ) -> Optional[List[str]]:
1366
+ ) -> Tuple[Optional[List[str]], List[str]]:
1272
1367
  cat_features = None
1273
1368
  search_keys_for_metrics = []
1274
1369
  if (
@@ -1328,30 +1423,38 @@ class FeaturesEnricher(TransformerMixin):
1328
1423
  progress_bar,
1329
1424
  progress_callback,
1330
1425
  )
1331
- X_sampled, y_sampled, enriched_X, eval_set_sampled_dict, search_keys = dataclasses.astuple(sampled_data)
1426
+ X_sampled, y_sampled, enriched_X, eval_set_sampled_dict, search_keys, columns_renaming = dataclasses.astuple(
1427
+ sampled_data
1428
+ )
1332
1429
 
1333
1430
  excluding_search_keys = list(search_keys.keys())
1334
1431
  if search_keys_for_metrics is not None and len(search_keys_for_metrics) > 0:
1335
- excluding_search_keys = [sk for sk in excluding_search_keys if sk not in search_keys_for_metrics]
1336
- meta = self._search_task.get_all_features_metadata_v2()
1337
- zero_importance_client_features = [m for m in meta if m.source == "etalon" and m.shap_value == 0.0]
1432
+ for sk in excluding_search_keys:
1433
+ if columns_renaming.get(sk) in search_keys_for_metrics:
1434
+ excluding_search_keys.remove(sk)
1338
1435
 
1339
1436
  client_features = [
1340
1437
  c
1341
1438
  for c in X_sampled.columns.to_list()
1342
- if c
1439
+ if (
1440
+ not self.select_features
1441
+ or c in self.feature_names_
1442
+ or (self.fit_columns_renaming is not None and self.fit_columns_renaming.get(c) in self.feature_names_)
1443
+ )
1444
+ and c
1343
1445
  not in (
1344
1446
  excluding_search_keys
1345
1447
  + list(self.fit_dropped_features)
1346
- + [DateTimeSearchKeyConverter.DATETIME_COL, SYSTEM_RECORD_ID]
1347
- + zero_importance_client_features
1448
+ + [DateTimeSearchKeyConverter.DATETIME_COL, SYSTEM_RECORD_ID, ENTITY_SYSTEM_RECORD_ID]
1348
1449
  )
1349
1450
  ]
1451
+ self.logger.info(f"Client features column on prepare data for metrics: {client_features}")
1350
1452
 
1351
1453
  filtered_enriched_features = self.__filtered_enriched_features(
1352
1454
  importance_threshold,
1353
1455
  max_features,
1354
1456
  )
1457
+ filtered_enriched_features = [c for c in filtered_enriched_features if c not in client_features]
1355
1458
 
1356
1459
  X_sorted, y_sorted = self._sort_by_system_record_id(X_sampled, y_sampled, self.cv)
1357
1460
  enriched_X_sorted, enriched_y_sorted = self._sort_by_system_record_id(enriched_X, y_sampled, self.cv)
@@ -1381,9 +1484,12 @@ class FeaturesEnricher(TransformerMixin):
1381
1484
  fitting_X = fitting_X.drop(columns=constant_columns, errors="ignore")
1382
1485
  fitting_enriched_X = fitting_enriched_X.drop(columns=constant_columns, errors="ignore")
1383
1486
 
1487
+ # TODO maybe there is no more need for these convertions
1384
1488
  # Remove datetime features
1385
1489
  datetime_features = [
1386
- f for f in fitting_X.columns if is_datetime64_any_dtype(fitting_X[f]) or is_period_dtype(fitting_X[f])
1490
+ f
1491
+ for f in fitting_X.columns
1492
+ if is_datetime64_any_dtype(fitting_X[f]) or isinstance(fitting_X[f].dtype, pd.PeriodDtype)
1387
1493
  ]
1388
1494
  if len(datetime_features) > 0:
1389
1495
  self.logger.warning(self.bundle.get("dataset_date_features").format(datetime_features))
@@ -1403,37 +1509,25 @@ class FeaturesEnricher(TransformerMixin):
1403
1509
  if len(decimal_columns_to_fix) > 0:
1404
1510
  self.logger.warning(f"Convert strings with decimal comma to float: {decimal_columns_to_fix}")
1405
1511
  for col in decimal_columns_to_fix:
1406
- fitting_X[col] = fitting_X[col].astype("string").str.replace(",", ".").astype(np.float64)
1512
+ fitting_X[col] = fitting_X[col].astype("string").str.replace(",", ".", regex=False).astype(np.float64)
1407
1513
  fitting_enriched_X[col] = (
1408
- fitting_enriched_X[col].astype("string").str.replace(",", ".").astype(np.float64)
1514
+ fitting_enriched_X[col].astype("string").str.replace(",", ".", regex=False).astype(np.float64)
1409
1515
  )
1410
1516
 
1411
- fitting_eval_set_dict = dict()
1517
+ fitting_eval_set_dict = {}
1518
+ fitting_x_columns = fitting_X.columns.to_list()
1519
+ self.logger.info(f"Final list of fitting X columns: {fitting_x_columns}")
1520
+ fitting_enriched_x_columns = fitting_enriched_X.columns.to_list()
1521
+ self.logger.info(f"Final list of fitting enriched X columns: {fitting_enriched_x_columns}")
1412
1522
  for idx, eval_tuple in eval_set_sampled_dict.items():
1413
1523
  eval_X_sampled, enriched_eval_X, eval_y_sampled = eval_tuple
1414
1524
  eval_X_sorted, eval_y_sorted = self._sort_by_system_record_id(eval_X_sampled, eval_y_sampled, self.cv)
1415
1525
  enriched_eval_X_sorted, enriched_eval_y_sorted = self._sort_by_system_record_id(
1416
1526
  enriched_eval_X, eval_y_sampled, self.cv
1417
1527
  )
1418
- fitting_eval_X = eval_X_sorted[client_features].copy()
1419
- fitting_enriched_eval_X = enriched_eval_X_sorted[
1420
- client_features + existing_filtered_enriched_features
1421
- ].copy()
1422
-
1423
- # # Drop high cardinality features in eval set
1424
- if len(columns_with_high_cardinality) > 0:
1425
- fitting_eval_X = fitting_eval_X.drop(columns=columns_with_high_cardinality, errors="ignore")
1426
- fitting_enriched_eval_X = fitting_enriched_eval_X.drop(
1427
- columns=columns_with_high_cardinality, errors="ignore"
1428
- )
1429
- # Drop constant features in eval_set
1430
- if len(constant_columns) > 0:
1431
- fitting_eval_X = fitting_eval_X.drop(columns=constant_columns, errors="ignore")
1432
- fitting_enriched_eval_X = fitting_enriched_eval_X.drop(columns=constant_columns, errors="ignore")
1433
- # Drop datetime features in eval_set
1434
- if len(datetime_features) > 0:
1435
- fitting_eval_X = fitting_eval_X.drop(columns=datetime_features, errors="ignore")
1436
- fitting_enriched_eval_X = fitting_enriched_eval_X.drop(columns=datetime_features, errors="ignore")
1528
+ fitting_eval_X = eval_X_sorted[fitting_x_columns].copy()
1529
+ fitting_enriched_eval_X = enriched_eval_X_sorted[fitting_enriched_x_columns].copy()
1530
+
1437
1531
  # Convert bool to string in eval_set
1438
1532
  if len(bool_columns) > 0:
1439
1533
  fitting_eval_X[col] = fitting_eval_X[col].astype(str)
@@ -1441,9 +1535,14 @@ class FeaturesEnricher(TransformerMixin):
1441
1535
  # Correct string features with decimal commas
1442
1536
  if len(decimal_columns_to_fix) > 0:
1443
1537
  for col in decimal_columns_to_fix:
1444
- fitting_eval_X[col] = fitting_eval_X[col].astype("string").str.replace(",", ".").astype(np.float64)
1538
+ fitting_eval_X[col] = (
1539
+ fitting_eval_X[col].astype("string").str.replace(",", ".", regex=False).astype(np.float64)
1540
+ )
1445
1541
  fitting_enriched_eval_X[col] = (
1446
- fitting_enriched_eval_X[col].astype("string").str.replace(",", ".").astype(np.float64)
1542
+ fitting_enriched_eval_X[col]
1543
+ .astype("string")
1544
+ .str.replace(",", ".", regex=False)
1545
+ .astype(np.float64)
1447
1546
  )
1448
1547
 
1449
1548
  fitting_eval_set_dict[idx] = (
@@ -1463,6 +1562,7 @@ class FeaturesEnricher(TransformerMixin):
1463
1562
  search_keys,
1464
1563
  groups,
1465
1564
  cv,
1565
+ columns_renaming,
1466
1566
  )
1467
1567
 
1468
1568
  @dataclass
@@ -1472,6 +1572,7 @@ class FeaturesEnricher(TransformerMixin):
1472
1572
  enriched_X: pd.DataFrame
1473
1573
  eval_set_sampled_dict: Dict[int, Tuple[pd.DataFrame, pd.Series]]
1474
1574
  search_keys: Dict[str, SearchKey]
1575
+ columns_renaming: Dict[str, str]
1475
1576
 
1476
1577
  def _sample_data_for_metrics(
1477
1578
  self,
@@ -1486,18 +1587,28 @@ class FeaturesEnricher(TransformerMixin):
1486
1587
  progress_bar: Optional[ProgressBar],
1487
1588
  progress_callback: Optional[Callable[[SearchProgress], Any]],
1488
1589
  ) -> _SampledDataForMetrics:
1489
- if self.__cached_sampled_datasets is not None and is_input_same_as_fit and remove_outliers_calc_metrics is None:
1590
+ datasets_hash = hash_input(validated_X, validated_y, eval_set)
1591
+ cached_sampled_datasets = self.__cached_sampled_datasets.get(datasets_hash)
1592
+ if cached_sampled_datasets is not None and is_input_same_as_fit and remove_outliers_calc_metrics is None:
1490
1593
  self.logger.info("Cached enriched dataset found - use it")
1491
- return self.__get_sampled_cached_enriched(exclude_features_sources)
1594
+ return self.__get_sampled_cached_enriched(datasets_hash, exclude_features_sources)
1492
1595
  elif len(self.feature_importances_) == 0:
1493
1596
  self.logger.info("No external features selected. So use only input datasets for metrics calculation")
1494
1597
  return self.__sample_only_input(validated_X, validated_y, eval_set, is_demo_dataset)
1495
1598
  # TODO save and check if dataset was deduplicated - use imbalance branch for such case
1496
- elif not self.imbalanced and not exclude_features_sources and is_input_same_as_fit:
1599
+ elif (
1600
+ not self.imbalanced
1601
+ and not exclude_features_sources
1602
+ and is_input_same_as_fit
1603
+ and self.df_with_original_index is not None
1604
+ ):
1497
1605
  self.logger.info("Dataset is not imbalanced, so use enriched_X from fit")
1498
1606
  return self.__sample_balanced(eval_set, trace_id, remove_outliers_calc_metrics)
1499
1607
  else:
1500
- self.logger.info("Dataset is imbalanced or exclude_features_sources or X was passed. Run transform")
1608
+ self.logger.info(
1609
+ "Dataset is imbalanced or exclude_features_sources or X was passed or this is saved search."
1610
+ " Run transform"
1611
+ )
1501
1612
  print(self.bundle.get("prepare_data_for_metrics"))
1502
1613
  return self.__sample_imbalanced(
1503
1614
  validated_X,
@@ -1510,17 +1621,23 @@ class FeaturesEnricher(TransformerMixin):
1510
1621
  progress_callback,
1511
1622
  )
1512
1623
 
1513
- def __get_sampled_cached_enriched(self, exclude_features_sources: Optional[List[str]]) -> _SampledDataForMetrics:
1514
- X_sampled, y_sampled, enriched_X, eval_set_sampled_dict, search_keys = self.__cached_sampled_datasets
1624
+ def __get_sampled_cached_enriched(
1625
+ self, datasets_hash: str, exclude_features_sources: Optional[List[str]]
1626
+ ) -> _SampledDataForMetrics:
1627
+ X_sampled, y_sampled, enriched_X, eval_set_sampled_dict, search_keys, columns_renaming = (
1628
+ self.__cached_sampled_datasets[datasets_hash]
1629
+ )
1515
1630
  if exclude_features_sources:
1516
1631
  enriched_X = enriched_X.drop(columns=exclude_features_sources, errors="ignore")
1517
1632
 
1518
- return self.__mk_sampled_data_tuple(X_sampled, y_sampled, enriched_X, eval_set_sampled_dict, search_keys)
1633
+ return self.__mk_sampled_data_tuple(
1634
+ X_sampled, y_sampled, enriched_X, eval_set_sampled_dict, search_keys, columns_renaming
1635
+ )
1519
1636
 
1520
1637
  def __sample_only_input(
1521
1638
  self, validated_X: pd.DataFrame, validated_y: pd.Series, eval_set: Optional[List[tuple]], is_demo_dataset: bool
1522
1639
  ) -> _SampledDataForMetrics:
1523
- eval_set_sampled_dict = dict()
1640
+ eval_set_sampled_dict = {}
1524
1641
 
1525
1642
  df = validated_X.copy()
1526
1643
  df[TARGET] = validated_y
@@ -1533,7 +1650,31 @@ class FeaturesEnricher(TransformerMixin):
1533
1650
  eval_xy[EVAL_SET_INDEX] = idx + 1
1534
1651
  df = pd.concat([df, eval_xy])
1535
1652
 
1536
- df = clean_full_duplicates(df, logger=self.logger, silent=True, bundle=self.bundle)
1653
+ search_keys = self.search_keys.copy()
1654
+ search_keys = self.__prepare_search_keys(df, search_keys, is_demo_dataset, is_transform=True, silent_mode=True)
1655
+
1656
+ date_column = SearchKey.find_key(search_keys, [SearchKey.DATE, SearchKey.DATETIME])
1657
+ generated_features = []
1658
+ if date_column is not None:
1659
+ converter = DateTimeSearchKeyConverter(date_column, self.date_format, self.logger, self.bundle)
1660
+ # Leave original date column values
1661
+ df_with_date_features = converter.convert(df, keep_time=True)
1662
+ df_with_date_features[date_column] = df[date_column]
1663
+ df = df_with_date_features
1664
+ generated_features = converter.generated_features
1665
+
1666
+ email_columns = SearchKey.find_all_keys(search_keys, SearchKey.EMAIL)
1667
+ if email_columns:
1668
+ generator = EmailDomainGenerator(email_columns)
1669
+ df = generator.generate(df)
1670
+ generated_features.extend(generator.generated_features)
1671
+
1672
+ # normalizer = Normalizer(self.bundle, self.logger)
1673
+ # df, search_keys, generated_features = normalizer.normalize(df, search_keys, generated_features)
1674
+ # columns_renaming = normalizer.columns_renaming
1675
+ columns_renaming = {c: c for c in df.columns}
1676
+
1677
+ df, _ = clean_full_duplicates(df, logger=self.logger, bundle=self.bundle)
1537
1678
 
1538
1679
  num_samples = _num_samples(df)
1539
1680
  sample_threshold, sample_rows = (
@@ -1545,24 +1686,36 @@ class FeaturesEnricher(TransformerMixin):
1545
1686
  self.logger.info(f"Downsampling from {num_samples} to {sample_rows}")
1546
1687
  df = df.sample(n=sample_rows, random_state=self.random_state)
1547
1688
 
1548
- df_extended, search_keys = self._extend_x(df, is_demo_dataset)
1549
- df_extended = self.__add_fit_system_record_id(df_extended, dict(), search_keys)
1689
+ df = self.__add_fit_system_record_id(df, search_keys, SYSTEM_RECORD_ID)
1690
+ if DateTimeSearchKeyConverter.DATETIME_COL in df.columns:
1691
+ df = df.drop(columns=DateTimeSearchKeyConverter.DATETIME_COL)
1550
1692
 
1551
- train_df = df_extended.query(f"{EVAL_SET_INDEX} == 0") if eval_set is not None else df_extended
1693
+ train_df = df.query(f"{EVAL_SET_INDEX} == 0") if eval_set is not None else df
1552
1694
  X_sampled = train_df.drop(columns=[TARGET, EVAL_SET_INDEX], errors="ignore")
1553
1695
  y_sampled = train_df[TARGET].copy()
1554
1696
  enriched_X = X_sampled
1555
1697
 
1556
1698
  if eval_set is not None:
1557
1699
  for idx in range(len(eval_set)):
1558
- eval_xy_sampled = df_extended.query(f"{EVAL_SET_INDEX} == {idx + 1}")
1700
+ eval_xy_sampled = df.query(f"{EVAL_SET_INDEX} == {idx + 1}")
1559
1701
  eval_X_sampled = eval_xy_sampled.drop(columns=[TARGET, EVAL_SET_INDEX], errors="ignore")
1560
1702
  eval_y_sampled = eval_xy_sampled[TARGET].copy()
1561
1703
  enriched_eval_X = eval_X_sampled
1562
1704
  eval_set_sampled_dict[idx] = (eval_X_sampled, enriched_eval_X, eval_y_sampled)
1563
- self.__cached_sampled_datasets = (X_sampled, y_sampled, enriched_X, eval_set_sampled_dict, search_keys)
1564
1705
 
1565
- return self.__mk_sampled_data_tuple(X_sampled, y_sampled, enriched_X, eval_set_sampled_dict, search_keys)
1706
+ datasets_hash = hash_input(X_sampled, y_sampled, eval_set_sampled_dict)
1707
+ self.__cached_sampled_datasets[datasets_hash] = (
1708
+ X_sampled,
1709
+ y_sampled,
1710
+ enriched_X,
1711
+ eval_set_sampled_dict,
1712
+ search_keys,
1713
+ columns_renaming,
1714
+ )
1715
+
1716
+ return self.__mk_sampled_data_tuple(
1717
+ X_sampled, y_sampled, enriched_X, eval_set_sampled_dict, search_keys, columns_renaming
1718
+ )
1566
1719
 
1567
1720
  def __sample_balanced(
1568
1721
  self,
@@ -1570,22 +1723,21 @@ class FeaturesEnricher(TransformerMixin):
1570
1723
  trace_id: str,
1571
1724
  remove_outliers_calc_metrics: Optional[bool],
1572
1725
  ) -> _SampledDataForMetrics:
1573
- eval_set_sampled_dict = dict()
1726
+ eval_set_sampled_dict = {}
1574
1727
  search_keys = self.fit_search_keys
1575
1728
 
1576
1729
  rows_to_drop = None
1577
- has_date = self._get_date_column(search_keys) is not None
1578
- task_type = self.model_task_type or define_task(
1730
+ has_date = SearchKey.find_key(search_keys, [SearchKey.DATE, SearchKey.DATETIME]) is not None
1731
+ self.model_task_type = self.model_task_type or define_task(
1579
1732
  self.df_with_original_index[TARGET], has_date, self.logger, silent=True
1580
1733
  )
1581
- if task_type == ModelTaskType.REGRESSION:
1734
+ if self.model_task_type == ModelTaskType.REGRESSION:
1582
1735
  target_outliers_df = self._search_task.get_target_outliers(trace_id)
1583
1736
  if target_outliers_df is not None and len(target_outliers_df) > 0:
1584
1737
  outliers = pd.merge(
1585
1738
  self.df_with_original_index,
1586
1739
  target_outliers_df,
1587
- left_on=SYSTEM_RECORD_ID,
1588
- right_on=SYSTEM_RECORD_ID,
1740
+ on=ENTITY_SYSTEM_RECORD_ID,
1589
1741
  how="inner",
1590
1742
  )
1591
1743
  top_outliers = outliers.sort_values(by=TARGET, ascending=False)[TARGET].head(3)
@@ -1612,6 +1764,7 @@ class FeaturesEnricher(TransformerMixin):
1612
1764
  X_sampled = enriched_Xy[x_columns].copy()
1613
1765
  y_sampled = enriched_Xy[TARGET].copy()
1614
1766
  enriched_X = enriched_Xy.drop(columns=[TARGET, EVAL_SET_INDEX], errors="ignore")
1767
+ enriched_X_columns = enriched_X.columns.to_list()
1615
1768
 
1616
1769
  self.logger.info(f"Shape of enriched_X: {enriched_X.shape}")
1617
1770
  self.logger.info(f"Shape of X after sampling: {X_sampled.shape}")
@@ -1626,12 +1779,22 @@ class FeaturesEnricher(TransformerMixin):
1626
1779
  for idx in range(len(eval_set)):
1627
1780
  eval_X_sampled = enriched_eval_sets[idx + 1][x_columns].copy()
1628
1781
  eval_y_sampled = enriched_eval_sets[idx + 1][TARGET].copy()
1629
- enriched_eval_X = enriched_eval_sets[idx + 1].drop(columns=[TARGET, EVAL_SET_INDEX])
1782
+ enriched_eval_X = enriched_eval_sets[idx + 1][enriched_X_columns].copy()
1630
1783
  eval_set_sampled_dict[idx] = (eval_X_sampled, enriched_eval_X, eval_y_sampled)
1631
1784
 
1632
- self.__cached_sampled_datasets = (X_sampled, y_sampled, enriched_X, eval_set_sampled_dict, search_keys)
1785
+ datasets_hash = hash_input(self.X, self.y, self.eval_set)
1786
+ self.__cached_sampled_datasets[datasets_hash] = (
1787
+ X_sampled,
1788
+ y_sampled,
1789
+ enriched_X,
1790
+ eval_set_sampled_dict,
1791
+ search_keys,
1792
+ self.fit_columns_renaming,
1793
+ )
1633
1794
 
1634
- return self.__mk_sampled_data_tuple(X_sampled, y_sampled, enriched_X, eval_set_sampled_dict, search_keys)
1795
+ return self.__mk_sampled_data_tuple(
1796
+ X_sampled, y_sampled, enriched_X, eval_set_sampled_dict, search_keys, self.fit_columns_renaming
1797
+ )
1635
1798
 
1636
1799
  def __sample_imbalanced(
1637
1800
  self,
@@ -1644,7 +1807,7 @@ class FeaturesEnricher(TransformerMixin):
1644
1807
  progress_bar: Optional[ProgressBar],
1645
1808
  progress_callback: Optional[Callable[[SearchProgress], Any]],
1646
1809
  ) -> _SampledDataForMetrics:
1647
- eval_set_sampled_dict = dict()
1810
+ eval_set_sampled_dict = {}
1648
1811
  if eval_set is not None:
1649
1812
  self.logger.info("Transform with eval_set")
1650
1813
  # concatenate X and eval_set with eval_set_index
@@ -1658,7 +1821,7 @@ class FeaturesEnricher(TransformerMixin):
1658
1821
  eval_df_with_index[EVAL_SET_INDEX] = idx + 1
1659
1822
  df = pd.concat([df, eval_df_with_index])
1660
1823
 
1661
- df = clean_full_duplicates(df, logger=self.logger, silent=True, bundle=self.bundle)
1824
+ df, _ = clean_full_duplicates(df, logger=self.logger, bundle=self.bundle)
1662
1825
 
1663
1826
  # downsample if need to eval_set threshold
1664
1827
  num_samples = _num_samples(df)
@@ -1666,12 +1829,12 @@ class FeaturesEnricher(TransformerMixin):
1666
1829
  self.logger.info(f"Downsampling from {num_samples} to {Dataset.FIT_SAMPLE_WITH_EVAL_SET_ROWS}")
1667
1830
  df = df.sample(n=Dataset.FIT_SAMPLE_WITH_EVAL_SET_ROWS, random_state=self.random_state)
1668
1831
 
1669
- eval_set_sampled_dict = dict()
1832
+ eval_set_sampled_dict = {}
1670
1833
 
1671
1834
  tmp_target_name = "__target"
1672
1835
  df = df.rename(columns={TARGET: tmp_target_name})
1673
1836
 
1674
- enriched_df = self.__inner_transform(
1837
+ enriched_df, columns_renaming, generated_features = self.__inner_transform(
1675
1838
  trace_id,
1676
1839
  df,
1677
1840
  exclude_features_sources=exclude_features_sources,
@@ -1688,7 +1851,7 @@ class FeaturesEnricher(TransformerMixin):
1688
1851
 
1689
1852
  x_columns = [
1690
1853
  c
1691
- for c in (validated_X.columns.tolist() + self.fit_generated_features + [SYSTEM_RECORD_ID])
1854
+ for c in (validated_X.columns.tolist() + generated_features + [SYSTEM_RECORD_ID])
1692
1855
  if c in enriched_df.columns
1693
1856
  ]
1694
1857
 
@@ -1696,12 +1859,13 @@ class FeaturesEnricher(TransformerMixin):
1696
1859
  X_sampled = enriched_Xy[x_columns].copy()
1697
1860
  y_sampled = enriched_Xy[TARGET].copy()
1698
1861
  enriched_X = enriched_Xy.drop(columns=[TARGET, EVAL_SET_INDEX])
1862
+ enriched_X_columns = enriched_X.columns.tolist()
1699
1863
 
1700
1864
  for idx in range(len(eval_set)):
1701
1865
  enriched_eval_xy = enriched_df.query(f"{EVAL_SET_INDEX} == {idx + 1}")
1702
1866
  eval_x_sampled = enriched_eval_xy[x_columns].copy()
1703
1867
  eval_y_sampled = enriched_eval_xy[TARGET].copy()
1704
- enriched_eval_x = enriched_eval_xy.drop(columns=[TARGET, EVAL_SET_INDEX])
1868
+ enriched_eval_x = enriched_eval_xy[enriched_X_columns].copy()
1705
1869
  eval_set_sampled_dict[idx] = (eval_x_sampled, enriched_eval_x, eval_y_sampled)
1706
1870
  else:
1707
1871
  self.logger.info("Transform without eval_set")
@@ -1709,7 +1873,7 @@ class FeaturesEnricher(TransformerMixin):
1709
1873
 
1710
1874
  df[TARGET] = validated_y
1711
1875
 
1712
- df = clean_full_duplicates(df, logger=self.logger, silent=True, bundle=self.bundle)
1876
+ df, _ = clean_full_duplicates(df, logger=self.logger, bundle=self.bundle)
1713
1877
 
1714
1878
  num_samples = _num_samples(df)
1715
1879
  if num_samples > Dataset.FIT_SAMPLE_THRESHOLD:
@@ -1719,7 +1883,7 @@ class FeaturesEnricher(TransformerMixin):
1719
1883
  tmp_target_name = "__target"
1720
1884
  df = df.rename(columns={TARGET: tmp_target_name})
1721
1885
 
1722
- enriched_Xy = self.__inner_transform(
1886
+ enriched_Xy, columns_renaming, generated_features = self.__inner_transform(
1723
1887
  trace_id,
1724
1888
  df,
1725
1889
  exclude_features_sources=exclude_features_sources,
@@ -1736,7 +1900,7 @@ class FeaturesEnricher(TransformerMixin):
1736
1900
 
1737
1901
  x_columns = [
1738
1902
  c
1739
- for c in (validated_X.columns.tolist() + self.fit_generated_features + [SYSTEM_RECORD_ID])
1903
+ for c in (validated_X.columns.tolist() + generated_features + [SYSTEM_RECORD_ID])
1740
1904
  if c in enriched_Xy.columns
1741
1905
  ]
1742
1906
 
@@ -1744,9 +1908,19 @@ class FeaturesEnricher(TransformerMixin):
1744
1908
  y_sampled = enriched_Xy[TARGET].copy()
1745
1909
  enriched_X = enriched_Xy.drop(columns=TARGET)
1746
1910
 
1747
- self.__cached_sampled_datasets = (X_sampled, y_sampled, enriched_X, eval_set_sampled_dict, self.search_keys)
1911
+ datasets_hash = hash_input(validated_X, validated_y, eval_set)
1912
+ self.__cached_sampled_datasets[datasets_hash] = (
1913
+ X_sampled,
1914
+ y_sampled,
1915
+ enriched_X,
1916
+ eval_set_sampled_dict,
1917
+ self.search_keys,
1918
+ columns_renaming,
1919
+ )
1748
1920
 
1749
- return self.__mk_sampled_data_tuple(X_sampled, y_sampled, enriched_X, eval_set_sampled_dict, self.search_keys)
1921
+ return self.__mk_sampled_data_tuple(
1922
+ X_sampled, y_sampled, enriched_X, eval_set_sampled_dict, self.search_keys, columns_renaming
1923
+ )
1750
1924
 
1751
1925
  def __mk_sampled_data_tuple(
1752
1926
  self,
@@ -1755,6 +1929,7 @@ class FeaturesEnricher(TransformerMixin):
1755
1929
  enriched_X: pd.DataFrame,
1756
1930
  eval_set_sampled_dict: Dict,
1757
1931
  search_keys: Dict,
1932
+ columns_renaming: Dict[str, str],
1758
1933
  ):
1759
1934
  search_keys = {k: v for k, v in search_keys.items() if k in X_sampled.columns.to_list()}
1760
1935
  return FeaturesEnricher._SampledDataForMetrics(
@@ -1763,6 +1938,7 @@ class FeaturesEnricher(TransformerMixin):
1763
1938
  enriched_X=enriched_X,
1764
1939
  eval_set_sampled_dict=eval_set_sampled_dict,
1765
1940
  search_keys=search_keys,
1941
+ columns_renaming=columns_renaming,
1766
1942
  )
1767
1943
 
1768
1944
  def get_search_id(self) -> Optional[str]:
@@ -1812,9 +1988,19 @@ class FeaturesEnricher(TransformerMixin):
1812
1988
  file_metadata = self._search_task.get_file_metadata(str(uuid.uuid4()))
1813
1989
  search_keys = file_metadata.search_types()
1814
1990
  if SearchKey.IPV6_ADDRESS in search_keys:
1815
- search_keys.remove(SearchKey.IPV6_ADDRESS)
1991
+ # search_keys.remove(SearchKey.IPV6_ADDRESS)
1992
+ search_keys.pop(SearchKey.IPV6_ADDRESS, None)
1816
1993
 
1817
- keys = "{" + ", ".join([f'"{key.name}": "{key_example(key)}"' for key in search_keys]) + "}"
1994
+ keys = (
1995
+ "{"
1996
+ + ", ".join(
1997
+ [
1998
+ f'"{key.name}": {{"name": "{name}", "value": "{key_example(key)}"}}'
1999
+ for key, name in search_keys.items()
2000
+ ]
2001
+ )
2002
+ + "}"
2003
+ )
1818
2004
  features_for_transform = self._search_task.get_features_for_transform()
1819
2005
  if features_for_transform:
1820
2006
  original_features_for_transform = [
@@ -1851,37 +2037,41 @@ class FeaturesEnricher(TransformerMixin):
1851
2037
  progress_bar: Optional[ProgressBar] = None,
1852
2038
  progress_callback: Optional[Callable[[SearchProgress], Any]] = None,
1853
2039
  add_fit_system_record_id: bool = False,
1854
- ) -> pd.DataFrame:
2040
+ ) -> Tuple[pd.DataFrame, Dict[str, str], List[str]]:
1855
2041
  if self._search_task is None:
1856
2042
  raise NotFittedError(self.bundle.get("transform_unfitted_enricher"))
1857
2043
 
1858
2044
  start_time = time.time()
1859
2045
  with MDC(trace_id=trace_id):
1860
2046
  self.logger.info("Start transform")
1861
- self.__log_debug_information(X, exclude_features_sources=exclude_features_sources)
2047
+
2048
+ validated_X = self._validate_X(X, is_transform=True)
2049
+
2050
+ self.__log_debug_information(validated_X, exclude_features_sources=exclude_features_sources)
1862
2051
 
1863
2052
  self.__validate_search_keys(self.search_keys, self.search_id)
1864
2053
 
1865
2054
  if len(self.feature_names_) == 0:
1866
2055
  self.logger.warning(self.bundle.get("no_important_features_for_transform"))
1867
- return X
2056
+ return X, {c: c for c in X.columns}, []
1868
2057
 
1869
2058
  if self._has_paid_features(exclude_features_sources):
1870
2059
  msg = self.bundle.get("transform_with_paid_features")
1871
2060
  self.logger.warning(msg)
1872
2061
  self.__display_support_link(msg)
1873
- return None
2062
+ return None, {c: c for c in X.columns}, []
1874
2063
 
1875
2064
  if not metrics_calculation:
1876
2065
  transform_usage = self.rest_client.get_current_transform_usage(trace_id)
1877
2066
  self.logger.info(f"Current transform usage: {transform_usage}. Transforming {len(X)} rows")
1878
2067
  if transform_usage.has_limit:
1879
2068
  if len(X) > transform_usage.rest_rows:
1880
- msg = self.bundle.get("transform_usage_warning").format(len(X), transform_usage.rest_rows)
2069
+ rest_rows = max(transform_usage.rest_rows, 0)
2070
+ msg = self.bundle.get("transform_usage_warning").format(len(X), rest_rows)
1881
2071
  self.logger.warning(msg)
1882
2072
  print(msg)
1883
2073
  show_request_quote_button()
1884
- return None
2074
+ return None, {c: c for c in X.columns}, []
1885
2075
  else:
1886
2076
  msg = self.bundle.get("transform_usage_info").format(
1887
2077
  transform_usage.limit, transform_usage.transformed_rows
@@ -1889,11 +2079,11 @@ class FeaturesEnricher(TransformerMixin):
1889
2079
  self.logger.info(msg)
1890
2080
  print(msg)
1891
2081
 
1892
- validated_X = self._validate_X(X, is_transform=True)
1893
-
1894
2082
  is_demo_dataset = hash_input(validated_X) in DEMO_DATASET_HASHES
1895
2083
 
1896
- columns_to_drop = [c for c in validated_X.columns if c in self.feature_names_]
2084
+ columns_to_drop = [
2085
+ c for c in validated_X.columns if c in self.feature_names_ and c in self.dropped_client_feature_names_
2086
+ ]
1897
2087
  if len(columns_to_drop) > 0:
1898
2088
  msg = self.bundle.get("x_contains_enriching_columns").format(columns_to_drop)
1899
2089
  self.logger.warning(msg)
@@ -1919,79 +2109,135 @@ class FeaturesEnricher(TransformerMixin):
1919
2109
  df = self.__add_country_code(df, search_keys)
1920
2110
 
1921
2111
  generated_features = []
1922
- date_column = self._get_date_column(search_keys)
2112
+ date_column = SearchKey.find_key(search_keys, [SearchKey.DATE, SearchKey.DATETIME])
1923
2113
  if date_column is not None:
1924
2114
  converter = DateTimeSearchKeyConverter(date_column, self.date_format, self.logger, bundle=self.bundle)
1925
- df = converter.convert(df)
2115
+ df = converter.convert(df, keep_time=True)
1926
2116
  self.logger.info(f"Date column after convertion: {df[date_column]}")
1927
2117
  generated_features.extend(converter.generated_features)
1928
2118
  else:
1929
2119
  self.logger.info("Input dataset hasn't date column")
1930
2120
  if self.add_date_if_missing:
1931
2121
  df = self._add_current_date_as_key(df, search_keys, self.logger, self.bundle)
2122
+
2123
+ email_columns = SearchKey.find_all_keys(search_keys, SearchKey.EMAIL)
2124
+ if email_columns:
2125
+ generator = EmailDomainGenerator(email_columns)
2126
+ df = generator.generate(df)
2127
+ generated_features.extend(generator.generated_features)
2128
+
2129
+ normalizer = Normalizer(self.bundle, self.logger)
2130
+ df, search_keys, generated_features = normalizer.normalize(df, search_keys, generated_features)
2131
+ columns_renaming = normalizer.columns_renaming
2132
+
2133
+ # Don't pass all features in backend on transform
2134
+ runtime_parameters = self._get_copy_of_runtime_parameters()
2135
+ features_for_transform = self._search_task.get_features_for_transform() or []
2136
+ if len(features_for_transform) > 0:
2137
+ missing_features_for_transform = [
2138
+ columns_renaming.get(f) for f in features_for_transform if f not in df.columns
2139
+ ]
2140
+ if len(missing_features_for_transform) > 0:
2141
+ raise ValidationError(
2142
+ self.bundle.get("missing_features_for_transform").format(missing_features_for_transform)
2143
+ )
2144
+ runtime_parameters.properties["features_for_embeddings"] = ",".join(features_for_transform)
2145
+
2146
+ columns_for_system_record_id = sorted(list(search_keys.keys()) + features_for_transform)
2147
+
2148
+ df[ENTITY_SYSTEM_RECORD_ID] = pd.util.hash_pandas_object(
2149
+ df[columns_for_system_record_id], index=False
2150
+ ).astype("float64")
2151
+
2152
+ # Explode multiple search keys
2153
+ df, unnest_search_keys = self._explode_multiple_search_keys(df, search_keys, columns_renaming)
2154
+
1932
2155
  email_column = self._get_email_column(search_keys)
1933
2156
  hem_column = self._get_hem_column(search_keys)
1934
- email_converted_to_hem = False
1935
2157
  if email_column:
1936
- converter = EmailSearchKeyConverter(email_column, hem_column, search_keys, self.logger)
2158
+ converter = EmailSearchKeyConverter(
2159
+ email_column,
2160
+ hem_column,
2161
+ search_keys,
2162
+ columns_renaming,
2163
+ list(unnest_search_keys.keys()),
2164
+ self.logger,
2165
+ )
1937
2166
  df = converter.convert(df)
1938
- generated_features.extend(converter.generated_features)
1939
- email_converted_to_hem = converter.email_converted_to_hem
1940
- if (
1941
- self.detect_missing_search_keys
1942
- and list(search_keys.values()) == [SearchKey.DATE]
1943
- and self.country_code is None
1944
- ):
1945
- converter = IpToCountrySearchKeyConverter(search_keys, self.logger)
2167
+
2168
+ ip_column = self._get_ip_column(search_keys)
2169
+ if ip_column:
2170
+ converter = IpSearchKeyConverter(
2171
+ ip_column,
2172
+ search_keys,
2173
+ columns_renaming,
2174
+ list(unnest_search_keys.keys()),
2175
+ self.bundle,
2176
+ self.logger,
2177
+ )
1946
2178
  df = converter.convert(df)
1947
- generated_features = [f for f in generated_features if f in self.fit_generated_features]
1948
2179
 
1949
- meaning_types = {col: key.value for col, key in search_keys.items()}
1950
- non_keys_columns = [column for column in df.columns if column not in search_keys.keys()]
2180
+ phone_column = self._get_phone_column(search_keys)
2181
+ country_column = self._get_country_column(search_keys)
2182
+ if phone_column:
2183
+ converter = PhoneSearchKeyConverter(phone_column, country_column)
2184
+ df = converter.convert(df)
2185
+
2186
+ if country_column:
2187
+ converter = CountrySearchKeyConverter(country_column)
2188
+ df = converter.convert(df)
1951
2189
 
1952
- if email_converted_to_hem:
1953
- non_keys_columns.append(email_column)
2190
+ postal_code = self._get_postal_column(search_keys)
2191
+ if postal_code:
2192
+ converter = PostalCodeSearchKeyConverter(postal_code)
2193
+ df = converter.convert(df)
1954
2194
 
1955
- # Don't pass features in backend on transform
1956
- original_features_for_transform = None
1957
- runtime_parameters = self._get_copy_of_runtime_parameters()
1958
- if len(non_keys_columns) > 0:
1959
- # Pass only features that need for transform
1960
- features_for_transform = self._search_task.get_features_for_transform()
1961
- if features_for_transform is not None and len(features_for_transform) > 0:
1962
- file_metadata = self._search_task.get_file_metadata(trace_id)
1963
- original_features_for_transform = [
1964
- c.originalName or c.name for c in file_metadata.columns if c.name in features_for_transform
1965
- ]
1966
- non_keys_columns = [c for c in non_keys_columns if c not in original_features_for_transform]
2195
+ # generated_features = [f for f in generated_features if f in self.fit_generated_features]
1967
2196
 
1968
- runtime_parameters.properties["features_for_embeddings"] = ",".join(features_for_transform)
2197
+ meaning_types = {col: key.value for col, key in search_keys.items()}
2198
+ for col in features_for_transform:
2199
+ meaning_types[col] = FileColumnMeaningType.FEATURE
2200
+ features_not_to_pass = [
2201
+ c
2202
+ for c in df.columns
2203
+ if c not in search_keys.keys()
2204
+ and c not in features_for_transform
2205
+ and c not in [ENTITY_SYSTEM_RECORD_ID, SEARCH_KEY_UNNEST]
2206
+ ]
1969
2207
 
1970
2208
  if add_fit_system_record_id:
1971
- df = self.__add_fit_system_record_id(df, dict(), search_keys)
2209
+ df = self.__add_fit_system_record_id(df, search_keys, SYSTEM_RECORD_ID)
1972
2210
  df = df.rename(columns={SYSTEM_RECORD_ID: SORT_ID})
1973
- non_keys_columns.append(SORT_ID)
2211
+ features_not_to_pass.append(SORT_ID)
1974
2212
 
1975
- columns_for_system_record_id = sorted(list(search_keys.keys()) + (original_features_for_transform or []))
2213
+ if DateTimeSearchKeyConverter.DATETIME_COL in df.columns:
2214
+ df = df.drop(columns=DateTimeSearchKeyConverter.DATETIME_COL)
1976
2215
 
2216
+ # search keys might be changed after explode
2217
+ columns_for_system_record_id = sorted(list(search_keys.keys()) + features_for_transform)
1977
2218
  df[SYSTEM_RECORD_ID] = pd.util.hash_pandas_object(df[columns_for_system_record_id], index=False).astype(
1978
- "Float64"
2219
+ "float64"
1979
2220
  )
1980
2221
  meaning_types[SYSTEM_RECORD_ID] = FileColumnMeaningType.SYSTEM_RECORD_ID
2222
+ meaning_types[ENTITY_SYSTEM_RECORD_ID] = FileColumnMeaningType.ENTITY_SYSTEM_RECORD_ID
2223
+ if SEARCH_KEY_UNNEST in df.columns:
2224
+ meaning_types[SEARCH_KEY_UNNEST] = FileColumnMeaningType.UNNEST_KEY
1981
2225
 
1982
2226
  df = df.reset_index(drop=True)
1983
- system_columns_with_original_index = [SYSTEM_RECORD_ID] + generated_features
2227
+ system_columns_with_original_index = [SYSTEM_RECORD_ID, ENTITY_SYSTEM_RECORD_ID] + generated_features
1984
2228
  if add_fit_system_record_id:
1985
2229
  system_columns_with_original_index.append(SORT_ID)
1986
2230
  df_with_original_index = df[system_columns_with_original_index].copy()
1987
2231
 
1988
2232
  combined_search_keys = combine_search_keys(search_keys.keys())
1989
2233
 
1990
- df_without_features = df.drop(columns=non_keys_columns)
2234
+ df_without_features = df.drop(columns=features_not_to_pass, errors="ignore")
1991
2235
 
1992
- df_without_features = clean_full_duplicates(
1993
- df_without_features, self.logger, silent=silent_mode, bundle=self.bundle
2236
+ df_without_features, full_duplicates_warning = clean_full_duplicates(
2237
+ df_without_features, self.logger, bundle=self.bundle
1994
2238
  )
2239
+ if not silent_mode and full_duplicates_warning:
2240
+ self.__log_warning(full_duplicates_warning)
1995
2241
 
1996
2242
  del df
1997
2243
  gc.collect()
@@ -1999,14 +2245,14 @@ class FeaturesEnricher(TransformerMixin):
1999
2245
  dataset = Dataset(
2000
2246
  "sample_" + str(uuid.uuid4()),
2001
2247
  df=df_without_features,
2248
+ meaning_types=meaning_types,
2249
+ search_keys=combined_search_keys,
2250
+ unnest_search_keys=unnest_search_keys,
2002
2251
  date_format=self.date_format,
2003
2252
  rest_client=self.rest_client,
2004
2253
  logger=self.logger,
2005
2254
  )
2006
- dataset.meaning_types = meaning_types
2007
- dataset.search_keys = combined_search_keys
2008
- if email_converted_to_hem:
2009
- dataset.ignore_columns = [email_column]
2255
+ dataset.columns_renaming = columns_renaming
2010
2256
 
2011
2257
  if max_features is not None or importance_threshold is not None:
2012
2258
  exclude_features_sources = list(
@@ -2094,9 +2340,15 @@ class FeaturesEnricher(TransformerMixin):
2094
2340
  else:
2095
2341
  result = enrich()
2096
2342
 
2343
+ selecting_columns = [
2344
+ c
2345
+ for c in itertools.chain(validated_X.columns.tolist(), generated_features)
2346
+ if c not in self.dropped_client_feature_names_
2347
+ ]
2097
2348
  filtered_columns = self.__filtered_enriched_features(importance_threshold, max_features)
2098
- existing_filtered_columns = [c for c in filtered_columns if c in result.columns]
2099
- selecting_columns = validated_X.columns.tolist() + generated_features + existing_filtered_columns
2349
+ selecting_columns.extend(
2350
+ c for c in filtered_columns if c in result.columns and c not in validated_X.columns
2351
+ )
2100
2352
  if add_fit_system_record_id:
2101
2353
  selecting_columns.append(SORT_ID)
2102
2354
 
@@ -2108,7 +2360,7 @@ class FeaturesEnricher(TransformerMixin):
2108
2360
  if add_fit_system_record_id:
2109
2361
  result = result.rename(columns={SORT_ID: SYSTEM_RECORD_ID})
2110
2362
 
2111
- return result
2363
+ return result, columns_renaming, generated_features
2112
2364
 
2113
2365
  def _get_excluded_features(self, max_features: Optional[int], importance_threshold: Optional[float]) -> List[str]:
2114
2366
  features_info = self._internal_features_info
@@ -2132,7 +2384,7 @@ class FeaturesEnricher(TransformerMixin):
2132
2384
  ]
2133
2385
  return excluded_features[feature_name_header].values.tolist()
2134
2386
 
2135
- def __validate_search_keys(self, search_keys: Dict[str, SearchKey], search_id: Optional[str]):
2387
+ def __validate_search_keys(self, search_keys: Dict[str, SearchKey], search_id: Optional[str] = None):
2136
2388
  if (search_keys is None or len(search_keys) == 0) and self.country_code is None:
2137
2389
  if search_id:
2138
2390
  self.logger.debug(f"search_id {search_id} provided without search_keys")
@@ -2143,6 +2395,14 @@ class FeaturesEnricher(TransformerMixin):
2143
2395
 
2144
2396
  key_types = search_keys.values()
2145
2397
 
2398
+ # Multiple search keys allowed only for PHONE, IP, POSTAL_CODE, EMAIL, HEM
2399
+ multi_keys = [key for key, count in Counter(key_types).items() if count > 1]
2400
+ for multi_key in multi_keys:
2401
+ if multi_key not in [SearchKey.PHONE, SearchKey.IP, SearchKey.POSTAL_CODE, SearchKey.EMAIL, SearchKey.HEM]:
2402
+ msg = self.bundle.get("unsupported_multi_key").format(multi_key)
2403
+ self.logger.warning(msg)
2404
+ raise ValidationError(msg)
2405
+
2146
2406
  if SearchKey.DATE in key_types and SearchKey.DATETIME in key_types:
2147
2407
  msg = self.bundle.get("date_and_datetime_simultanious")
2148
2408
  self.logger.warning(msg)
@@ -2158,11 +2418,11 @@ class FeaturesEnricher(TransformerMixin):
2158
2418
  self.logger.warning(msg)
2159
2419
  raise ValidationError(msg)
2160
2420
 
2161
- for key_type in SearchKey.__members__.values():
2162
- if key_type != SearchKey.CUSTOM_KEY and list(key_types).count(key_type) > 1:
2163
- msg = self.bundle.get("multiple_search_key").format(key_type)
2164
- self.logger.warning(msg)
2165
- raise ValidationError(msg)
2421
+ # for key_type in SearchKey.__members__.values():
2422
+ # if key_type != SearchKey.CUSTOM_KEY and list(key_types).count(key_type) > 1:
2423
+ # msg = self.bundle.get("multiple_search_key").format(key_type)
2424
+ # self.logger.warning(msg)
2425
+ # raise ValidationError(msg)
2166
2426
 
2167
2427
  # non_personal_keys = set(SearchKey.__members__.values()) - set(SearchKey.personal_keys())
2168
2428
  # if (
@@ -2178,6 +2438,15 @@ class FeaturesEnricher(TransformerMixin):
2178
2438
  def __is_registered(self) -> bool:
2179
2439
  return self.api_key is not None and self.api_key != ""
2180
2440
 
2441
+ def __log_warning(self, message: str, show_support_link: bool = False):
2442
+ warning_num = self.warning_counter.increment()
2443
+ formatted_message = f"WARNING #{warning_num}: {message}\n"
2444
+ if show_support_link:
2445
+ self.__display_support_link(formatted_message)
2446
+ else:
2447
+ print(formatted_message)
2448
+ self.logger.warning(message)
2449
+
2181
2450
  def __inner_fit(
2182
2451
  self,
2183
2452
  trace_id: str,
@@ -2199,8 +2468,11 @@ class FeaturesEnricher(TransformerMixin):
2199
2468
  ):
2200
2469
  self.warning_counter.reset()
2201
2470
  self.df_with_original_index = None
2202
- self.__cached_sampled_datasets = None
2471
+ self.__cached_sampled_datasets = dict()
2203
2472
  self.metrics = None
2473
+ self.fit_columns_renaming = None
2474
+ self.fit_dropped_features = set()
2475
+ self.fit_generated_features = []
2204
2476
 
2205
2477
  validated_X = self._validate_X(X)
2206
2478
  validated_y = self._validate_y(validated_X, y)
@@ -2221,9 +2493,7 @@ class FeaturesEnricher(TransformerMixin):
2221
2493
  checked_generate_features = []
2222
2494
  for gen_feature in self.generate_features:
2223
2495
  if gen_feature not in x_columns:
2224
- msg = self.bundle.get("missing_generate_feature").format(gen_feature, x_columns)
2225
- print(msg)
2226
- self.logger.warning(msg)
2496
+ self.__log_warning(self.bundle.get("missing_generate_feature").format(gen_feature, x_columns))
2227
2497
  else:
2228
2498
  checked_generate_features.append(gen_feature)
2229
2499
  self.generate_features = checked_generate_features
@@ -2232,9 +2502,9 @@ class FeaturesEnricher(TransformerMixin):
2232
2502
  validate_scoring_argument(scoring)
2233
2503
 
2234
2504
  self.__log_debug_information(
2235
- X,
2236
- y,
2237
- eval_set,
2505
+ validated_X,
2506
+ validated_y,
2507
+ validated_eval_set,
2238
2508
  exclude_features_sources=exclude_features_sources,
2239
2509
  calculate_metrics=calculate_metrics,
2240
2510
  scoring=scoring,
@@ -2244,20 +2514,6 @@ class FeaturesEnricher(TransformerMixin):
2244
2514
 
2245
2515
  df = pd.concat([validated_X, validated_y], axis=1)
2246
2516
 
2247
- self.fit_search_keys = self.search_keys.copy()
2248
- self.fit_search_keys = self.__prepare_search_keys(validated_X, self.fit_search_keys, is_demo_dataset)
2249
-
2250
- validate_dates_distribution(validated_X, self.fit_search_keys, self.logger, self.bundle, self.warning_counter)
2251
-
2252
- maybe_date_column = self._get_date_column(self.fit_search_keys)
2253
- has_date = maybe_date_column is not None
2254
- model_task_type = self.model_task_type or define_task(validated_y, has_date, self.logger)
2255
- self._validate_binary_observations(validated_y, model_task_type)
2256
-
2257
- self.runtime_parameters = get_runtime_params_custom_loss(
2258
- self.loss, model_task_type, self.runtime_parameters, self.logger
2259
- )
2260
-
2261
2517
  if validated_eval_set is not None and len(validated_eval_set) > 0:
2262
2518
  df[EVAL_SET_INDEX] = 0
2263
2519
  for idx, (eval_X, eval_y) in enumerate(validated_eval_set):
@@ -2265,12 +2521,21 @@ class FeaturesEnricher(TransformerMixin):
2265
2521
  eval_df[EVAL_SET_INDEX] = idx + 1
2266
2522
  df = pd.concat([df, eval_df])
2267
2523
 
2268
- df = self.__correct_target(df)
2269
-
2524
+ self.fit_search_keys = self.search_keys.copy()
2270
2525
  df = self.__handle_index_search_keys(df, self.fit_search_keys)
2526
+ self.fit_search_keys = self.__prepare_search_keys(df, self.fit_search_keys, is_demo_dataset)
2271
2527
 
2272
- if is_numeric_dtype(df[self.TARGET_NAME]) and has_date:
2273
- self._validate_PSI(df.sort_values(by=maybe_date_column))
2528
+ maybe_date_column = SearchKey.find_key(self.fit_search_keys, [SearchKey.DATE, SearchKey.DATETIME])
2529
+ has_date = maybe_date_column is not None
2530
+ self.model_task_type = self.model_task_type or define_task(validated_y, has_date, self.logger)
2531
+
2532
+ self._validate_binary_observations(validated_y, self.model_task_type)
2533
+
2534
+ self.runtime_parameters = get_runtime_params_custom_loss(
2535
+ self.loss, self.model_task_type, self.runtime_parameters, self.logger
2536
+ )
2537
+
2538
+ df = self.__correct_target(df)
2274
2539
 
2275
2540
  if DEFAULT_INDEX in df.columns:
2276
2541
  msg = self.bundle.get("unsupported_index_column")
@@ -2281,58 +2546,132 @@ class FeaturesEnricher(TransformerMixin):
2281
2546
 
2282
2547
  df = self.__add_country_code(df, self.fit_search_keys)
2283
2548
 
2284
- df = remove_fintech_duplicates(
2285
- df, self.fit_search_keys, date_format=self.date_format, logger=self.logger, bundle=self.bundle
2286
- )
2287
- df = clean_full_duplicates(df, self.logger, bundle=self.bundle)
2288
-
2289
- date_column = self._get_date_column(self.fit_search_keys)
2290
- self.__adjust_cv(df, date_column, model_task_type)
2291
-
2292
2549
  self.fit_generated_features = []
2293
2550
 
2294
- if date_column is not None:
2295
- converter = DateTimeSearchKeyConverter(date_column, self.date_format, self.logger, bundle=self.bundle)
2551
+ if has_date:
2552
+ converter = DateTimeSearchKeyConverter(
2553
+ maybe_date_column,
2554
+ self.date_format,
2555
+ self.logger,
2556
+ bundle=self.bundle,
2557
+ )
2296
2558
  df = converter.convert(df, keep_time=True)
2297
- self.logger.info(f"Date column after convertion: {df[date_column]}")
2559
+ if converter.has_old_dates:
2560
+ self.__log_warning(self.bundle.get("dataset_drop_old_dates"))
2561
+ self.logger.info(f"Date column after convertion: {df[maybe_date_column]}")
2298
2562
  self.fit_generated_features.extend(converter.generated_features)
2299
2563
  else:
2300
2564
  self.logger.info("Input dataset hasn't date column")
2301
2565
  if self.add_date_if_missing:
2302
2566
  df = self._add_current_date_as_key(df, self.fit_search_keys, self.logger, self.bundle)
2567
+
2568
+ email_columns = SearchKey.find_all_keys(self.fit_search_keys, SearchKey.EMAIL)
2569
+ if email_columns:
2570
+ generator = EmailDomainGenerator(email_columns)
2571
+ df = generator.generate(df)
2572
+ self.fit_generated_features.extend(generator.generated_features)
2573
+
2574
+ # Checks that need validated date
2575
+ try:
2576
+ if not is_dates_distribution_valid(df, self.fit_search_keys):
2577
+ self.__log_warning(bundle.get("x_unstable_by_date"))
2578
+ except Exception:
2579
+ self.logger.exception("Failed to check dates distribution validity")
2580
+
2581
+ if (
2582
+ is_numeric_dtype(df[self.TARGET_NAME])
2583
+ and self.model_task_type in [ModelTaskType.BINARY, ModelTaskType.MULTICLASS]
2584
+ and has_date
2585
+ ):
2586
+ self._validate_PSI(df.sort_values(by=maybe_date_column))
2587
+
2588
+ normalizer = Normalizer(self.bundle, self.logger)
2589
+ df, self.fit_search_keys, self.fit_generated_features = normalizer.normalize(
2590
+ df, self.fit_search_keys, self.fit_generated_features
2591
+ )
2592
+ self.fit_columns_renaming = normalizer.columns_renaming
2593
+ if normalizer.removed_features:
2594
+ self.__log_warning(self.bundle.get("dataset_date_features").format(normalizer.removed_features))
2595
+
2596
+ self.__adjust_cv(df)
2597
+
2598
+ df, fintech_warnings = remove_fintech_duplicates(
2599
+ df, self.fit_search_keys, date_format=self.date_format, logger=self.logger, bundle=self.bundle
2600
+ )
2601
+ if fintech_warnings:
2602
+ for fintech_warning in fintech_warnings:
2603
+ self.__log_warning(fintech_warning)
2604
+ df, full_duplicates_warning = clean_full_duplicates(df, self.logger, bundle=self.bundle)
2605
+ if full_duplicates_warning:
2606
+ self.__log_warning(full_duplicates_warning)
2607
+
2608
+ # Explode multiple search keys
2609
+ df = self.__add_fit_system_record_id(df, self.fit_search_keys, ENTITY_SYSTEM_RECORD_ID)
2610
+
2611
+ # TODO check that this is correct for enrichment
2612
+ self.df_with_original_index = df.copy()
2613
+ # TODO check maybe need to drop _time column from df_with_original_index
2614
+
2615
+ df, unnest_search_keys = self._explode_multiple_search_keys(df, self.fit_search_keys, self.fit_columns_renaming)
2616
+
2617
+ # Convert EMAIL to HEM after unnesting to do it only with one column
2303
2618
  email_column = self._get_email_column(self.fit_search_keys)
2304
2619
  hem_column = self._get_hem_column(self.fit_search_keys)
2305
- email_converted_to_hem = False
2306
2620
  if email_column:
2307
- converter = EmailSearchKeyConverter(email_column, hem_column, self.fit_search_keys, self.logger)
2621
+ converter = EmailSearchKeyConverter(
2622
+ email_column,
2623
+ hem_column,
2624
+ self.fit_search_keys,
2625
+ self.fit_columns_renaming,
2626
+ list(unnest_search_keys.keys()),
2627
+ self.logger,
2628
+ )
2308
2629
  df = converter.convert(df)
2309
- self.fit_generated_features.extend(converter.generated_features)
2310
- email_converted_to_hem = converter.email_converted_to_hem
2311
- if (
2312
- self.detect_missing_search_keys
2313
- and list(self.fit_search_keys.values()) == [SearchKey.DATE]
2314
- and self.country_code is None
2315
- ):
2316
- converter = IpToCountrySearchKeyConverter(self.fit_search_keys, self.logger)
2630
+
2631
+ ip_column = self._get_ip_column(self.fit_search_keys)
2632
+ if ip_column:
2633
+ converter = IpSearchKeyConverter(
2634
+ ip_column,
2635
+ self.fit_search_keys,
2636
+ self.fit_columns_renaming,
2637
+ list(unnest_search_keys.keys()),
2638
+ self.bundle,
2639
+ self.logger,
2640
+ )
2641
+ df = converter.convert(df)
2642
+
2643
+ phone_column = self._get_phone_column(self.fit_search_keys)
2644
+ country_column = self._get_country_column(self.fit_search_keys)
2645
+ if phone_column:
2646
+ converter = PhoneSearchKeyConverter(phone_column, country_column)
2647
+ df = converter.convert(df)
2648
+
2649
+ if country_column:
2650
+ converter = CountrySearchKeyConverter(country_column)
2317
2651
  df = converter.convert(df)
2318
2652
 
2319
- non_feature_columns = [self.TARGET_NAME, EVAL_SET_INDEX] + list(self.fit_search_keys.keys())
2320
- if email_converted_to_hem:
2321
- non_feature_columns.append(email_column)
2653
+ postal_code = self._get_postal_column(self.fit_search_keys)
2654
+ if postal_code:
2655
+ converter = PostalCodeSearchKeyConverter(postal_code)
2656
+ df = converter.convert(df)
2657
+
2658
+ non_feature_columns = [self.TARGET_NAME, EVAL_SET_INDEX, ENTITY_SYSTEM_RECORD_ID, SEARCH_KEY_UNNEST] + list(
2659
+ self.fit_search_keys.keys()
2660
+ )
2322
2661
  if DateTimeSearchKeyConverter.DATETIME_COL in df.columns:
2323
2662
  non_feature_columns.append(DateTimeSearchKeyConverter.DATETIME_COL)
2324
2663
 
2325
2664
  features_columns = [c for c in df.columns if c not in non_feature_columns]
2326
2665
 
2327
- features_to_drop = FeaturesValidator(self.logger).validate(
2328
- df, features_columns, self.generate_features, self.warning_counter
2666
+ features_to_drop, feature_validator_warnings = FeaturesValidator(self.logger).validate(
2667
+ df, features_columns, self.generate_features, self.fit_columns_renaming
2329
2668
  )
2669
+ if feature_validator_warnings:
2670
+ for warning in feature_validator_warnings:
2671
+ self.__log_warning(warning)
2330
2672
  self.fit_dropped_features.update(features_to_drop)
2331
2673
  df = df.drop(columns=features_to_drop)
2332
2674
 
2333
- if email_converted_to_hem:
2334
- self.fit_dropped_features.add(email_column)
2335
-
2336
2675
  self.fit_generated_features = [f for f in self.fit_generated_features if f not in self.fit_dropped_features]
2337
2676
 
2338
2677
  meaning_types = {
@@ -2340,12 +2679,19 @@ class FeaturesEnricher(TransformerMixin):
2340
2679
  **{str(c): FileColumnMeaningType.FEATURE for c in df.columns if c not in non_feature_columns},
2341
2680
  }
2342
2681
  meaning_types[self.TARGET_NAME] = FileColumnMeaningType.TARGET
2682
+ meaning_types[ENTITY_SYSTEM_RECORD_ID] = FileColumnMeaningType.ENTITY_SYSTEM_RECORD_ID
2683
+ if SEARCH_KEY_UNNEST in df.columns:
2684
+ meaning_types[SEARCH_KEY_UNNEST] = FileColumnMeaningType.UNNEST_KEY
2343
2685
  if eval_set is not None and len(eval_set) > 0:
2344
2686
  meaning_types[EVAL_SET_INDEX] = FileColumnMeaningType.EVAL_SET_INDEX
2345
2687
 
2346
- df = self.__add_fit_system_record_id(df, meaning_types, self.fit_search_keys)
2688
+ df = self.__add_fit_system_record_id(df, self.fit_search_keys, SYSTEM_RECORD_ID)
2689
+
2690
+ if DateTimeSearchKeyConverter.DATETIME_COL in df.columns:
2691
+ df = df.drop(columns=DateTimeSearchKeyConverter.DATETIME_COL)
2692
+
2693
+ meaning_types[SYSTEM_RECORD_ID] = FileColumnMeaningType.SYSTEM_RECORD_ID
2347
2694
 
2348
- self.df_with_original_index = df.copy()
2349
2695
  df = df.reset_index(drop=True).sort_values(by=SYSTEM_RECORD_ID).reset_index(drop=True)
2350
2696
 
2351
2697
  combined_search_keys = combine_search_keys(self.fit_search_keys.keys())
@@ -2353,16 +2699,16 @@ class FeaturesEnricher(TransformerMixin):
2353
2699
  dataset = Dataset(
2354
2700
  "tds_" + str(uuid.uuid4()),
2355
2701
  df=df,
2356
- model_task_type=model_task_type,
2702
+ meaning_types=meaning_types,
2703
+ search_keys=combined_search_keys,
2704
+ unnest_search_keys=unnest_search_keys,
2705
+ model_task_type=self.model_task_type,
2357
2706
  date_format=self.date_format,
2358
2707
  random_state=self.random_state,
2359
2708
  rest_client=self.rest_client,
2360
2709
  logger=self.logger,
2361
2710
  )
2362
- dataset.meaning_types = meaning_types
2363
- dataset.search_keys = combined_search_keys
2364
- if email_converted_to_hem:
2365
- dataset.ignore_columns = [email_column]
2711
+ dataset.columns_renaming = self.fit_columns_renaming
2366
2712
 
2367
2713
  self.passed_features = [
2368
2714
  column for column, meaning_type in meaning_types.items() if meaning_type == FileColumnMeaningType.FEATURE
@@ -2438,9 +2784,7 @@ class FeaturesEnricher(TransformerMixin):
2438
2784
  zero_hit_columns = self.get_columns_by_search_keys(zero_hit_search_keys)
2439
2785
  if zero_hit_columns:
2440
2786
  msg = self.bundle.get("features_info_zero_hit_rate_search_keys").format(zero_hit_columns)
2441
- self.logger.warning(msg)
2442
- self.__display_support_link(msg)
2443
- self.warning_counter.increment()
2787
+ self.__log_warning(msg, show_support_link=True)
2444
2788
 
2445
2789
  if (
2446
2790
  self._search_task.unused_features_for_generation is not None
@@ -2450,9 +2794,7 @@ class FeaturesEnricher(TransformerMixin):
2450
2794
  dataset.columns_renaming.get(col) or col for col in self._search_task.unused_features_for_generation
2451
2795
  ]
2452
2796
  msg = self.bundle.get("features_not_generated").format(unused_features_for_generation)
2453
- self.logger.warning(msg)
2454
- print(msg)
2455
- self.warning_counter.increment()
2797
+ self.__log_warning(msg)
2456
2798
 
2457
2799
  self.__prepare_feature_importances(trace_id, validated_X.columns.to_list() + self.fit_generated_features)
2458
2800
 
@@ -2460,7 +2802,13 @@ class FeaturesEnricher(TransformerMixin):
2460
2802
 
2461
2803
  autofe_description = self.get_autofe_features_description()
2462
2804
  if autofe_description is not None:
2463
- display_html_dataframe(autofe_description, autofe_description, "*Description of AutoFE feature names")
2805
+ self.logger.info(f"AutoFE descriptions: {autofe_description}")
2806
+ self.autofe_features_display_handle = display_html_dataframe(
2807
+ df=autofe_description,
2808
+ internal_df=autofe_description,
2809
+ header=self.bundle.get("autofe_descriptions_header"),
2810
+ display_id="autofe_descriptions",
2811
+ )
2464
2812
 
2465
2813
  if self._has_paid_features(exclude_features_sources):
2466
2814
  if calculate_metrics is not None and calculate_metrics:
@@ -2500,32 +2848,32 @@ class FeaturesEnricher(TransformerMixin):
2500
2848
  progress_callback,
2501
2849
  )
2502
2850
  except Exception:
2503
- self.__show_report_button()
2851
+ self.report_button_handle = self.__show_report_button(display_id="report_button")
2504
2852
  raise
2505
2853
 
2506
- self.__show_report_button()
2854
+ self.report_button_handle = self.__show_report_button(display_id="report_button")
2507
2855
 
2508
2856
  if not self.warning_counter.has_warnings():
2509
2857
  self.__display_support_link(self.bundle.get("all_ok_community_invite"))
2510
2858
 
2511
- def __adjust_cv(self, df: pd.DataFrame, date_column: pd.Series, model_task_type: ModelTaskType):
2859
+ def __adjust_cv(self, df: pd.DataFrame):
2860
+ date_column = SearchKey.find_key(self.fit_search_keys, [SearchKey.DATE, SearchKey.DATETIME])
2512
2861
  # Check Multivariate time series
2513
2862
  if (
2514
2863
  self.cv is None
2515
2864
  and date_column
2516
- and model_task_type == ModelTaskType.REGRESSION
2865
+ and self.model_task_type == ModelTaskType.REGRESSION
2517
2866
  and len({SearchKey.PHONE, SearchKey.EMAIL, SearchKey.HEM}.intersection(self.fit_search_keys.keys())) == 0
2518
2867
  and is_blocked_time_series(df, date_column, list(self.fit_search_keys.keys()) + [TARGET])
2519
2868
  ):
2520
2869
  msg = self.bundle.get("multivariate_timeseries_detected")
2521
2870
  self.__override_cv(CVType.blocked_time_series, msg, print_warning=False)
2522
- elif (
2523
- self.cv is None
2524
- and model_task_type != ModelTaskType.REGRESSION
2525
- and self._get_group_columns(df, self.fit_search_keys)
2526
- ):
2871
+ elif self.cv is None and self.model_task_type != ModelTaskType.REGRESSION:
2527
2872
  msg = self.bundle.get("group_k_fold_in_classification")
2528
2873
  self.__override_cv(CVType.group_k_fold, msg, print_warning=self.cv is not None)
2874
+ group_columns = self._get_group_columns(df, self.fit_search_keys)
2875
+ self.runtime_parameters.properties["cv_params.group_columns"] = ",".join(group_columns)
2876
+ self.runtime_parameters.properties["cv_params.shuffle_kfold"] = "True"
2529
2877
 
2530
2878
  def __override_cv(self, cv: CVType, msg: str, print_warning: bool = True):
2531
2879
  if print_warning:
@@ -2543,9 +2891,6 @@ class FeaturesEnricher(TransformerMixin):
2543
2891
  return [c for c, v in search_keys_with_autodetection.items() if v.value.value in keys]
2544
2892
 
2545
2893
  def _validate_X(self, X, is_transform=False) -> pd.DataFrame:
2546
- if _num_samples(X) == 0:
2547
- raise ValidationError(self.bundle.get("x_is_empty"))
2548
-
2549
2894
  if isinstance(X, pd.DataFrame):
2550
2895
  if isinstance(X.columns, pd.MultiIndex) or isinstance(X.index, pd.MultiIndex):
2551
2896
  raise ValidationError(self.bundle.get("x_multiindex_unsupported"))
@@ -2559,6 +2904,9 @@ class FeaturesEnricher(TransformerMixin):
2559
2904
  else:
2560
2905
  raise ValidationError(self.bundle.get("unsupported_x_type").format(type(X)))
2561
2906
 
2907
+ if _num_samples(X) == 0:
2908
+ raise ValidationError(self.bundle.get("x_is_empty"))
2909
+
2562
2910
  if len(set(validated_X.columns)) != len(validated_X.columns):
2563
2911
  raise ValidationError(self.bundle.get("x_contains_dup_columns"))
2564
2912
  if not is_transform and not validated_X.index.is_unique:
@@ -2578,13 +2926,12 @@ class FeaturesEnricher(TransformerMixin):
2578
2926
  raise ValidationError(self.bundle.get("x_contains_reserved_column_name").format(EVAL_SET_INDEX))
2579
2927
  if SYSTEM_RECORD_ID in validated_X.columns:
2580
2928
  raise ValidationError(self.bundle.get("x_contains_reserved_column_name").format(SYSTEM_RECORD_ID))
2929
+ if ENTITY_SYSTEM_RECORD_ID in validated_X.columns:
2930
+ raise ValidationError(self.bundle.get("x_contains_reserved_column_name").format(ENTITY_SYSTEM_RECORD_ID))
2581
2931
 
2582
2932
  return validated_X
2583
2933
 
2584
2934
  def _validate_y(self, X: pd.DataFrame, y) -> pd.Series:
2585
- if _num_samples(y) == 0:
2586
- raise ValidationError(self.bundle.get("y_is_empty"))
2587
-
2588
2935
  if (
2589
2936
  not isinstance(y, pd.Series)
2590
2937
  and not isinstance(y, pd.DataFrame)
@@ -2593,6 +2940,9 @@ class FeaturesEnricher(TransformerMixin):
2593
2940
  ):
2594
2941
  raise ValidationError(self.bundle.get("unsupported_y_type").format(type(y)))
2595
2942
 
2943
+ if _num_samples(y) == 0:
2944
+ raise ValidationError(self.bundle.get("y_is_empty"))
2945
+
2596
2946
  if _num_samples(X) != _num_samples(y):
2597
2947
  raise ValidationError(self.bundle.get("x_and_y_diff_size").format(_num_samples(X), _num_samples(y)))
2598
2948
 
@@ -2730,9 +3080,10 @@ class FeaturesEnricher(TransformerMixin):
2730
3080
  X: pd.DataFrame, y: pd.Series, cv: Optional[CVType]
2731
3081
  ) -> Tuple[pd.DataFrame, pd.Series]:
2732
3082
  if cv not in [CVType.time_series, CVType.blocked_time_series]:
3083
+ record_id_column = ENTITY_SYSTEM_RECORD_ID if ENTITY_SYSTEM_RECORD_ID in X else SYSTEM_RECORD_ID
2733
3084
  Xy = X.copy()
2734
3085
  Xy[TARGET] = y
2735
- Xy = Xy.sort_values(by=SYSTEM_RECORD_ID).reset_index(drop=True)
3086
+ Xy = Xy.sort_values(by=record_id_column).reset_index(drop=True)
2736
3087
  X = Xy.drop(columns=TARGET)
2737
3088
  y = Xy[TARGET].copy()
2738
3089
 
@@ -2750,7 +3101,7 @@ class FeaturesEnricher(TransformerMixin):
2750
3101
  if DateTimeSearchKeyConverter.DATETIME_COL in X.columns:
2751
3102
  date_column = DateTimeSearchKeyConverter.DATETIME_COL
2752
3103
  else:
2753
- date_column = FeaturesEnricher._get_date_column(search_keys)
3104
+ date_column = SearchKey.find_key(search_keys, [SearchKey.DATE, SearchKey.DATETIME])
2754
3105
  sort_columns = [date_column] if date_column is not None else []
2755
3106
 
2756
3107
  # Xy = pd.concat([X, y], axis=1)
@@ -2846,7 +3197,7 @@ class FeaturesEnricher(TransformerMixin):
2846
3197
 
2847
3198
  do_without_pandas_limits(print_datasets_sample)
2848
3199
 
2849
- maybe_date_col = self._get_date_column(self.search_keys)
3200
+ maybe_date_col = SearchKey.find_key(self.search_keys, [SearchKey.DATE, SearchKey.DATETIME])
2850
3201
  if X is not None and maybe_date_col is not None and maybe_date_col in X.columns:
2851
3202
  # TODO cast date column to single dtype
2852
3203
  date_converter = DateTimeSearchKeyConverter(maybe_date_col, self.date_format)
@@ -2856,7 +3207,7 @@ class FeaturesEnricher(TransformerMixin):
2856
3207
  self.logger.info(f"Dates interval is ({min_date}, {max_date})")
2857
3208
 
2858
3209
  except Exception:
2859
- self.logger.exception("Failed to log debug information")
3210
+ self.logger.warning("Failed to log debug information", exc_info=True)
2860
3211
 
2861
3212
  def __handle_index_search_keys(self, df: pd.DataFrame, search_keys: Dict[str, SearchKey]) -> pd.DataFrame:
2862
3213
  index_names = df.index.names if df.index.names != [None] else [DEFAULT_INDEX]
@@ -2876,15 +3227,8 @@ class FeaturesEnricher(TransformerMixin):
2876
3227
 
2877
3228
  return df
2878
3229
 
2879
- @staticmethod
2880
- def _get_date_column(search_keys: Dict[str, SearchKey]) -> Optional[str]:
2881
- for col, t in search_keys.items():
2882
- if t in [SearchKey.DATE, SearchKey.DATETIME]:
2883
- return col
2884
-
2885
- @staticmethod
2886
3230
  def _add_current_date_as_key(
2887
- df: pd.DataFrame, search_keys: Dict[str, SearchKey], logger: logging.Logger, bundle: ResourceBundle
3231
+ self, df: pd.DataFrame, search_keys: Dict[str, SearchKey], logger: logging.Logger, bundle: ResourceBundle
2888
3232
  ) -> pd.DataFrame:
2889
3233
  if (
2890
3234
  set(search_keys.values()) == {SearchKey.PHONE}
@@ -2892,12 +3236,10 @@ class FeaturesEnricher(TransformerMixin):
2892
3236
  or set(search_keys.values()) == {SearchKey.HEM}
2893
3237
  or set(search_keys.values()) == {SearchKey.COUNTRY, SearchKey.POSTAL_CODE}
2894
3238
  ):
2895
- msg = bundle.get("current_date_added")
2896
- print(msg)
2897
- logger.warning(msg)
3239
+ self.__log_warning(bundle.get("current_date_added"))
2898
3240
  df[FeaturesEnricher.CURRENT_DATE] = datetime.date.today()
2899
3241
  search_keys[FeaturesEnricher.CURRENT_DATE] = SearchKey.DATE
2900
- converter = DateTimeSearchKeyConverter(FeaturesEnricher.CURRENT_DATE, None, logger, bundle)
3242
+ converter = DateTimeSearchKeyConverter(FeaturesEnricher.CURRENT_DATE)
2901
3243
  df = converter.convert(df)
2902
3244
  return df
2903
3245
 
@@ -2911,24 +3253,87 @@ class FeaturesEnricher(TransformerMixin):
2911
3253
 
2912
3254
  @staticmethod
2913
3255
  def _get_email_column(search_keys: Dict[str, SearchKey]) -> Optional[str]:
3256
+ cols = [col for col, t in search_keys.items() if t == SearchKey.EMAIL]
3257
+ if len(cols) > 1:
3258
+ raise Exception("More than one email column found after unnest")
3259
+ if len(cols) == 1:
3260
+ return cols[0]
3261
+
3262
+ @staticmethod
3263
+ def _get_hem_column(search_keys: Dict[str, SearchKey]) -> Optional[str]:
3264
+ cols = [col for col, t in search_keys.items() if t == SearchKey.HEM]
3265
+ if len(cols) > 1:
3266
+ raise Exception("More than one hem column found after unnest")
3267
+ if len(cols) == 1:
3268
+ return cols[0]
3269
+
3270
+ @staticmethod
3271
+ def _get_ip_column(search_keys: Dict[str, SearchKey]) -> Optional[str]:
3272
+ cols = [col for col, t in search_keys.items() if t == SearchKey.IP]
3273
+ if len(cols) > 1:
3274
+ raise Exception("More than one ip column found after unnest")
3275
+ if len(cols) == 1:
3276
+ return cols[0]
3277
+
3278
+ @staticmethod
3279
+ def _get_phone_column(search_keys: Dict[str, SearchKey]) -> Optional[str]:
2914
3280
  for col, t in search_keys.items():
2915
- if t == SearchKey.EMAIL:
3281
+ if t == SearchKey.PHONE:
2916
3282
  return col
2917
3283
 
2918
3284
  @staticmethod
2919
- def _get_hem_column(search_keys: Dict[str, SearchKey]) -> Optional[str]:
3285
+ def _get_country_column(search_keys: Dict[str, SearchKey]) -> Optional[str]:
2920
3286
  for col, t in search_keys.items():
2921
- if t == SearchKey.HEM:
3287
+ if t == SearchKey.COUNTRY:
2922
3288
  return col
2923
3289
 
2924
3290
  @staticmethod
2925
- def _get_phone_column(search_keys: Dict[str, SearchKey]) -> Optional[str]:
3291
+ def _get_postal_column(search_keys: Dict[str, SearchKey]) -> Optional[str]:
2926
3292
  for col, t in search_keys.items():
2927
- if t == SearchKey.PHONE:
3293
+ if t == SearchKey.POSTAL_CODE:
2928
3294
  return col
2929
3295
 
3296
+ def _explode_multiple_search_keys(
3297
+ self, df: pd.DataFrame, search_keys: Dict[str, SearchKey], columns_renaming: Dict[str, str]
3298
+ ) -> Tuple[pd.DataFrame, Dict[str, List[str]]]:
3299
+ # find groups of multiple search keys
3300
+ search_key_names_by_type: Dict[SearchKey, List[str]] = {}
3301
+ for key_name, key_type in search_keys.items():
3302
+ search_key_names_by_type[key_type] = search_key_names_by_type.get(key_type, []) + [key_name]
3303
+ search_key_names_by_type = {
3304
+ key_type: key_names for key_type, key_names in search_key_names_by_type.items() if len(key_names) > 1
3305
+ }
3306
+ if len(search_key_names_by_type) == 0:
3307
+ return df, {}
3308
+
3309
+ self.logger.info(f"Start exploding dataset by {search_key_names_by_type}. Size before: {len(df)}")
3310
+ multiple_keys_columns = [col for cols in search_key_names_by_type.values() for col in cols]
3311
+ other_columns = [col for col in df.columns if col not in multiple_keys_columns]
3312
+ exploded_dfs = []
3313
+ unnest_search_keys = {}
3314
+
3315
+ for key_type, key_names in search_key_names_by_type.items():
3316
+ new_search_key = f"upgini_{key_type.name.lower()}_unnest"
3317
+ exploded_df = pd.melt(
3318
+ df, id_vars=other_columns, value_vars=key_names, var_name=SEARCH_KEY_UNNEST, value_name=new_search_key
3319
+ )
3320
+ exploded_dfs.append(exploded_df)
3321
+ for old_key in key_names:
3322
+ del search_keys[old_key]
3323
+ search_keys[new_search_key] = key_type
3324
+ unnest_search_keys[new_search_key] = key_names
3325
+ columns_renaming[new_search_key] = new_search_key
3326
+
3327
+ df = pd.concat(exploded_dfs, ignore_index=True)
3328
+ self.logger.info(f"Finished explosion. Size after: {len(df)}")
3329
+ return df, unnest_search_keys
3330
+
2930
3331
  def __add_fit_system_record_id(
2931
- self, df: pd.DataFrame, meaning_types: Dict[str, FileColumnMeaningType], search_keys: Dict[str, SearchKey]
3332
+ self,
3333
+ df: pd.DataFrame,
3334
+ # meaning_types: Dict[str, FileColumnMeaningType],
3335
+ search_keys: Dict[str, SearchKey],
3336
+ id_name: str,
2932
3337
  ) -> pd.DataFrame:
2933
3338
  # save original order or rows
2934
3339
  original_index_name = df.index.name
@@ -2939,52 +3344,61 @@ class FeaturesEnricher(TransformerMixin):
2939
3344
 
2940
3345
  # order by date and idempotent order by other keys
2941
3346
  if self.cv not in [CVType.time_series, CVType.blocked_time_series]:
2942
- sort_exclude_columns = [original_order_name, ORIGINAL_INDEX, EVAL_SET_INDEX, TARGET, "__target"]
3347
+ sort_exclude_columns = [
3348
+ original_order_name,
3349
+ ORIGINAL_INDEX,
3350
+ EVAL_SET_INDEX,
3351
+ TARGET,
3352
+ "__target",
3353
+ ENTITY_SYSTEM_RECORD_ID,
3354
+ ]
2943
3355
  if DateTimeSearchKeyConverter.DATETIME_COL in df.columns:
2944
3356
  date_column = DateTimeSearchKeyConverter.DATETIME_COL
2945
- sort_exclude_columns.append(self._get_date_column(search_keys))
3357
+ sort_exclude_columns.append(SearchKey.find_key(search_keys, [SearchKey.DATE, SearchKey.DATETIME]))
2946
3358
  else:
2947
- date_column = self._get_date_column(search_keys)
3359
+ date_column = SearchKey.find_key(search_keys, [SearchKey.DATE, SearchKey.DATETIME])
2948
3360
  sort_columns = [date_column] if date_column is not None else []
2949
3361
 
3362
+ sorted_other_keys = sorted(search_keys, key=lambda x: str(search_keys.get(x)))
3363
+ sorted_other_keys = [k for k in sorted_other_keys if k not in sort_exclude_columns]
3364
+
2950
3365
  other_columns = sorted(
2951
3366
  [
2952
3367
  c
2953
3368
  for c in df.columns
2954
- if c not in sort_columns and c not in sort_exclude_columns and df[c].nunique() > 1
3369
+ if c not in sort_columns
3370
+ and c not in sorted_other_keys
3371
+ and c not in sort_exclude_columns
3372
+ and df[c].nunique() > 1
2955
3373
  ]
2956
- # [
2957
- # sk
2958
- # for sk, key_type in search_keys.items()
2959
- # if key_type not in [SearchKey.DATE, SearchKey.DATETIME]
2960
- # and sk in df.columns
2961
- # and df[sk].nunique() > 1 # don't use constant keys for hash
2962
- # ]
2963
3374
  )
2964
3375
 
3376
+ all_other_columns = sorted_other_keys + other_columns
3377
+
2965
3378
  search_keys_hash = "search_keys_hash"
2966
- if len(other_columns) > 0:
3379
+ if len(all_other_columns) > 0:
2967
3380
  sort_columns.append(search_keys_hash)
2968
- df[search_keys_hash] = pd.util.hash_pandas_object(df[other_columns], index=False)
3381
+ df[search_keys_hash] = pd.util.hash_pandas_object(df[all_other_columns], index=False)
2969
3382
 
2970
3383
  df = df.sort_values(by=sort_columns)
2971
3384
 
2972
3385
  if search_keys_hash in df.columns:
2973
3386
  df.drop(columns=search_keys_hash, inplace=True)
2974
3387
 
2975
- if DateTimeSearchKeyConverter.DATETIME_COL in df.columns:
2976
- df.drop(columns=DateTimeSearchKeyConverter.DATETIME_COL, inplace=True)
2977
-
2978
3388
  df = df.reset_index(drop=True).reset_index()
2979
3389
  # system_record_id saves correct order for fit
2980
- df = df.rename(columns={DEFAULT_INDEX: SYSTEM_RECORD_ID})
3390
+ df = df.rename(columns={DEFAULT_INDEX: id_name})
2981
3391
 
2982
3392
  # return original order
2983
3393
  df = df.set_index(ORIGINAL_INDEX)
2984
3394
  df.index.name = original_index_name
2985
3395
  df = df.sort_values(by=original_order_name).drop(columns=original_order_name)
2986
3396
 
2987
- meaning_types[SYSTEM_RECORD_ID] = FileColumnMeaningType.SYSTEM_RECORD_ID
3397
+ # meaning_types[id_name] = (
3398
+ # FileColumnMeaningType.SYSTEM_RECORD_ID
3399
+ # if id_name == SYSTEM_RECORD_ID
3400
+ # else FileColumnMeaningType.ENTITY_SYSTEM_RECORD_ID
3401
+ # )
2988
3402
  return df
2989
3403
 
2990
3404
  def __correct_target(self, df: pd.DataFrame) -> pd.DataFrame:
@@ -3039,7 +3453,11 @@ class FeaturesEnricher(TransformerMixin):
3039
3453
  )
3040
3454
 
3041
3455
  comparing_columns = X.columns if is_transform else df_with_original_index.columns
3042
- dup_features = [c for c in comparing_columns if c in result_features.columns and c != SYSTEM_RECORD_ID]
3456
+ dup_features = [
3457
+ c
3458
+ for c in comparing_columns
3459
+ if c in result_features.columns and c not in [SYSTEM_RECORD_ID, ENTITY_SYSTEM_RECORD_ID]
3460
+ ]
3043
3461
  if len(dup_features) > 0:
3044
3462
  self.logger.warning(f"X contain columns with same name as returned from backend: {dup_features}")
3045
3463
  raise ValidationError(self.bundle.get("returned_features_same_as_passed").format(dup_features))
@@ -3047,11 +3465,11 @@ class FeaturesEnricher(TransformerMixin):
3047
3465
  # index overrites from result_features
3048
3466
  original_index_name = df_with_original_index.index.name
3049
3467
  df_with_original_index = df_with_original_index.reset_index()
3468
+ # TODO drop system_record_id before merge
3050
3469
  result_features = pd.merge(
3051
3470
  df_with_original_index,
3052
3471
  result_features,
3053
- left_on=SYSTEM_RECORD_ID,
3054
- right_on=SYSTEM_RECORD_ID,
3472
+ on=ENTITY_SYSTEM_RECORD_ID,
3055
3473
  how="left" if is_transform else "inner",
3056
3474
  )
3057
3475
  result_features = result_features.set_index(original_index_name or DEFAULT_INDEX)
@@ -3059,10 +3477,12 @@ class FeaturesEnricher(TransformerMixin):
3059
3477
 
3060
3478
  if rows_to_drop is not None:
3061
3479
  self.logger.info(f"Before dropping target outliers size: {len(result_features)}")
3062
- result_features = result_features[~result_features[SYSTEM_RECORD_ID].isin(rows_to_drop[SYSTEM_RECORD_ID])]
3480
+ result_features = result_features[
3481
+ ~result_features[ENTITY_SYSTEM_RECORD_ID].isin(rows_to_drop[ENTITY_SYSTEM_RECORD_ID])
3482
+ ]
3063
3483
  self.logger.info(f"After dropping target outliers size: {len(result_features)}")
3064
3484
 
3065
- result_eval_sets = dict()
3485
+ result_eval_sets = {}
3066
3486
  if not is_transform and EVAL_SET_INDEX in result_features.columns:
3067
3487
  result_train_features = result_features.loc[result_features[EVAL_SET_INDEX] == 0].copy()
3068
3488
  eval_set_indices = list(result_features[EVAL_SET_INDEX].unique())
@@ -3092,16 +3512,17 @@ class FeaturesEnricher(TransformerMixin):
3092
3512
  result_train = result_train_features
3093
3513
 
3094
3514
  if drop_system_record_id:
3095
- if SYSTEM_RECORD_ID in result_train.columns:
3096
- result_train = result_train.drop(columns=SYSTEM_RECORD_ID)
3515
+ result_train = result_train.drop(columns=[SYSTEM_RECORD_ID, ENTITY_SYSTEM_RECORD_ID], errors="ignore")
3097
3516
  for eval_set_index in result_eval_sets.keys():
3098
- if SYSTEM_RECORD_ID in result_eval_sets[eval_set_index].columns:
3099
- result_eval_sets[eval_set_index] = result_eval_sets[eval_set_index].drop(columns=SYSTEM_RECORD_ID)
3517
+ result_eval_sets[eval_set_index] = result_eval_sets[eval_set_index].drop(
3518
+ columns=[SYSTEM_RECORD_ID, ENTITY_SYSTEM_RECORD_ID], errors="ignore"
3519
+ )
3100
3520
 
3101
3521
  return result_train, result_eval_sets
3102
3522
 
3103
- def __prepare_feature_importances(self, trace_id: str, x_columns: List[str], silent=False):
3104
- llm_source = "LLM with external data augmentation"
3523
+ def __prepare_feature_importances(
3524
+ self, trace_id: str, x_columns: List[str], updated_shaps: Optional[Dict[str, float]] = None, silent=False
3525
+ ):
3105
3526
  if self._search_task is None:
3106
3527
  raise NotFittedError(self.bundle.get("transform_unfitted_enricher"))
3107
3528
  features_meta = self._search_task.get_all_features_metadata_v2()
@@ -3112,122 +3533,44 @@ class FeaturesEnricher(TransformerMixin):
3112
3533
  features_df = self._search_task.get_all_initial_raw_features(trace_id, metrics_calculation=True)
3113
3534
 
3114
3535
  self.feature_names_ = []
3536
+ self.dropped_client_feature_names_ = []
3115
3537
  self.feature_importances_ = []
3116
3538
  features_info = []
3117
3539
  features_info_without_links = []
3118
3540
  internal_features_info = []
3119
3541
 
3120
- def round_shap_value(shap: float) -> float:
3121
- if shap > 0.0 and shap < 0.0001:
3122
- return 0.0001
3123
- else:
3124
- return round(shap, 4)
3125
-
3126
- def list_or_single(lst: List[str], single: str):
3127
- return lst or ([single] if single else [])
3128
-
3129
- def to_anchor(link: str, value: str) -> str:
3130
- if not value:
3131
- return ""
3132
- elif not link:
3133
- return value
3134
- elif value == llm_source:
3135
- return value
3136
- else:
3137
- return f"<a href='{link}' target='_blank' rel='noopener noreferrer'>{value}</a>"
3138
-
3139
- def make_links(names: List[str], links: List[str]):
3140
- all_links = [to_anchor(link, name) for name, link in itertools.zip_longest(names, links)]
3141
- return ",".join(all_links)
3542
+ if updated_shaps is not None:
3543
+ for fm in features_meta:
3544
+ fm.shap_value = updated_shaps.get(fm.name, 0.0)
3142
3545
 
3143
3546
  features_meta.sort(key=lambda m: (-m.shap_value, m.name))
3144
3547
  for feature_meta in features_meta:
3145
3548
  if feature_meta.name in original_names_dict.keys():
3146
3549
  feature_meta.name = original_names_dict[feature_meta.name]
3147
- # Use only enriched features
3550
+
3551
+ is_client_feature = feature_meta.name in x_columns
3552
+
3553
+ if feature_meta.shap_value == 0.0:
3554
+ if self.select_features:
3555
+ self.dropped_client_feature_names_.append(feature_meta.name)
3556
+ continue
3557
+
3558
+ # Use only important features
3148
3559
  if (
3149
- feature_meta.name in x_columns
3560
+ feature_meta.name in self.fit_generated_features
3150
3561
  or feature_meta.name == COUNTRY
3151
- or feature_meta.shap_value == 0.0
3152
- or feature_meta.name in self.fit_generated_features
3562
+ # In select_features mode we select also from etalon features and need to show them
3563
+ or (not self.select_features and is_client_feature)
3153
3564
  ):
3154
3565
  continue
3155
3566
 
3156
- feature_sample = []
3157
3567
  self.feature_names_.append(feature_meta.name)
3158
- self.feature_importances_.append(round_shap_value(feature_meta.shap_value))
3159
- if feature_meta.name in features_df.columns:
3160
- feature_sample = np.random.choice(features_df[feature_meta.name].dropna().unique(), 3).tolist()
3161
- if len(feature_sample) > 0 and isinstance(feature_sample[0], float):
3162
- feature_sample = [round(f, 4) for f in feature_sample]
3163
- feature_sample = [str(f) for f in feature_sample]
3164
- feature_sample = ", ".join(feature_sample)
3165
- if len(feature_sample) > 30:
3166
- feature_sample = feature_sample[:30] + "..."
3167
-
3168
- internal_provider = feature_meta.data_provider or "Upgini"
3169
- providers = list_or_single(feature_meta.data_providers, feature_meta.data_provider)
3170
- provider_links = list_or_single(feature_meta.data_provider_links, feature_meta.data_provider_link)
3171
- if providers:
3172
- provider = make_links(providers, provider_links)
3173
- else:
3174
- provider = to_anchor("https://upgini.com", "Upgini")
3568
+ self.feature_importances_.append(_round_shap_value(feature_meta.shap_value))
3175
3569
 
3176
- internal_source = feature_meta.data_source or (
3177
- llm_source
3178
- if not feature_meta.name.endswith("_country") and not feature_meta.name.endswith("_postal_code")
3179
- else ""
3180
- )
3181
- sources = list_or_single(feature_meta.data_sources, feature_meta.data_source)
3182
- source_links = list_or_single(feature_meta.data_source_links, feature_meta.data_source_link)
3183
- if sources:
3184
- source = make_links(sources, source_links)
3185
- else:
3186
- source = internal_source
3187
-
3188
- internal_feature_name = feature_meta.name
3189
- if feature_meta.doc_link:
3190
- feature_name = to_anchor(feature_meta.doc_link, feature_meta.name)
3191
- else:
3192
- feature_name = internal_feature_name
3193
-
3194
- features_info.append(
3195
- {
3196
- self.bundle.get("features_info_name"): feature_name,
3197
- self.bundle.get("features_info_shap"): round_shap_value(feature_meta.shap_value),
3198
- self.bundle.get("features_info_hitrate"): feature_meta.hit_rate,
3199
- self.bundle.get("features_info_value_preview"): feature_sample,
3200
- self.bundle.get("features_info_provider"): provider,
3201
- self.bundle.get("features_info_source"): source,
3202
- self.bundle.get("features_info_update_frequency"): feature_meta.update_frequency,
3203
- }
3204
- )
3205
- features_info_without_links.append(
3206
- {
3207
- self.bundle.get("features_info_name"): internal_feature_name,
3208
- self.bundle.get("features_info_shap"): round_shap_value(feature_meta.shap_value),
3209
- self.bundle.get("features_info_hitrate"): feature_meta.hit_rate,
3210
- self.bundle.get("features_info_value_preview"): feature_sample,
3211
- self.bundle.get("features_info_provider"): internal_provider,
3212
- self.bundle.get("features_info_source"): internal_source,
3213
- self.bundle.get("features_info_update_frequency"): feature_meta.update_frequency,
3214
- }
3215
- )
3216
- internal_features_info.append(
3217
- {
3218
- self.bundle.get("features_info_name"): internal_feature_name,
3219
- "feature_link": feature_meta.doc_link,
3220
- self.bundle.get("features_info_shap"): round_shap_value(feature_meta.shap_value),
3221
- self.bundle.get("features_info_hitrate"): feature_meta.hit_rate,
3222
- self.bundle.get("features_info_value_preview"): feature_sample,
3223
- self.bundle.get("features_info_provider"): internal_provider,
3224
- "provider_link": feature_meta.data_provider_link,
3225
- self.bundle.get("features_info_source"): internal_source,
3226
- "source_link": feature_meta.data_source_link,
3227
- self.bundle.get("features_info_commercial_schema"): feature_meta.commercial_schema or "",
3228
- self.bundle.get("features_info_update_frequency"): feature_meta.update_frequency,
3229
- }
3230
- )
3570
+ feature_info = FeatureInfo.from_metadata(feature_meta, features_df, is_client_feature)
3571
+ features_info.append(feature_info.to_row(self.bundle))
3572
+ features_info_without_links.append(feature_info.to_row_without_links(self.bundle))
3573
+ internal_features_info.append(feature_info.to_internal_row(self.bundle))
3231
3574
 
3232
3575
  if len(features_info) > 0:
3233
3576
  self.features_info = pd.DataFrame(features_info)
@@ -3252,7 +3595,22 @@ class FeaturesEnricher(TransformerMixin):
3252
3595
  autofe_meta = self._search_task.get_autofe_metadata()
3253
3596
  if autofe_meta is None:
3254
3597
  return None
3255
- features_meta = self._search_task.get_all_features_metadata_v2()
3598
+ if len(self._internal_features_info) != 0:
3599
+
3600
+ def to_feature_meta(row):
3601
+ fm = FeaturesMetadataV2(
3602
+ name=row[bundle.get("features_info_name")],
3603
+ type="",
3604
+ source="",
3605
+ hit_rate=row[bundle.get("features_info_hitrate")],
3606
+ shap_value=row[bundle.get("features_info_shap")],
3607
+ data_source=row[bundle.get("features_info_source")],
3608
+ )
3609
+ return fm
3610
+
3611
+ features_meta = self._internal_features_info.apply(to_feature_meta, axis=1).to_list()
3612
+ else:
3613
+ features_meta = self._search_task.get_all_features_metadata_v2()
3256
3614
 
3257
3615
  def get_feature_by_name(name: str):
3258
3616
  for m in features_meta:
@@ -3261,41 +3619,52 @@ class FeaturesEnricher(TransformerMixin):
3261
3619
 
3262
3620
  descriptions = []
3263
3621
  for m in autofe_meta:
3264
- autofe_feature = Feature.from_formula(m.formula)
3265
3622
  orig_to_hashed = {base_column.original_name: base_column.hashed_name for base_column in m.base_columns}
3266
- autofe_feature.rename_columns(orig_to_hashed)
3267
- autofe_feature.set_display_index(m.display_index)
3623
+
3624
+ autofe_feature = (
3625
+ Feature.from_formula(m.formula)
3626
+ .set_display_index(m.display_index)
3627
+ .set_alias(m.alias)
3628
+ .set_op_params(m.operator_params or {})
3629
+ .rename_columns(orig_to_hashed)
3630
+ )
3631
+
3268
3632
  if autofe_feature.op.is_vector:
3269
3633
  continue
3270
3634
 
3271
- description = dict()
3635
+ description = {}
3272
3636
 
3273
3637
  feature_meta = get_feature_by_name(autofe_feature.get_display_name(shorten=True))
3274
3638
  if feature_meta is None:
3275
3639
  self.logger.warning(f"Feature meta for display index {m.display_index} not found")
3276
3640
  continue
3277
3641
  description["shap"] = feature_meta.shap_value
3278
- description["Sources"] = feature_meta.data_source.replace("AutoFE: features from ", "").replace(
3279
- "AutoFE: feature from ", ""
3280
- )
3281
- description["Feature name"] = feature_meta.name
3642
+ description[self.bundle.get("autofe_descriptions_sources")] = feature_meta.data_source.replace(
3643
+ "AutoFE: features from ", ""
3644
+ ).replace("AutoFE: feature from ", "")
3645
+ description[self.bundle.get("autofe_descriptions_feature_name")] = feature_meta.name
3282
3646
 
3283
3647
  feature_idx = 1
3284
3648
  for bc in m.base_columns:
3285
- description[f"Feature {feature_idx}"] = bc.hashed_name
3649
+ description[self.bundle.get("autofe_descriptions_feature").format(feature_idx)] = bc.hashed_name
3286
3650
  feature_idx += 1
3287
3651
 
3288
- description["Function"] = autofe_feature.op.name
3652
+ description[self.bundle.get("autofe_descriptions_function")] = ",".join(
3653
+ sorted(autofe_feature.get_all_operand_names())
3654
+ )
3289
3655
 
3290
3656
  descriptions.append(description)
3291
3657
 
3292
3658
  if len(descriptions) == 0:
3293
3659
  return None
3294
3660
 
3295
- descriptions_df = pd.DataFrame(descriptions)
3296
- descriptions_df.fillna("", inplace=True)
3297
- descriptions_df.sort_values(by="shap", ascending=False, inplace=True)
3298
- descriptions_df.drop(columns="shap", inplace=True)
3661
+ descriptions_df = (
3662
+ pd.DataFrame(descriptions)
3663
+ .fillna("")
3664
+ .sort_values(by="shap", ascending=False)
3665
+ .drop(columns="shap")
3666
+ .reset_index(drop=True)
3667
+ )
3299
3668
  return descriptions_df
3300
3669
 
3301
3670
  except Exception:
@@ -3348,10 +3717,16 @@ class FeaturesEnricher(TransformerMixin):
3348
3717
  is_transform=False,
3349
3718
  silent_mode=False,
3350
3719
  ):
3720
+ for _, key_type in search_keys.items():
3721
+ if not isinstance(key_type, SearchKey):
3722
+ raise ValidationError(self.bundle.get("unsupported_type_of_search_key").format(key_type))
3723
+
3351
3724
  valid_search_keys = {}
3352
3725
  unsupported_search_keys = {
3353
3726
  SearchKey.IP_RANGE_FROM,
3354
3727
  SearchKey.IP_RANGE_TO,
3728
+ SearchKey.IPV6_RANGE_FROM,
3729
+ SearchKey.IPV6_RANGE_TO,
3355
3730
  SearchKey.MSISDN_RANGE_FROM,
3356
3731
  SearchKey.MSISDN_RANGE_TO,
3357
3732
  # SearchKey.EMAIL_ONE_DOMAIN,
@@ -3360,11 +3735,17 @@ class FeaturesEnricher(TransformerMixin):
3360
3735
  if len(passed_unsupported_search_keys) > 0:
3361
3736
  raise ValidationError(self.bundle.get("unsupported_search_key").format(passed_unsupported_search_keys))
3362
3737
 
3738
+ x_columns = [
3739
+ c
3740
+ for c in x.columns
3741
+ if c not in [TARGET, EVAL_SET_INDEX, SYSTEM_RECORD_ID, ENTITY_SYSTEM_RECORD_ID, SEARCH_KEY_UNNEST]
3742
+ ]
3743
+
3363
3744
  for column_id, meaning_type in search_keys.items():
3364
3745
  column_name = None
3365
3746
  if isinstance(column_id, str):
3366
3747
  if column_id not in x.columns:
3367
- raise ValidationError(self.bundle.get("search_key_not_found").format(column_id, list(x.columns)))
3748
+ raise ValidationError(self.bundle.get("search_key_not_found").format(column_id, x_columns))
3368
3749
  column_name = column_id
3369
3750
  valid_search_keys[column_name] = meaning_type
3370
3751
  elif isinstance(column_id, int):
@@ -3378,15 +3759,15 @@ class FeaturesEnricher(TransformerMixin):
3378
3759
  if meaning_type == SearchKey.COUNTRY and self.country_code is not None:
3379
3760
  msg = self.bundle.get("search_key_country_and_country_code")
3380
3761
  self.logger.warning(msg)
3381
- print(msg)
3762
+ if not silent_mode:
3763
+ self.__log_warning(msg)
3382
3764
  self.country_code = None
3383
3765
 
3384
3766
  if not self.__is_registered and not is_demo_dataset and meaning_type in SearchKey.personal_keys():
3385
3767
  msg = self.bundle.get("unregistered_with_personal_keys").format(meaning_type)
3386
3768
  self.logger.warning(msg)
3387
3769
  if not silent_mode:
3388
- self.warning_counter.increment()
3389
- print(msg)
3770
+ self.__log_warning(msg)
3390
3771
 
3391
3772
  valid_search_keys[column_name] = SearchKey.CUSTOM_KEY
3392
3773
  else:
@@ -3420,27 +3801,23 @@ class FeaturesEnricher(TransformerMixin):
3420
3801
  and not silent_mode
3421
3802
  ):
3422
3803
  msg = self.bundle.get("date_only_search")
3423
- print(msg)
3424
- self.logger.warning(msg)
3425
- self.warning_counter.increment()
3804
+ self.__log_warning(msg)
3426
3805
 
3427
3806
  maybe_date = [k for k, v in valid_search_keys.items() if v in [SearchKey.DATE, SearchKey.DATETIME]]
3428
3807
  if (self.cv is None or self.cv == CVType.k_fold) and len(maybe_date) > 0 and not silent_mode:
3429
3808
  date_column = next(iter(maybe_date))
3430
3809
  if x[date_column].nunique() > 0.9 * _num_samples(x):
3431
3810
  msg = self.bundle.get("date_search_without_time_series")
3432
- print(msg)
3433
- self.logger.warning(msg)
3434
- self.warning_counter.increment()
3811
+ self.__log_warning(msg)
3435
3812
 
3436
3813
  if len(valid_search_keys) == 1:
3437
- for k, v in valid_search_keys.items():
3438
- # Show warning for country only if country is the only key
3439
- if x[k].nunique() == 1 and (v != SearchKey.COUNTRY or len(valid_search_keys) == 1):
3440
- msg = self.bundle.get("single_constant_search_key").format(v, x[k].values[0])
3441
- print(msg)
3442
- self.logger.warning(msg)
3443
- self.warning_counter.increment()
3814
+ key, value = list(valid_search_keys.items())[0]
3815
+ # Show warning for country only if country is the only key
3816
+ if x[key].nunique() == 1:
3817
+ msg = self.bundle.get("single_constant_search_key").format(value, x[key].values[0])
3818
+ if not silent_mode:
3819
+ self.__log_warning(msg)
3820
+ # TODO maybe raise ValidationError
3444
3821
 
3445
3822
  self.logger.info(f"Prepared search keys: {valid_search_keys}")
3446
3823
 
@@ -3473,7 +3850,10 @@ class FeaturesEnricher(TransformerMixin):
3473
3850
  display_html_dataframe(self.metrics, self.metrics, msg)
3474
3851
 
3475
3852
  def __show_selected_features(self, search_keys: Dict[str, SearchKey]):
3476
- msg = self.bundle.get("features_info_header").format(len(self.feature_names_), list(search_keys.keys()))
3853
+ search_key_names = search_keys.keys()
3854
+ if self.fit_columns_renaming:
3855
+ search_key_names = [self.fit_columns_renaming.get(col, col) for col in search_key_names]
3856
+ msg = self.bundle.get("features_info_header").format(len(self.feature_names_), search_key_names)
3477
3857
 
3478
3858
  try:
3479
3859
  _ = get_ipython() # type: ignore
@@ -3481,27 +3861,29 @@ class FeaturesEnricher(TransformerMixin):
3481
3861
  print(Format.GREEN + Format.BOLD + msg + Format.END)
3482
3862
  self.logger.info(msg)
3483
3863
  if len(self.feature_names_) > 0:
3484
- display_html_dataframe(
3485
- self.features_info, self._features_info_without_links, self.bundle.get("relevant_features_header")
3864
+ self.features_info_display_handle = display_html_dataframe(
3865
+ self.features_info,
3866
+ self._features_info_without_links,
3867
+ self.bundle.get("relevant_features_header"),
3868
+ display_id="features_info",
3486
3869
  )
3487
3870
 
3488
- display_html_dataframe(
3871
+ self.data_sources_display_handle = display_html_dataframe(
3489
3872
  self.relevant_data_sources,
3490
3873
  self._relevant_data_sources_wo_links,
3491
3874
  self.bundle.get("relevant_data_sources_header"),
3875
+ display_id="data_sources",
3492
3876
  )
3493
3877
  else:
3494
3878
  msg = self.bundle.get("features_info_zero_important_features")
3495
- self.logger.warning(msg)
3496
- self.__display_support_link(msg)
3497
- self.warning_counter.increment()
3879
+ self.__log_warning(msg, show_support_link=True)
3498
3880
  except (ImportError, NameError):
3499
3881
  print(msg)
3500
3882
  print(self._internal_features_info)
3501
3883
 
3502
- def __show_report_button(self):
3884
+ def __show_report_button(self, display_id: Optional[str] = None, display_handle=None):
3503
3885
  try:
3504
- prepare_and_show_report(
3886
+ return prepare_and_show_report(
3505
3887
  relevant_features_df=self._features_info_without_links,
3506
3888
  relevant_datasources_df=self.relevant_data_sources,
3507
3889
  metrics_df=self.metrics,
@@ -3509,6 +3891,8 @@ class FeaturesEnricher(TransformerMixin):
3509
3891
  search_id=self._search_task.search_task_id,
3510
3892
  email=self.rest_client.get_current_email(),
3511
3893
  search_keys=[str(sk) for sk in self.search_keys.values()],
3894
+ display_id=display_id,
3895
+ display_handle=display_handle,
3512
3896
  )
3513
3897
  except Exception:
3514
3898
  pass
@@ -3550,65 +3934,70 @@ class FeaturesEnricher(TransformerMixin):
3550
3934
  def check_need_detect(search_key: SearchKey):
3551
3935
  return not is_transform or search_key in self.fit_search_keys.values()
3552
3936
 
3553
- if SearchKey.POSTAL_CODE not in search_keys.values() and check_need_detect(SearchKey.POSTAL_CODE):
3554
- maybe_key = PostalCodeSearchKeyDetector().get_search_key_column(sample)
3555
- if maybe_key is not None:
3556
- search_keys[maybe_key] = SearchKey.POSTAL_CODE
3557
- self.autodetected_search_keys[maybe_key] = SearchKey.POSTAL_CODE
3558
- self.logger.info(f"Autodetected search key POSTAL_CODE in column {maybe_key}")
3937
+ # if SearchKey.POSTAL_CODE not in search_keys.values() and check_need_detect(SearchKey.POSTAL_CODE):
3938
+ if check_need_detect(SearchKey.POSTAL_CODE):
3939
+ maybe_keys = PostalCodeSearchKeyDetector().get_search_key_columns(sample, search_keys)
3940
+ if maybe_keys:
3941
+ new_keys = {key: SearchKey.POSTAL_CODE for key in maybe_keys}
3942
+ search_keys.update(new_keys)
3943
+ self.autodetected_search_keys.update(new_keys)
3944
+ self.logger.info(f"Autodetected search key POSTAL_CODE in column {maybe_keys}")
3559
3945
  if not silent_mode:
3560
- print(self.bundle.get("postal_code_detected").format(maybe_key))
3946
+ print(self.bundle.get("postal_code_detected").format(maybe_keys))
3561
3947
 
3562
3948
  if (
3563
3949
  SearchKey.COUNTRY not in search_keys.values()
3564
3950
  and self.country_code is None
3565
3951
  and check_need_detect(SearchKey.COUNTRY)
3566
3952
  ):
3567
- maybe_key = CountrySearchKeyDetector().get_search_key_column(sample)
3568
- if maybe_key is not None:
3569
- search_keys[maybe_key] = SearchKey.COUNTRY
3570
- self.autodetected_search_keys[maybe_key] = SearchKey.COUNTRY
3953
+ maybe_key = CountrySearchKeyDetector().get_search_key_columns(sample, search_keys)
3954
+ if maybe_key:
3955
+ search_keys[maybe_key[0]] = SearchKey.COUNTRY
3956
+ self.autodetected_search_keys[maybe_key[0]] = SearchKey.COUNTRY
3571
3957
  self.logger.info(f"Autodetected search key COUNTRY in column {maybe_key}")
3572
3958
  if not silent_mode:
3573
3959
  print(self.bundle.get("country_detected").format(maybe_key))
3574
3960
 
3575
3961
  if (
3576
- SearchKey.EMAIL not in search_keys.values()
3577
- and SearchKey.HEM not in search_keys.values()
3962
+ # SearchKey.EMAIL not in search_keys.values()
3963
+ SearchKey.HEM not in search_keys.values()
3578
3964
  and check_need_detect(SearchKey.HEM)
3579
3965
  ):
3580
- maybe_key = EmailSearchKeyDetector().get_search_key_column(sample)
3581
- if maybe_key is not None and maybe_key not in search_keys.keys():
3966
+ maybe_keys = EmailSearchKeyDetector().get_search_key_columns(sample, search_keys)
3967
+ if maybe_keys:
3582
3968
  if self.__is_registered or is_demo_dataset:
3583
- search_keys[maybe_key] = SearchKey.EMAIL
3584
- self.autodetected_search_keys[maybe_key] = SearchKey.EMAIL
3585
- self.logger.info(f"Autodetected search key EMAIL in column {maybe_key}")
3969
+ new_keys = {key: SearchKey.EMAIL for key in maybe_keys}
3970
+ search_keys.update(new_keys)
3971
+ self.autodetected_search_keys.update(new_keys)
3972
+ self.logger.info(f"Autodetected search key EMAIL in column {maybe_keys}")
3586
3973
  if not silent_mode:
3587
- print(self.bundle.get("email_detected").format(maybe_key))
3974
+ print(self.bundle.get("email_detected").format(maybe_keys))
3588
3975
  else:
3589
3976
  self.logger.warning(
3590
- f"Autodetected search key EMAIL in column {maybe_key}. But not used because not registered user"
3977
+ f"Autodetected search key EMAIL in column {maybe_keys}."
3978
+ " But not used because not registered user"
3591
3979
  )
3592
3980
  if not silent_mode:
3593
- print(self.bundle.get("email_detected_not_registered").format(maybe_key))
3594
- self.warning_counter.increment()
3981
+ self.__log_warning(self.bundle.get("email_detected_not_registered").format(maybe_keys))
3595
3982
 
3596
- if SearchKey.PHONE not in search_keys.values() and check_need_detect(SearchKey.PHONE):
3597
- maybe_key = PhoneSearchKeyDetector().get_search_key_column(sample)
3598
- if maybe_key is not None and maybe_key not in search_keys.keys():
3983
+ # if SearchKey.PHONE not in search_keys.values() and check_need_detect(SearchKey.PHONE):
3984
+ if check_need_detect(SearchKey.PHONE):
3985
+ maybe_keys = PhoneSearchKeyDetector().get_search_key_columns(sample, search_keys)
3986
+ if maybe_keys:
3599
3987
  if self.__is_registered or is_demo_dataset:
3600
- search_keys[maybe_key] = SearchKey.PHONE
3601
- self.autodetected_search_keys[maybe_key] = SearchKey.PHONE
3602
- self.logger.info(f"Autodetected search key PHONE in column {maybe_key}")
3988
+ new_keys = {key: SearchKey.PHONE for key in maybe_keys}
3989
+ search_keys.update(new_keys)
3990
+ self.autodetected_search_keys.update(new_keys)
3991
+ self.logger.info(f"Autodetected search key PHONE in column {maybe_keys}")
3603
3992
  if not silent_mode:
3604
- print(self.bundle.get("phone_detected").format(maybe_key))
3993
+ print(self.bundle.get("phone_detected").format(maybe_keys))
3605
3994
  else:
3606
3995
  self.logger.warning(
3607
- f"Autodetected search key PHONE in column {maybe_key}. But not used because not registered user"
3996
+ f"Autodetected search key PHONE in column {maybe_keys}. "
3997
+ "But not used because not registered user"
3608
3998
  )
3609
3999
  if not silent_mode:
3610
- print(self.bundle.get("phone_detected_not_registered"))
3611
- self.warning_counter.increment()
4000
+ self.__log_warning(self.bundle.get("phone_detected_not_registered"))
3612
4001
 
3613
4002
  return search_keys
3614
4003
 
@@ -3630,21 +4019,19 @@ class FeaturesEnricher(TransformerMixin):
3630
4019
  half_train = round(len(train) / 2)
3631
4020
  part1 = train[:half_train]
3632
4021
  part2 = train[half_train:]
3633
- train_psi = calculate_psi(part1[self.TARGET_NAME], part2[self.TARGET_NAME])
3634
- if train_psi > 0.2:
3635
- self.warning_counter.increment()
3636
- msg = self.bundle.get("train_unstable_target").format(train_psi)
3637
- print(msg)
3638
- self.logger.warning(msg)
4022
+ train_psi_result = calculate_psi(part1[self.TARGET_NAME], part2[self.TARGET_NAME])
4023
+ if isinstance(train_psi_result, Exception):
4024
+ self.logger.exception("Failed to calculate train PSI", train_psi_result)
4025
+ elif train_psi_result > 0.2:
4026
+ self.__log_warning(self.bundle.get("train_unstable_target").format(train_psi_result))
3639
4027
 
3640
4028
  # 2. Check train-test PSI
3641
4029
  if eval1 is not None:
3642
- train_test_psi = calculate_psi(train[self.TARGET_NAME], eval1[self.TARGET_NAME])
3643
- if train_test_psi > 0.2:
3644
- self.warning_counter.increment()
3645
- msg = self.bundle.get("eval_unstable_target").format(train_test_psi)
3646
- print(msg)
3647
- self.logger.warning(msg)
4030
+ train_test_psi_result = calculate_psi(train[self.TARGET_NAME], eval1[self.TARGET_NAME])
4031
+ if isinstance(train_test_psi_result, Exception):
4032
+ self.logger.exception("Failed to calculate test PSI", train_test_psi_result)
4033
+ elif train_test_psi_result > 0.2:
4034
+ self.__log_warning(self.bundle.get("eval_unstable_target").format(train_test_psi_result))
3648
4035
 
3649
4036
  def _dump_python_libs(self):
3650
4037
  try:
@@ -3666,8 +4053,8 @@ class FeaturesEnricher(TransformerMixin):
3666
4053
  self.logger.warning(f"Showing support link: {link_text}")
3667
4054
  display(
3668
4055
  HTML(
3669
- f"""<br/>{link_text} <a href='{support_link}' target='_blank' rel='noopener noreferrer'>
3670
- here</a>"""
4056
+ f"""{link_text} <a href='{support_link}' target='_blank' rel='noopener noreferrer'>
4057
+ here</a><br/>"""
3671
4058
  )
3672
4059
  )
3673
4060
  except (ImportError, NameError):
@@ -3712,7 +4099,7 @@ class FeaturesEnricher(TransformerMixin):
3712
4099
  if y is not None:
3713
4100
  with open(f"{tmp_dir}/y.pickle", "wb") as y_file:
3714
4101
  pickle.dump(sample(y, xy_sample_index), y_file)
3715
- if eval_set:
4102
+ if eval_set and _num_samples(eval_set[0][0]) > 0:
3716
4103
  eval_xy_sample_index = rnd.randint(0, _num_samples(eval_set[0][0]), size=1000)
3717
4104
  with open(f"{tmp_dir}/eval_x.pickle", "wb") as eval_x_file:
3718
4105
  pickle.dump(sample(eval_set[0][0], eval_xy_sample_index), eval_x_file)
@@ -3803,6 +4190,8 @@ def hash_input(X: pd.DataFrame, y: Optional[pd.Series] = None, eval_set: Optiona
3803
4190
  if y is not None:
3804
4191
  hashed_objects.append(pd.util.hash_pandas_object(y, index=False).values)
3805
4192
  if eval_set is not None:
4193
+ if isinstance(eval_set, tuple):
4194
+ eval_set = [eval_set]
3806
4195
  for eval_X, eval_y in eval_set:
3807
4196
  hashed_objects.append(pd.util.hash_pandas_object(eval_X, index=False).values)
3808
4197
  hashed_objects.append(pd.util.hash_pandas_object(eval_y, index=False).values)