PVNet_summation 1.0.1__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pvnet_summation/data/datamodule.py +127 -28
- pvnet_summation/load_model.py +5 -1
- pvnet_summation/models/base_model.py +14 -3
- pvnet_summation/models/horizon_dense_model.py +171 -0
- pvnet_summation/training/train.py +38 -7
- pvnet_summation/utils.py +51 -6
- {pvnet_summation-1.0.1.dist-info → pvnet_summation-1.1.0.dist-info}/METADATA +3 -3
- pvnet_summation-1.1.0.dist-info/RECORD +19 -0
- pvnet_summation-1.0.1.dist-info/RECORD +0 -18
- {pvnet_summation-1.0.1.dist-info → pvnet_summation-1.1.0.dist-info}/WHEEL +0 -0
- {pvnet_summation-1.0.1.dist-info → pvnet_summation-1.1.0.dist-info}/licenses/LICENSE +0 -0
- {pvnet_summation-1.0.1.dist-info → pvnet_summation-1.1.0.dist-info}/top_level.txt +0 -0
pvnet_summation/data/datamodule.py
CHANGED

@@ -1,5 +1,6 @@
 """Pytorch lightning datamodules for loading pre-saved samples and predictions."""
 
+import os
 from glob import glob
 from typing import TypeAlias
 
@@ -7,16 +8,68 @@ import numpy as np
 import pandas as pd
 import torch
 from lightning.pytorch import LightningDataModule
-from ocf_data_sampler.load.gsp import open_gsp
-from ocf_data_sampler.numpy_sample.common_types import NumpyBatch
+from ocf_data_sampler.load.gsp import get_gsp_boundaries, open_gsp
+from ocf_data_sampler.numpy_sample.common_types import NumpyBatch
+from ocf_data_sampler.numpy_sample.sun_position import calculate_azimuth_and_elevation
+from ocf_data_sampler.select.geospatial import osgb_to_lon_lat
 from ocf_data_sampler.torch_datasets.datasets.pvnet_uk import PVNetUKConcurrentDataset
 from ocf_data_sampler.utils import minutes
-from torch.utils.data import DataLoader, Dataset, default_collate
+from torch.utils.data import DataLoader, Dataset, Subset, default_collate
 from typing_extensions import override
 
 SumNumpySample: TypeAlias = dict[str, np.ndarray | NumpyBatch]
 SumTensorBatch: TypeAlias = dict[str, torch.Tensor]
 
+def get_gb_centroid_lon_lat() -> tuple[float, float]:
+    """Get the longitude and latitude of the centroid of Great Britain"""
+    row = get_gsp_boundaries("20250109").loc[0]
+    x_osgb = row.x_osgb.item()
+    y_osgb = row.y_osgb.item()
+    return osgb_to_lon_lat(x_osgb, y_osgb)
+
+LON, LAT = get_gb_centroid_lon_lat()
+
+
+def construct_sample(
+    pvnet_inputs: NumpyBatch,
+    valid_times: pd.DatetimeIndex,
+    relative_capacities: np.ndarray,
+    target: np.ndarray | None,
+    last_outturn: float | None = None,
+) -> SumNumpySample:
+    """Construct an input sample for the summation model
+
+    Args:
+        pvnet_inputs: The PVNet batch for all GSPs
+        valid_times: An array of valid times for the forecast
+        relative_capacities: Array of capacities of all GSPs normalised by the total capacity
+        target: The target national outturn. This is only needed during training.
+        last_outturn: The previous national outturn. This is only needed during training.
+    """
+
+    azimuth, elevation = calculate_azimuth_and_elevation(valid_times, LON, LAT)
+
+    sample = {
+        # NumpyBatch object with batch size = num_locations
+        "pvnet_inputs": pvnet_inputs,
+        # Shape: [time]
+        "valid_times": valid_times.values.astype(int),
+        # Shape: [num_locations]
+        "relative_capacity": relative_capacities,
+        # Shape: [time]
+        "azimuth": azimuth.astype(np.float32) / 360,
+        # Shape: [time]
+        "elevation": elevation.astype(np.float32) / 180 + 0.5,
+    }
+
+    if target is not None:
+        # Shape: [time]
+        sample["target"] = target
+    if last_outturn is not None:
+        # Shape: scalar
+        sample["last_outturn"] = last_outturn
+    return sample
+
 
 class StreamedDataset(PVNetUKConcurrentDataset):
     """A torch dataset for creating concurrent PVNet inputs and national targets."""
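Note: the solar-position features added by construct_sample above are plain degree values rescaled into roughly the unit interval. A minimal sketch with made-up numbers:

    # Hypothetical values: azimuth spans 0-360 degrees, elevation -90 to +90 degrees
    azimuth_deg = 180.0
    elevation_deg = 30.0

    azimuth_feature = azimuth_deg / 360            # 0.5
    elevation_feature = elevation_deg / 180 + 0.5  # ~0.67, mapped into [0, 1]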
@@ -37,7 +90,7 @@ class StreamedDataset(PVNetUKConcurrentDataset):
         super().__init__(config_filename, start_time, end_time, gsp_ids=None)
 
         # Load and normalise the national GSP data to use as target values
-        national_gsp_data = (
+        self.national_gsp_data = (
             open_gsp(
                 zarr_path=self.config.input_data.gsp.zarr_path,
                 boundaries_version=self.config.input_data.gsp.boundaries_version
@@ -45,8 +98,6 @@ class StreamedDataset(PVNetUKConcurrentDataset):
             .sel(gsp_id=0)
             .compute()
         )
-        self.national_gsp_data = national_gsp_data / national_gsp_data.effective_capacity_mwp
-
 
     def _get_sample(self, t0: pd.Timestamp) -> SumNumpySample:
         """Generate a concurrent PVNet sample for given init-time.
@@ -55,33 +106,32 @@ class StreamedDataset(PVNetUKConcurrentDataset):
            t0: init-time for sample
        """
 
-
-
-        location_capacities = pvnet_inputs["gsp_effective_capacity_mwp"]
+        # Get the PVNet input batch
+        pvnet_inputs: NumpyBatch = super()._get_sample(t0)
 
+        # Construct an array of valid times for each forecast horizon
         valid_times = pd.date_range(
             t0+minutes(self.config.input_data.gsp.time_resolution_minutes),
             t0+minutes(self.config.input_data.gsp.interval_end_minutes),
             freq=minutes(self.config.input_data.gsp.time_resolution_minutes)
         )
 
-
+        # Get the GSP and national capacities
+        location_capacities = pvnet_inputs["gsp_effective_capacity_mwp"]
         total_capacity = self.national_gsp_data.sel(time_utc=t0).effective_capacity_mwp.item()
-
+
+        # Calculate required inputs for the sample
         relative_capacities = location_capacities / total_capacity
-
-
-
-
-
-
-
-
-
-
-            # Shape: [num_locations]
-            "relative_capacity": relative_capacities,
-        }
+        target = self.national_gsp_data.sel(time_utc=valid_times).values / total_capacity
+        last_outturn = self.national_gsp_data.sel(time_utc=t0).values / total_capacity
+
+        return construct_sample(
+            pvnet_inputs=pvnet_inputs,
+            valid_times=valid_times,
+            relative_capacities=relative_capacities,
+            target=target,
+            last_outturn=last_outturn,
+        )
 
     @override
     def __getitem__(self, idx: int) -> SumNumpySample:
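Note: _get_sample above normalises everything by the national effective capacity, so the model works with capacity factors rather than MW. A small sketch with invented numbers:

    import numpy as np

    location_capacities = np.array([120.0, 80.0, 200.0])  # MW per GSP (made up)
    total_capacity = 400.0                                 # national effective capacity, MW

    relative_capacities = location_capacities / total_capacity  # [0.3, 0.2, 0.5]

    national_outturn = np.array([100.0, 140.0])  # MW at two valid times (made up)
    target = national_outturn / total_capacity   # [0.25, 0.35], i.e. capacity factors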
@@ -103,6 +153,8 @@ class StreamedDataModule(LightningDataModule):
         num_workers: int = 0,
         prefetch_factor: int | None = None,
         persistent_workers: bool = False,
+        seed: int | None = None,
+        dataset_pickle_dir: str | None = None,
     ):
         """Datamodule for creating concurrent PVNet inputs and national targets.
 
@@ -115,11 +167,16 @@ class StreamedDataModule(LightningDataModule):
             persistent_workers: If True, the data loader will not shut down the worker processes
                 after a dataset has been consumed once. This allows to maintain the workers Dataset
                 instances alive.
+            seed: Random seed used in shuffling datasets.
+            dataset_pickle_dir: Directory in which the val and train set will be presaved as
+                pickle objects. Setting this speeds up instantiation of multiple workers a lot.
        """
         super().__init__()
         self.configuration = configuration
         self.train_period = train_period
         self.val_period = val_period
+        self.seed = seed
+        self.dataset_pickle_dir = dataset_pickle_dir
 
         self._dataloader_kwargs = dict(
             batch_size=None,
@@ -132,17 +189,58 @@ class StreamedDataModule(LightningDataModule):
             worker_init_fn=None,
             prefetch_factor=prefetch_factor,
             persistent_workers=persistent_workers,
+            multiprocessing_context="spawn" if num_workers > 0 else None,
         )
 
+    def setup(self, stage: str | None = None):
+        """Called once to prepare the datasets."""
+
+        # This logic runs only once at the start of training, therefore the val dataset is only
+        # shuffled once
+        if self.dataset_pickle_dir is not None:
+            os.makedirs(self.dataset_pickle_dir, exist_ok=True)
+
+            train_dataset_path = f"{self.dataset_pickle_dir}/train_dataset.pkl"
+            val_dataset_path = f"{self.dataset_pickle_dir}/val_dataset.pkl"
+
+            # For safety, these pickled datasets cannot be overwritten.
+            # See: https://github.com/openclimatefix/pvnet/pull/445
+            for path in [train_dataset_path, val_dataset_path]:
+                if os.path.exists(path):
+                    raise FileExistsError(
+                        f"The pickled dataset path '{path}' already exists. Make sure that "
+                        "this can be safely deleted (i.e. not currently being used by any "
+                        "training run) and delete it manually. Else change the "
+                        "`dataset_pickle_dir` to a different directory."
+                    )
+
+        # Prepare the train dataset
+        self.train_dataset = StreamedDataset(self.configuration, *self.train_period)
+
+        # Prepare and pre-shuffle the val dataset and set seed for reproducibility
+        val_dataset = StreamedDataset(self.configuration, *self.val_period)
+        shuffled_indices = np.random.default_rng(seed=self.seed).permutation(len(val_dataset))
+        self.val_dataset = Subset(val_dataset, shuffled_indices)
+
+        if self.dataset_pickle_dir is not None:
+            self.train_dataset.presave_pickle(train_dataset_path)
+            self.train_dataset.presave_pickle(val_dataset_path)
+
+    def teardown(self, stage: str | None = None) -> None:
+        """Clean up the pickled datasets"""
+        if self.dataset_pickle_dir is not None:
+            for filename in ["val_dataset.pkl", "train_dataset.pkl"]:
+                filepath = f"{self.dataset_pickle_dir}/{filename}"
+                if os.path.exists(filepath):
+                    os.remove(filepath)
+
     def train_dataloader(self, shuffle: bool = False) -> DataLoader:
         """Construct train dataloader"""
-
-        return DataLoader(dataset, shuffle=shuffle, **self._dataloader_kwargs)
+        return DataLoader(self.train_dataset, shuffle=shuffle, **self._dataloader_kwargs)
 
     def val_dataloader(self, shuffle: bool = False) -> DataLoader:
         """Construct val dataloader"""
-
-        return DataLoader(dataset, shuffle=shuffle, **self._dataloader_kwargs)
+        return DataLoader(self.val_dataset, shuffle=shuffle, **self._dataloader_kwargs)
 
 
 class PresavedDataset(Dataset):
@@ -200,6 +298,7 @@ class PresavedDataModule(LightningDataModule):
             worker_init_fn=None,
             prefetch_factor=prefetch_factor,
             persistent_workers=persistent_workers,
+            multiprocessing_context="spawn" if num_workers > 0 else None,
         )
 
     def train_dataloader(self, shuffle: bool = True) -> DataLoader:
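Note: putting the new datamodule options together, usage might look like the sketch below; the config path, date ranges, and pickle directory are all hypothetical:

    datamodule = StreamedDataModule(
        configuration="pvnet_data_config.yaml",
        train_period=["2020-01-01", "2022-12-31"],
        val_period=["2023-01-01", "2023-12-31"],
        num_workers=4,
        seed=42,                               # fixes the one-off shuffle of the val set
        dataset_pickle_dir="/tmp/ds_pickles",  # must not already hold pickled datasets
    )
    datamodule.setup()
    train_dl = datamodule.train_dataloader(shuffle=True)
    val_dl = datamodule.val_dataloader()
    # ... iterate over the dataloaders ...
    datamodule.teardown()  # removes the pickled datasets again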
pvnet_summation/load_model.py
CHANGED

@@ -20,10 +20,14 @@ def get_model_from_checkpoints(
 ) -> tuple[torch.nn.Module, dict, str | None, str | None]:
     """Load a model from its checkpoint directory
 
+    Args:
+        checkpoint_dir_path: str path to the directory with the model files
+        val_best (optional): if True, load the best epoch model; otherwise, load the last
+
     Returns:
         tuple:
             model: nn.Module of pretrained model.
-            model_config:
+            model_config: dict of model config used to train the model.
             datamodule_config: path to datamodule used to create samples e.g. train/test split info.
             experiment_configs: path to the full experimental config.
 
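Note: based on the signature and docstring above, a call might look like this hypothetical sketch (the checkpoint directory is made up):

    model, model_config, datamodule_config, experiment_config = get_model_from_checkpoints(
        checkpoint_dir_path="checkpoints/summation_run_01",
        val_best=True,  # load the best-epoch weights rather than the last epoch
    )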
pvnet_summation/models/base_model.py
CHANGED

@@ -4,6 +4,7 @@ import os
 import shutil
 import time
 from importlib.metadata import version
+from math import prod
 from pathlib import Path
 
 import hydra
@@ -293,6 +294,12 @@ class BaseModel(torch.nn.Module, HuggingfaceMixin):
         """
         super().__init__()
 
+        if (output_quantiles is not None):
+            if output_quantiles != sorted(output_quantiles):
+                raise ValueError("output_quantiles should be in ascending order")
+            if 0.5 not in output_quantiles:
+                raise ValueError("Quantiles must include 0.5")
+
         self.output_quantiles = output_quantiles
 
         self.num_input_locations = num_input_locations
@@ -309,17 +316,21 @@ class BaseModel(torch.nn.Module, HuggingfaceMixin):
         # Store whether the model should use quantile regression or simply predict the mean
         self.use_quantile_regression = self.output_quantiles is not None
 
-        #
+        # Also store the final output shape
         if self.use_quantile_regression:
-            self.
+            self.output_shape = (self.forecast_len, len(input_quantiles))
         else:
-            self.
+            self.output_shape = (self.forecast_len,)
+
+        # Store the number of output features that the model should predict
+        self.num_output_features = prod(self.output_shape)
 
         # Store the expected input shape
         if input_quantiles is None:
             self.input_shape = (self.num_input_locations, self.forecast_len)
         else:
             self.input_shape = (self.num_input_locations, self.forecast_len, len(input_quantiles))
+
 
     def _quantiles_to_prediction(self, y_quantiles: torch.Tensor) -> torch.Tensor:
         """Convert network prediction into a point prediction.
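Note: the new validation in BaseModel rejects unsorted quantile lists and lists missing the median, and num_output_features is simply the product of the output shape. A sketch with assumed values:

    from math import prod

    output_quantiles = [0.1, 0.25, 0.5, 0.75, 0.9]  # ascending and contains 0.5 -> accepted
    # [0.5, 0.1, 0.9] would raise ValueError: output_quantiles should be in ascending order
    # [0.1, 0.9]      would raise ValueError: Quantiles must include 0.5

    forecast_len = 16  # made-up forecast length
    output_shape = (forecast_len, len(output_quantiles))  # (16, 5)
    num_output_features = prod(output_shape)              # 80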
pvnet_summation/models/horizon_dense_model.py
ADDED

@@ -0,0 +1,171 @@
+"""Neural network architecture based on dense layers applied independently at each horizon"""
+
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from pvnet_summation.data.datamodule import SumTensorBatch
+from pvnet_summation.models.base_model import BaseModel
+
+
+class HorizonDenseModel(BaseModel):
+    """Neural network architecture based on dense layers applied independently at each horizon.
+    """
+
+    def __init__(
+        self,
+        output_quantiles: list[float] | None,
+        num_input_locations: int,
+        input_quantiles: list[float] | None,
+        history_minutes: int,
+        forecast_minutes: int,
+        interval_minutes: int,
+        output_network: torch.nn.Module,
+        predict_difference_from_sum: bool = False,
+        use_horizon_encoding: bool = False,
+        use_solar_position: bool = False,
+        force_non_crossing: bool = False,
+        beta: float = 3,
+    ):
+        """Neural network architecture based on dense layers applied independently at each horizon.
+
+        Args:
+            output_quantiles: A list of float (0.0, 1.0) quantiles to predict values for. If set to
+                None the output is a single value.
+            num_input_locations: The number of input locations (e.g. number of GSPs)
+            input_quantiles: A list of float (0.0, 1.0) quantiles which PVNet predicts for. If set
+                to None we assume PVNet predicts a single value
+            history_minutes (int): Length of the GSP history period in minutes
+            forecast_minutes (int): Length of the GSP forecast period in minutes
+            interval_minutes: The interval in minutes between each timestep in the data
+            output_network: A partially instantiated pytorch Module class used to predict the
+                outturn at each horizon.
+            predict_difference_from_sum: Whether to predict the difference from the sum of
+                locations, else the total is predicted directly
+            use_horizon_encoding: Whether to use the forecast horizon as an input feature
+            use_solar_position: Whether to use the solar coordinates as input features
+            force_non_crossing: If predicting quantiles, whether to predict the quantiles other
+                than the median by predicting the distance between them and integrating.
+            beta: If using force_non_crossing, the beta value to use in the softplus activation
+        """
+
+        super().__init__(
+            output_quantiles,
+            num_input_locations,
+            input_quantiles,
+            history_minutes,
+            forecast_minutes,
+            interval_minutes,
+        )
+
+        if force_non_crossing:
+            assert self.use_quantile_regression
+
+        self.use_horizon_encoding = use_horizon_encoding
+        self.predict_difference_from_sum = predict_difference_from_sum
+        self.force_non_crossing = force_non_crossing
+        self.beta = beta
+        self.use_solar_position = use_solar_position
+
+        in_features = 1 if self.input_quantiles is None else len(self.input_quantiles)
+        in_features = in_features * self.num_input_locations
+
+        if use_horizon_encoding:
+            in_features += 1
+
+        if use_solar_position:
+            in_features += 2
+
+        out_features = (len(self.output_quantiles) if self.use_quantile_regression else 1)
+
+        model = output_network(in_features=in_features, out_features=out_features)
+
+        # Add linear layer if predicting difference from sum
+        # - This allows difference to be positive or negative
+        # Also add linear layer if we are applying force_non_crossing since a softplus will be used
+        if predict_difference_from_sum or force_non_crossing:
+            model = nn.Sequential(
+                model,
+                nn.Linear(out_features, out_features),
+            )
+
+        self.model = model
+
+
+    def forward(self, x: SumTensorBatch) -> torch.Tensor:
+        """Run model forward"""
+
+        # x["pvnet_outputs"] has shape [batch, locs, horizon, (quantile)]
+        batch_size = x["pvnet_outputs"].shape[0]
+        x_in = torch.swapaxes(x["pvnet_outputs"], 1, 2)  # -> [batch, horizon, locs, (quantile)]
+        x_in = torch.flatten(x_in, start_dim=2)  # -> [batch, horizon, locs*(quantile)]
+
+        if self.use_horizon_encoding:
+            horizon_encoding = torch.linspace(
+                start=0,
+                end=1,
+                steps=self.forecast_len,
+                device=x_in.device,
+                dtype=x_in.dtype,
+            )
+            horizon_encoding = horizon_encoding.tile((batch_size, 1)).unsqueeze(-1)
+            x_in = torch.cat([x_in, horizon_encoding], dim=2)
+
+        if self.use_solar_position:
+            x_in = torch.cat(
+                [x_in, x["azimuth"].unsqueeze(-1), x["elevation"].unsqueeze(-1)],
+                dim=2
+            )
+
+        x_in = torch.flatten(x_in, start_dim=0, end_dim=1)  # -> [batch*horizon, features]
+
+        out = self.model(x_in)
+        out = out.view(batch_size, *self.output_shape)  # -> [batch, horizon, (quantile)]
+
+        if self.force_non_crossing:
+
+            # Get the prediction of the median
+            idx = self.output_quantiles.index(0.5)
+            if self.predict_difference_from_sum:
+                loc_sum = self.sum_of_locations(x).unsqueeze(-1)
+                y_median = loc_sum + out[..., idx:idx+1]
+            else:
+                y_median = out[..., idx:idx+1]
+
+            # These are the differences between the remaining quantiles
+            dy_below = F.softplus(out[..., :idx], beta=self.beta)
+            dy_above = F.softplus(out[..., idx+1:], beta=self.beta)
+
+            # Find the absolute value of the quantile predictions from the differences
+            y_below = []
+            y = y_median
+            for i in range(dy_below.shape[-1]):
+                # We detach y to avoid the gradients caused by errors from one quantile
+                # prediction flowing back to affect the other quantile predictions.
+                # For example if the 0.9 quantile prediction was too low, we don't want the
+                # gradient to pull the 0.5 quantile prediction higher to compensate.
+                y = y.detach() - dy_below[..., i:i+1]
+                y_below.append(y)
+
+            y_above = []
+            y = y_median
+            for i in range(dy_above.shape[-1]):
+                y = y.detach() + dy_above[..., i:i+1]
+                y_above.append(y)
+
+            # Compile the quantile predictions in the correct order
+            out = torch.cat(y_below[::-1] + [y_median,] + y_above, dim=-1)
+
+        else:
+
+            if self.predict_difference_from_sum:
+                loc_sum = self.sum_of_locations(x)
+
+                if self.use_quantile_regression:
+                    loc_sum = loc_sum.unsqueeze(-1)
+
+                out = loc_sum + out
+
+        # Use leaky relu as a soft clip to 0
+        return F.leaky_relu(out, negative_slope=0.01)
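Note: the force_non_crossing branch above predicts the median directly and turns the remaining raw outputs into positive gaps via softplus, so stacked quantiles cannot cross. A minimal sketch of the idea (loops and detach omitted; the tensors are made up):

    import torch
    import torch.nn.functional as F

    raw = torch.tensor([[0.2, 1.5, -0.3]])  # raw outputs for quantiles [0.1, 0.5, 0.9]
    idx = 1                                  # position of the 0.5 quantile

    y_median = raw[..., idx:idx+1]                   # 1.5
    dy_below = F.softplus(raw[..., :idx], beta=3)    # positive gap below the median
    dy_above = F.softplus(raw[..., idx+1:], beta=3)  # positive gap above the median

    y_10 = y_median - dy_below[..., 0:1]  # guaranteed <= median
    y_90 = y_median + dy_above[..., 0:1]  # guaranteed >= median
    quantiles = torch.cat([y_10, y_median, y_90], dim=-1)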
pvnet_summation/training/train.py
CHANGED

@@ -7,13 +7,21 @@ import torch
 from lightning.pytorch import Callback, Trainer, seed_everything
 from lightning.pytorch.callbacks import ModelCheckpoint
 from lightning.pytorch.loggers import Logger, WandbLogger
-from ocf_data_sampler.torch_datasets.
+from ocf_data_sampler.torch_datasets.utils.torch_batch_utils import (
+    batch_to_tensor,
+    copy_batch_to_device,
+)
 from omegaconf import DictConfig, OmegaConf
 from pvnet.models import BaseModel as PVNetBaseModel
 from tqdm import tqdm
 
 from pvnet_summation.data.datamodule import PresavedDataModule, StreamedDataModule
-from pvnet_summation.utils import
+from pvnet_summation.utils import (
+    DATAMODULE_CONFIG_NAME,
+    FULL_CONFIG_NAME,
+    MODEL_CONFIG_NAME,
+    create_pvnet_model_config,
+)
 
 log = logging.getLogger(__name__)
 
@@ -21,9 +29,8 @@ log = logging.getLogger(__name__)
 def resolve_monitor_loss(output_quantiles: list | None) -> str:
     """Return the desired metric to monitor based on whether quantile regression is being used.
 
-
+    Adds the option to use
         monitor: "${resolve_monitor_loss:${model.model.output_quantiles}}"
-
     in early stopping and model checkpoint callbacks so the callbacks config does not need to be
     modified depending on whether quantile regression is being used or not.
     """
@@ -86,15 +93,33 @@ def train(config: DictConfig) -> None:
     os.makedirs(f"{save_dir}/train")
     os.makedirs(f"{save_dir}/val")
 
+    pvnet_data_config_path = f"{save_dir}/pvnet_data_config.yaml"
+
+    data_source_paths = OmegaConf.to_container(
+        config.datamodule.data_source_paths,
+        resolve=True,
+    )
+
+    create_pvnet_model_config(
+        save_path=pvnet_data_config_path,
+        repo=config.datamodule.pvnet_model.model_id,
+        commit=config.datamodule.pvnet_model.revision,
+        data_source_paths=data_source_paths,
+    )
+
     datamodule = StreamedDataModule(
-        configuration=
+        configuration=pvnet_data_config_path,
         num_workers=config.datamodule.num_workers,
         prefetch_factor=config.datamodule.prefetch_factor,
         train_period=config.datamodule.train_period,
         val_period=config.datamodule.val_period,
         persistent_workers=False,
+        seed=config.datamodule.seed,
+        dataset_pickle_dir=config.datamodule.dataset_pickle_dir,
     )
 
+    datamodule.setup()
+
     for dataloader_func, max_num_samples, split in [
         (datamodule.train_dataloader, config.datamodule.max_num_train_samples, "train",),
         (datamodule.val_dataloader, config.datamodule.max_num_val_samples, "val"),
@@ -103,7 +128,10 @@ def train(config: DictConfig) -> None:
         log.info(f"Saving {split} outputs")
         dataloader = dataloader_func(shuffle=True)
 
-
+        # If max_num_samples is set to None, use all samples
+        max_num_samples = max_num_samples or len(dataloader)
+
+        for i, sample in tqdm(zip(range(max_num_samples), dataloader), total=max_num_samples):
             # Run PVNet inputs through model
             x = copy_batch_to_device(batch_to_tensor(sample["pvnet_inputs"]), device)
             pvnet_outputs = pvnet_model(x).detach().cpu()
@@ -116,6 +144,9 @@ def train(config: DictConfig) -> None:
 
         del dataloader
 
+    datamodule.teardown()
+
+
     datamodule = PresavedDataModule(
         sample_dir=save_dir,
         batch_size=config.datamodule.batch_size,
@@ -182,4 +213,4 @@ def train(config: DictConfig) -> None:
     )
 
     # Train the model completely
-    trainer.fit(model=model, datamodule=datamodule)
+    trainer.fit(model=model, datamodule=datamodule)
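Note: the sample-capping idiom above relies on zip() stopping at the shorter iterator, so range(max_num_samples) truncates the dataloader. A toy sketch:

    samples = ["s0", "s1", "s2", "s3"]  # stand-in for a dataloader
    max_num_samples = None              # e.g. config value left unset

    max_num_samples = max_num_samples or len(samples)  # None -> use everything
    for i, sample in zip(range(max_num_samples), samples):
        print(i, sample)  # visits all four samples here; fewer if a cap was set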
pvnet_summation/utils.py
CHANGED

@@ -3,8 +3,10 @@ import logging
 
 import rich.syntax
 import rich.tree
+import yaml
 from lightning.pytorch.utilities import rank_zero_only
 from omegaconf import DictConfig, OmegaConf
+from pvnet.models.base_model import BaseModel as PVNetBaseModel
 
 logger = logging.getLogger(__name__)
 
@@ -17,11 +19,10 @@ MODEL_CARD_NAME = "README.md"
 
 
 
-def
-"""
+def maybe_apply_debug_mode(config: DictConfig) -> None:
+    """Check if debugging run is requested and force debug-friendly configuration
 
-    Controlled by main config file
-    - forcing debug friendly configuration
+    Controlled by main config file
 
     Modifies DictConfig in place.
 
@@ -52,7 +53,7 @@ def run_config_utilities(config: DictConfig) -> None:
 @rank_zero_only
 def print_config(
     config: DictConfig,
-    fields: tuple[str] = (
+    fields: tuple[str, ...] = (
         "trainer",
         "model",
         "datamodule",
@@ -66,7 +67,7 @@ def print_config(
 
     Args:
         config (DictConfig): Configuration composed by Hydra.
-        fields (
+        fields (tuple[str, ...], optional): Determines which main fields from config will
             be printed and in what order.
         resolve (bool, optional): Whether to resolve reference fields of DictConfig.
     """
@@ -85,3 +86,47 @@ def print_config(
         branch.add(rich.syntax.Syntax(branch_content, "yaml"))
 
     rich.print(tree)
+
+def populate_config_with_data_data_filepaths(config: dict, data_source_paths: dict) -> dict:
+    """Populate the data source filepaths in the config
+
+    Args:
+        config: The data config
+        data_source_paths: A dictionary of data paths for the different input sources
+    """
+
+    # Replace the GSP data path
+    config["input_data"]["gsp"]["zarr_path"] = data_source_paths["gsp"]
+
+    # Replace satellite data path if using it
+    if "satellite" in config["input_data"]:
+        if config["input_data"]["satellite"]["zarr_path"] != "":
+            config["input_data"]["satellite"]["zarr_path"] = data_source_paths["satellite"]
+
+    # NWP is nested so must be treated separately
+    if "nwp" in config["input_data"]:
+        nwp_config = config["input_data"]["nwp"]
+        for nwp_source in nwp_config.keys():
+            provider = nwp_config[nwp_source]["provider"]
+            assert provider in data_source_paths["nwp"], f"Missing NWP path: {provider}"
+            nwp_config[nwp_source]["zarr_path"] = data_source_paths["nwp"][provider]
+
+    return config
+
+
+def create_pvnet_model_config(
+    save_path: str,
+    repo: str,
+    commit: str,
+    data_source_paths: dict,
+) -> None:
+    """Create the data config needed to run the PVNet model"""
+    data_config_path = PVNetBaseModel.get_data_config(repo, revision=commit)
+
+    with open(data_config_path) as file:
+        data_config = yaml.load(file, Loader=yaml.FullLoader)
+
+    data_config = populate_config_with_data_data_filepaths(data_config, data_source_paths)
+
+    with open(save_path, "w") as file:
+        yaml.dump(data_config, file, default_flow_style=False)
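Note: from the logic above, data_source_paths is a flat dict for "gsp" and "satellite" and a provider-keyed dict for "nwp". A hypothetical example (all paths are made up):

    data_source_paths = {
        "gsp": "/data/gsp.zarr",
        "satellite": "/data/satellite.zarr",
        "nwp": {"ukv": "/data/nwp_ukv.zarr"},  # keyed by provider, per the assert above
    }

    data_config = {
        "input_data": {
            "gsp": {"zarr_path": ""},
            "satellite": {"zarr_path": "old.zarr"},  # non-empty -> will be replaced
            "nwp": {"ukv": {"provider": "ukv", "zarr_path": ""}},
        }
    }

    data_config = populate_config_with_data_data_filepaths(data_config, data_source_paths)
    # data_config["input_data"]["nwp"]["ukv"]["zarr_path"] is now "/data/nwp_ukv.zarr"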
{pvnet_summation-1.0.1.dist-info → pvnet_summation-1.1.0.dist-info}/METADATA
CHANGED

@@ -1,13 +1,13 @@
 Metadata-Version: 2.4
 Name: PVNet_summation
-Version: 1.0
+Version: 1.1.0
 Summary: PVNet_summation
 Author-email: James Fulton <info@openclimatefix.org>
-Requires-Python: >=3.
+Requires-Python: >=3.11
 Description-Content-Type: text/markdown
 License-File: LICENSE
 Requires-Dist: pvnet>=5.0.0
-Requires-Dist: ocf-data-sampler>=0.
+Requires-Dist: ocf-data-sampler>=0.6.0
 Requires-Dist: numpy
 Requires-Dist: pandas
 Requires-Dist: matplotlib
pvnet_summation-1.1.0.dist-info/RECORD
ADDED

@@ -0,0 +1,19 @@
+pvnet_summation/__init__.py,sha256=8bjkx2pvF7lZ2W5BiTpHr7iqpkRXc3vW5K1pxJAWaj0,22
+pvnet_summation/load_model.py,sha256=mQJXJ9p8wb25CVsm5UBGb0IL6xGZj-81iIBKHsNdQMY,2515
+pvnet_summation/optimizers.py,sha256=kuR3PUnISiAO5bSaKhq_7vqRKZ0gO5cRS4UbjmKgq1c,6472
+pvnet_summation/utils.py,sha256=JyqzDQjABCtRsdLgxr5j9K9AdmNlQhmYGenj6mKGnFY,4352
+pvnet_summation/data/__init__.py,sha256=AYJFlJ3KaAQXED0PxuuknI2lKEeFMFLJiJ9b6-H8398,81
+pvnet_summation/data/datamodule.py,sha256=Pa2iip-ALihhkAVtqDBPJZ93vh4evJwG9L9YCJiRQag,12517
+pvnet_summation/models/__init__.py,sha256=v3KMMH_bz9YGUFWsrb5Ndg-d_dgxQPw7yiFahQAag4c,103
+pvnet_summation/models/base_model.py,sha256=mxrEq8k6NAVpezLx3ORPM33OrXzRccVD2ErFkPIw8bc,12496
+pvnet_summation/models/dense_model.py,sha256=vh3Hrm-n7apgVkta_RtQ5mdxb6jiJNFm3ObWukSBgdU,2305
+pvnet_summation/models/horizon_dense_model.py,sha256=8NfJiO4upQT8ksqwDn1Jkct5-nrbs_EKfKBseVRay1U,7011
+pvnet_summation/training/__init__.py,sha256=2fbydXPJFk527DUGPlNV0Teaqvu4WNp8hgcODwHJFEw,110
+pvnet_summation/training/lightning_module.py,sha256=t16gcAc4Fmi1g26dhQwQOm4qe2mwnTfEBbOyH_BFZ4o,8695
+pvnet_summation/training/plots.py,sha256=VZHyzI6UvCEd4nmXiJCF1FiVlpDyFHTxX6_rc0vmJrU,2248
+pvnet_summation/training/train.py,sha256=ze4LCr4XvJ18NjiZhR9KslVf_5HoC1xjGIhBcfw8u5E,8000
+pvnet_summation-1.1.0.dist-info/licenses/LICENSE,sha256=F-Q3UFCR-BECSocV55BFDpn4YKxve9PKrm-lTt6o_Tg,1073
+pvnet_summation-1.1.0.dist-info/METADATA,sha256=uy-zlQ8IyRNgM27nYxL207m41MrlHfuCPDzp0474-e8,3720
+pvnet_summation-1.1.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+pvnet_summation-1.1.0.dist-info/top_level.txt,sha256=5fWJ75RKtpaHUdLG_-2oDCInXeq4r1aMCxkZp5Wy-LQ,16
+pvnet_summation-1.1.0.dist-info/RECORD,,
pvnet_summation-1.0.1.dist-info/RECORD
REMOVED

@@ -1,18 +0,0 @@
-pvnet_summation/__init__.py,sha256=8bjkx2pvF7lZ2W5BiTpHr7iqpkRXc3vW5K1pxJAWaj0,22
-pvnet_summation/load_model.py,sha256=GfreRSaKVTWjV9fnJGNYjp09wrpZwaTunHijdff6cyc,2338
-pvnet_summation/optimizers.py,sha256=kuR3PUnISiAO5bSaKhq_7vqRKZ0gO5cRS4UbjmKgq1c,6472
-pvnet_summation/utils.py,sha256=G7l2iZK8qNWEau27pJYPvGOLSzPaSttFrGwr75yTlPQ,2628
-pvnet_summation/data/__init__.py,sha256=AYJFlJ3KaAQXED0PxuuknI2lKEeFMFLJiJ9b6-H8398,81
-pvnet_summation/data/datamodule.py,sha256=dexqqz9CHsH2c7ehgOTnJw5LjlOTNCvNhDZsFOVwy1g,8072
-pvnet_summation/models/__init__.py,sha256=v3KMMH_bz9YGUFWsrb5Ndg-d_dgxQPw7yiFahQAag4c,103
-pvnet_summation/models/base_model.py,sha256=qtsbH8WqrRUQdWpBdeLJ3yz3dlhUeLFUKzVvX7uiopo,12074
-pvnet_summation/models/dense_model.py,sha256=vh3Hrm-n7apgVkta_RtQ5mdxb6jiJNFm3ObWukSBgdU,2305
-pvnet_summation/training/__init__.py,sha256=2fbydXPJFk527DUGPlNV0Teaqvu4WNp8hgcODwHJFEw,110
-pvnet_summation/training/lightning_module.py,sha256=t16gcAc4Fmi1g26dhQwQOm4qe2mwnTfEBbOyH_BFZ4o,8695
-pvnet_summation/training/plots.py,sha256=VZHyzI6UvCEd4nmXiJCF1FiVlpDyFHTxX6_rc0vmJrU,2248
-pvnet_summation/training/train.py,sha256=qBzSCsBMsJpbbBx3laVfOSdBSTCBF7XBWl_AZglbsKQ,7171
-pvnet_summation-1.0.1.dist-info/licenses/LICENSE,sha256=F-Q3UFCR-BECSocV55BFDpn4YKxve9PKrm-lTt6o_Tg,1073
-pvnet_summation-1.0.1.dist-info/METADATA,sha256=fIi2uaWV8-ihgZFMGxIoIQSa2-mHCa5u6-UYcP8fipA,3721
-pvnet_summation-1.0.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-pvnet_summation-1.0.1.dist-info/top_level.txt,sha256=5fWJ75RKtpaHUdLG_-2oDCInXeq4r1aMCxkZp5Wy-LQ,16
-pvnet_summation-1.0.1.dist-info/RECORD,,

{pvnet_summation-1.0.1.dist-info → pvnet_summation-1.1.0.dist-info}/WHEEL
File without changes

{pvnet_summation-1.0.1.dist-info → pvnet_summation-1.1.0.dist-info}/licenses/LICENSE
File without changes

{pvnet_summation-1.0.1.dist-info → pvnet_summation-1.1.0.dist-info}/top_level.txt
File without changes