wavetrainer 0.1.11__tar.gz → 0.1.13__tar.gz

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the contents of those package versions as they appear in their respective public registries.
Files changed (70)
  1. {wavetrainer-0.1.11/wavetrainer.egg-info → wavetrainer-0.1.13}/PKG-INFO +3 -1
  2. {wavetrainer-0.1.11 → wavetrainer-0.1.13}/README.md +1 -0
  3. wavetrainer-0.1.11/wavetrainer.egg-info/requires.txt → wavetrainer-0.1.13/requirements.txt +1 -0
  4. {wavetrainer-0.1.11 → wavetrainer-0.1.13}/setup.py +1 -1
  5. {wavetrainer-0.1.11 → wavetrainer-0.1.13}/wavetrainer/__init__.py +1 -1
  6. {wavetrainer-0.1.11 → wavetrainer-0.1.13}/wavetrainer/model/lightgbm/lightgbm_model.py +1 -1
  7. {wavetrainer-0.1.11 → wavetrainer-0.1.13}/wavetrainer/model/model_router.py +11 -2
  8. {wavetrainer-0.1.11 → wavetrainer-0.1.13}/wavetrainer/reducer/combined_reducer.py +1 -2
  9. {wavetrainer-0.1.11 → wavetrainer-0.1.13}/wavetrainer/trainer.py +82 -15
  10. {wavetrainer-0.1.11 → wavetrainer-0.1.13/wavetrainer.egg-info}/PKG-INFO +3 -1
  11. wavetrainer-0.1.11/requirements.txt → wavetrainer-0.1.13/wavetrainer.egg-info/requires.txt +2 -1
  12. {wavetrainer-0.1.11 → wavetrainer-0.1.13}/LICENSE +0 -0
  13. {wavetrainer-0.1.11 → wavetrainer-0.1.13}/MANIFEST.in +0 -0
  14. {wavetrainer-0.1.11 → wavetrainer-0.1.13}/setup.cfg +0 -0
  15. {wavetrainer-0.1.11 → wavetrainer-0.1.13}/tests/__init__.py +0 -0
  16. {wavetrainer-0.1.11 → wavetrainer-0.1.13}/tests/model/__init__.py +0 -0
  17. {wavetrainer-0.1.11 → wavetrainer-0.1.13}/tests/model/catboost_kwargs_test.py +0 -0
  18. {wavetrainer-0.1.11 → wavetrainer-0.1.13}/tests/trainer_test.py +0 -0
  19. {wavetrainer-0.1.11 → wavetrainer-0.1.13}/wavetrainer/calibrator/__init__.py +0 -0
  20. {wavetrainer-0.1.11 → wavetrainer-0.1.13}/wavetrainer/calibrator/calibrator.py +0 -0
  21. {wavetrainer-0.1.11 → wavetrainer-0.1.13}/wavetrainer/calibrator/calibrator_router.py +0 -0
  22. {wavetrainer-0.1.11 → wavetrainer-0.1.13}/wavetrainer/calibrator/vennabers_calibrator.py +0 -0
  23. {wavetrainer-0.1.11 → wavetrainer-0.1.13}/wavetrainer/create.py +0 -0
  24. {wavetrainer-0.1.11 → wavetrainer-0.1.13}/wavetrainer/exceptions.py +0 -0
  25. {wavetrainer-0.1.11 → wavetrainer-0.1.13}/wavetrainer/fit.py +0 -0
  26. {wavetrainer-0.1.11 → wavetrainer-0.1.13}/wavetrainer/model/__init__.py +0 -0
  27. {wavetrainer-0.1.11 → wavetrainer-0.1.13}/wavetrainer/model/catboost/__init__.py +0 -0
  28. {wavetrainer-0.1.11 → wavetrainer-0.1.13}/wavetrainer/model/catboost/catboost_classifier_wrap.py +0 -0
  29. {wavetrainer-0.1.11 → wavetrainer-0.1.13}/wavetrainer/model/catboost/catboost_kwargs.py +0 -0
  30. {wavetrainer-0.1.11 → wavetrainer-0.1.13}/wavetrainer/model/catboost/catboost_model.py +0 -0
  31. {wavetrainer-0.1.11 → wavetrainer-0.1.13}/wavetrainer/model/catboost/catboost_regressor_wrap.py +0 -0
  32. {wavetrainer-0.1.11 → wavetrainer-0.1.13}/wavetrainer/model/lightgbm/__init__.py +0 -0
  33. {wavetrainer-0.1.11 → wavetrainer-0.1.13}/wavetrainer/model/model.py +0 -0
  34. {wavetrainer-0.1.11 → wavetrainer-0.1.13}/wavetrainer/model/tabpfn/__init__.py +0 -0
  35. {wavetrainer-0.1.11 → wavetrainer-0.1.13}/wavetrainer/model/tabpfn/tabpfn_model.py +0 -0
  36. {wavetrainer-0.1.11 → wavetrainer-0.1.13}/wavetrainer/model/xgboost/__init__.py +0 -0
  37. {wavetrainer-0.1.11 → wavetrainer-0.1.13}/wavetrainer/model/xgboost/early_stopper.py +0 -0
  38. {wavetrainer-0.1.11 → wavetrainer-0.1.13}/wavetrainer/model/xgboost/xgboost_logger.py +0 -0
  39. {wavetrainer-0.1.11 → wavetrainer-0.1.13}/wavetrainer/model/xgboost/xgboost_model.py +0 -0
  40. {wavetrainer-0.1.11 → wavetrainer-0.1.13}/wavetrainer/model_type.py +0 -0
  41. {wavetrainer-0.1.11 → wavetrainer-0.1.13}/wavetrainer/params.py +0 -0
  42. {wavetrainer-0.1.11 → wavetrainer-0.1.13}/wavetrainer/reducer/__init__.py +0 -0
  43. {wavetrainer-0.1.11 → wavetrainer-0.1.13}/wavetrainer/reducer/base_selector_reducer.py +0 -0
  44. {wavetrainer-0.1.11 → wavetrainer-0.1.13}/wavetrainer/reducer/constant_reducer.py +0 -0
  45. {wavetrainer-0.1.11 → wavetrainer-0.1.13}/wavetrainer/reducer/correlation_reducer.py +0 -0
  46. {wavetrainer-0.1.11 → wavetrainer-0.1.13}/wavetrainer/reducer/duplicate_reducer.py +0 -0
  47. {wavetrainer-0.1.11 → wavetrainer-0.1.13}/wavetrainer/reducer/non_categorical_numeric_columns.py +0 -0
  48. {wavetrainer-0.1.11 → wavetrainer-0.1.13}/wavetrainer/reducer/nonnumeric_reducer.py +0 -0
  49. {wavetrainer-0.1.11 → wavetrainer-0.1.13}/wavetrainer/reducer/pca_reducer.py +0 -0
  50. {wavetrainer-0.1.11 → wavetrainer-0.1.13}/wavetrainer/reducer/reducer.py +0 -0
  51. {wavetrainer-0.1.11 → wavetrainer-0.1.13}/wavetrainer/reducer/select_by_single_feature_performance_reducer.py +0 -0
  52. {wavetrainer-0.1.11 → wavetrainer-0.1.13}/wavetrainer/reducer/smart_correlation_reducer.py +0 -0
  53. {wavetrainer-0.1.11 → wavetrainer-0.1.13}/wavetrainer/reducer/unseen_reducer.py +0 -0
  54. {wavetrainer-0.1.11 → wavetrainer-0.1.13}/wavetrainer/selector/__init__.py +0 -0
  55. {wavetrainer-0.1.11 → wavetrainer-0.1.13}/wavetrainer/selector/selector.py +0 -0
  56. {wavetrainer-0.1.11 → wavetrainer-0.1.13}/wavetrainer/weights/__init__.py +0 -0
  57. {wavetrainer-0.1.11 → wavetrainer-0.1.13}/wavetrainer/weights/class_weights.py +0 -0
  58. {wavetrainer-0.1.11 → wavetrainer-0.1.13}/wavetrainer/weights/combined_weights.py +0 -0
  59. {wavetrainer-0.1.11 → wavetrainer-0.1.13}/wavetrainer/weights/exponential_weights.py +0 -0
  60. {wavetrainer-0.1.11 → wavetrainer-0.1.13}/wavetrainer/weights/linear_weights.py +0 -0
  61. {wavetrainer-0.1.11 → wavetrainer-0.1.13}/wavetrainer/weights/noop_weights.py +0 -0
  62. {wavetrainer-0.1.11 → wavetrainer-0.1.13}/wavetrainer/weights/sigmoid_weights.py +0 -0
  63. {wavetrainer-0.1.11 → wavetrainer-0.1.13}/wavetrainer/weights/weights.py +0 -0
  64. {wavetrainer-0.1.11 → wavetrainer-0.1.13}/wavetrainer/weights/weights_router.py +0 -0
  65. {wavetrainer-0.1.11 → wavetrainer-0.1.13}/wavetrainer/windower/__init__.py +0 -0
  66. {wavetrainer-0.1.11 → wavetrainer-0.1.13}/wavetrainer/windower/windower.py +0 -0
  67. {wavetrainer-0.1.11 → wavetrainer-0.1.13}/wavetrainer.egg-info/SOURCES.txt +0 -0
  68. {wavetrainer-0.1.11 → wavetrainer-0.1.13}/wavetrainer.egg-info/dependency_links.txt +0 -0
  69. {wavetrainer-0.1.11 → wavetrainer-0.1.13}/wavetrainer.egg-info/not-zip-safe +0 -0
  70. {wavetrainer-0.1.11 → wavetrainer-0.1.13}/wavetrainer.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: wavetrainer
3
- Version: 0.1.11
3
+ Version: 0.1.13
4
4
  Summary: A library for automatically finding the optimal model within feature and hyperparameter space.
5
5
  Home-page: https://github.com/8W9aG/wavetrainer
6
6
  Author: Will Sackfield
@@ -30,6 +30,7 @@ Requires-Dist: tabpfn_extensions>=0.0.4
30
30
  Requires-Dist: hyperopt>=0.2.7
31
31
  Requires-Dist: pycaleva>=0.8.2
32
32
  Requires-Dist: lightgbm>=4.6.0
33
+ Requires-Dist: kaleido>=0.2.1
33
34
 
34
35
  # wavetrainer
35
36
 
@@ -66,6 +67,7 @@ Python 3.11.6:
66
67
  - [hyperopt](https://github.com/hyperopt/hyperopt)
67
68
  - [pycaleva](https://github.com/MartinWeigl/pycaleva)
68
69
  - [lightgbm](https://github.com/microsoft/LightGBM)
70
+ - [kaleido](https://github.com/plotly/Kaleido)
69
71
 
70
72
  ## Raison D'être :thought_balloon:
71
73
 
@@ -33,6 +33,7 @@ Python 3.11.6:
33
33
  - [hyperopt](https://github.com/hyperopt/hyperopt)
34
34
  - [pycaleva](https://github.com/MartinWeigl/pycaleva)
35
35
  - [lightgbm](https://github.com/microsoft/LightGBM)
36
+ - [kaleido](https://github.com/plotly/Kaleido)
36
37
 
37
38
  ## Raison D'être :thought_balloon:
38
39
 
@@ -17,3 +17,4 @@ tabpfn_extensions>=0.0.4
17
17
  hyperopt>=0.2.7
18
18
  pycaleva>=0.8.2
19
19
  lightgbm>=4.6.0
20
+ kaleido>=0.2.1
@@ -23,7 +23,7 @@ def install_requires() -> typing.List[str]:
23
23
 
24
24
  setup(
25
25
  name='wavetrainer',
26
- version='0.1.11',
26
+ version='0.1.13',
27
27
  description='A library for automatically finding the optimal model within feature and hyperparameter space.',
28
28
  long_description=long_description,
29
29
  long_description_content_type='text/markdown',
@@ -2,5 +2,5 @@
2
2
 
3
3
  from .create import create
4
4
 
5
- __VERSION__ = "0.1.11"
5
+ __VERSION__ = "0.1.13"
6
6
  __all__ = ("create",)
@@ -148,7 +148,7 @@ class LightGBMModel(Model):
148
148
 
149
149
  eval_set = None
150
150
  callbacks = []
151
- if eval_x is None or eval_y is None:
151
+ if eval_x is not None and eval_y is not None:
152
152
  eval_set = [(eval_x, eval_y.to_numpy().flatten())] # type: ignore
153
153
  callbacks = [
154
154
  lgb.early_stopping(stopping_rounds=early_stopping_rounds),
@@ -35,13 +35,18 @@ class ModelRouter(Model):
35
35
  _model: Model | None
36
36
  _false_positive_reduction_steps: int | None
37
37
 
38
- def __init__(self, allowed_models: set[str] | None) -> None:
38
+ def __init__(
39
+ self,
40
+ allowed_models: set[str] | None,
41
+ max_false_positive_reduction_steps: int | None,
42
+ ) -> None:
39
43
  super().__init__()
40
44
  self._model = None
41
45
  self._false_positive_reduction_steps = None
42
46
  self._allowed_models = (
43
47
  allowed_models if allowed_models is not None else set(_MODELS.keys())
44
48
  )
49
+ self._max_false_positive_reduction_steps = max_false_positive_reduction_steps
45
50
 
46
51
  @classmethod
47
52
  def name(cls) -> str:
@@ -93,7 +98,11 @@ class ModelRouter(Model):
93
98
  self, trial: optuna.Trial | optuna.trial.FrozenTrial, df: pd.DataFrame
94
99
  ) -> None:
95
100
  self._false_positive_reduction_steps = trial.suggest_int(
96
- _FALSE_POSITIVE_REDUCTION_STEPS_KEY, 0, 5
101
+ _FALSE_POSITIVE_REDUCTION_STEPS_KEY,
102
+ 0,
103
+ 5
104
+ if self._max_false_positive_reduction_steps is None
105
+ else self._max_false_positive_reduction_steps,
97
106
  )
98
107
  model_name = trial.suggest_categorical(
99
108
  "model",
@@ -2,7 +2,6 @@
2
2
 
3
3
  # pylint: disable=line-too-long
4
4
  import json
5
- import logging
6
5
  import os
7
6
  import time
8
7
  from typing import Self
@@ -129,6 +128,6 @@ class CombinedReducer(Reducer):
129
128
  try:
130
129
  df = reducer.transform(df)
131
130
  except ValueError as exc:
132
- logging.warning("Failed to reduce %s", reducer.name())
131
+ print("Failed to reduce %s", reducer.name())
133
132
  raise exc
134
133
  return df
@@ -1,5 +1,6 @@
1
1
  """The trainer class."""
2
2
 
3
+ # pylint: disable=line-too-long
3
4
  import datetime
4
5
  import functools
5
6
  import json
@@ -12,12 +13,14 @@ from typing import Self
12
13
  import optuna
13
14
  import pandas as pd
14
15
  import tqdm
15
- from sklearn.metrics import f1_score, r2_score # type: ignore
16
+ from sklearn.metrics import f1_score # type: ignore
17
+ from sklearn.metrics import (accuracy_score, brier_score_loss, log_loss,
18
+ precision_score, r2_score, recall_score)
16
19
 
17
20
  from .calibrator.calibrator_router import CalibratorRouter
18
21
  from .exceptions import WavetrainException
19
22
  from .fit import Fit
20
- from .model.model import PREDICTION_COLUMN
23
+ from .model.model import PREDICTION_COLUMN, PROBABILITY_COLUMN_PREFIX
21
24
  from .model.model_router import ModelRouter
22
25
  from .model_type import ModelType, determine_model_type
23
26
  from .reducer.combined_reducer import CombinedReducer
@@ -38,6 +41,7 @@ _TEST_SIZE_KEY = "test_size"
38
41
  _VALIDATION_SIZE_KEY = "validation_size"
39
42
  _IDX_USR_ATTR_KEY = "idx"
40
43
  _DT_COLUMN_KEY = "dt_column"
44
+ _MAX_FALSE_POSITIVE_REDUCTION_STEPS_KEY = "max_false_positive_reduction_steps"
41
45
  _BAD_OUTPUT = -1000.0
42
46
 
43
47
 
@@ -48,6 +52,11 @@ def _assign_bin(timestamp, bins: list[datetime.datetime]) -> int:
48
52
  return len(bins) - 2 # Assign to last bin if at the end
49
53
 
50
54
 
55
+ def _best_trial(study: optuna.Study) -> optuna.trial.FrozenTrial:
56
+ best_brier = min(study.best_trials, key=lambda t: t.values[1])
57
+ return best_brier
58
+
59
+
51
60
  class Trainer(Fit):
52
61
  """A class for training and predicting from an array of data."""
53
62
 
@@ -65,6 +74,7 @@ class Trainer(Fit):
65
74
  cutoff_dt: datetime.datetime | None = None,
66
75
  embedding_cols: list[list[str]] | None = None,
67
76
  allowed_models: set[str] | None = None,
77
+ max_false_positive_reduction_steps: int | None = None,
68
78
  ):
69
79
  tqdm.tqdm.pandas()
70
80
 
@@ -115,6 +125,10 @@ class Trainer(Fit):
115
125
  )
116
126
  if dt_column is None:
117
127
  dt_column = params[_DT_COLUMN_KEY]
128
+ if max_false_positive_reduction_steps is None:
129
+ max_false_positive_reduction_steps = params.get(
130
+ _MAX_FALSE_POSITIVE_REDUCTION_STEPS_KEY
131
+ )
118
132
  else:
119
133
  with open(params_file, "w", encoding="utf8") as handle:
120
134
  validation_size_value = None
@@ -145,6 +159,7 @@ class Trainer(Fit):
145
159
  _TEST_SIZE_KEY: test_size_value,
146
160
  _VALIDATION_SIZE_KEY: validation_size_value,
147
161
  _DT_COLUMN_KEY: dt_column,
162
+ _MAX_FALSE_POSITIVE_REDUCTION_STEPS_KEY: max_false_positive_reduction_steps,
148
163
  },
149
164
  handle,
150
165
  )
@@ -157,6 +172,7 @@ class Trainer(Fit):
157
172
  self._cutoff_dt = cutoff_dt
158
173
  self.embedding_cols = embedding_cols
159
174
  self._allowed_models = allowed_models
175
+ self._max_false_positive_reduction_steps = max_false_positive_reduction_steps
160
176
 
161
177
  def _provide_study(self, column: str) -> optuna.Study:
162
178
  storage_name = f"sqlite:///{self._folder}/{column}/{_STUDYDB_FILENAME}"
@@ -170,7 +186,10 @@ class Trainer(Fit):
170
186
  storage=storage_name,
171
187
  load_if_exists=True,
172
188
  sampler=restored_sampler,
173
- direction=optuna.study.StudyDirection.MAXIMIZE,
189
+ directions=[
190
+ optuna.study.StudyDirection.MAXIMIZE,
191
+ optuna.study.StudyDirection.MINIMIZE,
192
+ ],
174
193
  )
175
194
 
176
195
  def fit(
@@ -210,7 +229,7 @@ class Trainer(Fit):
210
229
  save: bool,
211
230
  split_idx: datetime.datetime,
212
231
  no_evaluation: bool,
213
- ) -> float:
232
+ ) -> tuple[float, float]:
214
233
  print(f"Beginning trial for: {split_idx.isoformat()}")
215
234
  trial.set_user_attr(_IDX_USR_ATTR_KEY, split_idx.isoformat())
216
235
  folder = os.path.join(
@@ -246,7 +265,7 @@ class Trainer(Fit):
246
265
  if new_folder:
247
266
  os.removedirs(folder)
248
267
  logging.warning("Y train only contains 1 unique datapoint.")
249
- return _BAD_OUTPUT
268
+ return _BAD_OUTPUT, -_BAD_OUTPUT
250
269
  print(f"Windowing took {time.time() - start_windower}")
251
270
 
252
271
  # Perform common reductions
@@ -267,7 +286,9 @@ class Trainer(Fit):
267
286
  print(f"Row weights took {time.time() - start_row_weights}")
268
287
 
269
288
  # Create model
270
- model = ModelRouter(self._allowed_models)
289
+ model = ModelRouter(
290
+ self._allowed_models, self._max_false_positive_reduction_steps
291
+ )
271
292
  model.set_options(trial, x)
272
293
 
273
294
  # Train
@@ -311,10 +332,29 @@ class Trainer(Fit):
311
332
  )
312
333
  cal_pred[PREDICTION_COLUMN] = y_pred[PREDICTION_COLUMN]
313
334
  output = 0.0
335
+ loss = 0.0
314
336
  if determine_model_type(y_series) == ModelType.REGRESSION:
315
337
  output = float(r2_score(y_test, y_pred[[PREDICTION_COLUMN]]))
338
+ print(f"R2: {output}")
316
339
  else:
317
340
  output = float(f1_score(y_test, y_pred[[PREDICTION_COLUMN]]))
341
+ print(f"F1: {output}")
342
+ prob_col = PROBABILITY_COLUMN_PREFIX + str(1)
343
+ if prob_col in y_pred.columns.values.tolist():
344
+ loss = float(brier_score_loss(y_test, y_pred[[prob_col]]))
345
+ print(f"Brier: {loss}")
346
+ print(
347
+ f"Log Loss: {float(log_loss(y_test.astype(float), y_pred[[prob_col]]))}"
348
+ )
349
+ print(
350
+ f"Accuracy: {float(accuracy_score(y_test, y_pred[[PREDICTION_COLUMN]]))}"
351
+ )
352
+ print(
353
+ f"Precision: {float(precision_score(y_test, y_pred[[PREDICTION_COLUMN]]))}"
354
+ )
355
+ print(
356
+ f"Recall: {float(recall_score(y_test, y_pred[[PREDICTION_COLUMN]]))}"
357
+ )
318
358
 
319
359
  if save:
320
360
  windower.save(folder, trial)
@@ -332,13 +372,13 @@ class Trainer(Fit):
332
372
  handle,
333
373
  )
334
374
 
335
- return output
375
+ return output, loss
336
376
  except WavetrainException as exc:
337
377
  print(str(exc))
338
378
  logging.warning(str(exc))
339
379
  if new_folder:
340
380
  os.removedirs(folder)
341
- return _BAD_OUTPUT
381
+ return _BAD_OUTPUT, -_BAD_OUTPUT
342
382
 
343
383
  start_validation_index = (
344
384
  dt_index.to_list()[-int(len(dt_index) * self._validation_size) - 1]
@@ -359,7 +399,7 @@ class Trainer(Fit):
359
399
  ].to_list()[0]
360
400
  )
361
401
 
362
- def test_objective(trial: optuna.Trial) -> float:
402
+ def test_objective(trial: optuna.Trial) -> tuple[float, float]:
363
403
  return _fit(
364
404
  trial,
365
405
  test_df,
@@ -382,7 +422,8 @@ class Trainer(Fit):
382
422
  else self._max_train_timeout.total_seconds(),
383
423
  )
384
424
  while (
385
- study.best_trial.value is None or study.best_trial.value == _BAD_OUTPUT
425
+ _best_trial(study).values is None
426
+ or _best_trial(study).values == (_BAD_OUTPUT, -_BAD_OUTPUT)
386
427
  ) and len(study.trials) < 1000:
387
428
  logging.info("Performing extra train")
388
429
  study.optimize(
@@ -420,7 +461,7 @@ class Trainer(Fit):
420
461
  if found:
421
462
  last_processed_dt = test_dt
422
463
  _fit(
423
- study.best_trial,
464
+ _best_trial(study),
424
465
  test_df.copy(),
425
466
  test_series,
426
467
  True,
@@ -441,7 +482,7 @@ class Trainer(Fit):
441
482
 
442
483
  def validate_objctive(
443
484
  trial: optuna.Trial, idx: datetime.datetime, series: pd.Series
444
- ) -> float:
485
+ ) -> tuple[float, float]:
445
486
  return _fit(trial, test_df.copy(), series, False, idx, False)
446
487
 
447
488
  study.optimize(
@@ -457,10 +498,36 @@ class Trainer(Fit):
457
498
  break
458
499
 
459
500
  _fit(
460
- study.best_trial, test_df.copy(), test_series, True, test_idx, True
501
+ _best_trial(study),
502
+ test_df.copy(),
503
+ test_series,
504
+ True,
505
+ test_idx,
506
+ True,
461
507
  )
462
508
  last_processed_dt = test_idx
463
509
 
510
+ target_names = ["F1", "Brier"]
511
+ fig = optuna.visualization.plot_pareto_front(
512
+ study, target_names=target_names
513
+ )
514
+ fig.write_image(
515
+ os.path.join(column_dir, "pareto_frontier.png"),
516
+ format="png",
517
+ width=800,
518
+ height=600,
519
+ )
520
+ for target_name in target_names:
521
+ fig = optuna.visualization.plot_param_importances(
522
+ study, target=lambda t: t.values[0], target_name=target_name
523
+ )
524
+ fig.write_image(
525
+ os.path.join(column_dir, f"{target_name}_frontier.png"),
526
+ format="png",
527
+ width=800,
528
+ height=600,
529
+ )
530
+
464
531
  if isinstance(y, pd.Series):
465
532
  _fit_column(y)
466
533
  else:
@@ -519,7 +586,7 @@ class Trainer(Fit):
519
586
  reducer = CombinedReducer(self.embedding_cols)
520
587
  reducer.load(folder)
521
588
 
522
- model = ModelRouter(None)
589
+ model = ModelRouter(None, None)
523
590
  model.load(folder)
524
591
 
525
592
  selector = Selector(model)
@@ -572,7 +639,7 @@ class Trainer(Fit):
572
639
  if not os.path.isdir(date_path):
573
640
  continue
574
641
  try:
575
- model = ModelRouter(None)
642
+ model = ModelRouter(None, None)
576
643
  model.load(date_path)
577
644
  feature_importances[date_str] = model.feature_importances
578
645
  except FileNotFoundError as exc:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: wavetrainer
3
- Version: 0.1.11
3
+ Version: 0.1.13
4
4
  Summary: A library for automatically finding the optimal model within feature and hyperparameter space.
5
5
  Home-page: https://github.com/8W9aG/wavetrainer
6
6
  Author: Will Sackfield
@@ -30,6 +30,7 @@ Requires-Dist: tabpfn_extensions>=0.0.4
30
30
  Requires-Dist: hyperopt>=0.2.7
31
31
  Requires-Dist: pycaleva>=0.8.2
32
32
  Requires-Dist: lightgbm>=4.6.0
33
+ Requires-Dist: kaleido>=0.2.1
33
34
 
34
35
  # wavetrainer
35
36
 
@@ -66,6 +67,7 @@ Python 3.11.6:
66
67
  - [hyperopt](https://github.com/hyperopt/hyperopt)
67
68
  - [pycaleva](https://github.com/MartinWeigl/pycaleva)
68
69
  - [lightgbm](https://github.com/microsoft/LightGBM)
70
+ - [kaleido](https://github.com/plotly/Kaleido)
69
71
 
70
72
  ## Raison D'être :thought_balloon:
71
73
 
@@ -16,4 +16,5 @@ jax>=0.6.1
16
16
  tabpfn_extensions>=0.0.4
17
17
  hyperopt>=0.2.7
18
18
  pycaleva>=0.8.2
19
- lightgbm>=4.6.0
19
+ lightgbm>=4.6.0
20
+ kaleido>=0.2.1
File without changes
File without changes
File without changes