PVNet 5.3.0__py3-none-any.whl → 5.3.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -46,6 +46,7 @@ class LateFusionModel(BaseModel):
46
46
  include_generation_history: bool = False,
47
47
  include_sun: bool = True,
48
48
  include_time: bool = False,
49
+ t0_embedding_dim: int = 0,
49
50
  location_id_mapping: dict[Any, int] | None = None,
50
51
  embedding_dim: int = 16,
51
52
  forecast_minutes: int = 30,
@@ -85,6 +86,8 @@ class LateFusionModel(BaseModel):
85
86
  include_generation_history: Include generation yield data.
86
87
  include_sun: Include sun azimuth and altitude data.
87
88
  include_time: Include sine and cosine of dates and times.
89
+ t0_embedding_dim: Shape of the embedding of the init-time (t0) of the forecast. Not used
90
+ if set to 0.
88
91
  location_id_mapping: A dictionary mapping the location ID to an integer. ID embedding is
89
92
  not used if this is not provided.
90
93
  embedding_dim: Number of embedding dimensions to use for location ID.
@@ -119,6 +122,7 @@ class LateFusionModel(BaseModel):
119
122
  self.include_pv = pv_encoder is not None
120
123
  self.include_sun = include_sun
121
124
  self.include_time = include_time
125
+ self.t0_embedding_dim = t0_embedding_dim
122
126
  self.location_id_mapping = location_id_mapping
123
127
  self.embedding_dim = embedding_dim
124
128
  self.add_image_embedding_channel = add_image_embedding_channel
@@ -246,6 +250,8 @@ class LateFusionModel(BaseModel):
246
250
  # Update num features
247
251
  fusion_input_features += 32
248
252
 
253
+ fusion_input_features += self.t0_embedding_dim
254
+
249
255
  if include_generation_history:
250
256
  # Update num features
251
257
  fusion_input_features += self.history_len + 1
@@ -321,6 +327,9 @@ class LateFusionModel(BaseModel):
321
327
  time = self.time_fc1(time)
322
328
  modes["time"] = time
323
329
 
330
+ if self.t0_embedding_dim>0:
331
+ modes["t0_embed"] = x["t0_embedding"]
332
+
324
333
  out = self.output_network(modes)
325
334
 
326
335
  if self.use_quantile_regression:
@@ -158,7 +158,7 @@ class SingleAttentionNetwork(AbstractSitesEncoder):
158
158
  super().__init__(sequence_length, num_sites, out_features)
159
159
  self.sequence_length = sequence_length
160
160
  self.target_id_embedding = nn.Embedding(target_id_dim, out_features)
161
- self.id_embedding = nn.Embedding(num_sites, id_embed_dim)
161
+ self.site_id_embedding = nn.Embedding(num_sites, id_embed_dim)
162
162
  self._ids = nn.parameter.Parameter(torch.arange(num_sites), requires_grad=False)
163
163
  self.use_id_in_value = use_id_in_value
164
164
  self.key_to_use = key_to_use
@@ -224,7 +224,7 @@ class SingleAttentionNetwork(AbstractSitesEncoder):
224
224
  site_seqs, batch_size = self._encode_inputs(x)
225
225
 
226
226
  # site ID embeddings are the same for each sample
227
- id_embed = torch.tile(self.id_embedding(self._ids), (batch_size, 1, 1))
227
+ id_embed = torch.tile(self.site_id_embedding(self._ids), (batch_size, 1, 1))
228
228
  # Each concated (site sequence, site ID embedding) is processed with encoder
229
229
  x_seq_in = torch.cat((site_seqs, id_embed), dim=2).flatten(0, 1)
230
230
  key = self._key_encoder(x_seq_in)
pvnet/optimizers.py CHANGED
@@ -65,7 +65,7 @@ class AbstractOptimizer(ABC):
65
65
  """
66
66
 
67
67
  @abstractmethod
68
- def __call__(self):
68
+ def __call__(self, model: Module):
69
69
  """Abstract call"""
70
70
  pass
71
71
 
@@ -129,19 +129,18 @@ class EmbAdamWReduceLROnPlateau(AbstractOptimizer):
129
129
  {"params": decay, "weight_decay": self.weight_decay},
130
130
  {"params": no_decay, "weight_decay": 0.0},
131
131
  ]
132
+ monitor = "quantile_loss/val" if model.use_quantile_regression else "MAE/val"
132
133
  opt = torch.optim.AdamW(optim_groups, lr=self.lr, **self.opt_kwargs)
133
-
134
134
  sch = torch.optim.lr_scheduler.ReduceLROnPlateau(
135
135
  opt,
136
136
  factor=self.factor,
137
137
  patience=self.patience,
138
138
  threshold=self.threshold,
139
139
  )
140
- sch = {
141
- "scheduler": sch,
142
- "monitor": "quantile_loss/val" if model.use_quantile_regression else "MAE/val",
140
+ return {
141
+ "optimizer": opt,
142
+ "lr_scheduler": {"scheduler": sch, "monitor": monitor},
143
143
  }
144
- return [opt], [sch]
145
144
 
146
145
 
147
146
  class AdamWReduceLROnPlateau(AbstractOptimizer):
@@ -153,15 +152,13 @@ class AdamWReduceLROnPlateau(AbstractOptimizer):
153
152
  patience: int = 3,
154
153
  factor: float = 0.5,
155
154
  threshold: float = 2e-4,
156
- step_freq=None,
157
155
  **opt_kwargs,
158
156
  ):
159
157
  """AdamW optimizer and reduce on plateau scheduler"""
160
- self._lr = lr
158
+ self.lr = lr
161
159
  self.patience = patience
162
160
  self.factor = factor
163
161
  self.threshold = threshold
164
- self.step_freq = step_freq
165
162
  self.opt_kwargs = opt_kwargs
166
163
 
167
164
  def _call_multi(self, model):
@@ -169,7 +166,7 @@ class AdamWReduceLROnPlateau(AbstractOptimizer):
169
166
 
170
167
  group_args = []
171
168
 
172
- for key in self._lr.keys():
169
+ for key in self.lr.keys():
173
170
  if key == "default":
174
171
  continue
175
172
 
@@ -178,43 +175,38 @@ class AdamWReduceLROnPlateau(AbstractOptimizer):
178
175
  if param_name.startswith(key):
179
176
  submodule_params += [remaining_params.pop(param_name)]
180
177
 
181
- group_args += [{"params": submodule_params, "lr": self._lr[key]}]
178
+ group_args += [{"params": submodule_params, "lr": self.lr[key]}]
182
179
 
183
180
  remaining_params = [p for k, p in remaining_params.items()]
184
181
  group_args += [{"params": remaining_params}]
185
-
186
- opt = torch.optim.AdamW(
187
- group_args,
188
- lr=self._lr["default"] if model.lr is None else model.lr,
189
- **self.opt_kwargs,
182
+ monitor = "quantile_loss/val" if model.use_quantile_regression else "MAE/val"
183
+ opt = torch.optim.AdamW(group_args, lr=self.lr["default"], **self.opt_kwargs)
184
+ sch = torch.optim.lr_scheduler.ReduceLROnPlateau(
185
+ opt,
186
+ factor=self.factor,
187
+ patience=self.patience,
188
+ threshold=self.threshold,
190
189
  )
191
- sch = {
192
- "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau(
193
- opt,
194
- factor=self.factor,
195
- patience=self.patience,
196
- threshold=self.threshold,
197
- ),
198
- "monitor": "quantile_loss/val" if model.use_quantile_regression else "MAE/val",
190
+ return {
191
+ "optimizer": opt,
192
+ "lr_scheduler": {"scheduler": sch, "monitor": monitor},
199
193
  }
200
194
 
201
- return [opt], [sch]
202
195
 
203
196
  def __call__(self, model):
204
197
  """Return optimizer"""
205
- if not isinstance(self._lr, float):
198
+ if isinstance(self.lr, dict):
206
199
  return self._call_multi(model)
207
200
  else:
208
- default_lr = self._lr if model.lr is None else model.lr
209
- opt = torch.optim.AdamW(model.parameters(), lr=default_lr, **self.opt_kwargs)
201
+ monitor = "quantile_loss/val" if model.use_quantile_regression else "MAE/val"
202
+ opt = torch.optim.AdamW(model.parameters(), lr=self.lr, **self.opt_kwargs)
210
203
  sch = torch.optim.lr_scheduler.ReduceLROnPlateau(
211
204
  opt,
212
205
  factor=self.factor,
213
206
  patience=self.patience,
214
207
  threshold=self.threshold,
215
208
  )
216
- sch = {
217
- "scheduler": sch,
218
- "monitor": "quantile_loss/val" if model.use_quantile_regression else "MAE/val",
209
+ return {
210
+ "optimizer": opt,
211
+ "lr_scheduler": {"scheduler": sch, "monitor": monitor},
219
212
  }
220
- return [opt], [sch]
@@ -109,7 +109,7 @@ class PVNetLightningModule(pl.LightningModule):
109
109
  losses = self._calculate_common_losses(y, y_hat)
110
110
  losses = {f"{k}/train": v for k, v in losses.items()}
111
111
 
112
- self.log_dict(losses, on_step=True, on_epoch=True)
112
+ self.log_dict(losses, on_step=True, on_epoch=True, batch_size=y.size(0))
113
113
 
114
114
  if self.model.use_quantile_regression:
115
115
  opt_target = losses["quantile_loss/train"]
@@ -256,11 +256,28 @@ class PVNetLightningModule(pl.LightningModule):
256
256
 
257
257
  # Calculate the persistance losses - we only need to do this once per training run
258
258
  # not every epoch
259
- if self.current_epoch == 0:
259
+ if self.current_epoch==0:
260
+ # Need to find last valid value before forecast
261
+ target_data = batch["generation"]
262
+ history_data = target_data[:, :-(self.model.forecast_len)]
263
+
264
+ # Find where values aren't dropped
265
+ valid_mask = history_data >= 0
266
+
267
+ # Last valid value index for each sample
268
+ flipped_mask = valid_mask.float().flip(dims=[1])
269
+ last_valid_indices_flipped = torch.argmax(flipped_mask, dim=1)
270
+ last_valid_indices = history_data.shape[1] - 1 - last_valid_indices_flipped
271
+
272
+ # Grab those last valid values
273
+ batch_indices = torch.arange(
274
+ history_data.shape[0],
275
+ device=history_data.device
276
+ )
277
+ last_valid_values = history_data[batch_indices, last_valid_indices]
278
+
260
279
  y_persist = (
261
- batch["generation"][:, -(self.model.forecast_len + 1)]
262
- .unsqueeze(1)
263
- .expand(-1, self.model.forecast_len)
280
+ last_valid_values.unsqueeze(1).expand(-1, self.model.forecast_len)
264
281
  )
265
282
  mae_step_persist, mse_step_persist = self._calculate_step_metrics(y, y_persist)
266
283
  self._val_persistence_horizon_maes.append(mae_step_persist)
@@ -272,7 +289,7 @@ class PVNetLightningModule(pl.LightningModule):
272
289
  )
273
290
 
274
291
  # Log the metrics
275
- self.log_dict(losses, on_step=False, on_epoch=True)
292
+ self.log_dict(losses, on_step=False, on_epoch=True, batch_size=y.size(0))
276
293
 
277
294
  def on_validation_epoch_end(self) -> None:
278
295
  """Run on epoch end"""
@@ -1,12 +1,12 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: PVNet
3
- Version: 5.3.0
3
+ Version: 5.3.5
4
4
  Summary: PVNet
5
5
  Author-email: Peter Dudfield <info@openclimatefix.org>
6
6
  Requires-Python: <3.14,>=3.11
7
7
  Description-Content-Type: text/markdown
8
8
  License-File: LICENSE
9
- Requires-Dist: ocf-data-sampler>=0.6.0
9
+ Requires-Dist: ocf-data-sampler>=1.0.9
10
10
  Requires-Dist: numpy
11
11
  Requires-Dist: pandas
12
12
  Requires-Dist: matplotlib
@@ -1,14 +1,14 @@
1
1
  pvnet/__init__.py,sha256=TAZm88TJ5ieL1XjEyRg1LciIGuSScEucdAruQLfM92I,25
2
2
  pvnet/datamodule.py,sha256=wc1RQfFhgW9Hxyw7vrpFERhOd2FmjDsO1x49J2erOYk,5750
3
3
  pvnet/load_model.py,sha256=P1QODX_mJRnKZ_kIll9BlOjK_A1W4YM3QG-mZd-2Mcc,3852
4
- pvnet/optimizers.py,sha256=1N4b-Xd6QiIrcUU8cbU326bbFC0BvMNIV8VYWtGILJc,6548
4
+ pvnet/optimizers.py,sha256=DZ74KcFQV226zwu7-qAzofTMTYeIyScox4Kqbq30WWY,6440
5
5
  pvnet/utils.py,sha256=L3MDF5m1Ez_btAZZ8t-T5wXLzFmyj7UZtorA91DEpFw,6003
6
6
  pvnet/models/__init__.py,sha256=owzZ9xkD0DRTT51mT2Dx_p96oJjwDz57xo_MaMIEosk,145
7
7
  pvnet/models/base_model.py,sha256=V-vBqtzZc_c8Ho5hVo_ikq2wzZ7hsAIM7I4vhzGDfNc,16051
8
8
  pvnet/models/ensemble.py,sha256=USpNQ0O5eiffapLPE9T6gR-uK9f_3E4pX3DK7Lmkn2U,2228
9
9
  pvnet/models/late_fusion/__init__.py,sha256=Jf0B-E0_5IvSBFoj1wvnPtwYDxs4pRIFm5qHv--Bbps,26
10
10
  pvnet/models/late_fusion/basic_blocks.py,sha256=_cYGVyAIyEJS4wd-DEAXQXu0br66guZJn3ugoebWqZ0,1479
11
- pvnet/models/late_fusion/late_fusion.py,sha256=kQUnyqMykmwc0GdoFhNXYStJPrjr3hFSvUNe8FumVx4,15260
11
+ pvnet/models/late_fusion/late_fusion.py,sha256=r05RJvw2-ZQgWJobOGq1g4rlMJQjGM0UzG3syA4T0qo,15617
12
12
  pvnet/models/late_fusion/encoders/__init__.py,sha256=bLBQdnCeLYhwISW0t88ZZBz-ebS94m7ZwBcsofWMHR4,51
13
13
  pvnet/models/late_fusion/encoders/basic_blocks.py,sha256=DGkFFIZv4S4FLTaAIOrAngAFBpgZQHfkGM4dzezZLk4,3044
14
14
  pvnet/models/late_fusion/encoders/encoders3d.py,sha256=9fmqVHO73F-jN62w065cgEQI_icNFC2nQH6ZEGvTHxU,7116
@@ -17,13 +17,13 @@ pvnet/models/late_fusion/linear_networks/basic_blocks.py,sha256=RnwdeuX_-itY4ncM
17
17
  pvnet/models/late_fusion/linear_networks/networks.py,sha256=exEIz_Z85f8nSwcvp4wqiiLECEAg9YbkKhSZJvFy75M,2231
18
18
  pvnet/models/late_fusion/site_encoders/__init__.py,sha256=QoUiiWWFf12vEpdkw0gO4TWpOEoI_tgAyUFCWFFpYAk,45
19
19
  pvnet/models/late_fusion/site_encoders/basic_blocks.py,sha256=iEB_N7ZL5HMQ1hZM6H32A71GCwP7YbErUx0oQF21PQM,1042
20
- pvnet/models/late_fusion/site_encoders/encoders.py,sha256=DcTV2LeZ0pSZpGFmsmPEqYIhmPQeYCNi3UM406zHm14,11310
20
+ pvnet/models/late_fusion/site_encoders/encoders.py,sha256=PemEUa_Wv5pFWw3usPKEtXcvs_MX2LSrO6nhldO_QVk,11320
21
21
  pvnet/training/__init__.py,sha256=FKxmPZ59Vuj5_mXomN4saJ3En5M-aDMxSs6OttTQOcg,49
22
- pvnet/training/lightning_module.py,sha256=57sT7bPCU7mJw4EskzOE-JJ9JhWIuAbs40_x5RoBbA8,12705
22
+ pvnet/training/lightning_module.py,sha256=hmvne9DQauWpG61sRK-t8MTZRVwdywaEFCs0VFVRuMs,13522
23
23
  pvnet/training/plots.py,sha256=7JtjA9zIotuoKZ2l0fbS-FZDB48TcIk_-XLA2EWVMv4,2448
24
24
  pvnet/training/train.py,sha256=Sry2wYgggUmtIB-k_umFts7xMr2roEL76NCu9ySbLUY,4107
25
- pvnet-5.3.0.dist-info/licenses/LICENSE,sha256=tKUnlSmcLBWMJWkHx3UjZGdrjs9LidGwLo0jsBUBAwU,1077
26
- pvnet-5.3.0.dist-info/METADATA,sha256=b4Ki0jGoNNEd1VopMvR5p-iasCi0ZVtGwA-RfoHRCWw,16479
27
- pvnet-5.3.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
28
- pvnet-5.3.0.dist-info/top_level.txt,sha256=4mg6WjeW05SR7pg3-Q4JRE2yAoutHYpspOsiUzYVNv0,6
29
- pvnet-5.3.0.dist-info/RECORD,,
25
+ pvnet-5.3.5.dist-info/licenses/LICENSE,sha256=tKUnlSmcLBWMJWkHx3UjZGdrjs9LidGwLo0jsBUBAwU,1077
26
+ pvnet-5.3.5.dist-info/METADATA,sha256=rIlZGmFiIzkMpG_5U-6SrsdDW6fIke667JAG79g3KN4,16479
27
+ pvnet-5.3.5.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
28
+ pvnet-5.3.5.dist-info/top_level.txt,sha256=4mg6WjeW05SR7pg3-Q4JRE2yAoutHYpspOsiUzYVNv0,6
29
+ pvnet-5.3.5.dist-info/RECORD,,
File without changes