pvnet_summation-1.0.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pvnet_summation/__init__.py +1 -0
- pvnet_summation/data/__init__.py +2 -0
- pvnet_summation/data/datamodule.py +213 -0
- pvnet_summation/load_model.py +70 -0
- pvnet_summation/models/__init__.py +3 -0
- pvnet_summation/models/base_model.py +345 -0
- pvnet_summation/models/dense_model.py +75 -0
- pvnet_summation/optimizers.py +219 -0
- pvnet_summation/training/__init__.py +3 -0
- pvnet_summation/training/lightning_module.py +247 -0
- pvnet_summation/training/plots.py +80 -0
- pvnet_summation/training/train.py +185 -0
- pvnet_summation/utils.py +87 -0
- pvnet_summation-1.0.0.dist-info/METADATA +100 -0
- pvnet_summation-1.0.0.dist-info/RECORD +18 -0
- pvnet_summation-1.0.0.dist-info/WHEEL +5 -0
- pvnet_summation-1.0.0.dist-info/licenses/LICENSE +21 -0
- pvnet_summation-1.0.0.dist-info/top_level.txt +1 -0
pvnet_summation/__init__.py
@@ -0,0 +1 @@
+"""PVNet_summation"""

pvnet_summation/data/datamodule.py
@@ -0,0 +1,213 @@
+"""Pytorch lightning datamodules for loading pre-saved samples and predictions."""
+
+from glob import glob
+from typing import TypeAlias
+
+import numpy as np
+import pandas as pd
+import torch
+from lightning.pytorch import LightningDataModule
+from ocf_data_sampler.load.gsp import open_gsp
+from ocf_data_sampler.numpy_sample.common_types import NumpyBatch, NumpySample
+from ocf_data_sampler.torch_datasets.datasets.pvnet_uk import PVNetUKConcurrentDataset
+from ocf_data_sampler.utils import minutes
+from torch.utils.data import DataLoader, Dataset, default_collate
+from typing_extensions import override
+
+SumNumpySample: TypeAlias = dict[str, np.ndarray | NumpyBatch]
+SumTensorBatch: TypeAlias = dict[str, torch.Tensor]
+
+
+class StreamedDataset(PVNetUKConcurrentDataset):
+    """A torch dataset for creating concurrent PVNet inputs and national targets."""
+
+    def __init__(
+        self,
+        config_filename: str,
+        start_time: str | None = None,
+        end_time: str | None = None,
+    ) -> None:
+        """A torch dataset for creating concurrent PVNet inputs and national targets.
+
+        Args:
+            config_filename: Path to the configuration file
+            start_time: Limit the init-times to be after this
+            end_time: Limit the init-times to be before this
+        """
+        super().__init__(config_filename, start_time, end_time, gsp_ids=None)
+
+        # Load and normalise the national GSP data to use as target values
+        national_gsp_data = (
+            open_gsp(
+                zarr_path=self.config.input_data.gsp.zarr_path,
+                boundaries_version=self.config.input_data.gsp.boundaries_version,
+            )
+            .sel(gsp_id=0)
+            .compute()
+        )
+        self.national_gsp_data = national_gsp_data / national_gsp_data.effective_capacity_mwp
+
+
+    def _get_sample(self, t0: pd.Timestamp) -> SumNumpySample:
+        """Generate a concurrent PVNet sample for a given init-time.
+
+        Args:
+            t0: init-time for sample
+        """
+
+        pvnet_inputs: NumpySample = super()._get_sample(t0)
+
+        location_capacities = pvnet_inputs["gsp_effective_capacity_mwp"]
+
+        valid_times = pd.date_range(
+            t0 + minutes(self.config.input_data.gsp.time_resolution_minutes),
+            t0 + minutes(self.config.input_data.gsp.interval_end_minutes),
+            freq=minutes(self.config.input_data.gsp.time_resolution_minutes),
+        )
+
+        total_outturns = self.national_gsp_data.sel(time_utc=valid_times).values
+        total_capacity = self.national_gsp_data.sel(time_utc=t0).effective_capacity_mwp.item()
+
+        relative_capacities = location_capacities / total_capacity
+
+        return {
+            # NumpyBatch object with batch size = num_locations
+            "pvnet_inputs": pvnet_inputs,
+            # Shape: [time]
+            "target": total_outturns,
+            # Shape: [time]
+            "valid_times": valid_times.values.astype(int),
+            # Shape: scalar
+            "last_outturn": self.national_gsp_data.sel(time_utc=t0).values,
+            # Shape: [num_locations]
+            "relative_capacity": relative_capacities,
+        }
+
+    @override
+    def __getitem__(self, idx: int) -> SumNumpySample:
+        return super().__getitem__(idx)
+
+    @override
+    def get_sample(self, t0: pd.Timestamp) -> SumNumpySample:
+        return super().get_sample(t0)
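
For orientation, a minimal usage sketch of the streamed dataset (the config path is a
placeholder; a real ocf-data-sampler configuration pointing at GSP data is required):

    import pandas as pd

    from pvnet_summation.data.datamodule import StreamedDataset

    dataset = StreamedDataset("configs/data_sampler.yaml")  # placeholder path
    sample = dataset.get_sample(pd.Timestamp("2020-06-01 12:00"))

    # "pvnet_inputs" is a NumpyBatch covering every GSP for this init-time, while
    # "target" is the normalised national outturn over the forecast horizon.
    print(sample["target"].shape, sample["relative_capacity"].shape)
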
+
+
+class StreamedDataModule(LightningDataModule):
+    """Datamodule for training pvnet_summation."""
+
+    def __init__(
+        self,
+        configuration: str,
+        train_period: list[str | None] = [None, None],
+        val_period: list[str | None] = [None, None],
+        num_workers: int = 0,
+        prefetch_factor: int | None = None,
+        persistent_workers: bool = False,
+    ):
+        """Datamodule for creating concurrent PVNet inputs and national targets.
+
+        Args:
+            configuration: Path to ocf-data-sampler configuration file.
+            train_period: Date range filter for train dataloader.
+            val_period: Date range filter for val dataloader.
+            num_workers: Number of workers to use in multiprocess batch loading.
+            prefetch_factor: Number of batches loaded in advance by each worker.
+            persistent_workers: If True, the data loader will not shut down the worker processes
+                after the dataset has been consumed once, keeping the worker Dataset
+                instances alive.
+        """
+        super().__init__()
+        self.configuration = configuration
+        self.train_period = train_period
+        self.val_period = val_period
+
+        self._dataloader_kwargs = dict(
+            batch_size=None,
+            batch_sampler=None,
+            num_workers=num_workers,
+            collate_fn=None,
+            pin_memory=False,
+            drop_last=False,
+            timeout=0,
+            worker_init_fn=None,
+            prefetch_factor=prefetch_factor,
+            persistent_workers=persistent_workers,
+        )
+
+    def train_dataloader(self, shuffle: bool = False) -> DataLoader:
+        """Construct train dataloader"""
+        dataset = StreamedDataset(self.configuration, *self.train_period)
+        return DataLoader(dataset, shuffle=shuffle, **self._dataloader_kwargs)
+
+    def val_dataloader(self, shuffle: bool = False) -> DataLoader:
+        """Construct val dataloader"""
+        dataset = StreamedDataset(self.configuration, *self.val_period)
+        return DataLoader(dataset, shuffle=shuffle, **self._dataloader_kwargs)
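
Note that batch_size=None above disables automatic batching: each element yielded by
these dataloaders is one full concurrent sample (every location for a single init-time),
so the leading "batch" dimension of the PVNet inputs is the number of locations. A
minimal iteration sketch, again with a placeholder config path:

    datamodule = StreamedDataModule(configuration="configs/data_sampler.yaml")  # placeholder
    for sample in datamodule.train_dataloader():
        # One init-time per step; sample matches the dict built by StreamedDataset._get_sample
        break
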
+
+
+class PresavedDataset(Dataset):
+    """Dataset for loading pre-saved PVNet predictions from disk"""
+
+    def __init__(self, sample_dir: str):
+        """Dataset for loading pre-saved PVNet predictions from disk.
+
+        Args:
+            sample_dir: The directory containing the saved samples
+        """
+        self.sample_filepaths = sorted(glob(f"{sample_dir}/*.pt"))
+
+    def __len__(self) -> int:
+        return len(self.sample_filepaths)
+
+    def __getitem__(self, idx: int) -> dict:
+        return torch.load(self.sample_filepaths[idx], weights_only=True)
+
+
+class PresavedDataModule(LightningDataModule):
+    """Datamodule for loading pre-saved PVNet predictions."""
+
+    def __init__(
+        self,
+        sample_dir: str,
+        batch_size: int = 16,
+        num_workers: int = 0,
+        prefetch_factor: int | None = None,
+        persistent_workers: bool = False,
+    ):
+        """Datamodule for loading pre-saved PVNet predictions.
+
+        Args:
+            sample_dir: Path to the directory of pre-saved samples.
+            batch_size: Batch size.
+            num_workers: Number of workers to use in multiprocess batch loading.
+            prefetch_factor: Number of batches loaded in advance by each worker.
+            persistent_workers: If True, the data loader will not shut down the worker processes
+                after the dataset has been consumed once, keeping the worker Dataset
+                instances alive.
+        """
+        super().__init__()
+        self.sample_dir = sample_dir
+
+        self._dataloader_kwargs = dict(
+            batch_size=batch_size,
+            sampler=None,
+            batch_sampler=None,
+            num_workers=num_workers,
+            collate_fn=None if batch_size is None else default_collate,
+            pin_memory=False,
+            drop_last=False,
+            timeout=0,
+            worker_init_fn=None,
+            prefetch_factor=prefetch_factor,
+            persistent_workers=persistent_workers,
+        )
+
+    def train_dataloader(self, shuffle: bool = True) -> DataLoader:
+        """Construct train dataloader"""
+        dataset = PresavedDataset(f"{self.sample_dir}/train")
+        return DataLoader(dataset, shuffle=shuffle, **self._dataloader_kwargs)
+
+    def val_dataloader(self, shuffle: bool = False) -> DataLoader:
+        """Construct val dataloader"""
+        dataset = PresavedDataset(f"{self.sample_dir}/val")
+        return DataLoader(dataset, shuffle=shuffle, **self._dataloader_kwargs)
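
The pre-saved pathway expects sample_dir to contain train/ and val/ subdirectories of
torch-saved sample dicts, which default_collate stacks into batches. A usage sketch with
a placeholder directory:

    datamodule = PresavedDataModule(sample_dir="/path/to/presaved_samples", batch_size=16)
    for batch in datamodule.train_dataloader():
        # Each batch is a dict of tensors collated from batch_size individual .pt files
        break
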

pvnet_summation/load_model.py
@@ -0,0 +1,70 @@
+"""Load a model from its checkpoint directory"""
+
+import glob
+import os
+
+import hydra
+import torch
+import yaml
+
+from pvnet_summation.utils import (
+    DATAMODULE_CONFIG_NAME,
+    FULL_CONFIG_NAME,
+    MODEL_CONFIG_NAME,
+)
+
+
+def get_model_from_checkpoints(
+    checkpoint_dir_path: str,
+    val_best: bool = True,
+) -> tuple[torch.nn.Module, dict, str | None, str | None]:
+    """Load a model from its checkpoint directory
+
+    Returns:
+        tuple:
+            model: nn.Module of the pretrained model.
+            model_config: dict of the model config used to train the model.
+            datamodule_config: path to the datamodule config used to create samples,
+                e.g. train/test split info.
+            experiment_config: path to the full experimental config.
+    """
+
+    # Load lightning training module
+    with open(f"{checkpoint_dir_path}/{MODEL_CONFIG_NAME}") as cfg:
+        model_config = yaml.load(cfg, Loader=yaml.FullLoader)
+
+    lightning_module = hydra.utils.instantiate(model_config)
+
+    if val_best:
+        # Only one epoch (best) saved per model
+        files = glob.glob(f"{checkpoint_dir_path}/epoch*.ckpt")
+        if len(files) != 1:
+            raise ValueError(
+                f"Found {len(files)} checkpoints @ {checkpoint_dir_path}/epoch*.ckpt. Expected one."
+            )
+
+        checkpoint = torch.load(files[0], map_location="cpu", weights_only=True)
+    else:
+        checkpoint = torch.load(
+            f"{checkpoint_dir_path}/last.ckpt",
+            map_location="cpu",
+            weights_only=True,
+        )
+
+    lightning_module.load_state_dict(state_dict=checkpoint["state_dict"])
+
+    # Extract the model from the lightning module
+    model = lightning_module.model
+    model_config = model_config["model"]
+
+    # Check for datamodule config
+    # This only exists if the model was trained with presaved samples
+    datamodule_config = f"{checkpoint_dir_path}/{DATAMODULE_CONFIG_NAME}"
+    datamodule_config = datamodule_config if os.path.isfile(datamodule_config) else None
+
+    # Check for experiment config
+    # For backwards compatibility - this might not always exist
+    experiment_config = f"{checkpoint_dir_path}/{FULL_CONFIG_NAME}"
+    experiment_config = experiment_config if os.path.isfile(experiment_config) else None
+
+    return model, model_config, datamodule_config, experiment_config
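
A usage sketch for the loader above, assuming a checkpoint directory produced by
training in this repo (the path is a placeholder; the directory must contain the model
config and either a single epoch*.ckpt or a last.ckpt file):

    from pvnet_summation.load_model import get_model_from_checkpoints

    model, model_config, datamodule_config, experiment_config = get_model_from_checkpoints(
        checkpoint_dir_path="checkpoints/my_run",  # placeholder
        val_best=True,  # load the single best-epoch checkpoint rather than last.ckpt
    )
    model.eval()
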

pvnet_summation/models/base_model.py
@@ -0,0 +1,345 @@
+"""Base model for all PVNet submodels"""
+import logging
+import os
+import shutil
+import time
+from importlib.metadata import version
+from pathlib import Path
+
+import hydra
+import torch
+import yaml
+from huggingface_hub import ModelCard, ModelCardData, snapshot_download
+from huggingface_hub.hf_api import HfApi
+from safetensors.torch import load_file, save_file
+
+from pvnet_summation.data.datamodule import SumTensorBatch
+from pvnet_summation.utils import (
+    DATAMODULE_CONFIG_NAME,
+    FULL_CONFIG_NAME,
+    MODEL_CARD_NAME,
+    MODEL_CONFIG_NAME,
+    PYTORCH_WEIGHTS_NAME,
+)
+
+
+def santize_datamodule(config: dict) -> dict:
+    """Create a new datamodule config which keeps only the details required for inference"""
+    return {"pvnet_model": config["pvnet_model"]}
+
+
+def download_from_hf(
+    repo_id: str,
+    filename: str | list[str],
+    revision: str,
+    cache_dir: str | None,
+    force_download: bool,
+    max_retries: int = 5,
+    wait_time: int = 10,
+) -> str | list[str]:
+    """Tries to download one or more files from HuggingFace up to max_retries times.
+
+    Args:
+        repo_id: HuggingFace repo ID
+        filename: Name of the file(s) to download
+        revision: Specific model revision
+        cache_dir: Cache directory
+        force_download: Whether to force a new download
+        max_retries: Maximum number of retry attempts
+        wait_time: Wait time (in seconds) before retrying
+
+    Returns:
+        The local file path of the downloaded file(s)
+    """
+    for attempt in range(1, max_retries + 1):
+        try:
+            save_dir = snapshot_download(
+                repo_id=repo_id,
+                allow_patterns=filename,
+                revision=revision,
+                cache_dir=cache_dir,
+                force_download=force_download,
+            )
+
+            if isinstance(filename, list):
+                return [f"{save_dir}/{f}" for f in filename]
+            else:
+                return f"{save_dir}/{filename}"
+
+        except Exception as e:
+            if attempt == max_retries:
+                raise Exception(
+                    f"Failed to download {filename} from {repo_id} after {max_retries} attempts."
+                ) from e
+            logging.warning(
+                (
+                    f"Attempt {attempt}/{max_retries} failed to download {filename} "
+                    f"from {repo_id}. Retrying in {wait_time} seconds..."
+                )
+            )
+            time.sleep(wait_time)
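
A hedged sketch of the retry wrapper in use; the repo id is hypothetical and the weights
filename comes from the package constant:

    from pvnet_summation.models.base_model import download_from_hf
    from pvnet_summation.utils import PYTORCH_WEIGHTS_NAME

    weights_path = download_from_hf(
        repo_id="some-org/some-summation-model",  # hypothetical repo id
        filename=PYTORCH_WEIGHTS_NAME,
        revision="main",
        cache_dir=None,
        force_download=False,
    )
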
+
+
+class HuggingfaceMixin:
+    """Mixin for saving and loading models to and from HuggingFace"""
+
+    @classmethod
+    def from_pretrained(
+        cls,
+        model_id: str,
+        revision: str,
+        cache_dir: str | None = None,
+        force_download: bool = False,
+        strict: bool = True,
+    ) -> "BaseModel":
+        """Load Pytorch pretrained weights and return the loaded model."""
+
+        if os.path.isdir(model_id):
+            print("Loading model from local directory")
+            model_file = f"{model_id}/{PYTORCH_WEIGHTS_NAME}"
+            config_file = f"{model_id}/{MODEL_CONFIG_NAME}"
+        else:
+            print("Loading model from huggingface repo")
+
+            model_file, config_file = download_from_hf(
+                repo_id=model_id,
+                filename=[PYTORCH_WEIGHTS_NAME, MODEL_CONFIG_NAME],
+                revision=revision,
+                cache_dir=cache_dir,
+                force_download=force_download,
+                max_retries=5,
+                wait_time=10,
+            )
+
+        with open(config_file, "r") as f:
+            model = hydra.utils.instantiate(yaml.safe_load(f))
+
+        state_dict = load_file(model_file)
+        model.load_state_dict(state_dict, strict=strict)  # type: ignore
+        model.eval()  # type: ignore
+
+        return model
+
+    @classmethod
+    def get_datamodule_config(
+        cls,
+        model_id: str,
+        revision: str,
+        cache_dir: str | None = None,
+        force_download: bool = False,
+    ) -> str:
+        """Load data config file."""
+        if os.path.isdir(model_id):
+            print("Loading datamodule config from local directory")
+            datamodule_config_file = os.path.join(model_id, DATAMODULE_CONFIG_NAME)
+        else:
+            print("Loading datamodule config from huggingface repo")
+            datamodule_config_file = download_from_hf(
+                repo_id=model_id,
+                filename=DATAMODULE_CONFIG_NAME,
+                revision=revision,
+                cache_dir=cache_dir,
+                force_download=force_download,
+                max_retries=5,
+                wait_time=10,
+            )
+
+        return datamodule_config_file
+
+    def _save_model_weights(self, save_directory: str) -> None:
+        """Save weights from a Pytorch model to a local directory."""
+        save_file(self.state_dict(), f"{save_directory}/{PYTORCH_WEIGHTS_NAME}")
+
+    def save_pretrained(
+        self,
+        save_directory: str,
+        model_config: dict,
+        wandb_repo: str,
+        wandb_id: str,
+        card_template_path: str,
+        datamodule_config_path: str,
+        experiment_config_path: str | None = None,
+        hf_repo_id: str | None = None,
+        push_to_hub: bool = False,
+    ) -> None:
+        """Save weights in local directory or upload to huggingface hub.
+
+        Args:
+            save_directory:
+                Path to directory in which the model weights and configuration will be saved.
+            model_config:
+                Model configuration specified as a key/value dictionary.
+            wandb_repo: Identifier of the repo on wandb.
+            wandb_id: Identifier of the model on wandb.
+            card_template_path:
+                Path to the HuggingFace model card template.
+            datamodule_config_path:
+                The path to the datamodule config.
+            experiment_config_path:
+                The path to the full experimental config.
+            hf_repo_id:
+                ID of your repository on the Hub. Used only if push_to_hub=True. Will default to
+                the folder name if not provided.
+            push_to_hub:
+                Whether or not to push your model to the HuggingFace Hub after saving it.
+        """
+
+        save_directory = Path(save_directory)
+        save_directory.mkdir(parents=True, exist_ok=True)
+
+        # Save model weights/files
+        self._save_model_weights(save_directory)
+
+        # Save the model config
+        if isinstance(model_config, dict):
+            with open(save_directory / MODEL_CONFIG_NAME, "w") as outfile:
+                yaml.dump(model_config, outfile, sort_keys=False, default_flow_style=False)
+
+        # Sanitize and save the datamodule config
+        with open(datamodule_config_path) as cfg:
+            datamodule_config = yaml.load(cfg, Loader=yaml.FullLoader)
+
+        datamodule_config = santize_datamodule(datamodule_config)
+
+        with open(save_directory / DATAMODULE_CONFIG_NAME, "w") as outfile:
+            yaml.dump(datamodule_config, outfile, sort_keys=False, default_flow_style=False)
+
+        # Save the full experimental config
+        if experiment_config_path is not None:
+            shutil.copyfile(experiment_config_path, save_directory / FULL_CONFIG_NAME)
+
+        card = self.create_hugging_face_model_card(card_template_path, wandb_repo, wandb_id)
+
+        (save_directory / MODEL_CARD_NAME).write_text(str(card))
+
+        if push_to_hub:
+            api = HfApi()
+
+            api.upload_folder(
+                repo_id=hf_repo_id,
+                folder_path=save_directory,
+                repo_type="model",
+                commit_message=f"Upload model - {wandb_id}",
+            )
+
+            # Print the most recent commit hash
+            c = api.list_repo_commits(repo_id=hf_repo_id, repo_type="model")[0]
+
+            message = (
+                f"The latest commit is now: \n"
+                f" date: {c.created_at} \n"
+                f" commit hash: {c.commit_id}\n"
+                f" by: {c.authors}\n"
+                f" title: {c.title}\n"
+            )
+
+            print(message)
+
+    @staticmethod
+    def create_hugging_face_model_card(
+        card_template_path: str,
+        wandb_repo: str,
+        wandb_id: str,
+    ) -> ModelCard:
+        """
+        Creates Hugging Face model card
+
+        Args:
+            card_template_path: Path to the HuggingFace model card template
+            wandb_repo: Identifier of the repo on wandb.
+            wandb_id: Identifier of the model on wandb.
+
+        Returns:
+            card: ModelCard - Hugging Face model card object
+        """
+
+        # Creating and saving model card.
+        card_data = ModelCardData(language="en", license="mit", library_name="pytorch")
+
+        link = f"https://wandb.ai/{wandb_repo}/runs/{wandb_id}"
+        wandb_link = f" - [{link}]({link})\n"
+
+        # Find package versions for OCF packages
+        packages_to_display = ["pvnet_summation", "ocf-data-sampler"]
+        packages_and_versions = {package: version(package) for package in packages_to_display}
+
+
+        package_versions_markdown = ""
+        for package, v in packages_and_versions.items():
+            package_versions_markdown += f" - {package}=={v}\n"
+
+        return ModelCard.from_template(
+            card_data,
+            template_path=card_template_path,
+            wandb_link=wandb_link,
+            package_versions=package_versions_markdown,
+        )
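
The mixin gives each summation model a HuggingFace round trip. A sketch of loading a
pretrained model, assuming the DenseModel class (its name inferred from the
dense_model.py module in this release) subclasses BaseModel; the repo id is hypothetical:

    from pvnet_summation.models.dense_model import DenseModel

    model = DenseModel.from_pretrained(
        model_id="some-org/some-summation-model",  # hypothetical repo id
        revision="main",
    )
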
+
+
+class BaseModel(torch.nn.Module, HuggingfaceMixin):
+    """Abstract base class for PVNet-summation submodels"""
+
+    def __init__(
+        self,
+        output_quantiles: list[float] | None,
+        num_input_locations: int,
+        input_quantiles: list[float] | None,
+        history_minutes: int,
+        forecast_minutes: int,
+        interval_minutes: int,
+    ):
+        """Abstract base class for PVNet-summation submodels.
+
+        """
+        super().__init__()
+
+        self.output_quantiles = output_quantiles
+
+        self.num_input_locations = num_input_locations
+        self.input_quantiles = input_quantiles
+
+        self.history_minutes = history_minutes
+        self.forecast_minutes = forecast_minutes
+        self.interval_minutes = interval_minutes
+
+        # Number of timesteps in the history and forecast periods
+        self.history_len = history_minutes // interval_minutes
+        self.forecast_len = forecast_minutes // interval_minutes
+
+        # Store whether the model should use quantile regression or simply predict the mean
+        self.use_quantile_regression = self.output_quantiles is not None
+
+        # Store the number of output features that the model should predict
+        if self.use_quantile_regression:
+            self.num_output_features = self.forecast_len * len(self.output_quantiles)
+        else:
+            self.num_output_features = self.forecast_len
+
+        # Store the expected input shape
+        if input_quantiles is None:
+            self.input_shape = (self.num_input_locations, self.forecast_len)
+        else:
+            self.input_shape = (self.num_input_locations, self.forecast_len, len(input_quantiles))
+
+    def _quantiles_to_prediction(self, y_quantiles: torch.Tensor) -> torch.Tensor:
+        """Convert network prediction into a point prediction.
+
+        Args:
+            y_quantiles: Quantile prediction of network
+
+        Returns:
+            torch.Tensor: Point prediction
+        """
+        # y_quantiles Shape: [batch_size, seq_length, num_quantiles]
+        idx = self.output_quantiles.index(0.5)
+        return y_quantiles[..., idx]
+
+    def sum_of_locations(self, x: SumTensorBatch) -> torch.Tensor:
+        """Compute the sum of the location-level predictions"""
+        if self.input_quantiles is None:
+            y_hat = x["pvnet_outputs"]
+        else:
+            idx = self.input_quantiles.index(0.5)
+            y_hat = x["pvnet_outputs"][..., idx]
+
+        return (y_hat * x["relative_capacity"].unsqueeze(-1)).sum(dim=1)