PVNet 5.3.1__tar.gz → 5.3.6__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36) hide show
  1. {pvnet-5.3.1 → pvnet-5.3.6}/PKG-INFO +2 -2
  2. {pvnet-5.3.1 → pvnet-5.3.6}/PVNet.egg-info/PKG-INFO +2 -2
  3. {pvnet-5.3.1 → pvnet-5.3.6}/PVNet.egg-info/requires.txt +1 -1
  4. {pvnet-5.3.1 → pvnet-5.3.6}/pvnet/models/late_fusion/late_fusion.py +9 -0
  5. {pvnet-5.3.1 → pvnet-5.3.6}/pvnet/optimizers.py +24 -32
  6. {pvnet-5.3.1 → pvnet-5.3.6}/pvnet/training/lightning_module.py +23 -6
  7. {pvnet-5.3.1 → pvnet-5.3.6}/pvnet/utils.py +36 -38
  8. {pvnet-5.3.1 → pvnet-5.3.6}/pyproject.toml +1 -1
  9. {pvnet-5.3.1 → pvnet-5.3.6}/LICENSE +0 -0
  10. {pvnet-5.3.1 → pvnet-5.3.6}/PVNet.egg-info/SOURCES.txt +0 -0
  11. {pvnet-5.3.1 → pvnet-5.3.6}/PVNet.egg-info/dependency_links.txt +0 -0
  12. {pvnet-5.3.1 → pvnet-5.3.6}/PVNet.egg-info/top_level.txt +0 -0
  13. {pvnet-5.3.1 → pvnet-5.3.6}/README.md +0 -0
  14. {pvnet-5.3.1 → pvnet-5.3.6}/pvnet/__init__.py +0 -0
  15. {pvnet-5.3.1 → pvnet-5.3.6}/pvnet/datamodule.py +0 -0
  16. {pvnet-5.3.1 → pvnet-5.3.6}/pvnet/load_model.py +0 -0
  17. {pvnet-5.3.1 → pvnet-5.3.6}/pvnet/models/__init__.py +0 -0
  18. {pvnet-5.3.1 → pvnet-5.3.6}/pvnet/models/base_model.py +0 -0
  19. {pvnet-5.3.1 → pvnet-5.3.6}/pvnet/models/ensemble.py +0 -0
  20. {pvnet-5.3.1 → pvnet-5.3.6}/pvnet/models/late_fusion/__init__.py +0 -0
  21. {pvnet-5.3.1 → pvnet-5.3.6}/pvnet/models/late_fusion/basic_blocks.py +0 -0
  22. {pvnet-5.3.1 → pvnet-5.3.6}/pvnet/models/late_fusion/encoders/__init__.py +0 -0
  23. {pvnet-5.3.1 → pvnet-5.3.6}/pvnet/models/late_fusion/encoders/basic_blocks.py +0 -0
  24. {pvnet-5.3.1 → pvnet-5.3.6}/pvnet/models/late_fusion/encoders/encoders3d.py +0 -0
  25. {pvnet-5.3.1 → pvnet-5.3.6}/pvnet/models/late_fusion/linear_networks/__init__.py +0 -0
  26. {pvnet-5.3.1 → pvnet-5.3.6}/pvnet/models/late_fusion/linear_networks/basic_blocks.py +0 -0
  27. {pvnet-5.3.1 → pvnet-5.3.6}/pvnet/models/late_fusion/linear_networks/networks.py +0 -0
  28. {pvnet-5.3.1 → pvnet-5.3.6}/pvnet/models/late_fusion/site_encoders/__init__.py +0 -0
  29. {pvnet-5.3.1 → pvnet-5.3.6}/pvnet/models/late_fusion/site_encoders/basic_blocks.py +0 -0
  30. {pvnet-5.3.1 → pvnet-5.3.6}/pvnet/models/late_fusion/site_encoders/encoders.py +0 -0
  31. {pvnet-5.3.1 → pvnet-5.3.6}/pvnet/training/__init__.py +0 -0
  32. {pvnet-5.3.1 → pvnet-5.3.6}/pvnet/training/plots.py +0 -0
  33. {pvnet-5.3.1 → pvnet-5.3.6}/pvnet/training/train.py +0 -0
  34. {pvnet-5.3.1 → pvnet-5.3.6}/setup.cfg +0 -0
  35. {pvnet-5.3.1 → pvnet-5.3.6}/tests/test_datamodule.py +0 -0
  36. {pvnet-5.3.1 → pvnet-5.3.6}/tests/test_end2end.py +0 -0
@@ -1,12 +1,12 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: PVNet
3
- Version: 5.3.1
3
+ Version: 5.3.6
4
4
  Summary: PVNet
5
5
  Author-email: Peter Dudfield <info@openclimatefix.org>
6
6
  Requires-Python: <3.14,>=3.11
7
7
  Description-Content-Type: text/markdown
8
8
  License-File: LICENSE
9
- Requires-Dist: ocf-data-sampler>=0.6.0
9
+ Requires-Dist: ocf-data-sampler>=1.0.9
10
10
  Requires-Dist: numpy
11
11
  Requires-Dist: pandas
12
12
  Requires-Dist: matplotlib
@@ -1,12 +1,12 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: PVNet
3
- Version: 5.3.1
3
+ Version: 5.3.6
4
4
  Summary: PVNet
5
5
  Author-email: Peter Dudfield <info@openclimatefix.org>
6
6
  Requires-Python: <3.14,>=3.11
7
7
  Description-Content-Type: text/markdown
8
8
  License-File: LICENSE
9
- Requires-Dist: ocf-data-sampler>=0.6.0
9
+ Requires-Dist: ocf-data-sampler>=1.0.9
10
10
  Requires-Dist: numpy
11
11
  Requires-Dist: pandas
12
12
  Requires-Dist: matplotlib
@@ -1,4 +1,4 @@
1
- ocf-data-sampler>=0.6.0
1
+ ocf-data-sampler>=1.0.9
2
2
  numpy
3
3
  pandas
4
4
  matplotlib
@@ -46,6 +46,7 @@ class LateFusionModel(BaseModel):
46
46
  include_generation_history: bool = False,
47
47
  include_sun: bool = True,
48
48
  include_time: bool = False,
49
+ t0_embedding_dim: int = 0,
49
50
  location_id_mapping: dict[Any, int] | None = None,
50
51
  embedding_dim: int = 16,
51
52
  forecast_minutes: int = 30,
@@ -85,6 +86,8 @@ class LateFusionModel(BaseModel):
85
86
  include_generation_history: Include generation yield data.
86
87
  include_sun: Include sun azimuth and altitude data.
87
88
  include_time: Include sine and cosine of dates and times.
89
+ t0_embedding_dim: Shape of the embedding of the init-time (t0) of the forecast. Not used
90
+ if set to 0.
88
91
  location_id_mapping: A dictionary mapping the location ID to an integer. ID embedding is
89
92
  not used if this is not provided.
90
93
  embedding_dim: Number of embedding dimensions to use for location ID.
@@ -119,6 +122,7 @@ class LateFusionModel(BaseModel):
119
122
  self.include_pv = pv_encoder is not None
120
123
  self.include_sun = include_sun
121
124
  self.include_time = include_time
125
+ self.t0_embedding_dim = t0_embedding_dim
122
126
  self.location_id_mapping = location_id_mapping
123
127
  self.embedding_dim = embedding_dim
124
128
  self.add_image_embedding_channel = add_image_embedding_channel
@@ -246,6 +250,8 @@ class LateFusionModel(BaseModel):
246
250
  # Update num features
247
251
  fusion_input_features += 32
248
252
 
253
+ fusion_input_features += self.t0_embedding_dim
254
+
249
255
  if include_generation_history:
250
256
  # Update num features
251
257
  fusion_input_features += self.history_len + 1
@@ -321,6 +327,9 @@ class LateFusionModel(BaseModel):
321
327
  time = self.time_fc1(time)
322
328
  modes["time"] = time
323
329
 
330
+ if self.t0_embedding_dim>0:
331
+ modes["t0_embed"] = x["t0_embedding"]
332
+
324
333
  out = self.output_network(modes)
325
334
 
326
335
  if self.use_quantile_regression:
@@ -65,7 +65,7 @@ class AbstractOptimizer(ABC):
65
65
  """
66
66
 
67
67
  @abstractmethod
68
- def __call__(self):
68
+ def __call__(self, model: Module):
69
69
  """Abstract call"""
70
70
  pass
71
71
 
@@ -129,19 +129,18 @@ class EmbAdamWReduceLROnPlateau(AbstractOptimizer):
129
129
  {"params": decay, "weight_decay": self.weight_decay},
130
130
  {"params": no_decay, "weight_decay": 0.0},
131
131
  ]
132
+ monitor = "quantile_loss/val" if model.use_quantile_regression else "MAE/val"
132
133
  opt = torch.optim.AdamW(optim_groups, lr=self.lr, **self.opt_kwargs)
133
-
134
134
  sch = torch.optim.lr_scheduler.ReduceLROnPlateau(
135
135
  opt,
136
136
  factor=self.factor,
137
137
  patience=self.patience,
138
138
  threshold=self.threshold,
139
139
  )
140
- sch = {
141
- "scheduler": sch,
142
- "monitor": "quantile_loss/val" if model.use_quantile_regression else "MAE/val",
140
+ return {
141
+ "optimizer": opt,
142
+ "lr_scheduler": {"scheduler": sch, "monitor": monitor},
143
143
  }
144
- return [opt], [sch]
145
144
 
146
145
 
147
146
  class AdamWReduceLROnPlateau(AbstractOptimizer):
@@ -153,15 +152,13 @@ class AdamWReduceLROnPlateau(AbstractOptimizer):
153
152
  patience: int = 3,
154
153
  factor: float = 0.5,
155
154
  threshold: float = 2e-4,
156
- step_freq=None,
157
155
  **opt_kwargs,
158
156
  ):
159
157
  """AdamW optimizer and reduce on plateau scheduler"""
160
- self._lr = lr
158
+ self.lr = lr
161
159
  self.patience = patience
162
160
  self.factor = factor
163
161
  self.threshold = threshold
164
- self.step_freq = step_freq
165
162
  self.opt_kwargs = opt_kwargs
166
163
 
167
164
  def _call_multi(self, model):
@@ -169,7 +166,7 @@ class AdamWReduceLROnPlateau(AbstractOptimizer):
169
166
 
170
167
  group_args = []
171
168
 
172
- for key in self._lr.keys():
169
+ for key in self.lr.keys():
173
170
  if key == "default":
174
171
  continue
175
172
 
@@ -178,43 +175,38 @@ class AdamWReduceLROnPlateau(AbstractOptimizer):
178
175
  if param_name.startswith(key):
179
176
  submodule_params += [remaining_params.pop(param_name)]
180
177
 
181
- group_args += [{"params": submodule_params, "lr": self._lr[key]}]
178
+ group_args += [{"params": submodule_params, "lr": self.lr[key]}]
182
179
 
183
180
  remaining_params = [p for k, p in remaining_params.items()]
184
181
  group_args += [{"params": remaining_params}]
185
-
186
- opt = torch.optim.AdamW(
187
- group_args,
188
- lr=self._lr["default"] if model.lr is None else model.lr,
189
- **self.opt_kwargs,
182
+ monitor = "quantile_loss/val" if model.use_quantile_regression else "MAE/val"
183
+ opt = torch.optim.AdamW(group_args, lr=self.lr["default"], **self.opt_kwargs)
184
+ sch = torch.optim.lr_scheduler.ReduceLROnPlateau(
185
+ opt,
186
+ factor=self.factor,
187
+ patience=self.patience,
188
+ threshold=self.threshold,
190
189
  )
191
- sch = {
192
- "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau(
193
- opt,
194
- factor=self.factor,
195
- patience=self.patience,
196
- threshold=self.threshold,
197
- ),
198
- "monitor": "quantile_loss/val" if model.use_quantile_regression else "MAE/val",
190
+ return {
191
+ "optimizer": opt,
192
+ "lr_scheduler": {"scheduler": sch, "monitor": monitor},
199
193
  }
200
194
 
201
- return [opt], [sch]
202
195
 
203
196
  def __call__(self, model):
204
197
  """Return optimizer"""
205
- if not isinstance(self._lr, float):
198
+ if isinstance(self.lr, dict):
206
199
  return self._call_multi(model)
207
200
  else:
208
- default_lr = self._lr if model.lr is None else model.lr
209
- opt = torch.optim.AdamW(model.parameters(), lr=default_lr, **self.opt_kwargs)
201
+ monitor = "quantile_loss/val" if model.use_quantile_regression else "MAE/val"
202
+ opt = torch.optim.AdamW(model.parameters(), lr=self.lr, **self.opt_kwargs)
210
203
  sch = torch.optim.lr_scheduler.ReduceLROnPlateau(
211
204
  opt,
212
205
  factor=self.factor,
213
206
  patience=self.patience,
214
207
  threshold=self.threshold,
215
208
  )
216
- sch = {
217
- "scheduler": sch,
218
- "monitor": "quantile_loss/val" if model.use_quantile_regression else "MAE/val",
209
+ return {
210
+ "optimizer": opt,
211
+ "lr_scheduler": {"scheduler": sch, "monitor": monitor},
219
212
  }
220
- return [opt], [sch]
@@ -109,7 +109,7 @@ class PVNetLightningModule(pl.LightningModule):
109
109
  losses = self._calculate_common_losses(y, y_hat)
110
110
  losses = {f"{k}/train": v for k, v in losses.items()}
111
111
 
112
- self.log_dict(losses, on_step=True, on_epoch=True)
112
+ self.log_dict(losses, on_step=True, on_epoch=True, batch_size=y.size(0))
113
113
 
114
114
  if self.model.use_quantile_regression:
115
115
  opt_target = losses["quantile_loss/train"]
@@ -256,11 +256,28 @@ class PVNetLightningModule(pl.LightningModule):
256
256
 
257
257
  # Calculate the persistance losses - we only need to do this once per training run
258
258
  # not every epoch
259
- if self.current_epoch == 0:
259
+ if self.current_epoch==0:
260
+ # Need to find last valid value before forecast
261
+ target_data = batch["generation"]
262
+ history_data = target_data[:, :-(self.model.forecast_len)]
263
+
264
+ # Find where values aren't dropped
265
+ valid_mask = history_data >= 0
266
+
267
+ # Last valid value index for each sample
268
+ flipped_mask = valid_mask.float().flip(dims=[1])
269
+ last_valid_indices_flipped = torch.argmax(flipped_mask, dim=1)
270
+ last_valid_indices = history_data.shape[1] - 1 - last_valid_indices_flipped
271
+
272
+ # Grab those last valid values
273
+ batch_indices = torch.arange(
274
+ history_data.shape[0],
275
+ device=history_data.device
276
+ )
277
+ last_valid_values = history_data[batch_indices, last_valid_indices]
278
+
260
279
  y_persist = (
261
- batch["generation"][:, -(self.model.forecast_len + 1)]
262
- .unsqueeze(1)
263
- .expand(-1, self.model.forecast_len)
280
+ last_valid_values.unsqueeze(1).expand(-1, self.model.forecast_len)
264
281
  )
265
282
  mae_step_persist, mse_step_persist = self._calculate_step_metrics(y, y_persist)
266
283
  self._val_persistence_horizon_maes.append(mae_step_persist)
@@ -272,7 +289,7 @@ class PVNetLightningModule(pl.LightningModule):
272
289
  )
273
290
 
274
291
  # Log the metrics
275
- self.log_dict(losses, on_step=False, on_epoch=True)
292
+ self.log_dict(losses, on_step=False, on_epoch=True, batch_size=y.size(0))
276
293
 
277
294
  def on_validation_epoch_end(self) -> None:
278
295
  """Run on epoch end"""
@@ -101,66 +101,64 @@ def validate_batch_against_config(
101
101
  logger.info("Performing batch shape validation against model config.")
102
102
 
103
103
  # NWP validation
104
- if hasattr(model, "nwp_encoders_dict"):
104
+ if model.include_nwp:
105
105
  if "nwp" not in batch:
106
- raise ValueError(
107
- "Model configured with 'nwp_encoders_dict' but 'nwp' data missing from batch."
108
- )
106
+ raise ValueError("Model uses NWP data but 'nwp' missing from batch.")
109
107
 
110
- for source, nwp_data in batch["nwp"].items():
111
- if source in model.nwp_encoders_dict:
112
- enc = model.nwp_encoders_dict[source]
113
- expected_channels = enc.in_channels
114
- if model.add_image_embedding_channel:
115
- expected_channels -= 1
116
-
117
- expected = (
118
- nwp_data["nwp"].shape[0],
119
- enc.sequence_length,
120
- expected_channels,
121
- enc.image_size_pixels,
122
- enc.image_size_pixels,
108
+ for source in model.nwp_encoders_dict:
109
+ if source not in batch["nwp"]:
110
+ raise ValueError(
111
+ f"Model uses NWP source '{source}' but it is missing from batch['nwp']."
112
+ )
113
+
114
+ enc = model.nwp_encoders_dict[source]
115
+ expected_channels = enc.in_channels - int(model.add_image_embedding_channel)
116
+
117
+ expected_shape = (
118
+ batch["nwp"][source]["nwp"].shape[0],
119
+ enc.sequence_length,
120
+ expected_channels,
121
+ enc.image_size_pixels,
122
+ enc.image_size_pixels,
123
+ )
124
+ actual_shape = tuple(batch["nwp"][source]["nwp"].shape)
125
+ if actual_shape != expected_shape:
126
+ raise ValueError(
127
+ f"NWP.{source} shape mismatch: expected {expected_shape}, got {actual_shape}"
123
128
  )
124
- if tuple(nwp_data["nwp"].shape) != expected:
125
- actual_shape = tuple(nwp_data["nwp"].shape)
126
- raise ValueError(
127
- f"NWP.{source} shape mismatch: expected {expected}, got {actual_shape}"
128
- )
129
129
 
130
130
  # Satellite validation
131
- if hasattr(model, "sat_encoder"):
131
+ if model.include_sat:
132
132
  if "satellite_actual" not in batch:
133
133
  raise ValueError(
134
- "Model configured with 'sat_encoder' but 'satellite_actual' missing from batch."
134
+ "Model uses satellite data but 'satellite_actual' missing from batch."
135
135
  )
136
136
 
137
137
  enc = model.sat_encoder
138
- expected_channels = enc.in_channels
139
- if model.add_image_embedding_channel:
140
- expected_channels -= 1
138
+ expected_channels = enc.in_channels - int(model.add_image_embedding_channel)
141
139
 
142
- expected = (
140
+ expected_shape = (
143
141
  batch["satellite_actual"].shape[0],
144
142
  enc.sequence_length,
145
143
  expected_channels,
146
144
  enc.image_size_pixels,
147
145
  enc.image_size_pixels,
148
146
  )
149
- if tuple(batch["satellite_actual"].shape) != expected:
150
- actual_shape = tuple(batch["satellite_actual"].shape)
151
- raise ValueError(f"Satellite shape mismatch: expected {expected}, got {actual_shape}")
147
+ actual_shape = tuple(batch["satellite_actual"].shape)
148
+ if actual_shape != expected_shape:
149
+ raise ValueError(
150
+ f"Satellite shape mismatch: expected {expected_shape}, got {actual_shape}"
151
+ )
152
152
 
153
- # generation validation
154
153
  key = "generation"
155
154
  if key in batch:
156
155
  total_minutes = model.history_minutes + model.forecast_minutes
157
- interval = model.interval_minutes
158
- expected_len = total_minutes // interval + 1
159
- expected = (batch[key].shape[0], expected_len)
160
- if tuple(batch[key].shape) != expected:
161
- actual_shape = tuple(batch[key].shape)
156
+ expected_len = total_minutes // model.interval_minutes + 1
157
+ expected_shape = (batch[key].shape[0], expected_len)
158
+ actual_shape = tuple(batch[key].shape)
159
+ if actual_shape != expected_shape:
162
160
  raise ValueError(
163
- f"{key.upper()} shape mismatch: expected {expected}, got {actual_shape}"
161
+ f"Generation data shape mismatch: expected {expected_shape}, got {actual_shape}"
164
162
  )
165
163
 
166
164
  logger.info("Batch shape validation successful!")
@@ -12,7 +12,7 @@ readme = {file="README.md", content-type="text/markdown"}
12
12
  requires-python = ">=3.11,<3.14"
13
13
 
14
14
  dependencies = [
15
- "ocf-data-sampler>=0.6.0",
15
+ "ocf-data-sampler>=1.0.9",
16
16
  "numpy",
17
17
  "pandas",
18
18
  "matplotlib",
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes