PVNet 5.3.0__py3-none-any.whl → 5.3.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pvnet/models/late_fusion/late_fusion.py +9 -0
- pvnet/models/late_fusion/site_encoders/encoders.py +2 -2
- pvnet/optimizers.py +24 -32
- pvnet/training/lightning_module.py +23 -6
- {pvnet-5.3.0.dist-info → pvnet-5.3.5.dist-info}/METADATA +2 -2
- {pvnet-5.3.0.dist-info → pvnet-5.3.5.dist-info}/RECORD +9 -9
- {pvnet-5.3.0.dist-info → pvnet-5.3.5.dist-info}/WHEEL +0 -0
- {pvnet-5.3.0.dist-info → pvnet-5.3.5.dist-info}/licenses/LICENSE +0 -0
- {pvnet-5.3.0.dist-info → pvnet-5.3.5.dist-info}/top_level.txt +0 -0
|
@@ -46,6 +46,7 @@ class LateFusionModel(BaseModel):
|
|
|
46
46
|
include_generation_history: bool = False,
|
|
47
47
|
include_sun: bool = True,
|
|
48
48
|
include_time: bool = False,
|
|
49
|
+
t0_embedding_dim: int = 0,
|
|
49
50
|
location_id_mapping: dict[Any, int] | None = None,
|
|
50
51
|
embedding_dim: int = 16,
|
|
51
52
|
forecast_minutes: int = 30,
|
|
@@ -85,6 +86,8 @@ class LateFusionModel(BaseModel):
|
|
|
85
86
|
include_generation_history: Include generation yield data.
|
|
86
87
|
include_sun: Include sun azimuth and altitude data.
|
|
87
88
|
include_time: Include sine and cosine of dates and times.
|
|
89
|
+
t0_embedding_dim: Shape of the embedding of the init-time (t0) of the forecast. Not used
|
|
90
|
+
if set to 0.
|
|
88
91
|
location_id_mapping: A dictionary mapping the location ID to an integer. ID embedding is
|
|
89
92
|
not used if this is not provided.
|
|
90
93
|
embedding_dim: Number of embedding dimensions to use for location ID.
|
|
@@ -119,6 +122,7 @@ class LateFusionModel(BaseModel):
|
|
|
119
122
|
self.include_pv = pv_encoder is not None
|
|
120
123
|
self.include_sun = include_sun
|
|
121
124
|
self.include_time = include_time
|
|
125
|
+
self.t0_embedding_dim = t0_embedding_dim
|
|
122
126
|
self.location_id_mapping = location_id_mapping
|
|
123
127
|
self.embedding_dim = embedding_dim
|
|
124
128
|
self.add_image_embedding_channel = add_image_embedding_channel
|
|
@@ -246,6 +250,8 @@ class LateFusionModel(BaseModel):
|
|
|
246
250
|
# Update num features
|
|
247
251
|
fusion_input_features += 32
|
|
248
252
|
|
|
253
|
+
fusion_input_features += self.t0_embedding_dim
|
|
254
|
+
|
|
249
255
|
if include_generation_history:
|
|
250
256
|
# Update num features
|
|
251
257
|
fusion_input_features += self.history_len + 1
|
|
@@ -321,6 +327,9 @@ class LateFusionModel(BaseModel):
|
|
|
321
327
|
time = self.time_fc1(time)
|
|
322
328
|
modes["time"] = time
|
|
323
329
|
|
|
330
|
+
if self.t0_embedding_dim>0:
|
|
331
|
+
modes["t0_embed"] = x["t0_embedding"]
|
|
332
|
+
|
|
324
333
|
out = self.output_network(modes)
|
|
325
334
|
|
|
326
335
|
if self.use_quantile_regression:
|
|
@@ -158,7 +158,7 @@ class SingleAttentionNetwork(AbstractSitesEncoder):
|
|
|
158
158
|
super().__init__(sequence_length, num_sites, out_features)
|
|
159
159
|
self.sequence_length = sequence_length
|
|
160
160
|
self.target_id_embedding = nn.Embedding(target_id_dim, out_features)
|
|
161
|
-
self.
|
|
161
|
+
self.site_id_embedding = nn.Embedding(num_sites, id_embed_dim)
|
|
162
162
|
self._ids = nn.parameter.Parameter(torch.arange(num_sites), requires_grad=False)
|
|
163
163
|
self.use_id_in_value = use_id_in_value
|
|
164
164
|
self.key_to_use = key_to_use
|
|
@@ -224,7 +224,7 @@ class SingleAttentionNetwork(AbstractSitesEncoder):
|
|
|
224
224
|
site_seqs, batch_size = self._encode_inputs(x)
|
|
225
225
|
|
|
226
226
|
# site ID embeddings are the same for each sample
|
|
227
|
-
id_embed = torch.tile(self.
|
|
227
|
+
id_embed = torch.tile(self.site_id_embedding(self._ids), (batch_size, 1, 1))
|
|
228
228
|
# Each concated (site sequence, site ID embedding) is processed with encoder
|
|
229
229
|
x_seq_in = torch.cat((site_seqs, id_embed), dim=2).flatten(0, 1)
|
|
230
230
|
key = self._key_encoder(x_seq_in)
|
pvnet/optimizers.py
CHANGED
|
@@ -65,7 +65,7 @@ class AbstractOptimizer(ABC):
|
|
|
65
65
|
"""
|
|
66
66
|
|
|
67
67
|
@abstractmethod
|
|
68
|
-
def __call__(self):
|
|
68
|
+
def __call__(self, model: Module):
|
|
69
69
|
"""Abstract call"""
|
|
70
70
|
pass
|
|
71
71
|
|
|
@@ -129,19 +129,18 @@ class EmbAdamWReduceLROnPlateau(AbstractOptimizer):
|
|
|
129
129
|
{"params": decay, "weight_decay": self.weight_decay},
|
|
130
130
|
{"params": no_decay, "weight_decay": 0.0},
|
|
131
131
|
]
|
|
132
|
+
monitor = "quantile_loss/val" if model.use_quantile_regression else "MAE/val"
|
|
132
133
|
opt = torch.optim.AdamW(optim_groups, lr=self.lr, **self.opt_kwargs)
|
|
133
|
-
|
|
134
134
|
sch = torch.optim.lr_scheduler.ReduceLROnPlateau(
|
|
135
135
|
opt,
|
|
136
136
|
factor=self.factor,
|
|
137
137
|
patience=self.patience,
|
|
138
138
|
threshold=self.threshold,
|
|
139
139
|
)
|
|
140
|
-
|
|
141
|
-
"
|
|
142
|
-
"
|
|
140
|
+
return {
|
|
141
|
+
"optimizer": opt,
|
|
142
|
+
"lr_scheduler": {"scheduler": sch, "monitor": monitor},
|
|
143
143
|
}
|
|
144
|
-
return [opt], [sch]
|
|
145
144
|
|
|
146
145
|
|
|
147
146
|
class AdamWReduceLROnPlateau(AbstractOptimizer):
|
|
@@ -153,15 +152,13 @@ class AdamWReduceLROnPlateau(AbstractOptimizer):
|
|
|
153
152
|
patience: int = 3,
|
|
154
153
|
factor: float = 0.5,
|
|
155
154
|
threshold: float = 2e-4,
|
|
156
|
-
step_freq=None,
|
|
157
155
|
**opt_kwargs,
|
|
158
156
|
):
|
|
159
157
|
"""AdamW optimizer and reduce on plateau scheduler"""
|
|
160
|
-
self.
|
|
158
|
+
self.lr = lr
|
|
161
159
|
self.patience = patience
|
|
162
160
|
self.factor = factor
|
|
163
161
|
self.threshold = threshold
|
|
164
|
-
self.step_freq = step_freq
|
|
165
162
|
self.opt_kwargs = opt_kwargs
|
|
166
163
|
|
|
167
164
|
def _call_multi(self, model):
|
|
@@ -169,7 +166,7 @@ class AdamWReduceLROnPlateau(AbstractOptimizer):
|
|
|
169
166
|
|
|
170
167
|
group_args = []
|
|
171
168
|
|
|
172
|
-
for key in self.
|
|
169
|
+
for key in self.lr.keys():
|
|
173
170
|
if key == "default":
|
|
174
171
|
continue
|
|
175
172
|
|
|
@@ -178,43 +175,38 @@ class AdamWReduceLROnPlateau(AbstractOptimizer):
|
|
|
178
175
|
if param_name.startswith(key):
|
|
179
176
|
submodule_params += [remaining_params.pop(param_name)]
|
|
180
177
|
|
|
181
|
-
group_args += [{"params": submodule_params, "lr": self.
|
|
178
|
+
group_args += [{"params": submodule_params, "lr": self.lr[key]}]
|
|
182
179
|
|
|
183
180
|
remaining_params = [p for k, p in remaining_params.items()]
|
|
184
181
|
group_args += [{"params": remaining_params}]
|
|
185
|
-
|
|
186
|
-
opt = torch.optim.AdamW(
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
182
|
+
monitor = "quantile_loss/val" if model.use_quantile_regression else "MAE/val"
|
|
183
|
+
opt = torch.optim.AdamW(group_args, lr=self.lr["default"], **self.opt_kwargs)
|
|
184
|
+
sch = torch.optim.lr_scheduler.ReduceLROnPlateau(
|
|
185
|
+
opt,
|
|
186
|
+
factor=self.factor,
|
|
187
|
+
patience=self.patience,
|
|
188
|
+
threshold=self.threshold,
|
|
190
189
|
)
|
|
191
|
-
|
|
192
|
-
"
|
|
193
|
-
|
|
194
|
-
factor=self.factor,
|
|
195
|
-
patience=self.patience,
|
|
196
|
-
threshold=self.threshold,
|
|
197
|
-
),
|
|
198
|
-
"monitor": "quantile_loss/val" if model.use_quantile_regression else "MAE/val",
|
|
190
|
+
return {
|
|
191
|
+
"optimizer": opt,
|
|
192
|
+
"lr_scheduler": {"scheduler": sch, "monitor": monitor},
|
|
199
193
|
}
|
|
200
194
|
|
|
201
|
-
return [opt], [sch]
|
|
202
195
|
|
|
203
196
|
def __call__(self, model):
|
|
204
197
|
"""Return optimizer"""
|
|
205
|
-
if
|
|
198
|
+
if isinstance(self.lr, dict):
|
|
206
199
|
return self._call_multi(model)
|
|
207
200
|
else:
|
|
208
|
-
|
|
209
|
-
opt = torch.optim.AdamW(model.parameters(), lr=
|
|
201
|
+
monitor = "quantile_loss/val" if model.use_quantile_regression else "MAE/val"
|
|
202
|
+
opt = torch.optim.AdamW(model.parameters(), lr=self.lr, **self.opt_kwargs)
|
|
210
203
|
sch = torch.optim.lr_scheduler.ReduceLROnPlateau(
|
|
211
204
|
opt,
|
|
212
205
|
factor=self.factor,
|
|
213
206
|
patience=self.patience,
|
|
214
207
|
threshold=self.threshold,
|
|
215
208
|
)
|
|
216
|
-
|
|
217
|
-
"
|
|
218
|
-
"
|
|
209
|
+
return {
|
|
210
|
+
"optimizer": opt,
|
|
211
|
+
"lr_scheduler": {"scheduler": sch, "monitor": monitor},
|
|
219
212
|
}
|
|
220
|
-
return [opt], [sch]
|
|
@@ -109,7 +109,7 @@ class PVNetLightningModule(pl.LightningModule):
|
|
|
109
109
|
losses = self._calculate_common_losses(y, y_hat)
|
|
110
110
|
losses = {f"{k}/train": v for k, v in losses.items()}
|
|
111
111
|
|
|
112
|
-
self.log_dict(losses, on_step=True, on_epoch=True)
|
|
112
|
+
self.log_dict(losses, on_step=True, on_epoch=True, batch_size=y.size(0))
|
|
113
113
|
|
|
114
114
|
if self.model.use_quantile_regression:
|
|
115
115
|
opt_target = losses["quantile_loss/train"]
|
|
@@ -256,11 +256,28 @@ class PVNetLightningModule(pl.LightningModule):
|
|
|
256
256
|
|
|
257
257
|
# Calculate the persistance losses - we only need to do this once per training run
|
|
258
258
|
# not every epoch
|
|
259
|
-
if self.current_epoch
|
|
259
|
+
if self.current_epoch==0:
|
|
260
|
+
# Need to find last valid value before forecast
|
|
261
|
+
target_data = batch["generation"]
|
|
262
|
+
history_data = target_data[:, :-(self.model.forecast_len)]
|
|
263
|
+
|
|
264
|
+
# Find where values aren't dropped
|
|
265
|
+
valid_mask = history_data >= 0
|
|
266
|
+
|
|
267
|
+
# Last valid value index for each sample
|
|
268
|
+
flipped_mask = valid_mask.float().flip(dims=[1])
|
|
269
|
+
last_valid_indices_flipped = torch.argmax(flipped_mask, dim=1)
|
|
270
|
+
last_valid_indices = history_data.shape[1] - 1 - last_valid_indices_flipped
|
|
271
|
+
|
|
272
|
+
# Grab those last valid values
|
|
273
|
+
batch_indices = torch.arange(
|
|
274
|
+
history_data.shape[0],
|
|
275
|
+
device=history_data.device
|
|
276
|
+
)
|
|
277
|
+
last_valid_values = history_data[batch_indices, last_valid_indices]
|
|
278
|
+
|
|
260
279
|
y_persist = (
|
|
261
|
-
|
|
262
|
-
.unsqueeze(1)
|
|
263
|
-
.expand(-1, self.model.forecast_len)
|
|
280
|
+
last_valid_values.unsqueeze(1).expand(-1, self.model.forecast_len)
|
|
264
281
|
)
|
|
265
282
|
mae_step_persist, mse_step_persist = self._calculate_step_metrics(y, y_persist)
|
|
266
283
|
self._val_persistence_horizon_maes.append(mae_step_persist)
|
|
@@ -272,7 +289,7 @@ class PVNetLightningModule(pl.LightningModule):
|
|
|
272
289
|
)
|
|
273
290
|
|
|
274
291
|
# Log the metrics
|
|
275
|
-
self.log_dict(losses, on_step=False, on_epoch=True)
|
|
292
|
+
self.log_dict(losses, on_step=False, on_epoch=True, batch_size=y.size(0))
|
|
276
293
|
|
|
277
294
|
def on_validation_epoch_end(self) -> None:
|
|
278
295
|
"""Run on epoch end"""
|
|
@@ -1,12 +1,12 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: PVNet
|
|
3
|
-
Version: 5.3.
|
|
3
|
+
Version: 5.3.5
|
|
4
4
|
Summary: PVNet
|
|
5
5
|
Author-email: Peter Dudfield <info@openclimatefix.org>
|
|
6
6
|
Requires-Python: <3.14,>=3.11
|
|
7
7
|
Description-Content-Type: text/markdown
|
|
8
8
|
License-File: LICENSE
|
|
9
|
-
Requires-Dist: ocf-data-sampler>=0.
|
|
9
|
+
Requires-Dist: ocf-data-sampler>=1.0.9
|
|
10
10
|
Requires-Dist: numpy
|
|
11
11
|
Requires-Dist: pandas
|
|
12
12
|
Requires-Dist: matplotlib
|
|
@@ -1,14 +1,14 @@
|
|
|
1
1
|
pvnet/__init__.py,sha256=TAZm88TJ5ieL1XjEyRg1LciIGuSScEucdAruQLfM92I,25
|
|
2
2
|
pvnet/datamodule.py,sha256=wc1RQfFhgW9Hxyw7vrpFERhOd2FmjDsO1x49J2erOYk,5750
|
|
3
3
|
pvnet/load_model.py,sha256=P1QODX_mJRnKZ_kIll9BlOjK_A1W4YM3QG-mZd-2Mcc,3852
|
|
4
|
-
pvnet/optimizers.py,sha256=
|
|
4
|
+
pvnet/optimizers.py,sha256=DZ74KcFQV226zwu7-qAzofTMTYeIyScox4Kqbq30WWY,6440
|
|
5
5
|
pvnet/utils.py,sha256=L3MDF5m1Ez_btAZZ8t-T5wXLzFmyj7UZtorA91DEpFw,6003
|
|
6
6
|
pvnet/models/__init__.py,sha256=owzZ9xkD0DRTT51mT2Dx_p96oJjwDz57xo_MaMIEosk,145
|
|
7
7
|
pvnet/models/base_model.py,sha256=V-vBqtzZc_c8Ho5hVo_ikq2wzZ7hsAIM7I4vhzGDfNc,16051
|
|
8
8
|
pvnet/models/ensemble.py,sha256=USpNQ0O5eiffapLPE9T6gR-uK9f_3E4pX3DK7Lmkn2U,2228
|
|
9
9
|
pvnet/models/late_fusion/__init__.py,sha256=Jf0B-E0_5IvSBFoj1wvnPtwYDxs4pRIFm5qHv--Bbps,26
|
|
10
10
|
pvnet/models/late_fusion/basic_blocks.py,sha256=_cYGVyAIyEJS4wd-DEAXQXu0br66guZJn3ugoebWqZ0,1479
|
|
11
|
-
pvnet/models/late_fusion/late_fusion.py,sha256=
|
|
11
|
+
pvnet/models/late_fusion/late_fusion.py,sha256=r05RJvw2-ZQgWJobOGq1g4rlMJQjGM0UzG3syA4T0qo,15617
|
|
12
12
|
pvnet/models/late_fusion/encoders/__init__.py,sha256=bLBQdnCeLYhwISW0t88ZZBz-ebS94m7ZwBcsofWMHR4,51
|
|
13
13
|
pvnet/models/late_fusion/encoders/basic_blocks.py,sha256=DGkFFIZv4S4FLTaAIOrAngAFBpgZQHfkGM4dzezZLk4,3044
|
|
14
14
|
pvnet/models/late_fusion/encoders/encoders3d.py,sha256=9fmqVHO73F-jN62w065cgEQI_icNFC2nQH6ZEGvTHxU,7116
|
|
@@ -17,13 +17,13 @@ pvnet/models/late_fusion/linear_networks/basic_blocks.py,sha256=RnwdeuX_-itY4ncM
|
|
|
17
17
|
pvnet/models/late_fusion/linear_networks/networks.py,sha256=exEIz_Z85f8nSwcvp4wqiiLECEAg9YbkKhSZJvFy75M,2231
|
|
18
18
|
pvnet/models/late_fusion/site_encoders/__init__.py,sha256=QoUiiWWFf12vEpdkw0gO4TWpOEoI_tgAyUFCWFFpYAk,45
|
|
19
19
|
pvnet/models/late_fusion/site_encoders/basic_blocks.py,sha256=iEB_N7ZL5HMQ1hZM6H32A71GCwP7YbErUx0oQF21PQM,1042
|
|
20
|
-
pvnet/models/late_fusion/site_encoders/encoders.py,sha256=
|
|
20
|
+
pvnet/models/late_fusion/site_encoders/encoders.py,sha256=PemEUa_Wv5pFWw3usPKEtXcvs_MX2LSrO6nhldO_QVk,11320
|
|
21
21
|
pvnet/training/__init__.py,sha256=FKxmPZ59Vuj5_mXomN4saJ3En5M-aDMxSs6OttTQOcg,49
|
|
22
|
-
pvnet/training/lightning_module.py,sha256=
|
|
22
|
+
pvnet/training/lightning_module.py,sha256=hmvne9DQauWpG61sRK-t8MTZRVwdywaEFCs0VFVRuMs,13522
|
|
23
23
|
pvnet/training/plots.py,sha256=7JtjA9zIotuoKZ2l0fbS-FZDB48TcIk_-XLA2EWVMv4,2448
|
|
24
24
|
pvnet/training/train.py,sha256=Sry2wYgggUmtIB-k_umFts7xMr2roEL76NCu9ySbLUY,4107
|
|
25
|
-
pvnet-5.3.
|
|
26
|
-
pvnet-5.3.
|
|
27
|
-
pvnet-5.3.
|
|
28
|
-
pvnet-5.3.
|
|
29
|
-
pvnet-5.3.
|
|
25
|
+
pvnet-5.3.5.dist-info/licenses/LICENSE,sha256=tKUnlSmcLBWMJWkHx3UjZGdrjs9LidGwLo0jsBUBAwU,1077
|
|
26
|
+
pvnet-5.3.5.dist-info/METADATA,sha256=rIlZGmFiIzkMpG_5U-6SrsdDW6fIke667JAG79g3KN4,16479
|
|
27
|
+
pvnet-5.3.5.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
28
|
+
pvnet-5.3.5.dist-info/top_level.txt,sha256=4mg6WjeW05SR7pg3-Q4JRE2yAoutHYpspOsiUzYVNv0,6
|
|
29
|
+
pvnet-5.3.5.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|