ocf-data-sampler 0.1.11__py3-none-any.whl → 0.1.17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ocf-data-sampler might be problematic. Click here for more details.

Files changed (78)
  1. ocf_data_sampler/config/load.py +3 -3
  2. ocf_data_sampler/config/model.py +146 -64
  3. ocf_data_sampler/config/save.py +5 -4
  4. ocf_data_sampler/load/gsp.py +6 -5
  5. ocf_data_sampler/load/load_dataset.py +5 -6
  6. ocf_data_sampler/load/nwp/nwp.py +17 -5
  7. ocf_data_sampler/load/nwp/providers/ecmwf.py +6 -7
  8. ocf_data_sampler/load/nwp/providers/gfs.py +36 -0
  9. ocf_data_sampler/load/nwp/providers/icon.py +46 -0
  10. ocf_data_sampler/load/nwp/providers/ukv.py +4 -5
  11. ocf_data_sampler/load/nwp/providers/utils.py +3 -1
  12. ocf_data_sampler/load/satellite.py +9 -10
  13. ocf_data_sampler/load/site.py +10 -6
  14. ocf_data_sampler/load/utils.py +21 -16
  15. ocf_data_sampler/numpy_sample/collate.py +10 -9
  16. ocf_data_sampler/numpy_sample/datetime_features.py +3 -5
  17. ocf_data_sampler/numpy_sample/gsp.py +12 -14
  18. ocf_data_sampler/numpy_sample/nwp.py +12 -12
  19. ocf_data_sampler/numpy_sample/satellite.py +9 -9
  20. ocf_data_sampler/numpy_sample/site.py +5 -8
  21. ocf_data_sampler/numpy_sample/sun_position.py +16 -21
  22. ocf_data_sampler/sample/base.py +15 -17
  23. ocf_data_sampler/sample/site.py +13 -20
  24. ocf_data_sampler/sample/uk_regional.py +29 -35
  25. ocf_data_sampler/select/dropout.py +16 -14
  26. ocf_data_sampler/select/fill_time_periods.py +15 -5
  27. ocf_data_sampler/select/find_contiguous_time_periods.py +88 -75
  28. ocf_data_sampler/select/geospatial.py +63 -54
  29. ocf_data_sampler/select/location.py +16 -51
  30. ocf_data_sampler/select/select_spatial_slice.py +105 -89
  31. ocf_data_sampler/select/select_time_slice.py +71 -58
  32. ocf_data_sampler/select/spatial_slice_for_dataset.py +7 -6
  33. ocf_data_sampler/select/time_slice_for_dataset.py +17 -16
  34. ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py +140 -131
  35. ocf_data_sampler/torch_datasets/datasets/site.py +152 -112
  36. ocf_data_sampler/torch_datasets/utils/__init__.py +3 -0
  37. ocf_data_sampler/torch_datasets/utils/channel_dict_to_dataarray.py +11 -0
  38. ocf_data_sampler/torch_datasets/utils/merge_and_fill_utils.py +6 -2
  39. ocf_data_sampler/torch_datasets/utils/valid_time_periods.py +23 -22
  40. ocf_data_sampler/utils.py +3 -1
  41. {ocf_data_sampler-0.1.11.dist-info → ocf_data_sampler-0.1.17.dist-info}/METADATA +7 -18
  42. ocf_data_sampler-0.1.17.dist-info/RECORD +56 -0
  43. {ocf_data_sampler-0.1.11.dist-info → ocf_data_sampler-0.1.17.dist-info}/WHEEL +1 -1
  44. {ocf_data_sampler-0.1.11.dist-info → ocf_data_sampler-0.1.17.dist-info}/top_level.txt +1 -1
  45. scripts/refactor_site.py +63 -33
  46. utils/compute_icon_mean_stddev.py +72 -0
  47. ocf_data_sampler/constants.py +0 -222
  48. ocf_data_sampler/torch_datasets/utils/validate_channels.py +0 -82
  49. ocf_data_sampler-0.1.11.dist-info/LICENSE +0 -21
  50. ocf_data_sampler-0.1.11.dist-info/RECORD +0 -82
  51. tests/__init__.py +0 -0
  52. tests/config/test_config.py +0 -113
  53. tests/config/test_load.py +0 -7
  54. tests/config/test_save.py +0 -28
  55. tests/conftest.py +0 -319
  56. tests/load/test_load_gsp.py +0 -15
  57. tests/load/test_load_nwp.py +0 -21
  58. tests/load/test_load_satellite.py +0 -17
  59. tests/load/test_load_sites.py +0 -14
  60. tests/numpy_sample/test_collate.py +0 -21
  61. tests/numpy_sample/test_datetime_features.py +0 -37
  62. tests/numpy_sample/test_gsp.py +0 -38
  63. tests/numpy_sample/test_nwp.py +0 -13
  64. tests/numpy_sample/test_satellite.py +0 -40
  65. tests/numpy_sample/test_sun_position.py +0 -81
  66. tests/select/test_dropout.py +0 -69
  67. tests/select/test_fill_time_periods.py +0 -28
  68. tests/select/test_find_contiguous_time_periods.py +0 -202
  69. tests/select/test_location.py +0 -67
  70. tests/select/test_select_spatial_slice.py +0 -154
  71. tests/select/test_select_time_slice.py +0 -275
  72. tests/test_sample/test_base.py +0 -164
  73. tests/test_sample/test_site_sample.py +0 -165
  74. tests/test_sample/test_uk_regional_sample.py +0 -136
  75. tests/torch_datasets/test_merge_and_fill_utils.py +0 -40
  76. tests/torch_datasets/test_pvnet_uk.py +0 -154
  77. tests/torch_datasets/test_site.py +0 -226
  78. tests/torch_datasets/test_validate_channels_utils.py +0 -78
@@ -1,8 +1,10 @@
1
+ """Utility functions for the NWP data processing."""
2
+
1
3
  import xarray as xr
2
4
 
3
5
 
4
6
  def open_zarr_paths(zarr_path: str | list[str], time_dim: str = "init_time") -> xr.Dataset:
5
- """Opens the NWP data
7
+ """Opens the NWP data.
6
8
 
7
9
  Args:
8
10
  zarr_path: Path to the zarr(s) to open
@@ -1,15 +1,16 @@
1
- """Satellite loader"""
1
+ """Satellite loader."""
2
2
 
3
3
  import xarray as xr
4
+
4
5
  from ocf_data_sampler.load.utils import (
5
6
  check_time_unique_increasing,
7
+ get_xr_data_array_from_xr_dataset,
6
8
  make_spatial_coords_increasing,
7
- get_xr_data_array_from_xr_dataset
8
9
  )
9
10
 
10
11
 
11
12
  def get_single_sat_data(zarr_path: str) -> xr.Dataset:
12
- """Helper function to open a zarr from either a local or GCP path
13
+ """Helper function to open a zarr from either a local or GCP path.
13
14
 
14
15
  Args:
15
16
  zarr_path: path to a zarr file. Wildcards (*) are supported only for local paths
@@ -21,7 +22,6 @@ def get_single_sat_data(zarr_path: str) -> xr.Dataset:
21
22
  Raises:
22
23
  ValueError: If a wildcard (*) is used in a GCS (gs://) path
23
24
  """
24
-
25
25
  # Raise an error if a wildcard is used in a GCP path
26
26
  if "gs://" in str(zarr_path) and "*" in str(zarr_path):
27
27
  raise ValueError("Wildcard (*) paths are not supported for GCP (gs://) URLs")
@@ -44,15 +44,14 @@ def get_single_sat_data(zarr_path: str) -> xr.Dataset:
44
44
 
45
45
 
46
46
  def open_sat_data(zarr_path: str | list[str]) -> xr.DataArray:
47
- """Lazily opens the zarr store
47
+ """Lazily opens the zarr store.
48
48
 
49
49
  Args:
50
- zarr_path: Cloud URL or local path pattern, or list of these. If GCS URL,
50
+ zarr_path: Cloud URL or local path pattern, or list of these. If GCS URL,
51
51
  it must start with 'gs://'
52
52
  """
53
-
54
53
  # Open the data
55
- if isinstance(zarr_path, (list, tuple)):
54
+ if isinstance(zarr_path, list | tuple):
56
55
  ds = xr.combine_nested(
57
56
  [get_single_sat_data(path) for path in zarr_path],
58
57
  concat_dim="time",
@@ -66,12 +65,12 @@ def open_sat_data(zarr_path: str | list[str]) -> xr.DataArray:
66
65
  {
67
66
  "variable": "channel",
68
67
  "time": "time_utc",
69
- }
68
+ },
70
69
  )
71
70
 
72
71
  check_time_unique_increasing(ds.time_utc)
73
72
  ds = make_spatial_coords_increasing(ds, x_coord="x_geostationary", y_coord="y_geostationary")
74
73
  ds = ds.transpose("time_utc", "channel", "x_geostationary", "y_geostationary")
75
-
74
+
76
75
  # TODO: should we control the dtype of the DataArray?
77
76
  return get_xr_data_array_from_xr_dataset(ds)
@@ -1,3 +1,5 @@
1
+ """Functions for loading site data."""
2
+
1
3
  import numpy as np
2
4
  import pandas as pd
3
5
  import xarray as xr
@@ -5,7 +7,7 @@ import xarray as xr
5
7
 
6
8
  def open_site(generation_file_path: str, metadata_file_path: str) -> xr.DataArray:
7
9
  """Open a site's generation data and metadata.
8
-
10
+
9
11
  Args:
10
12
  generation_file_path: Path to the site generation netcdf data
11
13
  metadata_file_path: Path to the site csv metadata
@@ -13,12 +15,12 @@ def open_site(generation_file_path: str, metadata_file_path: str) -> xr.DataArra
13
15
  Returns:
14
16
  xr.DataArray: The opened site generation data
15
17
  """
16
-
17
18
  generation_ds = xr.open_dataset(generation_file_path)
18
19
 
19
20
  metadata_df = pd.read_csv(metadata_file_path, index_col="site_id")
20
21
 
21
- assert metadata_df.index.is_unique
22
+ if not metadata_df.index.is_unique:
23
+ raise ValueError("site_id is not unique in metadata")
22
24
 
23
25
  # Ensure metadata aligns with the site_id dimension in generation_ds
24
26
  metadata_df = metadata_df.reindex(generation_ds.site_id.values)
@@ -31,7 +33,9 @@ def open_site(generation_file_path: str, metadata_file_path: str) -> xr.DataArra
31
33
  )
32
34
 
33
35
  # Sanity checks
34
- assert np.isfinite(generation_ds.capacity_kwp.values).all()
35
- assert (generation_ds.capacity_kwp.values > 0).all()
36
-
36
+ if not np.isfinite(generation_ds.generation_kw.values).all():
37
+ raise ValueError("generation_kw contains non-finite values")
38
+ if not (generation_ds.capacity_kwp.values > 0).all():
39
+ raise ValueError("capacity_kwp contains non-positive values")
40
+
37
41
  return generation_ds.generation_kw
@@ -1,43 +1,48 @@
1
- import xarray as xr
1
+ """Utility functions for working with xarray objects."""
2
+
2
3
  import pandas as pd
4
+ import xarray as xr
5
+
3
6
 
4
- def check_time_unique_increasing(datetimes) -> None:
5
- """Check that the time dimension is unique and increasing"""
7
+ def check_time_unique_increasing(datetimes: xr.DataArray) -> None:
8
+ """Check that the time dimension is unique and increasing."""
6
9
  time = pd.DatetimeIndex(datetimes)
7
- assert time.is_unique
8
- assert time.is_monotonic_increasing
10
+ if not (time.is_unique and time.is_monotonic_increasing):
11
+ raise ValueError("Time dimension must be unique and monotonically increasing")
9
12
 
10
13
 
11
14
  def make_spatial_coords_increasing(ds: xr.Dataset, x_coord: str, y_coord: str) -> xr.Dataset:
12
- """Make sure the spatial coordinates are in increasing order
13
-
15
+ """Make sure the spatial coordinates are in increasing order.
16
+
14
17
  Args:
15
18
  ds: Xarray Dataset
16
19
  x_coord: Name of the x coordinate
17
20
  y_coord: Name of the y coordinate
18
21
  """
19
-
20
22
  # Make sure the coords are in increasing order
21
23
  if ds[x_coord][0] > ds[x_coord][-1]:
22
- ds = ds.isel({x_coord:slice(None, None, -1)})
24
+ ds = ds.isel({x_coord: slice(None, None, -1)})
23
25
  if ds[y_coord][0] > ds[y_coord][-1]:
24
- ds = ds.isel({y_coord:slice(None, None, -1)})
26
+ ds = ds.isel({y_coord: slice(None, None, -1)})
25
27
 
26
28
  # Check the coords are all increasing now
27
- assert (ds[x_coord].diff(dim=x_coord) > 0).all()
28
- assert (ds[y_coord].diff(dim=y_coord) > 0).all()
29
+ if not (ds[x_coord].diff(dim=x_coord) > 0).all():
30
+ raise ValueError(f"'{x_coord}' coordinate must be increasing")
31
+ if not (ds[y_coord].diff(dim=y_coord) > 0).all():
32
+ raise ValueError(f"'{y_coord}' coordinate must be increasing")
29
33
 
30
34
  return ds
31
35
 
32
36
 
33
37
  def get_xr_data_array_from_xr_dataset(ds: xr.Dataset) -> xr.DataArray:
34
- """Return underlying xr.DataArray from passed xr.Dataset.
38
+ """Return underlying xr.DataArray from passed xr.Dataset.
39
+
35
40
  Checks only one variable is present and returns it as an xr.DataArray.
36
41
 
37
42
  Args:
38
43
  ds: xr.Dataset to extract xr.DataArray from
39
44
  """
40
-
41
45
  datavars = list(ds.var())
42
- assert len(datavars) == 1, "Cannot open as xr.DataArray: dataset contains multiple variables"
43
- return ds[datavars[0]]
46
+ if len(datavars) != 1:
47
+ raise ValueError("Cannot open as xr.DataArray: dataset contains multiple variables")
48
+ return ds[datavars[0]]
@@ -1,8 +1,10 @@
1
+ """Functions for collating samples into batches."""
2
+
1
3
  import numpy as np
2
4
 
3
5
 
4
6
  def stack_np_samples_into_batch(dict_list: list[dict]) -> dict:
5
- """Stacks list of dict samples into a dict where all samples are joined along a new axis
7
+ """Stacks list of dict samples into a dict where all samples are joined along a new axis.
6
8
 
7
9
  Args:
8
10
  dict_list: A list of dict-like samples to stack
@@ -10,7 +12,6 @@ def stack_np_samples_into_batch(dict_list: list[dict]) -> dict:
10
12
  Returns:
11
13
  Dict of the samples stacked with new batch dimension on axis 0
12
14
  """
13
-
14
15
  batch = {}
15
16
 
16
17
  keys = list(dict_list[0].keys())
@@ -26,10 +27,10 @@ def stack_np_samples_into_batch(dict_list: list[dict]) -> dict:
26
27
  for nwp_provider in nwp_providers:
27
28
  # Keys can be different for different NWPs
28
29
  nwp_keys = list(dict_list[0]["nwp"][nwp_provider].keys())
29
-
30
+
30
31
  # Create dict to store NWP batch for this provider
31
32
  nwp_provider_batch = {}
32
-
33
+
33
34
  for nwp_key in nwp_keys:
34
35
  # Stack values under each NWP key for this provider
35
36
  nwp_provider_batch[nwp_key] = stack_data_list(
@@ -46,16 +47,16 @@ def stack_np_samples_into_batch(dict_list: list[dict]) -> dict:
46
47
 
47
48
 
48
49
  def _key_is_constant(key: str) -> bool:
49
- """Check if a key is for value which is constant for all samples"""
50
+ """Check if a key is for value which is constant for all samples."""
50
51
  return key.endswith("t0_idx") or key.endswith("channel_names")
51
52
 
52
53
 
53
54
  def stack_data_list(data_list: list, key: str) -> np.ndarray:
54
- """Stack a sequence of data elements along a new axis
55
+ """Stack a sequence of data elements along a new axis.
55
56
 
56
- Args:
57
- data_list: List of data elements to combine
58
- key: string identifying the data type
57
+ Args:
58
+ data_list: List of data elements to combine
59
+ key: string identifying the data type
59
60
  """
60
61
  if _key_is_constant(key):
61
62
  return data_list[0]
@@ -1,11 +1,11 @@
1
- """Functions to create trigonometric date and time inputs"""
1
+ """Functions to create trigonometric date and time inputs."""
2
2
 
3
3
  import numpy as np
4
4
  import pandas as pd
5
5
 
6
6
 
7
7
  def _get_date_time_in_pi(dt: pd.DatetimeIndex) -> tuple[np.ndarray, np.ndarray]:
8
- """Create positional embeddings for the datetimes in radians
8
+ """Create positional embeddings for the datetimes in radians.
9
9
 
10
10
  Args:
11
11
  dt: DatetimeIndex to create radian embeddings for
@@ -13,7 +13,6 @@ def _get_date_time_in_pi(dt: pd.DatetimeIndex) -> tuple[np.ndarray, np.ndarray]:
13
13
  Returns:
14
14
  Tuple of numpy arrays containing radian coordinates for date and time
15
15
  """
16
-
17
16
  day_of_year = dt.dayofyear
18
17
  minute_of_day = dt.minute + dt.hour * 60
19
18
 
@@ -24,8 +23,7 @@ def _get_date_time_in_pi(dt: pd.DatetimeIndex) -> tuple[np.ndarray, np.ndarray]:
24
23
 
25
24
 
26
25
  def make_datetime_numpy_dict(datetimes: pd.DatetimeIndex, key_prefix: str = "wind") -> dict:
27
- """ Creates dictionary of cyclical datetime features - encoded """
28
-
26
+ """Creates dictionary of cyclical datetime features - encoded."""
29
27
  date_in_pi, time_in_pi = _get_date_time_in_pi(datetimes)
30
28
 
31
29
  time_numpy_sample = {}
@@ -1,30 +1,28 @@
1
- """Convert GSP to Numpy Sample"""
1
+ """Convert GSP to Numpy Sample."""
2
2
 
3
3
  import xarray as xr
4
4
 
5
5
 
6
6
  class GSPSampleKey:
7
+ """Keys for the GSP sample dictionary."""
7
8
 
8
- gsp = 'gsp'
9
- nominal_capacity_mwp = 'gsp_nominal_capacity_mwp'
10
- effective_capacity_mwp = 'gsp_effective_capacity_mwp'
11
- time_utc = 'gsp_time_utc'
12
- t0_idx = 'gsp_t0_idx'
13
- solar_azimuth = 'gsp_solar_azimuth'
14
- solar_elevation = 'gsp_solar_elevation'
15
- gsp_id = 'gsp_id'
16
- x_osgb = 'gsp_x_osgb'
17
- y_osgb = 'gsp_y_osgb'
9
+ gsp = "gsp"
10
+ nominal_capacity_mwp = "gsp_nominal_capacity_mwp"
11
+ effective_capacity_mwp = "gsp_effective_capacity_mwp"
12
+ time_utc = "gsp_time_utc"
13
+ t0_idx = "gsp_t0_idx"
14
+ gsp_id = "gsp_id"
15
+ x_osgb = "gsp_x_osgb"
16
+ y_osgb = "gsp_y_osgb"
18
17
 
19
18
 
20
19
  def convert_gsp_to_numpy_sample(da: xr.DataArray, t0_idx: int | None = None) -> dict:
21
- """Convert from Xarray to NumpySample
22
-
20
+ """Convert from Xarray to NumpySample.
21
+
23
22
  Args:
24
23
  da: Xarray DataArray containing GSP data
25
24
  t0_idx: Index of the t0 timestamp in the time dimension of the GSP data
26
25
  """
27
-
28
26
  sample = {
29
27
  GSPSampleKey.gsp: da.values,
30
28
  GSPSampleKey.nominal_capacity_mwp: da.isel(time_utc=0)["nominal_capacity_mwp"].values,
@@ -1,27 +1,27 @@
1
- """Convert NWP to NumpySample"""
1
+ """Convert NWP to NumpySample."""
2
2
 
3
3
  import pandas as pd
4
4
  import xarray as xr
5
5
 
6
6
 
7
7
  class NWPSampleKey:
8
+ """Keys for NWP NumpySample."""
8
9
 
9
- nwp = 'nwp'
10
- channel_names = 'nwp_channel_names'
11
- init_time_utc = 'nwp_init_time_utc'
12
- step = 'nwp_step'
13
- target_time_utc = 'nwp_target_time_utc'
14
- t0_idx = 'nwp_t0_idx'
10
+ nwp = "nwp"
11
+ channel_names = "nwp_channel_names"
12
+ init_time_utc = "nwp_init_time_utc"
13
+ step = "nwp_step"
14
+ target_time_utc = "nwp_target_time_utc"
15
+ t0_idx = "nwp_t0_idx"
15
16
 
16
17
 
17
18
  def convert_nwp_to_numpy_sample(da: xr.DataArray, t0_idx: int | None = None) -> dict:
18
- """Convert from Xarray to NWP NumpySample
19
-
19
+ """Convert from Xarray to NWP NumpySample.
20
+
20
21
  Args:
21
22
  da: Xarray DataArray containing NWP data
22
23
  t0_idx: Index of the t0 timestamp in the time dimension of the NWP
23
24
  """
24
-
25
25
  sample = {
26
26
  NWPSampleKey.nwp: da.values,
27
27
  NWPSampleKey.channel_names: da.channel.values,
@@ -32,5 +32,5 @@ def convert_nwp_to_numpy_sample(da: xr.DataArray, t0_idx: int | None = None) ->
32
32
 
33
33
  if t0_idx is not None:
34
34
  sample[NWPSampleKey.t0_idx] = t0_idx
35
-
36
- return sample
35
+
36
+ return sample
@@ -1,20 +1,21 @@
1
- """Convert Satellite to NumpySample"""
1
+ """Convert Satellite to NumpySample."""
2
2
 
3
3
  import xarray as xr
4
4
 
5
5
 
6
6
  class SatelliteSampleKey:
7
+ """Keys for the SatelliteSample dictionary."""
7
8
 
8
- satellite_actual = 'satellite_actual'
9
- time_utc = 'satellite_time_utc'
10
- x_geostationary = 'satellite_x_geostationary'
11
- y_geostationary = 'satellite_y_geostationary'
12
- t0_idx = 'satellite_t0_idx'
9
+ satellite_actual = "satellite_actual"
10
+ time_utc = "satellite_time_utc"
11
+ x_geostationary = "satellite_x_geostationary"
12
+ y_geostationary = "satellite_y_geostationary"
13
+ t0_idx = "satellite_t0_idx"
13
14
 
14
15
 
15
16
  def convert_satellite_to_numpy_sample(da: xr.DataArray, t0_idx: int | None = None) -> dict:
16
- """Convert from Xarray to NumpySample
17
-
17
+ """Convert from Xarray to NumpySample.
18
+
18
19
  Args:
19
20
  da: xarray DataArray containing satellite data
20
21
  t0_idx: Index of the t0 timestamp in the time dimension of the satellite data
@@ -30,4 +31,3 @@ def convert_satellite_to_numpy_sample(da: xr.DataArray, t0_idx: int | None = Non
30
31
  sample[SatelliteSampleKey.t0_idx] = t0_idx
31
32
 
32
33
  return sample
33
-
@@ -1,37 +1,34 @@
1
- """Convert site to Numpy Sample"""
1
+ """Convert site to Numpy Sample."""
2
2
 
3
3
  import xarray as xr
4
4
 
5
5
 
6
6
  class SiteSampleKey:
7
+ """Keys for the site sample dictionary."""
7
8
 
8
9
  generation = "site"
9
10
  capacity_kwp = "site_capacity_kwp"
10
11
  time_utc = "site_time_utc"
11
12
  t0_idx = "site_t0_idx"
12
13
  id = "site_id"
13
- solar_azimuth = "site_solar_azimuth"
14
- solar_elevation = "site_solar_elevation"
15
14
  date_sin = "site_date_sin"
16
15
  date_cos = "site_date_cos"
17
16
  time_sin = "site_time_sin"
18
17
  time_cos = "site_time_cos"
19
18
 
19
+
20
20
  def convert_site_to_numpy_sample(da: xr.DataArray, t0_idx: int | None = None) -> dict:
21
- """Convert from Xarray to NumpySample
22
-
21
+ """Convert from Xarray to NumpySample.
22
+
23
23
  Args:
24
24
  da: xarray DataArray containing site data
25
25
  t0_idx: Index of the t0 timestamp in the time dimension of the site data
26
26
  """
27
-
28
27
  sample = {
29
28
  SiteSampleKey.generation: da.values,
30
29
  SiteSampleKey.capacity_kwp: da.isel(time_utc=0)["capacity_kwp"].values,
31
30
  SiteSampleKey.time_utc: da["time_utc"].values.astype(float),
32
31
  SiteSampleKey.id: da["site_id"].values,
33
- SiteSampleKey.solar_azimuth: da["solar_azimuth"].values,
34
- SiteSampleKey.solar_elevation: da["solar_elevation"].values,
35
32
  SiteSampleKey.date_sin: da["date_sin"].values,
36
33
  SiteSampleKey.date_cos: da["date_cos"].values,
37
34
  SiteSampleKey.time_sin: da["time_sin"].values,
@@ -1,16 +1,17 @@
1
+ """Module for calculating solar position."""
1
2
 
2
- import pvlib
3
3
  import numpy as np
4
4
  import pandas as pd
5
+ import pvlib
5
6
 
6
7
 
7
8
  def calculate_azimuth_and_elevation(
8
- datetimes: pd.DatetimeIndex,
9
- lon: float,
10
- lat: float
9
+ datetimes: pd.DatetimeIndex,
10
+ lon: float,
11
+ lat: float,
11
12
  ) -> tuple[np.ndarray, np.ndarray]:
12
- """Calculate the solar coordinates for multiple datetimes at a single location
13
-
13
+ """Calculate the solar coordinates for multiple datetimes at a single location.
14
+
14
15
  Args:
15
16
  datetimes: The datetimes to calculate for
16
17
  lon: The longitude
@@ -20,45 +21,39 @@ def calculate_azimuth_and_elevation(
20
21
  np.ndarray: The azimuth of the datetimes in degrees
21
22
  np.ndarray: The elevation of the datetimes in degrees
22
23
  """
23
-
24
24
  solpos = pvlib.solarposition.get_solarposition(
25
25
  time=datetimes,
26
26
  longitude=lon,
27
27
  latitude=lat,
28
- method='nrel_numpy'
28
+ method="nrel_numpy",
29
29
  )
30
30
 
31
31
  return solpos["azimuth"].values, solpos["elevation"].values
32
32
 
33
33
 
34
34
  def make_sun_position_numpy_sample(
35
- datetimes: pd.DatetimeIndex,
36
- lon: float,
37
- lat: float,
38
- key_prefix: str = "gsp"
35
+ datetimes: pd.DatetimeIndex,
36
+ lon: float,
37
+ lat: float,
39
38
  ) -> dict:
40
- """Creates NumpySample with standardized solar coordinates
39
+ """Creates NumpySample with standardized solar coordinates.
41
40
 
42
41
  Args:
43
42
  datetimes: The datetimes to calculate solar angles for
44
43
  lon: The longitude
45
44
  lat: The latitude
46
- key_prefix: The prefix to add to the keys in the NumpySample
47
45
  """
48
-
49
46
  azimuth, elevation = calculate_azimuth_and_elevation(datetimes, lon, lat)
50
47
 
51
48
  # Normalise
52
49
  # Azimuth is in range [0, 360] degrees
53
50
  azimuth = azimuth / 360
54
51
 
55
- # Elevation is in range [-90, 90] degrees
52
+ # Elevation is in range [-90, 90] degrees
56
53
  elevation = elevation / 180 + 0.5
57
54
 
58
55
  # Make NumpySample
59
- sun_numpy_sample = {
60
- key_prefix + "_solar_azimuth": azimuth,
61
- key_prefix + "_solar_elevation": elevation,
56
+ return {
57
+ "solar_azimuth": azimuth,
58
+ "solar_elevation": elevation,
62
59
  }
63
-
64
- return sun_numpy_sample
@@ -1,11 +1,10 @@
1
- """ Base class for handling flat/nested data structures with NWP consideration """
1
+ """Base class for handling flat/nested data structures with NWP consideration."""
2
2
 
3
- import numpy as np
4
- import torch
5
-
6
- from typing import TypeAlias
7
3
  from abc import ABC, abstractmethod
4
+ from typing import TypeAlias
8
5
 
6
+ import numpy as np
7
+ import torch
9
8
 
10
9
  NumpySample: TypeAlias = dict[str, np.ndarray | dict[str, np.ndarray]]
11
10
  NumpyBatch: TypeAlias = dict[str, np.ndarray | dict[str, np.ndarray]]
@@ -13,39 +12,38 @@ TensorBatch: TypeAlias = dict[str, torch.Tensor | dict[str, torch.Tensor]]
13
12
 
14
13
 
15
14
  class SampleBase(ABC):
16
- """
17
- Abstract base class for all sample types
18
- Provides core data storage functionality
19
- """
15
+ """Abstract base class for all sample types."""
20
16
 
21
17
  @abstractmethod
22
18
  def to_numpy(self) -> NumpySample:
23
- """Convert sample data to numpy format"""
19
+ """Convert sample data to numpy format."""
24
20
  raise NotImplementedError
25
21
 
26
22
  @abstractmethod
27
23
  def plot(self) -> None:
24
+ """Create a visualisation of the data."""
28
25
  raise NotImplementedError
29
26
 
30
27
  @abstractmethod
31
28
  def save(self, path: str) -> None:
29
+ """Saves the sample to disk in the implementations' required format."""
32
30
  raise NotImplementedError
33
31
 
34
32
  @classmethod
35
33
  @abstractmethod
36
- def load(cls, path: str) -> 'SampleBase':
34
+ def load(cls, path: str) -> "SampleBase":
35
+ """Load a sample from disk from the implementations' format."""
37
36
  raise NotImplementedError
38
37
 
39
38
 
40
39
  def batch_to_tensor(batch: NumpyBatch) -> TensorBatch:
41
- """
42
- Recursively converts numpy arrays in nested dict to torch tensors
40
+ """Recursively converts numpy arrays in nested dict to torch tensors.
41
+
43
42
  Args:
44
43
  batch: NumpyBatch with data in numpy arrays
45
44
  Returns:
46
45
  TensorBatch with data in torch tensors
47
46
  """
48
-
49
47
  for k, v in batch.items():
50
48
  if isinstance(v, dict):
51
49
  batch[k] = batch_to_tensor(v)
@@ -58,12 +56,12 @@ def batch_to_tensor(batch: NumpyBatch) -> TensorBatch:
58
56
 
59
57
 
60
58
  def copy_batch_to_device(batch: TensorBatch, device: torch.device) -> TensorBatch:
61
- """Recursively copies tensors in nested dict to specified device
59
+ """Recursively copies tensors in nested dict to specified device.
62
60
 
63
61
  Args:
64
62
  batch: Nested dict with tensors to move
65
63
  device: Device to move tensors to
66
-
64
+
67
65
  Returns:
68
66
  A dict with tensors moved to the new device
69
67
  """
@@ -71,7 +69,7 @@ def copy_batch_to_device(batch: TensorBatch, device: torch.device) -> TensorBatc
71
69
 
72
70
  for k, v in batch.items():
73
71
  if isinstance(v, dict):
74
- batch_copy[k] = copy_batch_to_device(v, device)
72
+ batch_copy[k] = copy_batch_to_device(v, device)
75
73
  elif isinstance(v, torch.Tensor):
76
74
  batch_copy[k] = v.to(device)
77
75
  else:
@@ -1,44 +1,37 @@
1
- """PVNet Site sample implementation for netCDF data handling and conversion"""
1
+ """PVNet Site sample implementation for netCDF data handling and conversion."""
2
2
 
3
3
  import xarray as xr
4
-
5
4
  from typing_extensions import override
6
5
 
7
- from ocf_data_sampler.sample.base import SampleBase, NumpySample
6
+ from ocf_data_sampler.sample.base import NumpySample, SampleBase
8
7
  from ocf_data_sampler.torch_datasets.datasets.site import convert_netcdf_to_numpy_sample
9
8
 
10
9
 
11
10
  class SiteSample(SampleBase):
12
- """Handles PVNet site specific netCDF operations"""
11
+ """Handles PVNet site specific netCDF operations."""
13
12
 
14
- def __init__(self, data: xr.Dataset):
15
-
13
+ def __init__(self, data: xr.Dataset) -> None:
14
+ """Initializes the SiteSample object with the given xarray Dataset."""
16
15
  if not isinstance(data, xr.Dataset):
17
16
  raise TypeError(f"Data must be xarray Dataset - Found type {type(data)}")
18
-
19
17
  self._data = data
20
18
 
21
19
  @override
22
- def to_numpy(self) -> NumpySample:
20
+ def to_numpy(self) -> NumpySample:
23
21
  return convert_netcdf_to_numpy_sample(self._data)
24
22
 
23
+ @override
25
24
  def save(self, path: str) -> None:
26
- """Save site sample data as netCDF
27
-
28
- Args:
29
- path: Path to save the netCDF file
30
- """
25
+ # Saves as NetCDF
31
26
  self._data.to_netcdf(path, mode="w", engine="h5netcdf")
32
27
 
33
28
  @classmethod
34
- def load(cls, path: str) -> 'SiteSample':
35
- """Load site sample data from netCDF
36
-
37
- Args:
38
- path: Path to load the netCDF file from
39
- """
29
+ @override
30
+ def load(cls, path: str) -> "SiteSample":
31
+ # Loads from NetCDF
40
32
  return cls(xr.open_dataset(path))
41
33
 
42
- # TODO - placeholder for now
34
+ @override
43
35
  def plot(self) -> None:
36
+ # TODO - placeholder for now
44
37
  raise NotImplementedError("Plotting not yet implemented for SiteSample")