upgini 1.2.141__py3-none-any.whl → 1.2.142a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of upgini might be problematic.

@@ -42,6 +42,7 @@ from upgini.http import (
  get_rest_client,
  )
  from upgini.mdc import MDC
+ from upgini.mdc.context import get_mdc_fields
  from upgini.metadata import (
  COUNTRY,
  CURRENT_DATE_COL,
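Note: the new get_mdc_fields import lets the enricher reuse a correlation id that is already active in the MDC (mapped diagnostic context) instead of minting a fresh one for every call. A minimal sketch of the pattern, assembled from constructs visible in this diff (not the package's exact code):

    import time

    from upgini.mdc import MDC
    from upgini.mdc.context import get_mdc_fields

    def get_trace_id():
        # Reuse the correlation id of an enclosing MDC block if one is active,
        # otherwise fall back to a millisecond timestamp (mirrors _get_trace_id below).
        active = get_mdc_fields().get("correlation_id")
        return active if active is not None else int(time.time() * 1000)

    with MDC(correlation_id=get_trace_id()):
        # Nested helpers now read the same correlation_id via get_mdc_fields()
        # instead of receiving an explicit trace_id argument.
        ...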
@@ -273,7 +274,7 @@ class FeaturesEnricher(TransformerMixin):
  self.X: pd.DataFrame | None = None
  self.y: pd.Series | None = None
  self.eval_set: list[tuple] | None = None
- self.autodetected_search_keys: dict[str, SearchKey] = {}
+ self.autodetected_search_keys: dict[str, SearchKey] = dict()
  self.imbalanced = False
  self.fit_select_features = True
  self.__cached_sampled_datasets: dict[str, tuple[pd.DataFrame, pd.DataFrame, pd.Series, dict, dict, dict]] = (
@@ -309,8 +310,8 @@ class FeaturesEnricher(TransformerMixin):
  search_task = SearchTask(search_id, rest_client=self.rest_client, logger=self.logger)

  print(self.bundle.get("search_by_task_id_start"))
- trace_id = time.time_ns()
- with MDC(correlation_id=trace_id):
+ trace_id = self._get_trace_id()
+ with MDC(correlation_id=trace_id, search_task_id=search_id):
  try:
  self.logger.debug(f"FeaturesEnricher created from existing search: {search_id}")
  self._search_task = search_task.poll_result(trace_id, quiet=True, check_fit=True)
@@ -318,7 +319,7 @@ class FeaturesEnricher(TransformerMixin):
  x_columns = [c.name for c in file_metadata.columns]
  self.fit_columns_renaming = {c.name: c.originalName for c in file_metadata.columns}
  df = pd.DataFrame(columns=x_columns)
- self.__prepare_feature_importances(trace_id, df, silent=True, update_selected_features=False)
+ self.__prepare_feature_importances(df, silent=True, update_selected_features=False)
  if print_loaded_report:
  self.__show_selected_features()
  # TODO validate search_keys with search_keys from file_metadata
@@ -487,7 +488,7 @@ class FeaturesEnricher(TransformerMixin):
  stability_agg_func: str, optional (default="max")
  Function to aggregate stability values. Can be "max", "min", "mean".
  """
- trace_id = time.time_ns()
+ trace_id = self._get_trace_id()
  if self.print_trace_id:
  print(f"https://app.datadoghq.eu/logs?query=%40correlation_id%3A{trace_id}")
  start_time = time.time()
@@ -522,10 +523,9 @@ class FeaturesEnricher(TransformerMixin):
  self.X = X
  self.y = y
  self.eval_set = self._check_eval_set(eval_set, X)
- self.dump_input(trace_id, X, y, self.eval_set)
+ self.dump_input(X, y, self.eval_set)
  self.__set_select_features(select_features)
  self.__inner_fit(
- trace_id,
  X,
  y,
  self.eval_set,
@@ -646,7 +646,7 @@ class FeaturesEnricher(TransformerMixin):

  self.warning_counter.reset()
  auto_fe_parameters = AutoFEParameters() if auto_fe_parameters is None else auto_fe_parameters
- trace_id = time.time_ns()
+ trace_id = self._get_trace_id()
  if self.print_trace_id:
  print(f"https://app.datadoghq.eu/logs?query=%40correlation_id%3A{trace_id}")
  start_time = time.time()
@@ -677,13 +677,12 @@ class FeaturesEnricher(TransformerMixin):
  self.y = y
  self.eval_set = self._check_eval_set(eval_set, X)
  self.__set_select_features(select_features)
- self.dump_input(trace_id, X, y, self.eval_set)
+ self.dump_input(X, y, self.eval_set)

  if _num_samples(drop_duplicates(X)) > Dataset.MAX_ROWS:
  raise ValidationError(self.bundle.get("dataset_too_many_rows_registered").format(Dataset.MAX_ROWS))

  self.__inner_fit(
- trace_id,
  X,
  y,
  self.eval_set,
@@ -738,7 +737,6 @@ class FeaturesEnricher(TransformerMixin):
  y,
  exclude_features_sources=exclude_features_sources,
  keep_input=keep_input,
- trace_id=trace_id,
  silent_mode=True,
  progress_bar=progress_bar,
  progress_callback=progress_callback,
@@ -753,7 +751,6 @@ class FeaturesEnricher(TransformerMixin):
  *args,
  exclude_features_sources: list[str] | None = None,
  keep_input: bool = True,
- trace_id: str | None = None,
  silent_mode=False,
  progress_bar: ProgressBar | None = None,
  progress_callback: Callable[[SearchProgress], Any] | None = None,
@@ -790,12 +787,12 @@ class FeaturesEnricher(TransformerMixin):
  progress_bar.progress = search_progress.to_progress_bar()
  if new_progress:
  progress_bar.display()
- trace_id = trace_id or time.time_ns()
+ trace_id = self._get_trace_id()
  if self.print_trace_id:
  print(f"https://app.datadoghq.eu/logs?query=%40correlation_id%3A{trace_id}")
  search_id = self.search_id or (self._search_task.search_task_id if self._search_task is not None else None)
  with MDC(correlation_id=trace_id, search_id=search_id):
- self.dump_input(trace_id, X)
+ self.dump_input(X)
  if len(args) > 0:
  msg = f"WARNING: Unsupported positional arguments for transform: {args}"
  self.logger.warning(msg)
@@ -808,7 +805,6 @@ class FeaturesEnricher(TransformerMixin):
  start_time = time.time()
  try:
  result, _, _, _ = self.__inner_transform(
- trace_id,
  X,
  y=y,
  exclude_features_sources=exclude_features_sources,
@@ -872,7 +868,6 @@ class FeaturesEnricher(TransformerMixin):
  estimator=None,
  exclude_features_sources: list[str] | None = None,
  remove_outliers_calc_metrics: bool | None = None,
- trace_id: str | None = None,
  internal_call: bool = False,
  progress_bar: ProgressBar | None = None,
  progress_callback: Callable[[SearchProgress], Any] | None = None,
@@ -910,7 +905,7 @@ class FeaturesEnricher(TransformerMixin):
  Dataframe with metrics calculated on train and validation datasets.
  """

- trace_id = trace_id or time.time_ns()
+ trace_id = self._get_trace_id()
  start_time = time.time()
  search_id = self.search_id or (self._search_task.search_task_id if self._search_task is not None else None)
  with MDC(correlation_id=trace_id, search_id=search_id):
@@ -943,7 +938,7 @@ class FeaturesEnricher(TransformerMixin):
  raise ValidationError(self.bundle.get("metrics_unfitted_enricher"))

  validated_X, validated_y, validated_eval_set = self._validate_train_eval(
- effective_X, effective_y, effective_eval_set, silent=internal_call
+ effective_X, effective_y, effective_eval_set
  )

  if self.X is None:
@@ -978,11 +973,13 @@ class FeaturesEnricher(TransformerMixin):
  self.__display_support_link(msg)
  return None

+ search_keys = self._get_fit_search_keys_with_original_names()
+
  cat_features_from_backend = self.__get_categorical_features()
  # Convert to original names
  cat_features_from_backend = [self.fit_columns_renaming.get(c, c) for c in cat_features_from_backend]
  client_cat_features, search_keys_for_metrics = self._get_and_validate_client_cat_features(
- estimator, validated_X, self.search_keys
+ estimator, validated_X, search_keys
  )
  # Exclude id columns from cat_features
  if self.id_columns and self.id_columns_encoder is not None:
@@ -1004,7 +1001,6 @@ class FeaturesEnricher(TransformerMixin):
  self.logger.info(f"Search keys for metrics: {search_keys_for_metrics}")

  prepared_data = self._get_cached_enriched_data(
- trace_id=trace_id,
  X=X,
  y=y,
  eval_set=eval_set,
@@ -1257,7 +1253,7 @@ class FeaturesEnricher(TransformerMixin):

  if updating_shaps is not None:
  decoded_X = self._decode_id_columns(fitting_X)
- self._update_shap_values(trace_id, decoded_X, updating_shaps, silent=not internal_call)
+ self._update_shap_values(decoded_X, updating_shaps, silent=not internal_call)

  metrics_df = pd.DataFrame(metrics)
  mean_target_hdr = self.bundle.get("quality_metrics_mean_target_header")
@@ -1307,9 +1303,33 @@ class FeaturesEnricher(TransformerMixin):
  finally:
  self.logger.info(f"Calculating metrics elapsed time: {time.time() - start_time}")

+ def _get_trace_id(self):
+ if get_mdc_fields().get("correlation_id") is not None:
+ return get_mdc_fields().get("correlation_id")
+ return int(time.time() * 1000)
+
+ def _get_autodetected_search_keys(self):
+ if self.autodetected_search_keys is None and self._search_task is not None:
+ meta = self._search_task.get_file_metadata(self._get_trace_id())
+ self.autodetected_search_keys = {k: SearchKey[v] for k, v in meta.autodetectedSearchKeys.items()}
+
+ return self.autodetected_search_keys
+
+ def _get_fit_search_keys_with_original_names(self):
+ if self.fit_search_keys is None and self._search_task is not None:
+ fit_search_keys = dict()
+ meta = self._search_task.get_file_metadata(self._get_trace_id())
+ for column in meta.columns:
+ # TODO check for EMAIL->HEM and multikeys
+ search_key_type = SearchKey.from_meaning_type(column.meaningType)
+ if search_key_type is not None:
+ fit_search_keys[column.originalName] = search_key_type
+ else:
+ fit_search_keys = {self.fit_columns_renaming.get(k, k): v for k, v in self.fit_search_keys.items()}
+ return fit_search_keys
+
  def _select_features_by_psi(
  self,
- trace_id: str,
  X: pd.DataFrame | pd.Series | np.ndarray,
  y: pd.DataFrame | pd.Series | np.ndarray | list,
  eval_set: list[tuple] | tuple | None,
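Note: the helpers added in the hunk above let an enricher restored from a saved search_id rebuild its search keys from the search task's file metadata instead of relying on in-memory state. A simplified illustration of the mapping they perform; the column objects here are stand-ins exposing only the originalName and meaningType attributes used in this diff, not the real upgini metadata classes:

    from upgini import SearchKey  # public import used in upgini's examples

    def search_keys_from_metadata(columns) -> dict:
        # columns: metadata entries with .originalName and .meaningType,
        # as iterated by _get_fit_search_keys_with_original_names above.
        keys = {}
        for column in columns:
            key_type = SearchKey.from_meaning_type(column.meaningType)
            if key_type is not None:
                keys[column.originalName] = key_type
        return keys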
@@ -1322,7 +1342,8 @@ class FeaturesEnricher(TransformerMixin):
  progress_callback: Callable | None = None,
  ):
  search_keys = self.search_keys.copy()
- validated_X, _, validated_eval_set = self._validate_train_eval(X, y, eval_set, silent=True)
+ search_keys.update(self._get_autodetected_search_keys())
+ validated_X, _, validated_eval_set = self._validate_train_eval(X, y, eval_set)
  if isinstance(X, np.ndarray):
  search_keys = {str(k): v for k, v in search_keys.items()}

@@ -1357,7 +1378,6 @@ class FeaturesEnricher(TransformerMixin):
  ]

  prepared_data = self._get_cached_enriched_data(
- trace_id=trace_id,
  X=X,
  y=y,
  eval_set=eval_set,
@@ -1506,7 +1526,7 @@ class FeaturesEnricher(TransformerMixin):

  return total_unstable_features

- def _update_shap_values(self, trace_id: str, df: pd.DataFrame, new_shaps: dict[str, float], silent: bool = False):
+ def _update_shap_values(self, df: pd.DataFrame, new_shaps: dict[str, float], silent: bool = False):
  renaming = self.fit_columns_renaming or {}
  self.logger.info(f"Updating SHAP values: {new_shaps}")
  new_shaps = {
@@ -1514,7 +1534,7 @@ class FeaturesEnricher(TransformerMixin):
  for feature, shap in new_shaps.items()
  if feature in self.feature_names_ or renaming.get(feature, feature) in self.feature_names_
  }
- self.__prepare_feature_importances(trace_id, df, new_shaps)
+ self.__prepare_feature_importances(df, new_shaps)

  if not silent and self.features_info_display_handle is not None:
  try:
@@ -1694,7 +1714,6 @@ class FeaturesEnricher(TransformerMixin):

  def _get_cached_enriched_data(
  self,
- trace_id: str,
  X: pd.DataFrame | pd.Series | np.ndarray | None = None,
  y: pd.DataFrame | pd.Series | np.ndarray | list | None = None,
  eval_set: list[tuple] | tuple | None = None,
@@ -1710,10 +1729,9 @@ class FeaturesEnricher(TransformerMixin):
  is_input_same_as_fit, X, y, eval_set = self._is_input_same_as_fit(X, y, eval_set)
  is_demo_dataset = hash_input(X, y, eval_set) in DEMO_DATASET_HASHES
  checked_eval_set = self._check_eval_set(eval_set, X)
- validated_X, validated_y, validated_eval_set = self._validate_train_eval(X, y, checked_eval_set, silent=True)
+ validated_X, validated_y, validated_eval_set = self._validate_train_eval(X, y, checked_eval_set)

  sampled_data = self._get_enriched_datasets(
- trace_id=trace_id,
  validated_X=validated_X,
  validated_y=validated_y,
  eval_set=validated_eval_set,
@@ -1740,7 +1758,7 @@ class FeaturesEnricher(TransformerMixin):

  self.logger.info(f"Excluding search keys: {excluding_search_keys}")

- file_meta = self._search_task.get_file_metadata(trace_id)
+ file_meta = self._search_task.get_file_metadata(self._get_trace_id())
  fit_dropped_features = self.fit_dropped_features or file_meta.droppedColumns or []
  original_dropped_features = [columns_renaming.get(f, f) for f in fit_dropped_features]

@@ -1917,7 +1935,6 @@ class FeaturesEnricher(TransformerMixin):

  def _get_enriched_datasets(
  self,
- trace_id: str,
  validated_X: pd.DataFrame | pd.Series | np.ndarray | None,
  validated_y: pd.DataFrame | pd.Series | np.ndarray | list | None,
  eval_set: list[tuple] | None,
@@ -1945,9 +1962,7 @@ class FeaturesEnricher(TransformerMixin):
  and self.df_with_original_index is not None
  ):
  self.logger.info("Dataset is not imbalanced, so use enriched_X from fit")
- return self.__get_enriched_from_fit(
- validated_X, validated_y, eval_set, trace_id, remove_outliers_calc_metrics
- )
+ return self.__get_enriched_from_fit(validated_X, validated_y, eval_set, remove_outliers_calc_metrics)
  else:
  self.logger.info(
  "Dataset is imbalanced or exclude_features_sources or X was passed or this is saved search."
@@ -1959,7 +1974,6 @@ class FeaturesEnricher(TransformerMixin):
  validated_y,
  eval_set,
  exclude_features_sources,
- trace_id,
  progress_bar,
  progress_callback,
  is_for_metrics=is_for_metrics,
@@ -2088,7 +2102,6 @@ class FeaturesEnricher(TransformerMixin):
  validated_X: pd.DataFrame,
  validated_y: pd.Series,
  eval_set: list[tuple] | None,
- trace_id: str,
  remove_outliers_calc_metrics: bool | None,
  ) -> _EnrichedDataForMetrics:
  eval_set_sampled_dict = {}
@@ -2103,7 +2116,7 @@ class FeaturesEnricher(TransformerMixin):
  if remove_outliers_calc_metrics is None:
  remove_outliers_calc_metrics = True
  if self.model_task_type == ModelTaskType.REGRESSION and remove_outliers_calc_metrics:
- target_outliers_df = self._search_task.get_target_outliers(trace_id)
+ target_outliers_df = self._search_task.get_target_outliers(self._get_trace_id())
  if target_outliers_df is not None and len(target_outliers_df) > 0:
  outliers = pd.merge(
  self.df_with_original_index,
@@ -2120,7 +2133,7 @@ class FeaturesEnricher(TransformerMixin):

  # index in each dataset (X, eval set) may be reordered and non unique, but index in validated datasets
  # can differs from it
- fit_features = self._search_task.get_all_initial_raw_features(trace_id, metrics_calculation=True)
+ fit_features = self._search_task.get_all_initial_raw_features(self._get_trace_id(), metrics_calculation=True)

  # Pre-process features if we need to drop outliers
  if rows_to_drop is not None:
@@ -2146,7 +2159,7 @@ class FeaturesEnricher(TransformerMixin):
  validated_Xy[TARGET] = validated_y

  selecting_columns = self._selecting_input_and_generated_columns(
- validated_Xy, self.fit_generated_features, keep_input=True, trace_id=trace_id
+ validated_Xy, self.fit_generated_features, keep_input=True
  )
  selecting_columns.extend(
  c
@@ -2211,7 +2224,6 @@ class FeaturesEnricher(TransformerMixin):
  validated_y: pd.Series,
  eval_set: list[tuple] | None,
  exclude_features_sources: list[str] | None,
- trace_id: str,
  progress_bar: ProgressBar | None,
  progress_callback: Callable[[SearchProgress], Any] | None,
  is_for_metrics: bool = False,
@@ -2237,7 +2249,6 @@ class FeaturesEnricher(TransformerMixin):

  # Transform
  enriched_df, columns_renaming, generated_features, search_keys = self.__inner_transform(
- trace_id,
  X=df.drop(columns=[TARGET]),
  y=df[TARGET],
  exclude_features_sources=exclude_features_sources,
@@ -2414,11 +2425,10 @@ class FeaturesEnricher(TransformerMixin):

  return self.features_info

- def get_progress(self, trace_id: str | None = None, search_task: SearchTask | None = None) -> SearchProgress:
+ def get_progress(self, search_task: SearchTask | None = None) -> SearchProgress:
  search_task = search_task or self._search_task
  if search_task is not None:
- trace_id = trace_id or time.time_ns()
- return search_task.get_progress(trace_id)
+ return search_task.get_progress(self._get_trace_id())

  def display_transactional_transform_api(self, only_online_sources=False):
  if self.api_key is None:
@@ -2519,7 +2529,6 @@ if response.status_code == 200:

  def __inner_transform(
  self,
- trace_id: str,
  X: pd.DataFrame,
  *,
  y: pd.Series | None = None,
@@ -2538,182 +2547,135 @@ if response.status_code == 200:
  raise NotFittedError(self.bundle.get("transform_unfitted_enricher"))

  start_time = time.time()
- search_id = self.search_id or (self._search_task.search_task_id if self._search_task is not None else None)
- with MDC(correlation_id=trace_id, search_id=search_id):
- self.logger.info("Start transform")
+ self.logger.info("Start transform")

- validated_X, validated_y, validated_eval_set = self._validate_train_eval(
- X, y, eval_set=None, is_transform=True, silent=True
- )
- df = self.__combine_train_and_eval_sets(validated_X, validated_y, validated_eval_set)
+ search_keys = self.search_keys.copy()

- validated_Xy = df.copy()
+ self.__validate_search_keys(search_keys, self.search_id)

- self.__log_debug_information(validated_X, validated_y, exclude_features_sources=exclude_features_sources)
+ validated_X, validated_y, validated_eval_set = self._validate_train_eval(
+ X, y, eval_set=None, is_transform=True
+ )
+ df = self.__combine_train_and_eval_sets(validated_X, validated_y, validated_eval_set)

- # If there are no important features, return original dataframe
- if len(self.feature_names_) == 0:
- msg = self.bundle.get("no_important_features_for_transform")
- self.__log_warning(msg, show_support_link=True)
- return None, {}, [], self.search_keys
+ validated_Xy = df.copy()

- self.__validate_search_keys(self.search_keys, self.search_id)
+ self.__log_debug_information(validated_X, validated_y, exclude_features_sources=exclude_features_sources)

- if self._has_paid_features(exclude_features_sources):
- msg = self.bundle.get("transform_with_paid_features")
- self.logger.warning(msg)
- self.__display_support_link(msg)
- return None, {}, [], self.search_keys
+ # If there are no important features, return original dataframe
+ if len(self.feature_names_) == 0:
+ msg = self.bundle.get("no_important_features_for_transform")
+ self.__log_warning(msg, show_support_link=True)
+ return None, {}, [], search_keys

- online_api_features = [fm.name for fm in features_meta if fm.from_online_api and fm.shap_value > 0]
- if len(online_api_features) > 0:
- self.logger.warning(
- f"There are important features for transform, that generated by online API: {online_api_features}"
- )
- msg = self.bundle.get("online_api_features_transform").format(online_api_features)
- self.logger.warning(msg)
- print(msg)
- self.display_transactional_transform_api(only_online_sources=True)
-
- if not metrics_calculation:
- transform_usage = self.rest_client.get_current_transform_usage(trace_id)
- self.logger.info(f"Current transform usage: {transform_usage}. Transforming {len(X)} rows")
- if transform_usage.has_limit:
- if len(X) > transform_usage.rest_rows:
- rest_rows = max(transform_usage.rest_rows, 0)
- bundle_msg = (
- "transform_usage_warning_registered"
- if self.__is_registered
- else "transform_usage_warning_demo"
- )
- msg = self.bundle.get(bundle_msg).format(rest_rows, len(X))
- self.logger.warning(msg)
- print(msg)
- show_request_quote_button(is_registered=self.__is_registered)
- return None, {}, [], {}
- else:
- msg = self.bundle.get("transform_usage_info").format(
- transform_usage.limit, transform_usage.transformed_rows
- )
- self.logger.info(msg)
- print(msg)
+ if self._has_paid_features(exclude_features_sources):
+ msg = self.bundle.get("transform_with_paid_features")
+ self.logger.warning(msg)
+ self.__display_support_link(msg)
+ return None, {}, [], search_keys

- is_demo_dataset = hash_input(df) in DEMO_DATASET_HASHES
+ online_api_features = [fm.name for fm in features_meta if fm.from_online_api and fm.shap_value > 0]
+ if len(online_api_features) > 0:
+ self.logger.warning(
+ f"There are important features for transform, that generated by online API: {online_api_features}"
+ )
+ msg = self.bundle.get("online_api_features_transform").format(online_api_features)
+ self.logger.warning(msg)
+ print(msg)
+ self.display_transactional_transform_api(only_online_sources=True)
+
+ if not metrics_calculation:
+ transform_usage = self.rest_client.get_current_transform_usage(self._get_trace_id())
+ self.logger.info(f"Current transform usage: {transform_usage}. Transforming {len(X)} rows")
+ if transform_usage.has_limit:
+ if len(X) > transform_usage.rest_rows:
+ rest_rows = max(transform_usage.rest_rows, 0)
+ bundle_msg = (
+ "transform_usage_warning_registered" if self.__is_registered else "transform_usage_warning_demo"
+ )
+ msg = self.bundle.get(bundle_msg).format(rest_rows, len(X))
+ self.logger.warning(msg)
+ print(msg)
+ show_request_quote_button(is_registered=self.__is_registered)
+ return None, {}, [], {}
+ else:
+ msg = self.bundle.get("transform_usage_info").format(
+ transform_usage.limit, transform_usage.transformed_rows
+ )
+ self.logger.info(msg)
+ print(msg)

- columns_to_drop = [
- c for c in df.columns if c in self.feature_names_ and c in self.external_source_feature_names
- ]
- if len(columns_to_drop) > 0:
- msg = self.bundle.get("x_contains_enriching_columns").format(columns_to_drop)
- self.logger.warning(msg)
- print(msg)
- df = df.drop(columns=columns_to_drop)
+ is_demo_dataset = hash_input(df) in DEMO_DATASET_HASHES

- search_keys = self.search_keys.copy()
- if self.id_columns is not None and self.cv is not None and self.cv.is_time_series():
- search_keys.update(
- {col: SearchKey.CUSTOM_KEY for col in self.id_columns if col not in self.search_keys}
- )
+ columns_to_drop = [
+ c for c in df.columns if c in self.feature_names_ and c in self.external_source_feature_names
+ ]
+ if len(columns_to_drop) > 0:
+ msg = self.bundle.get("x_contains_enriching_columns").format(columns_to_drop)
+ self.logger.warning(msg)
+ print(msg)
+ df = df.drop(columns=columns_to_drop)

- search_keys = self.__prepare_search_keys(
- df, search_keys, is_demo_dataset, is_transform=True, silent_mode=silent_mode
- )
+ if self.id_columns is not None and self.cv is not None and self.cv.is_time_series():
+ search_keys.update({col: SearchKey.CUSTOM_KEY for col in self.id_columns if col not in search_keys})

- df = self.__handle_index_search_keys(df, search_keys)
+ search_keys = self.__prepare_search_keys(
+ df, search_keys, is_demo_dataset, is_transform=True, silent_mode=silent_mode
+ )

- if DEFAULT_INDEX in df.columns:
- msg = self.bundle.get("unsupported_index_column")
- self.logger.info(msg)
- print(msg)
- df.drop(columns=DEFAULT_INDEX, inplace=True)
- validated_Xy.drop(columns=DEFAULT_INDEX, inplace=True)
+ df = self.__handle_index_search_keys(df, search_keys)

- df = self.__add_country_code(df, search_keys)
+ if DEFAULT_INDEX in df.columns:
+ msg = self.bundle.get("unsupported_index_column")
+ self.logger.info(msg)
+ print(msg)
+ df.drop(columns=DEFAULT_INDEX, inplace=True)
+ validated_Xy.drop(columns=DEFAULT_INDEX, inplace=True)

- generated_features = []
- date_column = self._get_date_column(search_keys)
- if date_column is not None:
- converter = DateTimeConverter(
- date_column,
- self.date_format,
- self.logger,
- bundle=self.bundle,
- generate_cyclical_features=self.generate_search_key_features,
- )
- df = converter.convert(df, keep_time=True)
- self.logger.info(f"Date column after convertion: {df[date_column]}")
- generated_features.extend(converter.generated_features)
- else:
- self.logger.info("Input dataset hasn't date column")
- if self.__should_add_date_column():
- df = self._add_current_date_as_key(df, search_keys, self.bundle, silent=True)
-
- email_columns = SearchKey.find_all_keys(search_keys, SearchKey.EMAIL)
- if email_columns and self.generate_search_key_features:
- generator = EmailDomainGenerator(email_columns)
- df = generator.generate(df)
- generated_features.extend(generator.generated_features)
-
- normalizer = Normalizer(self.bundle, self.logger)
- df, search_keys, generated_features = normalizer.normalize(df, search_keys, generated_features)
- columns_renaming = normalizer.columns_renaming
-
- # If there are no external features, we don't call backend on transform
- external_features = [fm for fm in features_meta if fm.shap_value > 0 and fm.source != "etalon"]
- if len(external_features) == 0:
- self.logger.warning(
- "No external features found, returning original dataframe"
- f" with generated important features: {self.feature_names_}"
- )
- df = df.rename(columns=columns_renaming)
- generated_features = [columns_renaming.get(c, c) for c in generated_features]
- search_keys = {columns_renaming.get(c, c): t for c, t in search_keys.items()}
- selecting_columns = self._selecting_input_and_generated_columns(
- validated_Xy, generated_features, keep_input, trace_id, is_transform=True
- )
- self.logger.warning(f"Filtered columns by existance in dataframe: {selecting_columns}")
- if add_fit_system_record_id:
- df = self._add_fit_system_record_id(
- df,
- search_keys,
- SYSTEM_RECORD_ID,
- TARGET,
- columns_renaming,
- self.id_columns,
- self.cv,
- self.model_task_type,
- self.logger,
- self.bundle,
- )
- selecting_columns.append(SYSTEM_RECORD_ID)
- return df[selecting_columns], columns_renaming, generated_features, search_keys
-
- # Don't pass all features in backend on transform
- runtime_parameters = self._get_copy_of_runtime_parameters()
- features_for_transform = self._search_task.get_features_for_transform()
- if features_for_transform:
- missing_features_for_transform = [
- columns_renaming.get(f) or f for f in features_for_transform if f not in df.columns
- ]
- if TARGET in missing_features_for_transform:
- raise ValidationError(self.bundle.get("missing_target_for_transform"))
+ df = self.__add_country_code(df, search_keys)

- if len(missing_features_for_transform) > 0:
- raise ValidationError(
- self.bundle.get("missing_features_for_transform").format(missing_features_for_transform)
- )
- features_for_embeddings = self._search_task.get_features_for_embeddings()
- if features_for_embeddings:
- runtime_parameters.properties["features_for_embeddings"] = ",".join(features_for_embeddings)
- features_for_transform = [f for f in features_for_transform if f not in search_keys.keys()]
+ generated_features = []
+ date_column = self._get_date_column(search_keys)
+ if date_column is not None:
+ converter = DateTimeConverter(
+ date_column,
+ self.date_format,
+ self.logger,
+ bundle=self.bundle,
+ generate_cyclical_features=self.generate_search_key_features,
+ )
+ df = converter.convert(df, keep_time=True)
+ self.logger.info(f"Date column after convertion: {df[date_column]}")
+ generated_features.extend(converter.generated_features)
+ else:
+ self.logger.info("Input dataset hasn't date column")
+ if self.__should_add_date_column():
+ df = self._add_current_date_as_key(df, search_keys, self.bundle, silent=True)

- columns_for_system_record_id = sorted(list(search_keys.keys()) + features_for_transform)
+ email_columns = SearchKey.find_all_keys(search_keys, SearchKey.EMAIL)
+ if email_columns and self.generate_search_key_features:
+ generator = EmailDomainGenerator(email_columns)
+ df = generator.generate(df)
+ generated_features.extend(generator.generated_features)

- df[ENTITY_SYSTEM_RECORD_ID] = pd.util.hash_pandas_object(
- df[columns_for_system_record_id], index=False
- ).astype("float64")
+ normalizer = Normalizer(self.bundle, self.logger)
+ df, search_keys, generated_features = normalizer.normalize(df, search_keys, generated_features)
+ columns_renaming = normalizer.columns_renaming

- features_not_to_pass = []
+ # If there are no external features, we don't call backend on transform
+ external_features = [fm for fm in features_meta if fm.shap_value > 0 and fm.source != "etalon"]
+ if len(external_features) == 0:
+ self.logger.warning(
+ "No external features found, returning original dataframe"
+ f" with generated important features: {self.feature_names_}"
+ )
+ df = df.rename(columns=columns_renaming)
+ generated_features = [columns_renaming.get(c, c) for c in generated_features]
+ search_keys = {columns_renaming.get(c, c): t for c, t in search_keys.items()}
+ selecting_columns = self._selecting_input_and_generated_columns(
+ validated_Xy, generated_features, keep_input, is_transform=True
+ )
+ self.logger.warning(f"Filtered columns by existance in dataframe: {selecting_columns}")
  if add_fit_system_record_id:
  df = self._add_fit_system_record_id(
  df,
@@ -2727,252 +2689,294 @@ if response.status_code == 200:
  self.logger,
  self.bundle,
  )
- df = df.rename(columns={SYSTEM_RECORD_ID: SORT_ID})
- features_not_to_pass.append(SORT_ID)
+ selecting_columns.append(SYSTEM_RECORD_ID)
+ return df[selecting_columns], columns_renaming, generated_features, search_keys

- system_columns_with_original_index = [ENTITY_SYSTEM_RECORD_ID] + generated_features
- if add_fit_system_record_id:
- system_columns_with_original_index.append(SORT_ID)
+ # Don't pass all features in backend on transform
+ runtime_parameters = self._get_copy_of_runtime_parameters()
+ features_for_transform = self._search_task.get_features_for_transform()
+ if features_for_transform:
+ missing_features_for_transform = [
+ columns_renaming.get(f) or f for f in features_for_transform if f not in df.columns
+ ]
+ if TARGET in missing_features_for_transform:
+ raise ValidationError(self.bundle.get("missing_target_for_transform"))

- df_before_explode = df[system_columns_with_original_index].copy()
+ if len(missing_features_for_transform) > 0:
+ raise ValidationError(
+ self.bundle.get("missing_features_for_transform").format(missing_features_for_transform)
+ )
+ features_for_embeddings = self._search_task.get_features_for_embeddings()
+ if features_for_embeddings:
+ runtime_parameters.properties["features_for_embeddings"] = ",".join(features_for_embeddings)
+ features_for_transform = [f for f in features_for_transform if f not in search_keys.keys()]

- # Explode multiple search keys
- df, unnest_search_keys = self._explode_multiple_search_keys(df, search_keys, columns_renaming)
+ columns_for_system_record_id = sorted(list(search_keys.keys()) + features_for_transform)

- # Convert search keys and generate features on them
+ df[ENTITY_SYSTEM_RECORD_ID] = pd.util.hash_pandas_object(df[columns_for_system_record_id], index=False).astype(
+ "float64"
+ )

- email_column = self._get_email_column(search_keys)
- hem_column = self._get_hem_column(search_keys)
- if email_column:
- converter = EmailSearchKeyConverter(
- email_column,
- hem_column,
- search_keys,
- columns_renaming,
- list(unnest_search_keys.keys()),
- self.logger,
- )
- df = converter.convert(df)
+ features_not_to_pass = []
+ if add_fit_system_record_id:
+ df = self._add_fit_system_record_id(
+ df,
+ search_keys,
+ SYSTEM_RECORD_ID,
+ TARGET,
+ columns_renaming,
+ self.id_columns,
+ self.cv,
+ self.model_task_type,
+ self.logger,
+ self.bundle,
+ )
+ df = df.rename(columns={SYSTEM_RECORD_ID: SORT_ID})
+ features_not_to_pass.append(SORT_ID)

- ip_column = self._get_ip_column(search_keys)
- if ip_column:
- converter = IpSearchKeyConverter(
- ip_column,
- search_keys,
- columns_renaming,
- list(unnest_search_keys.keys()),
- self.bundle,
- self.logger,
- )
- df = converter.convert(df)
-
- date_features = []
- for col in features_for_transform:
- if DateTimeConverter(col).is_datetime(df):
- df[col] = DateTimeConverter(col).to_date_string(df)
- date_features.append(col)
-
- meaning_types = {}
- meaning_types.update(
- {
- col: FileColumnMeaningType.FEATURE
- for col in features_for_transform
- if col not in date_features and col not in generated_features
- }
+ system_columns_with_original_index = [ENTITY_SYSTEM_RECORD_ID] + generated_features
+ if add_fit_system_record_id:
+ system_columns_with_original_index.append(SORT_ID)
+
+ df_before_explode = df[system_columns_with_original_index].copy()
+
+ # Explode multiple search keys
+ df, unnest_search_keys = self._explode_multiple_search_keys(df, search_keys, columns_renaming)
+
+ # Convert search keys and generate features on them
+
+ email_column = self._get_email_column(search_keys)
+ hem_column = self._get_hem_column(search_keys)
+ if email_column:
+ converter = EmailSearchKeyConverter(
+ email_column,
+ hem_column,
+ search_keys,
+ columns_renaming,
+ list(unnest_search_keys.keys()),
+ self.logger,
  )
- meaning_types.update({col: FileColumnMeaningType.GENERATED_FEATURE for col in generated_features})
- meaning_types.update({col: FileColumnMeaningType.DATE_FEATURE for col in date_features})
- meaning_types.update({col: key.value for col, key in search_keys.items()})
+ df = converter.convert(df)

- features_not_to_pass.extend(
- [
- c
- for c in df.columns
- if c not in search_keys.keys()
- and c not in features_for_transform
- and c not in [ENTITY_SYSTEM_RECORD_ID, SEARCH_KEY_UNNEST]
- ]
+ ip_column = self._get_ip_column(search_keys)
+ if ip_column:
+ converter = IpSearchKeyConverter(
+ ip_column,
+ search_keys,
+ columns_renaming,
+ list(unnest_search_keys.keys()),
+ self.bundle,
+ self.logger,
  )
+ df = converter.convert(df)

- if DateTimeConverter.DATETIME_COL in df.columns:
- df = df.drop(columns=DateTimeConverter.DATETIME_COL)
+ date_features = []
+ for col in features_for_transform:
+ if DateTimeConverter(col).is_datetime(df):
+ df[col] = DateTimeConverter(col).to_date_string(df)
+ date_features.append(col)

- # search keys might be changed after explode
- columns_for_system_record_id = sorted(list(search_keys.keys()) + features_for_transform)
- df[SYSTEM_RECORD_ID] = pd.util.hash_pandas_object(df[columns_for_system_record_id], index=False).astype(
- "float64"
- )
- meaning_types[SYSTEM_RECORD_ID] = FileColumnMeaningType.SYSTEM_RECORD_ID
- meaning_types[ENTITY_SYSTEM_RECORD_ID] = FileColumnMeaningType.ENTITY_SYSTEM_RECORD_ID
- if SEARCH_KEY_UNNEST in df.columns:
- meaning_types[SEARCH_KEY_UNNEST] = FileColumnMeaningType.UNNEST_KEY
+ meaning_types = {}
+ meaning_types.update(
+ {
+ col: FileColumnMeaningType.FEATURE
+ for col in features_for_transform
+ if col not in date_features and col not in generated_features
+ }
+ )
+ meaning_types.update({col: FileColumnMeaningType.GENERATED_FEATURE for col in generated_features})
+ meaning_types.update({col: FileColumnMeaningType.DATE_FEATURE for col in date_features})
+ meaning_types.update({col: key.value for col, key in search_keys.items()})

- df = df.reset_index(drop=True)
+ features_not_to_pass.extend(
+ [
+ c
+ for c in df.columns
+ if c not in search_keys.keys()
+ and c not in features_for_transform
+ and c not in [ENTITY_SYSTEM_RECORD_ID, SEARCH_KEY_UNNEST]
+ ]
+ )

- combined_search_keys = combine_search_keys(search_keys.keys())
+ if DateTimeConverter.DATETIME_COL in df.columns:
+ df = df.drop(columns=DateTimeConverter.DATETIME_COL)

- df_without_features = df.drop(columns=features_not_to_pass, errors="ignore")
+ # search keys might be changed after explode
+ columns_for_system_record_id = sorted(list(search_keys.keys()) + features_for_transform)
+ df[SYSTEM_RECORD_ID] = pd.util.hash_pandas_object(df[columns_for_system_record_id], index=False).astype(
+ "float64"
+ )
+ meaning_types[SYSTEM_RECORD_ID] = FileColumnMeaningType.SYSTEM_RECORD_ID
+ meaning_types[ENTITY_SYSTEM_RECORD_ID] = FileColumnMeaningType.ENTITY_SYSTEM_RECORD_ID
+ if SEARCH_KEY_UNNEST in df.columns:
+ meaning_types[SEARCH_KEY_UNNEST] = FileColumnMeaningType.UNNEST_KEY

- df_without_features, full_duplicates_warning = clean_full_duplicates(
- df_without_features, is_transform=True, logger=self.logger, bundle=self.bundle
- )
- if not silent_mode and full_duplicates_warning:
- self.__log_warning(full_duplicates_warning)
+ df = df.reset_index(drop=True)

- del df
- gc.collect()
-
- dataset = Dataset(
- "sample_" + str(uuid.uuid4()),
- df=df_without_features,
- meaning_types=meaning_types,
- search_keys=combined_search_keys,
- unnest_search_keys=unnest_search_keys,
- id_columns=self.__get_renamed_id_columns(columns_renaming),
- date_column=self._get_date_column(search_keys),
- date_format=self.date_format,
- sample_config=self.sample_config,
- rest_client=self.rest_client,
- logger=self.logger,
- bundle=self.bundle,
- warning_callback=self.__log_warning,
- )
- dataset.columns_renaming = columns_renaming
-
- validation_task = self._search_task.validation(
- trace_id,
- dataset,
- start_time=start_time,
- extract_features=True,
- runtime_parameters=runtime_parameters,
- exclude_features_sources=exclude_features_sources,
- metrics_calculation=metrics_calculation,
- silent_mode=silent_mode,
- progress_bar=progress_bar,
- progress_callback=progress_callback,
- )
+ combined_search_keys = combine_search_keys(search_keys.keys())

- del df_without_features, dataset
- gc.collect()
+ df_without_features = df.drop(columns=features_not_to_pass, errors="ignore")

- if not silent_mode:
- print(self.bundle.get("polling_transform_task").format(validation_task.search_task_id))
- if not self.__is_registered:
- print(self.bundle.get("polling_unregister_information"))
+ df_without_features, full_duplicates_warning = clean_full_duplicates(
+ df_without_features, is_transform=True, logger=self.logger, bundle=self.bundle
+ )
+ if not silent_mode and full_duplicates_warning:
+ self.__log_warning(full_duplicates_warning)

- progress = self.get_progress(trace_id, validation_task)
- progress.recalculate_eta(time.time() - start_time)
- if progress_bar is not None:
- progress_bar.progress = progress.to_progress_bar()
- if progress_callback is not None:
- progress_callback(progress)
- prev_progress: SearchProgress | None = None
- polling_period_seconds = 1
- try:
- while progress.stage != ProgressStage.DOWNLOADING.value:
- if prev_progress is None or prev_progress.percent != progress.percent:
- progress.recalculate_eta(time.time() - start_time)
- else:
- progress.update_eta(prev_progress.eta - polling_period_seconds)
- prev_progress = progress
- if progress_bar is not None:
- progress_bar.progress = progress.to_progress_bar()
- if progress_callback is not None:
- progress_callback(progress)
- if progress.stage == ProgressStage.FAILED.value:
- raise Exception(progress.error_message)
- time.sleep(polling_period_seconds)
- progress = self.get_progress(trace_id, validation_task)
- except KeyboardInterrupt as e:
- print(self.bundle.get("search_stopping"))
- self.rest_client.stop_search_task_v2(trace_id, validation_task.search_task_id)
- self.logger.warning(f"Search {validation_task.search_task_id} stopped by user")
- print(self.bundle.get("search_stopped"))
- raise e
+ del df
+ gc.collect()

- validation_task.poll_result(trace_id, quiet=True)
+ dataset = Dataset(
+ "sample_" + str(uuid.uuid4()),
+ df=df_without_features,
+ meaning_types=meaning_types,
+ search_keys=combined_search_keys,
+ unnest_search_keys=unnest_search_keys,
+ id_columns=self.__get_renamed_id_columns(columns_renaming),
+ date_column=self._get_date_column(search_keys),
+ date_format=self.date_format,
+ sample_config=self.sample_config,
+ rest_client=self.rest_client,
+ logger=self.logger,
+ bundle=self.bundle,
+ warning_callback=self.__log_warning,
+ )
+ dataset.columns_renaming = columns_renaming

- seconds_left = time.time() - start_time
- progress = SearchProgress(97.0, ProgressStage.DOWNLOADING, seconds_left)
- if progress_bar is not None:
- progress_bar.progress = progress.to_progress_bar()
- if progress_callback is not None:
- progress_callback(progress)
+ validation_task = self._search_task.validation(
+ self._get_trace_id(),
+ dataset,
+ start_time=start_time,
+ extract_features=True,
+ runtime_parameters=runtime_parameters,
+ exclude_features_sources=exclude_features_sources,
+ metrics_calculation=metrics_calculation,
+ silent_mode=silent_mode,
+ progress_bar=progress_bar,
+ progress_callback=progress_callback,
+ )

- if not silent_mode:
- print(self.bundle.get("transform_start"))
+ del df_without_features, dataset
+ gc.collect()

- # Prepare input DataFrame for __enrich by concatenating generated ids and client features
- df_before_explode = df_before_explode.rename(columns=columns_renaming)
- generated_features = [columns_renaming.get(c, c) for c in generated_features]
- combined_df = pd.concat(
- [
- validated_Xy.reset_index(drop=True),
- df_before_explode.reset_index(drop=True),
- ],
- axis=1,
- ).set_index(validated_Xy.index)
-
- result_features = validation_task.get_all_validation_raw_features(trace_id, metrics_calculation)
-
- result = self.__enrich(
- combined_df,
- result_features,
- how="left",
- )
+ if not silent_mode:
+ print(self.bundle.get("polling_transform_task").format(validation_task.search_task_id))
+ if not self.__is_registered:
+ print(self.bundle.get("polling_unregister_information"))

- selecting_columns = self._selecting_input_and_generated_columns(
- validated_Xy, generated_features, keep_input, trace_id, is_transform=True
- )
- selecting_columns.extend(
- c
- for c in result.columns
- if c in self.feature_names_ and c not in selecting_columns and c not in validated_Xy.columns
- )
- if add_fit_system_record_id:
- selecting_columns.append(SORT_ID)
+ progress = self.get_progress(validation_task)
+ progress.recalculate_eta(time.time() - start_time)
+ if progress_bar is not None:
+ progress_bar.progress = progress.to_progress_bar()
+ if progress_callback is not None:
+ progress_callback(progress)
+ prev_progress: SearchProgress | None = None
+ polling_period_seconds = 1
+ try:
+ while progress.stage != ProgressStage.DOWNLOADING.value:
+ if prev_progress is None or prev_progress.percent != progress.percent:
+ progress.recalculate_eta(time.time() - start_time)
+ else:
+ progress.update_eta(prev_progress.eta - polling_period_seconds)
+ prev_progress = progress
+ if progress_bar is not None:
+ progress_bar.progress = progress.to_progress_bar()
+ if progress_callback is not None:
+ progress_callback(progress)
+ if progress.stage == ProgressStage.FAILED.value:
+ raise Exception(progress.error_message)
+ time.sleep(polling_period_seconds)
+ progress = self.get_progress(validation_task)
+ except KeyboardInterrupt as e:
+ print(self.bundle.get("search_stopping"))
+ self.rest_client.stop_search_task_v2(self._get_trace_id(), validation_task.search_task_id)
+ self.logger.warning(f"Search {validation_task.search_task_id} stopped by user")
+ print(self.bundle.get("search_stopped"))
+ raise e
+
+ validation_task.poll_result(self._get_trace_id(), quiet=True)
+
+ seconds_left = time.time() - start_time
+ progress = SearchProgress(97.0, ProgressStage.DOWNLOADING, seconds_left)
+ if progress_bar is not None:
+ progress_bar.progress = progress.to_progress_bar()
+ if progress_callback is not None:
+ progress_callback(progress)

- selecting_columns = list(set(selecting_columns))
- # sorting: first columns from X, then generated features, then enriched features
- sorted_selecting_columns = [c for c in validated_Xy.columns if c in selecting_columns]
- for c in generated_features:
- if c in selecting_columns and c not in sorted_selecting_columns:
- sorted_selecting_columns.append(c)
- for c in result.columns:
- if c in selecting_columns and c not in sorted_selecting_columns:
- sorted_selecting_columns.append(c)
+ if not silent_mode:
+ print(self.bundle.get("transform_start"))

- self.logger.info(f"Transform sorted_selecting_columns: {sorted_selecting_columns}")
+ # Prepare input DataFrame for __enrich by concatenating generated ids and client features
+ df_before_explode = df_before_explode.rename(columns=columns_renaming)
+ generated_features = [columns_renaming.get(c, c) for c in generated_features]
+ combined_df = pd.concat(
+ [
+ validated_Xy.reset_index(drop=True),
+ df_before_explode.reset_index(drop=True),
+ ],
+ axis=1,
+ ).set_index(validated_Xy.index)
+
+ result_features = validation_task.get_all_validation_raw_features(self._get_trace_id(), metrics_calculation)
+
+ result = self.__enrich(
+ combined_df,
+ result_features,
+ how="left",
+ )
+
+ selecting_columns = self._selecting_input_and_generated_columns(
+ validated_Xy, generated_features, keep_input, is_transform=True
+ )
+ selecting_columns.extend(
+ c
+ for c in result.columns
+ if c in self.feature_names_ and c not in selecting_columns and c not in validated_Xy.columns
+ )
+ if add_fit_system_record_id:
+ selecting_columns.append(SORT_ID)

- result = result[sorted_selecting_columns]
+ selecting_columns = list(set(selecting_columns))
+ # sorting: first columns from X, then generated features, then enriched features
+ sorted_selecting_columns = [c for c in validated_Xy.columns if c in selecting_columns]
+ for c in generated_features:
+ if c in selecting_columns and c not in sorted_selecting_columns:
+ sorted_selecting_columns.append(c)
+ for c in result.columns:
+ if c in selecting_columns and c not in sorted_selecting_columns:
+ sorted_selecting_columns.append(c)

- if self.country_added:
- result = result.drop(columns=COUNTRY, errors="ignore")
+ self.logger.info(f"Transform sorted_selecting_columns: {sorted_selecting_columns}")

- if add_fit_system_record_id:
- result = result.rename(columns={SORT_ID: SYSTEM_RECORD_ID})
+ result = result[sorted_selecting_columns]
+
+ if self.country_added:
+ result = result.drop(columns=COUNTRY, errors="ignore")

- for c in result.columns:
- if result[c].dtype == "category":
- result.loc[:, c] = np.where(~result[c].isin(result[c].dtype.categories), np.nan, result[c])
+ if add_fit_system_record_id:
+ result = result.rename(columns={SORT_ID: SYSTEM_RECORD_ID})

- return result, columns_renaming, generated_features, search_keys
+ for c in result.columns:
+ if result[c].dtype == "category":
+ result.loc[:, c] = np.where(~result[c].isin(result[c].dtype.categories), np.nan, result[c])
+
+ return result, columns_renaming, generated_features, search_keys

  def _selecting_input_and_generated_columns(
  self,
  validated_Xy: pd.DataFrame,
  generated_features: list[str],
  keep_input: bool,
- trace_id: str,
  is_transform: bool = False,
  ):
- file_meta = self._search_task.get_file_metadata(trace_id)
+ file_meta = self._search_task.get_file_metadata(self._get_trace_id())
  fit_dropped_features = self.fit_dropped_features or file_meta.droppedColumns or []
  fit_input_columns = [c.originalName for c in file_meta.columns]
  original_dropped_features = [self.fit_columns_renaming.get(c, c) for c in fit_dropped_features]
  new_columns_on_transform = [
  c for c in validated_Xy.columns if c not in fit_input_columns and c not in original_dropped_features
  ]
+ fit_original_search_keys = self._get_fit_search_keys_with_original_names()

  selected_generated_features = [c for c in generated_features if c in self.feature_names_]
  if keep_input is True:
@@ -2982,7 +2986,7 @@ if response.status_code == 200:
  if not self.fit_select_features
  or c in self.feature_names_
  or (c in new_columns_on_transform and is_transform)
- or c in self.search_keys
+ or c in fit_original_search_keys
  or c in (self.id_columns or [])
  or c in [EVAL_SET_INDEX, TARGET] # transform for metrics calculation
  or c == self.baseline_score_column
@@ -3067,7 +3071,6 @@ if response.status_code == 200:

  def __inner_fit(
  self,
- trace_id: str,
  X: pd.DataFrame | pd.Series | np.ndarray,
  y: pd.DataFrame | pd.Series | np.ndarray | list | None,
  eval_set: list[tuple] | None,
@@ -3149,6 +3152,8 @@ if response.status_code == 200:
  df = self.__handle_index_search_keys(df, self.fit_search_keys)
  self.fit_search_keys = self.__prepare_search_keys(df, self.fit_search_keys, is_demo_dataset)

+ df = self._validate_OOT(df, self.fit_search_keys)
+
  maybe_date_column = SearchKey.find_key(self.fit_search_keys, [SearchKey.DATE, SearchKey.DATETIME])
  has_date = maybe_date_column is not None and maybe_date_column in validated_X.columns

@@ -3400,6 +3405,7 @@ if response.status_code == 200:
3400
3405
  id_columns=self.__get_renamed_id_columns(),
3401
3406
  is_imbalanced=self.imbalanced,
3402
3407
  dropped_columns=[self.fit_columns_renaming.get(f, f) for f in self.fit_dropped_features],
3408
+ autodetected_search_keys=self.autodetected_search_keys,
3403
3409
  date_column=self._get_date_column(self.fit_search_keys),
3404
3410
  date_format=self.date_format,
3405
3411
  random_state=self.random_state,
@@ -3423,7 +3429,7 @@ if response.status_code == 200:
3423
3429
  ]
3424
3430
 
3425
3431
  self._search_task = dataset.search(
3426
- trace_id=trace_id,
3432
+ trace_id=self._get_trace_id(),
3427
3433
  progress_bar=progress_bar,
3428
3434
  start_time=start_time,
3429
3435
  progress_callback=progress_callback,
@@ -3443,7 +3449,7 @@ if response.status_code == 200:
3443
3449
  if not self.__is_registered:
3444
3450
  print(self.bundle.get("polling_unregister_information"))
3445
3451
 
3446
- progress = self.get_progress(trace_id)
3452
+ progress = self.get_progress()
3447
3453
  prev_progress = None
3448
3454
  progress.recalculate_eta(time.time() - start_time)
3449
3455
  if progress_bar is not None:
@@ -3469,16 +3475,16 @@ if response.status_code == 200:
  )
  raise RuntimeError(self.bundle.get("search_task_failed_status"))
  time.sleep(poll_period_seconds)
- progress = self.get_progress(trace_id)
+ progress = self.get_progress()
  except KeyboardInterrupt as e:
  print(self.bundle.get("search_stopping"))
- self.rest_client.stop_search_task_v2(trace_id, self._search_task.search_task_id)
+ self.rest_client.stop_search_task_v2(self._get_trace_id(), self._search_task.search_task_id)
  self.logger.warning(f"Search {self._search_task.search_task_id} stopped by user")
  self._search_task = None
  print(self.bundle.get("search_stopped"))
  raise e
 
- self._search_task.poll_result(trace_id, quiet=True)
+ self._search_task.poll_result(self._get_trace_id(), quiet=True)
 
  seconds_left = time.time() - start_time
  progress = SearchProgress(97.0, ProgressStage.GENERATING_REPORT, seconds_left)
@@ -3507,10 +3513,9 @@ if response.status_code == 200:
  msg = self.bundle.get("features_not_generated").format(unused_features_for_generation)
  self.__log_warning(msg)
 
- self.__prepare_feature_importances(trace_id, df)
+ self.__prepare_feature_importances(df)
 
  self._select_features_by_psi(
- trace_id=trace_id,
  X=X,
  y=y,
  eval_set=eval_set,
@@ -3523,7 +3528,7 @@ if response.status_code == 200:
  progress_callback=progress_callback,
  )
 
- self.__prepare_feature_importances(trace_id, df)
+ self.__prepare_feature_importances(df)
 
  self.__show_selected_features()
 
@@ -3558,7 +3563,6 @@ if response.status_code == 200:
  scoring,
  estimator,
  remove_outliers_calc_metrics,
- trace_id,
  progress_bar,
  progress_callback,
  )
@@ -3653,11 +3657,10 @@ if response.status_code == 200:
  y: pd.Series | None = None,
  eval_set: list[tuple[pd.DataFrame, pd.Series]] | None = None,
  is_transform: bool = False,
- silent: bool = False,
  ) -> tuple[pd.DataFrame, pd.Series, list[tuple[pd.DataFrame, pd.Series]]] | None:
  validated_X = self._validate_X(X, is_transform)
  validated_y = self._validate_y(validated_X, y, enforce_y=not is_transform)
- validated_eval_set = self._validate_eval_set(validated_X, eval_set, silent=silent)
+ validated_eval_set = self._validate_eval_set(validated_X, eval_set)
  return validated_X, validated_y, validated_eval_set
 
  def _encode_id_columns(
@@ -3783,31 +3786,41 @@ if response.status_code == 200:
  return validated_y
 
  def _validate_eval_set(
- self, X: pd.DataFrame, eval_set: list[tuple[pd.DataFrame, pd.Series]] | None, silent: bool = False
- ):
+ self,
+ X: pd.DataFrame,
+ eval_set: list[tuple[pd.DataFrame, pd.Series]] | None,
+ ) -> list[tuple[pd.DataFrame, pd.Series]] | None:
  if eval_set is None:
  return None
  validated_eval_set = []
- date_col = self._get_date_column(self.search_keys)
- has_date = date_col is not None and date_col in X.columns
- for idx, eval_pair in enumerate(eval_set):
+ for _, eval_pair in enumerate(eval_set):
  validated_pair = self._validate_eval_set_pair(X, eval_pair)
- if validated_pair[1].isna().all():
- if not has_date:
- msg = self.bundle.get("oot_without_date_not_supported").format(idx + 1)
- elif self.columns_for_online_api:
- msg = self.bundle.get("oot_with_online_sources_not_supported").format(idx + 1)
- else:
- msg = None
- if msg:
- if not silent:
- print(msg)
- self.logger.warning(msg)
- continue
  validated_eval_set.append(validated_pair)
 
  return validated_eval_set
 
+ def _validate_OOT(self, df: pd.DataFrame, search_keys: dict[str, SearchKey]) -> pd.DataFrame:
+ if EVAL_SET_INDEX not in df.columns:
+ return df
+
+ for eval_set_index in df[EVAL_SET_INDEX].unique():
+ if eval_set_index == 0:
+ continue
+ eval_df = df[df[EVAL_SET_INDEX] == eval_set_index]
+ date_col = self._get_date_column(search_keys)
+ has_date = date_col is not None and date_col in eval_df.columns
+ if eval_df[TARGET].isna().all():
+ msg = None
+ if not has_date:
+ msg = self.bundle.get("oot_without_date_not_supported").format(eval_set_index)
+ elif self.columns_for_online_api:
+ msg = self.bundle.get("oot_with_online_sources_not_supported").format(eval_set_index)
+ if msg:
+ print(msg)
+ self.logger.warning(msg)
+ df = df[df[EVAL_SET_INDEX] != eval_set_index]
+ return df
+
  def _validate_eval_set_pair(self, X: pd.DataFrame, eval_pair: tuple) -> tuple[pd.DataFrame, pd.Series]:
  if len(eval_pair) != 2:
  raise ValidationError(self.bundle.get("eval_set_invalid_tuple_size").format(len(eval_pair)))
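
A minimal sketch of the out-of-time (OOT) filtering idea shown in the hunk above, rewritten as a standalone function under stated assumptions: `EVAL_SET_INDEX`, `TARGET`, and the column names are hypothetical stand-ins for the package constants, and the bundle messages and online-source check are omitted. It is an illustration, not the package API.

```python
# Hedged sketch: drop eval segments whose target is entirely NaN when no date
# column is available, roughly mirroring the _validate_OOT hunk above.
import numpy as np
import pandas as pd

EVAL_SET_INDEX = "eval_set_index"   # stand-in for the package constant
TARGET = "target"                   # stand-in for the package constant

def drop_unsupported_oot(df: pd.DataFrame, date_column: str | None) -> pd.DataFrame:
    # Segment 0 is the train part; only eval segments are inspected.
    for idx in df[EVAL_SET_INDEX].unique():
        if idx == 0:
            continue
        segment = df[df[EVAL_SET_INDEX] == idx]
        is_oot = segment[TARGET].isna().all()      # OOT segment = no labels at all
        has_date = date_column is not None and date_column in segment.columns
        if is_oot and not has_date:
            # Without a date key the OOT segment cannot be used, so drop it.
            df = df[df[EVAL_SET_INDEX] != idx]
    return df

df = pd.DataFrame({
    EVAL_SET_INDEX: [0, 0, 1, 1],
    TARGET: [1, 0, np.nan, np.nan],
    "feature": [0.1, 0.2, 0.3, 0.4],
})
print(drop_unsupported_oot(df, date_column=None)[EVAL_SET_INDEX].unique())  # -> [0]
```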
@@ -4423,47 +4436,6 @@ if response.status_code == 200:
 
  return result_features
 
- def __get_features_importance_from_server(self, trace_id: str, df: pd.DataFrame):
- if self._search_task is None:
- raise NotFittedError(self.bundle.get("transform_unfitted_enricher"))
- features_meta = self._search_task.get_all_features_metadata_v2()
- if features_meta is None:
- raise Exception(self.bundle.get("missing_features_meta"))
- features_meta = deepcopy(features_meta)
-
- original_names_dict = {c.name: c.originalName for c in self._search_task.get_file_metadata(trace_id).columns}
- df = df.rename(columns=original_names_dict)
-
- features_meta.sort(key=lambda m: (-m.shap_value, m.name))
-
- importances = {}
-
- for feature_meta in features_meta:
- if feature_meta.name in original_names_dict.keys():
- feature_meta.name = original_names_dict[feature_meta.name]
-
- is_client_feature = feature_meta.name in df.columns
-
- if feature_meta.shap_value == 0.0:
- continue
-
- # Use only important features
- if (
- feature_meta.name == COUNTRY
- # In select_features mode we select also from etalon features and need to show them
- or (not self.fit_select_features and is_client_feature)
- ):
- continue
-
- # Temporary workaround for duplicate features metadata
- if feature_meta.name in importances:
- self.logger.warning(f"WARNING: Duplicate feature metadata: {feature_meta}")
- continue
-
- importances[feature_meta.name] = feature_meta.shap_value
-
- return importances
-
  def __get_categorical_features(self) -> list[str]:
  features_meta = self._search_task.get_all_features_metadata_v2()
  if features_meta is None:
@@ -4473,7 +4445,6 @@ if response.status_code == 200:
 
  def __prepare_feature_importances(
  self,
- trace_id: str,
  clients_features_df: pd.DataFrame,
  updated_shaps: dict[str, float] | None = None,
  update_selected_features: bool = True,
@@ -4481,16 +4452,16 @@ if response.status_code == 200:
  ):
  if self._search_task is None:
  raise NotFittedError(self.bundle.get("transform_unfitted_enricher"))
- selected_features = self._search_task.get_selected_features(trace_id)
+ selected_features = self._search_task.get_selected_features(self._get_trace_id())
  features_meta = self._search_task.get_all_features_metadata_v2()
  if features_meta is None:
  raise Exception(self.bundle.get("missing_features_meta"))
  features_meta = deepcopy(features_meta)
 
- file_metadata_columns = self._search_task.get_file_metadata(trace_id).columns
+ file_metadata_columns = self._search_task.get_file_metadata(self._get_trace_id()).columns
  file_meta_by_orig_name = {c.originalName: c for c in file_metadata_columns}
  original_names_dict = {c.name: c.originalName for c in file_metadata_columns}
- features_df = self._search_task.get_all_initial_raw_features(trace_id, metrics_calculation=True)
+ features_df = self._search_task.get_all_initial_raw_features(self._get_trace_id(), metrics_calculation=True)
 
  # To be sure that names with hash suffixes
  clients_features_df = clients_features_df.rename(columns=original_names_dict)
@@ -4581,7 +4552,7 @@ if response.status_code == 200:
  internal_features_info.append(feature_info.to_internal_row(self.bundle))
 
  if update_selected_features:
- self._search_task.update_selected_features(trace_id, self.feature_names_)
+ self._search_task.update_selected_features(self._get_trace_id(), self.feature_names_)
 
  if len(features_info) > 0:
  self.features_info = pd.DataFrame(features_info)
@@ -4779,12 +4750,17 @@ if response.status_code == 200:
  ):
  raise ValidationError(self.bundle.get("empty_search_key").format(column_name))
 
- if self.autodetect_search_keys and (
- not is_transform or set(valid_search_keys.values()) != set(self.fit_search_keys.values())
- ):
- valid_search_keys = self.__detect_missing_search_keys(
- x, valid_search_keys, is_demo_dataset, silent_mode, is_transform
- )
+ if is_transform:
+ fit_autodetected_search_keys = self._get_autodetected_search_keys()
+ if fit_autodetected_search_keys is not None:
+ for key in fit_autodetected_search_keys.keys():
+ if key not in x.columns:
+ raise ValidationError(
+ self.bundle.get("autodetected_search_key_not_found").format(key, x.columns)
+ )
+ valid_search_keys.update(fit_autodetected_search_keys)
+ elif self.autodetect_search_keys:
+ valid_search_keys = self.__detect_missing_search_keys(x, valid_search_keys, is_demo_dataset)
 
  if all(k == SearchKey.CUSTOM_KEY for k in valid_search_keys.values()):
  if self.__is_registered:
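
A rough illustration of the transform-time behavior in the hunk above, under stated assumptions: the `fit_autodetected` dict, the column names, and the plain `ValueError` are hypothetical stand-ins for the enricher's saved autodetection result, its `_get_autodetected_search_keys()` helper, and the bundle-driven `ValidationError`.

```python
# Hedged sketch: at transform time, reuse the search keys autodetected at fit
# time and fail fast if one of those columns is missing from the new frame.
import pandas as pd

fit_autodetected = {"reg_date": "DATETIME", "zip": "POSTAL_CODE"}  # hypothetical fit-time result

def merge_transform_keys(x: pd.DataFrame, user_keys: dict[str, str]) -> dict[str, str]:
    for column in fit_autodetected:
        if column not in x.columns:
            raise ValueError(f"Autodetected search key column '{column}' not found in {list(x.columns)}")
    merged = dict(user_keys)
    merged.update(fit_autodetected)   # autodetected keys win, mirroring valid_search_keys.update(...)
    return merged

x = pd.DataFrame({"reg_date": ["2024-01-01"], "zip": ["10001"], "amount": [10.0]})
print(merge_transform_keys(x, {}))
```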
@@ -4829,7 +4805,6 @@ if response.status_code == 200:
  scoring: Callable | str | None,
  estimator: Any | None,
  remove_outliers_calc_metrics: bool | None,
- trace_id: str,
  progress_bar: ProgressBar | None = None,
  progress_callback: Callable[[SearchProgress], Any] | None = None,
  ):
@@ -4837,7 +4812,6 @@ if response.status_code == 200:
  scoring=scoring,
  estimator=estimator,
  remove_outliers_calc_metrics=remove_outliers_calc_metrics,
- trace_id=trace_id,
  internal_call=True,
  progress_bar=progress_bar,
  progress_callback=progress_callback,
@@ -4902,60 +4876,36 @@ if response.status_code == 200:
  df: pd.DataFrame,
  search_keys: dict[str, SearchKey],
  is_demo_dataset: bool,
- silent_mode=False,
- is_transform=False,
  ) -> dict[str, SearchKey]:
  sample = df.head(100)
 
- def check_need_detect(search_key: SearchKey):
- return not is_transform or (
- search_key in self.fit_search_keys.values() and search_key not in search_keys.values()
- )
-
- if (
- SearchKey.DATE not in search_keys.values()
- and SearchKey.DATETIME not in search_keys.values()
- and check_need_detect(SearchKey.DATE)
- and check_need_detect(SearchKey.DATETIME)
- ):
+ if SearchKey.DATE not in search_keys.values() and SearchKey.DATETIME not in search_keys.values():
  maybe_keys = DateSearchKeyDetector().get_search_key_columns(sample, search_keys)
  if len(maybe_keys) > 0:
  datetime_key = maybe_keys[0]
  search_keys[datetime_key] = SearchKey.DATETIME
  self.autodetected_search_keys[datetime_key] = SearchKey.DATETIME
  self.logger.info(f"Autodetected search key DATETIME in column {datetime_key}")
- if not silent_mode:
- print(self.bundle.get("datetime_detected").format(datetime_key))
+ print(self.bundle.get("datetime_detected").format(datetime_key))
 
  # if SearchKey.POSTAL_CODE not in search_keys.values() and check_need_detect(SearchKey.POSTAL_CODE):
- if check_need_detect(SearchKey.POSTAL_CODE):
- maybe_keys = PostalCodeSearchKeyDetector().get_search_key_columns(sample, search_keys)
- if maybe_keys:
- new_keys = {key: SearchKey.POSTAL_CODE for key in maybe_keys}
- search_keys.update(new_keys)
- self.autodetected_search_keys.update(new_keys)
- self.logger.info(f"Autodetected search key POSTAL_CODE in column {maybe_keys}")
- if not silent_mode:
- print(self.bundle.get("postal_code_detected").format(maybe_keys))
-
- if (
- SearchKey.COUNTRY not in search_keys.values()
- and self.country_code is None
- and check_need_detect(SearchKey.COUNTRY)
- ):
+ maybe_keys = PostalCodeSearchKeyDetector().get_search_key_columns(sample, search_keys)
+ if maybe_keys:
+ new_keys = {key: SearchKey.POSTAL_CODE for key in maybe_keys}
+ search_keys.update(new_keys)
+ self.autodetected_search_keys.update(new_keys)
+ self.logger.info(f"Autodetected search key POSTAL_CODE in column {maybe_keys}")
+ print(self.bundle.get("postal_code_detected").format(maybe_keys))
+
+ if SearchKey.COUNTRY not in search_keys.values() and self.country_code is None:
  maybe_key = CountrySearchKeyDetector().get_search_key_columns(sample, search_keys)
  if maybe_key:
  search_keys[maybe_key[0]] = SearchKey.COUNTRY
  self.autodetected_search_keys[maybe_key[0]] = SearchKey.COUNTRY
  self.logger.info(f"Autodetected search key COUNTRY in column {maybe_key}")
- if not silent_mode:
- print(self.bundle.get("country_detected").format(maybe_key))
+ print(self.bundle.get("country_detected").format(maybe_key))
 
- if (
- # SearchKey.EMAIL not in search_keys.values()
- SearchKey.HEM not in search_keys.values()
- and check_need_detect(SearchKey.HEM)
- ):
+ if SearchKey.EMAIL not in search_keys.values() and SearchKey.HEM not in search_keys.values():
  maybe_keys = EmailSearchKeyDetector().get_search_key_columns(sample, search_keys)
  if maybe_keys:
  if self.__is_registered or is_demo_dataset:
@@ -4963,34 +4913,28 @@ if response.status_code == 200:
  search_keys.update(new_keys)
  self.autodetected_search_keys.update(new_keys)
  self.logger.info(f"Autodetected search key EMAIL in column {maybe_keys}")
- if not silent_mode:
- print(self.bundle.get("email_detected").format(maybe_keys))
+ print(self.bundle.get("email_detected").format(maybe_keys))
  else:
  self.logger.warning(
  f"Autodetected search key EMAIL in column {maybe_keys}."
  " But not used because not registered user"
  )
- if not silent_mode:
- self.__log_warning(self.bundle.get("email_detected_not_registered").format(maybe_keys))
+ self.__log_warning(self.bundle.get("email_detected_not_registered").format(maybe_keys))
 
  # if SearchKey.PHONE not in search_keys.values() and check_need_detect(SearchKey.PHONE):
- if check_need_detect(SearchKey.PHONE):
- maybe_keys = PhoneSearchKeyDetector().get_search_key_columns(sample, search_keys)
- if maybe_keys:
- if self.__is_registered or is_demo_dataset:
- new_keys = {key: SearchKey.PHONE for key in maybe_keys}
- search_keys.update(new_keys)
- self.autodetected_search_keys.update(new_keys)
- self.logger.info(f"Autodetected search key PHONE in column {maybe_keys}")
- if not silent_mode:
- print(self.bundle.get("phone_detected").format(maybe_keys))
- else:
- self.logger.warning(
- f"Autodetected search key PHONE in column {maybe_keys}. "
- "But not used because not registered user"
- )
- if not silent_mode:
- self.__log_warning(self.bundle.get("phone_detected_not_registered"))
+ maybe_keys = PhoneSearchKeyDetector().get_search_key_columns(sample, search_keys)
+ if maybe_keys:
+ if self.__is_registered or is_demo_dataset:
+ new_keys = {key: SearchKey.PHONE for key in maybe_keys}
+ search_keys.update(new_keys)
+ self.autodetected_search_keys.update(new_keys)
+ self.logger.info(f"Autodetected search key PHONE in column {maybe_keys}")
+ print(self.bundle.get("phone_detected").format(maybe_keys))
+ else:
+ self.logger.warning(
+ f"Autodetected search key PHONE in column {maybe_keys}. " "But not used because not registered user"
+ )
+ self.__log_warning(self.bundle.get("phone_detected_not_registered"))
 
  return search_keys
 
@@ -5062,13 +5006,12 @@ if response.status_code == 200:
 
  def dump_input(
  self,
- trace_id: str,
  X: pd.DataFrame | pd.Series,
  y: pd.DataFrame | pd.Series | None = None,
  eval_set: tuple | None = None,
  ):
- def dump_task(X_, y_, eval_set_):
- with MDC(correlation_id=trace_id):
+ def dump_task(X_, y_, eval_set_, trace_id_):
+ with MDC(correlation_id=trace_id_):
  try:
  if isinstance(X_, pd.Series):
  X_ = X_.to_frame()
@@ -5076,27 +5019,25 @@ if response.status_code == 200:
  with tempfile.TemporaryDirectory() as tmp_dir:
  X_.to_parquet(f"{tmp_dir}/x.parquet", compression="zstd")
  x_digest_sha256 = file_hash(f"{tmp_dir}/x.parquet")
- if self.rest_client.is_file_uploaded(trace_id, x_digest_sha256):
+ if self.rest_client.is_file_uploaded(trace_id_, x_digest_sha256):
  self.logger.info(
  f"File x.parquet was already uploaded with digest {x_digest_sha256}, skipping"
  )
  else:
- self.rest_client.dump_input_file(
- trace_id, f"{tmp_dir}/x.parquet", "x.parquet", x_digest_sha256
- )
+ self.rest_client.dump_input_file(f"{tmp_dir}/x.parquet", "x.parquet", x_digest_sha256)
 
  if y_ is not None:
  if isinstance(y_, pd.Series):
  y_ = y_.to_frame()
  y_.to_parquet(f"{tmp_dir}/y.parquet", compression="zstd")
  y_digest_sha256 = file_hash(f"{tmp_dir}/y.parquet")
- if self.rest_client.is_file_uploaded(trace_id, y_digest_sha256):
+ if self.rest_client.is_file_uploaded(trace_id_, y_digest_sha256):
  self.logger.info(
  f"File y.parquet was already uploaded with digest {y_digest_sha256}, skipping"
  )
  else:
  self.rest_client.dump_input_file(
- trace_id, f"{tmp_dir}/y.parquet", "y.parquet", y_digest_sha256
+ trace_id_, f"{tmp_dir}/y.parquet", "y.parquet", y_digest_sha256
  )
 
  if eval_set_ is not None and len(eval_set_) > 0:
@@ -5105,14 +5046,14 @@ if response.status_code == 200:
  eval_x_ = eval_x_.to_frame()
  eval_x_.to_parquet(f"{tmp_dir}/eval_x_{idx}.parquet", compression="zstd")
  eval_x_digest_sha256 = file_hash(f"{tmp_dir}/eval_x_{idx}.parquet")
- if self.rest_client.is_file_uploaded(trace_id, eval_x_digest_sha256):
+ if self.rest_client.is_file_uploaded(trace_id_, eval_x_digest_sha256):
  self.logger.info(
  f"File eval_x_{idx}.parquet was already uploaded with"
  f" digest {eval_x_digest_sha256}, skipping"
  )
  else:
  self.rest_client.dump_input_file(
- trace_id,
+ trace_id_,
  f"{tmp_dir}/eval_x_{idx}.parquet",
  f"eval_x_{idx}.parquet",
  eval_x_digest_sha256,
@@ -5122,14 +5063,14 @@ if response.status_code == 200:
  eval_y_ = eval_y_.to_frame()
  eval_y_.to_parquet(f"{tmp_dir}/eval_y_{idx}.parquet", compression="zstd")
  eval_y_digest_sha256 = file_hash(f"{tmp_dir}/eval_y_{idx}.parquet")
- if self.rest_client.is_file_uploaded(trace_id, eval_y_digest_sha256):
+ if self.rest_client.is_file_uploaded(trace_id_, eval_y_digest_sha256):
  self.logger.info(
  f"File eval_y_{idx}.parquet was already uploaded"
  f" with digest {eval_y_digest_sha256}, skipping"
  )
  else:
  self.rest_client.dump_input_file(
- trace_id,
+ trace_id_,
  f"{tmp_dir}/eval_y_{idx}.parquet",
  f"eval_y_{idx}.parquet",
  eval_y_digest_sha256,
@@ -5138,7 +5079,8 @@ if response.status_code == 200:
  self.logger.warning("Failed to dump input files", exc_info=True)
 
  try:
- Thread(target=dump_task, args=(X, y, eval_set), daemon=True).start()
+ trace_id = self._get_trace_id()
+ Thread(target=dump_task, args=(X, y, eval_set, trace_id), daemon=True).start()
  except Exception:
  self.logger.warning("Failed to dump input files", exc_info=True)
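
A minimal sketch of the pattern used in the hunk above: the correlation id is captured before the daemon thread starts and handed to the worker as an explicit argument, so background log lines stay correlated. Plain `logging`, `threading`, and `uuid` stand in for the package's MDC context and `_get_trace_id()` helper; the names here are hypothetical.

```python
# Hedged sketch: pass a correlation id into a background worker explicitly
# instead of relying on the parent thread's context.
import logging
import threading
import uuid

logging.basicConfig(level=logging.INFO, format="%(message)s")
logger = logging.getLogger("dump")

def dump_task(payload: str, trace_id: str) -> None:
    # The id travels as an argument; nothing is read from thread-local state.
    logger.info("correlation_id=%s dumping %s", trace_id, payload)

trace_id = uuid.uuid4().hex          # analogous to capturing self._get_trace_id() up front
thread = threading.Thread(target=dump_task, args=("x.parquet", trace_id), daemon=True)
thread.start()
thread.join()
```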