PVNet 5.0.14__py3-none-any.whl → 5.0.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,5 +1,4 @@
1
1
  """Base model for all PVNet submodels"""
2
- import copy
3
2
  import logging
4
3
  import os
5
4
  import shutil
@@ -12,9 +11,7 @@ import torch
12
11
  import yaml
13
12
  from huggingface_hub import ModelCard, ModelCardData, snapshot_download
14
13
  from huggingface_hub.hf_api import HfApi
15
- from ocf_data_sampler.numpy_sample.common_types import TensorBatch
16
14
  from safetensors.torch import load_file, save_file
17
- from torchvision.transforms.functional import center_crop
18
15
 
19
16
  from pvnet.utils import (
20
17
  DATA_CONFIG_NAME,
@@ -437,69 +434,6 @@ class BaseModel(torch.nn.Module, HuggingfaceMixin):
437
434
  else:
438
435
  self.num_output_features = self.forecast_len
439
436
 
440
- def _adapt_batch(self, batch: TensorBatch) -> TensorBatch:
441
- """Slice batches into appropriate shapes for model.
442
-
443
- Returns a new batch dictionary with adapted data, leaving the original batch unchanged.
444
- We make some specific assumptions about the original batch and the derived sliced batch:
445
- - We are only limiting the future projections. I.e. we are never shrinking the batch from
446
- the left hand side of the time axis, only slicing it from the right
447
- - We are only shrinking the spatial crop of the satellite and NWP data
448
-
449
- """
450
- # Create a copy of the batch to avoid modifying the original
451
- new_batch = {key: copy.deepcopy(value) for key, value in batch.items()}
452
-
453
- if "gsp" in new_batch.keys():
454
- # Slice off the end of the GSP data
455
- gsp_len = self.forecast_len + self.history_len + 1
456
- new_batch["gsp"] = new_batch["gsp"][:, :gsp_len]
457
- new_batch["gsp_time_utc"] = new_batch["gsp_time_utc"][:, :gsp_len]
458
-
459
- if "site" in new_batch.keys():
460
- # Slice off the end of the site data
461
- site_len = self.forecast_len + self.history_len + 1
462
- new_batch["site"] = new_batch["site"][:, :site_len]
463
-
464
- # Slice all site related datetime coordinates and features
465
- site_time_keys = [
466
- "site_time_utc",
467
- "site_date_sin",
468
- "site_date_cos",
469
- "site_time_sin",
470
- "site_time_cos",
471
- ]
472
-
473
- for key in site_time_keys:
474
- if key in new_batch.keys():
475
- new_batch[key] = new_batch[key][:, :site_len]
476
-
477
- if self.include_sat:
478
- # Slice off the end of the satellite data and spatially crop
479
- # Shape: batch_size, seq_length, channel, height, width
480
- new_batch["satellite_actual"] = center_crop(
481
- new_batch["satellite_actual"][:, : self.sat_sequence_len],
482
- output_size=self.sat_encoder.image_size_pixels,
483
- )
484
-
485
- if self.include_nwp:
486
- # Slice off the end of the NWP data and spatially crop
487
- for nwp_source in self.nwp_encoders_dict:
488
- # shape: batch_size, seq_len, n_chans, height, width
489
- new_batch["nwp"][nwp_source]["nwp"] = center_crop(
490
- new_batch["nwp"][nwp_source]["nwp"],
491
- output_size=self.nwp_encoders_dict[nwp_source].image_size_pixels,
492
- )[:, : self.nwp_encoders_dict[nwp_source].sequence_length]
493
-
494
- if self.include_sun:
495
- sun_len = self.forecast_len + self.history_len + 1
496
- # Slice off end of solar coords
497
- for s in ["solar_azimuth", "solar_elevation"]:
498
- if s in new_batch.keys():
499
- new_batch[s] = new_batch[s][:, :sun_len]
500
-
501
- return new_batch
502
-
503
437
  def _quantiles_to_prediction(self, y_quantiles: torch.Tensor) -> torch.Tensor:
504
438
  """
505
439
  Convert network prediction into a point prediction.
@@ -61,7 +61,6 @@ class LateFusionModel(BaseModel):
61
61
  nwp_interval_minutes: DictConfig | None = None,
62
62
  pv_interval_minutes: int = 5,
63
63
  sat_interval_minutes: int = 5,
64
- adapt_batches: bool = False,
65
64
  ):
66
65
  """Neural network which combines information from different sources.
67
66
 
@@ -110,9 +109,6 @@ class LateFusionModel(BaseModel):
110
109
  data for each source
111
110
  pv_interval_minutes: The interval between each sample of the PV data
112
111
  sat_interval_minutes: The interval between each sample of the satellite data
113
- adapt_batches: If set to true, we attempt to slice the batches to the expected shape for
114
- the model to use. This allows us to overprepare batches and slice from them for the
115
- data we need for a model run.
116
112
  """
117
113
  super().__init__(
118
114
  history_minutes=history_minutes,
@@ -134,7 +130,6 @@ class LateFusionModel(BaseModel):
134
130
  self.add_image_embedding_channel = add_image_embedding_channel
135
131
  self.interval_minutes = interval_minutes
136
132
  self.min_sat_delay_minutes = min_sat_delay_minutes
137
- self.adapt_batches = adapt_batches
138
133
 
139
134
  if self.location_id_mapping is None:
140
135
  logger.warning(
@@ -272,9 +267,6 @@ class LateFusionModel(BaseModel):
272
267
  def forward(self, x: TensorBatch) -> torch.Tensor:
273
268
  """Run model forward"""
274
269
 
275
- if self.adapt_batches:
276
- x = self._adapt_batch(x)
277
-
278
270
  if self.use_id_embedding:
279
271
  # eg: x['gsp_id'] = [1] with location_id_mapping = {1:0}, would give [0]
280
272
  id = torch.tensor(
@@ -105,10 +105,6 @@ class PVNetLightningModule(pl.LightningModule):
105
105
  """Run training step"""
106
106
  y_hat = self.model(batch)
107
107
 
108
- # Batch may be adapted in the model forward method, would need adapting here too
109
- if self.model.adapt_batches:
110
- batch = self.model._adapt_batch(batch)
111
-
112
108
  y = batch[self.model._target_key][:, -self.model.forecast_len :]
113
109
 
114
110
  losses = self._calculate_common_losses(y, y_hat)
@@ -212,10 +208,6 @@ class PVNetLightningModule(pl.LightningModule):
212
208
  with torch.no_grad():
213
209
  y_hat = self.model(batch)
214
210
 
215
- # Batch may be adapted in the model forward method, would need adapting here too
216
- if self.model.adapt_batches:
217
- batch = self.model._adapt_batch(batch)
218
-
219
211
  fig = plot_sample_forecasts(
220
212
  batch,
221
213
  y_hat,
@@ -235,9 +227,6 @@ class PVNetLightningModule(pl.LightningModule):
235
227
  """Run validation step"""
236
228
 
237
229
  y_hat = self.model(batch)
238
- # Batch may be adapted in the model forward method, would need adapting here too
239
- if self.model.adapt_batches:
240
- batch = self.model._adapt_batch(batch)
241
230
 
242
231
  # Internally store the val predictions
243
232
  self._store_val_predictions(batch, y_hat)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: PVNet
3
- Version: 5.0.14
3
+ Version: 5.0.15
4
4
  Summary: PVNet
5
5
  Author-email: Peter Dudfield <info@openclimatefix.org>
6
6
  Requires-Python: >=3.11
@@ -14,7 +14,6 @@ Requires-Dist: xarray
14
14
  Requires-Dist: h5netcdf
15
15
  Requires-Dist: torch>=2.0.0
16
16
  Requires-Dist: lightning
17
- Requires-Dist: torchvision
18
17
  Requires-Dist: typer
19
18
  Requires-Dist: sqlalchemy
20
19
  Requires-Dist: fsspec[s3]
@@ -7,11 +7,11 @@ pvnet/data/base_datamodule.py,sha256=Ibz0RoSr15HT6tMCs6ftXTpMa-NOKAmEd5ky55MqEK0
7
7
  pvnet/data/site_datamodule.py,sha256=-KGxirGCBXVwcCREsjFkF7JDfa6NICv8bBDV6EILF_Q,962
8
8
  pvnet/data/uk_regional_datamodule.py,sha256=KA2_7DYuSggmD5b-XiXshXq8xmu36BjtFmy_pS7e4QE,1017
9
9
  pvnet/models/__init__.py,sha256=owzZ9xkD0DRTT51mT2Dx_p96oJjwDz57xo_MaMIEosk,145
10
- pvnet/models/base_model.py,sha256=sdVWUuJivQYmfK0CF2wzv8Hm-C50lJ--zttMScRYirY,19203
10
+ pvnet/models/base_model.py,sha256=4Sxal_8hq4F6irr2uG7GbtYZAjaOzUxmLBnO1U7fce0,16181
11
11
  pvnet/models/ensemble.py,sha256=1mFUEsl33kWcLL5d7zfDm9ypWxgAxBHgBiJLt0vwTeg,2363
12
12
  pvnet/models/late_fusion/__init__.py,sha256=Jf0B-E0_5IvSBFoj1wvnPtwYDxs4pRIFm5qHv--Bbps,26
13
13
  pvnet/models/late_fusion/basic_blocks.py,sha256=_cYGVyAIyEJS4wd-DEAXQXu0br66guZJn3ugoebWqZ0,1479
14
- pvnet/models/late_fusion/late_fusion.py,sha256=Lz4qrhIwVz7hrIR2sOzOoWAL2rKWWmUREIAtQs0zI8c,16131
14
+ pvnet/models/late_fusion/late_fusion.py,sha256=VpP1aY646iO-JBlYkyoBlj0Z-gzsqaHnMgpNzjTqAIo,15735
15
15
  pvnet/models/late_fusion/encoders/__init__.py,sha256=bLBQdnCeLYhwISW0t88ZZBz-ebS94m7ZwBcsofWMHR4,51
16
16
  pvnet/models/late_fusion/encoders/basic_blocks.py,sha256=N8CZCQUwydmBip10AHY5hSkwrSP12B_Mrvm-5XVcz1s,3005
17
17
  pvnet/models/late_fusion/encoders/encoders3d.py,sha256=i8POpAsdNFwdnx64Hn7Zm9q9bby8qMPn7G7miwhsxGk,6645
@@ -22,11 +22,11 @@ pvnet/models/late_fusion/site_encoders/__init__.py,sha256=QoUiiWWFf12vEpdkw0gO4T
22
22
  pvnet/models/late_fusion/site_encoders/basic_blocks.py,sha256=iEB_N7ZL5HMQ1hZM6H32A71GCwP7YbErUx0oQF21PQM,1042
23
23
  pvnet/models/late_fusion/site_encoders/encoders.py,sha256=k4z690cfcP6J4pm2KtDujHN-W3uOl7QY0WvBIu1tM8c,11703
24
24
  pvnet/training/__init__.py,sha256=FKxmPZ59Vuj5_mXomN4saJ3En5M-aDMxSs6OttTQOcg,49
25
- pvnet/training/lightning_module.py,sha256=GVqdi5ALFo9-_WRYeyMMj2qH_k4gPxQ2sG6FhL_wRFE,13242
25
+ pvnet/training/lightning_module.py,sha256=UvtpijxlRSDKrwH979FSQBeCLLMReYe-S-guWe1upl4,12685
26
26
  pvnet/training/plots.py,sha256=4xID7TBA4IazaARaCN5AoG5fFPJF1wIprn0y6I0C31c,2469
27
27
  pvnet/training/train.py,sha256=zj9JMi9C6W68vGsQUBapWkJ4aDzDuJFMv0IVjO73s1k,5215
28
- pvnet-5.0.14.dist-info/licenses/LICENSE,sha256=tKUnlSmcLBWMJWkHx3UjZGdrjs9LidGwLo0jsBUBAwU,1077
29
- pvnet-5.0.14.dist-info/METADATA,sha256=FJMwH-1nLPAkE6zsgtdQeofXdWoJn3qzEpUU9ltWn_s,18044
30
- pvnet-5.0.14.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
31
- pvnet-5.0.14.dist-info/top_level.txt,sha256=4mg6WjeW05SR7pg3-Q4JRE2yAoutHYpspOsiUzYVNv0,6
32
- pvnet-5.0.14.dist-info/RECORD,,
28
+ pvnet-5.0.15.dist-info/licenses/LICENSE,sha256=tKUnlSmcLBWMJWkHx3UjZGdrjs9LidGwLo0jsBUBAwU,1077
29
+ pvnet-5.0.15.dist-info/METADATA,sha256=uiHZcCcXv50Cl82sg_5XU1D0adASBW8C9DRH6pvamaA,18017
30
+ pvnet-5.0.15.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
31
+ pvnet-5.0.15.dist-info/top_level.txt,sha256=4mg6WjeW05SR7pg3-Q4JRE2yAoutHYpspOsiUzYVNv0,6
32
+ pvnet-5.0.15.dist-info/RECORD,,
File without changes