PVNet 5.0.1__tar.gz → 5.0.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. {pvnet-5.0.1 → pvnet-5.0.3}/PKG-INFO +1 -1
  2. {pvnet-5.0.1 → pvnet-5.0.3}/PVNet.egg-info/PKG-INFO +1 -1
  3. {pvnet-5.0.1 → pvnet-5.0.3}/pvnet/data/base_datamodule.py +19 -5
  4. {pvnet-5.0.1 → pvnet-5.0.3}/LICENSE +0 -0
  5. {pvnet-5.0.1 → pvnet-5.0.3}/PVNet.egg-info/SOURCES.txt +0 -0
  6. {pvnet-5.0.1 → pvnet-5.0.3}/PVNet.egg-info/dependency_links.txt +0 -0
  7. {pvnet-5.0.1 → pvnet-5.0.3}/PVNet.egg-info/requires.txt +0 -0
  8. {pvnet-5.0.1 → pvnet-5.0.3}/PVNet.egg-info/top_level.txt +0 -0
  9. {pvnet-5.0.1 → pvnet-5.0.3}/README.md +0 -0
  10. {pvnet-5.0.1 → pvnet-5.0.3}/pvnet/__init__.py +0 -0
  11. {pvnet-5.0.1 → pvnet-5.0.3}/pvnet/data/__init__.py +0 -0
  12. {pvnet-5.0.1 → pvnet-5.0.3}/pvnet/data/site_datamodule.py +0 -0
  13. {pvnet-5.0.1 → pvnet-5.0.3}/pvnet/data/uk_regional_datamodule.py +0 -0
  14. {pvnet-5.0.1 → pvnet-5.0.3}/pvnet/load_model.py +0 -0
  15. {pvnet-5.0.1 → pvnet-5.0.3}/pvnet/models/__init__.py +0 -0
  16. {pvnet-5.0.1 → pvnet-5.0.3}/pvnet/models/base_model.py +0 -0
  17. {pvnet-5.0.1 → pvnet-5.0.3}/pvnet/models/ensemble.py +0 -0
  18. {pvnet-5.0.1 → pvnet-5.0.3}/pvnet/models/late_fusion/__init__.py +0 -0
  19. {pvnet-5.0.1 → pvnet-5.0.3}/pvnet/models/late_fusion/basic_blocks.py +0 -0
  20. {pvnet-5.0.1 → pvnet-5.0.3}/pvnet/models/late_fusion/encoders/__init__.py +0 -0
  21. {pvnet-5.0.1 → pvnet-5.0.3}/pvnet/models/late_fusion/encoders/basic_blocks.py +0 -0
  22. {pvnet-5.0.1 → pvnet-5.0.3}/pvnet/models/late_fusion/encoders/encoders3d.py +0 -0
  23. {pvnet-5.0.1 → pvnet-5.0.3}/pvnet/models/late_fusion/late_fusion.py +0 -0
  24. {pvnet-5.0.1 → pvnet-5.0.3}/pvnet/models/late_fusion/linear_networks/__init__.py +0 -0
  25. {pvnet-5.0.1 → pvnet-5.0.3}/pvnet/models/late_fusion/linear_networks/basic_blocks.py +0 -0
  26. {pvnet-5.0.1 → pvnet-5.0.3}/pvnet/models/late_fusion/linear_networks/networks.py +0 -0
  27. {pvnet-5.0.1 → pvnet-5.0.3}/pvnet/models/late_fusion/site_encoders/__init__.py +0 -0
  28. {pvnet-5.0.1 → pvnet-5.0.3}/pvnet/models/late_fusion/site_encoders/basic_blocks.py +0 -0
  29. {pvnet-5.0.1 → pvnet-5.0.3}/pvnet/models/late_fusion/site_encoders/encoders.py +0 -0
  30. {pvnet-5.0.1 → pvnet-5.0.3}/pvnet/optimizers.py +0 -0
  31. {pvnet-5.0.1 → pvnet-5.0.3}/pvnet/training/__init__.py +0 -0
  32. {pvnet-5.0.1 → pvnet-5.0.3}/pvnet/training/lightning_module.py +0 -0
  33. {pvnet-5.0.1 → pvnet-5.0.3}/pvnet/training/plots.py +0 -0
  34. {pvnet-5.0.1 → pvnet-5.0.3}/pvnet/training/train.py +0 -0
  35. {pvnet-5.0.1 → pvnet-5.0.3}/pvnet/utils.py +0 -0
  36. {pvnet-5.0.1 → pvnet-5.0.3}/pyproject.toml +0 -0
  37. {pvnet-5.0.1 → pvnet-5.0.3}/setup.cfg +0 -0
  38. {pvnet-5.0.1 → pvnet-5.0.3}/tests/test_end2end.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: PVNet
3
- Version: 5.0.1
3
+ Version: 5.0.3
4
4
  Summary: PVNet
5
5
  Author-email: Peter Dudfield <info@openclimatefix.org>
6
6
  Requires-Python: >=3.10
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: PVNet
3
- Version: 5.0.1
3
+ Version: 5.0.3
4
4
  Summary: PVNet
5
5
  Author-email: Peter Dudfield <info@openclimatefix.org>
6
6
  Requires-Python: >=3.10
@@ -46,6 +46,7 @@ class BasePresavedDataModule(LightningDataModule):
46
46
  num_workers: int = 0,
47
47
  prefetch_factor: int | None = None,
48
48
  persistent_workers: bool = False,
49
+ pin_memory: bool = False,
49
50
  ):
50
51
  """Base Datamodule for loading pre-saved samples
51
52
 
@@ -53,10 +54,12 @@ class BasePresavedDataModule(LightningDataModule):
53
54
  sample_dir: Path to the directory of pre-saved samples.
54
55
  batch_size: Batch size.
55
56
  num_workers: Number of workers to use in multiprocess batch loading.
56
- prefetch_factor: Number of data will be prefetched at the end of each worker process.
57
+ prefetch_factor: Number of batches loaded in advance by each worker.
57
58
  persistent_workers: If True, the data loader will not shut down the worker processes
58
59
  after a dataset has been consumed once. This allows to maintain the workers Dataset
59
60
  instances alive.
61
+ pin_memory: If True, the data loader will copy Tensors into device/CUDA pinned memory
62
+ before returning them.
60
63
  """
61
64
  super().__init__()
62
65
 
@@ -68,7 +71,7 @@ class BasePresavedDataModule(LightningDataModule):
68
71
  batch_sampler=None,
69
72
  num_workers=num_workers,
70
73
  collate_fn=collate_fn,
71
- pin_memory=False,
74
+ pin_memory=pin_memory,
72
75
  drop_last=False,
73
76
  timeout=0,
74
77
  worker_init_fn=None,
@@ -100,8 +103,10 @@ class BaseStreamedDataModule(LightningDataModule):
100
103
  num_workers: int = 0,
101
104
  prefetch_factor: int | None = None,
102
105
  persistent_workers: bool = False,
106
+ pin_memory: bool = False,
103
107
  train_period: list[str | None] = [None, None],
104
108
  val_period: list[str | None] = [None, None],
109
+ seed: int | None = None,
105
110
 
106
111
  ):
107
112
  """Base Datamodule for streaming samples.
@@ -110,25 +115,29 @@ class BaseStreamedDataModule(LightningDataModule):
110
115
  configuration: Path to ocf-data-sampler configuration file.
111
116
  batch_size: Batch size.
112
117
  num_workers: Number of workers to use in multiprocess batch loading.
113
- prefetch_factor: Number of data will be prefetched at the end of each worker process.
118
+ prefetch_factor: Number of batches loaded in advance by each worker.
114
119
  persistent_workers: If True, the data loader will not shut down the worker processes
115
120
  after a dataset has been consumed once. This allows to maintain the workers Dataset
116
121
  instances alive.
122
+ pin_memory: If True, the data loader will copy Tensors into device/CUDA pinned memory
123
+ before returning them.
117
124
  train_period: Date range filter for train dataloader.
118
125
  val_period: Date range filter for val dataloader.
126
+ seed: Random seed used in shuffling datasets.
119
127
  """
120
128
  super().__init__()
121
129
 
122
130
  self.configuration = configuration
123
131
  self.train_period = train_period
124
132
  self.val_period = val_period
133
+ self.seed = seed
125
134
 
126
135
  self._common_dataloader_kwargs = dict(
127
136
  batch_size=batch_size,
128
137
  batch_sampler=None,
129
138
  num_workers=num_workers,
130
139
  collate_fn=collate_fn,
131
- pin_memory=False,
140
+ pin_memory=pin_memory,
132
141
  drop_last=False,
133
142
  timeout=0,
134
143
  worker_init_fn=None,
@@ -142,11 +151,16 @@ class BaseStreamedDataModule(LightningDataModule):
142
151
  # This logic runs only once at the start of training, therefore the val dataset is only
143
152
  # shuffled once
144
153
  if stage == "fit":
154
+
145
155
  # Prepare the train dataset
146
156
  self.train_dataset = self._get_streamed_samples_dataset(*self.train_period)
147
157
 
148
- # Prepare and pre-shuffle the val dataset
158
+ # Prepare and pre-shuffle the val dataset and set seed for reproducibility
149
159
  val_dataset = self._get_streamed_samples_dataset(*self.val_period)
160
+
161
+ if self.seed is not None:
162
+ torch.manual_seed(self.seed)
163
+
150
164
  shuffled_indices = torch.randperm(len(val_dataset))
151
165
  self.val_dataset = Subset(val_dataset, shuffled_indices)
152
166
 
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes