PVNet 5.2.3__py3-none-any.whl → 5.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pvnet/datamodule.py CHANGED
@@ -1,4 +1,4 @@
- """ Data module for pytorch lightning """
+ """Data module for pytorch lightning"""

  import os

@@ -6,10 +6,9 @@ import numpy as np
  from lightning.pytorch import LightningDataModule
  from ocf_data_sampler.numpy_sample.collate import stack_np_samples_into_batch
  from ocf_data_sampler.numpy_sample.common_types import NumpySample, TensorBatch
- from ocf_data_sampler.torch_datasets.datasets.pvnet_uk import PVNetUKRegionalDataset
- from ocf_data_sampler.torch_datasets.datasets.site import SitesDataset
+ from ocf_data_sampler.torch_datasets.pvnet_dataset import PVNetDataset
  from ocf_data_sampler.torch_datasets.utils.torch_batch_utils import batch_to_tensor
- from torch.utils.data import DataLoader, Dataset, Subset
+ from torch.utils.data import DataLoader, Subset


  def collate_fn(samples: list[NumpySample]) -> TensorBatch:
@@ -17,7 +16,7 @@ def collate_fn(samples: list[NumpySample]) -> TensorBatch:
  return batch_to_tensor(stack_np_samples_into_batch(samples))


- class BaseDataModule(LightningDataModule):
+ class PVNetDataModule(LightningDataModule):
  """Base Datamodule which streams samples using a sampler from ocf-data-sampler."""

  def __init__(
@@ -40,10 +39,10 @@ class BaseDataModule(LightningDataModule):
  batch_size: Batch size.
  num_workers: Number of workers to use in multiprocess batch loading.
  prefetch_factor: Number of batches loaded in advance by each worker.
- persistent_workers: If True, the data loader will not shut down the worker processes
- after a dataset has been consumed once. This allows to maintain the workers Dataset
+ persistent_workers: If True, the data loader will not shut down the worker processes
+ after a dataset has been consumed once. This allows to maintain the workers Dataset
  instances alive.
- pin_memory: If True, the data loader will copy Tensors into device/CUDA pinned memory
+ pin_memory: If True, the data loader will copy Tensors into device/CUDA pinned memory
  before returning them.
  train_period: Date range filter for train dataloader.
  val_period: Date range filter for val dataloader.
@@ -70,7 +69,7 @@ class BaseDataModule(LightningDataModule):
  worker_init_fn=None,
  prefetch_factor=prefetch_factor,
  persistent_workers=persistent_workers,
- multiprocessing_context="spawn" if num_workers>0 else None,
+ multiprocessing_context="spawn" if num_workers > 0 else None,
  )

  def setup(self, stage: str | None = None):
@@ -79,16 +78,15 @@ class BaseDataModule(LightningDataModule):
  # This logic runs only once at the start of training, therefore the val dataset is only
  # shuffled once
  if stage == "fit":
-
  # Prepare the train dataset
  self.train_dataset = self._get_dataset(*self.train_period)

- # Prepare and pre-shuffle the val dataset and set seed for reproducibility
+ # Prepare and pre-shuffle the val dataset and set seed for reproducibility
  val_dataset = self._get_dataset(*self.val_period)

  shuffled_indices = np.random.default_rng(seed=self.seed).permutation(len(val_dataset))
  self.val_dataset = Subset(val_dataset, shuffled_indices)
-
+
  if self.dataset_pickle_dir is not None:
  os.makedirs(self.dataset_pickle_dir, exist_ok=True)
  train_dataset_path = f"{self.dataset_pickle_dir}/train_dataset.pkl"
@@ -116,8 +114,8 @@ class BaseDataModule(LightningDataModule):
  if os.path.exists(filepath):
  os.remove(filepath)

- def _get_dataset(self, start_time: str | None, end_time: str | None) -> Dataset:
- raise NotImplementedError
+ def _get_dataset(self, start_time: str | None, end_time: str | None) -> PVNetDataset:
+ return PVNetDataset(self.configuration, start_time=start_time, end_time=end_time)

  def train_dataloader(self) -> DataLoader:
  """Construct train dataloader"""
@@ -126,17 +124,3 @@ class BaseDataModule(LightningDataModule):
  def val_dataloader(self) -> DataLoader:
  """Construct val dataloader"""
  return DataLoader(self.val_dataset, shuffle=False, **self._common_dataloader_kwargs)
-
-
- class UKRegionalDataModule(BaseDataModule):
- """Datamodule for streaming UK regional samples."""
-
- def _get_dataset(self, start_time: str | None, end_time: str | None) -> PVNetUKRegionalDataset:
- return PVNetUKRegionalDataset(self.configuration, start_time=start_time, end_time=end_time)
-
-
- class SitesDataModule(BaseDataModule):
- """Datamodule for streaming site samples."""
-
- def _get_dataset(self, start_time: str | None, end_time: str | None) -> SitesDataset:
- return SitesDataset(self.configuration, start_time=start_time, end_time=end_time)
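
With UKRegionalDataModule and SitesDataModule removed, the single PVNetDataModule above now builds a PVNetDataset directly. A minimal usage sketch; the keyword names are assumed from the docstring in this diff and may differ from the released signature:

# Hypothetical usage sketch (constructor keywords assumed from the docstring above).
from pvnet.datamodule import PVNetDataModule

datamodule = PVNetDataModule(
    configuration="data_config.yaml",   # assumed keyword for the sampler configuration
    batch_size=8,
    num_workers=4,
    prefetch_factor=2,
    train_period=[None, "2022-12-31"],
    val_period=["2023-01-01", None],
)
datamodule.setup("fit")                 # builds the train dataset and a pre-shuffled val Subset
train_loader = datamodule.train_dataloader()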
pvnet/models/base_model.py CHANGED
@@ -1,4 +1,5 @@
  """Base model for all PVNet submodels"""
+
  import logging
  import os
  import shutil
@@ -32,7 +33,7 @@ def fill_config_paths_with_placeholder(config: dict, placeholder: str = "PLACEHO
  """
  input_config = config["input_data"]

- for source in ["gsp", "satellite"]:
+ for source in ["generation", "satellite"]:
  if source in input_config:
  # If not empty - i.e. if used
  if input_config[source]["zarr_path"] != "":
@@ -78,8 +79,8 @@ def minimize_config_for_model(config: dict, model: "BaseModel") -> dict:

  # Replace the interval_end_minutes minutes
  nwp_config["interval_end_minutes"] = (
- nwp_config["interval_start_minutes"] +
- (model.nwp_encoders_dict[nwp_source].sequence_length - 1)
+ nwp_config["interval_start_minutes"]
+ + (model.nwp_encoders_dict[nwp_source].sequence_length - 1)
  * nwp_config["time_resolution_minutes"]
  )

@@ -96,20 +97,19 @@ def minimize_config_for_model(config: dict, model: "BaseModel") -> dict:

  # Replace the interval_end_minutes minutes
  sat_config["interval_end_minutes"] = (
- sat_config["interval_start_minutes"] +
- (model.sat_encoder.sequence_length - 1)
- * sat_config["time_resolution_minutes"]
+ sat_config["interval_start_minutes"]
+ + (model.sat_encoder.sequence_length - 1) * sat_config["time_resolution_minutes"]
  )

  if "pv" in input_config:
  if not model.include_pv:
  del input_config["pv"]

- if "gsp" in input_config:
- gsp_config = input_config["gsp"]
+ if "generation" in input_config:
+ generation_config = input_config["generation"]

  # Replace the forecast minutes
- gsp_config["interval_end_minutes"] = model.forecast_minutes
+ generation_config["interval_end_minutes"] = model.forecast_minutes

  if "solar_position" in input_config:
  solar_config = input_config["solar_position"]
@@ -138,9 +138,9 @@ def download_from_hf(
  force_download: Whether to force a new download
  max_retries: Maximum number of retry attempts
  wait_time: Wait time (in seconds) before retrying
- token:
+ token:
  HF authentication token. If True, the token is read from the HuggingFace config folder.
- If a string, it is used as the authentication token.
+ If a string, it is used as the authentication token.

  Returns:
  The local file path of the downloaded file(s)
@@ -160,7 +160,7 @@ def download_from_hf(
  return [f"{save_dir}/{f}" for f in filename]
  else:
  return f"{save_dir}/{filename}"
-
+
  except Exception as e:
  if attempt == max_retries:
  raise Exception(
@@ -205,7 +205,7 @@ class HuggingfaceMixin:
  force_download=force_download,
  max_retries=5,
  wait_time=10,
- token=token
+ token=token,
  )

  with open(config_file, "r") as f:
@@ -240,7 +240,7 @@ class HuggingfaceMixin:
  force_download=force_download,
  max_retries=5,
  wait_time=10,
- token=token
+ token=token,
  )

  return data_config_file
@@ -301,7 +301,7 @@ class HuggingfaceMixin:
  # Save cleaned version of input data configuration file
  with open(data_config_path) as cfg:
  config = yaml.load(cfg, Loader=yaml.FullLoader)
-
+
  config = fill_config_paths_with_placeholder(config)
  config = minimize_config_for_model(config, self)

@@ -311,7 +311,7 @@ class HuggingfaceMixin:
  # Save the datamodule config
  if datamodule_config_path is not None:
  shutil.copyfile(datamodule_config_path, save_directory / DATAMODULE_CONFIG_NAME)
-
+
  # Save the full experimental config
  if experiment_config_path is not None:
  shutil.copyfile(experiment_config_path, save_directory / FULL_CONFIG_NAME)
@@ -378,7 +378,6 @@ class HuggingfaceMixin:
  packages_to_display = ["pvnet", "ocf-data-sampler"]
  packages_and_versions = {package: version(package) for package in packages_to_display}

-
  package_versions_markdown = ""
  for package, v in packages_and_versions.items():
  package_versions_markdown += f" - {package}=={v}\n"
@@ -399,23 +398,19 @@ class BaseModel(torch.nn.Module, HuggingfaceMixin):
  history_minutes: int,
  forecast_minutes: int,
  output_quantiles: list[float] | None = None,
- target_key: str = "gsp",
  interval_minutes: int = 30,
  ):
  """Abtstract base class for PVNet submodels.

  Args:
- history_minutes (int): Length of the GSP history period in minutes
- forecast_minutes (int): Length of the GSP forecast period in minutes
+ history_minutes (int): Length of the generation history period in minutes
+ forecast_minutes (int): Length of the generation forecast period in minutes
  output_quantiles: A list of float (0.0, 1.0) quantiles to predict values for. If set to
  None the output is a single value.
- target_key: The key of the target variable in the batch
  interval_minutes: The interval in minutes between each timestep in the data
  """
  super().__init__()

- self._target_key = target_key
-
  self.history_minutes = history_minutes
  self.forecast_minutes = forecast_minutes
  self.output_quantiles = output_quantiles
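
The reflowed interval_end_minutes expressions compute the same value as before: the end of the interval is the start plus (sequence length - 1) steps of the source's time resolution. A small worked example with illustrative numbers:

# Worked example of the interval_end_minutes arithmetic in minimize_config_for_model
# (the numbers are illustrative, not taken from any real config).
interval_start_minutes = -60
sequence_length = 7                  # e.g. an encoder consuming 7 timesteps
time_resolution_minutes = 60
interval_end_minutes = (
    interval_start_minutes
    + (sequence_length - 1) * time_resolution_minutes
)
print(interval_end_minutes)          # -60 + 6 * 60 = 300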
pvnet/models/ensemble.py CHANGED
@@ -26,7 +26,6 @@ class Ensemble(BaseModel):
  output_quantiles = []
  history_minutes = []
  forecast_minutes = []
- target_key = []
  interval_minutes = []

  # Get some model properties from each model
@@ -34,7 +33,6 @@ class Ensemble(BaseModel):
  output_quantiles.append(model.output_quantiles)
  history_minutes.append(model.history_minutes)
  forecast_minutes.append(model.forecast_minutes)
- target_key.append(model._target_key)
  interval_minutes.append(model.interval_minutes)

  # Check these properties are all the same
@@ -42,7 +40,6 @@ class Ensemble(BaseModel):
  output_quantiles,
  history_minutes,
  forecast_minutes,
- target_key,
  interval_minutes,
  ]:
  assert all([p == param_list[0] for p in param_list]), param_list
@@ -51,7 +48,6 @@ class Ensemble(BaseModel):
  history_minutes=history_minutes[0],
  forecast_minutes=forecast_minutes[0],
  output_quantiles=output_quantiles[0],
- target_key=target_key[0],
  interval_minutes=interval_minutes[0],
  )

pvnet/models/late_fusion/late_fusion.py CHANGED
@@ -28,8 +28,8 @@ class LateFusionModel(BaseModel):
  - NWP, if included, is put through a similar encoder.
  - PV site-level data, if included, is put through an encoder which transforms it from 2D, with
  time and system-ID dimensions, to become a 1D feature vector.
- - The satellite features*, NWP features*, PV site-level features*, GSP ID embedding*, and sun
- paramters* are concatenated into a 1D feature vector and passed through another neural
+ - The satellite features*, NWP features*, PV site-level features*, location ID embedding*, and
+ sun paramters* are concatenated into a 1D feature vector and passed through another neural
  network to combine them and produce a forecast.

  * if included
@@ -43,8 +43,7 @@ class LateFusionModel(BaseModel):
  sat_encoder: AbstractNWPSatelliteEncoder | None = None,
  pv_encoder: AbstractSitesEncoder | None = None,
  add_image_embedding_channel: bool = False,
- include_gsp_yield_history: bool = True,
- include_site_yield_history: bool = False,
+ include_generation_history: bool = False,
  include_sun: bool = True,
  include_time: bool = False,
  location_id_mapping: dict[Any, int] | None = None,
@@ -56,7 +55,6 @@ class LateFusionModel(BaseModel):
  nwp_forecast_minutes: DictConfig | None = None,
  nwp_history_minutes: DictConfig | None = None,
  pv_history_minutes: int | None = None,
- target_key: str = "gsp",
  interval_minutes: int = 30,
  nwp_interval_minutes: DictConfig | None = None,
  pv_interval_minutes: int = 5,
@@ -83,14 +81,13 @@ class LateFusionModel(BaseModel):
  pv_encoder: A partially instantiated pytorch Module class used to encode the site-level
  PV data from 2D into a 1D feature vector.
  add_image_embedding_channel: Add a channel to the NWP and satellite data with the
- embedding of the GSP ID.
- include_gsp_yield_history: Include GSP yield data.
- include_site_yield_history: Include Site yield data.
+ embedding of the location ID.
+ include_generation_history: Include generation yield data.
  include_sun: Include sun azimuth and altitude data.
  include_time: Include sine and cosine of dates and times.
  location_id_mapping: A dictionary mapping the location ID to an integer. ID embedding is
  not used if this is not provided.
- embedding_dim: Number of embedding dimensions to use for GSP ID.
+ embedding_dim: Number of embedding dimensions to use for location ID.
  forecast_minutes: The amount of minutes that should be forecasted.
  history_minutes: The default amount of historical minutes that are used.
  sat_history_minutes: Length of recent observations used for satellite inputs. Defaults
@@ -103,7 +100,6 @@ class LateFusionModel(BaseModel):
  `history_minutes` if not provided.
  pv_history_minutes: Length of recent site-level PV data used as
  input. Defaults to `history_minutes` if not provided.
- target_key: The key of the target variable in the batch.
  interval_minutes: The interval between each sample of the target data
  nwp_interval_minutes: Dictionary of the intervals between each sample of the NWP
  data for each source
@@ -114,12 +110,10 @@ class LateFusionModel(BaseModel):
  history_minutes=history_minutes,
  forecast_minutes=forecast_minutes,
  output_quantiles=output_quantiles,
- target_key=target_key,
  interval_minutes=interval_minutes,
  )

- self.include_gsp_yield_history = include_gsp_yield_history
- self.include_site_yield_history = include_site_yield_history
+ self.include_generation_history = include_generation_history
  self.include_sat = sat_encoder is not None
  self.include_nwp = nwp_encoders_dict is not None and len(nwp_encoders_dict) != 0
  self.include_pv = pv_encoder is not None
@@ -133,8 +127,7 @@ class LateFusionModel(BaseModel):

  if self.location_id_mapping is None:
  logger.warning(
- "location_id_mapping` is not provided, defaulting to outdated GSP mapping"
- "(0 to 317)"
+ "location_id_mapping` is not provided, defaulting to outdated GSP mapping(0 to 317)"
  )

  # Note 318 is the 2024 UK GSP count, so this is a temporary fix
@@ -223,8 +216,7 @@ class LateFusionModel(BaseModel):

  self.pv_encoder = pv_encoder(
  sequence_length=pv_history_minutes // pv_interval_minutes + 1,
- target_key_to_use=self._target_key,
- input_key_to_use="site",
+ key_to_use="generation",
  )

  # Update num features
@@ -238,8 +230,7 @@ class LateFusionModel(BaseModel):

  if self.include_sun:
  self.sun_fc1 = nn.Linear(
- in_features=2
- * (self.forecast_len + self.history_len + 1),
+ in_features=2 * (self.forecast_len + self.history_len + 1),
  out_features=16,
  )

@@ -248,19 +239,14 @@ class LateFusionModel(BaseModel):

  if self.include_time:
  self.time_fc1 = nn.Linear(
- in_features=4
- * (self.forecast_len + self.history_len + 1),
+ in_features=4 * (self.forecast_len + self.history_len + 1),
  out_features=32,
  )

  # Update num features
  fusion_input_features += 32

- if include_gsp_yield_history:
- # Update num features
- fusion_input_features += self.history_len
-
- if include_site_yield_history:
+ if include_generation_history:
  # Update num features
  fusion_input_features += self.history_len + 1

@@ -269,15 +255,14 @@ class LateFusionModel(BaseModel):
  out_features=self.num_output_features,
  )

-
  def forward(self, x: TensorBatch) -> torch.Tensor:
  """Run model forward"""

  if self.use_id_embedding:
- # eg: x['gsp_id'] = [1] with location_id_mapping = {1:0}, would give [0]
+ # eg: x['location_id'] = [1] with location_id_mapping = {1:0}, would give [0]
  id = torch.tensor(
- [self.location_id_mapping[i.item()] for i in x[f"{self._target_key}_id"]],
- device=x[f"{self._target_key}_id"].device,
+ [self.location_id_mapping[i.item()] for i in x["location_id"]],
+ device=x["location_id"].device,
  dtype=torch.int64,
  )

@@ -308,32 +293,20 @@ class LateFusionModel(BaseModel):
  nwp_out = self.nwp_encoders_dict[nwp_source](nwp_data)
  modes[f"nwp/{nwp_source}"] = nwp_out

- # *********************** Site Data *************************************
- # Add site-level yield history
- if self.include_site_yield_history:
- site_history = x["site"][:, : self.history_len + 1].float()
- site_history = site_history.reshape(site_history.shape[0], -1)
- modes["site"] = site_history
+ # *********************** Generation Data *************************************
+ # Add generation yield history
+ if self.include_generation_history:
+ generation_history = x["generation"][:, : self.history_len + 1].float()
+ generation_history = generation_history.reshape(generation_history.shape[0], -1)
+ modes["generation"] = generation_history

- # Add site-level yield history through PV encoder
+ # Add location-level yield history through PV encoder
  if self.include_pv:
- if self._target_key != "site":
- modes["site"] = self.pv_encoder(x)
- else:
- # Target is PV, so only take the history
- # Copy batch
- x_tmp = x.copy()
- x_tmp["site"] = x_tmp["site"][:, : self.history_len + 1]
- modes["site"] = self.pv_encoder(x_tmp)
-
- # *********************** GSP Data ************************************
- # Add gsp yield history
- if self.include_gsp_yield_history:
- gsp_history = x["gsp"][:, : self.history_len].float()
- gsp_history = gsp_history.reshape(gsp_history.shape[0], -1)
- modes["gsp"] = gsp_history
-
- # ********************** Embedding of GSP/Site ID ********************
+ x_tmp = x.copy()
+ x_tmp["generation"] = x_tmp["generation"][:, : self.history_len + 1]
+ modes["generation"] = self.pv_encoder(x_tmp)
+
+ # ********************** Embedding of location ID ********************
  if self.use_id_embedding:
  modes["id"] = self.embed(id)

@@ -341,7 +314,7 @@ class LateFusionModel(BaseModel):
  sun = torch.cat((x["solar_azimuth"], x["solar_elevation"]), dim=1).float()
  sun = self.sun_fc1(sun)
  modes["sun"] = sun
-
+
  if self.include_time:
  time = [x[k] for k in ["date_sin", "date_cos", "time_sin", "time_cos"]]
  time = torch.cat(time, dim=1).float()
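
Taken together, the renames above change which sample-dictionary keys the model reads. An illustrative, non-exhaustive sketch of the keys referenced in this diff; the full sample schema is defined in ocf-data-sampler and is not shown here:

# Keys referenced by the new forward() path above; values are placeholders only.
batch = {
    "generation": ...,       # generation yield history/target (replaces "gsp"/"site")
    "location_id": ...,      # integer location IDs (replaces "gsp_id"/"site_id")
    "solar_azimuth": ...,    # solar position inputs, unchanged
    "solar_elevation": ...,
    "nwp": {...},            # per-source NWP tensors, unchanged
}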
pvnet/models/late_fusion/site_encoders/encoders.py CHANGED
@@ -1,6 +1,4 @@
- """Encoder modules for the site-level PV data.
-
- """
+ """Encoder modules for the site-level PV data."""

  import einops
  import torch
@@ -11,6 +9,7 @@ from pvnet.models.late_fusion.linear_networks.networks import ResFCNet
  from pvnet.models.late_fusion.site_encoders.basic_blocks import AbstractSitesEncoder


+ # TODO update this to work with the new sample data format
  class SimpleLearnedAggregator(AbstractSitesEncoder):
  """A simple model which learns a different weighted-average across all PV sites for each GSP.

@@ -127,8 +126,7 @@ class SingleAttentionNetwork(AbstractSitesEncoder):
  kv_res_block_layers: int = 2,
  use_id_in_value: bool = False,
  target_id_dim: int = 318,
- target_key_to_use: str = "gsp",
- input_key_to_use: str = "site",
+ key_to_use: str = "generation",
  num_channels: int = 1,
  num_sites_in_inference: int = 1,
  ):
@@ -149,8 +147,7 @@ class SingleAttentionNetwork(AbstractSitesEncoder):
  use_id_in_value: Whether to use a site ID embedding in network used to produce the
  value for the attention layer.
  target_id_dim: The number of unique IDs.
- target_key_to_use: The key to use for the target in the attention layer.
- input_key_to_use: The key to use for the input in the attention layer.
+ key_to_use: The key to use in the attention layer.
  num_channels: Number of channels in the input data
  num_sites_in_inference: Number of sites to use in inference.
  This is used to determine the number of sites to use in the
@@ -161,11 +158,10 @@ class SingleAttentionNetwork(AbstractSitesEncoder):
  super().__init__(sequence_length, num_sites, out_features)
  self.sequence_length = sequence_length
  self.target_id_embedding = nn.Embedding(target_id_dim, out_features)
- self.site_id_embedding = nn.Embedding(num_sites, id_embed_dim)
+ self.id_embedding = nn.Embedding(num_sites, id_embed_dim)
  self._ids = nn.parameter.Parameter(torch.arange(num_sites), requires_grad=False)
  self.use_id_in_value = use_id_in_value
- self.target_key_to_use = target_key_to_use
- self.input_key_to_use = input_key_to_use
+ self.key_to_use = key_to_use
  self.num_channels = num_channels
  self.num_sites_in_inference = num_sites_in_inference

@@ -206,7 +202,7 @@ class SingleAttentionNetwork(AbstractSitesEncoder):
  def _encode_inputs(self, x: TensorBatch) -> tuple[torch.Tensor, int]:
  # Shape: [batch size, sequence length, number of sites]
  # Shape: [batch size, station_id, sequence length, channels]
- input_data = x[f"{self.input_key_to_use}"]
+ input_data = x[f"{self.key_to_use}"]
  if len(input_data.shape) == 2: # one site per sample
  input_data = input_data.unsqueeze(-1) # add dimension of 1 to end to make 3D
  if len(input_data.shape) == 4: # Has multiple channels
@@ -216,16 +212,11 @@ class SingleAttentionNetwork(AbstractSitesEncoder):
  input_data = input_data[:, : self.sequence_length]
  site_seqs = input_data.float()
  batch_size = site_seqs.shape[0]
- site_seqs = site_seqs.swapaxes(1, 2) # [batch size, Site ID, sequence length]
+ site_seqs = site_seqs.swapaxes(1, 2) # [batch size, location ID, sequence length]
  return site_seqs, batch_size

  def _encode_query(self, x: TensorBatch) -> torch.Tensor:
- if self.target_key_to_use == "gsp":
- # GSP seems to have a different structure
- ids = x[f"{self.target_key_to_use}_id"]
- else:
- ids = x[f"{self.input_key_to_use}_id"]
- ids = ids.int()
+ ids = x["location_id"].int()
  query = self.target_id_embedding(ids).unsqueeze(1)
  return query

@@ -233,9 +224,9 @@ class SingleAttentionNetwork(AbstractSitesEncoder):
  site_seqs, batch_size = self._encode_inputs(x)

  # site ID embeddings are the same for each sample
- site_id_embed = torch.tile(self.site_id_embedding(self._ids), (batch_size, 1, 1))
+ id_embed = torch.tile(self.id_embedding(self._ids), (batch_size, 1, 1))
  # Each concated (site sequence, site ID embedding) is processed with encoder
- x_seq_in = torch.cat((site_seqs, site_id_embed), dim=2).flatten(0, 1)
+ x_seq_in = torch.cat((site_seqs, id_embed), dim=2).flatten(0, 1)
  key = self._key_encoder(x_seq_in)

  # Reshape to [batch size, site, kdim]
@@ -247,9 +238,9 @@ class SingleAttentionNetwork(AbstractSitesEncoder):

  if self.use_id_in_value:
  # site ID embeddings are the same for each sample
- site_id_embed = torch.tile(self.value_id_embedding(self._ids), (batch_size, 1, 1))
+ id_embed = torch.tile(self.value_id_embedding(self._ids), (batch_size, 1, 1))
  # Each concated (site sequence, site ID embedding) is processed with encoder
- x_seq_in = torch.cat((site_seqs, site_id_embed), dim=2).flatten(0, 1)
+ x_seq_in = torch.cat((site_seqs, id_embed), dim=2).flatten(0, 1)
  else:
  # Encode each site sequence independently
  x_seq_in = site_seqs.flatten(0, 1)
@@ -260,9 +251,8 @@ class SingleAttentionNetwork(AbstractSitesEncoder):
  return value

  def _attention_forward(
- self, x: dict,
- average_attn_weights: bool = True
- ) -> tuple[torch.Tensor, torch.Tensor:]:
+ self, x: dict, average_attn_weights: bool = True
+ ) -> tuple[torch.Tensor, torch.Tensor :]:
  query = self._encode_query(x)
  key = self._encode_key(x)
  value = self._encode_value(x)
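
The two key arguments of SingleAttentionNetwork collapse into the single key_to_use. A hedged construction sketch; only key_to_use and the keywords visible in this diff are certain, and the remaining parameter names are assumed from the super().__init__ call:

# Hypothetical construction sketch; sequence_length/num_sites/out_features are assumed
# parameter names, and the values are illustrative only.
from pvnet.models.late_fusion.site_encoders.encoders import SingleAttentionNetwork

encoder = SingleAttentionNetwork(
    sequence_length=13,
    num_sites=1,
    out_features=128,
    key_to_use="generation",   # replaces target_key_to_use / input_key_to_use
)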
pvnet/training/lightning_module.py CHANGED
@@ -45,9 +45,9 @@ class PVNetLightningModule(pl.LightningModule):
  self.lr = None

  def transfer_batch_to_device(
- self,
- batch: TensorBatch,
- device: torch.device,
+ self,
+ batch: TensorBatch,
+ device: torch.device,
  dataloader_idx: int,
  ) -> dict:
  """Method to move custom batches to a given device"""
@@ -75,7 +75,7 @@ class PVNetLightningModule(pl.LightningModule):
  losses = 2 * torch.cat(losses, dim=2)

  return losses.mean()
-
+
  def configure_optimizers(self):
  """Configure the optimizers using learning rate found with LR finder if used"""
  if self.lr is not None:
@@ -84,7 +84,7 @@ class PVNetLightningModule(pl.LightningModule):
  return self._optimizer(self.model)

  def _calculate_common_losses(
- self,
+ self,
  y: torch.Tensor,
  y_hat: torch.Tensor,
  ) -> dict[str, torch.Tensor]:
@@ -96,15 +96,15 @@ class PVNetLightningModule(pl.LightningModule):
  losses["quantile_loss"] = self._calculate_quantile_loss(y_hat, y)
  y_hat = self.model._quantiles_to_prediction(y_hat)

- losses.update({"MSE": F.mse_loss(y_hat, y), "MAE": F.l1_loss(y_hat, y)})
+ losses.update({"MSE": F.mse_loss(y_hat, y), "MAE": F.l1_loss(y_hat, y)})

  return losses
-
+
  def training_step(self, batch: TensorBatch, batch_idx: int) -> torch.Tensor:
  """Run training step"""
  y_hat = self.model(batch)

- y = batch[self.model._target_key][:, -self.model.forecast_len :]
+ y = batch["generation"][:, -self.model.forecast_len :]

  losses = self._calculate_common_losses(y, y_hat)
  losses = {f"{k}/train": v for k, v in losses.items()}
@@ -116,10 +116,10 @@ class PVNetLightningModule(pl.LightningModule):
  else:
  opt_target = losses["MAE/train"]
  return opt_target
-
+
  def _calculate_val_losses(
- self,
- y: torch.Tensor,
+ self,
+ y: torch.Tensor,
  y_hat: torch.Tensor,
  ) -> dict[str, torch.Tensor]:
  """Calculate additional losses only run in validation"""
@@ -138,28 +138,25 @@ class PVNetLightningModule(pl.LightningModule):
  return losses

  def _calculate_step_metrics(
- self,
- y: torch.Tensor,
- y_hat: torch.Tensor,
+ self,
+ y: torch.Tensor,
+ y_hat: torch.Tensor,
  ) -> tuple[np.array, np.array]:
  """Calculate the MAE and MSE at each forecast step"""

  mae_each_step = torch.mean(torch.abs(y_hat - y), dim=0).cpu().numpy()
  mse_each_step = torch.mean((y_hat - y) ** 2, dim=0).cpu().numpy()
-
+
  return mae_each_step, mse_each_step
-
+
  def _store_val_predictions(self, batch: TensorBatch, y_hat: torch.Tensor) -> None:
  """Internally store the validation predictions"""
-
- target_key = self.model._target_key

- y = batch[target_key][:, -self.model.forecast_len :].cpu().numpy()
- y_hat = y_hat.cpu().numpy()
- ids = batch[f"{target_key}_id"].cpu().numpy()
+ y = batch["generation"][:, -self.model.forecast_len :].cpu().numpy()
+ y_hat = y_hat.cpu().numpy()
+ ids = batch["location_id"].cpu().numpy()
  init_times_utc = pd.to_datetime(
- batch[f"{target_key}_time_utc"][:, self.model.history_len+1]
- .cpu().numpy().astype("datetime64[ns]")
+ batch["time_utc"][:, self.model.history_len + 1].cpu().numpy().astype("datetime64[ns]")
  )

  if self.model.use_quantile_regression:
@@ -170,7 +167,7 @@ class PVNetLightningModule(pl.LightningModule):

  ds_preds_batch = xr.Dataset(
  data_vars=dict(
- y_hat=(["sample_num", "forecast_step", "p_level"], y_hat),
+ y_hat=(["sample_num", "forecast_step", "p_level"], y_hat),
  y=(["sample_num", "forecast_step"], y),
  ),
  coords=dict(
@@ -186,7 +183,7 @@ class PVNetLightningModule(pl.LightningModule):
  # Set up stores which we will fill during validation
  self.all_val_results: list[xr.Dataset] = []
  self._val_horizon_maes: list[np.array] = []
- if self.current_epoch==0:
+ if self.current_epoch == 0:
  self._val_persistence_horizon_maes: list[np.array] = []

  # Plot some sample forecasts
@@ -197,9 +194,9 @@ class PVNetLightningModule(pl.LightningModule):

  for plot_num in range(num_figures):
  idxs = np.arange(plots_per_figure) + plot_num * plots_per_figure
- idxs = idxs[idxs<len(val_dataset)]
+ idxs = idxs[idxs < len(val_dataset)]

- if len(idxs)==0:
+ if len(idxs) == 0:
  continue

  batch = collate_fn([val_dataset[i] for i in idxs])
@@ -207,19 +204,16 @@ class PVNetLightningModule(pl.LightningModule):

  # Batch validation check only during sanity check phase - use first batch
  if self.trainer.sanity_checking and plot_num == 0:
- validate_batch_against_config(
- batch=batch,
- model=self.model
- )
-
+ validate_batch_against_config(batch=batch, model=self.model)
+
  with torch.no_grad():
  y_hat = self.model(batch)
-
+
  fig = plot_sample_forecasts(
  batch,
  y_hat,
  quantiles=self.model.output_quantiles,
- key_to_plot=self.model._target_key,
+ key_to_plot="generation",
  )

  plot_name = f"val_forecast_samples/sample_set_{plot_num}"
@@ -238,7 +232,7 @@ class PVNetLightningModule(pl.LightningModule):
  # Internally store the val predictions
  self._store_val_predictions(batch, y_hat)

- y = batch[self.model._target_key][:, -self.model.forecast_len :]
+ y = batch["generation"][:, -self.model.forecast_len :]

  losses = self._calculate_common_losses(y, y_hat)
  losses = {f"{k}/val": v for k, v in losses.items()}
@@ -262,21 +256,22 @@ class PVNetLightningModule(pl.LightningModule):

  # Calculate the persistance losses - we only need to do this once per training run
  # not every epoch
- if self.current_epoch==0:
+ if self.current_epoch == 0:
  y_persist = (
- batch[self.model._target_key][:, -(self.model.forecast_len+1)]
- .unsqueeze(1).expand(-1, self.model.forecast_len)
+ batch["generation"][:, -(self.model.forecast_len + 1)]
+ .unsqueeze(1)
+ .expand(-1, self.model.forecast_len)
  )
  mae_step_persist, mse_step_persist = self._calculate_step_metrics(y, y_persist)
  self._val_persistence_horizon_maes.append(mae_step_persist)
  losses.update(
  {
- "MAE/val_persistence": mae_step_persist.mean(),
- "MSE/val_persistence": mse_step_persist.mean()
+ "MAE/val_persistence": mae_step_persist.mean(),
+ "MSE/val_persistence": mse_step_persist.mean(),
  }
  )

- # Log the metrics
+ # Log the metrics
  self.log_dict(losses, on_step=False, on_epoch=True)

  def on_validation_epoch_end(self) -> None:
@@ -289,7 +284,7 @@ class PVNetLightningModule(pl.LightningModule):
  self._val_horizon_maes = []

  # We only run this on the first epoch
- if self.current_epoch==0:
+ if self.current_epoch == 0:
  val_persistence_horizon_maes = np.mean(self._val_persistence_horizon_maes, axis=0)
  self._val_persistence_horizon_maes = []

@@ -321,25 +316,25 @@ class PVNetLightningModule(pl.LightningModule):
  wandb_log_dir = self.logger.experiment.dir
  filepath = f"{wandb_log_dir}/validation_results.netcdf"
  ds_val_results.to_netcdf(filepath)
-
- # Uplodad to wandb
+
+ # Uplodad to wandb
  self.logger.experiment.save(filepath, base_path=wandb_log_dir, policy="now")
-
+
  # Create the horizon accuracy curve
  horizon_mae_plot = wandb_line_plot(
- x=np.arange(self.model.forecast_len),
+ x=np.arange(self.model.forecast_len),
  y=val_horizon_maes,
  xlabel="Horizon step",
  ylabel="MAE",
  title="Val horizon loss curve",
  )
-
+
  wandb.log({"val_horizon_mae_plot": horizon_mae_plot})

  # Create persistence horizon accuracy curve but only on first epoch
- if self.current_epoch==0:
+ if self.current_epoch == 0:
  persist_horizon_mae_plot = wandb_line_plot(
- x=np.arange(self.model.forecast_len),
+ x=np.arange(self.model.forecast_len),
  y=val_persistence_horizon_maes,
  xlabel="Horizon step",
  ylabel="MAE",
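
The persistence baseline above simply repeats the last observed generation value across the forecast horizon. A self-contained sketch of the same operation with illustrative shapes:

import torch

# [batch, history + forecast steps]; random placeholder values
generation = torch.rand(4, 13)
forecast_len = 8
y_persist = (
    generation[:, -(forecast_len + 1)]   # last value before the forecast window
    .unsqueeze(1)
    .expand(-1, forecast_len)            # repeat it for every forecast step
)
print(y_persist.shape)                   # torch.Size([4, 8])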
pvnet/training/plots.py CHANGED
@@ -32,9 +32,9 @@ def plot_sample_forecasts(

  y = batch[key_to_plot].cpu().numpy()
  y_hat = y_hat.cpu().numpy()
- ids = batch[f"{key_to_plot}_id"].cpu().numpy().squeeze()
+ ids = batch["location_id"].cpu().numpy().squeeze()
  times_utc = pd.to_datetime(
- batch[f"{key_to_plot}_time_utc"].cpu().numpy().squeeze().astype("datetime64[ns]")
+ batch["time_utc"].cpu().numpy().squeeze().astype("datetime64[ns]")
  )
  batch_size = y.shape[0]

pvnet/utils.py CHANGED
@@ -1,4 +1,5 @@
  """Utils"""
+
  import logging
  from typing import TYPE_CHECKING

@@ -17,7 +18,7 @@ PYTORCH_WEIGHTS_NAME = "model_weights.safetensors"
  MODEL_CONFIG_NAME = "model_config.yaml"
  DATA_CONFIG_NAME = "data_config.yaml"
  DATAMODULE_CONFIG_NAME = "datamodule_config.yaml"
- FULL_CONFIG_NAME = "full_experiment_config.yaml"
+ FULL_CONFIG_NAME = "full_experiment_config.yaml"
  MODEL_CARD_NAME = "README.md"


@@ -93,37 +94,41 @@ def print_config(


  def validate_batch_against_config(
- batch: dict,
+ batch: dict,
  model: "BaseModel",
  ) -> None:
  """Validates tensor shapes in batch against model configuration."""
  logger.info("Performing batch shape validation against model config.")
-
+
  # NWP validation
- if hasattr(model, 'nwp_encoders_dict'):
+ if hasattr(model, "nwp_encoders_dict"):
  if "nwp" not in batch:
  raise ValueError(
  "Model configured with 'nwp_encoders_dict' but 'nwp' data missing from batch."
  )
-
+
  for source, nwp_data in batch["nwp"].items():
  if source in model.nwp_encoders_dict:
-
- enc = model.nwp_encoders_dict[source]
+ enc = model.nwp_encoders_dict[source]
  expected_channels = enc.in_channels
  if model.add_image_embedding_channel:
  expected_channels -= 1

- expected = (nwp_data["nwp"].shape[0], enc.sequence_length,
- expected_channels, enc.image_size_pixels, enc.image_size_pixels)
+ expected = (
+ nwp_data["nwp"].shape[0],
+ enc.sequence_length,
+ expected_channels,
+ enc.image_size_pixels,
+ enc.image_size_pixels,
+ )
  if tuple(nwp_data["nwp"].shape) != expected:
- actual_shape = tuple(nwp_data['nwp'].shape)
+ actual_shape = tuple(nwp_data["nwp"].shape)
  raise ValueError(
  f"NWP.{source} shape mismatch: expected {expected}, got {actual_shape}"
  )

  # Satellite validation
- if hasattr(model, 'sat_encoder'):
+ if hasattr(model, "sat_encoder"):
  if "satellite_actual" not in batch:
  raise ValueError(
  "Model configured with 'sat_encoder' but 'satellite_actual' missing from batch."
@@ -134,14 +139,19 @@ def validate_batch_against_config(
  if model.add_image_embedding_channel:
  expected_channels -= 1

- expected = (batch["satellite_actual"].shape[0], enc.sequence_length, expected_channels,
- enc.image_size_pixels, enc.image_size_pixels)
+ expected = (
+ batch["satellite_actual"].shape[0],
+ enc.sequence_length,
+ expected_channels,
+ enc.image_size_pixels,
+ enc.image_size_pixels,
+ )
  if tuple(batch["satellite_actual"].shape) != expected:
- actual_shape = tuple(batch['satellite_actual'].shape)
+ actual_shape = tuple(batch["satellite_actual"].shape)
  raise ValueError(f"Satellite shape mismatch: expected {expected}, got {actual_shape}")

- # GSP/Site validation
- key = model._target_key
+ # generation validation
+ key = "generation"
  if key in batch:
  total_minutes = model.history_minutes + model.forecast_minutes
  interval = model.interval_minutes
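
For reference, the expected image-tensor shape assembled in validate_batch_against_config is (batch, sequence length, channels, height, width). A tiny worked example with illustrative numbers:

# Illustrative only; real values come from the model's encoder configuration.
batch_size, sequence_length, channels, image_size_pixels = 8, 7, 2, 24
expected = (batch_size, sequence_length, channels, image_size_pixels, image_size_pixels)
print(expected)   # (8, 7, 2, 24, 24)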
pvnet-5.2.3.dist-info/METADATA → pvnet-5.3.0.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: PVNet
- Version: 5.2.3
+ Version: 5.3.0
  Summary: PVNet
  Author-email: Peter Dudfield <info@openclimatefix.org>
  Requires-Python: <3.14,>=3.11
pvnet-5.2.3.dist-info/RECORD → pvnet-5.3.0.dist-info/RECORD CHANGED
@@ -1,14 +1,14 @@
  pvnet/__init__.py,sha256=TAZm88TJ5ieL1XjEyRg1LciIGuSScEucdAruQLfM92I,25
- pvnet/datamodule.py,sha256=sTACPJXPaqojpxf86wldqxxlnFRoPvlvRHkmcGSsmDw,6368
+ pvnet/datamodule.py,sha256=wc1RQfFhgW9Hxyw7vrpFERhOd2FmjDsO1x49J2erOYk,5750
  pvnet/load_model.py,sha256=P1QODX_mJRnKZ_kIll9BlOjK_A1W4YM3QG-mZd-2Mcc,3852
  pvnet/optimizers.py,sha256=1N4b-Xd6QiIrcUU8cbU326bbFC0BvMNIV8VYWtGILJc,6548
- pvnet/utils.py,sha256=6hVKQN8F89pJbiC9VSuHCm5yJqzIzs7hLF3ztkBU-TY,5895
+ pvnet/utils.py,sha256=L3MDF5m1Ez_btAZZ8t-T5wXLzFmyj7UZtorA91DEpFw,6003
  pvnet/models/__init__.py,sha256=owzZ9xkD0DRTT51mT2Dx_p96oJjwDz57xo_MaMIEosk,145
- pvnet/models/base_model.py,sha256=CnQaaf2kAdOcXqo1319nWa120mHfLQiwOQ639m4OzPk,16182
- pvnet/models/ensemble.py,sha256=1mFUEsl33kWcLL5d7zfDm9ypWxgAxBHgBiJLt0vwTeg,2363
+ pvnet/models/base_model.py,sha256=V-vBqtzZc_c8Ho5hVo_ikq2wzZ7hsAIM7I4vhzGDfNc,16051
+ pvnet/models/ensemble.py,sha256=USpNQ0O5eiffapLPE9T6gR-uK9f_3E4pX3DK7Lmkn2U,2228
  pvnet/models/late_fusion/__init__.py,sha256=Jf0B-E0_5IvSBFoj1wvnPtwYDxs4pRIFm5qHv--Bbps,26
  pvnet/models/late_fusion/basic_blocks.py,sha256=_cYGVyAIyEJS4wd-DEAXQXu0br66guZJn3ugoebWqZ0,1479
- pvnet/models/late_fusion/late_fusion.py,sha256=7uQPo_OlNXrJOE9nYHTEvwJx2POKg4drJfdnPxwiaJU,16283
+ pvnet/models/late_fusion/late_fusion.py,sha256=kQUnyqMykmwc0GdoFhNXYStJPrjr3hFSvUNe8FumVx4,15260
  pvnet/models/late_fusion/encoders/__init__.py,sha256=bLBQdnCeLYhwISW0t88ZZBz-ebS94m7ZwBcsofWMHR4,51
  pvnet/models/late_fusion/encoders/basic_blocks.py,sha256=DGkFFIZv4S4FLTaAIOrAngAFBpgZQHfkGM4dzezZLk4,3044
  pvnet/models/late_fusion/encoders/encoders3d.py,sha256=9fmqVHO73F-jN62w065cgEQI_icNFC2nQH6ZEGvTHxU,7116
@@ -17,13 +17,13 @@ pvnet/models/late_fusion/linear_networks/basic_blocks.py,sha256=RnwdeuX_-itY4ncM
  pvnet/models/late_fusion/linear_networks/networks.py,sha256=exEIz_Z85f8nSwcvp4wqiiLECEAg9YbkKhSZJvFy75M,2231
  pvnet/models/late_fusion/site_encoders/__init__.py,sha256=QoUiiWWFf12vEpdkw0gO4TWpOEoI_tgAyUFCWFFpYAk,45
  pvnet/models/late_fusion/site_encoders/basic_blocks.py,sha256=iEB_N7ZL5HMQ1hZM6H32A71GCwP7YbErUx0oQF21PQM,1042
- pvnet/models/late_fusion/site_encoders/encoders.py,sha256=k4z690cfcP6J4pm2KtDujHN-W3uOl7QY0WvBIu1tM8c,11703
+ pvnet/models/late_fusion/site_encoders/encoders.py,sha256=DcTV2LeZ0pSZpGFmsmPEqYIhmPQeYCNi3UM406zHm14,11310
  pvnet/training/__init__.py,sha256=FKxmPZ59Vuj5_mXomN4saJ3En5M-aDMxSs6OttTQOcg,49
- pvnet/training/lightning_module.py,sha256=KcEbHYBe_Gx0as0-A7bggoMjev-A_i6Y3PHGRaYllTg,12956
- pvnet/training/plots.py,sha256=4xID7TBA4IazaARaCN5AoG5fFPJF1wIprn0y6I0C31c,2469
+ pvnet/training/lightning_module.py,sha256=57sT7bPCU7mJw4EskzOE-JJ9JhWIuAbs40_x5RoBbA8,12705
+ pvnet/training/plots.py,sha256=7JtjA9zIotuoKZ2l0fbS-FZDB48TcIk_-XLA2EWVMv4,2448
  pvnet/training/train.py,sha256=Sry2wYgggUmtIB-k_umFts7xMr2roEL76NCu9ySbLUY,4107
- pvnet-5.2.3.dist-info/licenses/LICENSE,sha256=tKUnlSmcLBWMJWkHx3UjZGdrjs9LidGwLo0jsBUBAwU,1077
- pvnet-5.2.3.dist-info/METADATA,sha256=xA54YM0qDAlvtNSEQXwLLYF9S6sgo-kejaB0awBb0MA,16479
- pvnet-5.2.3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
- pvnet-5.2.3.dist-info/top_level.txt,sha256=4mg6WjeW05SR7pg3-Q4JRE2yAoutHYpspOsiUzYVNv0,6
- pvnet-5.2.3.dist-info/RECORD,,
+ pvnet-5.3.0.dist-info/licenses/LICENSE,sha256=tKUnlSmcLBWMJWkHx3UjZGdrjs9LidGwLo0jsBUBAwU,1077
+ pvnet-5.3.0.dist-info/METADATA,sha256=b4Ki0jGoNNEd1VopMvR5p-iasCi0ZVtGwA-RfoHRCWw,16479
+ pvnet-5.3.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ pvnet-5.3.0.dist-info/top_level.txt,sha256=4mg6WjeW05SR7pg3-Q4JRE2yAoutHYpspOsiUzYVNv0,6
+ pvnet-5.3.0.dist-info/RECORD,,