PVNet 5.0.26__tar.gz → 5.1.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pvnet-5.0.26 → pvnet-5.1.1}/PKG-INFO +1 -7
- {pvnet-5.0.26 → pvnet-5.1.1}/PVNet.egg-info/PKG-INFO +1 -7
- {pvnet-5.0.26 → pvnet-5.1.1}/PVNet.egg-info/SOURCES.txt +2 -4
- {pvnet-5.0.26 → pvnet-5.1.1}/README.md +0 -6
- pvnet-5.0.26/pvnet/data/base_datamodule.py → pvnet-5.1.1/pvnet/datamodule.py +22 -90
- {pvnet-5.0.26 → pvnet-5.1.1}/pvnet/load_model.py +2 -0
- {pvnet-5.0.26 → pvnet-5.1.1}/pvnet/training/lightning_module.py +1 -1
- {pvnet-5.0.26 → pvnet-5.1.1}/pvnet/training/train.py +1 -22
- pvnet-5.1.1/tests/test_datamodule.py +15 -0
- {pvnet-5.0.26 → pvnet-5.1.1}/tests/test_end2end.py +2 -2
- pvnet-5.0.26/pvnet/data/__init__.py +0 -3
- pvnet-5.0.26/pvnet/data/site_datamodule.py +0 -29
- pvnet-5.0.26/pvnet/data/uk_regional_datamodule.py +0 -29
- {pvnet-5.0.26 → pvnet-5.1.1}/LICENSE +0 -0
- {pvnet-5.0.26 → pvnet-5.1.1}/PVNet.egg-info/dependency_links.txt +0 -0
- {pvnet-5.0.26 → pvnet-5.1.1}/PVNet.egg-info/requires.txt +0 -0
- {pvnet-5.0.26 → pvnet-5.1.1}/PVNet.egg-info/top_level.txt +0 -0
- {pvnet-5.0.26 → pvnet-5.1.1}/pvnet/__init__.py +0 -0
- {pvnet-5.0.26 → pvnet-5.1.1}/pvnet/models/__init__.py +0 -0
- {pvnet-5.0.26 → pvnet-5.1.1}/pvnet/models/base_model.py +0 -0
- {pvnet-5.0.26 → pvnet-5.1.1}/pvnet/models/ensemble.py +0 -0
- {pvnet-5.0.26 → pvnet-5.1.1}/pvnet/models/late_fusion/__init__.py +0 -0
- {pvnet-5.0.26 → pvnet-5.1.1}/pvnet/models/late_fusion/basic_blocks.py +0 -0
- {pvnet-5.0.26 → pvnet-5.1.1}/pvnet/models/late_fusion/encoders/__init__.py +0 -0
- {pvnet-5.0.26 → pvnet-5.1.1}/pvnet/models/late_fusion/encoders/basic_blocks.py +0 -0
- {pvnet-5.0.26 → pvnet-5.1.1}/pvnet/models/late_fusion/encoders/encoders3d.py +0 -0
- {pvnet-5.0.26 → pvnet-5.1.1}/pvnet/models/late_fusion/late_fusion.py +0 -0
- {pvnet-5.0.26 → pvnet-5.1.1}/pvnet/models/late_fusion/linear_networks/__init__.py +0 -0
- {pvnet-5.0.26 → pvnet-5.1.1}/pvnet/models/late_fusion/linear_networks/basic_blocks.py +0 -0
- {pvnet-5.0.26 → pvnet-5.1.1}/pvnet/models/late_fusion/linear_networks/networks.py +0 -0
- {pvnet-5.0.26 → pvnet-5.1.1}/pvnet/models/late_fusion/site_encoders/__init__.py +0 -0
- {pvnet-5.0.26 → pvnet-5.1.1}/pvnet/models/late_fusion/site_encoders/basic_blocks.py +0 -0
- {pvnet-5.0.26 → pvnet-5.1.1}/pvnet/models/late_fusion/site_encoders/encoders.py +0 -0
- {pvnet-5.0.26 → pvnet-5.1.1}/pvnet/optimizers.py +0 -0
- {pvnet-5.0.26 → pvnet-5.1.1}/pvnet/training/__init__.py +0 -0
- {pvnet-5.0.26 → pvnet-5.1.1}/pvnet/training/plots.py +0 -0
- {pvnet-5.0.26 → pvnet-5.1.1}/pvnet/utils.py +0 -0
- {pvnet-5.0.26 → pvnet-5.1.1}/pyproject.toml +0 -0
- {pvnet-5.0.26 → pvnet-5.1.1}/setup.cfg +0 -0

{pvnet-5.0.26 → pvnet-5.1.1}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: PVNet
-Version: 5.0.26
+Version: 5.1.1
 Summary: PVNet
 Author-email: Peter Dudfield <info@openclimatefix.org>
 Requires-Python: >=3.11
@@ -142,12 +142,6 @@ pip install -e <PATH-TO-ocf-data-sampler-REPO>
 If you install the local version of `ocf-data-sampler` that is more recent than the version
 specified in `PVNet` it is not guarenteed to function properly with this library.
 
-## Streaming samples (no pre-save)
-
-PVNet now trains and validates directly from **streamed_samples** (i.e. no pre-saving to disk).
-
-Make sure you have copied example configs (as already stated above):
-    cp -r configs.example configs
 
 ### Set up and config example for streaming
 

{pvnet-5.0.26 → pvnet-5.1.1}/PVNet.egg-info/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: PVNet
-Version: 5.0.26
+Version: 5.1.1
 Summary: PVNet
 Author-email: Peter Dudfield <info@openclimatefix.org>
 Requires-Python: >=3.11
@@ -142,12 +142,6 @@ pip install -e <PATH-TO-ocf-data-sampler-REPO>
 If you install the local version of `ocf-data-sampler` that is more recent than the version
 specified in `PVNet` it is not guarenteed to function properly with this library.
 
-## Streaming samples (no pre-save)
-
-PVNet now trains and validates directly from **streamed_samples** (i.e. no pre-saving to disk).
-
-Make sure you have copied example configs (as already stated above):
-    cp -r configs.example configs
 
 ### Set up and config example for streaming
 

{pvnet-5.0.26 → pvnet-5.1.1}/PVNet.egg-info/SOURCES.txt

@@ -7,13 +7,10 @@ PVNet.egg-info/dependency_links.txt
 PVNet.egg-info/requires.txt
 PVNet.egg-info/top_level.txt
 pvnet/__init__.py
+pvnet/datamodule.py
 pvnet/load_model.py
 pvnet/optimizers.py
 pvnet/utils.py
-pvnet/data/__init__.py
-pvnet/data/base_datamodule.py
-pvnet/data/site_datamodule.py
-pvnet/data/uk_regional_datamodule.py
 pvnet/models/__init__.py
 pvnet/models/base_model.py
 pvnet/models/ensemble.py
@@ -33,4 +30,5 @@ pvnet/training/__init__.py
 pvnet/training/lightning_module.py
 pvnet/training/plots.py
 pvnet/training/train.py
+tests/test_datamodule.py
 tests/test_end2end.py

{pvnet-5.0.26 → pvnet-5.1.1}/README.md

@@ -113,12 +113,6 @@ pip install -e <PATH-TO-ocf-data-sampler-REPO>
 If you install the local version of `ocf-data-sampler` that is more recent than the version
 specified in `PVNet` it is not guarenteed to function properly with this library.
 
-## Streaming samples (no pre-save)
-
-PVNet now trains and validates directly from **streamed_samples** (i.e. no pre-saving to disk).
-
-Make sure you have copied example configs (as already stated above):
-    cp -r configs.example configs
 
 ### Set up and config example for streaming
 

pvnet-5.0.26/pvnet/data/base_datamodule.py → pvnet-5.1.1/pvnet/datamodule.py

@@ -1,13 +1,14 @@
 """ Data module for pytorch lightning """
 
 import os
-from glob import glob
 
 import numpy as np
 from lightning.pytorch import LightningDataModule
 from ocf_data_sampler.numpy_sample.collate import stack_np_samples_into_batch
 from ocf_data_sampler.numpy_sample.common_types import NumpySample, TensorBatch
-from ocf_data_sampler.torch_datasets.sample.base import SampleBase, batch_to_tensor
+from ocf_data_sampler.torch_datasets.datasets.pvnet_uk import PVNetUKRegionalDataset
+from ocf_data_sampler.torch_datasets.datasets.site import SitesDataset
+from ocf_data_sampler.torch_datasets.sample.base import batch_to_tensor
 from torch.utils.data import DataLoader, Dataset, Subset
 
 
@@ -16,87 +17,8 @@ def collate_fn(samples: list[NumpySample]) -> TensorBatch:
     return batch_to_tensor(stack_np_samples_into_batch(samples))
 
 
-class PresavedSamplesDataset(Dataset):
-    """
-
-    Args:
-        sample_dir: Path to the directory of pre-saved samples.
-        sample_class: sample class type to use for save/load/to_numpy
-    """
-
-    def __init__(self, sample_dir: str, sample_class: SampleBase):
-        """Initialise PresavedSamplesDataset"""
-        self.sample_paths = glob(f"{sample_dir}/*")
-        self.sample_class = sample_class
-
-    def __len__(self) -> int:
-        return len(self.sample_paths)
-
-    def __getitem__(self, idx) -> NumpySample:
-        sample = self.sample_class.load(self.sample_paths[idx])
-        return sample.to_numpy()
-
-
-class BasePresavedDataModule(LightningDataModule):
-    """Base Datamodule for loading pre-saved samples."""
-
-    def __init__(
-        self,
-        sample_dir: str,
-        batch_size: int = 16,
-        num_workers: int = 0,
-        prefetch_factor: int | None = None,
-        persistent_workers: bool = False,
-        pin_memory: bool = False,
-    ):
-        """Base Datamodule for loading pre-saved samples
-
-        Args:
-            sample_dir: Path to the directory of pre-saved samples.
-            batch_size: Batch size.
-            num_workers: Number of workers to use in multiprocess batch loading.
-            prefetch_factor: Number of batches loaded in advance by each worker.
-            persistent_workers: If True, the data loader will not shut down the worker processes
-                after a dataset has been consumed once. This allows to maintain the workers Dataset
-                instances alive.
-            pin_memory: If True, the data loader will copy Tensors into device/CUDA pinned memory
-                before returning them.
-        """
-        super().__init__()
-
-        self.sample_dir = sample_dir
-
-        self._common_dataloader_kwargs = dict(
-            batch_size=batch_size,
-            sampler=None,
-            batch_sampler=None,
-            num_workers=num_workers,
-            collate_fn=collate_fn,
-            pin_memory=pin_memory,
-            drop_last=False,
-            timeout=0,
-            worker_init_fn=None,
-            prefetch_factor=prefetch_factor,
-            persistent_workers=persistent_workers,
-            multiprocessing_context="spawn" if num_workers>0 else None,
-        )
-
-    def _get_premade_samples_dataset(self, subdir: str) -> Dataset:
-        raise NotImplementedError
-
-    def train_dataloader(self) -> DataLoader:
-        """Construct train dataloader"""
-        dataset = self._get_premade_samples_dataset("train")
-        return DataLoader(dataset, shuffle=True, **self._common_dataloader_kwargs)
-
-    def val_dataloader(self) -> DataLoader:
-        """Construct val dataloader"""
-        dataset = self._get_premade_samples_dataset("val")
-        return DataLoader(dataset, shuffle=False, **self._common_dataloader_kwargs)
-
-
-class BaseStreamedDataModule(LightningDataModule):
-    """Base Datamodule which streams samples using a sampler for ocf-data-sampler."""
+class BaseDataModule(LightningDataModule):
+    """Base Datamodule which streams samples using a sampler from ocf-data-sampler."""
 
     def __init__(
         self,
@@ -159,10 +81,10 @@ class BaseStreamedDataModule(LightningDataModule):
         if stage == "fit":
 
             # Prepare the train dataset
-            self.train_dataset = self._get_streamed_samples_dataset(*self.train_period)
+            self.train_dataset = self._get_dataset(*self.train_period)
 
             # Prepare and pre-shuffle the val dataset and set seed for reproducibility
-            val_dataset = self._get_streamed_samples_dataset(*self.val_period)
+            val_dataset = self._get_dataset(*self.val_period)
 
             shuffled_indices = np.random.default_rng(seed=self.seed).permutation(len(val_dataset))
             self.val_dataset = Subset(val_dataset, shuffled_indices)
@@ -194,11 +116,7 @@ class BaseStreamedDataModule(LightningDataModule):
         if os.path.exists(filepath):
             os.remove(filepath)
 
-    def _get_streamed_samples_dataset(
-        self,
-        start_time: str | None,
-        end_time: str | None
-    ) -> Dataset:
+    def _get_dataset(self, start_time: str | None, end_time: str | None) -> Dataset:
         raise NotImplementedError
 
     def train_dataloader(self) -> DataLoader:
@@ -208,3 +126,17 @@ class BaseStreamedDataModule(LightningDataModule):
     def val_dataloader(self) -> DataLoader:
         """Construct val dataloader"""
         return DataLoader(self.val_dataset, shuffle=False, **self._common_dataloader_kwargs)
+
+
+class UKRegionalDataModule(BaseDataModule):
+    """Datamodule for streaming UK regional samples."""
+
+    def _get_dataset(self, start_time: str | None, end_time: str | None) -> PVNetUKRegionalDataset:
+        return PVNetUKRegionalDataset(self.configuration, start_time=start_time, end_time=end_time)
+
+
+class SitesDataModule(BaseDataModule):
+    """Datamodule for streaming site samples."""
+
+    def _get_dataset(self, start_time: str | None, end_time: str | None) -> SitesDataset:
+        return SitesDataset(self.configuration, start_time=start_time, end_time=end_time)
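
The net effect of this rewrite is that pvnet.datamodule now exposes one streaming base class plus the two concrete datamodules above. A minimal construction sketch, assuming pvnet 5.1.1 and ocf-data-sampler are installed; the config path and the Lightning setup call are illustrative, and the keyword arguments mirror those used in this package's tests:

```python
from pvnet.datamodule import UKRegionalDataModule

# "configs/data_config.yaml" is a placeholder path to an ocf-data-sampler data config.
datamodule = UKRegionalDataModule(
    configuration="configs/data_config.yaml",
    batch_size=2,
    num_workers=0,
    prefetch_factor=None,
    train_period=[None, None],  # stream the full available period
    val_period=[None, None],
)

# Standard LightningDataModule hook: setup("fit") builds the train dataset and the
# pre-shuffled validation Subset shown in the diff above.
datamodule.setup("fit")
train_loader = datamodule.train_dataloader()
```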

{pvnet-5.0.26 → pvnet-5.1.1}/pvnet/load_model.py

@@ -73,6 +73,8 @@ def get_model_from_checkpoints(
     else:
         raise FileNotFoundError(f"File {data_config} does not exist")
 
+    # TODO: This should be removed in a future release since no new models will be trained on
+    # presaved samples
     # Check for datamodule config
     # This only exists if the model was trained with presaved samples
     datamodule_config = f"{path}/{DATAMODULE_CONFIG_NAME}"
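
For context, the check the new TODO refers to amounts to looking for a legacy datamodule config saved next to the checkpoint; a rough sketch of that idea, where the filename constant is an assumption (the real value comes from pvnet.utils.DATAMODULE_CONFIG_NAME):

```python
import os

# Assumed filename for illustration only; pvnet reads the real value from
# pvnet.utils.DATAMODULE_CONFIG_NAME.
DATAMODULE_CONFIG_NAME = "datamodule.yaml"


def was_trained_on_presaved_samples(checkpoint_dir: str) -> bool:
    """Legacy heuristic: the datamodule config only exists for presaved-sample runs."""
    return os.path.exists(os.path.join(checkpoint_dir, DATAMODULE_CONFIG_NAME))
```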

{pvnet-5.0.26 → pvnet-5.1.1}/pvnet/training/lightning_module.py

@@ -11,7 +11,7 @@ import xarray as xr
 from ocf_data_sampler.numpy_sample.common_types import TensorBatch
 from ocf_data_sampler.torch_datasets.sample.base import copy_batch_to_device
 
-from pvnet.data.base_datamodule import collate_fn
+from pvnet.datamodule import collate_fn
 from pvnet.models.base_model import BaseModel
 from pvnet.optimizers import AbstractOptimizer
 from pvnet.training.plots import plot_sample_forecasts, wandb_line_plot
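
Only the import path changes here; collate_fn still stacks individual NumpySamples into a batch and converts the result to tensors. A sketch of wiring it into a plain torch DataLoader, mirroring what BaseDataModule does internally (the config path is a placeholder):

```python
from ocf_data_sampler.torch_datasets.datasets.pvnet_uk import PVNetUKRegionalDataset
from torch.utils.data import DataLoader

from pvnet.datamodule import collate_fn

# Placeholder path to an ocf-data-sampler data configuration.
dataset = PVNetUKRegionalDataset("configs/data_config.yaml", start_time=None, end_time=None)

loader = DataLoader(
    dataset,
    batch_size=4,
    collate_fn=collate_fn,  # stack_np_samples_into_batch followed by batch_to_tensor
)
batch = next(iter(loader))  # a TensorBatch ready for PVNetLightningModule
```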

{pvnet-5.0.26 → pvnet-5.1.1}/pvnet/training/train.py

@@ -15,10 +15,8 @@ from lightning.pytorch.callbacks import ModelCheckpoint
 from lightning.pytorch.loggers import Logger, WandbLogger
 from omegaconf import DictConfig, OmegaConf
 
-from pvnet.data.base_datamodule import BasePresavedDataModule
 from pvnet.utils import (
     DATA_CONFIG_NAME,
-    DATAMODULE_CONFIG_NAME,
     FULL_CONFIG_NAME,
     MODEL_CONFIG_NAME,
 )
@@ -102,27 +100,8 @@ def train(config: DictConfig) -> None:
     os.makedirs(save_dir, exist_ok=True)
     OmegaConf.save(config.model, f"{save_dir}/{MODEL_CONFIG_NAME}")
 
-    # If using pre-saved samples we need to extract the data config from the directory
-    # those samples were saved to
-    if isinstance(datamodule, BasePresavedDataModule):
-        data_config = f"{config.datamodule.sample_dir}/{DATA_CONFIG_NAME}"
-
-        # We also save the datamodule config used to create the samples to the output
-        # directory and to wandb
-        shutil.copyfile(
-            f"{config.datamodule.sample_dir}/{DATAMODULE_CONFIG_NAME}",
-            f"{save_dir}/{DATAMODULE_CONFIG_NAME}"
-        )
-        wandb_logger.experiment.save(
-            f"{save_dir}/{DATAMODULE_CONFIG_NAME}",
-            base_path=save_dir,
-        )
-    else:
-        # If we are streaming batches the data config is defined and we don't need to
-        # save the datamodule config separately
-        data_config = config.datamodule.configuration
-
     # Save the data config to the output directory and to wandb
+    data_config = config.datamodule.configuration
     shutil.copyfile(data_config, f"{save_dir}/{DATA_CONFIG_NAME}")
     wandb_logger.experiment.save(f"{save_dir}/{DATA_CONFIG_NAME}", base_path=save_dir)
 
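
With the branch removed, train() always takes the data config path from config.datamodule.configuration. An illustrative OmegaConf fragment showing the shape such a config might take; the _target_ key and the path are assumptions, while the other field names mirror the datamodule constructor arguments used in the tests:

```python
from omegaconf import OmegaConf

config = OmegaConf.create(
    {
        "datamodule": {
            # Assumed hydra-style target; not taken from the package's own configs.
            "_target_": "pvnet.datamodule.UKRegionalDataModule",
            "configuration": "configs/data_config.yaml",  # placeholder path
            "batch_size": 8,
            "num_workers": 2,
            "prefetch_factor": 2,
            "train_period": [None, None],
            "val_period": [None, None],
        }
    }
)

# This is the value train() now copies to the run directory and uploads to wandb.
print(config.datamodule.configuration)
```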

pvnet-5.1.1/tests/test_datamodule.py

@@ -0,0 +1,15 @@
+from pvnet.datamodule import SitesDataModule
+
+
+
+def test_sites_data_module(site_data_config_path):
+    """Test SitesDataModule initialization"""
+
+    _ = SitesDataModule(
+        configuration=site_data_config_path,
+        batch_size=2,
+        num_workers=0,
+        prefetch_factor=None,
+        train_period=[None, None],
+        val_period=[None, None],
+    )

{pvnet-5.0.26 → pvnet-5.1.1}/tests/test_end2end.py

@@ -1,6 +1,6 @@
 import lightning
 
-from pvnet.data.uk_regional_datamodule import UKRegionalStreamedDataModule
+from pvnet.datamodule import UKRegionalDataModule
 from pvnet.optimizers import EmbAdamWReduceLROnPlateau
 from pvnet.training.lightning_module import PVNetLightningModule
 
@@ -8,7 +8,7 @@ from pvnet.training.lightning_module import PVNetLightningModule
 def test_model_trainer_fit(session_tmp_path, uk_data_config_path, late_fusion_model):
     """Test end-to-end training."""
 
-    datamodule = UKRegionalStreamedDataModule(
+    datamodule = UKRegionalDataModule(
         configuration=uk_data_config_path,
         batch_size=2,
         num_workers=2,

pvnet-5.0.26/pvnet/data/site_datamodule.py

@@ -1,29 +0,0 @@
-""" Data module for pytorch lightning """
-
-from ocf_data_sampler.torch_datasets.datasets.site import SitesDataset
-from ocf_data_sampler.torch_datasets.sample.site import SiteSample
-from torch.utils.data import Dataset
-
-from pvnet.data.base_datamodule import (
-    BasePresavedDataModule,
-    BaseStreamedDataModule,
-    PresavedSamplesDataset,
-)
-
-
-class SitePresavedDataModule(BasePresavedDataModule):
-    """Datamodule for loading pre-saved samples."""
-
-    def _get_premade_samples_dataset(self, subdir: str) -> Dataset:
-        return PresavedSamplesDataset(f"{self.sample_dir}/{subdir}", SiteSample)
-
-
-class SiteStreamedDataModule(BaseStreamedDataModule):
-    """Datamodule which streams samples using sampler for ocf-data-sampler."""
-
-    def _get_streamed_samples_dataset(
-        self,
-        start_time: str | None,
-        end_time: str | None
-    ) -> Dataset:
-        return SitesDataset(self.configuration, start_time=start_time, end_time=end_time)

pvnet-5.0.26/pvnet/data/uk_regional_datamodule.py

@@ -1,29 +0,0 @@
-""" Data module for pytorch lightning """
-
-from ocf_data_sampler.torch_datasets.datasets.pvnet_uk import PVNetUKRegionalDataset
-from ocf_data_sampler.torch_datasets.sample.uk_regional import UKRegionalSample
-from torch.utils.data import Dataset
-
-from pvnet.data.base_datamodule import (
-    BasePresavedDataModule,
-    BaseStreamedDataModule,
-    PresavedSamplesDataset,
-)
-
-
-class UKRegionalPresavedDataModule(BasePresavedDataModule):
-    """Datamodule for loading pre-saved samples."""
-
-    def _get_premade_samples_dataset(self, subdir: str) -> Dataset:
-        return PresavedSamplesDataset(f"{self.sample_dir}/{subdir}", UKRegionalSample)
-
-
-class UKRegionalStreamedDataModule(BaseStreamedDataModule):
-    """Datamodule which streams samples using sampler for ocf-data-sampler."""
-
-    def _get_streamed_samples_dataset(
-        self,
-        start_time: str | None,
-        end_time: str | None
-    ) -> Dataset:
-        return PVNetUKRegionalDataset(self.configuration, start_time=start_time, end_time=end_time)
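
For downstream code that imported the two deleted modules, the streamed variants map directly onto the new classes; a rough migration sketch, noting that the presaved-sample classes have no 5.1.x replacement:

```python
# pvnet 5.0.x imports (modules deleted in 5.1.x):
# from pvnet.data.site_datamodule import SiteStreamedDataModule
# from pvnet.data.uk_regional_datamodule import UKRegionalStreamedDataModule

# pvnet 5.1.x equivalents. BasePresavedDataModule, PresavedSamplesDataset and the
# *PresavedDataModule classes were removed outright, since training now streams
# samples instead of loading pre-saved ones.
from pvnet.datamodule import SitesDataModule, UKRegionalDataModule
```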

The remaining 26 files listed above are unchanged between versions.