PVNet 5.3.14__tar.gz → 5.3.16__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pvnet-5.3.14 → pvnet-5.3.16}/PKG-INFO +1 -1
- {pvnet-5.3.14 → pvnet-5.3.16}/PVNet.egg-info/PKG-INFO +1 -1
- {pvnet-5.3.14 → pvnet-5.3.16}/pvnet/training/lightning_module.py +22 -14
- {pvnet-5.3.14 → pvnet-5.3.16}/pvnet/training/plots.py +14 -9
- {pvnet-5.3.14 → pvnet-5.3.16}/LICENSE +0 -0
- {pvnet-5.3.14 → pvnet-5.3.16}/PVNet.egg-info/SOURCES.txt +0 -0
- {pvnet-5.3.14 → pvnet-5.3.16}/PVNet.egg-info/dependency_links.txt +0 -0
- {pvnet-5.3.14 → pvnet-5.3.16}/PVNet.egg-info/requires.txt +0 -0
- {pvnet-5.3.14 → pvnet-5.3.16}/PVNet.egg-info/top_level.txt +0 -0
- {pvnet-5.3.14 → pvnet-5.3.16}/README.md +0 -0
- {pvnet-5.3.14 → pvnet-5.3.16}/pvnet/__init__.py +0 -0
- {pvnet-5.3.14 → pvnet-5.3.16}/pvnet/datamodule.py +0 -0
- {pvnet-5.3.14 → pvnet-5.3.16}/pvnet/load_model.py +0 -0
- {pvnet-5.3.14 → pvnet-5.3.16}/pvnet/models/__init__.py +0 -0
- {pvnet-5.3.14 → pvnet-5.3.16}/pvnet/models/base_model.py +0 -0
- {pvnet-5.3.14 → pvnet-5.3.16}/pvnet/models/ensemble.py +0 -0
- {pvnet-5.3.14 → pvnet-5.3.16}/pvnet/models/late_fusion/__init__.py +0 -0
- {pvnet-5.3.14 → pvnet-5.3.16}/pvnet/models/late_fusion/basic_blocks.py +0 -0
- {pvnet-5.3.14 → pvnet-5.3.16}/pvnet/models/late_fusion/encoders/__init__.py +0 -0
- {pvnet-5.3.14 → pvnet-5.3.16}/pvnet/models/late_fusion/encoders/basic_blocks.py +0 -0
- {pvnet-5.3.14 → pvnet-5.3.16}/pvnet/models/late_fusion/encoders/encoders3d.py +0 -0
- {pvnet-5.3.14 → pvnet-5.3.16}/pvnet/models/late_fusion/late_fusion.py +0 -0
- {pvnet-5.3.14 → pvnet-5.3.16}/pvnet/models/late_fusion/linear_networks/__init__.py +0 -0
- {pvnet-5.3.14 → pvnet-5.3.16}/pvnet/models/late_fusion/linear_networks/basic_blocks.py +0 -0
- {pvnet-5.3.14 → pvnet-5.3.16}/pvnet/models/late_fusion/linear_networks/networks.py +0 -0
- {pvnet-5.3.14 → pvnet-5.3.16}/pvnet/models/late_fusion/site_encoders/__init__.py +0 -0
- {pvnet-5.3.14 → pvnet-5.3.16}/pvnet/models/late_fusion/site_encoders/basic_blocks.py +0 -0
- {pvnet-5.3.14 → pvnet-5.3.16}/pvnet/models/late_fusion/site_encoders/encoders.py +0 -0
- {pvnet-5.3.14 → pvnet-5.3.16}/pvnet/optimizers.py +0 -0
- {pvnet-5.3.14 → pvnet-5.3.16}/pvnet/training/__init__.py +0 -0
- {pvnet-5.3.14 → pvnet-5.3.16}/pvnet/training/train.py +0 -0
- {pvnet-5.3.14 → pvnet-5.3.16}/pvnet/utils.py +0 -0
- {pvnet-5.3.14 → pvnet-5.3.16}/pyproject.toml +0 -0
- {pvnet-5.3.14 → pvnet-5.3.16}/setup.cfg +0 -0
- {pvnet-5.3.14 → pvnet-5.3.16}/tests/test_datamodule.py +0 -0
- {pvnet-5.3.14 → pvnet-5.3.16}/tests/test_end2end.py +0 -0
|
@@ -206,23 +206,25 @@ class PVNetLightningModule(pl.LightningModule):
|
|
|
206
206
|
if self.trainer.sanity_checking and plot_num == 0:
|
|
207
207
|
validate_batch_against_config(batch=batch, model=self.model)
|
|
208
208
|
|
|
209
|
-
|
|
209
|
+
# Save example forecast plots via logger
|
|
210
|
+
if self.logger:
|
|
210
211
|
y_hat = self.model(batch)
|
|
211
212
|
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
213
|
+
fig = plot_sample_forecasts(
|
|
214
|
+
batch,
|
|
215
|
+
y_hat,
|
|
216
|
+
quantiles=self.model.output_quantiles,
|
|
217
|
+
key_to_plot="generation",
|
|
218
|
+
)
|
|
218
219
|
|
|
219
|
-
|
|
220
|
+
plot_name = f"val_forecast_samples/sample_set_{plot_num}"
|
|
220
221
|
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
222
|
+
self.logger.experiment.log(
|
|
223
|
+
{plot_name: wandb.Image(fig)},
|
|
224
|
+
step=self.trainer.global_step,
|
|
225
|
+
)
|
|
224
226
|
|
|
225
|
-
|
|
227
|
+
plt.close(fig)
|
|
226
228
|
|
|
227
229
|
def validation_step(self, batch: TensorBatch, batch_idx: int) -> None:
|
|
228
230
|
"""Run validation step"""
|
|
@@ -346,7 +348,10 @@ class PVNetLightningModule(pl.LightningModule):
|
|
|
346
348
|
title="Val horizon loss curve",
|
|
347
349
|
)
|
|
348
350
|
|
|
349
|
-
wandb.log(
|
|
351
|
+
wandb.log(
|
|
352
|
+
{"val_horizon_mae_plot": horizon_mae_plot},
|
|
353
|
+
step=self.trainer.global_step,
|
|
354
|
+
)
|
|
350
355
|
|
|
351
356
|
# Create persistence horizon accuracy curve but only on first epoch
|
|
352
357
|
if self.current_epoch == 0:
|
|
@@ -357,4 +362,7 @@ class PVNetLightningModule(pl.LightningModule):
|
|
|
357
362
|
ylabel="MAE",
|
|
358
363
|
title="Val persistence horizon loss curve",
|
|
359
364
|
)
|
|
360
|
-
wandb.log(
|
|
365
|
+
wandb.log(
|
|
366
|
+
{"persistence_val_horizon_mae_plot": persist_horizon_mae_plot},
|
|
367
|
+
step=self.trainer.global_step,
|
|
368
|
+
)
|
|
@@ -31,7 +31,8 @@ def plot_sample_forecasts(
|
|
|
31
31
|
"""Plot a batch of data and the forecast from that batch"""
|
|
32
32
|
|
|
33
33
|
y = batch[key_to_plot].cpu().numpy()
|
|
34
|
-
y_hat = y_hat.cpu().numpy()
|
|
34
|
+
y_hat = y_hat.cpu().numpy()
|
|
35
|
+
forecast_length = y_hat.shape[1]
|
|
35
36
|
ids = batch["location_id"].cpu().numpy().squeeze()
|
|
36
37
|
times_utc = pd.to_datetime(
|
|
37
38
|
batch["time_utc"].cpu().numpy().squeeze().astype("datetime64[ns]")
|
|
@@ -42,30 +43,34 @@ def plot_sample_forecasts(
|
|
|
42
43
|
|
|
43
44
|
for i, ax in enumerate(axes.ravel()[:batch_size]):
|
|
44
45
|
|
|
45
|
-
|
|
46
|
+
# Crop to the forecast window only
|
|
47
|
+
forecast_times = times_utc[i][-forecast_length:]
|
|
48
|
+
y_no_history = y[i][-forecast_length:]
|
|
49
|
+
|
|
50
|
+
ax.plot(forecast_times, y_no_history, marker=".", color="k", label=r"$y$")
|
|
46
51
|
|
|
47
52
|
if quantiles is None:
|
|
48
53
|
ax.plot(
|
|
49
|
-
|
|
50
|
-
y_hat[i],
|
|
51
|
-
marker=".",
|
|
52
|
-
color="r",
|
|
54
|
+
forecast_times,
|
|
55
|
+
y_hat[i],
|
|
56
|
+
marker=".",
|
|
57
|
+
color="r",
|
|
53
58
|
label=r"$\hat{y}$",
|
|
54
59
|
)
|
|
55
60
|
else:
|
|
56
61
|
cm = pylab.get_cmap("twilight")
|
|
57
62
|
for nq, q in enumerate(quantiles):
|
|
58
63
|
ax.plot(
|
|
59
|
-
|
|
64
|
+
forecast_times,
|
|
60
65
|
y_hat[i, :, nq],
|
|
61
66
|
color=cm(q),
|
|
62
67
|
label=r"$\hat{y}$" + f"({q})",
|
|
63
68
|
alpha=0.7,
|
|
64
69
|
)
|
|
65
70
|
|
|
66
|
-
ax.set_title(f"ID: {ids[i]} | {
|
|
71
|
+
ax.set_title(f"ID: {ids[i]} | {forecast_times[0].date()}", fontsize="small")
|
|
67
72
|
|
|
68
|
-
xticks = [t for t in
|
|
73
|
+
xticks = [t for t in forecast_times if t.minute == 0][::2]
|
|
69
74
|
ax.set_xticks(ticks=xticks, labels=[f"{t.hour:02}" for t in xticks], rotation=90)
|
|
70
75
|
ax.grid()
|
|
71
76
|
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|