ocf-data-sampler 0.1.5__py3-none-any.whl → 0.1.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ocf-data-sampler might be problematic.
- ocf_data_sampler/config/__init__.py +1 -1
- ocf_data_sampler/config/load.py +6 -17
- ocf_data_sampler/config/model.py +10 -20
- ocf_data_sampler/config/save.py +9 -62
- ocf_data_sampler/load/__init__.py +5 -1
- ocf_data_sampler/load/gsp.py +10 -6
- ocf_data_sampler/load/load_dataset.py +15 -17
- ocf_data_sampler/load/nwp/nwp.py +3 -4
- ocf_data_sampler/load/nwp/providers/ecmwf.py +6 -17
- ocf_data_sampler/load/nwp/providers/ukv.py +1 -9
- ocf_data_sampler/load/nwp/providers/utils.py +1 -5
- ocf_data_sampler/load/satellite.py +4 -8
- ocf_data_sampler/load/site.py +20 -13
- ocf_data_sampler/numpy_sample/collate.py +3 -4
- ocf_data_sampler/numpy_sample/datetime_features.py +14 -22
- ocf_data_sampler/select/find_contiguous_time_periods.py +2 -2
- ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py +2 -2
- ocf_data_sampler/torch_datasets/datasets/site.py +1 -1
- {ocf_data_sampler-0.1.5.dist-info → ocf_data_sampler-0.1.6.dist-info}/METADATA +1 -1
- {ocf_data_sampler-0.1.5.dist-info → ocf_data_sampler-0.1.6.dist-info}/RECORD +29 -28
- tests/config/test_config.py +1 -47
- tests/config/test_load.py +7 -0
- tests/config/test_save.py +21 -30
- tests/load/test_load_sites.py +1 -1
- tests/numpy_sample/test_datetime_features.py +0 -10
- tests/torch_datasets/test_site.py +3 -3
- {ocf_data_sampler-0.1.5.dist-info → ocf_data_sampler-0.1.6.dist-info}/LICENSE +0 -0
- {ocf_data_sampler-0.1.5.dist-info → ocf_data_sampler-0.1.6.dist-info}/WHEEL +0 -0
- {ocf_data_sampler-0.1.5.dist-info → ocf_data_sampler-0.1.6.dist-info}/top_level.txt +0 -0
ocf_data_sampler/config/__init__.py
CHANGED
@@ -1,5 +1,5 @@
 """Configuration model"""
 
-from ocf_data_sampler.config.model import Configuration
+from ocf_data_sampler.config.model import Configuration, InputData
 from ocf_data_sampler.config.save import save_yaml_configuration
 from ocf_data_sampler.config.load import load_yaml_configuration
ocf_data_sampler/config/load.py
CHANGED
@@ -1,33 +1,22 @@
-"""
-
-Example:
-
-    from ocf_data_sampler.config import load_yaml_configuration
-    configuration = load_yaml_configuration(filename)
-"""
+"""Load configuration from a yaml file"""
 
 import fsspec
-from pathy import Pathy
 from pyaml_env import parse_config
-
 from ocf_data_sampler.config import Configuration
 
 
-def load_yaml_configuration(filename: str
+def load_yaml_configuration(filename: str) -> Configuration:
     """
     Load a yaml file which has a configuration in it
 
     Args:
-        filename: the file name that you want to load. Will load from local, AWS, or GCP
+        filename: the yaml file name that you want to load. Will load from local, AWS, or GCP
            depending on the protocol suffix (e.g. 's3://bucket/config.yaml').
 
-    Returns:pydantic class
+    Returns: pydantic class
 
    """
-    # load the file to a dictionary
    with fsspec.open(filename, mode="r") as stream:
        configuration = parse_config(data=stream)
-
-
-    configuration = Configuration(**configuration)
-    return configuration
+
+    return Configuration(**configuration)
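Note: the loader's contract is now a single pipeline — read with fsspec, expand environment tags with pyaml-env, validate with pydantic. A minimal usage sketch (the config path is illustrative, not from this diff):

# Hypothetical path; fsspec resolves local, s3:// or gs:// locations,
# and pyaml_env expands `!ENV ${VAR}` tags before pydantic validation.
from ocf_data_sampler.config import load_yaml_configuration

configuration = load_yaml_configuration("s3://my-bucket/config.yaml")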
ocf_data_sampler/config/model.py
CHANGED
@@ -1,16 +1,10 @@
 """Configuration model for the dataset.
 
-All paths must include the protocol prefix. For local files,
-it's sufficient to just start with a '/'. For aws, start with 's3://',
-for gcp start with 'gs://'.
 
-
+Absolute or relative zarr filepath(s). Prefix with a protocol like s3:// to read from alternative filesystems.
 
-from ocf_data_sampler.config import Configuration
-config = Configuration(**config_dict)
 """
 
-import logging
 from typing import Dict, List, Optional
 from typing_extensions import Self
 
@@ -18,10 +12,6 @@ from pydantic import BaseModel, Field, RootModel, field_validator, ValidationInfo
 
 from ocf_data_sampler.constants import NWP_PROVIDERS
 
-logger = logging.getLogger(__name__)
-
-providers = ["pvoutput.org", "solar_sheffield_passiv"]
-
 
 class Base(BaseModel):
     """Pydantic Base model where no extras can be added"""
@@ -79,8 +69,6 @@ class TimeWindowMixin(Base):
         return v
 
 
-
-# noinspection PyMethodParameters
 class DropoutMixin(Base):
     """Mixin class, to add dropout minutes"""
 
@@ -137,7 +125,8 @@ class Satellite(TimeWindowMixin, DropoutMixin, SpatialWindowMixin):
 
     zarr_path: str | tuple[str] | list[str] = Field(
         ...,
-        description="
+        description="Absolute or relative zarr filepath(s). Prefix with a protocol like s3:// "
+        "to read from alternative filesystems.",
     )
 
     channels: list[str] = Field(
@@ -145,13 +134,13 @@ class Satellite(TimeWindowMixin, DropoutMixin, SpatialWindowMixin):
     )
 
 
-# noinspection PyMethodParameters
 class NWP(TimeWindowMixin, DropoutMixin, SpatialWindowMixin):
     """NWP configuration model"""
 
     zarr_path: str | tuple[str] | list[str] = Field(
         ...,
-        description="
+        description="Absolute or relative zarr filepath(s). Prefix with a protocol like s3:// "
+        "to read from alternative filesystems.",
     )
 
     channels: list[str] = Field(
@@ -175,7 +164,6 @@ class NWP(TimeWindowMixin, DropoutMixin, SpatialWindowMixin):
         """Validate 'provider'"""
         if v.lower() not in NWP_PROVIDERS:
             message = f"NWP provider {v} is not in {NWP_PROVIDERS}"
-            logger.warning(message)
             raise Exception(message)
         return v
 
@@ -209,7 +197,11 @@ class MultiNWP(RootModel):
 class GSP(TimeWindowMixin, DropoutMixin):
     """GSP configuration model"""
 
-    zarr_path: str = Field(
+    zarr_path: str = Field(
+        ...,
+        description="Absolute or relative zarr filepath. Prefix with a protocol like s3:// "
+        "to read from alternative filesystems.",
+    )
 
 
 class Site(TimeWindowMixin, DropoutMixin):
@@ -228,8 +220,6 @@ class Site(TimeWindowMixin, DropoutMixin):
     # TODO validate the csv for metadata
 
 
-
-# noinspection PyPep8Naming
 class InputData(Base):
     """Input data model"""
 
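Note: because every model inherits from `Base`, unknown keys are rejected at construction time rather than silently ignored. A sketch of how that surfaces (the field name is made up):

from pydantic import ValidationError
from ocf_data_sampler.config import Configuration

try:
    # "made_up_source" is not a real field - Base forbids extras
    Configuration(input_data={"made_up_source": {}})
except ValidationError as err:
    print(err)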
ocf_data_sampler/config/save.py
CHANGED
@@ -2,25 +2,16 @@
 
 This module provides functionality to save configuration objects to YAML files,
 supporting local and cloud storage locations.
-
-Example:
-    from ocf_data_sampler.config import save_yaml_configuration
-    saved_path = save_yaml_configuration(config, "config.yaml")
 """
 
 import json
-from pathlib import Path
-from typing import Union
-
 import fsspec
 import yaml
+import os
 
 from ocf_data_sampler.config import Configuration
 
-def save_yaml_configuration(
-    configuration: Configuration,
-    filename: Union[str, Path],
-) -> Path:
+def save_yaml_configuration(configuration: Configuration, filename: str) -> None:
     """Save a configuration object to a YAML file.
 
     Args:
@@ -28,57 +19,13 @@ def save_yaml_configuration(
         filename: Destination path for the YAML file. Can be a local path or
             cloud storage URL (e.g., 'gs://', 's3://'). For local paths,
             absolute paths are recommended.
-
-    Returns:
-        Path: The path where the configuration was saved
-
-    Raises:
-        ValueError: If filename is None, directory doesn't exist, or if writing to the specified path fails
-        TypeError: If the configuration cannot be serialized
     """
-
-
-
-    try:
-        # Convert to absolute path if it's a relative path
-        if isinstance(filename, (str, Path)) and not any(
-            str(filename).startswith(prefix) for prefix in ('gs://', 's3://', '/')
-        ):
-            filename = Path.cwd() / filename
-
-        filepath = Path(filename)
-
-        # For local paths, check if parent directory exists before attempting to create
-        if filepath.is_absolute():
-            if not filepath.parent.exists():
-                raise ValueError("Directory does not exist")
-
-            # Only try to create directory if it's in a writable location
-            try:
-                filepath.parent.mkdir(parents=True, exist_ok=True)
-            except PermissionError:
-                raise ValueError(f"Permission denied when accessing directory {filepath.parent}")
-
-        # Serialize configuration to JSON-compatible dictionary
-        config_dict = json.loads(configuration.model_dump_json())
-
-        # Write to file directly for local paths
-        if filepath.is_absolute():
-            try:
-                with open(filepath, 'w') as f:
-                    yaml.safe_dump(config_dict, f, default_flow_style=False)
-            except PermissionError:
-                raise ValueError(f"Permission denied when writing to {filename}")
-        else:
-            # Use fsspec for cloud storage
-            with fsspec.open(str(filepath), mode='w') as yaml_file:
-                yaml.safe_dump(config_dict, yaml_file, default_flow_style=False)
+
+    if os.path.exists(filename):
+        raise FileExistsError(f"File already exists: {filename}")
 
-
+    # Serialize configuration to JSON-compatible dictionary
+    config_dict = json.loads(configuration.model_dump_json())
 
-
-
-    except (IOError, OSError) as e:
-        if "Permission denied" in str(e):
-            raise ValueError(f"Permission denied when writing to {filename}") from e
-        raise ValueError(f"Failed to write configuration to {filename}: {str(e)}") from e
+    with fsspec.open(filename, mode='w') as yaml_file:
+        yaml.safe_dump(config_dict, yaml_file, default_flow_style=False)
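Note: the rewrite above trades roughly fifty lines of path juggling for a single fsspec write plus an explicit overwrite guard; the `os.path.exists` check only protects local paths, not cloud URLs. A usage sketch (the path is illustrative):

from ocf_data_sampler.config import Configuration, save_yaml_configuration

config = Configuration()
save_yaml_configuration(config, "/tmp/ocf_config.yaml")

# A second save to the same path now fails loudly instead of
# silently overwriting the file
try:
    save_yaml_configuration(config, "/tmp/ocf_config.yaml")
except FileExistsError as err:
    print(err)  # File already exists: /tmp/ocf_config.yaml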
ocf_data_sampler/load/gsp.py
CHANGED
@@ -1,16 +1,21 @@
-from pathlib import Path
 import pkg_resources
 
 import pandas as pd
 import xarray as xr
 
 
-def open_gsp(zarr_path: str
+def open_gsp(zarr_path: str) -> xr.DataArray:
+    """Open the GSP data
+
+    Args:
+        zarr_path: Path to the GSP zarr data
+
+    Returns:
+        xr.DataArray: The opened GSP data
+    """
 
-    # Load GSP generation xr.Dataset
     ds = xr.open_zarr(zarr_path)
 
-    # Rename to standard time name
     ds = ds.rename({"datetime_gmt": "time_utc"})
 
     # Load UK GSP locations
@@ -19,13 +24,12 @@ def open_gsp(zarr_path: str | Path) -> xr.DataArray:
         index_col="gsp_id",
     )
 
-    # Add coordinates
+    # Add locations and capacities as coordinates for each GSP and datetime
     ds = ds.assign_coords(
         x_osgb=(df_gsp_loc.x_osgb.to_xarray()),
         y_osgb=(df_gsp_loc.y_osgb.to_xarray()),
         nominal_capacity_mwp=ds.installedcapacity_mwp,
         effective_capacity_mwp=ds.capacity_mwp,
-
     )
 
     return ds.generation_mw
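Note: `open_gsp` now takes a plain `str` path and documents its return. A usage sketch (the zarr path is illustrative):

from ocf_data_sampler.load.gsp import open_gsp

# Returns generation_mw as a DataArray, with x_osgb / y_osgb and the
# nominal/effective capacities attached as coordinates
da_gsp = open_gsp(zarr_path="/data/gsp.zarr")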
ocf_data_sampler/load/load_dataset.py
CHANGED
@@ -1,36 +1,31 @@
 """ Loads all data sources """
 import xarray as xr
 
-from ocf_data_sampler.config import
-from ocf_data_sampler.load
-from ocf_data_sampler.load.nwp import open_nwp
-from ocf_data_sampler.load.satellite import open_sat_data
-from ocf_data_sampler.load.site import open_site
+from ocf_data_sampler.config import InputData
+from ocf_data_sampler.load import open_nwp, open_gsp, open_sat_data, open_site
 
 
-def get_dataset_dict(
+def get_dataset_dict(input_config: InputData) -> dict[str, dict[xr.DataArray] | xr.DataArray]:
    """Construct dictionary of all of the input data sources
 
     Args:
-
+        input_config: InputData configuration object
     """
 
-    in_config = config.input_data
-
     datasets_dict = {}
 
     # Load GSP data unless the path is None
-    if
-        da_gsp = open_gsp(zarr_path=
+    if input_config.gsp and input_config.gsp.zarr_path:
+        da_gsp = open_gsp(zarr_path=input_config.gsp.zarr_path).compute()
 
         # Remove national GSP
         datasets_dict["gsp"] = da_gsp.sel(gsp_id=slice(1, None))
 
     # Load NWP data if in config
-    if
+    if input_config.nwp:
 
         datasets_dict["nwp"] = {}
-        for nwp_source, nwp_config in
+        for nwp_source, nwp_config in input_config.nwp.items():
 
             da_nwp = open_nwp(nwp_config.zarr_path, provider=nwp_config.provider)
 
@@ -39,8 +34,8 @@ def get_dataset_dict(config: Configuration) -> dict[str, dict[xr.DataArray]]:
             datasets_dict["nwp"][nwp_source] = da_nwp
 
     # Load satellite data if in config
-    if
-        sat_config =
+    if input_config.satellite:
+        sat_config = input_config.satellite
 
         da_sat = open_sat_data(sat_config.zarr_path)
 
@@ -48,8 +43,11 @@ def get_dataset_dict(config: Configuration) -> dict[str, dict[xr.DataArray]]:
 
         datasets_dict["sat"] = da_sat
 
-    if
-        da_sites = open_site(
+    if input_config.site:
+        da_sites = open_site(
+            generation_file_path=input_config.site.file_path,
+            metadata_file_path=input_config.site.metadata_file_path,
+        )
         datasets_dict["site"] = da_sites
 
     return datasets_dict
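Note: this is the call-site contract change that ripples through the torch datasets below — `get_dataset_dict` now takes the `InputData` sub-config rather than the whole `Configuration`. A before/after sketch (the config path is illustrative):

from ocf_data_sampler.config import load_yaml_configuration
from ocf_data_sampler.load.load_dataset import get_dataset_dict

config = load_yaml_configuration("/data/config.yaml")

# 0.1.5: datasets_dict = get_dataset_dict(config)
# 0.1.6:
datasets_dict = get_dataset_dict(config.input_data)
# Keys present depend on the config, e.g. "gsp", "nwp", "sat", "site"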
ocf_data_sampler/load/nwp/nwp.py
CHANGED
@@ -1,15 +1,14 @@
-from pathlib import Path
 import xarray as xr
 
 from ocf_data_sampler.load.nwp.providers.ukv import open_ukv
 from ocf_data_sampler.load.nwp.providers.ecmwf import open_ifs
 
 
-def open_nwp(zarr_path:
-    """Opens NWP
+def open_nwp(zarr_path: str | list[str], provider: str) -> xr.DataArray:
+    """Opens NWP zarr
 
     Args:
-        zarr_path:
+        zarr_path: path to the zarr file
         provider: NWP provider
     """
 
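Note: `open_nwp` and the provider loaders below all drop `pathlib.Path` from their accepted types in favour of plain strings. A usage sketch (the path and provider string are illustrative):

from ocf_data_sampler.load.nwp import open_nwp

# Dispatches to open_ukv / open_ifs based on the provider string
da_nwp = open_nwp(zarr_path="/data/ukv.zarr", provider="ukv")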
ocf_data_sampler/load/nwp/providers/ecmwf.py
CHANGED
@@ -1,5 +1,5 @@
 """ECMWF provider loaders"""
-
+
 import xarray as xr
 from ocf_data_sampler.load.nwp.providers.utils import open_zarr_paths
 from ocf_data_sampler.load.utils import (
@@ -9,7 +9,7 @@ from ocf_data_sampler.load.utils import (
 )
 
 
-def open_ifs(zarr_path:
+def open_ifs(zarr_path: str | list[str]) -> xr.DataArray:
     """
     Opens the ECMWF IFS NWP data
 
@@ -19,25 +19,14 @@ def open_ifs(zarr_path: Path | str | list[Path] | list[str]) -> xr.DataArray:
     Returns:
         Xarray DataArray of the NWP data
     """
-    # Open the data
-    ds = open_zarr_paths(zarr_path)
-
-    # Rename
-    ds = ds.rename(
-        {
-            "init_time": "init_time_utc",
-        }
-    )
 
-
-
-
-
+    ds = open_zarr_paths(zarr_path)
+
+    # LEGACY SUPPORT - rename variable to channel if it exists
+    ds = ds.rename({"init_time": "init_time_utc", "variable": "channel"})
 
-    # Check the timestamps are unique and increasing
     check_time_unique_increasing(ds.init_time_utc)
 
-    # Make sure the spatial coords are in increasing order
     ds = make_spatial_coords_increasing(ds, x_coord="longitude", y_coord="latitude")
 
     ds = ds.transpose("init_time_utc", "step", "channel", "longitude", "latitude")
ocf_data_sampler/load/nwp/providers/ukv.py
CHANGED
@@ -2,8 +2,6 @@
 
 import xarray as xr
 
-from pathlib import Path
-
 from ocf_data_sampler.load.nwp.providers.utils import open_zarr_paths
 from ocf_data_sampler.load.utils import (
     check_time_unique_increasing,
@@ -12,7 +10,7 @@ from ocf_data_sampler.load.utils import (
 )
 
 
-def open_ukv(zarr_path:
+def open_ukv(zarr_path: str | list[str]) -> xr.DataArray:
     """
     Opens the NWP data
 
@@ -22,10 +20,8 @@ def open_ukv(zarr_path: Path | str | list[Path] | list[str]) -> xr.DataArray:
     Returns:
         Xarray DataArray of the NWP data
     """
-    # Open the data
     ds = open_zarr_paths(zarr_path)
 
-    # Rename
     ds = ds.rename(
         {
             "init_time": "init_time_utc",
@@ -35,15 +31,11 @@ def open_ukv(zarr_path: Path | str | list[Path] | list[str]) -> xr.DataArray:
         }
     )
 
-    # Check the timestamps are unique and increasing
     check_time_unique_increasing(ds.init_time_utc)
 
-    # Make sure the spatial coords are in increasing order
     ds = make_spatial_coords_increasing(ds, x_coord="x_osgb", y_coord="y_osgb")
 
     ds = ds.transpose("init_time_utc", "step", "channel", "x_osgb", "y_osgb")
 
     # TODO: should we control the dtype of the DataArray?
     return get_xr_data_array_from_xr_dataset(ds)
-
-
ocf_data_sampler/load/nwp/providers/utils.py
CHANGED
@@ -1,11 +1,7 @@
-from pathlib import Path
 import xarray as xr
 
 
-def open_zarr_paths(
-    zarr_path: Path | str | list[Path] | list[str],
-    time_dim: str = "init_time"
-) -> xr.Dataset:
+def open_zarr_paths(zarr_path: str | list[str], time_dim: str = "init_time") -> xr.Dataset:
     """Opens the NWP data
 
     Args:
ocf_data_sampler/load/satellite.py
CHANGED
@@ -1,7 +1,6 @@
 """Satellite loader"""
 
 import subprocess
-from pathlib import Path
 
 import xarray as xr
 from ocf_data_sampler.load.utils import (
@@ -11,7 +10,7 @@ from ocf_data_sampler.load.utils import (
 )
 
 
-def _get_single_sat_data(zarr_path:
+def _get_single_sat_data(zarr_path: str) -> xr.Dataset:
     """Helper function to open a Zarr from either a local or GCP path.
 
     Args:
@@ -50,7 +49,7 @@ def _get_single_sat_data(zarr_path: Path | str) -> xr.Dataset:
     return ds
 
 
-def open_sat_data(zarr_path:
+def open_sat_data(zarr_path: str | list[str]) -> xr.DataArray:
     """Lazily opens the Zarr store.
 
     Args:
@@ -69,7 +68,6 @@ def open_sat_data(zarr_path: Path | str | list[Path] | list[str]) -> xr.DataArray:
     else:
         ds = _get_single_sat_data(zarr_path)
 
-    # Rename dimensions
     ds = ds.rename(
         {
             "variable": "channel",
@@ -77,13 +75,11 @@ def open_sat_data(zarr_path: Path | str | list[Path] | list[str]) -> xr.DataArray:
         }
     )
 
-    # Check timestamps
     check_time_unique_increasing(ds.time_utc)
 
-    # Ensure spatial coordinates are sorted
     ds = make_spatial_coords_increasing(ds, x_coord="x_geostationary", y_coord="y_geostationary")
-
+
     ds = ds.transpose("time_utc", "channel", "x_geostationary", "y_geostationary")
+
     # TODO: should we control the dtype of the DataArray?
-
     return get_xr_data_array_from_xr_dataset(ds)
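Note: `open_sat_data` likewise now accepts only `str` or `list[str]`. A usage sketch (the path is illustrative):

from ocf_data_sampler.load.satellite import open_sat_data

# One store, or a list of stores, opened lazily; the result comes back
# transposed to (time_utc, channel, x_geostationary, y_geostationary)
da_sat = open_sat_data(zarr_path="/data/sat_2023.zarr")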
ocf_data_sampler/load/site.py
CHANGED
@@ -1,30 +1,37 @@
+import numpy as np
 import pandas as pd
 import xarray as xr
-import numpy as np
 
-from ocf_data_sampler.config.model import Site
 
+def open_site(generation_file_path: str, metadata_file_path: str) -> xr.DataArray:
+    """Open a site's generation data and metadata.
+
+    Args:
+        generation_file_path: Path to the site generation netcdf data
+        metadata_file_path: Path to the site csv metadata
 
-
+    Returns:
+        xr.DataArray: The opened site generation data
+    """
 
-
-    site_generation_ds = xr.open_dataset(sites_config.file_path)
+    generation_ds = xr.open_dataset(generation_file_path)
 
-
-
+    metadata_df = pd.read_csv(metadata_file_path, index_col="site_id")
+
+    assert metadata_df.index.is_unique
 
     # Ensure metadata aligns with the site_id dimension in data_ds
-    metadata_df = metadata_df.reindex(
+    metadata_df = metadata_df.reindex(generation_ds.site_id.values)
 
     # Assign coordinates to the Dataset using the aligned metadata
-
+    generation_ds = generation_ds.assign_coords(
         latitude=("site_id", metadata_df["latitude"].values),
         longitude=("site_id", metadata_df["longitude"].values),
         capacity_kwp=("site_id", metadata_df["capacity_kwp"].values),
     )
 
     # Sanity checks
-    assert np.isfinite(
-    assert (
-
-    return
+    assert np.isfinite(generation_ds.capacity_kwp.values).all()
+    assert (generation_ds.capacity_kwp.values > 0).all()
+
+    return generation_ds.generation_kw
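Note: `open_site` no longer receives a `Site` config object; callers pass the two file paths explicitly, matching the change in load_dataset.py above. A usage sketch (the paths are illustrative):

from ocf_data_sampler.load.site import open_site

# Returns generation_kw with latitude / longitude / capacity_kwp
# joined on from the metadata csv as per-site coordinates
da_sites = open_site(
    generation_file_path="/data/site_generation.nc",
    metadata_file_path="/data/site_metadata.csv",
)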
ocf_data_sampler/numpy_sample/collate.py
CHANGED
@@ -45,11 +45,12 @@ def stack_np_samples_into_batch(dict_list: list[dict]) -> dict:
     return batch
 
 
-def _key_is_constant(key: str):
+def _key_is_constant(key: str) -> bool:
+    """Check if a key is for value which is constant for all samples"""
     return key.endswith("t0_idx") or key.endswith("channel_names")
 
 
-def stack_data_list(data_list: list, key: str):
+def stack_data_list(data_list: list, key: str) -> np.ndarray:
     """Stack a sequence of data elements along a new axis
 
     Args:
@@ -57,8 +58,6 @@ def stack_data_list(data_list: list, key: str):
         key: string identifying the data type
     """
     if _key_is_constant(key):
-        # These are always the same for all examples.
         return data_list[0]
     else:
         return np.stack(data_list)
-
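Note: the rule the annotated helpers implement — keys ending in "t0_idx" or "channel_names" are batch-constant and passed through from the first sample, while everything else gains a leading batch axis. A small sketch under that assumption (the sample keys are made up):

import numpy as np
from ocf_data_sampler.numpy_sample.collate import stack_np_samples_into_batch

samples = [{"gsp": np.zeros(4), "gsp_t0_idx": 2} for _ in range(3)]
batch = stack_np_samples_into_batch(samples)

print(batch["gsp"].shape)   # (3, 4) - stacked along a new batch axis
print(batch["gsp_t0_idx"])  # 2 - constant, taken from the first sample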
ocf_data_sampler/numpy_sample/datetime_features.py
CHANGED
@@ -2,20 +2,21 @@
 
 import numpy as np
 import pandas as pd
-from numpy.typing import NDArray
 
 
-def _get_date_time_in_pi(
-
-
-
-
+def _get_date_time_in_pi(dt: pd.DatetimeIndex) -> tuple[np.ndarray, np.ndarray]:
+    """Create positional embeddings for the datetimes in radians
+
+    Args:
+        dt: DatetimeIndex to create radian embeddings for
+
+    Returns:
+        Tuple of numpy arrays containing radian coordinates for date and time
     """
 
     day_of_year = dt.dayofyear
     minute_of_day = dt.minute + dt.hour * 60
 
-    # converting into positions on sin-cos circle
     time_in_pi = (2 * np.pi) * (minute_of_day / (24 * 60))
     date_in_pi = (2 * np.pi) * (day_of_year / 365)
 
@@ -23,24 +24,15 @@ def _get_date_time_in_pi(
 
 
 def make_datetime_numpy_dict(datetimes: pd.DatetimeIndex, key_prefix: str = "wind") -> dict:
-    """
-
-    if datetimes.empty:
-        raise ValueError("Input datetimes is empty for 'make_datetime_numpy_dict' function")
-
-    time_numpy_sample = {}
+    """ Creates dictionary of cyclical datetime features - encoded """
 
     date_in_pi, time_in_pi = _get_date_time_in_pi(datetimes)
 
-
-    date_sin_batch_key = key_prefix + "_date_sin"
-    date_cos_batch_key = key_prefix + "_date_cos"
-    time_sin_batch_key = key_prefix + "_time_sin"
-    time_cos_batch_key = key_prefix + "_time_cos"
+    time_numpy_sample = {}
 
-    time_numpy_sample[
-    time_numpy_sample[
-    time_numpy_sample[
-    time_numpy_sample[
+    time_numpy_sample[key_prefix + "_date_sin"] = np.sin(date_in_pi)
+    time_numpy_sample[key_prefix + "_date_cos"] = np.cos(date_in_pi)
+    time_numpy_sample[key_prefix + "_time_sin"] = np.sin(time_in_pi)
+    time_numpy_sample[key_prefix + "_time_cos"] = np.cos(time_in_pi)
 
     return time_numpy_sample
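Note: the consolidated encoding is the standard cyclical-feature trick — minute-of-day and day-of-year are mapped to angles so 23:59 sits next to 00:00 and 31 Dec next to 1 Jan, and the sin/cos pair gives each instant continuous coordinates. The same arithmetic, standalone:

import numpy as np
import pandas as pd

dt = pd.DatetimeIndex(["2024-06-21 12:00", "2024-12-31 23:59"])
time_in_pi = (2 * np.pi) * ((dt.minute + dt.hour * 60) / (24 * 60))
date_in_pi = (2 * np.pi) * (dt.dayofyear / 365)

# Noon maps to an angle of pi, so its sine is ~0; 23:59 lands just
# short of a full circle
print(np.sin(np.asarray(time_in_pi)).round(3))  # [ 0.    -0.004]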
ocf_data_sampler/select/find_contiguous_time_periods.py
CHANGED
@@ -2,6 +2,7 @@
 
 import numpy as np
 import pandas as pd
+from ocf_data_sampler.load.utils import check_time_unique_increasing
 
 
 
@@ -28,8 +29,7 @@ def find_contiguous_time_periods(
     # Sanity checks.
     assert len(datetimes) > 0
     assert min_seq_length > 1
-
-    assert datetimes.is_unique
+    check_time_unique_increasing(datetimes)
 
     # Find indices of gaps larger than max_gap:
     gap_mask = pd.TimedeltaIndex(np.diff(datetimes)) > max_gap_duration
ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py
CHANGED
@@ -187,7 +187,7 @@ class PVNetUKRegionalDataset(Dataset):
 
         config = load_yaml_configuration(config_filename)
 
-        datasets_dict = get_dataset_dict(config)
+        datasets_dict = get_dataset_dict(config.input_data)
 
         # Get t0 times where all input data is available
         valid_t0_times = find_valid_t0_times(datasets_dict, config)
@@ -295,7 +295,7 @@ class PVNetUKConcurrentDataset(Dataset):
 
         config = load_yaml_configuration(config_filename)
 
-        datasets_dict = get_dataset_dict(config)
+        datasets_dict = get_dataset_dict(config.input_data)
 
         # Get t0 times where all input data is available
         valid_t0_times = find_valid_t0_times(datasets_dict, config)
ocf_data_sampler/torch_datasets/datasets/site.py
CHANGED
@@ -47,7 +47,7 @@ class SitesDataset(Dataset):
         """
 
         config: Configuration = load_yaml_configuration(config_filename)
-        datasets_dict = get_dataset_dict(config)
+        datasets_dict = get_dataset_dict(config.input_data)
 
         # Assign config and input data to self
         self.datasets_dict = datasets_dict
{ocf_data_sampler-0.1.5.dist-info → ocf_data_sampler-0.1.6.dist-info}/RECORD
CHANGED
@@ -1,26 +1,26 @@
 ocf_data_sampler/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
 ocf_data_sampler/constants.py,sha256=ClteRIgp7EPlUPqIbkel83BfIaD7_VIDjUeHzUfyhnM,5079
 ocf_data_sampler/utils.py,sha256=rKA0BHAyAG4f90zEcgxp25EEYrXS-aOVNzttZ6Mzv2k,250
-ocf_data_sampler/config/__init__.py,sha256=
-ocf_data_sampler/config/load.py,sha256=
-ocf_data_sampler/config/model.py,sha256=
-ocf_data_sampler/config/save.py,sha256=
+ocf_data_sampler/config/__init__.py,sha256=O29mbH0XG2gIY1g3BaveGCnpBO2SFqdu-qzJ7a6evl0,223
+ocf_data_sampler/config/load.py,sha256=sKCKmhkkeFvvkNL5xmnFvdAulaCtV4-rigPsFvVDPDc,634
+ocf_data_sampler/config/model.py,sha256=IMJhsjL_oGh2c50q8pBnCnArY4qHQcBc_M8jqlEeD0c,7129
+ocf_data_sampler/config/save.py,sha256=OqCPT3e0d7vMI2g2iRzmifPD7GscDkFQztU_qE5I0JY,1066
 ocf_data_sampler/data/uk_gsp_locations.csv,sha256=RSh7DRh55E3n8lVAaWXGTaXXHevZZtI58td4d4DhGos,10415772
-ocf_data_sampler/load/__init__.py,sha256=
-ocf_data_sampler/load/gsp.py,sha256=
-ocf_data_sampler/load/load_dataset.py,sha256=
-ocf_data_sampler/load/satellite.py,sha256=
-ocf_data_sampler/load/site.py,sha256=
+ocf_data_sampler/load/__init__.py,sha256=T5Zj1PGt0aiiNEN7Ra1Ac-cBsNKhphmmHy_8g7XU_w0,219
+ocf_data_sampler/load/gsp.py,sha256=uRxEORH7J99JAJ-D38nm0iJFOQh7dkm_NCXcpbYkyvo,857
+ocf_data_sampler/load/load_dataset.py,sha256=PHUGSm4hFHfS9nfIP2KjHHCp325O4br7uGBdQH_DP7g,1603
+ocf_data_sampler/load/satellite.py,sha256=4MRJBFDHxx5WXu_6X71wEBznJTIuldEVnu9d6DVoLPI,2436
+ocf_data_sampler/load/site.py,sha256=74M_7RYwEc1bU4idjs3ZmQrx9I8mJXm6H4lwEL-h9n0,1226
 ocf_data_sampler/load/utils.py,sha256=sAEkPMS9LXVCrc5pANQo97zaoEItVg9hoNj2ZWfx_Ug,1405
 ocf_data_sampler/load/nwp/__init__.py,sha256=SmcrnbygO5xtCKmGR4wtHrj-HI7nOAvnAtfuvRufBGQ,25
-ocf_data_sampler/load/nwp/nwp.py,sha256=
+ocf_data_sampler/load/nwp/nwp.py,sha256=Jyq1dE7DN0iSe6iSEGA76uu9LoeJz9FzfEUkq6ZZExQ,565
 ocf_data_sampler/load/nwp/providers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-ocf_data_sampler/load/nwp/providers/ecmwf.py,sha256=
-ocf_data_sampler/load/nwp/providers/ukv.py,sha256=
-ocf_data_sampler/load/nwp/providers/utils.py,sha256=
+ocf_data_sampler/load/nwp/providers/ecmwf.py,sha256=8rYZKdV62AdczVNSOJ2G0BM4-fRFRV0_y5zkHgNYkQs,1004
+ocf_data_sampler/load/nwp/providers/ukv.py,sha256=dM_kvUI0xk9xEdslXqZGjOPP96PEw3qAci5mPUgUvxA,1014
+ocf_data_sampler/load/nwp/providers/utils.py,sha256=MFOZ5ZXLu3-SxYVJExdlo30b3y3s5ebRx3_6DO-33FQ,780
 ocf_data_sampler/numpy_sample/__init__.py,sha256=nY5C6CcuxiWZ_jrXRzWtN7WyKXhJImSiVTIG6Rz4B_4,401
-ocf_data_sampler/numpy_sample/collate.py,sha256=
-ocf_data_sampler/numpy_sample/datetime_features.py,sha256=
+ocf_data_sampler/numpy_sample/collate.py,sha256=oX5axq30sCsSquhNbmWAVMjM54HT1v3MCMopYHcO5Q0,1950
+ocf_data_sampler/numpy_sample/datetime_features.py,sha256=D0RajbnBjg15qjYk16h2H0XO4wH3fw-x0--4VC2nq0s,1204
 ocf_data_sampler/numpy_sample/gsp.py,sha256=5UaWO_aGRRVQo82wnDaT4zBKHihOnIsXiwgPjM8vGFM,1005
 ocf_data_sampler/numpy_sample/nwp.py,sha256=_seQNWsut3IzPsrpipqImjnaM3XNHZCy5_5be6syivk,1297
 ocf_data_sampler/numpy_sample/satellite.py,sha256=8OaTvkPjzSjotcdKsa6BKmmlBKDBunbhDN4Pjo0Grxs,910
@@ -33,7 +33,7 @@ ocf_data_sampler/sample/uk_regional.py,sha256=D1A6nQB1PYCmxb3FzU9gqbNufQfx__wcpr
 ocf_data_sampler/select/__init__.py,sha256=E4AJulEbO2K-o0UlG1fgaEteuf_1ZFjHTvrotXSb4YU,332
 ocf_data_sampler/select/dropout.py,sha256=HCx5Wzk8Oh2Z9vV94Jy-ALJsHtGduwvMaQOleQXp5z0,1142
 ocf_data_sampler/select/fill_time_periods.py,sha256=h0XD1Ds_wUUoy-7bILxmN8AIbjlQ6YdXRKuCk_Is5jo,460
-ocf_data_sampler/select/find_contiguous_time_periods.py,sha256=
+ocf_data_sampler/select/find_contiguous_time_periods.py,sha256=Nvz4gLCbbKzAe3sQXfxgExL9NtZVk1WNORvHs94DQ_k,11130
 ocf_data_sampler/select/geospatial.py,sha256=4xL-9y674jjoaXeqE52NHCHVfknciE4OEGsZtn9DvP4,4911
 ocf_data_sampler/select/location.py,sha256=26Y5ZjfFngShBwXieuWSoOA-RLaRzci4TTmcDk3Wg7U,2015
 ocf_data_sampler/select/select_spatial_slice.py,sha256=WNxwur9Q5oetvogATw8-hNejDuEwrXHzuZIovFDjNJA,11488
@@ -41,21 +41,22 @@ ocf_data_sampler/select/select_time_slice.py,sha256=9M-yvDv9K77XfEys_OIR31_aVB56
 ocf_data_sampler/select/spatial_slice_for_dataset.py,sha256=3tRrMBXr7s4CnClbVSIq7hpls3H4Y3qYTDwswcxCCCE,1763
 ocf_data_sampler/select/time_slice_for_dataset.py,sha256=Z7pOiilSHScxmBKZNG18K5J-S4ifdXXAYGZoHRHD3AY,4324
 ocf_data_sampler/torch_datasets/datasets/__init__.py,sha256=jfJSFcR0eO1AqeH7S3KnGjsBqVZT5w3oyi784PUR6Q0,146
-ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py,sha256=
-ocf_data_sampler/torch_datasets/datasets/site.py,sha256=
+ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py,sha256=xuNJyCXZ4dZ9UldX1lqOoRSRNP39Vcy0DR77Vr7dxlk,11895
+ocf_data_sampler/torch_datasets/datasets/site.py,sha256=ZjvJS0mWUyQE7ZcrhS1TdMHaPrEZXVbBAv2vDwBvQwA,16044
 ocf_data_sampler/torch_datasets/utils/merge_and_fill_utils.py,sha256=hIbekql64eXsNDFIoEc--GWxwdVWrh2qKegdOi70Bow,874
 ocf_data_sampler/torch_datasets/utils/valid_time_periods.py,sha256=Qo65qUHtle_bW5tLTYr7empHTRv-lpjvfx_6GNJj3Xg,4371
 scripts/refactor_site.py,sha256=asZ27hQ4IyXgCCUaFJqcz1ObBNcV2W3ywqHBpSXA_fc,1728
 tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 tests/conftest.py,sha256=RlC7YYtBLipUzFS1tQxela1SgHCxSpReUKEJ4429PwQ,7689
-tests/config/test_config.py,sha256=
-tests/config/
+tests/config/test_config.py,sha256=VQjNiucIk5VnPQdGA6Mr-RNd9CwGI06AiikChTHrcnY,3969
+tests/config/test_load.py,sha256=8nui2UsgK_eufWGD74yXvf-6eY_SxBFKhDmGYUtRQxw,260
+tests/config/test_save.py,sha256=BxSd2S50-bRPIXP_4iX0B6Wt7pRFJnUbLYtzfLaqlAs,915
 tests/load/test_load_gsp.py,sha256=aT_nqaSXmUTcdHzuTT7AmXJr3R31k4OEN-Fv3eLxlQE,424
 tests/load/test_load_nwp.py,sha256=3qyyDkB1q9t3tyAwogfotNrxqUOpXXimco1CImoEWGg,753
 tests/load/test_load_satellite.py,sha256=IQ8ISRZKCEoi8IsJoPpXZJTolD0mwjnl2E7762RM_PM,524
-tests/load/test_load_sites.py,sha256=
+tests/load/test_load_sites.py,sha256=6V-U3_EtBklkV7w-hOoR4nba3dSaZ_cnjuRWFs8kYVU,405
 tests/numpy_sample/test_collate.py,sha256=RqHCD5_LTRpe4r6kqC_2TKhmhM_IHYM0ZtFUvSjDqcM,654
-tests/numpy_sample/test_datetime_features.py,sha256=
+tests/numpy_sample/test_datetime_features.py,sha256=iR9WdBLj1nIBNqoaTFE9rkUaH1eKFJSNb96nwiEaQH0,1449
 tests/numpy_sample/test_gsp.py,sha256=FLlq4SlJ-9cSRAepf4_ksA6PsUVKegnKEAc5pUojCJ0,1458
 tests/numpy_sample/test_nwp.py,sha256=yf4u7mAU0E3FQ4xAH6YjuHuHBzzFoXjHSFNkOVJUdSM,1455
 tests/numpy_sample/test_satellite.py,sha256=cCqtn5See-uSNfh89COGTUQNuFm6sIZ8QmBVHsuUeRI,1189
@@ -71,9 +72,9 @@ tests/test_sample/test_site_sample.py,sha256=Gln-Or060cUWvA7Q7c1vsthgCttOAM2z9yB
 tests/test_sample/test_uk_regional_sample.py,sha256=gkeQWC2wC757jKJz_QBmDMFQjn3R54q_tEo948yyxCY,4840
 tests/torch_datasets/test_merge_and_fill_utils.py,sha256=GtuQg82BM1eHQjT7Ik1x1zaVcuc7KJO4_NC9stXsd4s,1123
 tests/torch_datasets/test_pvnet_uk.py,sha256=loueo7PUUYJVda3-vBn3bQIC_zgrTAThfx-GTDcBOZg,5596
-tests/torch_datasets/test_site.py,sha256=
-ocf_data_sampler-0.1.
-ocf_data_sampler-0.1.
-ocf_data_sampler-0.1.
-ocf_data_sampler-0.1.
-ocf_data_sampler-0.1.
+tests/torch_datasets/test_site.py,sha256=t57vAR_RRWcbG_kEFk6VrFCYzVxwFG6qJKBnRHF02fM,7000
+ocf_data_sampler-0.1.6.dist-info/LICENSE,sha256=F-Q3UFCR-BECSocV55BFDpn4YKxve9PKrm-lTt6o_Tg,1073
+ocf_data_sampler-0.1.6.dist-info/METADATA,sha256=qltSR8dsD54ufCfXXFFYYLY_l_1saBWGaxwzZDIaJoU,12173
+ocf_data_sampler-0.1.6.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
+ocf_data_sampler-0.1.6.dist-info/top_level.txt,sha256=Faob6N6cFdPc5eUpCTYcXgCaNhi4XLLteUL5W5ayYmg,31
+ocf_data_sampler-0.1.6.dist-info/RECORD,,
tests/config/test_config.py
CHANGED
@@ -1,59 +1,13 @@
-import tempfile
-
 import pytest
 from pydantic import ValidationError
-from
-from ocf_data_sampler.config import (
-    load_yaml_configuration,
-    Configuration,
-    save_yaml_configuration
-)
+from ocf_data_sampler.config import load_yaml_configuration, Configuration
 
 
 def test_default_configuration():
     """Test default pydantic class"""
-
     _ = Configuration()
 
 
-def test_load_yaml_configuration(test_config_filename):
-    """
-    Test that yaml loading works for 'test_config.yaml'
-    and fails for an empty .yaml file
-    """
-    # Create temporary directory instead of file
-    with tempfile.TemporaryDirectory() as temp_dir:
-        # Create path for empty file
-        empty_file = Path(temp_dir) / "empty.yaml"
-
-        # Create an empty file
-        empty_file.touch()
-
-        # Test loading empty file
-        with pytest.raises(TypeError):
-            _ = load_yaml_configuration(str(empty_file))
-
-def test_yaml_save(test_config_filename):
-    """
-    Check configuration can be saved to a .yaml file
-    """
-    test_config = load_yaml_configuration(test_config_filename)
-
-    with tempfile.TemporaryDirectory() as temp_dir:
-        # Create path for config file
-        config_path = Path(temp_dir) / "test_config.yaml"
-
-        # Save configuration
-        saved_path = save_yaml_configuration(test_config, config_path)
-
-        # Verify file exists
-        assert saved_path.exists()
-
-        # Test loading saved configuration
-        loaded_config = load_yaml_configuration(str(saved_path))
-        assert loaded_config == test_config
-
-
 def test_extra_field_error():
     """
     Check an extra parameters in config causes error
tests/config/test_save.py
CHANGED
@@ -1,37 +1,28 @@
 """Tests for configuration saving functionality."""
-import
-from
-import tempfile
-import yaml
+import os
+from ocf_data_sampler.config import Configuration, save_yaml_configuration, load_yaml_configuration
 
-from ocf_data_sampler.config import Configuration, save_yaml_configuration
 
-
-
-    """Create a temporary directory."""
-    with tempfile.TemporaryDirectory() as tmpdirname:
-        yield Path(tmpdirname)
-
-def test_save_yaml_configuration_basic(temp_dir):
-    """Test basic configuration saving functionality."""
+def test_save_yaml_configuration_basic(tmp_path):
+    """Save an empty configuration object"""
     config = Configuration()
-
-
+
+    filepath = f"{tmp_path}/config.yaml"
+    save_yaml_configuration(config, filepath)
 
-    assert
-    with open(filepath) as f:
-        loaded_yaml = yaml.safe_load(f)
-    assert isinstance(loaded_yaml, dict)
+    assert os.path.exists(filepath)
 
-def test_save_yaml_configuration_none_filename():
-    """Test that None filename raises ValueError."""
-    config = Configuration()
-    with pytest.raises(ValueError, match="filename cannot be None"):
-        save_yaml_configuration(config, None)
 
-def
-    """
-
-
-
-
+def test_save_load_yaml_configuration(tmp_path, test_config_filename):
+    """Make sure a saved configuration is the same after loading"""
+
+    # Start with this config
+    initial_config = load_yaml_configuration(test_config_filename)
+
+    # Save it
+    filepath = f"{tmp_path}/config.yaml"
+    save_yaml_configuration(initial_config, filepath)
+
+    # Load it and check it is still the same
+    loaded_config = load_yaml_configuration(filepath)
+    assert loaded_config == initial_config
tests/numpy_sample/test_datetime_features.py
CHANGED
@@ -35,13 +35,3 @@ def test_make_datetime_numpy_batch_custom_key_prefix():
     # Assert dict contains expected quantity of keys and verify starting with custom prefix
     assert len(datetime_features) == 4
     assert all(key.startswith(key_prefix) for key in datetime_features.keys())
-
-
-def test_make_datetime_numpy_batch_empty_input():
-    # Verification that function raises error for empty input
-    datetimes = pd.DatetimeIndex([])
-
-    with pytest.raises(
-        ValueError, match="Input datetimes is empty for 'make_datetime_numpy_dict' function"
-    ):
-        make_datetime_numpy_dict(datetimes)
tests/torch_datasets/test_site.py
CHANGED
@@ -33,7 +33,7 @@ def sites_dataset(site_config_filename):
     return SitesDataset(site_config_filename)
 
 
-def test_site(site_config_filename):
+def test_site(tmp_path, site_config_filename):
 
     # Create dataset object
     dataset = SitesDataset(site_config_filename)
@@ -71,8 +71,8 @@ def test_site(site_config_filename):
     expected_data_vars = {"nwp-ukv", "satellite", "site"}
 
 
-    sample.to_netcdf("sample.nc")
-    sample = xr.open_dataset("sample.nc")
+    sample.to_netcdf(f"{tmp_path}/sample.nc")
+    sample = xr.open_dataset(f"{tmp_path}/sample.nc")
 
     # Check dimensions
     assert (
{ocf_data_sampler-0.1.5.dist-info → ocf_data_sampler-0.1.6.dist-info}/LICENSE
File without changes

{ocf_data_sampler-0.1.5.dist-info → ocf_data_sampler-0.1.6.dist-info}/WHEEL
File without changes

{ocf_data_sampler-0.1.5.dist-info → ocf_data_sampler-0.1.6.dist-info}/top_level.txt
File without changes