PVNet 5.2.3__py3-none-any.whl → 5.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pvnet/datamodule.py +12 -28
- pvnet/models/base_model.py +18 -23
- pvnet/models/ensemble.py +0 -4
- pvnet/models/late_fusion/late_fusion.py +28 -55
- pvnet/models/late_fusion/site_encoders/encoders.py +14 -24
- pvnet/training/lightning_module.py +44 -49
- pvnet/training/plots.py +2 -2
- pvnet/utils.py +26 -16
- {pvnet-5.2.3.dist-info → pvnet-5.3.1.dist-info}/METADATA +1 -1
- {pvnet-5.2.3.dist-info → pvnet-5.3.1.dist-info}/RECORD +13 -13
- {pvnet-5.2.3.dist-info → pvnet-5.3.1.dist-info}/WHEEL +0 -0
- {pvnet-5.2.3.dist-info → pvnet-5.3.1.dist-info}/licenses/LICENSE +0 -0
- {pvnet-5.2.3.dist-info → pvnet-5.3.1.dist-info}/top_level.txt +0 -0
pvnet/datamodule.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
"""
|
|
1
|
+
"""Data module for pytorch lightning"""
|
|
2
2
|
|
|
3
3
|
import os
|
|
4
4
|
|
|
@@ -6,10 +6,9 @@ import numpy as np
|
|
|
6
6
|
from lightning.pytorch import LightningDataModule
|
|
7
7
|
from ocf_data_sampler.numpy_sample.collate import stack_np_samples_into_batch
|
|
8
8
|
from ocf_data_sampler.numpy_sample.common_types import NumpySample, TensorBatch
|
|
9
|
-
from ocf_data_sampler.torch_datasets.
|
|
10
|
-
from ocf_data_sampler.torch_datasets.datasets.site import SitesDataset
|
|
9
|
+
from ocf_data_sampler.torch_datasets.pvnet_dataset import PVNetDataset
|
|
11
10
|
from ocf_data_sampler.torch_datasets.utils.torch_batch_utils import batch_to_tensor
|
|
12
|
-
from torch.utils.data import DataLoader,
|
|
11
|
+
from torch.utils.data import DataLoader, Subset
|
|
13
12
|
|
|
14
13
|
|
|
15
14
|
def collate_fn(samples: list[NumpySample]) -> TensorBatch:
|
|
@@ -17,7 +16,7 @@ def collate_fn(samples: list[NumpySample]) -> TensorBatch:
|
|
|
17
16
|
return batch_to_tensor(stack_np_samples_into_batch(samples))
|
|
18
17
|
|
|
19
18
|
|
|
20
|
-
class BaseDataModule(LightningDataModule):
|
|
19
|
+
class PVNetDataModule(LightningDataModule):
|
|
21
20
|
"""Base Datamodule which streams samples using a sampler from ocf-data-sampler."""
|
|
22
21
|
|
|
23
22
|
def __init__(
|
|
@@ -40,10 +39,10 @@ class BaseDataModule(LightningDataModule):
|
|
|
40
39
|
batch_size: Batch size.
|
|
41
40
|
num_workers: Number of workers to use in multiprocess batch loading.
|
|
42
41
|
prefetch_factor: Number of batches loaded in advance by each worker.
|
|
43
|
-
persistent_workers: If True, the data loader will not shut down the worker processes
|
|
44
|
-
after a dataset has been consumed once. This allows to maintain the workers Dataset
|
|
42
|
+
persistent_workers: If True, the data loader will not shut down the worker processes
|
|
43
|
+
after a dataset has been consumed once. This allows to maintain the workers Dataset
|
|
45
44
|
instances alive.
|
|
46
|
-
pin_memory: If True, the data loader will copy Tensors into device/CUDA pinned memory
|
|
45
|
+
pin_memory: If True, the data loader will copy Tensors into device/CUDA pinned memory
|
|
47
46
|
before returning them.
|
|
48
47
|
train_period: Date range filter for train dataloader.
|
|
49
48
|
val_period: Date range filter for val dataloader.
|
|
@@ -70,7 +69,7 @@ class BaseDataModule(LightningDataModule):
|
|
|
70
69
|
worker_init_fn=None,
|
|
71
70
|
prefetch_factor=prefetch_factor,
|
|
72
71
|
persistent_workers=persistent_workers,
|
|
73
|
-
multiprocessing_context="spawn" if num_workers>0 else None,
|
|
72
|
+
multiprocessing_context="spawn" if num_workers > 0 else None,
|
|
74
73
|
)
|
|
75
74
|
|
|
76
75
|
def setup(self, stage: str | None = None):
|
|
@@ -79,16 +78,15 @@ class BaseDataModule(LightningDataModule):
|
|
|
79
78
|
# This logic runs only once at the start of training, therefore the val dataset is only
|
|
80
79
|
# shuffled once
|
|
81
80
|
if stage == "fit":
|
|
82
|
-
|
|
83
81
|
# Prepare the train dataset
|
|
84
82
|
self.train_dataset = self._get_dataset(*self.train_period)
|
|
85
83
|
|
|
86
|
-
#
|
|
84
|
+
# Prepare and pre-shuffle the val dataset and set seed for reproducibility
|
|
87
85
|
val_dataset = self._get_dataset(*self.val_period)
|
|
88
86
|
|
|
89
87
|
shuffled_indices = np.random.default_rng(seed=self.seed).permutation(len(val_dataset))
|
|
90
88
|
self.val_dataset = Subset(val_dataset, shuffled_indices)
|
|
91
|
-
|
|
89
|
+
|
|
92
90
|
if self.dataset_pickle_dir is not None:
|
|
93
91
|
os.makedirs(self.dataset_pickle_dir, exist_ok=True)
|
|
94
92
|
train_dataset_path = f"{self.dataset_pickle_dir}/train_dataset.pkl"
|
|
@@ -116,8 +114,8 @@ class BaseDataModule(LightningDataModule):
|
|
|
116
114
|
if os.path.exists(filepath):
|
|
117
115
|
os.remove(filepath)
|
|
118
116
|
|
|
119
|
-
def _get_dataset(self, start_time: str | None, end_time: str | None) ->
|
|
120
|
-
|
|
117
|
+
def _get_dataset(self, start_time: str | None, end_time: str | None) -> PVNetDataset:
|
|
118
|
+
return PVNetDataset(self.configuration, start_time=start_time, end_time=end_time)
|
|
121
119
|
|
|
122
120
|
def train_dataloader(self) -> DataLoader:
|
|
123
121
|
"""Construct train dataloader"""
|
|
@@ -126,17 +124,3 @@ class BaseDataModule(LightningDataModule):
|
|
|
126
124
|
def val_dataloader(self) -> DataLoader:
|
|
127
125
|
"""Construct val dataloader"""
|
|
128
126
|
return DataLoader(self.val_dataset, shuffle=False, **self._common_dataloader_kwargs)
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
class UKRegionalDataModule(BaseDataModule):
|
|
132
|
-
"""Datamodule for streaming UK regional samples."""
|
|
133
|
-
|
|
134
|
-
def _get_dataset(self, start_time: str | None, end_time: str | None) -> PVNetUKRegionalDataset:
|
|
135
|
-
return PVNetUKRegionalDataset(self.configuration, start_time=start_time, end_time=end_time)
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
class SitesDataModule(BaseDataModule):
|
|
139
|
-
"""Datamodule for streaming site samples."""
|
|
140
|
-
|
|
141
|
-
def _get_dataset(self, start_time: str | None, end_time: str | None) -> SitesDataset:
|
|
142
|
-
return SitesDataset(self.configuration, start_time=start_time, end_time=end_time)
|
pvnet/models/base_model.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
"""Base model for all PVNet submodels"""
|
|
2
|
+
|
|
2
3
|
import logging
|
|
3
4
|
import os
|
|
4
5
|
import shutil
|
|
@@ -32,7 +33,7 @@ def fill_config_paths_with_placeholder(config: dict, placeholder: str = "PLACEHO
|
|
|
32
33
|
"""
|
|
33
34
|
input_config = config["input_data"]
|
|
34
35
|
|
|
35
|
-
for source in ["
|
|
36
|
+
for source in ["generation", "satellite"]:
|
|
36
37
|
if source in input_config:
|
|
37
38
|
# If not empty - i.e. if used
|
|
38
39
|
if input_config[source]["zarr_path"] != "":
|
|
@@ -78,8 +79,8 @@ def minimize_config_for_model(config: dict, model: "BaseModel") -> dict:
|
|
|
78
79
|
|
|
79
80
|
# Replace the interval_end_minutes minutes
|
|
80
81
|
nwp_config["interval_end_minutes"] = (
|
|
81
|
-
nwp_config["interval_start_minutes"]
|
|
82
|
-
(model.nwp_encoders_dict[nwp_source].sequence_length - 1)
|
|
82
|
+
nwp_config["interval_start_minutes"]
|
|
83
|
+
+ (model.nwp_encoders_dict[nwp_source].sequence_length - 1)
|
|
83
84
|
* nwp_config["time_resolution_minutes"]
|
|
84
85
|
)
|
|
85
86
|
|
|
@@ -96,20 +97,19 @@ def minimize_config_for_model(config: dict, model: "BaseModel") -> dict:
|
|
|
96
97
|
|
|
97
98
|
# Replace the interval_end_minutes minutes
|
|
98
99
|
sat_config["interval_end_minutes"] = (
|
|
99
|
-
sat_config["interval_start_minutes"]
|
|
100
|
-
(model.sat_encoder.sequence_length - 1)
|
|
101
|
-
* sat_config["time_resolution_minutes"]
|
|
100
|
+
sat_config["interval_start_minutes"]
|
|
101
|
+
+ (model.sat_encoder.sequence_length - 1) * sat_config["time_resolution_minutes"]
|
|
102
102
|
)
|
|
103
103
|
|
|
104
104
|
if "pv" in input_config:
|
|
105
105
|
if not model.include_pv:
|
|
106
106
|
del input_config["pv"]
|
|
107
107
|
|
|
108
|
-
if "
|
|
109
|
-
|
|
108
|
+
if "generation" in input_config:
|
|
109
|
+
generation_config = input_config["generation"]
|
|
110
110
|
|
|
111
111
|
# Replace the forecast minutes
|
|
112
|
-
|
|
112
|
+
generation_config["interval_end_minutes"] = model.forecast_minutes
|
|
113
113
|
|
|
114
114
|
if "solar_position" in input_config:
|
|
115
115
|
solar_config = input_config["solar_position"]
|
|
@@ -138,9 +138,9 @@ def download_from_hf(
|
|
|
138
138
|
force_download: Whether to force a new download
|
|
139
139
|
max_retries: Maximum number of retry attempts
|
|
140
140
|
wait_time: Wait time (in seconds) before retrying
|
|
141
|
-
token:
|
|
141
|
+
token:
|
|
142
142
|
HF authentication token. If True, the token is read from the HuggingFace config folder.
|
|
143
|
-
If a string, it is used as the authentication token.
|
|
143
|
+
If a string, it is used as the authentication token.
|
|
144
144
|
|
|
145
145
|
Returns:
|
|
146
146
|
The local file path of the downloaded file(s)
|
|
@@ -160,7 +160,7 @@ def download_from_hf(
|
|
|
160
160
|
return [f"{save_dir}/{f}" for f in filename]
|
|
161
161
|
else:
|
|
162
162
|
return f"{save_dir}/{filename}"
|
|
163
|
-
|
|
163
|
+
|
|
164
164
|
except Exception as e:
|
|
165
165
|
if attempt == max_retries:
|
|
166
166
|
raise Exception(
|
|
@@ -205,7 +205,7 @@ class HuggingfaceMixin:
|
|
|
205
205
|
force_download=force_download,
|
|
206
206
|
max_retries=5,
|
|
207
207
|
wait_time=10,
|
|
208
|
-
token=token
|
|
208
|
+
token=token,
|
|
209
209
|
)
|
|
210
210
|
|
|
211
211
|
with open(config_file, "r") as f:
|
|
@@ -240,7 +240,7 @@ class HuggingfaceMixin:
|
|
|
240
240
|
force_download=force_download,
|
|
241
241
|
max_retries=5,
|
|
242
242
|
wait_time=10,
|
|
243
|
-
token=token
|
|
243
|
+
token=token,
|
|
244
244
|
)
|
|
245
245
|
|
|
246
246
|
return data_config_file
|
|
@@ -301,7 +301,7 @@ class HuggingfaceMixin:
|
|
|
301
301
|
# Save cleaned version of input data configuration file
|
|
302
302
|
with open(data_config_path) as cfg:
|
|
303
303
|
config = yaml.load(cfg, Loader=yaml.FullLoader)
|
|
304
|
-
|
|
304
|
+
|
|
305
305
|
config = fill_config_paths_with_placeholder(config)
|
|
306
306
|
config = minimize_config_for_model(config, self)
|
|
307
307
|
|
|
@@ -311,7 +311,7 @@ class HuggingfaceMixin:
|
|
|
311
311
|
# Save the datamodule config
|
|
312
312
|
if datamodule_config_path is not None:
|
|
313
313
|
shutil.copyfile(datamodule_config_path, save_directory / DATAMODULE_CONFIG_NAME)
|
|
314
|
-
|
|
314
|
+
|
|
315
315
|
# Save the full experimental config
|
|
316
316
|
if experiment_config_path is not None:
|
|
317
317
|
shutil.copyfile(experiment_config_path, save_directory / FULL_CONFIG_NAME)
|
|
@@ -378,7 +378,6 @@ class HuggingfaceMixin:
|
|
|
378
378
|
packages_to_display = ["pvnet", "ocf-data-sampler"]
|
|
379
379
|
packages_and_versions = {package: version(package) for package in packages_to_display}
|
|
380
380
|
|
|
381
|
-
|
|
382
381
|
package_versions_markdown = ""
|
|
383
382
|
for package, v in packages_and_versions.items():
|
|
384
383
|
package_versions_markdown += f" - {package}=={v}\n"
|
|
@@ -399,23 +398,19 @@ class BaseModel(torch.nn.Module, HuggingfaceMixin):
|
|
|
399
398
|
history_minutes: int,
|
|
400
399
|
forecast_minutes: int,
|
|
401
400
|
output_quantiles: list[float] | None = None,
|
|
402
|
-
target_key: str = "gsp",
|
|
403
401
|
interval_minutes: int = 30,
|
|
404
402
|
):
|
|
405
403
|
"""Abtstract base class for PVNet submodels.
|
|
406
404
|
|
|
407
405
|
Args:
|
|
408
|
-
history_minutes (int): Length of the
|
|
409
|
-
forecast_minutes (int): Length of the
|
|
406
|
+
history_minutes (int): Length of the generation history period in minutes
|
|
407
|
+
forecast_minutes (int): Length of the generation forecast period in minutes
|
|
410
408
|
output_quantiles: A list of float (0.0, 1.0) quantiles to predict values for. If set to
|
|
411
409
|
None the output is a single value.
|
|
412
|
-
target_key: The key of the target variable in the batch
|
|
413
410
|
interval_minutes: The interval in minutes between each timestep in the data
|
|
414
411
|
"""
|
|
415
412
|
super().__init__()
|
|
416
413
|
|
|
417
|
-
self._target_key = target_key
|
|
418
|
-
|
|
419
414
|
self.history_minutes = history_minutes
|
|
420
415
|
self.forecast_minutes = forecast_minutes
|
|
421
416
|
self.output_quantiles = output_quantiles
|
pvnet/models/ensemble.py
CHANGED
|
@@ -26,7 +26,6 @@ class Ensemble(BaseModel):
|
|
|
26
26
|
output_quantiles = []
|
|
27
27
|
history_minutes = []
|
|
28
28
|
forecast_minutes = []
|
|
29
|
-
target_key = []
|
|
30
29
|
interval_minutes = []
|
|
31
30
|
|
|
32
31
|
# Get some model properties from each model
|
|
@@ -34,7 +33,6 @@ class Ensemble(BaseModel):
|
|
|
34
33
|
output_quantiles.append(model.output_quantiles)
|
|
35
34
|
history_minutes.append(model.history_minutes)
|
|
36
35
|
forecast_minutes.append(model.forecast_minutes)
|
|
37
|
-
target_key.append(model._target_key)
|
|
38
36
|
interval_minutes.append(model.interval_minutes)
|
|
39
37
|
|
|
40
38
|
# Check these properties are all the same
|
|
@@ -42,7 +40,6 @@ class Ensemble(BaseModel):
|
|
|
42
40
|
output_quantiles,
|
|
43
41
|
history_minutes,
|
|
44
42
|
forecast_minutes,
|
|
45
|
-
target_key,
|
|
46
43
|
interval_minutes,
|
|
47
44
|
]:
|
|
48
45
|
assert all([p == param_list[0] for p in param_list]), param_list
|
|
@@ -51,7 +48,6 @@ class Ensemble(BaseModel):
|
|
|
51
48
|
history_minutes=history_minutes[0],
|
|
52
49
|
forecast_minutes=forecast_minutes[0],
|
|
53
50
|
output_quantiles=output_quantiles[0],
|
|
54
|
-
target_key=target_key[0],
|
|
55
51
|
interval_minutes=interval_minutes[0],
|
|
56
52
|
)
|
|
57
53
|
|
|
pvnet/models/late_fusion/late_fusion.py
CHANGED
|
@@ -28,8 +28,8 @@ class LateFusionModel(BaseModel):
|
|
|
28
28
|
- NWP, if included, is put through a similar encoder.
|
|
29
29
|
- PV site-level data, if included, is put through an encoder which transforms it from 2D, with
|
|
30
30
|
time and system-ID dimensions, to become a 1D feature vector.
|
|
31
|
-
- The satellite features*, NWP features*, PV site-level features*,
|
|
32
|
-
paramters* are concatenated into a 1D feature vector and passed through another neural
|
|
31
|
+
- The satellite features*, NWP features*, PV site-level features*, location ID embedding*, and
|
|
32
|
+
sun paramters* are concatenated into a 1D feature vector and passed through another neural
|
|
33
33
|
network to combine them and produce a forecast.
|
|
34
34
|
|
|
35
35
|
* if included
|
|
@@ -43,8 +43,7 @@ class LateFusionModel(BaseModel):
|
|
|
43
43
|
sat_encoder: AbstractNWPSatelliteEncoder | None = None,
|
|
44
44
|
pv_encoder: AbstractSitesEncoder | None = None,
|
|
45
45
|
add_image_embedding_channel: bool = False,
|
|
46
|
-
|
|
47
|
-
include_site_yield_history: bool = False,
|
|
46
|
+
include_generation_history: bool = False,
|
|
48
47
|
include_sun: bool = True,
|
|
49
48
|
include_time: bool = False,
|
|
50
49
|
location_id_mapping: dict[Any, int] | None = None,
|
|
@@ -56,7 +55,6 @@ class LateFusionModel(BaseModel):
|
|
|
56
55
|
nwp_forecast_minutes: DictConfig | None = None,
|
|
57
56
|
nwp_history_minutes: DictConfig | None = None,
|
|
58
57
|
pv_history_minutes: int | None = None,
|
|
59
|
-
target_key: str = "gsp",
|
|
60
58
|
interval_minutes: int = 30,
|
|
61
59
|
nwp_interval_minutes: DictConfig | None = None,
|
|
62
60
|
pv_interval_minutes: int = 5,
|
|
@@ -83,14 +81,13 @@ class LateFusionModel(BaseModel):
|
|
|
83
81
|
pv_encoder: A partially instantiated pytorch Module class used to encode the site-level
|
|
84
82
|
PV data from 2D into a 1D feature vector.
|
|
85
83
|
add_image_embedding_channel: Add a channel to the NWP and satellite data with the
|
|
86
|
-
embedding of the
|
|
87
|
-
|
|
88
|
-
include_site_yield_history: Include Site yield data.
|
|
84
|
+
embedding of the location ID.
|
|
85
|
+
include_generation_history: Include generation yield data.
|
|
89
86
|
include_sun: Include sun azimuth and altitude data.
|
|
90
87
|
include_time: Include sine and cosine of dates and times.
|
|
91
88
|
location_id_mapping: A dictionary mapping the location ID to an integer. ID embedding is
|
|
92
89
|
not used if this is not provided.
|
|
93
|
-
embedding_dim: Number of embedding dimensions to use for
|
|
90
|
+
embedding_dim: Number of embedding dimensions to use for location ID.
|
|
94
91
|
forecast_minutes: The amount of minutes that should be forecasted.
|
|
95
92
|
history_minutes: The default amount of historical minutes that are used.
|
|
96
93
|
sat_history_minutes: Length of recent observations used for satellite inputs. Defaults
|
|
@@ -103,7 +100,6 @@ class LateFusionModel(BaseModel):
|
|
|
103
100
|
`history_minutes` if not provided.
|
|
104
101
|
pv_history_minutes: Length of recent site-level PV data used as
|
|
105
102
|
input. Defaults to `history_minutes` if not provided.
|
|
106
|
-
target_key: The key of the target variable in the batch.
|
|
107
103
|
interval_minutes: The interval between each sample of the target data
|
|
108
104
|
nwp_interval_minutes: Dictionary of the intervals between each sample of the NWP
|
|
109
105
|
data for each source
|
|
@@ -114,12 +110,10 @@ class LateFusionModel(BaseModel):
|
|
|
114
110
|
history_minutes=history_minutes,
|
|
115
111
|
forecast_minutes=forecast_minutes,
|
|
116
112
|
output_quantiles=output_quantiles,
|
|
117
|
-
target_key=target_key,
|
|
118
113
|
interval_minutes=interval_minutes,
|
|
119
114
|
)
|
|
120
115
|
|
|
121
|
-
self.include_gsp_yield_history = include_gsp_yield_history
|
|
122
|
-
self.include_site_yield_history = include_site_yield_history
|
|
116
|
+
self.include_generation_history = include_generation_history
|
|
123
117
|
self.include_sat = sat_encoder is not None
|
|
124
118
|
self.include_nwp = nwp_encoders_dict is not None and len(nwp_encoders_dict) != 0
|
|
125
119
|
self.include_pv = pv_encoder is not None
|
|
@@ -133,8 +127,7 @@ class LateFusionModel(BaseModel):
|
|
|
133
127
|
|
|
134
128
|
if self.location_id_mapping is None:
|
|
135
129
|
logger.warning(
|
|
136
|
-
"location_id_mapping` is not provided, defaulting to outdated GSP mapping"
|
|
137
|
-
"(0 to 317)"
|
|
130
|
+
"location_id_mapping` is not provided, defaulting to outdated GSP mapping(0 to 317)"
|
|
138
131
|
)
|
|
139
132
|
|
|
140
133
|
# Note 318 is the 2024 UK GSP count, so this is a temporary fix
|
|
@@ -223,8 +216,7 @@ class LateFusionModel(BaseModel):
|
|
|
223
216
|
|
|
224
217
|
self.pv_encoder = pv_encoder(
|
|
225
218
|
sequence_length=pv_history_minutes // pv_interval_minutes + 1,
|
|
226
|
-
|
|
227
|
-
input_key_to_use="site",
|
|
219
|
+
key_to_use="generation",
|
|
228
220
|
)
|
|
229
221
|
|
|
230
222
|
# Update num features
|
|
@@ -238,8 +230,7 @@ class LateFusionModel(BaseModel):
|
|
|
238
230
|
|
|
239
231
|
if self.include_sun:
|
|
240
232
|
self.sun_fc1 = nn.Linear(
|
|
241
|
-
in_features=2
|
|
242
|
-
* (self.forecast_len + self.history_len + 1),
|
|
233
|
+
in_features=2 * (self.forecast_len + self.history_len + 1),
|
|
243
234
|
out_features=16,
|
|
244
235
|
)
|
|
245
236
|
|
|
@@ -248,19 +239,14 @@ class LateFusionModel(BaseModel):
|
|
|
248
239
|
|
|
249
240
|
if self.include_time:
|
|
250
241
|
self.time_fc1 = nn.Linear(
|
|
251
|
-
in_features=4
|
|
252
|
-
* (self.forecast_len + self.history_len + 1),
|
|
242
|
+
in_features=4 * (self.forecast_len + self.history_len + 1),
|
|
253
243
|
out_features=32,
|
|
254
244
|
)
|
|
255
245
|
|
|
256
246
|
# Update num features
|
|
257
247
|
fusion_input_features += 32
|
|
258
248
|
|
|
259
|
-
if
|
|
260
|
-
# Update num features
|
|
261
|
-
fusion_input_features += self.history_len
|
|
262
|
-
|
|
263
|
-
if include_site_yield_history:
|
|
249
|
+
if include_generation_history:
|
|
264
250
|
# Update num features
|
|
265
251
|
fusion_input_features += self.history_len + 1
|
|
266
252
|
|
|
@@ -269,15 +255,14 @@ class LateFusionModel(BaseModel):
|
|
|
269
255
|
out_features=self.num_output_features,
|
|
270
256
|
)
|
|
271
257
|
|
|
272
|
-
|
|
273
258
|
def forward(self, x: TensorBatch) -> torch.Tensor:
|
|
274
259
|
"""Run model forward"""
|
|
275
260
|
|
|
276
261
|
if self.use_id_embedding:
|
|
277
|
-
# eg: x['
|
|
262
|
+
# eg: x['location_id'] = [1] with location_id_mapping = {1:0}, would give [0]
|
|
278
263
|
id = torch.tensor(
|
|
279
|
-
[self.location_id_mapping[i.item()] for i in x[
|
|
280
|
-
device=x[
|
|
264
|
+
[self.location_id_mapping[i.item()] for i in x["location_id"]],
|
|
265
|
+
device=x["location_id"].device,
|
|
281
266
|
dtype=torch.int64,
|
|
282
267
|
)
|
|
283
268
|
|
|
@@ -308,32 +293,20 @@ class LateFusionModel(BaseModel):
|
|
|
308
293
|
nwp_out = self.nwp_encoders_dict[nwp_source](nwp_data)
|
|
309
294
|
modes[f"nwp/{nwp_source}"] = nwp_out
|
|
310
295
|
|
|
311
|
-
# ***********************
|
|
312
|
-
# Add
|
|
313
|
-
if self.
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
modes["
|
|
296
|
+
# *********************** Generation Data *************************************
|
|
297
|
+
# Add generation yield history
|
|
298
|
+
if self.include_generation_history:
|
|
299
|
+
generation_history = x["generation"][:, : self.history_len + 1].float()
|
|
300
|
+
generation_history = generation_history.reshape(generation_history.shape[0], -1)
|
|
301
|
+
modes["generation"] = generation_history
|
|
317
302
|
|
|
318
|
-
# Add
|
|
303
|
+
# Add location-level yield history through PV encoder
|
|
319
304
|
if self.include_pv:
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
x_tmp = x.copy()
|
|
326
|
-
x_tmp["site"] = x_tmp["site"][:, : self.history_len + 1]
|
|
327
|
-
modes["site"] = self.pv_encoder(x_tmp)
|
|
328
|
-
|
|
329
|
-
# *********************** GSP Data ************************************
|
|
330
|
-
# Add gsp yield history
|
|
331
|
-
if self.include_gsp_yield_history:
|
|
332
|
-
gsp_history = x["gsp"][:, : self.history_len].float()
|
|
333
|
-
gsp_history = gsp_history.reshape(gsp_history.shape[0], -1)
|
|
334
|
-
modes["gsp"] = gsp_history
|
|
335
|
-
|
|
336
|
-
# ********************** Embedding of GSP/Site ID ********************
|
|
305
|
+
x_tmp = x.copy()
|
|
306
|
+
x_tmp["generation"] = x_tmp["generation"][:, : self.history_len + 1]
|
|
307
|
+
modes["generation"] = self.pv_encoder(x_tmp)
|
|
308
|
+
|
|
309
|
+
# ********************** Embedding of location ID ********************
|
|
337
310
|
if self.use_id_embedding:
|
|
338
311
|
modes["id"] = self.embed(id)
|
|
339
312
|
|
|
@@ -341,7 +314,7 @@ class LateFusionModel(BaseModel):
|
|
|
341
314
|
sun = torch.cat((x["solar_azimuth"], x["solar_elevation"]), dim=1).float()
|
|
342
315
|
sun = self.sun_fc1(sun)
|
|
343
316
|
modes["sun"] = sun
|
|
344
|
-
|
|
317
|
+
|
|
345
318
|
if self.include_time:
|
|
346
319
|
time = [x[k] for k in ["date_sin", "date_cos", "time_sin", "time_cos"]]
|
|
347
320
|
time = torch.cat(time, dim=1).float()
|
|
pvnet/models/late_fusion/site_encoders/encoders.py
CHANGED
|
@@ -1,6 +1,4 @@
|
|
|
1
|
-
"""Encoder modules for the site-level PV data.
|
|
2
|
-
|
|
3
|
-
"""
|
|
1
|
+
"""Encoder modules for the site-level PV data."""
|
|
4
2
|
|
|
5
3
|
import einops
|
|
6
4
|
import torch
|
|
@@ -11,6 +9,7 @@ from pvnet.models.late_fusion.linear_networks.networks import ResFCNet
|
|
|
11
9
|
from pvnet.models.late_fusion.site_encoders.basic_blocks import AbstractSitesEncoder
|
|
12
10
|
|
|
13
11
|
|
|
12
|
+
# TODO update this to work with the new sample data format
|
|
14
13
|
class SimpleLearnedAggregator(AbstractSitesEncoder):
|
|
15
14
|
"""A simple model which learns a different weighted-average across all PV sites for each GSP.
|
|
16
15
|
|
|
@@ -127,8 +126,7 @@ class SingleAttentionNetwork(AbstractSitesEncoder):
|
|
|
127
126
|
kv_res_block_layers: int = 2,
|
|
128
127
|
use_id_in_value: bool = False,
|
|
129
128
|
target_id_dim: int = 318,
|
|
130
|
-
|
|
131
|
-
input_key_to_use: str = "site",
|
|
129
|
+
key_to_use: str = "generation",
|
|
132
130
|
num_channels: int = 1,
|
|
133
131
|
num_sites_in_inference: int = 1,
|
|
134
132
|
):
|
|
@@ -149,8 +147,7 @@ class SingleAttentionNetwork(AbstractSitesEncoder):
|
|
|
149
147
|
use_id_in_value: Whether to use a site ID embedding in network used to produce the
|
|
150
148
|
value for the attention layer.
|
|
151
149
|
target_id_dim: The number of unique IDs.
|
|
152
|
-
|
|
153
|
-
input_key_to_use: The key to use for the input in the attention layer.
|
|
150
|
+
key_to_use: The key to use in the attention layer.
|
|
154
151
|
num_channels: Number of channels in the input data
|
|
155
152
|
num_sites_in_inference: Number of sites to use in inference.
|
|
156
153
|
This is used to determine the number of sites to use in the
|
|
@@ -164,8 +161,7 @@ class SingleAttentionNetwork(AbstractSitesEncoder):
|
|
|
164
161
|
self.site_id_embedding = nn.Embedding(num_sites, id_embed_dim)
|
|
165
162
|
self._ids = nn.parameter.Parameter(torch.arange(num_sites), requires_grad=False)
|
|
166
163
|
self.use_id_in_value = use_id_in_value
|
|
167
|
-
self.target_key_to_use = target_key_to_use
|
|
168
|
-
self.input_key_to_use = input_key_to_use
|
|
164
|
+
self.key_to_use = key_to_use
|
|
169
165
|
self.num_channels = num_channels
|
|
170
166
|
self.num_sites_in_inference = num_sites_in_inference
|
|
171
167
|
|
|
@@ -206,7 +202,7 @@ class SingleAttentionNetwork(AbstractSitesEncoder):
|
|
|
206
202
|
def _encode_inputs(self, x: TensorBatch) -> tuple[torch.Tensor, int]:
|
|
207
203
|
# Shape: [batch size, sequence length, number of sites]
|
|
208
204
|
# Shape: [batch size, station_id, sequence length, channels]
|
|
209
|
-
input_data = x[f"{self.
|
|
205
|
+
input_data = x[f"{self.key_to_use}"]
|
|
210
206
|
if len(input_data.shape) == 2: # one site per sample
|
|
211
207
|
input_data = input_data.unsqueeze(-1) # add dimension of 1 to end to make 3D
|
|
212
208
|
if len(input_data.shape) == 4: # Has multiple channels
|
|
@@ -216,16 +212,11 @@ class SingleAttentionNetwork(AbstractSitesEncoder):
|
|
|
216
212
|
input_data = input_data[:, : self.sequence_length]
|
|
217
213
|
site_seqs = input_data.float()
|
|
218
214
|
batch_size = site_seqs.shape[0]
|
|
219
|
-
site_seqs = site_seqs.swapaxes(1, 2) # [batch size,
|
|
215
|
+
site_seqs = site_seqs.swapaxes(1, 2) # [batch size, location ID, sequence length]
|
|
220
216
|
return site_seqs, batch_size
|
|
221
217
|
|
|
222
218
|
def _encode_query(self, x: TensorBatch) -> torch.Tensor:
|
|
223
|
-
|
|
224
|
-
# GSP seems to have a different structure
|
|
225
|
-
ids = x[f"{self.target_key_to_use}_id"]
|
|
226
|
-
else:
|
|
227
|
-
ids = x[f"{self.input_key_to_use}_id"]
|
|
228
|
-
ids = ids.int()
|
|
219
|
+
ids = x["location_id"].int()
|
|
229
220
|
query = self.target_id_embedding(ids).unsqueeze(1)
|
|
230
221
|
return query
|
|
231
222
|
|
|
@@ -233,9 +224,9 @@ class SingleAttentionNetwork(AbstractSitesEncoder):
|
|
|
233
224
|
site_seqs, batch_size = self._encode_inputs(x)
|
|
234
225
|
|
|
235
226
|
# site ID embeddings are the same for each sample
|
|
236
|
-
|
|
227
|
+
id_embed = torch.tile(self.site_id_embedding(self._ids), (batch_size, 1, 1))
|
|
237
228
|
# Each concated (site sequence, site ID embedding) is processed with encoder
|
|
238
|
-
x_seq_in = torch.cat((site_seqs,
|
|
229
|
+
x_seq_in = torch.cat((site_seqs, id_embed), dim=2).flatten(0, 1)
|
|
239
230
|
key = self._key_encoder(x_seq_in)
|
|
240
231
|
|
|
241
232
|
# Reshape to [batch size, site, kdim]
|
|
@@ -247,9 +238,9 @@ class SingleAttentionNetwork(AbstractSitesEncoder):
|
|
|
247
238
|
|
|
248
239
|
if self.use_id_in_value:
|
|
249
240
|
# site ID embeddings are the same for each sample
|
|
250
|
-
|
|
241
|
+
id_embed = torch.tile(self.value_id_embedding(self._ids), (batch_size, 1, 1))
|
|
251
242
|
# Each concated (site sequence, site ID embedding) is processed with encoder
|
|
252
|
-
x_seq_in = torch.cat((site_seqs,
|
|
243
|
+
x_seq_in = torch.cat((site_seqs, id_embed), dim=2).flatten(0, 1)
|
|
253
244
|
else:
|
|
254
245
|
# Encode each site sequence independently
|
|
255
246
|
x_seq_in = site_seqs.flatten(0, 1)
|
|
@@ -260,9 +251,8 @@ class SingleAttentionNetwork(AbstractSitesEncoder):
|
|
|
260
251
|
return value
|
|
261
252
|
|
|
262
253
|
def _attention_forward(
|
|
263
|
-
self, x: dict,
|
|
264
|
-
|
|
265
|
-
) -> tuple[torch.Tensor, torch.Tensor:]:
|
|
254
|
+
self, x: dict, average_attn_weights: bool = True
|
|
255
|
+
) -> tuple[torch.Tensor, torch.Tensor :]:
|
|
266
256
|
query = self._encode_query(x)
|
|
267
257
|
key = self._encode_key(x)
|
|
268
258
|
value = self._encode_value(x)
|
|
pvnet/training/lightning_module.py
CHANGED
|
@@ -45,9 +45,9 @@ class PVNetLightningModule(pl.LightningModule):
|
|
|
45
45
|
self.lr = None
|
|
46
46
|
|
|
47
47
|
def transfer_batch_to_device(
|
|
48
|
-
self,
|
|
49
|
-
batch: TensorBatch,
|
|
50
|
-
device: torch.device,
|
|
48
|
+
self,
|
|
49
|
+
batch: TensorBatch,
|
|
50
|
+
device: torch.device,
|
|
51
51
|
dataloader_idx: int,
|
|
52
52
|
) -> dict:
|
|
53
53
|
"""Method to move custom batches to a given device"""
|
|
@@ -75,7 +75,7 @@ class PVNetLightningModule(pl.LightningModule):
|
|
|
75
75
|
losses = 2 * torch.cat(losses, dim=2)
|
|
76
76
|
|
|
77
77
|
return losses.mean()
|
|
78
|
-
|
|
78
|
+
|
|
79
79
|
def configure_optimizers(self):
|
|
80
80
|
"""Configure the optimizers using learning rate found with LR finder if used"""
|
|
81
81
|
if self.lr is not None:
|
|
@@ -84,7 +84,7 @@ class PVNetLightningModule(pl.LightningModule):
|
|
|
84
84
|
return self._optimizer(self.model)
|
|
85
85
|
|
|
86
86
|
def _calculate_common_losses(
|
|
87
|
-
self,
|
|
87
|
+
self,
|
|
88
88
|
y: torch.Tensor,
|
|
89
89
|
y_hat: torch.Tensor,
|
|
90
90
|
) -> dict[str, torch.Tensor]:
|
|
@@ -96,15 +96,15 @@ class PVNetLightningModule(pl.LightningModule):
|
|
|
96
96
|
losses["quantile_loss"] = self._calculate_quantile_loss(y_hat, y)
|
|
97
97
|
y_hat = self.model._quantiles_to_prediction(y_hat)
|
|
98
98
|
|
|
99
|
-
losses.update({"MSE":
|
|
99
|
+
losses.update({"MSE": F.mse_loss(y_hat, y), "MAE": F.l1_loss(y_hat, y)})
|
|
100
100
|
|
|
101
101
|
return losses
|
|
102
|
-
|
|
102
|
+
|
|
103
103
|
def training_step(self, batch: TensorBatch, batch_idx: int) -> torch.Tensor:
|
|
104
104
|
"""Run training step"""
|
|
105
105
|
y_hat = self.model(batch)
|
|
106
106
|
|
|
107
|
-
y = batch[
|
|
107
|
+
y = batch["generation"][:, -self.model.forecast_len :]
|
|
108
108
|
|
|
109
109
|
losses = self._calculate_common_losses(y, y_hat)
|
|
110
110
|
losses = {f"{k}/train": v for k, v in losses.items()}
|
|
@@ -116,10 +116,10 @@ class PVNetLightningModule(pl.LightningModule):
|
|
|
116
116
|
else:
|
|
117
117
|
opt_target = losses["MAE/train"]
|
|
118
118
|
return opt_target
|
|
119
|
-
|
|
119
|
+
|
|
120
120
|
def _calculate_val_losses(
|
|
121
|
-
self,
|
|
122
|
-
y: torch.Tensor,
|
|
121
|
+
self,
|
|
122
|
+
y: torch.Tensor,
|
|
123
123
|
y_hat: torch.Tensor,
|
|
124
124
|
) -> dict[str, torch.Tensor]:
|
|
125
125
|
"""Calculate additional losses only run in validation"""
|
|
@@ -138,28 +138,25 @@ class PVNetLightningModule(pl.LightningModule):
|
|
|
138
138
|
return losses
|
|
139
139
|
|
|
140
140
|
def _calculate_step_metrics(
|
|
141
|
-
self,
|
|
142
|
-
y: torch.Tensor,
|
|
143
|
-
y_hat: torch.Tensor,
|
|
141
|
+
self,
|
|
142
|
+
y: torch.Tensor,
|
|
143
|
+
y_hat: torch.Tensor,
|
|
144
144
|
) -> tuple[np.array, np.array]:
|
|
145
145
|
"""Calculate the MAE and MSE at each forecast step"""
|
|
146
146
|
|
|
147
147
|
mae_each_step = torch.mean(torch.abs(y_hat - y), dim=0).cpu().numpy()
|
|
148
148
|
mse_each_step = torch.mean((y_hat - y) ** 2, dim=0).cpu().numpy()
|
|
149
|
-
|
|
149
|
+
|
|
150
150
|
return mae_each_step, mse_each_step
|
|
151
|
-
|
|
151
|
+
|
|
152
152
|
def _store_val_predictions(self, batch: TensorBatch, y_hat: torch.Tensor) -> None:
|
|
153
153
|
"""Internally store the validation predictions"""
|
|
154
|
-
|
|
155
|
-
target_key = self.model._target_key
|
|
156
154
|
|
|
157
|
-
y = batch[
|
|
158
|
-
y_hat = y_hat.cpu().numpy()
|
|
159
|
-
ids = batch[
|
|
155
|
+
y = batch["generation"][:, -self.model.forecast_len :].cpu().numpy()
|
|
156
|
+
y_hat = y_hat.cpu().numpy()
|
|
157
|
+
ids = batch["location_id"].cpu().numpy()
|
|
160
158
|
init_times_utc = pd.to_datetime(
|
|
161
|
-
batch[
|
|
162
|
-
.cpu().numpy().astype("datetime64[ns]")
|
|
159
|
+
batch["time_utc"][:, self.model.history_len + 1].cpu().numpy().astype("datetime64[ns]")
|
|
163
160
|
)
|
|
164
161
|
|
|
165
162
|
if self.model.use_quantile_regression:
|
|
@@ -170,7 +167,7 @@ class PVNetLightningModule(pl.LightningModule):
|
|
|
170
167
|
|
|
171
168
|
ds_preds_batch = xr.Dataset(
|
|
172
169
|
data_vars=dict(
|
|
173
|
-
y_hat=(["sample_num", "forecast_step",
|
|
170
|
+
y_hat=(["sample_num", "forecast_step", "p_level"], y_hat),
|
|
174
171
|
y=(["sample_num", "forecast_step"], y),
|
|
175
172
|
),
|
|
176
173
|
coords=dict(
|
|
@@ -186,7 +183,7 @@ class PVNetLightningModule(pl.LightningModule):
|
|
|
186
183
|
# Set up stores which we will fill during validation
|
|
187
184
|
self.all_val_results: list[xr.Dataset] = []
|
|
188
185
|
self._val_horizon_maes: list[np.array] = []
|
|
189
|
-
if self.current_epoch==0:
|
|
186
|
+
if self.current_epoch == 0:
|
|
190
187
|
self._val_persistence_horizon_maes: list[np.array] = []
|
|
191
188
|
|
|
192
189
|
# Plot some sample forecasts
|
|
@@ -197,9 +194,9 @@ class PVNetLightningModule(pl.LightningModule):
|
|
|
197
194
|
|
|
198
195
|
for plot_num in range(num_figures):
|
|
199
196
|
idxs = np.arange(plots_per_figure) + plot_num * plots_per_figure
|
|
200
|
-
idxs = idxs[idxs<len(val_dataset)]
|
|
197
|
+
idxs = idxs[idxs < len(val_dataset)]
|
|
201
198
|
|
|
202
|
-
if len(idxs)==0:
|
|
199
|
+
if len(idxs) == 0:
|
|
203
200
|
continue
|
|
204
201
|
|
|
205
202
|
batch = collate_fn([val_dataset[i] for i in idxs])
|
|
@@ -207,19 +204,16 @@ class PVNetLightningModule(pl.LightningModule):
|
|
|
207
204
|
|
|
208
205
|
# Batch validation check only during sanity check phase - use first batch
|
|
209
206
|
if self.trainer.sanity_checking and plot_num == 0:
|
|
210
|
-
validate_batch_against_config(
|
|
211
|
-
|
|
212
|
-
model=self.model
|
|
213
|
-
)
|
|
214
|
-
|
|
207
|
+
validate_batch_against_config(batch=batch, model=self.model)
|
|
208
|
+
|
|
215
209
|
with torch.no_grad():
|
|
216
210
|
y_hat = self.model(batch)
|
|
217
|
-
|
|
211
|
+
|
|
218
212
|
fig = plot_sample_forecasts(
|
|
219
213
|
batch,
|
|
220
214
|
y_hat,
|
|
221
215
|
quantiles=self.model.output_quantiles,
|
|
222
|
-
key_to_plot=
|
|
216
|
+
key_to_plot="generation",
|
|
223
217
|
)
|
|
224
218
|
|
|
225
219
|
plot_name = f"val_forecast_samples/sample_set_{plot_num}"
|
|
@@ -238,7 +232,7 @@ class PVNetLightningModule(pl.LightningModule):
|
|
|
238
232
|
# Internally store the val predictions
|
|
239
233
|
self._store_val_predictions(batch, y_hat)
|
|
240
234
|
|
|
241
|
-
y = batch[
|
|
235
|
+
y = batch["generation"][:, -self.model.forecast_len :]
|
|
242
236
|
|
|
243
237
|
losses = self._calculate_common_losses(y, y_hat)
|
|
244
238
|
losses = {f"{k}/val": v for k, v in losses.items()}
|
|
@@ -262,21 +256,22 @@ class PVNetLightningModule(pl.LightningModule):
|
|
|
262
256
|
|
|
263
257
|
# Calculate the persistence losses - we only need to do this once per training run
|
|
264
258
|
# not every epoch
|
|
265
|
-
if self.current_epoch==0:
|
|
259
|
+
if self.current_epoch == 0:
|
|
266
260
|
y_persist = (
|
|
267
|
-
batch[
|
|
268
|
-
.unsqueeze(1)
|
|
261
|
+
batch["generation"][:, -(self.model.forecast_len + 1)]
|
|
262
|
+
.unsqueeze(1)
|
|
263
|
+
.expand(-1, self.model.forecast_len)
|
|
269
264
|
)
|
|
270
265
|
mae_step_persist, mse_step_persist = self._calculate_step_metrics(y, y_persist)
|
|
271
266
|
self._val_persistence_horizon_maes.append(mae_step_persist)
|
|
272
267
|
losses.update(
|
|
273
268
|
{
|
|
274
|
-
"MAE/val_persistence": mae_step_persist.mean(),
|
|
275
|
-
"MSE/val_persistence": mse_step_persist.mean()
|
|
269
|
+
"MAE/val_persistence": mae_step_persist.mean(),
|
|
270
|
+
"MSE/val_persistence": mse_step_persist.mean(),
|
|
276
271
|
}
|
|
277
272
|
)
|
|
278
273
|
|
|
279
|
-
#
|
|
274
|
+
# Log the metrics
|
|
280
275
|
self.log_dict(losses, on_step=False, on_epoch=True)
|
|
281
276
|
|
|
282
277
|
def on_validation_epoch_end(self) -> None:
|
|
@@ -289,7 +284,7 @@ class PVNetLightningModule(pl.LightningModule):
|
|
|
289
284
|
self._val_horizon_maes = []
|
|
290
285
|
|
|
291
286
|
# We only run this on the first epoch
|
|
292
|
-
if self.current_epoch==0:
|
|
287
|
+
if self.current_epoch == 0:
|
|
293
288
|
val_persistence_horizon_maes = np.mean(self._val_persistence_horizon_maes, axis=0)
|
|
294
289
|
self._val_persistence_horizon_maes = []
|
|
295
290
|
|
|
@@ -321,25 +316,25 @@ class PVNetLightningModule(pl.LightningModule):
|
|
|
321
316
|
wandb_log_dir = self.logger.experiment.dir
|
|
322
317
|
filepath = f"{wandb_log_dir}/validation_results.netcdf"
|
|
323
318
|
ds_val_results.to_netcdf(filepath)
|
|
324
|
-
|
|
325
|
-
#
|
|
319
|
+
|
|
320
|
+
# Upload to wandb
|
|
326
321
|
self.logger.experiment.save(filepath, base_path=wandb_log_dir, policy="now")
|
|
327
|
-
|
|
322
|
+
|
|
328
323
|
# Create the horizon accuracy curve
|
|
329
324
|
horizon_mae_plot = wandb_line_plot(
|
|
330
|
-
x=np.arange(self.model.forecast_len),
|
|
325
|
+
x=np.arange(self.model.forecast_len),
|
|
331
326
|
y=val_horizon_maes,
|
|
332
327
|
xlabel="Horizon step",
|
|
333
328
|
ylabel="MAE",
|
|
334
329
|
title="Val horizon loss curve",
|
|
335
330
|
)
|
|
336
|
-
|
|
331
|
+
|
|
337
332
|
wandb.log({"val_horizon_mae_plot": horizon_mae_plot})
|
|
338
333
|
|
|
339
334
|
# Create persistence horizon accuracy curve but only on first epoch
|
|
340
|
-
if self.current_epoch==0:
|
|
335
|
+
if self.current_epoch == 0:
|
|
341
336
|
persist_horizon_mae_plot = wandb_line_plot(
|
|
342
|
-
x=np.arange(self.model.forecast_len),
|
|
337
|
+
x=np.arange(self.model.forecast_len),
|
|
343
338
|
y=val_persistence_horizon_maes,
|
|
344
339
|
xlabel="Horizon step",
|
|
345
340
|
ylabel="MAE",
|
pvnet/training/plots.py
CHANGED
|
@@ -32,9 +32,9 @@ def plot_sample_forecasts(
|
|
|
32
32
|
|
|
33
33
|
y = batch[key_to_plot].cpu().numpy()
|
|
34
34
|
y_hat = y_hat.cpu().numpy()
|
|
35
|
-
ids = batch[
|
|
35
|
+
ids = batch["location_id"].cpu().numpy().squeeze()
|
|
36
36
|
times_utc = pd.to_datetime(
|
|
37
|
-
batch[
|
|
37
|
+
batch["time_utc"].cpu().numpy().squeeze().astype("datetime64[ns]")
|
|
38
38
|
)
|
|
39
39
|
batch_size = y.shape[0]
|
|
40
40
|
|
pvnet/utils.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
"""Utils"""
|
|
2
|
+
|
|
2
3
|
import logging
|
|
3
4
|
from typing import TYPE_CHECKING
|
|
4
5
|
|
|
@@ -17,7 +18,7 @@ PYTORCH_WEIGHTS_NAME = "model_weights.safetensors"
|
|
|
17
18
|
MODEL_CONFIG_NAME = "model_config.yaml"
|
|
18
19
|
DATA_CONFIG_NAME = "data_config.yaml"
|
|
19
20
|
DATAMODULE_CONFIG_NAME = "datamodule_config.yaml"
|
|
20
|
-
FULL_CONFIG_NAME =
|
|
21
|
+
FULL_CONFIG_NAME = "full_experiment_config.yaml"
|
|
21
22
|
MODEL_CARD_NAME = "README.md"
|
|
22
23
|
|
|
23
24
|
|
|
@@ -93,37 +94,41 @@ def print_config(
|
|
|
93
94
|
|
|
94
95
|
|
|
95
96
|
def validate_batch_against_config(
|
|
96
|
-
batch: dict,
|
|
97
|
+
batch: dict,
|
|
97
98
|
model: "BaseModel",
|
|
98
99
|
) -> None:
|
|
99
100
|
"""Validates tensor shapes in batch against model configuration."""
|
|
100
101
|
logger.info("Performing batch shape validation against model config.")
|
|
101
|
-
|
|
102
|
+
|
|
102
103
|
# NWP validation
|
|
103
|
-
if hasattr(model,
|
|
104
|
+
if hasattr(model, "nwp_encoders_dict"):
|
|
104
105
|
if "nwp" not in batch:
|
|
105
106
|
raise ValueError(
|
|
106
107
|
"Model configured with 'nwp_encoders_dict' but 'nwp' data missing from batch."
|
|
107
108
|
)
|
|
108
|
-
|
|
109
|
+
|
|
109
110
|
for source, nwp_data in batch["nwp"].items():
|
|
110
111
|
if source in model.nwp_encoders_dict:
|
|
111
|
-
|
|
112
|
-
enc = model.nwp_encoders_dict[source]
|
|
112
|
+
enc = model.nwp_encoders_dict[source]
|
|
113
113
|
expected_channels = enc.in_channels
|
|
114
114
|
if model.add_image_embedding_channel:
|
|
115
115
|
expected_channels -= 1
|
|
116
116
|
|
|
117
|
-
expected = (
|
|
118
|
-
|
|
117
|
+
expected = (
|
|
118
|
+
nwp_data["nwp"].shape[0],
|
|
119
|
+
enc.sequence_length,
|
|
120
|
+
expected_channels,
|
|
121
|
+
enc.image_size_pixels,
|
|
122
|
+
enc.image_size_pixels,
|
|
123
|
+
)
|
|
119
124
|
if tuple(nwp_data["nwp"].shape) != expected:
|
|
120
|
-
actual_shape = tuple(nwp_data[
|
|
125
|
+
actual_shape = tuple(nwp_data["nwp"].shape)
|
|
121
126
|
raise ValueError(
|
|
122
127
|
f"NWP.{source} shape mismatch: expected {expected}, got {actual_shape}"
|
|
123
128
|
)
|
|
124
129
|
|
|
125
130
|
# Satellite validation
|
|
126
|
-
if hasattr(model,
|
|
131
|
+
if hasattr(model, "sat_encoder"):
|
|
127
132
|
if "satellite_actual" not in batch:
|
|
128
133
|
raise ValueError(
|
|
129
134
|
"Model configured with 'sat_encoder' but 'satellite_actual' missing from batch."
|
|
@@ -134,14 +139,19 @@ def validate_batch_against_config(
|
|
|
134
139
|
if model.add_image_embedding_channel:
|
|
135
140
|
expected_channels -= 1
|
|
136
141
|
|
|
137
|
-
expected = (
|
|
138
|
-
|
|
142
|
+
expected = (
|
|
143
|
+
batch["satellite_actual"].shape[0],
|
|
144
|
+
enc.sequence_length,
|
|
145
|
+
expected_channels,
|
|
146
|
+
enc.image_size_pixels,
|
|
147
|
+
enc.image_size_pixels,
|
|
148
|
+
)
|
|
139
149
|
if tuple(batch["satellite_actual"].shape) != expected:
|
|
140
|
-
actual_shape = tuple(batch[
|
|
150
|
+
actual_shape = tuple(batch["satellite_actual"].shape)
|
|
141
151
|
raise ValueError(f"Satellite shape mismatch: expected {expected}, got {actual_shape}")
|
|
142
152
|
|
|
143
|
-
#
|
|
144
|
-
key =
|
|
153
|
+
# Generation validation
|
|
154
|
+
key = "generation"
|
|
145
155
|
if key in batch:
|
|
146
156
|
total_minutes = model.history_minutes + model.forecast_minutes
|
|
147
157
|
interval = model.interval_minutes
|
|
@@ -1,14 +1,14 @@
|
|
|
1
1
|
pvnet/__init__.py,sha256=TAZm88TJ5ieL1XjEyRg1LciIGuSScEucdAruQLfM92I,25
|
|
2
|
-
pvnet/datamodule.py,sha256=
|
|
2
|
+
pvnet/datamodule.py,sha256=wc1RQfFhgW9Hxyw7vrpFERhOd2FmjDsO1x49J2erOYk,5750
|
|
3
3
|
pvnet/load_model.py,sha256=P1QODX_mJRnKZ_kIll9BlOjK_A1W4YM3QG-mZd-2Mcc,3852
|
|
4
4
|
pvnet/optimizers.py,sha256=1N4b-Xd6QiIrcUU8cbU326bbFC0BvMNIV8VYWtGILJc,6548
|
|
5
|
-
pvnet/utils.py,sha256=
|
|
5
|
+
pvnet/utils.py,sha256=L3MDF5m1Ez_btAZZ8t-T5wXLzFmyj7UZtorA91DEpFw,6003
|
|
6
6
|
pvnet/models/__init__.py,sha256=owzZ9xkD0DRTT51mT2Dx_p96oJjwDz57xo_MaMIEosk,145
|
|
7
|
-
pvnet/models/base_model.py,sha256=
|
|
8
|
-
pvnet/models/ensemble.py,sha256=
|
|
7
|
+
pvnet/models/base_model.py,sha256=V-vBqtzZc_c8Ho5hVo_ikq2wzZ7hsAIM7I4vhzGDfNc,16051
|
|
8
|
+
pvnet/models/ensemble.py,sha256=USpNQ0O5eiffapLPE9T6gR-uK9f_3E4pX3DK7Lmkn2U,2228
|
|
9
9
|
pvnet/models/late_fusion/__init__.py,sha256=Jf0B-E0_5IvSBFoj1wvnPtwYDxs4pRIFm5qHv--Bbps,26
|
|
10
10
|
pvnet/models/late_fusion/basic_blocks.py,sha256=_cYGVyAIyEJS4wd-DEAXQXu0br66guZJn3ugoebWqZ0,1479
|
|
11
|
-
pvnet/models/late_fusion/late_fusion.py,sha256=
|
|
11
|
+
pvnet/models/late_fusion/late_fusion.py,sha256=kQUnyqMykmwc0GdoFhNXYStJPrjr3hFSvUNe8FumVx4,15260
|
|
12
12
|
pvnet/models/late_fusion/encoders/__init__.py,sha256=bLBQdnCeLYhwISW0t88ZZBz-ebS94m7ZwBcsofWMHR4,51
|
|
13
13
|
pvnet/models/late_fusion/encoders/basic_blocks.py,sha256=DGkFFIZv4S4FLTaAIOrAngAFBpgZQHfkGM4dzezZLk4,3044
|
|
14
14
|
pvnet/models/late_fusion/encoders/encoders3d.py,sha256=9fmqVHO73F-jN62w065cgEQI_icNFC2nQH6ZEGvTHxU,7116
|
|
@@ -17,13 +17,13 @@ pvnet/models/late_fusion/linear_networks/basic_blocks.py,sha256=RnwdeuX_-itY4ncM
|
|
|
17
17
|
pvnet/models/late_fusion/linear_networks/networks.py,sha256=exEIz_Z85f8nSwcvp4wqiiLECEAg9YbkKhSZJvFy75M,2231
|
|
18
18
|
pvnet/models/late_fusion/site_encoders/__init__.py,sha256=QoUiiWWFf12vEpdkw0gO4TWpOEoI_tgAyUFCWFFpYAk,45
|
|
19
19
|
pvnet/models/late_fusion/site_encoders/basic_blocks.py,sha256=iEB_N7ZL5HMQ1hZM6H32A71GCwP7YbErUx0oQF21PQM,1042
|
|
20
|
-
pvnet/models/late_fusion/site_encoders/encoders.py,sha256=
|
|
20
|
+
pvnet/models/late_fusion/site_encoders/encoders.py,sha256=PemEUa_Wv5pFWw3usPKEtXcvs_MX2LSrO6nhldO_QVk,11320
|
|
21
21
|
pvnet/training/__init__.py,sha256=FKxmPZ59Vuj5_mXomN4saJ3En5M-aDMxSs6OttTQOcg,49
|
|
22
|
-
pvnet/training/lightning_module.py,sha256=
|
|
23
|
-
pvnet/training/plots.py,sha256=
|
|
22
|
+
pvnet/training/lightning_module.py,sha256=57sT7bPCU7mJw4EskzOE-JJ9JhWIuAbs40_x5RoBbA8,12705
|
|
23
|
+
pvnet/training/plots.py,sha256=7JtjA9zIotuoKZ2l0fbS-FZDB48TcIk_-XLA2EWVMv4,2448
|
|
24
24
|
pvnet/training/train.py,sha256=Sry2wYgggUmtIB-k_umFts7xMr2roEL76NCu9ySbLUY,4107
|
|
25
|
-
pvnet-5.
|
|
26
|
-
pvnet-5.
|
|
27
|
-
pvnet-5.
|
|
28
|
-
pvnet-5.
|
|
29
|
-
pvnet-5.
|
|
25
|
+
pvnet-5.3.1.dist-info/licenses/LICENSE,sha256=tKUnlSmcLBWMJWkHx3UjZGdrjs9LidGwLo0jsBUBAwU,1077
|
|
26
|
+
pvnet-5.3.1.dist-info/METADATA,sha256=LMfxIQEjnBwoJQktBq3DOEKYgcUUxaMD6k3s6vOBWiU,16479
|
|
27
|
+
pvnet-5.3.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
28
|
+
pvnet-5.3.1.dist-info/top_level.txt,sha256=4mg6WjeW05SR7pg3-Q4JRE2yAoutHYpspOsiUzYVNv0,6
|
|
29
|
+
pvnet-5.3.1.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|