pynnlf 0.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pynnlf/__about__.py +1 -0
- pynnlf/__init__.py +5 -0
- pynnlf/api.py +17 -0
- pynnlf/discovery.py +63 -0
- pynnlf/engine.py +1238 -0
- pynnlf/hyperparams.py +38 -0
- pynnlf/model_utils.py +186 -0
- pynnlf/runner.py +108 -0
- pynnlf/scaffold/README_WORKSPACE.md +0 -0
- pynnlf/scaffold/data/README_data.md +40 -0
- pynnlf/scaffold/data/ds0_test.csv +4081 -0
- pynnlf/scaffold/models/README_models.md +61 -0
- pynnlf/scaffold/models/hyperparameters.yaml +264 -0
- pynnlf/scaffold/models/m10_rf.py +65 -0
- pynnlf/scaffold/models/m11_svr.py +53 -0
- pynnlf/scaffold/models/m12_rnn.py +152 -0
- pynnlf/scaffold/models/m13_lstm.py +208 -0
- pynnlf/scaffold/models/m14_gru.py +139 -0
- pynnlf/scaffold/models/m15_transformer.py +138 -0
- pynnlf/scaffold/models/m16_prophet.py +216 -0
- pynnlf/scaffold/models/m17_xgb.py +66 -0
- pynnlf/scaffold/models/m18_nbeats.py +107 -0
- pynnlf/scaffold/models/m1_naive.py +49 -0
- pynnlf/scaffold/models/m2_snaive.py +49 -0
- pynnlf/scaffold/models/m3_ets.py +133 -0
- pynnlf/scaffold/models/m4_arima.py +123 -0
- pynnlf/scaffold/models/m5_sarima.py +128 -0
- pynnlf/scaffold/models/m6_lr.py +76 -0
- pynnlf/scaffold/models/m7_ann.py +148 -0
- pynnlf/scaffold/models/m8_dnn.py +141 -0
- pynnlf/scaffold/models/m9_rt.py +74 -0
- pynnlf/scaffold/models/mXX_template.py +68 -0
- pynnlf/scaffold/specs/batch.yaml +4 -0
- pynnlf/scaffold/specs/experiment.yaml +4 -0
- pynnlf/scaffold/specs/pynnlf_config.yaml +69 -0
- pynnlf/scaffold/specs/testing_benchmark.csv +613 -0
- pynnlf/scaffold/specs/testing_benchmark_metadata.md +12 -0
- pynnlf/scaffold/specs/tests_ci.yaml +8 -0
- pynnlf/scaffold/specs/tests_full.yaml +23 -0
- pynnlf/tests_runner.py +211 -0
- pynnlf/tools/strip_notebook_artifacts.py +32 -0
- pynnlf/workspace.py +63 -0
- pynnlf/yamlio.py +28 -0
- pynnlf-0.2.2.dist-info/METADATA +168 -0
- pynnlf-0.2.2.dist-info/RECORD +47 -0
- pynnlf-0.2.2.dist-info/WHEEL +5 -0
- pynnlf-0.2.2.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
#!/usr/bin/env python
|
|
2
|
+
# coding: utf-8
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
import pandas as pd
|
|
6
|
+
|
|
7
|
+
def train_model_mXX_template(hyperparameter, train_df_X, train_df_y, forecast_horizon=None):
    """
    Fit the template model and hand back the object used later for forecasting.

    This template does no real fitting. It only reads an additive ``bias`` from
    the hyperparameters so the train -> forecast contract can be exercised end
    to end.

    Notes:
        - Model files must be named with a numeric ID prefix, e.g. ``m19_*.py``.
        - Function names must embed the file stem (the file name without ``.py``).
          Example file ``m19_my_model.py`` must define:
            - ``train_model_m19_my_model(...)``
            - ``produce_forecast_m19_my_model(...)``

    Args:
        hyperparameter (dict): hyperparameters for this run (from
            ``models/hyperparameters.yaml``). Only the optional ``"bias"`` key is
            read; it falls back to 0.0 when absent.
        train_df_X (pd.DataFrame): training-set predictors (unused by this template).
        train_df_y (pd.DataFrame): training-set target with column ``'y'``
            (unused by this template).
        forecast_horizon (int | None): forecast horizon in minutes (optional;
            unused by this template).

    Returns:
        dict: ``{"bias": <float>}``. Any object (dict or class instance) is
        allowed here as long as the matching ``produce_forecast_*`` understands it.
    """
    # Nothing is learned from the data: the "model" is just the configured offset,
    # coerced to float so string/int YAML values behave uniformly.
    return {"bias": float(hyperparameter.get("bias", 0.0))}
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def produce_forecast_mXX_template(model, train_df_X, test_df_X, train_df_y=None, forecast_horizon=None):
    """
    Generate in-sample (train) and out-of-sample (test) forecasts.

    The template is a naive baseline. Each row's forecast is the value of the
    lagged-target predictor column named
    ``f"y_lag_{pd.Timedelta(minutes=forecast_horizon)}m"``, shifted by the
    model's ``bias``.

    Important:
        - Return TWO outputs: ``(train_y_hat, test_y_hat)``.
        - Each output may be either:
            - a ``pd.Series`` indexed like ``train_df_X`` / ``test_df_X``, or
            - a 1D numpy array of matching length.
          The engine aligns/converts it to a Series.

    Args:
        model (object): object returned by ``train_model_*``; here a dict with an
            optional ``"bias"`` key (0.0 when absent).
        train_df_X (pd.DataFrame): training-set predictors.
        test_df_X (pd.DataFrame): test-set predictors.
        train_df_y (pd.DataFrame | None): training-set target (optional; used by
            Prophet-like models, ignored here).
        forecast_horizon (int | None): forecast horizon in minutes. Required by
            this template; it is converted with ``int()`` before use.

    Returns:
        tuple:
            train_y_hat (pd.Series | np.ndarray): forecast for the train rows.
            test_y_hat (pd.Series | np.ndarray): forecast for the test rows.

    Raises:
        ValueError: if ``forecast_horizon`` is None.
        KeyError: if the lag column is missing from ``train_df_X`` or ``test_df_X``.
    """
    if forecast_horizon is None:
        raise ValueError("forecast_horizon must be provided for this template.")

    # The column name embeds the Timedelta's string form (e.g. "0 days 00:30:00").
    # NOTE(review): presumably this matches how the engine names its lag features;
    # confirm against engine.py before changing it.
    lag_column = "y_lag_{}m".format(pd.Timedelta(minutes=int(forecast_horizon)))
    offset = float(model.get("bias", 0.0))

    train_forecast = train_df_X[lag_column] + offset
    test_forecast = test_df_X[lag_column] + offset
    return train_forecast, test_forecast
|
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
paths:
|
|
2
|
+
data_dir: data
|
|
3
|
+
output_dir: experiment_result
|
|
4
|
+
hyperparameters_path: models/hyperparameters.yaml
|
|
5
|
+
|
|
6
|
+
dataset_download:
|
|
7
|
+
base_url: "https://raw.githubusercontent.com/mssamhan31/PyNNLF/main/data/"
|
|
8
|
+
|
|
9
|
+
datasets:
|
|
10
|
+
ds0: ds0_test.csv
|
|
11
|
+
ds1: ds1_ashd.csv
|
|
12
|
+
ds2: ds2_aedp_5min.csv
|
|
13
|
+
ds3: ds3_aedp_30min.csv
|
|
14
|
+
ds4: ds4_ashd_with_weather.csv
|
|
15
|
+
ds5: ds5_aedp_30min_with_weather.csv
|
|
16
|
+
ds6: ds6_aedp_cluster_5min.csv
|
|
17
|
+
ds7: ds7_aedp_cluster_30min.csv
|
|
18
|
+
ds8: ds8_aedp_cluster_30min_with_weather.csv
|
|
19
|
+
ds9: ds9_aedp_cluster2_5min.csv
|
|
20
|
+
ds10: ds10_aedp_cluster2_30min.csv
|
|
21
|
+
ds11: ds11_aedp_cluster2_30min_with_weather.csv
|
|
22
|
+
ds12: ds12_ashd_with_cloud_bom.csv
|
|
23
|
+
ds13: ds13_ashd_with_cloud_solcast.csv
|
|
24
|
+
ds14: ds14_ausgrid_zs_mascot.csv
|
|
25
|
+
ds15: ds15_ausgrid_zs_mascot_30min_with_weather.csv
|
|
26
|
+
|
|
27
|
+
forecast_horizons:
|
|
28
|
+
fh1: 30
|
|
29
|
+
fh2: 60
|
|
30
|
+
fh3: 120
|
|
31
|
+
fh4: 180
|
|
32
|
+
fh5: 240
|
|
33
|
+
fh6: 300
|
|
34
|
+
fh7: 360
|
|
35
|
+
fh8: 1440
|
|
36
|
+
fh9: 2880
|
|
37
|
+
fh10: 10080
|
|
38
|
+
fh11: 43200
|
|
39
|
+
|
|
40
|
+
models:
|
|
41
|
+
m1: m1_naive
|
|
42
|
+
m2: m2_snaive
|
|
43
|
+
m3: m3_ets
|
|
44
|
+
m4: m4_arima
|
|
45
|
+
m5: m5_sarima
|
|
46
|
+
m6: m6_lr
|
|
47
|
+
m7: m7_ann
|
|
48
|
+
m8: m8_dnn
|
|
49
|
+
m9: m9_rt
|
|
50
|
+
m10: m10_rf
|
|
51
|
+
m11: m11_svr
|
|
52
|
+
m12: m12_rnn
|
|
53
|
+
m13: m13_lstm
|
|
54
|
+
m14: m14_gru
|
|
55
|
+
m15: m15_transformer
|
|
56
|
+
m16: m16_prophet
|
|
57
|
+
m17: m17_xgb
|
|
58
|
+
m18: m18_nbeats
|
|
59
|
+
|
|
60
|
+
cv:
|
|
61
|
+
k: 10
|
|
62
|
+
max_lag_day: 7
|
|
63
|
+
|
|
64
|
+
plot:
|
|
65
|
+
enabled: true
|
|
66
|
+
colors:
|
|
67
|
+
dark_blue: "#22303D"
|
|
68
|
+
orange: "#EB932C"
|
|
69
|
+
font_family: Arial
|