upgini 1.1.244a7__tar.gz → 1.1.244a8__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of upgini might be problematic. Click here for more details.
- {upgini-1.1.244a7/src/upgini.egg-info → upgini-1.1.244a8}/PKG-INFO +1 -1
- {upgini-1.1.244a7 → upgini-1.1.244a8}/setup.py +1 -1
- {upgini-1.1.244a7 → upgini-1.1.244a8}/src/upgini/dataset.py +1 -1
- {upgini-1.1.244a7 → upgini-1.1.244a8}/src/upgini/features_enricher.py +3 -0
- {upgini-1.1.244a7 → upgini-1.1.244a8}/src/upgini/metrics.py +68 -15
- {upgini-1.1.244a7 → upgini-1.1.244a8}/src/upgini/utils/sklearn_ext.py +18 -15
- {upgini-1.1.244a7 → upgini-1.1.244a8/src/upgini.egg-info}/PKG-INFO +1 -1
- {upgini-1.1.244a7 → upgini-1.1.244a8}/tests/test_continuous_dataset.py +6 -3
- {upgini-1.1.244a7 → upgini-1.1.244a8}/tests/test_etalon_validation.py +4 -3
- {upgini-1.1.244a7 → upgini-1.1.244a8}/tests/test_features_enricher.py +0 -1
- {upgini-1.1.244a7 → upgini-1.1.244a8}/tests/test_metrics.py +9 -1
- {upgini-1.1.244a7 → upgini-1.1.244a8}/LICENSE +0 -0
- {upgini-1.1.244a7 → upgini-1.1.244a8}/README.md +0 -0
- {upgini-1.1.244a7 → upgini-1.1.244a8}/pyproject.toml +0 -0
- {upgini-1.1.244a7 → upgini-1.1.244a8}/setup.cfg +0 -0
- {upgini-1.1.244a7 → upgini-1.1.244a8}/src/upgini/__init__.py +0 -0
- {upgini-1.1.244a7 → upgini-1.1.244a8}/src/upgini/ads.py +0 -0
- {upgini-1.1.244a7 → upgini-1.1.244a8}/src/upgini/ads_management/__init__.py +0 -0
- {upgini-1.1.244a7 → upgini-1.1.244a8}/src/upgini/ads_management/ads_manager.py +0 -0
- {upgini-1.1.244a7 → upgini-1.1.244a8}/src/upgini/autofe/__init__.py +0 -0
- {upgini-1.1.244a7 → upgini-1.1.244a8}/src/upgini/autofe/all_operands.py +0 -0
- {upgini-1.1.244a7 → upgini-1.1.244a8}/src/upgini/autofe/binary.py +0 -0
- {upgini-1.1.244a7 → upgini-1.1.244a8}/src/upgini/autofe/feature.py +0 -0
- {upgini-1.1.244a7 → upgini-1.1.244a8}/src/upgini/autofe/groupby.py +0 -0
- {upgini-1.1.244a7 → upgini-1.1.244a8}/src/upgini/autofe/operand.py +0 -0
- {upgini-1.1.244a7 → upgini-1.1.244a8}/src/upgini/autofe/unary.py +0 -0
- {upgini-1.1.244a7 → upgini-1.1.244a8}/src/upgini/autofe/vector.py +0 -0
- {upgini-1.1.244a7 → upgini-1.1.244a8}/src/upgini/data_source/__init__.py +0 -0
- {upgini-1.1.244a7 → upgini-1.1.244a8}/src/upgini/data_source/data_source_publisher.py +0 -0
- {upgini-1.1.244a7 → upgini-1.1.244a8}/src/upgini/errors.py +0 -0
- {upgini-1.1.244a7 → upgini-1.1.244a8}/src/upgini/fingerprint.js +0 -0
- {upgini-1.1.244a7 → upgini-1.1.244a8}/src/upgini/http.py +0 -0
- {upgini-1.1.244a7 → upgini-1.1.244a8}/src/upgini/mdc/__init__.py +0 -0
- {upgini-1.1.244a7 → upgini-1.1.244a8}/src/upgini/mdc/context.py +0 -0
- {upgini-1.1.244a7 → upgini-1.1.244a8}/src/upgini/metadata.py +0 -0
- {upgini-1.1.244a7 → upgini-1.1.244a8}/src/upgini/normalizer/__init__.py +0 -0
- {upgini-1.1.244a7 → upgini-1.1.244a8}/src/upgini/normalizer/phone_normalizer.py +0 -0
- {upgini-1.1.244a7 → upgini-1.1.244a8}/src/upgini/resource_bundle/__init__.py +0 -0
- {upgini-1.1.244a7 → upgini-1.1.244a8}/src/upgini/resource_bundle/exceptions.py +0 -0
- {upgini-1.1.244a7 → upgini-1.1.244a8}/src/upgini/resource_bundle/strings.properties +0 -0
- {upgini-1.1.244a7 → upgini-1.1.244a8}/src/upgini/sampler/__init__.py +0 -0
- {upgini-1.1.244a7 → upgini-1.1.244a8}/src/upgini/sampler/base.py +0 -0
- {upgini-1.1.244a7 → upgini-1.1.244a8}/src/upgini/sampler/random_under_sampler.py +0 -0
- {upgini-1.1.244a7 → upgini-1.1.244a8}/src/upgini/sampler/utils.py +0 -0
- {upgini-1.1.244a7 → upgini-1.1.244a8}/src/upgini/search_task.py +0 -0
- {upgini-1.1.244a7 → upgini-1.1.244a8}/src/upgini/spinner.py +0 -0
- {upgini-1.1.244a7 → upgini-1.1.244a8}/src/upgini/utils/__init__.py +0 -0
- {upgini-1.1.244a7 → upgini-1.1.244a8}/src/upgini/utils/base_search_key_detector.py +0 -0
- {upgini-1.1.244a7 → upgini-1.1.244a8}/src/upgini/utils/blocked_time_series.py +0 -0
- {upgini-1.1.244a7 → upgini-1.1.244a8}/src/upgini/utils/country_utils.py +0 -0
- {upgini-1.1.244a7 → upgini-1.1.244a8}/src/upgini/utils/custom_loss_utils.py +0 -0
- {upgini-1.1.244a7 → upgini-1.1.244a8}/src/upgini/utils/cv_utils.py +0 -0
- {upgini-1.1.244a7 → upgini-1.1.244a8}/src/upgini/utils/datetime_utils.py +0 -0
- {upgini-1.1.244a7 → upgini-1.1.244a8}/src/upgini/utils/deduplicate_utils.py +0 -0
- {upgini-1.1.244a7 → upgini-1.1.244a8}/src/upgini/utils/display_utils.py +0 -0
- {upgini-1.1.244a7 → upgini-1.1.244a8}/src/upgini/utils/email_utils.py +0 -0
- {upgini-1.1.244a7 → upgini-1.1.244a8}/src/upgini/utils/fallback_progress_bar.py +0 -0
- {upgini-1.1.244a7 → upgini-1.1.244a8}/src/upgini/utils/features_validator.py +0 -0
- {upgini-1.1.244a7 → upgini-1.1.244a8}/src/upgini/utils/format.py +0 -0
- {upgini-1.1.244a7 → upgini-1.1.244a8}/src/upgini/utils/ip_utils.py +0 -0
- {upgini-1.1.244a7 → upgini-1.1.244a8}/src/upgini/utils/phone_utils.py +0 -0
- {upgini-1.1.244a7 → upgini-1.1.244a8}/src/upgini/utils/postal_code_utils.py +0 -0
- {upgini-1.1.244a7 → upgini-1.1.244a8}/src/upgini/utils/progress_bar.py +0 -0
- {upgini-1.1.244a7 → upgini-1.1.244a8}/src/upgini/utils/target_utils.py +0 -0
- {upgini-1.1.244a7 → upgini-1.1.244a8}/src/upgini/utils/track_info.py +0 -0
- {upgini-1.1.244a7 → upgini-1.1.244a8}/src/upgini/utils/warning_counter.py +0 -0
- {upgini-1.1.244a7 → upgini-1.1.244a8}/src/upgini/version_validator.py +0 -0
- {upgini-1.1.244a7 → upgini-1.1.244a8}/src/upgini.egg-info/SOURCES.txt +0 -0
- {upgini-1.1.244a7 → upgini-1.1.244a8}/src/upgini.egg-info/dependency_links.txt +0 -0
- {upgini-1.1.244a7 → upgini-1.1.244a8}/src/upgini.egg-info/requires.txt +0 -0
- {upgini-1.1.244a7 → upgini-1.1.244a8}/src/upgini.egg-info/top_level.txt +0 -0
- {upgini-1.1.244a7 → upgini-1.1.244a8}/tests/test_binary_dataset.py +0 -0
- {upgini-1.1.244a7 → upgini-1.1.244a8}/tests/test_blocked_time_series.py +0 -0
- {upgini-1.1.244a7 → upgini-1.1.244a8}/tests/test_categorical_dataset.py +0 -0
- {upgini-1.1.244a7 → upgini-1.1.244a8}/tests/test_country_utils.py +0 -0
- {upgini-1.1.244a7 → upgini-1.1.244a8}/tests/test_custom_loss_utils.py +0 -0
- {upgini-1.1.244a7 → upgini-1.1.244a8}/tests/test_datetime_utils.py +0 -0
- {upgini-1.1.244a7 → upgini-1.1.244a8}/tests/test_email_utils.py +0 -0
- {upgini-1.1.244a7 → upgini-1.1.244a8}/tests/test_phone_utils.py +0 -0
- {upgini-1.1.244a7 → upgini-1.1.244a8}/tests/test_postal_code_utils.py +0 -0
- {upgini-1.1.244a7 → upgini-1.1.244a8}/tests/test_widget.py +0 -0
|
@@ -61,7 +61,7 @@ class Dataset: # (pd.DataFrame):
|
|
|
61
61
|
FIT_SAMPLE_THRESHOLD = 200_000
|
|
62
62
|
FIT_SAMPLE_WITH_EVAL_SET_ROWS = 200_000
|
|
63
63
|
FIT_SAMPLE_WITH_EVAL_SET_THRESHOLD = 200_000
|
|
64
|
-
MIN_SAMPLE_THRESHOLD =
|
|
64
|
+
MIN_SAMPLE_THRESHOLD = 5_000
|
|
65
65
|
IMBALANCE_THESHOLD = 0.4
|
|
66
66
|
MIN_TARGET_CLASS_ROWS = 100
|
|
67
67
|
MAX_MULTICLASS_CLASS_COUNT = 100
|
|
@@ -955,6 +955,7 @@ class FeaturesEnricher(TransformerMixin):
|
|
|
955
955
|
fitting_enriched_X,
|
|
956
956
|
scoring,
|
|
957
957
|
groups=groups,
|
|
958
|
+
text_features=self.generate_features,
|
|
958
959
|
)
|
|
959
960
|
metric = wrapper.metric_name
|
|
960
961
|
multiplier = wrapper.multiplier
|
|
@@ -980,6 +981,7 @@ class FeaturesEnricher(TransformerMixin):
|
|
|
980
981
|
cat_features,
|
|
981
982
|
add_params=custom_loss_add_params,
|
|
982
983
|
groups=groups,
|
|
984
|
+
text_features=self.generate_features,
|
|
983
985
|
)
|
|
984
986
|
etalon_metric = baseline_estimator.cross_val_predict(
|
|
985
987
|
fitting_X, y_sorted, self.baseline_score_column
|
|
@@ -1004,6 +1006,7 @@ class FeaturesEnricher(TransformerMixin):
|
|
|
1004
1006
|
cat_features,
|
|
1005
1007
|
add_params=custom_loss_add_params,
|
|
1006
1008
|
groups=groups,
|
|
1009
|
+
text_features=self.generate_features,
|
|
1007
1010
|
)
|
|
1008
1011
|
enriched_metric = enriched_estimator.cross_val_predict(fitting_enriched_X, enriched_y_sorted)
|
|
1009
1012
|
self.logger.info(f"Enriched {metric} on train combined features: {enriched_metric}")
|
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
import logging
|
|
2
|
+
import re
|
|
2
3
|
from copy import deepcopy
|
|
3
4
|
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
|
4
5
|
|
|
@@ -201,6 +202,7 @@ class EstimatorWrapper:
|
|
|
201
202
|
target_type: ModelTaskType,
|
|
202
203
|
add_params: Optional[Dict[str, Any]] = None,
|
|
203
204
|
groups: Optional[np.ndarray] = None,
|
|
205
|
+
text_features: Optional[List[str]] = None,
|
|
204
206
|
):
|
|
205
207
|
self.estimator = estimator
|
|
206
208
|
self.scorer = scorer
|
|
@@ -213,6 +215,7 @@ class EstimatorWrapper:
|
|
|
213
215
|
self.add_params = add_params
|
|
214
216
|
self.cv_estimators = None
|
|
215
217
|
self.groups = groups
|
|
218
|
+
self.text_features = text_features
|
|
216
219
|
|
|
217
220
|
def fit(self, X: pd.DataFrame, y: np.ndarray, **kwargs):
|
|
218
221
|
X, y, _, fit_params = self._prepare_to_fit(X, y)
|
|
@@ -285,6 +288,7 @@ class EstimatorWrapper:
|
|
|
285
288
|
groups=groups,
|
|
286
289
|
fit_params=fit_params,
|
|
287
290
|
return_estimator=True,
|
|
291
|
+
error_score="raise",
|
|
288
292
|
)
|
|
289
293
|
metrics_by_fold = cv_results["test_score"]
|
|
290
294
|
self.cv_estimators = cv_results["estimator"]
|
|
@@ -330,6 +334,7 @@ class EstimatorWrapper:
|
|
|
330
334
|
"cv": cv,
|
|
331
335
|
"target_type": target_type,
|
|
332
336
|
"groups": groups,
|
|
337
|
+
"text_features": text_features,
|
|
333
338
|
}
|
|
334
339
|
if estimator is None:
|
|
335
340
|
params = dict()
|
|
@@ -391,27 +396,56 @@ class CatBoostWrapper(EstimatorWrapper):
|
|
|
391
396
|
cv: BaseCrossValidator,
|
|
392
397
|
target_type: ModelTaskType,
|
|
393
398
|
groups: Optional[List[str]] = None,
|
|
399
|
+
text_features: Optional[List[str]] = None,
|
|
394
400
|
):
|
|
395
401
|
super(CatBoostWrapper, self).__init__(
|
|
396
|
-
estimator, scorer, metric_name, multiplier, cv, target_type, groups=groups
|
|
402
|
+
estimator, scorer, metric_name, multiplier, cv, target_type, groups=groups, text_features=text_features
|
|
397
403
|
)
|
|
398
404
|
self.cat_features = None
|
|
399
405
|
self.cat_features_idx = None
|
|
406
|
+
self.emb_groups = None
|
|
400
407
|
|
|
401
408
|
def _prepare_to_fit(self, X: pd.DataFrame, y: pd.Series) -> Tuple[pd.DataFrame, np.ndarray, np.ndarray, dict]:
|
|
402
409
|
X, y, groups, params = super()._prepare_to_fit(X, y)
|
|
410
|
+
|
|
411
|
+
# Find embeddings
|
|
412
|
+
emb_pattern = r"(.+)_emb\d+"
|
|
413
|
+
emb_features = [c for c in X.columns if re.match(emb_pattern, c) and is_numeric_dtype(X[c])]
|
|
414
|
+
embedding_features = []
|
|
415
|
+
if len(emb_features) > 0:
|
|
416
|
+
# group by source feature
|
|
417
|
+
self.emb_groups = dict()
|
|
418
|
+
for emb in emb_features:
|
|
419
|
+
source_feature = re.match(emb_pattern, emb).group(1)
|
|
420
|
+
embs = self.emb_groups.get(source_feature, [])
|
|
421
|
+
embs.append(emb)
|
|
422
|
+
self.emb_groups[source_feature] = embs
|
|
423
|
+
self.emb_groups = {
|
|
424
|
+
source_feature: embs for source_feature, embs in self.emb_groups.items() if len(embs) > 1
|
|
425
|
+
}
|
|
426
|
+
X, embedding_features = self.group_embeddings(X)
|
|
427
|
+
params["embedding_features"] = embedding_features
|
|
428
|
+
|
|
429
|
+
# Find text features from passed in generate_features
|
|
430
|
+
if self.text_features is not None:
|
|
431
|
+
self.text_features = [f for f in self.text_features if not is_numeric_dtype(X[f])]
|
|
432
|
+
params["text_features"] = self.text_features
|
|
433
|
+
|
|
434
|
+
# Find rest categorical features
|
|
403
435
|
self.cat_features = _get_cat_features(X)
|
|
436
|
+
if self.text_features is not None:
|
|
437
|
+
self.cat_features = [
|
|
438
|
+
f for f in self.cat_features if f not in self.text_features and f not in embedding_features
|
|
439
|
+
]
|
|
404
440
|
X = fill_na_cat_features(X, self.cat_features)
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
# cat_features_idx = [X.columns.get_loc(c) for c in unique_cat_features]
|
|
414
|
-
self.cat_features_idx = [X.columns.get_loc(c) for c in self.cat_features]
|
|
441
|
+
unique_cat_features = []
|
|
442
|
+
for name in self.cat_features:
|
|
443
|
+
# Remove constant categorical features
|
|
444
|
+
if X[name].nunique() > 1:
|
|
445
|
+
unique_cat_features.append(name)
|
|
446
|
+
else:
|
|
447
|
+
X = X.drop(columns=name)
|
|
448
|
+
self.cat_features_idx = [X.columns.get_loc(c) for c in unique_cat_features]
|
|
415
449
|
if (
|
|
416
450
|
hasattr(self.estimator, "get_param")
|
|
417
451
|
and hasattr(self.estimator, "_init_params")
|
|
@@ -422,15 +456,32 @@ class CatBoostWrapper(EstimatorWrapper):
|
|
|
422
456
|
self.cat_features_idx = list(cat_features_set)
|
|
423
457
|
del self.estimator._init_params["cat_features"]
|
|
424
458
|
|
|
425
|
-
params
|
|
459
|
+
params["cat_features"] = self.cat_features_idx
|
|
460
|
+
|
|
426
461
|
return X, y, groups, params
|
|
427
462
|
|
|
463
|
+
def group_embeddings(self, df: pd.DataFrame):
|
|
464
|
+
emb_columns = []
|
|
465
|
+
for source_feature, embs in self.emb_groups.items():
|
|
466
|
+
emb_name = f"{source_feature}_emb"
|
|
467
|
+
df[embs] = df[embs].fillna(0.0)
|
|
468
|
+
df[emb_name] = df[embs].values.tolist()
|
|
469
|
+
df = df.drop(columns=embs)
|
|
470
|
+
emb_columns.append(emb_name)
|
|
471
|
+
return df, emb_columns
|
|
472
|
+
|
|
428
473
|
def _prepare_to_calculate(self, X: pd.DataFrame, y: pd.Series) -> Tuple[pd.DataFrame, np.ndarray, dict]:
|
|
429
474
|
X, y, params = super()._prepare_to_calculate(X, y)
|
|
475
|
+
if self.text_features is not None:
|
|
476
|
+
params["text_features"] = self.text_features
|
|
477
|
+
if self.emb_groups is not None:
|
|
478
|
+
X, emb_columns = self.group_embeddings(X)
|
|
479
|
+
params["embedding_features"] = emb_columns
|
|
430
480
|
if self.cat_features is not None:
|
|
431
481
|
X = fill_na_cat_features(X, self.cat_features)
|
|
432
482
|
if self.cat_features_idx is not None:
|
|
433
|
-
params
|
|
483
|
+
params["cat_features"] = self.cat_features_idx
|
|
484
|
+
|
|
434
485
|
return X, y, params
|
|
435
486
|
|
|
436
487
|
|
|
@@ -444,9 +495,10 @@ class LightGBMWrapper(EstimatorWrapper):
|
|
|
444
495
|
cv: BaseCrossValidator,
|
|
445
496
|
target_type: ModelTaskType,
|
|
446
497
|
groups: Optional[List[str]] = None,
|
|
498
|
+
text_features: Optional[List[str]] = None,
|
|
447
499
|
):
|
|
448
500
|
super(LightGBMWrapper, self).__init__(
|
|
449
|
-
estimator, scorer, metric_name, multiplier, cv, target_type, groups=groups
|
|
501
|
+
estimator, scorer, metric_name, multiplier, cv, target_type, groups=groups, text_features=text_features
|
|
450
502
|
)
|
|
451
503
|
self.cat_features = None
|
|
452
504
|
|
|
@@ -482,9 +534,10 @@ class OtherEstimatorWrapper(EstimatorWrapper):
|
|
|
482
534
|
cv: BaseCrossValidator,
|
|
483
535
|
target_type: ModelTaskType,
|
|
484
536
|
groups: Optional[List[str]] = None,
|
|
537
|
+
text_features: Optional[List[str]] = None,
|
|
485
538
|
):
|
|
486
539
|
super(OtherEstimatorWrapper, self).__init__(
|
|
487
|
-
estimator, scorer, metric_name, multiplier, cv, target_type, groups=groups
|
|
540
|
+
estimator, scorer, metric_name, multiplier, cv, target_type, groups=groups, text_features=text_features
|
|
488
541
|
)
|
|
489
542
|
self.cat_features = None
|
|
490
543
|
|
|
@@ -21,6 +21,7 @@ from sklearn.metrics._scorer import _MultimetricScorer
|
|
|
21
21
|
from sklearn.model_selection import check_cv
|
|
22
22
|
from sklearn.utils.fixes import np_version, parse_version
|
|
23
23
|
from sklearn.utils.validation import indexable
|
|
24
|
+
from sklearn.model_selection import cross_validate as original_cross_validate
|
|
24
25
|
|
|
25
26
|
_DEFAULT_TAGS = {
|
|
26
27
|
"non_deterministic": False,
|
|
@@ -313,21 +314,23 @@ def cross_validate(
|
|
|
313
314
|
return ret
|
|
314
315
|
except Exception:
|
|
315
316
|
logging.exception("Failed to execute overriden cross_validate. Fallback to original")
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
317
|
+
raise
|
|
318
|
+
# fit_params["use_best_model"] = False
|
|
319
|
+
# return original_cross_validate(
|
|
320
|
+
# estimator,
|
|
321
|
+
# X,
|
|
322
|
+
# y,
|
|
323
|
+
# groups=groups,
|
|
324
|
+
# scoring=scoring,
|
|
325
|
+
# cv=cv,
|
|
326
|
+
# n_jobs=n_jobs,
|
|
327
|
+
# verbose=verbose,
|
|
328
|
+
# fit_params=fit_params,
|
|
329
|
+
# pre_dispatch=pre_dispatch,
|
|
330
|
+
# return_train_score=return_train_score,
|
|
331
|
+
# return_estimator=return_estimator,
|
|
332
|
+
# error_score=error_score,
|
|
333
|
+
# )
|
|
331
334
|
|
|
332
335
|
|
|
333
336
|
def _fit_and_score(
|
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
import os
|
|
2
|
+
from typing import Dict
|
|
2
3
|
|
|
3
4
|
import pandas as pd
|
|
4
5
|
import pytest
|
|
@@ -13,7 +14,7 @@ FIXTURE_DIR = os.path.join(
|
|
|
13
14
|
|
|
14
15
|
|
|
15
16
|
@pytest.fixture
|
|
16
|
-
def etalon_definition():
|
|
17
|
+
def etalon_definition() -> Dict[str, FileColumnMeaningType]:
|
|
17
18
|
return {
|
|
18
19
|
"phone_num": FileColumnMeaningType.MSISDN,
|
|
19
20
|
"rep_date": FileColumnMeaningType.DATE,
|
|
@@ -27,9 +28,11 @@ def etalon_search_keys():
|
|
|
27
28
|
|
|
28
29
|
|
|
29
30
|
@pytest.mark.datafiles(os.path.join(FIXTURE_DIR, "data.csv.gz"))
|
|
30
|
-
def test_continuous_dataset(datafiles, etalon_definition, etalon_search_keys):
|
|
31
|
+
def test_continuous_dataset(datafiles, etalon_definition: Dict[str, FileColumnMeaningType], etalon_search_keys):
|
|
31
32
|
df = pd.read_csv(datafiles / "data.csv.gz")
|
|
32
|
-
df = df.reset_index().rename(columns={"index": "system_record_id"})
|
|
33
|
+
df = df.reset_index().rename(columns={"index": "system_record_id", "score": "target"})
|
|
34
|
+
del etalon_definition["score"]
|
|
35
|
+
etalon_definition["target"] = FileColumnMeaningType.TARGET
|
|
33
36
|
converter = DateTimeSearchKeyConverter("rep_date")
|
|
34
37
|
df = converter.convert(df)
|
|
35
38
|
ds = Dataset(
|
|
@@ -25,7 +25,7 @@ def test_etalon_validation(etalon: Dataset):
|
|
|
25
25
|
valid_count = len(etalon)
|
|
26
26
|
valid_rate = 100 * valid_count / count
|
|
27
27
|
|
|
28
|
-
assert valid_count ==
|
|
28
|
+
assert valid_count == 5
|
|
29
29
|
valid_rate_expected = 100 * (valid_count / 10)
|
|
30
30
|
assert valid_rate == pytest.approx(valid_rate_expected, abs=0.01)
|
|
31
31
|
|
|
@@ -497,8 +497,9 @@ def test_time_cutoff_from_period():
|
|
|
497
497
|
def test_time_cutoff_from_timestamp():
|
|
498
498
|
df = pd.DataFrame({"date": [1577836800000000000, 1577840400000000000, 1577844000000000000]})
|
|
499
499
|
converter = DateTimeSearchKeyConverter("date")
|
|
500
|
-
with pytest.raises(Exception, match="Unsupported type of date column date.*"):
|
|
501
|
-
|
|
500
|
+
# with pytest.raises(Exception, match="Unsupported type of date column date.*"):
|
|
501
|
+
df = converter.convert(df)
|
|
502
|
+
assert len(df) == 3
|
|
502
503
|
|
|
503
504
|
|
|
504
505
|
def test_time_cutoff_with_different_timezones():
|
|
@@ -2161,7 +2161,6 @@ def test_idempotent_order_with_imbalanced_dataset(requests_mock: Mocker):
|
|
|
2161
2161
|
pass
|
|
2162
2162
|
|
|
2163
2163
|
actual_result_df = result_wrapper.df.sort_values(by="system_record_id").reset_index(drop=True)
|
|
2164
|
-
|
|
2165
2164
|
assert_frame_equal(actual_result_df, expected_result_df)
|
|
2166
2165
|
|
|
2167
2166
|
for i in range(5):
|
|
@@ -147,6 +147,7 @@ def test_real_case_metric_binary(requests_mock: Mocker):
|
|
|
147
147
|
mock_raw_features(requests_mock, url, search_task_id, path_to_mock_features)
|
|
148
148
|
|
|
149
149
|
train = pd.read_parquet(os.path.join(BASE_DIR, "real_train.parquet"))
|
|
150
|
+
train.sort_index()
|
|
150
151
|
X = train[["request_date", "score"]]
|
|
151
152
|
y = train["target1"].rename("target")
|
|
152
153
|
test = pd.read_parquet(os.path.join(BASE_DIR, "real_test.parquet"))
|
|
@@ -168,18 +169,25 @@ def test_real_case_metric_binary(requests_mock: Mocker):
|
|
|
168
169
|
enricher.eval_set = eval_set
|
|
169
170
|
|
|
170
171
|
enriched_X = pd.read_parquet(os.path.join(BASE_DIR, "real_enriched_x.parquet"))
|
|
172
|
+
|
|
173
|
+
# TODO join enriched_X and X and y by index and then sort and add system_record_id
|
|
174
|
+
|
|
175
|
+
enriched_X = enriched_X.sort_values(by="request_date").reset_index().rename(columns={"index": "system_record_id"})
|
|
171
176
|
enriched_eval_x = pd.read_parquet(os.path.join(BASE_DIR, "real_enriched_eval_x.parquet"))
|
|
177
|
+
enriched_eval_x = enriched_eval_x.sort_values(by="request_date").reset_index().rename(columns={"index": "system_record_id"})
|
|
172
178
|
|
|
173
179
|
sampled_Xy = X.copy()
|
|
174
180
|
sampled_Xy["target"] = y
|
|
175
181
|
sampled_Xy = sampled_Xy[sampled_Xy.index.isin(enriched_X.index)]
|
|
176
182
|
sampled_X = sampled_Xy.drop(columns="target")
|
|
183
|
+
sampled_X = sampled_X.reset_index().rename(columns={"index": "system_record_id"})
|
|
177
184
|
sampled_y = sampled_Xy["target"]
|
|
185
|
+
sampled_eval_x = eval_set[0][0].reset_index().rename(columns={"index": "system_record_id"})
|
|
178
186
|
enricher._FeaturesEnricher__cached_sampled_datasets = (
|
|
179
187
|
sampled_X,
|
|
180
188
|
sampled_y,
|
|
181
189
|
enriched_X,
|
|
182
|
-
{0: (
|
|
190
|
+
{0: (sampled_eval_x, enriched_eval_x, eval_set[0][1])},
|
|
183
191
|
search_keys,
|
|
184
192
|
)
|
|
185
193
|
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|