wavetrainer 0.0.40__tar.gz → 0.0.41__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70)
  1. {wavetrainer-0.0.40/wavetrainer.egg-info → wavetrainer-0.0.41}/PKG-INFO +1 -1
  2. {wavetrainer-0.0.40 → wavetrainer-0.0.41}/setup.py +1 -1
  3. {wavetrainer-0.0.40 → wavetrainer-0.0.41}/wavetrainer/__init__.py +1 -1
  4. {wavetrainer-0.0.40 → wavetrainer-0.0.41}/wavetrainer/calibrator/calibrator.py +6 -0
  5. {wavetrainer-0.0.40 → wavetrainer-0.0.41}/wavetrainer/calibrator/calibrator_router.py +7 -0
  6. {wavetrainer-0.0.40 → wavetrainer-0.0.41}/wavetrainer/calibrator/mapie_calibrator.py +19 -20
  7. {wavetrainer-0.0.40 → wavetrainer-0.0.41}/wavetrainer/calibrator/vennabers_calibrator.py +3 -0
  8. {wavetrainer-0.0.40 → wavetrainer-0.0.41}/wavetrainer/model/catboost/catboost_kwargs.py +7 -6
  9. {wavetrainer-0.0.40 → wavetrainer-0.0.41}/wavetrainer/model/catboost/catboost_model.py +66 -83
  10. {wavetrainer-0.0.40 → wavetrainer-0.0.41}/wavetrainer/model/model.py +7 -15
  11. {wavetrainer-0.0.40 → wavetrainer-0.0.41}/wavetrainer/model/model_router.py +9 -17
  12. {wavetrainer-0.0.40 → wavetrainer-0.0.41}/wavetrainer/model/tabpfn/tabpfn_model.py +33 -39
  13. {wavetrainer-0.0.40 → wavetrainer-0.0.41}/wavetrainer/model/xgboost/xgboost_logger.py +4 -0
  14. wavetrainer-0.0.41/wavetrainer/model/xgboost/xgboost_model.py +331 -0
  15. {wavetrainer-0.0.40 → wavetrainer-0.0.41}/wavetrainer/reducer/base_selector_reducer.py +4 -14
  16. {wavetrainer-0.0.40 → wavetrainer-0.0.41}/wavetrainer/reducer/smart_correlation_reducer.py +2 -4
  17. wavetrainer-0.0.41/wavetrainer/selector/selector.py +112 -0
  18. {wavetrainer-0.0.40 → wavetrainer-0.0.41}/wavetrainer/trainer.py +33 -12
  19. {wavetrainer-0.0.40 → wavetrainer-0.0.41/wavetrainer.egg-info}/PKG-INFO +1 -1
  20. wavetrainer-0.0.40/wavetrainer/model/xgboost/xgboost_model.py +0 -277
  21. wavetrainer-0.0.40/wavetrainer/selector/selector.py +0 -90
  22. {wavetrainer-0.0.40 → wavetrainer-0.0.41}/LICENSE +0 -0
  23. {wavetrainer-0.0.40 → wavetrainer-0.0.41}/MANIFEST.in +0 -0
  24. {wavetrainer-0.0.40 → wavetrainer-0.0.41}/README.md +0 -0
  25. {wavetrainer-0.0.40 → wavetrainer-0.0.41}/requirements.txt +0 -0
  26. {wavetrainer-0.0.40 → wavetrainer-0.0.41}/setup.cfg +0 -0
  27. {wavetrainer-0.0.40 → wavetrainer-0.0.41}/tests/__init__.py +0 -0
  28. {wavetrainer-0.0.40 → wavetrainer-0.0.41}/tests/model/__init__.py +0 -0
  29. {wavetrainer-0.0.40 → wavetrainer-0.0.41}/tests/model/catboost_kwargs_test.py +0 -0
  30. {wavetrainer-0.0.40 → wavetrainer-0.0.41}/tests/trainer_test.py +0 -0
  31. {wavetrainer-0.0.40 → wavetrainer-0.0.41}/wavetrainer/calibrator/__init__.py +0 -0
  32. {wavetrainer-0.0.40 → wavetrainer-0.0.41}/wavetrainer/create.py +0 -0
  33. {wavetrainer-0.0.40 → wavetrainer-0.0.41}/wavetrainer/exceptions.py +0 -0
  34. {wavetrainer-0.0.40 → wavetrainer-0.0.41}/wavetrainer/fit.py +0 -0
  35. {wavetrainer-0.0.40 → wavetrainer-0.0.41}/wavetrainer/model/__init__.py +0 -0
  36. {wavetrainer-0.0.40 → wavetrainer-0.0.41}/wavetrainer/model/catboost/__init__.py +0 -0
  37. {wavetrainer-0.0.40 → wavetrainer-0.0.41}/wavetrainer/model/catboost/catboost_classifier_wrap.py +0 -0
  38. {wavetrainer-0.0.40 → wavetrainer-0.0.41}/wavetrainer/model/catboost/catboost_regressor_wrap.py +0 -0
  39. {wavetrainer-0.0.40 → wavetrainer-0.0.41}/wavetrainer/model/tabpfn/__init__.py +0 -0
  40. {wavetrainer-0.0.40 → wavetrainer-0.0.41}/wavetrainer/model/xgboost/__init__.py +0 -0
  41. {wavetrainer-0.0.40 → wavetrainer-0.0.41}/wavetrainer/model/xgboost/early_stopper.py +0 -0
  42. {wavetrainer-0.0.40 → wavetrainer-0.0.41}/wavetrainer/model_type.py +0 -0
  43. {wavetrainer-0.0.40 → wavetrainer-0.0.41}/wavetrainer/params.py +0 -0
  44. {wavetrainer-0.0.40 → wavetrainer-0.0.41}/wavetrainer/reducer/__init__.py +0 -0
  45. {wavetrainer-0.0.40 → wavetrainer-0.0.41}/wavetrainer/reducer/combined_reducer.py +0 -0
  46. {wavetrainer-0.0.40 → wavetrainer-0.0.41}/wavetrainer/reducer/constant_reducer.py +0 -0
  47. {wavetrainer-0.0.40 → wavetrainer-0.0.41}/wavetrainer/reducer/correlation_reducer.py +0 -0
  48. {wavetrainer-0.0.40 → wavetrainer-0.0.41}/wavetrainer/reducer/duplicate_reducer.py +0 -0
  49. {wavetrainer-0.0.40 → wavetrainer-0.0.41}/wavetrainer/reducer/non_categorical_numeric_columns.py +0 -0
  50. {wavetrainer-0.0.40 → wavetrainer-0.0.41}/wavetrainer/reducer/nonnumeric_reducer.py +0 -0
  51. {wavetrainer-0.0.40 → wavetrainer-0.0.41}/wavetrainer/reducer/reducer.py +0 -0
  52. {wavetrainer-0.0.40 → wavetrainer-0.0.41}/wavetrainer/reducer/select_by_single_feature_performance_reducer.py +0 -0
  53. {wavetrainer-0.0.40 → wavetrainer-0.0.41}/wavetrainer/reducer/unseen_reducer.py +0 -0
  54. {wavetrainer-0.0.40 → wavetrainer-0.0.41}/wavetrainer/selector/__init__.py +0 -0
  55. {wavetrainer-0.0.40 → wavetrainer-0.0.41}/wavetrainer/weights/__init__.py +0 -0
  56. {wavetrainer-0.0.40 → wavetrainer-0.0.41}/wavetrainer/weights/class_weights.py +0 -0
  57. {wavetrainer-0.0.40 → wavetrainer-0.0.41}/wavetrainer/weights/combined_weights.py +0 -0
  58. {wavetrainer-0.0.40 → wavetrainer-0.0.41}/wavetrainer/weights/exponential_weights.py +0 -0
  59. {wavetrainer-0.0.40 → wavetrainer-0.0.41}/wavetrainer/weights/linear_weights.py +0 -0
  60. {wavetrainer-0.0.40 → wavetrainer-0.0.41}/wavetrainer/weights/noop_weights.py +0 -0
  61. {wavetrainer-0.0.40 → wavetrainer-0.0.41}/wavetrainer/weights/sigmoid_weights.py +0 -0
  62. {wavetrainer-0.0.40 → wavetrainer-0.0.41}/wavetrainer/weights/weights.py +0 -0
  63. {wavetrainer-0.0.40 → wavetrainer-0.0.41}/wavetrainer/weights/weights_router.py +0 -0
  64. {wavetrainer-0.0.40 → wavetrainer-0.0.41}/wavetrainer/windower/__init__.py +0 -0
  65. {wavetrainer-0.0.40 → wavetrainer-0.0.41}/wavetrainer/windower/windower.py +0 -0
  66. {wavetrainer-0.0.40 → wavetrainer-0.0.41}/wavetrainer.egg-info/SOURCES.txt +0 -0
  67. {wavetrainer-0.0.40 → wavetrainer-0.0.41}/wavetrainer.egg-info/dependency_links.txt +0 -0
  68. {wavetrainer-0.0.40 → wavetrainer-0.0.41}/wavetrainer.egg-info/not-zip-safe +0 -0
  69. {wavetrainer-0.0.40 → wavetrainer-0.0.41}/wavetrainer.egg-info/requires.txt +0 -0
  70. {wavetrainer-0.0.40 → wavetrainer-0.0.41}/wavetrainer.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: wavetrainer
3
- Version: 0.0.40
3
+ Version: 0.0.41
4
4
  Summary: A library for automatically finding the optimal model within feature and hyperparameter space.
5
5
  Home-page: https://github.com/8W9aG/wavetrainer
6
6
  Author: Will Sackfield
@@ -23,7 +23,7 @@ def install_requires() -> typing.List[str]:
23
23
 
24
24
  setup(
25
25
  name='wavetrainer',
26
- version='0.0.40',
26
+ version='0.0.41',
27
27
  description='A library for automatically finding the optimal model within feature and hyperparameter space.',
28
28
  long_description=long_description,
29
29
  long_description_content_type='text/markdown',
@@ -2,5 +2,5 @@
2
2
 
3
3
  from .create import create
4
4
 
5
- __VERSION__ = "0.0.40"
5
+ __VERSION__ = "0.0.41"
6
6
  __all__ = ("create",)
@@ -1,5 +1,7 @@
1
1
  """The prototype calibrator class."""
2
2
 
3
+ import pandas as pd
4
+
3
5
  from ..fit import Fit
4
6
  from ..model.model import Model
5
7
  from ..params import Params
@@ -15,3 +17,7 @@ class Calibrator(Params, Fit):
15
17
  def name(cls) -> str:
16
18
  """The name of the calibrator."""
17
19
  raise NotImplementedError("name not implemented in parent class.")
20
+
21
+ def predictions_as_x(self, y: pd.Series | pd.DataFrame | None = None) -> bool:
22
+ """Whether the calibrator wants predictions as X rather than features."""
23
+ raise NotImplementedError("predictions_as_x not implemented in parent class.")
@@ -36,6 +36,13 @@ class CalibratorRouter(Calibrator):
36
36
  def name(cls) -> str:
37
37
  return "router"
38
38
 
39
+ def predictions_as_x(self, y: pd.Series | pd.DataFrame | None = None) -> bool:
40
+ if y is None:
41
+ raise ValueError("y is null")
42
+ if determine_model_type(y) == ModelType.REGRESSION:
43
+ return False
44
+ return True
45
+
39
46
  def set_options(
40
47
  self, trial: optuna.Trial | optuna.trial.FrozenTrial, df: pd.DataFrame
41
48
  ) -> None:
@@ -1,13 +1,11 @@
1
1
  """A calibrator that implements MAPIE."""
2
2
 
3
- import logging
4
3
  import os
5
4
  from typing import Self
6
5
 
7
6
  import joblib # type: ignore
8
7
  import optuna
9
8
  import pandas as pd
10
- import sklearn # type: ignore
11
9
  from mapie.regression import MapieRegressor # type: ignore
12
10
 
13
11
  from ..model.model import PROBABILITY_COLUMN_PREFIX, Model
@@ -23,12 +21,15 @@ class MAPIECalibrator(Calibrator):
23
21
 
24
22
  def __init__(self, model: Model):
25
23
  super().__init__(model)
26
- self._mapie = MapieRegressor(model.estimator, method="plus")
24
+ self._mapie = MapieRegressor(model.create_estimator(), method="plus")
27
25
 
28
26
  @classmethod
29
27
  def name(cls) -> str:
30
28
  return "mapie"
31
29
 
30
+ def predictions_as_x(self, y: pd.Series | pd.DataFrame | None = None) -> bool:
31
+ return False
32
+
32
33
  def set_options(
33
34
  self, trial: optuna.Trial | optuna.trial.FrozenTrial, df: pd.DataFrame
34
35
  ) -> None:
@@ -59,20 +60,18 @@ class MAPIECalibrator(Calibrator):
59
60
  return self
60
61
 
61
62
  def transform(self, df: pd.DataFrame) -> pd.DataFrame:
62
- try:
63
- alpha = []
64
- for potential_alpha in [0.05, 0.32]:
65
- if len(df) > int(1.0 / potential_alpha) + 1:
66
- alpha.append(potential_alpha)
67
- if alpha:
68
- _, y_pis = self._mapie.predict(df, alpha=alpha)
69
- for i in range(y_pis.shape[1]):
70
- if i >= len(alpha):
71
- continue
72
- for ii in range(y_pis.shape[2]):
73
- alpha_val = alpha[i]
74
- values = y_pis[:, i, ii].flatten().tolist()
75
- df[f"{PROBABILITY_COLUMN_PREFIX}{alpha_val}_{ii == 1}"] = values
76
- except sklearn.exceptions.NotFittedError as exc: # type: ignore
77
- logging.warning(str(exc))
78
- return df
63
+ alpha = []
64
+ for potential_alpha in [0.05, 0.32]:
65
+ if len(df) > int(1.0 / potential_alpha) + 1:
66
+ alpha.append(potential_alpha)
67
+ ret_df = pd.DataFrame(index=df.index)
68
+ if alpha:
69
+ _, y_pis = self._mapie.predict(df, alpha=alpha)
70
+ for i in range(y_pis.shape[1]):
71
+ if i >= len(alpha):
72
+ continue
73
+ for ii in range(y_pis.shape[2]):
74
+ alpha_val = alpha[i]
75
+ values = y_pis[:, i, ii].flatten().tolist()
76
+ ret_df[f"{PROBABILITY_COLUMN_PREFIX}{alpha_val}_{ii == 1}"] = values
77
+ return ret_df
@@ -28,6 +28,9 @@ class VennabersCalibrator(Calibrator):
28
28
  def name(cls) -> str:
29
29
  return "vennabers"
30
30
 
31
+ def predictions_as_x(self, y: pd.Series | pd.DataFrame | None = None) -> bool:
32
+ return True
33
+
31
34
  def set_options(
32
35
  self, trial: optuna.Trial | optuna.trial.FrozenTrial, df: pd.DataFrame
33
36
  ) -> None:
@@ -38,12 +38,13 @@ def handle_fit_kwargs(*args, **kwargs) -> tuple[tuple[Any, ...], dict[str, Any]]
38
38
  args_list[0] = df[included_columns]
39
39
  args = tuple(args_list)
40
40
 
41
- eval_x = eval_x[included_columns]
42
- kwargs[EVAL_SET_ARG_KEY] = Pool(
43
- eval_x,
44
- label=eval_y,
45
- cat_features=cat_features,
46
- )
41
+ if eval_x is not None:
42
+ eval_x = eval_x[included_columns]
43
+ kwargs[EVAL_SET_ARG_KEY] = Pool(
44
+ eval_x,
45
+ label=eval_y,
46
+ cat_features=cat_features,
47
+ )
47
48
  kwargs[CAT_FEATURES_ARG_KEY] = cat_features
48
49
 
49
50
  del kwargs[ORIGINAL_X_ARG_KEY]
@@ -1,9 +1,10 @@
1
1
  """A model that wraps catboost."""
2
2
 
3
+ # pylint: disable=line-too-long
3
4
  import json
4
5
  import logging
5
6
  import os
6
- from typing import Any, Self
7
+ from typing import Self
7
8
 
8
9
  import optuna
9
10
  import pandas as pd
@@ -13,8 +14,6 @@ from catboost import CatBoost, Pool # type: ignore
13
14
  from ...model_type import ModelType, determine_model_type
14
15
  from ..model import PREDICTION_COLUMN, PROBABILITY_COLUMN_PREFIX, Model
15
16
  from .catboost_classifier_wrap import CatBoostClassifierWrapper
16
- from .catboost_kwargs import (CAT_FEATURES_ARG_KEY, EVAL_SET_ARG_KEY,
17
- ORIGINAL_X_ARG_KEY)
18
17
  from .catboost_regressor_wrap import CatBoostRegressorWrapper
19
18
 
20
19
  _MODEL_FILENAME = "model.cbm"
@@ -64,10 +63,6 @@ class CatboostModel(Model):
64
63
  self._early_stopping_rounds = None
65
64
  self._best_iteration = None
66
65
 
67
- @property
68
- def estimator(self) -> Any:
69
- return self._provide_catboost()
70
-
71
66
  @property
72
67
  def supports_importances(self) -> bool:
73
68
  return True
@@ -82,23 +77,11 @@ class CatboostModel(Model):
82
77
  importances = importances["Importances"].to_list() # type: ignore
83
78
  return {feature_ids[x]: importances[x] for x in range(len(feature_ids))}
84
79
 
85
- def pre_fit(
86
- self,
87
- df: pd.DataFrame,
88
- y: pd.Series | pd.DataFrame | None,
89
- eval_x: pd.DataFrame | None = None,
90
- eval_y: pd.Series | pd.DataFrame | None = None,
91
- w: pd.Series | None = None,
92
- ):
93
- if y is None:
94
- raise ValueError("y is null.")
95
- self._model_type = determine_model_type(y)
96
- return {
97
- EVAL_SET_ARG_KEY: (eval_x, eval_y),
98
- CAT_FEATURES_ARG_KEY: df.select_dtypes(include="category").columns.tolist(),
99
- ORIGINAL_X_ARG_KEY: df,
100
- "sample_weight": w,
101
- }
80
+ def provide_estimator(self):
81
+ return self._provide_catboost()
82
+
83
+ def create_estimator(self):
84
+ return self._create_catboost()
102
85
 
103
86
  def set_options(
104
87
  self, trial: optuna.Trial | optuna.trial.FrozenTrial, df: pd.DataFrame
@@ -214,66 +197,66 @@ class CatboostModel(Model):
214
197
  def _provide_catboost(self) -> CatBoost:
215
198
  catboost = self._catboost
216
199
  if catboost is None:
217
- best_iteration = self._best_iteration
218
- iterations = (
219
- best_iteration if best_iteration is not None else self._iterations
220
- )
221
- logging.info(
222
- "Creating catboost model with depth %d, boosting type %s, best iteration %d",
223
- self._depth,
224
- self._boosting_type,
225
- -1 if best_iteration is None else best_iteration,
226
- )
227
- match self._model_type:
228
- case ModelType.BINARY:
229
- catboost = CatBoostClassifierWrapper(
230
- iterations=iterations,
231
- learning_rate=self._learning_rate,
232
- depth=self._depth,
233
- l2_leaf_reg=self._l2_leaf_reg,
234
- boosting_type=self._boosting_type,
235
- early_stopping_rounds=self._early_stopping_rounds,
236
- metric_period=100,
237
- task_type="GPU" if torch.cuda.is_available() else "CPU",
238
- devices="0" if torch.cuda.is_available() else None,
239
- )
240
- case ModelType.REGRESSION:
241
- catboost = CatBoostRegressorWrapper(
242
- iterations=iterations,
243
- learning_rate=self._learning_rate,
244
- depth=self._depth,
245
- l2_leaf_reg=self._l2_leaf_reg,
246
- boosting_type=self._boosting_type,
247
- early_stopping_rounds=self._early_stopping_rounds,
248
- metric_period=100,
249
- task_type="GPU" if torch.cuda.is_available() else "CPU",
250
- devices="0" if torch.cuda.is_available() else None,
251
- )
252
- case ModelType.BINNED_BINARY:
253
- catboost = CatBoostClassifierWrapper(
254
- iterations=iterations,
255
- learning_rate=self._learning_rate,
256
- depth=self._depth,
257
- l2_leaf_reg=self._l2_leaf_reg,
258
- boosting_type=self._boosting_type,
259
- early_stopping_rounds=self._early_stopping_rounds,
260
- metric_period=100,
261
- task_type="GPU" if torch.cuda.is_available() else "CPU",
262
- devices="0" if torch.cuda.is_available() else None,
263
- )
264
- case ModelType.MULTI_CLASSIFICATION:
265
- catboost = CatBoostClassifierWrapper(
266
- iterations=iterations,
267
- learning_rate=self._learning_rate,
268
- depth=self._depth,
269
- l2_leaf_reg=self._l2_leaf_reg,
270
- boosting_type=self._boosting_type,
271
- early_stopping_rounds=self._early_stopping_rounds,
272
- metric_period=100,
273
- task_type="GPU" if torch.cuda.is_available() else "CPU",
274
- devices="0" if torch.cuda.is_available() else None,
275
- )
200
+ catboost = self._create_catboost()
276
201
  self._catboost = catboost
277
202
  if catboost is None:
278
203
  raise ValueError("catboost is null")
279
204
  return catboost
205
+
206
+ def _create_catboost(self) -> CatBoost:
207
+ best_iteration = self._best_iteration
208
+ iterations = best_iteration if best_iteration is not None else self._iterations
209
+ print(
210
+ f"Creating catboost model with depth {self._depth}, boosting type {self._boosting_type}, best iteration {best_iteration}",
211
+ )
212
+ match self._model_type:
213
+ case ModelType.BINARY:
214
+ return CatBoostClassifierWrapper(
215
+ iterations=iterations,
216
+ learning_rate=self._learning_rate,
217
+ depth=self._depth,
218
+ l2_leaf_reg=self._l2_leaf_reg,
219
+ boosting_type=self._boosting_type,
220
+ early_stopping_rounds=self._early_stopping_rounds,
221
+ metric_period=100,
222
+ task_type="GPU" if torch.cuda.is_available() else "CPU",
223
+ devices="0" if torch.cuda.is_available() else None,
224
+ )
225
+ case ModelType.REGRESSION:
226
+ return CatBoostRegressorWrapper(
227
+ iterations=iterations,
228
+ learning_rate=self._learning_rate,
229
+ depth=self._depth,
230
+ l2_leaf_reg=self._l2_leaf_reg,
231
+ boosting_type=self._boosting_type,
232
+ early_stopping_rounds=self._early_stopping_rounds,
233
+ metric_period=100,
234
+ task_type="GPU" if torch.cuda.is_available() else "CPU",
235
+ devices="0" if torch.cuda.is_available() else None,
236
+ )
237
+ case ModelType.BINNED_BINARY:
238
+ return CatBoostClassifierWrapper(
239
+ iterations=iterations,
240
+ learning_rate=self._learning_rate,
241
+ depth=self._depth,
242
+ l2_leaf_reg=self._l2_leaf_reg,
243
+ boosting_type=self._boosting_type,
244
+ early_stopping_rounds=self._early_stopping_rounds,
245
+ metric_period=100,
246
+ task_type="GPU" if torch.cuda.is_available() else "CPU",
247
+ devices="0" if torch.cuda.is_available() else None,
248
+ )
249
+ case ModelType.MULTI_CLASSIFICATION:
250
+ return CatBoostClassifierWrapper(
251
+ iterations=iterations,
252
+ learning_rate=self._learning_rate,
253
+ depth=self._depth,
254
+ l2_leaf_reg=self._l2_leaf_reg,
255
+ boosting_type=self._boosting_type,
256
+ early_stopping_rounds=self._early_stopping_rounds,
257
+ metric_period=100,
258
+ task_type="GPU" if torch.cuda.is_available() else "CPU",
259
+ devices="0" if torch.cuda.is_available() else None,
260
+ )
261
+ case _:
262
+ raise ValueError(f"Unrecognised model type: {self._model_type}")
@@ -25,11 +25,6 @@ class Model(Params, Fit):
25
25
  """Whether the model supports the X values."""
26
26
  raise NotImplementedError("supports_x not implemented in parent class.")
27
27
 
28
- @property
29
- def estimator(self) -> Any:
30
- """The estimator backing the model."""
31
- raise NotImplementedError("estimator not implemented in parent class.")
32
-
33
28
  @property
34
29
  def supports_importances(self) -> bool:
35
30
  """Whether this model supports feature importances."""
@@ -44,13 +39,10 @@ class Model(Params, Fit):
44
39
  "feature_importances not implemented in parent class."
45
40
  )
46
41
 
47
- def pre_fit(
48
- self,
49
- df: pd.DataFrame,
50
- y: pd.Series | pd.DataFrame | None,
51
- eval_x: pd.DataFrame | None = None,
52
- eval_y: pd.Series | pd.DataFrame | None = None,
53
- w: pd.Series | None = None,
54
- ) -> dict[str, Any]:
55
- """A call to make sure the model is prepared for the target type."""
56
- raise NotImplementedError("pre_fit not implemented in parent class.")
42
+ def provide_estimator(self) -> Any:
43
+ """Provides the current estimator."""
44
+ raise NotImplementedError("provides_estimator not implemented in parent class.")
45
+
46
+ def create_estimator(self) -> Any:
47
+ """Creates a new estimator."""
48
+ raise NotImplementedError("creates_estimator not implemented in parent class.")
@@ -2,7 +2,7 @@
2
2
 
3
3
  import json
4
4
  import os
5
- from typing import Any, Self
5
+ from typing import Self
6
6
 
7
7
  import optuna
8
8
  import pandas as pd
@@ -40,13 +40,6 @@ class ModelRouter(Model):
40
40
  def supports_x(cls, df: pd.DataFrame) -> bool:
41
41
  return True
42
42
 
43
- @property
44
- def estimator(self) -> Any:
45
- model = self._model
46
- if model is None:
47
- raise ValueError("model is null")
48
- return model.estimator
49
-
50
43
  @property
51
44
  def supports_importances(self) -> bool:
52
45
  model = self._model
@@ -61,18 +54,17 @@ class ModelRouter(Model):
61
54
  raise ValueError("model is null")
62
55
  return model.feature_importances
63
56
 
64
- def pre_fit(
65
- self,
66
- df: pd.DataFrame,
67
- y: pd.Series | pd.DataFrame | None,
68
- eval_x: pd.DataFrame | None = None,
69
- eval_y: pd.Series | pd.DataFrame | None = None,
70
- w: pd.Series | None = None,
71
- ) -> dict[str, Any]:
57
+ def provide_estimator(self):
58
+ model = self._model
59
+ if model is None:
60
+ raise ValueError("model is null")
61
+ return model.provide_estimator()
62
+
63
+ def create_estimator(self):
72
64
  model = self._model
73
65
  if model is None:
74
66
  raise ValueError("model is null")
75
- return model.pre_fit(df, y=y, eval_x=eval_x, eval_y=eval_y, w=w)
67
+ return model.create_estimator()
76
68
 
77
69
  def set_options(
78
70
  self, trial: optuna.Trial | optuna.trial.FrozenTrial, df: pd.DataFrame
@@ -5,7 +5,7 @@ import json
5
5
  import logging
6
6
  import os
7
7
  import pickle
8
- from typing import Any, Self
8
+ from typing import Self
9
9
 
10
10
  import optuna
11
11
  import pandas as pd
@@ -42,10 +42,6 @@ class TabPFNModel(Model):
42
42
  self._tabpfn = None
43
43
  self._model_type = None
44
44
 
45
- @property
46
- def estimator(self) -> Any:
47
- return self._provide_tabpfn()
48
-
49
45
  @property
50
46
  def supports_importances(self) -> bool:
51
47
  return False
@@ -54,18 +50,11 @@ class TabPFNModel(Model):
54
50
  def feature_importances(self) -> dict[str, float]:
55
51
  return {}
56
52
 
57
- def pre_fit(
58
- self,
59
- df: pd.DataFrame,
60
- y: pd.Series | pd.DataFrame | None,
61
- eval_x: pd.DataFrame | None = None,
62
- eval_y: pd.Series | pd.DataFrame | None = None,
63
- w: pd.Series | None = None,
64
- ):
65
- if y is None:
66
- raise ValueError("y is null.")
67
- self._model_type = determine_model_type(y)
68
- return {}
53
+ def provide_estimator(self):
54
+ return self._provide_tabpfn()
55
+
56
+ def create_estimator(self):
57
+ return self._create_tabpfn()
69
58
 
70
59
  def set_options(
71
60
  self, trial: optuna.Trial | optuna.trial.FrozenTrial, df: pd.DataFrame
@@ -133,29 +122,34 @@ class TabPFNModel(Model):
133
122
  def _provide_tabpfn(self) -> AutoTabPFNClassifier | AutoTabPFNRegressor:
134
123
  tabpfn = self._tabpfn
135
124
  if tabpfn is None:
136
- max_time = 1 if pytest_is_running.is_running() else 120
137
- match self._model_type:
138
- case ModelType.BINARY:
139
- tabpfn = AutoTabPFNClassifier(
140
- max_time=max_time,
141
- device="cuda" if torch.cuda.is_available() else "cpu",
142
- )
143
- case ModelType.REGRESSION:
144
- tabpfn = AutoTabPFNRegressor(
145
- max_time=max_time,
146
- device="cuda" if torch.cuda.is_available() else "cpu",
147
- )
148
- case ModelType.BINNED_BINARY:
149
- tabpfn = AutoTabPFNClassifier(
150
- max_time=max_time,
151
- device="cuda" if torch.cuda.is_available() else "cpu",
152
- )
153
- case ModelType.MULTI_CLASSIFICATION:
154
- tabpfn = AutoTabPFNClassifier(
155
- max_time=max_time,
156
- device="cuda" if torch.cuda.is_available() else "cpu",
157
- )
125
+ tabpfn = self._create_tabpfn()
158
126
  self._tabpfn = tabpfn
159
127
  if tabpfn is None:
160
128
  raise ValueError("tabpfn is null")
161
129
  return tabpfn
130
+
131
+ def _create_tabpfn(self) -> AutoTabPFNClassifier | AutoTabPFNRegressor:
132
+ max_time = 1 if pytest_is_running.is_running() else 120
133
+ match self._model_type:
134
+ case ModelType.BINARY:
135
+ return AutoTabPFNClassifier(
136
+ max_time=max_time,
137
+ device="cuda" if torch.cuda.is_available() else "cpu",
138
+ )
139
+ case ModelType.REGRESSION:
140
+ return AutoTabPFNRegressor(
141
+ max_time=max_time,
142
+ device="cuda" if torch.cuda.is_available() else "cpu",
143
+ )
144
+ case ModelType.BINNED_BINARY:
145
+ return AutoTabPFNClassifier(
146
+ max_time=max_time,
147
+ device="cuda" if torch.cuda.is_available() else "cpu",
148
+ )
149
+ case ModelType.MULTI_CLASSIFICATION:
150
+ return AutoTabPFNClassifier(
151
+ max_time=max_time,
152
+ device="cuda" if torch.cuda.is_available() else "cpu",
153
+ )
154
+ case _:
155
+ raise ValueError(f"Unrecognised model type: {self._model_type}")
@@ -15,6 +15,10 @@ class XGBoostEpochsLogger(TrainingCallback):
15
15
  return False
16
16
  log_items = []
17
17
  for dataset, metrics in evals_log.items():
18
+ if dataset == "validation_0":
19
+ dataset = "validation"
20
+ elif dataset == "validation_1":
21
+ dataset = "train"
18
22
  for metric_name, values in metrics.items():
19
23
  current_val = values[-1]
20
24
  log_items.append(f"{dataset}-{metric_name}: {current_val:.5f}")