PVNet_summation 1.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,216 @@
1
+ """Training"""
2
+ import logging
3
+ import os
4
+
5
+ import hydra
6
+ import torch
7
+ from lightning.pytorch import Callback, Trainer, seed_everything
8
+ from lightning.pytorch.callbacks import ModelCheckpoint
9
+ from lightning.pytorch.loggers import Logger, WandbLogger
10
+ from ocf_data_sampler.torch_datasets.utils.torch_batch_utils import (
11
+ batch_to_tensor,
12
+ copy_batch_to_device,
13
+ )
14
+ from omegaconf import DictConfig, OmegaConf
15
+ from pvnet.models import BaseModel as PVNetBaseModel
16
+ from tqdm import tqdm
17
+
18
+ from pvnet_summation.data.datamodule import PresavedDataModule, StreamedDataModule
19
+ from pvnet_summation.utils import (
20
+ DATAMODULE_CONFIG_NAME,
21
+ FULL_CONFIG_NAME,
22
+ MODEL_CONFIG_NAME,
23
+ create_pvnet_model_config,
24
+ )
25
+
26
+ log = logging.getLogger(__name__)
27
+
28
+
29
+ def resolve_monitor_loss(output_quantiles: list | None) -> str:
30
+ """Return the desired metric to monitor based on whether quantile regression is being used.
31
+
32
+ Adds the option to use
33
+ monitor: "${resolve_monitor_loss:${model.model.output_quantiles}}"
34
+ in early stopping and model checkpoint callbacks so the callbacks config does not need to be
35
+ modified depending on whether quantile regression is being used or not.
36
+ """
37
+ if output_quantiles is None:
38
+ return "MAE/val"
39
+ else:
40
+ return "quantile_loss/val"
41
+
42
+
43
+ OmegaConf.register_new_resolver("resolve_monitor_loss", resolve_monitor_loss)
44
+
45
+
46
def train(config: DictConfig) -> None:
    """Contains training pipeline.

    Instantiates all PyTorch Lightning objects from config, pre-computes and caches the
    PVNet predictions that the summation model consumes, then trains the summation model
    on the cached outputs.

    Args:
        config (DictConfig): Configuration composed by Hydra.
    """

    # Get the pre-trained PVNet model whose outputs the summation model will combine
    pvnet_model = PVNetBaseModel.from_pretrained(
        model_id=config.datamodule.pvnet_model.model_id,
        revision=config.datamodule.pvnet_model.revision
    )

    # Enable adding new keys to config
    OmegaConf.set_struct(config, False)
    # Set summation model parameters to align with the input PVNet model
    config.model.model.history_minutes = pvnet_model.history_minutes
    config.model.model.forecast_minutes = pvnet_model.forecast_minutes
    config.model.model.interval_minutes = pvnet_model.interval_minutes
    config.model.model.num_input_locations = len(pvnet_model.location_id_mapping)
    config.model.model.input_quantiles = pvnet_model.output_quantiles
    OmegaConf.set_struct(config, True)

    # Set seed for random number generators in pytorch, numpy and python.random
    if "seed" in config:
        seed_everything(config.seed, workers=True)

    # Compute and save the PVNet predictions before training the summation model.
    # Outputs are cached per (model_id, revision) so repeat runs can reuse them.
    save_dir = (
        f"{config.sample_save_dir}/{config.datamodule.pvnet_model.model_id}"
        f"/{config.datamodule.pvnet_model.revision}"
    )

    if os.path.isdir(save_dir):
        log.info(
            f"PVNet output directory already exists: {save_dir}\n"
            "Skipping saving new outputs. The existing saved outputs will be loaded."
        )
    else:
        log.info(f"Saving PVNet outputs to {save_dir}")

        # Move to device and disable gradients for inference
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        pvnet_model.to(device).requires_grad_(False)

        # These would raise if the directories existed, but the isdir() check above
        # guarantees they don't
        os.makedirs(f"{save_dir}/train")
        os.makedirs(f"{save_dir}/val")

        pvnet_data_config_path = f"{save_dir}/pvnet_data_config.yaml"

        data_source_paths = OmegaConf.to_container(
            config.datamodule.data_source_paths,
            resolve=True,
        )

        # Write a data config pointing the PVNet model at the local data sources
        create_pvnet_model_config(
            save_path=pvnet_data_config_path,
            repo=config.datamodule.pvnet_model.model_id,
            commit=config.datamodule.pvnet_model.revision,
            data_source_paths=data_source_paths,
        )

        datamodule = StreamedDataModule(
            configuration=pvnet_data_config_path,
            num_workers=config.datamodule.num_workers,
            prefetch_factor=config.datamodule.prefetch_factor,
            train_period=config.datamodule.train_period,
            val_period=config.datamodule.val_period,
            persistent_workers=False,
            seed=config.datamodule.seed,
            dataset_pickle_dir=config.datamodule.dataset_pickle_dir,
        )

        datamodule.setup()

        for dataloader_func, max_num_samples, split in [
            (datamodule.train_dataloader, config.datamodule.max_num_train_samples, "train",),
            (datamodule.val_dataloader, config.datamodule.max_num_val_samples, "val"),
        ]:

            log.info(f"Saving {split} outputs")
            # NOTE(review): shuffle=True is applied to the val split too — presumably so
            # that a random subset is kept when max_num_val_samples truncates it; confirm
            dataloader = dataloader_func(shuffle=True)

            # If max_num_samples set to None use all samples
            max_num_samples = max_num_samples or len(dataloader)

            for i, sample in tqdm(zip(range(max_num_samples), dataloader), total=max_num_samples):
                # Run PVNet inputs through model
                x = copy_batch_to_device(batch_to_tensor(sample["pvnet_inputs"]), device)
                pvnet_outputs = pvnet_model(x).detach().cpu()

                # Create version of sample without the PVNet inputs and save
                sample_to_save = {k: v.clone() for k, v in sample.items() if k!="pvnet_inputs"}

                sample_to_save["pvnet_outputs"] = pvnet_outputs
                torch.save(sample_to_save, f"{save_dir}/{split}/{i:06}.pt")

            del dataloader

        datamodule.teardown()

    # Train the summation model on the pre-saved PVNet outputs
    datamodule = PresavedDataModule(
        sample_dir=save_dir,
        batch_size=config.datamodule.batch_size,
        num_workers=config.datamodule.num_workers,
        prefetch_factor=config.datamodule.prefetch_factor,
        persistent_workers=config.datamodule.persistent_workers,
    )

    # Init lightning loggers
    loggers: list[Logger] = []
    if "logger" in config:
        for _, lg_conf in config.logger.items():
            loggers.append(hydra.utils.instantiate(lg_conf))

    # Init lightning callbacks
    callbacks: list[Callback] = []
    if "callbacks" in config:
        for _, cb_conf in config.callbacks.items():
            callbacks.append(hydra.utils.instantiate(cb_conf))

    # Align the wandb id with the checkpoint path
    # - only works if wandb logger and model checkpoint used
    # - this makes it easy to push the model to huggingface
    use_wandb_logger = False
    for logger in loggers:
        if isinstance(logger, WandbLogger):
            use_wandb_logger = True
            wandb_logger = logger
            break

    # Set the output directory based on the wandb-id of the run
    if use_wandb_logger:
        for callback in callbacks:
            if isinstance(callback, ModelCheckpoint):
                # Calling the .experiment property instantiates a wandb run
                wandb_id = wandb_logger.experiment.id

                # Save the run results to the expected parent folder but with the folder name
                # set by the wandb ID
                save_dir = f"{os.path.dirname(callback.dirpath)}/{wandb_id}"

                callback.dirpath = save_dir

                # Save the model config
                os.makedirs(save_dir, exist_ok=True)
                OmegaConf.save(config.model, f"{save_dir}/{MODEL_CONFIG_NAME}")

                # Save the datamodule config
                OmegaConf.save(config.datamodule, f"{save_dir}/{DATAMODULE_CONFIG_NAME}")

                # Save the full hydra config to the output directory and to wandb
                OmegaConf.save(config, f"{save_dir}/{FULL_CONFIG_NAME}")
                wandb_logger.experiment.save(f"{save_dir}/{FULL_CONFIG_NAME}", base_path=save_dir)


    # Init lightning model
    model = hydra.utils.instantiate(config.model)

    trainer: Trainer = hydra.utils.instantiate(
        config.trainer,
        logger=loggers,
        _convert_="partial",
        callbacks=callbacks,
    )

    # Train the model completely
    trainer.fit(model=model, datamodule=datamodule)
@@ -0,0 +1,132 @@
1
+ """Utils"""
2
+ import logging
3
+
4
+ import rich.syntax
5
+ import rich.tree
6
+ import yaml
7
+ from lightning.pytorch.utilities import rank_zero_only
8
+ from omegaconf import DictConfig, OmegaConf
9
+ from pvnet.models.base_model import BaseModel as PVNetBaseModel
10
+
11
logger = logging.getLogger(__name__)


# Canonical filenames used when saving and loading model artifacts, shared across
# the training and model-export code paths
PYTORCH_WEIGHTS_NAME = "model_weights.safetensors"
MODEL_CONFIG_NAME = "model_config.yaml"
DATAMODULE_CONFIG_NAME = "datamodule_config.yaml"
FULL_CONFIG_NAME = "full_experiment_config.yaml"
MODEL_CARD_NAME = "README.md"
19
+
20
+
21
+
22
def maybe_apply_debug_mode(config: DictConfig) -> None:
    """Check if a debugging run is requested and force a debug-friendly configuration.

    Controlled by the main config file.

    Modifies DictConfig in place.

    Args:
        config (DictConfig): Configuration composed by Hydra.
    """

    # Temporarily allow new keys to be added to the config
    OmegaConf.set_struct(config, False)

    # Force debugger friendly configuration if <config.trainer.fast_dev_run=True>
    if config.trainer.get("fast_dev_run"):
        logger.info("Forcing debugger friendly configuration! <config.trainer.fast_dev_run=True>")

        trainer_conf = config.trainer
        datamodule_conf = config.datamodule

        # Debuggers don't play well with GPUs or multiprocessing, so disable both
        if trainer_conf.get("gpus"):
            trainer_conf.gpus = 0
        if datamodule_conf.get("pin_memory"):
            datamodule_conf.pin_memory = False
        if datamodule_conf.get("num_workers"):
            datamodule_conf.num_workers = 0
        if datamodule_conf.get("prefetch_factor"):
            datamodule_conf.prefetch_factor = None

    # Disable adding new keys to config again
    OmegaConf.set_struct(config, True)
51
+
52
+
53
@rank_zero_only
def print_config(
    config: DictConfig,
    fields: tuple[str, ...] = (
        "trainer",
        "model",
        "datamodule",
        "callbacks",
        "logger",
        "seed",
    ),
    resolve: bool = True,
) -> None:
    """Prints content of DictConfig using Rich library and its tree structure.

    Args:
        config (DictConfig): Configuration composed by Hydra.
        fields (tuple[str, ...], optional): Determines which main fields from config will
            be printed and in what order.
        resolve (bool, optional): Whether to resolve reference fields of DictConfig.
    """

    tree_style = "dim"
    tree = rich.tree.Tree("CONFIG", style=tree_style, guide_style=tree_style)

    for field in fields:
        branch = tree.add(field, style=tree_style, guide_style=tree_style)

        section = config.get(field)
        # DictConfig sections are rendered as YAML; anything else falls back to str()
        if isinstance(section, DictConfig):
            rendered = OmegaConf.to_yaml(section, resolve=resolve)
        else:
            rendered = str(section)

        branch.add(rich.syntax.Syntax(rendered, "yaml"))

    rich.print(tree)
89
+
90
def populate_config_with_data_data_filepaths(config: dict, data_source_paths: dict) -> dict:
    """Populate the data source filepaths in the config.

    Args:
        config: The data config
        data_source_paths: A dictionary of data paths for the different input sources

    Returns:
        The config with its data source paths filled in (also modified in place).
    """

    input_data = config["input_data"]

    # The GSP data path is always replaced
    input_data["gsp"]["zarr_path"] = data_source_paths["gsp"]

    # The satellite path is only replaced when a satellite source is configured with a
    # non-empty path
    if "satellite" in input_data and input_data["satellite"]["zarr_path"] != "":
        input_data["satellite"]["zarr_path"] = data_source_paths["satellite"]

    # The NWP section is nested one level deeper, so each source must be handled in turn
    if "nwp" in input_data:
        for source_config in input_data["nwp"].values():
            provider = source_config["provider"]
            assert provider in data_source_paths["nwp"], f"Missing NWP path: {provider}"
            source_config["zarr_path"] = data_source_paths["nwp"][provider]

    return config
115
+
116
+
117
def create_pvnet_model_config(
    save_path: str,
    repo: str,
    commit: str,
    data_source_paths: dict,
) -> None:
    """Create the data config needed to run the PVNet model.

    Args:
        save_path: Path where the populated data config will be written.
        repo: The huggingface repo ID (or local directory) of the PVNet model.
        commit: The model revision/commit whose data config should be used.
        data_source_paths: A dictionary of data paths for the different input sources.
    """

    # Fetch the data config the PVNet model was trained with
    data_config_path = PVNetBaseModel.get_data_config(repo, revision=commit)

    with open(data_config_path) as file:
        data_config = yaml.load(file, Loader=yaml.FullLoader)

    # Point the config at the locally available data sources
    populated_config = populate_config_with_data_data_filepaths(data_config, data_source_paths)

    # Write the populated config back out for the datamodule to consume
    with open(save_path, "w") as file:
        yaml.dump(populated_config, file, default_flow_style=False)
@@ -0,0 +1,100 @@
1
+ Metadata-Version: 2.4
2
+ Name: PVNet_summation
3
+ Version: 1.1.2
4
+ Summary: PVNet_summation
5
+ Author-email: James Fulton <info@openclimatefix.org>
6
+ Requires-Python: <3.14,>=3.11
7
+ Description-Content-Type: text/markdown
8
+ License-File: LICENSE
9
+ Requires-Dist: pvnet>=5.0.0
10
+ Requires-Dist: ocf-data-sampler>=0.6.0
11
+ Requires-Dist: numpy
12
+ Requires-Dist: pandas
13
+ Requires-Dist: matplotlib
14
+ Requires-Dist: xarray
15
+ Requires-Dist: torch>=2.0.0
16
+ Requires-Dist: lightning
17
+ Requires-Dist: typer
18
+ Requires-Dist: wandb
19
+ Requires-Dist: huggingface-hub
20
+ Requires-Dist: tqdm
21
+ Requires-Dist: omegaconf
22
+ Requires-Dist: hydra-core
23
+ Requires-Dist: rich
24
+ Requires-Dist: safetensors
25
+ Dynamic: license-file
26
+
27
+ # PVNet summation
28
+ [![ease of contribution: hard](https://img.shields.io/badge/ease%20of%20contribution:%20hard-bb2629)](https://github.com/openclimatefix/ocf-meta-repo?tab=readme-ov-file#overview-of-ocfs-nowcasting-repositories)
29
+
30
+ This project is used for training a model to sum the GSP predictions of [PVNet](https://github.com/openclimatefix/pvnet) into a national estimate.
31
+
32
+ Using the summation model to sum the GSP predictions rather than doing a simple sum increases the accuracy of the national predictions and can be configured to produce estimates of the uncertainty range of the national estimate. See the [PVNet](https://github.com/openclimatefix/pvnet) repo for more details and our paper.
33
+
34
+
35
+ ## Setup / Installation
36
+
37
+ ```bash
38
+ git clone https://github.com/openclimatefix/PVNet_summation
39
+ cd PVNet_summation
40
+ pip install .
41
+ ```
42
+
43
+ ### Additional development dependencies
44
+
45
+ ```bash
46
+ pip install ".[dev]"
47
+ ```
48
+
49
+ ## Getting started with running PVNet summation
50
+
51
+ In order to run PVNet summation, we assume that you are already set up with
52
+ [PVNet](https://github.com/openclimatefix/pvnet) and have a trained PVNet model already available either locally or pushed to HuggingFace.
53
+
54
+ Before running any code, copy the example configuration to a configs directory:
55
+
56
+ ```
57
+ cp -r configs.example configs
58
+ ```
59
+
60
+ You will be making local amendments to these configs.
61
+
62
+ ### Datasets
63
+
64
+ The datasets required are the same as documented in
65
+ [PVNet](https://github.com/openclimatefix/pvnet). The only addition is that you will need PVLive
66
+ data for the national sum i.e. GSP ID 0.
67
+
68
+
69
+ ### Training PVNet_summation
70
+
71
+ How PVNet_summation is run is determined by the extensive configuration in the config files. The
72
+ configs are stored in `configs.example`.
73
+
74
+ Make sure to update the following config files before training your model:
75
+
76
+
77
+ 1. At the very start of training we loop over all of the input samples and make predictions for them using PVNet. These predictions are saved to disk and will be loaded in the training loop for more efficient training. In `configs/config.yaml` update `sample_save_dir` to set where the predictions will be saved to.
78
+
79
+ 2. In `configs/datamodule/default.yaml`:
80
+ - Update `pvnet_model.model_id` and `pvnet_model.revision` to point to the Huggingface commit or local directory where the exported PVNet model is.
81
+ - Update `configuration` to point to a data configuration compatible with the PVNet model whose outputs will be fed into the summation model.
82
+ - Set `train_period` and `val_period` to control the time ranges of the train and val period
83
+ - Optionally set `max_num_train_samples` and `max_num_val_samples` to limit the number of possible train and validation examples which will be used.
84
+
85
+ 3. In `configs/model/default.yaml`:
86
+ - Update the hyperparameters and structure of the summation model
87
+ 4. In `configs/trainer/default.yaml`:
88
+ - Set `accelerator: 0` if running on a system without a supported GPU
89
+
90
+
91
+ Assuming you have updated the configs, you should now be able to run:
92
+
93
+ ```
94
+ python run.py
95
+ ```
96
+
97
+
98
+ ## Testing
99
+
100
+ You can use `python -m pytest tests` to run tests
@@ -0,0 +1,19 @@
1
+ pvnet_summation/__init__.py,sha256=8bjkx2pvF7lZ2W5BiTpHr7iqpkRXc3vW5K1pxJAWaj0,22
2
+ pvnet_summation/load_model.py,sha256=mQJXJ9p8wb25CVsm5UBGb0IL6xGZj-81iIBKHsNdQMY,2515
3
+ pvnet_summation/optimizers.py,sha256=kuR3PUnISiAO5bSaKhq_7vqRKZ0gO5cRS4UbjmKgq1c,6472
4
+ pvnet_summation/utils.py,sha256=JyqzDQjABCtRsdLgxr5j9K9AdmNlQhmYGenj6mKGnFY,4352
5
+ pvnet_summation/data/__init__.py,sha256=AYJFlJ3KaAQXED0PxuuknI2lKEeFMFLJiJ9b6-H8398,81
6
+ pvnet_summation/data/datamodule.py,sha256=Pa2iip-ALihhkAVtqDBPJZ93vh4evJwG9L9YCJiRQag,12517
7
+ pvnet_summation/models/__init__.py,sha256=v3KMMH_bz9YGUFWsrb5Ndg-d_dgxQPw7yiFahQAag4c,103
8
+ pvnet_summation/models/base_model.py,sha256=mxrEq8k6NAVpezLx3ORPM33OrXzRccVD2ErFkPIw8bc,12496
9
+ pvnet_summation/models/dense_model.py,sha256=vh3Hrm-n7apgVkta_RtQ5mdxb6jiJNFm3ObWukSBgdU,2305
10
+ pvnet_summation/models/horizon_dense_model.py,sha256=8NfJiO4upQT8ksqwDn1Jkct5-nrbs_EKfKBseVRay1U,7011
11
+ pvnet_summation/training/__init__.py,sha256=2fbydXPJFk527DUGPlNV0Teaqvu4WNp8hgcODwHJFEw,110
12
+ pvnet_summation/training/lightning_module.py,sha256=IMwayobtjA69Blz8v6dxhG31-GgovB9kBqUZJ5A5qRA,9926
13
+ pvnet_summation/training/plots.py,sha256=wjiNh1bH6FQa9rf4Y9Xtp1jyks1bzGJG2-8936I_Dk0,2475
14
+ pvnet_summation/training/train.py,sha256=ze4LCr4XvJ18NjiZhR9KslVf_5HoC1xjGIhBcfw8u5E,8000
15
+ pvnet_summation-1.1.2.dist-info/licenses/LICENSE,sha256=F-Q3UFCR-BECSocV55BFDpn4YKxve9PKrm-lTt6o_Tg,1073
16
+ pvnet_summation-1.1.2.dist-info/METADATA,sha256=zTBNEYtw5n-s_kAloLYcQPka1Ql_9uMbw5zRIYadeiM,3726
17
+ pvnet_summation-1.1.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
18
+ pvnet_summation-1.1.2.dist-info/top_level.txt,sha256=5fWJ75RKtpaHUdLG_-2oDCInXeq4r1aMCxkZp5Wy-LQ,16
19
+ pvnet_summation-1.1.2.dist-info/RECORD,,
@@ -0,0 +1,5 @@
1
+ Wheel-Version: 1.0
2
+ Generator: setuptools (80.9.0)
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
5
+
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2023 Open Climate Fix
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1 @@
1
+ pvnet_summation