upgini 1.2.141__py3-none-any.whl → 1.2.142__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of upgini might be problematic.
- upgini/__about__.py +1 -1
- upgini/dataset.py +8 -0
- upgini/features_enricher.py +502 -559
- upgini/metadata.py +2 -1
- upgini/normalizer/normalize_utils.py +1 -1
- upgini/resource_bundle/strings.properties +10 -9
- upgini/utils/datetime_utils.py +7 -4
- {upgini-1.2.141.dist-info → upgini-1.2.142.dist-info}/METADATA +1 -1
- {upgini-1.2.141.dist-info → upgini-1.2.142.dist-info}/RECORD +11 -11
- {upgini-1.2.141.dist-info → upgini-1.2.142.dist-info}/WHEEL +0 -0
- {upgini-1.2.141.dist-info → upgini-1.2.142.dist-info}/licenses/LICENSE +0 -0
upgini/features_enricher.py CHANGED
@@ -42,6 +42,7 @@ from upgini.http import (
     get_rest_client,
 )
 from upgini.mdc import MDC
+from upgini.mdc.context import get_mdc_fields
 from upgini.metadata import (
     COUNTRY,
     CURRENT_DATE_COL,
@@ -273,7 +274,7 @@ class FeaturesEnricher(TransformerMixin):
         self.X: pd.DataFrame | None = None
         self.y: pd.Series | None = None
         self.eval_set: list[tuple] | None = None
-        self.autodetected_search_keys: dict[str, SearchKey] =
+        self.autodetected_search_keys: dict[str, SearchKey] = dict()
         self.imbalanced = False
         self.fit_select_features = True
         self.__cached_sampled_datasets: dict[str, tuple[pd.DataFrame, pd.DataFrame, pd.Series, dict, dict, dict]] = (
@@ -309,8 +310,8 @@ class FeaturesEnricher(TransformerMixin):
         search_task = SearchTask(search_id, rest_client=self.rest_client, logger=self.logger)
 
         print(self.bundle.get("search_by_task_id_start"))
-        trace_id =
-        with MDC(correlation_id=trace_id):
+        trace_id = self._get_trace_id()
+        with MDC(correlation_id=trace_id, search_task_id=search_id):
             try:
                 self.logger.debug(f"FeaturesEnricher created from existing search: {search_id}")
                 self._search_task = search_task.poll_result(trace_id, quiet=True, check_fit=True)
@@ -318,7 +319,7 @@ class FeaturesEnricher(TransformerMixin):
                 x_columns = [c.name for c in file_metadata.columns]
                 self.fit_columns_renaming = {c.name: c.originalName for c in file_metadata.columns}
                 df = pd.DataFrame(columns=x_columns)
-                self.__prepare_feature_importances(
+                self.__prepare_feature_importances(df, silent=True, update_selected_features=False)
                 if print_loaded_report:
                     self.__show_selected_features()
                 # TODO validate search_keys with search_keys from file_metadata
@@ -487,7 +488,7 @@ class FeaturesEnricher(TransformerMixin):
         stability_agg_func: str, optional (default="max")
             Function to aggregate stability values. Can be "max", "min", "mean".
         """
-        trace_id =
+        trace_id = self._get_trace_id()
         if self.print_trace_id:
             print(f"https://app.datadoghq.eu/logs?query=%40correlation_id%3A{trace_id}")
         start_time = time.time()
@@ -522,10 +523,9 @@ class FeaturesEnricher(TransformerMixin):
         self.X = X
         self.y = y
         self.eval_set = self._check_eval_set(eval_set, X)
-        self.dump_input(
+        self.dump_input(X, y, self.eval_set)
         self.__set_select_features(select_features)
         self.__inner_fit(
-            trace_id,
             X,
             y,
             self.eval_set,
@@ -646,7 +646,7 @@ class FeaturesEnricher(TransformerMixin):
 
         self.warning_counter.reset()
         auto_fe_parameters = AutoFEParameters() if auto_fe_parameters is None else auto_fe_parameters
-        trace_id =
+        trace_id = self._get_trace_id()
         if self.print_trace_id:
             print(f"https://app.datadoghq.eu/logs?query=%40correlation_id%3A{trace_id}")
         start_time = time.time()
@@ -677,13 +677,12 @@ class FeaturesEnricher(TransformerMixin):
         self.y = y
         self.eval_set = self._check_eval_set(eval_set, X)
         self.__set_select_features(select_features)
-        self.dump_input(
+        self.dump_input(X, y, self.eval_set)
 
         if _num_samples(drop_duplicates(X)) > Dataset.MAX_ROWS:
             raise ValidationError(self.bundle.get("dataset_too_many_rows_registered").format(Dataset.MAX_ROWS))
 
         self.__inner_fit(
-            trace_id,
             X,
             y,
             self.eval_set,
@@ -738,7 +737,6 @@ class FeaturesEnricher(TransformerMixin):
             y,
             exclude_features_sources=exclude_features_sources,
             keep_input=keep_input,
-            trace_id=trace_id,
             silent_mode=True,
             progress_bar=progress_bar,
             progress_callback=progress_callback,
@@ -753,7 +751,6 @@ class FeaturesEnricher(TransformerMixin):
         *args,
         exclude_features_sources: list[str] | None = None,
         keep_input: bool = True,
-        trace_id: str | None = None,
         silent_mode=False,
         progress_bar: ProgressBar | None = None,
         progress_callback: Callable[[SearchProgress], Any] | None = None,
@@ -790,12 +787,12 @@ class FeaturesEnricher(TransformerMixin):
                 progress_bar.progress = search_progress.to_progress_bar()
                 if new_progress:
                     progress_bar.display()
-        trace_id =
+        trace_id = self._get_trace_id()
         if self.print_trace_id:
             print(f"https://app.datadoghq.eu/logs?query=%40correlation_id%3A{trace_id}")
         search_id = self.search_id or (self._search_task.search_task_id if self._search_task is not None else None)
         with MDC(correlation_id=trace_id, search_id=search_id):
-            self.dump_input(
+            self.dump_input(X)
             if len(args) > 0:
                 msg = f"WARNING: Unsupported positional arguments for transform: {args}"
                 self.logger.warning(msg)
@@ -808,7 +805,6 @@ class FeaturesEnricher(TransformerMixin):
             start_time = time.time()
             try:
                 result, _, _, _ = self.__inner_transform(
-                    trace_id,
                     X,
                     y=y,
                     exclude_features_sources=exclude_features_sources,
@@ -872,7 +868,6 @@ class FeaturesEnricher(TransformerMixin):
         estimator=None,
         exclude_features_sources: list[str] | None = None,
         remove_outliers_calc_metrics: bool | None = None,
-        trace_id: str | None = None,
         internal_call: bool = False,
         progress_bar: ProgressBar | None = None,
         progress_callback: Callable[[SearchProgress], Any] | None = None,
@@ -910,7 +905,7 @@ class FeaturesEnricher(TransformerMixin):
             Dataframe with metrics calculated on train and validation datasets.
         """
 
-        trace_id =
+        trace_id = self._get_trace_id()
         start_time = time.time()
         search_id = self.search_id or (self._search_task.search_task_id if self._search_task is not None else None)
         with MDC(correlation_id=trace_id, search_id=search_id):
@@ -943,7 +938,7 @@ class FeaturesEnricher(TransformerMixin):
                 raise ValidationError(self.bundle.get("metrics_unfitted_enricher"))
 
             validated_X, validated_y, validated_eval_set = self._validate_train_eval(
-                effective_X, effective_y, effective_eval_set
+                effective_X, effective_y, effective_eval_set
             )
 
             if self.X is None:
@@ -978,11 +973,13 @@ class FeaturesEnricher(TransformerMixin):
                 self.__display_support_link(msg)
                 return None
 
+            search_keys = self._get_fit_search_keys_with_original_names()
+
             cat_features_from_backend = self.__get_categorical_features()
             # Convert to original names
             cat_features_from_backend = [self.fit_columns_renaming.get(c, c) for c in cat_features_from_backend]
             client_cat_features, search_keys_for_metrics = self._get_and_validate_client_cat_features(
-                estimator, validated_X,
+                estimator, validated_X, search_keys
             )
             # Exclude id columns from cat_features
             if self.id_columns and self.id_columns_encoder is not None:
@@ -1004,7 +1001,6 @@ class FeaturesEnricher(TransformerMixin):
             self.logger.info(f"Search keys for metrics: {search_keys_for_metrics}")
 
             prepared_data = self._get_cached_enriched_data(
-                trace_id=trace_id,
                 X=X,
                 y=y,
                 eval_set=eval_set,
@@ -1257,7 +1253,7 @@ class FeaturesEnricher(TransformerMixin):
 
             if updating_shaps is not None:
                 decoded_X = self._decode_id_columns(fitting_X)
-                self._update_shap_values(
+                self._update_shap_values(decoded_X, updating_shaps, silent=not internal_call)
 
             metrics_df = pd.DataFrame(metrics)
             mean_target_hdr = self.bundle.get("quality_metrics_mean_target_header")
@@ -1307,9 +1303,33 @@ class FeaturesEnricher(TransformerMixin):
         finally:
             self.logger.info(f"Calculating metrics elapsed time: {time.time() - start_time}")
 
+    def _get_trace_id(self):
+        if get_mdc_fields().get("correlation_id") is not None:
+            return get_mdc_fields().get("correlation_id")
+        return int(time.time() * 1000)
+
+    def _get_autodetected_search_keys(self):
+        if self.autodetected_search_keys is None and self._search_task is not None:
+            meta = self._search_task.get_file_metadata(self._get_trace_id())
+            self.autodetected_search_keys = {k: SearchKey[v] for k, v in meta.autodetectedSearchKeys.items()}
+
+        return self.autodetected_search_keys
+
+    def _get_fit_search_keys_with_original_names(self):
+        if self.fit_search_keys is None and self._search_task is not None:
+            fit_search_keys = dict()
+            meta = self._search_task.get_file_metadata(self._get_trace_id())
+            for column in meta.columns:
+                # TODO check for EMAIL->HEM and multikeys
+                search_key_type = SearchKey.from_meaning_type(column.meaningType)
+                if search_key_type is not None:
+                    fit_search_keys[column.originalName] = search_key_type
+        else:
+            fit_search_keys = {self.fit_columns_renaming.get(k, k): v for k, v in self.fit_search_keys.items()}
+        return fit_search_keys
+
     def _select_features_by_psi(
         self,
-        trace_id: str,
         X: pd.DataFrame | pd.Series | np.ndarray,
         y: pd.DataFrame | pd.Series | np.ndarray | list,
         eval_set: list[tuple] | tuple | None,
@@ -1322,7 +1342,8 @@ class FeaturesEnricher(TransformerMixin):
         progress_callback: Callable | None = None,
     ):
         search_keys = self.search_keys.copy()
-
+        search_keys.update(self._get_autodetected_search_keys())
+        validated_X, _, validated_eval_set = self._validate_train_eval(X, y, eval_set)
         if isinstance(X, np.ndarray):
             search_keys = {str(k): v for k, v in search_keys.items()}
 
@@ -1357,7 +1378,6 @@ class FeaturesEnricher(TransformerMixin):
         ]
 
         prepared_data = self._get_cached_enriched_data(
-            trace_id=trace_id,
             X=X,
             y=y,
             eval_set=eval_set,
@@ -1506,7 +1526,7 @@ class FeaturesEnricher(TransformerMixin):
 
         return total_unstable_features
 
-    def _update_shap_values(self,
+    def _update_shap_values(self, df: pd.DataFrame, new_shaps: dict[str, float], silent: bool = False):
         renaming = self.fit_columns_renaming or {}
         self.logger.info(f"Updating SHAP values: {new_shaps}")
         new_shaps = {
@@ -1514,7 +1534,7 @@ class FeaturesEnricher(TransformerMixin):
             for feature, shap in new_shaps.items()
             if feature in self.feature_names_ or renaming.get(feature, feature) in self.feature_names_
         }
-        self.__prepare_feature_importances(
+        self.__prepare_feature_importances(df, new_shaps)
 
         if not silent and self.features_info_display_handle is not None:
             try:
@@ -1694,7 +1714,6 @@ class FeaturesEnricher(TransformerMixin):
 
     def _get_cached_enriched_data(
         self,
-        trace_id: str,
         X: pd.DataFrame | pd.Series | np.ndarray | None = None,
         y: pd.DataFrame | pd.Series | np.ndarray | list | None = None,
         eval_set: list[tuple] | tuple | None = None,
@@ -1710,10 +1729,9 @@ class FeaturesEnricher(TransformerMixin):
         is_input_same_as_fit, X, y, eval_set = self._is_input_same_as_fit(X, y, eval_set)
         is_demo_dataset = hash_input(X, y, eval_set) in DEMO_DATASET_HASHES
         checked_eval_set = self._check_eval_set(eval_set, X)
-        validated_X, validated_y, validated_eval_set = self._validate_train_eval(X, y, checked_eval_set
+        validated_X, validated_y, validated_eval_set = self._validate_train_eval(X, y, checked_eval_set)
 
         sampled_data = self._get_enriched_datasets(
-            trace_id=trace_id,
             validated_X=validated_X,
             validated_y=validated_y,
             eval_set=validated_eval_set,
@@ -1740,7 +1758,7 @@ class FeaturesEnricher(TransformerMixin):
 
         self.logger.info(f"Excluding search keys: {excluding_search_keys}")
 
-        file_meta = self._search_task.get_file_metadata(
+        file_meta = self._search_task.get_file_metadata(self._get_trace_id())
         fit_dropped_features = self.fit_dropped_features or file_meta.droppedColumns or []
         original_dropped_features = [columns_renaming.get(f, f) for f in fit_dropped_features]
 
@@ -1917,7 +1935,6 @@ class FeaturesEnricher(TransformerMixin):
 
     def _get_enriched_datasets(
         self,
-        trace_id: str,
         validated_X: pd.DataFrame | pd.Series | np.ndarray | None,
         validated_y: pd.DataFrame | pd.Series | np.ndarray | list | None,
         eval_set: list[tuple] | None,
@@ -1945,9 +1962,7 @@ class FeaturesEnricher(TransformerMixin):
             and self.df_with_original_index is not None
         ):
             self.logger.info("Dataset is not imbalanced, so use enriched_X from fit")
-            return self.__get_enriched_from_fit(
-                validated_X, validated_y, eval_set, trace_id, remove_outliers_calc_metrics
-            )
+            return self.__get_enriched_from_fit(validated_X, validated_y, eval_set, remove_outliers_calc_metrics)
         else:
             self.logger.info(
                 "Dataset is imbalanced or exclude_features_sources or X was passed or this is saved search."
@@ -1959,7 +1974,6 @@ class FeaturesEnricher(TransformerMixin):
                 validated_y,
                 eval_set,
                 exclude_features_sources,
-                trace_id,
                 progress_bar,
                 progress_callback,
                 is_for_metrics=is_for_metrics,
@@ -2088,7 +2102,6 @@ class FeaturesEnricher(TransformerMixin):
         validated_X: pd.DataFrame,
         validated_y: pd.Series,
         eval_set: list[tuple] | None,
-        trace_id: str,
         remove_outliers_calc_metrics: bool | None,
     ) -> _EnrichedDataForMetrics:
         eval_set_sampled_dict = {}
@@ -2103,7 +2116,7 @@ class FeaturesEnricher(TransformerMixin):
         if remove_outliers_calc_metrics is None:
             remove_outliers_calc_metrics = True
         if self.model_task_type == ModelTaskType.REGRESSION and remove_outliers_calc_metrics:
-            target_outliers_df = self._search_task.get_target_outliers(
+            target_outliers_df = self._search_task.get_target_outliers(self._get_trace_id())
             if target_outliers_df is not None and len(target_outliers_df) > 0:
                 outliers = pd.merge(
                     self.df_with_original_index,
@@ -2120,7 +2133,7 @@ class FeaturesEnricher(TransformerMixin):
 
         # index in each dataset (X, eval set) may be reordered and non unique, but index in validated datasets
         # can differs from it
-        fit_features = self._search_task.get_all_initial_raw_features(
+        fit_features = self._search_task.get_all_initial_raw_features(self._get_trace_id(), metrics_calculation=True)
 
         # Pre-process features if we need to drop outliers
         if rows_to_drop is not None:
@@ -2146,7 +2159,7 @@ class FeaturesEnricher(TransformerMixin):
         validated_Xy[TARGET] = validated_y
 
         selecting_columns = self._selecting_input_and_generated_columns(
-            validated_Xy, self.fit_generated_features, keep_input=True
+            validated_Xy, self.fit_generated_features, keep_input=True
        )
         selecting_columns.extend(
             c
@@ -2211,7 +2224,6 @@ class FeaturesEnricher(TransformerMixin):
         validated_y: pd.Series,
         eval_set: list[tuple] | None,
         exclude_features_sources: list[str] | None,
-        trace_id: str,
         progress_bar: ProgressBar | None,
         progress_callback: Callable[[SearchProgress], Any] | None,
         is_for_metrics: bool = False,
@@ -2237,7 +2249,6 @@ class FeaturesEnricher(TransformerMixin):
 
         # Transform
         enriched_df, columns_renaming, generated_features, search_keys = self.__inner_transform(
-            trace_id,
             X=df.drop(columns=[TARGET]),
             y=df[TARGET],
             exclude_features_sources=exclude_features_sources,
@@ -2414,11 +2425,10 @@ class FeaturesEnricher(TransformerMixin):
 
         return self.features_info
 
-    def get_progress(self,
+    def get_progress(self, search_task: SearchTask | None = None) -> SearchProgress:
         search_task = search_task or self._search_task
         if search_task is not None:
-
-            return search_task.get_progress(trace_id)
+            return search_task.get_progress(self._get_trace_id())
 
     def display_transactional_transform_api(self, only_online_sources=False):
         if self.api_key is None:
@@ -2519,7 +2529,6 @@ if response.status_code == 200:
 
     def __inner_transform(
         self,
-        trace_id: str,
         X: pd.DataFrame,
         *,
         y: pd.Series | None = None,
@@ -2538,182 +2547,135 @@ if response.status_code == 200:
             raise NotFittedError(self.bundle.get("transform_unfitted_enricher"))
 
         start_time = time.time()
-
-        with MDC(correlation_id=trace_id, search_id=search_id):
-            self.logger.info("Start transform")
+        self.logger.info("Start transform")
 
-            validated_X, validated_y, validated_eval_set = self._validate_train_eval(
-                X, y, eval_set=None, is_transform=True, silent=True
-            )
-            df = self.__combine_train_and_eval_sets(validated_X, validated_y, validated_eval_set)
+        search_keys = self.search_keys.copy()
 
-
+        self.__validate_search_keys(search_keys, self.search_id)
 
-
+        validated_X, validated_y, validated_eval_set = self._validate_train_eval(
+            X, y, eval_set=None, is_transform=True
+        )
+        df = self.__combine_train_and_eval_sets(validated_X, validated_y, validated_eval_set)
 
-
-            if len(self.feature_names_) == 0:
-                msg = self.bundle.get("no_important_features_for_transform")
-                self.__log_warning(msg, show_support_link=True)
-                return None, {}, [], self.search_keys
+        validated_Xy = df.copy()
 
-
+        self.__log_debug_information(validated_X, validated_y, exclude_features_sources=exclude_features_sources)
 
-
-
-
-
-
+        # If there are no important features, return original dataframe
+        if len(self.feature_names_) == 0:
+            msg = self.bundle.get("no_important_features_for_transform")
+            self.__log_warning(msg, show_support_link=True)
+            return None, {}, [], search_keys
 
-
-
-
-
-
-                msg = self.bundle.get("online_api_features_transform").format(online_api_features)
-                self.logger.warning(msg)
-                print(msg)
-                self.display_transactional_transform_api(only_online_sources=True)
-
-            if not metrics_calculation:
-                transform_usage = self.rest_client.get_current_transform_usage(trace_id)
-                self.logger.info(f"Current transform usage: {transform_usage}. Transforming {len(X)} rows")
-                if transform_usage.has_limit:
-                    if len(X) > transform_usage.rest_rows:
-                        rest_rows = max(transform_usage.rest_rows, 0)
-                        bundle_msg = (
-                            "transform_usage_warning_registered"
-                            if self.__is_registered
-                            else "transform_usage_warning_demo"
-                        )
-                        msg = self.bundle.get(bundle_msg).format(rest_rows, len(X))
-                        self.logger.warning(msg)
-                        print(msg)
-                        show_request_quote_button(is_registered=self.__is_registered)
-                        return None, {}, [], {}
-                else:
-                    msg = self.bundle.get("transform_usage_info").format(
-                        transform_usage.limit, transform_usage.transformed_rows
-                    )
-                    self.logger.info(msg)
-                    print(msg)
+        if self._has_paid_features(exclude_features_sources):
+            msg = self.bundle.get("transform_with_paid_features")
+            self.logger.warning(msg)
+            self.__display_support_link(msg)
+            return None, {}, [], search_keys
 
-
+        online_api_features = [fm.name for fm in features_meta if fm.from_online_api and fm.shap_value > 0]
+        if len(online_api_features) > 0:
+            self.logger.warning(
+                f"There are important features for transform, that generated by online API: {online_api_features}"
+            )
+            msg = self.bundle.get("online_api_features_transform").format(online_api_features)
+            self.logger.warning(msg)
+            print(msg)
+            self.display_transactional_transform_api(only_online_sources=True)
+
+        if not metrics_calculation:
+            transform_usage = self.rest_client.get_current_transform_usage(self._get_trace_id())
+            self.logger.info(f"Current transform usage: {transform_usage}. Transforming {len(X)} rows")
+            if transform_usage.has_limit:
+                if len(X) > transform_usage.rest_rows:
+                    rest_rows = max(transform_usage.rest_rows, 0)
+                    bundle_msg = (
+                        "transform_usage_warning_registered" if self.__is_registered else "transform_usage_warning_demo"
+                    )
+                    msg = self.bundle.get(bundle_msg).format(rest_rows, len(X))
+                    self.logger.warning(msg)
+                    print(msg)
+                    show_request_quote_button(is_registered=self.__is_registered)
+                    return None, {}, [], {}
+            else:
+                msg = self.bundle.get("transform_usage_info").format(
+                    transform_usage.limit, transform_usage.transformed_rows
+                )
+                self.logger.info(msg)
+                print(msg)
 
-            columns_to_drop = [
-                c for c in df.columns if c in self.feature_names_ and c in self.external_source_feature_names
-            ]
-            if len(columns_to_drop) > 0:
-                msg = self.bundle.get("x_contains_enriching_columns").format(columns_to_drop)
-                self.logger.warning(msg)
-                print(msg)
-                df = df.drop(columns=columns_to_drop)
+        is_demo_dataset = hash_input(df) in DEMO_DATASET_HASHES
 
-
-
-
-
-
+        columns_to_drop = [
+            c for c in df.columns if c in self.feature_names_ and c in self.external_source_feature_names
+        ]
+        if len(columns_to_drop) > 0:
+            msg = self.bundle.get("x_contains_enriching_columns").format(columns_to_drop)
+            self.logger.warning(msg)
+            print(msg)
+            df = df.drop(columns=columns_to_drop)
 
-            search_keys = self.__prepare_search_keys(
-                df, search_keys, is_demo_dataset, is_transform=True, silent_mode=silent_mode
-            )
+        if self.id_columns is not None and self.cv is not None and self.cv.is_time_series():
+            search_keys.update({col: SearchKey.CUSTOM_KEY for col in self.id_columns if col not in search_keys})
 
-
+        search_keys = self.__prepare_search_keys(
+            df, search_keys, is_demo_dataset, is_transform=True, silent_mode=silent_mode
+        )
 
-            if DEFAULT_INDEX in df.columns:
-                msg = self.bundle.get("unsupported_index_column")
-                self.logger.info(msg)
-                print(msg)
-                df.drop(columns=DEFAULT_INDEX, inplace=True)
-                validated_Xy.drop(columns=DEFAULT_INDEX, inplace=True)
+        df = self.__handle_index_search_keys(df, search_keys)
 
-
+        if DEFAULT_INDEX in df.columns:
+            msg = self.bundle.get("unsupported_index_column")
+            self.logger.info(msg)
+            print(msg)
+            df.drop(columns=DEFAULT_INDEX, inplace=True)
+            validated_Xy.drop(columns=DEFAULT_INDEX, inplace=True)
 
-
-            date_column = self._get_date_column(search_keys)
-            if date_column is not None:
-                converter = DateTimeConverter(
-                    date_column,
-                    self.date_format,
-                    self.logger,
-                    bundle=self.bundle,
-                    generate_cyclical_features=self.generate_search_key_features,
-                )
-                df = converter.convert(df, keep_time=True)
-                self.logger.info(f"Date column after convertion: {df[date_column]}")
-                generated_features.extend(converter.generated_features)
-            else:
-                self.logger.info("Input dataset hasn't date column")
-                if self.__should_add_date_column():
-                    df = self._add_current_date_as_key(df, search_keys, self.bundle, silent=True)
-
-            email_columns = SearchKey.find_all_keys(search_keys, SearchKey.EMAIL)
-            if email_columns and self.generate_search_key_features:
-                generator = EmailDomainGenerator(email_columns)
-                df = generator.generate(df)
-                generated_features.extend(generator.generated_features)
-
-            normalizer = Normalizer(self.bundle, self.logger)
-            df, search_keys, generated_features = normalizer.normalize(df, search_keys, generated_features)
-            columns_renaming = normalizer.columns_renaming
-
-            # If there are no external features, we don't call backend on transform
-            external_features = [fm for fm in features_meta if fm.shap_value > 0 and fm.source != "etalon"]
-            if len(external_features) == 0:
-                self.logger.warning(
-                    "No external features found, returning original dataframe"
-                    f" with generated important features: {self.feature_names_}"
-                )
-                df = df.rename(columns=columns_renaming)
-                generated_features = [columns_renaming.get(c, c) for c in generated_features]
-                search_keys = {columns_renaming.get(c, c): t for c, t in search_keys.items()}
-                selecting_columns = self._selecting_input_and_generated_columns(
-                    validated_Xy, generated_features, keep_input, trace_id, is_transform=True
-                )
-                self.logger.warning(f"Filtered columns by existance in dataframe: {selecting_columns}")
-                if add_fit_system_record_id:
-                    df = self._add_fit_system_record_id(
-                        df,
-                        search_keys,
-                        SYSTEM_RECORD_ID,
-                        TARGET,
-                        columns_renaming,
-                        self.id_columns,
-                        self.cv,
-                        self.model_task_type,
-                        self.logger,
-                        self.bundle,
-                    )
-                    selecting_columns.append(SYSTEM_RECORD_ID)
-                return df[selecting_columns], columns_renaming, generated_features, search_keys
-
-            # Don't pass all features in backend on transform
-            runtime_parameters = self._get_copy_of_runtime_parameters()
-            features_for_transform = self._search_task.get_features_for_transform()
-            if features_for_transform:
-                missing_features_for_transform = [
-                    columns_renaming.get(f) or f for f in features_for_transform if f not in df.columns
-                ]
-                if TARGET in missing_features_for_transform:
-                    raise ValidationError(self.bundle.get("missing_target_for_transform"))
+        df = self.__add_country_code(df, search_keys)
 
-
-
-
-
-
-
-
-
+        generated_features = []
+        date_column = self._get_date_column(search_keys)
+        if date_column is not None:
+            converter = DateTimeConverter(
+                date_column,
+                self.date_format,
+                self.logger,
+                bundle=self.bundle,
+                generate_cyclical_features=self.generate_search_key_features,
+            )
+            df = converter.convert(df, keep_time=True)
+            self.logger.info(f"Date column after convertion: {df[date_column]}")
+            generated_features.extend(converter.generated_features)
+        else:
+            self.logger.info("Input dataset hasn't date column")
+            if self.__should_add_date_column():
+                df = self._add_current_date_as_key(df, search_keys, self.bundle, silent=True)
 
-
+        email_columns = SearchKey.find_all_keys(search_keys, SearchKey.EMAIL)
+        if email_columns and self.generate_search_key_features:
+            generator = EmailDomainGenerator(email_columns)
+            df = generator.generate(df)
+            generated_features.extend(generator.generated_features)
 
-
-
-
+        normalizer = Normalizer(self.bundle, self.logger)
+        df, search_keys, generated_features = normalizer.normalize(df, search_keys, generated_features)
+        columns_renaming = normalizer.columns_renaming
 
-
+        # If there are no external features, we don't call backend on transform
+        external_features = [fm for fm in features_meta if fm.shap_value > 0 and fm.source != "etalon"]
+        if len(external_features) == 0:
+            self.logger.warning(
+                "No external features found, returning original dataframe"
+                f" with generated important features: {self.feature_names_}"
+            )
+            df = df.rename(columns=columns_renaming)
+            generated_features = [columns_renaming.get(c, c) for c in generated_features]
+            search_keys = {columns_renaming.get(c, c): t for c, t in search_keys.items()}
+            selecting_columns = self._selecting_input_and_generated_columns(
+                validated_Xy, generated_features, keep_input, is_transform=True
+            )
+            self.logger.warning(f"Filtered columns by existance in dataframe: {selecting_columns}")
         if add_fit_system_record_id:
             df = self._add_fit_system_record_id(
                 df,
@@ -2727,252 +2689,294 @@ if response.status_code == 200:
                 self.logger,
                 self.bundle,
             )
-
-
+            selecting_columns.append(SYSTEM_RECORD_ID)
+            return df[selecting_columns], columns_renaming, generated_features, search_keys
 
-
-
-
+        # Don't pass all features in backend on transform
+        runtime_parameters = self._get_copy_of_runtime_parameters()
+        features_for_transform = self._search_task.get_features_for_transform()
+        if features_for_transform:
+            missing_features_for_transform = [
+                columns_renaming.get(f) or f for f in features_for_transform if f not in df.columns
+            ]
+            if TARGET in missing_features_for_transform:
+                raise ValidationError(self.bundle.get("missing_target_for_transform"))
 
-
+            if len(missing_features_for_transform) > 0:
+                raise ValidationError(
+                    self.bundle.get("missing_features_for_transform").format(missing_features_for_transform)
+                )
+        features_for_embeddings = self._search_task.get_features_for_embeddings()
+        if features_for_embeddings:
+            runtime_parameters.properties["features_for_embeddings"] = ",".join(features_for_embeddings)
+        features_for_transform = [f for f in features_for_transform if f not in search_keys.keys()]
 
-
-            df, unnest_search_keys = self._explode_multiple_search_keys(df, search_keys, columns_renaming)
+        columns_for_system_record_id = sorted(list(search_keys.keys()) + features_for_transform)
 
-
+        df[ENTITY_SYSTEM_RECORD_ID] = pd.util.hash_pandas_object(df[columns_for_system_record_id], index=False).astype(
+            "float64"
+        )
 
-
-
-
-
-
-
-
-
-
-
-
-
+        features_not_to_pass = []
+        if add_fit_system_record_id:
+            df = self._add_fit_system_record_id(
+                df,
+                search_keys,
+                SYSTEM_RECORD_ID,
+                TARGET,
+                columns_renaming,
+                self.id_columns,
+                self.cv,
+                self.model_task_type,
+                self.logger,
+                self.bundle,
+            )
+            df = df.rename(columns={SYSTEM_RECORD_ID: SORT_ID})
+            features_not_to_pass.append(SORT_ID)
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                    col: FileColumnMeaningType.FEATURE
-                    for col in features_for_transform
-                    if col not in date_features and col not in generated_features
-                }
+        system_columns_with_original_index = [ENTITY_SYSTEM_RECORD_ID] + generated_features
+        if add_fit_system_record_id:
+            system_columns_with_original_index.append(SORT_ID)
+
+        df_before_explode = df[system_columns_with_original_index].copy()
+
+        # Explode multiple search keys
+        df, unnest_search_keys = self._explode_multiple_search_keys(df, search_keys, columns_renaming)
+
+        # Convert search keys and generate features on them
+
+        email_column = self._get_email_column(search_keys)
+        hem_column = self._get_hem_column(search_keys)
+        if email_column:
+            converter = EmailSearchKeyConverter(
+                email_column,
+                hem_column,
+                search_keys,
+                columns_renaming,
+                list(unnest_search_keys.keys()),
+                self.logger,
             )
-
-            meaning_types.update({col: FileColumnMeaningType.DATE_FEATURE for col in date_features})
-            meaning_types.update({col: key.value for col, key in search_keys.items()})
+            df = converter.convert(df)
 
-
-
-
-
-
-
-
-
+        ip_column = self._get_ip_column(search_keys)
+        if ip_column:
+            converter = IpSearchKeyConverter(
+                ip_column,
+                search_keys,
+                columns_renaming,
+                list(unnest_search_keys.keys()),
+                self.bundle,
+                self.logger,
             )
+            df = converter.convert(df)
 
-
-
+        date_features = []
+        for col in features_for_transform:
+            if DateTimeConverter(col).is_datetime(df):
+                df[col] = DateTimeConverter(col).to_date_string(df)
+                date_features.append(col)
 
-
-
-
-
-
-
-
-
-
+        meaning_types = {}
+        meaning_types.update(
+            {
+                col: FileColumnMeaningType.FEATURE
+                for col in features_for_transform
+                if col not in date_features and col not in generated_features
+            }
+        )
+        meaning_types.update({col: FileColumnMeaningType.GENERATED_FEATURE for col in generated_features})
+        meaning_types.update({col: FileColumnMeaningType.DATE_FEATURE for col in date_features})
+        meaning_types.update({col: key.value for col, key in search_keys.items()})
 
-
+        features_not_to_pass.extend(
+            [
+                c
+                for c in df.columns
+                if c not in search_keys.keys()
+                and c not in features_for_transform
+                and c not in [ENTITY_SYSTEM_RECORD_ID, SEARCH_KEY_UNNEST]
+            ]
+        )
 
-
+        if DateTimeConverter.DATETIME_COL in df.columns:
+            df = df.drop(columns=DateTimeConverter.DATETIME_COL)
 
-
+        # search keys might be changed after explode
+        columns_for_system_record_id = sorted(list(search_keys.keys()) + features_for_transform)
+        df[SYSTEM_RECORD_ID] = pd.util.hash_pandas_object(df[columns_for_system_record_id], index=False).astype(
+            "float64"
+        )
+        meaning_types[SYSTEM_RECORD_ID] = FileColumnMeaningType.SYSTEM_RECORD_ID
+        meaning_types[ENTITY_SYSTEM_RECORD_ID] = FileColumnMeaningType.ENTITY_SYSTEM_RECORD_ID
+        if SEARCH_KEY_UNNEST in df.columns:
+            meaning_types[SEARCH_KEY_UNNEST] = FileColumnMeaningType.UNNEST_KEY
 
-            df_without_features, full_duplicates_warning = clean_full_duplicates(
-                df_without_features, is_transform=True, logger=self.logger, bundle=self.bundle
-            )
-            if not silent_mode and full_duplicates_warning:
-                self.__log_warning(full_duplicates_warning)
+        df = df.reset_index(drop=True)
 
-
-            gc.collect()
-
-            dataset = Dataset(
-                "sample_" + str(uuid.uuid4()),
-                df=df_without_features,
-                meaning_types=meaning_types,
-                search_keys=combined_search_keys,
-                unnest_search_keys=unnest_search_keys,
-                id_columns=self.__get_renamed_id_columns(columns_renaming),
-                date_column=self._get_date_column(search_keys),
-                date_format=self.date_format,
-                sample_config=self.sample_config,
-                rest_client=self.rest_client,
-                logger=self.logger,
-                bundle=self.bundle,
-                warning_callback=self.__log_warning,
-            )
-            dataset.columns_renaming = columns_renaming
-
-            validation_task = self._search_task.validation(
-                trace_id,
-                dataset,
-                start_time=start_time,
-                extract_features=True,
-                runtime_parameters=runtime_parameters,
-                exclude_features_sources=exclude_features_sources,
-                metrics_calculation=metrics_calculation,
-                silent_mode=silent_mode,
-                progress_bar=progress_bar,
-                progress_callback=progress_callback,
-            )
+        combined_search_keys = combine_search_keys(search_keys.keys())
 
-
-            gc.collect()
+        df_without_features = df.drop(columns=features_not_to_pass, errors="ignore")
 
-
-
-
-
+        df_without_features, full_duplicates_warning = clean_full_duplicates(
+            df_without_features, is_transform=True, logger=self.logger, bundle=self.bundle
+        )
+        if not silent_mode and full_duplicates_warning:
+            self.__log_warning(full_duplicates_warning)
 
-
-
-            if progress_bar is not None:
-                progress_bar.progress = progress.to_progress_bar()
-            if progress_callback is not None:
-                progress_callback(progress)
-            prev_progress: SearchProgress | None = None
-            polling_period_seconds = 1
-            try:
-                while progress.stage != ProgressStage.DOWNLOADING.value:
-                    if prev_progress is None or prev_progress.percent != progress.percent:
-                        progress.recalculate_eta(time.time() - start_time)
-                    else:
-                        progress.update_eta(prev_progress.eta - polling_period_seconds)
-                    prev_progress = progress
-                    if progress_bar is not None:
-                        progress_bar.progress = progress.to_progress_bar()
-                    if progress_callback is not None:
-                        progress_callback(progress)
-                    if progress.stage == ProgressStage.FAILED.value:
-                        raise Exception(progress.error_message)
-                    time.sleep(polling_period_seconds)
-                    progress = self.get_progress(trace_id, validation_task)
-            except KeyboardInterrupt as e:
-                print(self.bundle.get("search_stopping"))
-                self.rest_client.stop_search_task_v2(trace_id, validation_task.search_task_id)
-                self.logger.warning(f"Search {validation_task.search_task_id} stopped by user")
-                print(self.bundle.get("search_stopped"))
-                raise e
+        del df
+        gc.collect()
 
-
+        dataset = Dataset(
+            "sample_" + str(uuid.uuid4()),
+            df=df_without_features,
+            meaning_types=meaning_types,
+            search_keys=combined_search_keys,
+            unnest_search_keys=unnest_search_keys,
+            id_columns=self.__get_renamed_id_columns(columns_renaming),
+            date_column=self._get_date_column(search_keys),
+            date_format=self.date_format,
+            sample_config=self.sample_config,
+            rest_client=self.rest_client,
+            logger=self.logger,
+            bundle=self.bundle,
+            warning_callback=self.__log_warning,
+        )
+        dataset.columns_renaming = columns_renaming
 
-
-
-
-
-
-
+        validation_task = self._search_task.validation(
+            self._get_trace_id(),
+            dataset,
+            start_time=start_time,
+            extract_features=True,
+            runtime_parameters=runtime_parameters,
+            exclude_features_sources=exclude_features_sources,
+            metrics_calculation=metrics_calculation,
+            silent_mode=silent_mode,
+            progress_bar=progress_bar,
+            progress_callback=progress_callback,
+        )
 
-
-
+        del df_without_features, dataset
+        gc.collect()
 
-
-
-
-            combined_df = pd.concat(
-                [
-                    validated_Xy.reset_index(drop=True),
-                    df_before_explode.reset_index(drop=True),
-                ],
-                axis=1,
-            ).set_index(validated_Xy.index)
-
-            result_features = validation_task.get_all_validation_raw_features(trace_id, metrics_calculation)
-
-            result = self.__enrich(
-                combined_df,
-                result_features,
-                how="left",
-            )
+        if not silent_mode:
+            print(self.bundle.get("polling_transform_task").format(validation_task.search_task_id))
+            if not self.__is_registered:
+                print(self.bundle.get("polling_unregister_information"))
 
-
-
-
-
-
-
-
-
-
-
+        progress = self.get_progress(validation_task)
+        progress.recalculate_eta(time.time() - start_time)
+        if progress_bar is not None:
+            progress_bar.progress = progress.to_progress_bar()
+        if progress_callback is not None:
+            progress_callback(progress)
+        prev_progress: SearchProgress | None = None
+        polling_period_seconds = 1
+        try:
+            while progress.stage != ProgressStage.DOWNLOADING.value:
+                if prev_progress is None or prev_progress.percent != progress.percent:
+                    progress.recalculate_eta(time.time() - start_time)
+                else:
+                    progress.update_eta(prev_progress.eta - polling_period_seconds)
+                prev_progress = progress
+                if progress_bar is not None:
+                    progress_bar.progress = progress.to_progress_bar()
+                if progress_callback is not None:
+                    progress_callback(progress)
+                if progress.stage == ProgressStage.FAILED.value:
+                    raise Exception(progress.error_message)
+                time.sleep(polling_period_seconds)
+                progress = self.get_progress(validation_task)
+        except KeyboardInterrupt as e:
+            print(self.bundle.get("search_stopping"))
+            self.rest_client.stop_search_task_v2(self._get_trace_id(), validation_task.search_task_id)
+            self.logger.warning(f"Search {validation_task.search_task_id} stopped by user")
+            print(self.bundle.get("search_stopped"))
+            raise e
+
+        validation_task.poll_result(self._get_trace_id(), quiet=True)
+
+        seconds_left = time.time() - start_time
+        progress = SearchProgress(97.0, ProgressStage.DOWNLOADING, seconds_left)
+        if progress_bar is not None:
+            progress_bar.progress = progress.to_progress_bar()
+        if progress_callback is not None:
+            progress_callback(progress)
 
-
-
-            sorted_selecting_columns = [c for c in validated_Xy.columns if c in selecting_columns]
-            for c in generated_features:
-                if c in selecting_columns and c not in sorted_selecting_columns:
-                    sorted_selecting_columns.append(c)
-            for c in result.columns:
-                if c in selecting_columns and c not in sorted_selecting_columns:
-                    sorted_selecting_columns.append(c)
+        if not silent_mode:
+            print(self.bundle.get("transform_start"))
 
-
+        # Prepare input DataFrame for __enrich by concatenating generated ids and client features
+        df_before_explode = df_before_explode.rename(columns=columns_renaming)
+        generated_features = [columns_renaming.get(c, c) for c in generated_features]
+        combined_df = pd.concat(
+            [
+                validated_Xy.reset_index(drop=True),
+                df_before_explode.reset_index(drop=True),
+            ],
+            axis=1,
+        ).set_index(validated_Xy.index)
+
+        result_features = validation_task.get_all_validation_raw_features(self._get_trace_id(), metrics_calculation)
+
+        result = self.__enrich(
+            combined_df,
+            result_features,
+            how="left",
+        )
+
+        selecting_columns = self._selecting_input_and_generated_columns(
+            validated_Xy, generated_features, keep_input, is_transform=True
+        )
+        selecting_columns.extend(
+            c
+            for c in result.columns
+            if c in self.feature_names_ and c not in selecting_columns and c not in validated_Xy.columns
+        )
+        if add_fit_system_record_id:
+            selecting_columns.append(SORT_ID)
 
-
+        selecting_columns = list(set(selecting_columns))
+        # sorting: first columns from X, then generated features, then enriched features
+        sorted_selecting_columns = [c for c in validated_Xy.columns if c in selecting_columns]
+        for c in generated_features:
+            if c in selecting_columns and c not in sorted_selecting_columns:
+                sorted_selecting_columns.append(c)
+        for c in result.columns:
+            if c in selecting_columns and c not in sorted_selecting_columns:
+                sorted_selecting_columns.append(c)
 
-            if self.country_added:
-                result = result.drop(columns=COUNTRY, errors="ignore")
+        self.logger.info(f"Transform sorted_selecting_columns: {sorted_selecting_columns}")
 
-
-
+        result = result[sorted_selecting_columns]
+
+        if self.country_added:
+            result = result.drop(columns=COUNTRY, errors="ignore")
 
-            for c in result.columns:
-                if result[c].dtype == "category":
-                    result.loc[:, c] = np.where(~result[c].isin(result[c].dtype.categories), np.nan, result[c])
+        if add_fit_system_record_id:
+            result = result.rename(columns={SORT_ID: SYSTEM_RECORD_ID})
 
-
+        for c in result.columns:
+            if result[c].dtype == "category":
+                result.loc[:, c] = np.where(~result[c].isin(result[c].dtype.categories), np.nan, result[c])
+
+        return result, columns_renaming, generated_features, search_keys
 
     def _selecting_input_and_generated_columns(
         self,
         validated_Xy: pd.DataFrame,
         generated_features: list[str],
         keep_input: bool,
-        trace_id: str,
         is_transform: bool = False,
     ):
-        file_meta = self._search_task.get_file_metadata(
+        file_meta = self._search_task.get_file_metadata(self._get_trace_id())
         fit_dropped_features = self.fit_dropped_features or file_meta.droppedColumns or []
         fit_input_columns = [c.originalName for c in file_meta.columns]
         original_dropped_features = [self.fit_columns_renaming.get(c, c) for c in fit_dropped_features]
         new_columns_on_transform = [
             c for c in validated_Xy.columns if c not in fit_input_columns and c not in original_dropped_features
         ]
+        fit_original_search_keys = self._get_fit_search_keys_with_original_names()
 
         selected_generated_features = [c for c in generated_features if c in self.feature_names_]
         if keep_input is True:
@@ -2982,7 +2986,7 @@ if response.status_code == 200:
                 if not self.fit_select_features
                 or c in self.feature_names_
                 or (c in new_columns_on_transform and is_transform)
-                or c in
+                or c in fit_original_search_keys
                 or c in (self.id_columns or [])
                 or c in [EVAL_SET_INDEX, TARGET]  # transform for metrics calculation
                 or c == self.baseline_score_column
@@ -3067,7 +3071,6 @@ if response.status_code == 200:
 
     def __inner_fit(
         self,
-        trace_id: str,
         X: pd.DataFrame | pd.Series | np.ndarray,
         y: pd.DataFrame | pd.Series | np.ndarray | list | None,
         eval_set: list[tuple] | None,
@@ -3149,6 +3152,8 @@ if response.status_code == 200:
         df = self.__handle_index_search_keys(df, self.fit_search_keys)
         self.fit_search_keys = self.__prepare_search_keys(df, self.fit_search_keys, is_demo_dataset)
 
+        df = self._validate_OOT(df, self.fit_search_keys)
+
         maybe_date_column = SearchKey.find_key(self.fit_search_keys, [SearchKey.DATE, SearchKey.DATETIME])
         has_date = maybe_date_column is not None and maybe_date_column in validated_X.columns
 
@@ -3234,6 +3239,7 @@ if response.status_code == 200:
         )
         self.fit_columns_renaming = normalizer.columns_renaming
         if normalizer.removed_datetime_features:
+            self.fit_dropped_features.update(normalizer.removed_datetime_features)
             original_removed_datetime_features = [
                 self.fit_columns_renaming.get(f, f) for f in normalizer.removed_datetime_features
             ]
@@ -3400,6 +3406,7 @@ if response.status_code == 200:
             id_columns=self.__get_renamed_id_columns(),
             is_imbalanced=self.imbalanced,
             dropped_columns=[self.fit_columns_renaming.get(f, f) for f in self.fit_dropped_features],
+            autodetected_search_keys=self.autodetected_search_keys,
             date_column=self._get_date_column(self.fit_search_keys),
             date_format=self.date_format,
             random_state=self.random_state,
@@ -3423,7 +3430,7 @@ if response.status_code == 200:
         ]
 
         self._search_task = dataset.search(
-            trace_id=
+            trace_id=self._get_trace_id(),
            progress_bar=progress_bar,
            start_time=start_time,
            progress_callback=progress_callback,
@@ -3443,7 +3450,7 @@ if response.status_code == 200:
             if not self.__is_registered:
                 print(self.bundle.get("polling_unregister_information"))
 
-        progress = self.get_progress(
+        progress = self.get_progress()
         prev_progress = None
         progress.recalculate_eta(time.time() - start_time)
         if progress_bar is not None:
@@ -3469,16 +3476,16 @@ if response.status_code == 200:
                     )
                     raise RuntimeError(self.bundle.get("search_task_failed_status"))
                 time.sleep(poll_period_seconds)
-                progress = self.get_progress(
+                progress = self.get_progress()
         except KeyboardInterrupt as e:
             print(self.bundle.get("search_stopping"))
-            self.rest_client.stop_search_task_v2(
+            self.rest_client.stop_search_task_v2(self._get_trace_id(), self._search_task.search_task_id)
             self.logger.warning(f"Search {self._search_task.search_task_id} stopped by user")
             self._search_task = None
             print(self.bundle.get("search_stopped"))
             raise e
 
-        self._search_task.poll_result(
+        self._search_task.poll_result(self._get_trace_id(), quiet=True)
 
         seconds_left = time.time() - start_time
         progress = SearchProgress(97.0, ProgressStage.GENERATING_REPORT, seconds_left)
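The loop above polls progress with no arguments and resolves the trace id only when the task must be stopped or its result fetched. A self-contained sketch of the loop's shape, with a stub client standing in for the real REST client and search task (all names here are illustrative, not upgini's API):

import time


class StubClient:
    """Stand-in for the real search task client; not upgini's API."""

    def __init__(self):
        self.calls = 0

    def get_progress(self):
        self.calls += 1
        return min(100, self.calls * 50)

    def stop_search_task(self, trace_id, task_id):
        print(f"stopping {task_id} (trace {trace_id})")


client = StubClient()
try:
    progress = client.get_progress()
    while progress < 100:
        time.sleep(0.01)
        progress = client.get_progress()
    print("search finished")
except KeyboardInterrupt:
    # Mirrors the KeyboardInterrupt branch above: stop the remote task,
    # then re-raise so the caller still sees the interrupt.
    client.stop_search_task("trace-1", "task-1")
    raise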
@@ -3507,10 +3514,9 @@ if response.status_code == 200:
             msg = self.bundle.get("features_not_generated").format(unused_features_for_generation)
             self.__log_warning(msg)
 
-        self.__prepare_feature_importances(
+        self.__prepare_feature_importances(df)
 
         self._select_features_by_psi(
-            trace_id=trace_id,
            X=X,
            y=y,
            eval_set=eval_set,
@@ -3523,7 +3529,7 @@ if response.status_code == 200:
             progress_callback=progress_callback,
         )
 
-        self.__prepare_feature_importances(
+        self.__prepare_feature_importances(df)
 
         self.__show_selected_features()
 
@@ -3558,7 +3564,6 @@ if response.status_code == 200:
                 scoring,
                 estimator,
                 remove_outliers_calc_metrics,
-                trace_id,
                progress_bar,
                progress_callback,
            )
@@ -3653,11 +3658,10 @@ if response.status_code == 200:
         y: pd.Series | None = None,
         eval_set: list[tuple[pd.DataFrame, pd.Series]] | None = None,
         is_transform: bool = False,
-        silent: bool = False,
    ) -> tuple[pd.DataFrame, pd.Series, list[tuple[pd.DataFrame, pd.Series]]] | None:
        validated_X = self._validate_X(X, is_transform)
        validated_y = self._validate_y(validated_X, y, enforce_y=not is_transform)
-        validated_eval_set = self._validate_eval_set(validated_X, eval_set
+        validated_eval_set = self._validate_eval_set(validated_X, eval_set)
        return validated_X, validated_y, validated_eval_set
 
    def _encode_id_columns(
@@ -3783,31 +3787,41 @@ if response.status_code == 200:
         return validated_y
 
     def _validate_eval_set(
-        self,
-
+        self,
+        X: pd.DataFrame,
+        eval_set: list[tuple[pd.DataFrame, pd.Series]] | None,
+    ) -> list[tuple[pd.DataFrame, pd.Series]] | None:
        if eval_set is None:
            return None
        validated_eval_set = []
-
-        has_date = date_col is not None and date_col in X.columns
-        for idx, eval_pair in enumerate(eval_set):
+        for _, eval_pair in enumerate(eval_set):
            validated_pair = self._validate_eval_set_pair(X, eval_pair)
-            if validated_pair[1].isna().all():
-                if not has_date:
-                    msg = self.bundle.get("oot_without_date_not_supported").format(idx + 1)
-                elif self.columns_for_online_api:
-                    msg = self.bundle.get("oot_with_online_sources_not_supported").format(idx + 1)
-                else:
-                    msg = None
-                if msg:
-                    if not silent:
-                        print(msg)
-                    self.logger.warning(msg)
-                    continue
            validated_eval_set.append(validated_pair)
 
        return validated_eval_set
 
+    def _validate_OOT(self, df: pd.DataFrame, search_keys: dict[str, SearchKey]) -> pd.DataFrame:
+        if EVAL_SET_INDEX not in df.columns:
+            return df
+
+        for eval_set_index in df[EVAL_SET_INDEX].unique():
+            if eval_set_index == 0:
+                continue
+            eval_df = df[df[EVAL_SET_INDEX] == eval_set_index]
+            date_col = self._get_date_column(search_keys)
+            has_date = date_col is not None and date_col in eval_df.columns
+            if eval_df[TARGET].isna().all():
+                msg = None
+                if not has_date:
+                    msg = self.bundle.get("oot_without_date_not_supported").format(eval_set_index)
+                elif self.columns_for_online_api:
+                    msg = self.bundle.get("oot_with_online_sources_not_supported").format(eval_set_index)
+                if msg:
+                    print(msg)
+                    self.logger.warning(msg)
+                    df = df[df[EVAL_SET_INDEX] != eval_set_index]
+        return df
+
    def _validate_eval_set_pair(self, X: pd.DataFrame, eval_pair: tuple) -> tuple[pd.DataFrame, pd.Series]:
        if len(eval_pair) != 2:
            raise ValidationError(self.bundle.get("eval_set_invalid_tuple_size").format(len(eval_pair)))
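The OOT check has moved: _validate_eval_set no longer filters, and the new _validate_OOT drops out-of-time eval folds (identified by an all-NaN target) that cannot be used without a date search key. The same filtering logic as a standalone sketch, with simplified messages in place of the resource-bundle strings:

import pandas as pd

EVAL_SET_INDEX, TARGET = "eval_set_index", "target"


def validate_oot(df: pd.DataFrame, date_col: str | None) -> pd.DataFrame:
    if EVAL_SET_INDEX not in df.columns:
        return df
    for idx in df[EVAL_SET_INDEX].unique():
        if idx == 0:  # 0 marks the training fold
            continue
        fold = df[df[EVAL_SET_INDEX] == idx]
        has_date = date_col is not None and date_col in fold.columns
        # A fold whose target is entirely NaN is an out-of-time fold;
        # it is only usable when a date search key is present.
        if fold[TARGET].isna().all() and not has_date:
            print(f"Eval set {idx}: OOT without a date key is not supported, dropping")
            df = df[df[EVAL_SET_INDEX] != idx]
    return df


frame = pd.DataFrame({EVAL_SET_INDEX: [0, 0, 1, 1], TARGET: [1.0, 0.0, None, None]})
print(validate_oot(frame, date_col=None))  # fold 1 is removed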
@@ -4423,47 +4437,6 @@ if response.status_code == 200:
 
         return result_features
 
-    def __get_features_importance_from_server(self, trace_id: str, df: pd.DataFrame):
-        if self._search_task is None:
-            raise NotFittedError(self.bundle.get("transform_unfitted_enricher"))
-        features_meta = self._search_task.get_all_features_metadata_v2()
-        if features_meta is None:
-            raise Exception(self.bundle.get("missing_features_meta"))
-        features_meta = deepcopy(features_meta)
-
-        original_names_dict = {c.name: c.originalName for c in self._search_task.get_file_metadata(trace_id).columns}
-        df = df.rename(columns=original_names_dict)
-
-        features_meta.sort(key=lambda m: (-m.shap_value, m.name))
-
-        importances = {}
-
-        for feature_meta in features_meta:
-            if feature_meta.name in original_names_dict.keys():
-                feature_meta.name = original_names_dict[feature_meta.name]
-
-            is_client_feature = feature_meta.name in df.columns
-
-            if feature_meta.shap_value == 0.0:
-                continue
-
-            # Use only important features
-            if (
-                feature_meta.name == COUNTRY
-                # In select_features mode we select also from etalon features and need to show them
-                or (not self.fit_select_features and is_client_feature)
-            ):
-                continue
-
-            # Temporary workaround for duplicate features metadata
-            if feature_meta.name in importances:
-                self.logger.warning(f"WARNING: Duplicate feature metadata: {feature_meta}")
-                continue
-
-            importances[feature_meta.name] = feature_meta.shap_value
-
-        return importances
-
    def __get_categorical_features(self) -> list[str]:
        features_meta = self._search_task.get_all_features_metadata_v2()
        if features_meta is None:
@@ -4473,7 +4446,6 @@ if response.status_code == 200:
 
     def __prepare_feature_importances(
         self,
-        trace_id: str,
        clients_features_df: pd.DataFrame,
        updated_shaps: dict[str, float] | None = None,
        update_selected_features: bool = True,
@@ -4481,16 +4453,16 @@ if response.status_code == 200:
     ):
         if self._search_task is None:
             raise NotFittedError(self.bundle.get("transform_unfitted_enricher"))
-        selected_features = self._search_task.get_selected_features(
+        selected_features = self._search_task.get_selected_features(self._get_trace_id())
        features_meta = self._search_task.get_all_features_metadata_v2()
        if features_meta is None:
            raise Exception(self.bundle.get("missing_features_meta"))
        features_meta = deepcopy(features_meta)
 
-        file_metadata_columns = self._search_task.get_file_metadata(
+        file_metadata_columns = self._search_task.get_file_metadata(self._get_trace_id()).columns
        file_meta_by_orig_name = {c.originalName: c for c in file_metadata_columns}
        original_names_dict = {c.name: c.originalName for c in file_metadata_columns}
-        features_df = self._search_task.get_all_initial_raw_features(
+        features_df = self._search_task.get_all_initial_raw_features(self._get_trace_id(), metrics_calculation=True)
 
        # To be sure that names with hash suffixes
        clients_features_df = clients_features_df.rename(columns=original_names_dict)
@@ -4581,7 +4553,7 @@ if response.status_code == 200:
             internal_features_info.append(feature_info.to_internal_row(self.bundle))
 
         if update_selected_features:
-            self._search_task.update_selected_features(
+            self._search_task.update_selected_features(self._get_trace_id(), self.feature_names_)
 
        if len(features_info) > 0:
            self.features_info = pd.DataFrame(features_info)
@@ -4779,12 +4751,17 @@ if response.status_code == 200:
         ):
             raise ValidationError(self.bundle.get("empty_search_key").format(column_name))
 
-        if
-
-
-
-
-
+        if is_transform:
+            fit_autodetected_search_keys = self._get_autodetected_search_keys()
+            if fit_autodetected_search_keys is not None:
+                for key in fit_autodetected_search_keys.keys():
+                    if key not in x.columns:
+                        raise ValidationError(
+                            self.bundle.get("autodetected_search_key_not_found").format(key, x.columns)
+                        )
+                valid_search_keys.update(fit_autodetected_search_keys)
+        elif self.autodetect_search_keys:
+            valid_search_keys = self.__detect_missing_search_keys(x, valid_search_keys, is_demo_dataset)
 
        if all(k == SearchKey.CUSTOM_KEY for k in valid_search_keys.values()):
            if self.__is_registered:
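The hunk above changes transform-time behaviour: instead of re-running detection, the search keys autodetected during fit must be present again in the transform frame, and their absence is now a hard validation error. A condensed sketch of that contract (the class, method bodies, and error text below are illustrative, not upgini's exact code):

import pandas as pd


class Enricher:
    """Illustration only; names loosely mirror the diff."""

    def __init__(self):
        self.autodetected_search_keys = {}

    def fit(self, X: pd.DataFrame):
        # Pretend fit-time detection found a DATE key.
        if "date" in X.columns:
            self.autodetected_search_keys["date"] = "DATE"

    def transform(self, X: pd.DataFrame):
        # Transform does not re-detect: fit-time keys must be present.
        missing = [k for k in self.autodetected_search_keys if k not in X.columns]
        if missing:
            raise ValueError(f"Autodetected search keys {missing} not found in {list(X.columns)}")
        return X


enricher = Enricher()
enricher.fit(pd.DataFrame({"date": ["2024-01-01"], "feature": [1]}))
enricher.transform(pd.DataFrame({"date": ["2024-02-01"], "feature": [2]}))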
@@ -4829,7 +4806,6 @@ if response.status_code == 200:
         scoring: Callable | str | None,
         estimator: Any | None,
         remove_outliers_calc_metrics: bool | None,
-        trace_id: str,
        progress_bar: ProgressBar | None = None,
        progress_callback: Callable[[SearchProgress], Any] | None = None,
    ):
@@ -4837,7 +4813,6 @@ if response.status_code == 200:
             scoring=scoring,
             estimator=estimator,
             remove_outliers_calc_metrics=remove_outliers_calc_metrics,
-            trace_id=trace_id,
            internal_call=True,
            progress_bar=progress_bar,
            progress_callback=progress_callback,
@@ -4902,60 +4877,36 @@ if response.status_code == 200:
         df: pd.DataFrame,
         search_keys: dict[str, SearchKey],
         is_demo_dataset: bool,
-        silent_mode=False,
-        is_transform=False,
    ) -> dict[str, SearchKey]:
        sample = df.head(100)
 
-
-            return not is_transform or (
-                search_key in self.fit_search_keys.values() and search_key not in search_keys.values()
-            )
-
-        if (
-            SearchKey.DATE not in search_keys.values()
-            and SearchKey.DATETIME not in search_keys.values()
-            and check_need_detect(SearchKey.DATE)
-            and check_need_detect(SearchKey.DATETIME)
-        ):
+        if SearchKey.DATE not in search_keys.values() and SearchKey.DATETIME not in search_keys.values():
            maybe_keys = DateSearchKeyDetector().get_search_key_columns(sample, search_keys)
            if len(maybe_keys) > 0:
                datetime_key = maybe_keys[0]
                search_keys[datetime_key] = SearchKey.DATETIME
                self.autodetected_search_keys[datetime_key] = SearchKey.DATETIME
                self.logger.info(f"Autodetected search key DATETIME in column {datetime_key}")
-
-                    print(self.bundle.get("datetime_detected").format(datetime_key))
+                print(self.bundle.get("datetime_detected").format(datetime_key))
 
        # if SearchKey.POSTAL_CODE not in search_keys.values() and check_need_detect(SearchKey.POSTAL_CODE):
-
-
-
-
-
-
-
-
-
-
-        if (
-            SearchKey.COUNTRY not in search_keys.values()
-            and self.country_code is None
-            and check_need_detect(SearchKey.COUNTRY)
-        ):
+        maybe_keys = PostalCodeSearchKeyDetector().get_search_key_columns(sample, search_keys)
+        if maybe_keys:
+            new_keys = {key: SearchKey.POSTAL_CODE for key in maybe_keys}
+            search_keys.update(new_keys)
+            self.autodetected_search_keys.update(new_keys)
+            self.logger.info(f"Autodetected search key POSTAL_CODE in column {maybe_keys}")
+            print(self.bundle.get("postal_code_detected").format(maybe_keys))
+
+        if SearchKey.COUNTRY not in search_keys.values() and self.country_code is None:
            maybe_key = CountrySearchKeyDetector().get_search_key_columns(sample, search_keys)
            if maybe_key:
                search_keys[maybe_key[0]] = SearchKey.COUNTRY
                self.autodetected_search_keys[maybe_key[0]] = SearchKey.COUNTRY
                self.logger.info(f"Autodetected search key COUNTRY in column {maybe_key}")
-
-                    print(self.bundle.get("country_detected").format(maybe_key))
+                print(self.bundle.get("country_detected").format(maybe_key))
 
-        if (
-            # SearchKey.EMAIL not in search_keys.values()
-            SearchKey.HEM not in search_keys.values()
-            and check_need_detect(SearchKey.HEM)
-        ):
+        if SearchKey.EMAIL not in search_keys.values() and SearchKey.HEM not in search_keys.values():
            maybe_keys = EmailSearchKeyDetector().get_search_key_columns(sample, search_keys)
            if maybe_keys:
                if self.__is_registered or is_demo_dataset:
@@ -4963,34 +4914,28 @@ if response.status_code == 200:
                     search_keys.update(new_keys)
                     self.autodetected_search_keys.update(new_keys)
                     self.logger.info(f"Autodetected search key EMAIL in column {maybe_keys}")
-
-                        print(self.bundle.get("email_detected").format(maybe_keys))
+                    print(self.bundle.get("email_detected").format(maybe_keys))
                else:
                    self.logger.warning(
                        f"Autodetected search key EMAIL in column {maybe_keys}."
                        " But not used because not registered user"
                    )
-
-                        self.__log_warning(self.bundle.get("email_detected_not_registered").format(maybe_keys))
+                    self.__log_warning(self.bundle.get("email_detected_not_registered").format(maybe_keys))
 
        # if SearchKey.PHONE not in search_keys.values() and check_need_detect(SearchKey.PHONE):
-
-
-        if
-
-
-
-
-
-
-
-
-
-                        "But not used because not registered user"
-                    )
-                    if not silent_mode:
-                        self.__log_warning(self.bundle.get("phone_detected_not_registered"))
+        maybe_keys = PhoneSearchKeyDetector().get_search_key_columns(sample, search_keys)
+        if maybe_keys:
+            if self.__is_registered or is_demo_dataset:
+                new_keys = {key: SearchKey.PHONE for key in maybe_keys}
+                search_keys.update(new_keys)
+                self.autodetected_search_keys.update(new_keys)
+                self.logger.info(f"Autodetected search key PHONE in column {maybe_keys}")
+                print(self.bundle.get("phone_detected").format(maybe_keys))
+            else:
+                self.logger.warning(
+                    f"Autodetected search key PHONE in column {maybe_keys}. " "But not used because not registered user"
+                )
+                self.__log_warning(self.bundle.get("phone_detected_not_registered"))
 
        return search_keys
 
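With silent_mode, is_transform, and the check_need_detect helper gone, each detector now runs unconditionally whenever its key type is still missing, and every hit is both recorded in autodetected_search_keys and announced. A condensed sketch of that flow with a toy detector (the real detector classes live elsewhere in the package; this stub only illustrates the shape):

import pandas as pd


class PostalCodeDetector:
    """Toy detector; stands in for the real detector classes."""

    key = "POSTAL_CODE"

    def get_search_key_columns(self, sample, search_keys):
        return [c for c in sample.columns if "zip" in c.lower()]


def detect_missing_search_keys(sample, search_keys, autodetected):
    for detector in (PostalCodeDetector(),):
        if detector.key in search_keys.values():
            continue  # key type already supplied by the user
        for column in detector.get_search_key_columns(sample, search_keys):
            search_keys[column] = detector.key
            autodetected[column] = detector.key
            print(f"Autodetected search key {detector.key} in column {column}")
    return search_keys


sample = pd.DataFrame({"zip_code": ["10001"], "amount": [9.5]})
print(detect_missing_search_keys(sample, {}, {}))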
@@ -5062,13 +5007,12 @@ if response.status_code == 200:
 
     def dump_input(
         self,
-        trace_id: str,
        X: pd.DataFrame | pd.Series,
        y: pd.DataFrame | pd.Series | None = None,
        eval_set: tuple | None = None,
    ):
-        def dump_task(X_, y_, eval_set_):
-            with MDC(correlation_id=
+        def dump_task(X_, y_, eval_set_, trace_id_):
+            with MDC(correlation_id=trace_id_):
                try:
                    if isinstance(X_, pd.Series):
                        X_ = X_.to_frame()
@@ -5076,27 +5020,25 @@ if response.status_code == 200:
                 with tempfile.TemporaryDirectory() as tmp_dir:
                     X_.to_parquet(f"{tmp_dir}/x.parquet", compression="zstd")
                     x_digest_sha256 = file_hash(f"{tmp_dir}/x.parquet")
-                    if self.rest_client.is_file_uploaded(
+                    if self.rest_client.is_file_uploaded(trace_id_, x_digest_sha256):
                        self.logger.info(
                            f"File x.parquet was already uploaded with digest {x_digest_sha256}, skipping"
                        )
                    else:
-                        self.rest_client.dump_input_file(
-                            trace_id, f"{tmp_dir}/x.parquet", "x.parquet", x_digest_sha256
-                        )
+                        self.rest_client.dump_input_file(f"{tmp_dir}/x.parquet", "x.parquet", x_digest_sha256)
 
                    if y_ is not None:
                        if isinstance(y_, pd.Series):
                            y_ = y_.to_frame()
                        y_.to_parquet(f"{tmp_dir}/y.parquet", compression="zstd")
                        y_digest_sha256 = file_hash(f"{tmp_dir}/y.parquet")
-                        if self.rest_client.is_file_uploaded(
+                        if self.rest_client.is_file_uploaded(trace_id_, y_digest_sha256):
                            self.logger.info(
                                f"File y.parquet was already uploaded with digest {y_digest_sha256}, skipping"
                            )
                        else:
                            self.rest_client.dump_input_file(
-
+                                trace_id_, f"{tmp_dir}/y.parquet", "y.parquet", y_digest_sha256
                            )
 
                    if eval_set_ is not None and len(eval_set_) > 0:
@@ -5105,14 +5047,14 @@ if response.status_code == 200:
                             eval_x_ = eval_x_.to_frame()
                            eval_x_.to_parquet(f"{tmp_dir}/eval_x_{idx}.parquet", compression="zstd")
                            eval_x_digest_sha256 = file_hash(f"{tmp_dir}/eval_x_{idx}.parquet")
-                            if self.rest_client.is_file_uploaded(
+                            if self.rest_client.is_file_uploaded(trace_id_, eval_x_digest_sha256):
                                self.logger.info(
                                    f"File eval_x_{idx}.parquet was already uploaded with"
                                    f" digest {eval_x_digest_sha256}, skipping"
                                )
                            else:
                                self.rest_client.dump_input_file(
-
+                                    trace_id_,
                                    f"{tmp_dir}/eval_x_{idx}.parquet",
                                    f"eval_x_{idx}.parquet",
                                    eval_x_digest_sha256,
@@ -5122,14 +5064,14 @@ if response.status_code == 200:
                             eval_y_ = eval_y_.to_frame()
                            eval_y_.to_parquet(f"{tmp_dir}/eval_y_{idx}.parquet", compression="zstd")
                            eval_y_digest_sha256 = file_hash(f"{tmp_dir}/eval_y_{idx}.parquet")
-                            if self.rest_client.is_file_uploaded(
+                            if self.rest_client.is_file_uploaded(trace_id_, eval_y_digest_sha256):
                                self.logger.info(
                                    f"File eval_y_{idx}.parquet was already uploaded"
                                    f" with digest {eval_y_digest_sha256}, skipping"
                                )
                            else:
                                self.rest_client.dump_input_file(
-
+                                    trace_id_,
                                    f"{tmp_dir}/eval_y_{idx}.parquet",
                                    f"eval_y_{idx}.parquet",
                                    eval_y_digest_sha256,
@@ -5138,7 +5080,8 @@ if response.status_code == 200:
                 self.logger.warning("Failed to dump input files", exc_info=True)
 
         try:
-
+            trace_id = self._get_trace_id()
+            Thread(target=dump_task, args=(X, y, eval_set, trace_id), daemon=True).start()
        except Exception:
            self.logger.warning("Failed to dump input files", exc_info=True)
 
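dump_task now receives trace_id_ as an explicit argument because correlation context of the MDC kind is typically held in thread-local storage, which does not follow a newly spawned Thread. A small demonstration of why the id must be resolved in the caller and passed in (the helpers below are stand-ins, not upgini code):

import threading

_local = threading.local()


def get_trace_id():
    return getattr(_local, "trace_id", "<missing>")


def dump_task(trace_id_):
    # Inside the worker the thread-local is empty; only the argument survives.
    print(f"thread-local: {get_trace_id()}, argument: {trace_id_}")


_local.trace_id = "abc-123"
worker = threading.Thread(target=dump_task, args=(get_trace_id(),), daemon=True)
worker.start()
worker.join()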