upgini 1.2.141__py3-none-any.whl → 1.2.142__py3-none-any.whl

This diff shows the changes between publicly available package versions as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release.


This version of upgini might be problematic.

@@ -42,6 +42,7 @@ from upgini.http import (
42
42
  get_rest_client,
43
43
  )
44
44
  from upgini.mdc import MDC
45
+ from upgini.mdc.context import get_mdc_fields
45
46
  from upgini.metadata import (
46
47
  COUNTRY,
47
48
  CURRENT_DATE_COL,
@@ -273,7 +274,7 @@ class FeaturesEnricher(TransformerMixin):
273
274
  self.X: pd.DataFrame | None = None
274
275
  self.y: pd.Series | None = None
275
276
  self.eval_set: list[tuple] | None = None
276
- self.autodetected_search_keys: dict[str, SearchKey] = {}
277
+ self.autodetected_search_keys: dict[str, SearchKey] = dict()
277
278
  self.imbalanced = False
278
279
  self.fit_select_features = True
279
280
  self.__cached_sampled_datasets: dict[str, tuple[pd.DataFrame, pd.DataFrame, pd.Series, dict, dict, dict]] = (
@@ -309,8 +310,8 @@ class FeaturesEnricher(TransformerMixin):
309
310
  search_task = SearchTask(search_id, rest_client=self.rest_client, logger=self.logger)
310
311
 
311
312
  print(self.bundle.get("search_by_task_id_start"))
312
- trace_id = time.time_ns()
313
- with MDC(correlation_id=trace_id):
313
+ trace_id = self._get_trace_id()
314
+ with MDC(correlation_id=trace_id, search_task_id=search_id):
314
315
  try:
315
316
  self.logger.debug(f"FeaturesEnricher created from existing search: {search_id}")
316
317
  self._search_task = search_task.poll_result(trace_id, quiet=True, check_fit=True)
@@ -318,7 +319,7 @@ class FeaturesEnricher(TransformerMixin):
318
319
  x_columns = [c.name for c in file_metadata.columns]
319
320
  self.fit_columns_renaming = {c.name: c.originalName for c in file_metadata.columns}
320
321
  df = pd.DataFrame(columns=x_columns)
321
- self.__prepare_feature_importances(trace_id, df, silent=True, update_selected_features=False)
322
+ self.__prepare_feature_importances(df, silent=True, update_selected_features=False)
322
323
  if print_loaded_report:
323
324
  self.__show_selected_features()
324
325
  # TODO validate search_keys with search_keys from file_metadata
@@ -487,7 +488,7 @@ class FeaturesEnricher(TransformerMixin):
487
488
  stability_agg_func: str, optional (default="max")
488
489
  Function to aggregate stability values. Can be "max", "min", "mean".
489
490
  """
490
- trace_id = time.time_ns()
491
+ trace_id = self._get_trace_id()
491
492
  if self.print_trace_id:
492
493
  print(f"https://app.datadoghq.eu/logs?query=%40correlation_id%3A{trace_id}")
493
494
  start_time = time.time()
@@ -522,10 +523,9 @@ class FeaturesEnricher(TransformerMixin):
522
523
  self.X = X
523
524
  self.y = y
524
525
  self.eval_set = self._check_eval_set(eval_set, X)
525
- self.dump_input(trace_id, X, y, self.eval_set)
526
+ self.dump_input(X, y, self.eval_set)
526
527
  self.__set_select_features(select_features)
527
528
  self.__inner_fit(
528
- trace_id,
529
529
  X,
530
530
  y,
531
531
  self.eval_set,
@@ -646,7 +646,7 @@ class FeaturesEnricher(TransformerMixin):
646
646
 
647
647
  self.warning_counter.reset()
648
648
  auto_fe_parameters = AutoFEParameters() if auto_fe_parameters is None else auto_fe_parameters
649
- trace_id = time.time_ns()
649
+ trace_id = self._get_trace_id()
650
650
  if self.print_trace_id:
651
651
  print(f"https://app.datadoghq.eu/logs?query=%40correlation_id%3A{trace_id}")
652
652
  start_time = time.time()
@@ -677,13 +677,12 @@ class FeaturesEnricher(TransformerMixin):
677
677
  self.y = y
678
678
  self.eval_set = self._check_eval_set(eval_set, X)
679
679
  self.__set_select_features(select_features)
680
- self.dump_input(trace_id, X, y, self.eval_set)
680
+ self.dump_input(X, y, self.eval_set)
681
681
 
682
682
  if _num_samples(drop_duplicates(X)) > Dataset.MAX_ROWS:
683
683
  raise ValidationError(self.bundle.get("dataset_too_many_rows_registered").format(Dataset.MAX_ROWS))
684
684
 
685
685
  self.__inner_fit(
686
- trace_id,
687
686
  X,
688
687
  y,
689
688
  self.eval_set,
@@ -738,7 +737,6 @@ class FeaturesEnricher(TransformerMixin):
738
737
  y,
739
738
  exclude_features_sources=exclude_features_sources,
740
739
  keep_input=keep_input,
741
- trace_id=trace_id,
742
740
  silent_mode=True,
743
741
  progress_bar=progress_bar,
744
742
  progress_callback=progress_callback,
@@ -753,7 +751,6 @@ class FeaturesEnricher(TransformerMixin):
753
751
  *args,
754
752
  exclude_features_sources: list[str] | None = None,
755
753
  keep_input: bool = True,
756
- trace_id: str | None = None,
757
754
  silent_mode=False,
758
755
  progress_bar: ProgressBar | None = None,
759
756
  progress_callback: Callable[[SearchProgress], Any] | None = None,
@@ -790,12 +787,12 @@ class FeaturesEnricher(TransformerMixin):
790
787
  progress_bar.progress = search_progress.to_progress_bar()
791
788
  if new_progress:
792
789
  progress_bar.display()
793
- trace_id = trace_id or time.time_ns()
790
+ trace_id = self._get_trace_id()
794
791
  if self.print_trace_id:
795
792
  print(f"https://app.datadoghq.eu/logs?query=%40correlation_id%3A{trace_id}")
796
793
  search_id = self.search_id or (self._search_task.search_task_id if self._search_task is not None else None)
797
794
  with MDC(correlation_id=trace_id, search_id=search_id):
798
- self.dump_input(trace_id, X)
795
+ self.dump_input(X)
799
796
  if len(args) > 0:
800
797
  msg = f"WARNING: Unsupported positional arguments for transform: {args}"
801
798
  self.logger.warning(msg)
@@ -808,7 +805,6 @@ class FeaturesEnricher(TransformerMixin):
808
805
  start_time = time.time()
809
806
  try:
810
807
  result, _, _, _ = self.__inner_transform(
811
- trace_id,
812
808
  X,
813
809
  y=y,
814
810
  exclude_features_sources=exclude_features_sources,
@@ -872,7 +868,6 @@ class FeaturesEnricher(TransformerMixin):
872
868
  estimator=None,
873
869
  exclude_features_sources: list[str] | None = None,
874
870
  remove_outliers_calc_metrics: bool | None = None,
875
- trace_id: str | None = None,
876
871
  internal_call: bool = False,
877
872
  progress_bar: ProgressBar | None = None,
878
873
  progress_callback: Callable[[SearchProgress], Any] | None = None,
@@ -910,7 +905,7 @@ class FeaturesEnricher(TransformerMixin):
910
905
  Dataframe with metrics calculated on train and validation datasets.
911
906
  """
912
907
 
913
- trace_id = trace_id or time.time_ns()
908
+ trace_id = self._get_trace_id()
914
909
  start_time = time.time()
915
910
  search_id = self.search_id or (self._search_task.search_task_id if self._search_task is not None else None)
916
911
  with MDC(correlation_id=trace_id, search_id=search_id):
@@ -943,7 +938,7 @@ class FeaturesEnricher(TransformerMixin):
943
938
  raise ValidationError(self.bundle.get("metrics_unfitted_enricher"))
944
939
 
945
940
  validated_X, validated_y, validated_eval_set = self._validate_train_eval(
946
- effective_X, effective_y, effective_eval_set, silent=internal_call
941
+ effective_X, effective_y, effective_eval_set
947
942
  )
948
943
 
949
944
  if self.X is None:
@@ -978,11 +973,13 @@ class FeaturesEnricher(TransformerMixin):
978
973
  self.__display_support_link(msg)
979
974
  return None
980
975
 
976
+ search_keys = self._get_fit_search_keys_with_original_names()
977
+
981
978
  cat_features_from_backend = self.__get_categorical_features()
982
979
  # Convert to original names
983
980
  cat_features_from_backend = [self.fit_columns_renaming.get(c, c) for c in cat_features_from_backend]
984
981
  client_cat_features, search_keys_for_metrics = self._get_and_validate_client_cat_features(
985
- estimator, validated_X, self.search_keys
982
+ estimator, validated_X, search_keys
986
983
  )
987
984
  # Exclude id columns from cat_features
988
985
  if self.id_columns and self.id_columns_encoder is not None:
@@ -1004,7 +1001,6 @@ class FeaturesEnricher(TransformerMixin):
1004
1001
  self.logger.info(f"Search keys for metrics: {search_keys_for_metrics}")
1005
1002
 
1006
1003
  prepared_data = self._get_cached_enriched_data(
1007
- trace_id=trace_id,
1008
1004
  X=X,
1009
1005
  y=y,
1010
1006
  eval_set=eval_set,
@@ -1257,7 +1253,7 @@ class FeaturesEnricher(TransformerMixin):
1257
1253
 
1258
1254
  if updating_shaps is not None:
1259
1255
  decoded_X = self._decode_id_columns(fitting_X)
1260
- self._update_shap_values(trace_id, decoded_X, updating_shaps, silent=not internal_call)
1256
+ self._update_shap_values(decoded_X, updating_shaps, silent=not internal_call)
1261
1257
 
1262
1258
  metrics_df = pd.DataFrame(metrics)
1263
1259
  mean_target_hdr = self.bundle.get("quality_metrics_mean_target_header")
@@ -1307,9 +1303,33 @@ class FeaturesEnricher(TransformerMixin):
1307
1303
  finally:
1308
1304
  self.logger.info(f"Calculating metrics elapsed time: {time.time() - start_time}")
1309
1305
 
1306
+ def _get_trace_id(self):
1307
+ if get_mdc_fields().get("correlation_id") is not None:
1308
+ return get_mdc_fields().get("correlation_id")
1309
+ return int(time.time() * 1000)
1310
+
1311
+ def _get_autodetected_search_keys(self):
1312
+ if self.autodetected_search_keys is None and self._search_task is not None:
1313
+ meta = self._search_task.get_file_metadata(self._get_trace_id())
1314
+ self.autodetected_search_keys = {k: SearchKey[v] for k, v in meta.autodetectedSearchKeys.items()}
1315
+
1316
+ return self.autodetected_search_keys
1317
+
1318
+ def _get_fit_search_keys_with_original_names(self):
1319
+ if self.fit_search_keys is None and self._search_task is not None:
1320
+ fit_search_keys = dict()
1321
+ meta = self._search_task.get_file_metadata(self._get_trace_id())
1322
+ for column in meta.columns:
1323
+ # TODO check for EMAIL->HEM and multikeys
1324
+ search_key_type = SearchKey.from_meaning_type(column.meaningType)
1325
+ if search_key_type is not None:
1326
+ fit_search_keys[column.originalName] = search_key_type
1327
+ else:
1328
+ fit_search_keys = {self.fit_columns_renaming.get(k, k): v for k, v in self.fit_search_keys.items()}
1329
+ return fit_search_keys
1330
+
1310
1331
  def _select_features_by_psi(
1311
1332
  self,
1312
- trace_id: str,
1313
1333
  X: pd.DataFrame | pd.Series | np.ndarray,
1314
1334
  y: pd.DataFrame | pd.Series | np.ndarray | list,
1315
1335
  eval_set: list[tuple] | tuple | None,
@@ -1322,7 +1342,8 @@ class FeaturesEnricher(TransformerMixin):
1322
1342
  progress_callback: Callable | None = None,
1323
1343
  ):
1324
1344
  search_keys = self.search_keys.copy()
1325
- validated_X, _, validated_eval_set = self._validate_train_eval(X, y, eval_set, silent=True)
1345
+ search_keys.update(self._get_autodetected_search_keys())
1346
+ validated_X, _, validated_eval_set = self._validate_train_eval(X, y, eval_set)
1326
1347
  if isinstance(X, np.ndarray):
1327
1348
  search_keys = {str(k): v for k, v in search_keys.items()}
1328
1349
 
@@ -1357,7 +1378,6 @@ class FeaturesEnricher(TransformerMixin):
1357
1378
  ]
1358
1379
 
1359
1380
  prepared_data = self._get_cached_enriched_data(
1360
- trace_id=trace_id,
1361
1381
  X=X,
1362
1382
  y=y,
1363
1383
  eval_set=eval_set,
@@ -1506,7 +1526,7 @@ class FeaturesEnricher(TransformerMixin):
1506
1526
 
1507
1527
  return total_unstable_features
1508
1528
 
1509
- def _update_shap_values(self, trace_id: str, df: pd.DataFrame, new_shaps: dict[str, float], silent: bool = False):
1529
+ def _update_shap_values(self, df: pd.DataFrame, new_shaps: dict[str, float], silent: bool = False):
1510
1530
  renaming = self.fit_columns_renaming or {}
1511
1531
  self.logger.info(f"Updating SHAP values: {new_shaps}")
1512
1532
  new_shaps = {
@@ -1514,7 +1534,7 @@ class FeaturesEnricher(TransformerMixin):
1514
1534
  for feature, shap in new_shaps.items()
1515
1535
  if feature in self.feature_names_ or renaming.get(feature, feature) in self.feature_names_
1516
1536
  }
1517
- self.__prepare_feature_importances(trace_id, df, new_shaps)
1537
+ self.__prepare_feature_importances(df, new_shaps)
1518
1538
 
1519
1539
  if not silent and self.features_info_display_handle is not None:
1520
1540
  try:
@@ -1694,7 +1714,6 @@ class FeaturesEnricher(TransformerMixin):
1694
1714
 
1695
1715
  def _get_cached_enriched_data(
1696
1716
  self,
1697
- trace_id: str,
1698
1717
  X: pd.DataFrame | pd.Series | np.ndarray | None = None,
1699
1718
  y: pd.DataFrame | pd.Series | np.ndarray | list | None = None,
1700
1719
  eval_set: list[tuple] | tuple | None = None,
@@ -1710,10 +1729,9 @@ class FeaturesEnricher(TransformerMixin):
1710
1729
  is_input_same_as_fit, X, y, eval_set = self._is_input_same_as_fit(X, y, eval_set)
1711
1730
  is_demo_dataset = hash_input(X, y, eval_set) in DEMO_DATASET_HASHES
1712
1731
  checked_eval_set = self._check_eval_set(eval_set, X)
1713
- validated_X, validated_y, validated_eval_set = self._validate_train_eval(X, y, checked_eval_set, silent=True)
1732
+ validated_X, validated_y, validated_eval_set = self._validate_train_eval(X, y, checked_eval_set)
1714
1733
 
1715
1734
  sampled_data = self._get_enriched_datasets(
1716
- trace_id=trace_id,
1717
1735
  validated_X=validated_X,
1718
1736
  validated_y=validated_y,
1719
1737
  eval_set=validated_eval_set,
@@ -1740,7 +1758,7 @@ class FeaturesEnricher(TransformerMixin):
1740
1758
 
1741
1759
  self.logger.info(f"Excluding search keys: {excluding_search_keys}")
1742
1760
 
1743
- file_meta = self._search_task.get_file_metadata(trace_id)
1761
+ file_meta = self._search_task.get_file_metadata(self._get_trace_id())
1744
1762
  fit_dropped_features = self.fit_dropped_features or file_meta.droppedColumns or []
1745
1763
  original_dropped_features = [columns_renaming.get(f, f) for f in fit_dropped_features]
1746
1764
 
@@ -1917,7 +1935,6 @@ class FeaturesEnricher(TransformerMixin):
1917
1935
 
1918
1936
  def _get_enriched_datasets(
1919
1937
  self,
1920
- trace_id: str,
1921
1938
  validated_X: pd.DataFrame | pd.Series | np.ndarray | None,
1922
1939
  validated_y: pd.DataFrame | pd.Series | np.ndarray | list | None,
1923
1940
  eval_set: list[tuple] | None,
@@ -1945,9 +1962,7 @@ class FeaturesEnricher(TransformerMixin):
1945
1962
  and self.df_with_original_index is not None
1946
1963
  ):
1947
1964
  self.logger.info("Dataset is not imbalanced, so use enriched_X from fit")
1948
- return self.__get_enriched_from_fit(
1949
- validated_X, validated_y, eval_set, trace_id, remove_outliers_calc_metrics
1950
- )
1965
+ return self.__get_enriched_from_fit(validated_X, validated_y, eval_set, remove_outliers_calc_metrics)
1951
1966
  else:
1952
1967
  self.logger.info(
1953
1968
  "Dataset is imbalanced or exclude_features_sources or X was passed or this is saved search."
@@ -1959,7 +1974,6 @@ class FeaturesEnricher(TransformerMixin):
1959
1974
  validated_y,
1960
1975
  eval_set,
1961
1976
  exclude_features_sources,
1962
- trace_id,
1963
1977
  progress_bar,
1964
1978
  progress_callback,
1965
1979
  is_for_metrics=is_for_metrics,
@@ -2088,7 +2102,6 @@ class FeaturesEnricher(TransformerMixin):
2088
2102
  validated_X: pd.DataFrame,
2089
2103
  validated_y: pd.Series,
2090
2104
  eval_set: list[tuple] | None,
2091
- trace_id: str,
2092
2105
  remove_outliers_calc_metrics: bool | None,
2093
2106
  ) -> _EnrichedDataForMetrics:
2094
2107
  eval_set_sampled_dict = {}
@@ -2103,7 +2116,7 @@ class FeaturesEnricher(TransformerMixin):
2103
2116
  if remove_outliers_calc_metrics is None:
2104
2117
  remove_outliers_calc_metrics = True
2105
2118
  if self.model_task_type == ModelTaskType.REGRESSION and remove_outliers_calc_metrics:
2106
- target_outliers_df = self._search_task.get_target_outliers(trace_id)
2119
+ target_outliers_df = self._search_task.get_target_outliers(self._get_trace_id())
2107
2120
  if target_outliers_df is not None and len(target_outliers_df) > 0:
2108
2121
  outliers = pd.merge(
2109
2122
  self.df_with_original_index,
@@ -2120,7 +2133,7 @@ class FeaturesEnricher(TransformerMixin):
2120
2133
 
2121
2134
  # index in each dataset (X, eval set) may be reordered and non unique, but index in validated datasets
2122
2135
  # can differs from it
2123
- fit_features = self._search_task.get_all_initial_raw_features(trace_id, metrics_calculation=True)
2136
+ fit_features = self._search_task.get_all_initial_raw_features(self._get_trace_id(), metrics_calculation=True)
2124
2137
 
2125
2138
  # Pre-process features if we need to drop outliers
2126
2139
  if rows_to_drop is not None:
@@ -2146,7 +2159,7 @@ class FeaturesEnricher(TransformerMixin):
2146
2159
  validated_Xy[TARGET] = validated_y
2147
2160
 
2148
2161
  selecting_columns = self._selecting_input_and_generated_columns(
2149
- validated_Xy, self.fit_generated_features, keep_input=True, trace_id=trace_id
2162
+ validated_Xy, self.fit_generated_features, keep_input=True
2150
2163
  )
2151
2164
  selecting_columns.extend(
2152
2165
  c
@@ -2211,7 +2224,6 @@ class FeaturesEnricher(TransformerMixin):
2211
2224
  validated_y: pd.Series,
2212
2225
  eval_set: list[tuple] | None,
2213
2226
  exclude_features_sources: list[str] | None,
2214
- trace_id: str,
2215
2227
  progress_bar: ProgressBar | None,
2216
2228
  progress_callback: Callable[[SearchProgress], Any] | None,
2217
2229
  is_for_metrics: bool = False,
@@ -2237,7 +2249,6 @@ class FeaturesEnricher(TransformerMixin):
2237
2249
 
2238
2250
  # Transform
2239
2251
  enriched_df, columns_renaming, generated_features, search_keys = self.__inner_transform(
2240
- trace_id,
2241
2252
  X=df.drop(columns=[TARGET]),
2242
2253
  y=df[TARGET],
2243
2254
  exclude_features_sources=exclude_features_sources,
@@ -2414,11 +2425,10 @@ class FeaturesEnricher(TransformerMixin):
2414
2425
 
2415
2426
  return self.features_info
2416
2427
 
2417
- def get_progress(self, trace_id: str | None = None, search_task: SearchTask | None = None) -> SearchProgress:
2428
+ def get_progress(self, search_task: SearchTask | None = None) -> SearchProgress:
2418
2429
  search_task = search_task or self._search_task
2419
2430
  if search_task is not None:
2420
- trace_id = trace_id or time.time_ns()
2421
- return search_task.get_progress(trace_id)
2431
+ return search_task.get_progress(self._get_trace_id())
2422
2432
 
2423
2433
  def display_transactional_transform_api(self, only_online_sources=False):
2424
2434
  if self.api_key is None:
@@ -2519,7 +2529,6 @@ if response.status_code == 200:
2519
2529
 
2520
2530
  def __inner_transform(
2521
2531
  self,
2522
- trace_id: str,
2523
2532
  X: pd.DataFrame,
2524
2533
  *,
2525
2534
  y: pd.Series | None = None,
@@ -2538,182 +2547,135 @@ if response.status_code == 200:
2538
2547
  raise NotFittedError(self.bundle.get("transform_unfitted_enricher"))
2539
2548
 
2540
2549
  start_time = time.time()
2541
- search_id = self.search_id or (self._search_task.search_task_id if self._search_task is not None else None)
2542
- with MDC(correlation_id=trace_id, search_id=search_id):
2543
- self.logger.info("Start transform")
2550
+ self.logger.info("Start transform")
2544
2551
 
2545
- validated_X, validated_y, validated_eval_set = self._validate_train_eval(
2546
- X, y, eval_set=None, is_transform=True, silent=True
2547
- )
2548
- df = self.__combine_train_and_eval_sets(validated_X, validated_y, validated_eval_set)
2552
+ search_keys = self.search_keys.copy()
2549
2553
 
2550
- validated_Xy = df.copy()
2554
+ self.__validate_search_keys(search_keys, self.search_id)
2551
2555
 
2552
- self.__log_debug_information(validated_X, validated_y, exclude_features_sources=exclude_features_sources)
2556
+ validated_X, validated_y, validated_eval_set = self._validate_train_eval(
2557
+ X, y, eval_set=None, is_transform=True
2558
+ )
2559
+ df = self.__combine_train_and_eval_sets(validated_X, validated_y, validated_eval_set)
2553
2560
 
2554
- # If there are no important features, return original dataframe
2555
- if len(self.feature_names_) == 0:
2556
- msg = self.bundle.get("no_important_features_for_transform")
2557
- self.__log_warning(msg, show_support_link=True)
2558
- return None, {}, [], self.search_keys
2561
+ validated_Xy = df.copy()
2559
2562
 
2560
- self.__validate_search_keys(self.search_keys, self.search_id)
2563
+ self.__log_debug_information(validated_X, validated_y, exclude_features_sources=exclude_features_sources)
2561
2564
 
2562
- if self._has_paid_features(exclude_features_sources):
2563
- msg = self.bundle.get("transform_with_paid_features")
2564
- self.logger.warning(msg)
2565
- self.__display_support_link(msg)
2566
- return None, {}, [], self.search_keys
2565
+ # If there are no important features, return original dataframe
2566
+ if len(self.feature_names_) == 0:
2567
+ msg = self.bundle.get("no_important_features_for_transform")
2568
+ self.__log_warning(msg, show_support_link=True)
2569
+ return None, {}, [], search_keys
2567
2570
 
2568
- online_api_features = [fm.name for fm in features_meta if fm.from_online_api and fm.shap_value > 0]
2569
- if len(online_api_features) > 0:
2570
- self.logger.warning(
2571
- f"There are important features for transform, that generated by online API: {online_api_features}"
2572
- )
2573
- msg = self.bundle.get("online_api_features_transform").format(online_api_features)
2574
- self.logger.warning(msg)
2575
- print(msg)
2576
- self.display_transactional_transform_api(only_online_sources=True)
2577
-
2578
- if not metrics_calculation:
2579
- transform_usage = self.rest_client.get_current_transform_usage(trace_id)
2580
- self.logger.info(f"Current transform usage: {transform_usage}. Transforming {len(X)} rows")
2581
- if transform_usage.has_limit:
2582
- if len(X) > transform_usage.rest_rows:
2583
- rest_rows = max(transform_usage.rest_rows, 0)
2584
- bundle_msg = (
2585
- "transform_usage_warning_registered"
2586
- if self.__is_registered
2587
- else "transform_usage_warning_demo"
2588
- )
2589
- msg = self.bundle.get(bundle_msg).format(rest_rows, len(X))
2590
- self.logger.warning(msg)
2591
- print(msg)
2592
- show_request_quote_button(is_registered=self.__is_registered)
2593
- return None, {}, [], {}
2594
- else:
2595
- msg = self.bundle.get("transform_usage_info").format(
2596
- transform_usage.limit, transform_usage.transformed_rows
2597
- )
2598
- self.logger.info(msg)
2599
- print(msg)
2571
+ if self._has_paid_features(exclude_features_sources):
2572
+ msg = self.bundle.get("transform_with_paid_features")
2573
+ self.logger.warning(msg)
2574
+ self.__display_support_link(msg)
2575
+ return None, {}, [], search_keys
2600
2576
 
2601
- is_demo_dataset = hash_input(df) in DEMO_DATASET_HASHES
2577
+ online_api_features = [fm.name for fm in features_meta if fm.from_online_api and fm.shap_value > 0]
2578
+ if len(online_api_features) > 0:
2579
+ self.logger.warning(
2580
+ f"There are important features for transform, that generated by online API: {online_api_features}"
2581
+ )
2582
+ msg = self.bundle.get("online_api_features_transform").format(online_api_features)
2583
+ self.logger.warning(msg)
2584
+ print(msg)
2585
+ self.display_transactional_transform_api(only_online_sources=True)
2586
+
2587
+ if not metrics_calculation:
2588
+ transform_usage = self.rest_client.get_current_transform_usage(self._get_trace_id())
2589
+ self.logger.info(f"Current transform usage: {transform_usage}. Transforming {len(X)} rows")
2590
+ if transform_usage.has_limit:
2591
+ if len(X) > transform_usage.rest_rows:
2592
+ rest_rows = max(transform_usage.rest_rows, 0)
2593
+ bundle_msg = (
2594
+ "transform_usage_warning_registered" if self.__is_registered else "transform_usage_warning_demo"
2595
+ )
2596
+ msg = self.bundle.get(bundle_msg).format(rest_rows, len(X))
2597
+ self.logger.warning(msg)
2598
+ print(msg)
2599
+ show_request_quote_button(is_registered=self.__is_registered)
2600
+ return None, {}, [], {}
2601
+ else:
2602
+ msg = self.bundle.get("transform_usage_info").format(
2603
+ transform_usage.limit, transform_usage.transformed_rows
2604
+ )
2605
+ self.logger.info(msg)
2606
+ print(msg)
2602
2607
 
2603
- columns_to_drop = [
2604
- c for c in df.columns if c in self.feature_names_ and c in self.external_source_feature_names
2605
- ]
2606
- if len(columns_to_drop) > 0:
2607
- msg = self.bundle.get("x_contains_enriching_columns").format(columns_to_drop)
2608
- self.logger.warning(msg)
2609
- print(msg)
2610
- df = df.drop(columns=columns_to_drop)
2608
+ is_demo_dataset = hash_input(df) in DEMO_DATASET_HASHES
2611
2609
 
2612
- search_keys = self.search_keys.copy()
2613
- if self.id_columns is not None and self.cv is not None and self.cv.is_time_series():
2614
- search_keys.update(
2615
- {col: SearchKey.CUSTOM_KEY for col in self.id_columns if col not in self.search_keys}
2616
- )
2610
+ columns_to_drop = [
2611
+ c for c in df.columns if c in self.feature_names_ and c in self.external_source_feature_names
2612
+ ]
2613
+ if len(columns_to_drop) > 0:
2614
+ msg = self.bundle.get("x_contains_enriching_columns").format(columns_to_drop)
2615
+ self.logger.warning(msg)
2616
+ print(msg)
2617
+ df = df.drop(columns=columns_to_drop)
2617
2618
 
2618
- search_keys = self.__prepare_search_keys(
2619
- df, search_keys, is_demo_dataset, is_transform=True, silent_mode=silent_mode
2620
- )
2619
+ if self.id_columns is not None and self.cv is not None and self.cv.is_time_series():
2620
+ search_keys.update({col: SearchKey.CUSTOM_KEY for col in self.id_columns if col not in search_keys})
2621
2621
 
2622
- df = self.__handle_index_search_keys(df, search_keys)
2622
+ search_keys = self.__prepare_search_keys(
2623
+ df, search_keys, is_demo_dataset, is_transform=True, silent_mode=silent_mode
2624
+ )
2623
2625
 
2624
- if DEFAULT_INDEX in df.columns:
2625
- msg = self.bundle.get("unsupported_index_column")
2626
- self.logger.info(msg)
2627
- print(msg)
2628
- df.drop(columns=DEFAULT_INDEX, inplace=True)
2629
- validated_Xy.drop(columns=DEFAULT_INDEX, inplace=True)
2626
+ df = self.__handle_index_search_keys(df, search_keys)
2630
2627
 
2631
- df = self.__add_country_code(df, search_keys)
2628
+ if DEFAULT_INDEX in df.columns:
2629
+ msg = self.bundle.get("unsupported_index_column")
2630
+ self.logger.info(msg)
2631
+ print(msg)
2632
+ df.drop(columns=DEFAULT_INDEX, inplace=True)
2633
+ validated_Xy.drop(columns=DEFAULT_INDEX, inplace=True)
2632
2634
 
2633
- generated_features = []
2634
- date_column = self._get_date_column(search_keys)
2635
- if date_column is not None:
2636
- converter = DateTimeConverter(
2637
- date_column,
2638
- self.date_format,
2639
- self.logger,
2640
- bundle=self.bundle,
2641
- generate_cyclical_features=self.generate_search_key_features,
2642
- )
2643
- df = converter.convert(df, keep_time=True)
2644
- self.logger.info(f"Date column after convertion: {df[date_column]}")
2645
- generated_features.extend(converter.generated_features)
2646
- else:
2647
- self.logger.info("Input dataset hasn't date column")
2648
- if self.__should_add_date_column():
2649
- df = self._add_current_date_as_key(df, search_keys, self.bundle, silent=True)
2650
-
2651
- email_columns = SearchKey.find_all_keys(search_keys, SearchKey.EMAIL)
2652
- if email_columns and self.generate_search_key_features:
2653
- generator = EmailDomainGenerator(email_columns)
2654
- df = generator.generate(df)
2655
- generated_features.extend(generator.generated_features)
2656
-
2657
- normalizer = Normalizer(self.bundle, self.logger)
2658
- df, search_keys, generated_features = normalizer.normalize(df, search_keys, generated_features)
2659
- columns_renaming = normalizer.columns_renaming
2660
-
2661
- # If there are no external features, we don't call backend on transform
2662
- external_features = [fm for fm in features_meta if fm.shap_value > 0 and fm.source != "etalon"]
2663
- if len(external_features) == 0:
2664
- self.logger.warning(
2665
- "No external features found, returning original dataframe"
2666
- f" with generated important features: {self.feature_names_}"
2667
- )
2668
- df = df.rename(columns=columns_renaming)
2669
- generated_features = [columns_renaming.get(c, c) for c in generated_features]
2670
- search_keys = {columns_renaming.get(c, c): t for c, t in search_keys.items()}
2671
- selecting_columns = self._selecting_input_and_generated_columns(
2672
- validated_Xy, generated_features, keep_input, trace_id, is_transform=True
2673
- )
2674
- self.logger.warning(f"Filtered columns by existance in dataframe: {selecting_columns}")
2675
- if add_fit_system_record_id:
2676
- df = self._add_fit_system_record_id(
2677
- df,
2678
- search_keys,
2679
- SYSTEM_RECORD_ID,
2680
- TARGET,
2681
- columns_renaming,
2682
- self.id_columns,
2683
- self.cv,
2684
- self.model_task_type,
2685
- self.logger,
2686
- self.bundle,
2687
- )
2688
- selecting_columns.append(SYSTEM_RECORD_ID)
2689
- return df[selecting_columns], columns_renaming, generated_features, search_keys
2690
-
2691
- # Don't pass all features in backend on transform
2692
- runtime_parameters = self._get_copy_of_runtime_parameters()
2693
- features_for_transform = self._search_task.get_features_for_transform()
2694
- if features_for_transform:
2695
- missing_features_for_transform = [
2696
- columns_renaming.get(f) or f for f in features_for_transform if f not in df.columns
2697
- ]
2698
- if TARGET in missing_features_for_transform:
2699
- raise ValidationError(self.bundle.get("missing_target_for_transform"))
2635
+ df = self.__add_country_code(df, search_keys)
2700
2636
 
2701
- if len(missing_features_for_transform) > 0:
2702
- raise ValidationError(
2703
- self.bundle.get("missing_features_for_transform").format(missing_features_for_transform)
2704
- )
2705
- features_for_embeddings = self._search_task.get_features_for_embeddings()
2706
- if features_for_embeddings:
2707
- runtime_parameters.properties["features_for_embeddings"] = ",".join(features_for_embeddings)
2708
- features_for_transform = [f for f in features_for_transform if f not in search_keys.keys()]
2637
+ generated_features = []
2638
+ date_column = self._get_date_column(search_keys)
2639
+ if date_column is not None:
2640
+ converter = DateTimeConverter(
2641
+ date_column,
2642
+ self.date_format,
2643
+ self.logger,
2644
+ bundle=self.bundle,
2645
+ generate_cyclical_features=self.generate_search_key_features,
2646
+ )
2647
+ df = converter.convert(df, keep_time=True)
2648
+ self.logger.info(f"Date column after convertion: {df[date_column]}")
2649
+ generated_features.extend(converter.generated_features)
2650
+ else:
2651
+ self.logger.info("Input dataset hasn't date column")
2652
+ if self.__should_add_date_column():
2653
+ df = self._add_current_date_as_key(df, search_keys, self.bundle, silent=True)
2709
2654
 
2710
- columns_for_system_record_id = sorted(list(search_keys.keys()) + features_for_transform)
2655
+ email_columns = SearchKey.find_all_keys(search_keys, SearchKey.EMAIL)
2656
+ if email_columns and self.generate_search_key_features:
2657
+ generator = EmailDomainGenerator(email_columns)
2658
+ df = generator.generate(df)
2659
+ generated_features.extend(generator.generated_features)
2711
2660
 
2712
- df[ENTITY_SYSTEM_RECORD_ID] = pd.util.hash_pandas_object(
2713
- df[columns_for_system_record_id], index=False
2714
- ).astype("float64")
2661
+ normalizer = Normalizer(self.bundle, self.logger)
2662
+ df, search_keys, generated_features = normalizer.normalize(df, search_keys, generated_features)
2663
+ columns_renaming = normalizer.columns_renaming
2715
2664
 
2716
- features_not_to_pass = []
2665
+ # If there are no external features, we don't call backend on transform
2666
+ external_features = [fm for fm in features_meta if fm.shap_value > 0 and fm.source != "etalon"]
2667
+ if len(external_features) == 0:
2668
+ self.logger.warning(
2669
+ "No external features found, returning original dataframe"
2670
+ f" with generated important features: {self.feature_names_}"
2671
+ )
2672
+ df = df.rename(columns=columns_renaming)
2673
+ generated_features = [columns_renaming.get(c, c) for c in generated_features]
2674
+ search_keys = {columns_renaming.get(c, c): t for c, t in search_keys.items()}
2675
+ selecting_columns = self._selecting_input_and_generated_columns(
2676
+ validated_Xy, generated_features, keep_input, is_transform=True
2677
+ )
2678
+ self.logger.warning(f"Filtered columns by existance in dataframe: {selecting_columns}")
2717
2679
  if add_fit_system_record_id:
2718
2680
  df = self._add_fit_system_record_id(
2719
2681
  df,
@@ -2727,252 +2689,294 @@ if response.status_code == 200:
2727
2689
  self.logger,
2728
2690
  self.bundle,
2729
2691
  )
2730
- df = df.rename(columns={SYSTEM_RECORD_ID: SORT_ID})
2731
- features_not_to_pass.append(SORT_ID)
2692
+ selecting_columns.append(SYSTEM_RECORD_ID)
2693
+ return df[selecting_columns], columns_renaming, generated_features, search_keys
2732
2694
 
2733
- system_columns_with_original_index = [ENTITY_SYSTEM_RECORD_ID] + generated_features
2734
- if add_fit_system_record_id:
2735
- system_columns_with_original_index.append(SORT_ID)
2695
+ # Don't pass all features in backend on transform
2696
+ runtime_parameters = self._get_copy_of_runtime_parameters()
2697
+ features_for_transform = self._search_task.get_features_for_transform()
2698
+ if features_for_transform:
2699
+ missing_features_for_transform = [
2700
+ columns_renaming.get(f) or f for f in features_for_transform if f not in df.columns
2701
+ ]
2702
+ if TARGET in missing_features_for_transform:
2703
+ raise ValidationError(self.bundle.get("missing_target_for_transform"))
2736
2704
 
2737
- df_before_explode = df[system_columns_with_original_index].copy()
2705
+ if len(missing_features_for_transform) > 0:
2706
+ raise ValidationError(
2707
+ self.bundle.get("missing_features_for_transform").format(missing_features_for_transform)
2708
+ )
2709
+ features_for_embeddings = self._search_task.get_features_for_embeddings()
2710
+ if features_for_embeddings:
2711
+ runtime_parameters.properties["features_for_embeddings"] = ",".join(features_for_embeddings)
2712
+ features_for_transform = [f for f in features_for_transform if f not in search_keys.keys()]
2738
2713
 
2739
- # Explode multiple search keys
2740
- df, unnest_search_keys = self._explode_multiple_search_keys(df, search_keys, columns_renaming)
2714
+ columns_for_system_record_id = sorted(list(search_keys.keys()) + features_for_transform)
2741
2715
 
2742
- # Convert search keys and generate features on them
2716
+ df[ENTITY_SYSTEM_RECORD_ID] = pd.util.hash_pandas_object(df[columns_for_system_record_id], index=False).astype(
2717
+ "float64"
2718
+ )
2743
2719
 
2744
- email_column = self._get_email_column(search_keys)
2745
- hem_column = self._get_hem_column(search_keys)
2746
- if email_column:
2747
- converter = EmailSearchKeyConverter(
2748
- email_column,
2749
- hem_column,
2750
- search_keys,
2751
- columns_renaming,
2752
- list(unnest_search_keys.keys()),
2753
- self.logger,
2754
- )
2755
- df = converter.convert(df)
2720
+ features_not_to_pass = []
2721
+ if add_fit_system_record_id:
2722
+ df = self._add_fit_system_record_id(
2723
+ df,
2724
+ search_keys,
2725
+ SYSTEM_RECORD_ID,
2726
+ TARGET,
2727
+ columns_renaming,
2728
+ self.id_columns,
2729
+ self.cv,
2730
+ self.model_task_type,
2731
+ self.logger,
2732
+ self.bundle,
2733
+ )
2734
+ df = df.rename(columns={SYSTEM_RECORD_ID: SORT_ID})
2735
+ features_not_to_pass.append(SORT_ID)
2756
2736
 
2757
- ip_column = self._get_ip_column(search_keys)
2758
- if ip_column:
2759
- converter = IpSearchKeyConverter(
2760
- ip_column,
2761
- search_keys,
2762
- columns_renaming,
2763
- list(unnest_search_keys.keys()),
2764
- self.bundle,
2765
- self.logger,
2766
- )
2767
- df = converter.convert(df)
2768
-
2769
- date_features = []
2770
- for col in features_for_transform:
2771
- if DateTimeConverter(col).is_datetime(df):
2772
- df[col] = DateTimeConverter(col).to_date_string(df)
2773
- date_features.append(col)
2774
-
2775
- meaning_types = {}
2776
- meaning_types.update(
2777
- {
2778
- col: FileColumnMeaningType.FEATURE
2779
- for col in features_for_transform
2780
- if col not in date_features and col not in generated_features
2781
- }
2737
+ system_columns_with_original_index = [ENTITY_SYSTEM_RECORD_ID] + generated_features
2738
+ if add_fit_system_record_id:
2739
+ system_columns_with_original_index.append(SORT_ID)
2740
+
2741
+ df_before_explode = df[system_columns_with_original_index].copy()
2742
+
2743
+ # Explode multiple search keys
2744
+ df, unnest_search_keys = self._explode_multiple_search_keys(df, search_keys, columns_renaming)
2745
+
2746
+ # Convert search keys and generate features on them
2747
+
2748
+ email_column = self._get_email_column(search_keys)
2749
+ hem_column = self._get_hem_column(search_keys)
2750
+ if email_column:
2751
+ converter = EmailSearchKeyConverter(
2752
+ email_column,
2753
+ hem_column,
2754
+ search_keys,
2755
+ columns_renaming,
2756
+ list(unnest_search_keys.keys()),
2757
+ self.logger,
2782
2758
  )
2783
- meaning_types.update({col: FileColumnMeaningType.GENERATED_FEATURE for col in generated_features})
2784
- meaning_types.update({col: FileColumnMeaningType.DATE_FEATURE for col in date_features})
2785
- meaning_types.update({col: key.value for col, key in search_keys.items()})
2759
+ df = converter.convert(df)
2786
2760
 
2787
- features_not_to_pass.extend(
2788
- [
2789
- c
2790
- for c in df.columns
2791
- if c not in search_keys.keys()
2792
- and c not in features_for_transform
2793
- and c not in [ENTITY_SYSTEM_RECORD_ID, SEARCH_KEY_UNNEST]
2794
- ]
2761
+ ip_column = self._get_ip_column(search_keys)
2762
+ if ip_column:
2763
+ converter = IpSearchKeyConverter(
2764
+ ip_column,
2765
+ search_keys,
2766
+ columns_renaming,
2767
+ list(unnest_search_keys.keys()),
2768
+ self.bundle,
2769
+ self.logger,
2795
2770
  )
2771
+ df = converter.convert(df)
2796
2772
 
2797
- if DateTimeConverter.DATETIME_COL in df.columns:
2798
- df = df.drop(columns=DateTimeConverter.DATETIME_COL)
2773
+ date_features = []
2774
+ for col in features_for_transform:
2775
+ if DateTimeConverter(col).is_datetime(df):
2776
+ df[col] = DateTimeConverter(col).to_date_string(df)
2777
+ date_features.append(col)
2799
2778
 
2800
- # search keys might be changed after explode
2801
- columns_for_system_record_id = sorted(list(search_keys.keys()) + features_for_transform)
2802
- df[SYSTEM_RECORD_ID] = pd.util.hash_pandas_object(df[columns_for_system_record_id], index=False).astype(
2803
- "float64"
2804
- )
2805
- meaning_types[SYSTEM_RECORD_ID] = FileColumnMeaningType.SYSTEM_RECORD_ID
2806
- meaning_types[ENTITY_SYSTEM_RECORD_ID] = FileColumnMeaningType.ENTITY_SYSTEM_RECORD_ID
2807
- if SEARCH_KEY_UNNEST in df.columns:
2808
- meaning_types[SEARCH_KEY_UNNEST] = FileColumnMeaningType.UNNEST_KEY
2779
+ meaning_types = {}
2780
+ meaning_types.update(
2781
+ {
2782
+ col: FileColumnMeaningType.FEATURE
2783
+ for col in features_for_transform
2784
+ if col not in date_features and col not in generated_features
2785
+ }
2786
+ )
2787
+ meaning_types.update({col: FileColumnMeaningType.GENERATED_FEATURE for col in generated_features})
2788
+ meaning_types.update({col: FileColumnMeaningType.DATE_FEATURE for col in date_features})
2789
+ meaning_types.update({col: key.value for col, key in search_keys.items()})
2809
2790
 
2810
- df = df.reset_index(drop=True)
2791
+ features_not_to_pass.extend(
2792
+ [
2793
+ c
2794
+ for c in df.columns
2795
+ if c not in search_keys.keys()
2796
+ and c not in features_for_transform
2797
+ and c not in [ENTITY_SYSTEM_RECORD_ID, SEARCH_KEY_UNNEST]
2798
+ ]
2799
+ )
2811
2800
 
2812
- combined_search_keys = combine_search_keys(search_keys.keys())
2801
+ if DateTimeConverter.DATETIME_COL in df.columns:
2802
+ df = df.drop(columns=DateTimeConverter.DATETIME_COL)
2813
2803
 
2814
- df_without_features = df.drop(columns=features_not_to_pass, errors="ignore")
2804
+ # search keys might be changed after explode
2805
+ columns_for_system_record_id = sorted(list(search_keys.keys()) + features_for_transform)
2806
+ df[SYSTEM_RECORD_ID] = pd.util.hash_pandas_object(df[columns_for_system_record_id], index=False).astype(
2807
+ "float64"
2808
+ )
2809
+ meaning_types[SYSTEM_RECORD_ID] = FileColumnMeaningType.SYSTEM_RECORD_ID
2810
+ meaning_types[ENTITY_SYSTEM_RECORD_ID] = FileColumnMeaningType.ENTITY_SYSTEM_RECORD_ID
2811
+ if SEARCH_KEY_UNNEST in df.columns:
2812
+ meaning_types[SEARCH_KEY_UNNEST] = FileColumnMeaningType.UNNEST_KEY
2815
2813
 
2816
- df_without_features, full_duplicates_warning = clean_full_duplicates(
2817
- df_without_features, is_transform=True, logger=self.logger, bundle=self.bundle
2818
- )
2819
- if not silent_mode and full_duplicates_warning:
2820
- self.__log_warning(full_duplicates_warning)
2814
+ df = df.reset_index(drop=True)
2821
2815
 
2822
- del df
2823
- gc.collect()
2824
-
2825
- dataset = Dataset(
2826
- "sample_" + str(uuid.uuid4()),
2827
- df=df_without_features,
2828
- meaning_types=meaning_types,
2829
- search_keys=combined_search_keys,
2830
- unnest_search_keys=unnest_search_keys,
2831
- id_columns=self.__get_renamed_id_columns(columns_renaming),
2832
- date_column=self._get_date_column(search_keys),
2833
- date_format=self.date_format,
2834
- sample_config=self.sample_config,
2835
- rest_client=self.rest_client,
2836
- logger=self.logger,
2837
- bundle=self.bundle,
2838
- warning_callback=self.__log_warning,
2839
- )
2840
- dataset.columns_renaming = columns_renaming
2841
-
2842
- validation_task = self._search_task.validation(
2843
- trace_id,
2844
- dataset,
2845
- start_time=start_time,
2846
- extract_features=True,
2847
- runtime_parameters=runtime_parameters,
2848
- exclude_features_sources=exclude_features_sources,
2849
- metrics_calculation=metrics_calculation,
2850
- silent_mode=silent_mode,
2851
- progress_bar=progress_bar,
2852
- progress_callback=progress_callback,
2853
- )
2816
+ combined_search_keys = combine_search_keys(search_keys.keys())
2854
2817
 
2855
- del df_without_features, dataset
2856
- gc.collect()
2818
+ df_without_features = df.drop(columns=features_not_to_pass, errors="ignore")
2857
2819
 
2858
- if not silent_mode:
2859
- print(self.bundle.get("polling_transform_task").format(validation_task.search_task_id))
2860
- if not self.__is_registered:
2861
- print(self.bundle.get("polling_unregister_information"))
2820
+ df_without_features, full_duplicates_warning = clean_full_duplicates(
2821
+ df_without_features, is_transform=True, logger=self.logger, bundle=self.bundle
2822
+ )
2823
+ if not silent_mode and full_duplicates_warning:
2824
+ self.__log_warning(full_duplicates_warning)
2862
2825
 
2863
- progress = self.get_progress(trace_id, validation_task)
2864
- progress.recalculate_eta(time.time() - start_time)
2865
- if progress_bar is not None:
2866
- progress_bar.progress = progress.to_progress_bar()
2867
- if progress_callback is not None:
2868
- progress_callback(progress)
2869
- prev_progress: SearchProgress | None = None
2870
- polling_period_seconds = 1
2871
- try:
2872
- while progress.stage != ProgressStage.DOWNLOADING.value:
2873
- if prev_progress is None or prev_progress.percent != progress.percent:
2874
- progress.recalculate_eta(time.time() - start_time)
2875
- else:
2876
- progress.update_eta(prev_progress.eta - polling_period_seconds)
2877
- prev_progress = progress
2878
- if progress_bar is not None:
2879
- progress_bar.progress = progress.to_progress_bar()
2880
- if progress_callback is not None:
2881
- progress_callback(progress)
2882
- if progress.stage == ProgressStage.FAILED.value:
2883
- raise Exception(progress.error_message)
2884
- time.sleep(polling_period_seconds)
2885
- progress = self.get_progress(trace_id, validation_task)
2886
- except KeyboardInterrupt as e:
2887
- print(self.bundle.get("search_stopping"))
2888
- self.rest_client.stop_search_task_v2(trace_id, validation_task.search_task_id)
2889
- self.logger.warning(f"Search {validation_task.search_task_id} stopped by user")
2890
- print(self.bundle.get("search_stopped"))
2891
- raise e
2826
+ del df
2827
+ gc.collect()
2892
2828
 
2893
- validation_task.poll_result(trace_id, quiet=True)
2829
+ dataset = Dataset(
2830
+ "sample_" + str(uuid.uuid4()),
2831
+ df=df_without_features,
2832
+ meaning_types=meaning_types,
2833
+ search_keys=combined_search_keys,
2834
+ unnest_search_keys=unnest_search_keys,
2835
+ id_columns=self.__get_renamed_id_columns(columns_renaming),
2836
+ date_column=self._get_date_column(search_keys),
2837
+ date_format=self.date_format,
2838
+ sample_config=self.sample_config,
2839
+ rest_client=self.rest_client,
2840
+ logger=self.logger,
2841
+ bundle=self.bundle,
2842
+ warning_callback=self.__log_warning,
2843
+ )
2844
+ dataset.columns_renaming = columns_renaming
2894
2845
 
2895
- seconds_left = time.time() - start_time
2896
- progress = SearchProgress(97.0, ProgressStage.DOWNLOADING, seconds_left)
2897
- if progress_bar is not None:
2898
- progress_bar.progress = progress.to_progress_bar()
2899
- if progress_callback is not None:
2900
- progress_callback(progress)
2846
+ validation_task = self._search_task.validation(
2847
+ self._get_trace_id(),
2848
+ dataset,
2849
+ start_time=start_time,
2850
+ extract_features=True,
2851
+ runtime_parameters=runtime_parameters,
2852
+ exclude_features_sources=exclude_features_sources,
2853
+ metrics_calculation=metrics_calculation,
2854
+ silent_mode=silent_mode,
2855
+ progress_bar=progress_bar,
2856
+ progress_callback=progress_callback,
2857
+ )
2901
2858
 
2902
- if not silent_mode:
2903
- print(self.bundle.get("transform_start"))
2859
+ del df_without_features, dataset
2860
+ gc.collect()
2904
2861
 
2905
- # Prepare input DataFrame for __enrich by concatenating generated ids and client features
2906
- df_before_explode = df_before_explode.rename(columns=columns_renaming)
2907
- generated_features = [columns_renaming.get(c, c) for c in generated_features]
2908
- combined_df = pd.concat(
2909
- [
2910
- validated_Xy.reset_index(drop=True),
2911
- df_before_explode.reset_index(drop=True),
2912
- ],
2913
- axis=1,
2914
- ).set_index(validated_Xy.index)
2915
-
2916
- result_features = validation_task.get_all_validation_raw_features(trace_id, metrics_calculation)
2917
-
2918
- result = self.__enrich(
2919
- combined_df,
2920
- result_features,
2921
- how="left",
2922
- )
2862
+ if not silent_mode:
2863
+ print(self.bundle.get("polling_transform_task").format(validation_task.search_task_id))
2864
+ if not self.__is_registered:
2865
+ print(self.bundle.get("polling_unregister_information"))
2923
2866
 
2924
- selecting_columns = self._selecting_input_and_generated_columns(
2925
- validated_Xy, generated_features, keep_input, trace_id, is_transform=True
2926
- )
2927
- selecting_columns.extend(
2928
- c
2929
- for c in result.columns
2930
- if c in self.feature_names_ and c not in selecting_columns and c not in validated_Xy.columns
2931
- )
2932
- if add_fit_system_record_id:
2933
- selecting_columns.append(SORT_ID)
2867
+ progress = self.get_progress(validation_task)
2868
+ progress.recalculate_eta(time.time() - start_time)
2869
+ if progress_bar is not None:
2870
+ progress_bar.progress = progress.to_progress_bar()
2871
+ if progress_callback is not None:
2872
+ progress_callback(progress)
2873
+ prev_progress: SearchProgress | None = None
2874
+ polling_period_seconds = 1
2875
+ try:
2876
+ while progress.stage != ProgressStage.DOWNLOADING.value:
2877
+ if prev_progress is None or prev_progress.percent != progress.percent:
2878
+ progress.recalculate_eta(time.time() - start_time)
2879
+ else:
2880
+ progress.update_eta(prev_progress.eta - polling_period_seconds)
2881
+ prev_progress = progress
2882
+ if progress_bar is not None:
2883
+ progress_bar.progress = progress.to_progress_bar()
2884
+ if progress_callback is not None:
2885
+ progress_callback(progress)
2886
+ if progress.stage == ProgressStage.FAILED.value:
2887
+ raise Exception(progress.error_message)
2888
+ time.sleep(polling_period_seconds)
2889
+ progress = self.get_progress(validation_task)
2890
+ except KeyboardInterrupt as e:
2891
+ print(self.bundle.get("search_stopping"))
2892
+ self.rest_client.stop_search_task_v2(self._get_trace_id(), validation_task.search_task_id)
2893
+ self.logger.warning(f"Search {validation_task.search_task_id} stopped by user")
2894
+ print(self.bundle.get("search_stopped"))
2895
+ raise e
2896
+
2897
+ validation_task.poll_result(self._get_trace_id(), quiet=True)
2898
+
2899
+ seconds_left = time.time() - start_time
2900
+ progress = SearchProgress(97.0, ProgressStage.DOWNLOADING, seconds_left)
2901
+ if progress_bar is not None:
2902
+ progress_bar.progress = progress.to_progress_bar()
2903
+ if progress_callback is not None:
2904
+ progress_callback(progress)
2934
2905
 
2935
- selecting_columns = list(set(selecting_columns))
2936
- # sorting: first columns from X, then generated features, then enriched features
2937
- sorted_selecting_columns = [c for c in validated_Xy.columns if c in selecting_columns]
2938
- for c in generated_features:
2939
- if c in selecting_columns and c not in sorted_selecting_columns:
2940
- sorted_selecting_columns.append(c)
2941
- for c in result.columns:
2942
- if c in selecting_columns and c not in sorted_selecting_columns:
2943
- sorted_selecting_columns.append(c)
2906
+ if not silent_mode:
2907
+ print(self.bundle.get("transform_start"))
2944
2908
 
2945
- self.logger.info(f"Transform sorted_selecting_columns: {sorted_selecting_columns}")
2909
+ # Prepare input DataFrame for __enrich by concatenating generated ids and client features
2910
+ df_before_explode = df_before_explode.rename(columns=columns_renaming)
2911
+ generated_features = [columns_renaming.get(c, c) for c in generated_features]
2912
+ combined_df = pd.concat(
2913
+ [
2914
+ validated_Xy.reset_index(drop=True),
2915
+ df_before_explode.reset_index(drop=True),
2916
+ ],
2917
+ axis=1,
2918
+ ).set_index(validated_Xy.index)
2919
+
2920
+ result_features = validation_task.get_all_validation_raw_features(self._get_trace_id(), metrics_calculation)
2921
+
2922
+ result = self.__enrich(
2923
+ combined_df,
2924
+ result_features,
2925
+ how="left",
2926
+ )
2927
+
2928
+ selecting_columns = self._selecting_input_and_generated_columns(
2929
+ validated_Xy, generated_features, keep_input, is_transform=True
2930
+ )
2931
+ selecting_columns.extend(
2932
+ c
2933
+ for c in result.columns
2934
+ if c in self.feature_names_ and c not in selecting_columns and c not in validated_Xy.columns
2935
+ )
2936
+ if add_fit_system_record_id:
2937
+ selecting_columns.append(SORT_ID)
2946
2938
 
2947
- result = result[sorted_selecting_columns]
2939
+ selecting_columns = list(set(selecting_columns))
2940
+ # sorting: first columns from X, then generated features, then enriched features
2941
+ sorted_selecting_columns = [c for c in validated_Xy.columns if c in selecting_columns]
2942
+ for c in generated_features:
2943
+ if c in selecting_columns and c not in sorted_selecting_columns:
2944
+ sorted_selecting_columns.append(c)
2945
+ for c in result.columns:
2946
+ if c in selecting_columns and c not in sorted_selecting_columns:
2947
+ sorted_selecting_columns.append(c)
2948
2948
 
2949
- if self.country_added:
2950
- result = result.drop(columns=COUNTRY, errors="ignore")
2949
+ self.logger.info(f"Transform sorted_selecting_columns: {sorted_selecting_columns}")
2951
2950
 
2952
- if add_fit_system_record_id:
2953
- result = result.rename(columns={SORT_ID: SYSTEM_RECORD_ID})
2951
+ result = result[sorted_selecting_columns]
2952
+
2953
+ if self.country_added:
2954
+ result = result.drop(columns=COUNTRY, errors="ignore")
2954
2955
 
2955
- for c in result.columns:
2956
- if result[c].dtype == "category":
2957
- result.loc[:, c] = np.where(~result[c].isin(result[c].dtype.categories), np.nan, result[c])
2956
+ if add_fit_system_record_id:
2957
+ result = result.rename(columns={SORT_ID: SYSTEM_RECORD_ID})
2958
2958
 
2959
- return result, columns_renaming, generated_features, search_keys
2959
+ for c in result.columns:
2960
+ if result[c].dtype == "category":
2961
+ result.loc[:, c] = np.where(~result[c].isin(result[c].dtype.categories), np.nan, result[c])
2962
+
2963
+ return result, columns_renaming, generated_features, search_keys
2960
2964
 
2961
2965
  def _selecting_input_and_generated_columns(
2962
2966
  self,
2963
2967
  validated_Xy: pd.DataFrame,
2964
2968
  generated_features: list[str],
2965
2969
  keep_input: bool,
2966
- trace_id: str,
2967
2970
  is_transform: bool = False,
2968
2971
  ):
2969
- file_meta = self._search_task.get_file_metadata(trace_id)
2972
+ file_meta = self._search_task.get_file_metadata(self._get_trace_id())
2970
2973
  fit_dropped_features = self.fit_dropped_features or file_meta.droppedColumns or []
2971
2974
  fit_input_columns = [c.originalName for c in file_meta.columns]
2972
2975
  original_dropped_features = [self.fit_columns_renaming.get(c, c) for c in fit_dropped_features]
2973
2976
  new_columns_on_transform = [
2974
2977
  c for c in validated_Xy.columns if c not in fit_input_columns and c not in original_dropped_features
2975
2978
  ]
2979
+ fit_original_search_keys = self._get_fit_search_keys_with_original_names()
2976
2980
 
2977
2981
  selected_generated_features = [c for c in generated_features if c in self.feature_names_]
2978
2982
  if keep_input is True:
@@ -2982,7 +2986,7 @@ if response.status_code == 200:
2982
2986
  if not self.fit_select_features
2983
2987
  or c in self.feature_names_
2984
2988
  or (c in new_columns_on_transform and is_transform)
2985
- or c in self.search_keys
2989
+ or c in fit_original_search_keys
2986
2990
  or c in (self.id_columns or [])
2987
2991
  or c in [EVAL_SET_INDEX, TARGET] # transform for metrics calculation
2988
2992
  or c == self.baseline_score_column
@@ -3067,7 +3071,6 @@

  def __inner_fit(
  self,
- trace_id: str,
  X: pd.DataFrame | pd.Series | np.ndarray,
  y: pd.DataFrame | pd.Series | np.ndarray | list | None,
  eval_set: list[tuple] | None,
@@ -3149,6 +3152,8 @@
  df = self.__handle_index_search_keys(df, self.fit_search_keys)
  self.fit_search_keys = self.__prepare_search_keys(df, self.fit_search_keys, is_demo_dataset)

+ df = self._validate_OOT(df, self.fit_search_keys)
+
  maybe_date_column = SearchKey.find_key(self.fit_search_keys, [SearchKey.DATE, SearchKey.DATETIME])
  has_date = maybe_date_column is not None and maybe_date_column in validated_X.columns

@@ -3234,6 +3239,7 @@
  )
  self.fit_columns_renaming = normalizer.columns_renaming
  if normalizer.removed_datetime_features:
+ self.fit_dropped_features.update(normalizer.removed_datetime_features)
  original_removed_datetime_features = [
  self.fit_columns_renaming.get(f, f) for f in normalizer.removed_datetime_features
  ]
@@ -3400,6 +3406,7 @@
  id_columns=self.__get_renamed_id_columns(),
  is_imbalanced=self.imbalanced,
  dropped_columns=[self.fit_columns_renaming.get(f, f) for f in self.fit_dropped_features],
+ autodetected_search_keys=self.autodetected_search_keys,
  date_column=self._get_date_column(self.fit_search_keys),
  date_format=self.date_format,
  random_state=self.random_state,
@@ -3423,7 +3430,7 @@
  ]

  self._search_task = dataset.search(
- trace_id=trace_id,
+ trace_id=self._get_trace_id(),
  progress_bar=progress_bar,
  start_time=start_time,
  progress_callback=progress_callback,
@@ -3443,7 +3450,7 @@
  if not self.__is_registered:
  print(self.bundle.get("polling_unregister_information"))

- progress = self.get_progress(trace_id)
+ progress = self.get_progress()
  prev_progress = None
  progress.recalculate_eta(time.time() - start_time)
  if progress_bar is not None:
@@ -3469,16 +3476,16 @@
  )
  raise RuntimeError(self.bundle.get("search_task_failed_status"))
  time.sleep(poll_period_seconds)
- progress = self.get_progress(trace_id)
+ progress = self.get_progress()
  except KeyboardInterrupt as e:
  print(self.bundle.get("search_stopping"))
- self.rest_client.stop_search_task_v2(trace_id, self._search_task.search_task_id)
+ self.rest_client.stop_search_task_v2(self._get_trace_id(), self._search_task.search_task_id)
  self.logger.warning(f"Search {self._search_task.search_task_id} stopped by user")
  self._search_task = None
  print(self.bundle.get("search_stopped"))
  raise e

- self._search_task.poll_result(trace_id, quiet=True)
+ self._search_task.poll_result(self._get_trace_id(), quiet=True)

  seconds_left = time.time() - start_time
  progress = SearchProgress(97.0, ProgressStage.GENERATING_REPORT, seconds_left)
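These hunks stop threading a local trace_id through the progress polling, stop and poll_result calls; the enricher now asks itself for the current id via self._get_trace_id(). The body of that helper is not part of this diff; a plausible sketch, assuming it reuses an id already bound to the current run and otherwise creates a fresh one:

```python
import uuid

# Hypothetical sketch only: the diff calls self._get_trace_id() but does not show it.
class TraceIdMixin:
    def __init__(self) -> None:
        self._trace_id: str | None = None

    def _get_trace_id(self) -> str:
        # Reuse the id bound to the current operation if present, otherwise mint a new one.
        if self._trace_id is None:
            self._trace_id = uuid.uuid4().hex
        return self._trace_id
```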
@@ -3507,10 +3514,9 @@
  msg = self.bundle.get("features_not_generated").format(unused_features_for_generation)
  self.__log_warning(msg)

- self.__prepare_feature_importances(trace_id, df)
+ self.__prepare_feature_importances(df)

  self._select_features_by_psi(
- trace_id=trace_id,
  X=X,
  y=y,
  eval_set=eval_set,
@@ -3523,7 +3529,7 @@
  progress_callback=progress_callback,
  )

- self.__prepare_feature_importances(trace_id, df)
+ self.__prepare_feature_importances(df)

  self.__show_selected_features()

@@ -3558,7 +3564,6 @@
  scoring,
  estimator,
  remove_outliers_calc_metrics,
- trace_id,
  progress_bar,
  progress_callback,
  )
@@ -3653,11 +3658,10 @@
  y: pd.Series | None = None,
  eval_set: list[tuple[pd.DataFrame, pd.Series]] | None = None,
  is_transform: bool = False,
- silent: bool = False,
  ) -> tuple[pd.DataFrame, pd.Series, list[tuple[pd.DataFrame, pd.Series]]] | None:
  validated_X = self._validate_X(X, is_transform)
  validated_y = self._validate_y(validated_X, y, enforce_y=not is_transform)
- validated_eval_set = self._validate_eval_set(validated_X, eval_set, silent=silent)
+ validated_eval_set = self._validate_eval_set(validated_X, eval_set)
  return validated_X, validated_y, validated_eval_set

  def _encode_id_columns(
@@ -3783,31 +3787,41 @@
  return validated_y

  def _validate_eval_set(
- self, X: pd.DataFrame, eval_set: list[tuple[pd.DataFrame, pd.Series]] | None, silent: bool = False
- ):
+ self,
+ X: pd.DataFrame,
+ eval_set: list[tuple[pd.DataFrame, pd.Series]] | None,
+ ) -> list[tuple[pd.DataFrame, pd.Series]] | None:
  if eval_set is None:
  return None
  validated_eval_set = []
- date_col = self._get_date_column(self.search_keys)
- has_date = date_col is not None and date_col in X.columns
- for idx, eval_pair in enumerate(eval_set):
+ for _, eval_pair in enumerate(eval_set):
  validated_pair = self._validate_eval_set_pair(X, eval_pair)
- if validated_pair[1].isna().all():
- if not has_date:
- msg = self.bundle.get("oot_without_date_not_supported").format(idx + 1)
- elif self.columns_for_online_api:
- msg = self.bundle.get("oot_with_online_sources_not_supported").format(idx + 1)
- else:
- msg = None
- if msg:
- if not silent:
- print(msg)
- self.logger.warning(msg)
- continue
  validated_eval_set.append(validated_pair)

  return validated_eval_set

+ def _validate_OOT(self, df: pd.DataFrame, search_keys: dict[str, SearchKey]) -> pd.DataFrame:
+ if EVAL_SET_INDEX not in df.columns:
+ return df
+
+ for eval_set_index in df[EVAL_SET_INDEX].unique():
+ if eval_set_index == 0:
+ continue
+ eval_df = df[df[EVAL_SET_INDEX] == eval_set_index]
+ date_col = self._get_date_column(search_keys)
+ has_date = date_col is not None and date_col in eval_df.columns
+ if eval_df[TARGET].isna().all():
+ msg = None
+ if not has_date:
+ msg = self.bundle.get("oot_without_date_not_supported").format(eval_set_index)
+ elif self.columns_for_online_api:
+ msg = self.bundle.get("oot_with_online_sources_not_supported").format(eval_set_index)
+ if msg:
+ print(msg)
+ self.logger.warning(msg)
+ df = df[df[EVAL_SET_INDEX] != eval_set_index]
+ return df
+
  def _validate_eval_set_pair(self, X: pd.DataFrame, eval_pair: tuple) -> tuple[pd.DataFrame, pd.Series]:
  if len(eval_pair) != 2:
  raise ValidationError(self.bundle.get("eval_set_invalid_tuple_size").format(len(eval_pair)))
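The out-of-time (OOT) check moves out of _validate_eval_set and into the new _validate_OOT, which works on the already-combined dataframe: an eval segment whose target is entirely NaN is dropped when there is no date/datetime search key, or when online-API columns are configured. A self-contained sketch of the same filtering rule on a toy frame; column names and the boolean inputs are simplified for illustration:

```python
import numpy as np
import pandas as pd

EVAL_SET_INDEX, TARGET = "eval_set_index", "target"

def drop_unsupported_oot(df: pd.DataFrame, has_date: bool, has_online_api_columns: bool) -> pd.DataFrame:
    # Mirrors the rule added in this diff: OOT segments (eval_set_index > 0 with an
    # all-NaN target) are removed when they cannot be used for stability checks.
    for idx in df[EVAL_SET_INDEX].unique():
        if idx == 0:
            continue
        segment = df[df[EVAL_SET_INDEX] == idx]
        if segment[TARGET].isna().all() and (not has_date or has_online_api_columns):
            df = df[df[EVAL_SET_INDEX] != idx]
    return df

df = pd.DataFrame({
    EVAL_SET_INDEX: [0, 0, 1, 1],
    TARGET: [1, 0, np.nan, np.nan],  # eval set 1 is out-of-time (no labels)
})
print(drop_unsupported_oot(df, has_date=False, has_online_api_columns=False)[EVAL_SET_INDEX].unique())  # [0]
```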
@@ -4423,47 +4437,6 @@

  return result_features

- def __get_features_importance_from_server(self, trace_id: str, df: pd.DataFrame):
- if self._search_task is None:
- raise NotFittedError(self.bundle.get("transform_unfitted_enricher"))
- features_meta = self._search_task.get_all_features_metadata_v2()
- if features_meta is None:
- raise Exception(self.bundle.get("missing_features_meta"))
- features_meta = deepcopy(features_meta)
-
- original_names_dict = {c.name: c.originalName for c in self._search_task.get_file_metadata(trace_id).columns}
- df = df.rename(columns=original_names_dict)
-
- features_meta.sort(key=lambda m: (-m.shap_value, m.name))
-
- importances = {}
-
- for feature_meta in features_meta:
- if feature_meta.name in original_names_dict.keys():
- feature_meta.name = original_names_dict[feature_meta.name]
-
- is_client_feature = feature_meta.name in df.columns
-
- if feature_meta.shap_value == 0.0:
- continue
-
- # Use only important features
- if (
- feature_meta.name == COUNTRY
- # In select_features mode we select also from etalon features and need to show them
- or (not self.fit_select_features and is_client_feature)
- ):
- continue
-
- # Temporary workaround for duplicate features metadata
- if feature_meta.name in importances:
- self.logger.warning(f"WARNING: Duplicate feature metadata: {feature_meta}")
- continue
-
- importances[feature_meta.name] = feature_meta.shap_value
-
- return importances
-
  def __get_categorical_features(self) -> list[str]:
  features_meta = self._search_task.get_all_features_metadata_v2()
  if features_meta is None:
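The unused __get_features_importance_from_server helper is deleted; SHAP importances are produced in __prepare_feature_importances (updated below). For orientation, the core of what the removed helper computed, sorting metadata by SHAP value and keeping non-zero, non-duplicate entries, can be written as a small standalone function; this is an illustrative sketch, not library code:

```python
def importances_from_metadata(features_meta: list[tuple[str, float]]) -> dict[str, float]:
    # features_meta: (feature_name, shap_value) pairs, e.g. taken from server metadata.
    ordered = sorted(features_meta, key=lambda m: (-m[1], m[0]))
    importances: dict[str, float] = {}
    for name, shap_value in ordered:
        if shap_value == 0.0 or name in importances:  # skip unimportant and duplicated entries
            continue
        importances[name] = shap_value
    return importances

print(importances_from_metadata([("f1", 0.2), ("f2", 0.0), ("f3", 0.5), ("f3", 0.5)]))
# {'f3': 0.5, 'f1': 0.2}
```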
@@ -4473,7 +4446,6 @@

  def __prepare_feature_importances(
  self,
- trace_id: str,
  clients_features_df: pd.DataFrame,
  updated_shaps: dict[str, float] | None = None,
  update_selected_features: bool = True,
@@ -4481,16 +4453,16 @@
  ):
  if self._search_task is None:
  raise NotFittedError(self.bundle.get("transform_unfitted_enricher"))
- selected_features = self._search_task.get_selected_features(trace_id)
+ selected_features = self._search_task.get_selected_features(self._get_trace_id())
  features_meta = self._search_task.get_all_features_metadata_v2()
  if features_meta is None:
  raise Exception(self.bundle.get("missing_features_meta"))
  features_meta = deepcopy(features_meta)

- file_metadata_columns = self._search_task.get_file_metadata(trace_id).columns
+ file_metadata_columns = self._search_task.get_file_metadata(self._get_trace_id()).columns
  file_meta_by_orig_name = {c.originalName: c for c in file_metadata_columns}
  original_names_dict = {c.name: c.originalName for c in file_metadata_columns}
- features_df = self._search_task.get_all_initial_raw_features(trace_id, metrics_calculation=True)
+ features_df = self._search_task.get_all_initial_raw_features(self._get_trace_id(), metrics_calculation=True)

  # To be sure that names with hash suffixes
  clients_features_df = clients_features_df.rename(columns=original_names_dict)
@@ -4581,7 +4553,7 @@
  internal_features_info.append(feature_info.to_internal_row(self.bundle))

  if update_selected_features:
- self._search_task.update_selected_features(trace_id, self.feature_names_)
+ self._search_task.update_selected_features(self._get_trace_id(), self.feature_names_)

  if len(features_info) > 0:
  self.features_info = pd.DataFrame(features_info)
@@ -4779,12 +4751,17 @@
  ):
  raise ValidationError(self.bundle.get("empty_search_key").format(column_name))

- if self.autodetect_search_keys and (
- not is_transform or set(valid_search_keys.values()) != set(self.fit_search_keys.values())
- ):
- valid_search_keys = self.__detect_missing_search_keys(
- x, valid_search_keys, is_demo_dataset, silent_mode, is_transform
- )
+ if is_transform:
+ fit_autodetected_search_keys = self._get_autodetected_search_keys()
+ if fit_autodetected_search_keys is not None:
+ for key in fit_autodetected_search_keys.keys():
+ if key not in x.columns:
+ raise ValidationError(
+ self.bundle.get("autodetected_search_key_not_found").format(key, x.columns)
+ )
+ valid_search_keys.update(fit_autodetected_search_keys)
+ elif self.autodetect_search_keys:
+ valid_search_keys = self.__detect_missing_search_keys(x, valid_search_keys, is_demo_dataset)

  if all(k == SearchKey.CUSTOM_KEY for k in valid_search_keys.values()):
  if self.__is_registered:
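With this change, transform no longer re-runs search-key autodetection: it reuses the keys autodetected during fit and fails fast if any of those columns is missing from the transform input, while fit keeps calling __detect_missing_search_keys. A simplified, self-contained version of that decision; the bundle messages and _get_autodetected_search_keys are replaced with plain Python for illustration:

```python
def resolve_search_keys_for_transform(
    explicit_keys: dict[str, str],
    fit_autodetected_keys: dict[str, str] | None,
    transform_columns: list[str],
) -> dict[str, str]:
    # Reuse the fit-time autodetected keys instead of detecting again on transform.
    keys = dict(explicit_keys)
    if fit_autodetected_keys:
        missing = [c for c in fit_autodetected_keys if c not in transform_columns]
        if missing:
            raise ValueError(f"Autodetected search key column(s) {missing} not found in transform input")
        keys.update(fit_autodetected_keys)
    return keys

# Example: a POSTAL_CODE column detected at fit time must also be present at transform time.
print(resolve_search_keys_for_transform({"date": "DATE"}, {"zip": "POSTAL_CODE"}, ["date", "zip", "feature_1"]))
```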
@@ -4829,7 +4806,6 @@
  scoring: Callable | str | None,
  estimator: Any | None,
  remove_outliers_calc_metrics: bool | None,
- trace_id: str,
  progress_bar: ProgressBar | None = None,
  progress_callback: Callable[[SearchProgress], Any] | None = None,
  ):
@@ -4837,7 +4813,6 @@
  scoring=scoring,
  estimator=estimator,
  remove_outliers_calc_metrics=remove_outliers_calc_metrics,
- trace_id=trace_id,
  internal_call=True,
  progress_bar=progress_bar,
  progress_callback=progress_callback,
@@ -4902,60 +4877,36 @@
  df: pd.DataFrame,
  search_keys: dict[str, SearchKey],
  is_demo_dataset: bool,
- silent_mode=False,
- is_transform=False,
  ) -> dict[str, SearchKey]:
  sample = df.head(100)

- def check_need_detect(search_key: SearchKey):
- return not is_transform or (
- search_key in self.fit_search_keys.values() and search_key not in search_keys.values()
- )
-
- if (
- SearchKey.DATE not in search_keys.values()
- and SearchKey.DATETIME not in search_keys.values()
- and check_need_detect(SearchKey.DATE)
- and check_need_detect(SearchKey.DATETIME)
- ):
+ if SearchKey.DATE not in search_keys.values() and SearchKey.DATETIME not in search_keys.values():
  maybe_keys = DateSearchKeyDetector().get_search_key_columns(sample, search_keys)
  if len(maybe_keys) > 0:
  datetime_key = maybe_keys[0]
  search_keys[datetime_key] = SearchKey.DATETIME
  self.autodetected_search_keys[datetime_key] = SearchKey.DATETIME
  self.logger.info(f"Autodetected search key DATETIME in column {datetime_key}")
- if not silent_mode:
- print(self.bundle.get("datetime_detected").format(datetime_key))
+ print(self.bundle.get("datetime_detected").format(datetime_key))

  # if SearchKey.POSTAL_CODE not in search_keys.values() and check_need_detect(SearchKey.POSTAL_CODE):
- if check_need_detect(SearchKey.POSTAL_CODE):
- maybe_keys = PostalCodeSearchKeyDetector().get_search_key_columns(sample, search_keys)
- if maybe_keys:
- new_keys = {key: SearchKey.POSTAL_CODE for key in maybe_keys}
- search_keys.update(new_keys)
- self.autodetected_search_keys.update(new_keys)
- self.logger.info(f"Autodetected search key POSTAL_CODE in column {maybe_keys}")
- if not silent_mode:
- print(self.bundle.get("postal_code_detected").format(maybe_keys))
-
- if (
- SearchKey.COUNTRY not in search_keys.values()
- and self.country_code is None
- and check_need_detect(SearchKey.COUNTRY)
- ):
+ maybe_keys = PostalCodeSearchKeyDetector().get_search_key_columns(sample, search_keys)
+ if maybe_keys:
+ new_keys = {key: SearchKey.POSTAL_CODE for key in maybe_keys}
+ search_keys.update(new_keys)
+ self.autodetected_search_keys.update(new_keys)
+ self.logger.info(f"Autodetected search key POSTAL_CODE in column {maybe_keys}")
+ print(self.bundle.get("postal_code_detected").format(maybe_keys))
+
+ if SearchKey.COUNTRY not in search_keys.values() and self.country_code is None:
  maybe_key = CountrySearchKeyDetector().get_search_key_columns(sample, search_keys)
  if maybe_key:
  search_keys[maybe_key[0]] = SearchKey.COUNTRY
  self.autodetected_search_keys[maybe_key[0]] = SearchKey.COUNTRY
  self.logger.info(f"Autodetected search key COUNTRY in column {maybe_key}")
- if not silent_mode:
- print(self.bundle.get("country_detected").format(maybe_key))
+ print(self.bundle.get("country_detected").format(maybe_key))

- if (
- # SearchKey.EMAIL not in search_keys.values()
- SearchKey.HEM not in search_keys.values()
- and check_need_detect(SearchKey.HEM)
- ):
+ if SearchKey.EMAIL not in search_keys.values() and SearchKey.HEM not in search_keys.values():
  maybe_keys = EmailSearchKeyDetector().get_search_key_columns(sample, search_keys)
  if maybe_keys:
  if self.__is_registered or is_demo_dataset:
@@ -4963,34 +4914,28 @@
  search_keys.update(new_keys)
  self.autodetected_search_keys.update(new_keys)
  self.logger.info(f"Autodetected search key EMAIL in column {maybe_keys}")
- if not silent_mode:
- print(self.bundle.get("email_detected").format(maybe_keys))
+ print(self.bundle.get("email_detected").format(maybe_keys))
  else:
  self.logger.warning(
  f"Autodetected search key EMAIL in column {maybe_keys}."
  " But not used because not registered user"
  )
- if not silent_mode:
- self.__log_warning(self.bundle.get("email_detected_not_registered").format(maybe_keys))
+ self.__log_warning(self.bundle.get("email_detected_not_registered").format(maybe_keys))

  # if SearchKey.PHONE not in search_keys.values() and check_need_detect(SearchKey.PHONE):
- if check_need_detect(SearchKey.PHONE):
- maybe_keys = PhoneSearchKeyDetector().get_search_key_columns(sample, search_keys)
- if maybe_keys:
- if self.__is_registered or is_demo_dataset:
- new_keys = {key: SearchKey.PHONE for key in maybe_keys}
- search_keys.update(new_keys)
- self.autodetected_search_keys.update(new_keys)
- self.logger.info(f"Autodetected search key PHONE in column {maybe_keys}")
- if not silent_mode:
- print(self.bundle.get("phone_detected").format(maybe_keys))
- else:
- self.logger.warning(
- f"Autodetected search key PHONE in column {maybe_keys}. "
- "But not used because not registered user"
- )
- if not silent_mode:
- self.__log_warning(self.bundle.get("phone_detected_not_registered"))
+ maybe_keys = PhoneSearchKeyDetector().get_search_key_columns(sample, search_keys)
+ if maybe_keys:
+ if self.__is_registered or is_demo_dataset:
+ new_keys = {key: SearchKey.PHONE for key in maybe_keys}
+ search_keys.update(new_keys)
+ self.autodetected_search_keys.update(new_keys)
+ self.logger.info(f"Autodetected search key PHONE in column {maybe_keys}")
+ print(self.bundle.get("phone_detected").format(maybe_keys))
+ else:
+ self.logger.warning(
+ f"Autodetected search key PHONE in column {maybe_keys}. " "But not used because not registered user"
+ )
+ self.__log_warning(self.bundle.get("phone_detected_not_registered"))

  return search_keys

@@ -5062,13 +5007,12 @@

  def dump_input(
  self,
- trace_id: str,
  X: pd.DataFrame | pd.Series,
  y: pd.DataFrame | pd.Series | None = None,
  eval_set: tuple | None = None,
  ):
- def dump_task(X_, y_, eval_set_):
- with MDC(correlation_id=trace_id):
+ def dump_task(X_, y_, eval_set_, trace_id_):
+ with MDC(correlation_id=trace_id_):
  try:
  if isinstance(X_, pd.Series):
  X_ = X_.to_frame()
@@ -5076,27 +5020,25 @@
  with tempfile.TemporaryDirectory() as tmp_dir:
  X_.to_parquet(f"{tmp_dir}/x.parquet", compression="zstd")
  x_digest_sha256 = file_hash(f"{tmp_dir}/x.parquet")
- if self.rest_client.is_file_uploaded(trace_id, x_digest_sha256):
+ if self.rest_client.is_file_uploaded(trace_id_, x_digest_sha256):
  self.logger.info(
  f"File x.parquet was already uploaded with digest {x_digest_sha256}, skipping"
  )
  else:
- self.rest_client.dump_input_file(
- trace_id, f"{tmp_dir}/x.parquet", "x.parquet", x_digest_sha256
- )
+ self.rest_client.dump_input_file(f"{tmp_dir}/x.parquet", "x.parquet", x_digest_sha256)

  if y_ is not None:
  if isinstance(y_, pd.Series):
  y_ = y_.to_frame()
  y_.to_parquet(f"{tmp_dir}/y.parquet", compression="zstd")
  y_digest_sha256 = file_hash(f"{tmp_dir}/y.parquet")
- if self.rest_client.is_file_uploaded(trace_id, y_digest_sha256):
+ if self.rest_client.is_file_uploaded(trace_id_, y_digest_sha256):
  self.logger.info(
  f"File y.parquet was already uploaded with digest {y_digest_sha256}, skipping"
  )
  else:
  self.rest_client.dump_input_file(
- trace_id, f"{tmp_dir}/y.parquet", "y.parquet", y_digest_sha256
+ trace_id_, f"{tmp_dir}/y.parquet", "y.parquet", y_digest_sha256
  )

  if eval_set_ is not None and len(eval_set_) > 0:
@@ -5105,14 +5047,14 @@
  eval_x_ = eval_x_.to_frame()
  eval_x_.to_parquet(f"{tmp_dir}/eval_x_{idx}.parquet", compression="zstd")
  eval_x_digest_sha256 = file_hash(f"{tmp_dir}/eval_x_{idx}.parquet")
- if self.rest_client.is_file_uploaded(trace_id, eval_x_digest_sha256):
+ if self.rest_client.is_file_uploaded(trace_id_, eval_x_digest_sha256):
  self.logger.info(
  f"File eval_x_{idx}.parquet was already uploaded with"
  f" digest {eval_x_digest_sha256}, skipping"
  )
  else:
  self.rest_client.dump_input_file(
- trace_id,
+ trace_id_,
  f"{tmp_dir}/eval_x_{idx}.parquet",
  f"eval_x_{idx}.parquet",
  eval_x_digest_sha256,
@@ -5122,14 +5064,14 @@
  eval_y_ = eval_y_.to_frame()
  eval_y_.to_parquet(f"{tmp_dir}/eval_y_{idx}.parquet", compression="zstd")
  eval_y_digest_sha256 = file_hash(f"{tmp_dir}/eval_y_{idx}.parquet")
- if self.rest_client.is_file_uploaded(trace_id, eval_y_digest_sha256):
+ if self.rest_client.is_file_uploaded(trace_id_, eval_y_digest_sha256):
  self.logger.info(
  f"File eval_y_{idx}.parquet was already uploaded"
  f" with digest {eval_y_digest_sha256}, skipping"
  )
  else:
  self.rest_client.dump_input_file(
- trace_id,
+ trace_id_,
  f"{tmp_dir}/eval_y_{idx}.parquet",
  f"eval_y_{idx}.parquet",
  eval_y_digest_sha256,
@@ -5138,7 +5080,8 @@
  self.logger.warning("Failed to dump input files", exc_info=True)

  try:
- Thread(target=dump_task, args=(X, y, eval_set), daemon=True).start()
+ trace_id = self._get_trace_id()
+ Thread(target=dump_task, args=(X, y, eval_set, trace_id), daemon=True).start()
  except Exception:
  self.logger.warning("Failed to dump input files", exc_info=True)
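The background dump now receives the trace id as an explicit argument: dump_task runs on a daemon thread, and the MDC correlation context is re-established inside that thread instead of being captured from the enclosing scope. A minimal, self-contained sketch of the pattern; the MDC context manager below is a stand-in, and everything except the argument-passing shape is assumed for the example:

```python
import contextlib
import threading

# Stand-in for a thread-local logging context manager such as upgini.mdc.MDC.
@contextlib.contextmanager
def MDC(**fields):
    # In a real implementation this would bind the fields to the current thread's log context.
    yield

def dump_task(payload: dict, trace_id_: str) -> None:
    with MDC(correlation_id=trace_id_):
        print(f"uploading {sorted(payload)} with correlation_id={trace_id_}")

trace_id = "example-trace-id"  # in the diff this value comes from self._get_trace_id()
worker = threading.Thread(target=dump_task, args=({"x": "x.parquet"}, trace_id), daemon=True)
worker.start()
worker.join()
```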