PVNet_summation 1.0.0 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


@@ -0,0 +1,185 @@
+ """Training"""
+ import logging
+ import os
+
+ import hydra
+ import torch
+ from lightning.pytorch import Callback, Trainer, seed_everything
+ from lightning.pytorch.callbacks import ModelCheckpoint
+ from lightning.pytorch.loggers import Logger, WandbLogger
+ from ocf_data_sampler.torch_datasets.sample.base import batch_to_tensor, copy_batch_to_device
+ from omegaconf import DictConfig, OmegaConf
+ from pvnet.models import BaseModel as PVNetBaseModel
+ from tqdm import tqdm
+
+ from pvnet_summation.data.datamodule import PresavedDataModule, StreamedDataModule
+ from pvnet_summation.utils import DATAMODULE_CONFIG_NAME, FULL_CONFIG_NAME, MODEL_CONFIG_NAME
+
+ log = logging.getLogger(__name__)
+
+
+ def resolve_monitor_loss(output_quantiles: list | None) -> str:
+     """Return the desired metric to monitor based on whether quantile regression is being used.
+
+     This adds the option to use something like:
+     monitor: "${resolve_monitor_loss:${model.model.output_quantiles}}"
+
+     in early stopping and model checkpoint callbacks so the callbacks config does not need to be
+     modified depending on whether quantile regression is being used or not.
+     """
+     if output_quantiles is None:
+         return "MAE/val"
+     else:
+         return "quantile_loss/val"
+
+
+ OmegaConf.register_new_resolver("resolve_monitor_loss", resolve_monitor_loss)
+
+
+ def train(config: DictConfig) -> None:
+     """Contains training pipeline.
+
+     Instantiates all PyTorch Lightning objects from config.
+
+     Args:
+         config (DictConfig): Configuration composed by Hydra.
+     """
+
+     # Get the PVNet model
+     pvnet_model = PVNetBaseModel.from_pretrained(
+         model_id=config.datamodule.pvnet_model.model_id,
+         revision=config.datamodule.pvnet_model.revision
+     )
+
+     # Enable adding new keys to config
+     OmegaConf.set_struct(config, False)
+     # Set summation model parameters to align with the input PVNet model
+     config.model.model.history_minutes = pvnet_model.history_minutes
+     config.model.model.forecast_minutes = pvnet_model.forecast_minutes
+     config.model.model.interval_minutes = pvnet_model.interval_minutes
+     config.model.model.num_input_locations = len(pvnet_model.location_id_mapping)
+     config.model.model.input_quantiles = pvnet_model.output_quantiles
+     OmegaConf.set_struct(config, True)
+
+     # Set seed for random number generators in pytorch, numpy and python.random
+     if "seed" in config:
+         seed_everything(config.seed, workers=True)
+
+     # Compute and save the PVNet predictions before training the summation model
+     save_dir = (
+         f"{config.sample_save_dir}/{config.datamodule.pvnet_model.model_id}"
+         f"/{config.datamodule.pvnet_model.revision}"
+     )
+
+     if os.path.isdir(save_dir):
+         log.info(
+             f"PVNet output directory already exists: {save_dir}\n"
+             "Skipping saving new outputs. The existing saved outputs will be loaded."
+         )
+     else:
+         log.info(f"Saving PVNet outputs to {save_dir}")
+
+         # Move to device and disable gradients for inference
+         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         pvnet_model.to(device).requires_grad_(False)
+
+         os.makedirs(f"{save_dir}/train")
+         os.makedirs(f"{save_dir}/val")
+
+         datamodule = StreamedDataModule(
+             configuration=config.datamodule.configuration,
+             num_workers=config.datamodule.num_workers,
+             prefetch_factor=config.datamodule.prefetch_factor,
+             train_period=config.datamodule.train_period,
+             val_period=config.datamodule.val_period,
+             persistent_workers=False,
+         )
+
+         for dataloader_func, max_num_samples, split in [
+             (datamodule.train_dataloader, config.datamodule.max_num_train_samples, "train"),
+             (datamodule.val_dataloader, config.datamodule.max_num_val_samples, "val"),
+         ]:
+
+             log.info(f"Saving {split} outputs")
+             dataloader = dataloader_func(shuffle=True)
+
+             for i, sample in tqdm(zip(range(max_num_samples), dataloader)):
+                 # Run PVNet inputs through model
+                 x = copy_batch_to_device(batch_to_tensor(sample["pvnet_inputs"]), device)
+                 pvnet_outputs = pvnet_model(x).detach().cpu()
+
+                 # Create version of sample without the PVNet inputs and save
+                 sample_to_save = {k: v.clone() for k, v in sample.items() if k != "pvnet_inputs"}
+
+                 sample_to_save["pvnet_outputs"] = pvnet_outputs
+                 torch.save(sample_to_save, f"{save_dir}/{split}/{i:06}.pt")
+
+             del dataloader
+
+     datamodule = PresavedDataModule(
+         sample_dir=save_dir,
+         batch_size=config.datamodule.batch_size,
+         num_workers=config.datamodule.num_workers,
+         prefetch_factor=config.datamodule.prefetch_factor,
+         persistent_workers=config.datamodule.persistent_workers,
+     )
+
+     # Init lightning loggers
+     loggers: list[Logger] = []
+     if "logger" in config:
+         for _, lg_conf in config.logger.items():
+             loggers.append(hydra.utils.instantiate(lg_conf))
+
+     # Init lightning callbacks
+     callbacks: list[Callback] = []
+     if "callbacks" in config:
+         for _, cb_conf in config.callbacks.items():
+             callbacks.append(hydra.utils.instantiate(cb_conf))
+
+     # Align the wandb id with the checkpoint path
+     # - only works if wandb logger and model checkpoint used
+     # - this makes it easy to push the model to huggingface
+     use_wandb_logger = False
+     for logger in loggers:
+         if isinstance(logger, WandbLogger):
+             use_wandb_logger = True
+             wandb_logger = logger
+             break
+
+     # Set the output directory based on the wandb ID of the run
+     if use_wandb_logger:
+         for callback in callbacks:
+             if isinstance(callback, ModelCheckpoint):
+                 # Calling the .experiment property instantiates a wandb run
+                 wandb_id = wandb_logger.experiment.id
+
+                 # Save the run results to the expected parent folder but with the folder name
+                 # set by the wandb ID
+                 save_dir = f"{os.path.dirname(callback.dirpath)}/{wandb_id}"
+
+                 callback.dirpath = save_dir
+
+                 # Save the model config
+                 os.makedirs(save_dir, exist_ok=True)
+                 OmegaConf.save(config.model, f"{save_dir}/{MODEL_CONFIG_NAME}")
+
+                 # Save the datamodule config
+                 OmegaConf.save(config.datamodule, f"{save_dir}/{DATAMODULE_CONFIG_NAME}")
+
+                 # Save the full hydra config to the output directory and to wandb
+                 OmegaConf.save(config, f"{save_dir}/{FULL_CONFIG_NAME}")
+                 wandb_logger.experiment.save(f"{save_dir}/{FULL_CONFIG_NAME}", base_path=save_dir)
+
+
+     # Init lightning model
+     model = hydra.utils.instantiate(config.model)
+
+     trainer: Trainer = hydra.utils.instantiate(
+         config.trainer,
+         logger=loggers,
+         _convert_="partial",
+         callbacks=callbacks,
+     )
+
+     # Train the model completely
+     trainer.fit(model=model, datamodule=datamodule)
@@ -0,0 +1,87 @@
+ """Utils"""
+ import logging
+
+ import rich.syntax
+ import rich.tree
+ from lightning.pytorch.utilities import rank_zero_only
+ from omegaconf import DictConfig, OmegaConf
+
+ logger = logging.getLogger(__name__)
+
+
+ PYTORCH_WEIGHTS_NAME = "model_weights.safetensors"
+ MODEL_CONFIG_NAME = "model_config.yaml"
+ DATAMODULE_CONFIG_NAME = "datamodule_config.yaml"
+ FULL_CONFIG_NAME = "full_experiment_config.yaml"
+ MODEL_CARD_NAME = "README.md"
+
+
+
+ def run_config_utilities(config: DictConfig) -> None:
+     """A couple of optional utilities.
+
+     Controlled by main config file:
+     - forcing debug friendly configuration
+
+     Modifies DictConfig in place.
+
+     Args:
+         config (DictConfig): Configuration composed by Hydra.
+     """
+
+     # Enable adding new keys to config
+     OmegaConf.set_struct(config, False)
+
+     # Force debugger friendly configuration if <config.trainer.fast_dev_run=True>
+     if config.trainer.get("fast_dev_run"):
+         logger.info("Forcing debugger friendly configuration! <config.trainer.fast_dev_run=True>")
+         # Debuggers don't like GPUs or multiprocessing
+         if config.trainer.get("gpus"):
+             config.trainer.gpus = 0
+         if config.datamodule.get("pin_memory"):
+             config.datamodule.pin_memory = False
+         if config.datamodule.get("num_workers"):
+             config.datamodule.num_workers = 0
+         if config.datamodule.get("prefetch_factor"):
+             config.datamodule.prefetch_factor = None
+
+     # Disable adding new keys to config
+     OmegaConf.set_struct(config, True)
+
+
+ @rank_zero_only
+ def print_config(
+     config: DictConfig,
+     fields: tuple[str, ...] = (
+         "trainer",
+         "model",
+         "datamodule",
+         "callbacks",
+         "logger",
+         "seed",
+     ),
+     resolve: bool = True,
+ ) -> None:
+     """Prints content of DictConfig using Rich library and its tree structure.
+
+     Args:
+         config (DictConfig): Configuration composed by Hydra.
+         fields (Sequence[str], optional): Determines which main fields from config will
+             be printed and in what order.
+         resolve (bool, optional): Whether to resolve reference fields of DictConfig.
+     """
+
+     style = "dim"
+     tree = rich.tree.Tree("CONFIG", style=style, guide_style=style)
+
+     for field in fields:
+         branch = tree.add(field, style=style, guide_style=style)
+
+         config_section = config.get(field)
+         branch_content = str(config_section)
+         if isinstance(config_section, DictConfig):
+             branch_content = OmegaConf.to_yaml(config_section, resolve=resolve)
+
+         branch.add(rich.syntax.Syntax(branch_content, "yaml"))
+
+     rich.print(tree)
@@ -0,0 +1,100 @@
+ Metadata-Version: 2.4
+ Name: PVNet_summation
+ Version: 1.0.0
+ Summary: PVNet_summation
+ Author-email: James Fulton <info@openclimatefix.org>
+ Requires-Python: >=3.10
+ Description-Content-Type: text/markdown
+ License-File: LICENSE
+ Requires-Dist: pvnet>=5.0.0
+ Requires-Dist: ocf-data-sampler>=0.2.32
+ Requires-Dist: numpy
+ Requires-Dist: pandas
+ Requires-Dist: matplotlib
+ Requires-Dist: xarray
+ Requires-Dist: torch>=2.0.0
+ Requires-Dist: lightning
+ Requires-Dist: typer
+ Requires-Dist: wandb
+ Requires-Dist: huggingface-hub
+ Requires-Dist: tqdm
+ Requires-Dist: omegaconf
+ Requires-Dist: hydra-core
+ Requires-Dist: rich
+ Requires-Dist: safetensors
+ Dynamic: license-file
+
+ # PVNet summation
+ [![ease of contribution: hard](https://img.shields.io/badge/ease%20of%20contribution:%20hard-bb2629)](https://github.com/openclimatefix/ocf-meta-repo?tab=readme-ov-file#overview-of-ocfs-nowcasting-repositories)
+
+ This project is used for training a model to sum the GSP predictions of [PVNet](https://github.com/openclimatefix/pvnet) into a national estimate.
+
+ Using the summation model to sum the GSP predictions, rather than taking a simple sum, increases the accuracy of the national predictions, and the summation model can be configured to produce estimates of the uncertainty range of the national estimate. See the [PVNet](https://github.com/openclimatefix/pvnet) repo for more details and our paper.
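+
+ As a rough sketch of the idea (the shapes, the GSP count, and the tiny network below are purely illustrative and are not the architecture used in this package), a naive national estimate would just sum over the GSP dimension, whereas the summation model learns a mapping from the full set of GSP-level predictions to the national value:
+
+ ```python
+ import torch
+
+ # Illustrative shapes only: (batch, num_gsps, forecast_steps), ignoring quantile outputs.
+ gsp_preds = torch.rand(4, 317, 16)
+
+ # Naive baseline: a simple sum over the GSP dimension.
+ naive_national = gsp_preds.sum(dim=1)  # (batch, forecast_steps)
+
+ # A learned alternative: map all GSP-level predictions to the national forecast.
+ # This stand-in dense network is only a placeholder for the real summation model.
+ summation_net = torch.nn.Sequential(
+     torch.nn.Flatten(),
+     torch.nn.Linear(317 * 16, 128),
+     torch.nn.ReLU(),
+     torch.nn.Linear(128, 16),
+ )
+ learned_national = summation_net(gsp_preds)  # (batch, forecast_steps)
+ ```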
+
+
+ ## Setup / Installation
+
+ ```bash
+ git clone https://github.com/openclimatefix/PVNet_summation
+ cd PVNet_summation
+ pip install .
+ ```
+
+ ### Additional development dependencies
+
+ ```bash
+ pip install ".[dev]"
+ ```
+
+ ## Getting started with running PVNet summation
+
+ To run PVNet summation, we assume that you are already set up with
+ [PVNet](https://github.com/openclimatefix/pvnet) and have a trained PVNet model available either locally or pushed to HuggingFace.
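+
+ The training script loads this model with `PVNetBaseModel.from_pretrained`, so a quick way to check that your model reference is usable is something like the following (the `model_id` and `revision` values here are placeholders; swap in your own HuggingFace repo or local directory and commit):
+
+ ```python
+ from pvnet.models import BaseModel as PVNetBaseModel
+
+ # Placeholder identifiers: point these at your own trained PVNet model.
+ pvnet_model = PVNetBaseModel.from_pretrained(
+     model_id="openclimatefix/example_pvnet_model",
+     revision="main",
+ )
+
+ # A few attributes the summation model training reads from the PVNet model.
+ print(pvnet_model.forecast_minutes, pvnet_model.interval_minutes, pvnet_model.output_quantiles)
+ ```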
+
+ Before running any code, copy the example configuration to a configs directory:
+
+ ```bash
+ cp -r configs.example configs
+ ```
+
+ You will be making local amendments to these configs.
+
+ ### Datasets
+
+ The datasets required are the same as documented in
+ [PVNet](https://github.com/openclimatefix/pvnet). The only addition is that you will need PVLive
+ data for the national sum, i.e. GSP ID 0.
+
+
+ ### Training PVNet_summation
+
+ How PVNet_summation is run is determined by the extensive configuration in the config files. The
+ example configs are stored in `configs.example`.
+
+ Make sure to update the following config files before training your model:
+
+
+ 1. At the very start of training we loop over all of the input samples and make predictions for them using PVNet. These predictions are saved to disk and will be loaded in the training loop for more efficient training. In `configs/config.yaml` update `sample_save_dir` to set where the predictions will be saved (see the sketch after this list).
+
+ 2. In `configs/datamodule/default.yaml`:
+     - Update `pvnet_model.model_id` and `pvnet_model.revision` to point to the Huggingface commit or local directory where the exported PVNet model is.
+     - Update `configuration` to point to a data configuration compatible with the PVNet model whose outputs will be fed into the summation model.
+     - Set `train_period` and `val_period` to control the time ranges of the train and val periods.
+     - Optionally set `max_num_train_samples` and `max_num_val_samples` to limit the number of train and validation examples which will be used.
+ 3. In `configs/model/default.yaml`:
+     - Update the hyperparameters and structure of the summation model.
+ 4. In `configs/trainer/default.yaml`:
+     - Set `accelerator: 0` if running on a system without a supported GPU.
+
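+ As a sketch of what step 1 produces (the path below is a placeholder assembled from `sample_save_dir`, `pvnet_model.model_id` and `pvnet_model.revision`), each pre-computed sample is saved as a `.pt` file containing a dictionary that includes, among other keys, a `"pvnet_outputs"` tensor which the summation model is trained on:
+
+ ```python
+ import torch
+
+ # Placeholder path following {sample_save_dir}/{model_id}/{revision}/{split}/{index:06}.pt
+ sample_path = "pvnet_outputs/openclimatefix/example_pvnet_model/main/train/000000.pt"
+ sample = torch.load(sample_path)
+
+ # Everything from the original sample except the raw PVNet inputs is kept,
+ # plus the "pvnet_outputs" tensor holding the PVNet predictions.
+ print(sorted(sample.keys()))
+ print(sample["pvnet_outputs"].shape)
+ ```
+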
+
+ Assuming you have updated the configs, you should now be able to run:
+
+ ```bash
+ python run.py
+ ```
+
+
+ ## Testing
+
+ You can use `python -m pytest tests` to run the tests.
@@ -0,0 +1,18 @@
+ pvnet_summation/__init__.py,sha256=8bjkx2pvF7lZ2W5BiTpHr7iqpkRXc3vW5K1pxJAWaj0,22
+ pvnet_summation/load_model.py,sha256=GfreRSaKVTWjV9fnJGNYjp09wrpZwaTunHijdff6cyc,2338
+ pvnet_summation/optimizers.py,sha256=kuR3PUnISiAO5bSaKhq_7vqRKZ0gO5cRS4UbjmKgq1c,6472
+ pvnet_summation/utils.py,sha256=G7l2iZK8qNWEau27pJYPvGOLSzPaSttFrGwr75yTlPQ,2628
+ pvnet_summation/data/__init__.py,sha256=AYJFlJ3KaAQXED0PxuuknI2lKEeFMFLJiJ9b6-H8398,81
+ pvnet_summation/data/datamodule.py,sha256=dexqqz9CHsH2c7ehgOTnJw5LjlOTNCvNhDZsFOVwy1g,8072
+ pvnet_summation/models/__init__.py,sha256=v3KMMH_bz9YGUFWsrb5Ndg-d_dgxQPw7yiFahQAag4c,103
+ pvnet_summation/models/base_model.py,sha256=qtsbH8WqrRUQdWpBdeLJ3yz3dlhUeLFUKzVvX7uiopo,12074
+ pvnet_summation/models/dense_model.py,sha256=vh3Hrm-n7apgVkta_RtQ5mdxb6jiJNFm3ObWukSBgdU,2305
+ pvnet_summation/training/__init__.py,sha256=2fbydXPJFk527DUGPlNV0Teaqvu4WNp8hgcODwHJFEw,110
+ pvnet_summation/training/lightning_module.py,sha256=t16gcAc4Fmi1g26dhQwQOm4qe2mwnTfEBbOyH_BFZ4o,8695
+ pvnet_summation/training/plots.py,sha256=VZHyzI6UvCEd4nmXiJCF1FiVlpDyFHTxX6_rc0vmJrU,2248
+ pvnet_summation/training/train.py,sha256=qBzSCsBMsJpbbBx3laVfOSdBSTCBF7XBWl_AZglbsKQ,7171
+ pvnet_summation-1.0.0.dist-info/licenses/LICENSE,sha256=F-Q3UFCR-BECSocV55BFDpn4YKxve9PKrm-lTt6o_Tg,1073
+ pvnet_summation-1.0.0.dist-info/METADATA,sha256=ihU9WtYJmMGsioyVWW6M7Qv-7RF8h3r4EOwADhZ7W_s,3721
+ pvnet_summation-1.0.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ pvnet_summation-1.0.0.dist-info/top_level.txt,sha256=5fWJ75RKtpaHUdLG_-2oDCInXeq4r1aMCxkZp5Wy-LQ,16
+ pvnet_summation-1.0.0.dist-info/RECORD,,
@@ -0,0 +1,5 @@
+ Wheel-Version: 1.0
+ Generator: setuptools (80.9.0)
+ Root-Is-Purelib: true
+ Tag: py3-none-any
+
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2023 Open Climate Fix
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
@@ -0,0 +1 @@
+ pvnet_summation