upgini 1.2.141__py3-none-any.whl → 1.2.142a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- upgini/__about__.py +1 -1
- upgini/dataset.py +8 -0
- upgini/features_enricher.py +501 -559
- upgini/metadata.py +2 -1
- upgini/normalizer/normalize_utils.py +1 -1
- upgini/resource_bundle/strings.properties +10 -9
- {upgini-1.2.141.dist-info → upgini-1.2.142a1.dist-info}/METADATA +1 -1
- {upgini-1.2.141.dist-info → upgini-1.2.142a1.dist-info}/RECORD +10 -10
- {upgini-1.2.141.dist-info → upgini-1.2.142a1.dist-info}/WHEEL +0 -0
- {upgini-1.2.141.dist-info → upgini-1.2.142a1.dist-info}/licenses/LICENSE +0 -0
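
The most sweeping change in this release is in upgini/features_enricher.py: the explicit trace_id argument is dropped from internal method signatures and replaced by a _get_trace_id() helper that reuses the correlation id already bound in the MDC logging context, falling back to a millisecond timestamp when none is set. The snippet below is a minimal, self-contained sketch of that pattern; the MDC and get_mdc_fields stand-ins are assumptions for illustration, not upgini's actual upgini.mdc implementation.

import time
import contextvars

# Stand-in for upgini.mdc's context storage: a context-local dict of logging fields (assumption).
_mdc_fields: contextvars.ContextVar[dict] = contextvars.ContextVar("mdc_fields", default={})


def get_mdc_fields() -> dict:
    return _mdc_fields.get()


class MDC:
    """Bind extra fields (e.g. correlation_id) to the logging context for the duration of a with-block."""

    def __init__(self, **fields):
        self.fields = fields

    def __enter__(self):
        self._token = _mdc_fields.set({**_mdc_fields.get(), **self.fields})
        return self

    def __exit__(self, *exc):
        _mdc_fields.reset(self._token)


def get_trace_id():
    # Same fallback logic as the new FeaturesEnricher._get_trace_id: reuse a bound
    # correlation id if one exists, otherwise derive one from the clock.
    existing = get_mdc_fields().get("correlation_id")
    if existing is not None:
        return existing
    return int(time.time() * 1000)


# Usage: the public entry point binds the id once; nested helpers call get_trace_id()
# instead of having trace_id threaded through every signature.
with MDC(correlation_id=get_trace_id()):
    assert get_trace_id() == get_mdc_fields()["correlation_id"]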
upgini/features_enricher.py
CHANGED
@@ -42,6 +42,7 @@ from upgini.http import (
     get_rest_client,
 )
 from upgini.mdc import MDC
+from upgini.mdc.context import get_mdc_fields
 from upgini.metadata import (
     COUNTRY,
     CURRENT_DATE_COL,
@@ -273,7 +274,7 @@ class FeaturesEnricher(TransformerMixin):
         self.X: pd.DataFrame | None = None
         self.y: pd.Series | None = None
         self.eval_set: list[tuple] | None = None
-        self.autodetected_search_keys: dict[str, SearchKey] =
+        self.autodetected_search_keys: dict[str, SearchKey] = dict()
         self.imbalanced = False
         self.fit_select_features = True
         self.__cached_sampled_datasets: dict[str, tuple[pd.DataFrame, pd.DataFrame, pd.Series, dict, dict, dict]] = (
@@ -309,8 +310,8 @@ class FeaturesEnricher(TransformerMixin):
             search_task = SearchTask(search_id, rest_client=self.rest_client, logger=self.logger)

             print(self.bundle.get("search_by_task_id_start"))
-            trace_id =
-            with MDC(correlation_id=trace_id):
+            trace_id = self._get_trace_id()
+            with MDC(correlation_id=trace_id, search_task_id=search_id):
                 try:
                     self.logger.debug(f"FeaturesEnricher created from existing search: {search_id}")
                     self._search_task = search_task.poll_result(trace_id, quiet=True, check_fit=True)
@@ -318,7 +319,7 @@ class FeaturesEnricher(TransformerMixin):
                     x_columns = [c.name for c in file_metadata.columns]
                     self.fit_columns_renaming = {c.name: c.originalName for c in file_metadata.columns}
                     df = pd.DataFrame(columns=x_columns)
-                    self.__prepare_feature_importances(
+                    self.__prepare_feature_importances(df, silent=True, update_selected_features=False)
                     if print_loaded_report:
                         self.__show_selected_features()
                     # TODO validate search_keys with search_keys from file_metadata
@@ -487,7 +488,7 @@ class FeaturesEnricher(TransformerMixin):
         stability_agg_func: str, optional (default="max")
             Function to aggregate stability values. Can be "max", "min", "mean".
         """
-        trace_id =
+        trace_id = self._get_trace_id()
         if self.print_trace_id:
             print(f"https://app.datadoghq.eu/logs?query=%40correlation_id%3A{trace_id}")
         start_time = time.time()
@@ -522,10 +523,9 @@ class FeaturesEnricher(TransformerMixin):
         self.X = X
         self.y = y
         self.eval_set = self._check_eval_set(eval_set, X)
-        self.dump_input(
+        self.dump_input(X, y, self.eval_set)
         self.__set_select_features(select_features)
         self.__inner_fit(
-            trace_id,
             X,
             y,
             self.eval_set,
@@ -646,7 +646,7 @@ class FeaturesEnricher(TransformerMixin):

         self.warning_counter.reset()
         auto_fe_parameters = AutoFEParameters() if auto_fe_parameters is None else auto_fe_parameters
-        trace_id =
+        trace_id = self._get_trace_id()
         if self.print_trace_id:
             print(f"https://app.datadoghq.eu/logs?query=%40correlation_id%3A{trace_id}")
         start_time = time.time()
@@ -677,13 +677,12 @@ class FeaturesEnricher(TransformerMixin):
         self.y = y
         self.eval_set = self._check_eval_set(eval_set, X)
         self.__set_select_features(select_features)
-        self.dump_input(
+        self.dump_input(X, y, self.eval_set)

         if _num_samples(drop_duplicates(X)) > Dataset.MAX_ROWS:
             raise ValidationError(self.bundle.get("dataset_too_many_rows_registered").format(Dataset.MAX_ROWS))

         self.__inner_fit(
-            trace_id,
             X,
             y,
             self.eval_set,
@@ -738,7 +737,6 @@ class FeaturesEnricher(TransformerMixin):
             y,
             exclude_features_sources=exclude_features_sources,
             keep_input=keep_input,
-            trace_id=trace_id,
             silent_mode=True,
             progress_bar=progress_bar,
             progress_callback=progress_callback,
@@ -753,7 +751,6 @@ class FeaturesEnricher(TransformerMixin):
         *args,
         exclude_features_sources: list[str] | None = None,
         keep_input: bool = True,
-        trace_id: str | None = None,
         silent_mode=False,
         progress_bar: ProgressBar | None = None,
         progress_callback: Callable[[SearchProgress], Any] | None = None,
@@ -790,12 +787,12 @@ class FeaturesEnricher(TransformerMixin):
                 progress_bar.progress = search_progress.to_progress_bar()
                 if new_progress:
                     progress_bar.display()
-        trace_id =
+        trace_id = self._get_trace_id()
         if self.print_trace_id:
             print(f"https://app.datadoghq.eu/logs?query=%40correlation_id%3A{trace_id}")
         search_id = self.search_id or (self._search_task.search_task_id if self._search_task is not None else None)
         with MDC(correlation_id=trace_id, search_id=search_id):
-            self.dump_input(
+            self.dump_input(X)
             if len(args) > 0:
                 msg = f"WARNING: Unsupported positional arguments for transform: {args}"
                 self.logger.warning(msg)
@@ -808,7 +805,6 @@ class FeaturesEnricher(TransformerMixin):
             start_time = time.time()
             try:
                 result, _, _, _ = self.__inner_transform(
-                    trace_id,
                     X,
                     y=y,
                     exclude_features_sources=exclude_features_sources,
@@ -872,7 +868,6 @@ class FeaturesEnricher(TransformerMixin):
         estimator=None,
         exclude_features_sources: list[str] | None = None,
         remove_outliers_calc_metrics: bool | None = None,
-        trace_id: str | None = None,
         internal_call: bool = False,
         progress_bar: ProgressBar | None = None,
         progress_callback: Callable[[SearchProgress], Any] | None = None,
@@ -910,7 +905,7 @@ class FeaturesEnricher(TransformerMixin):
             Dataframe with metrics calculated on train and validation datasets.
         """

-        trace_id =
+        trace_id = self._get_trace_id()
         start_time = time.time()
         search_id = self.search_id or (self._search_task.search_task_id if self._search_task is not None else None)
         with MDC(correlation_id=trace_id, search_id=search_id):
@@ -943,7 +938,7 @@ class FeaturesEnricher(TransformerMixin):
                 raise ValidationError(self.bundle.get("metrics_unfitted_enricher"))

             validated_X, validated_y, validated_eval_set = self._validate_train_eval(
-                effective_X, effective_y, effective_eval_set
+                effective_X, effective_y, effective_eval_set
             )

             if self.X is None:
@@ -978,11 +973,13 @@ class FeaturesEnricher(TransformerMixin):
                 self.__display_support_link(msg)
                 return None

+            search_keys = self._get_fit_search_keys_with_original_names()
+
             cat_features_from_backend = self.__get_categorical_features()
             # Convert to original names
             cat_features_from_backend = [self.fit_columns_renaming.get(c, c) for c in cat_features_from_backend]
             client_cat_features, search_keys_for_metrics = self._get_and_validate_client_cat_features(
-                estimator, validated_X,
+                estimator, validated_X, search_keys
             )
             # Exclude id columns from cat_features
             if self.id_columns and self.id_columns_encoder is not None:
@@ -1004,7 +1001,6 @@ class FeaturesEnricher(TransformerMixin):
             self.logger.info(f"Search keys for metrics: {search_keys_for_metrics}")

             prepared_data = self._get_cached_enriched_data(
-                trace_id=trace_id,
                 X=X,
                 y=y,
                 eval_set=eval_set,
@@ -1257,7 +1253,7 @@ class FeaturesEnricher(TransformerMixin):

             if updating_shaps is not None:
                 decoded_X = self._decode_id_columns(fitting_X)
-                self._update_shap_values(
+                self._update_shap_values(decoded_X, updating_shaps, silent=not internal_call)

             metrics_df = pd.DataFrame(metrics)
             mean_target_hdr = self.bundle.get("quality_metrics_mean_target_header")
@@ -1307,9 +1303,33 @@ class FeaturesEnricher(TransformerMixin):
         finally:
             self.logger.info(f"Calculating metrics elapsed time: {time.time() - start_time}")

+    def _get_trace_id(self):
+        if get_mdc_fields().get("correlation_id") is not None:
+            return get_mdc_fields().get("correlation_id")
+        return int(time.time() * 1000)
+
+    def _get_autodetected_search_keys(self):
+        if self.autodetected_search_keys is None and self._search_task is not None:
+            meta = self._search_task.get_file_metadata(self._get_trace_id())
+            self.autodetected_search_keys = {k: SearchKey[v] for k, v in meta.autodetectedSearchKeys.items()}
+
+        return self.autodetected_search_keys
+
+    def _get_fit_search_keys_with_original_names(self):
+        if self.fit_search_keys is None and self._search_task is not None:
+            fit_search_keys = dict()
+            meta = self._search_task.get_file_metadata(self._get_trace_id())
+            for column in meta.columns:
+                # TODO check for EMAIL->HEM and multikeys
+                search_key_type = SearchKey.from_meaning_type(column.meaningType)
+                if search_key_type is not None:
+                    fit_search_keys[column.originalName] = search_key_type
+        else:
+            fit_search_keys = {self.fit_columns_renaming.get(k, k): v for k, v in self.fit_search_keys.items()}
+        return fit_search_keys
+
     def _select_features_by_psi(
         self,
-        trace_id: str,
         X: pd.DataFrame | pd.Series | np.ndarray,
         y: pd.DataFrame | pd.Series | np.ndarray | list,
         eval_set: list[tuple] | tuple | None,
@@ -1322,7 +1342,8 @@ class FeaturesEnricher(TransformerMixin):
         progress_callback: Callable | None = None,
     ):
         search_keys = self.search_keys.copy()
-
+        search_keys.update(self._get_autodetected_search_keys())
+        validated_X, _, validated_eval_set = self._validate_train_eval(X, y, eval_set)
         if isinstance(X, np.ndarray):
             search_keys = {str(k): v for k, v in search_keys.items()}

@@ -1357,7 +1378,6 @@ class FeaturesEnricher(TransformerMixin):
         ]

         prepared_data = self._get_cached_enriched_data(
-            trace_id=trace_id,
             X=X,
             y=y,
             eval_set=eval_set,
@@ -1506,7 +1526,7 @@ class FeaturesEnricher(TransformerMixin):

         return total_unstable_features

-    def _update_shap_values(self,
+    def _update_shap_values(self, df: pd.DataFrame, new_shaps: dict[str, float], silent: bool = False):
         renaming = self.fit_columns_renaming or {}
         self.logger.info(f"Updating SHAP values: {new_shaps}")
         new_shaps = {
@@ -1514,7 +1534,7 @@ class FeaturesEnricher(TransformerMixin):
             for feature, shap in new_shaps.items()
             if feature in self.feature_names_ or renaming.get(feature, feature) in self.feature_names_
         }
-        self.__prepare_feature_importances(
+        self.__prepare_feature_importances(df, new_shaps)

         if not silent and self.features_info_display_handle is not None:
             try:
@@ -1694,7 +1714,6 @@ class FeaturesEnricher(TransformerMixin):

     def _get_cached_enriched_data(
         self,
-        trace_id: str,
         X: pd.DataFrame | pd.Series | np.ndarray | None = None,
         y: pd.DataFrame | pd.Series | np.ndarray | list | None = None,
         eval_set: list[tuple] | tuple | None = None,
@@ -1710,10 +1729,9 @@ class FeaturesEnricher(TransformerMixin):
         is_input_same_as_fit, X, y, eval_set = self._is_input_same_as_fit(X, y, eval_set)
         is_demo_dataset = hash_input(X, y, eval_set) in DEMO_DATASET_HASHES
         checked_eval_set = self._check_eval_set(eval_set, X)
-        validated_X, validated_y, validated_eval_set = self._validate_train_eval(X, y, checked_eval_set
+        validated_X, validated_y, validated_eval_set = self._validate_train_eval(X, y, checked_eval_set)

         sampled_data = self._get_enriched_datasets(
-            trace_id=trace_id,
             validated_X=validated_X,
             validated_y=validated_y,
             eval_set=validated_eval_set,
@@ -1740,7 +1758,7 @@ class FeaturesEnricher(TransformerMixin):

         self.logger.info(f"Excluding search keys: {excluding_search_keys}")

-        file_meta = self._search_task.get_file_metadata(
+        file_meta = self._search_task.get_file_metadata(self._get_trace_id())
         fit_dropped_features = self.fit_dropped_features or file_meta.droppedColumns or []
         original_dropped_features = [columns_renaming.get(f, f) for f in fit_dropped_features]

@@ -1917,7 +1935,6 @@ class FeaturesEnricher(TransformerMixin):

     def _get_enriched_datasets(
         self,
-        trace_id: str,
         validated_X: pd.DataFrame | pd.Series | np.ndarray | None,
         validated_y: pd.DataFrame | pd.Series | np.ndarray | list | None,
         eval_set: list[tuple] | None,
@@ -1945,9 +1962,7 @@ class FeaturesEnricher(TransformerMixin):
             and self.df_with_original_index is not None
         ):
             self.logger.info("Dataset is not imbalanced, so use enriched_X from fit")
-            return self.__get_enriched_from_fit(
-                validated_X, validated_y, eval_set, trace_id, remove_outliers_calc_metrics
-            )
+            return self.__get_enriched_from_fit(validated_X, validated_y, eval_set, remove_outliers_calc_metrics)
         else:
             self.logger.info(
                 "Dataset is imbalanced or exclude_features_sources or X was passed or this is saved search."
@@ -1959,7 +1974,6 @@ class FeaturesEnricher(TransformerMixin):
                 validated_y,
                 eval_set,
                 exclude_features_sources,
-                trace_id,
                 progress_bar,
                 progress_callback,
                 is_for_metrics=is_for_metrics,
@@ -2088,7 +2102,6 @@ class FeaturesEnricher(TransformerMixin):
         validated_X: pd.DataFrame,
         validated_y: pd.Series,
         eval_set: list[tuple] | None,
-        trace_id: str,
         remove_outliers_calc_metrics: bool | None,
     ) -> _EnrichedDataForMetrics:
         eval_set_sampled_dict = {}
@@ -2103,7 +2116,7 @@ class FeaturesEnricher(TransformerMixin):
         if remove_outliers_calc_metrics is None:
             remove_outliers_calc_metrics = True
         if self.model_task_type == ModelTaskType.REGRESSION and remove_outliers_calc_metrics:
-            target_outliers_df = self._search_task.get_target_outliers(
+            target_outliers_df = self._search_task.get_target_outliers(self._get_trace_id())
            if target_outliers_df is not None and len(target_outliers_df) > 0:
                 outliers = pd.merge(
                     self.df_with_original_index,
@@ -2120,7 +2133,7 @@ class FeaturesEnricher(TransformerMixin):

         # index in each dataset (X, eval set) may be reordered and non unique, but index in validated datasets
         # can differs from it
-        fit_features = self._search_task.get_all_initial_raw_features(
+        fit_features = self._search_task.get_all_initial_raw_features(self._get_trace_id(), metrics_calculation=True)

         # Pre-process features if we need to drop outliers
         if rows_to_drop is not None:
@@ -2146,7 +2159,7 @@ class FeaturesEnricher(TransformerMixin):
         validated_Xy[TARGET] = validated_y

         selecting_columns = self._selecting_input_and_generated_columns(
-            validated_Xy, self.fit_generated_features, keep_input=True
+            validated_Xy, self.fit_generated_features, keep_input=True
         )
         selecting_columns.extend(
             c
@@ -2211,7 +2224,6 @@ class FeaturesEnricher(TransformerMixin):
         validated_y: pd.Series,
         eval_set: list[tuple] | None,
         exclude_features_sources: list[str] | None,
-        trace_id: str,
         progress_bar: ProgressBar | None,
         progress_callback: Callable[[SearchProgress], Any] | None,
         is_for_metrics: bool = False,
@@ -2237,7 +2249,6 @@ class FeaturesEnricher(TransformerMixin):

         # Transform
         enriched_df, columns_renaming, generated_features, search_keys = self.__inner_transform(
-            trace_id,
             X=df.drop(columns=[TARGET]),
             y=df[TARGET],
             exclude_features_sources=exclude_features_sources,
@@ -2414,11 +2425,10 @@ class FeaturesEnricher(TransformerMixin):

         return self.features_info

-    def get_progress(self,
+    def get_progress(self, search_task: SearchTask | None = None) -> SearchProgress:
         search_task = search_task or self._search_task
         if search_task is not None:
-
-            return search_task.get_progress(trace_id)
+            return search_task.get_progress(self._get_trace_id())

     def display_transactional_transform_api(self, only_online_sources=False):
         if self.api_key is None:
@@ -2519,7 +2529,6 @@ if response.status_code == 200:

     def __inner_transform(
         self,
-        trace_id: str,
         X: pd.DataFrame,
         *,
         y: pd.Series | None = None,
@@ -2538,182 +2547,135 @@ if response.status_code == 200:
             raise NotFittedError(self.bundle.get("transform_unfitted_enricher"))

         start_time = time.time()
-
-        with MDC(correlation_id=trace_id, search_id=search_id):
-            self.logger.info("Start transform")
+        self.logger.info("Start transform")

-
-                X, y, eval_set=None, is_transform=True, silent=True
-            )
-            df = self.__combine_train_and_eval_sets(validated_X, validated_y, validated_eval_set)
+        search_keys = self.search_keys.copy()

-
+        self.__validate_search_keys(search_keys, self.search_id)

-
+        validated_X, validated_y, validated_eval_set = self._validate_train_eval(
+            X, y, eval_set=None, is_transform=True
+        )
+        df = self.__combine_train_and_eval_sets(validated_X, validated_y, validated_eval_set)

-
-            if len(self.feature_names_) == 0:
-                msg = self.bundle.get("no_important_features_for_transform")
-                self.__log_warning(msg, show_support_link=True)
-                return None, {}, [], self.search_keys
+        validated_Xy = df.copy()

-
+        self.__log_debug_information(validated_X, validated_y, exclude_features_sources=exclude_features_sources)

-
-
-
-
-
+        # If there are no important features, return original dataframe
+        if len(self.feature_names_) == 0:
+            msg = self.bundle.get("no_important_features_for_transform")
+            self.__log_warning(msg, show_support_link=True)
+            return None, {}, [], search_keys

-
-
-
-
-
-            msg = self.bundle.get("online_api_features_transform").format(online_api_features)
-            self.logger.warning(msg)
-            print(msg)
-            self.display_transactional_transform_api(only_online_sources=True)
-
-        if not metrics_calculation:
-            transform_usage = self.rest_client.get_current_transform_usage(trace_id)
-            self.logger.info(f"Current transform usage: {transform_usage}. Transforming {len(X)} rows")
-            if transform_usage.has_limit:
-                if len(X) > transform_usage.rest_rows:
-                    rest_rows = max(transform_usage.rest_rows, 0)
-                    bundle_msg = (
-                        "transform_usage_warning_registered"
-                        if self.__is_registered
-                        else "transform_usage_warning_demo"
-                    )
-                    msg = self.bundle.get(bundle_msg).format(rest_rows, len(X))
-                    self.logger.warning(msg)
-                    print(msg)
-                    show_request_quote_button(is_registered=self.__is_registered)
-                    return None, {}, [], {}
-                else:
-                    msg = self.bundle.get("transform_usage_info").format(
-                        transform_usage.limit, transform_usage.transformed_rows
-                    )
-                    self.logger.info(msg)
-                    print(msg)
+        if self._has_paid_features(exclude_features_sources):
+            msg = self.bundle.get("transform_with_paid_features")
+            self.logger.warning(msg)
+            self.__display_support_link(msg)
+            return None, {}, [], search_keys

-
+        online_api_features = [fm.name for fm in features_meta if fm.from_online_api and fm.shap_value > 0]
+        if len(online_api_features) > 0:
+            self.logger.warning(
+                f"There are important features for transform, that generated by online API: {online_api_features}"
+            )
+            msg = self.bundle.get("online_api_features_transform").format(online_api_features)
+            self.logger.warning(msg)
+            print(msg)
+            self.display_transactional_transform_api(only_online_sources=True)
+
+        if not metrics_calculation:
+            transform_usage = self.rest_client.get_current_transform_usage(self._get_trace_id())
+            self.logger.info(f"Current transform usage: {transform_usage}. Transforming {len(X)} rows")
+            if transform_usage.has_limit:
+                if len(X) > transform_usage.rest_rows:
+                    rest_rows = max(transform_usage.rest_rows, 0)
+                    bundle_msg = (
+                        "transform_usage_warning_registered" if self.__is_registered else "transform_usage_warning_demo"
+                    )
+                    msg = self.bundle.get(bundle_msg).format(rest_rows, len(X))
+                    self.logger.warning(msg)
+                    print(msg)
+                    show_request_quote_button(is_registered=self.__is_registered)
+                    return None, {}, [], {}
+                else:
+                    msg = self.bundle.get("transform_usage_info").format(
+                        transform_usage.limit, transform_usage.transformed_rows
+                    )
+                    self.logger.info(msg)
+                    print(msg)

-
-            c for c in df.columns if c in self.feature_names_ and c in self.external_source_feature_names
-        ]
-        if len(columns_to_drop) > 0:
-            msg = self.bundle.get("x_contains_enriching_columns").format(columns_to_drop)
-            self.logger.warning(msg)
-            print(msg)
-            df = df.drop(columns=columns_to_drop)
+        is_demo_dataset = hash_input(df) in DEMO_DATASET_HASHES

-
-
-
-
-
+        columns_to_drop = [
+            c for c in df.columns if c in self.feature_names_ and c in self.external_source_feature_names
+        ]
+        if len(columns_to_drop) > 0:
+            msg = self.bundle.get("x_contains_enriching_columns").format(columns_to_drop)
+            self.logger.warning(msg)
+            print(msg)
+            df = df.drop(columns=columns_to_drop)

-
-
-        )
+        if self.id_columns is not None and self.cv is not None and self.cv.is_time_series():
+            search_keys.update({col: SearchKey.CUSTOM_KEY for col in self.id_columns if col not in search_keys})

-
+        search_keys = self.__prepare_search_keys(
+            df, search_keys, is_demo_dataset, is_transform=True, silent_mode=silent_mode
+        )

-
-            msg = self.bundle.get("unsupported_index_column")
-            self.logger.info(msg)
-            print(msg)
-            df.drop(columns=DEFAULT_INDEX, inplace=True)
-            validated_Xy.drop(columns=DEFAULT_INDEX, inplace=True)
+        df = self.__handle_index_search_keys(df, search_keys)

-
+        if DEFAULT_INDEX in df.columns:
+            msg = self.bundle.get("unsupported_index_column")
+            self.logger.info(msg)
+            print(msg)
+            df.drop(columns=DEFAULT_INDEX, inplace=True)
+            validated_Xy.drop(columns=DEFAULT_INDEX, inplace=True)

-
-        date_column = self._get_date_column(search_keys)
-        if date_column is not None:
-            converter = DateTimeConverter(
-                date_column,
-                self.date_format,
-                self.logger,
-                bundle=self.bundle,
-                generate_cyclical_features=self.generate_search_key_features,
-            )
-            df = converter.convert(df, keep_time=True)
-            self.logger.info(f"Date column after convertion: {df[date_column]}")
-            generated_features.extend(converter.generated_features)
-        else:
-            self.logger.info("Input dataset hasn't date column")
-            if self.__should_add_date_column():
-                df = self._add_current_date_as_key(df, search_keys, self.bundle, silent=True)
-
-        email_columns = SearchKey.find_all_keys(search_keys, SearchKey.EMAIL)
-        if email_columns and self.generate_search_key_features:
-            generator = EmailDomainGenerator(email_columns)
-            df = generator.generate(df)
-            generated_features.extend(generator.generated_features)
-
-        normalizer = Normalizer(self.bundle, self.logger)
-        df, search_keys, generated_features = normalizer.normalize(df, search_keys, generated_features)
-        columns_renaming = normalizer.columns_renaming
-
-        # If there are no external features, we don't call backend on transform
-        external_features = [fm for fm in features_meta if fm.shap_value > 0 and fm.source != "etalon"]
-        if len(external_features) == 0:
-            self.logger.warning(
-                "No external features found, returning original dataframe"
-                f" with generated important features: {self.feature_names_}"
-            )
-            df = df.rename(columns=columns_renaming)
-            generated_features = [columns_renaming.get(c, c) for c in generated_features]
-            search_keys = {columns_renaming.get(c, c): t for c, t in search_keys.items()}
-            selecting_columns = self._selecting_input_and_generated_columns(
-                validated_Xy, generated_features, keep_input, trace_id, is_transform=True
-            )
-            self.logger.warning(f"Filtered columns by existance in dataframe: {selecting_columns}")
-            if add_fit_system_record_id:
-                df = self._add_fit_system_record_id(
-                    df,
-                    search_keys,
-                    SYSTEM_RECORD_ID,
-                    TARGET,
-                    columns_renaming,
-                    self.id_columns,
-                    self.cv,
-                    self.model_task_type,
-                    self.logger,
-                    self.bundle,
-                )
-                selecting_columns.append(SYSTEM_RECORD_ID)
-            return df[selecting_columns], columns_renaming, generated_features, search_keys
-
-        # Don't pass all features in backend on transform
-        runtime_parameters = self._get_copy_of_runtime_parameters()
-        features_for_transform = self._search_task.get_features_for_transform()
-        if features_for_transform:
-            missing_features_for_transform = [
-                columns_renaming.get(f) or f for f in features_for_transform if f not in df.columns
-            ]
-            if TARGET in missing_features_for_transform:
-                raise ValidationError(self.bundle.get("missing_target_for_transform"))
+        df = self.__add_country_code(df, search_keys)

-
-
-
-
-
-
-
-
+        generated_features = []
+        date_column = self._get_date_column(search_keys)
+        if date_column is not None:
+            converter = DateTimeConverter(
+                date_column,
+                self.date_format,
+                self.logger,
+                bundle=self.bundle,
+                generate_cyclical_features=self.generate_search_key_features,
+            )
+            df = converter.convert(df, keep_time=True)
+            self.logger.info(f"Date column after convertion: {df[date_column]}")
+            generated_features.extend(converter.generated_features)
+        else:
+            self.logger.info("Input dataset hasn't date column")
+            if self.__should_add_date_column():
+                df = self._add_current_date_as_key(df, search_keys, self.bundle, silent=True)

-
+        email_columns = SearchKey.find_all_keys(search_keys, SearchKey.EMAIL)
+        if email_columns and self.generate_search_key_features:
+            generator = EmailDomainGenerator(email_columns)
+            df = generator.generate(df)
+            generated_features.extend(generator.generated_features)

-
-
-
+        normalizer = Normalizer(self.bundle, self.logger)
+        df, search_keys, generated_features = normalizer.normalize(df, search_keys, generated_features)
+        columns_renaming = normalizer.columns_renaming

-
+        # If there are no external features, we don't call backend on transform
+        external_features = [fm for fm in features_meta if fm.shap_value > 0 and fm.source != "etalon"]
+        if len(external_features) == 0:
+            self.logger.warning(
+                "No external features found, returning original dataframe"
+                f" with generated important features: {self.feature_names_}"
+            )
+            df = df.rename(columns=columns_renaming)
+            generated_features = [columns_renaming.get(c, c) for c in generated_features]
+            search_keys = {columns_renaming.get(c, c): t for c, t in search_keys.items()}
+            selecting_columns = self._selecting_input_and_generated_columns(
+                validated_Xy, generated_features, keep_input, is_transform=True
+            )
+            self.logger.warning(f"Filtered columns by existance in dataframe: {selecting_columns}")
             if add_fit_system_record_id:
                 df = self._add_fit_system_record_id(
                     df,
@@ -2727,252 +2689,294 @@ if response.status_code == 200:
                     self.logger,
                     self.bundle,
                 )
-
-
+                selecting_columns.append(SYSTEM_RECORD_ID)
+            return df[selecting_columns], columns_renaming, generated_features, search_keys

-
-
-
+        # Don't pass all features in backend on transform
+        runtime_parameters = self._get_copy_of_runtime_parameters()
+        features_for_transform = self._search_task.get_features_for_transform()
+        if features_for_transform:
+            missing_features_for_transform = [
+                columns_renaming.get(f) or f for f in features_for_transform if f not in df.columns
+            ]
+            if TARGET in missing_features_for_transform:
+                raise ValidationError(self.bundle.get("missing_target_for_transform"))

-
+            if len(missing_features_for_transform) > 0:
+                raise ValidationError(
+                    self.bundle.get("missing_features_for_transform").format(missing_features_for_transform)
+                )
+        features_for_embeddings = self._search_task.get_features_for_embeddings()
+        if features_for_embeddings:
+            runtime_parameters.properties["features_for_embeddings"] = ",".join(features_for_embeddings)
+        features_for_transform = [f for f in features_for_transform if f not in search_keys.keys()]

-
-        df, unnest_search_keys = self._explode_multiple_search_keys(df, search_keys, columns_renaming)
+        columns_for_system_record_id = sorted(list(search_keys.keys()) + features_for_transform)

-
+        df[ENTITY_SYSTEM_RECORD_ID] = pd.util.hash_pandas_object(df[columns_for_system_record_id], index=False).astype(
+            "float64"
+        )

-
-
-
-
-
-
-
-
-
-
-
-
+        features_not_to_pass = []
+        if add_fit_system_record_id:
+            df = self._add_fit_system_record_id(
+                df,
+                search_keys,
+                SYSTEM_RECORD_ID,
+                TARGET,
+                columns_renaming,
+                self.id_columns,
+                self.cv,
+                self.model_task_type,
+                self.logger,
+                self.bundle,
+            )
+            df = df.rename(columns={SYSTEM_RECORD_ID: SORT_ID})
+            features_not_to_pass.append(SORT_ID)

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                col: FileColumnMeaningType.FEATURE
-                for col in features_for_transform
-                if col not in date_features and col not in generated_features
-            }
+        system_columns_with_original_index = [ENTITY_SYSTEM_RECORD_ID] + generated_features
+        if add_fit_system_record_id:
+            system_columns_with_original_index.append(SORT_ID)
+
+        df_before_explode = df[system_columns_with_original_index].copy()
+
+        # Explode multiple search keys
+        df, unnest_search_keys = self._explode_multiple_search_keys(df, search_keys, columns_renaming)
+
+        # Convert search keys and generate features on them
+
+        email_column = self._get_email_column(search_keys)
+        hem_column = self._get_hem_column(search_keys)
+        if email_column:
+            converter = EmailSearchKeyConverter(
+                email_column,
+                hem_column,
+                search_keys,
+                columns_renaming,
+                list(unnest_search_keys.keys()),
+                self.logger,
             )
-
-        meaning_types.update({col: FileColumnMeaningType.DATE_FEATURE for col in date_features})
-        meaning_types.update({col: key.value for col, key in search_keys.items()})
+            df = converter.convert(df)

-
-
-
-
-
-
-
-
+        ip_column = self._get_ip_column(search_keys)
+        if ip_column:
+            converter = IpSearchKeyConverter(
+                ip_column,
+                search_keys,
+                columns_renaming,
+                list(unnest_search_keys.keys()),
+                self.bundle,
+                self.logger,
             )
+            df = converter.convert(df)

-
-
+        date_features = []
+        for col in features_for_transform:
+            if DateTimeConverter(col).is_datetime(df):
+                df[col] = DateTimeConverter(col).to_date_string(df)
+                date_features.append(col)

-
-
-
-
-
-
-
-
-
+        meaning_types = {}
+        meaning_types.update(
+            {
+                col: FileColumnMeaningType.FEATURE
+                for col in features_for_transform
+                if col not in date_features and col not in generated_features
+            }
+        )
+        meaning_types.update({col: FileColumnMeaningType.GENERATED_FEATURE for col in generated_features})
+        meaning_types.update({col: FileColumnMeaningType.DATE_FEATURE for col in date_features})
+        meaning_types.update({col: key.value for col, key in search_keys.items()})

-
+        features_not_to_pass.extend(
+            [
+                c
+                for c in df.columns
+                if c not in search_keys.keys()
+                and c not in features_for_transform
+                and c not in [ENTITY_SYSTEM_RECORD_ID, SEARCH_KEY_UNNEST]
+            ]
+        )

-
+        if DateTimeConverter.DATETIME_COL in df.columns:
+            df = df.drop(columns=DateTimeConverter.DATETIME_COL)

-
+        # search keys might be changed after explode
+        columns_for_system_record_id = sorted(list(search_keys.keys()) + features_for_transform)
+        df[SYSTEM_RECORD_ID] = pd.util.hash_pandas_object(df[columns_for_system_record_id], index=False).astype(
+            "float64"
+        )
+        meaning_types[SYSTEM_RECORD_ID] = FileColumnMeaningType.SYSTEM_RECORD_ID
+        meaning_types[ENTITY_SYSTEM_RECORD_ID] = FileColumnMeaningType.ENTITY_SYSTEM_RECORD_ID
+        if SEARCH_KEY_UNNEST in df.columns:
+            meaning_types[SEARCH_KEY_UNNEST] = FileColumnMeaningType.UNNEST_KEY

-
-            df_without_features, is_transform=True, logger=self.logger, bundle=self.bundle
-        )
-        if not silent_mode and full_duplicates_warning:
-            self.__log_warning(full_duplicates_warning)
+        df = df.reset_index(drop=True)

-
-        gc.collect()
-
-        dataset = Dataset(
-            "sample_" + str(uuid.uuid4()),
-            df=df_without_features,
-            meaning_types=meaning_types,
-            search_keys=combined_search_keys,
-            unnest_search_keys=unnest_search_keys,
-            id_columns=self.__get_renamed_id_columns(columns_renaming),
-            date_column=self._get_date_column(search_keys),
-            date_format=self.date_format,
-            sample_config=self.sample_config,
-            rest_client=self.rest_client,
-            logger=self.logger,
-            bundle=self.bundle,
-            warning_callback=self.__log_warning,
-        )
-        dataset.columns_renaming = columns_renaming
-
-        validation_task = self._search_task.validation(
-            trace_id,
-            dataset,
-            start_time=start_time,
-            extract_features=True,
-            runtime_parameters=runtime_parameters,
-            exclude_features_sources=exclude_features_sources,
-            metrics_calculation=metrics_calculation,
-            silent_mode=silent_mode,
-            progress_bar=progress_bar,
-            progress_callback=progress_callback,
-        )
+        combined_search_keys = combine_search_keys(search_keys.keys())

-
-        gc.collect()
+        df_without_features = df.drop(columns=features_not_to_pass, errors="ignore")

-
-
-
-
+        df_without_features, full_duplicates_warning = clean_full_duplicates(
+            df_without_features, is_transform=True, logger=self.logger, bundle=self.bundle
+        )
+        if not silent_mode and full_duplicates_warning:
+            self.__log_warning(full_duplicates_warning)

-
-
-        if progress_bar is not None:
-            progress_bar.progress = progress.to_progress_bar()
-        if progress_callback is not None:
-            progress_callback(progress)
-        prev_progress: SearchProgress | None = None
-        polling_period_seconds = 1
-        try:
-            while progress.stage != ProgressStage.DOWNLOADING.value:
-                if prev_progress is None or prev_progress.percent != progress.percent:
-                    progress.recalculate_eta(time.time() - start_time)
-                else:
-                    progress.update_eta(prev_progress.eta - polling_period_seconds)
-                prev_progress = progress
-                if progress_bar is not None:
-                    progress_bar.progress = progress.to_progress_bar()
-                if progress_callback is not None:
-                    progress_callback(progress)
-                if progress.stage == ProgressStage.FAILED.value:
-                    raise Exception(progress.error_message)
-                time.sleep(polling_period_seconds)
-                progress = self.get_progress(trace_id, validation_task)
-        except KeyboardInterrupt as e:
-            print(self.bundle.get("search_stopping"))
-            self.rest_client.stop_search_task_v2(trace_id, validation_task.search_task_id)
-            self.logger.warning(f"Search {validation_task.search_task_id} stopped by user")
-            print(self.bundle.get("search_stopped"))
-            raise e
+        del df
+        gc.collect()

-
+        dataset = Dataset(
+            "sample_" + str(uuid.uuid4()),
+            df=df_without_features,
+            meaning_types=meaning_types,
+            search_keys=combined_search_keys,
+            unnest_search_keys=unnest_search_keys,
+            id_columns=self.__get_renamed_id_columns(columns_renaming),
+            date_column=self._get_date_column(search_keys),
+            date_format=self.date_format,
+            sample_config=self.sample_config,
+            rest_client=self.rest_client,
+            logger=self.logger,
+            bundle=self.bundle,
+            warning_callback=self.__log_warning,
+        )
+        dataset.columns_renaming = columns_renaming

-
-
-
-
-
-
+        validation_task = self._search_task.validation(
+            self._get_trace_id(),
+            dataset,
+            start_time=start_time,
+            extract_features=True,
+            runtime_parameters=runtime_parameters,
+            exclude_features_sources=exclude_features_sources,
+            metrics_calculation=metrics_calculation,
+            silent_mode=silent_mode,
+            progress_bar=progress_bar,
+            progress_callback=progress_callback,
+        )

-
-
+        del df_without_features, dataset
+        gc.collect()

-
-
-
-
-            [
-                validated_Xy.reset_index(drop=True),
-                df_before_explode.reset_index(drop=True),
-            ],
-            axis=1,
-        ).set_index(validated_Xy.index)
-
-        result_features = validation_task.get_all_validation_raw_features(trace_id, metrics_calculation)
-
-        result = self.__enrich(
-            combined_df,
-            result_features,
-            how="left",
-        )
+        if not silent_mode:
+            print(self.bundle.get("polling_transform_task").format(validation_task.search_task_id))
+            if not self.__is_registered:
+                print(self.bundle.get("polling_unregister_information"))

-
-
-
-
-
-
-
-
-
-
+        progress = self.get_progress(validation_task)
+        progress.recalculate_eta(time.time() - start_time)
+        if progress_bar is not None:
+            progress_bar.progress = progress.to_progress_bar()
+        if progress_callback is not None:
+            progress_callback(progress)
+        prev_progress: SearchProgress | None = None
+        polling_period_seconds = 1
+        try:
+            while progress.stage != ProgressStage.DOWNLOADING.value:
+                if prev_progress is None or prev_progress.percent != progress.percent:
+                    progress.recalculate_eta(time.time() - start_time)
+                else:
+                    progress.update_eta(prev_progress.eta - polling_period_seconds)
+                prev_progress = progress
+                if progress_bar is not None:
+                    progress_bar.progress = progress.to_progress_bar()
+                if progress_callback is not None:
+                    progress_callback(progress)
+                if progress.stage == ProgressStage.FAILED.value:
+                    raise Exception(progress.error_message)
+                time.sleep(polling_period_seconds)
+                progress = self.get_progress(validation_task)
+        except KeyboardInterrupt as e:
+            print(self.bundle.get("search_stopping"))
+            self.rest_client.stop_search_task_v2(self._get_trace_id(), validation_task.search_task_id)
+            self.logger.warning(f"Search {validation_task.search_task_id} stopped by user")
+            print(self.bundle.get("search_stopped"))
+            raise e
+
+        validation_task.poll_result(self._get_trace_id(), quiet=True)
+
+        seconds_left = time.time() - start_time
+        progress = SearchProgress(97.0, ProgressStage.DOWNLOADING, seconds_left)
+        if progress_bar is not None:
+            progress_bar.progress = progress.to_progress_bar()
+        if progress_callback is not None:
+            progress_callback(progress)

-
-
-        sorted_selecting_columns = [c for c in validated_Xy.columns if c in selecting_columns]
-        for c in generated_features:
-            if c in selecting_columns and c not in sorted_selecting_columns:
-                sorted_selecting_columns.append(c)
-        for c in result.columns:
-            if c in selecting_columns and c not in sorted_selecting_columns:
-                sorted_selecting_columns.append(c)
+        if not silent_mode:
+            print(self.bundle.get("transform_start"))

-
+        # Prepare input DataFrame for __enrich by concatenating generated ids and client features
+        df_before_explode = df_before_explode.rename(columns=columns_renaming)
+        generated_features = [columns_renaming.get(c, c) for c in generated_features]
+        combined_df = pd.concat(
+            [
+                validated_Xy.reset_index(drop=True),
+                df_before_explode.reset_index(drop=True),
+            ],
+            axis=1,
+        ).set_index(validated_Xy.index)
+
+        result_features = validation_task.get_all_validation_raw_features(self._get_trace_id(), metrics_calculation)
+
+        result = self.__enrich(
+            combined_df,
+            result_features,
+            how="left",
+        )
+
+        selecting_columns = self._selecting_input_and_generated_columns(
+            validated_Xy, generated_features, keep_input, is_transform=True
+        )
+        selecting_columns.extend(
+            c
+            for c in result.columns
+            if c in self.feature_names_ and c not in selecting_columns and c not in validated_Xy.columns
+        )
+        if add_fit_system_record_id:
+            selecting_columns.append(SORT_ID)

-
+        selecting_columns = list(set(selecting_columns))
+        # sorting: first columns from X, then generated features, then enriched features
+        sorted_selecting_columns = [c for c in validated_Xy.columns if c in selecting_columns]
+        for c in generated_features:
+            if c in selecting_columns and c not in sorted_selecting_columns:
+                sorted_selecting_columns.append(c)
+        for c in result.columns:
+            if c in selecting_columns and c not in sorted_selecting_columns:
+                sorted_selecting_columns.append(c)

-
-        result = result.drop(columns=COUNTRY, errors="ignore")
+        self.logger.info(f"Transform sorted_selecting_columns: {sorted_selecting_columns}")

-
-
+        result = result[sorted_selecting_columns]
+
+        if self.country_added:
+            result = result.drop(columns=COUNTRY, errors="ignore")

-
-
-            result.loc[:, c] = np.where(~result[c].isin(result[c].dtype.categories), np.nan, result[c])
+        if add_fit_system_record_id:
+            result = result.rename(columns={SORT_ID: SYSTEM_RECORD_ID})

-
+        for c in result.columns:
+            if result[c].dtype == "category":
+                result.loc[:, c] = np.where(~result[c].isin(result[c].dtype.categories), np.nan, result[c])
+
+        return result, columns_renaming, generated_features, search_keys

     def _selecting_input_and_generated_columns(
         self,
         validated_Xy: pd.DataFrame,
         generated_features: list[str],
         keep_input: bool,
-        trace_id: str,
         is_transform: bool = False,
     ):
-        file_meta = self._search_task.get_file_metadata(
+        file_meta = self._search_task.get_file_metadata(self._get_trace_id())
         fit_dropped_features = self.fit_dropped_features or file_meta.droppedColumns or []
         fit_input_columns = [c.originalName for c in file_meta.columns]
         original_dropped_features = [self.fit_columns_renaming.get(c, c) for c in fit_dropped_features]
         new_columns_on_transform = [
             c for c in validated_Xy.columns if c not in fit_input_columns and c not in original_dropped_features
         ]
+        fit_original_search_keys = self._get_fit_search_keys_with_original_names()

         selected_generated_features = [c for c in generated_features if c in self.feature_names_]
         if keep_input is True:
@@ -2982,7 +2986,7 @@ if response.status_code == 200:
                 if not self.fit_select_features
                 or c in self.feature_names_
                 or (c in new_columns_on_transform and is_transform)
-                or c in
+                or c in fit_original_search_keys
                 or c in (self.id_columns or [])
                 or c in [EVAL_SET_INDEX, TARGET]  # transform for metrics calculation
                 or c == self.baseline_score_column
|
|
|
3067
3071
|
|
|
3068
3072
|
def __inner_fit(
|
|
3069
3073
|
self,
|
|
3070
|
-
trace_id: str,
|
|
3071
3074
|
X: pd.DataFrame | pd.Series | np.ndarray,
|
|
3072
3075
|
y: pd.DataFrame | pd.Series | np.ndarray | list | None,
|
|
3073
3076
|
eval_set: list[tuple] | None,
|
|
@@ -3149,6 +3152,8 @@ if response.status_code == 200:
|
|
|
3149
3152
|
df = self.__handle_index_search_keys(df, self.fit_search_keys)
|
|
3150
3153
|
self.fit_search_keys = self.__prepare_search_keys(df, self.fit_search_keys, is_demo_dataset)
|
|
3151
3154
|
|
|
3155
|
+
df = self._validate_OOT(df, self.fit_search_keys)
|
|
3156
|
+
|
|
3152
3157
|
maybe_date_column = SearchKey.find_key(self.fit_search_keys, [SearchKey.DATE, SearchKey.DATETIME])
|
|
3153
3158
|
has_date = maybe_date_column is not None and maybe_date_column in validated_X.columns
|
|
3154
3159
|
|
|
@@ -3400,6 +3405,7 @@ if response.status_code == 200:
|
|
|
3400
3405
|
id_columns=self.__get_renamed_id_columns(),
|
|
3401
3406
|
is_imbalanced=self.imbalanced,
|
|
3402
3407
|
dropped_columns=[self.fit_columns_renaming.get(f, f) for f in self.fit_dropped_features],
|
|
3408
|
+
autodetected_search_keys=self.autodetected_search_keys,
|
|
3403
3409
|
date_column=self._get_date_column(self.fit_search_keys),
|
|
3404
3410
|
date_format=self.date_format,
|
|
3405
3411
|
random_state=self.random_state,
|
|
@@ -3423,7 +3429,7 @@ if response.status_code == 200:
         ]

         self._search_task = dataset.search(
-            trace_id=
+            trace_id=self._get_trace_id(),
             progress_bar=progress_bar,
             start_time=start_time,
             progress_callback=progress_callback,
@@ -3443,7 +3449,7 @@ if response.status_code == 200:
         if not self.__is_registered:
             print(self.bundle.get("polling_unregister_information"))

-        progress = self.get_progress(
+        progress = self.get_progress()
         prev_progress = None
         progress.recalculate_eta(time.time() - start_time)
         if progress_bar is not None:
@@ -3469,16 +3475,16 @@ if response.status_code == 200:
                     )
                     raise RuntimeError(self.bundle.get("search_task_failed_status"))
                 time.sleep(poll_period_seconds)
-                progress = self.get_progress(
+                progress = self.get_progress()
         except KeyboardInterrupt as e:
             print(self.bundle.get("search_stopping"))
-            self.rest_client.stop_search_task_v2(
+            self.rest_client.stop_search_task_v2(self._get_trace_id(), self._search_task.search_task_id)
             self.logger.warning(f"Search {self._search_task.search_task_id} stopped by user")
             self._search_task = None
             print(self.bundle.get("search_stopped"))
             raise e

-        self._search_task.poll_result(
+        self._search_task.poll_result(self._get_trace_id(), quiet=True)

         seconds_left = time.time() - start_time
         progress = SearchProgress(97.0, ProgressStage.GENERATING_REPORT, seconds_left)
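
The polling loop above now calls get_progress() with no arguments and reports failures via RuntimeError. For observing progress from calling code, the signatures in this file thread a progress_callback through to the search; a hedged usage sketch, assuming fit() forwards such a callback as those signatures suggest (the callback name and the commented-out call are illustrative, not a documented API):

    def on_progress(progress):
        # A SearchProgress instance is passed on each update; printing the whole
        # object avoids assuming its attribute names.
        print(progress)

    # enricher.fit(X, y, progress_callback=on_progress)  # hypothetical call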
@@ -3507,10 +3513,9 @@ if response.status_code == 200:
             msg = self.bundle.get("features_not_generated").format(unused_features_for_generation)
             self.__log_warning(msg)

-        self.__prepare_feature_importances(
+        self.__prepare_feature_importances(df)

         self._select_features_by_psi(
-            trace_id=trace_id,
             X=X,
             y=y,
             eval_set=eval_set,
@@ -3523,7 +3528,7 @@ if response.status_code == 200:
             progress_callback=progress_callback,
         )

-        self.__prepare_feature_importances(
+        self.__prepare_feature_importances(df)

         self.__show_selected_features()

@@ -3558,7 +3563,6 @@ if response.status_code == 200:
             scoring,
             estimator,
             remove_outliers_calc_metrics,
-            trace_id,
             progress_bar,
             progress_callback,
         )
@@ -3653,11 +3657,10 @@ if response.status_code == 200:
         y: pd.Series | None = None,
         eval_set: list[tuple[pd.DataFrame, pd.Series]] | None = None,
         is_transform: bool = False,
-        silent: bool = False,
     ) -> tuple[pd.DataFrame, pd.Series, list[tuple[pd.DataFrame, pd.Series]]] | None:
         validated_X = self._validate_X(X, is_transform)
         validated_y = self._validate_y(validated_X, y, enforce_y=not is_transform)
-        validated_eval_set = self._validate_eval_set(validated_X, eval_set
+        validated_eval_set = self._validate_eval_set(validated_X, eval_set)
         return validated_X, validated_y, validated_eval_set

     def _encode_id_columns(
@@ -3783,31 +3786,41 @@ if response.status_code == 200:
         return validated_y

     def _validate_eval_set(
-        self,
-
+        self,
+        X: pd.DataFrame,
+        eval_set: list[tuple[pd.DataFrame, pd.Series]] | None,
+    ) -> list[tuple[pd.DataFrame, pd.Series]] | None:
         if eval_set is None:
             return None
         validated_eval_set = []
-
-        has_date = date_col is not None and date_col in X.columns
-        for idx, eval_pair in enumerate(eval_set):
+        for _, eval_pair in enumerate(eval_set):
             validated_pair = self._validate_eval_set_pair(X, eval_pair)
-            if validated_pair[1].isna().all():
-                if not has_date:
-                    msg = self.bundle.get("oot_without_date_not_supported").format(idx + 1)
-                elif self.columns_for_online_api:
-                    msg = self.bundle.get("oot_with_online_sources_not_supported").format(idx + 1)
-                else:
-                    msg = None
-                if msg:
-                    if not silent:
-                        print(msg)
-                    self.logger.warning(msg)
-                    continue
             validated_eval_set.append(validated_pair)

         return validated_eval_set

+    def _validate_OOT(self, df: pd.DataFrame, search_keys: dict[str, SearchKey]) -> pd.DataFrame:
+        if EVAL_SET_INDEX not in df.columns:
+            return df
+
+        for eval_set_index in df[EVAL_SET_INDEX].unique():
+            if eval_set_index == 0:
+                continue
+            eval_df = df[df[EVAL_SET_INDEX] == eval_set_index]
+            date_col = self._get_date_column(search_keys)
+            has_date = date_col is not None and date_col in eval_df.columns
+            if eval_df[TARGET].isna().all():
+                msg = None
+                if not has_date:
+                    msg = self.bundle.get("oot_without_date_not_supported").format(eval_set_index)
+                elif self.columns_for_online_api:
+                    msg = self.bundle.get("oot_with_online_sources_not_supported").format(eval_set_index)
+                if msg:
+                    print(msg)
+                    self.logger.warning(msg)
+                    df = df[df[EVAL_SET_INDEX] != eval_set_index]
+        return df
+
     def _validate_eval_set_pair(self, X: pd.DataFrame, eval_pair: tuple) -> tuple[pd.DataFrame, pd.Series]:
         if len(eval_pair) != 2:
             raise ValidationError(self.bundle.get("eval_set_invalid_tuple_size").format(len(eval_pair)))
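
Taken together, the new _validate_OOT treats an eval fold whose target is entirely missing as an out-of-time (OOT) holdout: the fold is kept only when a date/datetime search key is present and no online-API columns are configured, otherwise it is dropped with a warning. A standalone sketch of the same rule (the helper and its column names are invented for illustration, not part of the package):

    import pandas as pd

    def drop_unsupported_oot(df: pd.DataFrame, eval_col: str, target_col: str,
                             has_date: bool, uses_online_sources: bool) -> pd.DataFrame:
        # Mirrors the rule above: fold 0 is the train part; any other fold whose target
        # is all-NaN is an OOT fold and survives only with a date key and no online sources.
        for fold in df[eval_col].unique():
            if fold == 0:
                continue
            part = df[df[eval_col] == fold]
            if part[target_col].isna().all() and (not has_date or uses_online_sources):
                df = df[df[eval_col] != fold]
        return df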
@@ -4423,47 +4436,6 @@ if response.status_code == 200:

         return result_features

-    def __get_features_importance_from_server(self, trace_id: str, df: pd.DataFrame):
-        if self._search_task is None:
-            raise NotFittedError(self.bundle.get("transform_unfitted_enricher"))
-        features_meta = self._search_task.get_all_features_metadata_v2()
-        if features_meta is None:
-            raise Exception(self.bundle.get("missing_features_meta"))
-        features_meta = deepcopy(features_meta)
-
-        original_names_dict = {c.name: c.originalName for c in self._search_task.get_file_metadata(trace_id).columns}
-        df = df.rename(columns=original_names_dict)
-
-        features_meta.sort(key=lambda m: (-m.shap_value, m.name))
-
-        importances = {}
-
-        for feature_meta in features_meta:
-            if feature_meta.name in original_names_dict.keys():
-                feature_meta.name = original_names_dict[feature_meta.name]
-
-            is_client_feature = feature_meta.name in df.columns
-
-            if feature_meta.shap_value == 0.0:
-                continue
-
-            # Use only important features
-            if (
-                feature_meta.name == COUNTRY
-                # In select_features mode we select also from etalon features and need to show them
-                or (not self.fit_select_features and is_client_feature)
-            ):
-                continue
-
-            # Temporary workaround for duplicate features metadata
-            if feature_meta.name in importances:
-                self.logger.warning(f"WARNING: Duplicate feature metadata: {feature_meta}")
-                continue
-
-            importances[feature_meta.name] = feature_meta.shap_value
-
-        return importances
-
     def __get_categorical_features(self) -> list[str]:
         features_meta = self._search_task.get_all_features_metadata_v2()
         if features_meta is None:
@@ -4473,7 +4445,6 @@ if response.status_code == 200:

     def __prepare_feature_importances(
         self,
-        trace_id: str,
         clients_features_df: pd.DataFrame,
         updated_shaps: dict[str, float] | None = None,
         update_selected_features: bool = True,
@@ -4481,16 +4452,16 @@ if response.status_code == 200:
     ):
         if self._search_task is None:
             raise NotFittedError(self.bundle.get("transform_unfitted_enricher"))
-        selected_features = self._search_task.get_selected_features(
+        selected_features = self._search_task.get_selected_features(self._get_trace_id())
         features_meta = self._search_task.get_all_features_metadata_v2()
         if features_meta is None:
             raise Exception(self.bundle.get("missing_features_meta"))
         features_meta = deepcopy(features_meta)

-        file_metadata_columns = self._search_task.get_file_metadata(
+        file_metadata_columns = self._search_task.get_file_metadata(self._get_trace_id()).columns
         file_meta_by_orig_name = {c.originalName: c for c in file_metadata_columns}
         original_names_dict = {c.name: c.originalName for c in file_metadata_columns}
-        features_df = self._search_task.get_all_initial_raw_features(
+        features_df = self._search_task.get_all_initial_raw_features(self._get_trace_id(), metrics_calculation=True)

         # To be sure that names with hash suffixes
         clients_features_df = clients_features_df.rename(columns=original_names_dict)
@@ -4581,7 +4552,7 @@ if response.status_code == 200:
             internal_features_info.append(feature_info.to_internal_row(self.bundle))

         if update_selected_features:
-            self._search_task.update_selected_features(
+            self._search_task.update_selected_features(self._get_trace_id(), self.feature_names_)

         if len(features_info) > 0:
             self.features_info = pd.DataFrame(features_info)
@@ -4779,12 +4750,17 @@ if response.status_code == 200:
         ):
             raise ValidationError(self.bundle.get("empty_search_key").format(column_name))

-        if
-
-
-
-
-
+        if is_transform:
+            fit_autodetected_search_keys = self._get_autodetected_search_keys()
+            if fit_autodetected_search_keys is not None:
+                for key in fit_autodetected_search_keys.keys():
+                    if key not in x.columns:
+                        raise ValidationError(
+                            self.bundle.get("autodetected_search_key_not_found").format(key, x.columns)
+                        )
+                valid_search_keys.update(fit_autodetected_search_keys)
+        elif self.autodetect_search_keys:
+            valid_search_keys = self.__detect_missing_search_keys(x, valid_search_keys, is_demo_dataset)

         if all(k == SearchKey.CUSTOM_KEY for k in valid_search_keys.values()):
             if self.__is_registered:
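
The hunk above changes the search-key preparation step so that, on transform, the keys autodetected during fit are validated against the incoming frame and reused, instead of being re-detected; autodetection itself now only runs outside of transform when autodetect_search_keys is enabled. A rough sketch of the resulting contract (the column name and the stored mapping are invented for illustration):

    import pandas as pd

    # Suppose fit() autodetected a DATETIME key in a column called "signup_date" (hypothetical).
    fit_autodetected = {"signup_date": "DATETIME"}

    df_transform = pd.DataFrame({"feature_a": [1.0, 2.0]})  # key column missing on purpose
    missing = [col for col in fit_autodetected if col not in df_transform.columns]
    if missing:
        # features_enricher raises ValidationError with the "autodetected_search_key_not_found" message here.
        raise ValueError(f"Autodetected search key columns missing at transform time: {missing}")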
@@ -4829,7 +4805,6 @@ if response.status_code == 200:
         scoring: Callable | str | None,
         estimator: Any | None,
         remove_outliers_calc_metrics: bool | None,
-        trace_id: str,
         progress_bar: ProgressBar | None = None,
         progress_callback: Callable[[SearchProgress], Any] | None = None,
     ):
@@ -4837,7 +4812,6 @@ if response.status_code == 200:
             scoring=scoring,
             estimator=estimator,
             remove_outliers_calc_metrics=remove_outliers_calc_metrics,
-            trace_id=trace_id,
             internal_call=True,
             progress_bar=progress_bar,
             progress_callback=progress_callback,
@@ -4902,60 +4876,36 @@ if response.status_code == 200:
         df: pd.DataFrame,
         search_keys: dict[str, SearchKey],
         is_demo_dataset: bool,
-        silent_mode=False,
-        is_transform=False,
     ) -> dict[str, SearchKey]:
         sample = df.head(100)

-
-            return not is_transform or (
-                search_key in self.fit_search_keys.values() and search_key not in search_keys.values()
-            )
-
-        if (
-            SearchKey.DATE not in search_keys.values()
-            and SearchKey.DATETIME not in search_keys.values()
-            and check_need_detect(SearchKey.DATE)
-            and check_need_detect(SearchKey.DATETIME)
-        ):
+        if SearchKey.DATE not in search_keys.values() and SearchKey.DATETIME not in search_keys.values():
             maybe_keys = DateSearchKeyDetector().get_search_key_columns(sample, search_keys)
             if len(maybe_keys) > 0:
                 datetime_key = maybe_keys[0]
                 search_keys[datetime_key] = SearchKey.DATETIME
                 self.autodetected_search_keys[datetime_key] = SearchKey.DATETIME
                 self.logger.info(f"Autodetected search key DATETIME in column {datetime_key}")
-
-                    print(self.bundle.get("datetime_detected").format(datetime_key))
+                print(self.bundle.get("datetime_detected").format(datetime_key))

         # if SearchKey.POSTAL_CODE not in search_keys.values() and check_need_detect(SearchKey.POSTAL_CODE):
-
-
-
-
-
-
-
-
-
-
-        if (
-            SearchKey.COUNTRY not in search_keys.values()
-            and self.country_code is None
-            and check_need_detect(SearchKey.COUNTRY)
-        ):
+        maybe_keys = PostalCodeSearchKeyDetector().get_search_key_columns(sample, search_keys)
+        if maybe_keys:
+            new_keys = {key: SearchKey.POSTAL_CODE for key in maybe_keys}
+            search_keys.update(new_keys)
+            self.autodetected_search_keys.update(new_keys)
+            self.logger.info(f"Autodetected search key POSTAL_CODE in column {maybe_keys}")
+            print(self.bundle.get("postal_code_detected").format(maybe_keys))
+
+        if SearchKey.COUNTRY not in search_keys.values() and self.country_code is None:
             maybe_key = CountrySearchKeyDetector().get_search_key_columns(sample, search_keys)
             if maybe_key:
                 search_keys[maybe_key[0]] = SearchKey.COUNTRY
                 self.autodetected_search_keys[maybe_key[0]] = SearchKey.COUNTRY
                 self.logger.info(f"Autodetected search key COUNTRY in column {maybe_key}")
-
-                print(self.bundle.get("country_detected").format(maybe_key))
+                print(self.bundle.get("country_detected").format(maybe_key))

-        if (
-            # SearchKey.EMAIL not in search_keys.values()
-            SearchKey.HEM not in search_keys.values()
-            and check_need_detect(SearchKey.HEM)
-        ):
+        if SearchKey.EMAIL not in search_keys.values() and SearchKey.HEM not in search_keys.values():
             maybe_keys = EmailSearchKeyDetector().get_search_key_columns(sample, search_keys)
             if maybe_keys:
                 if self.__is_registered or is_demo_dataset:
@@ -4963,34 +4913,28 @@ if response.status_code == 200:
                     search_keys.update(new_keys)
                     self.autodetected_search_keys.update(new_keys)
                     self.logger.info(f"Autodetected search key EMAIL in column {maybe_keys}")
-
-                        print(self.bundle.get("email_detected").format(maybe_keys))
+                    print(self.bundle.get("email_detected").format(maybe_keys))
                 else:
                     self.logger.warning(
                         f"Autodetected search key EMAIL in column {maybe_keys}."
                         " But not used because not registered user"
                     )
-
-                        self.__log_warning(self.bundle.get("email_detected_not_registered").format(maybe_keys))
+                    self.__log_warning(self.bundle.get("email_detected_not_registered").format(maybe_keys))

         # if SearchKey.PHONE not in search_keys.values() and check_need_detect(SearchKey.PHONE):
-
-
-        if
-
-
-
-
-
-
-
-
-
-            "But not used because not registered user"
-            )
-            if not silent_mode:
-                self.__log_warning(self.bundle.get("phone_detected_not_registered"))
+        maybe_keys = PhoneSearchKeyDetector().get_search_key_columns(sample, search_keys)
+        if maybe_keys:
+            if self.__is_registered or is_demo_dataset:
+                new_keys = {key: SearchKey.PHONE for key in maybe_keys}
+                search_keys.update(new_keys)
+                self.autodetected_search_keys.update(new_keys)
+                self.logger.info(f"Autodetected search key PHONE in column {maybe_keys}")
+                print(self.bundle.get("phone_detected").format(maybe_keys))
+            else:
+                self.logger.warning(
+                    f"Autodetected search key PHONE in column {maybe_keys}. " "But not used because not registered user"
+                )
+                self.__log_warning(self.bundle.get("phone_detected_not_registered"))

         return search_keys

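
All detectors used above (DateSearchKeyDetector, PostalCodeSearchKeyDetector, CountrySearchKeyDetector, EmailSearchKeyDetector, PhoneSearchKeyDetector) share one call shape, and postal-code and phone detection now run unconditionally rather than behind the removed check_need_detect/silent_mode logic. A sketch of that shared shape (the wrapper function is invented; the detector classes are the ones already imported by features_enricher.py, so import paths are deliberately omitted here):

    def detect_key_columns(detector, df, known_keys):
        # Same sampling as __detect_missing_search_keys: only the first 100 rows are inspected.
        sample = df.head(100)
        return detector.get_search_key_columns(sample, known_keys)

    # e.g. detect_key_columns(PostalCodeSearchKeyDetector(), df, search_keys) -> candidate column names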
@@ -5062,13 +5006,12 @@ if response.status_code == 200:

     def dump_input(
         self,
-        trace_id: str,
         X: pd.DataFrame | pd.Series,
         y: pd.DataFrame | pd.Series | None = None,
         eval_set: tuple | None = None,
     ):
-        def dump_task(X_, y_, eval_set_):
-            with MDC(correlation_id=
+        def dump_task(X_, y_, eval_set_, trace_id_):
+            with MDC(correlation_id=trace_id_):
                 try:
                     if isinstance(X_, pd.Series):
                         X_ = X_.to_frame()
@@ -5076,27 +5019,25 @@ if response.status_code == 200:
                     with tempfile.TemporaryDirectory() as tmp_dir:
                         X_.to_parquet(f"{tmp_dir}/x.parquet", compression="zstd")
                         x_digest_sha256 = file_hash(f"{tmp_dir}/x.parquet")
-                        if self.rest_client.is_file_uploaded(
+                        if self.rest_client.is_file_uploaded(trace_id_, x_digest_sha256):
                             self.logger.info(
                                 f"File x.parquet was already uploaded with digest {x_digest_sha256}, skipping"
                             )
                         else:
-                            self.rest_client.dump_input_file(
-                                trace_id, f"{tmp_dir}/x.parquet", "x.parquet", x_digest_sha256
-                            )
+                            self.rest_client.dump_input_file(f"{tmp_dir}/x.parquet", "x.parquet", x_digest_sha256)

                         if y_ is not None:
                             if isinstance(y_, pd.Series):
                                 y_ = y_.to_frame()
                             y_.to_parquet(f"{tmp_dir}/y.parquet", compression="zstd")
                             y_digest_sha256 = file_hash(f"{tmp_dir}/y.parquet")
-                            if self.rest_client.is_file_uploaded(
+                            if self.rest_client.is_file_uploaded(trace_id_, y_digest_sha256):
                                 self.logger.info(
                                     f"File y.parquet was already uploaded with digest {y_digest_sha256}, skipping"
                                 )
                             else:
                                 self.rest_client.dump_input_file(
-
+                                    trace_id_, f"{tmp_dir}/y.parquet", "y.parquet", y_digest_sha256
                                 )

                         if eval_set_ is not None and len(eval_set_) > 0:
@@ -5105,14 +5046,14 @@ if response.status_code == 200:
                                     eval_x_ = eval_x_.to_frame()
                                 eval_x_.to_parquet(f"{tmp_dir}/eval_x_{idx}.parquet", compression="zstd")
                                 eval_x_digest_sha256 = file_hash(f"{tmp_dir}/eval_x_{idx}.parquet")
-                                if self.rest_client.is_file_uploaded(
+                                if self.rest_client.is_file_uploaded(trace_id_, eval_x_digest_sha256):
                                     self.logger.info(
                                         f"File eval_x_{idx}.parquet was already uploaded with"
                                         f" digest {eval_x_digest_sha256}, skipping"
                                     )
                                 else:
                                     self.rest_client.dump_input_file(
-
+                                        trace_id_,
                                         f"{tmp_dir}/eval_x_{idx}.parquet",
                                         f"eval_x_{idx}.parquet",
                                         eval_x_digest_sha256,
@@ -5122,14 +5063,14 @@ if response.status_code == 200:
                                     eval_y_ = eval_y_.to_frame()
                                 eval_y_.to_parquet(f"{tmp_dir}/eval_y_{idx}.parquet", compression="zstd")
                                 eval_y_digest_sha256 = file_hash(f"{tmp_dir}/eval_y_{idx}.parquet")
-                                if self.rest_client.is_file_uploaded(
+                                if self.rest_client.is_file_uploaded(trace_id_, eval_y_digest_sha256):
                                     self.logger.info(
                                         f"File eval_y_{idx}.parquet was already uploaded"
                                         f" with digest {eval_y_digest_sha256}, skipping"
                                     )
                                 else:
                                     self.rest_client.dump_input_file(
-
+                                        trace_id_,
                                         f"{tmp_dir}/eval_y_{idx}.parquet",
                                         f"eval_y_{idx}.parquet",
                                         eval_y_digest_sha256,
@@ -5138,7 +5079,8 @@ if response.status_code == 200:
                 self.logger.warning("Failed to dump input files", exc_info=True)

         try:
-
+            trace_id = self._get_trace_id()
+            Thread(target=dump_task, args=(X, y, eval_set, trace_id), daemon=True).start()
         except Exception:
             self.logger.warning("Failed to dump input files", exc_info=True)

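
The last hunk captures the correlation id on the caller's side and passes it into the background dump, so the daemon thread logs under the same id even though it runs outside the caller's MDC context. A generic sketch of that pattern (the wrapper function and its argument names are illustrative, not part of the package):

    from threading import Thread

    def start_background_dump(dump_task, X, y, eval_set, trace_id):
        # Fire-and-forget: daemon=True means the dump never blocks interpreter shutdown,
        # and the explicit trace_id keeps the worker's logs correlated with the caller.
        Thread(target=dump_task, args=(X, y, eval_set, trace_id), daemon=True).start()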