ocf-data-sampler 0.1.10__py3-none-any.whl → 0.1.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ocf-data-sampler might be problematic. See the registry's page for this release for more details.

Files changed (77)
  1. ocf_data_sampler/config/load.py +3 -3
  2. ocf_data_sampler/config/model.py +86 -72
  3. ocf_data_sampler/config/save.py +5 -4
  4. ocf_data_sampler/constants.py +140 -12
  5. ocf_data_sampler/load/gsp.py +6 -5
  6. ocf_data_sampler/load/load_dataset.py +5 -6
  7. ocf_data_sampler/load/nwp/nwp.py +17 -5
  8. ocf_data_sampler/load/nwp/providers/ecmwf.py +6 -7
  9. ocf_data_sampler/load/nwp/providers/gfs.py +36 -0
  10. ocf_data_sampler/load/nwp/providers/icon.py +46 -0
  11. ocf_data_sampler/load/nwp/providers/ukv.py +4 -5
  12. ocf_data_sampler/load/nwp/providers/utils.py +3 -1
  13. ocf_data_sampler/load/satellite.py +27 -36
  14. ocf_data_sampler/load/site.py +11 -7
  15. ocf_data_sampler/load/utils.py +21 -16
  16. ocf_data_sampler/numpy_sample/collate.py +10 -9
  17. ocf_data_sampler/numpy_sample/datetime_features.py +3 -5
  18. ocf_data_sampler/numpy_sample/gsp.py +15 -13
  19. ocf_data_sampler/numpy_sample/nwp.py +17 -23
  20. ocf_data_sampler/numpy_sample/satellite.py +17 -14
  21. ocf_data_sampler/numpy_sample/site.py +8 -7
  22. ocf_data_sampler/numpy_sample/sun_position.py +19 -25
  23. ocf_data_sampler/sample/__init__.py +0 -7
  24. ocf_data_sampler/sample/base.py +23 -44
  25. ocf_data_sampler/sample/site.py +25 -69
  26. ocf_data_sampler/sample/uk_regional.py +52 -103
  27. ocf_data_sampler/select/dropout.py +42 -27
  28. ocf_data_sampler/select/fill_time_periods.py +15 -3
  29. ocf_data_sampler/select/find_contiguous_time_periods.py +87 -75
  30. ocf_data_sampler/select/geospatial.py +63 -54
  31. ocf_data_sampler/select/location.py +16 -51
  32. ocf_data_sampler/select/select_spatial_slice.py +105 -89
  33. ocf_data_sampler/select/select_time_slice.py +71 -58
  34. ocf_data_sampler/select/spatial_slice_for_dataset.py +7 -6
  35. ocf_data_sampler/select/time_slice_for_dataset.py +17 -16
  36. ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py +126 -118
  37. ocf_data_sampler/torch_datasets/datasets/site.py +135 -101
  38. ocf_data_sampler/torch_datasets/utils/merge_and_fill_utils.py +6 -2
  39. ocf_data_sampler/torch_datasets/utils/valid_time_periods.py +23 -22
  40. ocf_data_sampler/torch_datasets/utils/validate_channels.py +23 -19
  41. ocf_data_sampler/utils.py +3 -1
  42. {ocf_data_sampler-0.1.10.dist-info → ocf_data_sampler-0.1.16.dist-info}/METADATA +7 -18
  43. ocf_data_sampler-0.1.16.dist-info/RECORD +56 -0
  44. {ocf_data_sampler-0.1.10.dist-info → ocf_data_sampler-0.1.16.dist-info}/WHEEL +1 -1
  45. {ocf_data_sampler-0.1.10.dist-info → ocf_data_sampler-0.1.16.dist-info}/top_level.txt +1 -1
  46. scripts/refactor_site.py +62 -33
  47. utils/compute_icon_mean_stddev.py +72 -0
  48. ocf_data_sampler-0.1.10.dist-info/LICENSE +0 -21
  49. ocf_data_sampler-0.1.10.dist-info/RECORD +0 -82
  50. tests/__init__.py +0 -0
  51. tests/config/test_config.py +0 -113
  52. tests/config/test_load.py +0 -7
  53. tests/config/test_save.py +0 -28
  54. tests/conftest.py +0 -286
  55. tests/load/test_load_gsp.py +0 -15
  56. tests/load/test_load_nwp.py +0 -21
  57. tests/load/test_load_satellite.py +0 -17
  58. tests/load/test_load_sites.py +0 -14
  59. tests/numpy_sample/test_collate.py +0 -21
  60. tests/numpy_sample/test_datetime_features.py +0 -37
  61. tests/numpy_sample/test_gsp.py +0 -38
  62. tests/numpy_sample/test_nwp.py +0 -52
  63. tests/numpy_sample/test_satellite.py +0 -40
  64. tests/numpy_sample/test_sun_position.py +0 -81
  65. tests/select/test_dropout.py +0 -75
  66. tests/select/test_fill_time_periods.py +0 -28
  67. tests/select/test_find_contiguous_time_periods.py +0 -202
  68. tests/select/test_location.py +0 -67
  69. tests/select/test_select_spatial_slice.py +0 -154
  70. tests/select/test_select_time_slice.py +0 -275
  71. tests/test_sample/test_base.py +0 -164
  72. tests/test_sample/test_site_sample.py +0 -195
  73. tests/test_sample/test_uk_regional_sample.py +0 -163
  74. tests/torch_datasets/test_merge_and_fill_utils.py +0 -40
  75. tests/torch_datasets/test_pvnet_uk.py +0 -167
  76. tests/torch_datasets/test_site.py +0 -226
  77. tests/torch_datasets/test_validate_channels_utils.py +0 -78
@@ -1,17 +1,17 @@
1
- """ECMWF provider loaders"""
1
+ """ECMWF provider loaders."""
2
2
 
3
3
  import xarray as xr
4
+
4
5
  from ocf_data_sampler.load.nwp.providers.utils import open_zarr_paths
5
6
  from ocf_data_sampler.load.utils import (
6
7
  check_time_unique_increasing,
8
+ get_xr_data_array_from_xr_dataset,
7
9
  make_spatial_coords_increasing,
8
- get_xr_data_array_from_xr_dataset
9
10
  )
10
11
 
11
12
 
12
13
  def open_ifs(zarr_path: str | list[str]) -> xr.DataArray:
13
- """
14
- Opens the ECMWF IFS NWP data
14
+ """Opens the ECMWF IFS NWP data.
15
15
 
16
16
  Args:
17
17
  zarr_path: Path to the zarr to open
@@ -19,9 +19,8 @@ def open_ifs(zarr_path: str | list[str]) -> xr.DataArray:
19
19
  Returns:
20
20
  Xarray DataArray of the NWP data
21
21
  """
22
-
23
22
  ds = open_zarr_paths(zarr_path)
24
-
23
+
25
24
  # LEGACY SUPPORT - rename variable to channel if it exists
26
25
  ds = ds.rename({"init_time": "init_time_utc", "variable": "channel"})
27
26
 
@@ -30,6 +29,6 @@ def open_ifs(zarr_path: str | list[str]) -> xr.DataArray:
30
29
  ds = make_spatial_coords_increasing(ds, x_coord="longitude", y_coord="latitude")
31
30
 
32
31
  ds = ds.transpose("init_time_utc", "step", "channel", "longitude", "latitude")
33
-
32
+
34
33
  # TODO: should we control the dtype of the DataArray?
35
34
  return get_xr_data_array_from_xr_dataset(ds)
@@ -0,0 +1,36 @@
1
+ """Open GFS Forecast data."""
2
+
3
+ import logging
4
+
5
+ import xarray as xr
6
+
7
+ from ocf_data_sampler.load.nwp.providers.utils import open_zarr_paths
8
+ from ocf_data_sampler.load.utils import check_time_unique_increasing, make_spatial_coords_increasing
9
+
10
+ _log = logging.getLogger(__name__)
11
+
12
+
13
+ def open_gfs(zarr_path: str | list[str]) -> xr.DataArray:
14
+ """Opens the GFS data.
15
+
16
+ Args:
17
+ zarr_path: Path to the zarr to open
18
+
19
+ Returns:
20
+ Xarray DataArray of the NWP data
21
+ """
22
+ _log.info("Loading NWP GFS data")
23
+
24
+ # Open data
25
+ gfs: xr.Dataset = open_zarr_paths(zarr_path, time_dim="init_time_utc")
26
+ nwp: xr.DataArray = gfs.to_array()
27
+
28
+ del gfs
29
+
30
+ nwp = nwp.rename({"variable": "channel","init_time": "init_time_utc"})
31
+ check_time_unique_increasing(nwp.init_time_utc)
32
+ nwp = make_spatial_coords_increasing(nwp, x_coord="longitude", y_coord="latitude")
33
+
34
+ nwp = nwp.transpose("init_time_utc", "step", "channel", "latitude", "longitude")
35
+
36
+ return nwp
@@ -0,0 +1,46 @@
1
+ """DWD ICON Loading."""
2
+
3
+ import xarray as xr
4
+
5
+ from ocf_data_sampler.load.nwp.providers.utils import open_zarr_paths
6
+ from ocf_data_sampler.load.utils import check_time_unique_increasing, make_spatial_coords_increasing
7
+
8
+
9
+ def remove_isobaric_lelvels_from_coords(nwp: xr.Dataset) -> xr.Dataset:
10
+ """Removes the isobaric levels from the coordinates of the NWP data.
11
+
12
+ Args:
13
+ nwp: NWP data
14
+
15
+ Returns:
16
+ NWP data without isobaric levels in the coordinates
17
+ """
18
+ variables_to_drop = [var for var in nwp.data_vars if "isobaricInhPa" in nwp[var].dims]
19
+ return nwp.drop_vars(["isobaricInhPa", *variables_to_drop])
20
+
21
+
22
+ def open_icon_eu(zarr_path: str) -> xr.Dataset:
23
+ """Opens the ICON data.
24
+
25
+ ICON EU Data is on a regular lat/lon grid
26
+ It has data on multiple pressure levels, as well as the surface
27
+ Each of the variables is its own data variable
28
+
29
+ Args:
30
+ zarr_path: Path to the zarr to open
31
+
32
+ Returns:
33
+ Xarray DataArray of the NWP data
34
+ """
35
+ # Open the data
36
+ nwp = open_zarr_paths(zarr_path, time_dim="time")
37
+ nwp = nwp.rename({"time": "init_time_utc"})
38
+ # Sanity checks.
39
+ check_time_unique_increasing(nwp.init_time_utc)
40
+ # 0-78 one hour steps, rest 3 hour steps
41
+ nwp = nwp.isel(step=slice(0, 78))
42
+ nwp = remove_isobaric_lelvels_from_coords(nwp)
43
+ nwp = nwp.to_array().rename({"variable": "channel"})
44
+ nwp = nwp.transpose("init_time_utc", "step", "channel", "latitude", "longitude")
45
+ nwp = make_spatial_coords_increasing(nwp, x_coord="longitude", y_coord="latitude")
46
+ return nwp
@@ -1,18 +1,17 @@
1
- """UKV provider loaders"""
1
+ """UKV provider loaders."""
2
2
 
3
3
  import xarray as xr
4
4
 
5
5
  from ocf_data_sampler.load.nwp.providers.utils import open_zarr_paths
6
6
  from ocf_data_sampler.load.utils import (
7
7
  check_time_unique_increasing,
8
+ get_xr_data_array_from_xr_dataset,
8
9
  make_spatial_coords_increasing,
9
- get_xr_data_array_from_xr_dataset
10
10
  )
11
11
 
12
12
 
13
13
  def open_ukv(zarr_path: str | list[str]) -> xr.DataArray:
14
- """
15
- Opens the NWP data
14
+ """Opens the NWP data.
16
15
 
17
16
  Args:
18
17
  zarr_path: Path to the zarr to open
@@ -28,7 +27,7 @@ def open_ukv(zarr_path: str | list[str]) -> xr.DataArray:
28
27
  "variable": "channel",
29
28
  "x": "x_osgb",
30
29
  "y": "y_osgb",
31
- }
30
+ },
32
31
  )
33
32
 
34
33
  check_time_unique_increasing(ds.init_time_utc)
@@ -1,8 +1,10 @@
1
+ """Utility functions for the NWP data processing."""
2
+
1
3
  import xarray as xr
2
4
 
3
5
 
4
6
  def open_zarr_paths(zarr_path: str | list[str], time_dim: str = "init_time") -> xr.Dataset:
5
- """Opens the NWP data
7
+ """Opens the NWP data.
6
8
 
7
9
  Args:
8
10
  zarr_path: Path to the zarr(s) to open
@@ -1,85 +1,76 @@
1
- """Satellite loader"""
2
-
3
- import subprocess
1
+ """Satellite loader."""
4
2
 
5
3
  import xarray as xr
4
+
6
5
  from ocf_data_sampler.load.utils import (
7
6
  check_time_unique_increasing,
7
+ get_xr_data_array_from_xr_dataset,
8
8
  make_spatial_coords_increasing,
9
- get_xr_data_array_from_xr_dataset
10
9
  )
11
10
 
12
11
 
13
- def _get_single_sat_data(zarr_path: str) -> xr.Dataset:
14
- """Helper function to open a Zarr from either a local or GCP path.
12
+ def get_single_sat_data(zarr_path: str) -> xr.Dataset:
13
+ """Helper function to open a zarr from either a local or GCP path.
15
14
 
16
15
  Args:
17
- zarr_path: Path to a Zarr file. Wildcards (*) are supported **only** for local paths.
18
- GCS paths (gs://) **do not support** wildcards.
16
+ zarr_path: path to a zarr file. Wildcards (*) are supported only for local paths
17
+ GCS paths (gs://) do not support wildcards
19
18
 
20
19
  Returns:
21
- An xarray Dataset containing satellite data.
20
+ An xarray Dataset containing satellite data
22
21
 
23
22
  Raises:
24
- ValueError: If a wildcard (*) is used in a GCS (gs://) path.
23
+ ValueError: If a wildcard (*) is used in a GCS (gs://) path
25
24
  """
26
-
27
- # These kwargs are used if the path contains "*"
28
- openmf_kwargs = dict(
29
- engine="zarr",
30
- concat_dim="time",
31
- combine="nested",
32
- chunks="auto",
33
- join="override",
34
- )
35
-
36
25
  # Raise an error if a wildcard is used in a GCP path
37
26
  if "gs://" in str(zarr_path) and "*" in str(zarr_path):
38
- raise ValueError("Wildcard (*) paths are not supported for GCP (gs://) URLs.")
27
+ raise ValueError("Wildcard (*) paths are not supported for GCP (gs://) URLs")
39
28
 
40
29
  # Handle multi-file dataset for local paths
41
30
  if "*" in str(zarr_path):
42
- ds = xr.open_mfdataset(zarr_path, **openmf_kwargs)
31
+ ds = xr.open_mfdataset(
32
+ zarr_path,
33
+ engine="zarr",
34
+ concat_dim="time",
35
+ combine="nested",
36
+ chunks="auto",
37
+ join="override",
38
+ )
39
+ check_time_unique_increasing(ds.time)
43
40
  else:
44
41
  ds = xr.open_dataset(zarr_path, engine="zarr", chunks="auto")
45
42
 
46
- # Ensure time is unique and sorted
47
- ds = ds.drop_duplicates("time").sortby("time")
48
-
49
43
  return ds
50
44
 
51
45
 
52
46
  def open_sat_data(zarr_path: str | list[str]) -> xr.DataArray:
53
- """Lazily opens the Zarr store.
47
+ """Lazily opens the zarr store.
54
48
 
55
49
  Args:
56
- zarr_path: Cloud URL or local path pattern, or list of these. If GCS URL, it must start with
57
- 'gs://'.
50
+ zarr_path: Cloud URL or local path pattern, or list of these. If GCS URL,
51
+ it must start with 'gs://'
58
52
  """
59
-
60
53
  # Open the data
61
- if isinstance(zarr_path, (list, tuple)):
54
+ if isinstance(zarr_path, list | tuple):
62
55
  ds = xr.combine_nested(
63
- [_get_single_sat_data(path) for path in zarr_path],
56
+ [get_single_sat_data(path) for path in zarr_path],
64
57
  concat_dim="time",
65
58
  combine_attrs="override",
66
59
  join="override",
67
60
  )
68
61
  else:
69
- ds = _get_single_sat_data(zarr_path)
62
+ ds = get_single_sat_data(zarr_path)
70
63
 
71
64
  ds = ds.rename(
72
65
  {
73
66
  "variable": "channel",
74
67
  "time": "time_utc",
75
- }
68
+ },
76
69
  )
77
70
 
78
71
  check_time_unique_increasing(ds.time_utc)
79
-
80
72
  ds = make_spatial_coords_increasing(ds, x_coord="x_geostationary", y_coord="y_geostationary")
81
-
82
73
  ds = ds.transpose("time_utc", "channel", "x_geostationary", "y_geostationary")
83
-
74
+
84
75
  # TODO: should we control the dtype of the DataArray?
85
76
  return get_xr_data_array_from_xr_dataset(ds)
@@ -1,3 +1,5 @@
1
+ """Funcitons for loading site data."""
2
+
1
3
  import numpy as np
2
4
  import pandas as pd
3
5
  import xarray as xr
@@ -5,7 +7,7 @@ import xarray as xr
5
7
 
6
8
  def open_site(generation_file_path: str, metadata_file_path: str) -> xr.DataArray:
7
9
  """Open a site's generation data and metadata.
8
-
10
+
9
11
  Args:
10
12
  generation_file_path: Path to the site generation netcdf data
11
13
  metadata_file_path: Path to the site csv metadata
@@ -13,14 +15,14 @@ def open_site(generation_file_path: str, metadata_file_path: str) -> xr.DataArra
13
15
  Returns:
14
16
  xr.DataArray: The opened site generation data
15
17
  """
16
-
17
18
  generation_ds = xr.open_dataset(generation_file_path)
18
19
 
19
20
  metadata_df = pd.read_csv(metadata_file_path, index_col="site_id")
20
21
 
21
- assert metadata_df.index.is_unique
22
+ if not metadata_df.index.is_unique:
23
+ raise ValueError("site_id is not unique in metadata")
22
24
 
23
- # Ensure metadata aligns with the site_id dimension in data_ds
25
+ # Ensure metadata aligns with the site_id dimension in generation_ds
24
26
  metadata_df = metadata_df.reindex(generation_ds.site_id.values)
25
27
 
26
28
  # Assign coordinates to the Dataset using the aligned metadata
@@ -31,7 +33,9 @@ def open_site(generation_file_path: str, metadata_file_path: str) -> xr.DataArra
31
33
  )
32
34
 
33
35
  # Sanity checks
34
- assert np.isfinite(generation_ds.capacity_kwp.values).all()
35
- assert (generation_ds.capacity_kwp.values > 0).all()
36
-
36
+ if not np.isfinite(generation_ds.generation_kw.values).all():
37
+ raise ValueError("generation_kw contains non-finite values")
38
+ if not (generation_ds.capacity_kwp.values > 0).all():
39
+ raise ValueError("capacity_kwp contains non-positive values")
40
+
37
41
  return generation_ds.generation_kw
@@ -1,43 +1,48 @@
1
- import xarray as xr
1
+ """Utility functions for working with xarray objects."""
2
+
2
3
  import pandas as pd
4
+ import xarray as xr
5
+
3
6
 
4
- def check_time_unique_increasing(datetimes) -> None:
5
- """Check that the time dimension is unique and increasing"""
7
+ def check_time_unique_increasing(datetimes: xr.DataArray) -> None:
8
+ """Check that the time dimension is unique and increasing."""
6
9
  time = pd.DatetimeIndex(datetimes)
7
- assert time.is_unique
8
- assert time.is_monotonic_increasing
10
+ if not (time.is_unique and time.is_monotonic_increasing):
11
+ raise ValueError("Time dimension must be unique and monotonically increasing")
9
12
 
10
13
 
11
14
  def make_spatial_coords_increasing(ds: xr.Dataset, x_coord: str, y_coord: str) -> xr.Dataset:
12
- """Make sure the spatial coordinates are in increasing order
13
-
15
+ """Make sure the spatial coordinates are in increasing order.
16
+
14
17
  Args:
15
18
  ds: Xarray Dataset
16
19
  x_coord: Name of the x coordinate
17
20
  y_coord: Name of the y coordinate
18
21
  """
19
-
20
22
  # Make sure the coords are in increasing order
21
23
  if ds[x_coord][0] > ds[x_coord][-1]:
22
- ds = ds.isel({x_coord:slice(None, None, -1)})
24
+ ds = ds.isel({x_coord: slice(None, None, -1)})
23
25
  if ds[y_coord][0] > ds[y_coord][-1]:
24
- ds = ds.isel({y_coord:slice(None, None, -1)})
26
+ ds = ds.isel({y_coord: slice(None, None, -1)})
25
27
 
26
28
  # Check the coords are all increasing now
27
- assert (ds[x_coord].diff(dim=x_coord) > 0).all()
28
- assert (ds[y_coord].diff(dim=y_coord) > 0).all()
29
+ if not (ds[x_coord].diff(dim=x_coord) > 0).all():
30
+ raise ValueError(f"'{x_coord}' coordinate must be increasing")
31
+ if not (ds[y_coord].diff(dim=y_coord) > 0).all():
32
+ raise ValueError(f"'{y_coord}' coordinate must be increasing")
29
33
 
30
34
  return ds
31
35
 
32
36
 
33
37
  def get_xr_data_array_from_xr_dataset(ds: xr.Dataset) -> xr.DataArray:
34
- """Return underlying xr.DataArray from passed xr.Dataset.
38
+ """Return underlying xr.DataArray from passed xr.Dataset.
39
+
35
40
  Checks only one variable is present and returns it as an xr.DataArray.
36
41
 
37
42
  Args:
38
43
  ds: xr.Dataset to extract xr.DataArray from
39
44
  """
40
-
41
45
  datavars = list(ds.var())
42
- assert len(datavars) == 1, "Cannot open as xr.DataArray: dataset contains multiple variables"
43
- return ds[datavars[0]]
46
+ if len(datavars) != 1:
47
+ raise ValueError("Cannot open as xr.DataArray: dataset contains multiple variables")
48
+ return ds[datavars[0]]
@@ -1,8 +1,10 @@
1
+ """Functions for collating samples into batches."""
2
+
1
3
  import numpy as np
2
4
 
3
5
 
4
6
  def stack_np_samples_into_batch(dict_list: list[dict]) -> dict:
5
- """Stacks list of dict samples into a dict where all samples are joined along a new axis
7
+ """Stacks list of dict samples into a dict where all samples are joined along a new axis.
6
8
 
7
9
  Args:
8
10
  dict_list: A list of dict-like samples to stack
@@ -10,7 +12,6 @@ def stack_np_samples_into_batch(dict_list: list[dict]) -> dict:
10
12
  Returns:
11
13
  Dict of the samples stacked with new batch dimension on axis 0
12
14
  """
13
-
14
15
  batch = {}
15
16
 
16
17
  keys = list(dict_list[0].keys())
@@ -26,10 +27,10 @@ def stack_np_samples_into_batch(dict_list: list[dict]) -> dict:
26
27
  for nwp_provider in nwp_providers:
27
28
  # Keys can be different for different NWPs
28
29
  nwp_keys = list(dict_list[0]["nwp"][nwp_provider].keys())
29
-
30
+
30
31
  # Create dict to store NWP batch for this provider
31
32
  nwp_provider_batch = {}
32
-
33
+
33
34
  for nwp_key in nwp_keys:
34
35
  # Stack values under each NWP key for this provider
35
36
  nwp_provider_batch[nwp_key] = stack_data_list(
@@ -46,16 +47,16 @@ def stack_np_samples_into_batch(dict_list: list[dict]) -> dict:
46
47
 
47
48
 
48
49
  def _key_is_constant(key: str) -> bool:
49
- """Check if a key is for value which is constant for all samples"""
50
+ """Check if a key is for value which is constant for all samples."""
50
51
  return key.endswith("t0_idx") or key.endswith("channel_names")
51
52
 
52
53
 
53
54
  def stack_data_list(data_list: list, key: str) -> np.ndarray:
54
- """Stack a sequence of data elements along a new axis
55
+ """Stack a sequence of data elements along a new axis.
55
56
 
56
- Args:
57
- data_list: List of data elements to combine
58
- key: string identifying the data type
57
+ Args:
58
+ data_list: List of data elements to combine
59
+ key: string identifying the data type
59
60
  """
60
61
  if _key_is_constant(key):
61
62
  return data_list[0]
@@ -1,11 +1,11 @@
1
- """Functions to create trigonometric date and time inputs"""
1
+ """Functions to create trigonometric date and time inputs."""
2
2
 
3
3
  import numpy as np
4
4
  import pandas as pd
5
5
 
6
6
 
7
7
  def _get_date_time_in_pi(dt: pd.DatetimeIndex) -> tuple[np.ndarray, np.ndarray]:
8
- """Create positional embeddings for the datetimes in radians
8
+ """Create positional embeddings for the datetimes in radians.
9
9
 
10
10
  Args:
11
11
  dt: DatetimeIndex to create radian embeddings for
@@ -13,7 +13,6 @@ def _get_date_time_in_pi(dt: pd.DatetimeIndex) -> tuple[np.ndarray, np.ndarray]:
13
13
  Returns:
14
14
  Tuple of numpy arrays containing radian coordinates for date and time
15
15
  """
16
-
17
16
  day_of_year = dt.dayofyear
18
17
  minute_of_day = dt.minute + dt.hour * 60
19
18
 
@@ -24,8 +23,7 @@ def _get_date_time_in_pi(dt: pd.DatetimeIndex) -> tuple[np.ndarray, np.ndarray]:
24
23
 
25
24
 
26
25
  def make_datetime_numpy_dict(datetimes: pd.DatetimeIndex, key_prefix: str = "wind") -> dict:
27
- """ Creates dictionary of cyclical datetime features - encoded """
28
-
26
+ """Creates dictionary of cyclical datetime features - encoded."""
29
27
  date_in_pi, time_in_pi = _get_date_time_in_pi(datetimes)
30
28
 
31
29
  time_numpy_sample = {}
@@ -1,26 +1,28 @@
1
- """Convert GSP to Numpy Sample"""
1
+ """Convert GSP to Numpy Sample."""
2
2
 
3
3
  import xarray as xr
4
4
 
5
5
 
6
6
  class GSPSampleKey:
7
+ """Keys for the GSP sample dictionary."""
7
8
 
8
- gsp = 'gsp'
9
- nominal_capacity_mwp = 'gsp_nominal_capacity_mwp'
10
- effective_capacity_mwp = 'gsp_effective_capacity_mwp'
11
- time_utc = 'gsp_time_utc'
12
- t0_idx = 'gsp_t0_idx'
13
- solar_azimuth = 'gsp_solar_azimuth'
14
- solar_elevation = 'gsp_solar_elevation'
15
- gsp_id = 'gsp_id'
16
- x_osgb = 'gsp_x_osgb'
17
- y_osgb = 'gsp_y_osgb'
9
+ gsp = "gsp"
10
+ nominal_capacity_mwp = "gsp_nominal_capacity_mwp"
11
+ effective_capacity_mwp = "gsp_effective_capacity_mwp"
12
+ time_utc = "gsp_time_utc"
13
+ t0_idx = "gsp_t0_idx"
14
+ gsp_id = "gsp_id"
15
+ x_osgb = "gsp_x_osgb"
16
+ y_osgb = "gsp_y_osgb"
18
17
 
19
18
 
20
19
  def convert_gsp_to_numpy_sample(da: xr.DataArray, t0_idx: int | None = None) -> dict:
21
- """Convert from Xarray to NumpySample"""
20
+ """Convert from Xarray to NumpySample.
22
21
 
23
- # Extract values from the DataArray
22
+ Args:
23
+ da: Xarray DataArray containing GSP data
24
+ t0_idx: Index of the t0 timestamp in the time dimension of the GSP data
25
+ """
24
26
  sample = {
25
27
  GSPSampleKey.gsp: da.values,
26
28
  GSPSampleKey.nominal_capacity_mwp: da.isel(time_utc=0)["nominal_capacity_mwp"].values,
@@ -1,42 +1,36 @@
1
- """Convert NWP to NumpySample"""
1
+ """Convert NWP to NumpySample."""
2
2
 
3
3
  import pandas as pd
4
4
  import xarray as xr
5
5
 
6
6
 
7
7
  class NWPSampleKey:
8
+ """Keys for NWP NumpySample."""
8
9
 
9
- nwp = 'nwp'
10
- channel_names = 'nwp_channel_names'
11
- init_time_utc = 'nwp_init_time_utc'
12
- step = 'nwp_step'
13
- target_time_utc = 'nwp_target_time_utc'
14
- t0_idx = 'nwp_t0_idx'
15
- y_osgb = 'nwp_y_osgb'
16
- x_osgb = 'nwp_x_osgb'
17
-
10
+ nwp = "nwp"
11
+ channel_names = "nwp_channel_names"
12
+ init_time_utc = "nwp_init_time_utc"
13
+ step = "nwp_step"
14
+ target_time_utc = "nwp_target_time_utc"
15
+ t0_idx = "nwp_t0_idx"
18
16
 
19
17
 
20
18
  def convert_nwp_to_numpy_sample(da: xr.DataArray, t0_idx: int | None = None) -> dict:
21
- """Convert from Xarray to NWP NumpySample"""
22
-
23
- # Create example and add t if available
19
+ """Convert from Xarray to NWP NumpySample.
20
+
21
+ Args:
22
+ da: Xarray DataArray containing NWP data
23
+ t0_idx: Index of the t0 timestamp in the time dimension of the NWP
24
+ """
24
25
  sample = {
25
26
  NWPSampleKey.nwp: da.values,
26
27
  NWPSampleKey.channel_names: da.channel.values,
27
28
  NWPSampleKey.init_time_utc: da.init_time_utc.values.astype(float),
28
29
  NWPSampleKey.step: (da.step.values / pd.Timedelta("1h")).astype(int),
30
+ NWPSampleKey.target_time_utc: da.target_time_utc.values.astype(float),
29
31
  }
30
32
 
31
- if "target_time_utc" in da.coords:
32
- sample[NWPSampleKey.target_time_utc] = da.target_time_utc.values.astype(float)
33
-
34
- # TODO: Do we need this at all? Especially since it is only present in UKV data
35
- for sample_key, dataset_key in ((NWPSampleKey.y_osgb, "y_osgb"),(NWPSampleKey.x_osgb, "x_osgb"),):
36
- if dataset_key in da.coords:
37
- sample[sample_key] = da[dataset_key].values
38
-
39
33
  if t0_idx is not None:
40
34
  sample[NWPSampleKey.t0_idx] = t0_idx
41
-
42
- return sample
35
+
36
+ return sample
@@ -1,30 +1,33 @@
1
- """Convert Satellite to NumpySample"""
1
+ """Convert Satellite to NumpySample."""
2
+
2
3
  import xarray as xr
3
4
 
4
5
 
5
6
  class SatelliteSampleKey:
7
+ """Keys for the SatelliteSample dictionary."""
6
8
 
7
- satellite_actual = 'satellite_actual'
8
- time_utc = 'satellite_time_utc'
9
- x_geostationary = 'satellite_x_geostationary'
10
- y_geostationary = 'satellite_y_geostationary'
11
- t0_idx = 'satellite_t0_idx'
9
+ satellite_actual = "satellite_actual"
10
+ time_utc = "satellite_time_utc"
11
+ x_geostationary = "satellite_x_geostationary"
12
+ y_geostationary = "satellite_y_geostationary"
13
+ t0_idx = "satellite_t0_idx"
12
14
 
13
15
 
14
16
  def convert_satellite_to_numpy_sample(da: xr.DataArray, t0_idx: int | None = None) -> dict:
15
- """Convert from Xarray to NumpySample"""
17
+ """Convert from Xarray to NumpySample.
18
+
19
+ Args:
20
+ da: xarray DataArray containing satellite data
21
+ t0_idx: Index of the t0 timestamp in the time dimension of the satellite data
22
+ """
16
23
  sample = {
17
24
  SatelliteSampleKey.satellite_actual: da.values,
18
25
  SatelliteSampleKey.time_utc: da.time_utc.values.astype(float),
26
+ SatelliteSampleKey.x_geostationary: da.x_geostationary.values,
27
+ SatelliteSampleKey.y_geostationary: da.y_geostationary.values,
19
28
  }
20
29
 
21
- for sample_key, dataset_key in (
22
- (SatelliteSampleKey.x_geostationary, "x_geostationary"),
23
- (SatelliteSampleKey.y_geostationary, "y_geostationary"),
24
- ):
25
- sample[sample_key] = da[dataset_key].values
26
-
27
30
  if t0_idx is not None:
28
31
  sample[SatelliteSampleKey.t0_idx] = t0_idx
29
32
 
30
- return sample
33
+ return sample