wavetrainer 0.0.26__tar.gz → 0.0.27__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {wavetrainer-0.0.26/wavetrainer.egg-info → wavetrainer-0.0.27}/PKG-INFO +1 -1
- {wavetrainer-0.0.26 → wavetrainer-0.0.27}/setup.py +1 -1
- {wavetrainer-0.0.26 → wavetrainer-0.0.27}/wavetrainer/__init__.py +1 -1
- {wavetrainer-0.0.26 → wavetrainer-0.0.27}/wavetrainer/calibrator/calibrator_router.py +2 -2
- {wavetrainer-0.0.26 → wavetrainer-0.0.27}/wavetrainer/calibrator/mapie_calibrator.py +1 -1
- {wavetrainer-0.0.26 → wavetrainer-0.0.27}/wavetrainer/calibrator/vennabers_calibrator.py +1 -1
- {wavetrainer-0.0.26 → wavetrainer-0.0.27}/wavetrainer/model/catboost_model.py +25 -11
- {wavetrainer-0.0.26 → wavetrainer-0.0.27}/wavetrainer/model/model_router.py +2 -2
- {wavetrainer-0.0.26 → wavetrainer-0.0.27}/wavetrainer/model/tabpfn_model.py +1 -1
- {wavetrainer-0.0.26 → wavetrainer-0.0.27}/wavetrainer/params.py +1 -1
- {wavetrainer-0.0.26 → wavetrainer-0.0.27}/wavetrainer/reducer/base_selector_reducer.py +1 -1
- {wavetrainer-0.0.26 → wavetrainer-0.0.27}/wavetrainer/reducer/combined_reducer.py +2 -2
- {wavetrainer-0.0.26 → wavetrainer-0.0.27}/wavetrainer/reducer/nonnumeric_reducer.py +1 -1
- {wavetrainer-0.0.26 → wavetrainer-0.0.27}/wavetrainer/reducer/pca_reducer.py +1 -1
- {wavetrainer-0.0.26 → wavetrainer-0.0.27}/wavetrainer/reducer/unseen_reducer.py +1 -1
- {wavetrainer-0.0.26 → wavetrainer-0.0.27}/wavetrainer/selector/selector.py +1 -1
- {wavetrainer-0.0.26 → wavetrainer-0.0.27}/wavetrainer/trainer.py +7 -6
- {wavetrainer-0.0.26 → wavetrainer-0.0.27}/wavetrainer/weights/class_weights.py +1 -1
- {wavetrainer-0.0.26 → wavetrainer-0.0.27}/wavetrainer/weights/combined_weights.py +2 -2
- {wavetrainer-0.0.26 → wavetrainer-0.0.27}/wavetrainer/weights/exponential_weights.py +1 -1
- {wavetrainer-0.0.26 → wavetrainer-0.0.27}/wavetrainer/weights/linear_weights.py +1 -1
- {wavetrainer-0.0.26 → wavetrainer-0.0.27}/wavetrainer/weights/noop_weights.py +1 -1
- {wavetrainer-0.0.26 → wavetrainer-0.0.27}/wavetrainer/weights/sigmoid_weights.py +1 -1
- {wavetrainer-0.0.26 → wavetrainer-0.0.27}/wavetrainer/weights/weights_router.py +2 -1
- {wavetrainer-0.0.26 → wavetrainer-0.0.27}/wavetrainer/windower/windower.py +1 -1
- {wavetrainer-0.0.26 → wavetrainer-0.0.27/wavetrainer.egg-info}/PKG-INFO +1 -1
- {wavetrainer-0.0.26 → wavetrainer-0.0.27}/LICENSE +0 -0
- {wavetrainer-0.0.26 → wavetrainer-0.0.27}/MANIFEST.in +0 -0
- {wavetrainer-0.0.26 → wavetrainer-0.0.27}/README.md +0 -0
- {wavetrainer-0.0.26 → wavetrainer-0.0.27}/requirements.txt +0 -0
- {wavetrainer-0.0.26 → wavetrainer-0.0.27}/setup.cfg +0 -0
- {wavetrainer-0.0.26 → wavetrainer-0.0.27}/tests/__init__.py +0 -0
- {wavetrainer-0.0.26 → wavetrainer-0.0.27}/tests/model/__init__.py +0 -0
- {wavetrainer-0.0.26 → wavetrainer-0.0.27}/tests/model/catboost_kwargs_test.py +0 -0
- {wavetrainer-0.0.26 → wavetrainer-0.0.27}/tests/trainer_test.py +0 -0
- {wavetrainer-0.0.26 → wavetrainer-0.0.27}/wavetrainer/calibrator/__init__.py +0 -0
- {wavetrainer-0.0.26 → wavetrainer-0.0.27}/wavetrainer/calibrator/calibrator.py +0 -0
- {wavetrainer-0.0.26 → wavetrainer-0.0.27}/wavetrainer/create.py +0 -0
- {wavetrainer-0.0.26 → wavetrainer-0.0.27}/wavetrainer/exceptions.py +0 -0
- {wavetrainer-0.0.26 → wavetrainer-0.0.27}/wavetrainer/fit.py +0 -0
- {wavetrainer-0.0.26 → wavetrainer-0.0.27}/wavetrainer/model/__init__.py +0 -0
- {wavetrainer-0.0.26 → wavetrainer-0.0.27}/wavetrainer/model/catboost_classifier_wrap.py +0 -0
- {wavetrainer-0.0.26 → wavetrainer-0.0.27}/wavetrainer/model/catboost_kwargs.py +0 -0
- {wavetrainer-0.0.26 → wavetrainer-0.0.27}/wavetrainer/model/catboost_regressor_wrap.py +0 -0
- {wavetrainer-0.0.26 → wavetrainer-0.0.27}/wavetrainer/model/model.py +0 -0
- {wavetrainer-0.0.26 → wavetrainer-0.0.27}/wavetrainer/model_type.py +0 -0
- {wavetrainer-0.0.26 → wavetrainer-0.0.27}/wavetrainer/reducer/__init__.py +0 -0
- {wavetrainer-0.0.26 → wavetrainer-0.0.27}/wavetrainer/reducer/constant_reducer.py +0 -0
- {wavetrainer-0.0.26 → wavetrainer-0.0.27}/wavetrainer/reducer/correlation_reducer.py +0 -0
- {wavetrainer-0.0.26 → wavetrainer-0.0.27}/wavetrainer/reducer/duplicate_reducer.py +0 -0
- {wavetrainer-0.0.26 → wavetrainer-0.0.27}/wavetrainer/reducer/reducer.py +0 -0
- {wavetrainer-0.0.26 → wavetrainer-0.0.27}/wavetrainer/selector/__init__.py +0 -0
- {wavetrainer-0.0.26 → wavetrainer-0.0.27}/wavetrainer/weights/__init__.py +0 -0
- {wavetrainer-0.0.26 → wavetrainer-0.0.27}/wavetrainer/weights/weights.py +0 -0
- {wavetrainer-0.0.26 → wavetrainer-0.0.27}/wavetrainer/windower/__init__.py +0 -0
- {wavetrainer-0.0.26 → wavetrainer-0.0.27}/wavetrainer.egg-info/SOURCES.txt +0 -0
- {wavetrainer-0.0.26 → wavetrainer-0.0.27}/wavetrainer.egg-info/dependency_links.txt +0 -0
- {wavetrainer-0.0.26 → wavetrainer-0.0.27}/wavetrainer.egg-info/not-zip-safe +0 -0
- {wavetrainer-0.0.26 → wavetrainer-0.0.27}/wavetrainer.egg-info/requires.txt +0 -0
- {wavetrainer-0.0.26 → wavetrainer-0.0.27}/wavetrainer.egg-info/top_level.txt +0 -0
setup.py
@@ -23,7 +23,7 @@ def install_requires() -> typing.List[str]:
 
 setup(
     name='wavetrainer',
-    version='0.0.26',
+    version='0.0.27',
     description='A library for automatically finding the optimal model within feature and hyperparameter space.',
     long_description=long_description,
     long_description_content_type='text/markdown',
wavetrainer/calibrator/calibrator_router.py
@@ -48,11 +48,11 @@ class CalibratorRouter(Calibrator):
         calibrator.load(folder)
         self._calibrator = calibrator
 
-    def save(self, folder: str) -> None:
+    def save(self, folder: str, trial: optuna.Trial | optuna.trial.FrozenTrial) -> None:
         calibrator = self._calibrator
         if calibrator is None:
             raise ValueError("calibrator is null.")
-        calibrator.save(folder)
+        calibrator.save(folder, trial)
         with open(
             os.path.join(folder, _CALIBRATOR_ROUTER_FILE), "w", encoding="utf8"
         ) as handle:
wavetrainer/calibrator/mapie_calibrator.py
@@ -35,7 +35,7 @@ class MAPIECalibrator(Calibrator):
     def load(self, folder: str) -> None:
         self._mapie = joblib.load(os.path.join(folder, _CALIBRATOR_FILENAME))
 
-    def save(self, folder: str) -> None:
+    def save(self, folder: str, trial: optuna.Trial | optuna.trial.FrozenTrial) -> None:
         joblib.dump(self._mapie, os.path.join(folder, _CALIBRATOR_FILENAME))
 
     def fit(
wavetrainer/calibrator/vennabers_calibrator.py
@@ -33,7 +33,7 @@ class VennabersCalibrator(Calibrator):
     def load(self, folder: str) -> None:
         self._vennabers = joblib.load(os.path.join(folder, _CALIBRATOR_FILENAME))
 
-    def save(self, folder: str) -> None:
+    def save(self, folder: str, trial: optuna.Trial | optuna.trial.FrozenTrial) -> None:
         joblib.dump(self._vennabers, os.path.join(folder, _CALIBRATOR_FILENAME))
 
     def fit(
wavetrainer/model/catboost_model.py
@@ -26,6 +26,7 @@ _L2_LEAF_REG_KEY = "l2_leaf_reg"
 _BOOSTING_TYPE_KEY = "boosting_type"
 _MODEL_TYPE_KEY = "model_type"
 _EARLY_STOPPING_ROUNDS = "early_stopping_rounds"
+_BEST_ITERATION_KEY = "best_iteration"
 
 
 class CatboostModel(Model):
@@ -41,6 +42,7 @@ class CatboostModel(Model):
     _boosting_type: None | str
     _model_type: None | ModelType
     _early_stopping_rounds: None | int
+    _best_iteration: None | int
 
     @classmethod
     def name(cls) -> str:
@@ -56,6 +58,7 @@ class CatboostModel(Model):
         self._boosting_type = None
         self._model_type = None
         self._early_stopping_rounds = None
+        self._best_iteration = None
 
     @property
     def estimator(self) -> Any:
@@ -92,6 +95,7 @@ class CatboostModel(Model):
             _BOOSTING_TYPE_KEY, ["Ordered", "Plain"]
         )
         self._early_stopping_rounds = trial.suggest_int(_EARLY_STOPPING_ROUNDS, 10, 500)
+        self._best_iteration = trial.user_attrs.get(_BEST_ITERATION_KEY)
 
     def load(self, folder: str) -> None:
         with open(
@@ -105,10 +109,11 @@ class CatboostModel(Model):
         self._boosting_type = params[_BOOSTING_TYPE_KEY]
         self._model_type = ModelType(params[_MODEL_TYPE_KEY])
         self._early_stopping_rounds = params[_EARLY_STOPPING_ROUNDS]
+        self._best_iteration = params.get(_BEST_ITERATION_KEY)
         catboost = self._provide_catboost()
         catboost.load_model(os.path.join(folder, _MODEL_FILENAME))
 
-    def save(self, folder: str) -> None:
+    def save(self, folder: str, trial: optuna.Trial | optuna.trial.FrozenTrial) -> None:
         with open(
             os.path.join(folder, _MODEL_PARAMS_FILENAME), "w", encoding="utf8"
         ) as handle:
@@ -121,11 +126,13 @@ class CatboostModel(Model):
                     _BOOSTING_TYPE_KEY: self._boosting_type,
                     _MODEL_TYPE_KEY: str(self._model_type),
                     _EARLY_STOPPING_ROUNDS: self._early_stopping_rounds,
+                    _BEST_ITERATION_KEY: self._best_iteration,
                 },
                 handle,
             )
         catboost = self._provide_catboost()
         catboost.save_model(os.path.join(folder, _MODEL_FILENAME))
+        trial.user_attrs[_BEST_ITERATION_KEY] = self._best_iteration
 
     def fit(
         self,
@@ -137,8 +144,6 @@ class CatboostModel(Model):
     ) -> Self:
         if y is None:
             raise ValueError("y is null.")
-        if eval_x is None:
-            raise ValueError("eval_x is null.")
         self._model_type = determine_model_type(y)
         catboost = self._provide_catboost()
 
@@ -148,10 +153,14 @@ class CatboostModel(Model):
             weight=w,
             cat_features=df.select_dtypes(include="category").columns.tolist(),
         )
-        eval_pool = Pool(
-            eval_x,
-            label=eval_y,
-            cat_features=eval_x.select_dtypes(include="category").columns.tolist(),
+        eval_pool = (
+            Pool(
+                eval_x,
+                label=eval_y,
+                cat_features=eval_x.select_dtypes(include="category").columns.tolist(),
+            )
+            if eval_x is not None
+            else None
         )
         catboost.fit(
             train_pool,
@@ -162,6 +171,7 @@ class CatboostModel(Model):
         )
         importances = catboost.get_feature_importance(prettified=True)
         logging.info("Importances:\n%s", importances)
+        self._best_iteration = catboost.get_best_iteration()
         return self
 
     def transform(self, df: pd.DataFrame) -> pd.DataFrame:
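Context for the new `self._best_iteration` bookkeeping above: a minimal, self-contained sketch (not wavetrainer code; the dataset and hyperparameters are invented) of the CatBoost behaviour it relies on. When a model is fitted against an eval set with early stopping, `get_best_iteration()` reports the iteration the stopping rule selected, which can then be reused as the iteration count for later fits.

```python
# Illustrative only: how CatBoost exposes the best iteration found via early
# stopping. The data below is synthetic.
import numpy as np
from catboost import CatBoostClassifier, Pool

X = np.random.rand(200, 5)
y = (X[:, 0] > 0.5).astype(int)
train_pool = Pool(X[:150], label=y[:150])
eval_pool = Pool(X[150:], label=y[150:])

model = CatBoostClassifier(iterations=1000, learning_rate=0.1, verbose=False)
model.fit(train_pool, eval_set=eval_pool, early_stopping_rounds=50)

# The iteration the early-stopping rule settled on; reusing it as `iterations`
# in a later fit avoids training far past the useful point.
print(model.get_best_iteration())
```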
@@ -186,10 +196,14 @@ class CatboostModel(Model):
     def _provide_catboost(self) -> CatBoost:
         catboost = self._catboost
         if catboost is None:
+            best_iteration = self._best_iteration
+            iterations = (
+                best_iteration if best_iteration is not None else self._iterations
+            )
             match self._model_type:
                 case ModelType.BINARY:
                     catboost = CatBoostClassifierWrapper(
-                        iterations=self._iterations,
+                        iterations=iterations,
                         learning_rate=self._learning_rate,
                         depth=self._depth,
                         l2_leaf_reg=self._l2_leaf_reg,
@@ -201,7 +215,7 @@ class CatboostModel(Model):
                     )
                 case ModelType.REGRESSION:
                     catboost = CatBoostRegressorWrapper(
-                        iterations=self._iterations,
+                        iterations=iterations,
                         learning_rate=self._learning_rate,
                         depth=self._depth,
                         l2_leaf_reg=self._l2_leaf_reg,
@@ -213,7 +227,7 @@ class CatboostModel(Model):
                     )
                 case ModelType.BINNED_BINARY:
                     catboost = CatBoostClassifierWrapper(
-                        iterations=self._iterations,
+                        iterations=iterations,
                         learning_rate=self._learning_rate,
                         depth=self._depth,
                         l2_leaf_reg=self._l2_leaf_reg,
@@ -225,7 +239,7 @@ class CatboostModel(Model):
                     )
                 case ModelType.MULTI_CLASSIFICATION:
                     catboost = CatBoostClassifierWrapper(
-                        iterations=self._iterations,
+                        iterations=iterations,
                         learning_rate=self._learning_rate,
                         depth=self._depth,
                         l2_leaf_reg=self._l2_leaf_reg,
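The saved best iteration is carried between steps via Optuna's per-trial user attributes (`trial.user_attrs.get(...)` on load, the trial's user attributes on save). A minimal sketch of that mechanism, illustrative only — the study, objective, and attribute key below are invented — using Optuna's documented `set_user_attr` / `user_attrs` accessors:

```python
# Illustrative only: stashing a value on an Optuna trial and reading it back
# later from the frozen trial record.
import optuna


def objective(trial: optuna.Trial) -> float:
    trial.suggest_int("early_stopping_rounds", 10, 500)
    best_iteration = 123  # in wavetrainer this would come from CatBoost
    trial.set_user_attr("best_iteration", best_iteration)  # documented setter
    return 0.0


study = optuna.create_study()
study.optimize(objective, n_trials=1)

# FrozenTrial exposes the stored attributes as a plain dict.
print(study.best_trial.user_attrs.get("best_iteration"))
```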
wavetrainer/model/model_router.py
@@ -73,11 +73,11 @@ class ModelRouter(Model):
         model.load(folder)
         self._model = model
 
-    def save(self, folder: str) -> None:
+    def save(self, folder: str, trial: optuna.Trial | optuna.trial.FrozenTrial) -> None:
         model = self._model
         if model is None:
             raise ValueError("model is null")
-        model.save(folder)
+        model.save(folder, trial)
         with open(
             os.path.join(folder, _MODEL_ROUTER_FILE), "w", encoding="utf8"
         ) as handle:
|
@@ -69,7 +69,7 @@ class TabPFNModel(Model):
|
|
69
69
|
params = json.load(handle)
|
70
70
|
self._model_type = ModelType(params[_MODEL_TYPE_KEY])
|
71
71
|
|
72
|
-
def save(self, folder: str) -> None:
|
72
|
+
def save(self, folder: str, trial: optuna.Trial | optuna.trial.FrozenTrial) -> None:
|
73
73
|
with open(os.path.join(folder, _MODEL_FILENAME), "wb") as f:
|
74
74
|
pickle.dump(self._tabpfn, f)
|
75
75
|
with open(
|
wavetrainer/params.py
@@ -14,6 +14,6 @@ class Params:
         """Loads the objects from a folder."""
         raise NotImplementedError("load not implemented in parent class.")
 
-    def save(self, folder: str) -> None:
+    def save(self, folder: str, trial: optuna.Trial | optuna.trial.FrozenTrial) -> None:
         """Saves the objects into a folder."""
         raise NotImplementedError("save not implemented in parent class.")
wavetrainer/reducer/base_selector_reducer.py
@@ -39,7 +39,7 @@ class BaseSelectorReducer(Reducer):
         file_path = os.path.join(folder, self._file_name)
         self._base_selector = joblib.load(file_path)
 
-    def save(self, folder: str) -> None:
+    def save(self, folder: str, trial: optuna.Trial | optuna.trial.FrozenTrial) -> None:
         file_path = os.path.join(folder, self._file_name)
         joblib.dump(self._base_selector, file_path)
 
wavetrainer/reducer/combined_reducer.py
@@ -67,7 +67,7 @@ class CombinedReducer(Reducer):
         for reducer in self._reducers:
             reducer.load(folder)
 
-    def save(self, folder: str) -> None:
+    def save(self, folder: str, trial: optuna.Trial | optuna.trial.FrozenTrial) -> None:
         with open(
             os.path.join(folder, _COMBINED_REDUCER_FILE), "w", encoding="utf8"
         ) as handle:
@@ -78,7 +78,7 @@ class CombinedReducer(Reducer):
                 handle,
             )
         for reducer in self._reducers:
-            reducer.save(folder)
+            reducer.save(folder, trial)
 
     def fit(
         self,
wavetrainer/reducer/pca_reducer.py
@@ -45,7 +45,7 @@ class PCAReducer(Reducer):
         if os.path.exists(pca_file):
             self._pca = joblib.load(pca_file)
 
-    def save(self, folder: str) -> None:
+    def save(self, folder: str, trial: optuna.Trial | optuna.trial.FrozenTrial) -> None:
         if self._scaler is not None:
             joblib.dump(self._scaler, os.path.join(folder, _PCA_SCALER_FILE))
         if self._pca is not None:
wavetrainer/reducer/unseen_reducer.py
@@ -34,7 +34,7 @@ class UnseenReducer(Reducer):
         ) as handle:
             self._seen_features = json.load(handle)
 
-    def save(self, folder: str) -> None:
+    def save(self, folder: str, trial: optuna.Trial | optuna.trial.FrozenTrial) -> None:
         with open(
             os.path.join(folder, _UNSEEN_REDUCER_FILE), "w", encoding="utf8"
         ) as handle:
wavetrainer/selector/selector.py
@@ -38,7 +38,7 @@ class Selector(Params, Fit):
     def load(self, folder: str) -> None:
         self._selector = joblib.load(os.path.join(folder, _SELECTOR_FILE))
 
-    def save(self, folder: str) -> None:
+    def save(self, folder: str, trial: optuna.Trial | optuna.trial.FrozenTrial) -> None:
         joblib.dump(self._selector, os.path.join(folder, _SELECTOR_FILE))
 
     def fit(
wavetrainer/trainer.py
@@ -258,12 +258,12 @@ class Trainer(Fit):
             )
             if not os.path.exists(folder):
                 os.mkdir(folder)
-            windower.save(folder)
-            reducer.save(folder)
-            weights.save(folder)
-            model.save(folder)
-            selector.save(folder)
-            calibrator.save(folder)
+            windower.save(folder, trial)
+            reducer.save(folder, trial)
+            weights.save(folder, trial)
+            model.save(folder, trial)
+            selector.save(folder, trial)
+            calibrator.save(folder, trial)
 
             y_pred = model.transform(x_test)
             y_pred = calibrator.transform(y_pred)
@@ -380,6 +380,7 @@ class Trainer(Fit):
 
     def transform(self, df: pd.DataFrame) -> pd.DataFrame:
         """Predict the expected values of the data."""
+        tqdm.tqdm.pandas(desc="Inferring...")
         input_df = df.copy()
         df = df.reindex(sorted(df.columns), axis=1)
         feature_columns = df.columns.values
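On the `tqdm.tqdm.pandas(desc="Inferring...")` line added to `Trainer.transform`: a minimal sketch (illustrative, with made-up data) of what that registration enables — pandas gains `progress_apply` / `progress_map`, which behave like `apply` / `map` but render a progress bar labelled with the given description.

```python
# Illustrative only: tqdm's pandas integration.
import pandas as pd
import tqdm

tqdm.tqdm.pandas(desc="Inferring...")

df = pd.DataFrame({"a": range(1000)})
doubled = df["a"].progress_apply(lambda v: v * 2)  # shows an "Inferring..." bar
```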
wavetrainer/weights/combined_weights.py
@@ -31,9 +31,9 @@ class CombinedWeights(Weights):
         for weights in self._weights:
             weights.load(folder)
 
-    def save(self, folder: str) -> None:
+    def save(self, folder: str, trial: optuna.Trial | optuna.trial.FrozenTrial) -> None:
         for weights in self._weights:
-            weights.save(folder)
+            weights.save(folder, trial)
 
     def fit(
         self,
wavetrainer/weights/weights_router.py
@@ -54,10 +54,11 @@ class WeightsRouter(Weights):
         weights = _WEIGHTS[params[_WEIGHTS_KEY]]()
         self._weights = weights
 
-    def save(self, folder: str) -> None:
+    def save(self, folder: str, trial: optuna.Trial | optuna.trial.FrozenTrial) -> None:
         weights = self._weights
         if weights is None:
             raise ValueError("weights is null")
+        weights.save(folder, trial)
         with open(
             os.path.join(folder, _WEIGHTS_ROUTER_FILE), "w", encoding="utf8"
         ) as handle:
wavetrainer/windower/windower.py
@@ -36,7 +36,7 @@ class Windower(Params, Fit):
             params = json.load(handle)
             self._lookback = params[_LOOKBACK_KEY]
 
-    def save(self, folder: str) -> None:
+    def save(self, folder: str, trial: optuna.Trial | optuna.trial.FrozenTrial) -> None:
         with open(os.path.join(folder, _WINDOWER_FILE), "w", encoding="utf8") as handle:
             json.dump(
                 {