PVNet_summation-1.0.0.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2023 Open Climate Fix
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
@@ -0,0 +1,100 @@
+ Metadata-Version: 2.4
+ Name: PVNet_summation
+ Version: 1.0.0
+ Summary: PVNet_summation
+ Author-email: James Fulton <info@openclimatefix.org>
+ Requires-Python: >=3.10
+ Description-Content-Type: text/markdown
+ License-File: LICENSE
+ Requires-Dist: pvnet>=5.0.0
+ Requires-Dist: ocf-data-sampler>=0.2.32
+ Requires-Dist: numpy
+ Requires-Dist: pandas
+ Requires-Dist: matplotlib
+ Requires-Dist: xarray
+ Requires-Dist: torch>=2.0.0
+ Requires-Dist: lightning
+ Requires-Dist: typer
+ Requires-Dist: wandb
+ Requires-Dist: huggingface-hub
+ Requires-Dist: tqdm
+ Requires-Dist: omegaconf
+ Requires-Dist: hydra-core
+ Requires-Dist: rich
+ Requires-Dist: safetensors
+ Dynamic: license-file
+
+ # PVNet summation
+ [![ease of contribution: hard](https://img.shields.io/badge/ease%20of%20contribution:%20hard-bb2629)](https://github.com/openclimatefix/ocf-meta-repo?tab=readme-ov-file#overview-of-ocfs-nowcasting-repositories)
+
+ This project is used for training a model to sum the GSP predictions of [PVNet](https://github.com/openclimatefix/pvnet) into a national estimate.
+
+ Using the summation model to sum the GSP predictions, rather than taking a simple sum, increases the accuracy of the national predictions, and the model can be configured to produce estimates of the uncertainty range of the national estimate. See the [PVNet](https://github.com/openclimatefix/pvnet) repo for more details and our paper.
+
+
+ ## Setup / Installation
+
+ ```bash
+ git clone https://github.com/openclimatefix/PVNet_summation
+ cd PVNet_summation
+ pip install .
+ ```
+
+ ### Additional development dependencies
+
+ ```bash
+ pip install ".[dev]"
+ ```
+
+ ## Getting started with running PVNet summation
+
+ In order to run PVNet summation, we assume that you are already set up with
+ [PVNet](https://github.com/openclimatefix/pvnet) and have a trained PVNet model already available, either locally or pushed to HuggingFace.
+
+ Before running any code, copy the example configuration to a configs directory:
+
+ ```
+ cp -r configs.example configs
+ ```
+
+ You will be making local amendments to these configs.
+
+ ### Datasets
+
+ The datasets required are the same as documented in
+ [PVNet](https://github.com/openclimatefix/pvnet). The only addition is that you will need PVLive
+ data for the national sum, i.e. GSP ID 0.
+
+
+ ### Training PVNet_summation
+
+ How PVNet_summation is run is determined by the extensive configuration in the config files. The
+ example configs are stored in `configs.example`.
+
+ Make sure to update the following config files before training your model:
+
+
+ 1. At the very start of training we loop over all of the input samples and make predictions for them using PVNet. These predictions are saved to disk and are loaded in the training loop for more efficient training. In `configs/config.yaml`, update `sample_save_dir` to set where the predictions will be saved.
+
+ 2. In `configs/datamodule/default.yaml`:
+     - Update `pvnet_model.model_id` and `pvnet_model.revision` to point to the HuggingFace commit or local directory where the exported PVNet model is.
+     - Update `configuration` to point to a data configuration compatible with the PVNet model whose outputs will be fed into the summation model.
+     - Set `train_period` and `val_period` to control the time ranges of the train and validation periods.
+     - Optionally set `max_num_train_samples` and `max_num_val_samples` to limit the number of train and validation samples which will be used.
+
+ 3. In `configs/model/default.yaml`:
+     - Update the hyperparameters and structure of the summation model.
+ 4. In `configs/trainer/default.yaml`:
+     - Set `accelerator: 0` if running on a system without a supported GPU.
+
+
+ Assuming you have updated the configs, you should now be able to run:
+
+ ```
+ python run.py
+ ```
+
+
+ ## Testing
+
+ You can use `python -m pytest tests` to run tests.
@@ -0,0 +1,100 @@
+ Metadata-Version: 2.4
+ Name: PVNet_summation
+ Version: 1.0.0
+ Summary: PVNet_summation
+ Author-email: James Fulton <info@openclimatefix.org>
+ Requires-Python: >=3.10
+ Description-Content-Type: text/markdown
+ License-File: LICENSE
+ Requires-Dist: pvnet>=5.0.0
+ Requires-Dist: ocf-data-sampler>=0.2.32
+ Requires-Dist: numpy
+ Requires-Dist: pandas
+ Requires-Dist: matplotlib
+ Requires-Dist: xarray
+ Requires-Dist: torch>=2.0.0
+ Requires-Dist: lightning
+ Requires-Dist: typer
+ Requires-Dist: wandb
+ Requires-Dist: huggingface-hub
+ Requires-Dist: tqdm
+ Requires-Dist: omegaconf
+ Requires-Dist: hydra-core
+ Requires-Dist: rich
+ Requires-Dist: safetensors
+ Dynamic: license-file
+
+ # PVNet summation
+ [![ease of contribution: hard](https://img.shields.io/badge/ease%20of%20contribution:%20hard-bb2629)](https://github.com/openclimatefix/ocf-meta-repo?tab=readme-ov-file#overview-of-ocfs-nowcasting-repositories)
+
+ This project is used for training a model to sum the GSP predictions of [PVNet](https://github.com/openclimatefix/pvnet) into a national estimate.
+
+ Using the summation model to sum the GSP predictions, rather than taking a simple sum, increases the accuracy of the national predictions, and the model can be configured to produce estimates of the uncertainty range of the national estimate. See the [PVNet](https://github.com/openclimatefix/pvnet) repo for more details and our paper.
+
+
+ ## Setup / Installation
+
+ ```bash
+ git clone https://github.com/openclimatefix/PVNet_summation
+ cd PVNet_summation
+ pip install .
+ ```
+
+ ### Additional development dependencies
+
+ ```bash
+ pip install ".[dev]"
+ ```
+
+ ## Getting started with running PVNet summation
+
+ In order to run PVNet summation, we assume that you are already set up with
+ [PVNet](https://github.com/openclimatefix/pvnet) and have a trained PVNet model already available, either locally or pushed to HuggingFace.
+
+ Before running any code, copy the example configuration to a configs directory:
+
+ ```
+ cp -r configs.example configs
+ ```
+
+ You will be making local amendments to these configs.
+
+ ### Datasets
+
+ The datasets required are the same as documented in
+ [PVNet](https://github.com/openclimatefix/pvnet). The only addition is that you will need PVLive
+ data for the national sum, i.e. GSP ID 0.
+
+
+ ### Training PVNet_summation
+
+ How PVNet_summation is run is determined by the extensive configuration in the config files. The
+ example configs are stored in `configs.example`.
+
+ Make sure to update the following config files before training your model:
+
+
+ 1. At the very start of training we loop over all of the input samples and make predictions for them using PVNet. These predictions are saved to disk and are loaded in the training loop for more efficient training. In `configs/config.yaml`, update `sample_save_dir` to set where the predictions will be saved.
+
+ 2. In `configs/datamodule/default.yaml`:
+     - Update `pvnet_model.model_id` and `pvnet_model.revision` to point to the HuggingFace commit or local directory where the exported PVNet model is.
+     - Update `configuration` to point to a data configuration compatible with the PVNet model whose outputs will be fed into the summation model.
+     - Set `train_period` and `val_period` to control the time ranges of the train and validation periods.
+     - Optionally set `max_num_train_samples` and `max_num_val_samples` to limit the number of train and validation samples which will be used.
+
+ 3. In `configs/model/default.yaml`:
+     - Update the hyperparameters and structure of the summation model.
+ 4. In `configs/trainer/default.yaml`:
+     - Set `accelerator: 0` if running on a system without a supported GPU.
+
+
+ Assuming you have updated the configs, you should now be able to run:
+
+ ```
+ python run.py
+ ```
+
+
+ ## Testing
+
+ You can use `python -m pytest tests` to run tests.
@@ -0,0 +1,22 @@
+ LICENSE
+ README.md
+ pyproject.toml
+ PVNet_summation.egg-info/PKG-INFO
+ PVNet_summation.egg-info/SOURCES.txt
+ PVNet_summation.egg-info/dependency_links.txt
+ PVNet_summation.egg-info/requires.txt
+ PVNet_summation.egg-info/top_level.txt
+ pvnet_summation/__init__.py
+ pvnet_summation/load_model.py
+ pvnet_summation/optimizers.py
+ pvnet_summation/utils.py
+ pvnet_summation/data/__init__.py
+ pvnet_summation/data/datamodule.py
+ pvnet_summation/models/__init__.py
+ pvnet_summation/models/base_model.py
+ pvnet_summation/models/dense_model.py
+ pvnet_summation/training/__init__.py
+ pvnet_summation/training/lightning_module.py
+ pvnet_summation/training/plots.py
+ pvnet_summation/training/train.py
+ tests/test_end2end.py
@@ -0,0 +1,16 @@
+ pvnet>=5.0.0
+ ocf-data-sampler>=0.2.32
+ numpy
+ pandas
+ matplotlib
+ xarray
+ torch>=2.0.0
+ lightning
+ typer
+ wandb
+ huggingface-hub
+ tqdm
+ omegaconf
+ hydra-core
+ rich
+ safetensors
@@ -0,0 +1 @@
+ pvnet_summation
@@ -0,0 +1,74 @@
+ # PVNet summation
+ [![ease of contribution: hard](https://img.shields.io/badge/ease%20of%20contribution:%20hard-bb2629)](https://github.com/openclimatefix/ocf-meta-repo?tab=readme-ov-file#overview-of-ocfs-nowcasting-repositories)
+
+ This project is used for training a model to sum the GSP predictions of [PVNet](https://github.com/openclimatefix/pvnet) into a national estimate.
+
+ Using the summation model to sum the GSP predictions, rather than taking a simple sum, increases the accuracy of the national predictions, and the model can be configured to produce estimates of the uncertainty range of the national estimate. See the [PVNet](https://github.com/openclimatefix/pvnet) repo for more details and our paper.
+
+
+ ## Setup / Installation
+
+ ```bash
+ git clone https://github.com/openclimatefix/PVNet_summation
+ cd PVNet_summation
+ pip install .
+ ```
+
+ ### Additional development dependencies
+
+ ```bash
+ pip install ".[dev]"
+ ```
+
+ ## Getting started with running PVNet summation
+
+ In order to run PVNet summation, we assume that you are already set up with
+ [PVNet](https://github.com/openclimatefix/pvnet) and have a trained PVNet model already available, either locally or pushed to HuggingFace.
+
+ Before running any code, copy the example configuration to a configs directory:
+
+ ```
+ cp -r configs.example configs
+ ```
+
+ You will be making local amendments to these configs.
+
+ ### Datasets
+
+ The datasets required are the same as documented in
+ [PVNet](https://github.com/openclimatefix/pvnet). The only addition is that you will need PVLive
+ data for the national sum, i.e. GSP ID 0.
+
+
+ ### Training PVNet_summation
+
+ How PVNet_summation is run is determined by the extensive configuration in the config files. The
+ example configs are stored in `configs.example`.
+
+ Make sure to update the following config files before training your model:
+
+
+ 1. At the very start of training we loop over all of the input samples and make predictions for them using PVNet. These predictions are saved to disk and are loaded in the training loop for more efficient training. In `configs/config.yaml`, update `sample_save_dir` to set where the predictions will be saved.
+
+ 2. In `configs/datamodule/default.yaml`:
+     - Update `pvnet_model.model_id` and `pvnet_model.revision` to point to the HuggingFace commit or local directory where the exported PVNet model is.
+     - Update `configuration` to point to a data configuration compatible with the PVNet model whose outputs will be fed into the summation model.
+     - Set `train_period` and `val_period` to control the time ranges of the train and validation periods.
+     - Optionally set `max_num_train_samples` and `max_num_val_samples` to limit the number of train and validation samples which will be used.
+
+ 3. In `configs/model/default.yaml`:
+     - Update the hyperparameters and structure of the summation model.
+ 4. In `configs/trainer/default.yaml`:
+     - Set `accelerator: 0` if running on a system without a supported GPU.
+
+
+ Assuming you have updated the configs, you should now be able to run:
+
+ ```
+ python run.py
+ ```
+
+
+ ## Testing
+
+ You can use `python -m pytest tests` to run tests.
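The README above describes replacing a simple capacity-weighted sum of GSP predictions with a learned summation model. As a minimal sketch of the baseline being improved on (the 317-GSP count, shapes, and random inputs are illustrative assumptions, not values from this package):

```python
import numpy as np

# Hypothetical inputs: 317 GSPs, 16 forecast steps, predictions normalised to [0, 1].
rng = np.random.default_rng(0)
gsp_preds = rng.uniform(0.0, 1.0, (317, 16))

# Each GSP's share of the national effective capacity.
relative_capacity = rng.uniform(0.0, 1.0, 317)
relative_capacity /= relative_capacity.sum()

# Baseline "simple sum": a capacity-weighted combination of the GSP predictions.
# The summation model replaces this fixed weighting with a learned mapping.
simple_sum = relative_capacity @ gsp_preds  # shape: [16]
```

The learned model takes the same per-GSP predictions and relative capacities as input but outputs the national estimate (and optionally uncertainty quantiles) directly, rather than applying fixed weights.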
@@ -0,0 +1 @@
+ """PVNet_summation"""
@@ -0,0 +1,2 @@
+ """Data module"""
+ from .datamodule import PresavedDataModule, StreamedDataModule
@@ -0,0 +1,213 @@
+ """Pytorch lightning datamodules for loading pre-saved samples and predictions."""
+
+ from glob import glob
+ from typing import TypeAlias
+
+ import numpy as np
+ import pandas as pd
+ import torch
+ from lightning.pytorch import LightningDataModule
+ from ocf_data_sampler.load.gsp import open_gsp
+ from ocf_data_sampler.numpy_sample.common_types import NumpyBatch, NumpySample
+ from ocf_data_sampler.torch_datasets.datasets.pvnet_uk import PVNetUKConcurrentDataset
+ from ocf_data_sampler.utils import minutes
+ from torch.utils.data import DataLoader, Dataset, default_collate
+ from typing_extensions import override
+
+ SumNumpySample: TypeAlias = dict[str, np.ndarray | NumpyBatch]
+ SumTensorBatch: TypeAlias = dict[str, torch.Tensor]
+
+
+ class StreamedDataset(PVNetUKConcurrentDataset):
+     """A torch dataset for creating concurrent PVNet inputs and national targets."""
+
+     def __init__(
+         self,
+         config_filename: str,
+         start_time: str | None = None,
+         end_time: str | None = None,
+     ) -> None:
+         """A torch dataset for creating concurrent PVNet inputs and national targets.
+
+         Args:
+             config_filename: Path to the configuration file
+             start_time: Limit the init-times to be after this
+             end_time: Limit the init-times to be before this
+         """
+         super().__init__(config_filename, start_time, end_time, gsp_ids=None)
+
+         # Load and normalise the national GSP data to use as target values
+         national_gsp_data = (
+             open_gsp(
+                 zarr_path=self.config.input_data.gsp.zarr_path,
+                 boundaries_version=self.config.input_data.gsp.boundaries_version,
+             )
+             .sel(gsp_id=0)
+             .compute()
+         )
+         self.national_gsp_data = national_gsp_data / national_gsp_data.effective_capacity_mwp
+
+     def _get_sample(self, t0: pd.Timestamp) -> SumNumpySample:
+         """Generate a concurrent PVNet sample for the given init-time.
+
+         Args:
+             t0: init-time for sample
+         """
+         pvnet_inputs: NumpySample = super()._get_sample(t0)
+
+         location_capacities = pvnet_inputs["gsp_effective_capacity_mwp"]
+
+         valid_times = pd.date_range(
+             t0 + minutes(self.config.input_data.gsp.time_resolution_minutes),
+             t0 + minutes(self.config.input_data.gsp.interval_end_minutes),
+             freq=minutes(self.config.input_data.gsp.time_resolution_minutes),
+         )
+
+         total_outturns = self.national_gsp_data.sel(time_utc=valid_times).values
+         total_capacity = self.national_gsp_data.sel(time_utc=t0).effective_capacity_mwp.item()
+
+         relative_capacities = location_capacities / total_capacity
+
+         return {
+             # NumpyBatch object with batch size = num_locations
+             "pvnet_inputs": pvnet_inputs,
+             # Shape: [time]
+             "target": total_outturns,
+             # Shape: [time]
+             "valid_times": valid_times.values.astype(int),
+             # Scalar
+             "last_outturn": self.national_gsp_data.sel(time_utc=t0).values,
+             # Shape: [num_locations]
+             "relative_capacity": relative_capacities,
+         }
+
+     @override
+     def __getitem__(self, idx: int) -> SumNumpySample:
+         return super().__getitem__(idx)
+
+     @override
+     def get_sample(self, t0: pd.Timestamp) -> SumNumpySample:
+         return super().get_sample(t0)
+
+
+ class StreamedDataModule(LightningDataModule):
+     """Datamodule for training pvnet_summation."""
+
+     def __init__(
+         self,
+         configuration: str,
+         train_period: list[str | None] = [None, None],
+         val_period: list[str | None] = [None, None],
+         num_workers: int = 0,
+         prefetch_factor: int | None = None,
+         persistent_workers: bool = False,
+     ):
+         """Datamodule for creating concurrent PVNet inputs and national targets.
+
+         Args:
+             configuration: Path to ocf-data-sampler configuration file.
+             train_period: Date range filter for train dataloader.
+             val_period: Date range filter for val dataloader.
+             num_workers: Number of workers to use in multiprocess batch loading.
+             prefetch_factor: Number of batches loaded in advance by each worker.
+             persistent_workers: If True, the data loader will not shut down the worker processes
+                 after a dataset has been consumed once. This keeps the worker Dataset
+                 instances alive.
+         """
+         super().__init__()
+         self.configuration = configuration
+         self.train_period = train_period
+         self.val_period = val_period
+
+         self._dataloader_kwargs = dict(
+             batch_size=None,
+             batch_sampler=None,
+             num_workers=num_workers,
+             collate_fn=None,
+             pin_memory=False,
+             drop_last=False,
+             timeout=0,
+             worker_init_fn=None,
+             prefetch_factor=prefetch_factor,
+             persistent_workers=persistent_workers,
+         )
+
+     def train_dataloader(self, shuffle: bool = False) -> DataLoader:
+         """Construct train dataloader"""
+         dataset = StreamedDataset(self.configuration, *self.train_period)
+         return DataLoader(dataset, shuffle=shuffle, **self._dataloader_kwargs)
+
+     def val_dataloader(self, shuffle: bool = False) -> DataLoader:
+         """Construct val dataloader"""
+         dataset = StreamedDataset(self.configuration, *self.val_period)
+         return DataLoader(dataset, shuffle=shuffle, **self._dataloader_kwargs)
+
+
+ class PresavedDataset(Dataset):
+     """Dataset for loading pre-saved PVNet predictions from disk"""
+
+     def __init__(self, sample_dir: str):
+         """Dataset for loading pre-saved PVNet predictions from disk.
+
+         Args:
+             sample_dir: The directory containing the saved samples
+         """
+         self.sample_filepaths = sorted(glob(f"{sample_dir}/*.pt"))
+
+     def __len__(self) -> int:
+         return len(self.sample_filepaths)
+
+     def __getitem__(self, idx: int) -> dict:
+         return torch.load(self.sample_filepaths[idx], weights_only=True)
+
+
+ class PresavedDataModule(LightningDataModule):
+     """Datamodule for loading pre-saved PVNet predictions."""
+
+     def __init__(
+         self,
+         sample_dir: str,
+         batch_size: int = 16,
+         num_workers: int = 0,
+         prefetch_factor: int | None = None,
+         persistent_workers: bool = False,
+     ):
+         """Datamodule for loading pre-saved PVNet predictions.
+
+         Args:
+             sample_dir: Path to the directory of pre-saved samples.
+             batch_size: Batch size.
+             num_workers: Number of workers to use in multiprocess batch loading.
+             prefetch_factor: Number of batches loaded in advance by each worker.
+             persistent_workers: If True, the data loader will not shut down the worker processes
+                 after a dataset has been consumed once. This keeps the worker Dataset
+                 instances alive.
+         """
+         super().__init__()
+         self.sample_dir = sample_dir
+
+         self._dataloader_kwargs = dict(
+             batch_size=batch_size,
+             sampler=None,
+             batch_sampler=None,
+             num_workers=num_workers,
+             collate_fn=None if batch_size is None else default_collate,
+             pin_memory=False,
+             drop_last=False,
+             timeout=0,
+             worker_init_fn=None,
+             prefetch_factor=prefetch_factor,
+             persistent_workers=persistent_workers,
+         )
+
+     def train_dataloader(self, shuffle: bool = True) -> DataLoader:
+         """Construct train dataloader"""
+         dataset = PresavedDataset(f"{self.sample_dir}/train")
+         return DataLoader(dataset, shuffle=shuffle, **self._dataloader_kwargs)
+
+     def val_dataloader(self, shuffle: bool = False) -> DataLoader:
+         """Construct val dataloader"""
+         dataset = PresavedDataset(f"{self.sample_dir}/val")
+         return DataLoader(dataset, shuffle=shuffle, **self._dataloader_kwargs)
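The `PresavedDataset` in this file indexes a directory of pre-saved sample files with a sorted glob and loads one file per index. The same pattern can be sketched with only the standard library (using `pickle` in place of `torch.load`, so it runs without torch; the class and file names here are illustrative, not from this package):

```python
import glob
import os
import pickle
import tempfile


class PresavedSamples:
    """Minimal stand-in for PresavedDataset: index pre-saved sample files on disk."""

    def __init__(self, sample_dir: str):
        # Sorted so sample order is deterministic across runs.
        self.sample_filepaths = sorted(glob.glob(f"{sample_dir}/*.pkl"))

    def __len__(self) -> int:
        return len(self.sample_filepaths)

    def __getitem__(self, idx: int) -> dict:
        with open(self.sample_filepaths[idx], "rb") as f:
            return pickle.load(f)


# Write three toy samples to a temporary directory, then index them.
sample_dir = tempfile.mkdtemp()
for i in range(3):
    with open(os.path.join(sample_dir, f"{i:08d}.pkl"), "wb") as f:
        pickle.dump({"target": float(i)}, f)

ds = PresavedSamples(sample_dir)
```

Zero-padded filenames matter here: the sorted glob orders lexicographically, so padding keeps file order aligned with sample order.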
@@ -0,0 +1,70 @@
+ """Load a model from its checkpoint directory"""
+
+ import glob
+ import os
+
+ import hydra
+ import torch
+ import yaml
+
+ from pvnet_summation.utils import (
+     DATAMODULE_CONFIG_NAME,
+     FULL_CONFIG_NAME,
+     MODEL_CONFIG_NAME,
+ )
+
+
+ def get_model_from_checkpoints(
+     checkpoint_dir_path: str,
+     val_best: bool = True,
+ ) -> tuple[torch.nn.Module, dict, str | None, str | None]:
+     """Load a model from its checkpoint directory
+
+     Returns:
+         tuple:
+             model: nn.Module of the pretrained model.
+             model_config: the model config used to train the model.
+             datamodule_config: path to the datamodule config used to create samples, e.g.
+                 train/test split info.
+             experiment_config: path to the full experiment config.
+     """
+     # Load lightning training module
+     with open(f"{checkpoint_dir_path}/{MODEL_CONFIG_NAME}") as cfg:
+         model_config = yaml.load(cfg, Loader=yaml.FullLoader)
+
+     lightning_module = hydra.utils.instantiate(model_config)
+
+     if val_best:
+         # Only one epoch (best) saved per model
+         files = glob.glob(f"{checkpoint_dir_path}/epoch*.ckpt")
+         if len(files) != 1:
+             raise ValueError(
+                 f"Found {len(files)} checkpoints @ {checkpoint_dir_path}/epoch*.ckpt. Expected one."
+             )
+
+         checkpoint = torch.load(files[0], map_location="cpu", weights_only=True)
+     else:
+         checkpoint = torch.load(
+             f"{checkpoint_dir_path}/last.ckpt",
+             map_location="cpu",
+             weights_only=True,
+         )
+
+     lightning_module.load_state_dict(state_dict=checkpoint["state_dict"])
+
+     # Extract the model from the lightning module
+     model = lightning_module.model
+     model_config = model_config["model"]
+
+     # Check for datamodule config
+     # This only exists if the model was trained with presaved samples
+     datamodule_config = f"{checkpoint_dir_path}/{DATAMODULE_CONFIG_NAME}"
+     datamodule_config = datamodule_config if os.path.isfile(datamodule_config) else None
+
+     # Check for experiment config
+     # For backwards compatibility - this might not always exist
+     experiment_config = f"{checkpoint_dir_path}/{FULL_CONFIG_NAME}"
+     experiment_config = experiment_config if os.path.isfile(experiment_config) else None
+
+     return model, model_config, datamodule_config, experiment_config
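The checkpoint-selection rule in `get_model_from_checkpoints` (with `val_best=True`, expect exactly one `epoch*.ckpt`; otherwise fall back to `last.ckpt`) can be isolated and exercised with only the standard library. The function name and the checkpoint filename below are illustrative assumptions:

```python
import glob
import os
import tempfile


def resolve_checkpoint(checkpoint_dir: str, val_best: bool = True) -> str:
    """Pick the checkpoint file the loader above would read (selection logic only)."""
    if val_best:
        # Only one "best" epoch checkpoint is expected per model.
        files = glob.glob(f"{checkpoint_dir}/epoch*.ckpt")
        if len(files) != 1:
            raise ValueError(f"Found {len(files)} best checkpoints. Expected one.")
        return files[0]
    return f"{checkpoint_dir}/last.ckpt"


# Fake checkpoint directory with one best checkpoint and a last checkpoint.
ckpt_dir = tempfile.mkdtemp()
for name in ("epoch=3-step=1000.ckpt", "last.ckpt"):
    open(os.path.join(ckpt_dir, name), "w").close()
```

Keeping the selection logic separate from the torch loading makes the "exactly one best checkpoint" invariant easy to test without any model weights on disk.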
@@ -0,0 +1,3 @@
+ """Models for PVNet summation"""
+ from .base_model import BaseModel
+ from .dense_model import DenseModel