PVNet_summation 1.0.2__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of PVNet_summation might be problematic. Click here for more details.
- pvnet_summation/data/datamodule.py +72 -23
- pvnet_summation/load_model.py +5 -1
- pvnet_summation/models/base_model.py +14 -3
- pvnet_summation/models/horizon_dense_model.py +171 -0
- pvnet_summation/training/train.py +28 -7
- pvnet_summation/utils.py +51 -6
- {pvnet_summation-1.0.2.dist-info → pvnet_summation-1.1.0.dist-info}/METADATA +3 -3
- pvnet_summation-1.1.0.dist-info/RECORD +19 -0
- pvnet_summation-1.0.2.dist-info/RECORD +0 -18
- {pvnet_summation-1.0.2.dist-info → pvnet_summation-1.1.0.dist-info}/WHEEL +0 -0
- {pvnet_summation-1.0.2.dist-info → pvnet_summation-1.1.0.dist-info}/licenses/LICENSE +0 -0
- {pvnet_summation-1.0.2.dist-info → pvnet_summation-1.1.0.dist-info}/top_level.txt +0 -0
|
@@ -8,8 +8,10 @@ import numpy as np
|
|
|
8
8
|
import pandas as pd
|
|
9
9
|
import torch
|
|
10
10
|
from lightning.pytorch import LightningDataModule
|
|
11
|
-
from ocf_data_sampler.load.gsp import open_gsp
|
|
12
|
-
from ocf_data_sampler.numpy_sample.common_types import NumpyBatch
|
|
11
|
+
from ocf_data_sampler.load.gsp import get_gsp_boundaries, open_gsp
|
|
12
|
+
from ocf_data_sampler.numpy_sample.common_types import NumpyBatch
|
|
13
|
+
from ocf_data_sampler.numpy_sample.sun_position import calculate_azimuth_and_elevation
|
|
14
|
+
from ocf_data_sampler.select.geospatial import osgb_to_lon_lat
|
|
13
15
|
from ocf_data_sampler.torch_datasets.datasets.pvnet_uk import PVNetUKConcurrentDataset
|
|
14
16
|
from ocf_data_sampler.utils import minutes
|
|
15
17
|
from torch.utils.data import DataLoader, Dataset, Subset, default_collate
|
|
@@ -18,6 +20,56 @@ from typing_extensions import override
|
|
|
18
20
|
SumNumpySample: TypeAlias = dict[str, np.ndarray | NumpyBatch]
|
|
19
21
|
SumTensorBatch: TypeAlias = dict[str, torch.Tensor]
|
|
20
22
|
|
|
23
|
+
def get_gb_centroid_lon_lat() -> tuple[float, float]:
|
|
24
|
+
"""Get the longitude and latitude of the centroid of Great Britain"""
|
|
25
|
+
row = get_gsp_boundaries("20250109").loc[0]
|
|
26
|
+
x_osgb = row.x_osgb.item()
|
|
27
|
+
y_osgb = row.y_osgb.item()
|
|
28
|
+
return osgb_to_lon_lat(x_osgb, y_osgb)
|
|
29
|
+
|
|
30
|
+
LON, LAT = get_gb_centroid_lon_lat()
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def construct_sample(
|
|
34
|
+
pvnet_inputs: NumpyBatch,
|
|
35
|
+
valid_times: pd.DatetimeIndex,
|
|
36
|
+
relative_capacities: np.ndarray,
|
|
37
|
+
target: np.ndarray | None,
|
|
38
|
+
last_outturn: float | None = None,
|
|
39
|
+
) -> SumNumpySample:
|
|
40
|
+
"""Construct an input sample for the summation model
|
|
41
|
+
|
|
42
|
+
Args:
|
|
43
|
+
pvnet_inputs: The PVNet batch for all GSPs
|
|
44
|
+
valid_times: An array of valid-times for the forecast
|
|
45
|
+
relative_capacities: Array of capacities of all GSPs normalised by the total capacity
|
|
46
|
+
target: The target national outturn. This is only needed during training.
|
|
47
|
+
last_outturn: The previous national outturn. This is only needed during training.
|
|
48
|
+
"""
|
|
49
|
+
|
|
50
|
+
azimuth, elevation = calculate_azimuth_and_elevation(valid_times, LON, LAT)
|
|
51
|
+
|
|
52
|
+
sample = {
|
|
53
|
+
# NumpyBatch object with batch size = num_locations
|
|
54
|
+
"pvnet_inputs": pvnet_inputs,
|
|
55
|
+
# Shape: [time]
|
|
56
|
+
"valid_times": valid_times.values.astype(int),
|
|
57
|
+
# Shape: [num_locations]
|
|
58
|
+
"relative_capacity": relative_capacities,
|
|
59
|
+
# Shape: [time]
|
|
60
|
+
"azimuth": azimuth.astype(np.float32) / 360,
|
|
61
|
+
# Shape: [time]
|
|
62
|
+
"elevation": elevation.astype(np.float32) / 180 + 0.5,
|
|
63
|
+
}
|
|
64
|
+
|
|
65
|
+
if target is not None:
|
|
66
|
+
# Shape: [time]
|
|
67
|
+
sample["target"] = target
|
|
68
|
+
if last_outturn is not None:
|
|
69
|
+
# Shape: scalar
|
|
70
|
+
sample["last_outturn"] = last_outturn
|
|
71
|
+
return sample
|
|
72
|
+
|
|
21
73
|
|
|
22
74
|
class StreamedDataset(PVNetUKConcurrentDataset):
|
|
23
75
|
"""A torch dataset for creating concurrent PVNet inputs and national targets."""
|
|
@@ -38,7 +90,7 @@ class StreamedDataset(PVNetUKConcurrentDataset):
|
|
|
38
90
|
super().__init__(config_filename, start_time, end_time, gsp_ids=None)
|
|
39
91
|
|
|
40
92
|
# Load and normalise the national GSP data to use as target values
|
|
41
|
-
national_gsp_data = (
|
|
93
|
+
self.national_gsp_data = (
|
|
42
94
|
open_gsp(
|
|
43
95
|
zarr_path=self.config.input_data.gsp.zarr_path,
|
|
44
96
|
boundaries_version=self.config.input_data.gsp.boundaries_version
|
|
@@ -46,8 +98,6 @@ class StreamedDataset(PVNetUKConcurrentDataset):
|
|
|
46
98
|
.sel(gsp_id=0)
|
|
47
99
|
.compute()
|
|
48
100
|
)
|
|
49
|
-
self.national_gsp_data = national_gsp_data / national_gsp_data.effective_capacity_mwp
|
|
50
|
-
|
|
51
101
|
|
|
52
102
|
def _get_sample(self, t0: pd.Timestamp) -> SumNumpySample:
|
|
53
103
|
"""Generate a concurrent PVNet sample for given init-time.
|
|
@@ -56,33 +106,32 @@ class StreamedDataset(PVNetUKConcurrentDataset):
|
|
|
56
106
|
t0: init-time for sample
|
|
57
107
|
"""
|
|
58
108
|
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
location_capacities = pvnet_inputs["gsp_effective_capacity_mwp"]
|
|
109
|
+
# Get the PVNet input batch
|
|
110
|
+
pvnet_inputs: NumpyBatch = super()._get_sample(t0)
|
|
62
111
|
|
|
112
|
+
# Construct an array of valid times for each forecast horizon
|
|
63
113
|
valid_times = pd.date_range(
|
|
64
114
|
t0+minutes(self.config.input_data.gsp.time_resolution_minutes),
|
|
65
115
|
t0+minutes(self.config.input_data.gsp.interval_end_minutes),
|
|
66
116
|
freq=minutes(self.config.input_data.gsp.time_resolution_minutes)
|
|
67
117
|
)
|
|
68
118
|
|
|
69
|
-
|
|
119
|
+
# Get the GSP and national capacities
|
|
120
|
+
location_capacities = pvnet_inputs["gsp_effective_capacity_mwp"]
|
|
70
121
|
total_capacity = self.national_gsp_data.sel(time_utc=t0).effective_capacity_mwp.item()
|
|
71
|
-
|
|
122
|
+
|
|
123
|
+
# Calculate required inputs for the sample
|
|
72
124
|
relative_capacities = location_capacities / total_capacity
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
# Shape: [num_locations]
|
|
84
|
-
"relative_capacity": relative_capacities,
|
|
85
|
-
}
|
|
125
|
+
target = self.national_gsp_data.sel(time_utc=valid_times).values / total_capacity
|
|
126
|
+
last_outturn = self.national_gsp_data.sel(time_utc=t0).values / total_capacity
|
|
127
|
+
|
|
128
|
+
return construct_sample(
|
|
129
|
+
pvnet_inputs=pvnet_inputs,
|
|
130
|
+
valid_times=valid_times,
|
|
131
|
+
relative_capacities=relative_capacities,
|
|
132
|
+
target=target,
|
|
133
|
+
last_outturn=last_outturn,
|
|
134
|
+
)
|
|
86
135
|
|
|
87
136
|
@override
|
|
88
137
|
def __getitem__(self, idx: int) -> SumNumpySample:
|
pvnet_summation/load_model.py
CHANGED
|
@@ -20,10 +20,14 @@ def get_model_from_checkpoints(
|
|
|
20
20
|
) -> tuple[torch.nn.Module, dict, str | None, str | None]:
|
|
21
21
|
"""Load a model from its checkpoint directory
|
|
22
22
|
|
|
23
|
+
Args:
|
|
24
|
+
checkpoint_dir_path: str path to the directory with the model files
|
|
25
|
+
val_best (optional): if True, load the best epoch model; otherwise, load the last
|
|
26
|
+
|
|
23
27
|
Returns:
|
|
24
28
|
tuple:
|
|
25
29
|
model: nn.Module of pretrained model.
|
|
26
|
-
model_config:
|
|
30
|
+
model_config: dict of model config used to train the model.
|
|
27
31
|
datamodule_config: path to datamodule used to create samples e.g train/test split info.
|
|
28
32
|
experiment_configs: path to the full experimental config.
|
|
29
33
|
|
|
@@ -4,6 +4,7 @@ import os
|
|
|
4
4
|
import shutil
|
|
5
5
|
import time
|
|
6
6
|
from importlib.metadata import version
|
|
7
|
+
from math import prod
|
|
7
8
|
from pathlib import Path
|
|
8
9
|
|
|
9
10
|
import hydra
|
|
@@ -293,6 +294,12 @@ class BaseModel(torch.nn.Module, HuggingfaceMixin):
|
|
|
293
294
|
"""
|
|
294
295
|
super().__init__()
|
|
295
296
|
|
|
297
|
+
if (output_quantiles is not None):
|
|
298
|
+
if output_quantiles != sorted(output_quantiles):
|
|
299
|
+
raise ValueError("output_quantiles should be in ascending order")
|
|
300
|
+
if 0.5 not in output_quantiles:
|
|
301
|
+
raise ValueError("Quantiles must include 0.5")
|
|
302
|
+
|
|
296
303
|
self.output_quantiles = output_quantiles
|
|
297
304
|
|
|
298
305
|
self.num_input_locations = num_input_locations
|
|
@@ -309,17 +316,21 @@ class BaseModel(torch.nn.Module, HuggingfaceMixin):
|
|
|
309
316
|
# Store whether the model should use quantile regression or simply predict the mean
|
|
310
317
|
self.use_quantile_regression = self.output_quantiles is not None
|
|
311
318
|
|
|
312
|
-
#
|
|
319
|
+
# Also store the final output shape
|
|
313
320
|
if self.use_quantile_regression:
|
|
314
|
-
self.
|
|
321
|
+
self.output_shape = (self.forecast_len, len(input_quantiles))
|
|
315
322
|
else:
|
|
316
|
-
self.
|
|
323
|
+
self.output_shape = (self.forecast_len,)
|
|
324
|
+
|
|
325
|
+
# Store the number of output features that the model should predict
|
|
326
|
+
self.num_output_features = prod(self.output_shape)
|
|
317
327
|
|
|
318
328
|
# Store the expected input shape
|
|
319
329
|
if input_quantiles is None:
|
|
320
330
|
self.input_shape = (self.num_input_locations, self.forecast_len)
|
|
321
331
|
else:
|
|
322
332
|
self.input_shape = (self.num_input_locations, self.forecast_len, len(input_quantiles))
|
|
333
|
+
|
|
323
334
|
|
|
324
335
|
def _quantiles_to_prediction(self, y_quantiles: torch.Tensor) -> torch.Tensor:
|
|
325
336
|
"""Convert network prediction into a point prediction.
|
|
@@ -0,0 +1,171 @@
|
|
|
1
|
+
"""Neural network architecture based on dense layers applied independently at each horizon"""
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
import torch.nn.functional as F
|
|
6
|
+
from torch import nn
|
|
7
|
+
|
|
8
|
+
from pvnet_summation.data.datamodule import SumTensorBatch
|
|
9
|
+
from pvnet_summation.models.base_model import BaseModel
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class HorizonDenseModel(BaseModel):
|
|
13
|
+
"""Neural network architecture based on dense layers applied independently at each horizon.
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
def __init__(
|
|
17
|
+
self,
|
|
18
|
+
output_quantiles: list[float] | None,
|
|
19
|
+
num_input_locations: int,
|
|
20
|
+
input_quantiles: list[float] | None,
|
|
21
|
+
history_minutes: int,
|
|
22
|
+
forecast_minutes: int,
|
|
23
|
+
interval_minutes: int,
|
|
24
|
+
output_network: torch.nn.Module,
|
|
25
|
+
predict_difference_from_sum: bool = False,
|
|
26
|
+
use_horizon_encoding: bool = False,
|
|
27
|
+
use_solar_position: bool = False,
|
|
28
|
+
force_non_crossing: bool = False,
|
|
29
|
+
beta: float = 3,
|
|
30
|
+
):
|
|
31
|
+
"""Neural network architecture based on dense layers applied independently at each horizon.
|
|
32
|
+
|
|
33
|
+
Args:
|
|
34
|
+
output_quantiles: A list of float (0.0, 1.0) quantiles to predict values for. If set to
|
|
35
|
+
None the output is a single value.
|
|
36
|
+
num_input_locations: The number of input locations (e.g. number of GSPs)
|
|
37
|
+
input_quantiles: A list of float (0.0, 1.0) quantiles which PVNet predicts for. If set
|
|
38
|
+
to None we assume PVNet predicts a single value
|
|
39
|
+
history_minutes (int): Length of the GSP history period in minutes
|
|
40
|
+
forecast_minutes (int): Length of the GSP forecast period in minutes
|
|
41
|
+
interval_minutes: The interval in minutes between each timestep in the data
|
|
42
|
+
output_network: A partially instantiated pytorch Module class used to predict the
|
|
43
|
+
outturn at each horizon.
|
|
44
|
+
predict_difference_from_sum: Whether to predict the difference from the sum of locations
|
|
45
|
+
else the total is predicted directly
|
|
46
|
+
use_horizon_encoding: Whether to use the forecast horizon as an input feature
|
|
47
|
+
use_solar_position: Whether to use the solar coordinates as input features
|
|
48
|
+
force_non_crossing: If predicting quantiles, whether to predict the quantiles other than
|
|
49
|
+
the median by predicting the distance between them and integrating.
|
|
50
|
+
beta: If using force_non_crossing, the beta value to use in the softplus activation
|
|
51
|
+
"""
|
|
52
|
+
|
|
53
|
+
super().__init__(
|
|
54
|
+
output_quantiles,
|
|
55
|
+
num_input_locations,
|
|
56
|
+
input_quantiles,
|
|
57
|
+
history_minutes,
|
|
58
|
+
forecast_minutes,
|
|
59
|
+
interval_minutes,
|
|
60
|
+
)
|
|
61
|
+
|
|
62
|
+
if force_non_crossing:
|
|
63
|
+
assert self.use_quantile_regression
|
|
64
|
+
|
|
65
|
+
self.use_horizon_encoding = use_horizon_encoding
|
|
66
|
+
self.predict_difference_from_sum = predict_difference_from_sum
|
|
67
|
+
self.force_non_crossing = force_non_crossing
|
|
68
|
+
self.beta = beta
|
|
69
|
+
self.use_solar_position = use_solar_position
|
|
70
|
+
|
|
71
|
+
in_features = 1 if self.input_quantiles is None else len(self.input_quantiles)
|
|
72
|
+
in_features = in_features * self.num_input_locations
|
|
73
|
+
|
|
74
|
+
if use_horizon_encoding:
|
|
75
|
+
in_features += 1
|
|
76
|
+
|
|
77
|
+
if use_solar_position:
|
|
78
|
+
in_features += 2
|
|
79
|
+
|
|
80
|
+
out_features = (len(self.output_quantiles) if self.use_quantile_regression else 1)
|
|
81
|
+
|
|
82
|
+
model = output_network(in_features=in_features, out_features=out_features)
|
|
83
|
+
|
|
84
|
+
# Add linear layer if predicting difference from sum
|
|
85
|
+
# - This allows difference to be positive or negative
|
|
86
|
+
# Also add linear layer if we are applying force_non_crossing since a softplus will be used
|
|
87
|
+
if predict_difference_from_sum or force_non_crossing:
|
|
88
|
+
model = nn.Sequential(
|
|
89
|
+
model,
|
|
90
|
+
nn.Linear(out_features, out_features),
|
|
91
|
+
)
|
|
92
|
+
|
|
93
|
+
self.model = model
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def forward(self, x: SumTensorBatch) -> torch.Tensor:
|
|
97
|
+
"""Run model forward"""
|
|
98
|
+
|
|
99
|
+
# x["pvnet_outputs"] has shape [batch, locs, horizon, (quantile)]
|
|
100
|
+
batch_size = x["pvnet_outputs"].shape[0]
|
|
101
|
+
x_in = torch.swapaxes(x["pvnet_outputs"], 1, 2) # -> [batch, horizon, locs, (quantile)]
|
|
102
|
+
x_in = torch.flatten(x_in, start_dim=2) # -> [batch, horizon, locs*(quantile)]
|
|
103
|
+
|
|
104
|
+
if self.use_horizon_encoding:
|
|
105
|
+
horizon_encoding = torch.linspace(
|
|
106
|
+
start=0,
|
|
107
|
+
end=1,
|
|
108
|
+
steps=self.forecast_len,
|
|
109
|
+
device=x_in.device,
|
|
110
|
+
dtype=x_in.dtype,
|
|
111
|
+
)
|
|
112
|
+
horizon_encoding = horizon_encoding.tile((batch_size,1)).unsqueeze(-1)
|
|
113
|
+
x_in = torch.cat([x_in, horizon_encoding], dim=2)
|
|
114
|
+
|
|
115
|
+
if self.use_solar_position:
|
|
116
|
+
x_in = torch.cat(
|
|
117
|
+
[x_in, x["azimuth"].unsqueeze(-1), x["elevation"].unsqueeze(-1)],
|
|
118
|
+
dim=2
|
|
119
|
+
)
|
|
120
|
+
|
|
121
|
+
x_in = torch.flatten(x_in, start_dim=0, end_dim=1) # -> [batch*horizon, features]
|
|
122
|
+
|
|
123
|
+
out = self.model(x_in)
|
|
124
|
+
out = out.view(batch_size, *self.output_shape) # -> [batch, horizon, (quantile)]
|
|
125
|
+
|
|
126
|
+
if self.force_non_crossing:
|
|
127
|
+
|
|
128
|
+
# Get the prediction of the median
|
|
129
|
+
idx = self.output_quantiles.index(0.5)
|
|
130
|
+
if self.predict_difference_from_sum:
|
|
131
|
+
loc_sum = self.sum_of_locations(x).unsqueeze(-1)
|
|
132
|
+
y_median = loc_sum + out[..., idx:idx+1]
|
|
133
|
+
else:
|
|
134
|
+
y_median = out[..., idx:idx+1]
|
|
135
|
+
|
|
136
|
+
# These are the differences between the remaining quantiles
|
|
137
|
+
dy_below = F.softplus(out[..., :idx], beta=self.beta)
|
|
138
|
+
dy_above = F.softplus(out[..., idx+1:], beta=self.beta)
|
|
139
|
+
|
|
140
|
+
# Find the absolute value of the quantile predictions from the differences
|
|
141
|
+
y_below = []
|
|
142
|
+
y = y_median
|
|
143
|
+
for i in range(dy_below.shape[-1]):
|
|
144
|
+
# We detach y to avoid the gradients caused by errors from one quantile
|
|
145
|
+
# prediction flowing back to affect the other quantile predictions.
|
|
146
|
+
# For example if the 0.9 quantile prediction was too low, we don't want the
|
|
147
|
+
# gradient to pull the 0.5 quantile prediction higher to compensate.
|
|
148
|
+
y = y.detach() - dy_below[..., i:i+1]
|
|
149
|
+
y_below.append(y)
|
|
150
|
+
|
|
151
|
+
y_above = []
|
|
152
|
+
y = y_median
|
|
153
|
+
for i in range(dy_above.shape[-1]):
|
|
154
|
+
y = y.detach() + dy_above[..., i:i+1]
|
|
155
|
+
y_above.append(y)
|
|
156
|
+
|
|
157
|
+
# Compile the quantile predictions in the correct order
|
|
158
|
+
out = torch.cat(y_below[::-1] + [y_median,] + y_above, dim=-1)
|
|
159
|
+
|
|
160
|
+
else:
|
|
161
|
+
|
|
162
|
+
if self.predict_difference_from_sum:
|
|
163
|
+
loc_sum = self.sum_of_locations(x)
|
|
164
|
+
|
|
165
|
+
if self.use_quantile_regression:
|
|
166
|
+
loc_sum = loc_sum.unsqueeze(-1)
|
|
167
|
+
|
|
168
|
+
out = loc_sum + out
|
|
169
|
+
|
|
170
|
+
# Use leaky relu as a soft clip to 0
|
|
171
|
+
return F.leaky_relu(out, negative_slope=0.01)
|
|
@@ -7,13 +7,21 @@ import torch
|
|
|
7
7
|
from lightning.pytorch import Callback, Trainer, seed_everything
|
|
8
8
|
from lightning.pytorch.callbacks import ModelCheckpoint
|
|
9
9
|
from lightning.pytorch.loggers import Logger, WandbLogger
|
|
10
|
-
from ocf_data_sampler.torch_datasets.
|
|
10
|
+
from ocf_data_sampler.torch_datasets.utils.torch_batch_utils import (
|
|
11
|
+
batch_to_tensor,
|
|
12
|
+
copy_batch_to_device,
|
|
13
|
+
)
|
|
11
14
|
from omegaconf import DictConfig, OmegaConf
|
|
12
15
|
from pvnet.models import BaseModel as PVNetBaseModel
|
|
13
16
|
from tqdm import tqdm
|
|
14
17
|
|
|
15
18
|
from pvnet_summation.data.datamodule import PresavedDataModule, StreamedDataModule
|
|
16
|
-
from pvnet_summation.utils import
|
|
19
|
+
from pvnet_summation.utils import (
|
|
20
|
+
DATAMODULE_CONFIG_NAME,
|
|
21
|
+
FULL_CONFIG_NAME,
|
|
22
|
+
MODEL_CONFIG_NAME,
|
|
23
|
+
create_pvnet_model_config,
|
|
24
|
+
)
|
|
17
25
|
|
|
18
26
|
log = logging.getLogger(__name__)
|
|
19
27
|
|
|
@@ -21,9 +29,8 @@ log = logging.getLogger(__name__)
|
|
|
21
29
|
def resolve_monitor_loss(output_quantiles: list | None) -> str:
|
|
22
30
|
"""Return the desired metric to monitor based on whether quantile regression is being used.
|
|
23
31
|
|
|
24
|
-
|
|
32
|
+
Adds the option to use
|
|
25
33
|
monitor: "${resolve_monitor_loss:${model.model.output_quantiles}}"
|
|
26
|
-
|
|
27
34
|
in early stopping and model checkpoint callbacks so the callbacks config does not need to be
|
|
28
35
|
modified depending on whether quantile regression is being used or not.
|
|
29
36
|
"""
|
|
@@ -86,8 +93,22 @@ def train(config: DictConfig) -> None:
|
|
|
86
93
|
os.makedirs(f"{save_dir}/train")
|
|
87
94
|
os.makedirs(f"{save_dir}/val")
|
|
88
95
|
|
|
96
|
+
pvnet_data_config_path = f"{save_dir}/pvnet_data_config.yaml"
|
|
97
|
+
|
|
98
|
+
data_source_paths = OmegaConf.to_container(
|
|
99
|
+
config.datamodule.data_source_paths,
|
|
100
|
+
resolve=True,
|
|
101
|
+
)
|
|
102
|
+
|
|
103
|
+
create_pvnet_model_config(
|
|
104
|
+
save_path=pvnet_data_config_path,
|
|
105
|
+
repo=config.datamodule.pvnet_model.model_id,
|
|
106
|
+
commit=config.datamodule.pvnet_model.revision,
|
|
107
|
+
data_source_paths=data_source_paths,
|
|
108
|
+
)
|
|
109
|
+
|
|
89
110
|
datamodule = StreamedDataModule(
|
|
90
|
-
configuration=
|
|
111
|
+
configuration=pvnet_data_config_path,
|
|
91
112
|
num_workers=config.datamodule.num_workers,
|
|
92
113
|
prefetch_factor=config.datamodule.prefetch_factor,
|
|
93
114
|
train_period=config.datamodule.train_period,
|
|
@@ -107,8 +128,8 @@ def train(config: DictConfig) -> None:
|
|
|
107
128
|
log.info(f"Saving {split} outputs")
|
|
108
129
|
dataloader = dataloader_func(shuffle=True)
|
|
109
130
|
|
|
110
|
-
|
|
111
|
-
|
|
131
|
+
# If max_num_samples set to None use all samples
|
|
132
|
+
max_num_samples = max_num_samples or len(dataloader)
|
|
112
133
|
|
|
113
134
|
for i, sample in tqdm(zip(range(max_num_samples), dataloader), total=max_num_samples):
|
|
114
135
|
# Run PVNet inputs through model
|
pvnet_summation/utils.py
CHANGED
|
@@ -3,8 +3,10 @@ import logging
|
|
|
3
3
|
|
|
4
4
|
import rich.syntax
|
|
5
5
|
import rich.tree
|
|
6
|
+
import yaml
|
|
6
7
|
from lightning.pytorch.utilities import rank_zero_only
|
|
7
8
|
from omegaconf import DictConfig, OmegaConf
|
|
9
|
+
from pvnet.models.base_model import BaseModel as PVNetBaseModel
|
|
8
10
|
|
|
9
11
|
logger = logging.getLogger(__name__)
|
|
10
12
|
|
|
@@ -17,11 +19,10 @@ MODEL_CARD_NAME = "README.md"
|
|
|
17
19
|
|
|
18
20
|
|
|
19
21
|
|
|
20
|
-
def
|
|
21
|
-
"""
|
|
22
|
+
def maybe_apply_debug_mode(config: DictConfig) -> None:
|
|
23
|
+
"""Check if debugging run is requested and force debug-friendly configuration
|
|
22
24
|
|
|
23
|
-
Controlled by main config file
|
|
24
|
-
- forcing debug friendly configuration
|
|
25
|
+
Controlled by main config file
|
|
25
26
|
|
|
26
27
|
Modifies DictConfig in place.
|
|
27
28
|
|
|
@@ -52,7 +53,7 @@ def run_config_utilities(config: DictConfig) -> None:
|
|
|
52
53
|
@rank_zero_only
|
|
53
54
|
def print_config(
|
|
54
55
|
config: DictConfig,
|
|
55
|
-
fields: tuple[str] = (
|
|
56
|
+
fields: tuple[str, ...] = (
|
|
56
57
|
"trainer",
|
|
57
58
|
"model",
|
|
58
59
|
"datamodule",
|
|
@@ -66,7 +67,7 @@ def print_config(
|
|
|
66
67
|
|
|
67
68
|
Args:
|
|
68
69
|
config (DictConfig): Configuration composed by Hydra.
|
|
69
|
-
fields (
|
|
70
|
+
fields (tuple[str, ...], optional): Determines which main fields from config will
|
|
70
71
|
be printed and in what order.
|
|
71
72
|
resolve (bool, optional): Whether to resolve reference fields of DictConfig.
|
|
72
73
|
"""
|
|
@@ -85,3 +86,47 @@ def print_config(
|
|
|
85
86
|
branch.add(rich.syntax.Syntax(branch_content, "yaml"))
|
|
86
87
|
|
|
87
88
|
rich.print(tree)
|
|
89
|
+
|
|
90
|
+
def populate_config_with_data_data_filepaths(config: dict, data_source_paths: dict) -> dict:
|
|
91
|
+
"""Populate the data source filepaths in the config
|
|
92
|
+
|
|
93
|
+
Args:
|
|
94
|
+
config: The data config
|
|
95
|
+
data_source_paths: A dictionary of data paths for the different input sources
|
|
96
|
+
"""
|
|
97
|
+
|
|
98
|
+
# Replace the GSP data path
|
|
99
|
+
config["input_data"]["gsp"]["zarr_path"] = data_source_paths["gsp"]
|
|
100
|
+
|
|
101
|
+
# Replace satellite data path if using it
|
|
102
|
+
if "satellite" in config["input_data"]:
|
|
103
|
+
if config["input_data"]["satellite"]["zarr_path"] != "":
|
|
104
|
+
config["input_data"]["satellite"]["zarr_path"] = data_source_paths["satellite"]
|
|
105
|
+
|
|
106
|
+
# NWP is nested so must be treated separately
|
|
107
|
+
if "nwp" in config["input_data"]:
|
|
108
|
+
nwp_config = config["input_data"]["nwp"]
|
|
109
|
+
for nwp_source in nwp_config.keys():
|
|
110
|
+
provider = nwp_config[nwp_source]["provider"]
|
|
111
|
+
assert provider in data_source_paths["nwp"], f"Missing NWP path: {provider}"
|
|
112
|
+
nwp_config[nwp_source]["zarr_path"] = data_source_paths["nwp"][provider]
|
|
113
|
+
|
|
114
|
+
return config
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def create_pvnet_model_config(
|
|
118
|
+
save_path: str,
|
|
119
|
+
repo: str,
|
|
120
|
+
commit: str,
|
|
121
|
+
data_source_paths: dict,
|
|
122
|
+
) -> None:
|
|
123
|
+
"""Create the data config needed to run the PVNet model"""
|
|
124
|
+
data_config_path = PVNetBaseModel.get_data_config(repo, revision=commit)
|
|
125
|
+
|
|
126
|
+
with open(data_config_path) as file:
|
|
127
|
+
data_config = yaml.load(file, Loader=yaml.FullLoader)
|
|
128
|
+
|
|
129
|
+
data_config = populate_config_with_data_data_filepaths(data_config, data_source_paths)
|
|
130
|
+
|
|
131
|
+
with open(save_path, "w") as file:
|
|
132
|
+
yaml.dump(data_config, file, default_flow_style=False)
|
|
@@ -1,13 +1,13 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: PVNet_summation
|
|
3
|
-
Version: 1.0
|
|
3
|
+
Version: 1.1.0
|
|
4
4
|
Summary: PVNet_summation
|
|
5
5
|
Author-email: James Fulton <info@openclimatefix.org>
|
|
6
|
-
Requires-Python: >=3.
|
|
6
|
+
Requires-Python: >=3.11
|
|
7
7
|
Description-Content-Type: text/markdown
|
|
8
8
|
License-File: LICENSE
|
|
9
9
|
Requires-Dist: pvnet>=5.0.0
|
|
10
|
-
Requires-Dist: ocf-data-sampler>=0.
|
|
10
|
+
Requires-Dist: ocf-data-sampler>=0.6.0
|
|
11
11
|
Requires-Dist: numpy
|
|
12
12
|
Requires-Dist: pandas
|
|
13
13
|
Requires-Dist: matplotlib
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
pvnet_summation/__init__.py,sha256=8bjkx2pvF7lZ2W5BiTpHr7iqpkRXc3vW5K1pxJAWaj0,22
|
|
2
|
+
pvnet_summation/load_model.py,sha256=mQJXJ9p8wb25CVsm5UBGb0IL6xGZj-81iIBKHsNdQMY,2515
|
|
3
|
+
pvnet_summation/optimizers.py,sha256=kuR3PUnISiAO5bSaKhq_7vqRKZ0gO5cRS4UbjmKgq1c,6472
|
|
4
|
+
pvnet_summation/utils.py,sha256=JyqzDQjABCtRsdLgxr5j9K9AdmNlQhmYGenj6mKGnFY,4352
|
|
5
|
+
pvnet_summation/data/__init__.py,sha256=AYJFlJ3KaAQXED0PxuuknI2lKEeFMFLJiJ9b6-H8398,81
|
|
6
|
+
pvnet_summation/data/datamodule.py,sha256=Pa2iip-ALihhkAVtqDBPJZ93vh4evJwG9L9YCJiRQag,12517
|
|
7
|
+
pvnet_summation/models/__init__.py,sha256=v3KMMH_bz9YGUFWsrb5Ndg-d_dgxQPw7yiFahQAag4c,103
|
|
8
|
+
pvnet_summation/models/base_model.py,sha256=mxrEq8k6NAVpezLx3ORPM33OrXzRccVD2ErFkPIw8bc,12496
|
|
9
|
+
pvnet_summation/models/dense_model.py,sha256=vh3Hrm-n7apgVkta_RtQ5mdxb6jiJNFm3ObWukSBgdU,2305
|
|
10
|
+
pvnet_summation/models/horizon_dense_model.py,sha256=8NfJiO4upQT8ksqwDn1Jkct5-nrbs_EKfKBseVRay1U,7011
|
|
11
|
+
pvnet_summation/training/__init__.py,sha256=2fbydXPJFk527DUGPlNV0Teaqvu4WNp8hgcODwHJFEw,110
|
|
12
|
+
pvnet_summation/training/lightning_module.py,sha256=t16gcAc4Fmi1g26dhQwQOm4qe2mwnTfEBbOyH_BFZ4o,8695
|
|
13
|
+
pvnet_summation/training/plots.py,sha256=VZHyzI6UvCEd4nmXiJCF1FiVlpDyFHTxX6_rc0vmJrU,2248
|
|
14
|
+
pvnet_summation/training/train.py,sha256=ze4LCr4XvJ18NjiZhR9KslVf_5HoC1xjGIhBcfw8u5E,8000
|
|
15
|
+
pvnet_summation-1.1.0.dist-info/licenses/LICENSE,sha256=F-Q3UFCR-BECSocV55BFDpn4YKxve9PKrm-lTt6o_Tg,1073
|
|
16
|
+
pvnet_summation-1.1.0.dist-info/METADATA,sha256=uy-zlQ8IyRNgM27nYxL207m41MrlHfuCPDzp0474-e8,3720
|
|
17
|
+
pvnet_summation-1.1.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
18
|
+
pvnet_summation-1.1.0.dist-info/top_level.txt,sha256=5fWJ75RKtpaHUdLG_-2oDCInXeq4r1aMCxkZp5Wy-LQ,16
|
|
19
|
+
pvnet_summation-1.1.0.dist-info/RECORD,,
|
|
@@ -1,18 +0,0 @@
|
|
|
1
|
-
pvnet_summation/__init__.py,sha256=8bjkx2pvF7lZ2W5BiTpHr7iqpkRXc3vW5K1pxJAWaj0,22
|
|
2
|
-
pvnet_summation/load_model.py,sha256=GfreRSaKVTWjV9fnJGNYjp09wrpZwaTunHijdff6cyc,2338
|
|
3
|
-
pvnet_summation/optimizers.py,sha256=kuR3PUnISiAO5bSaKhq_7vqRKZ0gO5cRS4UbjmKgq1c,6472
|
|
4
|
-
pvnet_summation/utils.py,sha256=G7l2iZK8qNWEau27pJYPvGOLSzPaSttFrGwr75yTlPQ,2628
|
|
5
|
-
pvnet_summation/data/__init__.py,sha256=AYJFlJ3KaAQXED0PxuuknI2lKEeFMFLJiJ9b6-H8398,81
|
|
6
|
-
pvnet_summation/data/datamodule.py,sha256=YtuLycH4P4-bEAreqNk9Pbu848CDKX_F-Z1LLF5re3o,10651
|
|
7
|
-
pvnet_summation/models/__init__.py,sha256=v3KMMH_bz9YGUFWsrb5Ndg-d_dgxQPw7yiFahQAag4c,103
|
|
8
|
-
pvnet_summation/models/base_model.py,sha256=qtsbH8WqrRUQdWpBdeLJ3yz3dlhUeLFUKzVvX7uiopo,12074
|
|
9
|
-
pvnet_summation/models/dense_model.py,sha256=vh3Hrm-n7apgVkta_RtQ5mdxb6jiJNFm3ObWukSBgdU,2305
|
|
10
|
-
pvnet_summation/training/__init__.py,sha256=2fbydXPJFk527DUGPlNV0Teaqvu4WNp8hgcODwHJFEw,110
|
|
11
|
-
pvnet_summation/training/lightning_module.py,sha256=t16gcAc4Fmi1g26dhQwQOm4qe2mwnTfEBbOyH_BFZ4o,8695
|
|
12
|
-
pvnet_summation/training/plots.py,sha256=VZHyzI6UvCEd4nmXiJCF1FiVlpDyFHTxX6_rc0vmJrU,2248
|
|
13
|
-
pvnet_summation/training/train.py,sha256=qijjwYdFxbHp-usdSiM2XHetQu1BSU7sL1Z20OOU7TM,7452
|
|
14
|
-
pvnet_summation-1.0.2.dist-info/licenses/LICENSE,sha256=F-Q3UFCR-BECSocV55BFDpn4YKxve9PKrm-lTt6o_Tg,1073
|
|
15
|
-
pvnet_summation-1.0.2.dist-info/METADATA,sha256=Kv6CrUJmyBTf9mOfP8pF0CCBx0i1_3dQ5ap4n_JqXhM,3721
|
|
16
|
-
pvnet_summation-1.0.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
17
|
-
pvnet_summation-1.0.2.dist-info/top_level.txt,sha256=5fWJ75RKtpaHUdLG_-2oDCInXeq4r1aMCxkZp5Wy-LQ,16
|
|
18
|
-
pvnet_summation-1.0.2.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|