upgini 1.1.168a1__py3-none-any.whl → 1.1.169a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of upgini has been flagged as potentially problematic; see the package's registry page for more details.

upgini/metrics.py CHANGED
@@ -1,6 +1,6 @@
1
1
  import logging
2
2
  from copy import deepcopy
3
- from typing import Callable, List, Optional, Tuple, Union, Dict, Any
3
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
4
4
 
5
5
  import numpy as np
6
6
  import pandas as pd
@@ -8,13 +8,25 @@ from catboost import CatBoostClassifier, CatBoostRegressor
8
8
  from lightgbm import LGBMClassifier, LGBMRegressor
9
9
  from numpy import log1p
10
10
  from pandas.api.types import is_numeric_dtype
11
- from sklearn.metrics import get_scorer_names, check_scoring, get_scorer, make_scorer
11
+ from sklearn.metrics import check_scoring, get_scorer, make_scorer
12
+
13
+ from upgini.utils.sklearn_ext import cross_validate
14
+
15
+ try:
16
+ from sklearn.metrics import get_scorer_names
17
+
18
+ available_scorers = get_scorer_names()
19
+ except ImportError:
20
+ from sklearn.metrics._scorer import SCORERS
21
+
22
+ available_scorers = SCORERS
23
+
12
24
  from sklearn.metrics._regression import (
13
25
  _check_reg_targets,
14
26
  check_consistent_length,
15
27
  mean_squared_error,
16
28
  )
17
- from sklearn.model_selection import BaseCrossValidator, cross_validate
29
+ from sklearn.model_selection import BaseCrossValidator
18
30
 
19
31
  from upgini.errors import ValidationError
20
32
  from upgini.metadata import ModelTaskType
@@ -29,6 +41,7 @@ CATBOOST_PARAMS = {
29
41
  "min_child_samples": 10,
30
42
  "max_depth": 5,
31
43
  "early_stopping_rounds": 20,
44
+ "use_best_model": True,
32
45
  "one_hot_max_size": 100,
33
46
  "verbose": False,
34
47
  "random_state": DEFAULT_RANDOM_STATE,
@@ -44,6 +57,8 @@ CATBOOST_MULTICLASS_PARAMS = {
44
57
  "loss_function": "MultiClass",
45
58
  "subsample": 0.5,
46
59
  "bootstrap_type": "Bernoulli",
60
+ "early_stopping_rounds": 20,
61
+ "use_best_model": True,
47
62
  "rsm": 0.1,
48
63
  "verbose": False,
49
64
  "random_state": DEFAULT_RANDOM_STATE,
@@ -90,6 +105,87 @@ NA_VALUES = [
90
105
 
91
106
  NA_REPLACEMENT = "NA"
92
107
 
108
+ SUPPORTED_CATBOOST_METRICS = {s.upper(): s for s in {
109
+ "Logloss",
110
+ "CrossEntropy",
111
+ "CtrFactor",
112
+ "Focal",
113
+ "RMSE",
114
+ "LogCosh",
115
+ "Lq",
116
+ "MAE",
117
+ "Quantile",
118
+ "MultiQuantile",
119
+ "Expectile",
120
+ "LogLinQuantile",
121
+ "MAPE",
122
+ "Poisson",
123
+ "MSLE",
124
+ "MedianAbsoluteError",
125
+ "SMAPE",
126
+ "Huber",
127
+ "Tweedie",
128
+ "Cox",
129
+ "RMSEWithUncertainty",
130
+ "MultiClass",
131
+ "MultiClassOneVsAll",
132
+ "PairLogit",
133
+ "PairLogitPairwise",
134
+ "YetiRank",
135
+ "YetiRankPairwise",
136
+ "QueryRMSE",
137
+ "QuerySoftMax",
138
+ "QueryCrossEntropy",
139
+ "StochasticFilter",
140
+ "LambdaMart",
141
+ "StochasticRank",
142
+ "PythonUserDefinedPerObject",
143
+ "PythonUserDefinedMultiTarget",
144
+ "UserPerObjMetric",
145
+ "UserQuerywiseMetric",
146
+ "R2",
147
+ "NumErrors",
148
+ "FairLoss",
149
+ "AUC",
150
+ "Accuracy",
151
+ "BalancedAccuracy",
152
+ "BalancedErrorRate",
153
+ "BrierScore",
154
+ "Precision",
155
+ "Recall",
156
+ "F1",
157
+ "TotalF1",
158
+ "F",
159
+ "MCC",
160
+ "ZeroOneLoss",
161
+ "HammingLoss",
162
+ "HingeLoss",
163
+ "Kappa",
164
+ "WKappa",
165
+ "LogLikelihoodOfPrediction",
166
+ "NormalizedGini",
167
+ "PRAUC",
168
+ "PairAccuracy",
169
+ "AverageGain",
170
+ "QueryAverage",
171
+ "QueryAUC",
172
+ "PFound",
173
+ "PrecisionAt",
174
+ "RecallAt",
175
+ "MAP",
176
+ "NDCG",
177
+ "DCG",
178
+ "FilteredDCG",
179
+ "MRR",
180
+ "ERR",
181
+ "SurvivalAft",
182
+ "MultiRMSE",
183
+ "MultiRMSEWithMissingValues",
184
+ "MultiLogloss",
185
+ "MultiCrossEntropy",
186
+ "Combination",
187
+ }}
188
+
93
189
 
94
190
  class EstimatorWrapper:
95
191
  def __init__(
@@ -166,7 +262,7 @@ class EstimatorWrapper:
166
262
  estimator=self.estimator,
167
263
  X=X,
168
264
  y=y,
169
- scoring={"score": scorer},
265
+ scoring=scorer,
170
266
  cv=self.cv,
171
267
  fit_params=fit_params,
172
268
  return_estimator=True,
@@ -204,14 +300,20 @@ class EstimatorWrapper:
204
300
  "target_type": target_type,
205
301
  }
206
302
  if estimator is None:
303
+ params = dict()
304
+ # if metric_name.upper() in SUPPORTED_CATBOOST_METRICS:
305
+ # params["eval_metric"] = SUPPORTED_CATBOOST_METRICS[metric_name.upper()]
207
306
  if target_type == ModelTaskType.MULTICLASS:
208
- params = _get_add_params(CATBOOST_MULTICLASS_PARAMS, add_params)
307
+ params = _get_add_params(params, CATBOOST_MULTICLASS_PARAMS)
308
+ params = _get_add_params(params, add_params)
209
309
  estimator = CatBoostWrapper(CatBoostClassifier(**params), **kwargs)
210
310
  elif target_type == ModelTaskType.BINARY:
211
- params = _get_add_params(CATBOOST_PARAMS, add_params)
311
+ params = _get_add_params(params, CATBOOST_PARAMS)
312
+ params = _get_add_params(params, add_params)
212
313
  estimator = CatBoostWrapper(CatBoostClassifier(**params), **kwargs)
213
314
  elif target_type == ModelTaskType.REGRESSION:
214
- params = _get_add_params(CATBOOST_PARAMS, add_params)
315
+ params = _get_add_params(params, CATBOOST_PARAMS)
316
+ params = _get_add_params(params, add_params)
215
317
  estimator = CatBoostWrapper(CatBoostRegressor(**params), **kwargs)
216
318
  else:
217
319
  raise Exception(bundle.get("metrics_unsupported_target_type").format(target_type))
@@ -385,7 +487,6 @@ def _get_scorer(target_type: ModelTaskType, scoring: Union[Callable, str, None])
385
487
 
386
488
  multiplier = 1
387
489
  if isinstance(scoring, str):
388
- available_scorers = get_scorer_names()
389
490
  metric_name = scoring
390
491
  if "mean_squared_log_error" == metric_name or "MSLE" == metric_name or "msle" == metric_name:
391
492
  scoring = make_scorer(_ext_mean_squared_log_error, greater_is_better=False)