upgini 1.2.39a1__py3-none-any.whl → 1.2.39a3769.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
upgini/__about__.py CHANGED
@@ -1 +1 @@
1
- __version__ = "1.2.39a1"
1
+ __version__ = "1.2.39a3769.dev2"
upgini/dataset.py CHANGED
@@ -22,6 +22,7 @@ from upgini.metadata import (
22
22
  EVAL_SET_INDEX,
23
23
  SYSTEM_RECORD_ID,
24
24
  TARGET,
25
+ CVType,
25
26
  DataType,
26
27
  FeaturesFilter,
27
28
  FileColumnMeaningType,
@@ -32,11 +33,12 @@ from upgini.metadata import (
32
33
  NumericInterval,
33
34
  RuntimeParameters,
34
35
  SearchCustomization,
36
+ SearchKey,
35
37
  )
36
38
  from upgini.resource_bundle import ResourceBundle, get_custom_bundle
37
39
  from upgini.search_task import SearchTask
38
40
  from upgini.utils.email_utils import EmailSearchKeyConverter
39
- from upgini.utils.target_utils import balance_undersample, balance_undersample_forced
41
+ from upgini.utils.target_utils import balance_undersample, balance_undersample_forced, balance_undersample_time_series
40
42
 
41
43
  try:
42
44
  from upgini.utils.progress_bar import CustomProgressBar as ProgressBar
@@ -74,6 +76,9 @@ class Dataset: # (pd.DataFrame):
74
76
  search_keys: Optional[List[Tuple[str, ...]]] = None,
75
77
  unnest_search_keys: Optional[Dict[str, str]] = None,
76
78
  model_task_type: Optional[ModelTaskType] = None,
79
+ cv_type: Optional[CVType] = None,
80
+ date_column: Optional[str] = None,
81
+ id_columns: Optional[List[str]] = None,
77
82
  random_state: Optional[int] = None,
78
83
  rest_client: Optional[_RestClient] = None,
79
84
  logger: Optional[logging.Logger] = None,
@@ -104,6 +109,7 @@ class Dataset: # (pd.DataFrame):
104
109
 
105
110
  self.dataset_name = dataset_name
106
111
  self.task_type = model_task_type
112
+ self.cv_type = cv_type
107
113
  self.description = description
108
114
  self.meaning_types = meaning_types
109
115
  self.search_keys = search_keys
@@ -116,6 +122,8 @@ class Dataset: # (pd.DataFrame):
116
122
  self.random_state = random_state
117
123
  self.columns_renaming: Dict[str, str] = {}
118
124
  self.imbalanced: bool = False
125
+ self.id_columns = id_columns
126
+ self.date_column = date_column
119
127
  if logger is not None:
120
128
  self.logger = logger
121
129
  else:
@@ -225,6 +233,9 @@ class Dataset: # (pd.DataFrame):
225
233
  df=self.data,
226
234
  target_column=target_column,
227
235
  task_type=self.task_type,
236
+ cv_type=self.cv_type,
237
+ date_column=self.date_column,
238
+ id_columns=self.id_columns,
228
239
  random_state=self.random_state,
229
240
  sample_size=self.FORCE_SAMPLE_SIZE,
230
241
  logger=self.logger,
@@ -297,7 +308,21 @@ class Dataset: # (pd.DataFrame):
297
308
  f"Etalon has size {len(self.data)} more than threshold {sample_threshold} "
298
309
  f"and will be downsampled to {sample_rows}"
299
310
  )
300
- resampled_data = self.data.sample(n=sample_rows, random_state=self.random_state)
311
+ if self.cv_type is not None and self.cv_type.is_time_series():
312
+ resampled_data = balance_undersample_time_series(
313
+ df=self.data,
314
+ id_columns=self.id_columns,
315
+ date_column=next(
316
+ k
317
+ for k, v in self.meaning_types.items()
318
+ if v in [FileColumnMeaningType.DATE, FileColumnMeaningType.DATETIME]
319
+ ),
320
+ sample_size=sample_rows,
321
+ random_state=self.random_state,
322
+ logger=self.logger,
323
+ )
324
+ else:
325
+ resampled_data = self.data.sample(n=sample_rows, random_state=self.random_state)
301
326
  self.data = resampled_data
302
327
  self.logger.info(f"Shape after threshold resampling: {self.data.shape}")
303
328
 
@@ -237,6 +237,7 @@ class FeaturesEnricher(TransformerMixin):
237
237
  add_date_if_missing: bool = True,
238
238
  select_features: bool = False,
239
239
  disable_force_downsampling: bool = False,
240
+ id_columns: Optional[List[str]] = None,
240
241
  **kwargs,
241
242
  ):
242
243
  self.bundle = get_custom_bundle(custom_bundle_config)
@@ -277,9 +278,12 @@ class FeaturesEnricher(TransformerMixin):
277
278
  )
278
279
 
279
280
  validate_version(self.logger, self.__log_warning)
281
+
280
282
  self.search_keys = search_keys or {}
283
+ self.id_columns = id_columns
281
284
  self.country_code = country_code
282
285
  self.__validate_search_keys(search_keys, search_id)
286
+
283
287
  self.model_task_type = model_task_type
284
288
  self.endpoint = endpoint
285
289
  self._search_task: Optional[SearchTask] = None
@@ -928,6 +932,8 @@ class FeaturesEnricher(TransformerMixin):
928
932
  cat_features, search_keys_for_metrics = self._get_client_cat_features(
929
933
  estimator, validated_X, self.search_keys
930
934
  )
935
+ search_keys_for_metrics.extend([c for c in self.id_columns or [] if c not in search_keys_for_metrics])
936
+ self.logger.info(f"Search keys for metrics: {search_keys_for_metrics}")
931
937
 
932
938
  prepared_data = self._prepare_data_for_metrics(
933
939
  trace_id=trace_id,
@@ -983,7 +989,7 @@ class FeaturesEnricher(TransformerMixin):
983
989
  with Spinner():
984
990
  self._check_train_and_eval_target_distribution(y_sorted, fitting_eval_set_dict)
985
991
 
986
- has_date = SearchKey.find_key(search_keys, [SearchKey.DATE, SearchKey.DATETIME]) is not None
992
+ has_date = self._get_date_column(search_keys) is not None
987
993
  model_task_type = self.model_task_type or define_task(y_sorted, has_date, self.logger, silent=True)
988
994
 
989
995
  wrapper = EstimatorWrapper.create(
@@ -1185,7 +1191,7 @@ class FeaturesEnricher(TransformerMixin):
1185
1191
  )
1186
1192
 
1187
1193
  uplift_col = self.bundle.get("quality_metrics_uplift_header")
1188
- date_column = SearchKey.find_key(search_keys, [SearchKey.DATE, SearchKey.DATETIME])
1194
+ date_column = self._get_date_column(search_keys)
1189
1195
  if (
1190
1196
  uplift_col in metrics_df.columns
1191
1197
  and (metrics_df[uplift_col] < 0).any()
@@ -1354,7 +1360,7 @@ class FeaturesEnricher(TransformerMixin):
1354
1360
  groups = None
1355
1361
 
1356
1362
  if not isinstance(_cv, BaseCrossValidator):
1357
- date_column = SearchKey.find_key(search_keys, [SearchKey.DATE, SearchKey.DATETIME])
1363
+ date_column = self._get_date_column(search_keys)
1358
1364
  date_series = X[date_column] if date_column is not None else None
1359
1365
  _cv, groups = CVConfig(
1360
1366
  _cv, date_series, self.random_state, self._search_task.get_shuffle_kfold(), group_columns=group_columns
@@ -1443,9 +1449,13 @@ class FeaturesEnricher(TransformerMixin):
1443
1449
 
1444
1450
  excluding_search_keys = list(search_keys.keys())
1445
1451
  if search_keys_for_metrics is not None and len(search_keys_for_metrics) > 0:
1452
+ excluded = set()
1446
1453
  for sk in excluding_search_keys:
1447
1454
  if columns_renaming.get(sk) in search_keys_for_metrics:
1448
- excluding_search_keys.remove(sk)
1455
+ excluded.add(sk)
1456
+ excluding_search_keys = [sk for sk in excluding_search_keys if sk not in excluded]
1457
+
1458
+ self.logger.info(f"Excluding search keys: {excluding_search_keys}")
1449
1459
 
1450
1460
  client_features = [
1451
1461
  c
@@ -1667,7 +1677,7 @@ class FeaturesEnricher(TransformerMixin):
1667
1677
  search_keys = self.search_keys.copy()
1668
1678
  search_keys = self.__prepare_search_keys(df, search_keys, is_demo_dataset, is_transform=True, silent_mode=True)
1669
1679
 
1670
- date_column = SearchKey.find_key(search_keys, [SearchKey.DATE, SearchKey.DATETIME])
1680
+ date_column = self._get_date_column(search_keys)
1671
1681
  generated_features = []
1672
1682
  if date_column is not None:
1673
1683
  converter = DateTimeSearchKeyConverter(date_column, self.date_format, self.logger, self.bundle)
@@ -1741,7 +1751,7 @@ class FeaturesEnricher(TransformerMixin):
1741
1751
  search_keys = self.fit_search_keys
1742
1752
 
1743
1753
  rows_to_drop = None
1744
- has_date = SearchKey.find_key(search_keys, [SearchKey.DATE, SearchKey.DATETIME]) is not None
1754
+ has_date = self._get_date_column(search_keys) is not None
1745
1755
  self.model_task_type = self.model_task_type or define_task(
1746
1756
  self.df_with_original_index[TARGET], has_date, self.logger, silent=True
1747
1757
  )
@@ -1853,7 +1863,10 @@ class FeaturesEnricher(TransformerMixin):
1853
1863
  df = balance_undersample_forced(
1854
1864
  df=df,
1855
1865
  target_column=TARGET,
1866
+ id_columns=self.id_columns,
1867
+ date_column=self._get_date_column(self.search_keys),
1856
1868
  task_type=self.model_task_type,
1869
+ cv_type=self.cv,
1857
1870
  random_state=self.random_state,
1858
1871
  sample_size=Dataset.FORCE_SAMPLE_SIZE,
1859
1872
  logger=self.logger,
@@ -1995,7 +2008,7 @@ class FeaturesEnricher(TransformerMixin):
1995
2008
  trace_id = trace_id or uuid.uuid4()
1996
2009
  return search_task.get_progress(trace_id)
1997
2010
 
1998
- def get_transactional_transform_api(self, only_online_sources=False):
2011
+ def get_transactional_transform_api(self):
1999
2012
  if self.api_key is None:
2000
2013
  raise ValidationError(self.bundle.get("transactional_transform_unregistered"))
2001
2014
  if self._search_task is None:
@@ -2053,7 +2066,7 @@ class FeaturesEnricher(TransformerMixin):
2053
2066
  api_example = f"""curl 'https://search.upgini.com/online/api/http_inference_trigger?search_id={search_id}' \\
2054
2067
  -H 'Authorization: {self.api_key}' \\
2055
2068
  -H 'Content-Type: application/json' \\
2056
- -d '{{"search_keys": {keys}{features_section}, "only_online_sources": {str(only_online_sources).lower()}}}'"""
2069
+ -d '{{"search_keys": {keys}{features_section}}}'"""
2057
2070
  return api_example
2058
2071
 
2059
2072
  def _get_copy_of_runtime_parameters(self) -> RuntimeParameters:
@@ -2097,15 +2110,13 @@ class FeaturesEnricher(TransformerMixin):
2097
2110
  return None, {c: c for c in X.columns}, []
2098
2111
 
2099
2112
  features_meta = self._search_task.get_all_features_metadata_v2()
2100
- online_api_features = [fm.name for fm in features_meta if fm.from_online_api and fm.shap_value > 0]
2113
+ online_api_features = [fm.name for fm in features_meta if fm.from_online_api]
2101
2114
  if len(online_api_features) > 0:
2102
2115
  self.logger.warning(
2103
2116
  f"There are important features for transform, that generated by online API: {online_api_features}"
2104
2117
  )
2105
- msg = self.bundle.get("online_api_features_transform").format(online_api_features)
2106
- self.logger.warning(msg)
2107
- print(msg)
2108
- print(self.get_transactional_transform_api(only_online_sources=True))
2118
+ # TODO
2119
+ raise Exception("There are features selected that are paid. Contact support (sales@upgini.com)")
2109
2120
 
2110
2121
  if not metrics_calculation:
2111
2122
  transform_usage = self.rest_client.get_current_transform_usage(trace_id)
@@ -2137,6 +2148,9 @@ class FeaturesEnricher(TransformerMixin):
2137
2148
  validated_X = validated_X.drop(columns=columns_to_drop)
2138
2149
 
2139
2150
  search_keys = self.search_keys.copy()
2151
+ if self.id_columns is not None and self.cv is not None and self.cv.is_time_series():
2152
+ self.search_keys.update({col: SearchKey.CUSTOM_KEY for col in self.id_columns})
2153
+
2140
2154
  search_keys = self.__prepare_search_keys(
2141
2155
  validated_X, search_keys, is_demo_dataset, is_transform=True, silent_mode=silent_mode
2142
2156
  )
@@ -2155,7 +2169,7 @@ class FeaturesEnricher(TransformerMixin):
2155
2169
  df = self.__add_country_code(df, search_keys)
2156
2170
 
2157
2171
  generated_features = []
2158
- date_column = SearchKey.find_key(search_keys, [SearchKey.DATE, SearchKey.DATETIME])
2172
+ date_column = self._get_date_column(search_keys)
2159
2173
  if date_column is not None:
2160
2174
  converter = DateTimeSearchKeyConverter(date_column, self.date_format, self.logger, bundle=self.bundle)
2161
2175
  df = converter.convert(df, keep_time=True)
@@ -2163,7 +2177,7 @@ class FeaturesEnricher(TransformerMixin):
2163
2177
  generated_features.extend(converter.generated_features)
2164
2178
  else:
2165
2179
  self.logger.info("Input dataset hasn't date column")
2166
- if self.add_date_if_missing:
2180
+ if self.__should_add_date_column():
2167
2181
  df = self._add_current_date_as_key(df, search_keys, self.logger, self.bundle)
2168
2182
 
2169
2183
  email_columns = SearchKey.find_all_keys(search_keys, SearchKey.EMAIL)
@@ -2294,6 +2308,8 @@ class FeaturesEnricher(TransformerMixin):
2294
2308
  meaning_types=meaning_types,
2295
2309
  search_keys=combined_search_keys,
2296
2310
  unnest_search_keys=unnest_search_keys,
2311
+ id_columns=self.__get_renamed_id_columns(columns_renaming),
2312
+ date_column=self._get_date_column(search_keys),
2297
2313
  date_format=self.date_format,
2298
2314
  rest_client=self.rest_client,
2299
2315
  logger=self.logger,
@@ -2446,7 +2462,14 @@ class FeaturesEnricher(TransformerMixin):
2446
2462
  # Multiple search keys allowed only for PHONE, IP, POSTAL_CODE, EMAIL, HEM
2447
2463
  multi_keys = [key for key, count in Counter(key_types).items() if count > 1]
2448
2464
  for multi_key in multi_keys:
2449
- if multi_key not in [SearchKey.PHONE, SearchKey.IP, SearchKey.POSTAL_CODE, SearchKey.EMAIL, SearchKey.HEM]:
2465
+ if multi_key not in [
2466
+ SearchKey.PHONE,
2467
+ SearchKey.IP,
2468
+ SearchKey.POSTAL_CODE,
2469
+ SearchKey.EMAIL,
2470
+ SearchKey.HEM,
2471
+ SearchKey.CUSTOM_KEY,
2472
+ ]:
2450
2473
  msg = self.bundle.get("unsupported_multi_key").format(multi_key)
2451
2474
  self.logger.warning(msg)
2452
2475
  raise ValidationError(msg)
@@ -2610,7 +2633,7 @@ class FeaturesEnricher(TransformerMixin):
2610
2633
  self.fit_generated_features.extend(converter.generated_features)
2611
2634
  else:
2612
2635
  self.logger.info("Input dataset hasn't date column")
2613
- if self.add_date_if_missing:
2636
+ if self.__should_add_date_column():
2614
2637
  df = self._add_current_date_as_key(df, self.fit_search_keys, self.logger, self.bundle)
2615
2638
 
2616
2639
  email_columns = SearchKey.find_all_keys(self.fit_search_keys, SearchKey.EMAIL)
@@ -2643,6 +2666,13 @@ class FeaturesEnricher(TransformerMixin):
2643
2666
 
2644
2667
  self.__adjust_cv(df)
2645
2668
 
2669
+ if self.id_columns is not None and self.cv is not None and self.cv.is_time_series():
2670
+ id_columns = self.__get_renamed_id_columns()
2671
+ if id_columns:
2672
+ self.fit_search_keys.update({col: SearchKey.CUSTOM_KEY for col in id_columns})
2673
+ self.search_keys.update({col: SearchKey.CUSTOM_KEY for col in self.id_columns})
2674
+ self.runtime_parameters.properties["id_columns"] = ",".join(id_columns)
2675
+
2646
2676
  df, fintech_warnings = remove_fintech_duplicates(
2647
2677
  df, self.fit_search_keys, date_format=self.date_format, logger=self.logger, bundle=self.bundle
2648
2678
  )
@@ -2672,7 +2702,6 @@ class FeaturesEnricher(TransformerMixin):
2672
2702
  self.fit_search_keys,
2673
2703
  self.fit_columns_renaming,
2674
2704
  list(unnest_search_keys.keys()),
2675
- self.bundle,
2676
2705
  self.logger,
2677
2706
  )
2678
2707
  df = converter.convert(df)
@@ -2765,6 +2794,9 @@ class FeaturesEnricher(TransformerMixin):
2765
2794
  search_keys=combined_search_keys,
2766
2795
  unnest_search_keys=unnest_search_keys,
2767
2796
  model_task_type=self.model_task_type,
2797
+ cv_type=self.cv,
2798
+ id_columns=self.__get_renamed_id_columns(),
2799
+ date_column=self._get_date_column(self.fit_search_keys),
2768
2800
  date_format=self.date_format,
2769
2801
  random_state=self.random_state,
2770
2802
  rest_client=self.rest_client,
@@ -2921,6 +2953,14 @@ class FeaturesEnricher(TransformerMixin):
2921
2953
  if not self.warning_counter.has_warnings():
2922
2954
  self.__display_support_link(self.bundle.get("all_ok_community_invite"))
2923
2955
 
2956
+ def __should_add_date_column(self):
2957
+ return self.add_date_if_missing or (self.cv is not None and self.cv.is_time_series())
2958
+
2959
+ def __get_renamed_id_columns(self, renaming: Optional[Dict[str, str]] = None):
2960
+ renaming = renaming or self.fit_columns_renaming
2961
+ reverse_renaming = {v: k for k, v in renaming.items()}
2962
+ return None if self.id_columns is None else [reverse_renaming.get(c) or c for c in self.id_columns]
2963
+
2924
2964
  def __adjust_cv(self, df: pd.DataFrame):
2925
2965
  date_column = SearchKey.find_key(self.fit_search_keys, [SearchKey.DATE, SearchKey.DATETIME])
2926
2966
  # Check Multivariate time series
@@ -3166,7 +3206,7 @@ class FeaturesEnricher(TransformerMixin):
3166
3206
  if DateTimeSearchKeyConverter.DATETIME_COL in X.columns:
3167
3207
  date_column = DateTimeSearchKeyConverter.DATETIME_COL
3168
3208
  else:
3169
- date_column = SearchKey.find_key(search_keys, [SearchKey.DATE, SearchKey.DATETIME])
3209
+ date_column = FeaturesEnricher._get_date_column(search_keys)
3170
3210
  sort_columns = [date_column] if date_column is not None else []
3171
3211
 
3172
3212
  # Xy = pd.concat([X, y], axis=1)
@@ -3229,7 +3269,6 @@ class FeaturesEnricher(TransformerMixin):
3229
3269
  f"Generate features: {self.generate_features}\n"
3230
3270
  f"Round embeddings: {self.round_embeddings}\n"
3231
3271
  f"Detect missing search keys: {self.detect_missing_search_keys}\n"
3232
- f"Exclude columns: {self.exclude_columns}\n"
3233
3272
  f"Exclude features sources: {exclude_features_sources}\n"
3234
3273
  f"Calculate metrics: {calculate_metrics}\n"
3235
3274
  f"Scoring: {scoring}\n"
@@ -3237,15 +3276,6 @@ class FeaturesEnricher(TransformerMixin):
3237
3276
  f"Remove target outliers: {remove_outliers_calc_metrics}\n"
3238
3277
  f"Exclude columns: {self.exclude_columns}\n"
3239
3278
  f"Search id: {self.search_id}\n"
3240
- f"Custom loss: {self.loss}\n"
3241
- f"Logs enabled: {self.logs_enabled}\n"
3242
- f"Raise validation error: {self.raise_validation_error}\n"
3243
- f"Baseline score column: {self.baseline_score_column}\n"
3244
- f"Client ip: {self.client_ip}\n"
3245
- f"Client visitorId: {self.client_visitorid}\n"
3246
- f"Add date if missing: {self.add_date_if_missing}\n"
3247
- f"Select features: {self.select_features}\n"
3248
- f"Disable force downsampling: {self.disable_force_downsampling}\n"
3249
3279
  )
3250
3280
 
3251
3281
  def sample(df):
@@ -3368,6 +3398,10 @@ class FeaturesEnricher(TransformerMixin):
3368
3398
  if t == SearchKey.POSTAL_CODE:
3369
3399
  return col
3370
3400
 
3401
+ @staticmethod
3402
+ def _get_date_column(search_keys: Dict[str, SearchKey]) -> Optional[str]:
3403
+ return SearchKey.find_key(search_keys, [SearchKey.DATE, SearchKey.DATETIME])
3404
+
3371
3405
  def _explode_multiple_search_keys(
3372
3406
  self, df: pd.DataFrame, search_keys: Dict[str, SearchKey], columns_renaming: Dict[str, str]
3373
3407
  ) -> Tuple[pd.DataFrame, Dict[str, List[str]]]:
@@ -3376,7 +3410,9 @@ class FeaturesEnricher(TransformerMixin):
3376
3410
  for key_name, key_type in search_keys.items():
3377
3411
  search_key_names_by_type[key_type] = search_key_names_by_type.get(key_type, []) + [key_name]
3378
3412
  search_key_names_by_type = {
3379
- key_type: key_names for key_type, key_names in search_key_names_by_type.items() if len(key_names) > 1
3413
+ key_type: key_names
3414
+ for key_type, key_names in search_key_names_by_type.items()
3415
+ if len(key_names) > 1 and key_type != SearchKey.CUSTOM_KEY
3380
3416
  }
3381
3417
  if len(search_key_names_by_type) == 0:
3382
3418
  return df, {}
@@ -3429,9 +3465,9 @@ class FeaturesEnricher(TransformerMixin):
3429
3465
  ]
3430
3466
  if DateTimeSearchKeyConverter.DATETIME_COL in df.columns:
3431
3467
  date_column = DateTimeSearchKeyConverter.DATETIME_COL
3432
- sort_exclude_columns.append(SearchKey.find_key(search_keys, [SearchKey.DATE, SearchKey.DATETIME]))
3468
+ sort_exclude_columns.append(self._get_date_column(search_keys))
3433
3469
  else:
3434
- date_column = SearchKey.find_key(search_keys, [SearchKey.DATE, SearchKey.DATETIME])
3470
+ date_column = self._get_date_column(search_keys)
3435
3471
  sort_columns = [date_column] if date_column is not None else []
3436
3472
 
3437
3473
  sorted_other_keys = sorted(search_keys, key=lambda x: str(search_keys.get(x)))
@@ -3867,11 +3903,6 @@ class FeaturesEnricher(TransformerMixin):
3867
3903
  self.logger.warning(msg + f" Provided search keys: {search_keys}")
3868
3904
  raise ValidationError(msg)
3869
3905
 
3870
- if SearchKey.CUSTOM_KEY in valid_search_keys.values():
3871
- custom_keys = [column for column, key in valid_search_keys.items() if key == SearchKey.CUSTOM_KEY]
3872
- for key in custom_keys:
3873
- del valid_search_keys[key]
3874
-
3875
3906
  if (
3876
3907
  len(valid_search_keys.values()) == 1
3877
3908
  and self.country_code is None
upgini/metadata.py CHANGED
@@ -350,3 +350,6 @@ class CVType(Enum):
350
350
  time_series = "time_series"
351
351
  blocked_time_series = "blocked_time_series"
352
352
  not_set = "not_set"
353
+
354
+ def is_time_series(self) -> bool:
355
+ return self in [CVType.time_series, CVType.blocked_time_series]
@@ -216,7 +216,6 @@ imbalanced_target=\nTarget is imbalanced and will be undersampled. Frequency of
216
216
  loss_selection_info=Using loss `{}` for feature selection
217
217
  loss_calc_metrics_info=Using loss `{}` for metrics calculation with default estimator
218
218
  forced_balance_undersample=For quick data retrieval, your dataset has been sampled. To use data search without data sampling please contact support (sales@upgini.com)
219
- online_api_features_transform=Please note that some of the selected features {} are provided through a slow enrichment interface and are not available via transformation. However, they can be accessed via the API:
220
219
 
221
220
  # Validation table
222
221
  validation_column_name_header=Column name
@@ -1,15 +1,18 @@
1
+ import itertools
1
2
  import logging
2
- from typing import Callable, Optional, Union
3
+ from typing import Callable, List, Optional, Union
3
4
 
4
5
  import numpy as np
5
6
  import pandas as pd
6
7
  from pandas.api.types import is_numeric_dtype, is_bool_dtype
7
8
 
8
9
  from upgini.errors import ValidationError
9
- from upgini.metadata import SYSTEM_RECORD_ID, ModelTaskType
10
+ from upgini.metadata import SYSTEM_RECORD_ID, CVType, ModelTaskType
10
11
  from upgini.resource_bundle import ResourceBundle, bundle, get_custom_bundle
11
12
  from upgini.sampler.random_under_sampler import RandomUnderSampler
12
13
 
14
+ TS_MIN_DIFFERENT_IDS_RATIO = 0.2
15
+
13
16
 
14
17
  def correct_string_target(y: Union[pd.Series, np.ndarray]) -> Union[pd.Series, np.ndarray]:
15
18
  if isinstance(y, pd.Series):
@@ -201,7 +204,10 @@ def balance_undersample(
201
204
  def balance_undersample_forced(
202
205
  df: pd.DataFrame,
203
206
  target_column: str,
207
+ id_columns: List[str],
208
+ date_column: str,
204
209
  task_type: ModelTaskType,
210
+ cv_type: CVType | None,
205
211
  random_state: int,
206
212
  sample_size: int = 7000,
207
213
  logger: Optional[logging.Logger] = None,
@@ -233,7 +239,17 @@ def balance_undersample_forced(
233
239
 
234
240
  resampled_data = df
235
241
  df = df.copy().sort_values(by=SYSTEM_RECORD_ID)
236
- if task_type in [ModelTaskType.MULTICLASS, ModelTaskType.REGRESSION, ModelTaskType.TIMESERIES]:
242
+ if cv_type is not None and cv_type.is_time_series():
243
+ logger.warning(f"Sampling time series dataset from {len(df)} to {sample_size}")
244
+ resampled_data = balance_undersample_time_series(
245
+ df,
246
+ id_columns=id_columns,
247
+ date_column=date_column,
248
+ sample_size=sample_size,
249
+ random_state=random_state,
250
+ logger=logger,
251
+ )
252
+ elif task_type in [ModelTaskType.MULTICLASS, ModelTaskType.REGRESSION]:
237
253
  logger.warning(f"Sampling dataset from {len(df)} to {sample_size}")
238
254
  resampled_data = df.sample(n=sample_size, random_state=random_state)
239
255
  else:
@@ -264,6 +280,65 @@ def balance_undersample_forced(
264
280
  return resampled_data
265
281
 
266
282
 
283
+ def balance_undersample_time_series(
284
+ df: pd.DataFrame,
285
+ id_columns: List[str],
286
+ date_column: str,
287
+ sample_size: int,
288
+ random_state: int = 42,
289
+ min_different_ids_ratio: float = TS_MIN_DIFFERENT_IDS_RATIO,
290
+ prefer_recent_dates: bool = True,
291
+ logger: Optional[logging.Logger] = None,
292
+ ):
293
+ def ensure_tuple(x):
294
+ return tuple([x]) if not isinstance(x, tuple) else x
295
+
296
+ random_state = np.random.RandomState(random_state)
297
+
298
+ if not id_columns:
299
+ id_columns = [date_column]
300
+ ids_sort = df.groupby(id_columns)[date_column].aggregate(["max", "count"]).T.to_dict()
301
+ ids_sort = {
302
+ ensure_tuple(k): (
303
+ (v["max"], v["count"], random_state.rand()) if prefer_recent_dates else (v["count"], random_state.rand())
304
+ )
305
+ for k, v in ids_sort.items()
306
+ }
307
+ id_counts = df[id_columns].value_counts()
308
+ id_counts.index = [ensure_tuple(i) for i in id_counts.index]
309
+ id_counts = id_counts.sort_index(key=lambda x: [ids_sort[y] for y in x], ascending=False).cumsum()
310
+ id_counts = id_counts[id_counts <= sample_size]
311
+ min_different_ids = max(int(len(df[id_columns].drop_duplicates()) * min_different_ids_ratio), 1)
312
+
313
+ def id_mask(sample_index: pd.Index) -> pd.Index:
314
+ if isinstance(sample_index, pd.MultiIndex):
315
+ return pd.MultiIndex.from_frame(df[id_columns]).isin(sample_index)
316
+ else:
317
+ return df[id_columns[0]].isin(sample_index)
318
+
319
+ if len(id_counts) < min_different_ids:
320
+ if logger is not None:
321
+ logger.info(
322
+ f"Different ids count {len(id_counts)} for sample size {sample_size} is less than min different ids {min_different_ids}, sampling time window"
323
+ )
324
+ date_counts = df.groupby(id_columns)[date_column].nunique().sort_values(ascending=False)
325
+ ids_to_sample = date_counts.index[:min_different_ids] if len(id_counts) > 0 else date_counts.index
326
+ mask = id_mask(ids_to_sample)
327
+ df = df[mask]
328
+ sample_date_counts = df[date_column].value_counts().sort_index(ascending=False).cumsum()
329
+ sample_date_counts = sample_date_counts[sample_date_counts <= sample_size]
330
+ df = df[df[date_column].isin(sample_date_counts.index)]
331
+ else:
332
+ if len(id_columns) > 1:
333
+ id_counts.index = pd.MultiIndex.from_tuples(id_counts.index)
334
+ else:
335
+ id_counts.index = [i[0] for i in id_counts.index]
336
+ mask = id_mask(id_counts.index)
337
+ df = df[mask]
338
+
339
+ return df
340
+
341
+
267
342
  def calculate_psi(expected: pd.Series, actual: pd.Series) -> Union[float, Exception]:
268
343
  try:
269
344
  df = pd.concat([expected, actual])
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: upgini
3
- Version: 1.2.39a1
3
+ Version: 1.2.39a3769.dev2
4
4
  Summary: Intelligent data search & enrichment for Machine Learning
5
5
  Project-URL: Bug Reports, https://github.com/upgini/upgini/issues
6
6
  Project-URL: Homepage, https://upgini.com/
@@ -1,12 +1,12 @@
1
- upgini/__about__.py,sha256=_wWeNiK5_JiwKIlVeEktsOM2zu0cB2l3qXursaGZU9U,25
1
+ upgini/__about__.py,sha256=2ilnzZVy_WdaVJ8AG6XQ1dEDOf4Mo3p6WiWCjIzOxF8,33
2
2
  upgini/__init__.py,sha256=LXSfTNU0HnlOkE69VCxkgIKDhWP-JFo_eBQ71OxTr5Y,261
3
3
  upgini/ads.py,sha256=nvuRxRx5MHDMgPr9SiU-fsqRdFaBv8p4_v1oqiysKpc,2714
4
- upgini/dataset.py,sha256=rUBE7_G7CLaaHAviFEyVPqjVSsX1DaLmi1dGFQR-eEo,32279
4
+ upgini/dataset.py,sha256=d9VlOs9hTf6eL8TX_9bO400HQj3y_jVGthABvQJqONs,33350
5
5
  upgini/errors.py,sha256=2b_Wbo0OYhLUbrZqdLIx5jBnAsiD1Mcenh-VjR4HCTw,950
6
- upgini/features_enricher.py,sha256=h17dmuAucpbkZs6E2T59-R9m-p8gW9bkXLY7NzvObKA,196002
6
+ upgini/features_enricher.py,sha256=HY7FBC-ioH5hNg2NVMLMV_YAqu4rThgrJoK0JT8cdhU,196975
7
7
  upgini/http.py,sha256=plZGTGoi1h2edd8Cnjt4eYB8t4NbBGnZz7DtPTByiNc,42885
8
8
  upgini/lazy_import.py,sha256=74gQ8JuA48BGRLxAo7lNHNKY2D2emMxrUxKGdxVGhuY,1012
9
- upgini/metadata.py,sha256=sB5uU-fdz_dA6g-PO6A8FzwIfDbkcFOewcpNs2xZzoY,11943
9
+ upgini/metadata.py,sha256=-ibqiNjD7dTagqg53FoEJNEqvAYbwgfyn9PGTRQ_YKU,12054
10
10
  upgini/metrics.py,sha256=hr7UwLphbZ_FEglLuO2lzr_pFgxOJ4c3WBeg7H-fNqY,35521
11
11
  upgini/search_task.py,sha256=qxUxAD-bed-FpZYmTB_4orW7YJsW_O6a1TcgnZIRFr4,17307
12
12
  upgini/spinner.py,sha256=4iMd-eIe_BnkqFEMIliULTbj6rNI2HkN_VJ4qYe0cUc,1118
@@ -30,7 +30,7 @@ upgini/normalizer/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU
30
30
  upgini/normalizer/normalize_utils.py,sha256=Ft2MwSgVoBilXAORAOYAuwPD79GOLfwn4qQE3IUFzzg,7218
31
31
  upgini/resource_bundle/__init__.py,sha256=S5F2G47pnJd2LDpmFsjDqEwiKkP8Hm-hcseDbMka6Ko,8345
32
32
  upgini/resource_bundle/exceptions.py,sha256=5fRvx0_vWdE1-7HcSgF0tckB4A9AKyf5RiinZkInTsI,621
33
- upgini/resource_bundle/strings.properties,sha256=uQWmbcd9TJh-xE0QpmHpHYKw-20utvXeHwFA-U_iTLw,27302
33
+ upgini/resource_bundle/strings.properties,sha256=TiYWmFnuhOq0R3aVg2nbA3F5AWLgjrgh68Yj6MhG-x8,27088
34
34
  upgini/resource_bundle/strings_widget.properties,sha256=gOdqvZWntP2LCza_tyVk1_yRYcG4c04K9sQOAVhF_gw,1577
35
35
  upgini/sampler/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
36
36
  upgini/sampler/base.py,sha256=7GpjYqjOp58vYcJLiX__1R5wjUlyQbxvHJ2klFnup_M,6389
@@ -56,10 +56,10 @@ upgini/utils/phone_utils.py,sha256=IrbztLuOJBiePqqxllfABWfYlfAjYevPhXKipl95wUI,1
56
56
  upgini/utils/postal_code_utils.py,sha256=5M0sUqH2DAr33kARWCTXR-ACyzWbjDq_-0mmEml6ZcU,1716
57
57
  upgini/utils/progress_bar.py,sha256=N-Sfdah2Hg8lXP_fV9EfUTXz_PyRt4lo9fAHoUDOoLc,1550
58
58
  upgini/utils/sklearn_ext.py,sha256=13jQS_k7v0aUtudXV6nGUEWjttPQzAW9AFYL5wgEz9k,44511
59
- upgini/utils/target_utils.py,sha256=Ed5IXkPjV9AfAZQAwCYksAmKaPGQliplvDYS_yeWdfk,11330
59
+ upgini/utils/target_utils.py,sha256=RlpKGss9kMibVSlA8iZuO_qxmyeplqzn7X8g6hiGGGs,14341
60
60
  upgini/utils/track_info.py,sha256=G5Lu1xxakg2_TQjKZk4b5SvrHsATTXNVV3NbvWtT8k8,5663
61
61
  upgini/utils/warning_counter.py,sha256=-GRY8EUggEBKODPSuXAkHn9KnEQwAORC0mmz_tim-PM,254
62
- upgini-1.2.39a1.dist-info/METADATA,sha256=qvNcejSCxKiITZbFqsGiaewkRsolxpy6OiePNwzqf90,48596
63
- upgini-1.2.39a1.dist-info/WHEEL,sha256=1yFddiXMmvYK7QYTqtRNtX66WJ0Mz8PYEiEUoOUUxRY,87
64
- upgini-1.2.39a1.dist-info/licenses/LICENSE,sha256=5RRzgvdJUu3BUDfv4bzVU6FqKgwHlIay63pPCSmSgzw,1514
65
- upgini-1.2.39a1.dist-info/RECORD,,
62
+ upgini-1.2.39a3769.dev2.dist-info/METADATA,sha256=Vh1Rr3q2Osl1_Ee7uetOp8LROY2nVUb_kvZwyxEDcHc,48604
63
+ upgini-1.2.39a3769.dev2.dist-info/WHEEL,sha256=zEMcRr9Kr03x1ozGwg5v9NQBKn3kndp6LSoSlVg-jhU,87
64
+ upgini-1.2.39a3769.dev2.dist-info/licenses/LICENSE,sha256=5RRzgvdJUu3BUDfv4bzVU6FqKgwHlIay63pPCSmSgzw,1514
65
+ upgini-1.2.39a3769.dev2.dist-info/RECORD,,
@@ -1,4 +1,4 @@
1
1
  Wheel-Version: 1.0
2
- Generator: hatchling 1.25.0
2
+ Generator: hatchling 1.24.2
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any