PVNet 5.3.2.tar.gz → 5.3.4.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pvnet-5.3.2 → pvnet-5.3.4}/PKG-INFO +1 -1
- {pvnet-5.3.2 → pvnet-5.3.4}/PVNet.egg-info/PKG-INFO +1 -1
- {pvnet-5.3.2 → pvnet-5.3.4}/pvnet/optimizers.py +24 -32
- {pvnet-5.3.2 → pvnet-5.3.4}/pvnet/training/lightning_module.py +23 -6
- {pvnet-5.3.2 → pvnet-5.3.4}/LICENSE +0 -0
- {pvnet-5.3.2 → pvnet-5.3.4}/PVNet.egg-info/SOURCES.txt +0 -0
- {pvnet-5.3.2 → pvnet-5.3.4}/PVNet.egg-info/dependency_links.txt +0 -0
- {pvnet-5.3.2 → pvnet-5.3.4}/PVNet.egg-info/requires.txt +0 -0
- {pvnet-5.3.2 → pvnet-5.3.4}/PVNet.egg-info/top_level.txt +0 -0
- {pvnet-5.3.2 → pvnet-5.3.4}/README.md +0 -0
- {pvnet-5.3.2 → pvnet-5.3.4}/pvnet/__init__.py +0 -0
- {pvnet-5.3.2 → pvnet-5.3.4}/pvnet/datamodule.py +0 -0
- {pvnet-5.3.2 → pvnet-5.3.4}/pvnet/load_model.py +0 -0
- {pvnet-5.3.2 → pvnet-5.3.4}/pvnet/models/__init__.py +0 -0
- {pvnet-5.3.2 → pvnet-5.3.4}/pvnet/models/base_model.py +0 -0
- {pvnet-5.3.2 → pvnet-5.3.4}/pvnet/models/ensemble.py +0 -0
- {pvnet-5.3.2 → pvnet-5.3.4}/pvnet/models/late_fusion/__init__.py +0 -0
- {pvnet-5.3.2 → pvnet-5.3.4}/pvnet/models/late_fusion/basic_blocks.py +0 -0
- {pvnet-5.3.2 → pvnet-5.3.4}/pvnet/models/late_fusion/encoders/__init__.py +0 -0
- {pvnet-5.3.2 → pvnet-5.3.4}/pvnet/models/late_fusion/encoders/basic_blocks.py +0 -0
- {pvnet-5.3.2 → pvnet-5.3.4}/pvnet/models/late_fusion/encoders/encoders3d.py +0 -0
- {pvnet-5.3.2 → pvnet-5.3.4}/pvnet/models/late_fusion/late_fusion.py +0 -0
- {pvnet-5.3.2 → pvnet-5.3.4}/pvnet/models/late_fusion/linear_networks/__init__.py +0 -0
- {pvnet-5.3.2 → pvnet-5.3.4}/pvnet/models/late_fusion/linear_networks/basic_blocks.py +0 -0
- {pvnet-5.3.2 → pvnet-5.3.4}/pvnet/models/late_fusion/linear_networks/networks.py +0 -0
- {pvnet-5.3.2 → pvnet-5.3.4}/pvnet/models/late_fusion/site_encoders/__init__.py +0 -0
- {pvnet-5.3.2 → pvnet-5.3.4}/pvnet/models/late_fusion/site_encoders/basic_blocks.py +0 -0
- {pvnet-5.3.2 → pvnet-5.3.4}/pvnet/models/late_fusion/site_encoders/encoders.py +0 -0
- {pvnet-5.3.2 → pvnet-5.3.4}/pvnet/training/__init__.py +0 -0
- {pvnet-5.3.2 → pvnet-5.3.4}/pvnet/training/plots.py +0 -0
- {pvnet-5.3.2 → pvnet-5.3.4}/pvnet/training/train.py +0 -0
- {pvnet-5.3.2 → pvnet-5.3.4}/pvnet/utils.py +0 -0
- {pvnet-5.3.2 → pvnet-5.3.4}/pyproject.toml +0 -0
- {pvnet-5.3.2 → pvnet-5.3.4}/setup.cfg +0 -0
- {pvnet-5.3.2 → pvnet-5.3.4}/tests/test_datamodule.py +0 -0
- {pvnet-5.3.2 → pvnet-5.3.4}/tests/test_end2end.py +0 -0
pvnet/optimizers.py

@@ -65,7 +65,7 @@ class AbstractOptimizer(ABC):
     """
 
     @abstractmethod
-    def __call__(self):
+    def __call__(self, model: Module):
         """Abstract call"""
         pass
 
@@ -129,19 +129,18 @@ class EmbAdamWReduceLROnPlateau(AbstractOptimizer):
             {"params": decay, "weight_decay": self.weight_decay},
             {"params": no_decay, "weight_decay": 0.0},
         ]
+        monitor = "quantile_loss/val" if model.use_quantile_regression else "MAE/val"
         opt = torch.optim.AdamW(optim_groups, lr=self.lr, **self.opt_kwargs)
-
         sch = torch.optim.lr_scheduler.ReduceLROnPlateau(
             opt,
             factor=self.factor,
             patience=self.patience,
             threshold=self.threshold,
         )
-
-        "
-        "
+        return {
+            "optimizer": opt,
+            "lr_scheduler": {"scheduler": sch, "monitor": monitor},
         }
-        return [opt], [sch]
 
 
 class AdamWReduceLROnPlateau(AbstractOptimizer):
@@ -153,15 +152,13 @@ class AdamWReduceLROnPlateau(AbstractOptimizer):
         patience: int = 3,
         factor: float = 0.5,
         threshold: float = 2e-4,
-        step_freq=None,
         **opt_kwargs,
     ):
         """AdamW optimizer and reduce on plateau scheduler"""
-        self.
+        self.lr = lr
         self.patience = patience
         self.factor = factor
         self.threshold = threshold
-        self.step_freq = step_freq
         self.opt_kwargs = opt_kwargs
 
     def _call_multi(self, model):
@@ -169,7 +166,7 @@ class AdamWReduceLROnPlateau(AbstractOptimizer):
 
         group_args = []
 
-        for key in self.
+        for key in self.lr.keys():
             if key == "default":
                 continue
 
@@ -178,43 +175,38 @@ class AdamWReduceLROnPlateau(AbstractOptimizer):
                 if param_name.startswith(key):
                     submodule_params += [remaining_params.pop(param_name)]
 
-            group_args += [{"params": submodule_params, "lr": self.
+            group_args += [{"params": submodule_params, "lr": self.lr[key]}]
 
         remaining_params = [p for k, p in remaining_params.items()]
         group_args += [{"params": remaining_params}]
-
-        opt = torch.optim.AdamW(
-
-
-
+        monitor = "quantile_loss/val" if model.use_quantile_regression else "MAE/val"
+        opt = torch.optim.AdamW(group_args, lr=self.lr["default"], **self.opt_kwargs)
+        sch = torch.optim.lr_scheduler.ReduceLROnPlateau(
+            opt,
+            factor=self.factor,
+            patience=self.patience,
+            threshold=self.threshold,
         )
-
-        "
-
-                factor=self.factor,
-                patience=self.patience,
-                threshold=self.threshold,
-            ),
-            "monitor": "quantile_loss/val" if model.use_quantile_regression else "MAE/val",
+        return {
+            "optimizer": opt,
+            "lr_scheduler": {"scheduler": sch, "monitor": monitor},
         }
 
-        return [opt], [sch]
 
     def __call__(self, model):
         """Return optimizer"""
-        if
+        if isinstance(self.lr, dict):
             return self._call_multi(model)
         else:
-
-            opt = torch.optim.AdamW(model.parameters(), lr=
+            monitor = "quantile_loss/val" if model.use_quantile_regression else "MAE/val"
+            opt = torch.optim.AdamW(model.parameters(), lr=self.lr, **self.opt_kwargs)
            sch = torch.optim.lr_scheduler.ReduceLROnPlateau(
                 opt,
                 factor=self.factor,
                 patience=self.patience,
                 threshold=self.threshold,
             )
-
-            "
-            "
+            return {
+                "optimizer": opt,
+                "lr_scheduler": {"scheduler": sch, "monitor": monitor},
             }
-            return [opt], [sch]
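Note: both optimizer classes above now pick a `monitor` metric and return a single configuration dict instead of `return [opt], [sch]`. The following is a minimal, hypothetical LightningModule (not taken from PVNet; `ToyModule`, its layer, and the metric keys are assumptions) showing how PyTorch Lightning consumes that dict shape and why the monitored metric must be logged when `ReduceLROnPlateau` is the scheduler.

```python
# Hypothetical sketch, not PVNet code.
import torch
from torch import nn
import lightning.pytorch as pl  # on older installs: import pytorch_lightning as pl


class ToyModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(8, 1)
        self.use_quantile_regression = False  # mirrors the flag checked in the diff

    def training_step(self, batch, batch_idx):
        x, y = batch
        return nn.functional.l1_loss(self.layer(x), y)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        # The monitored metric must actually be logged, otherwise the
        # ReduceLROnPlateau scheduler is never stepped.
        self.log("MAE/val", nn.functional.l1_loss(self.layer(x), y))

    def configure_optimizers(self):
        # Same shape as the dict now returned by the PVNet optimizer classes:
        # Lightning reads "monitor" and steps ReduceLROnPlateau on that metric.
        monitor = "quantile_loss/val" if self.use_quantile_regression else "MAE/val"
        opt = torch.optim.AdamW(self.parameters(), lr=1e-3)
        sch = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, factor=0.5, patience=3)
        return {
            "optimizer": opt,
            "lr_scheduler": {"scheduler": sch, "monitor": monitor},
        }


if __name__ == "__main__":
    print(ToyModule().configure_optimizers()["lr_scheduler"]["monitor"])  # MAE/val
```

Lightning also accepts other return shapes from `configure_optimizers` (including two lists), but the dict form is the documented way to attach a monitor metric to a plateau scheduler.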
pvnet/training/lightning_module.py

@@ -109,7 +109,7 @@ class PVNetLightningModule(pl.LightningModule):
         losses = self._calculate_common_losses(y, y_hat)
         losses = {f"{k}/train": v for k, v in losses.items()}
 
-        self.log_dict(losses, on_step=True, on_epoch=True)
+        self.log_dict(losses, on_step=True, on_epoch=True, batch_size=y.size(0))
 
         if self.model.use_quantile_regression:
             opt_target = losses["quantile_loss/train"]
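The `log_dict` changes in this file pass an explicit `batch_size`. Lightning otherwise tries to infer the batch size from the batch object when computing epoch-level averages, which can be ambiguous for dict-style batches; `batch_size=y.size(0)` makes the per-sample weighting explicit. A small self-contained illustration (not PVNet code; the numbers are made up) of why the weighting matters:

```python
# Hypothetical illustration of batch-size-weighted epoch averaging.
import torch

step_losses = torch.tensor([0.10, 0.20, 0.40])  # per-step mean losses
batch_sizes = torch.tensor([32.0, 32.0, 8.0])   # final batch is smaller

unweighted = step_losses.mean()
weighted = (step_losses * batch_sizes).sum() / batch_sizes.sum()
print(unweighted)  # tensor(0.2333)
print(weighted)    # tensor(0.1778), the correct per-sample epoch average
```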
@@ -256,11 +256,28 @@ class PVNetLightningModule(pl.LightningModule):
 
         # Calculate the persistance losses - we only need to do this once per training run
         # not every epoch
-        if self.current_epoch
+        if self.current_epoch==0:
+            # Need to find last valid value before forecast
+            target_data = batch["generation"]
+            history_data = target_data[:, :-(self.model.forecast_len)]
+
+            # Find where values aren't dropped
+            valid_mask = history_data >= 0
+
+            # Last valid value index for each sample
+            flipped_mask = valid_mask.float().flip(dims=[1])
+            last_valid_indices_flipped = torch.argmax(flipped_mask, dim=1)
+            last_valid_indices = history_data.shape[1] - 1 - last_valid_indices_flipped
+
+            # Grab those last valid values
+            batch_indices = torch.arange(
+                history_data.shape[0],
+                device=history_data.device
+            )
+            last_valid_values = history_data[batch_indices, last_valid_indices]
+
             y_persist = (
-
-                .unsqueeze(1)
-                .expand(-1, self.model.forecast_len)
+                last_valid_values.unsqueeze(1).expand(-1, self.model.forecast_len)
             )
             mae_step_persist, mse_step_persist = self._calculate_step_metrics(y, y_persist)
             self._val_persistence_horizon_maes.append(mae_step_persist)
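The block added above builds a persistence baseline from the last valid value in each sample's history, using a flip-then-argmax trick to locate that value per row. A small self-contained sketch (not PVNet code; encoding dropped values as negative numbers is an assumption matching the `>= 0` mask above):

```python
# Hypothetical demonstration of the flip + argmax "last valid value" trick.
import torch

history = torch.tensor(
    [
        [0.1, 0.3, -1.0, 0.5, -1.0],   # last valid value is 0.5 (index 3)
        [-1.0, 0.2, 0.4, -1.0, -1.0],  # last valid value is 0.4 (index 2)
    ]
)

valid_mask = history >= 0

# argmax returns the first maximal element, so flipping the mask along the
# time axis turns "last valid index" into "first valid index of the flip".
flipped = valid_mask.float().flip(dims=[1])
last_valid_idx = history.shape[1] - 1 - torch.argmax(flipped, dim=1)

batch_idx = torch.arange(history.shape[0])
last_valid_values = history[batch_idx, last_valid_idx]
print(last_valid_idx)     # tensor([3, 2])
print(last_valid_values)  # tensor([0.5000, 0.4000])

# Repeating the value across the horizon gives the persistence forecast.
forecast_len = 4
y_persist = last_valid_values.unsqueeze(1).expand(-1, forecast_len)
print(y_persist.shape)    # torch.Size([2, 4])
```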
@@ -272,7 +289,7 @@ class PVNetLightningModule(pl.LightningModule):
         )
 
         # Log the metrics
-        self.log_dict(losses, on_step=False, on_epoch=True)
+        self.log_dict(losses, on_step=False, on_epoch=True, batch_size=y.size(0))
 
     def on_validation_epoch_end(self) -> None:
         """Run on epoch end"""