PVNet_summation 1.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1 @@
+ """PVNet_summation"""
@@ -0,0 +1,2 @@
+ """Data module"""
+ from .datamodule import PresavedDataModule, StreamedDataModule
@@ -0,0 +1,312 @@
+ """Pytorch lightning datamodules for loading pre-saved samples and predictions."""
+
+ import os
+ from glob import glob
+ from typing import TypeAlias
+
+ import numpy as np
+ import pandas as pd
+ import torch
+ from lightning.pytorch import LightningDataModule
+ from ocf_data_sampler.load.gsp import get_gsp_boundaries, open_gsp
+ from ocf_data_sampler.numpy_sample.common_types import NumpyBatch
+ from ocf_data_sampler.numpy_sample.sun_position import calculate_azimuth_and_elevation
+ from ocf_data_sampler.select.geospatial import osgb_to_lon_lat
+ from ocf_data_sampler.torch_datasets.datasets.pvnet_uk import PVNetUKConcurrentDataset
+ from ocf_data_sampler.utils import minutes
+ from torch.utils.data import DataLoader, Dataset, Subset, default_collate
+ from typing_extensions import override
+
+ SumNumpySample: TypeAlias = dict[str, np.ndarray | NumpyBatch]
+ SumTensorBatch: TypeAlias = dict[str, torch.Tensor]
+
+ def get_gb_centroid_lon_lat() -> tuple[float, float]:
+     """Get the longitude and latitude of the centroid of Great Britain"""
+     row = get_gsp_boundaries("20250109").loc[0]
+     x_osgb = row.x_osgb.item()
+     y_osgb = row.y_osgb.item()
+     return osgb_to_lon_lat(x_osgb, y_osgb)
+
+ LON, LAT = get_gb_centroid_lon_lat()
+
+
+ def construct_sample(
+     pvnet_inputs: NumpyBatch,
+     valid_times: pd.DatetimeIndex,
+     relative_capacities: np.ndarray,
+     target: np.ndarray | None,
+     last_outturn: float | None = None,
+ ) -> SumNumpySample:
+     """Construct an input sample for the summation model
+
+     Args:
+         pvnet_inputs: The PVNet batch for all GSPs
+         valid_times: An array of valid-times for the forecast
+         relative_capacities: Array of capacities of all GSPs normalised by the total capacity
+         target: The target national outturn. This is only needed during training.
+         last_outturn: The previous national outturn. This is only needed during training.
+     """
+
+     azimuth, elevation = calculate_azimuth_and_elevation(valid_times, LON, LAT)
+
+     sample = {
+         # NumpyBatch object with batch size = num_locations
+         "pvnet_inputs": pvnet_inputs,
+         # Shape: [time]
+         "valid_times": valid_times.values.astype(int),
+         # Shape: [num_locations]
+         "relative_capacity": relative_capacities,
+         # Shape: [time]
+         "azimuth": azimuth.astype(np.float32) / 360,
+         # Shape: [time]
+         "elevation": elevation.astype(np.float32) / 180 + 0.5,
+     }
+
+     if target is not None:
+         # Shape: [time]
+         sample["target"] = target
+     if last_outturn is not None:
+         # Shape: scalar
+         sample["last_outturn"] = last_outturn
+     return sample
+
+
+ class StreamedDataset(PVNetUKConcurrentDataset):
+     """A torch dataset for creating concurrent PVNet inputs and national targets."""
+
+     def __init__(
+         self,
+         config_filename: str,
+         start_time: str | None = None,
+         end_time: str | None = None,
+     ) -> None:
+         """A torch dataset for creating concurrent PVNet inputs and national targets.
+
+         Args:
+             config_filename: Path to the configuration file
+             start_time: Limit the init-times to be after this
+             end_time: Limit the init-times to be before this
+         """
+         super().__init__(config_filename, start_time, end_time, gsp_ids=None)
+
+         # Load the national GSP data to use as target values
+         self.national_gsp_data = (
+             open_gsp(
+                 zarr_path=self.config.input_data.gsp.zarr_path,
+                 boundaries_version=self.config.input_data.gsp.boundaries_version,
+             )
+             .sel(gsp_id=0)
+             .compute()
+         )
+
+     def _get_sample(self, t0: pd.Timestamp) -> SumNumpySample:
+         """Generate a concurrent PVNet sample for a given init-time.
+
+         Args:
+             t0: init-time for sample
+         """
+
+         # Get the PVNet input batch
+         pvnet_inputs: NumpyBatch = super()._get_sample(t0)
+
+         # Construct an array of valid times for each forecast horizon
+         valid_times = pd.date_range(
+             t0 + minutes(self.config.input_data.gsp.time_resolution_minutes),
+             t0 + minutes(self.config.input_data.gsp.interval_end_minutes),
+             freq=minutes(self.config.input_data.gsp.time_resolution_minutes),
+         )
+
+         # Get the GSP and national capacities
+         location_capacities = pvnet_inputs["gsp_effective_capacity_mwp"]
+         total_capacity = self.national_gsp_data.sel(time_utc=t0).effective_capacity_mwp.item()
+
+         # Calculate the required inputs for the sample
+         relative_capacities = location_capacities / total_capacity
+         target = self.national_gsp_data.sel(time_utc=valid_times).values / total_capacity
+         last_outturn = self.national_gsp_data.sel(time_utc=t0).values / total_capacity
+
+         return construct_sample(
+             pvnet_inputs=pvnet_inputs,
+             valid_times=valid_times,
+             relative_capacities=relative_capacities,
+             target=target,
+             last_outturn=last_outturn,
+         )
+
+     @override
+     def __getitem__(self, idx: int) -> SumNumpySample:
+         return super().__getitem__(idx)
+
+     @override
+     def get_sample(self, t0: pd.Timestamp) -> SumNumpySample:
+         return super().get_sample(t0)
+
+
+ class StreamedDataModule(LightningDataModule):
+     """Datamodule for training pvnet_summation."""
+
+     def __init__(
+         self,
+         configuration: str,
+         train_period: list[str | None] = [None, None],
+         val_period: list[str | None] = [None, None],
+         num_workers: int = 0,
+         prefetch_factor: int | None = None,
+         persistent_workers: bool = False,
+         seed: int | None = None,
+         dataset_pickle_dir: str | None = None,
+     ):
+         """Datamodule for creating concurrent PVNet inputs and national targets.
+
+         Args:
+             configuration: Path to ocf-data-sampler configuration file.
+             train_period: Date range filter for train dataloader.
+             val_period: Date range filter for val dataloader.
+             num_workers: Number of workers to use in multiprocess batch loading.
+             prefetch_factor: Number of batches loaded in advance by each worker.
+             persistent_workers: If True, the data loader will not shut down the worker
+                 processes after a dataset has been consumed once. This keeps the worker
+                 Dataset instances alive.
+             seed: Random seed used in shuffling datasets.
+             dataset_pickle_dir: Directory in which the val and train set will be presaved as
+                 pickle objects. Setting this speeds up instantiation of multiple workers a lot.
+         """
+         super().__init__()
+         self.configuration = configuration
+         self.train_period = train_period
+         self.val_period = val_period
+         self.seed = seed
+         self.dataset_pickle_dir = dataset_pickle_dir
+
+         self._dataloader_kwargs = dict(
+             batch_size=None,
+             batch_sampler=None,
+             num_workers=num_workers,
+             collate_fn=None,
+             pin_memory=False,
+             drop_last=False,
+             timeout=0,
+             worker_init_fn=None,
+             prefetch_factor=prefetch_factor,
+             persistent_workers=persistent_workers,
+             multiprocessing_context="spawn" if num_workers > 0 else None,
+         )
+
+     def setup(self, stage: str | None = None):
+         """Called once to prepare the datasets."""
+
+         # This logic runs only once at the start of training, therefore the val dataset is only
+         # shuffled once
+         if self.dataset_pickle_dir is not None:
+             os.makedirs(self.dataset_pickle_dir, exist_ok=True)
+
+             train_dataset_path = f"{self.dataset_pickle_dir}/train_dataset.pkl"
+             val_dataset_path = f"{self.dataset_pickle_dir}/val_dataset.pkl"
+
+             # For safety, these pickled datasets cannot be overwritten.
+             # See: https://github.com/openclimatefix/pvnet/pull/445
+             for path in [train_dataset_path, val_dataset_path]:
+                 if os.path.exists(path):
+                     raise FileExistsError(
+                         f"The pickled dataset path '{path}' already exists. Make sure that "
+                         "this can be safely deleted (i.e. not currently being used by any "
+                         "training run) and delete it manually. Else change the "
+                         "`dataset_pickle_dir` to a different directory."
+                     )
+
+         # Prepare the train dataset
+         self.train_dataset = StreamedDataset(self.configuration, *self.train_period)
+
+         # Prepare and pre-shuffle the val dataset and set seed for reproducibility
+         val_dataset = StreamedDataset(self.configuration, *self.val_period)
+         shuffled_indices = np.random.default_rng(seed=self.seed).permutation(len(val_dataset))
+         self.val_dataset = Subset(val_dataset, shuffled_indices)
+
+         if self.dataset_pickle_dir is not None:
+             self.train_dataset.presave_pickle(train_dataset_path)
+             val_dataset.presave_pickle(val_dataset_path)
+
+     def teardown(self, stage: str | None = None) -> None:
+         """Clean up the pickled datasets"""
+         if self.dataset_pickle_dir is not None:
+             for filename in ["val_dataset.pkl", "train_dataset.pkl"]:
+                 filepath = f"{self.dataset_pickle_dir}/{filename}"
+                 if os.path.exists(filepath):
+                     os.remove(filepath)
+
+     def train_dataloader(self, shuffle: bool = False) -> DataLoader:
+         """Construct train dataloader"""
+         return DataLoader(self.train_dataset, shuffle=shuffle, **self._dataloader_kwargs)
+
+     def val_dataloader(self, shuffle: bool = False) -> DataLoader:
+         """Construct val dataloader"""
+         return DataLoader(self.val_dataset, shuffle=shuffle, **self._dataloader_kwargs)
+
+
+ class PresavedDataset(Dataset):
+     """Dataset for loading pre-saved PVNet predictions from disk"""
+
+     def __init__(self, sample_dir: str):
+         """Dataset for loading pre-saved PVNet predictions from disk.
+
+         Args:
+             sample_dir: The directory containing the saved samples
+         """
+         self.sample_filepaths = sorted(glob(f"{sample_dir}/*.pt"))
+
+     def __len__(self) -> int:
+         return len(self.sample_filepaths)
+
+     def __getitem__(self, idx: int) -> dict:
+         return torch.load(self.sample_filepaths[idx], weights_only=True)
+
+
+ class PresavedDataModule(LightningDataModule):
+     """Datamodule for loading pre-saved PVNet predictions."""
+
+     def __init__(
+         self,
+         sample_dir: str,
+         batch_size: int = 16,
+         num_workers: int = 0,
+         prefetch_factor: int | None = None,
+         persistent_workers: bool = False,
+     ):
+         """Datamodule for loading pre-saved PVNet predictions.
+
+         Args:
+             sample_dir: Path to the directory of pre-saved samples.
+             batch_size: Batch size.
+             num_workers: Number of workers to use in multiprocess batch loading.
+             prefetch_factor: Number of batches loaded in advance by each worker.
+             persistent_workers: If True, the data loader will not shut down the worker
+                 processes after a dataset has been consumed once. This keeps the worker
+                 Dataset instances alive.
+         """
+         super().__init__()
+         self.sample_dir = sample_dir
+
+         self._dataloader_kwargs = dict(
+             batch_size=batch_size,
+             sampler=None,
+             batch_sampler=None,
+             num_workers=num_workers,
+             collate_fn=None if batch_size is None else default_collate,
+             pin_memory=False,
+             drop_last=False,
+             timeout=0,
+             worker_init_fn=None,
+             prefetch_factor=prefetch_factor,
+             persistent_workers=persistent_workers,
+             multiprocessing_context="spawn" if num_workers > 0 else None,
+         )
+
+     def train_dataloader(self, shuffle: bool = True) -> DataLoader:
+         """Construct train dataloader"""
+         dataset = PresavedDataset(f"{self.sample_dir}/train")
+         return DataLoader(dataset, shuffle=shuffle, **self._dataloader_kwargs)
+
+     def val_dataloader(self, shuffle: bool = False) -> DataLoader:
+         """Construct val dataloader"""
+         dataset = PresavedDataset(f"{self.sample_dir}/val")
+         return DataLoader(dataset, shuffle=shuffle, **self._dataloader_kwargs)
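
For context, here is a minimal usage sketch of the two datamodules added above. It assumes the module is importable as pvnet_summation.data.datamodule (the __init__ hunk earlier re-exports both classes, but the exact subpackage name is not shown in this diff); the configuration path and sample directory are placeholders, and the keyword values simply echo the defaults visible in the code.

# Illustrative sketch only; not part of the released wheel.
# The import path and both filesystem paths below are assumptions.
from pvnet_summation.data.datamodule import PresavedDataModule, StreamedDataModule

# Streamed samples: concurrent PVNet inputs plus national targets built on the fly
streamed = StreamedDataModule(
    configuration="/path/to/data_config.yaml",  # hypothetical ocf-data-sampler config
    train_period=["2021-01-01", "2022-12-31"],
    val_period=["2023-01-01", "2023-12-31"],
    num_workers=2,
    dataset_pickle_dir=None,  # set a directory to pre-pickle the datasets for faster worker start-up
)
streamed.setup()
train_loader = streamed.train_dataloader()  # batch_size=None, so each item is a single SumNumpySample dict

# Pre-saved samples: loads .pt files from <sample_dir>/train and <sample_dir>/val
presaved = PresavedDataModule(sample_dir="/path/to/presaved_samples", batch_size=16)
val_loader = presaved.val_dataloader()  # batches collated with default_collate

Note the design difference visible in the dataloader kwargs: the streamed datamodule disables automatic batching (each sample already contains all GSPs for one init-time), while the pre-saved datamodule collates individual saved samples into batches.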
@@ -0,0 +1,74 @@
+ """Load a model from its checkpoint directory"""
+
+ import glob
+ import os
+
+ import hydra
+ import torch
+ import yaml
+
+ from pvnet_summation.utils import (
+     DATAMODULE_CONFIG_NAME,
+     FULL_CONFIG_NAME,
+     MODEL_CONFIG_NAME,
+ )
+
+
+ def get_model_from_checkpoints(
+     checkpoint_dir_path: str,
+     val_best: bool = True,
+ ) -> tuple[torch.nn.Module, dict, str | None, str | None]:
+     """Load a model from its checkpoint directory
+
+     Args:
+         checkpoint_dir_path: str path to the directory with the model files
+         val_best (optional): if True, load the best epoch model; otherwise, load the last
+
+     Returns:
+         tuple:
+             model: nn.Module of the pretrained model.
+             model_config: dict of the model config used to train the model.
+             datamodule_config: path to the datamodule config used to create samples, e.g. train/test split info.
+             experiment_config: path to the full experiment config.
+
+     """
+
+     # Load lightning training module
+     with open(f"{checkpoint_dir_path}/{MODEL_CONFIG_NAME}") as cfg:
+         model_config = yaml.load(cfg, Loader=yaml.FullLoader)
+
+     lightning_module = hydra.utils.instantiate(model_config)
+
+     if val_best:
+         # Only one epoch (best) saved per model
+         files = glob.glob(f"{checkpoint_dir_path}/epoch*.ckpt")
+         if len(files) != 1:
+             raise ValueError(
+                 f"Found {len(files)} checkpoints @ {checkpoint_dir_path}/epoch*.ckpt. Expected one."
+             )
+
+         checkpoint = torch.load(files[0], map_location="cpu", weights_only=True)
+     else:
+         checkpoint = torch.load(
+             f"{checkpoint_dir_path}/last.ckpt",
+             map_location="cpu",
+             weights_only=True,
+         )
+
+     lightning_module.load_state_dict(state_dict=checkpoint["state_dict"])
+
+     # Extract the model from the lightning module
+     model = lightning_module.model
+     model_config = model_config["model"]
+
+     # Check for datamodule config
+     # This only exists if the model was trained with presaved samples
+     datamodule_config = f"{checkpoint_dir_path}/{DATAMODULE_CONFIG_NAME}"
+     datamodule_config = datamodule_config if os.path.isfile(datamodule_config) else None
+
+     # Check for experiment config
+     # For backwards compatibility - this might not always exist
+     experiment_config = f"{checkpoint_dir_path}/{FULL_CONFIG_NAME}"
+     experiment_config = experiment_config if os.path.isfile(experiment_config) else None
+
+     return model, model_config, datamodule_config, experiment_config
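
A hedged sketch of how this loader might be used follows. The import path and the checkpoint directory are assumptions; the directory is expected to contain the saved model config YAML plus either a single epoch*.ckpt file or a last.ckpt file, as the function above requires.

# Illustrative sketch only; the import path and checkpoint directory are placeholders.
from pvnet_summation.load_model import get_model_from_checkpoints

model, model_config, datamodule_config, experiment_config = get_model_from_checkpoints(
    checkpoint_dir_path="/path/to/summation_checkpoint_dir",  # hypothetical run directory
    val_best=True,  # load the single best-epoch checkpoint rather than last.ckpt
)

model.eval()               # plain torch.nn.Module extracted from the lightning module
print(model_config)        # the "model" section of the training config
print(datamodule_config)   # path to the datamodule config, or None if it was not saved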
@@ -0,0 +1,3 @@
+ """Models for PVNet summation"""
+ from .base_model import BaseModel
+ from .dense_model import DenseModel