upgini 1.1.280a3418.post2__py3-none-any.whl → 1.2.31a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of upgini might be problematic. More details are available on the original registry page.

Files changed (43)
  1. upgini/__about__.py +1 -1
  2. upgini/__init__.py +4 -20
  3. upgini/autofe/all_operands.py +39 -10
  4. upgini/autofe/binary.py +148 -45
  5. upgini/autofe/date.py +197 -26
  6. upgini/autofe/feature.py +102 -19
  7. upgini/autofe/groupby.py +22 -22
  8. upgini/autofe/operand.py +9 -6
  9. upgini/autofe/unary.py +78 -54
  10. upgini/autofe/vector.py +8 -8
  11. upgini/data_source/data_source_publisher.py +128 -5
  12. upgini/dataset.py +50 -386
  13. upgini/features_enricher.py +936 -541
  14. upgini/http.py +27 -16
  15. upgini/lazy_import.py +35 -0
  16. upgini/metadata.py +84 -59
  17. upgini/metrics.py +164 -34
  18. upgini/normalizer/normalize_utils.py +197 -0
  19. upgini/resource_bundle/strings.properties +66 -51
  20. upgini/search_task.py +10 -4
  21. upgini/utils/Roboto-Regular.ttf +0 -0
  22. upgini/utils/base_search_key_detector.py +14 -12
  23. upgini/utils/country_utils.py +16 -0
  24. upgini/utils/custom_loss_utils.py +39 -36
  25. upgini/utils/datetime_utils.py +98 -45
  26. upgini/utils/deduplicate_utils.py +135 -112
  27. upgini/utils/display_utils.py +46 -15
  28. upgini/utils/email_utils.py +54 -16
  29. upgini/utils/feature_info.py +172 -0
  30. upgini/utils/features_validator.py +34 -20
  31. upgini/utils/ip_utils.py +100 -1
  32. upgini/utils/phone_utils.py +343 -0
  33. upgini/utils/postal_code_utils.py +34 -0
  34. upgini/utils/sklearn_ext.py +28 -19
  35. upgini/utils/target_utils.py +113 -57
  36. upgini/utils/warning_counter.py +1 -0
  37. upgini/version_validator.py +8 -4
  38. {upgini-1.1.280a3418.post2.dist-info → upgini-1.2.31a1.dist-info}/METADATA +31 -16
  39. upgini-1.2.31a1.dist-info/RECORD +65 -0
  40. upgini/normalizer/phone_normalizer.py +0 -340
  41. upgini-1.1.280a3418.post2.dist-info/RECORD +0 -62
  42. {upgini-1.1.280a3418.post2.dist-info → upgini-1.2.31a1.dist-info}/WHEEL +0 -0
  43. {upgini-1.1.280a3418.post2.dist-info → upgini-1.2.31a1.dist-info}/licenses/LICENSE +0 -0
@@ -11,6 +11,7 @@ import sys
11
11
  import tempfile
12
12
  import time
13
13
  import uuid
14
+ from collections import Counter
14
15
  from dataclasses import dataclass
15
16
  from threading import Thread
16
17
  from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
@@ -22,7 +23,6 @@ from pandas.api.types import (
22
23
  is_datetime64_any_dtype,
23
24
  is_numeric_dtype,
24
25
  is_object_dtype,
25
- is_period_dtype,
26
26
  is_string_dtype,
27
27
  )
28
28
  from scipy.stats import ks_2samp
@@ -45,24 +45,31 @@ from upgini.mdc import MDC
45
45
  from upgini.metadata import (
46
46
  COUNTRY,
47
47
  DEFAULT_INDEX,
48
+ ENTITY_SYSTEM_RECORD_ID,
48
49
  EVAL_SET_INDEX,
49
50
  ORIGINAL_INDEX,
50
51
  RENAMED_INDEX,
52
+ SEARCH_KEY_UNNEST,
51
53
  SORT_ID,
52
54
  SYSTEM_RECORD_ID,
53
55
  TARGET,
54
56
  CVType,
57
+ FeaturesMetadataV2,
55
58
  FileColumnMeaningType,
56
59
  ModelTaskType,
57
60
  RuntimeParameters,
58
61
  SearchKey,
59
62
  )
60
63
  from upgini.metrics import EstimatorWrapper, validate_scoring_argument
64
+ from upgini.normalizer.normalize_utils import Normalizer
61
65
  from upgini.resource_bundle import ResourceBundle, bundle, get_custom_bundle
62
66
  from upgini.search_task import SearchTask
63
67
  from upgini.spinner import Spinner
64
68
  from upgini.utils import combine_search_keys, find_numbers_with_decimal_comma
65
- from upgini.utils.country_utils import CountrySearchKeyDetector
69
+ from upgini.utils.country_utils import (
70
+ CountrySearchKeyConverter,
71
+ CountrySearchKeyDetector,
72
+ )
66
73
  from upgini.utils.custom_loss_utils import (
67
74
  get_additional_params_custom_loss,
68
75
  get_runtime_params_custom_loss,
@@ -71,8 +78,8 @@ from upgini.utils.cv_utils import CVConfig, get_groups
71
78
  from upgini.utils.datetime_utils import (
72
79
  DateTimeSearchKeyConverter,
73
80
  is_blocked_time_series,
81
+ is_dates_distribution_valid,
74
82
  is_time_series,
75
- validate_dates_distribution,
76
83
  )
77
84
  from upgini.utils.deduplicate_utils import (
78
85
  clean_full_duplicates,
@@ -84,12 +91,20 @@ from upgini.utils.display_utils import (
84
91
  prepare_and_show_report,
85
92
  show_request_quote_button,
86
93
  )
87
- from upgini.utils.email_utils import EmailSearchKeyConverter, EmailSearchKeyDetector
94
+ from upgini.utils.email_utils import (
95
+ EmailDomainGenerator,
96
+ EmailSearchKeyConverter,
97
+ EmailSearchKeyDetector,
98
+ )
99
+ from upgini.utils.feature_info import FeatureInfo, _round_shap_value
88
100
  from upgini.utils.features_validator import FeaturesValidator
89
101
  from upgini.utils.format import Format
90
- from upgini.utils.ip_utils import IpToCountrySearchKeyConverter
91
- from upgini.utils.phone_utils import PhoneSearchKeyDetector
92
- from upgini.utils.postal_code_utils import PostalCodeSearchKeyDetector
102
+ from upgini.utils.ip_utils import IpSearchKeyConverter
103
+ from upgini.utils.phone_utils import PhoneSearchKeyConverter, PhoneSearchKeyDetector
104
+ from upgini.utils.postal_code_utils import (
105
+ PostalCodeSearchKeyConverter,
106
+ PostalCodeSearchKeyDetector,
107
+ )
93
108
 
94
109
  try:
95
110
  from upgini.utils.progress_bar import CustomProgressBar as ProgressBar
@@ -145,6 +160,10 @@ class FeaturesEnricher(TransformerMixin):
145
160
 
146
161
  shared_datasets: list of str, optional (default=None)
147
162
  List of private shared dataset ids for custom search
163
+
164
+ select_features: bool, optional (default=False)
165
+ If True, return only selected features both from input and data sources.
166
+ Otherwise, return all features from input and only selected features from data sources.
148
167
  """
149
168
 
150
169
  TARGET_NAME = "target"
@@ -211,11 +230,12 @@ class FeaturesEnricher(TransformerMixin):
211
230
  client_visitorid: Optional[str] = None,
212
231
  custom_bundle_config: Optional[str] = None,
213
232
  add_date_if_missing: bool = True,
233
+ select_features: bool = False,
214
234
  **kwargs,
215
235
  ):
216
236
  self.bundle = get_custom_bundle(custom_bundle_config)
217
237
  self._api_key = api_key or os.environ.get(UPGINI_API_KEY)
218
- if api_key is not None and not isinstance(api_key, str):
238
+ if self._api_key is not None and not isinstance(self._api_key, str):
219
239
  raise ValidationError(f"api_key should be `string`, but passed: `{api_key}`")
220
240
  self.rest_client = get_rest_client(endpoint, self._api_key, client_ip, client_visitorid)
221
241
  self.client_ip = client_ip
@@ -235,6 +255,7 @@ class FeaturesEnricher(TransformerMixin):
235
255
 
236
256
  self.passed_features: List[str] = []
237
257
  self.df_with_original_index: Optional[pd.DataFrame] = None
258
+ self.fit_columns_renaming: Optional[Dict[str, str]] = None
238
259
  self.country_added = False
239
260
  self.fit_generated_features: List[str] = []
240
261
  self.fit_dropped_features: Set[str] = set()
@@ -245,10 +266,12 @@ class FeaturesEnricher(TransformerMixin):
245
266
  self.eval_set: Optional[List[Tuple]] = None
246
267
  self.autodetected_search_keys: Dict[str, SearchKey] = {}
247
268
  self.imbalanced = False
248
- self.__cached_sampled_datasets: Optional[Tuple[pd.DataFrame, pd.DataFrame, pd.Series, Dict, Dict]] = None
269
+ self.__cached_sampled_datasets: Dict[str, Tuple[pd.DataFrame, pd.DataFrame, pd.Series, Dict, Dict, Dict]] = (
270
+ dict()
271
+ )
249
272
 
250
- validate_version(self.logger)
251
- self.search_keys = search_keys or dict()
273
+ validate_version(self.logger, self.__log_warning)
274
+ self.search_keys = search_keys or {}
252
275
  self.country_code = country_code
253
276
  self.__validate_search_keys(search_keys, search_id)
254
277
  self.model_task_type = model_task_type
@@ -261,8 +284,11 @@ class FeaturesEnricher(TransformerMixin):
261
284
  self._relevant_data_sources_wo_links: pd.DataFrame = self.EMPTY_DATA_SOURCES
262
285
  self.metrics: Optional[pd.DataFrame] = None
263
286
  self.feature_names_ = []
287
+ self.dropped_client_feature_names_ = []
264
288
  self.feature_importances_ = []
265
289
  self.search_id = search_id
290
+ self.select_features = select_features
291
+
266
292
  if search_id:
267
293
  search_task = SearchTask(search_id, rest_client=self.rest_client, logger=self.logger)
268
294
 
@@ -322,6 +348,10 @@ class FeaturesEnricher(TransformerMixin):
322
348
  self.exclude_columns = exclude_columns
323
349
  self.baseline_score_column = baseline_score_column
324
350
  self.add_date_if_missing = add_date_if_missing
351
+ self.features_info_display_handle = None
352
+ self.data_sources_display_handle = None
353
+ self.autofe_features_display_handle = None
354
+ self.report_button_handle = None
325
355
 
326
356
  def _get_api_key(self):
327
357
  return self._api_key
@@ -423,7 +453,7 @@ class FeaturesEnricher(TransformerMixin):
423
453
 
424
454
  self.logger.info("Start fit")
425
455
 
426
- self.__validate_search_keys(self.search_keys, self.search_id)
456
+ self.__validate_search_keys(self.search_keys)
427
457
 
428
458
  # Validate client estimator params
429
459
  self._get_client_cat_features(estimator, X, self.search_keys)
@@ -557,7 +587,7 @@ class FeaturesEnricher(TransformerMixin):
557
587
 
558
588
  self.logger.info("Start fit_transform")
559
589
 
560
- self.__validate_search_keys(self.search_keys, self.search_id)
590
+ self.__validate_search_keys(self.search_keys)
561
591
 
562
592
  search_progress = SearchProgress(0.0, ProgressStage.START_FIT)
563
593
  if progress_callback is not None:
@@ -704,7 +734,7 @@ class FeaturesEnricher(TransformerMixin):
704
734
 
705
735
  start_time = time.time()
706
736
  try:
707
- result = self.__inner_transform(
737
+ result, _, _ = self.__inner_transform(
708
738
  trace_id,
709
739
  X,
710
740
  exclude_features_sources=exclude_features_sources,
@@ -831,17 +861,44 @@ class FeaturesEnricher(TransformerMixin):
831
861
  self.logger.warning(msg)
832
862
  print(msg)
833
863
 
864
+ if X is not None and y is None:
865
+ raise ValidationError("X passed without y")
866
+
834
867
  self.__validate_search_keys(self.search_keys, self.search_id)
835
868
  effective_X = X if X is not None else self.X
836
869
  effective_y = y if y is not None else self.y
837
870
  effective_eval_set = eval_set if eval_set is not None else self.eval_set
838
871
  effective_eval_set = self._check_eval_set(effective_eval_set, effective_X, self.bundle)
839
872
 
873
+ if (
874
+ self._search_task is None
875
+ or self._search_task.provider_metadata_v2 is None
876
+ or len(self._search_task.provider_metadata_v2) == 0
877
+ or effective_X is None
878
+ or effective_y is None
879
+ ):
880
+ raise ValidationError(self.bundle.get("metrics_unfitted_enricher"))
881
+
882
+ validated_X = self._validate_X(effective_X)
883
+ validated_y = self._validate_y(validated_X, effective_y)
884
+ validated_eval_set = (
885
+ [self._validate_eval_set_pair(validated_X, eval_pair) for eval_pair in effective_eval_set]
886
+ if effective_eval_set is not None
887
+ else None
888
+ )
889
+
890
+ if self.X is None:
891
+ self.X = X
892
+ if self.y is None:
893
+ self.y = y
894
+ if self.eval_set is None:
895
+ self.eval_set = effective_eval_set
896
+
840
897
  try:
841
898
  self.__log_debug_information(
842
- effective_X,
843
- effective_y,
844
- effective_eval_set,
899
+ validated_X,
900
+ validated_y,
901
+ validated_eval_set,
845
902
  exclude_features_sources=exclude_features_sources,
846
903
  cv=cv if cv is not None else self.cv,
847
904
  importance_threshold=importance_threshold,
@@ -851,21 +908,9 @@ class FeaturesEnricher(TransformerMixin):
851
908
  remove_outliers_calc_metrics=remove_outliers_calc_metrics,
852
909
  )
853
910
 
854
- if (
855
- self._search_task is None
856
- or self._search_task.provider_metadata_v2 is None
857
- or len(self._search_task.provider_metadata_v2) == 0
858
- or effective_X is None
859
- or effective_y is None
860
- ):
861
- raise ValidationError(self.bundle.get("metrics_unfitted_enricher"))
862
-
863
- if X is not None and y is None:
864
- raise ValidationError("X passed without y")
865
-
866
911
  validate_scoring_argument(scoring)
867
912
 
868
- self._validate_baseline_score(effective_X, effective_eval_set)
913
+ self._validate_baseline_score(validated_X, validated_eval_set)
869
914
 
870
915
  if self._has_paid_features(exclude_features_sources):
871
916
  msg = self.bundle.get("metrics_with_paid_features")
@@ -874,14 +919,14 @@ class FeaturesEnricher(TransformerMixin):
874
919
  return None
875
920
 
876
921
  cat_features, search_keys_for_metrics = self._get_client_cat_features(
877
- estimator, effective_X, self.search_keys
922
+ estimator, validated_X, self.search_keys
878
923
  )
879
924
 
880
925
  prepared_data = self._prepare_data_for_metrics(
881
926
  trace_id=trace_id,
882
- X=effective_X,
883
- y=effective_y,
884
- eval_set=effective_eval_set,
927
+ X=X,
928
+ y=y,
929
+ eval_set=eval_set,
885
930
  exclude_features_sources=exclude_features_sources,
886
931
  importance_threshold=importance_threshold,
887
932
  max_features=max_features,
@@ -904,21 +949,27 @@ class FeaturesEnricher(TransformerMixin):
904
949
  search_keys,
905
950
  groups,
906
951
  _cv,
952
+ columns_renaming,
907
953
  ) = prepared_data
908
954
 
955
+ # rename cat_features
956
+ if cat_features:
957
+ for new_c, old_c in columns_renaming.items():
958
+ if old_c in cat_features:
959
+ cat_features.remove(old_c)
960
+ cat_features.append(new_c)
961
+
909
962
  gc.collect()
910
963
 
964
+ if fitting_X.shape[1] == 0 and fitting_enriched_X.shape[1] == 0:
965
+ self.__log_warning(self.bundle.get("metrics_no_important_free_features"))
966
+ return None
967
+
911
968
  print(self.bundle.get("metrics_start"))
912
969
  with Spinner():
913
- if fitting_X.shape[1] == 0 and fitting_enriched_X.shape[1] == 0:
914
- print(self.bundle.get("metrics_no_important_free_features"))
915
- self.logger.warning("No client or free relevant ADS features found to calculate metrics")
916
- self.warning_counter.increment()
917
- return None
918
-
919
970
  self._check_train_and_eval_target_distribution(y_sorted, fitting_eval_set_dict)
920
971
 
921
- has_date = self._get_date_column(search_keys) is not None
972
+ has_date = SearchKey.find_key(search_keys, [SearchKey.DATE, SearchKey.DATETIME]) is not None
922
973
  model_task_type = self.model_task_type or define_task(y_sorted, has_date, self.logger, silent=True)
923
974
 
924
975
  wrapper = EstimatorWrapper.create(
@@ -930,11 +981,12 @@ class FeaturesEnricher(TransformerMixin):
930
981
  scoring,
931
982
  groups=groups,
932
983
  text_features=self.generate_features,
984
+ has_date=has_date,
933
985
  )
934
986
  metric = wrapper.metric_name
935
987
  multiplier = wrapper.multiplier
936
988
 
937
- # 1 If client features are presented - fit and predict with KFold CatBoost model
989
+ # 1 If client features are presented - fit and predict with KFold estimator
938
990
  # on etalon features and calculate baseline metric
939
991
  etalon_metric = None
940
992
  baseline_estimator = None
@@ -956,14 +1008,24 @@ class FeaturesEnricher(TransformerMixin):
956
1008
  add_params=custom_loss_add_params,
957
1009
  groups=groups,
958
1010
  text_features=self.generate_features,
1011
+ has_date=has_date,
959
1012
  )
960
- etalon_metric = baseline_estimator.cross_val_predict(
1013
+ etalon_cv_result = baseline_estimator.cross_val_predict(
961
1014
  fitting_X, y_sorted, self.baseline_score_column
962
1015
  )
963
- self.logger.info(f"Baseline {metric} on train client features: {etalon_metric}")
1016
+ etalon_metric = etalon_cv_result.get_display_metric()
1017
+ if etalon_metric is None:
1018
+ self.logger.info(
1019
+ f"Baseline {metric} on train client features is None (maybe all features was removed)"
1020
+ )
1021
+ baseline_estimator = None
1022
+ else:
1023
+ self.logger.info(f"Baseline {metric} on train client features: {etalon_metric}")
964
1024
 
965
- # 2 Fit and predict with KFold Catboost model on enriched tds
1025
+ # 2 Fit and predict with KFold estimator on enriched tds
966
1026
  # and calculate final metric (and uplift)
1027
+ enriched_metric = None
1028
+ uplift = None
967
1029
  enriched_estimator = None
968
1030
  if set(fitting_X.columns) != set(fitting_enriched_X.columns):
969
1031
  self.logger.info(
@@ -981,16 +1043,24 @@ class FeaturesEnricher(TransformerMixin):
981
1043
  add_params=custom_loss_add_params,
982
1044
  groups=groups,
983
1045
  text_features=self.generate_features,
1046
+ has_date=has_date,
984
1047
  )
985
- enriched_metric = enriched_estimator.cross_val_predict(fitting_enriched_X, enriched_y_sorted)
986
- self.logger.info(f"Enriched {metric} on train combined features: {enriched_metric}")
987
- if etalon_metric is not None:
988
- uplift = (enriched_metric - etalon_metric) * multiplier
1048
+ enriched_cv_result = enriched_estimator.cross_val_predict(fitting_enriched_X, enriched_y_sorted)
1049
+ enriched_metric = enriched_cv_result.get_display_metric()
1050
+ enriched_shaps = enriched_cv_result.shap_values
1051
+
1052
+ if enriched_shaps is not None:
1053
+ self._update_shap_values(trace_id, validated_X.columns.to_list(), enriched_shaps)
1054
+
1055
+ if enriched_metric is None:
1056
+ self.logger.warning(
1057
+ f"Enriched {metric} on train combined features is None (maybe all features was removed)"
1058
+ )
1059
+ enriched_estimator = None
989
1060
  else:
990
- uplift = None
991
- else:
992
- enriched_metric = None
993
- uplift = None
1061
+ self.logger.info(f"Enriched {metric} on train combined features: {enriched_metric}")
1062
+ if etalon_metric is not None and enriched_metric is not None:
1063
+ uplift = (enriched_cv_result.metric - etalon_cv_result.metric) * multiplier
994
1064
 
995
1065
  train_metrics = {
996
1066
  self.bundle.get("quality_metrics_segment_header"): self.bundle.get(
@@ -999,10 +1069,10 @@ class FeaturesEnricher(TransformerMixin):
999
1069
  self.bundle.get("quality_metrics_rows_header"): _num_samples(effective_X),
1000
1070
  }
1001
1071
  if model_task_type in [ModelTaskType.BINARY, ModelTaskType.REGRESSION] and is_numeric_dtype(
1002
- y_sorted
1072
+ validated_y
1003
1073
  ):
1004
1074
  train_metrics[self.bundle.get("quality_metrics_mean_target_header")] = round(
1005
- np.mean(effective_y), 4
1075
+ np.mean(validated_y), 4
1006
1076
  )
1007
1077
  if etalon_metric is not None:
1008
1078
  train_metrics[self.bundle.get("quality_metrics_baseline_header").format(metric)] = etalon_metric
@@ -1033,9 +1103,10 @@ class FeaturesEnricher(TransformerMixin):
1033
1103
  f"Calculate baseline {metric} on eval set {idx + 1} "
1034
1104
  f"on client features: {eval_X_sorted.columns.to_list()}"
1035
1105
  )
1036
- etalon_eval_metric = baseline_estimator.calculate_metric(
1106
+ etalon_eval_results = baseline_estimator.calculate_metric(
1037
1107
  eval_X_sorted, eval_y_sorted, self.baseline_score_column
1038
1108
  )
1109
+ etalon_eval_metric = etalon_eval_results.get_display_metric()
1039
1110
  self.logger.info(
1040
1111
  f"Baseline {metric} on eval set {idx + 1} client features: {etalon_eval_metric}"
1041
1112
  )
@@ -1047,9 +1118,10 @@ class FeaturesEnricher(TransformerMixin):
1047
1118
  f"Calculate enriched {metric} on eval set {idx + 1} "
1048
1119
  f"on combined features: {enriched_eval_X_sorted.columns.to_list()}"
1049
1120
  )
1050
- enriched_eval_metric = enriched_estimator.calculate_metric(
1121
+ enriched_eval_results = enriched_estimator.calculate_metric(
1051
1122
  enriched_eval_X_sorted, enriched_eval_y_sorted
1052
1123
  )
1124
+ enriched_eval_metric = enriched_eval_results.get_display_metric()
1053
1125
  self.logger.info(
1054
1126
  f"Enriched {metric} on eval set {idx + 1} combined features: {enriched_eval_metric}"
1055
1127
  )
@@ -1057,11 +1129,11 @@ class FeaturesEnricher(TransformerMixin):
1057
1129
  enriched_eval_metric = None
1058
1130
 
1059
1131
  if etalon_eval_metric is not None and enriched_eval_metric is not None:
1060
- eval_uplift = (enriched_eval_metric - etalon_eval_metric) * multiplier
1132
+ eval_uplift = (enriched_eval_results.metric - etalon_eval_results.metric) * multiplier
1061
1133
  else:
1062
1134
  eval_uplift = None
1063
1135
 
1064
- effective_eval_set = eval_set if eval_set is not None else self.eval_set
1136
+ # effective_eval_set = eval_set if eval_set is not None else self.eval_set
1065
1137
  eval_metrics = {
1066
1138
  self.bundle.get("quality_metrics_segment_header"): self.bundle.get(
1067
1139
  "quality_metrics_eval_segment"
@@ -1072,10 +1144,10 @@ class FeaturesEnricher(TransformerMixin):
1072
1144
  # self.bundle.get("quality_metrics_match_rate_header"): eval_hit_rate,
1073
1145
  }
1074
1146
  if model_task_type in [ModelTaskType.BINARY, ModelTaskType.REGRESSION] and is_numeric_dtype(
1075
- eval_y_sorted
1147
+ validated_eval_set[idx][1]
1076
1148
  ):
1077
1149
  eval_metrics[self.bundle.get("quality_metrics_mean_target_header")] = round(
1078
- np.mean(effective_eval_set[idx][1]), 4
1150
+ np.mean(validated_eval_set[idx][1]), 4
1079
1151
  )
1080
1152
  if etalon_eval_metric is not None:
1081
1153
  eval_metrics[self.bundle.get("quality_metrics_baseline_header").format(metric)] = (
@@ -1099,7 +1171,7 @@ class FeaturesEnricher(TransformerMixin):
1099
1171
  )
1100
1172
 
1101
1173
  uplift_col = self.bundle.get("quality_metrics_uplift_header")
1102
- date_column = self._get_date_column(search_keys)
1174
+ date_column = SearchKey.find_key(search_keys, [SearchKey.DATE, SearchKey.DATETIME])
1103
1175
  if (
1104
1176
  uplift_col in metrics_df.columns
1105
1177
  and (metrics_df[uplift_col] < 0).any()
@@ -1138,6 +1210,57 @@ class FeaturesEnricher(TransformerMixin):
1138
1210
  finally:
1139
1211
  self.logger.info(f"Calculating metrics elapsed time: {time.time() - start_time}")
1140
1212
 
1213
+ def _update_shap_values(self, trace_id: str, x_columns: List[str], new_shaps: Dict[str, float]):
1214
+ new_shaps = {
1215
+ feature: _round_shap_value(shap) for feature, shap in new_shaps.items() if feature in self.feature_names_
1216
+ }
1217
+ self.__prepare_feature_importances(trace_id, x_columns, new_shaps, silent=True)
1218
+
1219
+ if self.features_info_display_handle is not None:
1220
+ try:
1221
+ _ = get_ipython() # type: ignore
1222
+
1223
+ display_html_dataframe(
1224
+ self.features_info,
1225
+ self._features_info_without_links,
1226
+ self.bundle.get("relevant_features_header"),
1227
+ display_handle=self.features_info_display_handle,
1228
+ )
1229
+ except (ImportError, NameError):
1230
+ pass
1231
+ if self.data_sources_display_handle is not None:
1232
+ try:
1233
+ _ = get_ipython() # type: ignore
1234
+
1235
+ display_html_dataframe(
1236
+ self.relevant_data_sources,
1237
+ self._relevant_data_sources_wo_links,
1238
+ self.bundle.get("relevant_data_sources_header"),
1239
+ display_handle=self.data_sources_display_handle,
1240
+ )
1241
+ except (ImportError, NameError):
1242
+ pass
1243
+ if self.autofe_features_display_handle is not None:
1244
+ try:
1245
+ _ = get_ipython() # type: ignore
1246
+ autofe_descriptions_df = self.get_autofe_features_description()
1247
+ if autofe_descriptions_df is not None:
1248
+ display_html_dataframe(
1249
+ df=autofe_descriptions_df,
1250
+ internal_df=autofe_descriptions_df,
1251
+ header=self.bundle.get("autofe_descriptions_header"),
1252
+ display_handle=self.autofe_features_display_handle,
1253
+ )
1254
+ except (ImportError, NameError):
1255
+ pass
1256
+ if self.report_button_handle is not None:
1257
+ try:
1258
+ _ = get_ipython() # type: ignore
1259
+
1260
+ self.__show_report_button(display_handle=self.report_button_handle)
1261
+ except (ImportError, NameError):
1262
+ pass
1263
+
1141
1264
  def _check_train_and_eval_target_distribution(self, y, eval_set_dict):
1142
1265
  uneven_distribution = False
1143
1266
  for eval_set in eval_set_dict.values():
@@ -1174,34 +1297,6 @@ class FeaturesEnricher(TransformerMixin):
1174
1297
  def _has_paid_features(self, exclude_features_sources: Optional[List[str]]) -> bool:
1175
1298
  return self._has_features_with_commercial_schema(CommercialSchema.PAID.value, exclude_features_sources)
1176
1299
 
1177
- def _extend_x(self, x: pd.DataFrame, is_demo_dataset: bool) -> Tuple[pd.DataFrame, Dict[str, SearchKey]]:
1178
- search_keys = self.search_keys.copy()
1179
- search_keys = self.__prepare_search_keys(x, search_keys, is_demo_dataset, is_transform=True, silent_mode=True)
1180
-
1181
- extended_X = x.copy()
1182
- generated_features = []
1183
- date_column = self._get_date_column(search_keys)
1184
- if date_column is not None:
1185
- converter = DateTimeSearchKeyConverter(date_column, self.date_format, self.logger, self.bundle)
1186
- extended_X = converter.convert(extended_X, keep_time=True)
1187
- generated_features.extend(converter.generated_features)
1188
- email_column = self._get_email_column(search_keys)
1189
- hem_column = self._get_hem_column(search_keys)
1190
- if email_column:
1191
- converter = EmailSearchKeyConverter(email_column, hem_column, search_keys, self.logger)
1192
- extended_X = converter.convert(extended_X)
1193
- generated_features.extend(converter.generated_features)
1194
- if (
1195
- self.detect_missing_search_keys
1196
- and list(search_keys.values()) == [SearchKey.DATE]
1197
- and self.country_code is None
1198
- ):
1199
- converter = IpToCountrySearchKeyConverter(search_keys, self.logger)
1200
- extended_X = converter.convert(extended_X)
1201
- generated_features = [f for f in generated_features if f in self.fit_generated_features]
1202
-
1203
- return extended_X, search_keys
1204
-
1205
1300
  def _is_input_same_as_fit(
1206
1301
  self,
1207
1302
  X: Union[pd.DataFrame, pd.Series, np.ndarray, None] = None,
@@ -1245,7 +1340,7 @@ class FeaturesEnricher(TransformerMixin):
1245
1340
  groups = None
1246
1341
 
1247
1342
  if not isinstance(_cv, BaseCrossValidator):
1248
- date_column = self._get_date_column(search_keys)
1343
+ date_column = SearchKey.find_key(search_keys, [SearchKey.DATE, SearchKey.DATETIME])
1249
1344
  date_series = X[date_column] if date_column is not None else None
1250
1345
  _cv, groups = CVConfig(
1251
1346
  _cv, date_series, self.random_state, self._search_task.get_shuffle_kfold(), group_columns=group_columns
@@ -1268,7 +1363,7 @@ class FeaturesEnricher(TransformerMixin):
1268
1363
 
1269
1364
  def _get_client_cat_features(
1270
1365
  self, estimator: Optional[Any], X: pd.DataFrame, search_keys: Dict[str, SearchKey]
1271
- ) -> Optional[List[str]]:
1366
+ ) -> Tuple[Optional[List[str]], List[str]]:
1272
1367
  cat_features = None
1273
1368
  search_keys_for_metrics = []
1274
1369
  if (
@@ -1328,26 +1423,38 @@ class FeaturesEnricher(TransformerMixin):
1328
1423
  progress_bar,
1329
1424
  progress_callback,
1330
1425
  )
1331
- X_sampled, y_sampled, enriched_X, eval_set_sampled_dict, search_keys = dataclasses.astuple(sampled_data)
1426
+ X_sampled, y_sampled, enriched_X, eval_set_sampled_dict, search_keys, columns_renaming = dataclasses.astuple(
1427
+ sampled_data
1428
+ )
1332
1429
 
1333
1430
  excluding_search_keys = list(search_keys.keys())
1334
1431
  if search_keys_for_metrics is not None and len(search_keys_for_metrics) > 0:
1335
- excluding_search_keys = [sk for sk in excluding_search_keys if sk not in search_keys_for_metrics]
1432
+ for sk in excluding_search_keys:
1433
+ if columns_renaming.get(sk) in search_keys_for_metrics:
1434
+ excluding_search_keys.remove(sk)
1435
+
1336
1436
  client_features = [
1337
1437
  c
1338
1438
  for c in X_sampled.columns.to_list()
1339
- if c
1439
+ if (
1440
+ not self.select_features
1441
+ or c in self.feature_names_
1442
+ or (self.fit_columns_renaming is not None and self.fit_columns_renaming.get(c) in self.feature_names_)
1443
+ )
1444
+ and c
1340
1445
  not in (
1341
1446
  excluding_search_keys
1342
1447
  + list(self.fit_dropped_features)
1343
- + [DateTimeSearchKeyConverter.DATETIME_COL, SYSTEM_RECORD_ID]
1448
+ + [DateTimeSearchKeyConverter.DATETIME_COL, SYSTEM_RECORD_ID, ENTITY_SYSTEM_RECORD_ID]
1344
1449
  )
1345
1450
  ]
1451
+ self.logger.info(f"Client features column on prepare data for metrics: {client_features}")
1346
1452
 
1347
1453
  filtered_enriched_features = self.__filtered_enriched_features(
1348
1454
  importance_threshold,
1349
1455
  max_features,
1350
1456
  )
1457
+ filtered_enriched_features = [c for c in filtered_enriched_features if c not in client_features]
1351
1458
 
1352
1459
  X_sorted, y_sorted = self._sort_by_system_record_id(X_sampled, y_sampled, self.cv)
1353
1460
  enriched_X_sorted, enriched_y_sorted = self._sort_by_system_record_id(enriched_X, y_sampled, self.cv)
@@ -1377,9 +1484,12 @@ class FeaturesEnricher(TransformerMixin):
1377
1484
  fitting_X = fitting_X.drop(columns=constant_columns, errors="ignore")
1378
1485
  fitting_enriched_X = fitting_enriched_X.drop(columns=constant_columns, errors="ignore")
1379
1486
 
1487
+ # TODO maybe there is no more need for these convertions
1380
1488
  # Remove datetime features
1381
1489
  datetime_features = [
1382
- f for f in fitting_X.columns if is_datetime64_any_dtype(fitting_X[f]) or is_period_dtype(fitting_X[f])
1490
+ f
1491
+ for f in fitting_X.columns
1492
+ if is_datetime64_any_dtype(fitting_X[f]) or isinstance(fitting_X[f].dtype, pd.PeriodDtype)
1383
1493
  ]
1384
1494
  if len(datetime_features) > 0:
1385
1495
  self.logger.warning(self.bundle.get("dataset_date_features").format(datetime_features))
@@ -1399,37 +1509,25 @@ class FeaturesEnricher(TransformerMixin):
1399
1509
  if len(decimal_columns_to_fix) > 0:
1400
1510
  self.logger.warning(f"Convert strings with decimal comma to float: {decimal_columns_to_fix}")
1401
1511
  for col in decimal_columns_to_fix:
1402
- fitting_X[col] = fitting_X[col].astype("string").str.replace(",", ".").astype(np.float64)
1512
+ fitting_X[col] = fitting_X[col].astype("string").str.replace(",", ".", regex=False).astype(np.float64)
1403
1513
  fitting_enriched_X[col] = (
1404
- fitting_enriched_X[col].astype("string").str.replace(",", ".").astype(np.float64)
1514
+ fitting_enriched_X[col].astype("string").str.replace(",", ".", regex=False).astype(np.float64)
1405
1515
  )
1406
1516
 
1407
- fitting_eval_set_dict = dict()
1517
+ fitting_eval_set_dict = {}
1518
+ fitting_x_columns = fitting_X.columns.to_list()
1519
+ self.logger.info(f"Final list of fitting X columns: {fitting_x_columns}")
1520
+ fitting_enriched_x_columns = fitting_enriched_X.columns.to_list()
1521
+ self.logger.info(f"Final list of fitting enriched X columns: {fitting_enriched_x_columns}")
1408
1522
  for idx, eval_tuple in eval_set_sampled_dict.items():
1409
1523
  eval_X_sampled, enriched_eval_X, eval_y_sampled = eval_tuple
1410
1524
  eval_X_sorted, eval_y_sorted = self._sort_by_system_record_id(eval_X_sampled, eval_y_sampled, self.cv)
1411
1525
  enriched_eval_X_sorted, enriched_eval_y_sorted = self._sort_by_system_record_id(
1412
1526
  enriched_eval_X, eval_y_sampled, self.cv
1413
1527
  )
1414
- fitting_eval_X = eval_X_sorted[client_features].copy()
1415
- fitting_enriched_eval_X = enriched_eval_X_sorted[
1416
- client_features + existing_filtered_enriched_features
1417
- ].copy()
1418
-
1419
- # # Drop high cardinality features in eval set
1420
- if len(columns_with_high_cardinality) > 0:
1421
- fitting_eval_X = fitting_eval_X.drop(columns=columns_with_high_cardinality, errors="ignore")
1422
- fitting_enriched_eval_X = fitting_enriched_eval_X.drop(
1423
- columns=columns_with_high_cardinality, errors="ignore"
1424
- )
1425
- # Drop constant features in eval_set
1426
- if len(constant_columns) > 0:
1427
- fitting_eval_X = fitting_eval_X.drop(columns=constant_columns, errors="ignore")
1428
- fitting_enriched_eval_X = fitting_enriched_eval_X.drop(columns=constant_columns, errors="ignore")
1429
- # Drop datetime features in eval_set
1430
- if len(datetime_features) > 0:
1431
- fitting_eval_X = fitting_eval_X.drop(columns=datetime_features, errors="ignore")
1432
- fitting_enriched_eval_X = fitting_enriched_eval_X.drop(columns=datetime_features, errors="ignore")
1528
+ fitting_eval_X = eval_X_sorted[fitting_x_columns].copy()
1529
+ fitting_enriched_eval_X = enriched_eval_X_sorted[fitting_enriched_x_columns].copy()
1530
+
1433
1531
  # Convert bool to string in eval_set
1434
1532
  if len(bool_columns) > 0:
1435
1533
  fitting_eval_X[col] = fitting_eval_X[col].astype(str)
@@ -1437,9 +1535,14 @@ class FeaturesEnricher(TransformerMixin):
1437
1535
  # Correct string features with decimal commas
1438
1536
  if len(decimal_columns_to_fix) > 0:
1439
1537
  for col in decimal_columns_to_fix:
1440
- fitting_eval_X[col] = fitting_eval_X[col].astype("string").str.replace(",", ".").astype(np.float64)
1538
+ fitting_eval_X[col] = (
1539
+ fitting_eval_X[col].astype("string").str.replace(",", ".", regex=False).astype(np.float64)
1540
+ )
1441
1541
  fitting_enriched_eval_X[col] = (
1442
- fitting_enriched_eval_X[col].astype("string").str.replace(",", ".").astype(np.float64)
1542
+ fitting_enriched_eval_X[col]
1543
+ .astype("string")
1544
+ .str.replace(",", ".", regex=False)
1545
+ .astype(np.float64)
1443
1546
  )
1444
1547
 
1445
1548
  fitting_eval_set_dict[idx] = (
@@ -1459,6 +1562,7 @@ class FeaturesEnricher(TransformerMixin):
1459
1562
  search_keys,
1460
1563
  groups,
1461
1564
  cv,
1565
+ columns_renaming,
1462
1566
  )
1463
1567
 
1464
1568
  @dataclass
@@ -1468,6 +1572,7 @@ class FeaturesEnricher(TransformerMixin):
1468
1572
  enriched_X: pd.DataFrame
1469
1573
  eval_set_sampled_dict: Dict[int, Tuple[pd.DataFrame, pd.Series]]
1470
1574
  search_keys: Dict[str, SearchKey]
1575
+ columns_renaming: Dict[str, str]
1471
1576
 
1472
1577
  def _sample_data_for_metrics(
1473
1578
  self,
@@ -1482,18 +1587,28 @@ class FeaturesEnricher(TransformerMixin):
1482
1587
  progress_bar: Optional[ProgressBar],
1483
1588
  progress_callback: Optional[Callable[[SearchProgress], Any]],
1484
1589
  ) -> _SampledDataForMetrics:
1485
- if self.__cached_sampled_datasets is not None and is_input_same_as_fit and remove_outliers_calc_metrics is None:
1590
+ datasets_hash = hash_input(validated_X, validated_y, eval_set)
1591
+ cached_sampled_datasets = self.__cached_sampled_datasets.get(datasets_hash)
1592
+ if cached_sampled_datasets is not None and is_input_same_as_fit and remove_outliers_calc_metrics is None:
1486
1593
  self.logger.info("Cached enriched dataset found - use it")
1487
- return self.__get_sampled_cached_enriched(exclude_features_sources)
1594
+ return self.__get_sampled_cached_enriched(datasets_hash, exclude_features_sources)
1488
1595
  elif len(self.feature_importances_) == 0:
1489
1596
  self.logger.info("No external features selected. So use only input datasets for metrics calculation")
1490
1597
  return self.__sample_only_input(validated_X, validated_y, eval_set, is_demo_dataset)
1491
1598
  # TODO save and check if dataset was deduplicated - use imbalance branch for such case
1492
- elif not self.imbalanced and not exclude_features_sources and is_input_same_as_fit:
1599
+ elif (
1600
+ not self.imbalanced
1601
+ and not exclude_features_sources
1602
+ and is_input_same_as_fit
1603
+ and self.df_with_original_index is not None
1604
+ ):
1493
1605
  self.logger.info("Dataset is not imbalanced, so use enriched_X from fit")
1494
1606
  return self.__sample_balanced(eval_set, trace_id, remove_outliers_calc_metrics)
1495
1607
  else:
1496
- self.logger.info("Dataset is imbalanced or exclude_features_sources or X was passed. Run transform")
1608
+ self.logger.info(
1609
+ "Dataset is imbalanced or exclude_features_sources or X was passed or this is saved search."
1610
+ " Run transform"
1611
+ )
1497
1612
  print(self.bundle.get("prepare_data_for_metrics"))
1498
1613
  return self.__sample_imbalanced(
1499
1614
  validated_X,
@@ -1506,17 +1621,23 @@ class FeaturesEnricher(TransformerMixin):
1506
1621
  progress_callback,
1507
1622
  )
1508
1623
 
1509
- def __get_sampled_cached_enriched(self, exclude_features_sources: Optional[List[str]]) -> _SampledDataForMetrics:
1510
- X_sampled, y_sampled, enriched_X, eval_set_sampled_dict, search_keys = self.__cached_sampled_datasets
1624
+ def __get_sampled_cached_enriched(
1625
+ self, datasets_hash: str, exclude_features_sources: Optional[List[str]]
1626
+ ) -> _SampledDataForMetrics:
1627
+ X_sampled, y_sampled, enriched_X, eval_set_sampled_dict, search_keys, columns_renaming = (
1628
+ self.__cached_sampled_datasets[datasets_hash]
1629
+ )
1511
1630
  if exclude_features_sources:
1512
1631
  enriched_X = enriched_X.drop(columns=exclude_features_sources, errors="ignore")
1513
1632
 
1514
- return self.__mk_sampled_data_tuple(X_sampled, y_sampled, enriched_X, eval_set_sampled_dict, search_keys)
1633
+ return self.__mk_sampled_data_tuple(
1634
+ X_sampled, y_sampled, enriched_X, eval_set_sampled_dict, search_keys, columns_renaming
1635
+ )
1515
1636
 
1516
1637
  def __sample_only_input(
1517
1638
  self, validated_X: pd.DataFrame, validated_y: pd.Series, eval_set: Optional[List[tuple]], is_demo_dataset: bool
1518
1639
  ) -> _SampledDataForMetrics:
1519
- eval_set_sampled_dict = dict()
1640
+ eval_set_sampled_dict = {}
1520
1641
 
1521
1642
  df = validated_X.copy()
1522
1643
  df[TARGET] = validated_y
@@ -1529,7 +1650,31 @@ class FeaturesEnricher(TransformerMixin):
1529
1650
  eval_xy[EVAL_SET_INDEX] = idx + 1
1530
1651
  df = pd.concat([df, eval_xy])
1531
1652
 
1532
- df = clean_full_duplicates(df, logger=self.logger, silent=True, bundle=self.bundle)
1653
+ search_keys = self.search_keys.copy()
1654
+ search_keys = self.__prepare_search_keys(df, search_keys, is_demo_dataset, is_transform=True, silent_mode=True)
1655
+
1656
+ date_column = SearchKey.find_key(search_keys, [SearchKey.DATE, SearchKey.DATETIME])
1657
+ generated_features = []
1658
+ if date_column is not None:
1659
+ converter = DateTimeSearchKeyConverter(date_column, self.date_format, self.logger, self.bundle)
1660
+ # Leave original date column values
1661
+ df_with_date_features = converter.convert(df, keep_time=True)
1662
+ df_with_date_features[date_column] = df[date_column]
1663
+ df = df_with_date_features
1664
+ generated_features = converter.generated_features
1665
+
1666
+ email_columns = SearchKey.find_all_keys(search_keys, SearchKey.EMAIL)
1667
+ if email_columns:
1668
+ generator = EmailDomainGenerator(email_columns)
1669
+ df = generator.generate(df)
1670
+ generated_features.extend(generator.generated_features)
1671
+
1672
+ # normalizer = Normalizer(self.bundle, self.logger)
1673
+ # df, search_keys, generated_features = normalizer.normalize(df, search_keys, generated_features)
1674
+ # columns_renaming = normalizer.columns_renaming
1675
+ columns_renaming = {c: c for c in df.columns}
1676
+
1677
+ df, _ = clean_full_duplicates(df, logger=self.logger, bundle=self.bundle)
1533
1678
 
1534
1679
  num_samples = _num_samples(df)
1535
1680
  sample_threshold, sample_rows = (
@@ -1541,24 +1686,36 @@ class FeaturesEnricher(TransformerMixin):
1541
1686
  self.logger.info(f"Downsampling from {num_samples} to {sample_rows}")
1542
1687
  df = df.sample(n=sample_rows, random_state=self.random_state)
1543
1688
 
1544
- df_extended, search_keys = self._extend_x(df, is_demo_dataset)
1545
- df_extended = self.__add_fit_system_record_id(df_extended, dict(), search_keys)
1689
+ df = self.__add_fit_system_record_id(df, search_keys, SYSTEM_RECORD_ID)
1690
+ if DateTimeSearchKeyConverter.DATETIME_COL in df.columns:
1691
+ df = df.drop(columns=DateTimeSearchKeyConverter.DATETIME_COL)
1546
1692
 
1547
- train_df = df_extended.query(f"{EVAL_SET_INDEX} == 0") if eval_set is not None else df_extended
1693
+ train_df = df.query(f"{EVAL_SET_INDEX} == 0") if eval_set is not None else df
1548
1694
  X_sampled = train_df.drop(columns=[TARGET, EVAL_SET_INDEX], errors="ignore")
1549
1695
  y_sampled = train_df[TARGET].copy()
1550
1696
  enriched_X = X_sampled
1551
1697
 
1552
1698
  if eval_set is not None:
1553
1699
  for idx in range(len(eval_set)):
1554
- eval_xy_sampled = df_extended.query(f"{EVAL_SET_INDEX} == {idx + 1}")
1700
+ eval_xy_sampled = df.query(f"{EVAL_SET_INDEX} == {idx + 1}")
1555
1701
  eval_X_sampled = eval_xy_sampled.drop(columns=[TARGET, EVAL_SET_INDEX], errors="ignore")
1556
1702
  eval_y_sampled = eval_xy_sampled[TARGET].copy()
1557
1703
  enriched_eval_X = eval_X_sampled
1558
1704
  eval_set_sampled_dict[idx] = (eval_X_sampled, enriched_eval_X, eval_y_sampled)
1559
- self.__cached_sampled_datasets = (X_sampled, y_sampled, enriched_X, eval_set_sampled_dict, search_keys)
1560
1705
 
1561
- return self.__mk_sampled_data_tuple(X_sampled, y_sampled, enriched_X, eval_set_sampled_dict, search_keys)
1706
+ datasets_hash = hash_input(X_sampled, y_sampled, eval_set_sampled_dict)
1707
+ self.__cached_sampled_datasets[datasets_hash] = (
1708
+ X_sampled,
1709
+ y_sampled,
1710
+ enriched_X,
1711
+ eval_set_sampled_dict,
1712
+ search_keys,
1713
+ columns_renaming,
1714
+ )
1715
+
1716
+ return self.__mk_sampled_data_tuple(
1717
+ X_sampled, y_sampled, enriched_X, eval_set_sampled_dict, search_keys, columns_renaming
1718
+ )
1562
1719
 
1563
1720
  def __sample_balanced(
1564
1721
  self,
@@ -1566,22 +1723,21 @@ class FeaturesEnricher(TransformerMixin):
1566
1723
  trace_id: str,
1567
1724
  remove_outliers_calc_metrics: Optional[bool],
1568
1725
  ) -> _SampledDataForMetrics:
1569
- eval_set_sampled_dict = dict()
1726
+ eval_set_sampled_dict = {}
1570
1727
  search_keys = self.fit_search_keys
1571
1728
 
1572
1729
  rows_to_drop = None
1573
- has_date = self._get_date_column(search_keys) is not None
1574
- task_type = self.model_task_type or define_task(
1730
+ has_date = SearchKey.find_key(search_keys, [SearchKey.DATE, SearchKey.DATETIME]) is not None
1731
+ self.model_task_type = self.model_task_type or define_task(
1575
1732
  self.df_with_original_index[TARGET], has_date, self.logger, silent=True
1576
1733
  )
1577
- if task_type == ModelTaskType.REGRESSION:
1734
+ if self.model_task_type == ModelTaskType.REGRESSION:
1578
1735
  target_outliers_df = self._search_task.get_target_outliers(trace_id)
1579
1736
  if target_outliers_df is not None and len(target_outliers_df) > 0:
1580
1737
  outliers = pd.merge(
1581
1738
  self.df_with_original_index,
1582
1739
  target_outliers_df,
1583
- left_on=SYSTEM_RECORD_ID,
1584
- right_on=SYSTEM_RECORD_ID,
1740
+ on=ENTITY_SYSTEM_RECORD_ID,
1585
1741
  how="inner",
1586
1742
  )
1587
1743
  top_outliers = outliers.sort_values(by=TARGET, ascending=False)[TARGET].head(3)
@@ -1608,6 +1764,7 @@ class FeaturesEnricher(TransformerMixin):
1608
1764
  X_sampled = enriched_Xy[x_columns].copy()
1609
1765
  y_sampled = enriched_Xy[TARGET].copy()
1610
1766
  enriched_X = enriched_Xy.drop(columns=[TARGET, EVAL_SET_INDEX], errors="ignore")
1767
+ enriched_X_columns = enriched_X.columns.to_list()
1611
1768
 
1612
1769
  self.logger.info(f"Shape of enriched_X: {enriched_X.shape}")
1613
1770
  self.logger.info(f"Shape of X after sampling: {X_sampled.shape}")
@@ -1622,12 +1779,22 @@ class FeaturesEnricher(TransformerMixin):
1622
1779
  for idx in range(len(eval_set)):
1623
1780
  eval_X_sampled = enriched_eval_sets[idx + 1][x_columns].copy()
1624
1781
  eval_y_sampled = enriched_eval_sets[idx + 1][TARGET].copy()
1625
- enriched_eval_X = enriched_eval_sets[idx + 1].drop(columns=[TARGET, EVAL_SET_INDEX])
1782
+ enriched_eval_X = enriched_eval_sets[idx + 1][enriched_X_columns].copy()
1626
1783
  eval_set_sampled_dict[idx] = (eval_X_sampled, enriched_eval_X, eval_y_sampled)
1627
1784
 
1628
- self.__cached_sampled_datasets = (X_sampled, y_sampled, enriched_X, eval_set_sampled_dict, search_keys)
1785
+ datasets_hash = hash_input(self.X, self.y, self.eval_set)
1786
+ self.__cached_sampled_datasets[datasets_hash] = (
1787
+ X_sampled,
1788
+ y_sampled,
1789
+ enriched_X,
1790
+ eval_set_sampled_dict,
1791
+ search_keys,
1792
+ self.fit_columns_renaming,
1793
+ )
1629
1794
 
1630
- return self.__mk_sampled_data_tuple(X_sampled, y_sampled, enriched_X, eval_set_sampled_dict, search_keys)
1795
+ return self.__mk_sampled_data_tuple(
1796
+ X_sampled, y_sampled, enriched_X, eval_set_sampled_dict, search_keys, self.fit_columns_renaming
1797
+ )
1631
1798
 
1632
1799
  def __sample_imbalanced(
1633
1800
  self,
@@ -1640,7 +1807,7 @@ class FeaturesEnricher(TransformerMixin):
1640
1807
  progress_bar: Optional[ProgressBar],
1641
1808
  progress_callback: Optional[Callable[[SearchProgress], Any]],
1642
1809
  ) -> _SampledDataForMetrics:
1643
- eval_set_sampled_dict = dict()
1810
+ eval_set_sampled_dict = {}
1644
1811
  if eval_set is not None:
1645
1812
  self.logger.info("Transform with eval_set")
1646
1813
  # concatenate X and eval_set with eval_set_index
@@ -1654,7 +1821,7 @@ class FeaturesEnricher(TransformerMixin):
1654
1821
  eval_df_with_index[EVAL_SET_INDEX] = idx + 1
1655
1822
  df = pd.concat([df, eval_df_with_index])
1656
1823
 
1657
- df = clean_full_duplicates(df, logger=self.logger, silent=True, bundle=self.bundle)
1824
+ df, _ = clean_full_duplicates(df, logger=self.logger, bundle=self.bundle)
1658
1825
 
1659
1826
  # downsample if need to eval_set threshold
1660
1827
  num_samples = _num_samples(df)
@@ -1662,12 +1829,12 @@ class FeaturesEnricher(TransformerMixin):
1662
1829
  self.logger.info(f"Downsampling from {num_samples} to {Dataset.FIT_SAMPLE_WITH_EVAL_SET_ROWS}")
1663
1830
  df = df.sample(n=Dataset.FIT_SAMPLE_WITH_EVAL_SET_ROWS, random_state=self.random_state)
1664
1831
 
1665
- eval_set_sampled_dict = dict()
1832
+ eval_set_sampled_dict = {}
1666
1833
 
1667
1834
  tmp_target_name = "__target"
1668
1835
  df = df.rename(columns={TARGET: tmp_target_name})
1669
1836
 
1670
- enriched_df = self.__inner_transform(
1837
+ enriched_df, columns_renaming, generated_features = self.__inner_transform(
1671
1838
  trace_id,
1672
1839
  df,
1673
1840
  exclude_features_sources=exclude_features_sources,
@@ -1684,7 +1851,7 @@ class FeaturesEnricher(TransformerMixin):
1684
1851
 
1685
1852
  x_columns = [
1686
1853
  c
1687
- for c in (validated_X.columns.tolist() + self.fit_generated_features + [SYSTEM_RECORD_ID])
1854
+ for c in (validated_X.columns.tolist() + generated_features + [SYSTEM_RECORD_ID])
1688
1855
  if c in enriched_df.columns
1689
1856
  ]
1690
1857
 
@@ -1692,12 +1859,13 @@ class FeaturesEnricher(TransformerMixin):
1692
1859
  X_sampled = enriched_Xy[x_columns].copy()
1693
1860
  y_sampled = enriched_Xy[TARGET].copy()
1694
1861
  enriched_X = enriched_Xy.drop(columns=[TARGET, EVAL_SET_INDEX])
1862
+ enriched_X_columns = enriched_X.columns.tolist()
1695
1863
 
1696
1864
  for idx in range(len(eval_set)):
1697
1865
  enriched_eval_xy = enriched_df.query(f"{EVAL_SET_INDEX} == {idx + 1}")
1698
1866
  eval_x_sampled = enriched_eval_xy[x_columns].copy()
1699
1867
  eval_y_sampled = enriched_eval_xy[TARGET].copy()
1700
- enriched_eval_x = enriched_eval_xy.drop(columns=[TARGET, EVAL_SET_INDEX])
1868
+ enriched_eval_x = enriched_eval_xy[enriched_X_columns].copy()
1701
1869
  eval_set_sampled_dict[idx] = (eval_x_sampled, enriched_eval_x, eval_y_sampled)
1702
1870
  else:
1703
1871
  self.logger.info("Transform without eval_set")
@@ -1705,7 +1873,7 @@ class FeaturesEnricher(TransformerMixin):
1705
1873
 
1706
1874
  df[TARGET] = validated_y
1707
1875
 
1708
- df = clean_full_duplicates(df, logger=self.logger, silent=True, bundle=self.bundle)
1876
+ df, _ = clean_full_duplicates(df, logger=self.logger, bundle=self.bundle)
1709
1877
 
1710
1878
  num_samples = _num_samples(df)
1711
1879
  if num_samples > Dataset.FIT_SAMPLE_THRESHOLD:
@@ -1715,7 +1883,7 @@ class FeaturesEnricher(TransformerMixin):
1715
1883
  tmp_target_name = "__target"
1716
1884
  df = df.rename(columns={TARGET: tmp_target_name})
1717
1885
 
1718
- enriched_Xy = self.__inner_transform(
1886
+ enriched_Xy, columns_renaming, generated_features = self.__inner_transform(
1719
1887
  trace_id,
1720
1888
  df,
1721
1889
  exclude_features_sources=exclude_features_sources,
@@ -1732,7 +1900,7 @@ class FeaturesEnricher(TransformerMixin):
1732
1900
 
1733
1901
  x_columns = [
1734
1902
  c
1735
- for c in (validated_X.columns.tolist() + self.fit_generated_features + [SYSTEM_RECORD_ID])
1903
+ for c in (validated_X.columns.tolist() + generated_features + [SYSTEM_RECORD_ID])
1736
1904
  if c in enriched_Xy.columns
1737
1905
  ]
1738
1906
 
@@ -1740,9 +1908,19 @@ class FeaturesEnricher(TransformerMixin):
1740
1908
  y_sampled = enriched_Xy[TARGET].copy()
1741
1909
  enriched_X = enriched_Xy.drop(columns=TARGET)
1742
1910
 
1743
- self.__cached_sampled_datasets = (X_sampled, y_sampled, enriched_X, eval_set_sampled_dict, self.search_keys)
1911
+ datasets_hash = hash_input(validated_X, validated_y, eval_set)
1912
+ self.__cached_sampled_datasets[datasets_hash] = (
1913
+ X_sampled,
1914
+ y_sampled,
1915
+ enriched_X,
1916
+ eval_set_sampled_dict,
1917
+ self.search_keys,
1918
+ columns_renaming,
1919
+ )
1744
1920
 
1745
- return self.__mk_sampled_data_tuple(X_sampled, y_sampled, enriched_X, eval_set_sampled_dict, self.search_keys)
1921
+ return self.__mk_sampled_data_tuple(
1922
+ X_sampled, y_sampled, enriched_X, eval_set_sampled_dict, self.search_keys, columns_renaming
1923
+ )
1746
1924
 
1747
1925
  def __mk_sampled_data_tuple(
1748
1926
  self,
@@ -1751,6 +1929,7 @@ class FeaturesEnricher(TransformerMixin):
1751
1929
  enriched_X: pd.DataFrame,
1752
1930
  eval_set_sampled_dict: Dict,
1753
1931
  search_keys: Dict,
1932
+ columns_renaming: Dict[str, str],
1754
1933
  ):
1755
1934
  search_keys = {k: v for k, v in search_keys.items() if k in X_sampled.columns.to_list()}
1756
1935
  return FeaturesEnricher._SampledDataForMetrics(
@@ -1759,6 +1938,7 @@ class FeaturesEnricher(TransformerMixin):
1759
1938
  enriched_X=enriched_X,
1760
1939
  eval_set_sampled_dict=eval_set_sampled_dict,
1761
1940
  search_keys=search_keys,
1941
+ columns_renaming=columns_renaming,
1762
1942
  )
1763
1943
 
1764
1944
  def get_search_id(self) -> Optional[str]:
@@ -1808,9 +1988,19 @@ class FeaturesEnricher(TransformerMixin):
1808
1988
  file_metadata = self._search_task.get_file_metadata(str(uuid.uuid4()))
1809
1989
  search_keys = file_metadata.search_types()
1810
1990
  if SearchKey.IPV6_ADDRESS in search_keys:
1811
- search_keys.remove(SearchKey.IPV6_ADDRESS)
1991
+ # search_keys.remove(SearchKey.IPV6_ADDRESS)
1992
+ search_keys.pop(SearchKey.IPV6_ADDRESS, None)
1812
1993
 
1813
- keys = "{" + ", ".join([f'"{key.name}": "{key_example(key)}"' for key in search_keys]) + "}"
1994
+ keys = (
1995
+ "{"
1996
+ + ", ".join(
1997
+ [
1998
+ f'"{key.name}": {{"name": "{name}", "value": "{key_example(key)}"}}'
1999
+ for key, name in search_keys.items()
2000
+ ]
2001
+ )
2002
+ + "}"
2003
+ )
1814
2004
  features_for_transform = self._search_task.get_features_for_transform()
1815
2005
  if features_for_transform:
1816
2006
  original_features_for_transform = [
@@ -1847,37 +2037,41 @@ class FeaturesEnricher(TransformerMixin):
1847
2037
  progress_bar: Optional[ProgressBar] = None,
1848
2038
  progress_callback: Optional[Callable[[SearchProgress], Any]] = None,
1849
2039
  add_fit_system_record_id: bool = False,
1850
- ) -> pd.DataFrame:
2040
+ ) -> Tuple[pd.DataFrame, Dict[str, str], List[str]]:
1851
2041
  if self._search_task is None:
1852
2042
  raise NotFittedError(self.bundle.get("transform_unfitted_enricher"))
1853
2043
 
1854
2044
  start_time = time.time()
1855
2045
  with MDC(trace_id=trace_id):
1856
2046
  self.logger.info("Start transform")
1857
- self.__log_debug_information(X, exclude_features_sources=exclude_features_sources)
2047
+
2048
+ validated_X = self._validate_X(X, is_transform=True)
2049
+
2050
+ self.__log_debug_information(validated_X, exclude_features_sources=exclude_features_sources)
1858
2051
 
1859
2052
  self.__validate_search_keys(self.search_keys, self.search_id)
1860
2053
 
1861
2054
  if len(self.feature_names_) == 0:
1862
2055
  self.logger.warning(self.bundle.get("no_important_features_for_transform"))
1863
- return X
2056
+ return X, {c: c for c in X.columns}, []
1864
2057
 
1865
2058
  if self._has_paid_features(exclude_features_sources):
1866
2059
  msg = self.bundle.get("transform_with_paid_features")
1867
2060
  self.logger.warning(msg)
1868
2061
  self.__display_support_link(msg)
1869
- return None
2062
+ return None, {c: c for c in X.columns}, []
1870
2063
 
1871
2064
  if not metrics_calculation:
1872
2065
  transform_usage = self.rest_client.get_current_transform_usage(trace_id)
1873
2066
  self.logger.info(f"Current transform usage: {transform_usage}. Transforming {len(X)} rows")
1874
2067
  if transform_usage.has_limit:
1875
2068
  if len(X) > transform_usage.rest_rows:
1876
- msg = self.bundle.get("transform_usage_warning").format(len(X), transform_usage.rest_rows)
2069
+ rest_rows = max(transform_usage.rest_rows, 0)
2070
+ msg = self.bundle.get("transform_usage_warning").format(len(X), rest_rows)
1877
2071
  self.logger.warning(msg)
1878
2072
  print(msg)
1879
2073
  show_request_quote_button()
1880
- return None
2074
+ return None, {c: c for c in X.columns}, []
1881
2075
  else:
1882
2076
  msg = self.bundle.get("transform_usage_info").format(
1883
2077
  transform_usage.limit, transform_usage.transformed_rows
@@ -1885,11 +2079,11 @@ class FeaturesEnricher(TransformerMixin):
1885
2079
  self.logger.info(msg)
1886
2080
  print(msg)
1887
2081
 
1888
- validated_X = self._validate_X(X, is_transform=True)
1889
-
1890
2082
  is_demo_dataset = hash_input(validated_X) in DEMO_DATASET_HASHES
1891
2083
 
1892
- columns_to_drop = [c for c in validated_X.columns if c in self.feature_names_]
2084
+ columns_to_drop = [
2085
+ c for c in validated_X.columns if c in self.feature_names_ and c in self.dropped_client_feature_names_
2086
+ ]
1893
2087
  if len(columns_to_drop) > 0:
1894
2088
  msg = self.bundle.get("x_contains_enriching_columns").format(columns_to_drop)
1895
2089
  self.logger.warning(msg)
@@ -1915,79 +2109,135 @@ class FeaturesEnricher(TransformerMixin):
1915
2109
  df = self.__add_country_code(df, search_keys)
1916
2110
 
1917
2111
  generated_features = []
1918
- date_column = self._get_date_column(search_keys)
2112
+ date_column = SearchKey.find_key(search_keys, [SearchKey.DATE, SearchKey.DATETIME])
1919
2113
  if date_column is not None:
1920
2114
  converter = DateTimeSearchKeyConverter(date_column, self.date_format, self.logger, bundle=self.bundle)
1921
- df = converter.convert(df)
2115
+ df = converter.convert(df, keep_time=True)
1922
2116
  self.logger.info(f"Date column after convertion: {df[date_column]}")
1923
2117
  generated_features.extend(converter.generated_features)
1924
2118
  else:
1925
2119
  self.logger.info("Input dataset hasn't date column")
1926
2120
  if self.add_date_if_missing:
1927
2121
  df = self._add_current_date_as_key(df, search_keys, self.logger, self.bundle)
2122
+
2123
+ email_columns = SearchKey.find_all_keys(search_keys, SearchKey.EMAIL)
2124
+ if email_columns:
2125
+ generator = EmailDomainGenerator(email_columns)
2126
+ df = generator.generate(df)
2127
+ generated_features.extend(generator.generated_features)
2128
+
2129
+ normalizer = Normalizer(self.bundle, self.logger)
2130
+ df, search_keys, generated_features = normalizer.normalize(df, search_keys, generated_features)
2131
+ columns_renaming = normalizer.columns_renaming
2132
+
2133
+ # Don't pass all features in backend on transform
2134
+ runtime_parameters = self._get_copy_of_runtime_parameters()
2135
+ features_for_transform = self._search_task.get_features_for_transform() or []
2136
+ if len(features_for_transform) > 0:
2137
+ missing_features_for_transform = [
2138
+ columns_renaming.get(f) for f in features_for_transform if f not in df.columns
2139
+ ]
2140
+ if len(missing_features_for_transform) > 0:
2141
+ raise ValidationError(
2142
+ self.bundle.get("missing_features_for_transform").format(missing_features_for_transform)
2143
+ )
2144
+ runtime_parameters.properties["features_for_embeddings"] = ",".join(features_for_transform)
2145
+
2146
+ columns_for_system_record_id = sorted(list(search_keys.keys()) + features_for_transform)
2147
+
2148
+ df[ENTITY_SYSTEM_RECORD_ID] = pd.util.hash_pandas_object(
2149
+ df[columns_for_system_record_id], index=False
2150
+ ).astype("float64")
2151
+
2152
+ # Explode multiple search keys
2153
+ df, unnest_search_keys = self._explode_multiple_search_keys(df, search_keys, columns_renaming)
2154
+
1928
2155
  email_column = self._get_email_column(search_keys)
1929
2156
  hem_column = self._get_hem_column(search_keys)
1930
- email_converted_to_hem = False
1931
2157
  if email_column:
1932
- converter = EmailSearchKeyConverter(email_column, hem_column, search_keys, self.logger)
2158
+ converter = EmailSearchKeyConverter(
2159
+ email_column,
2160
+ hem_column,
2161
+ search_keys,
2162
+ columns_renaming,
2163
+ list(unnest_search_keys.keys()),
2164
+ self.logger,
2165
+ )
1933
2166
  df = converter.convert(df)
1934
- generated_features.extend(converter.generated_features)
1935
- email_converted_to_hem = converter.email_converted_to_hem
1936
- if (
1937
- self.detect_missing_search_keys
1938
- and list(search_keys.values()) == [SearchKey.DATE]
1939
- and self.country_code is None
1940
- ):
1941
- converter = IpToCountrySearchKeyConverter(search_keys, self.logger)
2167
+
2168
+ ip_column = self._get_ip_column(search_keys)
2169
+ if ip_column:
2170
+ converter = IpSearchKeyConverter(
2171
+ ip_column,
2172
+ search_keys,
2173
+ columns_renaming,
2174
+ list(unnest_search_keys.keys()),
2175
+ self.bundle,
2176
+ self.logger,
2177
+ )
1942
2178
  df = converter.convert(df)
1943
- generated_features = [f for f in generated_features if f in self.fit_generated_features]
1944
2179
 
1945
- meaning_types = {col: key.value for col, key in search_keys.items()}
1946
- non_keys_columns = [column for column in df.columns if column not in search_keys.keys()]
2180
+ phone_column = self._get_phone_column(search_keys)
2181
+ country_column = self._get_country_column(search_keys)
2182
+ if phone_column:
2183
+ converter = PhoneSearchKeyConverter(phone_column, country_column)
2184
+ df = converter.convert(df)
1947
2185
 
1948
- if email_converted_to_hem:
1949
- non_keys_columns.append(email_column)
2186
+ if country_column:
2187
+ converter = CountrySearchKeyConverter(country_column)
2188
+ df = converter.convert(df)
1950
2189
 
1951
- # Don't pass features in backend on transform
1952
- original_features_for_transform = None
1953
- runtime_parameters = self._get_copy_of_runtime_parameters()
1954
- if len(non_keys_columns) > 0:
1955
- # Pass only features that need for transform
1956
- features_for_transform = self._search_task.get_features_for_transform()
1957
- if features_for_transform is not None and len(features_for_transform) > 0:
1958
- file_metadata = self._search_task.get_file_metadata(trace_id)
1959
- original_features_for_transform = [
1960
- c.originalName or c.name for c in file_metadata.columns if c.name in features_for_transform
1961
- ]
1962
- non_keys_columns = [c for c in non_keys_columns if c not in original_features_for_transform]
2190
+ postal_code = self._get_postal_column(search_keys)
2191
+ if postal_code:
2192
+ converter = PostalCodeSearchKeyConverter(postal_code)
2193
+ df = converter.convert(df)
2194
+
2195
+ # generated_features = [f for f in generated_features if f in self.fit_generated_features]
1963
2196
 
1964
- runtime_parameters.properties["features_for_embeddings"] = ",".join(features_for_transform)
2197
+ meaning_types = {col: key.value for col, key in search_keys.items()}
2198
+ for col in features_for_transform:
2199
+ meaning_types[col] = FileColumnMeaningType.FEATURE
2200
+ features_not_to_pass = [
2201
+ c
2202
+ for c in df.columns
2203
+ if c not in search_keys.keys()
2204
+ and c not in features_for_transform
2205
+ and c not in [ENTITY_SYSTEM_RECORD_ID, SEARCH_KEY_UNNEST]
2206
+ ]
1965
2207
 
1966
2208
  if add_fit_system_record_id:
1967
- df = self.__add_fit_system_record_id(df, dict(), search_keys)
2209
+ df = self.__add_fit_system_record_id(df, search_keys, SYSTEM_RECORD_ID)
1968
2210
  df = df.rename(columns={SYSTEM_RECORD_ID: SORT_ID})
1969
- non_keys_columns.append(SORT_ID)
2211
+ features_not_to_pass.append(SORT_ID)
1970
2212
 
1971
- columns_for_system_record_id = sorted(list(search_keys.keys()) + (original_features_for_transform or []))
2213
+ if DateTimeSearchKeyConverter.DATETIME_COL in df.columns:
2214
+ df = df.drop(columns=DateTimeSearchKeyConverter.DATETIME_COL)
1972
2215
 
2216
+ # search keys might be changed after explode
2217
+ columns_for_system_record_id = sorted(list(search_keys.keys()) + features_for_transform)
1973
2218
  df[SYSTEM_RECORD_ID] = pd.util.hash_pandas_object(df[columns_for_system_record_id], index=False).astype(
1974
- "Float64"
2219
+ "float64"
1975
2220
  )
1976
2221
  meaning_types[SYSTEM_RECORD_ID] = FileColumnMeaningType.SYSTEM_RECORD_ID
2222
+ meaning_types[ENTITY_SYSTEM_RECORD_ID] = FileColumnMeaningType.ENTITY_SYSTEM_RECORD_ID
2223
+ if SEARCH_KEY_UNNEST in df.columns:
2224
+ meaning_types[SEARCH_KEY_UNNEST] = FileColumnMeaningType.UNNEST_KEY
1977
2225
 
1978
2226
  df = df.reset_index(drop=True)
1979
- system_columns_with_original_index = [SYSTEM_RECORD_ID] + generated_features
2227
+ system_columns_with_original_index = [SYSTEM_RECORD_ID, ENTITY_SYSTEM_RECORD_ID] + generated_features
1980
2228
  if add_fit_system_record_id:
1981
2229
  system_columns_with_original_index.append(SORT_ID)
1982
2230
  df_with_original_index = df[system_columns_with_original_index].copy()
1983
2231
 
1984
2232
  combined_search_keys = combine_search_keys(search_keys.keys())
1985
2233
 
1986
- df_without_features = df.drop(columns=non_keys_columns)
2234
+ df_without_features = df.drop(columns=features_not_to_pass, errors="ignore")
1987
2235
 
1988
- df_without_features = clean_full_duplicates(
1989
- df_without_features, self.logger, silent=silent_mode, bundle=self.bundle
2236
+ df_without_features, full_duplicates_warning = clean_full_duplicates(
2237
+ df_without_features, self.logger, bundle=self.bundle
1990
2238
  )
2239
+ if not silent_mode and full_duplicates_warning:
2240
+ self.__log_warning(full_duplicates_warning)
1991
2241
 
1992
2242
  del df
1993
2243
  gc.collect()
@@ -1995,14 +2245,14 @@ class FeaturesEnricher(TransformerMixin):
1995
2245
  dataset = Dataset(
1996
2246
  "sample_" + str(uuid.uuid4()),
1997
2247
  df=df_without_features,
2248
+ meaning_types=meaning_types,
2249
+ search_keys=combined_search_keys,
2250
+ unnest_search_keys=unnest_search_keys,
1998
2251
  date_format=self.date_format,
1999
2252
  rest_client=self.rest_client,
2000
2253
  logger=self.logger,
2001
2254
  )
2002
- dataset.meaning_types = meaning_types
2003
- dataset.search_keys = combined_search_keys
2004
- if email_converted_to_hem:
2005
- dataset.ignore_columns = [email_column]
2255
+ dataset.columns_renaming = columns_renaming
2006
2256
 
2007
2257
  if max_features is not None or importance_threshold is not None:
2008
2258
  exclude_features_sources = list(
@@ -2090,9 +2340,15 @@ class FeaturesEnricher(TransformerMixin):
2090
2340
  else:
2091
2341
  result = enrich()
2092
2342
 
2343
+ selecting_columns = [
2344
+ c
2345
+ for c in itertools.chain(validated_X.columns.tolist(), generated_features)
2346
+ if c not in self.dropped_client_feature_names_
2347
+ ]
2093
2348
  filtered_columns = self.__filtered_enriched_features(importance_threshold, max_features)
2094
- existing_filtered_columns = [c for c in filtered_columns if c in result.columns]
2095
- selecting_columns = validated_X.columns.tolist() + generated_features + existing_filtered_columns
2349
+ selecting_columns.extend(
2350
+ c for c in filtered_columns if c in result.columns and c not in validated_X.columns
2351
+ )
2096
2352
  if add_fit_system_record_id:
2097
2353
  selecting_columns.append(SORT_ID)
2098
2354
 
@@ -2104,7 +2360,7 @@ class FeaturesEnricher(TransformerMixin):
2104
2360
  if add_fit_system_record_id:
2105
2361
  result = result.rename(columns={SORT_ID: SYSTEM_RECORD_ID})
2106
2362
 
2107
- return result
2363
+ return result, columns_renaming, generated_features
2108
2364
 
2109
2365
  def _get_excluded_features(self, max_features: Optional[int], importance_threshold: Optional[float]) -> List[str]:
2110
2366
  features_info = self._internal_features_info
@@ -2128,7 +2384,7 @@ class FeaturesEnricher(TransformerMixin):
2128
2384
  ]
2129
2385
  return excluded_features[feature_name_header].values.tolist()
2130
2386
 
2131
- def __validate_search_keys(self, search_keys: Dict[str, SearchKey], search_id: Optional[str]):
2387
+ def __validate_search_keys(self, search_keys: Dict[str, SearchKey], search_id: Optional[str] = None):
2132
2388
  if (search_keys is None or len(search_keys) == 0) and self.country_code is None:
2133
2389
  if search_id:
2134
2390
  self.logger.debug(f"search_id {search_id} provided without search_keys")
@@ -2139,6 +2395,14 @@ class FeaturesEnricher(TransformerMixin):
2139
2395
 
2140
2396
  key_types = search_keys.values()
2141
2397
 
2398
+ # Multiple search keys allowed only for PHONE, IP, POSTAL_CODE, EMAIL, HEM
2399
+ multi_keys = [key for key, count in Counter(key_types).items() if count > 1]
2400
+ for multi_key in multi_keys:
2401
+ if multi_key not in [SearchKey.PHONE, SearchKey.IP, SearchKey.POSTAL_CODE, SearchKey.EMAIL, SearchKey.HEM]:
2402
+ msg = self.bundle.get("unsupported_multi_key").format(multi_key)
2403
+ self.logger.warning(msg)
2404
+ raise ValidationError(msg)
2405
+
2142
2406
  if SearchKey.DATE in key_types and SearchKey.DATETIME in key_types:
2143
2407
  msg = self.bundle.get("date_and_datetime_simultanious")
2144
2408
  self.logger.warning(msg)
@@ -2154,11 +2418,11 @@ class FeaturesEnricher(TransformerMixin):
2154
2418
  self.logger.warning(msg)
2155
2419
  raise ValidationError(msg)
2156
2420
 
2157
- for key_type in SearchKey.__members__.values():
2158
- if key_type != SearchKey.CUSTOM_KEY and list(key_types).count(key_type) > 1:
2159
- msg = self.bundle.get("multiple_search_key").format(key_type)
2160
- self.logger.warning(msg)
2161
- raise ValidationError(msg)
2421
+ # for key_type in SearchKey.__members__.values():
2422
+ # if key_type != SearchKey.CUSTOM_KEY and list(key_types).count(key_type) > 1:
2423
+ # msg = self.bundle.get("multiple_search_key").format(key_type)
2424
+ # self.logger.warning(msg)
2425
+ # raise ValidationError(msg)
2162
2426
 
2163
2427
  # non_personal_keys = set(SearchKey.__members__.values()) - set(SearchKey.personal_keys())
2164
2428
  # if (
@@ -2174,6 +2438,15 @@ class FeaturesEnricher(TransformerMixin):
2174
2438
  def __is_registered(self) -> bool:
2175
2439
  return self.api_key is not None and self.api_key != ""
2176
2440
 
2441
+ def __log_warning(self, message: str, show_support_link: bool = False):
2442
+ warning_num = self.warning_counter.increment()
2443
+ formatted_message = f"WARNING #{warning_num}: {message}\n"
2444
+ if show_support_link:
2445
+ self.__display_support_link(formatted_message)
2446
+ else:
2447
+ print(formatted_message)
2448
+ self.logger.warning(message)
2449
+
2177
2450
  def __inner_fit(
2178
2451
  self,
2179
2452
  trace_id: str,
@@ -2195,8 +2468,11 @@ class FeaturesEnricher(TransformerMixin):
2195
2468
  ):
2196
2469
  self.warning_counter.reset()
2197
2470
  self.df_with_original_index = None
2198
- self.__cached_sampled_datasets = None
2471
+ self.__cached_sampled_datasets = dict()
2199
2472
  self.metrics = None
2473
+ self.fit_columns_renaming = None
2474
+ self.fit_dropped_features = set()
2475
+ self.fit_generated_features = []
2200
2476
 
2201
2477
  validated_X = self._validate_X(X)
2202
2478
  validated_y = self._validate_y(validated_X, y)
@@ -2217,9 +2493,7 @@ class FeaturesEnricher(TransformerMixin):
2217
2493
  checked_generate_features = []
2218
2494
  for gen_feature in self.generate_features:
2219
2495
  if gen_feature not in x_columns:
2220
- msg = self.bundle.get("missing_generate_feature").format(gen_feature, x_columns)
2221
- print(msg)
2222
- self.logger.warning(msg)
2496
+ self.__log_warning(self.bundle.get("missing_generate_feature").format(gen_feature, x_columns))
2223
2497
  else:
2224
2498
  checked_generate_features.append(gen_feature)
2225
2499
  self.generate_features = checked_generate_features
@@ -2228,9 +2502,9 @@ class FeaturesEnricher(TransformerMixin):
2228
2502
  validate_scoring_argument(scoring)
2229
2503
 
2230
2504
  self.__log_debug_information(
2231
- X,
2232
- y,
2233
- eval_set,
2505
+ validated_X,
2506
+ validated_y,
2507
+ validated_eval_set,
2234
2508
  exclude_features_sources=exclude_features_sources,
2235
2509
  calculate_metrics=calculate_metrics,
2236
2510
  scoring=scoring,
@@ -2240,20 +2514,6 @@ class FeaturesEnricher(TransformerMixin):
2240
2514
 
2241
2515
  df = pd.concat([validated_X, validated_y], axis=1)
2242
2516
 
2243
- self.fit_search_keys = self.search_keys.copy()
2244
- self.fit_search_keys = self.__prepare_search_keys(validated_X, self.fit_search_keys, is_demo_dataset)
2245
-
2246
- validate_dates_distribution(validated_X, self.fit_search_keys, self.logger, self.bundle, self.warning_counter)
2247
-
2248
- maybe_date_column = self._get_date_column(self.fit_search_keys)
2249
- has_date = maybe_date_column is not None
2250
- model_task_type = self.model_task_type or define_task(validated_y, has_date, self.logger)
2251
- self._validate_binary_observations(validated_y, model_task_type)
2252
-
2253
- self.runtime_parameters = get_runtime_params_custom_loss(
2254
- self.loss, model_task_type, self.runtime_parameters, self.logger
2255
- )
2256
-
2257
2517
  if validated_eval_set is not None and len(validated_eval_set) > 0:
2258
2518
  df[EVAL_SET_INDEX] = 0
2259
2519
  for idx, (eval_X, eval_y) in enumerate(validated_eval_set):
@@ -2261,12 +2521,21 @@ class FeaturesEnricher(TransformerMixin):
2261
2521
  eval_df[EVAL_SET_INDEX] = idx + 1
2262
2522
  df = pd.concat([df, eval_df])
2263
2523
 
2264
- df = self.__correct_target(df)
2265
-
2524
+ self.fit_search_keys = self.search_keys.copy()
2266
2525
  df = self.__handle_index_search_keys(df, self.fit_search_keys)
2526
+ self.fit_search_keys = self.__prepare_search_keys(df, self.fit_search_keys, is_demo_dataset)
2267
2527
 
2268
- if is_numeric_dtype(df[self.TARGET_NAME]) and has_date:
2269
- self._validate_PSI(df.sort_values(by=maybe_date_column))
2528
+ maybe_date_column = SearchKey.find_key(self.fit_search_keys, [SearchKey.DATE, SearchKey.DATETIME])
2529
+ has_date = maybe_date_column is not None
2530
+ self.model_task_type = self.model_task_type or define_task(validated_y, has_date, self.logger)
2531
+
2532
+ self._validate_binary_observations(validated_y, self.model_task_type)
2533
+
2534
+ self.runtime_parameters = get_runtime_params_custom_loss(
2535
+ self.loss, self.model_task_type, self.runtime_parameters, self.logger
2536
+ )
2537
+
2538
+ df = self.__correct_target(df)
2270
2539
 
2271
2540
  if DEFAULT_INDEX in df.columns:
2272
2541
  msg = self.bundle.get("unsupported_index_column")
@@ -2277,58 +2546,132 @@ class FeaturesEnricher(TransformerMixin):
2277
2546
 
2278
2547
  df = self.__add_country_code(df, self.fit_search_keys)
2279
2548
 
2280
- df = remove_fintech_duplicates(
2281
- df, self.fit_search_keys, date_format=self.date_format, logger=self.logger, bundle=self.bundle
2282
- )
2283
- df = clean_full_duplicates(df, self.logger, bundle=self.bundle)
2284
-
2285
- date_column = self._get_date_column(self.fit_search_keys)
2286
- self.__adjust_cv(df, date_column, model_task_type)
2287
-
2288
2549
  self.fit_generated_features = []
2289
2550
 
2290
- if date_column is not None:
2291
- converter = DateTimeSearchKeyConverter(date_column, self.date_format, self.logger, bundle=self.bundle)
2551
+ if has_date:
2552
+ converter = DateTimeSearchKeyConverter(
2553
+ maybe_date_column,
2554
+ self.date_format,
2555
+ self.logger,
2556
+ bundle=self.bundle,
2557
+ )
2292
2558
  df = converter.convert(df, keep_time=True)
2293
- self.logger.info(f"Date column after convertion: {df[date_column]}")
2559
+ if converter.has_old_dates:
2560
+ self.__log_warning(self.bundle.get("dataset_drop_old_dates"))
2561
+ self.logger.info(f"Date column after convertion: {df[maybe_date_column]}")
2294
2562
  self.fit_generated_features.extend(converter.generated_features)
2295
2563
  else:
2296
2564
  self.logger.info("Input dataset hasn't date column")
2297
2565
  if self.add_date_if_missing:
2298
2566
  df = self._add_current_date_as_key(df, self.fit_search_keys, self.logger, self.bundle)
2567
+
2568
+ email_columns = SearchKey.find_all_keys(self.fit_search_keys, SearchKey.EMAIL)
2569
+ if email_columns:
2570
+ generator = EmailDomainGenerator(email_columns)
2571
+ df = generator.generate(df)
2572
+ self.fit_generated_features.extend(generator.generated_features)
2573
+
2574
+ # Checks that need validated date
2575
+ try:
2576
+ if not is_dates_distribution_valid(df, self.fit_search_keys):
2577
+ self.__log_warning(bundle.get("x_unstable_by_date"))
2578
+ except Exception:
2579
+ self.logger.exception("Failed to check dates distribution validity")
2580
+
2581
+ if (
2582
+ is_numeric_dtype(df[self.TARGET_NAME])
2583
+ and self.model_task_type in [ModelTaskType.BINARY, ModelTaskType.MULTICLASS]
2584
+ and has_date
2585
+ ):
2586
+ self._validate_PSI(df.sort_values(by=maybe_date_column))
2587
+
2588
+ normalizer = Normalizer(self.bundle, self.logger)
2589
+ df, self.fit_search_keys, self.fit_generated_features = normalizer.normalize(
2590
+ df, self.fit_search_keys, self.fit_generated_features
2591
+ )
2592
+ self.fit_columns_renaming = normalizer.columns_renaming
2593
+ if normalizer.removed_features:
2594
+ self.__log_warning(self.bundle.get("dataset_date_features").format(normalizer.removed_features))
2595
+
2596
+ self.__adjust_cv(df)
2597
+
2598
+ df, fintech_warnings = remove_fintech_duplicates(
2599
+ df, self.fit_search_keys, date_format=self.date_format, logger=self.logger, bundle=self.bundle
2600
+ )
2601
+ if fintech_warnings:
2602
+ for fintech_warning in fintech_warnings:
2603
+ self.__log_warning(fintech_warning)
2604
+ df, full_duplicates_warning = clean_full_duplicates(df, self.logger, bundle=self.bundle)
2605
+ if full_duplicates_warning:
2606
+ self.__log_warning(full_duplicates_warning)
2607
+
2608
+ # Explode multiple search keys
2609
+ df = self.__add_fit_system_record_id(df, self.fit_search_keys, ENTITY_SYSTEM_RECORD_ID)
2610
+
2611
+ # TODO check that this is correct for enrichment
2612
+ self.df_with_original_index = df.copy()
2613
+ # TODO check maybe need to drop _time column from df_with_original_index
2614
+
2615
+ df, unnest_search_keys = self._explode_multiple_search_keys(df, self.fit_search_keys, self.fit_columns_renaming)
2616
+
2617
+ # Convert EMAIL to HEM after unnesting to do it only with one column
2299
2618
  email_column = self._get_email_column(self.fit_search_keys)
2300
2619
  hem_column = self._get_hem_column(self.fit_search_keys)
2301
- email_converted_to_hem = False
2302
2620
  if email_column:
2303
- converter = EmailSearchKeyConverter(email_column, hem_column, self.fit_search_keys, self.logger)
2621
+ converter = EmailSearchKeyConverter(
2622
+ email_column,
2623
+ hem_column,
2624
+ self.fit_search_keys,
2625
+ self.fit_columns_renaming,
2626
+ list(unnest_search_keys.keys()),
2627
+ self.logger,
2628
+ )
2304
2629
  df = converter.convert(df)
2305
- self.fit_generated_features.extend(converter.generated_features)
2306
- email_converted_to_hem = converter.email_converted_to_hem
2307
- if (
2308
- self.detect_missing_search_keys
2309
- and list(self.fit_search_keys.values()) == [SearchKey.DATE]
2310
- and self.country_code is None
2311
- ):
2312
- converter = IpToCountrySearchKeyConverter(self.fit_search_keys, self.logger)
2630
+
2631
+ ip_column = self._get_ip_column(self.fit_search_keys)
2632
+ if ip_column:
2633
+ converter = IpSearchKeyConverter(
2634
+ ip_column,
2635
+ self.fit_search_keys,
2636
+ self.fit_columns_renaming,
2637
+ list(unnest_search_keys.keys()),
2638
+ self.bundle,
2639
+ self.logger,
2640
+ )
2641
+ df = converter.convert(df)
2642
+
2643
+ phone_column = self._get_phone_column(self.fit_search_keys)
2644
+ country_column = self._get_country_column(self.fit_search_keys)
2645
+ if phone_column:
2646
+ converter = PhoneSearchKeyConverter(phone_column, country_column)
2647
+ df = converter.convert(df)
2648
+
2649
+ if country_column:
2650
+ converter = CountrySearchKeyConverter(country_column)
2651
+ df = converter.convert(df)
2652
+
2653
+ postal_code = self._get_postal_column(self.fit_search_keys)
2654
+ if postal_code:
2655
+ converter = PostalCodeSearchKeyConverter(postal_code)
2313
2656
  df = converter.convert(df)
2314
2657
 
2315
- non_feature_columns = [self.TARGET_NAME, EVAL_SET_INDEX] + list(self.fit_search_keys.keys())
2316
- if email_converted_to_hem:
2317
- non_feature_columns.append(email_column)
2658
+ non_feature_columns = [self.TARGET_NAME, EVAL_SET_INDEX, ENTITY_SYSTEM_RECORD_ID, SEARCH_KEY_UNNEST] + list(
2659
+ self.fit_search_keys.keys()
2660
+ )
2318
2661
  if DateTimeSearchKeyConverter.DATETIME_COL in df.columns:
2319
2662
  non_feature_columns.append(DateTimeSearchKeyConverter.DATETIME_COL)
2320
2663
 
2321
2664
  features_columns = [c for c in df.columns if c not in non_feature_columns]
2322
2665
 
2323
- features_to_drop = FeaturesValidator(self.logger).validate(
2324
- df, features_columns, self.generate_features, self.warning_counter
2666
+ features_to_drop, feature_validator_warnings = FeaturesValidator(self.logger).validate(
2667
+ df, features_columns, self.generate_features, self.fit_columns_renaming
2325
2668
  )
2669
+ if feature_validator_warnings:
2670
+ for warning in feature_validator_warnings:
2671
+ self.__log_warning(warning)
2326
2672
  self.fit_dropped_features.update(features_to_drop)
2327
2673
  df = df.drop(columns=features_to_drop)
2328
2674
 
2329
- if email_converted_to_hem:
2330
- self.fit_dropped_features.add(email_column)
2331
-
2332
2675
  self.fit_generated_features = [f for f in self.fit_generated_features if f not in self.fit_dropped_features]
2333
2676
 
2334
2677
  meaning_types = {
@@ -2336,12 +2679,19 @@ class FeaturesEnricher(TransformerMixin):
2336
2679
  **{str(c): FileColumnMeaningType.FEATURE for c in df.columns if c not in non_feature_columns},
2337
2680
  }
2338
2681
  meaning_types[self.TARGET_NAME] = FileColumnMeaningType.TARGET
2682
+ meaning_types[ENTITY_SYSTEM_RECORD_ID] = FileColumnMeaningType.ENTITY_SYSTEM_RECORD_ID
2683
+ if SEARCH_KEY_UNNEST in df.columns:
2684
+ meaning_types[SEARCH_KEY_UNNEST] = FileColumnMeaningType.UNNEST_KEY
2339
2685
  if eval_set is not None and len(eval_set) > 0:
2340
2686
  meaning_types[EVAL_SET_INDEX] = FileColumnMeaningType.EVAL_SET_INDEX
2341
2687
 
2342
- df = self.__add_fit_system_record_id(df, meaning_types, self.fit_search_keys)
2688
+ df = self.__add_fit_system_record_id(df, self.fit_search_keys, SYSTEM_RECORD_ID)
2689
+
2690
+ if DateTimeSearchKeyConverter.DATETIME_COL in df.columns:
2691
+ df = df.drop(columns=DateTimeSearchKeyConverter.DATETIME_COL)
2692
+
2693
+ meaning_types[SYSTEM_RECORD_ID] = FileColumnMeaningType.SYSTEM_RECORD_ID
2343
2694
 
2344
- self.df_with_original_index = df.copy()
2345
2695
  df = df.reset_index(drop=True).sort_values(by=SYSTEM_RECORD_ID).reset_index(drop=True)
2346
2696
 
2347
2697
  combined_search_keys = combine_search_keys(self.fit_search_keys.keys())
@@ -2349,16 +2699,16 @@ class FeaturesEnricher(TransformerMixin):
2349
2699
  dataset = Dataset(
2350
2700
  "tds_" + str(uuid.uuid4()),
2351
2701
  df=df,
2352
- model_task_type=model_task_type,
2702
+ meaning_types=meaning_types,
2703
+ search_keys=combined_search_keys,
2704
+ unnest_search_keys=unnest_search_keys,
2705
+ model_task_type=self.model_task_type,
2353
2706
  date_format=self.date_format,
2354
2707
  random_state=self.random_state,
2355
2708
  rest_client=self.rest_client,
2356
2709
  logger=self.logger,
2357
2710
  )
2358
- dataset.meaning_types = meaning_types
2359
- dataset.search_keys = combined_search_keys
2360
- if email_converted_to_hem:
2361
- dataset.ignore_columns = [email_column]
2711
+ dataset.columns_renaming = self.fit_columns_renaming
2362
2712
 
2363
2713
  self.passed_features = [
2364
2714
  column for column, meaning_type in meaning_types.items() if meaning_type == FileColumnMeaningType.FEATURE
@@ -2434,9 +2784,7 @@ class FeaturesEnricher(TransformerMixin):
2434
2784
  zero_hit_columns = self.get_columns_by_search_keys(zero_hit_search_keys)
2435
2785
  if zero_hit_columns:
2436
2786
  msg = self.bundle.get("features_info_zero_hit_rate_search_keys").format(zero_hit_columns)
2437
- self.logger.warning(msg)
2438
- self.__display_support_link(msg)
2439
- self.warning_counter.increment()
2787
+ self.__log_warning(msg, show_support_link=True)
2440
2788
 
2441
2789
  if (
2442
2790
  self._search_task.unused_features_for_generation is not None
@@ -2446,9 +2794,7 @@ class FeaturesEnricher(TransformerMixin):
2446
2794
  dataset.columns_renaming.get(col) or col for col in self._search_task.unused_features_for_generation
2447
2795
  ]
2448
2796
  msg = self.bundle.get("features_not_generated").format(unused_features_for_generation)
2449
- self.logger.warning(msg)
2450
- print(msg)
2451
- self.warning_counter.increment()
2797
+ self.__log_warning(msg)
2452
2798
 
2453
2799
  self.__prepare_feature_importances(trace_id, validated_X.columns.to_list() + self.fit_generated_features)
2454
2800
 
@@ -2456,7 +2802,13 @@ class FeaturesEnricher(TransformerMixin):
2456
2802
 
2457
2803
  autofe_description = self.get_autofe_features_description()
2458
2804
  if autofe_description is not None:
2459
- display_html_dataframe(autofe_description, autofe_description, "*Description of AutoFE feature names")
2805
+ self.logger.info(f"AutoFE descriptions: {autofe_description}")
2806
+ self.autofe_features_display_handle = display_html_dataframe(
2807
+ df=autofe_description,
2808
+ internal_df=autofe_description,
2809
+ header=self.bundle.get("autofe_descriptions_header"),
2810
+ display_id="autofe_descriptions",
2811
+ )
2460
2812
 
2461
2813
  if self._has_paid_features(exclude_features_sources):
2462
2814
  if calculate_metrics is not None and calculate_metrics:
@@ -2496,32 +2848,32 @@ class FeaturesEnricher(TransformerMixin):
2496
2848
  progress_callback,
2497
2849
  )
2498
2850
  except Exception:
2499
- self.__show_report_button()
2851
+ self.report_button_handle = self.__show_report_button(display_id="report_button")
2500
2852
  raise
2501
2853
 
2502
- self.__show_report_button()
2854
+ self.report_button_handle = self.__show_report_button(display_id="report_button")
2503
2855
 
2504
2856
  if not self.warning_counter.has_warnings():
2505
2857
  self.__display_support_link(self.bundle.get("all_ok_community_invite"))
2506
2858
 
2507
- def __adjust_cv(self, df: pd.DataFrame, date_column: pd.Series, model_task_type: ModelTaskType):
2859
+ def __adjust_cv(self, df: pd.DataFrame):
2860
+ date_column = SearchKey.find_key(self.fit_search_keys, [SearchKey.DATE, SearchKey.DATETIME])
2508
2861
  # Check Multivariate time series
2509
2862
  if (
2510
2863
  self.cv is None
2511
2864
  and date_column
2512
- and model_task_type == ModelTaskType.REGRESSION
2865
+ and self.model_task_type == ModelTaskType.REGRESSION
2513
2866
  and len({SearchKey.PHONE, SearchKey.EMAIL, SearchKey.HEM}.intersection(self.fit_search_keys.keys())) == 0
2514
2867
  and is_blocked_time_series(df, date_column, list(self.fit_search_keys.keys()) + [TARGET])
2515
2868
  ):
2516
2869
  msg = self.bundle.get("multivariate_timeseries_detected")
2517
2870
  self.__override_cv(CVType.blocked_time_series, msg, print_warning=False)
2518
- elif (
2519
- self.cv is None
2520
- and model_task_type != ModelTaskType.REGRESSION
2521
- and self._get_group_columns(df, self.fit_search_keys)
2522
- ):
2871
+ elif self.cv is None and self.model_task_type != ModelTaskType.REGRESSION:
2523
2872
  msg = self.bundle.get("group_k_fold_in_classification")
2524
2873
  self.__override_cv(CVType.group_k_fold, msg, print_warning=self.cv is not None)
2874
+ group_columns = self._get_group_columns(df, self.fit_search_keys)
2875
+ self.runtime_parameters.properties["cv_params.group_columns"] = ",".join(group_columns)
2876
+ self.runtime_parameters.properties["cv_params.shuffle_kfold"] = "True"
2525
2877
 
2526
2878
  def __override_cv(self, cv: CVType, msg: str, print_warning: bool = True):
2527
2879
  if print_warning:
@@ -2539,9 +2891,6 @@ class FeaturesEnricher(TransformerMixin):
2539
2891
  return [c for c, v in search_keys_with_autodetection.items() if v.value.value in keys]
2540
2892
 
2541
2893
  def _validate_X(self, X, is_transform=False) -> pd.DataFrame:
2542
- if _num_samples(X) == 0:
2543
- raise ValidationError(self.bundle.get("x_is_empty"))
2544
-
2545
2894
  if isinstance(X, pd.DataFrame):
2546
2895
  if isinstance(X.columns, pd.MultiIndex) or isinstance(X.index, pd.MultiIndex):
2547
2896
  raise ValidationError(self.bundle.get("x_multiindex_unsupported"))
@@ -2555,6 +2904,9 @@ class FeaturesEnricher(TransformerMixin):
2555
2904
  else:
2556
2905
  raise ValidationError(self.bundle.get("unsupported_x_type").format(type(X)))
2557
2906
 
2907
+ if _num_samples(X) == 0:
2908
+ raise ValidationError(self.bundle.get("x_is_empty"))
2909
+
2558
2910
  if len(set(validated_X.columns)) != len(validated_X.columns):
2559
2911
  raise ValidationError(self.bundle.get("x_contains_dup_columns"))
2560
2912
  if not is_transform and not validated_X.index.is_unique:
@@ -2574,13 +2926,12 @@ class FeaturesEnricher(TransformerMixin):
2574
2926
  raise ValidationError(self.bundle.get("x_contains_reserved_column_name").format(EVAL_SET_INDEX))
2575
2927
  if SYSTEM_RECORD_ID in validated_X.columns:
2576
2928
  raise ValidationError(self.bundle.get("x_contains_reserved_column_name").format(SYSTEM_RECORD_ID))
2929
+ if ENTITY_SYSTEM_RECORD_ID in validated_X.columns:
2930
+ raise ValidationError(self.bundle.get("x_contains_reserved_column_name").format(ENTITY_SYSTEM_RECORD_ID))
2577
2931
 
2578
2932
  return validated_X
2579
2933
 
2580
2934
  def _validate_y(self, X: pd.DataFrame, y) -> pd.Series:
2581
- if _num_samples(y) == 0:
2582
- raise ValidationError(self.bundle.get("y_is_empty"))
2583
-
2584
2935
  if (
2585
2936
  not isinstance(y, pd.Series)
2586
2937
  and not isinstance(y, pd.DataFrame)
@@ -2589,6 +2940,9 @@ class FeaturesEnricher(TransformerMixin):
2589
2940
  ):
2590
2941
  raise ValidationError(self.bundle.get("unsupported_y_type").format(type(y)))
2591
2942
 
2943
+ if _num_samples(y) == 0:
2944
+ raise ValidationError(self.bundle.get("y_is_empty"))
2945
+
2592
2946
  if _num_samples(X) != _num_samples(y):
2593
2947
  raise ValidationError(self.bundle.get("x_and_y_diff_size").format(_num_samples(X), _num_samples(y)))
2594
2948
 
@@ -2726,9 +3080,10 @@ class FeaturesEnricher(TransformerMixin):
2726
3080
  X: pd.DataFrame, y: pd.Series, cv: Optional[CVType]
2727
3081
  ) -> Tuple[pd.DataFrame, pd.Series]:
2728
3082
  if cv not in [CVType.time_series, CVType.blocked_time_series]:
3083
+ record_id_column = ENTITY_SYSTEM_RECORD_ID if ENTITY_SYSTEM_RECORD_ID in X else SYSTEM_RECORD_ID
2729
3084
  Xy = X.copy()
2730
3085
  Xy[TARGET] = y
2731
- Xy = Xy.sort_values(by=SYSTEM_RECORD_ID).reset_index(drop=True)
3086
+ Xy = Xy.sort_values(by=record_id_column).reset_index(drop=True)
2732
3087
  X = Xy.drop(columns=TARGET)
2733
3088
  y = Xy[TARGET].copy()
2734
3089
 
@@ -2746,7 +3101,7 @@ class FeaturesEnricher(TransformerMixin):
2746
3101
  if DateTimeSearchKeyConverter.DATETIME_COL in X.columns:
2747
3102
  date_column = DateTimeSearchKeyConverter.DATETIME_COL
2748
3103
  else:
2749
- date_column = FeaturesEnricher._get_date_column(search_keys)
3104
+ date_column = SearchKey.find_key(search_keys, [SearchKey.DATE, SearchKey.DATETIME])
2750
3105
  sort_columns = [date_column] if date_column is not None else []
2751
3106
 
2752
3107
  # Xy = pd.concat([X, y], axis=1)
@@ -2842,15 +3197,17 @@ class FeaturesEnricher(TransformerMixin):
2842
3197
 
2843
3198
  do_without_pandas_limits(print_datasets_sample)
2844
3199
 
2845
- maybe_date_col = self._get_date_column(self.search_keys)
3200
+ maybe_date_col = SearchKey.find_key(self.search_keys, [SearchKey.DATE, SearchKey.DATETIME])
2846
3201
  if X is not None and maybe_date_col is not None and maybe_date_col in X.columns:
2847
3202
  # TODO cast date column to single dtype
2848
- min_date = X[maybe_date_col].min()
2849
- max_date = X[maybe_date_col].max()
3203
+ date_converter = DateTimeSearchKeyConverter(maybe_date_col, self.date_format)
3204
+ converted_X = date_converter.convert(X)
3205
+ min_date = converted_X[maybe_date_col].min()
3206
+ max_date = converted_X[maybe_date_col].max()
2850
3207
  self.logger.info(f"Dates interval is ({min_date}, {max_date})")
2851
3208
 
2852
3209
  except Exception:
2853
- self.logger.exception("Failed to log debug information")
3210
+ self.logger.warning("Failed to log debug information", exc_info=True)
2854
3211
 
2855
3212
  def __handle_index_search_keys(self, df: pd.DataFrame, search_keys: Dict[str, SearchKey]) -> pd.DataFrame:
2856
3213
  index_names = df.index.names if df.index.names != [None] else [DEFAULT_INDEX]
@@ -2870,15 +3227,8 @@ class FeaturesEnricher(TransformerMixin):
2870
3227
 
2871
3228
  return df
2872
3229
 
2873
- @staticmethod
2874
- def _get_date_column(search_keys: Dict[str, SearchKey]) -> Optional[str]:
2875
- for col, t in search_keys.items():
2876
- if t in [SearchKey.DATE, SearchKey.DATETIME]:
2877
- return col
2878
-
2879
- @staticmethod
2880
3230
  def _add_current_date_as_key(
2881
- df: pd.DataFrame, search_keys: Dict[str, SearchKey], logger: logging.Logger, bundle: ResourceBundle
3231
+ self, df: pd.DataFrame, search_keys: Dict[str, SearchKey], logger: logging.Logger, bundle: ResourceBundle
2882
3232
  ) -> pd.DataFrame:
2883
3233
  if (
2884
3234
  set(search_keys.values()) == {SearchKey.PHONE}
@@ -2886,12 +3236,10 @@ class FeaturesEnricher(TransformerMixin):
2886
3236
  or set(search_keys.values()) == {SearchKey.HEM}
2887
3237
  or set(search_keys.values()) == {SearchKey.COUNTRY, SearchKey.POSTAL_CODE}
2888
3238
  ):
2889
- msg = bundle.get("current_date_added")
2890
- print(msg)
2891
- logger.warning(msg)
3239
+ self.__log_warning(bundle.get("current_date_added"))
2892
3240
  df[FeaturesEnricher.CURRENT_DATE] = datetime.date.today()
2893
3241
  search_keys[FeaturesEnricher.CURRENT_DATE] = SearchKey.DATE
2894
- converter = DateTimeSearchKeyConverter(FeaturesEnricher.CURRENT_DATE, None, logger, bundle)
3242
+ converter = DateTimeSearchKeyConverter(FeaturesEnricher.CURRENT_DATE)
2895
3243
  df = converter.convert(df)
2896
3244
  return df
2897
3245
 
@@ -2905,24 +3253,87 @@ class FeaturesEnricher(TransformerMixin):
2905
3253
 
2906
3254
  @staticmethod
2907
3255
  def _get_email_column(search_keys: Dict[str, SearchKey]) -> Optional[str]:
3256
+ cols = [col for col, t in search_keys.items() if t == SearchKey.EMAIL]
3257
+ if len(cols) > 1:
3258
+ raise Exception("More than one email column found after unnest")
3259
+ if len(cols) == 1:
3260
+ return cols[0]
3261
+
3262
+ @staticmethod
3263
+ def _get_hem_column(search_keys: Dict[str, SearchKey]) -> Optional[str]:
3264
+ cols = [col for col, t in search_keys.items() if t == SearchKey.HEM]
3265
+ if len(cols) > 1:
3266
+ raise Exception("More than one hem column found after unnest")
3267
+ if len(cols) == 1:
3268
+ return cols[0]
3269
+
3270
+ @staticmethod
3271
+ def _get_ip_column(search_keys: Dict[str, SearchKey]) -> Optional[str]:
3272
+ cols = [col for col, t in search_keys.items() if t == SearchKey.IP]
3273
+ if len(cols) > 1:
3274
+ raise Exception("More than one ip column found after unnest")
3275
+ if len(cols) == 1:
3276
+ return cols[0]
3277
+
3278
+ @staticmethod
3279
+ def _get_phone_column(search_keys: Dict[str, SearchKey]) -> Optional[str]:
2908
3280
  for col, t in search_keys.items():
2909
- if t == SearchKey.EMAIL:
3281
+ if t == SearchKey.PHONE:
2910
3282
  return col
2911
3283
 
2912
3284
  @staticmethod
2913
- def _get_hem_column(search_keys: Dict[str, SearchKey]) -> Optional[str]:
3285
+ def _get_country_column(search_keys: Dict[str, SearchKey]) -> Optional[str]:
2914
3286
  for col, t in search_keys.items():
2915
- if t == SearchKey.HEM:
3287
+ if t == SearchKey.COUNTRY:
2916
3288
  return col
2917
3289
 
2918
3290
  @staticmethod
2919
- def _get_phone_column(search_keys: Dict[str, SearchKey]) -> Optional[str]:
3291
+ def _get_postal_column(search_keys: Dict[str, SearchKey]) -> Optional[str]:
2920
3292
  for col, t in search_keys.items():
2921
- if t == SearchKey.PHONE:
3293
+ if t == SearchKey.POSTAL_CODE:
2922
3294
  return col
2923
3295
 
3296
+ def _explode_multiple_search_keys(
3297
+ self, df: pd.DataFrame, search_keys: Dict[str, SearchKey], columns_renaming: Dict[str, str]
3298
+ ) -> Tuple[pd.DataFrame, Dict[str, List[str]]]:
3299
+ # find groups of multiple search keys
3300
+ search_key_names_by_type: Dict[SearchKey, List[str]] = {}
3301
+ for key_name, key_type in search_keys.items():
3302
+ search_key_names_by_type[key_type] = search_key_names_by_type.get(key_type, []) + [key_name]
3303
+ search_key_names_by_type = {
3304
+ key_type: key_names for key_type, key_names in search_key_names_by_type.items() if len(key_names) > 1
3305
+ }
3306
+ if len(search_key_names_by_type) == 0:
3307
+ return df, {}
3308
+
3309
+ self.logger.info(f"Start exploding dataset by {search_key_names_by_type}. Size before: {len(df)}")
3310
+ multiple_keys_columns = [col for cols in search_key_names_by_type.values() for col in cols]
3311
+ other_columns = [col for col in df.columns if col not in multiple_keys_columns]
3312
+ exploded_dfs = []
3313
+ unnest_search_keys = {}
3314
+
3315
+ for key_type, key_names in search_key_names_by_type.items():
3316
+ new_search_key = f"upgini_{key_type.name.lower()}_unnest"
3317
+ exploded_df = pd.melt(
3318
+ df, id_vars=other_columns, value_vars=key_names, var_name=SEARCH_KEY_UNNEST, value_name=new_search_key
3319
+ )
3320
+ exploded_dfs.append(exploded_df)
3321
+ for old_key in key_names:
3322
+ del search_keys[old_key]
3323
+ search_keys[new_search_key] = key_type
3324
+ unnest_search_keys[new_search_key] = key_names
3325
+ columns_renaming[new_search_key] = new_search_key
3326
+
3327
+ df = pd.concat(exploded_dfs, ignore_index=True)
3328
+ self.logger.info(f"Finished explosion. Size after: {len(df)}")
3329
+ return df, unnest_search_keys
3330
+
2924
3331
  def __add_fit_system_record_id(
2925
- self, df: pd.DataFrame, meaning_types: Dict[str, FileColumnMeaningType], search_keys: Dict[str, SearchKey]
3332
+ self,
3333
+ df: pd.DataFrame,
3334
+ # meaning_types: Dict[str, FileColumnMeaningType],
3335
+ search_keys: Dict[str, SearchKey],
3336
+ id_name: str,
2926
3337
  ) -> pd.DataFrame:
2927
3338
  # save original order or rows
2928
3339
  original_index_name = df.index.name
@@ -2933,52 +3344,61 @@ class FeaturesEnricher(TransformerMixin):
2933
3344
 
2934
3345
  # order by date and idempotent order by other keys
2935
3346
  if self.cv not in [CVType.time_series, CVType.blocked_time_series]:
2936
- sort_exclude_columns = [original_order_name, ORIGINAL_INDEX, EVAL_SET_INDEX, TARGET, "__target"]
3347
+ sort_exclude_columns = [
3348
+ original_order_name,
3349
+ ORIGINAL_INDEX,
3350
+ EVAL_SET_INDEX,
3351
+ TARGET,
3352
+ "__target",
3353
+ ENTITY_SYSTEM_RECORD_ID,
3354
+ ]
2937
3355
  if DateTimeSearchKeyConverter.DATETIME_COL in df.columns:
2938
3356
  date_column = DateTimeSearchKeyConverter.DATETIME_COL
2939
- sort_exclude_columns.append(self._get_date_column(search_keys))
3357
+ sort_exclude_columns.append(SearchKey.find_key(search_keys, [SearchKey.DATE, SearchKey.DATETIME]))
2940
3358
  else:
2941
- date_column = self._get_date_column(search_keys)
3359
+ date_column = SearchKey.find_key(search_keys, [SearchKey.DATE, SearchKey.DATETIME])
2942
3360
  sort_columns = [date_column] if date_column is not None else []
2943
3361
 
3362
+ sorted_other_keys = sorted(search_keys, key=lambda x: str(search_keys.get(x)))
3363
+ sorted_other_keys = [k for k in sorted_other_keys if k not in sort_exclude_columns]
3364
+
2944
3365
  other_columns = sorted(
2945
3366
  [
2946
3367
  c
2947
3368
  for c in df.columns
2948
- if c not in sort_columns and c not in sort_exclude_columns and df[c].nunique() > 1
3369
+ if c not in sort_columns
3370
+ and c not in sorted_other_keys
3371
+ and c not in sort_exclude_columns
3372
+ and df[c].nunique() > 1
2949
3373
  ]
2950
- # [
2951
- # sk
2952
- # for sk, key_type in search_keys.items()
2953
- # if key_type not in [SearchKey.DATE, SearchKey.DATETIME]
2954
- # and sk in df.columns
2955
- # and df[sk].nunique() > 1 # don't use constant keys for hash
2956
- # ]
2957
3374
  )
2958
3375
 
3376
+ all_other_columns = sorted_other_keys + other_columns
3377
+
2959
3378
  search_keys_hash = "search_keys_hash"
2960
- if len(other_columns) > 0:
3379
+ if len(all_other_columns) > 0:
2961
3380
  sort_columns.append(search_keys_hash)
2962
- df[search_keys_hash] = pd.util.hash_pandas_object(df[other_columns], index=False)
3381
+ df[search_keys_hash] = pd.util.hash_pandas_object(df[all_other_columns], index=False)
2963
3382
 
2964
3383
  df = df.sort_values(by=sort_columns)
2965
3384
 
2966
3385
  if search_keys_hash in df.columns:
2967
3386
  df.drop(columns=search_keys_hash, inplace=True)
2968
3387
 
2969
- if DateTimeSearchKeyConverter.DATETIME_COL in df.columns:
2970
- df.drop(columns=DateTimeSearchKeyConverter.DATETIME_COL, inplace=True)
2971
-
2972
3388
  df = df.reset_index(drop=True).reset_index()
2973
3389
  # system_record_id saves correct order for fit
2974
- df = df.rename(columns={DEFAULT_INDEX: SYSTEM_RECORD_ID})
3390
+ df = df.rename(columns={DEFAULT_INDEX: id_name})
2975
3391
 
2976
3392
  # return original order
2977
3393
  df = df.set_index(ORIGINAL_INDEX)
2978
3394
  df.index.name = original_index_name
2979
3395
  df = df.sort_values(by=original_order_name).drop(columns=original_order_name)
2980
3396
 
2981
- meaning_types[SYSTEM_RECORD_ID] = FileColumnMeaningType.SYSTEM_RECORD_ID
3397
+ # meaning_types[id_name] = (
3398
+ # FileColumnMeaningType.SYSTEM_RECORD_ID
3399
+ # if id_name == SYSTEM_RECORD_ID
3400
+ # else FileColumnMeaningType.ENTITY_SYSTEM_RECORD_ID
3401
+ # )
2982
3402
  return df
2983
3403
 
2984
3404
  def __correct_target(self, df: pd.DataFrame) -> pd.DataFrame:
@@ -3033,7 +3453,11 @@ class FeaturesEnricher(TransformerMixin):
3033
3453
  )
3034
3454
 
3035
3455
  comparing_columns = X.columns if is_transform else df_with_original_index.columns
3036
- dup_features = [c for c in comparing_columns if c in result_features.columns and c != SYSTEM_RECORD_ID]
3456
+ dup_features = [
3457
+ c
3458
+ for c in comparing_columns
3459
+ if c in result_features.columns and c not in [SYSTEM_RECORD_ID, ENTITY_SYSTEM_RECORD_ID]
3460
+ ]
3037
3461
  if len(dup_features) > 0:
3038
3462
  self.logger.warning(f"X contain columns with same name as returned from backend: {dup_features}")
3039
3463
  raise ValidationError(self.bundle.get("returned_features_same_as_passed").format(dup_features))
@@ -3041,11 +3465,11 @@ class FeaturesEnricher(TransformerMixin):
3041
3465
  # index overrites from result_features
3042
3466
  original_index_name = df_with_original_index.index.name
3043
3467
  df_with_original_index = df_with_original_index.reset_index()
3468
+ # TODO drop system_record_id before merge
3044
3469
  result_features = pd.merge(
3045
3470
  df_with_original_index,
3046
3471
  result_features,
3047
- left_on=SYSTEM_RECORD_ID,
3048
- right_on=SYSTEM_RECORD_ID,
3472
+ on=ENTITY_SYSTEM_RECORD_ID,
3049
3473
  how="left" if is_transform else "inner",
3050
3474
  )
3051
3475
  result_features = result_features.set_index(original_index_name or DEFAULT_INDEX)
@@ -3053,10 +3477,12 @@ class FeaturesEnricher(TransformerMixin):
3053
3477
 
3054
3478
  if rows_to_drop is not None:
3055
3479
  self.logger.info(f"Before dropping target outliers size: {len(result_features)}")
3056
- result_features = result_features[~result_features[SYSTEM_RECORD_ID].isin(rows_to_drop[SYSTEM_RECORD_ID])]
3480
+ result_features = result_features[
3481
+ ~result_features[ENTITY_SYSTEM_RECORD_ID].isin(rows_to_drop[ENTITY_SYSTEM_RECORD_ID])
3482
+ ]
3057
3483
  self.logger.info(f"After dropping target outliers size: {len(result_features)}")
3058
3484
 
3059
- result_eval_sets = dict()
3485
+ result_eval_sets = {}
3060
3486
  if not is_transform and EVAL_SET_INDEX in result_features.columns:
3061
3487
  result_train_features = result_features.loc[result_features[EVAL_SET_INDEX] == 0].copy()
3062
3488
  eval_set_indices = list(result_features[EVAL_SET_INDEX].unique())
@@ -3086,16 +3512,17 @@ class FeaturesEnricher(TransformerMixin):
3086
3512
  result_train = result_train_features
3087
3513
 
3088
3514
  if drop_system_record_id:
3089
- if SYSTEM_RECORD_ID in result_train.columns:
3090
- result_train = result_train.drop(columns=SYSTEM_RECORD_ID)
3515
+ result_train = result_train.drop(columns=[SYSTEM_RECORD_ID, ENTITY_SYSTEM_RECORD_ID], errors="ignore")
3091
3516
  for eval_set_index in result_eval_sets.keys():
3092
- if SYSTEM_RECORD_ID in result_eval_sets[eval_set_index].columns:
3093
- result_eval_sets[eval_set_index] = result_eval_sets[eval_set_index].drop(columns=SYSTEM_RECORD_ID)
3517
+ result_eval_sets[eval_set_index] = result_eval_sets[eval_set_index].drop(
3518
+ columns=[SYSTEM_RECORD_ID, ENTITY_SYSTEM_RECORD_ID], errors="ignore"
3519
+ )
3094
3520
 
3095
3521
  return result_train, result_eval_sets
3096
3522
 
3097
- def __prepare_feature_importances(self, trace_id: str, x_columns: List[str], silent=False):
3098
- llm_source = "LLM with external data augmentation"
3523
+ def __prepare_feature_importances(
3524
+ self, trace_id: str, x_columns: List[str], updated_shaps: Optional[Dict[str, float]] = None, silent=False
3525
+ ):
3099
3526
  if self._search_task is None:
3100
3527
  raise NotFittedError(self.bundle.get("transform_unfitted_enricher"))
3101
3528
  features_meta = self._search_task.get_all_features_metadata_v2()
@@ -3106,122 +3533,44 @@ class FeaturesEnricher(TransformerMixin):
3106
3533
  features_df = self._search_task.get_all_initial_raw_features(trace_id, metrics_calculation=True)
3107
3534
 
3108
3535
  self.feature_names_ = []
3536
+ self.dropped_client_feature_names_ = []
3109
3537
  self.feature_importances_ = []
3110
3538
  features_info = []
3111
3539
  features_info_without_links = []
3112
3540
  internal_features_info = []
3113
3541
 
3114
- def round_shap_value(shap: float) -> float:
3115
- if shap > 0.0 and shap < 0.0001:
3116
- return 0.0001
3117
- else:
3118
- return round(shap, 4)
3119
-
3120
- def list_or_single(lst: List[str], single: str):
3121
- return lst or ([single] if single else [])
3122
-
3123
- def to_anchor(link: str, value: str) -> str:
3124
- if not value:
3125
- return ""
3126
- elif not link:
3127
- return value
3128
- elif value == llm_source:
3129
- return value
3130
- else:
3131
- return f"<a href='{link}' target='_blank' rel='noopener noreferrer'>{value}</a>"
3132
-
3133
- def make_links(names: List[str], links: List[str]):
3134
- all_links = [to_anchor(link, name) for name, link in itertools.zip_longest(names, links)]
3135
- return ",".join(all_links)
3542
+ if updated_shaps is not None:
3543
+ for fm in features_meta:
3544
+ fm.shap_value = updated_shaps.get(fm.name, 0.0)
3136
3545
 
3137
3546
  features_meta.sort(key=lambda m: (-m.shap_value, m.name))
3138
3547
  for feature_meta in features_meta:
3139
3548
  if feature_meta.name in original_names_dict.keys():
3140
3549
  feature_meta.name = original_names_dict[feature_meta.name]
3141
- # Use only enriched features
3550
+
3551
+ is_client_feature = feature_meta.name in x_columns
3552
+
3553
+ if feature_meta.shap_value == 0.0:
3554
+ if self.select_features:
3555
+ self.dropped_client_feature_names_.append(feature_meta.name)
3556
+ continue
3557
+
3558
+ # Use only important features
3142
3559
  if (
3143
- feature_meta.name in x_columns
3560
+ feature_meta.name in self.fit_generated_features
3144
3561
  or feature_meta.name == COUNTRY
3145
- or feature_meta.shap_value == 0.0
3146
- or feature_meta.name in self.fit_generated_features
3562
+ # In select_features mode we select also from etalon features and need to show them
3563
+ or (not self.select_features and is_client_feature)
3147
3564
  ):
3148
3565
  continue
3149
3566
 
3150
- feature_sample = []
3151
3567
  self.feature_names_.append(feature_meta.name)
3152
- self.feature_importances_.append(round_shap_value(feature_meta.shap_value))
3153
- if feature_meta.name in features_df.columns:
3154
- feature_sample = np.random.choice(features_df[feature_meta.name].dropna().unique(), 3).tolist()
3155
- if len(feature_sample) > 0 and isinstance(feature_sample[0], float):
3156
- feature_sample = [round(f, 4) for f in feature_sample]
3157
- feature_sample = [str(f) for f in feature_sample]
3158
- feature_sample = ", ".join(feature_sample)
3159
- if len(feature_sample) > 30:
3160
- feature_sample = feature_sample[:30] + "..."
3161
-
3162
- internal_provider = feature_meta.data_provider or "Upgini"
3163
- providers = list_or_single(feature_meta.data_providers, feature_meta.data_provider)
3164
- provider_links = list_or_single(feature_meta.data_provider_links, feature_meta.data_provider_link)
3165
- if providers:
3166
- provider = make_links(providers, provider_links)
3167
- else:
3168
- provider = to_anchor("https://upgini.com", "Upgini")
3568
+ self.feature_importances_.append(_round_shap_value(feature_meta.shap_value))
3169
3569
 
3170
- internal_source = feature_meta.data_source or (
3171
- llm_source
3172
- if not feature_meta.name.endswith("_country") and not feature_meta.name.endswith("_postal_code")
3173
- else ""
3174
- )
3175
- sources = list_or_single(feature_meta.data_sources, feature_meta.data_source)
3176
- source_links = list_or_single(feature_meta.data_source_links, feature_meta.data_source_link)
3177
- if sources:
3178
- source = make_links(sources, source_links)
3179
- else:
3180
- source = internal_source
3181
-
3182
- internal_feature_name = feature_meta.name
3183
- if feature_meta.doc_link:
3184
- feature_name = to_anchor(feature_meta.doc_link, feature_meta.name)
3185
- else:
3186
- feature_name = internal_feature_name
3187
-
3188
- features_info.append(
3189
- {
3190
- self.bundle.get("features_info_name"): feature_name,
3191
- self.bundle.get("features_info_shap"): round_shap_value(feature_meta.shap_value),
3192
- self.bundle.get("features_info_hitrate"): feature_meta.hit_rate,
3193
- self.bundle.get("features_info_value_preview"): feature_sample,
3194
- self.bundle.get("features_info_provider"): provider,
3195
- self.bundle.get("features_info_source"): source,
3196
- self.bundle.get("features_info_update_frequency"): feature_meta.update_frequency,
3197
- }
3198
- )
3199
- features_info_without_links.append(
3200
- {
3201
- self.bundle.get("features_info_name"): internal_feature_name,
3202
- self.bundle.get("features_info_shap"): round_shap_value(feature_meta.shap_value),
3203
- self.bundle.get("features_info_hitrate"): feature_meta.hit_rate,
3204
- self.bundle.get("features_info_value_preview"): feature_sample,
3205
- self.bundle.get("features_info_provider"): internal_provider,
3206
- self.bundle.get("features_info_source"): internal_source,
3207
- self.bundle.get("features_info_update_frequency"): feature_meta.update_frequency,
3208
- }
3209
- )
3210
- internal_features_info.append(
3211
- {
3212
- self.bundle.get("features_info_name"): internal_feature_name,
3213
- "feature_link": feature_meta.doc_link,
3214
- self.bundle.get("features_info_shap"): round_shap_value(feature_meta.shap_value),
3215
- self.bundle.get("features_info_hitrate"): feature_meta.hit_rate,
3216
- self.bundle.get("features_info_value_preview"): feature_sample,
3217
- self.bundle.get("features_info_provider"): internal_provider,
3218
- "provider_link": feature_meta.data_provider_link,
3219
- self.bundle.get("features_info_source"): internal_source,
3220
- "source_link": feature_meta.data_source_link,
3221
- self.bundle.get("features_info_commercial_schema"): feature_meta.commercial_schema or "",
3222
- self.bundle.get("features_info_update_frequency"): feature_meta.update_frequency,
3223
- }
3224
- )
3570
+ feature_info = FeatureInfo.from_metadata(feature_meta, features_df, is_client_feature)
3571
+ features_info.append(feature_info.to_row(self.bundle))
3572
+ features_info_without_links.append(feature_info.to_row_without_links(self.bundle))
3573
+ internal_features_info.append(feature_info.to_internal_row(self.bundle))
3225
3574
 
3226
3575
  if len(features_info) > 0:
3227
3576
  self.features_info = pd.DataFrame(features_info)
@@ -3246,7 +3595,22 @@ class FeaturesEnricher(TransformerMixin):
3246
3595
  autofe_meta = self._search_task.get_autofe_metadata()
3247
3596
  if autofe_meta is None:
3248
3597
  return None
3249
- features_meta = self._search_task.get_all_features_metadata_v2()
3598
+ if len(self._internal_features_info) != 0:
3599
+
3600
+ def to_feature_meta(row):
3601
+ fm = FeaturesMetadataV2(
3602
+ name=row[bundle.get("features_info_name")],
3603
+ type="",
3604
+ source="",
3605
+ hit_rate=row[bundle.get("features_info_hitrate")],
3606
+ shap_value=row[bundle.get("features_info_shap")],
3607
+ data_source=row[bundle.get("features_info_source")],
3608
+ )
3609
+ return fm
3610
+
3611
+ features_meta = self._internal_features_info.apply(to_feature_meta, axis=1).to_list()
3612
+ else:
3613
+ features_meta = self._search_task.get_all_features_metadata_v2()
3250
3614
 
3251
3615
  def get_feature_by_name(name: str):
3252
3616
  for m in features_meta:
@@ -3255,41 +3619,52 @@ class FeaturesEnricher(TransformerMixin):
3255
3619
 
3256
3620
  descriptions = []
3257
3621
  for m in autofe_meta:
3258
- autofe_feature = Feature.from_formula(m.formula)
3259
3622
  orig_to_hashed = {base_column.original_name: base_column.hashed_name for base_column in m.base_columns}
3260
- autofe_feature.rename_columns(orig_to_hashed)
3261
- autofe_feature.set_display_index(m.display_index)
3623
+
3624
+ autofe_feature = (
3625
+ Feature.from_formula(m.formula)
3626
+ .set_display_index(m.display_index)
3627
+ .set_alias(m.alias)
3628
+ .set_op_params(m.operator_params or {})
3629
+ .rename_columns(orig_to_hashed)
3630
+ )
3631
+
3262
3632
  if autofe_feature.op.is_vector:
3263
3633
  continue
3264
3634
 
3265
- description = dict()
3635
+ description = {}
3266
3636
 
3267
3637
  feature_meta = get_feature_by_name(autofe_feature.get_display_name(shorten=True))
3268
3638
  if feature_meta is None:
3269
3639
  self.logger.warning(f"Feature meta for display index {m.display_index} not found")
3270
3640
  continue
3271
3641
  description["shap"] = feature_meta.shap_value
3272
- description["Sources"] = feature_meta.data_source.replace("AutoFE: features from ", "").replace(
3273
- "AutoFE: feature from ", ""
3274
- )
3275
- description["Feature name"] = feature_meta.name
3642
+ description[self.bundle.get("autofe_descriptions_sources")] = feature_meta.data_source.replace(
3643
+ "AutoFE: features from ", ""
3644
+ ).replace("AutoFE: feature from ", "")
3645
+ description[self.bundle.get("autofe_descriptions_feature_name")] = feature_meta.name
3276
3646
 
3277
3647
  feature_idx = 1
3278
3648
  for bc in m.base_columns:
3279
- description[f"Feature {feature_idx}"] = bc.hashed_name
3649
+ description[self.bundle.get("autofe_descriptions_feature").format(feature_idx)] = bc.hashed_name
3280
3650
  feature_idx += 1
3281
3651
 
3282
- description["Function"] = autofe_feature.op.name
3652
+ description[self.bundle.get("autofe_descriptions_function")] = ",".join(
3653
+ sorted(autofe_feature.get_all_operand_names())
3654
+ )
3283
3655
 
3284
3656
  descriptions.append(description)
3285
3657
 
3286
3658
  if len(descriptions) == 0:
3287
3659
  return None
3288
3660
 
3289
- descriptions_df = pd.DataFrame(descriptions)
3290
- descriptions_df.fillna("", inplace=True)
3291
- descriptions_df.sort_values(by="shap", ascending=False, inplace=True)
3292
- descriptions_df.drop(columns="shap", inplace=True)
3661
+ descriptions_df = (
3662
+ pd.DataFrame(descriptions)
3663
+ .fillna("")
3664
+ .sort_values(by="shap", ascending=False)
3665
+ .drop(columns="shap")
3666
+ .reset_index(drop=True)
3667
+ )
3293
3668
  return descriptions_df
3294
3669
 
3295
3670
  except Exception:
@@ -3342,10 +3717,16 @@ class FeaturesEnricher(TransformerMixin):
3342
3717
  is_transform=False,
3343
3718
  silent_mode=False,
3344
3719
  ):
3720
+ for _, key_type in search_keys.items():
3721
+ if not isinstance(key_type, SearchKey):
3722
+ raise ValidationError(self.bundle.get("unsupported_type_of_search_key").format(key_type))
3723
+
3345
3724
  valid_search_keys = {}
3346
3725
  unsupported_search_keys = {
3347
3726
  SearchKey.IP_RANGE_FROM,
3348
3727
  SearchKey.IP_RANGE_TO,
3728
+ SearchKey.IPV6_RANGE_FROM,
3729
+ SearchKey.IPV6_RANGE_TO,
3349
3730
  SearchKey.MSISDN_RANGE_FROM,
3350
3731
  SearchKey.MSISDN_RANGE_TO,
3351
3732
  # SearchKey.EMAIL_ONE_DOMAIN,
@@ -3354,11 +3735,17 @@ class FeaturesEnricher(TransformerMixin):
3354
3735
  if len(passed_unsupported_search_keys) > 0:
3355
3736
  raise ValidationError(self.bundle.get("unsupported_search_key").format(passed_unsupported_search_keys))
3356
3737
 
3738
+ x_columns = [
3739
+ c
3740
+ for c in x.columns
3741
+ if c not in [TARGET, EVAL_SET_INDEX, SYSTEM_RECORD_ID, ENTITY_SYSTEM_RECORD_ID, SEARCH_KEY_UNNEST]
3742
+ ]
3743
+
3357
3744
  for column_id, meaning_type in search_keys.items():
3358
3745
  column_name = None
3359
3746
  if isinstance(column_id, str):
3360
3747
  if column_id not in x.columns:
3361
- raise ValidationError(self.bundle.get("search_key_not_found").format(column_id, list(x.columns)))
3748
+ raise ValidationError(self.bundle.get("search_key_not_found").format(column_id, x_columns))
3362
3749
  column_name = column_id
3363
3750
  valid_search_keys[column_name] = meaning_type
3364
3751
  elif isinstance(column_id, int):
@@ -3372,15 +3759,15 @@ class FeaturesEnricher(TransformerMixin):
3372
3759
  if meaning_type == SearchKey.COUNTRY and self.country_code is not None:
3373
3760
  msg = self.bundle.get("search_key_country_and_country_code")
3374
3761
  self.logger.warning(msg)
3375
- print(msg)
3762
+ if not silent_mode:
3763
+ self.__log_warning(msg)
3376
3764
  self.country_code = None
3377
3765
 
3378
3766
  if not self.__is_registered and not is_demo_dataset and meaning_type in SearchKey.personal_keys():
3379
3767
  msg = self.bundle.get("unregistered_with_personal_keys").format(meaning_type)
3380
3768
  self.logger.warning(msg)
3381
3769
  if not silent_mode:
3382
- self.warning_counter.increment()
3383
- print(msg)
3770
+ self.__log_warning(msg)
3384
3771
 
3385
3772
  valid_search_keys[column_name] = SearchKey.CUSTOM_KEY
3386
3773
  else:
@@ -3414,27 +3801,23 @@ class FeaturesEnricher(TransformerMixin):
3414
3801
  and not silent_mode
3415
3802
  ):
3416
3803
  msg = self.bundle.get("date_only_search")
3417
- print(msg)
3418
- self.logger.warning(msg)
3419
- self.warning_counter.increment()
3804
+ self.__log_warning(msg)
3420
3805
 
3421
3806
  maybe_date = [k for k, v in valid_search_keys.items() if v in [SearchKey.DATE, SearchKey.DATETIME]]
3422
3807
  if (self.cv is None or self.cv == CVType.k_fold) and len(maybe_date) > 0 and not silent_mode:
3423
3808
  date_column = next(iter(maybe_date))
3424
3809
  if x[date_column].nunique() > 0.9 * _num_samples(x):
3425
3810
  msg = self.bundle.get("date_search_without_time_series")
3426
- print(msg)
3427
- self.logger.warning(msg)
3428
- self.warning_counter.increment()
3811
+ self.__log_warning(msg)
3429
3812
 
3430
3813
  if len(valid_search_keys) == 1:
3431
- for k, v in valid_search_keys.items():
3432
- # Show warning for country only if country is the only key
3433
- if x[k].nunique() == 1 and (v != SearchKey.COUNTRY or len(valid_search_keys) == 1):
3434
- msg = self.bundle.get("single_constant_search_key").format(v, x[k].values[0])
3435
- print(msg)
3436
- self.logger.warning(msg)
3437
- self.warning_counter.increment()
3814
+ key, value = list(valid_search_keys.items())[0]
3815
+ # Show warning for country only if country is the only key
3816
+ if x[key].nunique() == 1:
3817
+ msg = self.bundle.get("single_constant_search_key").format(value, x[key].values[0])
3818
+ if not silent_mode:
3819
+ self.__log_warning(msg)
3820
+ # TODO maybe raise ValidationError
3438
3821
 
3439
3822
  self.logger.info(f"Prepared search keys: {valid_search_keys}")
3440
3823
 
@@ -3467,7 +3850,10 @@ class FeaturesEnricher(TransformerMixin):
3467
3850
  display_html_dataframe(self.metrics, self.metrics, msg)
3468
3851
 
3469
3852
  def __show_selected_features(self, search_keys: Dict[str, SearchKey]):
3470
- msg = self.bundle.get("features_info_header").format(len(self.feature_names_), list(search_keys.keys()))
3853
+ search_key_names = search_keys.keys()
3854
+ if self.fit_columns_renaming:
3855
+ search_key_names = [self.fit_columns_renaming.get(col, col) for col in search_key_names]
3856
+ msg = self.bundle.get("features_info_header").format(len(self.feature_names_), search_key_names)
3471
3857
 
3472
3858
  try:
3473
3859
  _ = get_ipython() # type: ignore
@@ -3475,27 +3861,29 @@ class FeaturesEnricher(TransformerMixin):
3475
3861
  print(Format.GREEN + Format.BOLD + msg + Format.END)
3476
3862
  self.logger.info(msg)
3477
3863
  if len(self.feature_names_) > 0:
3478
- display_html_dataframe(
3479
- self.features_info, self._features_info_without_links, self.bundle.get("relevant_features_header")
3864
+ self.features_info_display_handle = display_html_dataframe(
3865
+ self.features_info,
3866
+ self._features_info_without_links,
3867
+ self.bundle.get("relevant_features_header"),
3868
+ display_id="features_info",
3480
3869
  )
3481
3870
 
3482
- display_html_dataframe(
3871
+ self.data_sources_display_handle = display_html_dataframe(
3483
3872
  self.relevant_data_sources,
3484
3873
  self._relevant_data_sources_wo_links,
3485
3874
  self.bundle.get("relevant_data_sources_header"),
3875
+ display_id="data_sources",
3486
3876
  )
3487
3877
  else:
3488
3878
  msg = self.bundle.get("features_info_zero_important_features")
3489
- self.logger.warning(msg)
3490
- self.__display_support_link(msg)
3491
- self.warning_counter.increment()
3879
+ self.__log_warning(msg, show_support_link=True)
3492
3880
  except (ImportError, NameError):
3493
3881
  print(msg)
3494
3882
  print(self._internal_features_info)
3495
3883
 
3496
- def __show_report_button(self):
3884
+ def __show_report_button(self, display_id: Optional[str] = None, display_handle=None):
3497
3885
  try:
3498
- prepare_and_show_report(
3886
+ return prepare_and_show_report(
3499
3887
  relevant_features_df=self._features_info_without_links,
3500
3888
  relevant_datasources_df=self.relevant_data_sources,
3501
3889
  metrics_df=self.metrics,
@@ -3503,6 +3891,8 @@ class FeaturesEnricher(TransformerMixin):
3503
3891
  search_id=self._search_task.search_task_id,
3504
3892
  email=self.rest_client.get_current_email(),
3505
3893
  search_keys=[str(sk) for sk in self.search_keys.values()],
3894
+ display_id=display_id,
3895
+ display_handle=display_handle,
3506
3896
  )
3507
3897
  except Exception:
3508
3898
  pass
@@ -3544,65 +3934,70 @@ class FeaturesEnricher(TransformerMixin):
3544
3934
  def check_need_detect(search_key: SearchKey):
3545
3935
  return not is_transform or search_key in self.fit_search_keys.values()
3546
3936
 
3547
- if SearchKey.POSTAL_CODE not in search_keys.values() and check_need_detect(SearchKey.POSTAL_CODE):
3548
- maybe_key = PostalCodeSearchKeyDetector().get_search_key_column(sample)
3549
- if maybe_key is not None:
3550
- search_keys[maybe_key] = SearchKey.POSTAL_CODE
3551
- self.autodetected_search_keys[maybe_key] = SearchKey.POSTAL_CODE
3552
- self.logger.info(f"Autodetected search key POSTAL_CODE in column {maybe_key}")
3937
+ # if SearchKey.POSTAL_CODE not in search_keys.values() and check_need_detect(SearchKey.POSTAL_CODE):
3938
+ if check_need_detect(SearchKey.POSTAL_CODE):
3939
+ maybe_keys = PostalCodeSearchKeyDetector().get_search_key_columns(sample, search_keys)
3940
+ if maybe_keys:
3941
+ new_keys = {key: SearchKey.POSTAL_CODE for key in maybe_keys}
3942
+ search_keys.update(new_keys)
3943
+ self.autodetected_search_keys.update(new_keys)
3944
+ self.logger.info(f"Autodetected search key POSTAL_CODE in column {maybe_keys}")
3553
3945
  if not silent_mode:
3554
- print(self.bundle.get("postal_code_detected").format(maybe_key))
3946
+ print(self.bundle.get("postal_code_detected").format(maybe_keys))
3555
3947
 
3556
3948
  if (
3557
3949
  SearchKey.COUNTRY not in search_keys.values()
3558
3950
  and self.country_code is None
3559
3951
  and check_need_detect(SearchKey.COUNTRY)
3560
3952
  ):
3561
- maybe_key = CountrySearchKeyDetector().get_search_key_column(sample)
3562
- if maybe_key is not None:
3563
- search_keys[maybe_key] = SearchKey.COUNTRY
3564
- self.autodetected_search_keys[maybe_key] = SearchKey.COUNTRY
3953
+ maybe_key = CountrySearchKeyDetector().get_search_key_columns(sample, search_keys)
3954
+ if maybe_key:
3955
+ search_keys[maybe_key[0]] = SearchKey.COUNTRY
3956
+ self.autodetected_search_keys[maybe_key[0]] = SearchKey.COUNTRY
3565
3957
  self.logger.info(f"Autodetected search key COUNTRY in column {maybe_key}")
3566
3958
  if not silent_mode:
3567
3959
  print(self.bundle.get("country_detected").format(maybe_key))
3568
3960
 
3569
3961
  if (
3570
- SearchKey.EMAIL not in search_keys.values()
3571
- and SearchKey.HEM not in search_keys.values()
3962
+ # SearchKey.EMAIL not in search_keys.values()
3963
+ SearchKey.HEM not in search_keys.values()
3572
3964
  and check_need_detect(SearchKey.HEM)
3573
3965
  ):
3574
- maybe_key = EmailSearchKeyDetector().get_search_key_column(sample)
3575
- if maybe_key is not None and maybe_key not in search_keys.keys():
3966
+ maybe_keys = EmailSearchKeyDetector().get_search_key_columns(sample, search_keys)
3967
+ if maybe_keys:
3576
3968
  if self.__is_registered or is_demo_dataset:
3577
- search_keys[maybe_key] = SearchKey.EMAIL
3578
- self.autodetected_search_keys[maybe_key] = SearchKey.EMAIL
3579
- self.logger.info(f"Autodetected search key EMAIL in column {maybe_key}")
3969
+ new_keys = {key: SearchKey.EMAIL for key in maybe_keys}
3970
+ search_keys.update(new_keys)
3971
+ self.autodetected_search_keys.update(new_keys)
3972
+ self.logger.info(f"Autodetected search key EMAIL in column {maybe_keys}")
3580
3973
  if not silent_mode:
3581
- print(self.bundle.get("email_detected").format(maybe_key))
3974
+ print(self.bundle.get("email_detected").format(maybe_keys))
3582
3975
  else:
3583
3976
  self.logger.warning(
3584
- f"Autodetected search key EMAIL in column {maybe_key}. But not used because not registered user"
3977
+ f"Autodetected search key EMAIL in column {maybe_keys}."
3978
+ " But not used because not registered user"
3585
3979
  )
3586
3980
  if not silent_mode:
3587
- print(self.bundle.get("email_detected_not_registered").format(maybe_key))
3588
- self.warning_counter.increment()
3981
+ self.__log_warning(self.bundle.get("email_detected_not_registered").format(maybe_keys))
3589
3982
 
3590
- if SearchKey.PHONE not in search_keys.values() and check_need_detect(SearchKey.PHONE):
3591
- maybe_key = PhoneSearchKeyDetector().get_search_key_column(sample)
3592
- if maybe_key is not None and maybe_key not in search_keys.keys():
3983
+ # if SearchKey.PHONE not in search_keys.values() and check_need_detect(SearchKey.PHONE):
3984
+ if check_need_detect(SearchKey.PHONE):
3985
+ maybe_keys = PhoneSearchKeyDetector().get_search_key_columns(sample, search_keys)
3986
+ if maybe_keys:
3593
3987
  if self.__is_registered or is_demo_dataset:
3594
- search_keys[maybe_key] = SearchKey.PHONE
3595
- self.autodetected_search_keys[maybe_key] = SearchKey.PHONE
3596
- self.logger.info(f"Autodetected search key PHONE in column {maybe_key}")
3988
+ new_keys = {key: SearchKey.PHONE for key in maybe_keys}
3989
+ search_keys.update(new_keys)
3990
+ self.autodetected_search_keys.update(new_keys)
3991
+ self.logger.info(f"Autodetected search key PHONE in column {maybe_keys}")
3597
3992
  if not silent_mode:
3598
- print(self.bundle.get("phone_detected").format(maybe_key))
3993
+ print(self.bundle.get("phone_detected").format(maybe_keys))
3599
3994
  else:
3600
3995
  self.logger.warning(
3601
- f"Autodetected search key PHONE in column {maybe_key}. But not used because not registered user"
3996
+ f"Autodetected search key PHONE in column {maybe_keys}. "
3997
+ "But not used because not registered user"
3602
3998
  )
3603
3999
  if not silent_mode:
3604
- print(self.bundle.get("phone_detected_not_registered"))
3605
- self.warning_counter.increment()
4000
+ self.__log_warning(self.bundle.get("phone_detected_not_registered"))
3606
4001
 
3607
4002
  return search_keys
3608
4003
 
@@ -3624,21 +4019,19 @@ class FeaturesEnricher(TransformerMixin):
3624
4019
  half_train = round(len(train) / 2)
3625
4020
  part1 = train[:half_train]
3626
4021
  part2 = train[half_train:]
3627
- train_psi = calculate_psi(part1[self.TARGET_NAME], part2[self.TARGET_NAME])
3628
- if train_psi > 0.2:
3629
- self.warning_counter.increment()
3630
- msg = self.bundle.get("train_unstable_target").format(train_psi)
3631
- print(msg)
3632
- self.logger.warning(msg)
4022
+ train_psi_result = calculate_psi(part1[self.TARGET_NAME], part2[self.TARGET_NAME])
4023
+ if isinstance(train_psi_result, Exception):
4024
+ self.logger.exception("Failed to calculate train PSI", train_psi_result)
4025
+ elif train_psi_result > 0.2:
4026
+ self.__log_warning(self.bundle.get("train_unstable_target").format(train_psi_result))
3633
4027
 
3634
4028
  # 2. Check train-test PSI
3635
4029
  if eval1 is not None:
3636
- train_test_psi = calculate_psi(train[self.TARGET_NAME], eval1[self.TARGET_NAME])
3637
- if train_test_psi > 0.2:
3638
- self.warning_counter.increment()
3639
- msg = self.bundle.get("eval_unstable_target").format(train_test_psi)
3640
- print(msg)
3641
- self.logger.warning(msg)
4030
+ train_test_psi_result = calculate_psi(train[self.TARGET_NAME], eval1[self.TARGET_NAME])
4031
+ if isinstance(train_test_psi_result, Exception):
4032
+ self.logger.exception("Failed to calculate test PSI", train_test_psi_result)
4033
+ elif train_test_psi_result > 0.2:
4034
+ self.__log_warning(self.bundle.get("eval_unstable_target").format(train_test_psi_result))
3642
4035
 
3643
4036
  def _dump_python_libs(self):
3644
4037
  try:
@@ -3660,8 +4053,8 @@ class FeaturesEnricher(TransformerMixin):
3660
4053
  self.logger.warning(f"Showing support link: {link_text}")
3661
4054
  display(
3662
4055
  HTML(
3663
- f"""<br/>{link_text} <a href='{support_link}' target='_blank' rel='noopener noreferrer'>
3664
- here</a>"""
4056
+ f"""{link_text} <a href='{support_link}' target='_blank' rel='noopener noreferrer'>
4057
+ here</a><br/>"""
3665
4058
  )
3666
4059
  )
3667
4060
  except (ImportError, NameError):
@@ -3706,7 +4099,7 @@ class FeaturesEnricher(TransformerMixin):
3706
4099
  if y is not None:
3707
4100
  with open(f"{tmp_dir}/y.pickle", "wb") as y_file:
3708
4101
  pickle.dump(sample(y, xy_sample_index), y_file)
3709
- if eval_set:
4102
+ if eval_set and _num_samples(eval_set[0][0]) > 0:
3710
4103
  eval_xy_sample_index = rnd.randint(0, _num_samples(eval_set[0][0]), size=1000)
3711
4104
  with open(f"{tmp_dir}/eval_x.pickle", "wb") as eval_x_file:
3712
4105
  pickle.dump(sample(eval_set[0][0], eval_xy_sample_index), eval_x_file)
@@ -3797,6 +4190,8 @@ def hash_input(X: pd.DataFrame, y: Optional[pd.Series] = None, eval_set: Optiona
3797
4190
  if y is not None:
3798
4191
  hashed_objects.append(pd.util.hash_pandas_object(y, index=False).values)
3799
4192
  if eval_set is not None:
4193
+ if isinstance(eval_set, tuple):
4194
+ eval_set = [eval_set]
3800
4195
  for eval_X, eval_y in eval_set:
3801
4196
  hashed_objects.append(pd.util.hash_pandas_object(eval_X, index=False).values)
3802
4197
  hashed_objects.append(pd.util.hash_pandas_object(eval_y, index=False).values)