PVNet_summation 1.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pvnet_summation/__init__.py +1 -0
- pvnet_summation/data/__init__.py +2 -0
- pvnet_summation/data/datamodule.py +312 -0
- pvnet_summation/load_model.py +74 -0
- pvnet_summation/models/__init__.py +3 -0
- pvnet_summation/models/base_model.py +356 -0
- pvnet_summation/models/dense_model.py +75 -0
- pvnet_summation/models/horizon_dense_model.py +171 -0
- pvnet_summation/optimizers.py +219 -0
- pvnet_summation/training/__init__.py +3 -0
- pvnet_summation/training/lightning_module.py +278 -0
- pvnet_summation/training/plots.py +91 -0
- pvnet_summation/training/train.py +216 -0
- pvnet_summation/utils.py +132 -0
- pvnet_summation-1.1.2.dist-info/METADATA +100 -0
- pvnet_summation-1.1.2.dist-info/RECORD +19 -0
- pvnet_summation-1.1.2.dist-info/WHEEL +5 -0
- pvnet_summation-1.1.2.dist-info/licenses/LICENSE +21 -0
- pvnet_summation-1.1.2.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""PVNet_summation"""
|
|
@@ -0,0 +1,312 @@
|
|
|
1
|
+
"""Pytorch lightning datamodules for loading pre-saved samples and predictions."""
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
from glob import glob
|
|
5
|
+
from typing import TypeAlias
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
import pandas as pd
|
|
9
|
+
import torch
|
|
10
|
+
from lightning.pytorch import LightningDataModule
|
|
11
|
+
from ocf_data_sampler.load.gsp import get_gsp_boundaries, open_gsp
|
|
12
|
+
from ocf_data_sampler.numpy_sample.common_types import NumpyBatch
|
|
13
|
+
from ocf_data_sampler.numpy_sample.sun_position import calculate_azimuth_and_elevation
|
|
14
|
+
from ocf_data_sampler.select.geospatial import osgb_to_lon_lat
|
|
15
|
+
from ocf_data_sampler.torch_datasets.datasets.pvnet_uk import PVNetUKConcurrentDataset
|
|
16
|
+
from ocf_data_sampler.utils import minutes
|
|
17
|
+
from torch.utils.data import DataLoader, Dataset, Subset, default_collate
|
|
18
|
+
from typing_extensions import override
|
|
19
|
+
|
|
20
|
+
# A single (un-batched) summation sample: field name -> numpy array, except for the
# "pvnet_inputs" entry which holds a whole PVNet NumpyBatch (see construct_sample).
SumNumpySample: TypeAlias = dict[str, np.ndarray | NumpyBatch]
# NOTE(review): presumably a collated batch of summation samples as torch tensors; it is not
# used in this module - confirm against the model/training code.
SumTensorBatch: TypeAlias = dict[str, torch.Tensor]
|
|
22
|
+
|
|
23
|
+
def get_gb_centroid_lon_lat() -> tuple[float, float]:
    """Return the (longitude, latitude) of the centroid of Great Britain.

    The point is read from the row with index label 0 of the "20250109" GSP boundaries, and
    its OSGB x/y coordinates are converted to longitude/latitude.
    """
    national_row = get_gsp_boundaries("20250109").loc[0]
    return osgb_to_lon_lat(national_row.x_osgb.item(), national_row.y_osgb.item())


# Evaluated once at import time; construct_sample uses this single location to compute the
# solar position for every sample.
LON, LAT = get_gb_centroid_lon_lat()
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def construct_sample(
    pvnet_inputs: NumpyBatch,
    valid_times: pd.DatetimeIndex,
    relative_capacities: np.ndarray,
    target: np.ndarray | None,
    last_outturn: float | None = None,
) -> SumNumpySample:
    """Construct an input sample for the summation model.

    Args:
        pvnet_inputs: The PVNet batch for all GSPs
        valid_times: The valid-times of the forecast, one per forecast step
        relative_capacities: Capacities of all GSPs normalised by the total capacity
        target: The target national outturn. This is only needed during training.
        last_outturn: The previous national outturn. This is only needed during training.

    Returns:
        A dict with the keys "pvnet_inputs", "valid_times", "relative_capacity", "azimuth" and
        "elevation". The keys "target" and "last_outturn" are added only when the
        corresponding argument is not None.
    """

    # The solar position is computed at one fixed location (the module-level LON/LAT) for
    # every valid-time
    azimuth, elevation = calculate_azimuth_and_elevation(valid_times, LON, LAT)

    sample = {
        # NumpyBatch object with batch size = num_locations
        "pvnet_inputs": pvnet_inputs,
        # Shape: [time]
        # Cast explicitly to 64-bit: the builtin `int` maps to int32 on some platforms
        # (e.g. Windows with NumPy < 2), which would silently overflow datetime64 values.
        "valid_times": valid_times.values.astype(np.int64),
        # Shape: [num_locations]
        "relative_capacity": relative_capacities,
        # Shape: [time]
        # Scaled by 1/360 (assumes azimuth is in degrees - TODO confirm)
        "azimuth": azimuth.astype(np.float32) / 360,
        # Shape: [time]
        # Maps [-90, 90] to [0, 1] (assumes elevation is in degrees - TODO confirm)
        "elevation": elevation.astype(np.float32) / 180 + 0.5,
    }

    if target is not None:
        # Shape: [time]
        sample["target"] = target
    if last_outturn is not None:
        # Shape: scalar
        sample["last_outturn"] = last_outturn
    return sample
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
class StreamedDataset(PVNetUKConcurrentDataset):
    """A torch dataset for creating concurrent PVNet inputs and national targets."""

    def __init__(
        self,
        config_filename: str,
        start_time: str | None = None,
        end_time: str | None = None,
    ) -> None:
        """A torch dataset for creating concurrent PVNet inputs and national targets.

        Args:
            config_filename: Path to the configuration file
            start_time: Limit the init-times to be after this
            end_time: Limit the init-times to be before this
        """
        super().__init__(config_filename, start_time, end_time, gsp_ids=None)

        # Load the national (gsp_id=0) GSP data into memory. It supplies the target values
        # and total capacity; normalisation by that capacity happens per sample in _get_sample.
        gsp_config = self.config.input_data.gsp
        all_gsp_data = open_gsp(
            zarr_path=gsp_config.zarr_path,
            boundaries_version=gsp_config.boundaries_version,
        )
        self.national_gsp_data = all_gsp_data.sel(gsp_id=0).compute()

    def _get_sample(self, t0: pd.Timestamp) -> SumNumpySample:
        """Generate a concurrent PVNet sample for given init-time.

        Args:
            t0: init-time for sample
        """
        # PVNet inputs for every GSP at this init-time
        pvnet_inputs: NumpyBatch = super()._get_sample(t0)

        gsp_config = self.config.input_data.gsp
        step = minutes(gsp_config.time_resolution_minutes)

        # One valid-time per forecast step: from one step after t0 up to
        # t0 + interval_end_minutes
        valid_times = pd.date_range(
            t0 + step,
            t0 + minutes(gsp_config.interval_end_minutes),
            freq=step,
        )

        # National data at the init-time supplies both the total capacity and last outturn
        national_at_t0 = self.national_gsp_data.sel(time_utc=t0)
        total_capacity = national_at_t0.effective_capacity_mwp.item()

        # All capacities and outturns are normalised by the national capacity at t0
        return construct_sample(
            pvnet_inputs=pvnet_inputs,
            valid_times=valid_times,
            relative_capacities=pvnet_inputs["gsp_effective_capacity_mwp"] / total_capacity,
            target=self.national_gsp_data.sel(time_utc=valid_times).values / total_capacity,
            last_outturn=national_at_t0.values / total_capacity,
        )

    @override
    def __getitem__(self, idx: int) -> SumNumpySample:
        # Delegates to the parent; re-declared with the summation sample return type
        return super().__getitem__(idx)

    @override
    def get_sample(self, t0: pd.Timestamp) -> SumNumpySample:
        # Delegates to the parent; re-declared with the summation sample return type
        return super().get_sample(t0)
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
class StreamedDataModule(LightningDataModule):
    """Datamodule for training pvnet_summation."""

    def __init__(
        self,
        configuration: str,
        train_period: list[str | None] | None = None,
        val_period: list[str | None] | None = None,
        num_workers: int = 0,
        prefetch_factor: int | None = None,
        persistent_workers: bool = False,
        seed: int | None = None,
        dataset_pickle_dir: str | None = None,
    ):
        """Datamodule for creating concurrent PVNet inputs and national targets.

        Args:
            configuration: Path to ocf-data-sampler configuration file.
            train_period: [start, end] date range filter for the train dataloader. Either
                bound may be None. Defaults to [None, None], i.e. no filtering.
            val_period: [start, end] date range filter for the val dataloader. Either bound
                may be None. Defaults to [None, None], i.e. no filtering.
            num_workers: Number of workers to use in multiprocess batch loading.
            prefetch_factor: Number of batches loaded in advance by each worker (passed to
                the DataLoader).
            persistent_workers: If True, the data loader will not shut down the worker processes
                after a dataset has been consumed once. This keeps the workers' Dataset
                instances alive.
            seed: Random seed used to shuffle the val dataset.
            dataset_pickle_dir: Directory in which the val and train set will be presaved as
                pickle objects. Setting this speeds up instantiation of multiple workers a lot.
        """
        super().__init__()
        self.configuration = configuration
        # `None` is used as the default instead of a list literal so that instances never
        # share a single mutable default object.
        self.train_period = [None, None] if train_period is None else train_period
        self.val_period = [None, None] if val_period is None else val_period
        self.seed = seed
        self.dataset_pickle_dir = dataset_pickle_dir

        self._dataloader_kwargs = dict(
            # batch_size=None disables automatic batching: samples are yielded one at a time
            batch_size=None,
            batch_sampler=None,
            num_workers=num_workers,
            collate_fn=None,
            pin_memory=False,
            drop_last=False,
            timeout=0,
            worker_init_fn=None,
            prefetch_factor=prefetch_factor,
            persistent_workers=persistent_workers,
            # DataLoader only accepts a multiprocessing context when worker processes are used
            multiprocessing_context="spawn" if num_workers > 0 else None,
        )

    def setup(self, stage: str | None = None):
        """Called once to prepare the datasets.

        Builds the train dataset and a pre-shuffled val dataset. If `dataset_pickle_dir` is
        set, both datasets are also presaved there via `presave_pickle`.

        Raises:
            FileExistsError: If a pickled dataset already exists in `dataset_pickle_dir`.
        """

        # This logic runs only once at the start of training, therefore the val dataset is only
        # shuffled once
        if self.dataset_pickle_dir is not None:
            os.makedirs(self.dataset_pickle_dir, exist_ok=True)

            train_dataset_path = f"{self.dataset_pickle_dir}/train_dataset.pkl"
            val_dataset_path = f"{self.dataset_pickle_dir}/val_dataset.pkl"

            # For safety, these pickled datasets cannot be overwritten.
            # See: https://github.com/openclimatefix/pvnet/pull/445
            for path in [train_dataset_path, val_dataset_path]:
                if os.path.exists(path):
                    raise FileExistsError(
                        f"The pickled dataset path '{path}' already exists. Make sure that "
                        "this can be safely deleted (i.e. not currently being used by any "
                        "training run) and delete it manually. Else change the "
                        "`dataset_pickle_dir` to a different directory."
                    )

        # Prepare the train dataset
        self.train_dataset = StreamedDataset(self.configuration, *self.train_period)

        # Prepare and pre-shuffle the val dataset and set seed for reproducibility
        val_dataset = StreamedDataset(self.configuration, *self.val_period)
        shuffled_indices = np.random.default_rng(seed=self.seed).permutation(len(val_dataset))
        self.val_dataset = Subset(val_dataset, shuffled_indices)

        if self.dataset_pickle_dir is not None:
            self.train_dataset.presave_pickle(train_dataset_path)
            # The val *StreamedDataset* must be pickled here (not the train dataset). The Subset
            # wrapper has no `presave_pickle`, so call it on the underlying dataset.
            val_dataset.presave_pickle(val_dataset_path)

    def teardown(self, stage: str | None = None) -> None:
        """Clean up the pickled datasets"""
        if self.dataset_pickle_dir is not None:
            for filename in ["val_dataset.pkl", "train_dataset.pkl"]:
                filepath = f"{self.dataset_pickle_dir}/{filename}"
                if os.path.exists(filepath):
                    os.remove(filepath)

    def train_dataloader(self, shuffle: bool = False) -> DataLoader:
        """Construct train dataloader"""
        return DataLoader(self.train_dataset, shuffle=shuffle, **self._dataloader_kwargs)

    def val_dataloader(self, shuffle: bool = False) -> DataLoader:
        """Construct val dataloader (the val dataset is already shuffled once in `setup`)"""
        return DataLoader(self.val_dataset, shuffle=shuffle, **self._dataloader_kwargs)
|
|
244
|
+
|
|
245
|
+
|
|
246
|
+
class PresavedDataset(Dataset):
    """Dataset for loading pre-saved PVNet predictions from disk"""

    def __init__(self, sample_dir: str):
        """Dataset for loading pre-saved PVNet predictions from disk.

        Args:
            sample_dir: The directory containing the saved samples. Every `*.pt` file directly
                inside it is one sample, and samples are ordered by their sorted file paths.
                A missing or empty directory yields an empty dataset.
        """
        self.sample_filepaths = sorted(glob(f"{sample_dir}/*.pt"))

    def __len__(self) -> int:
        """Return the number of saved samples found in `sample_dir`."""
        return len(self.sample_filepaths)

    def __getitem__(self, idx: int) -> dict:
        """Load the sample stored in the `idx`-th file.

        `weights_only=True` restricts unpickling to tensors and basic Python containers.
        """
        return torch.load(self.sample_filepaths[idx], weights_only=True)
|
|
262
|
+
|
|
263
|
+
|
|
264
|
+
class PresavedDataModule(LightningDataModule):
    """Datamodule for loading pre-saved PVNet predictions."""

    def __init__(
        self,
        sample_dir: str,
        batch_size: int = 16,
        num_workers: int = 0,
        prefetch_factor: int | None = None,
        persistent_workers: bool = False,
    ):
        """Datamodule for loading pre-saved PVNet predictions.

        Args:
            sample_dir: Path to the directory of pre-saved samples. It is expected to contain
                `train` and `val` sub-directories.
            batch_size: Batch size.
            num_workers: Number of workers to use in multiprocess batch loading.
            prefetch_factor: Number of batches loaded in advance by each worker.
            persistent_workers: If True, the data loader will not shut down the worker processes
                after a dataset has been consumed once. This keeps the workers' Dataset
                instances alive.
        """
        super().__init__()
        self.sample_dir = sample_dir

        uses_worker_processes = num_workers > 0
        batching_enabled = batch_size is not None

        self._dataloader_kwargs = {
            "batch_size": batch_size,
            "sampler": None,
            "batch_sampler": None,
            "num_workers": num_workers,
            # Collate only when automatic batching is on
            "collate_fn": default_collate if batching_enabled else None,
            "pin_memory": False,
            "drop_last": False,
            "timeout": 0,
            "worker_init_fn": None,
            "prefetch_factor": prefetch_factor,
            "persistent_workers": persistent_workers,
            "multiprocessing_context": "spawn" if uses_worker_processes else None,
        }

    def _make_dataloader(self, split: str, shuffle: bool) -> DataLoader:
        """Build a DataLoader over the pre-saved samples in `<sample_dir>/<split>`."""
        dataset = PresavedDataset(f"{self.sample_dir}/{split}")
        return DataLoader(dataset, shuffle=shuffle, **self._dataloader_kwargs)

    def train_dataloader(self, shuffle: bool = True) -> DataLoader:
        """Construct train dataloader"""
        return self._make_dataloader("train", shuffle)

    def val_dataloader(self, shuffle: bool = False) -> DataLoader:
        """Construct val dataloader"""
        return self._make_dataloader("val", shuffle)
|
|
@@ -0,0 +1,74 @@
|
|
|
1
|
+
"""Load a model from its checkpoint directory"""
|
|
2
|
+
|
|
3
|
+
import glob
|
|
4
|
+
import os
|
|
5
|
+
|
|
6
|
+
import hydra
|
|
7
|
+
import torch
|
|
8
|
+
import yaml
|
|
9
|
+
|
|
10
|
+
from pvnet_summation.utils import (
|
|
11
|
+
DATAMODULE_CONFIG_NAME,
|
|
12
|
+
FULL_CONFIG_NAME,
|
|
13
|
+
MODEL_CONFIG_NAME,
|
|
14
|
+
)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def get_model_from_checkpoints(
    checkpoint_dir_path: str,
    val_best: bool = True,
) -> tuple[torch.nn.Module, dict, str | None, str | None]:
    """Load a model from its checkpoint directory.

    The lightning module is instantiated with hydra from the saved model config, its weights
    are loaded from a checkpoint, and the wrapped `model` is returned.

    Args:
        checkpoint_dir_path: Path to the directory with the model files.
        val_best (optional): If True, load the single `epoch*.ckpt` checkpoint (the best
            epoch); otherwise load `last.ckpt`.

    Returns:
        tuple:
            model: nn.Module of pretrained model.
            model_config: dict of model config used to train the model (the "model" entry
                of the saved config).
            datamodule_config: Path to the datamodule config, or None if that file is absent.
            experiment_configs: Path to the full experimental config, or None if that file is
                absent.

    Raises:
        ValueError: If `val_best` is True and the directory does not contain exactly one
            `epoch*.ckpt` file.
    """

    def _path_if_exists(path: str) -> str | None:
        # Some config files are optional, so report None rather than failing
        return path if os.path.isfile(path) else None

    # Build the lightning training module from its saved config
    with open(f"{checkpoint_dir_path}/{MODEL_CONFIG_NAME}") as config_file:
        full_model_config = yaml.load(config_file, Loader=yaml.FullLoader)

    lightning_module = hydra.utils.instantiate(full_model_config)

    if val_best:
        # Only one epoch (best) saved per model
        candidates = glob.glob(f"{checkpoint_dir_path}/epoch*.ckpt")
        if len(candidates) != 1:
            raise ValueError(
                f"Found {len(candidates)} checkpoints @ {checkpoint_dir_path}/epoch*.ckpt. Expected one."
            )
        checkpoint_path = candidates[0]
    else:
        checkpoint_path = f"{checkpoint_dir_path}/last.ckpt"

    checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
    lightning_module.load_state_dict(state_dict=checkpoint["state_dict"])

    return (
        lightning_module.model,
        full_model_config["model"],
        # The datamodule config only exists if the model was trained with presaved samples
        _path_if_exists(f"{checkpoint_dir_path}/{DATAMODULE_CONFIG_NAME}"),
        # For backwards compatibility - the experiment config might not always exist
        _path_if_exists(f"{checkpoint_dir_path}/{FULL_CONFIG_NAME}"),
    )
|