wavetrainer 0.0.28__tar.gz → 0.0.29__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {wavetrainer-0.0.28/wavetrainer.egg-info → wavetrainer-0.0.29}/PKG-INFO +1 -1
- {wavetrainer-0.0.28 → wavetrainer-0.0.29}/setup.py +1 -1
- {wavetrainer-0.0.28 → wavetrainer-0.0.29}/wavetrainer/__init__.py +1 -1
- {wavetrainer-0.0.28 → wavetrainer-0.0.29}/wavetrainer/trainer.py +34 -11
- {wavetrainer-0.0.28 → wavetrainer-0.0.29/wavetrainer.egg-info}/PKG-INFO +1 -1
- {wavetrainer-0.0.28 → wavetrainer-0.0.29}/LICENSE +0 -0
- {wavetrainer-0.0.28 → wavetrainer-0.0.29}/MANIFEST.in +0 -0
- {wavetrainer-0.0.28 → wavetrainer-0.0.29}/README.md +0 -0
- {wavetrainer-0.0.28 → wavetrainer-0.0.29}/requirements.txt +0 -0
- {wavetrainer-0.0.28 → wavetrainer-0.0.29}/setup.cfg +0 -0
- {wavetrainer-0.0.28 → wavetrainer-0.0.29}/tests/__init__.py +0 -0
- {wavetrainer-0.0.28 → wavetrainer-0.0.29}/tests/model/__init__.py +0 -0
- {wavetrainer-0.0.28 → wavetrainer-0.0.29}/tests/model/catboost_kwargs_test.py +0 -0
- {wavetrainer-0.0.28 → wavetrainer-0.0.29}/tests/trainer_test.py +0 -0
- {wavetrainer-0.0.28 → wavetrainer-0.0.29}/wavetrainer/calibrator/__init__.py +0 -0
- {wavetrainer-0.0.28 → wavetrainer-0.0.29}/wavetrainer/calibrator/calibrator.py +0 -0
- {wavetrainer-0.0.28 → wavetrainer-0.0.29}/wavetrainer/calibrator/calibrator_router.py +0 -0
- {wavetrainer-0.0.28 → wavetrainer-0.0.29}/wavetrainer/calibrator/mapie_calibrator.py +0 -0
- {wavetrainer-0.0.28 → wavetrainer-0.0.29}/wavetrainer/calibrator/vennabers_calibrator.py +0 -0
- {wavetrainer-0.0.28 → wavetrainer-0.0.29}/wavetrainer/create.py +0 -0
- {wavetrainer-0.0.28 → wavetrainer-0.0.29}/wavetrainer/exceptions.py +0 -0
- {wavetrainer-0.0.28 → wavetrainer-0.0.29}/wavetrainer/fit.py +0 -0
- {wavetrainer-0.0.28 → wavetrainer-0.0.29}/wavetrainer/model/__init__.py +0 -0
- {wavetrainer-0.0.28 → wavetrainer-0.0.29}/wavetrainer/model/catboost_classifier_wrap.py +0 -0
- {wavetrainer-0.0.28 → wavetrainer-0.0.29}/wavetrainer/model/catboost_kwargs.py +0 -0
- {wavetrainer-0.0.28 → wavetrainer-0.0.29}/wavetrainer/model/catboost_model.py +0 -0
- {wavetrainer-0.0.28 → wavetrainer-0.0.29}/wavetrainer/model/catboost_regressor_wrap.py +0 -0
- {wavetrainer-0.0.28 → wavetrainer-0.0.29}/wavetrainer/model/model.py +0 -0
- {wavetrainer-0.0.28 → wavetrainer-0.0.29}/wavetrainer/model/model_router.py +0 -0
- {wavetrainer-0.0.28 → wavetrainer-0.0.29}/wavetrainer/model/tabpfn_model.py +0 -0
- {wavetrainer-0.0.28 → wavetrainer-0.0.29}/wavetrainer/model_type.py +0 -0
- {wavetrainer-0.0.28 → wavetrainer-0.0.29}/wavetrainer/params.py +0 -0
- {wavetrainer-0.0.28 → wavetrainer-0.0.29}/wavetrainer/reducer/__init__.py +0 -0
- {wavetrainer-0.0.28 → wavetrainer-0.0.29}/wavetrainer/reducer/base_selector_reducer.py +0 -0
- {wavetrainer-0.0.28 → wavetrainer-0.0.29}/wavetrainer/reducer/combined_reducer.py +0 -0
- {wavetrainer-0.0.28 → wavetrainer-0.0.29}/wavetrainer/reducer/constant_reducer.py +0 -0
- {wavetrainer-0.0.28 → wavetrainer-0.0.29}/wavetrainer/reducer/correlation_reducer.py +0 -0
- {wavetrainer-0.0.28 → wavetrainer-0.0.29}/wavetrainer/reducer/duplicate_reducer.py +0 -0
- {wavetrainer-0.0.28 → wavetrainer-0.0.29}/wavetrainer/reducer/nonnumeric_reducer.py +0 -0
- {wavetrainer-0.0.28 → wavetrainer-0.0.29}/wavetrainer/reducer/reducer.py +0 -0
- {wavetrainer-0.0.28 → wavetrainer-0.0.29}/wavetrainer/reducer/smart_correlation_reducer.py +0 -0
- {wavetrainer-0.0.28 → wavetrainer-0.0.29}/wavetrainer/reducer/unseen_reducer.py +0 -0
- {wavetrainer-0.0.28 → wavetrainer-0.0.29}/wavetrainer/selector/__init__.py +0 -0
- {wavetrainer-0.0.28 → wavetrainer-0.0.29}/wavetrainer/selector/selector.py +0 -0
- {wavetrainer-0.0.28 → wavetrainer-0.0.29}/wavetrainer/weights/__init__.py +0 -0
- {wavetrainer-0.0.28 → wavetrainer-0.0.29}/wavetrainer/weights/class_weights.py +0 -0
- {wavetrainer-0.0.28 → wavetrainer-0.0.29}/wavetrainer/weights/combined_weights.py +0 -0
- {wavetrainer-0.0.28 → wavetrainer-0.0.29}/wavetrainer/weights/exponential_weights.py +0 -0
- {wavetrainer-0.0.28 → wavetrainer-0.0.29}/wavetrainer/weights/linear_weights.py +0 -0
- {wavetrainer-0.0.28 → wavetrainer-0.0.29}/wavetrainer/weights/noop_weights.py +0 -0
- {wavetrainer-0.0.28 → wavetrainer-0.0.29}/wavetrainer/weights/sigmoid_weights.py +0 -0
- {wavetrainer-0.0.28 → wavetrainer-0.0.29}/wavetrainer/weights/weights.py +0 -0
- {wavetrainer-0.0.28 → wavetrainer-0.0.29}/wavetrainer/weights/weights_router.py +0 -0
- {wavetrainer-0.0.28 → wavetrainer-0.0.29}/wavetrainer/windower/__init__.py +0 -0
- {wavetrainer-0.0.28 → wavetrainer-0.0.29}/wavetrainer/windower/windower.py +0 -0
- {wavetrainer-0.0.28 → wavetrainer-0.0.29}/wavetrainer.egg-info/SOURCES.txt +0 -0
- {wavetrainer-0.0.28 → wavetrainer-0.0.29}/wavetrainer.egg-info/dependency_links.txt +0 -0
- {wavetrainer-0.0.28 → wavetrainer-0.0.29}/wavetrainer.egg-info/not-zip-safe +0 -0
- {wavetrainer-0.0.28 → wavetrainer-0.0.29}/wavetrainer.egg-info/requires.txt +0 -0
- {wavetrainer-0.0.28 → wavetrainer-0.0.29}/wavetrainer.egg-info/top_level.txt +0 -0
@@ -23,7 +23,7 @@ def install_requires() -> typing.List[str]:
|
|
23
23
|
|
24
24
|
setup(
|
25
25
|
name='wavetrainer',
|
26
|
-
version='0.0.
|
26
|
+
version='0.0.29',
|
27
27
|
description='A library for automatically finding the optimal model within feature and hyperparameter space.',
|
28
28
|
long_description=long_description,
|
29
29
|
long_description_content_type='text/markdown',
|
@@ -28,6 +28,7 @@ from .windower.windower import Windower
|
|
28
28
|
_SAMPLER_FILENAME = "sampler.pkl"
|
29
29
|
_STUDYDB_FILENAME = "study.db"
|
30
30
|
_PARAMS_FILENAME = "params.json"
|
31
|
+
_TRIAL_FILENAME = "trial.json"
|
31
32
|
_TRIALS_KEY = "trials"
|
32
33
|
_WALKFORWARD_TIMEDELTA_KEY = "walkforward_timedelta"
|
33
34
|
_DAYS_KEY = "days"
|
@@ -198,6 +199,20 @@ class Trainer(Fit):
|
|
198
199
|
) -> float:
|
199
200
|
print(f"Beginning trial for: {split_idx.isoformat()}")
|
200
201
|
trial.set_user_attr(_IDX_USR_ATTR_KEY, split_idx.isoformat())
|
202
|
+
folder = os.path.join(
|
203
|
+
self._folder, str(y_series.name), split_idx.isoformat()
|
204
|
+
)
|
205
|
+
os.makedirs(folder, exist_ok=True)
|
206
|
+
trial_file = os.path.join(folder, _TRIAL_FILENAME)
|
207
|
+
if os.path.exists(trial_file):
|
208
|
+
with open(trial_file, encoding="utf8") as handle:
|
209
|
+
trial_info = json.load(handle)
|
210
|
+
if trial_info["number"] == trial.number:
|
211
|
+
logging.info(
|
212
|
+
"Found trial %d previously executed, skipping...",
|
213
|
+
trial.number,
|
214
|
+
)
|
215
|
+
return trial_info["output"]
|
201
216
|
|
202
217
|
train_dt_index = dt_index[: len(x)]
|
203
218
|
x_train = x[train_dt_index < split_idx] # type: ignore
|
@@ -247,24 +262,32 @@ class Trainer(Fit):
|
|
247
262
|
calibrator.set_options(trial, x)
|
248
263
|
calibrator.fit(x_pred, y=y_train)
|
249
264
|
|
265
|
+
# Output
|
266
|
+
y_pred = model.transform(x_test)
|
267
|
+
y_pred = calibrator.transform(y_pred)
|
268
|
+
output = 0.0
|
269
|
+
if determine_model_type(y_series) == ModelType.REGRESSION:
|
270
|
+
output = float(r2_score(y_test, y_pred[[PREDICTION_COLUMN]]))
|
271
|
+
else:
|
272
|
+
output = float(f1_score(y_test, y_pred[[PREDICTION_COLUMN]]))
|
273
|
+
|
250
274
|
if save:
|
251
|
-
folder = os.path.join(
|
252
|
-
self._folder, str(y_series.name), split_idx.isoformat()
|
253
|
-
)
|
254
|
-
if not os.path.exists(folder):
|
255
|
-
os.mkdir(folder)
|
256
275
|
windower.save(folder, trial)
|
257
276
|
reducer.save(folder, trial)
|
258
277
|
weights.save(folder, trial)
|
259
278
|
model.save(folder, trial)
|
260
279
|
selector.save(folder, trial)
|
261
280
|
calibrator.save(folder, trial)
|
262
|
-
|
263
|
-
|
264
|
-
|
265
|
-
|
266
|
-
|
267
|
-
|
281
|
+
with open(trial_file, "w", encoding="utf8") as handle:
|
282
|
+
json.dump(
|
283
|
+
{
|
284
|
+
"number": trial.number,
|
285
|
+
"output": output,
|
286
|
+
},
|
287
|
+
handle,
|
288
|
+
)
|
289
|
+
|
290
|
+
return output
|
268
291
|
except WavetrainException as exc:
|
269
292
|
logging.warning(str(exc))
|
270
293
|
return -1.0
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|