upgini 1.1.262a3250.post4__py3-none-any.whl → 1.1.280a3418.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of upgini might be problematic. Click here for more details.

Files changed (49)
  1. upgini/__about__.py +1 -0
  2. upgini/ads.py +6 -2
  3. upgini/ads_management/ads_manager.py +4 -2
  4. upgini/autofe/all_operands.py +16 -4
  5. upgini/autofe/binary.py +2 -1
  6. upgini/autofe/date.py +74 -7
  7. upgini/autofe/feature.py +1 -1
  8. upgini/autofe/groupby.py +3 -1
  9. upgini/autofe/operand.py +4 -3
  10. upgini/autofe/unary.py +20 -1
  11. upgini/autofe/vector.py +2 -0
  12. upgini/data_source/data_source_publisher.py +14 -4
  13. upgini/dataset.py +8 -7
  14. upgini/errors.py +1 -1
  15. upgini/features_enricher.py +156 -63
  16. upgini/http.py +11 -10
  17. upgini/mdc/__init__.py +1 -3
  18. upgini/mdc/context.py +4 -6
  19. upgini/metadata.py +3 -0
  20. upgini/metrics.py +160 -96
  21. upgini/normalizer/phone_normalizer.py +2 -2
  22. upgini/resource_bundle/__init__.py +5 -5
  23. upgini/resource_bundle/strings.properties +9 -4
  24. upgini/sampler/base.py +1 -4
  25. upgini/sampler/random_under_sampler.py +2 -5
  26. upgini/search_task.py +4 -4
  27. upgini/spinner.py +1 -1
  28. upgini/utils/__init__.py +3 -2
  29. upgini/utils/base_search_key_detector.py +2 -2
  30. upgini/utils/blocked_time_series.py +4 -2
  31. upgini/utils/country_utils.py +2 -2
  32. upgini/utils/custom_loss_utils.py +3 -2
  33. upgini/utils/cv_utils.py +2 -2
  34. upgini/utils/datetime_utils.py +75 -18
  35. upgini/utils/deduplicate_utils.py +61 -18
  36. upgini/utils/email_utils.py +3 -3
  37. upgini/utils/fallback_progress_bar.py +1 -1
  38. upgini/utils/features_validator.py +2 -1
  39. upgini/utils/progress_bar.py +1 -1
  40. upgini/utils/sklearn_ext.py +15 -15
  41. upgini/utils/target_utils.py +21 -7
  42. upgini/utils/track_info.py +27 -15
  43. upgini/version_validator.py +2 -2
  44. {upgini-1.1.262a3250.post4.dist-info → upgini-1.1.280a3418.post2.dist-info}/METADATA +21 -23
  45. upgini-1.1.280a3418.post2.dist-info/RECORD +62 -0
  46. {upgini-1.1.262a3250.post4.dist-info → upgini-1.1.280a3418.post2.dist-info}/WHEEL +1 -2
  47. upgini-1.1.262a3250.post4.dist-info/RECORD +0 -62
  48. upgini-1.1.262a3250.post4.dist-info/top_level.txt +0 -1
  49. {upgini-1.1.262a3250.post4.dist-info → upgini-1.1.280a3418.post2.dist-info/licenses}/LICENSE +0 -0
upgini/sampler/random_under_sampler.py CHANGED
@@ -5,13 +5,10 @@
5
5
  # License: MIT
6
6
 
7
7
  import numpy as np
8
-
9
- from sklearn.utils import check_random_state
10
- from sklearn.utils import _safe_indexing
8
+ from sklearn.utils import _safe_indexing, check_random_state
11
9
 
12
10
  from .base import BaseUnderSampler
13
- from .utils import check_target_type
14
- from .utils import _deprecate_positional_args
11
+ from .utils import _deprecate_positional_args, check_target_type
15
12
 
16
13
 
17
14
  class RandomUnderSampler(BaseUnderSampler):
upgini/search_task.py CHANGED
@@ -8,10 +8,10 @@ import pandas as pd
8
8
 
9
9
  from upgini import dataset
10
10
  from upgini.http import (
11
- _RestClient,
12
11
  ProviderTaskSummary,
13
12
  SearchProgress,
14
13
  SearchTaskSummary,
14
+ _RestClient,
15
15
  get_rest_client,
16
16
  is_demo_api_key,
17
17
  )
@@ -295,7 +295,7 @@ class SearchTask:
295
295
  return self.rest_client.get_search_file_metadata(self.search_task_id, trace_id)
296
296
 
297
297
 
298
- @lru_cache()
298
+ @lru_cache
299
299
  def _get_all_initial_raw_features_cached(
300
300
  endpoint: Optional[str],
301
301
  api_key: Optional[str],
@@ -328,7 +328,7 @@ def _get_all_initial_raw_features_cached(
328
328
  return result_df
329
329
 
330
330
 
331
- @lru_cache()
331
+ @lru_cache
332
332
  def _get_all_validation_raw_features_cached(
333
333
  endpoint: Optional[str],
334
334
  api_key: Optional[str],
@@ -357,7 +357,7 @@ def _get_all_validation_raw_features_cached(
357
357
  return result_df
358
358
 
359
359
 
360
- @lru_cache()
360
+ @lru_cache
361
361
  def _get_target_outliers_cached(
362
362
  endpoint: Optional[str],
363
363
  api_key: Optional[str],
upgini/spinner.py CHANGED
@@ -1,6 +1,6 @@
1
1
  import threading
2
- from typing import Optional, List
3
2
  import time
3
+ from typing import List, Optional
4
4
 
5
5
 
6
6
  class Spinner:
upgini/utils/__init__.py CHANGED
@@ -2,7 +2,7 @@ import itertools
2
2
  from typing import List, Tuple
3
3
 
4
4
  import pandas as pd
5
- from pandas.api.types import is_string_dtype
5
+ from pandas.api.types import is_object_dtype, is_string_dtype
6
6
 
7
7
 
8
8
  def combine_search_keys(search_keys: List[str]) -> List[Tuple[str]]:
@@ -20,5 +20,6 @@ def find_numbers_with_decimal_comma(df: pd.DataFrame) -> pd.DataFrame:
20
20
  return [
21
21
  col
22
22
  for col in tmp.columns
23
- if is_string_dtype(tmp[col]) and tmp[col].astype("string").str.match("^[0-9]+,[0-9]*$").any()
23
+ if (is_string_dtype(tmp[col]) or is_object_dtype(tmp[col]))
24
+ and tmp[col].astype("string").str.match("^[0-9]+,[0-9]*$").any()
24
25
  ]
upgini/utils/base_search_key_detector.py CHANGED
@@ -5,10 +5,10 @@ import pandas as pd
5
5
 
6
6
  class BaseSearchKeyDetector:
7
7
  def _is_search_key_by_name(self, column_name: str) -> bool:
8
- raise NotImplementedError()
8
+ raise NotImplementedError
9
9
 
10
10
  def _is_search_key_by_values(self, column: pd.Series) -> bool:
11
- raise NotImplementedError()
11
+ raise NotImplementedError
12
12
 
13
13
  def _get_search_key_by_name(self, column_names: List[str]) -> Optional[str]:
14
14
  for column_name in column_names:
upgini/utils/blocked_time_series.py CHANGED
@@ -1,8 +1,10 @@
1
- import numpy as np
2
1
  import numbers
2
+
3
+ import numpy as np
4
+ from sklearn.model_selection import BaseCrossValidator
3
5
  from sklearn.utils import indexable
4
6
  from sklearn.utils.validation import _num_samples
5
- from sklearn.model_selection import BaseCrossValidator
7
+
6
8
  from upgini.resource_bundle import bundle
7
9
 
8
10
 
upgini/utils/country_utils.py CHANGED
@@ -1,5 +1,5 @@
1
1
  import pandas as pd
2
- from pandas.api.types import is_string_dtype
2
+ from pandas.api.types import is_object_dtype, is_string_dtype
3
3
 
4
4
  from upgini.utils.base_search_key_detector import BaseSearchKeyDetector
5
5
 
@@ -9,7 +9,7 @@ class CountrySearchKeyDetector(BaseSearchKeyDetector):
9
9
  return "country" in str(column_name).lower()
10
10
 
11
11
  def _is_search_key_by_values(self, column: pd.Series) -> bool:
12
- if not is_string_dtype(column):
12
+ if not is_string_dtype(column) and not is_object_dtype(column):
13
13
  return False
14
14
 
15
15
  all_count = len(column)
upgini/utils/custom_loss_utils.py CHANGED
@@ -1,6 +1,7 @@
1
- from upgini.metadata import ModelTaskType, RuntimeParameters
2
- from typing import Optional, Dict, Any
3
1
  import logging
2
+ from typing import Any, Dict, Optional
3
+
4
+ from upgini.metadata import ModelTaskType, RuntimeParameters
4
5
  from upgini.resource_bundle import bundle
5
6
 
6
7
 
upgini/utils/cv_utils.py CHANGED
@@ -1,9 +1,9 @@
1
1
  from functools import reduce
2
2
  from typing import Any, Dict, List, Optional, Tuple, Union
3
- import numpy as np
4
3
 
4
+ import numpy as np
5
5
  import pandas as pd
6
- from sklearn.model_selection import BaseCrossValidator, KFold, TimeSeriesSplit, GroupKFold, GroupShuffleSplit
6
+ from sklearn.model_selection import BaseCrossValidator, GroupKFold, GroupShuffleSplit, KFold, TimeSeriesSplit
7
7
 
8
8
  from upgini.metadata import CVType
9
9
  from upgini.utils.blocked_time_series import BlockedTimeSeriesSplit
upgini/utils/datetime_utils.py CHANGED
@@ -1,15 +1,20 @@
1
1
  import datetime
2
2
  import logging
3
3
  import re
4
- from typing import List, Optional
4
+ from typing import Dict, List, Optional
5
5
 
6
6
  import numpy as np
7
7
  import pandas as pd
8
8
  from dateutil.relativedelta import relativedelta
9
- from pandas.api.types import is_numeric_dtype, is_period_dtype, is_string_dtype
9
+ from pandas.api.types import (
10
+ is_numeric_dtype,
11
+ is_period_dtype,
12
+ )
10
13
 
11
14
  from upgini.errors import ValidationError
15
+ from upgini.metadata import SearchKey
12
16
  from upgini.resource_bundle import ResourceBundle, get_custom_bundle
17
+ from upgini.utils.warning_counter import WarningCounter
13
18
 
14
19
  DATE_FORMATS = [
15
20
  "%Y-%m-%d",
@@ -76,9 +81,6 @@ class DateTimeSearchKeyConverter:
76
81
  df[self.date_column] = df[self.date_column].apply(lambda x: x.replace(tzinfo=None))
77
82
  elif isinstance(df[self.date_column].values[0], datetime.date):
78
83
  df[self.date_column] = pd.to_datetime(df[self.date_column], errors="coerce")
79
- elif is_string_dtype(df[self.date_column]):
80
- df[self.date_column] = df[self.date_column].apply(self.clean_date)
81
- df[self.date_column] = self.parse_date(df)
82
84
  elif is_period_dtype(df[self.date_column]):
83
85
  df[self.date_column] = pd.to_datetime(df[self.date_column].astype("string"))
84
86
  elif is_numeric_dtype(df[self.date_column]):
@@ -98,6 +100,9 @@ class DateTimeSearchKeyConverter:
98
100
  msg = self.bundle.get("unsupported_date_type").format(self.date_column)
99
101
  self.logger.warning(msg)
100
102
  raise ValidationError(msg)
103
+ else:
104
+ df[self.date_column] = df[self.date_column].astype("string").apply(self.clean_date)
105
+ df[self.date_column] = self.parse_date(df)
101
106
 
102
107
  # If column with date is datetime then extract seconds of the day and minute of the hour
103
108
  # as additional features
@@ -121,9 +126,9 @@ class DateTimeSearchKeyConverter:
121
126
  df.drop(columns=seconds, inplace=True)
122
127
 
123
128
  if keep_time:
124
- df[self.DATETIME_COL] = df[self.date_column].view(np.int64) // 1_000_000
129
+ df[self.DATETIME_COL] = df[self.date_column].astype(np.int64) // 1_000_000
125
130
  df[self.DATETIME_COL] = df[self.DATETIME_COL].apply(self._int_to_opt).astype("Int64")
126
- df[self.date_column] = df[self.date_column].dt.floor("D").view(np.int64) // 1_000_000
131
+ df[self.date_column] = df[self.date_column].dt.floor("D").astype(np.int64) // 1_000_000
127
132
  df[self.date_column] = df[self.date_column].apply(self._int_to_opt).astype("Int64")
128
133
 
129
134
  self.logger.info(f"Date after convertion to timestamp: {df[self.date_column]}")
@@ -203,18 +208,17 @@ def is_blocked_time_series(df: pd.DataFrame, date_col: str, search_keys: List[st
203
208
  if nunique_dates / days_delta < 0.3:
204
209
  return False
205
210
 
206
- def check_differences(group):
207
- data = group.drop(date_col, axis=1)
208
- diffs = data.values[:, None] != data.values
209
- diff_counts = diffs.sum(axis=2)
210
- max_diff = np.max(diff_counts)
211
- return max_diff <= 2
211
+ accumulated_changing_columns = set()
212
212
 
213
- def is_multiple_rows(group):
213
+ def check_differences(group: pd.DataFrame):
214
+ changing_columns = group.columns[group.nunique(dropna=False) > 1].to_list()
215
+ accumulated_changing_columns.update(changing_columns)
216
+
217
+ def is_multiple_rows(group: pd.DataFrame) -> bool:
214
218
  return group.shape[0] > 1
215
219
 
216
- grouped = df.groupby(date_col)
217
- dates_with_multiple_rows = len(grouped.apply(is_multiple_rows))
220
+ grouped = df.groupby(date_col)[[c for c in df.columns if c != date_col]]
221
+ dates_with_multiple_rows = grouped.apply(is_multiple_rows).sum()
218
222
 
219
223
  # share of dates with more than one record is more than 99%
220
224
  if dates_with_multiple_rows / nunique_dates < 0.99:
@@ -223,5 +227,58 @@ def is_blocked_time_series(df: pd.DataFrame, date_col: str, search_keys: List[st
223
227
  if df.shape[1] <= 3:
224
228
  return True
225
229
 
226
- is_diff_less_than_two_columns = grouped.apply(check_differences)
227
- return is_diff_less_than_two_columns.all()
230
+ grouped.apply(check_differences)
231
+ return len(accumulated_changing_columns) <= 2
232
+
233
+
234
+ def validate_dates_distribution(
235
+ X: pd.DataFrame,
236
+ search_keys: Dict[str, SearchKey],
237
+ logger: Optional[logging.Logger] = None,
238
+ bundle: Optional[ResourceBundle] = None,
239
+ warning_counter: Optional[WarningCounter] = None,
240
+ ):
241
+ maybe_date_col = None
242
+ for key, key_type in search_keys.items():
243
+ if key_type in [SearchKey.DATE, SearchKey.DATETIME]:
244
+ maybe_date_col = key
245
+
246
+ if maybe_date_col is None:
247
+ for col in X.columns:
248
+ if col in search_keys:
249
+ continue
250
+ try:
251
+ if pd.__version__ >= "2.0.0":
252
+ # Format mixed to avoid massive warnings
253
+ pd.to_datetime(X[col], format="mixed")
254
+ else:
255
+ pd.to_datetime(X[col])
256
+ maybe_date_col = col
257
+ break
258
+ except Exception:
259
+ pass
260
+
261
+ if maybe_date_col is None:
262
+ return
263
+
264
+ if pd.__version__ >= "2.0.0":
265
+ dates = pd.to_datetime(X[maybe_date_col], format="mixed").dt.date
266
+ else:
267
+ dates = pd.to_datetime(X[maybe_date_col]).dt.date
268
+
269
+ date_counts = dates.value_counts().sort_index()
270
+
271
+ date_counts_1 = date_counts[: round(len(date_counts) / 2)]
272
+ date_counts_2 = date_counts[round(len(date_counts) / 2) :]
273
+ ratio = date_counts_2.mean() / date_counts_1.mean()
274
+
275
+ if ratio > 1.2 or ratio < 0.8:
276
+ if warning_counter is not None:
277
+ warning_counter.increment()
278
+ if logger is None:
279
+ logger = logging.getLogger("muted_logger")
280
+ logger.setLevel("FATAL")
281
+ bundle = bundle or get_custom_bundle()
282
+ msg = bundle.get("x_unstable_by_date")
283
+ print(msg)
284
+ logger.warning(msg)
upgini/utils/deduplicate_utils.py CHANGED
@@ -3,7 +3,7 @@ from typing import Dict, List, Optional, Union
3
3
 
4
4
  import pandas as pd
5
5
 
6
- from upgini.metadata import SORT_ID, SYSTEM_RECORD_ID, TARGET, ModelTaskType, SearchKey
6
+ from upgini.metadata import EVAL_SET_INDEX, SORT_ID, SYSTEM_RECORD_ID, TARGET, ModelTaskType, SearchKey
7
7
  from upgini.resource_bundle import ResourceBundle
8
8
  from upgini.utils.datetime_utils import DateTimeSearchKeyConverter
9
9
  from upgini.utils.target_utils import define_task
@@ -78,20 +78,58 @@ def remove_fintech_duplicates(
78
78
  rows_with_diff_target = grouped_by_personal_cols.filter(has_diff_target_within_60_days)
79
79
  if len(rows_with_diff_target) > 0:
80
80
  unique_keys_to_delete = rows_with_diff_target[personal_cols].drop_duplicates()
81
- rows_to_remove = pd.merge(df.reset_index(), unique_keys_to_delete, on=personal_cols)
82
- rows_to_remove = rows_to_remove.set_index(df.index.name or "index")
83
- perc = len(rows_to_remove) * 100 / len(df)
84
- msg = bundle.get("dataset_diff_target_duplicates_fintech").format(
85
- perc, len(rows_to_remove), rows_to_remove.index.to_list()
86
- )
87
- if not silent:
88
- print(msg)
89
- if logger:
90
- logger.warning(msg)
91
- logger.info(f"Dataset shape before clean fintech duplicates: {df.shape}")
92
- df = df[~df.index.isin(rows_to_remove.index)]
93
- logger.info(f"Dataset shape after clean fintech duplicates: {df.shape}")
94
-
81
+ if EVAL_SET_INDEX not in df.columns:
82
+ rows_to_remove = pd.merge(df.reset_index(), unique_keys_to_delete, on=personal_cols)
83
+ rows_to_remove = rows_to_remove.set_index(df.index.name or "index")
84
+ perc = len(rows_to_remove) * 100 / len(df)
85
+ msg = bundle.get("dataset_train_diff_target_duplicates_fintech").format(
86
+ perc, len(rows_to_remove), rows_to_remove.index.to_list()
87
+ )
88
+ if not silent:
89
+ print(msg)
90
+ if logger:
91
+ logger.warning(msg)
92
+ logger.info(f"Dataset shape before clean fintech duplicates: {df.shape}")
93
+ df = df[~df.index.isin(rows_to_remove.index)]
94
+ logger.info(f"Dataset shape after clean fintech duplicates: {df.shape}")
95
+ else:
96
+ # Indices in train and eval_set can be the same so we remove rows from them separately
97
+ train = df.query(f"{EVAL_SET_INDEX} == 0")
98
+ train_rows_to_remove = pd.merge(train.reset_index(), unique_keys_to_delete, on=personal_cols)
99
+ train_rows_to_remove = train_rows_to_remove.set_index(train.index.name or "index")
100
+ train_perc = len(train_rows_to_remove) * 100 / len(train)
101
+ msg = bundle.get("dataset_train_diff_target_duplicates_fintech").format(
102
+ train_perc, len(train_rows_to_remove), train_rows_to_remove.index.to_list()
103
+ )
104
+ if not silent:
105
+ print(msg)
106
+ if logger:
107
+ logger.warning(msg)
108
+ logger.info(f"Train dataset shape before clean fintech duplicates: {train.shape}")
109
+ train = train[~train.index.isin(train_rows_to_remove.index)]
110
+ logger.info(f"Train dataset shape after clean fintech duplicates: {train.shape}")
111
+
112
+ evals = [df.query(f"{EVAL_SET_INDEX} == {i}") for i in df[EVAL_SET_INDEX].unique() if i != 0]
113
+ new_evals = []
114
+ for i, eval in enumerate(evals):
115
+ eval_rows_to_remove = pd.merge(eval.reset_index(), unique_keys_to_delete, on=personal_cols)
116
+ eval_rows_to_remove = eval_rows_to_remove.set_index(eval.index.name or "index")
117
+ eval_perc = len(eval_rows_to_remove) * 100 / len(eval)
118
+ msg = bundle.get("dataset_eval_diff_target_duplicates_fintech").format(
119
+ eval_perc, len(eval_rows_to_remove), i + 1, eval_rows_to_remove.index.to_list()
120
+ )
121
+ if not silent:
122
+ print(msg)
123
+ if logger:
124
+ logger.warning(msg)
125
+ logger.info(f"Eval {i + 1} dataset shape before clean fintech duplicates: {eval.shape}")
126
+ eval = eval[~eval.index.isin(eval_rows_to_remove.index)]
127
+ logger.info(f"Eval {i + 1} dataset shape after clean fintech duplicates: {eval.shape}")
128
+ new_evals.append(eval)
129
+
130
+ logger.info(f"Dataset shape before clean fintech duplicates: {df.shape}")
131
+ df = pd.concat([train] + new_evals)
132
+ logger.info(f"Dataset shape after clean fintech duplicates: {df.shape}")
95
133
  return df
96
134
 
97
135
 
@@ -101,14 +139,18 @@ def clean_full_duplicates(
101
139
  nrows = len(df)
102
140
  if nrows == 0:
103
141
  return df
104
- # Remove absolute duplicates (exclude system_record_id)
142
+ # Remove full duplicates (exclude system_record_id, sort_id and eval_set_index)
105
143
  unique_columns = df.columns.tolist()
106
144
  if SYSTEM_RECORD_ID in unique_columns:
107
145
  unique_columns.remove(SYSTEM_RECORD_ID)
108
146
  if SORT_ID in unique_columns:
109
147
  unique_columns.remove(SORT_ID)
148
+ if EVAL_SET_INDEX in unique_columns:
149
+ unique_columns.remove(EVAL_SET_INDEX)
110
150
  logger.info(f"Dataset shape before clean duplicates: {df.shape}")
111
- df = df.drop_duplicates(subset=unique_columns)
151
+ # Train segment goes first so if duplicates are found in train and eval set
152
+ # then we keep unique rows in train segment
153
+ df = df.drop_duplicates(subset=unique_columns, keep="first")
112
154
  logger.info(f"Dataset shape after clean duplicates: {df.shape}")
113
155
  nrows_after_full_dedup = len(df)
114
156
  share_full_dedup = 100 * (1 - nrows_after_full_dedup / nrows)
@@ -123,7 +165,7 @@ def clean_full_duplicates(
123
165
  marked_duplicates = df.duplicated(subset=unique_columns, keep=False)
124
166
  if marked_duplicates.sum() > 0:
125
167
  dups_indices = df[marked_duplicates].index.to_list()
126
- nrows_after_tgt_dedup = len(df.drop_duplicates(subset=unique_columns))
168
+ nrows_after_tgt_dedup = len(df.drop_duplicates(subset=unique_columns, keep=False))
127
169
  num_dup_rows = nrows_after_full_dedup - nrows_after_tgt_dedup
128
170
  share_tgt_dedup = 100 * num_dup_rows / nrows_after_full_dedup
129
171
 
@@ -133,6 +175,7 @@ def clean_full_duplicates(
133
175
  print(msg)
134
176
  df = df.drop_duplicates(subset=unique_columns, keep=False)
135
177
  logger.info(f"Dataset shape after clean invalid target duplicates: {df.shape}")
178
+
136
179
  return df
137
180
 
138
181
 
upgini/utils/email_utils.py CHANGED
@@ -4,10 +4,10 @@ from hashlib import sha256
4
4
  from typing import Dict, List, Optional
5
5
 
6
6
  import pandas as pd
7
- from pandas.api.types import is_string_dtype
8
- from upgini.resource_bundle import bundle
7
+ from pandas.api.types import is_object_dtype, is_string_dtype
9
8
 
10
9
  from upgini.metadata import SearchKey
10
+ from upgini.resource_bundle import bundle
11
11
  from upgini.utils.base_search_key_detector import BaseSearchKeyDetector
12
12
 
13
13
  EMAIL_REGEX = re.compile(r"^[a-zA-Z0-9.!#$%&’*+/=?^_`{|}~-]+@[a-zA-Z0-9-]+(?:\.[a-zA-Z0-9-]+)*$")
@@ -18,7 +18,7 @@ class EmailSearchKeyDetector(BaseSearchKeyDetector):
18
18
  return str(column_name).lower() in ["email", "e_mail", "e-mail"]
19
19
 
20
20
  def _is_search_key_by_values(self, column: pd.Series) -> bool:
21
- if not is_string_dtype(column):
21
+ if not is_string_dtype(column) and not is_object_dtype:
22
22
  return False
23
23
  if not column.astype("string").str.contains("@").any():
24
24
  return False
upgini/utils/fallback_progress_bar.py CHANGED
@@ -22,7 +22,7 @@ class CustomFallbackProgressBar:
22
22
  fraction = self.progress / self.total
23
23
  filled = "=" * int(fraction * self.text_width)
24
24
  rest = " " * (self.text_width - len(filled))
25
- return "[{}{}] {}% {} {}".format(filled, rest, self.progress, self._stage, self._eta)
25
+ return f"[{filled}{rest}] {self.progress}% {self._stage} {self._eta}"
26
26
 
27
27
  def display(self):
28
28
  print(self)
upgini/utils/features_validator.py CHANGED
@@ -81,7 +81,8 @@ class FeaturesValidator:
81
81
  return [
82
82
  i
83
83
  for i in df
84
- if (is_string_dtype(df[i]) or is_integer_dtype(df[i])) and (df[i].nunique(dropna=False) / row_count >= 0.95)
84
+ if (is_object_dtype(df[i]) or is_string_dtype(df[i]) or is_integer_dtype(df[i]))
85
+ and (df[i].nunique(dropna=False) / row_count >= 0.85)
85
86
  ]
86
87
 
87
88
  @staticmethod
upgini/utils/progress_bar.py CHANGED
@@ -28,7 +28,7 @@ class CustomProgressBar(DisplayObject):
28
28
  fraction = self.progress / self.total
29
29
  filled = "=" * int(fraction * self.text_width)
30
30
  rest = " " * (self.text_width - len(filled))
31
- return "[{}{}] {}% {}".format(filled, rest, self.progress, self._stage)
31
+ return f"[{filled}{rest}] {self.progress}% {self._stage}"
32
32
 
33
33
  def _repr_html_(self):
34
34
  return "<progress style='width:{}' max='{}' value='{}'></progress> {}% {}</br>{}".format(
upgini/utils/sklearn_ext.py CHANGED
@@ -1,5 +1,4 @@
1
1
  import functools
2
- import logging
3
2
  import numbers
4
3
  import time
5
4
  import warnings
@@ -21,6 +20,7 @@ from sklearn.metrics._scorer import _MultimetricScorer
21
20
  from sklearn.model_selection import check_cv
22
21
  from sklearn.utils.fixes import np_version, parse_version
23
22
  from sklearn.utils.validation import indexable
23
+
24
24
  # from sklearn.model_selection import cross_validate as original_cross_validate
25
25
 
26
26
  _DEFAULT_TAGS = {
@@ -47,7 +47,7 @@ _DEFAULT_TAGS = {
47
47
 
48
48
  def cross_validate(
49
49
  estimator,
50
- X,
50
+ x,
51
51
  y=None,
52
52
  *,
53
53
  groups=None,
@@ -70,7 +70,7 @@ def cross_validate(
70
70
  estimator : estimator object implementing 'fit'
71
71
  The object to use to fit the data.
72
72
 
73
- X : array-like of shape (n_samples, n_features)
73
+ x : array-like of shape (n_samples, n_features)
74
74
  The data to fit. Can be for example a list, or an array.
75
75
 
76
76
  y : array-like of shape (n_samples,) or (n_samples, n_outputs), \
@@ -251,7 +251,7 @@ def cross_validate(
251
251
 
252
252
  """
253
253
  try:
254
- X, y, groups = indexable(X, y, groups)
254
+ x, y, groups = indexable(x, y, groups)
255
255
 
256
256
  cv = check_cv(cv, y, classifier=is_classifier(estimator))
257
257
 
@@ -268,7 +268,7 @@ def cross_validate(
268
268
  results = parallel(
269
269
  delayed(_fit_and_score)(
270
270
  clone(estimator),
271
- X,
271
+ x,
272
272
  y,
273
273
  scorers,
274
274
  train,
@@ -281,7 +281,7 @@ def cross_validate(
281
281
  return_estimator=return_estimator,
282
282
  error_score=error_score,
283
283
  )
284
- for train, test in cv.split(X, y, groups)
284
+ for train, test in cv.split(x, y, groups)
285
285
  )
286
286
 
287
287
  _warn_about_fit_failures(results, error_score)
@@ -313,7 +313,7 @@ def cross_validate(
313
313
 
314
314
  return ret
315
315
  except Exception:
316
- logging.exception("Failed to execute overriden cross_validate. Fallback to original")
316
+ # logging.exception("Failed to execute overriden cross_validate. Fallback to original")
317
317
  raise
318
318
  # fit_params["use_best_model"] = False
319
319
  # return original_cross_validate(
@@ -488,7 +488,7 @@ def _fit_and_score(
488
488
  if y_train is None:
489
489
  estimator.fit(X_train, **fit_params)
490
490
  else:
491
- if isinstance(estimator, CatBoostClassifier) or isinstance(estimator, CatBoostRegressor):
491
+ if isinstance(estimator, (CatBoostClassifier, CatBoostRegressor)):
492
492
  fit_params = fit_params.copy()
493
493
  fit_params["eval_set"] = [(X_test, y_test)]
494
494
  estimator.fit(X_train, y_train, **fit_params)
@@ -583,9 +583,11 @@ def _aggregate_score_dicts(scores):
583
583
  """
584
584
 
585
585
  return {
586
- key: np.asarray([score[key] for score in scores])
587
- if isinstance(scores[0][key], numbers.Number)
588
- else [score[key] for score in scores]
586
+ key: (
587
+ np.asarray([score[key] for score in scores])
588
+ if isinstance(scores[0][key], numbers.Number)
589
+ else [score[key] for score in scores]
590
+ )
589
591
  for key in scores[0]
590
592
  }
591
593
 
@@ -970,9 +972,7 @@ def _safe_indexing(X, indices, *, axis=0):
970
972
  return X
971
973
 
972
974
  if axis not in (0, 1):
973
- raise ValueError(
974
- "'axis' should be either 0 (to index rows) or 1 (to index " " column). Got {} instead.".format(axis)
975
- )
975
+ raise ValueError("'axis' should be either 0 (to index rows) or 1 (to index " f" column). Got {axis} instead.")
976
976
 
977
977
  indices_dtype = _determine_key_type(indices)
978
978
 
@@ -983,7 +983,7 @@ def _safe_indexing(X, indices, *, axis=0):
983
983
  raise ValueError(
984
984
  "'X' should be a 2D NumPy array, 2D sparse matrix or pandas "
985
985
  "dataframe when indexing the columns (i.e. 'axis=1'). "
986
- "Got {} instead with {} dimension(s).".format(type(X), X.ndim)
986
+ f"Got {type(X)} instead with {X.ndim} dimension(s)."
987
987
  )
988
988
 
989
989
  if axis == 1 and indices_dtype == "str" and not hasattr(X, "loc"):
upgini/utils/target_utils.py CHANGED
@@ -107,7 +107,7 @@ def balance_undersample(
107
107
  min_class_count = vc[min_class_value]
108
108
 
109
109
  min_class_percent = imbalance_threshold / target_classes_count
110
- min_class_threshold = min_class_percent * count
110
+ min_class_threshold = int(min_class_percent * count)
111
111
 
112
112
  resampled_data = df
113
113
  df = df.copy().sort_values(by=SYSTEM_RECORD_ID)
@@ -132,9 +132,7 @@ def balance_undersample(
132
132
  class_value = classes[class_idx]
133
133
  class_count = vc[class_value]
134
134
  sample_strategy[class_value] = min(class_count, quantile25_class_cnt * multiclass_bootstrap_loops)
135
- sampler = RandomUnderSampler(
136
- sampling_strategy=sample_strategy, random_state=random_state
137
- )
135
+ sampler = RandomUnderSampler(sampling_strategy=sample_strategy, random_state=random_state)
138
136
  X = df[SYSTEM_RECORD_ID]
139
137
  X = X.to_frame(SYSTEM_RECORD_ID)
140
138
  new_x, _ = sampler.fit_resample(X, target) # type: ignore
@@ -153,9 +151,7 @@ def balance_undersample(
153
151
  minority_class = df[df[target_column] == min_class_value]
154
152
  majority_class = df[df[target_column] != min_class_value]
155
153
  sample_size = min(len(majority_class), min_sample_threshold - min_class_count)
156
- sampled_majority_class = majority_class.sample(
157
- n=sample_size, random_state=random_state
158
- )
154
+ sampled_majority_class = majority_class.sample(n=sample_size, random_state=random_state)
159
155
  resampled_data = df[
160
156
  (df[SYSTEM_RECORD_ID].isin(minority_class[SYSTEM_RECORD_ID]))
161
157
  | (df[SYSTEM_RECORD_ID].isin(sampled_majority_class[SYSTEM_RECORD_ID]))
@@ -181,3 +177,21 @@ def balance_undersample(
181
177
 
182
178
  logger.info(f"Shape after rebalance resampling: {resampled_data}")
183
179
  return resampled_data
180
+
181
+
182
+ def calculate_psi(expected: pd.Series, actual: pd.Series) -> float:
183
+ df = pd.concat([expected, actual])
184
+
185
+ # Define the bins for the target variable
186
+ df_min = df.min()
187
+ df_max = df.max()
188
+ bins = [df_min, (df_min + df_max) / 2, df_max]
189
+
190
+ # Calculate the base distribution
191
+ train_distribution = expected.value_counts(bins=bins, normalize=True).sort_index().values
192
+
193
+ # Calculate the target distribution
194
+ test_distribution = actual.value_counts(bins=bins, normalize=True).sort_index().values
195
+
196
+ # Calculate the PSI
197
+ return np.sum((train_distribution - test_distribution) * np.log(train_distribution / test_distribution))