pvnet_summation-1.2.0.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (25)
  1. pvnet_summation-1.2.0/LICENSE +21 -0
  2. pvnet_summation-1.2.0/PKG-INFO +100 -0
  3. pvnet_summation-1.2.0/PVNet_summation.egg-info/PKG-INFO +100 -0
  4. pvnet_summation-1.2.0/PVNet_summation.egg-info/SOURCES.txt +23 -0
  5. pvnet_summation-1.2.0/PVNet_summation.egg-info/dependency_links.txt +1 -0
  6. pvnet_summation-1.2.0/PVNet_summation.egg-info/requires.txt +16 -0
  7. pvnet_summation-1.2.0/PVNet_summation.egg-info/top_level.txt +1 -0
  8. pvnet_summation-1.2.0/README.md +74 -0
  9. pvnet_summation-1.2.0/pvnet_summation/__init__.py +1 -0
  10. pvnet_summation-1.2.0/pvnet_summation/data/__init__.py +2 -0
  11. pvnet_summation-1.2.0/pvnet_summation/data/datamodule.py +310 -0
  12. pvnet_summation-1.2.0/pvnet_summation/load_model.py +74 -0
  13. pvnet_summation-1.2.0/pvnet_summation/models/__init__.py +3 -0
  14. pvnet_summation-1.2.0/pvnet_summation/models/base_model.py +356 -0
  15. pvnet_summation-1.2.0/pvnet_summation/models/dense_model.py +75 -0
  16. pvnet_summation-1.2.0/pvnet_summation/models/horizon_dense_model.py +171 -0
  17. pvnet_summation-1.2.0/pvnet_summation/optimizers.py +219 -0
  18. pvnet_summation-1.2.0/pvnet_summation/training/__init__.py +3 -0
  19. pvnet_summation-1.2.0/pvnet_summation/training/lightning_module.py +278 -0
  20. pvnet_summation-1.2.0/pvnet_summation/training/plots.py +91 -0
  21. pvnet_summation-1.2.0/pvnet_summation/training/train.py +216 -0
  22. pvnet_summation-1.2.0/pvnet_summation/utils.py +132 -0
  23. pvnet_summation-1.2.0/pyproject.toml +90 -0
  24. pvnet_summation-1.2.0/setup.cfg +4 -0
  25. pvnet_summation-1.2.0/tests/test_end2end.py +23 -0
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2023 Open Climate Fix
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
@@ -0,0 +1,100 @@
+ Metadata-Version: 2.4
+ Name: PVNet_summation
+ Version: 1.2.0
+ Summary: PVNet_summation
+ Author-email: James Fulton <info@openclimatefix.org>
+ Requires-Python: <3.14,>=3.11
+ Description-Content-Type: text/markdown
+ License-File: LICENSE
+ Requires-Dist: pvnet>=5.3.0
+ Requires-Dist: ocf-data-sampler>=1.0.0
+ Requires-Dist: numpy
+ Requires-Dist: pandas
+ Requires-Dist: matplotlib
+ Requires-Dist: xarray
+ Requires-Dist: torch>=2.0.0
+ Requires-Dist: lightning
+ Requires-Dist: typer
+ Requires-Dist: wandb
+ Requires-Dist: huggingface-hub
+ Requires-Dist: tqdm
+ Requires-Dist: omegaconf
+ Requires-Dist: hydra-core
+ Requires-Dist: rich
+ Requires-Dist: safetensors
+ Dynamic: license-file
+
+ # PVNet summation
+ [![ease of contribution: hard](https://img.shields.io/badge/ease%20of%20contribution:%20hard-bb2629)](https://github.com/openclimatefix/ocf-meta-repo?tab=readme-ov-file#overview-of-ocfs-nowcasting-repositories)
+
+ This project is used for training a model to sum the GSP predictions of [PVNet](https://github.com/openclimatefix/pvnet) into a national estimate.
+
+ Using the summation model to combine the GSP predictions, rather than taking a simple sum, increases the accuracy of the national predictions, and the model can be configured to produce estimates of the uncertainty range of the national estimate. See the [PVNet](https://github.com/openclimatefix/pvnet) repo for more details and our paper.
+
+
+ ## Setup / Installation
+
+ ```bash
+ git clone https://github.com/openclimatefix/PVNet_summation
+ cd PVNet_summation
+ pip install .
+ ```
+
+ ### Additional development dependencies
+
+ ```bash
+ pip install ".[dev]"
+ ```
+
+ ## Getting started with running PVNet summation
+
+ To run PVNet summation, we assume that you are already set up with
+ [PVNet](https://github.com/openclimatefix/pvnet) and have a trained PVNet model available, either locally or pushed to HuggingFace.
+
+ Before running any code, copy the example configuration to a configs directory:
+
+ ```
+ cp -r configs.example configs
+ ```
+
+ You will be making local amendments to these configs.
+
+ ### Datasets
+
+ The datasets required are the same as documented in
+ [PVNet](https://github.com/openclimatefix/pvnet). The only addition is that you will need PVLive
+ data for the national sum, i.e. GSP ID 0.
+
+
+ ### Training PVNet_summation
+
+ How PVNet_summation runs is determined by the extensive configuration in the config files. Example
+ configs are stored in `configs.example`.
+
+ Make sure to update the following config files before training your model:
+
+
+ 1. At the very start of training, we loop over all of the input samples and make predictions for them using PVNet. These predictions are saved to disk and loaded in the training loop for more efficient training. In `configs/config.yaml`, update `sample_save_dir` to set where the predictions will be saved.
+
+ 2. In `configs/datamodule/default.yaml`:
+     - Update `pvnet_model.model_id` and `pvnet_model.revision` to point to the HuggingFace commit or local directory where the exported PVNet model is.
+     - Update `configuration` to point to a data configuration compatible with the PVNet model whose outputs will be fed into the summation model.
+     - Set `train_period` and `val_period` to control the time ranges of the train and validation periods.
+     - Optionally set `max_num_train_samples` and `max_num_val_samples` to limit the number of train and validation examples which will be used.
+
+ 3. In `configs/model/default.yaml`:
+     - Update the hyperparameters and structure of the summation model.
+ 4. In `configs/trainer/default.yaml`:
+     - Set `accelerator: 0` if running on a system without a supported GPU.
+
+
+ Assuming you have updated the configs, you should now be able to run:
+
+ ```
+ python run.py
+ ```
+
+
+ ## Testing
+
+ You can use `python -m pytest tests` to run the tests.
@@ -0,0 +1,100 @@
+ Metadata-Version: 2.4
+ Name: PVNet_summation
+ Version: 1.2.0
+ Summary: PVNet_summation
+ Author-email: James Fulton <info@openclimatefix.org>
+ Requires-Python: <3.14,>=3.11
+ Description-Content-Type: text/markdown
+ License-File: LICENSE
+ Requires-Dist: pvnet>=5.3.0
+ Requires-Dist: ocf-data-sampler>=1.0.0
+ Requires-Dist: numpy
+ Requires-Dist: pandas
+ Requires-Dist: matplotlib
+ Requires-Dist: xarray
+ Requires-Dist: torch>=2.0.0
+ Requires-Dist: lightning
+ Requires-Dist: typer
+ Requires-Dist: wandb
+ Requires-Dist: huggingface-hub
+ Requires-Dist: tqdm
+ Requires-Dist: omegaconf
+ Requires-Dist: hydra-core
+ Requires-Dist: rich
+ Requires-Dist: safetensors
+ Dynamic: license-file
+
+ # PVNet summation
+ [![ease of contribution: hard](https://img.shields.io/badge/ease%20of%20contribution:%20hard-bb2629)](https://github.com/openclimatefix/ocf-meta-repo?tab=readme-ov-file#overview-of-ocfs-nowcasting-repositories)
+
+ This project is used for training a model to sum the GSP predictions of [PVNet](https://github.com/openclimatefix/pvnet) into a national estimate.
+
+ Using the summation model to combine the GSP predictions, rather than taking a simple sum, increases the accuracy of the national predictions, and the model can be configured to produce estimates of the uncertainty range of the national estimate. See the [PVNet](https://github.com/openclimatefix/pvnet) repo for more details and our paper.
+
+
+ ## Setup / Installation
+
+ ```bash
+ git clone https://github.com/openclimatefix/PVNet_summation
+ cd PVNet_summation
+ pip install .
+ ```
+
+ ### Additional development dependencies
+
+ ```bash
+ pip install ".[dev]"
+ ```
+
+ ## Getting started with running PVNet summation
+
+ To run PVNet summation, we assume that you are already set up with
+ [PVNet](https://github.com/openclimatefix/pvnet) and have a trained PVNet model available, either locally or pushed to HuggingFace.
+
+ Before running any code, copy the example configuration to a configs directory:
+
+ ```
+ cp -r configs.example configs
+ ```
+
+ You will be making local amendments to these configs.
+
+ ### Datasets
+
+ The datasets required are the same as documented in
+ [PVNet](https://github.com/openclimatefix/pvnet). The only addition is that you will need PVLive
+ data for the national sum, i.e. GSP ID 0.
+
+
+ ### Training PVNet_summation
+
+ How PVNet_summation runs is determined by the extensive configuration in the config files. Example
+ configs are stored in `configs.example`.
+
+ Make sure to update the following config files before training your model:
+
+
+ 1. At the very start of training, we loop over all of the input samples and make predictions for them using PVNet. These predictions are saved to disk and loaded in the training loop for more efficient training. In `configs/config.yaml`, update `sample_save_dir` to set where the predictions will be saved.
+
+ 2. In `configs/datamodule/default.yaml`:
+     - Update `pvnet_model.model_id` and `pvnet_model.revision` to point to the HuggingFace commit or local directory where the exported PVNet model is.
+     - Update `configuration` to point to a data configuration compatible with the PVNet model whose outputs will be fed into the summation model.
+     - Set `train_period` and `val_period` to control the time ranges of the train and validation periods.
+     - Optionally set `max_num_train_samples` and `max_num_val_samples` to limit the number of train and validation examples which will be used.
+
+ 3. In `configs/model/default.yaml`:
+     - Update the hyperparameters and structure of the summation model.
+ 4. In `configs/trainer/default.yaml`:
+     - Set `accelerator: 0` if running on a system without a supported GPU.
+
+
+ Assuming you have updated the configs, you should now be able to run:
+
+ ```
+ python run.py
+ ```
+
+
+ ## Testing
+
+ You can use `python -m pytest tests` to run the tests.
@@ -0,0 +1,23 @@
+ LICENSE
+ README.md
+ pyproject.toml
+ PVNet_summation.egg-info/PKG-INFO
+ PVNet_summation.egg-info/SOURCES.txt
+ PVNet_summation.egg-info/dependency_links.txt
+ PVNet_summation.egg-info/requires.txt
+ PVNet_summation.egg-info/top_level.txt
+ pvnet_summation/__init__.py
+ pvnet_summation/load_model.py
+ pvnet_summation/optimizers.py
+ pvnet_summation/utils.py
+ pvnet_summation/data/__init__.py
+ pvnet_summation/data/datamodule.py
+ pvnet_summation/models/__init__.py
+ pvnet_summation/models/base_model.py
+ pvnet_summation/models/dense_model.py
+ pvnet_summation/models/horizon_dense_model.py
+ pvnet_summation/training/__init__.py
+ pvnet_summation/training/lightning_module.py
+ pvnet_summation/training/plots.py
+ pvnet_summation/training/train.py
+ tests/test_end2end.py
@@ -0,0 +1,16 @@
+ pvnet>=5.3.0
+ ocf-data-sampler>=1.0.0
+ numpy
+ pandas
+ matplotlib
+ xarray
+ torch>=2.0.0
+ lightning
+ typer
+ wandb
+ huggingface-hub
+ tqdm
+ omegaconf
+ hydra-core
+ rich
+ safetensors
@@ -0,0 +1 @@
+ pvnet_summation
@@ -0,0 +1,74 @@
+ # PVNet summation
+ [![ease of contribution: hard](https://img.shields.io/badge/ease%20of%20contribution:%20hard-bb2629)](https://github.com/openclimatefix/ocf-meta-repo?tab=readme-ov-file#overview-of-ocfs-nowcasting-repositories)
+
+ This project is used for training a model to sum the GSP predictions of [PVNet](https://github.com/openclimatefix/pvnet) into a national estimate.
+
+ Using the summation model to combine the GSP predictions, rather than taking a simple sum, increases the accuracy of the national predictions, and the model can be configured to produce estimates of the uncertainty range of the national estimate. See the [PVNet](https://github.com/openclimatefix/pvnet) repo for more details and our paper.
+
+
+ ## Setup / Installation
+
+ ```bash
+ git clone https://github.com/openclimatefix/PVNet_summation
+ cd PVNet_summation
+ pip install .
+ ```
+
+ ### Additional development dependencies
+
+ ```bash
+ pip install ".[dev]"
+ ```
+
+ ## Getting started with running PVNet summation
+
+ To run PVNet summation, we assume that you are already set up with
+ [PVNet](https://github.com/openclimatefix/pvnet) and have a trained PVNet model available, either locally or pushed to HuggingFace.
+
+ Before running any code, copy the example configuration to a configs directory:
+
+ ```
+ cp -r configs.example configs
+ ```
+
+ You will be making local amendments to these configs.
+
+ ### Datasets
+
+ The datasets required are the same as documented in
+ [PVNet](https://github.com/openclimatefix/pvnet). The only addition is that you will need PVLive
+ data for the national sum, i.e. GSP ID 0.
+
+
+ ### Training PVNet_summation
+
+ How PVNet_summation runs is determined by the extensive configuration in the config files. Example
+ configs are stored in `configs.example`.
+
+ Make sure to update the following config files before training your model:
+
+
+ 1. At the very start of training, we loop over all of the input samples and make predictions for them using PVNet. These predictions are saved to disk and loaded in the training loop for more efficient training. In `configs/config.yaml`, update `sample_save_dir` to set where the predictions will be saved.
+
+ 2. In `configs/datamodule/default.yaml` (see the Python sketch after this list):
+     - Update `pvnet_model.model_id` and `pvnet_model.revision` to point to the HuggingFace commit or local directory where the exported PVNet model is.
+     - Update `configuration` to point to a data configuration compatible with the PVNet model whose outputs will be fed into the summation model.
+     - Set `train_period` and `val_period` to control the time ranges of the train and validation periods.
+     - Optionally set `max_num_train_samples` and `max_num_val_samples` to limit the number of train and validation examples which will be used.
+
+ 3. In `configs/model/default.yaml`:
+     - Update the hyperparameters and structure of the summation model.
+ 4. In `configs/trainer/default.yaml`:
+     - Set `accelerator: 0` if running on a system without a supported GPU.
+
+
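+ For orientation, the datamodule settings above map onto the `StreamedDataModule` class in `pvnet_summation/data/datamodule.py`. Below is a minimal sketch of constructing it directly in Python; the config path and date ranges are placeholders, not values shipped with this package:
+
+ ```python
+ from pvnet_summation.data import StreamedDataModule
+
+ datamodule = StreamedDataModule(
+     configuration="path/to/data_config.yaml",  # placeholder: must match the PVNet model's data config
+     train_period=["2020-01-01", "2022-12-31"],  # placeholder date ranges
+     val_period=["2023-01-01", "2023-12-31"],
+ )
+ datamodule.setup()
+ sample = next(iter(datamodule.train_dataloader()))  # one concurrent-locations sample dict
+ ```
+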
+ Assuming you have updated the configs, you should now be able to run:
+
+ ```
+ python run.py
+ ```
+
+
+ ## Testing
+
+ You can use `python -m pytest tests` to run the tests.
@@ -0,0 +1 @@
+ """PVNet_summation"""
@@ -0,0 +1,2 @@
+ """Data module"""
+ from .datamodule import PresavedDataModule, StreamedDataModule
@@ -0,0 +1,310 @@
+ """Pytorch lightning datamodules for loading pre-saved samples and predictions."""
+
+ import os
+ from glob import glob
+ from typing import TypeAlias
+
+ import numpy as np
+ import pandas as pd
+ import torch
+ from lightning.pytorch import LightningDataModule
+ from ocf_data_sampler.load.generation import open_generation
+ from ocf_data_sampler.numpy_sample.common_types import NumpyBatch
+ from ocf_data_sampler.numpy_sample.sun_position import calculate_azimuth_and_elevation
+ from ocf_data_sampler.torch_datasets.pvnet_dataset import PVNetConcurrentDataset
+ from ocf_data_sampler.utils import minutes
+ from torch.utils.data import DataLoader, Dataset, Subset, default_collate
+ from typing_extensions import override
+
+ SumNumpySample: TypeAlias = dict[str, np.ndarray | NumpyBatch]
+ SumTensorBatch: TypeAlias = dict[str, torch.Tensor]
+
+
+ def construct_sample(
+     pvnet_inputs: NumpyBatch,
+     valid_times: pd.DatetimeIndex,
+     relative_capacities: np.ndarray,
+     target: np.ndarray | None,
+     longitude: float,
+     latitude: float,
+     last_outturn: float | None = None,
+ ) -> SumNumpySample:
+     """Construct an input sample for the summation model.
+
+     Args:
+         pvnet_inputs: The PVNet batch for all locations
+         valid_times: An array of valid-times for the forecast
+         relative_capacities: Array of capacities of all locations normalised by the total capacity
+         target: The target national outturn. This is only needed during training.
+         longitude: The longitude of the national centroid
+         latitude: The latitude of the national centroid
+         last_outturn: The previous national outturn. This is only needed during training.
+
+     """
+
+     azimuth, elevation = calculate_azimuth_and_elevation(valid_times, longitude, latitude)
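+     # azimuth and elevation are in degrees; the divisions below rescale them to
+     # approximately [0, 1] before they are passed to the model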
+
+     sample = {
+         # NumpyBatch object with batch size = num_locations
+         "pvnet_inputs": pvnet_inputs,
+         # Shape: [time]
+         "valid_times": valid_times.values.astype(int),
+         # Shape: [num_locations]
+         "relative_capacity": relative_capacities,
+         # Shape: [time]
+         "azimuth": azimuth.astype(np.float32) / 360,
+         # Shape: [time]
+         "elevation": elevation.astype(np.float32) / 180 + 0.5,
+     }
+
+     if target is not None:
+         # Shape: [time]
+         sample["target"] = target
+     if last_outturn is not None:
+         # Shape: scalar
+         sample["last_outturn"] = last_outturn
+     return sample
+
+
+ class StreamedDataset(PVNetConcurrentDataset):
+     """A torch dataset for creating concurrent PVNet inputs and national targets."""
+
+     def __init__(
+         self,
+         config_filename: str,
+         start_time: str | None = None,
+         end_time: str | None = None,
+     ) -> None:
+         """A torch dataset for creating concurrent PVNet inputs and national targets.
+
+         Args:
+             config_filename: Path to the configuration file
+             start_time: Limit the init-times to be after this
+             end_time: Limit the init-times to be before this
+         """
+         super().__init__(config_filename, start_time, end_time)
+
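+         # location_id 0 is the national total (GSP ID 0); it supplies the targets and
+         # the total capacity used for normalisation below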
+         self.national_data = (
+             open_generation(
+                 zarr_path=self.config.input_data.generation.zarr_path,
+             )
+             .sel(location_id=0)
+             .compute()
+         )
+
+         self.longitude = self.national_data.longitude.item()
+         self.latitude = self.national_data.latitude.item()
+
+     def _get_sample(self, t0: pd.Timestamp) -> SumNumpySample:
+         """Generate a concurrent PVNet sample for a given init-time.
+
+         Args:
+             t0: Init-time for the sample
+         """
+
+         # Get the PVNet input batch
+         pvnet_inputs: NumpyBatch = super()._get_sample(t0)
+
+         # Construct an array of valid times for each forecast horizon
+         valid_times = pd.date_range(
+             t0 + minutes(self.config.input_data.generation.time_resolution_minutes),
+             t0 + minutes(self.config.input_data.generation.interval_end_minutes),
+             freq=minutes(self.config.input_data.generation.time_resolution_minutes),
+         )
+
+         # Get the regional and national capacities
+         location_capacities = pvnet_inputs["capacity_mwp"]
+         total_capacity = self.national_data.sel(time_utc=t0).capacity_mwp.item()
+
+         # Calculate the required inputs for the sample
+         relative_capacities = location_capacities / total_capacity
+         target = self.national_data.sel(time_utc=valid_times).values / total_capacity
+         last_outturn = self.national_data.sel(time_utc=t0).values / total_capacity
+
+         return construct_sample(
+             pvnet_inputs=pvnet_inputs,
+             valid_times=valid_times,
+             relative_capacities=relative_capacities,
+             target=target,
+             longitude=self.longitude,
+             latitude=self.latitude,
+             last_outturn=last_outturn,
+         )
+
+     @override
+     def __getitem__(self, idx: int) -> SumNumpySample:
+         return super().__getitem__(idx)
+
+     @override
+     def get_sample(self, t0: pd.Timestamp) -> SumNumpySample:
+         return super().get_sample(t0)
+
+
+ class StreamedDataModule(LightningDataModule):
+     """Datamodule for training pvnet_summation."""
+
+     def __init__(
+         self,
+         configuration: str,
+         train_period: list[str | None] = [None, None],
+         val_period: list[str | None] = [None, None],
+         num_workers: int = 0,
+         prefetch_factor: int | None = None,
+         persistent_workers: bool = False,
+         seed: int | None = None,
+         dataset_pickle_dir: str | None = None,
+     ):
+         """Datamodule for creating concurrent PVNet inputs and national targets.
+
+         Args:
+             configuration: Path to ocf-data-sampler configuration file.
+             train_period: Date range filter for train dataloader.
+             val_period: Date range filter for val dataloader.
+             num_workers: Number of workers to use in multiprocess batch loading.
+             prefetch_factor: Number of samples loaded in advance by each worker.
+             persistent_workers: If True, the data loader will not shut down the worker processes
+                 after a dataset has been consumed once, keeping the worker Dataset instances
+                 alive.
+             seed: Random seed used in shuffling datasets.
+             dataset_pickle_dir: Directory in which the train and val datasets will be pre-saved
+                 as pickle objects. Setting this greatly speeds up the instantiation of multiple
+                 workers.
+         """
+         super().__init__()
+         self.configuration = configuration
+         self.train_period = train_period
+         self.val_period = val_period
+         self.seed = seed
+         self.dataset_pickle_dir = dataset_pickle_dir
+
+         self._dataloader_kwargs = dict(
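+             # batch_size=None disables the DataLoader's automatic batching: each dataset
+             # sample already holds a "batch" of PVNet inputs across all locations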
+             batch_size=None,
+             batch_sampler=None,
+             num_workers=num_workers,
+             collate_fn=None,
+             pin_memory=False,
+             drop_last=False,
+             timeout=0,
+             worker_init_fn=None,
+             prefetch_factor=prefetch_factor,
+             persistent_workers=persistent_workers,
+             multiprocessing_context="spawn" if num_workers > 0 else None,
+         )
+
+     def setup(self, stage: str | None = None):
+         """Called once to prepare the datasets."""
+
+         # This logic runs only once at the start of training, therefore the val dataset is only
+         # shuffled once
+         if self.dataset_pickle_dir is not None:
+             os.makedirs(self.dataset_pickle_dir, exist_ok=True)
+
+             train_dataset_path = f"{self.dataset_pickle_dir}/train_dataset.pkl"
+             val_dataset_path = f"{self.dataset_pickle_dir}/val_dataset.pkl"
+
+             # For safety, these pickled datasets cannot be overwritten.
+             # See: https://github.com/openclimatefix/pvnet/pull/445
+             for path in [train_dataset_path, val_dataset_path]:
+                 if os.path.exists(path):
+                     raise FileExistsError(
+                         f"The pickled dataset path '{path}' already exists. Make sure that "
+                         "this can be safely deleted (i.e. not currently being used by any "
+                         "training run) and delete it manually. Otherwise change the "
+                         "`dataset_pickle_dir` to a different directory."
+                     )
+
+         # Prepare the train dataset
+         self.train_dataset = StreamedDataset(self.configuration, *self.train_period)
+
+         # Prepare and pre-shuffle the val dataset and set seed for reproducibility
+         val_dataset = StreamedDataset(self.configuration, *self.val_period)
+         shuffled_indices = np.random.default_rng(seed=self.seed).permutation(len(val_dataset))
+         self.val_dataset = Subset(val_dataset, shuffled_indices)
+
+         if self.dataset_pickle_dir is not None:
+             self.train_dataset.presave_pickle(train_dataset_path)
+             val_dataset.presave_pickle(val_dataset_path)
+
+     def teardown(self, stage: str | None = None) -> None:
+         """Clean up the pickled datasets"""
+         if self.dataset_pickle_dir is not None:
+             for filename in ["val_dataset.pkl", "train_dataset.pkl"]:
+                 filepath = f"{self.dataset_pickle_dir}/{filename}"
+                 if os.path.exists(filepath):
+                     os.remove(filepath)
+
+     def train_dataloader(self, shuffle: bool = False) -> DataLoader:
+         """Construct train dataloader"""
+         return DataLoader(self.train_dataset, shuffle=shuffle, **self._dataloader_kwargs)
+
+     def val_dataloader(self, shuffle: bool = False) -> DataLoader:
+         """Construct val dataloader"""
+         return DataLoader(self.val_dataset, shuffle=shuffle, **self._dataloader_kwargs)
+
+
+ class PresavedDataset(Dataset):
+     """Dataset for loading pre-saved PVNet predictions from disk"""
+
+     def __init__(self, sample_dir: str):
+         """Dataset for loading pre-saved PVNet predictions from disk.
+
+         Args:
+             sample_dir: The directory containing the saved samples
+         """
+         self.sample_filepaths = sorted(glob(f"{sample_dir}/*.pt"))
+
+     def __len__(self) -> int:
+         return len(self.sample_filepaths)
+
+     def __getitem__(self, idx: int) -> dict:
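+         # weights_only=True restricts torch.load to tensors and primitive containers,
+         # avoiding arbitrary code execution during un-pickling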
+         return torch.load(self.sample_filepaths[idx], weights_only=True)
+
+
+ class PresavedDataModule(LightningDataModule):
+     """Datamodule for loading pre-saved PVNet predictions."""
+
+     def __init__(
+         self,
+         sample_dir: str,
+         batch_size: int = 16,
+         num_workers: int = 0,
+         prefetch_factor: int | None = None,
+         persistent_workers: bool = False,
+     ):
+         """Datamodule for loading pre-saved PVNet predictions.
+
+         Args:
+             sample_dir: Path to the directory of pre-saved samples.
+             batch_size: Batch size.
+             num_workers: Number of workers to use in multiprocess batch loading.
+             prefetch_factor: Number of samples loaded in advance by each worker.
+             persistent_workers: If True, the data loader will not shut down the worker processes
+                 after a dataset has been consumed once, keeping the worker Dataset instances
+                 alive.
+         """
+         super().__init__()
+         self.sample_dir = sample_dir
+
+         self._dataloader_kwargs = dict(
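+             # Unlike the streamed datamodule, each pre-saved sample is a single example,
+             # so samples are collated into batches of `batch_size` here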
+             batch_size=batch_size,
+             sampler=None,
+             batch_sampler=None,
+             num_workers=num_workers,
+             collate_fn=None if batch_size is None else default_collate,
+             pin_memory=False,
+             drop_last=False,
+             timeout=0,
+             worker_init_fn=None,
+             prefetch_factor=prefetch_factor,
+             persistent_workers=persistent_workers,
+             multiprocessing_context="spawn" if num_workers > 0 else None,
+         )
+
+     def train_dataloader(self, shuffle: bool = True) -> DataLoader:
+         """Construct train dataloader"""
+         dataset = PresavedDataset(f"{self.sample_dir}/train")
+         return DataLoader(dataset, shuffle=shuffle, **self._dataloader_kwargs)
+
+     def val_dataloader(self, shuffle: bool = False) -> DataLoader:
+         """Construct val dataloader"""
+         dataset = PresavedDataset(f"{self.sample_dir}/val")
+         return DataLoader(dataset, shuffle=shuffle, **self._dataloader_kwargs)