PVNet 5.2.3__py3-none-any.whl → 5.3.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pvnet/datamodule.py CHANGED
@@ -1,4 +1,4 @@
- """ Data module for pytorch lightning """
+ """Data module for pytorch lightning"""
 
  import os
 
@@ -6,10 +6,9 @@ import numpy as np
  from lightning.pytorch import LightningDataModule
  from ocf_data_sampler.numpy_sample.collate import stack_np_samples_into_batch
  from ocf_data_sampler.numpy_sample.common_types import NumpySample, TensorBatch
- from ocf_data_sampler.torch_datasets.datasets.pvnet_uk import PVNetUKRegionalDataset
- from ocf_data_sampler.torch_datasets.datasets.site import SitesDataset
+ from ocf_data_sampler.torch_datasets.pvnet_dataset import PVNetDataset
  from ocf_data_sampler.torch_datasets.utils.torch_batch_utils import batch_to_tensor
- from torch.utils.data import DataLoader, Dataset, Subset
+ from torch.utils.data import DataLoader, Subset
 
 
  def collate_fn(samples: list[NumpySample]) -> TensorBatch:
@@ -17,7 +16,7 @@ def collate_fn(samples: list[NumpySample]) -> TensorBatch:
  return batch_to_tensor(stack_np_samples_into_batch(samples))
 
 
- class BaseDataModule(LightningDataModule):
+ class PVNetDataModule(LightningDataModule):
  """Base Datamodule which streams samples using a sampler from ocf-data-sampler."""
 
  def __init__(
@@ -40,10 +39,10 @@ class BaseDataModule(LightningDataModule):
  batch_size: Batch size.
  num_workers: Number of workers to use in multiprocess batch loading.
  prefetch_factor: Number of batches loaded in advance by each worker.
- persistent_workers: If True, the data loader will not shut down the worker processes
- after a dataset has been consumed once. This allows to maintain the workers Dataset
+ persistent_workers: If True, the data loader will not shut down the worker processes
+ after a dataset has been consumed once. This allows to maintain the workers Dataset
  instances alive.
- pin_memory: If True, the data loader will copy Tensors into device/CUDA pinned memory
+ pin_memory: If True, the data loader will copy Tensors into device/CUDA pinned memory
  before returning them.
  train_period: Date range filter for train dataloader.
  val_period: Date range filter for val dataloader.
@@ -70,7 +69,7 @@ class BaseDataModule(LightningDataModule):
  worker_init_fn=None,
  prefetch_factor=prefetch_factor,
  persistent_workers=persistent_workers,
- multiprocessing_context="spawn" if num_workers>0 else None,
+ multiprocessing_context="spawn" if num_workers > 0 else None,
  )
 
  def setup(self, stage: str | None = None):
@@ -79,16 +78,15 @@ class BaseDataModule(LightningDataModule):
  # This logic runs only once at the start of training, therefore the val dataset is only
  # shuffled once
  if stage == "fit":
-
  # Prepare the train dataset
  self.train_dataset = self._get_dataset(*self.train_period)
 
- # Prepare and pre-shuffle the val dataset and set seed for reproducibility
+ # Prepare and pre-shuffle the val dataset and set seed for reproducibility
  val_dataset = self._get_dataset(*self.val_period)
 
  shuffled_indices = np.random.default_rng(seed=self.seed).permutation(len(val_dataset))
  self.val_dataset = Subset(val_dataset, shuffled_indices)
-
+
  if self.dataset_pickle_dir is not None:
  os.makedirs(self.dataset_pickle_dir, exist_ok=True)
  train_dataset_path = f"{self.dataset_pickle_dir}/train_dataset.pkl"
@@ -116,8 +114,8 @@ class BaseDataModule(LightningDataModule):
  if os.path.exists(filepath):
  os.remove(filepath)
 
- def _get_dataset(self, start_time: str | None, end_time: str | None) -> Dataset:
- raise NotImplementedError
+ def _get_dataset(self, start_time: str | None, end_time: str | None) -> PVNetDataset:
+ return PVNetDataset(self.configuration, start_time=start_time, end_time=end_time)
 
  def train_dataloader(self) -> DataLoader:
  """Construct train dataloader"""
@@ -126,17 +124,3 @@ class BaseDataModule(LightningDataModule):
  def val_dataloader(self) -> DataLoader:
  """Construct val dataloader"""
  return DataLoader(self.val_dataset, shuffle=False, **self._common_dataloader_kwargs)
-
-
- class UKRegionalDataModule(BaseDataModule):
- """Datamodule for streaming UK regional samples."""
-
- def _get_dataset(self, start_time: str | None, end_time: str | None) -> PVNetUKRegionalDataset:
- return PVNetUKRegionalDataset(self.configuration, start_time=start_time, end_time=end_time)
-
-
- class SitesDataModule(BaseDataModule):
- """Datamodule for streaming site samples."""
-
- def _get_dataset(self, start_time: str | None, end_time: str | None) -> SitesDataset:
- return SitesDataset(self.configuration, start_time=start_time, end_time=end_time)
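Note on the setup() change above: the single datamodule now builds both splits directly from ocf-data-sampler's PVNetDataset, and the validation set is shuffled once with a fixed seed and wrapped in a torch Subset so every epoch sees the same ordering. A minimal standalone sketch of that pre-shuffle pattern (the toy dataset and seed below are illustrative, not part of the package):

    import numpy as np
    from torch.utils.data import Dataset, Subset

    class ToyDataset(Dataset):
        # Stand-in for ocf-data-sampler's PVNetDataset
        def __len__(self):
            return 10

        def __getitem__(self, idx):
            return idx

    val_dataset = ToyDataset()
    # Shuffle once with a fixed seed so the validation order is reproducible across epochs
    shuffled_indices = np.random.default_rng(seed=42).permutation(len(val_dataset))
    val_dataset = Subset(val_dataset, shuffled_indices)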
pvnet/models/base_model.py CHANGED
@@ -1,4 +1,5 @@
  """Base model for all PVNet submodels"""
+
  import logging
  import os
  import shutil
@@ -32,7 +33,7 @@ def fill_config_paths_with_placeholder(config: dict, placeholder: str = "PLACEHO
  """
  input_config = config["input_data"]
 
- for source in ["gsp", "satellite"]:
+ for source in ["generation", "satellite"]:
  if source in input_config:
  # If not empty - i.e. if used
  if input_config[source]["zarr_path"] != "":
@@ -78,8 +79,8 @@ def minimize_config_for_model(config: dict, model: "BaseModel") -> dict:
 
  # Replace the interval_end_minutes minutes
  nwp_config["interval_end_minutes"] = (
- nwp_config["interval_start_minutes"] +
- (model.nwp_encoders_dict[nwp_source].sequence_length - 1)
+ nwp_config["interval_start_minutes"]
+ + (model.nwp_encoders_dict[nwp_source].sequence_length - 1)
  * nwp_config["time_resolution_minutes"]
  )
 
@@ -96,20 +97,19 @@ def minimize_config_for_model(config: dict, model: "BaseModel") -> dict:
 
  # Replace the interval_end_minutes minutes
  sat_config["interval_end_minutes"] = (
- sat_config["interval_start_minutes"] +
- (model.sat_encoder.sequence_length - 1)
- * sat_config["time_resolution_minutes"]
+ sat_config["interval_start_minutes"]
+ + (model.sat_encoder.sequence_length - 1) * sat_config["time_resolution_minutes"]
  )
 
  if "pv" in input_config:
  if not model.include_pv:
  del input_config["pv"]
 
- if "gsp" in input_config:
- gsp_config = input_config["gsp"]
+ if "generation" in input_config:
+ generation_config = input_config["generation"]
 
  # Replace the forecast minutes
- gsp_config["interval_end_minutes"] = model.forecast_minutes
+ generation_config["interval_end_minutes"] = model.forecast_minutes
 
  if "solar_position" in input_config:
  solar_config = input_config["solar_position"]
@@ -138,9 +138,9 @@ def download_from_hf(
  force_download: Whether to force a new download
  max_retries: Maximum number of retry attempts
  wait_time: Wait time (in seconds) before retrying
- token:
+ token:
  HF authentication token. If True, the token is read from the HuggingFace config folder.
- If a string, it is used as the authentication token.
+ If a string, it is used as the authentication token.
 
  Returns:
  The local file path of the downloaded file(s)
@@ -160,7 +160,7 @@ def download_from_hf(
  return [f"{save_dir}/{f}" for f in filename]
  else:
  return f"{save_dir}/{filename}"
-
+
  except Exception as e:
  if attempt == max_retries:
  raise Exception(
@@ -205,7 +205,7 @@ class HuggingfaceMixin:
  force_download=force_download,
  max_retries=5,
  wait_time=10,
- token=token
+ token=token,
  )
 
  with open(config_file, "r") as f:
@@ -240,7 +240,7 @@ class HuggingfaceMixin:
  force_download=force_download,
  max_retries=5,
  wait_time=10,
- token=token
+ token=token,
  )
 
  return data_config_file
@@ -301,7 +301,7 @@ class HuggingfaceMixin:
  # Save cleaned version of input data configuration file
  with open(data_config_path) as cfg:
  config = yaml.load(cfg, Loader=yaml.FullLoader)
-
+
  config = fill_config_paths_with_placeholder(config)
  config = minimize_config_for_model(config, self)
 
@@ -311,7 +311,7 @@ class HuggingfaceMixin:
  # Save the datamodule config
  if datamodule_config_path is not None:
  shutil.copyfile(datamodule_config_path, save_directory / DATAMODULE_CONFIG_NAME)
-
+
  # Save the full experimental config
  if experiment_config_path is not None:
  shutil.copyfile(experiment_config_path, save_directory / FULL_CONFIG_NAME)
@@ -378,7 +378,6 @@ class HuggingfaceMixin:
  packages_to_display = ["pvnet", "ocf-data-sampler"]
  packages_and_versions = {package: version(package) for package in packages_to_display}
 
-
  package_versions_markdown = ""
  for package, v in packages_and_versions.items():
  package_versions_markdown += f" - {package}=={v}\n"
@@ -399,23 +398,19 @@ class BaseModel(torch.nn.Module, HuggingfaceMixin):
  history_minutes: int,
  forecast_minutes: int,
  output_quantiles: list[float] | None = None,
- target_key: str = "gsp",
  interval_minutes: int = 30,
  ):
  """Abtstract base class for PVNet submodels.
 
  Args:
- history_minutes (int): Length of the GSP history period in minutes
- forecast_minutes (int): Length of the GSP forecast period in minutes
+ history_minutes (int): Length of the generation history period in minutes
+ forecast_minutes (int): Length of the generation forecast period in minutes
  output_quantiles: A list of float (0.0, 1.0) quantiles to predict values for. If set to
  None the output is a single value.
- target_key: The key of the target variable in the batch
  interval_minutes: The interval in minutes between each timestep in the data
  """
  super().__init__()
 
- self._target_key = target_key
-
  self.history_minutes = history_minutes
  self.forecast_minutes = forecast_minutes
  self.output_quantiles = output_quantiles
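A quick worked example of the interval_end_minutes arithmetic reformatted above (the numbers are illustrative, not taken from any shipped config):

    # interval_end = interval_start + (sequence_length - 1) * time_resolution
    interval_start_minutes = -60
    sequence_length = 7
    time_resolution_minutes = 30
    interval_end_minutes = interval_start_minutes + (sequence_length - 1) * time_resolution_minutes
    assert interval_end_minutes == 120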
pvnet/models/ensemble.py CHANGED
@@ -26,7 +26,6 @@ class Ensemble(BaseModel):
  output_quantiles = []
  history_minutes = []
  forecast_minutes = []
- target_key = []
  interval_minutes = []
 
  # Get some model properties from each model
@@ -34,7 +33,6 @@
  output_quantiles.append(model.output_quantiles)
  history_minutes.append(model.history_minutes)
  forecast_minutes.append(model.forecast_minutes)
- target_key.append(model._target_key)
  interval_minutes.append(model.interval_minutes)
 
  # Check these properties are all the same
@@ -42,7 +40,6 @@
  output_quantiles,
  history_minutes,
  forecast_minutes,
- target_key,
  interval_minutes,
  ]:
  assert all([p == param_list[0] for p in param_list]), param_list
@@ -51,7 +48,6 @@
  history_minutes=history_minutes[0],
  forecast_minutes=forecast_minutes[0],
  output_quantiles=output_quantiles[0],
- target_key=target_key[0],
  interval_minutes=interval_minutes[0],
  )
 
pvnet/models/late_fusion/late_fusion.py CHANGED
@@ -28,8 +28,8 @@ class LateFusionModel(BaseModel):
  - NWP, if included, is put through a similar encoder.
  - PV site-level data, if included, is put through an encoder which transforms it from 2D, with
  time and system-ID dimensions, to become a 1D feature vector.
- - The satellite features*, NWP features*, PV site-level features*, GSP ID embedding*, and sun
- paramters* are concatenated into a 1D feature vector and passed through another neural
+ - The satellite features*, NWP features*, PV site-level features*, location ID embedding*, and
+ sun paramters* are concatenated into a 1D feature vector and passed through another neural
  network to combine them and produce a forecast.
 
  * if included
@@ -43,10 +43,10 @@
  sat_encoder: AbstractNWPSatelliteEncoder | None = None,
  pv_encoder: AbstractSitesEncoder | None = None,
  add_image_embedding_channel: bool = False,
- include_gsp_yield_history: bool = True,
- include_site_yield_history: bool = False,
+ include_generation_history: bool = False,
  include_sun: bool = True,
  include_time: bool = False,
+ t0_embedding_dim: int = 0,
  location_id_mapping: dict[Any, int] | None = None,
  embedding_dim: int = 16,
  forecast_minutes: int = 30,
@@ -56,7 +56,6 @@
  nwp_forecast_minutes: DictConfig | None = None,
  nwp_history_minutes: DictConfig | None = None,
  pv_history_minutes: int | None = None,
- target_key: str = "gsp",
  interval_minutes: int = 30,
  nwp_interval_minutes: DictConfig | None = None,
  pv_interval_minutes: int = 5,
@@ -83,14 +82,15 @@
  pv_encoder: A partially instantiated pytorch Module class used to encode the site-level
  PV data from 2D into a 1D feature vector.
  add_image_embedding_channel: Add a channel to the NWP and satellite data with the
- embedding of the GSP ID.
- include_gsp_yield_history: Include GSP yield data.
- include_site_yield_history: Include Site yield data.
+ embedding of the location ID.
+ include_generation_history: Include generation yield data.
  include_sun: Include sun azimuth and altitude data.
  include_time: Include sine and cosine of dates and times.
+ t0_embedding_dim: Shape of the embedding of the init-time (t0) of the forecast. Not used
+ if set to 0.
  location_id_mapping: A dictionary mapping the location ID to an integer. ID embedding is
  not used if this is not provided.
- embedding_dim: Number of embedding dimensions to use for GSP ID.
+ embedding_dim: Number of embedding dimensions to use for location ID.
  forecast_minutes: The amount of minutes that should be forecasted.
  history_minutes: The default amount of historical minutes that are used.
  sat_history_minutes: Length of recent observations used for satellite inputs. Defaults
@@ -103,7 +103,6 @@
  `history_minutes` if not provided.
  pv_history_minutes: Length of recent site-level PV data used as
  input. Defaults to `history_minutes` if not provided.
- target_key: The key of the target variable in the batch.
  interval_minutes: The interval between each sample of the target data
  nwp_interval_minutes: Dictionary of the intervals between each sample of the NWP
  data for each source
@@ -114,17 +113,16 @@
  history_minutes=history_minutes,
  forecast_minutes=forecast_minutes,
  output_quantiles=output_quantiles,
- target_key=target_key,
  interval_minutes=interval_minutes,
  )
 
- self.include_gsp_yield_history = include_gsp_yield_history
- self.include_site_yield_history = include_site_yield_history
+ self.include_generation_history = include_generation_history
  self.include_sat = sat_encoder is not None
  self.include_nwp = nwp_encoders_dict is not None and len(nwp_encoders_dict) != 0
  self.include_pv = pv_encoder is not None
  self.include_sun = include_sun
  self.include_time = include_time
+ self.t0_embedding_dim = t0_embedding_dim
  self.location_id_mapping = location_id_mapping
  self.embedding_dim = embedding_dim
  self.add_image_embedding_channel = add_image_embedding_channel
@@ -133,8 +131,7 @@
 
  if self.location_id_mapping is None:
  logger.warning(
- "location_id_mapping` is not provided, defaulting to outdated GSP mapping"
- "(0 to 317)"
+ "location_id_mapping` is not provided, defaulting to outdated GSP mapping(0 to 317)"
  )
 
  # Note 318 is the 2024 UK GSP count, so this is a temporary fix
@@ -223,8 +220,7 @@
 
  self.pv_encoder = pv_encoder(
  sequence_length=pv_history_minutes // pv_interval_minutes + 1,
- target_key_to_use=self._target_key,
- input_key_to_use="site",
+ key_to_use="generation",
  )
 
  # Update num features
@@ -238,8 +234,7 @@
 
  if self.include_sun:
  self.sun_fc1 = nn.Linear(
- in_features=2
- * (self.forecast_len + self.history_len + 1),
+ in_features=2 * (self.forecast_len + self.history_len + 1),
  out_features=16,
  )
 
@@ -248,19 +243,16 @@
 
  if self.include_time:
  self.time_fc1 = nn.Linear(
- in_features=4
- * (self.forecast_len + self.history_len + 1),
+ in_features=4 * (self.forecast_len + self.history_len + 1),
  out_features=32,
  )
 
  # Update num features
  fusion_input_features += 32
 
- if include_gsp_yield_history:
- # Update num features
- fusion_input_features += self.history_len
+ fusion_input_features += self.t0_embedding_dim
 
- if include_site_yield_history:
+ if include_generation_history:
  # Update num features
  fusion_input_features += self.history_len + 1
 
@@ -269,15 +261,14 @@
  out_features=self.num_output_features,
  )
 
-
  def forward(self, x: TensorBatch) -> torch.Tensor:
  """Run model forward"""
 
  if self.use_id_embedding:
- # eg: x['gsp_id'] = [1] with location_id_mapping = {1:0}, would give [0]
+ # eg: x['location_id'] = [1] with location_id_mapping = {1:0}, would give [0]
  id = torch.tensor(
- [self.location_id_mapping[i.item()] for i in x[f"{self._target_key}_id"]],
- device=x[f"{self._target_key}_id"].device,
+ [self.location_id_mapping[i.item()] for i in x["location_id"]],
+ device=x["location_id"].device,
  dtype=torch.int64,
  )
 
@@ -308,32 +299,20 @@
  nwp_out = self.nwp_encoders_dict[nwp_source](nwp_data)
  modes[f"nwp/{nwp_source}"] = nwp_out
 
- # *********************** Site Data *************************************
- # Add site-level yield history
- if self.include_site_yield_history:
- site_history = x["site"][:, : self.history_len + 1].float()
- site_history = site_history.reshape(site_history.shape[0], -1)
- modes["site"] = site_history
+ # *********************** Generation Data *************************************
+ # Add generation yield history
+ if self.include_generation_history:
+ generation_history = x["generation"][:, : self.history_len + 1].float()
+ generation_history = generation_history.reshape(generation_history.shape[0], -1)
+ modes["generation"] = generation_history
 
- # Add site-level yield history through PV encoder
+ # Add location-level yield history through PV encoder
  if self.include_pv:
- if self._target_key != "site":
- modes["site"] = self.pv_encoder(x)
- else:
- # Target is PV, so only take the history
- # Copy batch
- x_tmp = x.copy()
- x_tmp["site"] = x_tmp["site"][:, : self.history_len + 1]
- modes["site"] = self.pv_encoder(x_tmp)
-
- # *********************** GSP Data ************************************
- # Add gsp yield history
- if self.include_gsp_yield_history:
- gsp_history = x["gsp"][:, : self.history_len].float()
- gsp_history = gsp_history.reshape(gsp_history.shape[0], -1)
- modes["gsp"] = gsp_history
-
- # ********************** Embedding of GSP/Site ID ********************
+ x_tmp = x.copy()
+ x_tmp["generation"] = x_tmp["generation"][:, : self.history_len + 1]
+ modes["generation"] = self.pv_encoder(x_tmp)
+
+ # ********************** Embedding of location ID ********************
  if self.use_id_embedding:
  modes["id"] = self.embed(id)
 
@@ -341,13 +320,16 @@
  sun = torch.cat((x["solar_azimuth"], x["solar_elevation"]), dim=1).float()
  sun = self.sun_fc1(sun)
  modes["sun"] = sun
-
+
  if self.include_time:
  time = [x[k] for k in ["date_sin", "date_cos", "time_sin", "time_cos"]]
  time = torch.cat(time, dim=1).float()
  time = self.time_fc1(time)
  modes["time"] = time
 
+ if self.t0_embedding_dim>0:
+ modes["t0_embed"] = x["t0_embedding"]
+
  out = self.output_network(modes)
 
  if self.use_quantile_regression:
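To make the fusion feature bookkeeping above concrete (illustrative numbers only): with include_sun, include_time and include_generation_history all enabled, t0_embedding_dim = 8 and a history_len of 4, these blocks add 16 + 32 + 8 + (4 + 1) = 61 inputs to fusion_input_features, on top of the encoder outputs.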
pvnet/models/late_fusion/site_encoders/encoders.py CHANGED
@@ -1,6 +1,4 @@
- """Encoder modules for the site-level PV data.
-
- """
+ """Encoder modules for the site-level PV data."""
 
  import einops
  import torch
@@ -11,6 +9,7 @@ from pvnet.models.late_fusion.linear_networks.networks import ResFCNet
  from pvnet.models.late_fusion.site_encoders.basic_blocks import AbstractSitesEncoder
 
 
+ # TODO update this to work with the new sample data format
  class SimpleLearnedAggregator(AbstractSitesEncoder):
  """A simple model which learns a different weighted-average across all PV sites for each GSP.
 
@@ -127,8 +126,7 @@ class SingleAttentionNetwork(AbstractSitesEncoder):
  kv_res_block_layers: int = 2,
  use_id_in_value: bool = False,
  target_id_dim: int = 318,
- target_key_to_use: str = "gsp",
- input_key_to_use: str = "site",
+ key_to_use: str = "generation",
  num_channels: int = 1,
  num_sites_in_inference: int = 1,
  ):
@@ -149,8 +147,7 @@
  use_id_in_value: Whether to use a site ID embedding in network used to produce the
  value for the attention layer.
  target_id_dim: The number of unique IDs.
- target_key_to_use: The key to use for the target in the attention layer.
- input_key_to_use: The key to use for the input in the attention layer.
+ key_to_use: The key to use in the attention layer.
  num_channels: Number of channels in the input data
  num_sites_in_inference: Number of sites to use in inference.
  This is used to determine the number of sites to use in the
@@ -164,8 +161,7 @@
  self.site_id_embedding = nn.Embedding(num_sites, id_embed_dim)
  self._ids = nn.parameter.Parameter(torch.arange(num_sites), requires_grad=False)
  self.use_id_in_value = use_id_in_value
- self.target_key_to_use = target_key_to_use
- self.input_key_to_use = input_key_to_use
+ self.key_to_use = key_to_use
  self.num_channels = num_channels
  self.num_sites_in_inference = num_sites_in_inference
 
@@ -206,7 +202,7 @@
  def _encode_inputs(self, x: TensorBatch) -> tuple[torch.Tensor, int]:
  # Shape: [batch size, sequence length, number of sites]
  # Shape: [batch size, station_id, sequence length, channels]
- input_data = x[f"{self.input_key_to_use}"]
+ input_data = x[f"{self.key_to_use}"]
  if len(input_data.shape) == 2: # one site per sample
  input_data = input_data.unsqueeze(-1) # add dimension of 1 to end to make 3D
  if len(input_data.shape) == 4: # Has multiple channels
@@ -216,16 +212,11 @@
  input_data = input_data[:, : self.sequence_length]
  site_seqs = input_data.float()
  batch_size = site_seqs.shape[0]
- site_seqs = site_seqs.swapaxes(1, 2) # [batch size, Site ID, sequence length]
+ site_seqs = site_seqs.swapaxes(1, 2) # [batch size, location ID, sequence length]
  return site_seqs, batch_size
 
  def _encode_query(self, x: TensorBatch) -> torch.Tensor:
- if self.target_key_to_use == "gsp":
- # GSP seems to have a different structure
- ids = x[f"{self.target_key_to_use}_id"]
- else:
- ids = x[f"{self.input_key_to_use}_id"]
- ids = ids.int()
+ ids = x["location_id"].int()
  query = self.target_id_embedding(ids).unsqueeze(1)
  return query
 
@@ -233,9 +224,9 @@
  site_seqs, batch_size = self._encode_inputs(x)
 
  # site ID embeddings are the same for each sample
- site_id_embed = torch.tile(self.site_id_embedding(self._ids), (batch_size, 1, 1))
+ id_embed = torch.tile(self.site_id_embedding(self._ids), (batch_size, 1, 1))
  # Each concated (site sequence, site ID embedding) is processed with encoder
- x_seq_in = torch.cat((site_seqs, site_id_embed), dim=2).flatten(0, 1)
+ x_seq_in = torch.cat((site_seqs, id_embed), dim=2).flatten(0, 1)
  key = self._key_encoder(x_seq_in)
 
  # Reshape to [batch size, site, kdim]
@@ -247,9 +238,9 @@
 
  if self.use_id_in_value:
  # site ID embeddings are the same for each sample
- site_id_embed = torch.tile(self.value_id_embedding(self._ids), (batch_size, 1, 1))
+ id_embed = torch.tile(self.value_id_embedding(self._ids), (batch_size, 1, 1))
  # Each concated (site sequence, site ID embedding) is processed with encoder
- x_seq_in = torch.cat((site_seqs, site_id_embed), dim=2).flatten(0, 1)
+ x_seq_in = torch.cat((site_seqs, id_embed), dim=2).flatten(0, 1)
  else:
  # Encode each site sequence independently
  x_seq_in = site_seqs.flatten(0, 1)
@@ -260,9 +251,8 @@
  return value
 
  def _attention_forward(
- self, x: dict,
- average_attn_weights: bool = True
- ) -> tuple[torch.Tensor, torch.Tensor:]:
+ self, x: dict, average_attn_weights: bool = True
+ ) -> tuple[torch.Tensor, torch.Tensor :]:
  query = self._encode_query(x)
  key = self._encode_key(x)
  value = self._encode_value(x)
pvnet/optimizers.py CHANGED
@@ -65,7 +65,7 @@ class AbstractOptimizer(ABC):
  """
 
  @abstractmethod
- def __call__(self):
+ def __call__(self, model: Module):
  """Abstract call"""
  pass
 
@@ -129,19 +129,18 @@ class EmbAdamWReduceLROnPlateau(AbstractOptimizer):
  {"params": decay, "weight_decay": self.weight_decay},
  {"params": no_decay, "weight_decay": 0.0},
  ]
+ monitor = "quantile_loss/val" if model.use_quantile_regression else "MAE/val"
  opt = torch.optim.AdamW(optim_groups, lr=self.lr, **self.opt_kwargs)
-
  sch = torch.optim.lr_scheduler.ReduceLROnPlateau(
  opt,
  factor=self.factor,
  patience=self.patience,
  threshold=self.threshold,
  )
- sch = {
- "scheduler": sch,
- "monitor": "quantile_loss/val" if model.use_quantile_regression else "MAE/val",
+ return {
+ "optimizer": opt,
+ "lr_scheduler": {"scheduler": sch, "monitor": monitor},
  }
- return [opt], [sch]
 
 
  class AdamWReduceLROnPlateau(AbstractOptimizer):
@@ -153,15 +152,13 @@ class AdamWReduceLROnPlateau(AbstractOptimizer):
  patience: int = 3,
  factor: float = 0.5,
  threshold: float = 2e-4,
- step_freq=None,
  **opt_kwargs,
  ):
  """AdamW optimizer and reduce on plateau scheduler"""
- self._lr = lr
+ self.lr = lr
  self.patience = patience
  self.factor = factor
  self.threshold = threshold
- self.step_freq = step_freq
  self.opt_kwargs = opt_kwargs
 
  def _call_multi(self, model):
@@ -169,7 +166,7 @@ class AdamWReduceLROnPlateau(AbstractOptimizer):
 
  group_args = []
 
- for key in self._lr.keys():
+ for key in self.lr.keys():
  if key == "default":
  continue
 
@@ -178,43 +175,38 @@ class AdamWReduceLROnPlateau(AbstractOptimizer):
  if param_name.startswith(key):
  submodule_params += [remaining_params.pop(param_name)]
 
- group_args += [{"params": submodule_params, "lr": self._lr[key]}]
+ group_args += [{"params": submodule_params, "lr": self.lr[key]}]
 
  remaining_params = [p for k, p in remaining_params.items()]
  group_args += [{"params": remaining_params}]
-
- opt = torch.optim.AdamW(
- group_args,
- lr=self._lr["default"] if model.lr is None else model.lr,
- **self.opt_kwargs,
+ monitor = "quantile_loss/val" if model.use_quantile_regression else "MAE/val"
+ opt = torch.optim.AdamW(group_args, lr=self.lr["default"], **self.opt_kwargs)
+ sch = torch.optim.lr_scheduler.ReduceLROnPlateau(
+ opt,
+ factor=self.factor,
+ patience=self.patience,
+ threshold=self.threshold,
  )
- sch = {
- "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau(
- opt,
- factor=self.factor,
- patience=self.patience,
- threshold=self.threshold,
- ),
- "monitor": "quantile_loss/val" if model.use_quantile_regression else "MAE/val",
+ return {
+ "optimizer": opt,
+ "lr_scheduler": {"scheduler": sch, "monitor": monitor},
  }
 
- return [opt], [sch]
 
  def __call__(self, model):
  """Return optimizer"""
- if not isinstance(self._lr, float):
+ if isinstance(self.lr, dict):
  return self._call_multi(model)
  else:
- default_lr = self._lr if model.lr is None else model.lr
- opt = torch.optim.AdamW(model.parameters(), lr=default_lr, **self.opt_kwargs)
+ monitor = "quantile_loss/val" if model.use_quantile_regression else "MAE/val"
+ opt = torch.optim.AdamW(model.parameters(), lr=self.lr, **self.opt_kwargs)
  sch = torch.optim.lr_scheduler.ReduceLROnPlateau(
  opt,
  factor=self.factor,
  patience=self.patience,
  threshold=self.threshold,
  )
- sch = {
- "scheduler": sch,
- "monitor": "quantile_loss/val" if model.use_quantile_regression else "MAE/val",
+ return {
+ "optimizer": opt,
+ "lr_scheduler": {"scheduler": sch, "monitor": monitor},
  }
- return [opt], [sch]
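For reference, the new return value in the branches above follows the dictionary form that PyTorch Lightning accepts from configure_optimizers, where a ReduceLROnPlateau scheduler needs a "monitor" key; a minimal sketch with an illustrative stand-in model:

    import torch
    from torch import nn

    model = nn.Linear(4, 1)  # illustrative stand-in for the PVNet model
    opt = torch.optim.AdamW(model.parameters(), lr=1e-3)
    sch = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, factor=0.5, patience=3)

    # Returned from configure_optimizers(): the scheduler dict names the metric to monitor
    optimizer_config = {
        "optimizer": opt,
        "lr_scheduler": {"scheduler": sch, "monitor": "MAE/val"},
    }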
pvnet/training/lightning_module.py CHANGED
@@ -45,9 +45,9 @@ class PVNetLightningModule(pl.LightningModule):
  self.lr = None
 
  def transfer_batch_to_device(
- self,
- batch: TensorBatch,
- device: torch.device,
+ self,
+ batch: TensorBatch,
+ device: torch.device,
  dataloader_idx: int,
  ) -> dict:
  """Method to move custom batches to a given device"""
@@ -75,7 +75,7 @@ class PVNetLightningModule(pl.LightningModule):
  losses = 2 * torch.cat(losses, dim=2)
 
  return losses.mean()
-
+
  def configure_optimizers(self):
  """Configure the optimizers using learning rate found with LR finder if used"""
  if self.lr is not None:
@@ -84,7 +84,7 @@ class PVNetLightningModule(pl.LightningModule):
  return self._optimizer(self.model)
 
  def _calculate_common_losses(
- self,
+ self,
  y: torch.Tensor,
  y_hat: torch.Tensor,
  ) -> dict[str, torch.Tensor]:
@@ -96,30 +96,30 @@ class PVNetLightningModule(pl.LightningModule):
  losses["quantile_loss"] = self._calculate_quantile_loss(y_hat, y)
  y_hat = self.model._quantiles_to_prediction(y_hat)
 
- losses.update({"MSE": F.mse_loss(y_hat, y), "MAE": F.l1_loss(y_hat, y)})
+ losses.update({"MSE": F.mse_loss(y_hat, y), "MAE": F.l1_loss(y_hat, y)})
 
  return losses
-
+
  def training_step(self, batch: TensorBatch, batch_idx: int) -> torch.Tensor:
  """Run training step"""
  y_hat = self.model(batch)
 
- y = batch[self.model._target_key][:, -self.model.forecast_len :]
+ y = batch["generation"][:, -self.model.forecast_len :]
 
  losses = self._calculate_common_losses(y, y_hat)
  losses = {f"{k}/train": v for k, v in losses.items()}
 
- self.log_dict(losses, on_step=True, on_epoch=True)
+ self.log_dict(losses, on_step=True, on_epoch=True, batch_size=y.size(0))
 
  if self.model.use_quantile_regression:
  opt_target = losses["quantile_loss/train"]
  else:
  opt_target = losses["MAE/train"]
  return opt_target
-
+
  def _calculate_val_losses(
- self,
- y: torch.Tensor,
+ self,
+ y: torch.Tensor,
  y_hat: torch.Tensor,
  ) -> dict[str, torch.Tensor]:
  """Calculate additional losses only run in validation"""
@@ -138,28 +138,25 @@ class PVNetLightningModule(pl.LightningModule):
  return losses
 
  def _calculate_step_metrics(
- self,
- y: torch.Tensor,
- y_hat: torch.Tensor,
+ self,
+ y: torch.Tensor,
+ y_hat: torch.Tensor,
  ) -> tuple[np.array, np.array]:
  """Calculate the MAE and MSE at each forecast step"""
 
  mae_each_step = torch.mean(torch.abs(y_hat - y), dim=0).cpu().numpy()
  mse_each_step = torch.mean((y_hat - y) ** 2, dim=0).cpu().numpy()
-
+
  return mae_each_step, mse_each_step
-
+
  def _store_val_predictions(self, batch: TensorBatch, y_hat: torch.Tensor) -> None:
  """Internally store the validation predictions"""
-
- target_key = self.model._target_key
 
- y = batch[target_key][:, -self.model.forecast_len :].cpu().numpy()
- y_hat = y_hat.cpu().numpy()
- ids = batch[f"{target_key}_id"].cpu().numpy()
+ y = batch["generation"][:, -self.model.forecast_len :].cpu().numpy()
+ y_hat = y_hat.cpu().numpy()
+ ids = batch["location_id"].cpu().numpy()
  init_times_utc = pd.to_datetime(
- batch[f"{target_key}_time_utc"][:, self.model.history_len+1]
- .cpu().numpy().astype("datetime64[ns]")
+ batch["time_utc"][:, self.model.history_len + 1].cpu().numpy().astype("datetime64[ns]")
  )
 
  if self.model.use_quantile_regression:
@@ -170,7 +167,7 @@ class PVNetLightningModule(pl.LightningModule):
 
 
  ds_preds_batch = xr.Dataset(
  data_vars=dict(
- y_hat=(["sample_num", "forecast_step", "p_level"], y_hat),
+ y_hat=(["sample_num", "forecast_step", "p_level"], y_hat),
  y=(["sample_num", "forecast_step"], y),
  ),
@@ -186,7 +183,7 @@ class PVNetLightningModule(pl.LightningModule):
  # Set up stores which we will fill during validation
  self.all_val_results: list[xr.Dataset] = []
  self._val_horizon_maes: list[np.array] = []
- if self.current_epoch==0:
+ if self.current_epoch == 0:
  self._val_persistence_horizon_maes: list[np.array] = []
 
  # Plot some sample forecasts
@@ -197,9 +194,9 @@ class PVNetLightningModule(pl.LightningModule):
 
  for plot_num in range(num_figures):
  idxs = np.arange(plots_per_figure) + plot_num * plots_per_figure
- idxs = idxs[idxs<len(val_dataset)]
+ idxs = idxs[idxs < len(val_dataset)]
 
- if len(idxs)==0:
+ if len(idxs) == 0:
  continue
 
  batch = collate_fn([val_dataset[i] for i in idxs])
@@ -207,19 +204,16 @@ class PVNetLightningModule(pl.LightningModule):
 
  # Batch validation check only during sanity check phase - use first batch
  if self.trainer.sanity_checking and plot_num == 0:
- validate_batch_against_config(
- batch=batch,
- model=self.model
- )
-
+ validate_batch_against_config(batch=batch, model=self.model)
+
  with torch.no_grad():
  y_hat = self.model(batch)
-
+
  fig = plot_sample_forecasts(
  batch,
  y_hat,
  quantiles=self.model.output_quantiles,
- key_to_plot=self.model._target_key,
+ key_to_plot="generation",
  )
 
  plot_name = f"val_forecast_samples/sample_set_{plot_num}"
@@ -238,7 +232,7 @@ class PVNetLightningModule(pl.LightningModule):
  # Internally store the val predictions
  self._store_val_predictions(batch, y_hat)
 
- y = batch[self.model._target_key][:, -self.model.forecast_len :]
+ y = batch["generation"][:, -self.model.forecast_len :]
 
  losses = self._calculate_common_losses(y, y_hat)
  losses = {f"{k}/val": v for k, v in losses.items()}
@@ -263,21 +257,39 @@ class PVNetLightningModule(pl.LightningModule):
  # Calculate the persistance losses - we only need to do this once per training run
  # not every epoch
  if self.current_epoch==0:
+ # Need to find last valid value before forecast
+ target_data = batch["generation"]
+ history_data = target_data[:, :-(self.model.forecast_len)]
+
+ # Find where values aren't dropped
+ valid_mask = history_data >= 0
+
+ # Last valid value index for each sample
+ flipped_mask = valid_mask.float().flip(dims=[1])
+ last_valid_indices_flipped = torch.argmax(flipped_mask, dim=1)
+ last_valid_indices = history_data.shape[1] - 1 - last_valid_indices_flipped
+
+ # Grab those last valid values
+ batch_indices = torch.arange(
+ history_data.shape[0],
+ device=history_data.device
+ )
+ last_valid_values = history_data[batch_indices, last_valid_indices]
+
  y_persist = (
- batch[self.model._target_key][:, -(self.model.forecast_len+1)]
- .unsqueeze(1).expand(-1, self.model.forecast_len)
+ last_valid_values.unsqueeze(1).expand(-1, self.model.forecast_len)
  )
  mae_step_persist, mse_step_persist = self._calculate_step_metrics(y, y_persist)
  self._val_persistence_horizon_maes.append(mae_step_persist)
  losses.update(
  {
- "MAE/val_persistence": mae_step_persist.mean(),
- "MSE/val_persistence": mse_step_persist.mean()
+ "MAE/val_persistence": mae_step_persist.mean(),
+ "MSE/val_persistence": mse_step_persist.mean(),
  }
  )
 
- # Log the metrics
- self.log_dict(losses, on_step=False, on_epoch=True)
+ # Log the metrics
+ self.log_dict(losses, on_step=False, on_epoch=True, batch_size=y.size(0))
 
  def on_validation_epoch_end(self) -> None:
  """Run on epoch end"""
@@ -289,7 +301,7 @@ class PVNetLightningModule(pl.LightningModule):
  self._val_horizon_maes = []
 
  # We only run this on the first epoch
- if self.current_epoch==0:
+ if self.current_epoch == 0:
  val_persistence_horizon_maes = np.mean(self._val_persistence_horizon_maes, axis=0)
  self._val_persistence_horizon_maes = []
 
@@ -321,25 +333,25 @@ class PVNetLightningModule(pl.LightningModule):
  wandb_log_dir = self.logger.experiment.dir
  filepath = f"{wandb_log_dir}/validation_results.netcdf"
  ds_val_results.to_netcdf(filepath)
-
- # Uplodad to wandb
+
+ # Uplodad to wandb
  self.logger.experiment.save(filepath, base_path=wandb_log_dir, policy="now")
-
+
  # Create the horizon accuracy curve
  horizon_mae_plot = wandb_line_plot(
- x=np.arange(self.model.forecast_len),
+ x=np.arange(self.model.forecast_len),
  y=val_horizon_maes,
  xlabel="Horizon step",
  ylabel="MAE",
  title="Val horizon loss curve",
  )
-
+
  wandb.log({"val_horizon_mae_plot": horizon_mae_plot})
 
  # Create persistence horizon accuracy curve but only on first epoch
- if self.current_epoch==0:
+ if self.current_epoch == 0:
  persist_horizon_mae_plot = wandb_line_plot(
- x=np.arange(self.model.forecast_len),
+ x=np.arange(self.model.forecast_len),
  y=val_persistence_horizon_maes,
  xlabel="Horizon step",
  ylabel="MAE",
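The persistence baseline added above picks, for each sample, the last non-dropped (>= 0) value in the history window via an argmax over a flipped mask; a standalone sketch of that indexing trick on toy data (values are illustrative):

    import torch

    # Toy history window; -1.0 marks dropped values
    history_data = torch.tensor([[0.2, 0.4, -1.0],
                                 [-1.0, 0.3, 0.5]])
    valid_mask = history_data >= 0
    # argmax on the flipped mask finds the last True along the time axis
    flipped_mask = valid_mask.float().flip(dims=[1])
    last_valid_indices = history_data.shape[1] - 1 - torch.argmax(flipped_mask, dim=1)
    batch_indices = torch.arange(history_data.shape[0])
    last_valid_values = history_data[batch_indices, last_valid_indices]
    # last_valid_values -> tensor([0.4000, 0.5000])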
pvnet/training/plots.py CHANGED
@@ -32,9 +32,9 @@ def plot_sample_forecasts(
 
  y = batch[key_to_plot].cpu().numpy()
  y_hat = y_hat.cpu().numpy()
- ids = batch[f"{key_to_plot}_id"].cpu().numpy().squeeze()
+ ids = batch["location_id"].cpu().numpy().squeeze()
  times_utc = pd.to_datetime(
- batch[f"{key_to_plot}_time_utc"].cpu().numpy().squeeze().astype("datetime64[ns]")
+ batch["time_utc"].cpu().numpy().squeeze().astype("datetime64[ns]")
  )
  batch_size = y.shape[0]
 
pvnet/utils.py CHANGED
@@ -1,4 +1,5 @@
  """Utils"""
+
  import logging
  from typing import TYPE_CHECKING
 
@@ -17,7 +18,7 @@ PYTORCH_WEIGHTS_NAME = "model_weights.safetensors"
  MODEL_CONFIG_NAME = "model_config.yaml"
  DATA_CONFIG_NAME = "data_config.yaml"
  DATAMODULE_CONFIG_NAME = "datamodule_config.yaml"
- FULL_CONFIG_NAME = "full_experiment_config.yaml"
+ FULL_CONFIG_NAME = "full_experiment_config.yaml"
  MODEL_CARD_NAME = "README.md"
 
 
@@ -93,37 +94,41 @@ def print_config(
 
 
  def validate_batch_against_config(
- batch: dict,
+ batch: dict,
  model: "BaseModel",
  ) -> None:
  """Validates tensor shapes in batch against model configuration."""
  logger.info("Performing batch shape validation against model config.")
-
+
  # NWP validation
- if hasattr(model, 'nwp_encoders_dict'):
+ if hasattr(model, "nwp_encoders_dict"):
  if "nwp" not in batch:
  raise ValueError(
  "Model configured with 'nwp_encoders_dict' but 'nwp' data missing from batch."
  )
-
+
  for source, nwp_data in batch["nwp"].items():
  if source in model.nwp_encoders_dict:
-
- enc = model.nwp_encoders_dict[source]
+ enc = model.nwp_encoders_dict[source]
  expected_channels = enc.in_channels
  if model.add_image_embedding_channel:
  expected_channels -= 1
 
- expected = (nwp_data["nwp"].shape[0], enc.sequence_length,
- expected_channels, enc.image_size_pixels, enc.image_size_pixels)
+ expected = (
+ nwp_data["nwp"].shape[0],
+ enc.sequence_length,
+ expected_channels,
+ enc.image_size_pixels,
+ enc.image_size_pixels,
+ )
  if tuple(nwp_data["nwp"].shape) != expected:
- actual_shape = tuple(nwp_data['nwp'].shape)
+ actual_shape = tuple(nwp_data["nwp"].shape)
  raise ValueError(
  f"NWP.{source} shape mismatch: expected {expected}, got {actual_shape}"
  )
 
  # Satellite validation
- if hasattr(model, 'sat_encoder'):
+ if hasattr(model, "sat_encoder"):
  if "satellite_actual" not in batch:
  raise ValueError(
  "Model configured with 'sat_encoder' but 'satellite_actual' missing from batch."
@@ -134,14 +139,19 @@ def validate_batch_against_config(
  if model.add_image_embedding_channel:
  expected_channels -= 1
 
- expected = (batch["satellite_actual"].shape[0], enc.sequence_length, expected_channels,
- enc.image_size_pixels, enc.image_size_pixels)
+ expected = (
+ batch["satellite_actual"].shape[0],
+ enc.sequence_length,
+ expected_channels,
+ enc.image_size_pixels,
+ enc.image_size_pixels,
+ )
  if tuple(batch["satellite_actual"].shape) != expected:
- actual_shape = tuple(batch['satellite_actual'].shape)
+ actual_shape = tuple(batch["satellite_actual"].shape)
  raise ValueError(f"Satellite shape mismatch: expected {expected}, got {actual_shape}")
 
- # GSP/Site validation
- key = model._target_key
+ # generation validation
+ key = "generation"
  if key in batch:
  total_minutes = model.history_minutes + model.forecast_minutes
  interval = model.interval_minutes
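As a concrete illustration of the NWP shape check above (all numbers illustrative): the expected tensor is (batch, sequence_length, channels, image_size_pixels, image_size_pixels), with one channel subtracted when add_image_embedding_channel is set:

    # Illustrative encoder attributes only
    batch_size = 8
    sequence_length, in_channels, image_size_pixels = 4, 2, 24
    add_image_embedding_channel = True

    expected_channels = in_channels - (1 if add_image_embedding_channel else 0)
    expected = (batch_size, sequence_length, expected_channels, image_size_pixels, image_size_pixels)
    assert expected == (8, 4, 1, 24, 24)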
pvnet-5.3.5.dist-info/METADATA CHANGED
@@ -1,12 +1,12 @@
  Metadata-Version: 2.4
  Name: PVNet
- Version: 5.2.3
+ Version: 5.3.5
  Summary: PVNet
  Author-email: Peter Dudfield <info@openclimatefix.org>
  Requires-Python: <3.14,>=3.11
  Description-Content-Type: text/markdown
  License-File: LICENSE
- Requires-Dist: ocf-data-sampler>=0.6.0
+ Requires-Dist: ocf-data-sampler>=1.0.9
  Requires-Dist: numpy
  Requires-Dist: pandas
  Requires-Dist: matplotlib
pvnet-5.3.5.dist-info/RECORD CHANGED
@@ -1,14 +1,14 @@
  pvnet/__init__.py,sha256=TAZm88TJ5ieL1XjEyRg1LciIGuSScEucdAruQLfM92I,25
- pvnet/datamodule.py,sha256=sTACPJXPaqojpxf86wldqxxlnFRoPvlvRHkmcGSsmDw,6368
+ pvnet/datamodule.py,sha256=wc1RQfFhgW9Hxyw7vrpFERhOd2FmjDsO1x49J2erOYk,5750
  pvnet/load_model.py,sha256=P1QODX_mJRnKZ_kIll9BlOjK_A1W4YM3QG-mZd-2Mcc,3852
- pvnet/optimizers.py,sha256=1N4b-Xd6QiIrcUU8cbU326bbFC0BvMNIV8VYWtGILJc,6548
- pvnet/utils.py,sha256=6hVKQN8F89pJbiC9VSuHCm5yJqzIzs7hLF3ztkBU-TY,5895
+ pvnet/optimizers.py,sha256=DZ74KcFQV226zwu7-qAzofTMTYeIyScox4Kqbq30WWY,6440
+ pvnet/utils.py,sha256=L3MDF5m1Ez_btAZZ8t-T5wXLzFmyj7UZtorA91DEpFw,6003
  pvnet/models/__init__.py,sha256=owzZ9xkD0DRTT51mT2Dx_p96oJjwDz57xo_MaMIEosk,145
- pvnet/models/base_model.py,sha256=CnQaaf2kAdOcXqo1319nWa120mHfLQiwOQ639m4OzPk,16182
- pvnet/models/ensemble.py,sha256=1mFUEsl33kWcLL5d7zfDm9ypWxgAxBHgBiJLt0vwTeg,2363
+ pvnet/models/base_model.py,sha256=V-vBqtzZc_c8Ho5hVo_ikq2wzZ7hsAIM7I4vhzGDfNc,16051
+ pvnet/models/ensemble.py,sha256=USpNQ0O5eiffapLPE9T6gR-uK9f_3E4pX3DK7Lmkn2U,2228
  pvnet/models/late_fusion/__init__.py,sha256=Jf0B-E0_5IvSBFoj1wvnPtwYDxs4pRIFm5qHv--Bbps,26
  pvnet/models/late_fusion/basic_blocks.py,sha256=_cYGVyAIyEJS4wd-DEAXQXu0br66guZJn3ugoebWqZ0,1479
- pvnet/models/late_fusion/late_fusion.py,sha256=7uQPo_OlNXrJOE9nYHTEvwJx2POKg4drJfdnPxwiaJU,16283
+ pvnet/models/late_fusion/late_fusion.py,sha256=r05RJvw2-ZQgWJobOGq1g4rlMJQjGM0UzG3syA4T0qo,15617
  pvnet/models/late_fusion/encoders/__init__.py,sha256=bLBQdnCeLYhwISW0t88ZZBz-ebS94m7ZwBcsofWMHR4,51
  pvnet/models/late_fusion/encoders/basic_blocks.py,sha256=DGkFFIZv4S4FLTaAIOrAngAFBpgZQHfkGM4dzezZLk4,3044
  pvnet/models/late_fusion/encoders/encoders3d.py,sha256=9fmqVHO73F-jN62w065cgEQI_icNFC2nQH6ZEGvTHxU,7116
@@ -17,13 +17,13 @@ pvnet/models/late_fusion/linear_networks/basic_blocks.py,sha256=RnwdeuX_-itY4ncM
  pvnet/models/late_fusion/linear_networks/networks.py,sha256=exEIz_Z85f8nSwcvp4wqiiLECEAg9YbkKhSZJvFy75M,2231
  pvnet/models/late_fusion/site_encoders/__init__.py,sha256=QoUiiWWFf12vEpdkw0gO4TWpOEoI_tgAyUFCWFFpYAk,45
  pvnet/models/late_fusion/site_encoders/basic_blocks.py,sha256=iEB_N7ZL5HMQ1hZM6H32A71GCwP7YbErUx0oQF21PQM,1042
- pvnet/models/late_fusion/site_encoders/encoders.py,sha256=k4z690cfcP6J4pm2KtDujHN-W3uOl7QY0WvBIu1tM8c,11703
+ pvnet/models/late_fusion/site_encoders/encoders.py,sha256=PemEUa_Wv5pFWw3usPKEtXcvs_MX2LSrO6nhldO_QVk,11320
  pvnet/training/__init__.py,sha256=FKxmPZ59Vuj5_mXomN4saJ3En5M-aDMxSs6OttTQOcg,49
- pvnet/training/lightning_module.py,sha256=KcEbHYBe_Gx0as0-A7bggoMjev-A_i6Y3PHGRaYllTg,12956
- pvnet/training/plots.py,sha256=4xID7TBA4IazaARaCN5AoG5fFPJF1wIprn0y6I0C31c,2469
+ pvnet/training/lightning_module.py,sha256=hmvne9DQauWpG61sRK-t8MTZRVwdywaEFCs0VFVRuMs,13522
+ pvnet/training/plots.py,sha256=7JtjA9zIotuoKZ2l0fbS-FZDB48TcIk_-XLA2EWVMv4,2448
  pvnet/training/train.py,sha256=Sry2wYgggUmtIB-k_umFts7xMr2roEL76NCu9ySbLUY,4107
- pvnet-5.2.3.dist-info/licenses/LICENSE,sha256=tKUnlSmcLBWMJWkHx3UjZGdrjs9LidGwLo0jsBUBAwU,1077
- pvnet-5.2.3.dist-info/METADATA,sha256=xA54YM0qDAlvtNSEQXwLLYF9S6sgo-kejaB0awBb0MA,16479
- pvnet-5.2.3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
- pvnet-5.2.3.dist-info/top_level.txt,sha256=4mg6WjeW05SR7pg3-Q4JRE2yAoutHYpspOsiUzYVNv0,6
- pvnet-5.2.3.dist-info/RECORD,,
+ pvnet-5.3.5.dist-info/licenses/LICENSE,sha256=tKUnlSmcLBWMJWkHx3UjZGdrjs9LidGwLo0jsBUBAwU,1077
+ pvnet-5.3.5.dist-info/METADATA,sha256=rIlZGmFiIzkMpG_5U-6SrsdDW6fIke667JAG79g3KN4,16479
+ pvnet-5.3.5.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ pvnet-5.3.5.dist-info/top_level.txt,sha256=4mg6WjeW05SR7pg3-Q4JRE2yAoutHYpspOsiUzYVNv0,6
+ pvnet-5.3.5.dist-info/RECORD,,
File without changes