upgini 1.2.37__py3-none-any.whl → 1.2.38__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of upgini might be problematic. Click here for more details.
- upgini/__about__.py +1 -1
- upgini/dataset.py +24 -2
- upgini/features_enricher.py +72 -20
- upgini/metadata.py +3 -0
- upgini/utils/target_utils.py +78 -3
- {upgini-1.2.37.dist-info → upgini-1.2.38.dist-info}/METADATA +1 -1
- {upgini-1.2.37.dist-info → upgini-1.2.38.dist-info}/RECORD +9 -9
- {upgini-1.2.37.dist-info → upgini-1.2.38.dist-info}/WHEEL +1 -1
- {upgini-1.2.37.dist-info → upgini-1.2.38.dist-info}/licenses/LICENSE +0 -0
upgini/__about__.py
CHANGED
|
@@ -1 +1 @@
|
|
|
1
|
-
__version__ = "1.2.37"
|
|
1
|
+
__version__ = "1.2.38"
|
upgini/dataset.py
CHANGED
|
@@ -22,6 +22,7 @@ from upgini.metadata import (
|
|
|
22
22
|
EVAL_SET_INDEX,
|
|
23
23
|
SYSTEM_RECORD_ID,
|
|
24
24
|
TARGET,
|
|
25
|
+
CVType,
|
|
25
26
|
DataType,
|
|
26
27
|
FeaturesFilter,
|
|
27
28
|
FileColumnMeaningType,
|
|
@@ -32,11 +33,12 @@ from upgini.metadata import (
|
|
|
32
33
|
NumericInterval,
|
|
33
34
|
RuntimeParameters,
|
|
34
35
|
SearchCustomization,
|
|
36
|
+
SearchKey,
|
|
35
37
|
)
|
|
36
38
|
from upgini.resource_bundle import ResourceBundle, get_custom_bundle
|
|
37
39
|
from upgini.search_task import SearchTask
|
|
38
40
|
from upgini.utils.email_utils import EmailSearchKeyConverter
|
|
39
|
-
from upgini.utils.target_utils import balance_undersample, balance_undersample_forced
|
|
41
|
+
from upgini.utils.target_utils import balance_undersample, balance_undersample_forced, balance_undersample_time_series
|
|
40
42
|
|
|
41
43
|
try:
|
|
42
44
|
from upgini.utils.progress_bar import CustomProgressBar as ProgressBar
|
|
@@ -74,6 +76,8 @@ class Dataset: # (pd.DataFrame):
|
|
|
74
76
|
search_keys: Optional[List[Tuple[str, ...]]] = None,
|
|
75
77
|
unnest_search_keys: Optional[Dict[str, str]] = None,
|
|
76
78
|
model_task_type: Optional[ModelTaskType] = None,
|
|
79
|
+
cv_type: Optional[CVType] = None,
|
|
80
|
+
id_columns: Optional[List[str]] = None,
|
|
77
81
|
random_state: Optional[int] = None,
|
|
78
82
|
rest_client: Optional[_RestClient] = None,
|
|
79
83
|
logger: Optional[logging.Logger] = None,
|
|
@@ -104,6 +108,7 @@ class Dataset: # (pd.DataFrame):
|
|
|
104
108
|
|
|
105
109
|
self.dataset_name = dataset_name
|
|
106
110
|
self.task_type = model_task_type
|
|
111
|
+
self.cv_type = cv_type
|
|
107
112
|
self.description = description
|
|
108
113
|
self.meaning_types = meaning_types
|
|
109
114
|
self.search_keys = search_keys
|
|
@@ -116,6 +121,7 @@ class Dataset: # (pd.DataFrame):
|
|
|
116
121
|
self.random_state = random_state
|
|
117
122
|
self.columns_renaming: Dict[str, str] = {}
|
|
118
123
|
self.imbalanced: bool = False
|
|
124
|
+
self.id_columns = id_columns
|
|
119
125
|
if logger is not None:
|
|
120
126
|
self.logger = logger
|
|
121
127
|
else:
|
|
@@ -225,6 +231,8 @@ class Dataset: # (pd.DataFrame):
|
|
|
225
231
|
df=self.data,
|
|
226
232
|
target_column=target_column,
|
|
227
233
|
task_type=self.task_type,
|
|
234
|
+
cv_type=self.cv_type,
|
|
235
|
+
id_columns=self.id_columns,
|
|
228
236
|
random_state=self.random_state,
|
|
229
237
|
sample_size=self.FORCE_SAMPLE_SIZE,
|
|
230
238
|
logger=self.logger,
|
|
@@ -297,7 +305,21 @@ class Dataset: # (pd.DataFrame):
|
|
|
297
305
|
f"Etalon has size {len(self.data)} more than threshold {sample_threshold} "
|
|
298
306
|
f"and will be downsampled to {sample_rows}"
|
|
299
307
|
)
|
|
300
|
-
|
|
308
|
+
if self.cv_type is not None and self.cv_type.is_time_series():
|
|
309
|
+
resampled_data = balance_undersample_time_series(
|
|
310
|
+
df=self.data,
|
|
311
|
+
id_columns=self.id_columns,
|
|
312
|
+
date_column=next(
|
|
313
|
+
k
|
|
314
|
+
for k, v in self.meaning_types.items()
|
|
315
|
+
if v in [FileColumnMeaningType.DATE, FileColumnMeaningType.DATETIME]
|
|
316
|
+
),
|
|
317
|
+
sample_size=sample_rows,
|
|
318
|
+
random_state=self.random_state,
|
|
319
|
+
logger=self.logger,
|
|
320
|
+
)
|
|
321
|
+
else:
|
|
322
|
+
resampled_data = self.data.sample(n=sample_rows, random_state=self.random_state)
|
|
301
323
|
self.data = resampled_data
|
|
302
324
|
self.logger.info(f"Shape after threshold resampling: {self.data.shape}")
|
|
303
325
|
|
upgini/features_enricher.py
CHANGED
|
@@ -237,6 +237,7 @@ class FeaturesEnricher(TransformerMixin):
|
|
|
237
237
|
add_date_if_missing: bool = True,
|
|
238
238
|
select_features: bool = False,
|
|
239
239
|
disable_force_downsampling: bool = False,
|
|
240
|
+
id_columns: Optional[List[str]] = None,
|
|
240
241
|
**kwargs,
|
|
241
242
|
):
|
|
242
243
|
self.bundle = get_custom_bundle(custom_bundle_config)
|
|
@@ -277,9 +278,12 @@ class FeaturesEnricher(TransformerMixin):
|
|
|
277
278
|
)
|
|
278
279
|
|
|
279
280
|
validate_version(self.logger, self.__log_warning)
|
|
281
|
+
|
|
280
282
|
self.search_keys = search_keys or {}
|
|
283
|
+
self.id_columns = id_columns
|
|
281
284
|
self.country_code = country_code
|
|
282
285
|
self.__validate_search_keys(search_keys, search_id)
|
|
286
|
+
|
|
283
287
|
self.model_task_type = model_task_type
|
|
284
288
|
self.endpoint = endpoint
|
|
285
289
|
self._search_task: Optional[SearchTask] = None
|
|
@@ -928,6 +932,8 @@ class FeaturesEnricher(TransformerMixin):
|
|
|
928
932
|
cat_features, search_keys_for_metrics = self._get_client_cat_features(
|
|
929
933
|
estimator, validated_X, self.search_keys
|
|
930
934
|
)
|
|
935
|
+
search_keys_for_metrics.extend([c for c in self.id_columns or [] if c not in search_keys_for_metrics])
|
|
936
|
+
self.logger.info(f"Search keys for metrics: {search_keys_for_metrics}")
|
|
931
937
|
|
|
932
938
|
prepared_data = self._prepare_data_for_metrics(
|
|
933
939
|
trace_id=trace_id,
|
|
@@ -983,7 +989,7 @@ class FeaturesEnricher(TransformerMixin):
|
|
|
983
989
|
with Spinner():
|
|
984
990
|
self._check_train_and_eval_target_distribution(y_sorted, fitting_eval_set_dict)
|
|
985
991
|
|
|
986
|
-
has_date =
|
|
992
|
+
has_date = self._get_date_column(search_keys) is not None
|
|
987
993
|
model_task_type = self.model_task_type or define_task(y_sorted, has_date, self.logger, silent=True)
|
|
988
994
|
|
|
989
995
|
wrapper = EstimatorWrapper.create(
|
|
@@ -1185,7 +1191,7 @@ class FeaturesEnricher(TransformerMixin):
|
|
|
1185
1191
|
)
|
|
1186
1192
|
|
|
1187
1193
|
uplift_col = self.bundle.get("quality_metrics_uplift_header")
|
|
1188
|
-
date_column =
|
|
1194
|
+
date_column = self._get_date_column(search_keys)
|
|
1189
1195
|
if (
|
|
1190
1196
|
uplift_col in metrics_df.columns
|
|
1191
1197
|
and (metrics_df[uplift_col] < 0).any()
|
|
@@ -1354,7 +1360,7 @@ class FeaturesEnricher(TransformerMixin):
|
|
|
1354
1360
|
groups = None
|
|
1355
1361
|
|
|
1356
1362
|
if not isinstance(_cv, BaseCrossValidator):
|
|
1357
|
-
date_column =
|
|
1363
|
+
date_column = self._get_date_column(search_keys)
|
|
1358
1364
|
date_series = X[date_column] if date_column is not None else None
|
|
1359
1365
|
_cv, groups = CVConfig(
|
|
1360
1366
|
_cv, date_series, self.random_state, self._search_task.get_shuffle_kfold(), group_columns=group_columns
|
|
@@ -1443,9 +1449,13 @@ class FeaturesEnricher(TransformerMixin):
|
|
|
1443
1449
|
|
|
1444
1450
|
excluding_search_keys = list(search_keys.keys())
|
|
1445
1451
|
if search_keys_for_metrics is not None and len(search_keys_for_metrics) > 0:
|
|
1452
|
+
excluded = set()
|
|
1446
1453
|
for sk in excluding_search_keys:
|
|
1447
1454
|
if columns_renaming.get(sk) in search_keys_for_metrics:
|
|
1448
|
-
|
|
1455
|
+
excluded.add(sk)
|
|
1456
|
+
excluding_search_keys = [sk for sk in excluding_search_keys if sk not in excluded]
|
|
1457
|
+
|
|
1458
|
+
self.logger.info(f"Excluding search keys: {excluding_search_keys}")
|
|
1449
1459
|
|
|
1450
1460
|
client_features = [
|
|
1451
1461
|
c
|
|
@@ -1667,7 +1677,7 @@ class FeaturesEnricher(TransformerMixin):
|
|
|
1667
1677
|
search_keys = self.search_keys.copy()
|
|
1668
1678
|
search_keys = self.__prepare_search_keys(df, search_keys, is_demo_dataset, is_transform=True, silent_mode=True)
|
|
1669
1679
|
|
|
1670
|
-
date_column =
|
|
1680
|
+
date_column = self._get_date_column(search_keys)
|
|
1671
1681
|
generated_features = []
|
|
1672
1682
|
if date_column is not None:
|
|
1673
1683
|
converter = DateTimeSearchKeyConverter(date_column, self.date_format, self.logger, self.bundle)
|
|
@@ -1741,7 +1751,7 @@ class FeaturesEnricher(TransformerMixin):
|
|
|
1741
1751
|
search_keys = self.fit_search_keys
|
|
1742
1752
|
|
|
1743
1753
|
rows_to_drop = None
|
|
1744
|
-
has_date =
|
|
1754
|
+
has_date = self._get_date_column(search_keys) is not None
|
|
1745
1755
|
self.model_task_type = self.model_task_type or define_task(
|
|
1746
1756
|
self.df_with_original_index[TARGET], has_date, self.logger, silent=True
|
|
1747
1757
|
)
|
|
@@ -1853,7 +1863,10 @@ class FeaturesEnricher(TransformerMixin):
|
|
|
1853
1863
|
df = balance_undersample_forced(
|
|
1854
1864
|
df=df,
|
|
1855
1865
|
target_column=TARGET,
|
|
1866
|
+
id_columns=self.id_columns,
|
|
1867
|
+
date_column=self._get_date_column(self.search_keys),
|
|
1856
1868
|
task_type=self.model_task_type,
|
|
1869
|
+
cv_type=self.cv,
|
|
1857
1870
|
random_state=self.random_state,
|
|
1858
1871
|
sample_size=Dataset.FORCE_SAMPLE_SIZE,
|
|
1859
1872
|
logger=self.logger,
|
|
@@ -2097,7 +2110,7 @@ class FeaturesEnricher(TransformerMixin):
|
|
|
2097
2110
|
return None, {c: c for c in X.columns}, []
|
|
2098
2111
|
|
|
2099
2112
|
features_meta = self._search_task.get_all_features_metadata_v2()
|
|
2100
|
-
online_api_features = [fm.name for fm in features_meta if fm.from_online_api]
|
|
2113
|
+
online_api_features = [fm.name for fm in features_meta if fm.from_online_api and fm.shap_value > 0]
|
|
2101
2114
|
if len(online_api_features) > 0:
|
|
2102
2115
|
self.logger.warning(
|
|
2103
2116
|
f"There are important features for transform, that generated by online API: {online_api_features}"
|
|
@@ -2137,6 +2150,9 @@ class FeaturesEnricher(TransformerMixin):
|
|
|
2137
2150
|
validated_X = validated_X.drop(columns=columns_to_drop)
|
|
2138
2151
|
|
|
2139
2152
|
search_keys = self.search_keys.copy()
|
|
2153
|
+
if self.id_columns is not None and self.cv is not None and self.cv.is_time_series():
|
|
2154
|
+
self.search_keys.update({col: SearchKey.CUSTOM_KEY for col in self.id_columns})
|
|
2155
|
+
|
|
2140
2156
|
search_keys = self.__prepare_search_keys(
|
|
2141
2157
|
validated_X, search_keys, is_demo_dataset, is_transform=True, silent_mode=silent_mode
|
|
2142
2158
|
)
|
|
@@ -2155,7 +2171,7 @@ class FeaturesEnricher(TransformerMixin):
|
|
|
2155
2171
|
df = self.__add_country_code(df, search_keys)
|
|
2156
2172
|
|
|
2157
2173
|
generated_features = []
|
|
2158
|
-
date_column =
|
|
2174
|
+
date_column = self._get_date_column(search_keys)
|
|
2159
2175
|
if date_column is not None:
|
|
2160
2176
|
converter = DateTimeSearchKeyConverter(date_column, self.date_format, self.logger, bundle=self.bundle)
|
|
2161
2177
|
df = converter.convert(df, keep_time=True)
|
|
@@ -2163,7 +2179,7 @@ class FeaturesEnricher(TransformerMixin):
|
|
|
2163
2179
|
generated_features.extend(converter.generated_features)
|
|
2164
2180
|
else:
|
|
2165
2181
|
self.logger.info("Input dataset hasn't date column")
|
|
2166
|
-
if self.
|
|
2182
|
+
if self.__should_add_date_column():
|
|
2167
2183
|
df = self._add_current_date_as_key(df, search_keys, self.logger, self.bundle)
|
|
2168
2184
|
|
|
2169
2185
|
email_columns = SearchKey.find_all_keys(search_keys, SearchKey.EMAIL)
|
|
@@ -2294,6 +2310,7 @@ class FeaturesEnricher(TransformerMixin):
|
|
|
2294
2310
|
meaning_types=meaning_types,
|
|
2295
2311
|
search_keys=combined_search_keys,
|
|
2296
2312
|
unnest_search_keys=unnest_search_keys,
|
|
2313
|
+
id_columns=self.__get_renamed_id_columns(columns_renaming),
|
|
2297
2314
|
date_format=self.date_format,
|
|
2298
2315
|
rest_client=self.rest_client,
|
|
2299
2316
|
logger=self.logger,
|
|
@@ -2446,7 +2463,14 @@ class FeaturesEnricher(TransformerMixin):
|
|
|
2446
2463
|
# Multiple search keys allowed only for PHONE, IP, POSTAL_CODE, EMAIL, HEM
|
|
2447
2464
|
multi_keys = [key for key, count in Counter(key_types).items() if count > 1]
|
|
2448
2465
|
for multi_key in multi_keys:
|
|
2449
|
-
if multi_key not in [
|
|
2466
|
+
if multi_key not in [
|
|
2467
|
+
SearchKey.PHONE,
|
|
2468
|
+
SearchKey.IP,
|
|
2469
|
+
SearchKey.POSTAL_CODE,
|
|
2470
|
+
SearchKey.EMAIL,
|
|
2471
|
+
SearchKey.HEM,
|
|
2472
|
+
SearchKey.CUSTOM_KEY,
|
|
2473
|
+
]:
|
|
2450
2474
|
msg = self.bundle.get("unsupported_multi_key").format(multi_key)
|
|
2451
2475
|
self.logger.warning(msg)
|
|
2452
2476
|
raise ValidationError(msg)
|
|
@@ -2610,7 +2634,7 @@ class FeaturesEnricher(TransformerMixin):
|
|
|
2610
2634
|
self.fit_generated_features.extend(converter.generated_features)
|
|
2611
2635
|
else:
|
|
2612
2636
|
self.logger.info("Input dataset hasn't date column")
|
|
2613
|
-
if self.
|
|
2637
|
+
if self.__should_add_date_column():
|
|
2614
2638
|
df = self._add_current_date_as_key(df, self.fit_search_keys, self.logger, self.bundle)
|
|
2615
2639
|
|
|
2616
2640
|
email_columns = SearchKey.find_all_keys(self.fit_search_keys, SearchKey.EMAIL)
|
|
@@ -2643,6 +2667,13 @@ class FeaturesEnricher(TransformerMixin):
|
|
|
2643
2667
|
|
|
2644
2668
|
self.__adjust_cv(df)
|
|
2645
2669
|
|
|
2670
|
+
if self.id_columns is not None and self.cv is not None and self.cv.is_time_series():
|
|
2671
|
+
id_columns = self.__get_renamed_id_columns()
|
|
2672
|
+
if id_columns:
|
|
2673
|
+
self.fit_search_keys.update({col: SearchKey.CUSTOM_KEY for col in id_columns})
|
|
2674
|
+
self.search_keys.update({col: SearchKey.CUSTOM_KEY for col in self.id_columns})
|
|
2675
|
+
self.runtime_parameters.properties["id_columns"] = ",".join(id_columns)
|
|
2676
|
+
|
|
2646
2677
|
df, fintech_warnings = remove_fintech_duplicates(
|
|
2647
2678
|
df, self.fit_search_keys, date_format=self.date_format, logger=self.logger, bundle=self.bundle
|
|
2648
2679
|
)
|
|
@@ -2764,6 +2795,8 @@ class FeaturesEnricher(TransformerMixin):
|
|
|
2764
2795
|
search_keys=combined_search_keys,
|
|
2765
2796
|
unnest_search_keys=unnest_search_keys,
|
|
2766
2797
|
model_task_type=self.model_task_type,
|
|
2798
|
+
cv_type=self.cv,
|
|
2799
|
+
id_columns=self.__get_renamed_id_columns(),
|
|
2767
2800
|
date_format=self.date_format,
|
|
2768
2801
|
random_state=self.random_state,
|
|
2769
2802
|
rest_client=self.rest_client,
|
|
@@ -2920,6 +2953,14 @@ class FeaturesEnricher(TransformerMixin):
|
|
|
2920
2953
|
if not self.warning_counter.has_warnings():
|
|
2921
2954
|
self.__display_support_link(self.bundle.get("all_ok_community_invite"))
|
|
2922
2955
|
|
|
2956
|
+
def __should_add_date_column(self):
|
|
2957
|
+
return self.add_date_if_missing or (self.cv is not None and self.cv.is_time_series())
|
|
2958
|
+
|
|
2959
|
+
def __get_renamed_id_columns(self, renaming: Optional[Dict[str, str]] = None):
|
|
2960
|
+
renaming = renaming or self.fit_columns_renaming
|
|
2961
|
+
reverse_renaming = {v: k for k, v in renaming.items()}
|
|
2962
|
+
return None if self.id_columns is None else [reverse_renaming.get(c) or c for c in self.id_columns]
|
|
2963
|
+
|
|
2923
2964
|
def __adjust_cv(self, df: pd.DataFrame):
|
|
2924
2965
|
date_column = SearchKey.find_key(self.fit_search_keys, [SearchKey.DATE, SearchKey.DATETIME])
|
|
2925
2966
|
# Check Multivariate time series
|
|
@@ -3165,7 +3206,7 @@ class FeaturesEnricher(TransformerMixin):
|
|
|
3165
3206
|
if DateTimeSearchKeyConverter.DATETIME_COL in X.columns:
|
|
3166
3207
|
date_column = DateTimeSearchKeyConverter.DATETIME_COL
|
|
3167
3208
|
else:
|
|
3168
|
-
date_column =
|
|
3209
|
+
date_column = FeaturesEnricher._get_date_column(search_keys)
|
|
3169
3210
|
sort_columns = [date_column] if date_column is not None else []
|
|
3170
3211
|
|
|
3171
3212
|
# Xy = pd.concat([X, y], axis=1)
|
|
@@ -3228,6 +3269,7 @@ class FeaturesEnricher(TransformerMixin):
|
|
|
3228
3269
|
f"Generate features: {self.generate_features}\n"
|
|
3229
3270
|
f"Round embeddings: {self.round_embeddings}\n"
|
|
3230
3271
|
f"Detect missing search keys: {self.detect_missing_search_keys}\n"
|
|
3272
|
+
f"Exclude columns: {self.exclude_columns}\n"
|
|
3231
3273
|
f"Exclude features sources: {exclude_features_sources}\n"
|
|
3232
3274
|
f"Calculate metrics: {calculate_metrics}\n"
|
|
3233
3275
|
f"Scoring: {scoring}\n"
|
|
@@ -3235,6 +3277,15 @@ class FeaturesEnricher(TransformerMixin):
|
|
|
3235
3277
|
f"Remove target outliers: {remove_outliers_calc_metrics}\n"
|
|
3236
3278
|
f"Exclude columns: {self.exclude_columns}\n"
|
|
3237
3279
|
f"Search id: {self.search_id}\n"
|
|
3280
|
+
f"Custom loss: {self.loss}\n"
|
|
3281
|
+
f"Logs enabled: {self.logs_enabled}\n"
|
|
3282
|
+
f"Raise validation error: {self.raise_validation_error}\n"
|
|
3283
|
+
f"Baseline score column: {self.baseline_score_column}\n"
|
|
3284
|
+
f"Client ip: {self.client_ip}\n"
|
|
3285
|
+
f"Client visitorId: {self.client_visitorid}\n"
|
|
3286
|
+
f"Add date if missing: {self.add_date_if_missing}\n"
|
|
3287
|
+
f"Select features: {self.select_features}\n"
|
|
3288
|
+
f"Disable force downsampling: {self.disable_force_downsampling}\n"
|
|
3238
3289
|
)
|
|
3239
3290
|
|
|
3240
3291
|
def sample(df):
|
|
@@ -3357,6 +3408,10 @@ class FeaturesEnricher(TransformerMixin):
|
|
|
3357
3408
|
if t == SearchKey.POSTAL_CODE:
|
|
3358
3409
|
return col
|
|
3359
3410
|
|
|
3411
|
+
@staticmethod
|
|
3412
|
+
def _get_date_column(search_keys: Dict[str, SearchKey]) -> Optional[str]:
|
|
3413
|
+
return SearchKey.find_key(search_keys, [SearchKey.DATE, SearchKey.DATETIME])
|
|
3414
|
+
|
|
3360
3415
|
def _explode_multiple_search_keys(
|
|
3361
3416
|
self, df: pd.DataFrame, search_keys: Dict[str, SearchKey], columns_renaming: Dict[str, str]
|
|
3362
3417
|
) -> Tuple[pd.DataFrame, Dict[str, List[str]]]:
|
|
@@ -3365,7 +3420,9 @@ class FeaturesEnricher(TransformerMixin):
|
|
|
3365
3420
|
for key_name, key_type in search_keys.items():
|
|
3366
3421
|
search_key_names_by_type[key_type] = search_key_names_by_type.get(key_type, []) + [key_name]
|
|
3367
3422
|
search_key_names_by_type = {
|
|
3368
|
-
key_type: key_names
|
|
3423
|
+
key_type: key_names
|
|
3424
|
+
for key_type, key_names in search_key_names_by_type.items()
|
|
3425
|
+
if len(key_names) > 1 and key_type != SearchKey.CUSTOM_KEY
|
|
3369
3426
|
}
|
|
3370
3427
|
if len(search_key_names_by_type) == 0:
|
|
3371
3428
|
return df, {}
|
|
@@ -3418,9 +3475,9 @@ class FeaturesEnricher(TransformerMixin):
|
|
|
3418
3475
|
]
|
|
3419
3476
|
if DateTimeSearchKeyConverter.DATETIME_COL in df.columns:
|
|
3420
3477
|
date_column = DateTimeSearchKeyConverter.DATETIME_COL
|
|
3421
|
-
sort_exclude_columns.append(
|
|
3478
|
+
sort_exclude_columns.append(self._get_date_column(search_keys))
|
|
3422
3479
|
else:
|
|
3423
|
-
date_column =
|
|
3480
|
+
date_column = self._get_date_column(search_keys)
|
|
3424
3481
|
sort_columns = [date_column] if date_column is not None else []
|
|
3425
3482
|
|
|
3426
3483
|
sorted_other_keys = sorted(search_keys, key=lambda x: str(search_keys.get(x)))
|
|
@@ -3856,11 +3913,6 @@ class FeaturesEnricher(TransformerMixin):
|
|
|
3856
3913
|
self.logger.warning(msg + f" Provided search keys: {search_keys}")
|
|
3857
3914
|
raise ValidationError(msg)
|
|
3858
3915
|
|
|
3859
|
-
if SearchKey.CUSTOM_KEY in valid_search_keys.values():
|
|
3860
|
-
custom_keys = [column for column, key in valid_search_keys.items() if key == SearchKey.CUSTOM_KEY]
|
|
3861
|
-
for key in custom_keys:
|
|
3862
|
-
del valid_search_keys[key]
|
|
3863
|
-
|
|
3864
3916
|
if (
|
|
3865
3917
|
len(valid_search_keys.values()) == 1
|
|
3866
3918
|
and self.country_code is None
|
upgini/metadata.py
CHANGED
upgini/utils/target_utils.py
CHANGED
|
@@ -1,15 +1,18 @@
|
|
|
1
|
+
import itertools
|
|
1
2
|
import logging
|
|
2
|
-
from typing import Callable, Optional, Union
|
|
3
|
+
from typing import Callable, List, Optional, Union
|
|
3
4
|
|
|
4
5
|
import numpy as np
|
|
5
6
|
import pandas as pd
|
|
6
7
|
from pandas.api.types import is_numeric_dtype, is_bool_dtype
|
|
7
8
|
|
|
8
9
|
from upgini.errors import ValidationError
|
|
9
|
-
from upgini.metadata import SYSTEM_RECORD_ID, ModelTaskType
|
|
10
|
+
from upgini.metadata import SYSTEM_RECORD_ID, CVType, ModelTaskType
|
|
10
11
|
from upgini.resource_bundle import ResourceBundle, bundle, get_custom_bundle
|
|
11
12
|
from upgini.sampler.random_under_sampler import RandomUnderSampler
|
|
12
13
|
|
|
14
|
+
TS_MIN_DIFFERENT_IDS_RATIO = 0.2
|
|
15
|
+
|
|
13
16
|
|
|
14
17
|
def correct_string_target(y: Union[pd.Series, np.ndarray]) -> Union[pd.Series, np.ndarray]:
|
|
15
18
|
if isinstance(y, pd.Series):
|
|
@@ -201,7 +204,10 @@ def balance_undersample(
|
|
|
201
204
|
def balance_undersample_forced(
|
|
202
205
|
df: pd.DataFrame,
|
|
203
206
|
target_column: str,
|
|
207
|
+
id_columns: List[str],
|
|
208
|
+
date_column: str,
|
|
204
209
|
task_type: ModelTaskType,
|
|
210
|
+
cv_type: CVType | None,
|
|
205
211
|
random_state: int,
|
|
206
212
|
sample_size: int = 7000,
|
|
207
213
|
logger: Optional[logging.Logger] = None,
|
|
@@ -233,7 +239,17 @@ def balance_undersample_forced(
|
|
|
233
239
|
|
|
234
240
|
resampled_data = df
|
|
235
241
|
df = df.copy().sort_values(by=SYSTEM_RECORD_ID)
|
|
236
|
-
if
|
|
242
|
+
if cv_type is not None and cv_type.is_time_series():
|
|
243
|
+
logger.warning(f"Sampling time series dataset from {len(df)} to {sample_size}")
|
|
244
|
+
resampled_data = balance_undersample_time_series(
|
|
245
|
+
df,
|
|
246
|
+
id_columns=id_columns,
|
|
247
|
+
date_column=date_column,
|
|
248
|
+
sample_size=sample_size,
|
|
249
|
+
random_state=random_state,
|
|
250
|
+
logger=logger,
|
|
251
|
+
)
|
|
252
|
+
elif task_type in [ModelTaskType.MULTICLASS, ModelTaskType.REGRESSION]:
|
|
237
253
|
logger.warning(f"Sampling dataset from {len(df)} to {sample_size}")
|
|
238
254
|
resampled_data = df.sample(n=sample_size, random_state=random_state)
|
|
239
255
|
else:
|
|
@@ -264,6 +280,65 @@ def balance_undersample_forced(
|
|
|
264
280
|
return resampled_data
|
|
265
281
|
|
|
266
282
|
|
|
283
|
+
def balance_undersample_time_series(
|
|
284
|
+
df: pd.DataFrame,
|
|
285
|
+
id_columns: List[str],
|
|
286
|
+
date_column: str,
|
|
287
|
+
sample_size: int,
|
|
288
|
+
random_state: int = 42,
|
|
289
|
+
min_different_ids_ratio: float = TS_MIN_DIFFERENT_IDS_RATIO,
|
|
290
|
+
prefer_recent_dates: bool = True,
|
|
291
|
+
logger: Optional[logging.Logger] = None,
|
|
292
|
+
):
|
|
293
|
+
def ensure_tuple(x):
|
|
294
|
+
return tuple([x]) if not isinstance(x, tuple) else x
|
|
295
|
+
|
|
296
|
+
random_state = np.random.RandomState(random_state)
|
|
297
|
+
|
|
298
|
+
if not id_columns:
|
|
299
|
+
id_columns = [date_column]
|
|
300
|
+
ids_sort = df.groupby(id_columns)[date_column].aggregate(["max", "count"]).T.to_dict()
|
|
301
|
+
ids_sort = {
|
|
302
|
+
ensure_tuple(k): (
|
|
303
|
+
(v["max"], v["count"], random_state.rand()) if prefer_recent_dates else (v["count"], random_state.rand())
|
|
304
|
+
)
|
|
305
|
+
for k, v in ids_sort.items()
|
|
306
|
+
}
|
|
307
|
+
id_counts = df[id_columns].value_counts()
|
|
308
|
+
id_counts.index = [ensure_tuple(i) for i in id_counts.index]
|
|
309
|
+
id_counts = id_counts.sort_index(key=lambda x: [ids_sort[y] for y in x], ascending=False).cumsum()
|
|
310
|
+
id_counts = id_counts[id_counts <= sample_size]
|
|
311
|
+
min_different_ids = max(int(len(df[id_columns].drop_duplicates()) * min_different_ids_ratio), 1)
|
|
312
|
+
|
|
313
|
+
def id_mask(sample_index: pd.Index) -> pd.Index:
|
|
314
|
+
if isinstance(sample_index, pd.MultiIndex):
|
|
315
|
+
return pd.MultiIndex.from_frame(df[id_columns]).isin(sample_index)
|
|
316
|
+
else:
|
|
317
|
+
return df[id_columns[0]].isin(sample_index)
|
|
318
|
+
|
|
319
|
+
if len(id_counts) < min_different_ids:
|
|
320
|
+
if logger is not None:
|
|
321
|
+
logger.info(
|
|
322
|
+
f"Different ids count {len(id_counts)} for sample size {sample_size} is less than min different ids {min_different_ids}, sampling time window"
|
|
323
|
+
)
|
|
324
|
+
date_counts = df.groupby(id_columns)[date_column].nunique().sort_values(ascending=False)
|
|
325
|
+
ids_to_sample = date_counts.index[:min_different_ids] if len(id_counts) > 0 else date_counts.index
|
|
326
|
+
mask = id_mask(ids_to_sample)
|
|
327
|
+
df = df[mask]
|
|
328
|
+
sample_date_counts = df[date_column].value_counts().sort_index(ascending=False).cumsum()
|
|
329
|
+
sample_date_counts = sample_date_counts[sample_date_counts <= sample_size]
|
|
330
|
+
df = df[df[date_column].isin(sample_date_counts.index)]
|
|
331
|
+
else:
|
|
332
|
+
if len(id_columns) > 1:
|
|
333
|
+
id_counts.index = pd.MultiIndex.from_tuples(id_counts.index)
|
|
334
|
+
else:
|
|
335
|
+
id_counts.index = [i[0] for i in id_counts.index]
|
|
336
|
+
mask = id_mask(id_counts.index)
|
|
337
|
+
df = df[mask]
|
|
338
|
+
|
|
339
|
+
return df
|
|
340
|
+
|
|
341
|
+
|
|
267
342
|
def calculate_psi(expected: pd.Series, actual: pd.Series) -> Union[float, Exception]:
|
|
268
343
|
try:
|
|
269
344
|
df = pd.concat([expected, actual])
|
|
@@ -1,12 +1,12 @@
|
|
|
1
|
-
upgini/__about__.py,sha256=
|
|
1
|
+
upgini/__about__.py,sha256=LmN8HmHN2Px-OpOZ1Y29xw_nidvXMFTOaEZFjFE2XDs,23
|
|
2
2
|
upgini/__init__.py,sha256=LXSfTNU0HnlOkE69VCxkgIKDhWP-JFo_eBQ71OxTr5Y,261
|
|
3
3
|
upgini/ads.py,sha256=nvuRxRx5MHDMgPr9SiU-fsqRdFaBv8p4_v1oqiysKpc,2714
|
|
4
|
-
upgini/dataset.py,sha256
|
|
4
|
+
upgini/dataset.py,sha256=-3FeDMADnHxGb70rKFY_U96NCQO-TEUAXFicFl25CtY,33222
|
|
5
5
|
upgini/errors.py,sha256=2b_Wbo0OYhLUbrZqdLIx5jBnAsiD1Mcenh-VjR4HCTw,950
|
|
6
|
-
upgini/features_enricher.py,sha256=
|
|
6
|
+
upgini/features_enricher.py,sha256=Hs2-O3KBH8adgkLvy6ccXyXkfEBhskPMBEceD7hj5Qo,197718
|
|
7
7
|
upgini/http.py,sha256=plZGTGoi1h2edd8Cnjt4eYB8t4NbBGnZz7DtPTByiNc,42885
|
|
8
8
|
upgini/lazy_import.py,sha256=74gQ8JuA48BGRLxAo7lNHNKY2D2emMxrUxKGdxVGhuY,1012
|
|
9
|
-
upgini/metadata.py,sha256
|
|
9
|
+
upgini/metadata.py,sha256=-ibqiNjD7dTagqg53FoEJNEqvAYbwgfyn9PGTRQ_YKU,12054
|
|
10
10
|
upgini/metrics.py,sha256=hr7UwLphbZ_FEglLuO2lzr_pFgxOJ4c3WBeg7H-fNqY,35521
|
|
11
11
|
upgini/search_task.py,sha256=qxUxAD-bed-FpZYmTB_4orW7YJsW_O6a1TcgnZIRFr4,17307
|
|
12
12
|
upgini/spinner.py,sha256=4iMd-eIe_BnkqFEMIliULTbj6rNI2HkN_VJ4qYe0cUc,1118
|
|
@@ -56,10 +56,10 @@ upgini/utils/phone_utils.py,sha256=IrbztLuOJBiePqqxllfABWfYlfAjYevPhXKipl95wUI,1
|
|
|
56
56
|
upgini/utils/postal_code_utils.py,sha256=5M0sUqH2DAr33kARWCTXR-ACyzWbjDq_-0mmEml6ZcU,1716
|
|
57
57
|
upgini/utils/progress_bar.py,sha256=N-Sfdah2Hg8lXP_fV9EfUTXz_PyRt4lo9fAHoUDOoLc,1550
|
|
58
58
|
upgini/utils/sklearn_ext.py,sha256=13jQS_k7v0aUtudXV6nGUEWjttPQzAW9AFYL5wgEz9k,44511
|
|
59
|
-
upgini/utils/target_utils.py,sha256=
|
|
59
|
+
upgini/utils/target_utils.py,sha256=RlpKGss9kMibVSlA8iZuO_qxmyeplqzn7X8g6hiGGGs,14341
|
|
60
60
|
upgini/utils/track_info.py,sha256=G5Lu1xxakg2_TQjKZk4b5SvrHsATTXNVV3NbvWtT8k8,5663
|
|
61
61
|
upgini/utils/warning_counter.py,sha256=-GRY8EUggEBKODPSuXAkHn9KnEQwAORC0mmz_tim-PM,254
|
|
62
|
-
upgini-1.2.
|
|
63
|
-
upgini-1.2.
|
|
64
|
-
upgini-1.2.
|
|
65
|
-
upgini-1.2.
|
|
62
|
+
upgini-1.2.38.dist-info/METADATA,sha256=rI6ffoFk1_DJRQiIMb-pYyh-0MaSYBERdqleYet6-C0,48594
|
|
63
|
+
upgini-1.2.38.dist-info/WHEEL,sha256=zEMcRr9Kr03x1ozGwg5v9NQBKn3kndp6LSoSlVg-jhU,87
|
|
64
|
+
upgini-1.2.38.dist-info/licenses/LICENSE,sha256=5RRzgvdJUu3BUDfv4bzVU6FqKgwHlIay63pPCSmSgzw,1514
|
|
65
|
+
upgini-1.2.38.dist-info/RECORD,,
|
|
File without changes
|