upgini 1.1.240__py3-none-any.whl → 1.1.242__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
upgini/dataset.py CHANGED
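The thread running through the dataset.py hunks below is dependency injection: Dataset no longer re-resolves a REST client from endpoint + api_key on every call, but receives one shared _RestClient instance up front. A minimal sketch of the resulting wiring, with stub classes standing in for the real ones (constructor argument lists are abbreviated and illustrative only):

from typing import Optional

# Sketch only: stub classes mirroring the injection pattern in the hunks below.
class _RestClient:
    def __init__(self, service_endpoint: str, refresh_token: str):
        self._service_endpoint = service_endpoint
        self._refresh_token = refresh_token

class Dataset:
    def __init__(self, dataset_name: str, rest_client: Optional[_RestClient] = None):
        # One shared client replaces the endpoint/api_key pair.
        self.dataset_name = dataset_name
        self.rest_client = rest_client

client = _RestClient("https://search.upgini.com", "api-token")
ds = Dataset("sample_ds", rest_client=client)  # the same client is reused everywhere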
@@ -20,7 +20,7 @@ from pandas.api.types import (
 from pandas.core.dtypes.common import is_period_dtype
 
 from upgini.errors import ValidationError
-from upgini.http import ProgressStage, SearchProgress, get_rest_client
+from upgini.http import ProgressStage, SearchProgress, _RestClient
 from upgini.metadata import (
     EVAL_SET_INDEX,
     SYSTEM_COLUMNS,
@@ -78,8 +78,7 @@ class Dataset: # (pd.DataFrame):
         search_keys: Optional[List[Tuple[str, ...]]] = None,
         model_task_type: Optional[ModelTaskType] = None,
         random_state: Optional[int] = None,
-        endpoint: Optional[str] = None,
-        api_key: Optional[str] = None,
+        rest_client: Optional[_RestClient] = None,
         logger: Optional[logging.Logger] = None,
         warning_counter: Optional[WarningCounter] = None,
         **kwargs,
@@ -114,8 +113,7 @@ class Dataset: # (pd.DataFrame):
         self.hierarchical_subgroup_keys = []
         self.file_upload_id: Optional[str] = None
         self.etalon_def: Optional[Dict[str, str]] = None
-        self.endpoint = endpoint
-        self.api_key = api_key
+        self.rest_client = rest_client
        self.random_state = random_state
        self.columns_renaming: Dict[str, str] = {}
        self.imbalanced: bool = False
@@ -983,10 +981,10 @@ class Dataset: # (pd.DataFrame):
             runtime_parameters=runtime_parameters,
         )
 
-        if self.file_upload_id is not None and get_rest_client(self.endpoint, self.api_key).check_uploaded_file_v2(
+        if self.file_upload_id is not None and self.rest_client.check_uploaded_file_v2(
             trace_id, self.file_upload_id, file_metadata
         ):
-            search_task_response = get_rest_client(self.endpoint, self.api_key).initial_search_without_upload_v2(
+            search_task_response = self.rest_client.initial_search_without_upload_v2(
                 trace_id, self.file_upload_id, file_metadata, file_metrics, search_customization
             )
         else:
@@ -999,7 +997,7 @@ class Dataset: # (pd.DataFrame):
                 progress_bar.progress = search_progress.to_progress_bar()
             if progress_callback is not None:
                 progress_callback(search_progress)
-            search_task_response = get_rest_client(self.endpoint, self.api_key).initial_search_v2(
+            search_task_response = self.rest_client.initial_search_v2(
                 trace_id, parquet_file_path, file_metadata, file_metrics, search_customization
             )
             # if progress_bar is not None:
@@ -1015,8 +1013,7 @@ class Dataset: # (pd.DataFrame):
             extract_features,
             accurate_model,
             task_type=self.task_type,
-            endpoint=self.endpoint,
-            api_key=self.api_key,
+            rest_client=self.rest_client,
             logger=self.logger,
         )
 
@@ -1053,10 +1050,10 @@ class Dataset: # (pd.DataFrame):
                 progress_bar.progress = search_progress.to_progress_bar()
             if progress_callback is not None:
                 progress_callback(search_progress)
-        if self.file_upload_id is not None and get_rest_client(self.endpoint, self.api_key).check_uploaded_file_v2(
+        if self.file_upload_id is not None and self.rest_client.check_uploaded_file_v2(
             trace_id, self.file_upload_id, file_metadata
         ):
-            search_task_response = get_rest_client(self.endpoint, self.api_key).validation_search_without_upload_v2(
+            search_task_response = self.rest_client.validation_search_without_upload_v2(
                 trace_id, self.file_upload_id, initial_search_task_id, file_metadata, file_metrics, search_customization
             )
         else:
@@ -1065,7 +1062,7 @@ class Dataset: # (pd.DataFrame):
             # To avoid rate limit
             time.sleep(1)
 
-            search_task_response = get_rest_client(self.endpoint, self.api_key).validation_search_v2(
+            search_task_response = self.rest_client.validation_search_v2(
                 trace_id,
                 parquet_file_path,
                 initial_search_task_id,
@@ -1085,8 +1082,7 @@ class Dataset: # (pd.DataFrame):
             return_scores,
             extract_features,
             initial_search_task_id=initial_search_task_id,
-            endpoint=self.endpoint,
-            api_key=self.api_key,
+            rest_client=self.rest_client,
             logger=self.logger,
         )
upgini/features_enricher.py CHANGED
@@ -233,7 +233,7 @@ class FeaturesEnricher(TransformerMixin):
         self.feature_importances_ = []
         self.search_id = search_id
         if search_id:
-            search_task = SearchTask(search_id, endpoint=self.endpoint, api_key=self._api_key, logger=self.logger)
+            search_task = SearchTask(search_id, rest_client=self.rest_client, logger=self.logger)
 
             print(bundle.get("search_by_task_id_start"))
             trace_id = str(uuid.uuid4())
@@ -297,7 +297,8 @@ class FeaturesEnricher(TransformerMixin):
     def _set_api_key(self, api_key: str):
         self._api_key = api_key
         if self.logs_enabled:
-            self.logger = LoggerFactory().get_logger(self.endpoint, self._api_key, self.client_ip, self.client_visitorid)
+            self.logger = LoggerFactory().get_logger(self.endpoint, self._api_key,
+                                                     self.client_ip, self.client_visitorid)
 
     api_key = property(_get_api_key, _set_api_key)
 
@@ -855,9 +856,17 @@ class FeaturesEnricher(TransformerMixin):
 
         if X is not None and y is None:
             raise ValidationError("X passed without y")
+
+        effective_X = X if X is not None else self.X
+        effective_eval_set = eval_set if eval_set is not None else self.eval_set
+
+        effective_X = X if X is not None else self.X
+        effective_eval_set = eval_set if eval_set is not None else self.eval_set
 
         validate_scoring_argument(scoring)
 
+        self._validate_baseline_score(effective_X, effective_eval_set)
+
         if self._has_paid_features(exclude_features_sources):
             msg = bundle.get("metrics_with_paid_features")
             self.logger.warning(msg)
@@ -1000,15 +1009,17 @@ class FeaturesEnricher(TransformerMixin):
             enriched_metric = None
             uplift = None
 
+            effective_X = X if X is not None else self.X
+            effective_y = y if y is not None else self.y
             train_metrics = {
                 bundle.get("quality_metrics_segment_header"): bundle.get("quality_metrics_train_segment"),
-                bundle.get("quality_metrics_rows_header"): _num_samples(self.X),
+                bundle.get("quality_metrics_rows_header"): _num_samples(effective_X),
                 # bundle.get("quality_metrics_match_rate_header"): self._search_task.initial_max_hit_rate_v2(),
             }
             if model_task_type in [ModelTaskType.BINARY, ModelTaskType.REGRESSION] and is_numeric_dtype(
                 y_sorted
             ):
-                train_metrics[bundle.get("quality_metrics_mean_target_header")] = round(self.y.mean(), 4)
+                train_metrics[bundle.get("quality_metrics_mean_target_header")] = round(np.mean(effective_y), 4)
             if etalon_metric is not None:
                 train_metrics[bundle.get("quality_metrics_baseline_header").format(metric)] = etalon_metric
             if enriched_metric is not None:
@@ -1064,18 +1075,19 @@ class FeaturesEnricher(TransformerMixin):
                 else:
                     eval_uplift = None
 
+                effective_eval_set = eval_set if eval_set is not None else self.eval_set
                 eval_metrics = {
                     bundle.get("quality_metrics_segment_header"): bundle.get(
                         "quality_metrics_eval_segment"
                     ).format(idx + 1),
-                    bundle.get("quality_metrics_rows_header"): _num_samples(self.eval_set[idx][0]),  # _num_samples(eval_X_sorted),
+                    bundle.get("quality_metrics_rows_header"): _num_samples(effective_eval_set[idx][0]),
                     # bundle.get("quality_metrics_match_rate_header"): eval_hit_rate,
                 }
                 if model_task_type in [ModelTaskType.BINARY, ModelTaskType.REGRESSION] and is_numeric_dtype(
                     eval_y_sorted
                 ):
                     eval_metrics[bundle.get("quality_metrics_mean_target_header")] = round(
-                        self.eval_set[idx][1].mean(), 4
+                        np.mean(effective_eval_set[idx][1]), 4
                     )
                 if etalon_eval_metric is not None:
                     eval_metrics[
@@ -1091,6 +1103,9 @@ class FeaturesEnricher(TransformerMixin):
                 metrics.append(eval_metrics)
 
         metrics_df = pd.DataFrame(metrics)
+        mean_target_hdr = bundle.get("quality_metrics_mean_target_header")
+        if mean_target_hdr in metrics_df.columns:
+            metrics_df[mean_target_hdr] = metrics_df[mean_target_hdr].astype("float64")
         do_without_pandas_limits(
             lambda: self.logger.info(f"Metrics calculation finished successfully:\n{metrics_df}")
         )
@@ -1802,10 +1817,9 @@ class FeaturesEnricher(TransformerMixin):
 
         dataset = Dataset(
             "sample_" + str(uuid.uuid4()),
-            df=df_without_features,  # type: ignore
-            endpoint=self.endpoint,  # type: ignore
-            api_key=self.api_key,  # type: ignore
-            date_format=self.date_format,  # type: ignore
+            df=df_without_features,
+            date_format=self.date_format,
+            rest_client=self.rest_client,
             logger=self.logger,
         )
         dataset.meaning_types = meaning_types
@@ -2135,11 +2149,10 @@ class FeaturesEnricher(TransformerMixin):
         dataset = Dataset(
             "tds_" + str(uuid.uuid4()),
             df=df,  # type: ignore
-            model_task_type=model_task_type,  # type: ignore
-            endpoint=self.endpoint,  # type: ignore
-            api_key=self.api_key,  # type: ignore
-            date_format=self.date_format,  # type: ignore
-            random_state=self.random_state,  # type: ignore
+            model_task_type=model_task_type,
+            date_format=self.date_format,
+            random_state=self.random_state,
+            rest_client=self.rest_client,
             logger=self.logger,
         )
         dataset.meaning_types = meaning_types
@@ -2468,6 +2481,36 @@ class FeaturesEnricher(TransformerMixin):
                 raise ValidationError(bundle.get("y_is_constant_eval_set"))
 
         return validated_eval_X, validated_eval_y
+
+    def _validate_baseline_score(self, X: pd.DataFrame, eval_set: Optional[List[Tuple]]):
+        if self.baseline_score_column is not None:
+            if self.baseline_score_column not in X.columns:
+                raise ValidationError(bundle.get("baseline_score_column_not_exists").format(self.baseline_score_column))
+            if X[self.baseline_score_column].isna().any():
+                raise ValidationError(bundle.get("baseline_score_column_has_na"))
+            if eval_set is not None:
+                if isinstance(eval_set, tuple):
+                    eval_set = [eval_set]
+                for eval in eval_set:
+                    if self.baseline_score_column not in eval[0].columns:
+                        raise ValidationError(bundle.get("baseline_score_column_not_exists"))
+                    if eval[0][self.baseline_score_column].isna().any():
+                        raise ValidationError(bundle.get("baseline_score_column_has_na"))
+
+    def _validate_baseline_score(self, X: pd.DataFrame, eval_set: Optional[List[Tuple]]):
+        if self.baseline_score_column is not None:
+            if self.baseline_score_column not in X.columns:
+                raise ValidationError(bundle.get("baseline_score_column_not_exists").format(self.baseline_score_column))
+            if X[self.baseline_score_column].isna().any():
+                raise ValidationError(bundle.get("baseline_score_column_has_na"))
+            if eval_set is not None:
+                if isinstance(eval_set, tuple):
+                    eval_set = [eval_set]
+                for eval in eval_set:
+                    if self.baseline_score_column not in eval[0].columns:
+                        raise ValidationError(bundle.get("baseline_score_column_not_exists"))
+                    if eval[0][self.baseline_score_column].isna().any():
+                        raise ValidationError(bundle.get("baseline_score_column_has_na"))
 
     @staticmethod
     def _sample_X_and_y(X: pd.DataFrame, y: pd.Series, enriched_X: pd.DataFrame) -> Tuple[pd.DataFrame, pd.Series]:
@@ -3396,6 +3439,8 @@ class FeaturesEnricher(TransformerMixin):
 
 def _num_samples(x):
     """Return number of samples in array-like x."""
+    if x is None:
+        return 0
     message = "Expected sequence or array-like, got %s" % type(x)
     if hasattr(x, "fit") and callable(x.fit):
         # Don't get num_samples from an ensembles length!
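Two quirks in the features_enricher.py hunks above are easy to miss: the effective_X/effective_eval_set assignments and the _validate_baseline_score method are each added twice (Python executes the second assignment and binds the second, identical method definition, so behavior is unchanged), and the raise inside the eval_set loop calls bundle.get("baseline_score_column_not_exists") without .format(...), leaving a literal {} in that message. The functional change is that metrics computation now falls back to the data captured at fit time when X, y, or eval_set are not passed, and _num_samples tolerates None. A self-contained sketch of the fallback pattern; the Enricher class here is a stand-in, not the real FeaturesEnricher:

import pandas as pd

def _num_samples(x):
    # None now counts as zero samples instead of raising.
    if x is None:
        return 0
    return len(x)

class Enricher:
    def __init__(self):
        self.X = pd.DataFrame({"a": [1, 2, 3]})  # captured during fit()

    def calculate_metrics(self, X: pd.DataFrame = None):
        # Prefer the explicitly passed X; fall back to the fitted one.
        effective_X = X if X is not None else self.X
        return {"rows": _num_samples(effective_X)}

print(Enricher().calculate_metrics())                          # {'rows': 3}
print(Enricher().calculate_metrics(pd.DataFrame({"a": [1]})))  # {'rows': 1}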
upgini/http.py CHANGED
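The notable behavioral change in the http.py hunks below: the tracking multipart part used to be attached only when search_customization was set (and the validation variant even used the filename "ide" instead of "tracking.json"), whereas it is now built unconditionally, so track metrics accompany every initial and validation search. The added print() calls around client creation and the initial search read like debugging leftovers. A hedged sketch of the resulting multipart layout, using requests directly in place of the library's _send_post_file_req_v2 wrapper; the URL and payloads are placeholders:

import json
import requests

track_metrics = {"ide": "jupyter", "ip": "203.0.113.7"}  # stand-in for get_track_metrics()

files = {
    "metadata": ("metadata.json", json.dumps({"name": "ds"}).encode(), "application/json"),
    # The tracking part is now always present, not only with search_customization:
    "tracking": ("tracking.json", json.dumps(track_metrics).encode(), "application/json"),
    "metrics": ("metrics.json", json.dumps({}).encode(), "application/json"),
    "file": ("ds.parquet", b"...", "application/octet-stream"),
}
# Each (filename, content, content_type) tuple becomes one multipart part.
req = requests.Request("POST", "https://example.invalid/initial-search", files=files).prepare()
print(req.headers["Content-Type"])  # multipart/form-data; boundary=...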
@@ -301,13 +301,14 @@ class _RestClient:
     USER_AGENT_HEADER_VALUE = "pyupgini/" + __version__
     SEARCH_KEYS_HEADER_NAME = "Search-Keys"
 
-    def __init__(self, service_endpoint, refresh_token, silent_mode=False, client_ip=None, client_visitorid=None):
+    def __init__(self, service_endpoint, refresh_token, client_ip=None, client_visitorid=None):
         # debug_requests_on()
         self._service_endpoint = service_endpoint
         self._refresh_token = refresh_token
-        self.silent_mode = silent_mode
+        # self.silent_mode = silent_mode
         self.client_ip = client_ip
         self.client_visitorid = client_visitorid
+        print(f"Created RestClient with {client_ip} and {client_visitorid}")
         self._access_token = self._refresh_access_token()
         # self._access_token: Optional[str] = None  # self._refresh_access_token()
         self.last_refresh_time = time.time()
@@ -441,6 +442,10 @@ class _RestClient:
     ) -> SearchTaskResponse:
         api_path = self.INITIAL_SEARCH_URI_FMT_V2
 
+        print(f"Start initial search with {self.client_ip} and {self.client_visitorid}")
+        track_metrics = get_track_metrics(self.client_ip, self.client_visitorid)
+        print(f"Sending track metrics: {track_metrics}")
+
         def open_and_send():
             md5_hash = hashlib.md5()
             with open(file_path, "rb") as file:
@@ -461,6 +466,11 @@ class _RestClient:
                     metadata_with_md5.json(exclude_none=True).encode(),
                     "application/json",
                 ),
+                "tracking": (
+                    "tracking.json",
+                    dumps(track_metrics).encode(),
+                    "application/json",
+                ),
                 "metrics": ("metrics.json", metrics.json(exclude_none=True).encode(), "application/json"),
                 "file": (metadata_with_md5.name, file, "application/octet-stream"),
             }
@@ -470,11 +480,6 @@ class _RestClient:
                     search_customization.json(exclude_none=True).encode(),
                     "application/json",
                 )
-            files["tracking"] = (
-                "tracking.json",
-                dumps(get_track_metrics(self.client_ip, self.client_visitorid)).encode(),
-                "application/json",
-            )
             additional_headers = {self.SEARCH_KEYS_HEADER_NAME: ",".join(self.search_keys_meaning_types(metadata))}
 
             return self._send_post_file_req_v2(
@@ -545,6 +550,11 @@ class _RestClient:
                     metadata_with_md5.json(exclude_none=True).encode(),
                     "application/json",
                 ),
+                "tracking": (
+                    "tracking.json",
+                    dumps(get_track_metrics(self.client_ip, self.client_visitorid)).encode(),
+                    "application/json",
+                ),
                 "metrics": ("metrics.json", metrics.json(exclude_none=True).encode(), "application/json"),
                 "file": (metadata_with_md5.name, file, "application/octet-stream"),
             }
@@ -554,11 +564,6 @@ class _RestClient:
                     search_customization.json(exclude_none=True).encode(),
                     "application/json",
                 )
-            files["tracking"] = (
-                "ide",
-                dumps(get_track_metrics(self.client_ip, self.client_visitorid)).encode(),
-                "application/json",
-            )
 
             additional_headers = {self.SEARCH_KEYS_HEADER_NAME: ",".join(self.search_keys_meaning_types(metadata))}
 
@@ -922,12 +927,12 @@ def is_demo_api_key(api_token: Optional[str]) -> bool:
 @lru_cache()
 def _get_rest_client(backend_url: str, api_token: str,
                      client_ip: Optional[str] = None, client_visitorid: Optional[str] = None) -> _RestClient:
-    return _RestClient(backend_url, api_token)
+    return _RestClient(backend_url, api_token, client_ip, client_visitorid)
 
 
 class BackendLogHandler(logging.Handler):
-    def __init__(self, rest_client: _RestClient,
-                 client_ip: Optional[str] = None, client_visitorid: Optional[str] = None,
+    def __init__(self, rest_client: _RestClient,
+                 client_ip: Optional[str] = None, client_visitorid: Optional[str] = None,
                  *args, **kwargs) -> None:
         super().__init__(*args, **kwargs)
         self.rest_client = rest_client
@@ -982,7 +987,7 @@ class LoggerFactory:
         root.handlers.clear()
 
     def get_logger(
-        self, backend_url: Optional[str] = None, api_token: Optional[str] = None,
+        self, backend_url: Optional[str] = None, api_token: Optional[str] = None,
         client_ip: Optional[str] = None, client_visitorid: Optional[str] = None
     ) -> logging.Logger:
         url = _resolve_backend_url(backend_url)
@@ -994,7 +999,7 @@ class LoggerFactory:
 
         upgini_logger = logging.getLogger(f"upgini.{hash(key)}")
         upgini_logger.handlers.clear()
-        rest_client = get_rest_client(backend_url, api_token)
+        rest_client = get_rest_client(backend_url, api_token, client_ip, client_visitorid)
         datadog_handler = BackendLogHandler(rest_client, client_ip, client_visitorid)
         json_formatter = jsonlogger.JsonFormatter(
             "%(asctime)s %(threadName)s %(name)s %(levelname)s %(message)s",
upgini/metrics.py CHANGED
@@ -215,7 +215,7 @@ class EstimatorWrapper:
         self.groups = groups
 
     def fit(self, X: pd.DataFrame, y: np.ndarray, **kwargs):
-        X, y, fit_params = self._prepare_to_fit(X, y)
+        X, y, _, fit_params = self._prepare_to_fit(X, y)
         kwargs.update(fit_params)
         self.estimator.fit(X, y, **kwargs)
         return self
@@ -223,7 +223,13 @@ class EstimatorWrapper:
     def predict(self, **kwargs):
         return self.estimator.predict(**kwargs)
 
-    def _prepare_to_fit(self, X: pd.DataFrame, y: pd.Series) -> Tuple[pd.DataFrame, np.ndarray, dict]:
+    def _prepare_to_fit(self, X: pd.DataFrame, y: pd.Series) -> Tuple[pd.DataFrame, np.ndarray, np.ndarray, dict]:
+        X, y, groups = self._prepare_data(X, y, groups=self.groups)
+        return X, y, groups, {}
+
+    def _prepare_data(
+        self, X: pd.DataFrame, y: pd.Series, groups: Optional[np.ndarray] = None
+    ) -> Tuple[pd.DataFrame, np.ndarray, np.ndarray]:
         for c in X.columns:
             if is_numeric_dtype(X[c]):
                 X[c] = X[c].astype(float)
@@ -233,36 +239,33 @@ class EstimatorWrapper:
         if not isinstance(y, pd.Series):
             raise Exception(bundle.get("metrics_unsupported_target_type").format(type(y)))
 
-        joined = pd.concat([X, y], axis=1)
-        joined = joined[joined[y.name].notna()]
-        joined = joined.reset_index(drop=True)
-        X = joined.drop(columns=y.name)
-        y = np.array(list(joined[y.name].values))
-        return X, y, {}
-
-    def _prepare_to_calculate(self, X: pd.DataFrame, y: pd.Series) -> Tuple[pd.DataFrame, np.ndarray, dict]:
-        for c in X.columns:
-            if is_numeric_dtype(X[c]):
-                X[c] = X[c].astype(float)
-            else:
-                X[c] = X[c].astype(str)
+        if groups is not None:
+            X["__groups"] = groups
+            X, y = self._remove_empty_target_rows(X, y)
+            groups = X["__groups"]
+            X.drop(columns="__groups", inplace=True)
+        else:
+            X, y = self._remove_empty_target_rows(X, y)
 
-        if not isinstance(y, pd.Series):
-            raise Exception(bundle.get("metrics_unsupported_target_type").format(type(y)))
+        return X, y, groups
 
+    def _remove_empty_target_rows(self, X: pd.DataFrame, y: pd.Series) -> Tuple[pd.DataFrame, pd.Series]:
         joined = pd.concat([X, y], axis=1)
         joined = joined[joined[y.name].notna()]
         joined = joined.reset_index(drop=True)
         X = joined.drop(columns=y.name)
         y = np.array(list(joined[y.name].values))
+
+        return X, y
+
+    def _prepare_to_calculate(self, X: pd.DataFrame, y: pd.Series) -> Tuple[pd.DataFrame, np.ndarray, dict]:
+        X, y, _ = self._prepare_data(X, y)
         return X, y, {}
 
     def cross_val_predict(
         self, X: pd.DataFrame, y: np.ndarray, baseline_score_column: Optional[Any] = None
     ) -> Optional[float]:
-        X, y, fit_params = self._prepare_to_fit(X, y)
-        # if isinstance(self.estimator, CatBoostClassifier) or isinstance(self.estimator, CatBoostRegressor):
-        #     fit_params["early_stopping_rounds"] = 20
+        X, y, groups, fit_params = self._prepare_to_fit(X, y)
 
         if X.shape[1] == 0:
             return None
@@ -278,7 +281,7 @@ class EstimatorWrapper:
             y=y,
             scoring=scorer,
             cv=self.cv,
-            groups=self.groups,
+            groups=groups,
             fit_params=fit_params,
             return_estimator=True,
         )
@@ -393,8 +396,8 @@ class CatBoostWrapper(EstimatorWrapper):
         self.cat_features = None
         self.cat_features_idx = None
 
-    def _prepare_to_fit(self, X: pd.DataFrame, y: pd.Series) -> Tuple[pd.DataFrame, np.ndarray, dict]:
-        X, y, params = super()._prepare_to_fit(X, y)
+    def _prepare_to_fit(self, X: pd.DataFrame, y: pd.Series) -> Tuple[pd.DataFrame, np.ndarray, np.ndarray, dict]:
+        X, y, groups, params = super()._prepare_to_fit(X, y)
         self.cat_features = _get_cat_features(X)
         X = fill_na_cat_features(X, self.cat_features)
         # unique_cat_features = []
@@ -418,7 +421,7 @@ class CatBoostWrapper(EstimatorWrapper):
             del self.estimator._init_params["cat_features"]
 
         params.update({"cat_features": self.cat_features_idx})
-        return X, y, params
+        return X, y, groups, params
 
     def _prepare_to_calculate(self, X: pd.DataFrame, y: pd.Series) -> Tuple[pd.DataFrame, np.ndarray, dict]:
         X, y, params = super()._prepare_to_calculate(X, y)
@@ -445,8 +448,8 @@ class LightGBMWrapper(EstimatorWrapper):
         )
         self.cat_features = None
 
-    def _prepare_to_fit(self, X: pd.DataFrame, y: pd.Series) -> Tuple[pd.DataFrame, pd.Series, dict]:
-        X, y, params = super()._prepare_to_fit(X, y)
+    def _prepare_to_fit(self, X: pd.DataFrame, y: pd.Series) -> Tuple[pd.DataFrame, pd.Series, np.ndarray, dict]:
+        X, y, groups, params = super()._prepare_to_fit(X, y)
         self.cat_features = _get_cat_features(X)
         X = fill_na_cat_features(X, self.cat_features)
         for feature in self.cat_features:
@@ -454,7 +457,7 @@ class LightGBMWrapper(EstimatorWrapper):
         if not is_numeric_dtype(y):
             y = correct_string_target(y)
 
-        return X, y, params
+        return X, y, groups, params
 
     def _prepare_to_calculate(self, X: pd.DataFrame, y: pd.Series) -> Tuple[pd.DataFrame, np.ndarray, dict]:
         X, y, params = super()._prepare_to_calculate(X, y)
@@ -483,8 +486,8 @@ class OtherEstimatorWrapper(EstimatorWrapper):
         )
         self.cat_features = None
 
-    def _prepare_to_fit(self, X: pd.DataFrame, y: np.ndarray) -> Tuple[pd.DataFrame, np.ndarray, dict]:
-        X, y, params = super()._prepare_to_fit(X, y)
+    def _prepare_to_fit(self, X: pd.DataFrame, y: np.ndarray) -> Tuple[pd.DataFrame, np.ndarray, np.ndarray, dict]:
+        X, y, groups, params = super()._prepare_to_fit(X, y)
         self.cat_features = _get_cat_features(X)
         num_features = [col for col in X.columns if col not in self.cat_features]
         X[num_features] = X[num_features].fillna(-999)
@@ -494,7 +497,7 @@ class OtherEstimatorWrapper(EstimatorWrapper):
             X[feature] = X[feature].astype("category").cat.codes
         if not is_numeric_dtype(y):
             y = correct_string_target(y)
-        return X, y, params
+        return X, y, groups, params
 
     def _prepare_to_calculate(self, X: pd.DataFrame, y: pd.Series) -> Tuple[pd.DataFrame, np.ndarray, dict]:
         X, y, params = super()._prepare_to_calculate(X, y)
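The point of the metrics.py refactor above is to keep cross-validation groups row-aligned with X after rows with a missing target are dropped: _prepare_data stashes the groups in a temporary __groups column so they pass through the same filter as the features, and cross_val_predict then hands the filtered groups (rather than the original self.groups) to cross_validate. A minimal pure-pandas reproduction of the trick, with no upgini imports:

import numpy as np
import pandas as pd

X = pd.DataFrame({"f": [1.0, 2.0, 3.0, 4.0]})
y = pd.Series([0.0, np.nan, 1.0, np.nan], name="target")
groups = np.array(["a", "a", "b", "b"])

# Stash groups as a column so they are filtered together with the rows.
X["__groups"] = groups
joined = pd.concat([X, y], axis=1)
joined = joined[joined[y.name].notna()].reset_index(drop=True)
X = joined.drop(columns=y.name)
groups = X.pop("__groups").to_numpy()
y = joined[y.name].to_numpy()

print(X["f"].tolist(), y.tolist(), groups.tolist())
# [1.0, 3.0] [0.0, 1.0] ['a', 'b'] -- groups stay aligned after the NaN-target drop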
upgini/resource_bundle/strings.properties CHANGED
@@ -126,6 +126,8 @@ eval_y_multiindex_unsupported=Multi index in y in eval_set is not supported
 eval_x_is_empty=X in eval_set is empty.
 eval_y_is_empty=y in eval_set is empty.
 x_and_eval_x_diff_types=X and eval_set X has different types: {} and {}
+baseline_score_column_not_exists=baseline_score_column {} doesn't exist in input dataframe
+baseline_score_column_has_na=baseline_score_column contains NaN. Clear it and retry
 # target validation
 empty_target=Target is empty in all rows
 non_numeric_target=Binary target should be numerical type
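The two new bundle strings use positional {} placeholders filled via str.format at the raise site, as in _validate_baseline_score above. A tiny standalone sketch of the pattern, with a plain dict standing in for the resource bundle:

# A dict stands in for the resource bundle; keys and texts mirror the diff above.
bundle = {
    "baseline_score_column_not_exists": "baseline_score_column {} doesn't exist in input dataframe",
    "baseline_score_column_has_na": "baseline_score_column contains NaN. Clear it and retry",
}

print(bundle["baseline_score_column_not_exists"].format("my_baseline"))
# baseline_score_column my_baseline doesn't exist in input dataframe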
upgini/search_task.py CHANGED
@@ -8,7 +8,7 @@ import pandas as pd
 
 from upgini import dataset
 from upgini.http import (
-    LoggerFactory,
+    _RestClient,
     ProviderTaskSummary,
     SearchProgress,
     SearchTaskSummary,
@@ -42,8 +42,7 @@ class SearchTask:
         accurate_model: bool = False,
         initial_search_task_id: Optional[str] = None,
         task_type: Optional[ModelTaskType] = None,
-        endpoint: Optional[str] = None,
-        api_key: Optional[str] = None,
+        rest_client: Optional[_RestClient] = None,
         logger: Optional[logging.Logger] = None,
     ):
         self.search_task_id = search_task_id
@@ -54,8 +53,7 @@ class SearchTask:
         self.accurate_model = accurate_model
         self.task_type = task_type
         self.summary = None
-        self.endpoint = endpoint
-        self.api_key = api_key
+        self.rest_client = rest_client
         if logger is not None:
             self.logger = logger
         else:
@@ -65,7 +63,7 @@ class SearchTask:
         self.unused_features_for_generation: Optional[List[str]] = None
 
     def get_progress(self, trace_id: str) -> SearchProgress:
-        return get_rest_client(self.endpoint, self.api_key).get_search_progress(trace_id, self.search_task_id)
+        return self.rest_client.get_search_progress(trace_id, self.search_task_id)
 
     def poll_result(self, trace_id: str, quiet: bool = False, check_fit: bool = False) -> "SearchTask":
         completed_statuses = {"COMPLETED", "VALIDATION_COMPLETED"}
@@ -73,7 +71,7 @@ class SearchTask:
         submitted_statuses = {"SUBMITTED", "VALIDATION_SUBMITTED"}
         if not quiet:
             print(bundle.get("polling_search_task").format(self.search_task_id))
-            if is_demo_api_key(self.api_key):
+            if is_demo_api_key(self.rest_client._refresh_token):
                 print(bundle.get("polling_unregister_information"))
         search_task_id = self.initial_search_task_id if self.initial_search_task_id is not None else self.search_task_id
 
@@ -81,14 +79,14 @@ class SearchTask:
             with Spinner():
                 if self.PROTECT_FROM_RATE_LIMIT:
                     time.sleep(1)  # this is neccesary to avoid requests rate limit restrictions
-                self.summary = get_rest_client(self.endpoint, self.api_key).search_task_summary_v2(
+                self.summary = self.rest_client.search_task_summary_v2(
                     trace_id, search_task_id
                 )
                 while self.summary.status not in completed_statuses and (
                     not check_fit or "VALIDATION" not in self.summary.status
                 ):
                     time.sleep(self.POLLING_DELAY_SECONDS)
-                    self.summary = get_rest_client(self.endpoint, self.api_key).search_task_summary_v2(
+                    self.summary = self.rest_client.search_task_summary_v2(
                         trace_id, search_task_id
                     )
                     if self.summary.status in failed_statuses:
@@ -104,7 +102,7 @@ class SearchTask:
         except KeyboardInterrupt as e:
             if not check_fit:
                 print(bundle.get("search_stopping"))
-                get_rest_client(self.endpoint, self.api_key).stop_search_task_v2(trace_id, search_task_id)
+                self.rest_client.stop_search_task_v2(trace_id, search_task_id)
                 self.logger.warning(f"Search {search_task_id} stopped by user")
                 print(bundle.get("search_stopped"))
             raise e
@@ -132,7 +130,7 @@ class SearchTask:
         for provider_summary in self.summary.initial_important_providers:
             if provider_summary.status == "COMPLETED":
                 self.provider_metadata_v2.append(
-                    get_rest_client(self.endpoint, self.api_key).get_provider_search_metadata_v3(
+                    self.rest_client.get_provider_search_metadata_v3(
                         provider_summary.ads_search_task_id, trace_id
                     )
                 )
@@ -258,8 +256,8 @@ class SearchTask:
         if self.PROTECT_FROM_RATE_LIMIT:
             time.sleep(1)  # this is neccesary to avoid requests rate limit restrictions
         return _get_all_initial_raw_features_cached(
-            self.endpoint,
-            self.api_key,
+            self.rest_client._service_endpoint,
+            self.rest_client._refresh_token,
             trace_id,
             self.search_task_id,
             metrics_calculation,
@@ -269,7 +267,11 @@ class SearchTask:
     def get_target_outliers(self, trace_id: str) -> Optional[pd.DataFrame]:
         self._check_finished_initial_search()
         return _get_target_outliers_cached(
-            self.endpoint, self.api_key, trace_id, self.search_task_id, self.PROTECT_FROM_RATE_LIMIT
+            self.rest_client._service_endpoint,
+            self.rest_client._refresh_token,
+            trace_id,
+            self.search_task_id,
+            self.PROTECT_FROM_RATE_LIMIT
         )
 
     def get_max_initial_eval_set_hit_rate_v2(self) -> Optional[Dict[int, float]]:
@@ -287,8 +289,8 @@ class SearchTask:
     def get_all_validation_raw_features(self, trace_id: str, metrics_calculation=False) -> Optional[pd.DataFrame]:
         self._check_finished_validation_search()
         return _get_all_validation_raw_features_cached(
-            self.endpoint,
-            self.api_key,
+            self.rest_client._service_endpoint,
+            self.rest_client._refresh_token,
             trace_id,
             self.search_task_id,
             metrics_calculation,
@@ -296,7 +298,7 @@ class SearchTask:
         )
 
     def get_file_metadata(self, trace_id: str) -> FileMetadata:
-        return get_rest_client(self.endpoint, self.api_key).get_search_file_metadata(self.search_task_id, trace_id)
+        return self.rest_client.get_search_file_metadata(self.search_task_id, trace_id)
 
 
 @lru_cache()
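A subtlety in the raw-features hunks above: the *_cached helpers sit behind lru_cache (see the decorator closing the section), which is presumably why they receive the client's plain-string _service_endpoint and _refresh_token rather than the _RestClient object itself; strings hash by value and make stable cache keys, while a client instance would hash by identity. Note also that rest_client is typed Optional in the new SearchTask signature but dereferenced unconditionally, so callers must supply a real client. A generic illustration of the caching point (not upgini code):

from functools import lru_cache

@lru_cache()
def fetch_features(endpoint: str, token: str, search_task_id: str) -> str:
    # Pretend this is an expensive HTTP round trip.
    return f"features for {search_task_id} via {endpoint}"

fetch_features("https://search.upgini.com", "t0k3n", "abc123")
fetch_features("https://search.upgini.com", "t0k3n", "abc123")
print(fetch_features.cache_info().hits)  # 1 -- the second call hit the cache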
upgini/utils/target_utils.py CHANGED
@@ -30,7 +30,7 @@ def define_task(y: pd.Series, logger: Optional[logging.Logger] = None, silent: b
     target_items = target.nunique()
     if target_items == 1:
         raise ValidationError(bundle.get("dataset_constant_target"))
-
+
     if target_items == 2:
         task = ModelTaskType.BINARY
     else:
upgini/utils/track_info.py CHANGED
@@ -50,6 +50,7 @@ def _get_execution_ide() -> str:
     except Exception:
         return "other"
 
+
 @lru_cache()
 def get_track_metrics(client_ip: Optional[str] = None, client_visitorid: Optional[str] = None) -> dict:
     # default values
@@ -73,7 +74,7 @@ def get_track_metrics(client_ip: Optional[str] = None, client_visitorid: Optiona
         display(
             Javascript(
                 """
-                import('https://upgini.github.io/upgini/js/visitorid.js')
+                import('https://upgini.github.io/upgini/js/a.js')
                 .then(FingerprintJS => FingerprintJS.load())
                 .then(fp => fp.get())
                 .then(result => window.visitorId = result.visitorId);
{upgini-1.1.240.dist-info → upgini-1.1.242.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: upgini
-Version: 1.1.240
+Version: 1.1.242
 Summary: Intelligent data search & enrichment for Machine Learning
 Home-page: https://upgini.com/
 Author: Upgini Developers
{upgini-1.1.240.dist-info → upgini-1.1.242.dist-info}/RECORD RENAMED
@@ -1,13 +1,13 @@
 upgini/__init__.py,sha256=asENHgEVHQBIkV-e_0IhE_ZWqkCG6398U3ZLrNzAH6k,407
 upgini/ads.py,sha256=mre6xn44wcC_fg63iLT_kTh4mViZqR9AKRJZAtpQz8Y,2592
-upgini/dataset.py,sha256=qSjv09LKzCYayucb_JlhExw9uSRcscLWTaD8hqATE3s,49676
+upgini/dataset.py,sha256=y9rpNhdLU9QgfFZndrPGK-S6CL67q5ocmB9HMzwHtaA,49395
 upgini/errors.py,sha256=BqpvfhW2jJW5fa5KXj0alhXatGl-WK4xTl309-QNLp8,959
-upgini/features_enricher.py,sha256=DUo-pvBqHwp5O_Fr71f56TwGvZsmAM-KyzFUBMUAHk4,160312
+upgini/features_enricher.py,sha256=n2L9MWq4WoUQIzoDDECFyiuprwZslFPPhbLfpXsT3sQ,162975
 upgini/fingerprint.js,sha256=VygVIQlN1v4NGZfjHqtRogOw8zjTnnMNJg_f7M5iGQU,33442
-upgini/http.py,sha256=RG93QmV3mqKixQsSHqYeM1Mtucp-EpdavcpCuhufnGE,42141
+upgini/http.py,sha256=xeSatYNnSBMQfGMXsER_ZvhR5zfDTY8_E1g3YpIOb38,42477
 upgini/metadata.py,sha256=FZ5CQluLLWrfrBVThSIes1SW6wcs7n50aNZwzYnHiF0,9584
-upgini/metrics.py,sha256=YeYHJtEIs8OG-EzidG-nbSYB919pjZ4MMbdcZ_jfV2s,23639
-upgini/search_task.py,sha256=sqgb5MfwWXg6YAbVhLOPcVJ5tDCUyzxFRWfd9aWj8SM,17236
+upgini/metrics.py,sha256=rteVPPjDFYlL5bBFVpu-YwwXQGNV1IzwT7V7L9JtjaE,23762
+upgini/search_task.py,sha256=nTVrb3CE4M1zfDkI-W_qVdUhsc90b98w3lo0XxegeKo,17200
 upgini/spinner.py,sha256=yhakBaydMNS8E8TRAwTdCMdnWrHeWT0cR1M8c9hP6jA,1157
 upgini/version_validator.py,sha256=rDIncP6BEko4J2F2hUcMOtKm_vZbI4ICWcNcw8hrwM4,1400
 upgini/ads_management/__init__.py,sha256=qzyisOToVRP-tquAJD1PblZhNtMrOB8FiyF9JvfkvgE,50
@@ -28,7 +28,7 @@ upgini/normalizer/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU
 upgini/normalizer/phone_normalizer.py,sha256=VIgLXuDuzzjPEXiy_LyDVLZKGaS7-le6Fh6T4D-TQDU,9930
 upgini/resource_bundle/__init__.py,sha256=M7GtS7KPQw9pinz8P2aQWXpSkD2YFwUPVGk1w92Pn84,7888
 upgini/resource_bundle/exceptions.py,sha256=KT-OnqA2J4OTfLjhbEl3KFZM2ci7EOPjqJuY_rXp3vs,622
-upgini/resource_bundle/strings.properties,sha256=1mpOkd_wkKIJGwWRBgfXz0mLx4lqdDro5IUoj8BBxuE,24527
+upgini/resource_bundle/strings.properties,sha256=C6rXpf2nXByeCTCog1ZacEF9bKal6JJNlDUTvE0szAQ,24706
 upgini/sampler/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 upgini/sampler/base.py,sha256=X2PVsfZ3Rl7twpFDh5UWyxqY2K_jcMGxZ2NcHLwFRj4,6489
 upgini/sampler/random_under_sampler.py,sha256=whX_f_TtalHH8Seyn_7n3sX_TSiDHeYfALmme9saqDg,4082
@@ -50,11 +50,11 @@ upgini/utils/phone_utils.py,sha256=JNSkF8G6mgsN8Czy11pamaJdsY6rBINEMpi7jbVt_RA,4
 upgini/utils/postal_code_utils.py,sha256=_8CR9tBqsPptQsmMUvnrCAmBaMIQSWH3JfJ4ly3x_zs,409
 upgini/utils/progress_bar.py,sha256=iNXyqT3vKCeHpfiG5HHwr7Lk2cTtKViM93Fl8iZnjGc,1564
 upgini/utils/sklearn_ext.py,sha256=IMx2La70AXAggApVpT7sMEjWqVWon5AMZt4MARDsIMQ,43847
-upgini/utils/target_utils.py,sha256=n03QhNbm9P5OvvI_RPex2Wa8_swrE5l3CslPniU95Bg,1712
-upgini/utils/track_info.py,sha256=NK4VSPR4gkphnt0fMiOLEQLaOW04HPK0nKLgZHeS820,5214
+upgini/utils/target_utils.py,sha256=_VjYUm4ECXbgNvxNupr982fuOK_jtkg-8Xw7-zJBz2w,1708
+upgini/utils/track_info.py,sha256=EPcJ13Jqa17_T0JjM37Ac9kWDz5Zk0GVsIZKutOb8aU,5207
 upgini/utils/warning_counter.py,sha256=vnmdFo5-7GBkU2bK9h_uC0K0Y_wtfcYstxOdeRfacO0,228
-upgini-1.1.240.dist-info/LICENSE,sha256=5RRzgvdJUu3BUDfv4bzVU6FqKgwHlIay63pPCSmSgzw,1514
-upgini-1.1.240.dist-info/METADATA,sha256=c5l9RquzeHvhU-aq3esgp-5HWjiIxhc7vc-EIrQd-S8,48262
-upgini-1.1.240.dist-info/WHEEL,sha256=G16H4A3IeoQmnOrYV4ueZGKSjhipXx8zc8nu9FGlvMA,92
-upgini-1.1.240.dist-info/top_level.txt,sha256=OFhTGiDIWKl5gFI49qvWq1R9IKflPaE2PekcbDXDtx4,7
-upgini-1.1.240.dist-info/RECORD,,
+upgini-1.1.242.dist-info/LICENSE,sha256=5RRzgvdJUu3BUDfv4bzVU6FqKgwHlIay63pPCSmSgzw,1514
+upgini-1.1.242.dist-info/METADATA,sha256=FwVINjwPmABqlcahJ70lv1hjpyDTH7bt3CGKGZmBHE0,48262
+upgini-1.1.242.dist-info/WHEEL,sha256=G16H4A3IeoQmnOrYV4ueZGKSjhipXx8zc8nu9FGlvMA,92
+upgini-1.1.242.dist-info/top_level.txt,sha256=OFhTGiDIWKl5gFI49qvWq1R9IKflPaE2PekcbDXDtx4,7
+upgini-1.1.242.dist-info/RECORD,,