PVNet 5.0.26-py3-none-any.whl → 5.1.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pvnet/data/base_datamodule.py → pvnet/datamodule.py RENAMED
@@ -1,13 +1,14 @@
  """ Data module for pytorch lightning """

  import os
- from glob import glob

  import numpy as np
  from lightning.pytorch import LightningDataModule
  from ocf_data_sampler.numpy_sample.collate import stack_np_samples_into_batch
  from ocf_data_sampler.numpy_sample.common_types import NumpySample, TensorBatch
- from ocf_data_sampler.torch_datasets.sample.base import SampleBase, batch_to_tensor
+ from ocf_data_sampler.torch_datasets.datasets.pvnet_uk import PVNetUKRegionalDataset
+ from ocf_data_sampler.torch_datasets.datasets.site import SitesDataset
+ from ocf_data_sampler.torch_datasets.sample.base import batch_to_tensor
  from torch.utils.data import DataLoader, Dataset, Subset


@@ -16,87 +17,8 @@ def collate_fn(samples: list[NumpySample]) -> TensorBatch:
      return batch_to_tensor(stack_np_samples_into_batch(samples))


- class PresavedSamplesDataset(Dataset):
-     """Dataset of pre-saved samples
-
-     Args:
-         sample_dir: Path to the directory of pre-saved samples.
-         sample_class: sample class type to use for save/load/to_numpy
-     """
-
-     def __init__(self, sample_dir: str, sample_class: SampleBase):
-         """Initialise PresavedSamplesDataset"""
-         self.sample_paths = glob(f"{sample_dir}/*")
-         self.sample_class = sample_class
-
-     def __len__(self) -> int:
-         return len(self.sample_paths)
-
-     def __getitem__(self, idx) -> NumpySample:
-         sample = self.sample_class.load(self.sample_paths[idx])
-         return sample.to_numpy()
-
-
- class BasePresavedDataModule(LightningDataModule):
-     """Base Datamodule for loading pre-saved samples."""
-
-     def __init__(
-         self,
-         sample_dir: str,
-         batch_size: int = 16,
-         num_workers: int = 0,
-         prefetch_factor: int | None = None,
-         persistent_workers: bool = False,
-         pin_memory: bool = False,
-     ):
-         """Base Datamodule for loading pre-saved samples
-
-         Args:
-             sample_dir: Path to the directory of pre-saved samples.
-             batch_size: Batch size.
-             num_workers: Number of workers to use in multiprocess batch loading.
-             prefetch_factor: Number of batches loaded in advance by each worker.
-             persistent_workers: If True, the data loader will not shut down the worker processes
-                 after a dataset has been consumed once. This allows to maintain the workers Dataset
-                 instances alive.
-             pin_memory: If True, the data loader will copy Tensors into device/CUDA pinned memory
-                 before returning them.
-         """
-         super().__init__()
-
-         self.sample_dir = sample_dir
-
-         self._common_dataloader_kwargs = dict(
-             batch_size=batch_size,
-             sampler=None,
-             batch_sampler=None,
-             num_workers=num_workers,
-             collate_fn=collate_fn,
-             pin_memory=pin_memory,
-             drop_last=False,
-             timeout=0,
-             worker_init_fn=None,
-             prefetch_factor=prefetch_factor,
-             persistent_workers=persistent_workers,
-             multiprocessing_context="spawn" if num_workers>0 else None,
-         )
-
-     def _get_premade_samples_dataset(self, subdir: str) -> Dataset:
-         raise NotImplementedError
-
-     def train_dataloader(self) -> DataLoader:
-         """Construct train dataloader"""
-         dataset = self._get_premade_samples_dataset("train")
-         return DataLoader(dataset, shuffle=True, **self._common_dataloader_kwargs)
-
-     def val_dataloader(self) -> DataLoader:
-         """Construct val dataloader"""
-         dataset = self._get_premade_samples_dataset("val")
-         return DataLoader(dataset, shuffle=False, **self._common_dataloader_kwargs)
-
-
- class BaseStreamedDataModule(LightningDataModule):
-     """Base Datamodule which streams samples using a sampler for ocf-data-sampler."""
+ class BaseDataModule(LightningDataModule):
+     """Base Datamodule which streams samples using a sampler from ocf-data-sampler."""

      def __init__(
          self,
@@ -159,10 +81,10 @@ class BaseStreamedDataModule(LightningDataModule):
          if stage == "fit":

              # Prepare the train dataset
-             self.train_dataset = self._get_streamed_samples_dataset(*self.train_period)
+             self.train_dataset = self._get_dataset(*self.train_period)

              # Prepare and pre-shuffle the val dataset and set seed for reproducibility
-             val_dataset = self._get_streamed_samples_dataset(*self.val_period)
+             val_dataset = self._get_dataset(*self.val_period)

              shuffled_indices = np.random.default_rng(seed=self.seed).permutation(len(val_dataset))
              self.val_dataset = Subset(val_dataset, shuffled_indices)
@@ -194,11 +116,7 @@ class BaseStreamedDataModule(LightningDataModule):
          if os.path.exists(filepath):
              os.remove(filepath)

-     def _get_streamed_samples_dataset(
-         self,
-         start_time: str | None,
-         end_time: str | None
-     ) -> Dataset:
+     def _get_dataset(self, start_time: str | None, end_time: str | None) -> Dataset:
          raise NotImplementedError

      def train_dataloader(self) -> DataLoader:
@@ -208,3 +126,17 @@ class BaseStreamedDataModule(LightningDataModule):
      def val_dataloader(self) -> DataLoader:
          """Construct val dataloader"""
          return DataLoader(self.val_dataset, shuffle=False, **self._common_dataloader_kwargs)
+
+
+ class UKRegionalDataModule(BaseDataModule):
+     """Datamodule for streaming UK regional samples."""
+
+     def _get_dataset(self, start_time: str | None, end_time: str | None) -> PVNetUKRegionalDataset:
+         return PVNetUKRegionalDataset(self.configuration, start_time=start_time, end_time=end_time)
+
+
+ class SitesDataModule(BaseDataModule):
+     """Datamodule for streaming site samples."""
+
+     def _get_dataset(self, start_time: str | None, end_time: str | None) -> SitesDataset:
+         return SitesDataset(self.configuration, start_time=start_time, end_time=end_time)
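
For orientation, the block below is a minimal usage sketch of the consolidated datamodule introduced above. The constructor arguments are assumptions inferred from the attributes the diff references (`self.configuration`, `self.train_period`, `self.val_period`, `self.seed`); the actual `__init__` signature is not shown in this diff.

```python
# Hypothetical usage sketch; argument names are inferred, not confirmed by the diff.
from pvnet.datamodule import UKRegionalDataModule

datamodule = UKRegionalDataModule(
    configuration="data_config.yaml",           # assumed: path to an ocf-data-sampler data config
    train_period=["2021-01-01", "2022-12-31"],  # assumed: (start_time, end_time) strings
    val_period=["2023-01-01", "2023-06-30"],
    seed=0,
)

# Hooks shown in the diff: setup("fit") builds the streamed train dataset and a
# pre-shuffled Subset for validation; the dataloader methods wrap them with the
# shared DataLoader kwargs.
datamodule.setup("fit")
train_loader = datamodule.train_dataloader()
val_loader = datamodule.val_dataloader()
```

The same pattern applies to `SitesDataModule`, which only swaps in the other ocf-data-sampler dataset class.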
pvnet/load_model.py CHANGED
@@ -73,6 +73,8 @@ def get_model_from_checkpoints(
      else:
          raise FileNotFoundError(f"File {data_config} does not exist")

+     # TODO: This should be removed in a future release since no new models will be trained on
+     # presaved samples
      # Check for datamodule config
      # This only exists if the model was trained with presaved samples
      datamodule_config = f"{path}/{DATAMODULE_CONFIG_NAME}"
pvnet/training/lightning_module.py CHANGED
@@ -11,7 +11,7 @@ import xarray as xr
  from ocf_data_sampler.numpy_sample.common_types import TensorBatch
  from ocf_data_sampler.torch_datasets.sample.base import copy_batch_to_device

- from pvnet.data.base_datamodule import collate_fn
+ from pvnet.datamodule import collate_fn
  from pvnet.models.base_model import BaseModel
  from pvnet.optimizers import AbstractOptimizer
  from pvnet.training.plots import plot_sample_forecasts, wandb_line_plot
pvnet/training/train.py CHANGED
@@ -15,10 +15,8 @@ from lightning.pytorch.callbacks import ModelCheckpoint
  from lightning.pytorch.loggers import Logger, WandbLogger
  from omegaconf import DictConfig, OmegaConf

- from pvnet.data.base_datamodule import BasePresavedDataModule
  from pvnet.utils import (
      DATA_CONFIG_NAME,
-     DATAMODULE_CONFIG_NAME,
      FULL_CONFIG_NAME,
      MODEL_CONFIG_NAME,
  )
@@ -102,27 +100,8 @@ def train(config: DictConfig) -> None:
      os.makedirs(save_dir, exist_ok=True)
      OmegaConf.save(config.model, f"{save_dir}/{MODEL_CONFIG_NAME}")

-     # If using pre-saved samples we need to extract the data config from the directory
-     # those samples were saved to
-     if isinstance(datamodule, BasePresavedDataModule):
-         data_config = f"{config.datamodule.sample_dir}/{DATA_CONFIG_NAME}"
-
-         # We also save the datamodule config used to create the samples to the output
-         # directory and to wandb
-         shutil.copyfile(
-             f"{config.datamodule.sample_dir}/{DATAMODULE_CONFIG_NAME}",
-             f"{save_dir}/{DATAMODULE_CONFIG_NAME}"
-         )
-         wandb_logger.experiment.save(
-             f"{save_dir}/{DATAMODULE_CONFIG_NAME}",
-             base_path=save_dir,
-         )
-     else:
-         # If we are streaming batches the data config is defined and we don't need to
-         # save the datamodule config separately
-         data_config = config.datamodule.configuration
-
      # Save the data config to the output directory and to wandb
+     data_config = config.datamodule.configuration
      shutil.copyfile(data_config, f"{save_dir}/{DATA_CONFIG_NAME}")
      wandb_logger.experiment.save(f"{save_dir}/{DATA_CONFIG_NAME}", base_path=save_dir)

pvnet-5.0.26.dist-info/METADATA → pvnet-5.1.1.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: PVNet
- Version: 5.0.26
+ Version: 5.1.1
  Summary: PVNet
  Author-email: Peter Dudfield <info@openclimatefix.org>
  Requires-Python: >=3.11
@@ -142,12 +142,6 @@ pip install -e <PATH-TO-ocf-data-sampler-REPO>
  If you install the local version of `ocf-data-sampler` that is more recent than the version
  specified in `PVNet` it is not guarenteed to function properly with this library.

- ## Streaming samples (no pre-save)
-
- PVNet now trains and validates directly from **streamed_samples** (i.e. no pre-saving to disk).
-
- Make sure you have copied example configs (as already stated above):
- cp -r configs.example configs


  ### Set up and config example for streaming
pvnet-5.0.26.dist-info/RECORD → pvnet-5.1.1.dist-info/RECORD RENAMED
@@ -1,11 +1,8 @@
  pvnet/__init__.py,sha256=TAZm88TJ5ieL1XjEyRg1LciIGuSScEucdAruQLfM92I,25
- pvnet/load_model.py,sha256=LzN06O3oXzqhj1Dh_VlschDTxOq_Eea0OWDxrrboSKw,3726
+ pvnet/datamodule.py,sha256=1qWYoRaxIvewnQJZ3GYz3zveQ7BT92XAT9vHqpLci0I,6356
+ pvnet/load_model.py,sha256=P1QODX_mJRnKZ_kIll9BlOjK_A1W4YM3QG-mZd-2Mcc,3852
  pvnet/optimizers.py,sha256=1N4b-Xd6QiIrcUU8cbU326bbFC0BvMNIV8VYWtGILJc,6548
  pvnet/utils.py,sha256=iHG0wlN_cITKXpR16w54fs2R68wkX7vy3a_6f4SuanY,5414
- pvnet/data/__init__.py,sha256=FFD2tkLwEw9YiAVDam3tmaXNWMKiKVMHcnIz7zXCtrg,191
- pvnet/data/base_datamodule.py,sha256=Ibz0RoSr15HT6tMCs6ftXTpMa-NOKAmEd5ky55MqEK0,8615
- pvnet/data/site_datamodule.py,sha256=-KGxirGCBXVwcCREsjFkF7JDfa6NICv8bBDV6EILF_Q,962
- pvnet/data/uk_regional_datamodule.py,sha256=KA2_7DYuSggmD5b-XiXshXq8xmu36BjtFmy_pS7e4QE,1017
  pvnet/models/__init__.py,sha256=owzZ9xkD0DRTT51mT2Dx_p96oJjwDz57xo_MaMIEosk,145
  pvnet/models/base_model.py,sha256=CnQaaf2kAdOcXqo1319nWa120mHfLQiwOQ639m4OzPk,16182
  pvnet/models/ensemble.py,sha256=1mFUEsl33kWcLL5d7zfDm9ypWxgAxBHgBiJLt0vwTeg,2363
@@ -22,11 +19,11 @@ pvnet/models/late_fusion/site_encoders/__init__.py,sha256=QoUiiWWFf12vEpdkw0gO4T
  pvnet/models/late_fusion/site_encoders/basic_blocks.py,sha256=iEB_N7ZL5HMQ1hZM6H32A71GCwP7YbErUx0oQF21PQM,1042
  pvnet/models/late_fusion/site_encoders/encoders.py,sha256=k4z690cfcP6J4pm2KtDujHN-W3uOl7QY0WvBIu1tM8c,11703
  pvnet/training/__init__.py,sha256=FKxmPZ59Vuj5_mXomN4saJ3En5M-aDMxSs6OttTQOcg,49
- pvnet/training/lightning_module.py,sha256=RFp0kOmRW0cYU__i7SRCJUIvlp57s5cZ_H3hy2JnBxE,12954
+ pvnet/training/lightning_module.py,sha256=ByWv4m1QSUUYq2DNCqWCWqEBcXF5gEQ1g__8J6lOseo,12944
  pvnet/training/plots.py,sha256=4xID7TBA4IazaARaCN5AoG5fFPJF1wIprn0y6I0C31c,2469
- pvnet/training/train.py,sha256=1tDA34ianCRfilS0yEeIpR9nsQWaJiiZfTD2qRUmgEc,5214
- pvnet-5.0.26.dist-info/licenses/LICENSE,sha256=tKUnlSmcLBWMJWkHx3UjZGdrjs9LidGwLo0jsBUBAwU,1077
- pvnet-5.0.26.dist-info/METADATA,sha256=OcYyfKyx5U6tWv2Tq_L9yvV1ilsih4jtBwjsM0weN2Y,16707
- pvnet-5.0.26.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
- pvnet-5.0.26.dist-info/top_level.txt,sha256=4mg6WjeW05SR7pg3-Q4JRE2yAoutHYpspOsiUzYVNv0,6
- pvnet-5.0.26.dist-info/RECORD,,
+ pvnet/training/train.py,sha256=Sry2wYgggUmtIB-k_umFts7xMr2roEL76NCu9ySbLUY,4107
+ pvnet-5.1.1.dist-info/licenses/LICENSE,sha256=tKUnlSmcLBWMJWkHx3UjZGdrjs9LidGwLo0jsBUBAwU,1077
+ pvnet-5.1.1.dist-info/METADATA,sha256=52p2kuR_y7YhoLdLDio_P2ljaaIu14vcQYth-jY00_c,16474
+ pvnet-5.1.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ pvnet-5.1.1.dist-info/top_level.txt,sha256=4mg6WjeW05SR7pg3-Q4JRE2yAoutHYpspOsiUzYVNv0,6
+ pvnet-5.1.1.dist-info/RECORD,,
pvnet/data/__init__.py DELETED
@@ -1,3 +0,0 @@
- """Data parts"""
- from .site_datamodule import SitePresavedDataModule, SiteStreamedDataModule
- from .uk_regional_datamodule import UKRegionalPresavedDataModule, UKRegionalStreamedDataModule
pvnet/data/site_datamodule.py DELETED
@@ -1,29 +0,0 @@
- """ Data module for pytorch lightning """
-
- from ocf_data_sampler.torch_datasets.datasets.site import SitesDataset
- from ocf_data_sampler.torch_datasets.sample.site import SiteSample
- from torch.utils.data import Dataset
-
- from pvnet.data.base_datamodule import (
-     BasePresavedDataModule,
-     BaseStreamedDataModule,
-     PresavedSamplesDataset,
- )
-
-
- class SitePresavedDataModule(BasePresavedDataModule):
-     """Datamodule for loading pre-saved samples."""
-
-     def _get_premade_samples_dataset(self, subdir: str) -> Dataset:
-         return PresavedSamplesDataset(f"{self.sample_dir}/{subdir}", SiteSample)
-
-
- class SiteStreamedDataModule(BaseStreamedDataModule):
-     """Datamodule which streams samples using sampler for ocf-data-sampler."""
-
-     def _get_streamed_samples_dataset(
-         self,
-         start_time: str | None,
-         end_time: str | None
-     ) -> Dataset:
-         return SitesDataset(self.configuration, start_time=start_time, end_time=end_time)
pvnet/data/uk_regional_datamodule.py DELETED
@@ -1,29 +0,0 @@
- """ Data module for pytorch lightning """
-
- from ocf_data_sampler.torch_datasets.datasets.pvnet_uk import PVNetUKRegionalDataset
- from ocf_data_sampler.torch_datasets.sample.uk_regional import UKRegionalSample
- from torch.utils.data import Dataset
-
- from pvnet.data.base_datamodule import (
-     BasePresavedDataModule,
-     BaseStreamedDataModule,
-     PresavedSamplesDataset,
- )
-
-
- class UKRegionalPresavedDataModule(BasePresavedDataModule):
-     """Datamodule for loading pre-saved samples."""
-
-     def _get_premade_samples_dataset(self, subdir: str) -> Dataset:
-         return PresavedSamplesDataset(f"{self.sample_dir}/{subdir}", UKRegionalSample)
-
-
- class UKRegionalStreamedDataModule(BaseStreamedDataModule):
-     """Datamodule which streams samples using sampler for ocf-data-sampler."""
-
-     def _get_streamed_samples_dataset(
-         self,
-         start_time: str | None,
-         end_time: str | None
-     ) -> Dataset:
-         return PVNetUKRegionalDataset(self.configuration, start_time=start_time, end_time=end_time)
File without changes
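
For downstream code that imported the removed 5.0.x datamodules directly, the sketch below maps the old names to the new ones, based only on the classes deleted and added in this diff; it is illustrative, not an official migration guide.

```python
# Hypothetical import migration inferred from this diff.
# Before (5.0.26):
#   from pvnet.data.uk_regional_datamodule import UKRegionalStreamedDataModule
#   from pvnet.data.site_datamodule import SiteStreamedDataModule
#   from pvnet.data.base_datamodule import collate_fn
# After (5.1.1):
from pvnet.datamodule import SitesDataModule, UKRegionalDataModule, collate_fn

# The presaved-sample path (PresavedSamplesDataset, BasePresavedDataModule,
# UKRegionalPresavedDataModule, SitePresavedDataModule) was removed outright;
# 5.1.1 streams samples via ocf-data-sampler instead of loading them from disk.
```

If hydra-style `_target_` strings are used in datamodule configs, the equivalent dotted paths would need the same update.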