PVNet 5.0.0__tar.gz → 5.0.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. {pvnet-5.0.0 → pvnet-5.0.2}/PKG-INFO +1 -1
  2. {pvnet-5.0.0 → pvnet-5.0.2}/PVNet.egg-info/PKG-INFO +1 -1
  3. {pvnet-5.0.0 → pvnet-5.0.2}/pvnet/data/base_datamodule.py +20 -5
  4. {pvnet-5.0.0 → pvnet-5.0.2}/pvnet/utils.py +2 -0
  5. {pvnet-5.0.0 → pvnet-5.0.2}/LICENSE +0 -0
  6. {pvnet-5.0.0 → pvnet-5.0.2}/PVNet.egg-info/SOURCES.txt +0 -0
  7. {pvnet-5.0.0 → pvnet-5.0.2}/PVNet.egg-info/dependency_links.txt +0 -0
  8. {pvnet-5.0.0 → pvnet-5.0.2}/PVNet.egg-info/requires.txt +0 -0
  9. {pvnet-5.0.0 → pvnet-5.0.2}/PVNet.egg-info/top_level.txt +0 -0
  10. {pvnet-5.0.0 → pvnet-5.0.2}/README.md +0 -0
  11. {pvnet-5.0.0 → pvnet-5.0.2}/pvnet/__init__.py +0 -0
  12. {pvnet-5.0.0 → pvnet-5.0.2}/pvnet/data/__init__.py +0 -0
  13. {pvnet-5.0.0 → pvnet-5.0.2}/pvnet/data/site_datamodule.py +0 -0
  14. {pvnet-5.0.0 → pvnet-5.0.2}/pvnet/data/uk_regional_datamodule.py +0 -0
  15. {pvnet-5.0.0 → pvnet-5.0.2}/pvnet/load_model.py +0 -0
  16. {pvnet-5.0.0 → pvnet-5.0.2}/pvnet/models/__init__.py +0 -0
  17. {pvnet-5.0.0 → pvnet-5.0.2}/pvnet/models/base_model.py +0 -0
  18. {pvnet-5.0.0 → pvnet-5.0.2}/pvnet/models/ensemble.py +0 -0
  19. {pvnet-5.0.0 → pvnet-5.0.2}/pvnet/models/late_fusion/__init__.py +0 -0
  20. {pvnet-5.0.0 → pvnet-5.0.2}/pvnet/models/late_fusion/basic_blocks.py +0 -0
  21. {pvnet-5.0.0 → pvnet-5.0.2}/pvnet/models/late_fusion/encoders/__init__.py +0 -0
  22. {pvnet-5.0.0 → pvnet-5.0.2}/pvnet/models/late_fusion/encoders/basic_blocks.py +0 -0
  23. {pvnet-5.0.0 → pvnet-5.0.2}/pvnet/models/late_fusion/encoders/encoders3d.py +0 -0
  24. {pvnet-5.0.0 → pvnet-5.0.2}/pvnet/models/late_fusion/late_fusion.py +0 -0
  25. {pvnet-5.0.0 → pvnet-5.0.2}/pvnet/models/late_fusion/linear_networks/__init__.py +0 -0
  26. {pvnet-5.0.0 → pvnet-5.0.2}/pvnet/models/late_fusion/linear_networks/basic_blocks.py +0 -0
  27. {pvnet-5.0.0 → pvnet-5.0.2}/pvnet/models/late_fusion/linear_networks/networks.py +0 -0
  28. {pvnet-5.0.0 → pvnet-5.0.2}/pvnet/models/late_fusion/site_encoders/__init__.py +0 -0
  29. {pvnet-5.0.0 → pvnet-5.0.2}/pvnet/models/late_fusion/site_encoders/basic_blocks.py +0 -0
  30. {pvnet-5.0.0 → pvnet-5.0.2}/pvnet/models/late_fusion/site_encoders/encoders.py +0 -0
  31. {pvnet-5.0.0 → pvnet-5.0.2}/pvnet/optimizers.py +0 -0
  32. {pvnet-5.0.0 → pvnet-5.0.2}/pvnet/training/__init__.py +0 -0
  33. {pvnet-5.0.0 → pvnet-5.0.2}/pvnet/training/lightning_module.py +0 -0
  34. {pvnet-5.0.0 → pvnet-5.0.2}/pvnet/training/plots.py +0 -0
  35. {pvnet-5.0.0 → pvnet-5.0.2}/pvnet/training/train.py +0 -0
  36. {pvnet-5.0.0 → pvnet-5.0.2}/pyproject.toml +0 -0
  37. {pvnet-5.0.0 → pvnet-5.0.2}/setup.cfg +0 -0
  38. {pvnet-5.0.0 → pvnet-5.0.2}/tests/test_end2end.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: PVNet
3
- Version: 5.0.0
3
+ Version: 5.0.2
4
4
  Summary: PVNet
5
5
  Author-email: Peter Dudfield <info@openclimatefix.org>
6
6
  Requires-Python: >=3.10
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: PVNet
3
- Version: 5.0.0
3
+ Version: 5.0.2
4
4
  Summary: PVNet
5
5
  Author-email: Peter Dudfield <info@openclimatefix.org>
6
6
  Requires-Python: >=3.10
@@ -45,6 +45,7 @@ class BasePresavedDataModule(LightningDataModule):
45
45
  batch_size: int = 16,
46
46
  num_workers: int = 0,
47
47
  prefetch_factor: int | None = None,
48
+ persistent_workers: bool = False,
48
49
  ):
49
50
  """Base Datamodule for loading pre-saved samples
50
51
 
@@ -53,8 +54,9 @@ class BasePresavedDataModule(LightningDataModule):
53
54
  batch_size: Batch size.
54
55
  num_workers: Number of workers to use in multiprocess batch loading.
55
56
  prefetch_factor: Number of data will be prefetched at the end of each worker process.
56
- train_period: Date range filter for train dataloader.
57
- val_period: Date range filter for val dataloader.
57
+ persistent_workers: If True, the data loader will not shut down the worker processes
58
+ after a dataset has been consumed once. This allows to maintain the workers Dataset
59
+ instances alive.
58
60
  """
59
61
  super().__init__()
60
62
 
@@ -71,7 +73,7 @@ class BasePresavedDataModule(LightningDataModule):
71
73
  timeout=0,
72
74
  worker_init_fn=None,
73
75
  prefetch_factor=prefetch_factor,
74
- persistent_workers=False,
76
+ persistent_workers=persistent_workers,
75
77
  )
76
78
 
77
79
  def _get_premade_samples_dataset(self, subdir: str) -> Dataset:
@@ -97,8 +99,11 @@ class BaseStreamedDataModule(LightningDataModule):
97
99
  batch_size: int = 16,
98
100
  num_workers: int = 0,
99
101
  prefetch_factor: int | None = None,
102
+ persistent_workers: bool = False,
100
103
  train_period: list[str | None] = [None, None],
101
104
  val_period: list[str | None] = [None, None],
105
+ seed: int | None = None,
106
+
102
107
  ):
103
108
  """Base Datamodule for streaming samples.
104
109
 
@@ -107,14 +112,19 @@ class BaseStreamedDataModule(LightningDataModule):
107
112
  batch_size: Batch size.
108
113
  num_workers: Number of workers to use in multiprocess batch loading.
109
114
  prefetch_factor: Number of data will be prefetched at the end of each worker process.
115
+ persistent_workers: If True, the data loader will not shut down the worker processes
116
+ after a dataset has been consumed once. This allows to maintain the workers Dataset
117
+ instances alive.
110
118
  train_period: Date range filter for train dataloader.
111
119
  val_period: Date range filter for val dataloader.
120
+ seed: Random seed used in shuffling datasets.
112
121
  """
113
122
  super().__init__()
114
123
 
115
124
  self.configuration = configuration
116
125
  self.train_period = train_period
117
126
  self.val_period = val_period
127
+ self.seed = seed
118
128
 
119
129
  self._common_dataloader_kwargs = dict(
120
130
  batch_size=batch_size,
@@ -126,7 +136,7 @@ class BaseStreamedDataModule(LightningDataModule):
126
136
  timeout=0,
127
137
  worker_init_fn=None,
128
138
  prefetch_factor=prefetch_factor,
129
- persistent_workers=False,
139
+ persistent_workers=persistent_workers,
130
140
  )
131
141
 
132
142
  def setup(self, stage: str | None = None):
@@ -135,11 +145,16 @@ class BaseStreamedDataModule(LightningDataModule):
135
145
  # This logic runs only once at the start of training, therefore the val dataset is only
136
146
  # shuffled once
137
147
  if stage == "fit":
148
+
138
149
  # Prepare the train dataset
139
150
  self.train_dataset = self._get_streamed_samples_dataset(*self.train_period)
140
151
 
141
- # Prepare and pre-shuffle the val dataset
152
+ # Prepare and pre-shuffle the val dataset and set seed for reproducibility
142
153
  val_dataset = self._get_streamed_samples_dataset(*self.val_period)
154
+
155
+ if self.seed is not None:
156
+ torch.manual_seed(self.seed)
157
+
143
158
  shuffled_indices = torch.randperm(len(val_dataset))
144
159
  self.val_dataset = Subset(val_dataset, shuffled_indices)
145
160
 
@@ -43,6 +43,8 @@ def run_config_utilities(config: DictConfig) -> None:
43
43
  config.datamodule.pin_memory = False
44
44
  if config.datamodule.get("num_workers"):
45
45
  config.datamodule.num_workers = 0
46
+ if config.datamodule.get("prefetch_factor"):
47
+ config.datamodule.prefetch_factor = None
46
48
 
47
49
  # Disable adding new keys to config
48
50
  OmegaConf.set_struct(config, True)
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes