upgini 1.2.113a5__py3-none-any.whl → 1.2.113a3974.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
upgini/metadata.py CHANGED
@@ -285,7 +285,6 @@ class FeaturesMetadataV2(BaseModel):
     doc_link: Optional[str] = None
     update_frequency: Optional[str] = None
     from_online_api: Optional[bool] = None
-    psi_value: Optional[float] = None


 class HitRateMetrics(BaseModel):
upgini/metrics.py CHANGED
@@ -1175,10 +1175,7 @@ def _ext_mean_squared_log_error(y_true, y_pred, *, sample_weight=None, multioutp
     >>> mean_squared_log_error(y_true, y_pred, multioutput=[0.3, 0.7])
     0.060...
     """
-    try:
-        _, y_true, y_pred, multioutput = _check_reg_targets(y_true, y_pred, multioutput)
-    except TypeError:
-        _, y_true, y_pred, sample_weight, multioutput = _check_reg_targets(y_true, y_pred, sample_weight, multioutput)
+    _, y_true, y_pred, multioutput = _check_reg_targets(y_true, y_pred, multioutput)
     check_consistent_length(y_true, y_pred, sample_weight)

     if (y_true < 0).any():
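Editor's note: the metric guarded above is the standard mean squared log error, computed on log1p-transformed targets. A minimal standalone sketch with invented arrays, independent of upgini's internal `_check_reg_targets` handling:

```python
import numpy as np

def msle(y_true, y_pred):
    # MSLE = mean((log(1 + y_true) - log(1 + y_pred)) ** 2); inputs must be non-negative
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    return float(np.mean((np.log1p(y_true) - np.log1p(y_pred)) ** 2))

print(msle([3, 5, 2.5, 7], [2.5, 5, 4, 8]))  # ~0.0397
```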
upgini/resource_bundle/strings.properties CHANGED
@@ -123,7 +123,7 @@ train_unstable_target=Your training sample contains an unstable target event, PS
 eval_unstable_target=Your training and evaluation samples have a difference in target distribution. PSI = {}. The results will be unstable. It is recommended to redesign the training and evaluation samples
 # eval set validation
 unsupported_type_eval_set=Unsupported type of eval_set: {}. It should be list of tuples with two elements: X and y
-eval_set_invalid_tuple_size=eval_set contains a tuple of size {}. It should contain only pairs of X and y or X only
+eval_set_invalid_tuple_size=eval_set contains a tuple of size {}. It should contain only pairs of X and y
 unsupported_x_type_eval_set=Unsupported type of X in eval_set: {}. Use pandas.DataFrame, pandas.Series or numpy.ndarray or list.
 eval_x_and_x_diff_shape=The column set in eval_set are differ from the column set in X
 unsupported_y_type_eval_set=Unsupported type of y in eval_set: {}. Use pandas.Series, numpy.ndarray or list
@@ -139,8 +139,6 @@ eval_x_is_empty=X in eval_set is empty.
 eval_y_is_empty=y in eval_set is empty.
 x_and_eval_x_diff_types=X and eval_set X has different types: {} and {}
 eval_x_has_train_samples=Eval set X has rows that are present in train set X
-oot_without_date_not_supported=Eval set {} provided as OOT but date column is missing. It will be ignored for stability check
-oot_with_online_sources_not_supported=Eval set {} provided as OOT and also provided columns for online API. It will be ignored for stability check

 baseline_score_column_not_exists=baseline_score_column {} doesn't exist in input dataframe
 baseline_score_column_has_na=baseline_score_column contains NaN. Clear it and and retry
@@ -257,7 +255,6 @@ features_info_provider=Provider
 features_info_source=Source
 features_info_name=Feature name
 features_info_shap=SHAP value
-features_info_psi=PSI value
 features_info_hitrate=Coverage %
 features_info_type=Type
 # Deprecated
upgini/sampler/base.py CHANGED
@@ -1,7 +1,6 @@
 """
 Base class for the under-sampling method.
 """
-
 # Authors: Guillaume Lemaitre <g.lemaitre58@gmail.com>
 # License: MIT

@@ -13,7 +12,6 @@ import numpy as np
 from sklearn.base import BaseEstimator
 from sklearn.preprocessing import label_binarize
 from sklearn.utils.multiclass import check_classification_targets
-from sklearn.utils.validation import check_X_y

 from .utils import ArraysTransformer, check_sampling_strategy, check_target_type

@@ -127,7 +125,7 @@ class BaseSampler(SamplerMixin):
         if accept_sparse is None:
             accept_sparse = ["csr", "csc"]
         y, binarize_y = check_target_type(y, indicate_one_vs_all=True)
-        X, y = check_X_y(X, y, accept_sparse=accept_sparse, dtype=None, ensure_all_finite=False)
+        X, y = self._validate_data(X, y, reset=True, accept_sparse=accept_sparse)
         return X, y, binarize_y

     def _more_tags(self):
upgini/sampler/random_under_sampler.py CHANGED
@@ -80,24 +80,14 @@ RandomUnderSampler # doctest: +NORMALIZE_WHITESPACE

     def _check_X_y(self, X, y):
         y, binarize_y = check_target_type(y, indicate_one_vs_all=True)
-        try:
-            X, y = self._validate_data(
-                X,
-                y,
-                reset=True,
-                accept_sparse=["csr", "csc"],
-                dtype=None,
-                force_all_finite=False,
-            )
-        except AttributeError:
-            from sklearn.utils.validation import check_X_y
-            X, y = check_X_y(
-                X,
-                y,
-                accept_sparse=["csr", "csc"],
-                dtype=None,
-                ensure_all_finite=False,
-            )
+        X, y = self._validate_data(
+            X,
+            y,
+            reset=True,
+            accept_sparse=["csr", "csc"],
+            dtype=None,
+            force_all_finite=False,
+        )
         return X, y, binarize_y

     def _fit_resample(self, X, y):
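Editor's note: the fallback removed here (and the check_X_y call replaced in base.py above) is the usual way to bridge scikit-learn API differences: older releases expose `BaseEstimator._validate_data` with a `force_all_finite` flag, while newer ones drop that method and rename the flag to `ensure_all_finite` in `sklearn.utils.validation.check_X_y`. A generic sketch of such a compatibility shim, not upgini's actual code:

```python
def validate_X_y_compat(estimator, X, y):
    """Validate X/y across scikit-learn versions (illustrative sketch only)."""
    try:
        # Older scikit-learn: estimators still provide _validate_data / force_all_finite
        return estimator._validate_data(
            X, y, reset=True, accept_sparse=["csr", "csc"], dtype=None, force_all_finite=False
        )
    except (AttributeError, TypeError):
        # Newer scikit-learn: use the public helper with the renamed flag
        from sklearn.utils.validation import check_X_y
        return check_X_y(X, y, accept_sparse=["csr", "csc"], dtype=None, ensure_all_finite=False)
```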
upgini/utils/deduplicate_utils.py CHANGED
@@ -136,9 +136,6 @@ def remove_fintech_duplicates(
     # Process each eval_set part separately
     new_eval_dfs = []
     for i, eval_df in enumerate(eval_dfs, 1):
-        # Skip OOT
-        if eval_df[TARGET].isna().all():
-            continue
         logger.info(f"Eval {i} dataset shape before clean fintech duplicates: {eval_df.shape}")
         cleaned_eval_df, eval_warning = process_df(eval_df, i)
         if eval_warning:
@@ -193,49 +190,16 @@ def clean_full_duplicates(
     msg = None
     if TARGET in df.columns:
         unique_columns.remove(TARGET)
-
-        # Separate rows to exclude from deduplication:
-        # for each eval_set_index != 0 check separately, all TARGET values are NaN
-        excluded_from_dedup = pd.DataFrame()
-        df_for_dedup = df
-
-        if EVAL_SET_INDEX in df.columns:
-            excluded_parts = []
-            # Get all unique eval_set_index values, except 0
-            unique_eval_indices = df[df[EVAL_SET_INDEX] != 0][EVAL_SET_INDEX].unique()
-
-            for eval_idx in unique_eval_indices:
-                eval_subset = df[df[EVAL_SET_INDEX] == eval_idx]
-                # Check that all TARGET values for this specific eval_set_index are NaN
-                if len(eval_subset) > 0 and eval_subset[TARGET].isna().all():
-                    excluded_parts.append(eval_subset)
-                    logger.info(
-                        f"Excluded {len(eval_subset)} rows from deduplication "
-                        f"(eval_set_index={eval_idx} and all TARGET values are NaN)"
-                    )
-
-            # Combine all excluded parts
-            if excluded_parts:
-                excluded_from_dedup = pd.concat(excluded_parts, ignore_index=False)
-                # Remove excluded rows from dataframe for deduplication
-                excluded_indices = excluded_from_dedup.index
-                df_for_dedup = df[~df.index.isin(excluded_indices)]
-        marked_duplicates = df_for_dedup.duplicated(subset=unique_columns, keep=False)
+        marked_duplicates = df.duplicated(subset=unique_columns, keep=False)
         if marked_duplicates.sum() > 0:
-            dups_indices = df_for_dedup[marked_duplicates].index.to_list()[:100]
-            nrows_after_tgt_dedup = len(df_for_dedup.drop_duplicates(subset=unique_columns, keep=False))
-            num_dup_rows = len(df_for_dedup) - nrows_after_tgt_dedup
-            share_tgt_dedup = 100 * num_dup_rows / len(df_for_dedup)
+            dups_indices = df[marked_duplicates].index.to_list()[:100]
+            nrows_after_tgt_dedup = len(df.drop_duplicates(subset=unique_columns, keep=False))
+            num_dup_rows = nrows_after_full_dedup - nrows_after_tgt_dedup
+            share_tgt_dedup = 100 * num_dup_rows / nrows_after_full_dedup

             msg = bundle.get("dataset_diff_target_duplicates").format(share_tgt_dedup, num_dup_rows, dups_indices)
-            df_for_dedup = df_for_dedup.drop_duplicates(subset=unique_columns, keep=False)
-            logger.info(f"Dataset shape after clean invalid target duplicates: {df_for_dedup.shape}")
-        # Combine back excluded rows
-        if len(excluded_from_dedup) > 0:
-            df = pd.concat([df_for_dedup, excluded_from_dedup], ignore_index=False)
-            logger.info(f"Final dataset shape after adding back excluded rows: {df.shape}")
-        else:
-            df = df_for_dedup
+            df = df.drop_duplicates(subset=unique_columns, keep=False)
+            logger.info(f"Dataset shape after clean invalid target duplicates: {df.shape}")

     return df, msg
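Editor's note: the simplified branch leans on pandas' `keep=False` semantics, where every member of a duplicate group is flagged, so rows that share the key columns but disagree on the target are all dropped. A tiny sketch with invented data:

```python
import pandas as pd

df = pd.DataFrame({"key": [1, 1, 2, 3], "target": [0, 1, 1, 0]})  # toy data, invented
unique_columns = ["key"]  # key columns without the target

marked = df.duplicated(subset=unique_columns, keep=False)        # both rows with key == 1 are marked
deduped = df.drop_duplicates(subset=unique_columns, keep=False)  # drops every row of each duplicate group
print(int(marked.sum()), len(df) - len(deduped))                 # 2 2 -> two contradictory rows removed
```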
 
upgini/utils/feature_info.py CHANGED
@@ -27,7 +27,6 @@ class FeatureInfo:
     doc_link: str
     data_provider_link: str
     data_source_link: str
-    psi_value: Optional[float] = None

     @staticmethod
     def from_metadata(
@@ -48,14 +47,12 @@ class FeatureInfo:
             doc_link=feature_meta.doc_link,
             data_provider_link=feature_meta.data_provider_link,
             data_source_link=feature_meta.data_source_link,
-            psi_value=feature_meta.psi_value,
         )

     def to_row(self, bundle: ResourceBundle) -> Dict[str, str]:
         return {
             bundle.get("features_info_name"): self.name,
             bundle.get("features_info_shap"): self.rounded_shap,
-            bundle.get("features_info_psi"): self.psi_value,
             bundle.get("features_info_hitrate"): self.hitrate,
             bundle.get("features_info_value_preview"): self.value_preview,
             bundle.get("features_info_provider"): self.provider,
@@ -67,7 +64,6 @@ class FeatureInfo:
         return {
             bundle.get("features_info_name"): self.internal_name,
             bundle.get("features_info_shap"): self.rounded_shap,
-            bundle.get("features_info_psi"): self.psi_value,
             bundle.get("features_info_hitrate"): self.hitrate,
             bundle.get("features_info_value_preview"): self.value_preview,
             bundle.get("features_info_provider"): self.internal_provider,
@@ -80,7 +76,6 @@ class FeatureInfo:
             bundle.get("features_info_name"): self.internal_name,
             "feature_link": self.doc_link,
             bundle.get("features_info_shap"): self.rounded_shap,
-            bundle.get("features_info_psi"): self.psi_value,
             bundle.get("features_info_hitrate"): self.hitrate,
             bundle.get("features_info_value_preview"): self.value_preview,
             bundle.get("features_info_provider"): self.internal_provider,
upgini-1.2.113a5.dist-info/METADATA → upgini-1.2.113a3974.dev2.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: upgini
-Version: 1.2.113a5
+Version: 1.2.113a3974.dev2
 Summary: Intelligent data search & enrichment for Machine Learning
 Project-URL: Bug Reports, https://github.com/upgini/upgini/issues
 Project-URL: Homepage, https://upgini.com/
upgini-1.2.113a5.dist-info/RECORD → upgini-1.2.113a3974.dev2.dist-info/RECORD RENAMED
@@ -1,12 +1,12 @@
-upgini/__about__.py,sha256=QdA0r4M8wEBY37BMjK9uA_83s1sWkyXy2XJhfn7vl3A,26
+upgini/__about__.py,sha256=ziYMT-cCb1zPGJYidvejUtxXlUCjQLvR25p82kAy21c,34
 upgini/__init__.py,sha256=LXSfTNU0HnlOkE69VCxkgIKDhWP-JFo_eBQ71OxTr5Y,261
 upgini/ads.py,sha256=nvuRxRx5MHDMgPr9SiU-fsqRdFaBv8p4_v1oqiysKpc,2714
 upgini/dataset.py,sha256=xFi0a-A3uvtxVwFM6JOyitkEPd1I2slIBj5SWfys3hQ,32724
 upgini/errors.py,sha256=2b_Wbo0OYhLUbrZqdLIx5jBnAsiD1Mcenh-VjR4HCTw,950
-upgini/features_enricher.py,sha256=wifdmDP-3e3y51KYhCHPYuN6vU8mj2m3SYo-kMWcNz0,234523
+upgini/features_enricher.py,sha256=rfVdHgUYEq9saqhWcI04jUmNQcAAn5Kto4w3WpxlOpA,221762
 upgini/http.py,sha256=zeAZvT6IAzOs9jQ3WG8mJBANLajgvv2LZePFzKz004w,45482
-upgini/metadata.py,sha256=sx4X9fPkyCgXB6FPk9Rq_S1Kx8ibkbaWA-qNDVCuSmg,12811
-upgini/metrics.py,sha256=O19UqmgZ6SA136eCYV5lVU3J26ecgZlGXnxGblMvZJc,45869
+upgini/metadata.py,sha256=9_0lFEWPpIHRBW-xWYSEcwPzICTC6_bQ6dUUlE75Xns,12773
+upgini/metrics.py,sha256=V2SP6NS5bfFHzRqufeKVsCXME1yG4t_8Dmk2E3zKdYk,45715
 upgini/search_task.py,sha256=Q5HjBpLIB3OCxAD1zNv5yQ3ZNJx696WCK_-H35_y7Rs,17912
 upgini/spinner.py,sha256=4iMd-eIe_BnkqFEMIliULTbj6rNI2HkN_VJ4qYe0cUc,1118
 upgini/version_validator.py,sha256=DvbaAvuYFoJqYt0fitpsk6Xcv-H1BYDJYHUMxaKSH_Y,1509
@@ -15,7 +15,7 @@ upgini/ads_management/ads_manager.py,sha256=igVbN2jz80Umb2BUJixmJVj-zx8unoKpecVo
 upgini/autofe/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 upgini/autofe/all_operators.py,sha256=rdjF5eaE4bC6Q4eu_el5Z7ekYt8DjOFermz2bePPbUc,333
 upgini/autofe/binary.py,sha256=oOEECc4nRzZN2tYaiqx8F2XHnfWpk1bVvb7ZkZJ0lO8,7709
-upgini/autofe/date.py,sha256=Ga022BUSgXJ4W3P8uWkPNo6k6J0IuEZw6Ezs9KNikPk,11188
+upgini/autofe/date.py,sha256=RvexgrL1_6ISYPVrl9HUQmPgpVSGQsTNv8YhNQWs-5M,11329
 upgini/autofe/feature.py,sha256=b4Ps_sCPui9b4h0K3ya85cfL1SWpLVrlHc40zkKVfAY,16329
 upgini/autofe/groupby.py,sha256=IYmQV9uoCdRcpkeWZj_kI3ObzoNCNx3ff3h8sTL01tk,3603
 upgini/autofe/operator.py,sha256=RB3rKMjFi5Cx81RiYXN3OTCuXjmvzmFKQrxn4h0Oclo,5219
@@ -38,11 +38,11 @@ upgini/normalizer/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU
 upgini/normalizer/normalize_utils.py,sha256=mDh2mBW3aQMB4EFP2aHbf2dGMVkOcWnp4sKKvKDBh8w,8511
 upgini/resource_bundle/__init__.py,sha256=S5F2G47pnJd2LDpmFsjDqEwiKkP8Hm-hcseDbMka6Ko,8345
 upgini/resource_bundle/exceptions.py,sha256=5fRvx0_vWdE1-7HcSgF0tckB4A9AKyf5RiinZkInTsI,621
-upgini/resource_bundle/strings.properties,sha256=6Q3dwI0v1aiXt7_3Xx0Ih6jMmSCBaaRGIoUiZ5-VnCY,28988
+upgini/resource_bundle/strings.properties,sha256=NyxRwzehkrL5LMoVyjkhN811MvalepavNfjlC9ubE0Q,28677
 upgini/resource_bundle/strings_widget.properties,sha256=gOdqvZWntP2LCza_tyVk1_yRYcG4c04K9sQOAVhF_gw,1577
 upgini/sampler/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-upgini/sampler/base.py,sha256=Fva2FEhLiNRPZ9Q6uOtJRtRzwsayjv7aphalAZO_4lc,6452
-upgini/sampler/random_under_sampler.py,sha256=4mofmaRTmNwT_HqxecWJyfXdLKK0h9jMBwS46xdrIqE,4356
+upgini/sampler/base.py,sha256=7GpjYqjOp58vYcJLiX__1R5wjUlyQbxvHJ2klFnup_M,6389
+upgini/sampler/random_under_sampler.py,sha256=TIbm7ATo-bCMF-IiS5sZeDC1ad1SYg0eY_rRmg84yIQ,4024
 upgini/sampler/utils.py,sha256=PYOk3kKSnFlyxcpdtDNLBEEhTB4lO_iP7pQHqeUcmAc,20211
 upgini/utils/Roboto-Regular.ttf,sha256=kqYnZjMRQMpbyLulIChCLSdgYa1XF8GsUIoRi2Gcauw,168260
 upgini/utils/__init__.py,sha256=O_KgzKiJjW3g4NoqZ7lAxUpoHcBi_gze6r3ndEjCH74,842
@@ -52,11 +52,11 @@ upgini/utils/country_utils.py,sha256=lY-eXWwFVegdVENFttbvLcgGDjFO17Sex8hd2PyJaRk
 upgini/utils/custom_loss_utils.py,sha256=kieNZYBYZm5ZGBltF1F_jOSF4ea6C29rYuCyiDcqVNY,3857
 upgini/utils/cv_utils.py,sha256=w6FQb9nO8BWDx88EF83NpjPLarK4eR4ia0Wg0kLBJC4,3525
 upgini/utils/datetime_utils.py,sha256=UL1ernnawW0LV9mPDpCIc6sFy0HUhFscWVNwfH4V7rI,14366
-upgini/utils/deduplicate_utils.py,sha256=xXashCSIg87gCy6QyXc0eb8huuzPLANmckMVxUVBEgM,10729
+upgini/utils/deduplicate_utils.py,sha256=EpBVCov42-FJIAPfa4jY_ZRct3N2MFaC7i-oJNZ_MGI,8954
 upgini/utils/display_utils.py,sha256=Ou7dYdgvvdh443OgOLTM_xKwC2ITx9DQrpKoC2vCRYc,11856
 upgini/utils/email_utils.py,sha256=pZ2vCfNxLIPUhxr0-OlABNXm12jjU44isBk8kGmqQzA,5277
 upgini/utils/fallback_progress_bar.py,sha256=PDaKb8dYpVZaWMroNcOHsTc3pSjgi9mOm0--cOFTwJ0,1074
-upgini/utils/feature_info.py,sha256=6vihytwKma_TlXtTn4l6Aj4kqlOj0ouLy-yWVV6VUw8,7551
+upgini/utils/feature_info.py,sha256=b3RvAeOHSEu-ZXWTrf42Dll_3ZUBL0pw7sdk7hgUKD0,7284
 upgini/utils/features_validator.py,sha256=lEfmk4DoxZ4ooOE1HC0ZXtUb_lFKRFHIrnFULZ4_rL8,3746
 upgini/utils/format.py,sha256=Yv5cvvSs2bOLUzzNu96Pu33VMDNbabio92QepUj41jU,243
 upgini/utils/ip_utils.py,sha256=wmnnwVQdjX9o1cNQw6VQMk6maHhvsq6hNsZBYf9knrw,6585
@@ -64,7 +64,6 @@ upgini/utils/mstats.py,sha256=u3gQVUtDRbyrOQK6V1UJ2Rx1QbkSNYGjXa6m3Z_dPVs,6286
 upgini/utils/phone_utils.py,sha256=IrbztLuOJBiePqqxllfABWfYlfAjYevPhXKipl95wUI,10432
 upgini/utils/postal_code_utils.py,sha256=5M0sUqH2DAr33kARWCTXR-ACyzWbjDq_-0mmEml6ZcU,1716
 upgini/utils/progress_bar.py,sha256=N-Sfdah2Hg8lXP_fV9EfUTXz_PyRt4lo9fAHoUDOoLc,1550
-upgini/utils/psi.py,sha256=pLtECcCeco_WRqMjFnQvhUB4vHArjHtD5HzJFP9ICMc,10972
 upgini/utils/sample_utils.py,sha256=lZJ4yf9Jiq9Em2Ny9m3RIiF7WSxBPrc4E3xxn_8sQk8,15417
 upgini/utils/sklearn_ext.py,sha256=jLJWAKkqQinV15Z4y1ZnsN3c-fKFwXTsprs00COnyVU,49315
 upgini/utils/sort.py,sha256=8uuHs2nfSMVnz8GgvbOmgMB1PgEIZP1uhmeRFxcwnYw,7039
@@ -72,7 +71,7 @@ upgini/utils/target_utils.py,sha256=i3Xt5l9ybB2_nF_ma5cfPuL3OeFTs2dY2xDI0p4Azpg,
 upgini/utils/track_info.py,sha256=G5Lu1xxakg2_TQjKZk4b5SvrHsATTXNVV3NbvWtT8k8,5663
 upgini/utils/ts_utils.py,sha256=26vhC0pN7vLXK6R09EEkMK3Lwb9IVPH7LRdqFIQ3kPs,1383
 upgini/utils/warning_counter.py,sha256=-GRY8EUggEBKODPSuXAkHn9KnEQwAORC0mmz_tim-PM,254
-upgini-1.2.113a5.dist-info/METADATA,sha256=VOeoK4hhJyhb0OJWG2cgsN-hES6xe3QIRyZMovxP8ek,49531
-upgini-1.2.113a5.dist-info/WHEEL,sha256=1yFddiXMmvYK7QYTqtRNtX66WJ0Mz8PYEiEUoOUUxRY,87
-upgini-1.2.113a5.dist-info/licenses/LICENSE,sha256=5RRzgvdJUu3BUDfv4bzVU6FqKgwHlIay63pPCSmSgzw,1514
-upgini-1.2.113a5.dist-info/RECORD,,
+upgini-1.2.113a3974.dev2.dist-info/METADATA,sha256=RC2p2RrCBlPWX6hGAcLGtt-k6wOmmq2DFhetxg3LvGk,49539
+upgini-1.2.113a3974.dev2.dist-info/WHEEL,sha256=zEMcRr9Kr03x1ozGwg5v9NQBKn3kndp6LSoSlVg-jhU,87
+upgini-1.2.113a3974.dev2.dist-info/licenses/LICENSE,sha256=5RRzgvdJUu3BUDfv4bzVU6FqKgwHlIay63pPCSmSgzw,1514
+upgini-1.2.113a3974.dev2.dist-info/RECORD,,
upgini-1.2.113a5.dist-info/WHEEL → upgini-1.2.113a3974.dev2.dist-info/WHEEL RENAMED
@@ -1,4 +1,4 @@
 Wheel-Version: 1.0
-Generator: hatchling 1.25.0
+Generator: hatchling 1.24.2
 Root-Is-Purelib: true
 Tag: py3-none-any
upgini/utils/psi.py DELETED
@@ -1,294 +0,0 @@
-import itertools
-import logging
-import operator
-from functools import reduce
-from typing import Callable, Dict, Optional
-
-import more_itertools
-import numpy as np
-import pandas as pd
-from pandas.api.types import is_numeric_dtype
-from pydantic import BaseModel
-
-from upgini.metadata import TARGET, ModelTaskType
-
-
-class StabilityParams(BaseModel):
-    threshold: float = 999
-    n_intervals: int = 12
-    min_intervals: int = 10
-    max_intervals: Optional[int] = None
-    min_values_in_interval: Optional[int] = None
-    n_bins: int = 10
-    min_values_in_bin: Optional[int] = None
-    cat_top_pct: float = 0.7
-    agg: str = "max"
-
-
-DEFAULT_TARGET_PARAMS = StabilityParams(
-    n_intervals=12,
-    min_intervals=10,
-    max_intervals=None,
-    min_values_in_interval=None,
-    n_bins=5,
-)
-
-DEFAULT_FEATURES_PARAMS = StabilityParams(
-    n_intervals=12,
-    min_intervals=10,
-    max_intervals=None,
-    min_values_in_interval=None,
-    n_bins=10,
-)
-
-
-def calculate_sparsity_psi(
-    df: pd.DataFrame,
-    cat_features: list[str],
-    date_column: str,
-    logger: logging.Logger,
-    model_task_type: ModelTaskType,
-    psi_features_params: StabilityParams = DEFAULT_FEATURES_PARAMS,
-    psi_target_params: StabilityParams = DEFAULT_TARGET_PARAMS,
-) -> Dict[str, float]:
-    sparse_features = df.columns[df.isna().sum() > 0].to_list()
-    if len(sparse_features) > 0:
-        logger.info(f"Calculating sparsity stability for {len(sparse_features)} sparse features")
-        sparse_df = df[sparse_features].notna()
-        sparse_df[date_column] = df[date_column]
-        return calculate_features_psi(
-            sparse_df,
-            cat_features,
-            date_column,
-            logger,
-            model_task_type,
-            psi_target_params,
-            psi_features_params,
-        )
-    return {}
-
-
-def calculate_features_psi(
-    df: pd.DataFrame,
-    cat_features: list[str],
-    date_column: str,
-    logger: logging.Logger,
-    model_task_type: ModelTaskType,
-    psi_features_params: StabilityParams = DEFAULT_FEATURES_PARAMS,
-    psi_target_params: StabilityParams = DEFAULT_TARGET_PARAMS,
-) -> Dict[str, float]:
-    empty_res = pd.Series(index=df.columns, data=0)
-
-    if not is_numeric_dtype(df[date_column]):
-        df[date_column] = pd.to_datetime(df[date_column]).dt.floor("D").astype(np.int64) / 10**6
-
-    n_months = pd.to_datetime(df[date_column], unit="ms").dt.month.nunique()
-
-    if TARGET in df.columns:
-        psi_target_params.n_intervals = min(
-            psi_target_params.max_intervals or np.inf, max(psi_target_params.min_intervals, n_months)
-        )
-        logger.info(f"Setting {psi_target_params.n_intervals} intervals for target PSI check")
-
-        logger.info(f"Calculating target PSI for {psi_target_params.n_intervals} intervals")
-        reference_mask, current_masks = _split_intervals(df, date_column, psi_target_params.n_intervals, logger)
-
-        if psi_target_params.min_values_in_interval is not None and any(
-            len(mask) < psi_target_params.min_values_in_interval
-            for mask in itertools.chain(current_masks, [reference_mask])
-        ):
-            logger.info(
-                f"Some intervals have less than {psi_target_params.min_values_in_interval} values. Skip PSI check"
-            )
-            return empty_res
-
-        target_agg_func = _get_agg_func(psi_target_params.agg)
-        logger.info(f"Calculating target PSI with agg function {target_agg_func}")
-        target_psi = _stability_agg(
-            [df[TARGET][cur] for cur in current_masks],
-            reference_data=df[TARGET][reference_mask],
-            is_numerical=model_task_type == ModelTaskType.REGRESSION,
-            min_values_in_bin=psi_target_params.min_values_in_bin,
-            n_bins=psi_target_params.n_bins,
-            cat_top_pct=psi_target_params.cat_top_pct,
-            agg_func=target_agg_func,
-        )
-        if target_psi is None:
-            logger.info("Cannot determine target PSI. Skip feature PSI check")
-            return pd.Series(index=df.columns, data=0)
-
-        if target_psi > psi_target_params.threshold:
-            logger.info(
-                f"Target PSI {target_psi} is more than threshold {psi_target_params.threshold}. Skip feature PSI check"
-            )
-            return empty_res
-
-    psi_features_params.n_intervals = min(
-        psi_features_params.max_intervals or np.inf, max(psi_features_params.min_intervals, n_months)
-    )
-    logger.info(f"Setting {psi_features_params.n_intervals} intervals for features PSI check")
-
-    logger.info(f"Calculating PSI for {len(df.columns)} features")
-    reference_mask, current_masks = _split_intervals(df, date_column, psi_features_params.n_intervals, logger)
-    features_agg_func = _get_agg_func(psi_features_params.agg)
-    logger.info(f"Calculating features PSI with agg function {features_agg_func}")
-    psi_values = [
-        _stability_agg(
-            [df[feature][cur] for cur in current_masks],
-            reference_data=df[feature][reference_mask],
-            is_numerical=feature not in cat_features,
-            min_values_in_bin=psi_features_params.min_values_in_bin,
-            n_bins=psi_features_params.n_bins,
-            cat_top_pct=psi_features_params.cat_top_pct,
-            agg_func=features_agg_func,
-        )
-        for feature in df.columns
-        if feature not in [TARGET, date_column]
-    ]
-    return {feature: psi for feature, psi in zip(df.columns, psi_values)}
-
-
-def _split_intervals(
-    df: pd.DataFrame, date_column: str, n_intervals: int, logger: logging.Logger
-) -> tuple[pd.Series, list[pd.Series]]:
-    date_series = df[date_column]
-
-    # Check if we have enough unique values for the requested number of intervals
-    unique_values = date_series.nunique()
-
-    # If we have fewer unique values than requested intervals, adjust n_intervals
-    if unique_values < n_intervals:
-        logger.warning(f"Date column '{date_column}' has only {unique_values} unique values")
-
-    time_intervals = pd.qcut(date_series, q=n_intervals, duplicates="drop")
-    interval_labels = time_intervals.unique()
-    reference_mask = time_intervals == interval_labels[0]
-    current_masks = [time_intervals == label for label in interval_labels[1:]]
-    return reference_mask, current_masks
-
-
-def _get_agg_func(agg: str):
-    np_agg = getattr(np, agg, None)
-    if np_agg is None and agg.startswith("q"):
-        q = int(agg[1:])
-        return lambda x: np.quantile(list(x), q / 100, method="higher")
-    return np_agg
-
-
-def _psi(reference_percent: np.ndarray, current_percent: np.ndarray) -> float:
-    return np.sum((reference_percent - current_percent) * np.log(reference_percent / current_percent))
-
-
-def _stability_agg(
-    current_data: list[pd.Series],
-    reference_data: pd.Series,
-    is_numerical: bool = True,
-    min_values_in_bin: int | None = None,
-    n_bins: int = 10,
-    cat_top_pct: float = 0.7,
-    agg_func: Callable = max,
-) -> float | None:
-    """Calculate the PSI
-    Args:
-        current_data: current data
-        reference_data: reference data
-        is_numerical: whether the feature is numerical
-        reference_ratio: ratio of current data to use as reference if reference_data is not provided
-        min_values_in_bin: minimum number of values in a bin to calculate PSI
-        n_bins: number of bins to use for numerical features
-    Returns:
-        psi_value: calculated PSI
-    """
-    reference, current = _get_binned_data(reference_data, current_data, is_numerical, n_bins, cat_top_pct)
-
-    if len(reference) == 0 or len(current) == 0:
-        return None
-
-    nonempty_current = [i for i, c in enumerate(current) if len(c) > 0]
-    current = [current[i] for i in nonempty_current]
-    current_data = [current_data[i] for i in nonempty_current]
-
-    if len(current) == 0:
-        return None
-
-    if min_values_in_bin is not None and (
-        np.array(reference).min() < min_values_in_bin or any(np.array(c).min() < min_values_in_bin for c in current)
-    ):
-        return None
-
-    reference = _fill_zeroes(reference / len(reference_data))
-    current = [_fill_zeroes(c / len(d)) for c, d in zip(current, current_data)]
-
-    psi_value = agg_func([_psi(reference, c) for c in current])
-
-    return psi_value
-
-
-def _get_binned_data(
-    reference_data: pd.Series,
-    current_data: list[pd.Series],
-    is_numerical: bool,
-    n_bins: int,
-    cat_top_pct: float,
-):
-    """Split variable into n buckets based on reference quantiles
-    Args:
-        reference_data: reference data
-        current_data: current data
-        feature_type: feature type
-        n: number of quantiles
-    Returns:
-        reference_counts: number of records in each bucket for reference
-        current_counts: number of records in each bucket for current
-    """
-    n_vals = reference_data.nunique()
-
-    if is_numerical and n_vals > 20:
-        bins = _get_bin_edges(reference_data, n_bins)
-        reference_counts = np.histogram(reference_data, bins)[0]
-        current_counts = [np.histogram(d, bins)[0] for d in current_data]
-
-    else:
-        keys = _get_unique_not_nan_values_list_from_series([reference_data] + current_data)
-        ref_feature_dict = {**dict.fromkeys(keys, 0), **dict(reference_data.value_counts())}
-        current_feature_dict = [{**dict.fromkeys(keys, 0), **dict(d.value_counts())} for d in current_data]
-        key_dict = more_itertools.map_reduce(
-            itertools.chain(ref_feature_dict.items(), *(d.items() for d in current_feature_dict)),
-            keyfunc=operator.itemgetter(0),
-            valuefunc=operator.itemgetter(1),
-            reducefunc=sum,
-        )
-        key_dict = pd.Series(key_dict)
-        keys = key_dict.index[key_dict.rank(pct=True) >= cat_top_pct]
-        reference_counts = np.array([ref_feature_dict[key] for key in keys])
-        current_counts = [np.array([current_feature_dict[i][key] for key in keys]) for i in range(len(current_data))]
-
-    reference_counts = np.append(reference_counts, reference_data.isna().sum())
-    current_counts = [np.append(d, current_data[i].isna().sum()) for i, d in enumerate(current_counts)]
-
-    return reference_counts, current_counts
-
-
-def _fill_zeroes(percents: np.ndarray) -> np.ndarray:
-    eps = 0.0001
-    if (percents == 0).all():
-        np.place(percents, percents == 0, eps)
-    else:
-        min_value = min(percents[percents != 0])
-        if min_value <= eps:
-            np.place(percents, percents == 0, eps)
-        else:
-            np.place(percents, percents == 0, min_value / 10**6)
-    return percents
-
-
-def _get_bin_edges(data: pd.Series, n_bins: int) -> np.ndarray:
-    bins = np.nanquantile(data, np.linspace(0, 1, n_bins + 1))
-    bins[0] = -np.inf
-    bins[-1] = np.inf
-    return bins
-
-
-def _get_unique_not_nan_values_list_from_series(series: list[pd.Series]) -> list:
-    """Get unique values from current and reference series, drop NaNs"""
-    return list(reduce(set.union, (set(s.dropna().unique()) for s in series)))
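Editor's note: for readers tracing what the removed module computed, the core of `_psi` above is the standard population stability index over binned frequencies. A self-contained numeric sketch (bin counts invented for illustration; the zero-bin guard is a simpler floor than the original `_fill_zeroes`):

```python
import numpy as np

def psi(reference_counts, current_counts, eps=1e-4):
    # Convert bin counts to proportions and floor zero bins to avoid log(0)
    ref = np.asarray(reference_counts, dtype=float)
    cur = np.asarray(current_counts, dtype=float)
    ref = np.clip(ref / ref.sum(), eps, None)
    cur = np.clip(cur / cur.sum(), eps, None)
    # PSI = sum((ref - cur) * ln(ref / cur))
    return float(np.sum((ref - cur) * np.log(ref / cur)))

print(psi([100, 80, 60, 40], [90, 85, 70, 30]))  # small value -> similar distributions
```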