PVNet_summation 1.0.1__tar.gz → 1.0.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of PVNet_summation might be problematic.
Files changed (24)
  1. {pvnet_summation-1.0.1 → pvnet_summation-1.0.2}/PKG-INFO +1 -1
  2. {pvnet_summation-1.0.1 → pvnet_summation-1.0.2}/PVNet_summation.egg-info/PKG-INFO +1 -1
  3. {pvnet_summation-1.0.1 → pvnet_summation-1.0.2}/pvnet_summation/data/datamodule.py +55 -5
  4. {pvnet_summation-1.0.1 → pvnet_summation-1.0.2}/pvnet_summation/training/train.py +12 -2
  5. {pvnet_summation-1.0.1 → pvnet_summation-1.0.2}/LICENSE +0 -0
  6. {pvnet_summation-1.0.1 → pvnet_summation-1.0.2}/PVNet_summation.egg-info/SOURCES.txt +0 -0
  7. {pvnet_summation-1.0.1 → pvnet_summation-1.0.2}/PVNet_summation.egg-info/dependency_links.txt +0 -0
  8. {pvnet_summation-1.0.1 → pvnet_summation-1.0.2}/PVNet_summation.egg-info/requires.txt +0 -0
  9. {pvnet_summation-1.0.1 → pvnet_summation-1.0.2}/PVNet_summation.egg-info/top_level.txt +0 -0
  10. {pvnet_summation-1.0.1 → pvnet_summation-1.0.2}/README.md +0 -0
  11. {pvnet_summation-1.0.1 → pvnet_summation-1.0.2}/pvnet_summation/__init__.py +0 -0
  12. {pvnet_summation-1.0.1 → pvnet_summation-1.0.2}/pvnet_summation/data/__init__.py +0 -0
  13. {pvnet_summation-1.0.1 → pvnet_summation-1.0.2}/pvnet_summation/load_model.py +0 -0
  14. {pvnet_summation-1.0.1 → pvnet_summation-1.0.2}/pvnet_summation/models/__init__.py +0 -0
  15. {pvnet_summation-1.0.1 → pvnet_summation-1.0.2}/pvnet_summation/models/base_model.py +0 -0
  16. {pvnet_summation-1.0.1 → pvnet_summation-1.0.2}/pvnet_summation/models/dense_model.py +0 -0
  17. {pvnet_summation-1.0.1 → pvnet_summation-1.0.2}/pvnet_summation/optimizers.py +0 -0
  18. {pvnet_summation-1.0.1 → pvnet_summation-1.0.2}/pvnet_summation/training/__init__.py +0 -0
  19. {pvnet_summation-1.0.1 → pvnet_summation-1.0.2}/pvnet_summation/training/lightning_module.py +0 -0
  20. {pvnet_summation-1.0.1 → pvnet_summation-1.0.2}/pvnet_summation/training/plots.py +0 -0
  21. {pvnet_summation-1.0.1 → pvnet_summation-1.0.2}/pvnet_summation/utils.py +0 -0
  22. {pvnet_summation-1.0.1 → pvnet_summation-1.0.2}/pyproject.toml +0 -0
  23. {pvnet_summation-1.0.1 → pvnet_summation-1.0.2}/setup.cfg +0 -0
  24. {pvnet_summation-1.0.1 → pvnet_summation-1.0.2}/tests/test_end2end.py +0 -0
{pvnet_summation-1.0.1 → pvnet_summation-1.0.2}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: PVNet_summation
-Version: 1.0.1
+Version: 1.0.2
 Summary: PVNet_summation
 Author-email: James Fulton <info@openclimatefix.org>
 Requires-Python: >=3.10
{pvnet_summation-1.0.1 → pvnet_summation-1.0.2}/PVNet_summation.egg-info/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: PVNet_summation
-Version: 1.0.1
+Version: 1.0.2
 Summary: PVNet_summation
 Author-email: James Fulton <info@openclimatefix.org>
 Requires-Python: >=3.10
{pvnet_summation-1.0.1 → pvnet_summation-1.0.2}/pvnet_summation/data/datamodule.py

@@ -1,5 +1,6 @@
 """Pytorch lightning datamodules for loading pre-saved samples and predictions."""

+import os
 from glob import glob
 from typing import TypeAlias

@@ -11,7 +12,7 @@ from ocf_data_sampler.load.gsp import open_gsp
 from ocf_data_sampler.numpy_sample.common_types import NumpyBatch, NumpySample
 from ocf_data_sampler.torch_datasets.datasets.pvnet_uk import PVNetUKConcurrentDataset
 from ocf_data_sampler.utils import minutes
-from torch.utils.data import DataLoader, Dataset, default_collate
+from torch.utils.data import DataLoader, Dataset, Subset, default_collate
 from typing_extensions import override

 SumNumpySample: TypeAlias = dict[str, np.ndarray | NumpyBatch]
@@ -103,6 +104,8 @@ class StreamedDataModule(LightningDataModule):
         num_workers: int = 0,
         prefetch_factor: int | None = None,
         persistent_workers: bool = False,
+        seed: int | None = None,
+        dataset_pickle_dir: str | None = None,
     ):
         """Datamodule for creating concurrent PVNet inputs and national targets.

@@ -115,11 +118,16 @@ class StreamedDataModule(LightningDataModule):
             persistent_workers: If True, the data loader will not shut down the worker processes
                 after a dataset has been consumed once. This allows to maintain the workers Dataset
                 instances alive.
+            seed: Random seed used in shuffling datasets.
+            dataset_pickle_dir: Directory in which the val and train set will be presaved as
+                pickle objects. Setting this speeds up instantiation of multiple workers a lot.
         """
         super().__init__()
         self.configuration = configuration
         self.train_period = train_period
         self.val_period = val_period
+        self.seed = seed
+        self.dataset_pickle_dir = dataset_pickle_dir

         self._dataloader_kwargs = dict(
             batch_size=None,
@@ -132,17 +140,58 @@ class StreamedDataModule(LightningDataModule):
             worker_init_fn=None,
             prefetch_factor=prefetch_factor,
             persistent_workers=persistent_workers,
+            multiprocessing_context="spawn" if num_workers > 0 else None,
         )

+    def setup(self, stage: str | None = None):
+        """Called once to prepare the datasets."""
+
+        # This logic runs only once at the start of training, therefore the val dataset is only
+        # shuffled once
+        if self.dataset_pickle_dir is not None:
+            os.makedirs(self.dataset_pickle_dir, exist_ok=True)
+
+            train_dataset_path = f"{self.dataset_pickle_dir}/train_dataset.pkl"
+            val_dataset_path = f"{self.dataset_pickle_dir}/val_dataset.pkl"
+
+            # For safety, these pickled datasets cannot be overwritten.
+            # See: https://github.com/openclimatefix/pvnet/pull/445
+            for path in [train_dataset_path, val_dataset_path]:
+                if os.path.exists(path):
+                    raise FileExistsError(
+                        f"The pickled dataset path '{path}' already exists. Make sure that "
+                        "this can be safely deleted (i.e. not currently being used by any "
+                        "training run) and delete it manually. Else change the "
+                        "`dataset_pickle_dir` to a different directory."
+                    )
+
+        # Prepare the train dataset
+        self.train_dataset = StreamedDataset(self.configuration, *self.train_period)
+
+        # Prepare and pre-shuffle the val dataset and set seed for reproducibility
+        val_dataset = StreamedDataset(self.configuration, *self.val_period)
+        shuffled_indices = np.random.default_rng(seed=self.seed).permutation(len(val_dataset))
+        self.val_dataset = Subset(val_dataset, shuffled_indices)
+
+        if self.dataset_pickle_dir is not None:
+            self.train_dataset.presave_pickle(train_dataset_path)
+            val_dataset.presave_pickle(val_dataset_path)
+
+    def teardown(self, stage: str | None = None) -> None:
+        """Clean up the pickled datasets"""
+        if self.dataset_pickle_dir is not None:
+            for filename in ["val_dataset.pkl", "train_dataset.pkl"]:
+                filepath = f"{self.dataset_pickle_dir}/{filename}"
+                if os.path.exists(filepath):
+                    os.remove(filepath)
+
     def train_dataloader(self, shuffle: bool = False) -> DataLoader:
         """Construct train dataloader"""
-        dataset = StreamedDataset(self.configuration, *self.train_period)
-        return DataLoader(dataset, shuffle=shuffle, **self._dataloader_kwargs)
+        return DataLoader(self.train_dataset, shuffle=shuffle, **self._dataloader_kwargs)

     def val_dataloader(self, shuffle: bool = False) -> DataLoader:
         """Construct val dataloader"""
-        dataset = StreamedDataset(self.configuration, *self.val_period)
-        return DataLoader(dataset, shuffle=shuffle, **self._dataloader_kwargs)


 class PresavedDataset(Dataset):
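
For context on the hunk above: the seeded pre-shuffle in setup() is a standard PyTorch pattern. A numpy Generator produces a deterministic permutation of the dataset indices, and torch.utils.data.Subset presents the dataset in that order without copying any data, so the validation order is fixed for a given seed across restarts. A minimal standalone sketch of the same pattern, using a toy TensorDataset in place of the package's StreamedDataset:

    import numpy as np
    import torch
    from torch.utils.data import Subset, TensorDataset

    # Toy stand-in for StreamedDataset: 10 samples of 3 features each
    dataset = TensorDataset(torch.arange(30).reshape(10, 3))

    # Deterministic permutation of the indices; the same seed always
    # yields the same ordering
    shuffled_indices = np.random.default_rng(seed=42).permutation(len(dataset))

    # Subset re-indexes the dataset through the permutation without copying it
    val_dataset = Subset(dataset, shuffled_indices)

    print(shuffled_indices[:5])  # identical for seed=42 on any machine
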
@@ -200,6 +249,7 @@ class PresavedDataModule(LightningDataModule):
             worker_init_fn=None,
             prefetch_factor=prefetch_factor,
             persistent_workers=persistent_workers,
+            multiprocessing_context="spawn" if num_workers > 0 else None,
         )

     def train_dataloader(self, shuffle: bool = True) -> DataLoader:
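
Taken together, the datamodule changes serve one goal: with num_workers > 0 both dataloaders now use the "spawn" multiprocessing context, and spawned workers must recreate their Dataset state, which is presumably what the new dataset_pickle_dir cache speeds up, per the docstring above. A hypothetical instantiation sketch (the configuration path, periods, and cache directory below are illustrative, not taken from the package):

    from pvnet_summation.data.datamodule import StreamedDataModule

    datamodule = StreamedDataModule(
        configuration="/path/to/data_config.yaml",  # hypothetical path
        train_period=["2020-01-01", "2022-12-31"],  # hypothetical periods
        val_period=["2023-01-01", "2023-12-31"],
        num_workers=4,
        seed=42,                                  # new in 1.0.2: reproducible val shuffle
        dataset_pickle_dir="/tmp/dataset_cache",  # new in 1.0.2: worker start-up cache
    )

    datamodule.setup()     # builds datasets, pre-shuffles val, writes the pickles
    train_dl = datamodule.train_dataloader()
    val_dl = datamodule.val_dataloader()
    # ... training loop ...
    datamodule.teardown()  # deletes the pickles

Note that setup() refuses to overwrite existing pickle files (raising FileExistsError), so a stale cache directory must be cleared manually or pointed elsewhere.
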
{pvnet_summation-1.0.1 → pvnet_summation-1.0.2}/pvnet_summation/training/train.py

@@ -93,8 +93,12 @@ def train(config: DictConfig) -> None:
         train_period=config.datamodule.train_period,
         val_period=config.datamodule.val_period,
         persistent_workers=False,
+        seed=config.datamodule.seed,
+        dataset_pickle_dir=config.datamodule.dataset_pickle_dir,
     )

+    datamodule.setup()
+
     for dataloader_func, max_num_samples, split in [
         (datamodule.train_dataloader, config.datamodule.max_num_train_samples, "train"),
         (datamodule.val_dataloader, config.datamodule.max_num_val_samples, "val"),
@@ -103,7 +107,10 @@ def train(config: DictConfig) -> None:
         log.info(f"Saving {split} outputs")
         dataloader = dataloader_func(shuffle=True)

-        for i, sample in tqdm(zip(range(max_num_samples), dataloader)):
+        if max_num_samples is None:
+            max_num_samples = len(dataloader)
+
+        for i, sample in tqdm(zip(range(max_num_samples), dataloader), total=max_num_samples):
             # Run PVNet inputs through model
             x = copy_batch_to_device(batch_to_tensor(sample["pvnet_inputs"]), device)
             pvnet_outputs = pvnet_model(x).detach().cpu()
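
The bounded-iteration idiom in this hunk is worth noting: zip(range(max_num_samples), dataloader) stops after max_num_samples batches without consuming the rest of the dataloader, and because zip has no length, tqdm needs the explicit total= to draw a full progress bar. A minimal sketch of the same idiom, with a plain generator standing in for the DataLoader:

    from tqdm import tqdm

    def sample_stream():
        """Stand-in for a DataLoader: an endless iterable of samples."""
        n = 0
        while True:
            yield f"sample-{n}"
            n += 1

    max_num_samples = 5

    # zip stops after max_num_samples items; total= tells tqdm the length,
    # which it cannot infer from a generator
    for i, sample in tqdm(zip(range(max_num_samples), sample_stream()),
                          total=max_num_samples):
        pass  # process (i, sample) here
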
@@ -116,6 +123,9 @@ def train(config: DictConfig) -> None:

         del dataloader

+    datamodule.teardown()
+
+
     datamodule = PresavedDataModule(
         sample_dir=save_dir,
         batch_size=config.datamodule.batch_size,
@@ -182,4 +192,4 @@ def train(config: DictConfig) -> None:
     )

     # Train the model completely
-    trainer.fit(model=model, datamodule=datamodule)
+    trainer.fit(model=model, datamodule=datamodule)