upgini 1.2.91a3906.dev1__py3-none-any.whl → 1.2.92__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- upgini/__about__.py +1 -1
- upgini/autofe/unary.py +0 -8
- upgini/dataset.py +58 -109
- upgini/features_enricher.py +225 -187
- upgini/metadata.py +3 -0
- upgini/metrics.py +12 -11
- upgini/resource_bundle/strings.properties +2 -0
- upgini/utils/feature_info.py +2 -2
- upgini/utils/sample_utils.py +416 -0
- upgini/utils/target_utils.py +3 -199
- {upgini-1.2.91a3906.dev1.dist-info → upgini-1.2.92.dist-info}/METADATA +1 -1
- {upgini-1.2.91a3906.dev1.dist-info → upgini-1.2.92.dist-info}/RECORD +14 -13
- {upgini-1.2.91a3906.dev1.dist-info → upgini-1.2.92.dist-info}/WHEEL +0 -0
- {upgini-1.2.91a3906.dev1.dist-info → upgini-1.2.92.dist-info}/licenses/LICENSE +0 -0
upgini/utils/target_utils.py
CHANGED
@@ -1,17 +1,14 @@
|
|
1
1
|
import logging
|
2
|
-
from typing import Callable,
|
2
|
+
from typing import Callable, Optional, Union
|
3
3
|
|
4
4
|
import numpy as np
|
5
5
|
import pandas as pd
|
6
6
|
from pandas.api.types import is_bool_dtype, is_datetime64_any_dtype, is_numeric_dtype
|
7
7
|
|
8
8
|
from upgini.errors import ValidationError
|
9
|
-
from upgini.metadata import SYSTEM_RECORD_ID,
|
10
|
-
from upgini.resource_bundle import ResourceBundle,
|
9
|
+
from upgini.metadata import SYSTEM_RECORD_ID, ModelTaskType
|
10
|
+
from upgini.resource_bundle import ResourceBundle, get_custom_bundle, bundle
|
11
11
|
from upgini.sampler.random_under_sampler import RandomUnderSampler
|
12
|
-
from upgini.utils.ts_utils import get_most_frequent_time_unit, trunc_datetime
|
13
|
-
|
14
|
-
TS_MIN_DIFFERENT_IDS_RATIO = 0.2
|
15
12
|
|
16
13
|
|
17
14
|
def prepare_target(y: Union[pd.Series, np.ndarray], target_type: ModelTaskType) -> Union[pd.Series, np.ndarray]:
|
@@ -204,199 +201,6 @@ def balance_undersample(
|
|
204
201
|
return resampled_data
|
205
202
|
|
206
203
|
|
207
|
-
def balance_undersample_forced(
|
208
|
-
df: pd.DataFrame,
|
209
|
-
target_column: str,
|
210
|
-
id_columns: Optional[List[str]],
|
211
|
-
date_column: str,
|
212
|
-
task_type: ModelTaskType,
|
213
|
-
cv_type: Optional[CVType],
|
214
|
-
random_state: int,
|
215
|
-
sample_size: int = 7000,
|
216
|
-
logger: Optional[logging.Logger] = None,
|
217
|
-
bundle: Optional[ResourceBundle] = None,
|
218
|
-
warning_callback: Optional[Callable] = None,
|
219
|
-
):
|
220
|
-
if len(df) <= sample_size:
|
221
|
-
return df
|
222
|
-
|
223
|
-
if logger is None:
|
224
|
-
logger = logging.getLogger("muted_logger")
|
225
|
-
logger.setLevel("FATAL")
|
226
|
-
bundle = bundle or get_custom_bundle()
|
227
|
-
if SYSTEM_RECORD_ID not in df.columns:
|
228
|
-
raise Exception("System record id must be presented for undersampling")
|
229
|
-
|
230
|
-
msg = bundle.get("forced_balance_undersample")
|
231
|
-
logger.info(msg)
|
232
|
-
if warning_callback is not None:
|
233
|
-
warning_callback(msg)
|
234
|
-
|
235
|
-
target = df[target_column].copy()
|
236
|
-
|
237
|
-
vc = target.value_counts()
|
238
|
-
max_class_value = vc.index[0]
|
239
|
-
min_class_value = vc.index[len(vc) - 1]
|
240
|
-
max_class_count = vc[max_class_value]
|
241
|
-
min_class_count = vc[min_class_value]
|
242
|
-
|
243
|
-
resampled_data = df
|
244
|
-
df = df.copy().sort_values(by=SYSTEM_RECORD_ID)
|
245
|
-
if cv_type is not None and cv_type.is_time_series():
|
246
|
-
logger.warning(f"Sampling time series dataset from {len(df)} to {sample_size}")
|
247
|
-
resampled_data = balance_undersample_time_series_trunc(
|
248
|
-
df,
|
249
|
-
id_columns=id_columns,
|
250
|
-
date_column=date_column,
|
251
|
-
sample_size=sample_size,
|
252
|
-
random_state=random_state,
|
253
|
-
logger=logger,
|
254
|
-
)
|
255
|
-
elif task_type in [ModelTaskType.MULTICLASS, ModelTaskType.REGRESSION]:
|
256
|
-
logger.warning(f"Sampling dataset from {len(df)} to {sample_size}")
|
257
|
-
resampled_data = df.sample(n=sample_size, random_state=random_state)
|
258
|
-
else:
|
259
|
-
msg = bundle.get("imbalanced_target").format(min_class_value, min_class_count)
|
260
|
-
logger.warning(msg)
|
261
|
-
|
262
|
-
# fill up to min_sample_threshold by majority class
|
263
|
-
minority_class = df[df[target_column] == min_class_value]
|
264
|
-
majority_class = df[df[target_column] != min_class_value]
|
265
|
-
logger.info(
|
266
|
-
f"Min class count: {min_class_count}. Max class count: {max_class_count}."
|
267
|
-
f" Rebalance sample size: {sample_size}"
|
268
|
-
)
|
269
|
-
if len(minority_class) > (sample_size / 2):
|
270
|
-
sampled_minority_class = minority_class.sample(n=int(sample_size / 2), random_state=random_state)
|
271
|
-
else:
|
272
|
-
sampled_minority_class = minority_class
|
273
|
-
|
274
|
-
if len(majority_class) > (sample_size) / 2:
|
275
|
-
sampled_majority_class = majority_class.sample(n=int(sample_size / 2), random_state=random_state)
|
276
|
-
|
277
|
-
resampled_data = df[
|
278
|
-
(df[SYSTEM_RECORD_ID].isin(sampled_minority_class[SYSTEM_RECORD_ID]))
|
279
|
-
| (df[SYSTEM_RECORD_ID].isin(sampled_majority_class[SYSTEM_RECORD_ID]))
|
280
|
-
]
|
281
|
-
|
282
|
-
logger.info(f"Shape after forced rebalance resampling: {resampled_data}")
|
283
|
-
return resampled_data
|
284
|
-
|
285
|
-
|
286
|
-
DEFAULT_HIGH_FREQ_TRUNC_LENGTHS = [pd.DateOffset(years=2, months=6), pd.DateOffset(years=2, days=7)]
|
287
|
-
DEFAULT_LOW_FREQ_TRUNC_LENGTHS = [pd.DateOffset(years=7), pd.DateOffset(years=5)]
|
288
|
-
DEFAULT_TIME_UNIT_THRESHOLD = pd.Timedelta(weeks=4)
|
289
|
-
|
290
|
-
|
291
|
-
def balance_undersample_time_series_trunc(
|
292
|
-
df: pd.DataFrame,
|
293
|
-
id_columns: Optional[List[str]],
|
294
|
-
date_column: str,
|
295
|
-
sample_size: int,
|
296
|
-
random_state: int = 42,
|
297
|
-
logger: Optional[logging.Logger] = None,
|
298
|
-
highfreq_trunc_lengths: List[pd.DateOffset] = DEFAULT_HIGH_FREQ_TRUNC_LENGTHS,
|
299
|
-
lowfreq_trunc_lengths: List[pd.DateOffset] = DEFAULT_LOW_FREQ_TRUNC_LENGTHS,
|
300
|
-
time_unit_threshold: pd.Timedelta = DEFAULT_TIME_UNIT_THRESHOLD,
|
301
|
-
**kwargs,
|
302
|
-
):
|
303
|
-
if id_columns is None:
|
304
|
-
id_columns = []
|
305
|
-
# Convert date column to datetime
|
306
|
-
dates_df = df[id_columns + [date_column]].copy()
|
307
|
-
dates_df[date_column] = pd.to_datetime(dates_df[date_column], unit="ms")
|
308
|
-
|
309
|
-
time_unit = get_most_frequent_time_unit(dates_df, id_columns, date_column)
|
310
|
-
if logger is not None:
|
311
|
-
logger.info(f"Time unit: {time_unit}")
|
312
|
-
|
313
|
-
if time_unit is None:
|
314
|
-
if logger is not None:
|
315
|
-
logger.info("Cannot detect time unit, returning original dataset")
|
316
|
-
return df
|
317
|
-
|
318
|
-
if time_unit < time_unit_threshold:
|
319
|
-
for trunc_length in highfreq_trunc_lengths:
|
320
|
-
sampled_df = trunc_datetime(dates_df, id_columns, date_column, trunc_length, logger=logger)
|
321
|
-
if len(sampled_df) <= sample_size:
|
322
|
-
break
|
323
|
-
if len(sampled_df) > sample_size:
|
324
|
-
sampled_df = balance_undersample_time_series(
|
325
|
-
sampled_df, id_columns, date_column, sample_size, random_state, logger=logger, **kwargs
|
326
|
-
)
|
327
|
-
else:
|
328
|
-
for trunc_length in lowfreq_trunc_lengths:
|
329
|
-
sampled_df = trunc_datetime(dates_df, id_columns, date_column, trunc_length, logger=logger)
|
330
|
-
if len(sampled_df) <= sample_size:
|
331
|
-
break
|
332
|
-
if len(sampled_df) > sample_size:
|
333
|
-
sampled_df = balance_undersample_time_series(
|
334
|
-
sampled_df, id_columns, date_column, sample_size, random_state, logger=logger, **kwargs
|
335
|
-
)
|
336
|
-
|
337
|
-
return df.loc[sampled_df.index]
|
338
|
-
|
339
|
-
|
340
|
-
def balance_undersample_time_series(
|
341
|
-
df: pd.DataFrame,
|
342
|
-
id_columns: List[str],
|
343
|
-
date_column: str,
|
344
|
-
sample_size: int,
|
345
|
-
random_state: int = 42,
|
346
|
-
min_different_ids_ratio: float = TS_MIN_DIFFERENT_IDS_RATIO,
|
347
|
-
prefer_recent_dates: bool = True,
|
348
|
-
logger: Optional[logging.Logger] = None,
|
349
|
-
):
|
350
|
-
def ensure_tuple(x):
|
351
|
-
return tuple([x]) if not isinstance(x, tuple) else x
|
352
|
-
|
353
|
-
random_state = np.random.RandomState(random_state)
|
354
|
-
|
355
|
-
if not id_columns:
|
356
|
-
id_columns = [date_column]
|
357
|
-
ids_sort = df.groupby(id_columns)[date_column].aggregate(["max", "count"]).T.to_dict()
|
358
|
-
ids_sort = {
|
359
|
-
ensure_tuple(k): (
|
360
|
-
(v["max"], v["count"], random_state.rand()) if prefer_recent_dates else (v["count"], random_state.rand())
|
361
|
-
)
|
362
|
-
for k, v in ids_sort.items()
|
363
|
-
}
|
364
|
-
id_counts = df[id_columns].value_counts()
|
365
|
-
id_counts.index = [ensure_tuple(i) for i in id_counts.index]
|
366
|
-
id_counts = id_counts.sort_index(key=lambda x: [ids_sort[y] for y in x], ascending=False).cumsum()
|
367
|
-
id_counts = id_counts[id_counts <= sample_size]
|
368
|
-
min_different_ids = max(int(len(df[id_columns].drop_duplicates()) * min_different_ids_ratio), 1)
|
369
|
-
|
370
|
-
def id_mask(sample_index: pd.Index) -> pd.Index:
|
371
|
-
if isinstance(sample_index, pd.MultiIndex):
|
372
|
-
return pd.MultiIndex.from_frame(df[id_columns]).isin(sample_index)
|
373
|
-
else:
|
374
|
-
return df[id_columns[0]].isin(sample_index)
|
375
|
-
|
376
|
-
if len(id_counts) < min_different_ids:
|
377
|
-
if logger is not None:
|
378
|
-
logger.info(
|
379
|
-
f"Different ids count {len(id_counts)} for sample size {sample_size}"
|
380
|
-
f" is less than min different ids {min_different_ids}, sampling time window"
|
381
|
-
)
|
382
|
-
date_counts = df.groupby(id_columns)[date_column].nunique().sort_values(ascending=False)
|
383
|
-
ids_to_sample = date_counts.index[:min_different_ids] if len(id_counts) > 0 else date_counts.index
|
384
|
-
mask = id_mask(ids_to_sample)
|
385
|
-
df = df[mask]
|
386
|
-
sample_date_counts = df[date_column].value_counts().sort_index(ascending=False).cumsum()
|
387
|
-
sample_date_counts = sample_date_counts[sample_date_counts <= sample_size]
|
388
|
-
df = df[df[date_column].isin(sample_date_counts.index)]
|
389
|
-
else:
|
390
|
-
if len(id_columns) > 1:
|
391
|
-
id_counts.index = pd.MultiIndex.from_tuples(id_counts.index)
|
392
|
-
else:
|
393
|
-
id_counts.index = [i[0] for i in id_counts.index]
|
394
|
-
mask = id_mask(id_counts.index)
|
395
|
-
df = df[mask]
|
396
|
-
|
397
|
-
return df
|
398
|
-
|
399
|
-
|
400
204
|
def calculate_psi(expected: pd.Series, actual: pd.Series) -> Union[float, Exception]:
|
401
205
|
try:
|
402
206
|
df = pd.concat([expected, actual])
|
@@ -1,12 +1,12 @@
|
|
1
|
-
upgini/__about__.py,sha256=
|
1
|
+
upgini/__about__.py,sha256=wXo9Q87kBdNAVEzs4oUkI_3AmrQDgiMvfXa7xRn9cOE,23
|
2
2
|
upgini/__init__.py,sha256=LXSfTNU0HnlOkE69VCxkgIKDhWP-JFo_eBQ71OxTr5Y,261
|
3
3
|
upgini/ads.py,sha256=nvuRxRx5MHDMgPr9SiU-fsqRdFaBv8p4_v1oqiysKpc,2714
|
4
|
-
upgini/dataset.py,sha256=
|
4
|
+
upgini/dataset.py,sha256=e6JDYTZ2AwC5aF-dqclKZKkiKrHo2f6cFmMQO2ZZmjM,32724
|
5
5
|
upgini/errors.py,sha256=2b_Wbo0OYhLUbrZqdLIx5jBnAsiD1Mcenh-VjR4HCTw,950
|
6
|
-
upgini/features_enricher.py,sha256=
|
6
|
+
upgini/features_enricher.py,sha256=wFeqZ30Dkhiyath--Jg6uVVoTwCdPJ42Rbe_smr1ue4,218465
|
7
7
|
upgini/http.py,sha256=4i7fQwrwU3WzDUOWzrgR-4C8eJwj_5dBwRAR-UjUtlc,44345
|
8
|
-
upgini/metadata.py,sha256=
|
9
|
-
upgini/metrics.py,sha256=
|
8
|
+
upgini/metadata.py,sha256=vsbbHyPCP3Rs8WkeDgQg99uAA_zmsbDStAT-NwDYhO4,12455
|
9
|
+
upgini/metrics.py,sha256=Bc1L9DUmEL8OWwNvIEjPjw5EyHSZbiu3v2hWyBmedis,45313
|
10
10
|
upgini/search_task.py,sha256=Q5HjBpLIB3OCxAD1zNv5yQ3ZNJx696WCK_-H35_y7Rs,17912
|
11
11
|
upgini/spinner.py,sha256=4iMd-eIe_BnkqFEMIliULTbj6rNI2HkN_VJ4qYe0cUc,1118
|
12
12
|
upgini/version_validator.py,sha256=DvbaAvuYFoJqYt0fitpsk6Xcv-H1BYDJYHUMxaKSH_Y,1509
|
@@ -19,7 +19,7 @@ upgini/autofe/date.py,sha256=MM1S-6imNSzCDOhbNnmsc_bwSqUWBcS8vWAdHF8j1kY,11134
|
|
19
19
|
upgini/autofe/feature.py,sha256=cu4xXjzVVF13ZV4RxuTrysK2qCfezlRCMOzCKRo1rNs,15558
|
20
20
|
upgini/autofe/groupby.py,sha256=IYmQV9uoCdRcpkeWZj_kI3ObzoNCNx3ff3h8sTL01tk,3603
|
21
21
|
upgini/autofe/operator.py,sha256=EOffJw6vKXpEh5yymqb1RFNJPxGxmnHdFRo9dB5SCFo,4969
|
22
|
-
upgini/autofe/unary.py,sha256=
|
22
|
+
upgini/autofe/unary.py,sha256=Sx11IoHRh5nwyALzjgG9GQOrVNIs8NZ1JzunAJuN66A,5731
|
23
23
|
upgini/autofe/utils.py,sha256=dYrtyAM8Vcc_R8u4dNo54IsGrHKagTHDJTKhGho0bRg,2967
|
24
24
|
upgini/autofe/vector.py,sha256=jHs0nNTOaHspYUlxW7fjQepk4cvr_JDQ65L1OCiVsds,1360
|
25
25
|
upgini/autofe/timeseries/__init__.py,sha256=PGwwDAMwvkXl3el12tXVEmZUgDUvlmIPlXtROm6bD18,738
|
@@ -38,7 +38,7 @@ upgini/normalizer/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU
|
|
38
38
|
upgini/normalizer/normalize_utils.py,sha256=g2TcDXZeJp9kAFO2sTqZ4CAsN4J1qHNgoJHZ8gtzUWo,7376
|
39
39
|
upgini/resource_bundle/__init__.py,sha256=S5F2G47pnJd2LDpmFsjDqEwiKkP8Hm-hcseDbMka6Ko,8345
|
40
40
|
upgini/resource_bundle/exceptions.py,sha256=5fRvx0_vWdE1-7HcSgF0tckB4A9AKyf5RiinZkInTsI,621
|
41
|
-
upgini/resource_bundle/strings.properties,sha256=
|
41
|
+
upgini/resource_bundle/strings.properties,sha256=Hfpr2-I5Ws6ugIN1QSz549OHayZeLYglRsbrGDT6g9g,28491
|
42
42
|
upgini/resource_bundle/strings_widget.properties,sha256=gOdqvZWntP2LCza_tyVk1_yRYcG4c04K9sQOAVhF_gw,1577
|
43
43
|
upgini/sampler/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
44
44
|
upgini/sampler/base.py,sha256=7GpjYqjOp58vYcJLiX__1R5wjUlyQbxvHJ2klFnup_M,6389
|
@@ -56,7 +56,7 @@ upgini/utils/deduplicate_utils.py,sha256=EpBVCov42-FJIAPfa4jY_ZRct3N2MFaC7i-oJNZ
|
|
56
56
|
upgini/utils/display_utils.py,sha256=hAeWEcJtPDg8fAVcMNrNB-azFD2WJp1nvbPAhR7SeP4,12071
|
57
57
|
upgini/utils/email_utils.py,sha256=pZ2vCfNxLIPUhxr0-OlABNXm12jjU44isBk8kGmqQzA,5277
|
58
58
|
upgini/utils/fallback_progress_bar.py,sha256=PDaKb8dYpVZaWMroNcOHsTc3pSjgi9mOm0--cOFTwJ0,1074
|
59
|
-
upgini/utils/feature_info.py,sha256=
|
59
|
+
upgini/utils/feature_info.py,sha256=b3RvAeOHSEu-ZXWTrf42Dll_3ZUBL0pw7sdk7hgUKD0,7284
|
60
60
|
upgini/utils/features_validator.py,sha256=lEfmk4DoxZ4ooOE1HC0ZXtUb_lFKRFHIrnFULZ4_rL8,3746
|
61
61
|
upgini/utils/format.py,sha256=Yv5cvvSs2bOLUzzNu96Pu33VMDNbabio92QepUj41jU,243
|
62
62
|
upgini/utils/ip_utils.py,sha256=TSQ_qDsLlVnm09X1HacpabEf_HNqSWpxBF4Sdc2xs08,6580
|
@@ -64,13 +64,14 @@ upgini/utils/mstats.py,sha256=u3gQVUtDRbyrOQK6V1UJ2Rx1QbkSNYGjXa6m3Z_dPVs,6286
|
|
64
64
|
upgini/utils/phone_utils.py,sha256=IrbztLuOJBiePqqxllfABWfYlfAjYevPhXKipl95wUI,10432
|
65
65
|
upgini/utils/postal_code_utils.py,sha256=5M0sUqH2DAr33kARWCTXR-ACyzWbjDq_-0mmEml6ZcU,1716
|
66
66
|
upgini/utils/progress_bar.py,sha256=N-Sfdah2Hg8lXP_fV9EfUTXz_PyRt4lo9fAHoUDOoLc,1550
|
67
|
+
upgini/utils/sample_utils.py,sha256=ETLPKQU_YngiYbdlnEoF2h7QS-3oN8et54q3Qs2ZAbA,15417
|
67
68
|
upgini/utils/sklearn_ext.py,sha256=jLJWAKkqQinV15Z4y1ZnsN3c-fKFwXTsprs00COnyVU,49315
|
68
69
|
upgini/utils/sort.py,sha256=8uuHs2nfSMVnz8GgvbOmgMB1PgEIZP1uhmeRFxcwnYw,7039
|
69
|
-
upgini/utils/target_utils.py,sha256=
|
70
|
+
upgini/utils/target_utils.py,sha256=i3Xt5l9ybB2_nF_ma5cfPuL3OeFTs2dY2xDI0p4Azpg,9049
|
70
71
|
upgini/utils/track_info.py,sha256=G5Lu1xxakg2_TQjKZk4b5SvrHsATTXNVV3NbvWtT8k8,5663
|
71
72
|
upgini/utils/ts_utils.py,sha256=26vhC0pN7vLXK6R09EEkMK3Lwb9IVPH7LRdqFIQ3kPs,1383
|
72
73
|
upgini/utils/warning_counter.py,sha256=-GRY8EUggEBKODPSuXAkHn9KnEQwAORC0mmz_tim-PM,254
|
73
|
-
upgini-1.2.
|
74
|
-
upgini-1.2.
|
75
|
-
upgini-1.2.
|
76
|
-
upgini-1.2.
|
74
|
+
upgini-1.2.92.dist-info/METADATA,sha256=yXqDsCwRNGqlytVFuoBL04Swo6xYo5lsk9_YHj-6PfQ,49536
|
75
|
+
upgini-1.2.92.dist-info/WHEEL,sha256=zEMcRr9Kr03x1ozGwg5v9NQBKn3kndp6LSoSlVg-jhU,87
|
76
|
+
upgini-1.2.92.dist-info/licenses/LICENSE,sha256=5RRzgvdJUu3BUDfv4bzVU6FqKgwHlIay63pPCSmSgzw,1514
|
77
|
+
upgini-1.2.92.dist-info/RECORD,,
|
File without changes
|
File without changes
|