wavetrainer 0.0.26__tar.gz → 0.0.27__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. {wavetrainer-0.0.26/wavetrainer.egg-info → wavetrainer-0.0.27}/PKG-INFO +1 -1
  2. {wavetrainer-0.0.26 → wavetrainer-0.0.27}/setup.py +1 -1
  3. {wavetrainer-0.0.26 → wavetrainer-0.0.27}/wavetrainer/__init__.py +1 -1
  4. {wavetrainer-0.0.26 → wavetrainer-0.0.27}/wavetrainer/calibrator/calibrator_router.py +2 -2
  5. {wavetrainer-0.0.26 → wavetrainer-0.0.27}/wavetrainer/calibrator/mapie_calibrator.py +1 -1
  6. {wavetrainer-0.0.26 → wavetrainer-0.0.27}/wavetrainer/calibrator/vennabers_calibrator.py +1 -1
  7. {wavetrainer-0.0.26 → wavetrainer-0.0.27}/wavetrainer/model/catboost_model.py +25 -11
  8. {wavetrainer-0.0.26 → wavetrainer-0.0.27}/wavetrainer/model/model_router.py +2 -2
  9. {wavetrainer-0.0.26 → wavetrainer-0.0.27}/wavetrainer/model/tabpfn_model.py +1 -1
  10. {wavetrainer-0.0.26 → wavetrainer-0.0.27}/wavetrainer/params.py +1 -1
  11. {wavetrainer-0.0.26 → wavetrainer-0.0.27}/wavetrainer/reducer/base_selector_reducer.py +1 -1
  12. {wavetrainer-0.0.26 → wavetrainer-0.0.27}/wavetrainer/reducer/combined_reducer.py +2 -2
  13. {wavetrainer-0.0.26 → wavetrainer-0.0.27}/wavetrainer/reducer/nonnumeric_reducer.py +1 -1
  14. {wavetrainer-0.0.26 → wavetrainer-0.0.27}/wavetrainer/reducer/pca_reducer.py +1 -1
  15. {wavetrainer-0.0.26 → wavetrainer-0.0.27}/wavetrainer/reducer/unseen_reducer.py +1 -1
  16. {wavetrainer-0.0.26 → wavetrainer-0.0.27}/wavetrainer/selector/selector.py +1 -1
  17. {wavetrainer-0.0.26 → wavetrainer-0.0.27}/wavetrainer/trainer.py +7 -6
  18. {wavetrainer-0.0.26 → wavetrainer-0.0.27}/wavetrainer/weights/class_weights.py +1 -1
  19. {wavetrainer-0.0.26 → wavetrainer-0.0.27}/wavetrainer/weights/combined_weights.py +2 -2
  20. {wavetrainer-0.0.26 → wavetrainer-0.0.27}/wavetrainer/weights/exponential_weights.py +1 -1
  21. {wavetrainer-0.0.26 → wavetrainer-0.0.27}/wavetrainer/weights/linear_weights.py +1 -1
  22. {wavetrainer-0.0.26 → wavetrainer-0.0.27}/wavetrainer/weights/noop_weights.py +1 -1
  23. {wavetrainer-0.0.26 → wavetrainer-0.0.27}/wavetrainer/weights/sigmoid_weights.py +1 -1
  24. {wavetrainer-0.0.26 → wavetrainer-0.0.27}/wavetrainer/weights/weights_router.py +2 -1
  25. {wavetrainer-0.0.26 → wavetrainer-0.0.27}/wavetrainer/windower/windower.py +1 -1
  26. {wavetrainer-0.0.26 → wavetrainer-0.0.27/wavetrainer.egg-info}/PKG-INFO +1 -1
  27. {wavetrainer-0.0.26 → wavetrainer-0.0.27}/LICENSE +0 -0
  28. {wavetrainer-0.0.26 → wavetrainer-0.0.27}/MANIFEST.in +0 -0
  29. {wavetrainer-0.0.26 → wavetrainer-0.0.27}/README.md +0 -0
  30. {wavetrainer-0.0.26 → wavetrainer-0.0.27}/requirements.txt +0 -0
  31. {wavetrainer-0.0.26 → wavetrainer-0.0.27}/setup.cfg +0 -0
  32. {wavetrainer-0.0.26 → wavetrainer-0.0.27}/tests/__init__.py +0 -0
  33. {wavetrainer-0.0.26 → wavetrainer-0.0.27}/tests/model/__init__.py +0 -0
  34. {wavetrainer-0.0.26 → wavetrainer-0.0.27}/tests/model/catboost_kwargs_test.py +0 -0
  35. {wavetrainer-0.0.26 → wavetrainer-0.0.27}/tests/trainer_test.py +0 -0
  36. {wavetrainer-0.0.26 → wavetrainer-0.0.27}/wavetrainer/calibrator/__init__.py +0 -0
  37. {wavetrainer-0.0.26 → wavetrainer-0.0.27}/wavetrainer/calibrator/calibrator.py +0 -0
  38. {wavetrainer-0.0.26 → wavetrainer-0.0.27}/wavetrainer/create.py +0 -0
  39. {wavetrainer-0.0.26 → wavetrainer-0.0.27}/wavetrainer/exceptions.py +0 -0
  40. {wavetrainer-0.0.26 → wavetrainer-0.0.27}/wavetrainer/fit.py +0 -0
  41. {wavetrainer-0.0.26 → wavetrainer-0.0.27}/wavetrainer/model/__init__.py +0 -0
  42. {wavetrainer-0.0.26 → wavetrainer-0.0.27}/wavetrainer/model/catboost_classifier_wrap.py +0 -0
  43. {wavetrainer-0.0.26 → wavetrainer-0.0.27}/wavetrainer/model/catboost_kwargs.py +0 -0
  44. {wavetrainer-0.0.26 → wavetrainer-0.0.27}/wavetrainer/model/catboost_regressor_wrap.py +0 -0
  45. {wavetrainer-0.0.26 → wavetrainer-0.0.27}/wavetrainer/model/model.py +0 -0
  46. {wavetrainer-0.0.26 → wavetrainer-0.0.27}/wavetrainer/model_type.py +0 -0
  47. {wavetrainer-0.0.26 → wavetrainer-0.0.27}/wavetrainer/reducer/__init__.py +0 -0
  48. {wavetrainer-0.0.26 → wavetrainer-0.0.27}/wavetrainer/reducer/constant_reducer.py +0 -0
  49. {wavetrainer-0.0.26 → wavetrainer-0.0.27}/wavetrainer/reducer/correlation_reducer.py +0 -0
  50. {wavetrainer-0.0.26 → wavetrainer-0.0.27}/wavetrainer/reducer/duplicate_reducer.py +0 -0
  51. {wavetrainer-0.0.26 → wavetrainer-0.0.27}/wavetrainer/reducer/reducer.py +0 -0
  52. {wavetrainer-0.0.26 → wavetrainer-0.0.27}/wavetrainer/selector/__init__.py +0 -0
  53. {wavetrainer-0.0.26 → wavetrainer-0.0.27}/wavetrainer/weights/__init__.py +0 -0
  54. {wavetrainer-0.0.26 → wavetrainer-0.0.27}/wavetrainer/weights/weights.py +0 -0
  55. {wavetrainer-0.0.26 → wavetrainer-0.0.27}/wavetrainer/windower/__init__.py +0 -0
  56. {wavetrainer-0.0.26 → wavetrainer-0.0.27}/wavetrainer.egg-info/SOURCES.txt +0 -0
  57. {wavetrainer-0.0.26 → wavetrainer-0.0.27}/wavetrainer.egg-info/dependency_links.txt +0 -0
  58. {wavetrainer-0.0.26 → wavetrainer-0.0.27}/wavetrainer.egg-info/not-zip-safe +0 -0
  59. {wavetrainer-0.0.26 → wavetrainer-0.0.27}/wavetrainer.egg-info/requires.txt +0 -0
  60. {wavetrainer-0.0.26 → wavetrainer-0.0.27}/wavetrainer.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: wavetrainer
3
- Version: 0.0.26
3
+ Version: 0.0.27
4
4
  Summary: A library for automatically finding the optimal model within feature and hyperparameter space.
5
5
  Home-page: https://github.com/8W9aG/wavetrainer
6
6
  Author: Will Sackfield
@@ -23,7 +23,7 @@ def install_requires() -> typing.List[str]:
23
23
 
24
24
  setup(
25
25
  name='wavetrainer',
26
- version='0.0.26',
26
+ version='0.0.27',
27
27
  description='A library for automatically finding the optimal model within feature and hyperparameter space.',
28
28
  long_description=long_description,
29
29
  long_description_content_type='text/markdown',
@@ -2,5 +2,5 @@
2
2
 
3
3
  from .create import create
4
4
 
5
- __VERSION__ = "0.0.26"
5
+ __VERSION__ = "0.0.27"
6
6
  __all__ = ("create",)
@@ -48,11 +48,11 @@ class CalibratorRouter(Calibrator):
48
48
  calibrator.load(folder)
49
49
  self._calibrator = calibrator
50
50
 
51
- def save(self, folder: str) -> None:
51
+ def save(self, folder: str, trial: optuna.Trial | optuna.trial.FrozenTrial) -> None:
52
52
  calibrator = self._calibrator
53
53
  if calibrator is None:
54
54
  raise ValueError("calibrator is null.")
55
- calibrator.save(folder)
55
+ calibrator.save(folder, trial)
56
56
  with open(
57
57
  os.path.join(folder, _CALIBRATOR_ROUTER_FILE), "w", encoding="utf8"
58
58
  ) as handle:
@@ -35,7 +35,7 @@ class MAPIECalibrator(Calibrator):
35
35
  def load(self, folder: str) -> None:
36
36
  self._mapie = joblib.load(os.path.join(folder, _CALIBRATOR_FILENAME))
37
37
 
38
- def save(self, folder: str) -> None:
38
+ def save(self, folder: str, trial: optuna.Trial | optuna.trial.FrozenTrial) -> None:
39
39
  joblib.dump(self._mapie, os.path.join(folder, _CALIBRATOR_FILENAME))
40
40
 
41
41
  def fit(
@@ -33,7 +33,7 @@ class VennabersCalibrator(Calibrator):
33
33
  def load(self, folder: str) -> None:
34
34
  self._vennabers = joblib.load(os.path.join(folder, _CALIBRATOR_FILENAME))
35
35
 
36
- def save(self, folder: str) -> None:
36
+ def save(self, folder: str, trial: optuna.Trial | optuna.trial.FrozenTrial) -> None:
37
37
  joblib.dump(self._vennabers, os.path.join(folder, _CALIBRATOR_FILENAME))
38
38
 
39
39
  def fit(
@@ -26,6 +26,7 @@ _L2_LEAF_REG_KEY = "l2_leaf_reg"
26
26
  _BOOSTING_TYPE_KEY = "boosting_type"
27
27
  _MODEL_TYPE_KEY = "model_type"
28
28
  _EARLY_STOPPING_ROUNDS = "early_stopping_rounds"
29
+ _BEST_ITERATION_KEY = "best_iteration"
29
30
 
30
31
 
31
32
  class CatboostModel(Model):
@@ -41,6 +42,7 @@ class CatboostModel(Model):
41
42
  _boosting_type: None | str
42
43
  _model_type: None | ModelType
43
44
  _early_stopping_rounds: None | int
45
+ _best_iteration: None | int
44
46
 
45
47
  @classmethod
46
48
  def name(cls) -> str:
@@ -56,6 +58,7 @@ class CatboostModel(Model):
56
58
  self._boosting_type = None
57
59
  self._model_type = None
58
60
  self._early_stopping_rounds = None
61
+ self._best_iteration = None
59
62
 
60
63
  @property
61
64
  def estimator(self) -> Any:
@@ -92,6 +95,7 @@ class CatboostModel(Model):
92
95
  _BOOSTING_TYPE_KEY, ["Ordered", "Plain"]
93
96
  )
94
97
  self._early_stopping_rounds = trial.suggest_int(_EARLY_STOPPING_ROUNDS, 10, 500)
98
+ self._best_iteration = trial.user_attrs.get(_BEST_ITERATION_KEY)
95
99
 
96
100
  def load(self, folder: str) -> None:
97
101
  with open(
@@ -105,10 +109,11 @@ class CatboostModel(Model):
105
109
  self._boosting_type = params[_BOOSTING_TYPE_KEY]
106
110
  self._model_type = ModelType(params[_MODEL_TYPE_KEY])
107
111
  self._early_stopping_rounds = params[_EARLY_STOPPING_ROUNDS]
112
+ self._best_iteration = params.get(_BEST_ITERATION_KEY)
108
113
  catboost = self._provide_catboost()
109
114
  catboost.load_model(os.path.join(folder, _MODEL_FILENAME))
110
115
 
111
- def save(self, folder: str) -> None:
116
+ def save(self, folder: str, trial: optuna.Trial | optuna.trial.FrozenTrial) -> None:
112
117
  with open(
113
118
  os.path.join(folder, _MODEL_PARAMS_FILENAME), "w", encoding="utf8"
114
119
  ) as handle:
@@ -121,11 +126,13 @@ class CatboostModel(Model):
121
126
  _BOOSTING_TYPE_KEY: self._boosting_type,
122
127
  _MODEL_TYPE_KEY: str(self._model_type),
123
128
  _EARLY_STOPPING_ROUNDS: self._early_stopping_rounds,
129
+ _BEST_ITERATION_KEY: self._best_iteration,
124
130
  },
125
131
  handle,
126
132
  )
127
133
  catboost = self._provide_catboost()
128
134
  catboost.save_model(os.path.join(folder, _MODEL_FILENAME))
135
+ trial.user_attrs[_BEST_ITERATION_KEY] = self._best_iteration
129
136
 
130
137
  def fit(
131
138
  self,
@@ -137,8 +144,6 @@ class CatboostModel(Model):
137
144
  ) -> Self:
138
145
  if y is None:
139
146
  raise ValueError("y is null.")
140
- if eval_x is None:
141
- raise ValueError("eval_x is null.")
142
147
  self._model_type = determine_model_type(y)
143
148
  catboost = self._provide_catboost()
144
149
 
@@ -148,10 +153,14 @@ class CatboostModel(Model):
148
153
  weight=w,
149
154
  cat_features=df.select_dtypes(include="category").columns.tolist(),
150
155
  )
151
- eval_pool = Pool(
152
- eval_x,
153
- label=eval_y,
154
- cat_features=eval_x.select_dtypes(include="category").columns.tolist(),
156
+ eval_pool = (
157
+ Pool(
158
+ eval_x,
159
+ label=eval_y,
160
+ cat_features=eval_x.select_dtypes(include="category").columns.tolist(),
161
+ )
162
+ if eval_x is not None
163
+ else None
155
164
  )
156
165
  catboost.fit(
157
166
  train_pool,
@@ -162,6 +171,7 @@ class CatboostModel(Model):
162
171
  )
163
172
  importances = catboost.get_feature_importance(prettified=True)
164
173
  logging.info("Importances:\n%s", importances)
174
+ self._best_iteration = catboost.get_best_iteration()
165
175
  return self
166
176
 
167
177
  def transform(self, df: pd.DataFrame) -> pd.DataFrame:
@@ -186,10 +196,14 @@ class CatboostModel(Model):
186
196
  def _provide_catboost(self) -> CatBoost:
187
197
  catboost = self._catboost
188
198
  if catboost is None:
199
+ best_iteration = self._best_iteration
200
+ iterations = (
201
+ best_iteration if best_iteration is not None else self._iterations
202
+ )
189
203
  match self._model_type:
190
204
  case ModelType.BINARY:
191
205
  catboost = CatBoostClassifierWrapper(
192
- iterations=self._iterations,
206
+ iterations=iterations,
193
207
  learning_rate=self._learning_rate,
194
208
  depth=self._depth,
195
209
  l2_leaf_reg=self._l2_leaf_reg,
@@ -201,7 +215,7 @@ class CatboostModel(Model):
201
215
  )
202
216
  case ModelType.REGRESSION:
203
217
  catboost = CatBoostRegressorWrapper(
204
- iterations=self._iterations,
218
+ iterations=iterations,
205
219
  learning_rate=self._learning_rate,
206
220
  depth=self._depth,
207
221
  l2_leaf_reg=self._l2_leaf_reg,
@@ -213,7 +227,7 @@ class CatboostModel(Model):
213
227
  )
214
228
  case ModelType.BINNED_BINARY:
215
229
  catboost = CatBoostClassifierWrapper(
216
- iterations=self._iterations,
230
+ iterations=iterations,
217
231
  learning_rate=self._learning_rate,
218
232
  depth=self._depth,
219
233
  l2_leaf_reg=self._l2_leaf_reg,
@@ -225,7 +239,7 @@ class CatboostModel(Model):
225
239
  )
226
240
  case ModelType.MULTI_CLASSIFICATION:
227
241
  catboost = CatBoostClassifierWrapper(
228
- iterations=self._iterations,
242
+ iterations=iterations,
229
243
  learning_rate=self._learning_rate,
230
244
  depth=self._depth,
231
245
  l2_leaf_reg=self._l2_leaf_reg,
@@ -73,11 +73,11 @@ class ModelRouter(Model):
73
73
  model.load(folder)
74
74
  self._model = model
75
75
 
76
- def save(self, folder: str) -> None:
76
+ def save(self, folder: str, trial: optuna.Trial | optuna.trial.FrozenTrial) -> None:
77
77
  model = self._model
78
78
  if model is None:
79
79
  raise ValueError("model is null")
80
- model.save(folder)
80
+ model.save(folder, trial)
81
81
  with open(
82
82
  os.path.join(folder, _MODEL_ROUTER_FILE), "w", encoding="utf8"
83
83
  ) as handle:
@@ -69,7 +69,7 @@ class TabPFNModel(Model):
69
69
  params = json.load(handle)
70
70
  self._model_type = ModelType(params[_MODEL_TYPE_KEY])
71
71
 
72
- def save(self, folder: str) -> None:
72
+ def save(self, folder: str, trial: optuna.Trial | optuna.trial.FrozenTrial) -> None:
73
73
  with open(os.path.join(folder, _MODEL_FILENAME), "wb") as f:
74
74
  pickle.dump(self._tabpfn, f)
75
75
  with open(
@@ -14,6 +14,6 @@ class Params:
14
14
  """Loads the objects from a folder."""
15
15
  raise NotImplementedError("load not implemented in parent class.")
16
16
 
17
- def save(self, folder: str) -> None:
17
+ def save(self, folder: str, trial: optuna.Trial | optuna.trial.FrozenTrial) -> None:
18
18
  """Saves the objects into a folder."""
19
19
  raise NotImplementedError("save not implemented in parent class.")
@@ -39,7 +39,7 @@ class BaseSelectorReducer(Reducer):
39
39
  file_path = os.path.join(folder, self._file_name)
40
40
  self._base_selector = joblib.load(file_path)
41
41
 
42
- def save(self, folder: str) -> None:
42
+ def save(self, folder: str, trial: optuna.Trial | optuna.trial.FrozenTrial) -> None:
43
43
  file_path = os.path.join(folder, self._file_name)
44
44
  joblib.dump(self._base_selector, file_path)
45
45
 
@@ -67,7 +67,7 @@ class CombinedReducer(Reducer):
67
67
  for reducer in self._reducers:
68
68
  reducer.load(folder)
69
69
 
70
- def save(self, folder: str) -> None:
70
+ def save(self, folder: str, trial: optuna.Trial | optuna.trial.FrozenTrial) -> None:
71
71
  with open(
72
72
  os.path.join(folder, _COMBINED_REDUCER_FILE), "w", encoding="utf8"
73
73
  ) as handle:
@@ -78,7 +78,7 @@ class CombinedReducer(Reducer):
78
78
  handle,
79
79
  )
80
80
  for reducer in self._reducers:
81
- reducer.save(folder)
81
+ reducer.save(folder, trial)
82
82
 
83
83
  def fit(
84
84
  self,
@@ -23,7 +23,7 @@ class NonNumericReducer(Reducer):
23
23
  def load(self, folder: str) -> None:
24
24
  pass
25
25
 
26
- def save(self, folder: str) -> None:
26
+ def save(self, folder: str, trial: optuna.Trial | optuna.trial.FrozenTrial) -> None:
27
27
  pass
28
28
 
29
29
  def fit(
@@ -45,7 +45,7 @@ class PCAReducer(Reducer):
45
45
  if os.path.exists(pca_file):
46
46
  self._pca = joblib.load(pca_file)
47
47
 
48
- def save(self, folder: str) -> None:
48
+ def save(self, folder: str, trial: optuna.Trial | optuna.trial.FrozenTrial) -> None:
49
49
  if self._scaler is not None:
50
50
  joblib.dump(self._scaler, os.path.join(folder, _PCA_SCALER_FILE))
51
51
  if self._pca is not None:
@@ -34,7 +34,7 @@ class UnseenReducer(Reducer):
34
34
  ) as handle:
35
35
  self._seen_features = json.load(handle)
36
36
 
37
- def save(self, folder: str) -> None:
37
+ def save(self, folder: str, trial: optuna.Trial | optuna.trial.FrozenTrial) -> None:
38
38
  with open(
39
39
  os.path.join(folder, _UNSEEN_REDUCER_FILE), "w", encoding="utf8"
40
40
  ) as handle:
@@ -38,7 +38,7 @@ class Selector(Params, Fit):
38
38
  def load(self, folder: str) -> None:
39
39
  self._selector = joblib.load(os.path.join(folder, _SELECTOR_FILE))
40
40
 
41
- def save(self, folder: str) -> None:
41
+ def save(self, folder: str, trial: optuna.Trial | optuna.trial.FrozenTrial) -> None:
42
42
  joblib.dump(self._selector, os.path.join(folder, _SELECTOR_FILE))
43
43
 
44
44
  def fit(
@@ -258,12 +258,12 @@ class Trainer(Fit):
258
258
  )
259
259
  if not os.path.exists(folder):
260
260
  os.mkdir(folder)
261
- windower.save(folder)
262
- reducer.save(folder)
263
- weights.save(folder)
264
- model.save(folder)
265
- selector.save(folder)
266
- calibrator.save(folder)
261
+ windower.save(folder, trial)
262
+ reducer.save(folder, trial)
263
+ weights.save(folder, trial)
264
+ model.save(folder, trial)
265
+ selector.save(folder, trial)
266
+ calibrator.save(folder, trial)
267
267
 
268
268
  y_pred = model.transform(x_test)
269
269
  y_pred = calibrator.transform(y_pred)
@@ -380,6 +380,7 @@ class Trainer(Fit):
380
380
 
381
381
  def transform(self, df: pd.DataFrame) -> pd.DataFrame:
382
382
  """Predict the expected values of the data."""
383
+ tqdm.tqdm.pandas(desc="Inferring...")
383
384
  input_df = df.copy()
384
385
  df = df.reindex(sorted(df.columns), axis=1)
385
386
  feature_columns = df.columns.values
@@ -33,7 +33,7 @@ class ClassWeights(Weights):
33
33
  def load(self, folder: str) -> None:
34
34
  pass
35
35
 
36
- def save(self, folder: str) -> None:
36
+ def save(self, folder: str, trial: optuna.Trial | optuna.trial.FrozenTrial) -> None:
37
37
  pass
38
38
 
39
39
  def fit(
@@ -31,9 +31,9 @@ class CombinedWeights(Weights):
31
31
  for weights in self._weights:
32
32
  weights.load(folder)
33
33
 
34
- def save(self, folder: str) -> None:
34
+ def save(self, folder: str, trial: optuna.Trial | optuna.trial.FrozenTrial) -> None:
35
35
  for weights in self._weights:
36
- weights.save(folder)
36
+ weights.save(folder, trial)
37
37
 
38
38
  def fit(
39
39
  self,
@@ -25,7 +25,7 @@ class ExponentialWeights(Weights):
25
25
  def load(self, folder: str) -> None:
26
26
  pass
27
27
 
28
- def save(self, folder: str) -> None:
28
+ def save(self, folder: str, trial: optuna.Trial | optuna.trial.FrozenTrial) -> None:
29
29
  pass
30
30
 
31
31
  def fit(
@@ -25,7 +25,7 @@ class LinearWeights(Weights):
25
25
  def load(self, folder: str) -> None:
26
26
  pass
27
27
 
28
- def save(self, folder: str) -> None:
28
+ def save(self, folder: str, trial: optuna.Trial | optuna.trial.FrozenTrial) -> None:
29
29
  pass
30
30
 
31
31
  def fit(
@@ -25,7 +25,7 @@ class NoopWeights(Weights):
25
25
  def load(self, folder: str) -> None:
26
26
  pass
27
27
 
28
- def save(self, folder: str) -> None:
28
+ def save(self, folder: str, trial: optuna.Trial | optuna.trial.FrozenTrial) -> None:
29
29
  pass
30
30
 
31
31
  def fit(
@@ -26,7 +26,7 @@ class SigmoidWeights(Weights):
26
26
  def load(self, folder: str) -> None:
27
27
  pass
28
28
 
29
- def save(self, folder: str) -> None:
29
+ def save(self, folder: str, trial: optuna.Trial | optuna.trial.FrozenTrial) -> None:
30
30
  pass
31
31
 
32
32
  def fit(
@@ -54,10 +54,11 @@ class WeightsRouter(Weights):
54
54
  weights = _WEIGHTS[params[_WEIGHTS_KEY]]()
55
55
  self._weights = weights
56
56
 
57
- def save(self, folder: str) -> None:
57
+ def save(self, folder: str, trial: optuna.Trial | optuna.trial.FrozenTrial) -> None:
58
58
  weights = self._weights
59
59
  if weights is None:
60
60
  raise ValueError("weights is null")
61
+ weights.save(folder, trial)
61
62
  with open(
62
63
  os.path.join(folder, _WEIGHTS_ROUTER_FILE), "w", encoding="utf8"
63
64
  ) as handle:
@@ -36,7 +36,7 @@ class Windower(Params, Fit):
36
36
  params = json.load(handle)
37
37
  self._lookback = params[_LOOKBACK_KEY]
38
38
 
39
- def save(self, folder: str) -> None:
39
+ def save(self, folder: str, trial: optuna.Trial | optuna.trial.FrozenTrial) -> None:
40
40
  with open(os.path.join(folder, _WINDOWER_FILE), "w", encoding="utf8") as handle:
41
41
  json.dump(
42
42
  {
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: wavetrainer
3
- Version: 0.0.26
3
+ Version: 0.0.27
4
4
  Summary: A library for automatically finding the optimal model within feature and hyperparameter space.
5
5
  Home-page: https://github.com/8W9aG/wavetrainer
6
6
  Author: Will Sackfield
File without changes
File without changes
File without changes
File without changes