ocf-data-sampler 0.1.10__py3-none-any.whl → 0.1.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of ocf-data-sampler might be problematic.
- ocf_data_sampler/config/load.py +3 -3
- ocf_data_sampler/config/model.py +86 -72
- ocf_data_sampler/config/save.py +5 -4
- ocf_data_sampler/constants.py +140 -12
- ocf_data_sampler/load/gsp.py +6 -5
- ocf_data_sampler/load/load_dataset.py +5 -6
- ocf_data_sampler/load/nwp/nwp.py +17 -5
- ocf_data_sampler/load/nwp/providers/ecmwf.py +6 -7
- ocf_data_sampler/load/nwp/providers/gfs.py +36 -0
- ocf_data_sampler/load/nwp/providers/icon.py +46 -0
- ocf_data_sampler/load/nwp/providers/ukv.py +4 -5
- ocf_data_sampler/load/nwp/providers/utils.py +3 -1
- ocf_data_sampler/load/satellite.py +27 -36
- ocf_data_sampler/load/site.py +11 -7
- ocf_data_sampler/load/utils.py +21 -16
- ocf_data_sampler/numpy_sample/collate.py +10 -9
- ocf_data_sampler/numpy_sample/datetime_features.py +3 -5
- ocf_data_sampler/numpy_sample/gsp.py +15 -13
- ocf_data_sampler/numpy_sample/nwp.py +17 -23
- ocf_data_sampler/numpy_sample/satellite.py +17 -14
- ocf_data_sampler/numpy_sample/site.py +8 -7
- ocf_data_sampler/numpy_sample/sun_position.py +19 -25
- ocf_data_sampler/sample/__init__.py +0 -7
- ocf_data_sampler/sample/base.py +23 -44
- ocf_data_sampler/sample/site.py +25 -69
- ocf_data_sampler/sample/uk_regional.py +52 -103
- ocf_data_sampler/select/dropout.py +42 -27
- ocf_data_sampler/select/fill_time_periods.py +15 -3
- ocf_data_sampler/select/find_contiguous_time_periods.py +87 -75
- ocf_data_sampler/select/geospatial.py +63 -54
- ocf_data_sampler/select/location.py +16 -51
- ocf_data_sampler/select/select_spatial_slice.py +105 -89
- ocf_data_sampler/select/select_time_slice.py +71 -58
- ocf_data_sampler/select/spatial_slice_for_dataset.py +7 -6
- ocf_data_sampler/select/time_slice_for_dataset.py +17 -16
- ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py +126 -118
- ocf_data_sampler/torch_datasets/datasets/site.py +135 -101
- ocf_data_sampler/torch_datasets/utils/merge_and_fill_utils.py +6 -2
- ocf_data_sampler/torch_datasets/utils/valid_time_periods.py +23 -22
- ocf_data_sampler/torch_datasets/utils/validate_channels.py +23 -19
- ocf_data_sampler/utils.py +3 -1
- {ocf_data_sampler-0.1.10.dist-info → ocf_data_sampler-0.1.16.dist-info}/METADATA +7 -18
- ocf_data_sampler-0.1.16.dist-info/RECORD +56 -0
- {ocf_data_sampler-0.1.10.dist-info → ocf_data_sampler-0.1.16.dist-info}/WHEEL +1 -1
- {ocf_data_sampler-0.1.10.dist-info → ocf_data_sampler-0.1.16.dist-info}/top_level.txt +1 -1
- scripts/refactor_site.py +62 -33
- utils/compute_icon_mean_stddev.py +72 -0
- ocf_data_sampler-0.1.10.dist-info/LICENSE +0 -21
- ocf_data_sampler-0.1.10.dist-info/RECORD +0 -82
- tests/__init__.py +0 -0
- tests/config/test_config.py +0 -113
- tests/config/test_load.py +0 -7
- tests/config/test_save.py +0 -28
- tests/conftest.py +0 -286
- tests/load/test_load_gsp.py +0 -15
- tests/load/test_load_nwp.py +0 -21
- tests/load/test_load_satellite.py +0 -17
- tests/load/test_load_sites.py +0 -14
- tests/numpy_sample/test_collate.py +0 -21
- tests/numpy_sample/test_datetime_features.py +0 -37
- tests/numpy_sample/test_gsp.py +0 -38
- tests/numpy_sample/test_nwp.py +0 -52
- tests/numpy_sample/test_satellite.py +0 -40
- tests/numpy_sample/test_sun_position.py +0 -81
- tests/select/test_dropout.py +0 -75
- tests/select/test_fill_time_periods.py +0 -28
- tests/select/test_find_contiguous_time_periods.py +0 -202
- tests/select/test_location.py +0 -67
- tests/select/test_select_spatial_slice.py +0 -154
- tests/select/test_select_time_slice.py +0 -275
- tests/test_sample/test_base.py +0 -164
- tests/test_sample/test_site_sample.py +0 -195
- tests/test_sample/test_uk_regional_sample.py +0 -163
- tests/torch_datasets/test_merge_and_fill_utils.py +0 -40
- tests/torch_datasets/test_pvnet_uk.py +0 -167
- tests/torch_datasets/test_site.py +0 -226
- tests/torch_datasets/test_validate_channels_utils.py +0 -78
ocf_data_sampler/load/nwp/providers/ecmwf.py
CHANGED
@@ -1,17 +1,17 @@
-"""ECMWF provider loaders"""
+"""ECMWF provider loaders."""
 
 import xarray as xr
+
 from ocf_data_sampler.load.nwp.providers.utils import open_zarr_paths
 from ocf_data_sampler.load.utils import (
     check_time_unique_increasing,
+    get_xr_data_array_from_xr_dataset,
     make_spatial_coords_increasing,
-    get_xr_data_array_from_xr_dataset
 )
 
 
 def open_ifs(zarr_path: str | list[str]) -> xr.DataArray:
-    """
-    Opens the ECMWF IFS NWP data
+    """Opens the ECMWF IFS NWP data.
 
     Args:
         zarr_path: Path to the zarr to open
@@ -19,9 +19,8 @@ def open_ifs(zarr_path: str | list[str]) -> xr.DataArray:
     Returns:
         Xarray DataArray of the NWP data
     """
-
     ds = open_zarr_paths(zarr_path)
-
+
     # LEGACY SUPPORT - rename variable to channel if it exists
     ds = ds.rename({"init_time": "init_time_utc", "variable": "channel"})
 
@@ -30,6 +29,6 @@ def open_ifs(zarr_path: str | list[str]) -> xr.DataArray:
     ds = make_spatial_coords_increasing(ds, x_coord="longitude", y_coord="latitude")
 
     ds = ds.transpose("init_time_utc", "step", "channel", "longitude", "latitude")
-
+
     # TODO: should we control the dtype of the DataArray?
     return get_xr_data_array_from_xr_dataset(ds)
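Note: a minimal usage sketch for the reshuffled IFS loader, against a hypothetical local store; only the imports and docstring style changed in this file, so 0.1.10 call sites should keep working:

from ocf_data_sampler.load.nwp.providers.ecmwf import open_ifs

# Hypothetical path. The loader renames the legacy init_time/variable dims
# to init_time_utc/channel and returns a DataArray ordered
# (init_time_utc, step, channel, longitude, latitude).
da = open_ifs("data/ecmwf_ifs.zarr")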
ocf_data_sampler/load/nwp/providers/gfs.py
ADDED
@@ -0,0 +1,36 @@
+"""Open GFS Forecast data."""
+
+import logging
+
+import xarray as xr
+
+from ocf_data_sampler.load.nwp.providers.utils import open_zarr_paths
+from ocf_data_sampler.load.utils import check_time_unique_increasing, make_spatial_coords_increasing
+
+_log = logging.getLogger(__name__)
+
+
+def open_gfs(zarr_path: str | list[str]) -> xr.DataArray:
+    """Opens the GFS data.
+
+    Args:
+        zarr_path: Path to the zarr to open
+
+    Returns:
+        Xarray DataArray of the NWP data
+    """
+    _log.info("Loading NWP GFS data")
+
+    # Open data
+    gfs: xr.Dataset = open_zarr_paths(zarr_path, time_dim="init_time_utc")
+    nwp: xr.DataArray = gfs.to_array()
+
+    del gfs
+
+    nwp = nwp.rename({"variable": "channel","init_time": "init_time_utc"})
+    check_time_unique_increasing(nwp.init_time_utc)
+    nwp = make_spatial_coords_increasing(nwp, x_coord="longitude", y_coord="latitude")
+
+    nwp = nwp.transpose("init_time_utc", "step", "channel", "latitude", "longitude")
+
+    return nwp
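Note: open_gfs is new in this release. A sketch of a call against a hypothetical store:

from ocf_data_sampler.load.nwp.providers.gfs import open_gfs

# Hypothetical path. The data variables are stacked into a "channel" dim and
# the result is ordered (init_time_utc, step, channel, latitude, longitude).
da = open_gfs("data/gfs.zarr")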
ocf_data_sampler/load/nwp/providers/icon.py
ADDED
@@ -0,0 +1,46 @@
+"""DWD ICON Loading."""
+
+import xarray as xr
+
+from ocf_data_sampler.load.nwp.providers.utils import open_zarr_paths
+from ocf_data_sampler.load.utils import check_time_unique_increasing, make_spatial_coords_increasing
+
+
+def remove_isobaric_lelvels_from_coords(nwp: xr.Dataset) -> xr.Dataset:
+    """Removes the isobaric levels from the coordinates of the NWP data.
+
+    Args:
+        nwp: NWP data
+
+    Returns:
+        NWP data without isobaric levels in the coordinates
+    """
+    variables_to_drop = [var for var in nwp.data_vars if "isobaricInhPa" in nwp[var].dims]
+    return nwp.drop_vars(["isobaricInhPa", *variables_to_drop])
+
+
+def open_icon_eu(zarr_path: str) -> xr.Dataset:
+    """Opens the ICON data.
+
+    ICON EU Data is on a regular lat/lon grid
+    It has data on multiple pressure levels, as well as the surface
+    Each of the variables is its own data variable
+
+    Args:
+        zarr_path: Path to the zarr to open
+
+    Returns:
+        Xarray DataArray of the NWP data
+    """
+    # Open the data
+    nwp = open_zarr_paths(zarr_path, time_dim="time")
+    nwp = nwp.rename({"time": "init_time_utc"})
+    # Sanity checks.
+    check_time_unique_increasing(nwp.init_time_utc)
+    # 0-78 one hour steps, rest 3 hour steps
+    nwp = nwp.isel(step=slice(0, 78))
+    nwp = remove_isobaric_lelvels_from_coords(nwp)
+    nwp = nwp.to_array().rename({"variable": "channel"})
+    nwp = nwp.transpose("init_time_utc", "step", "channel", "latitude", "longitude")
+    nwp = make_spatial_coords_increasing(nwp, x_coord="longitude", y_coord="latitude")
+    return nwp
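Note: the ICON-EU loader is also new. Anything on the isobaricInhPa dimension is dropped before the variables are stacked, and only the first 78 (hourly) steps are kept. A sketch with a hypothetical store:

from ocf_data_sampler.load.nwp.providers.icon import open_icon_eu

# Hypothetical path. Surface variables only; pressure-level variables are
# removed by remove_isobaric_lelvels_from_coords (the typo is in the
# released function name).
da = open_icon_eu("data/icon_eu.zarr")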
ocf_data_sampler/load/nwp/providers/ukv.py
CHANGED
@@ -1,18 +1,17 @@
-"""UKV provider loaders"""
+"""UKV provider loaders."""
 
 import xarray as xr
 
 from ocf_data_sampler.load.nwp.providers.utils import open_zarr_paths
 from ocf_data_sampler.load.utils import (
     check_time_unique_increasing,
+    get_xr_data_array_from_xr_dataset,
     make_spatial_coords_increasing,
-    get_xr_data_array_from_xr_dataset
 )
 
 
 def open_ukv(zarr_path: str | list[str]) -> xr.DataArray:
-    """
-    Opens the NWP data
+    """Opens the NWP data.
 
     Args:
         zarr_path: Path to the zarr to open
@@ -28,7 +27,7 @@ def open_ukv(zarr_path: str | list[str]) -> xr.DataArray:
             "variable": "channel",
             "x": "x_osgb",
             "y": "y_osgb",
-        }
+        },
     )
 
     check_time_unique_increasing(ds.init_time_utc)
ocf_data_sampler/load/nwp/providers/utils.py
CHANGED
@@ -1,8 +1,10 @@
+"""Utility functions for the NWP data processing."""
+
 import xarray as xr
 
 
 def open_zarr_paths(zarr_path: str | list[str], time_dim: str = "init_time") -> xr.Dataset:
-    """Opens the NWP data
+    """Opens the NWP data.
 
     Args:
         zarr_path: Path to the zarr(s) to open
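Note: every provider now routes through this helper. A sketch of a call, under the assumption (the body is not shown in this hunk) that a list of stores is concatenated along time_dim:

from ocf_data_sampler.load.nwp.providers.utils import open_zarr_paths

# Hypothetical paths; gfs.py calls this with time_dim="init_time_utc",
# icon.py with time_dim="time".
ds = open_zarr_paths(["run_a.zarr", "run_b.zarr"], time_dim="init_time")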
ocf_data_sampler/load/satellite.py
CHANGED
@@ -1,85 +1,76 @@
-"""Satellite loader"""
-
-import subprocess
+"""Satellite loader."""
 
 import xarray as xr
+
 from ocf_data_sampler.load.utils import (
     check_time_unique_increasing,
+    get_xr_data_array_from_xr_dataset,
     make_spatial_coords_increasing,
-    get_xr_data_array_from_xr_dataset
 )
 
 
-def
-    """Helper function to open a
+def get_single_sat_data(zarr_path: str) -> xr.Dataset:
+    """Helper function to open a zarr from either a local or GCP path.
 
     Args:
-        zarr_path:
-            GCS paths (gs://)
+        zarr_path: path to a zarr file. Wildcards (*) are supported only for local paths
+            GCS paths (gs://) do not support wildcards
 
     Returns:
-        An xarray Dataset containing satellite data
+        An xarray Dataset containing satellite data
 
     Raises:
-        ValueError: If a wildcard (*) is used in a GCS (gs://) path
+        ValueError: If a wildcard (*) is used in a GCS (gs://) path
     """
-
-    # These kwargs are used if the path contains "*"
-    openmf_kwargs = dict(
-        engine="zarr",
-        concat_dim="time",
-        combine="nested",
-        chunks="auto",
-        join="override",
-    )
-
     # Raise an error if a wildcard is used in a GCP path
     if "gs://" in str(zarr_path) and "*" in str(zarr_path):
-        raise ValueError("Wildcard (*) paths are not supported for GCP (gs://) URLs
+        raise ValueError("Wildcard (*) paths are not supported for GCP (gs://) URLs")
 
     # Handle multi-file dataset for local paths
     if "*" in str(zarr_path):
-        ds = xr.open_mfdataset(
+        ds = xr.open_mfdataset(
+            zarr_path,
+            engine="zarr",
+            concat_dim="time",
+            combine="nested",
+            chunks="auto",
+            join="override",
+        )
+        check_time_unique_increasing(ds.time)
     else:
         ds = xr.open_dataset(zarr_path, engine="zarr", chunks="auto")
 
-    # Ensure time is unique and sorted
-    ds = ds.drop_duplicates("time").sortby("time")
-
     return ds
 
 
 def open_sat_data(zarr_path: str | list[str]) -> xr.DataArray:
-    """Lazily opens the
+    """Lazily opens the zarr store.
 
     Args:
-        zarr_path: Cloud URL or local path pattern, or list of these. If GCS URL,
-
+        zarr_path: Cloud URL or local path pattern, or list of these. If GCS URL,
+            it must start with 'gs://'
     """
-
     # Open the data
-    if isinstance(zarr_path,
+    if isinstance(zarr_path, list | tuple):
         ds = xr.combine_nested(
-            [
+            [get_single_sat_data(path) for path in zarr_path],
             concat_dim="time",
             combine_attrs="override",
             join="override",
         )
     else:
-        ds =
+        ds = get_single_sat_data(zarr_path)
 
     ds = ds.rename(
         {
             "variable": "channel",
             "time": "time_utc",
-        }
+        },
     )
 
     check_time_unique_increasing(ds.time_utc)
-
     ds = make_spatial_coords_increasing(ds, x_coord="x_geostationary", y_coord="y_geostationary")
-
     ds = ds.transpose("time_utc", "channel", "x_geostationary", "y_geostationary")
-
+
     # TODO: should we control the dtype of the DataArray?
     return get_xr_data_array_from_xr_dataset(ds)
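Note: duplicate-timestamp handling changed here. 0.1.10 silently applied drop_duplicates/sortby after opening; 0.1.16 validates with check_time_unique_increasing and raises instead. A sketch with hypothetical stores:

from ocf_data_sampler.load.satellite import open_sat_data

# Hypothetical paths. A list of stores is combined along time; local paths
# may use wildcards, gs:// paths must not.
da = open_sat_data(["sat_2020.zarr", "sat_2021.zarr"])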
ocf_data_sampler/load/site.py
CHANGED
@@ -1,3 +1,5 @@
+"""Funcitons for loading site data."""
+
 import numpy as np
 import pandas as pd
 import xarray as xr
@@ -5,7 +7,7 @@ import xarray as xr
 
 def open_site(generation_file_path: str, metadata_file_path: str) -> xr.DataArray:
     """Open a site's generation data and metadata.
-
+
     Args:
         generation_file_path: Path to the site generation netcdf data
         metadata_file_path: Path to the site csv metadata
@@ -13,14 +15,14 @@ def open_site(generation_file_path: str, metadata_file_path: str) -> xr.DataArray:
     Returns:
         xr.DataArray: The opened site generation data
     """
-
     generation_ds = xr.open_dataset(generation_file_path)
 
     metadata_df = pd.read_csv(metadata_file_path, index_col="site_id")
 
-
+    if not metadata_df.index.is_unique:
+        raise ValueError("site_id is not unique in metadata")
 
-    # Ensure metadata aligns with the site_id dimension in
+    # Ensure metadata aligns with the site_id dimension in generation_ds
     metadata_df = metadata_df.reindex(generation_ds.site_id.values)
 
     # Assign coordinates to the Dataset using the aligned metadata
@@ -31,7 +33,9 @@ def open_site(generation_file_path: str, metadata_file_path: str) -> xr.DataArray:
     )
 
     # Sanity checks
-
-
-
+    if not np.isfinite(generation_ds.generation_kw.values).all():
+        raise ValueError("generation_kw contains non-finite values")
+    if not (generation_ds.capacity_kwp.values > 0).all():
+        raise ValueError("capacity_kwp contains non-positive values")
+
     return generation_ds.generation_kw
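Note: the sanity checks in open_site now raise ValueError for a non-unique site_id, non-finite generation_kw, or non-positive capacity_kwp (the removed lines are truncated in this view, so the old behavior is not visible). A sketch with hypothetical file paths:

from ocf_data_sampler.load.site import open_site

# Hypothetical paths. The metadata csv must have a unique site_id column;
# the returned DataArray is generation_kw with metadata attached as coords.
da = open_site("generation.nc", "metadata.csv")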
ocf_data_sampler/load/utils.py
CHANGED
@@ -1,43 +1,48 @@
-
+"""Utility functions for working with xarray objects."""
+
 import pandas as pd
+import xarray as xr
+
 
-def check_time_unique_increasing(datetimes) -> None:
-    """Check that the time dimension is unique and increasing"""
+def check_time_unique_increasing(datetimes: xr.DataArray) -> None:
+    """Check that the time dimension is unique and increasing."""
     time = pd.DatetimeIndex(datetimes)
-
-
+    if not (time.is_unique and time.is_monotonic_increasing):
+        raise ValueError("Time dimension must be unique and monotonically increasing")
 
 
 def make_spatial_coords_increasing(ds: xr.Dataset, x_coord: str, y_coord: str) -> xr.Dataset:
-    """Make sure the spatial coordinates are in increasing order
-
+    """Make sure the spatial coordinates are in increasing order.
+
     Args:
         ds: Xarray Dataset
         x_coord: Name of the x coordinate
         y_coord: Name of the y coordinate
     """
-
     # Make sure the coords are in increasing order
     if ds[x_coord][0] > ds[x_coord][-1]:
-        ds = ds.isel({x_coord:slice(None, None, -1)})
+        ds = ds.isel({x_coord: slice(None, None, -1)})
     if ds[y_coord][0] > ds[y_coord][-1]:
-
+        ds = ds.isel({y_coord: slice(None, None, -1)})
 
     # Check the coords are all increasing now
-
-
+    if not (ds[x_coord].diff(dim=x_coord) > 0).all():
+        raise ValueError(f"'{x_coord}' coordinate must be increasing")
+    if not (ds[y_coord].diff(dim=y_coord) > 0).all():
+        raise ValueError(f"'{y_coord}' coordinate must be increasing")
 
     return ds
 
 
 def get_xr_data_array_from_xr_dataset(ds: xr.Dataset) -> xr.DataArray:
-    """Return underlying xr.DataArray from passed xr.Dataset.
+    """Return underlying xr.DataArray from passed xr.Dataset.
+
     Checks only one variable is present and returns it as an xr.DataArray.
 
     Args:
         ds: xr.Dataset to extract xr.DataArray from
     """
-
     datavars = list(ds.var())
-
-
+    if len(datavars) != 1:
+        raise ValueError("Cannot open as xr.DataArray: dataset contains multiple variables")
+    return ds[datavars[0]]
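Note: both validations now raise ValueError (the removed bodies are truncated in this view, so the old failure mode is not visible). A quick sketch of the 0.1.16 behavior:

import numpy as np
import pandas as pd
import xarray as xr

from ocf_data_sampler.load.utils import (
    check_time_unique_increasing,
    make_spatial_coords_increasing,
)

# Duplicate timestamps now raise ValueError
times = pd.to_datetime(["2024-01-01 00:00", "2024-01-01 00:00"])
try:
    check_time_unique_increasing(times)
except ValueError as e:
    print(e)

# A decreasing axis is flipped, then re-verified
ds = xr.Dataset(coords={"x": np.array([3, 2, 1]), "y": np.array([1, 2, 3])})
ds = make_spatial_coords_increasing(ds, x_coord="x", y_coord="y")
print(ds["x"].values)  # [1 2 3]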
ocf_data_sampler/numpy_sample/collate.py
CHANGED
@@ -1,8 +1,10 @@
+"""Functions for collating samples into batches."""
+
 import numpy as np
 
 
 def stack_np_samples_into_batch(dict_list: list[dict]) -> dict:
-    """Stacks list of dict samples into a dict where all samples are joined along a new axis
+    """Stacks list of dict samples into a dict where all samples are joined along a new axis.
 
     Args:
         dict_list: A list of dict-like samples to stack
@@ -10,7 +12,6 @@ def stack_np_samples_into_batch(dict_list: list[dict]) -> dict:
     Returns:
         Dict of the samples stacked with new batch dimension on axis 0
     """
-
     batch = {}
 
     keys = list(dict_list[0].keys())
@@ -26,10 +27,10 @@ def stack_np_samples_into_batch(dict_list: list[dict]) -> dict:
     for nwp_provider in nwp_providers:
         # Keys can be different for different NWPs
         nwp_keys = list(dict_list[0]["nwp"][nwp_provider].keys())
-
+
         # Create dict to store NWP batch for this provider
         nwp_provider_batch = {}
-
+
         for nwp_key in nwp_keys:
             # Stack values under each NWP key for this provider
             nwp_provider_batch[nwp_key] = stack_data_list(
@@ -46,16 +47,16 @@ def stack_np_samples_into_batch(dict_list: list[dict]) -> dict:
 
 
 def _key_is_constant(key: str) -> bool:
-    """Check if a key is for value which is constant for all samples"""
+    """Check if a key is for value which is constant for all samples."""
     return key.endswith("t0_idx") or key.endswith("channel_names")
 
 
 def stack_data_list(data_list: list, key: str) -> np.ndarray:
-    """Stack a sequence of data elements along a new axis
+    """Stack a sequence of data elements along a new axis.
 
-
-
-
+    Args:
+        data_list: List of data elements to combine
+        key: string identifying the data type
     """
     if _key_is_constant(key):
         return data_list[0]
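Note: a sketch of the collate behavior with two hypothetical samples; keys ending in t0_idx or channel_names are treated as batch constants and taken from the first sample:

import numpy as np

from ocf_data_sampler.numpy_sample.collate import stack_np_samples_into_batch

samples = [
    {"gsp": np.zeros(4), "gsp_t0_idx": 2},
    {"gsp": np.ones(4), "gsp_t0_idx": 2},
]
batch = stack_np_samples_into_batch(samples)
print(batch["gsp"].shape)  # (2, 4): new batch axis on axis 0
print(batch["gsp_t0_idx"])  # 2: constant, not stacked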
ocf_data_sampler/numpy_sample/datetime_features.py
CHANGED
@@ -1,11 +1,11 @@
-"""Functions to create trigonometric date and time inputs"""
+"""Functions to create trigonometric date and time inputs."""
 
 import numpy as np
 import pandas as pd
 
 
 def _get_date_time_in_pi(dt: pd.DatetimeIndex) -> tuple[np.ndarray, np.ndarray]:
-    """Create positional embeddings for the datetimes in radians
+    """Create positional embeddings for the datetimes in radians.
 
     Args:
         dt: DatetimeIndex to create radian embeddings for
@@ -13,7 +13,6 @@ def _get_date_time_in_pi(dt: pd.DatetimeIndex) -> tuple[np.ndarray, np.ndarray]:
     Returns:
         Tuple of numpy arrays containing radian coordinates for date and time
     """
-
     day_of_year = dt.dayofyear
     minute_of_day = dt.minute + dt.hour * 60
 
@@ -24,8 +23,7 @@ def _get_date_time_in_pi(dt: pd.DatetimeIndex) -> tuple[np.ndarray, np.ndarray]:
 
 
 def make_datetime_numpy_dict(datetimes: pd.DatetimeIndex, key_prefix: str = "wind") -> dict:
-    """
-
+    """Creates dictionary of cyclical datetime features - encoded."""
     date_in_pi, time_in_pi = _get_date_time_in_pi(datetimes)
 
     time_numpy_sample = {}
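Note: the function bodies are truncated in this view, so the exact scaling is not visible. A sketch of the kind of cyclical encoding the docstrings describe, under the assumption that day-of-year and minute-of-day are mapped onto [0, 2π); the key names below are illustrative, the real keys are built from key_prefix:

import numpy as np
import pandas as pd

dt = pd.DatetimeIndex(["2024-06-21 12:00"])

# Assumed mapping onto radians (the released scaling may differ, e.g. in
# leap-year handling)
date_in_pi = 2 * np.pi * (dt.dayofyear / 365)
time_in_pi = 2 * np.pi * ((dt.minute + dt.hour * 60) / (24 * 60))

features = {
    "wind_date_sin": np.sin(date_in_pi),
    "wind_date_cos": np.cos(date_in_pi),
    "wind_time_sin": np.sin(time_in_pi),
    "wind_time_cos": np.cos(time_in_pi),
}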
ocf_data_sampler/numpy_sample/gsp.py
CHANGED
@@ -1,26 +1,28 @@
-"""Convert GSP to Numpy Sample"""
+"""Convert GSP to Numpy Sample."""
 
 import xarray as xr
 
 
 class GSPSampleKey:
+    """Keys for the GSP sample dictionary."""
 
-    gsp =
-    nominal_capacity_mwp =
-    effective_capacity_mwp =
-    time_utc =
-    t0_idx =
-
-
-
-    x_osgb = 'gsp_x_osgb'
-    y_osgb = 'gsp_y_osgb'
+    gsp = "gsp"
+    nominal_capacity_mwp = "gsp_nominal_capacity_mwp"
+    effective_capacity_mwp = "gsp_effective_capacity_mwp"
+    time_utc = "gsp_time_utc"
+    t0_idx = "gsp_t0_idx"
+    gsp_id = "gsp_id"
+    x_osgb = "gsp_x_osgb"
+    y_osgb = "gsp_y_osgb"
 
 
 def convert_gsp_to_numpy_sample(da: xr.DataArray, t0_idx: int | None = None) -> dict:
-    """Convert from Xarray to NumpySample
+    """Convert from Xarray to NumpySample.
 
-
+    Args:
+        da: Xarray DataArray containing GSP data
+        t0_idx: Index of the t0 timestamp in the time dimension of the GSP data
+    """
     sample = {
         GSPSampleKey.gsp: da.values,
         GSPSampleKey.nominal_capacity_mwp: da.isel(time_utc=0)["nominal_capacity_mwp"].values,
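Note: a sketch with a hypothetical GSP DataArray. Only the first lines of the function body are visible in this hunk, so the exact set of required coordinates is an assumption based on the key class:

import numpy as np
import pandas as pd
import xarray as xr

from ocf_data_sampler.numpy_sample.gsp import GSPSampleKey, convert_gsp_to_numpy_sample

times = pd.date_range("2024-01-01", periods=4, freq="30min")
da = xr.DataArray(
    np.random.rand(4),
    dims=["time_utc"],
    coords={
        "time_utc": times,
        "nominal_capacity_mwp": ("time_utc", np.full(4, 100.0)),
        "effective_capacity_mwp": ("time_utc", np.full(4, 90.0)),
    },
)
sample = convert_gsp_to_numpy_sample(da, t0_idx=1)
print(sample[GSPSampleKey.gsp].shape)  # (4,)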
ocf_data_sampler/numpy_sample/nwp.py
CHANGED
@@ -1,42 +1,36 @@
-"""Convert NWP to NumpySample"""
+"""Convert NWP to NumpySample."""
 
 import pandas as pd
 import xarray as xr
 
 
 class NWPSampleKey:
+    """Keys for NWP NumpySample."""
 
-    nwp =
-    channel_names =
-    init_time_utc =
-    step =
-    target_time_utc =
-    t0_idx =
-    y_osgb = 'nwp_y_osgb'
-    x_osgb = 'nwp_x_osgb'
-
+    nwp = "nwp"
+    channel_names = "nwp_channel_names"
+    init_time_utc = "nwp_init_time_utc"
+    step = "nwp_step"
+    target_time_utc = "nwp_target_time_utc"
+    t0_idx = "nwp_t0_idx"
 
 
 def convert_nwp_to_numpy_sample(da: xr.DataArray, t0_idx: int | None = None) -> dict:
-    """Convert from Xarray to NWP NumpySample
-
-
+    """Convert from Xarray to NWP NumpySample.
+
+    Args:
+        da: Xarray DataArray containing NWP data
+        t0_idx: Index of the t0 timestamp in the time dimension of the NWP
+    """
     sample = {
         NWPSampleKey.nwp: da.values,
         NWPSampleKey.channel_names: da.channel.values,
         NWPSampleKey.init_time_utc: da.init_time_utc.values.astype(float),
         NWPSampleKey.step: (da.step.values / pd.Timedelta("1h")).astype(int),
+        NWPSampleKey.target_time_utc: da.target_time_utc.values.astype(float),
     }
 
-    if "target_time_utc" in da.coords:
-        sample[NWPSampleKey.target_time_utc] = da.target_time_utc.values.astype(float)
-
-    # TODO: Do we need this at all? Especially since it is only present in UKV data
-    for sample_key, dataset_key in ((NWPSampleKey.y_osgb, "y_osgb"),(NWPSampleKey.x_osgb, "x_osgb"),):
-        if dataset_key in da.coords:
-            sample[sample_key] = da[dataset_key].values
-
     if t0_idx is not None:
         sample[NWPSampleKey.t0_idx] = t0_idx
-
-    return sample
+
+    return sample
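Note: two behavior changes worth flagging for downstream code. The target times are now always written into the sample (previously only when the target_time_utc coord existed), and the UKV-only x_osgb/y_osgb keys are removed from the class entirely:

from ocf_data_sampler.numpy_sample.nwp import NWPSampleKey

print(NWPSampleKey.target_time_utc)  # "nwp_target_time_utc"
# NWPSampleKey.x_osgb raises AttributeError in 0.1.16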
ocf_data_sampler/numpy_sample/satellite.py
CHANGED
@@ -1,30 +1,33 @@
-"""Convert Satellite to NumpySample"""
+"""Convert Satellite to NumpySample."""
+
 import xarray as xr
 
 
 class SatelliteSampleKey:
+    """Keys for the SatelliteSample dictionary."""
 
-    satellite_actual =
-    time_utc =
-    x_geostationary =
-    y_geostationary =
-    t0_idx =
+    satellite_actual = "satellite_actual"
+    time_utc = "satellite_time_utc"
+    x_geostationary = "satellite_x_geostationary"
+    y_geostationary = "satellite_y_geostationary"
+    t0_idx = "satellite_t0_idx"
 
 
 def convert_satellite_to_numpy_sample(da: xr.DataArray, t0_idx: int | None = None) -> dict:
-    """Convert from Xarray to NumpySample
+    """Convert from Xarray to NumpySample.
+
+    Args:
+        da: xarray DataArray containing satellite data
+        t0_idx: Index of the t0 timestamp in the time dimension of the satellite data
+    """
     sample = {
         SatelliteSampleKey.satellite_actual: da.values,
         SatelliteSampleKey.time_utc: da.time_utc.values.astype(float),
+        SatelliteSampleKey.x_geostationary: da.x_geostationary.values,
+        SatelliteSampleKey.y_geostationary: da.y_geostationary.values,
     }
 
-    for sample_key, dataset_key in (
-        (SatelliteSampleKey.x_geostationary, "x_geostationary"),
-        (SatelliteSampleKey.y_geostationary, "y_geostationary"),
-    ):
-        sample[sample_key] = da[dataset_key].values
-
     if t0_idx is not None:
         sample[SatelliteSampleKey.t0_idx] = t0_idx
 
-    return sample
+    return sample
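Note: x_geostationary/y_geostationary are now written unconditionally into the sample dict, so the input DataArray must carry both coords. This function is fully visible above, so the sketch below should match the released behavior (the data values are hypothetical):

import numpy as np
import pandas as pd
import xarray as xr

from ocf_data_sampler.numpy_sample.satellite import (
    SatelliteSampleKey,
    convert_satellite_to_numpy_sample,
)

times = pd.date_range("2024-01-01", periods=3, freq="5min")
da = xr.DataArray(
    np.random.rand(3, 1, 2, 2),
    dims=["time_utc", "channel", "x_geostationary", "y_geostationary"],
    coords={
        "time_utc": times,
        "channel": ["IR_016"],
        "x_geostationary": [0.0, 1.0],
        "y_geostationary": [0.0, 1.0],
    },
)
sample = convert_satellite_to_numpy_sample(da, t0_idx=1)
print(sample[SatelliteSampleKey.satellite_actual].shape)  # (3, 1, 2, 2)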