ocf-data-sampler 0.1.9__tar.gz → 0.1.11__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of ocf-data-sampler might be problematic.
Files changed (93)
  1. {ocf_data_sampler-0.1.9/ocf_data_sampler.egg-info → ocf_data_sampler-0.1.11}/PKG-INFO +1 -1
  2. {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler/config/model.py +25 -23
  3. {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler/load/satellite.py +21 -29
  4. {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler/load/site.py +1 -1
  5. {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler/numpy_sample/gsp.py +6 -2
  6. {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler/numpy_sample/nwp.py +7 -13
  7. {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler/numpy_sample/satellite.py +11 -8
  8. {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler/numpy_sample/site.py +6 -2
  9. {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler/numpy_sample/sun_position.py +9 -10
  10. {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler/sample/__init__.py +0 -7
  11. ocf_data_sampler-0.1.11/ocf_data_sampler/sample/base.py +79 -0
  12. ocf_data_sampler-0.1.11/ocf_data_sampler/sample/site.py +44 -0
  13. ocf_data_sampler-0.1.11/ocf_data_sampler/sample/uk_regional.py +75 -0
  14. ocf_data_sampler-0.1.11/ocf_data_sampler/select/dropout.py +52 -0
  15. {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler/select/fill_time_periods.py +3 -1
  16. {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler/select/find_contiguous_time_periods.py +0 -1
  17. {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py +2 -3
  18. {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler/torch_datasets/datasets/site.py +9 -5
  19. {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11/ocf_data_sampler.egg-info}/PKG-INFO +1 -1
  20. {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/pyproject.toml +1 -1
  21. {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/tests/config/test_config.py +3 -3
  22. {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/tests/conftest.py +33 -0
  23. ocf_data_sampler-0.1.11/tests/numpy_sample/test_nwp.py +13 -0
  24. {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/tests/select/test_dropout.py +7 -13
  25. {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/tests/test_sample/test_site_sample.py +5 -35
  26. {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/tests/test_sample/test_uk_regional_sample.py +8 -35
  27. {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/tests/torch_datasets/test_pvnet_uk.py +6 -19
  28. ocf_data_sampler-0.1.9/ocf_data_sampler/sample/base.py +0 -98
  29. ocf_data_sampler-0.1.9/ocf_data_sampler/sample/site.py +0 -81
  30. ocf_data_sampler-0.1.9/ocf_data_sampler/sample/uk_regional.py +0 -120
  31. ocf_data_sampler-0.1.9/ocf_data_sampler/select/dropout.py +0 -39
  32. ocf_data_sampler-0.1.9/tests/numpy_sample/test_nwp.py +0 -52
  33. {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/LICENSE +0 -0
  34. {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/MANIFEST.in +0 -0
  35. {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/README.md +0 -0
  36. {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler/__init__.py +0 -0
  37. {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler/config/__init__.py +0 -0
  38. {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler/config/load.py +0 -0
  39. {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler/config/save.py +0 -0
  40. {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler/constants.py +0 -0
  41. {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler/data/uk_gsp_locations.csv +0 -0
  42. {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler/load/__init__.py +0 -0
  43. {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler/load/gsp.py +0 -0
  44. {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler/load/load_dataset.py +0 -0
  45. {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler/load/nwp/__init__.py +0 -0
  46. {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler/load/nwp/nwp.py +0 -0
  47. {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler/load/nwp/providers/__init__.py +0 -0
  48. {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler/load/nwp/providers/ecmwf.py +0 -0
  49. {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler/load/nwp/providers/ukv.py +0 -0
  50. {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler/load/nwp/providers/utils.py +0 -0
  51. {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler/load/utils.py +0 -0
  52. {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler/numpy_sample/__init__.py +0 -0
  53. {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler/numpy_sample/collate.py +0 -0
  54. {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler/numpy_sample/datetime_features.py +0 -0
  55. {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler/select/__init__.py +0 -0
  56. {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler/select/geospatial.py +0 -0
  57. {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler/select/location.py +0 -0
  58. {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler/select/select_spatial_slice.py +0 -0
  59. {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler/select/select_time_slice.py +0 -0
  60. {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler/select/spatial_slice_for_dataset.py +0 -0
  61. {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler/select/time_slice_for_dataset.py +0 -0
  62. {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler/torch_datasets/datasets/__init__.py +0 -0
  63. {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler/torch_datasets/utils/merge_and_fill_utils.py +0 -0
  64. {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler/torch_datasets/utils/valid_time_periods.py +0 -0
  65. {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler/torch_datasets/utils/validate_channels.py +0 -0
  66. {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler/utils.py +0 -0
  67. {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler.egg-info/SOURCES.txt +0 -0
  68. {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler.egg-info/dependency_links.txt +0 -0
  69. {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler.egg-info/requires.txt +0 -0
  70. {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler.egg-info/top_level.txt +0 -0
  71. {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/scripts/refactor_site.py +0 -0
  72. {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/setup.cfg +0 -0
  73. {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/tests/__init__.py +0 -0
  74. {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/tests/config/test_load.py +0 -0
  75. {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/tests/config/test_save.py +0 -0
  76. {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/tests/load/test_load_gsp.py +0 -0
  77. {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/tests/load/test_load_nwp.py +0 -0
  78. {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/tests/load/test_load_satellite.py +0 -0
  79. {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/tests/load/test_load_sites.py +0 -0
  80. {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/tests/numpy_sample/test_collate.py +0 -0
  81. {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/tests/numpy_sample/test_datetime_features.py +0 -0
  82. {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/tests/numpy_sample/test_gsp.py +0 -0
  83. {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/tests/numpy_sample/test_satellite.py +0 -0
  84. {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/tests/numpy_sample/test_sun_position.py +0 -0
  85. {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/tests/select/test_fill_time_periods.py +0 -0
  86. {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/tests/select/test_find_contiguous_time_periods.py +0 -0
  87. {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/tests/select/test_location.py +0 -0
  88. {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/tests/select/test_select_spatial_slice.py +0 -0
  89. {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/tests/select/test_select_time_slice.py +0 -0
  90. {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/tests/test_sample/test_base.py +0 -0
  91. {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/tests/torch_datasets/test_merge_and_fill_utils.py +0 -0
  92. {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/tests/torch_datasets/test_site.py +0 -0
  93. {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/tests/torch_datasets/test_validate_channels_utils.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.2
 Name: ocf_data_sampler
-Version: 0.1.9
+Version: 0.1.11
 Summary: Sample from weather data for renewable energy prediction
 Author: James Fulton, Peter Dudfield, and the Open Climate Fix team
 Author-email: info@openclimatefix.org
ocf_data_sampler/config/model.py
@@ -49,31 +49,34 @@ class TimeWindowMixin(Base):
         ...,
         description="Data interval ends at `t0 + interval_end_minutes`",
     )
-
+
     @model_validator(mode='after')
-    def check_interval_range(cls, values):
-        if values.interval_start_minutes > values.interval_end_minutes:
-            raise ValueError('interval_start_minutes must be <= interval_end_minutes')
+    def validate_intervals(cls, values):
+        start = values.interval_start_minutes
+        end = values.interval_end_minutes
+        resolution = values.time_resolution_minutes
+        if start > end:
+            raise ValueError(
+                f"interval_start_minutes ({start}) must be <= interval_end_minutes ({end})"
+            )
+        if (start % resolution != 0):
+            raise ValueError(
+                f"interval_start_minutes ({start}) must be divisible "
+                f"by time_resolution_minutes ({resolution})"
+            )
+        if (end % resolution != 0):
+            raise ValueError(
+                f"interval_end_minutes ({end}) must be divisible "
+                f"by time_resolution_minutes ({resolution})"
+            )
         return values
 
-    @field_validator("interval_start_minutes")
-    def interval_start_minutes_divide_by_time_resolution(cls, v: int, info: ValidationInfo) -> int:
-        if v % info.data["time_resolution_minutes"] != 0:
-            raise ValueError("interval_start_minutes must be divisible by time_resolution_minutes")
-        return v
-
-    @field_validator("interval_end_minutes")
-    def interval_end_minutes_divide_by_time_resolution(cls, v: int, info: ValidationInfo) -> int:
-        if v % info.data["time_resolution_minutes"] != 0:
-            raise ValueError("interval_end_minutes must be divisible by time_resolution_minutes")
-        return v
-
 
 class DropoutMixin(Base):
     """Mixin class, to add dropout minutes"""
 
-    dropout_timedeltas_minutes: Optional[List[int]] = Field(
-        default=None,
+    dropout_timedeltas_minutes: List[int] = Field(
+        default=[],
         description="List of possible minutes before t0 where data availability may start. Must be "
         "negative or zero.",
     )
@@ -88,18 +91,17 @@ class DropoutMixin(Base):
     @field_validator("dropout_timedeltas_minutes")
     def dropout_timedeltas_minutes_negative(cls, v: List[int]) -> List[int]:
         """Validate 'dropout_timedeltas_minutes'"""
-        if v is not None:
-            for m in v:
-                assert m <= 0, "Dropout timedeltas must be negative"
+        for m in v:
+            assert m <= 0, "Dropout timedeltas must be negative"
         return v
 
     @model_validator(mode="after")
     def dropout_instructions_consistent(self) -> Self:
         if self.dropout_fraction == 0:
-            if self.dropout_timedeltas_minutes is not None:
+            if self.dropout_timedeltas_minutes != []:
                 raise ValueError("To use dropout timedeltas dropout fraction should be > 0")
         else:
-            if self.dropout_timedeltas_minutes is None:
+            if self.dropout_timedeltas_minutes == []:
                 raise ValueError("To dropout fraction > 0 requires a list of dropout timedeltas")
         return self
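Note: the range check and the two divisibility checks are now consolidated into a single `validate_intervals` model validator, so all interval errors are raised from one place after model construction. A minimal sketch of the pattern on a standalone Pydantic model (the `TimeWindow` class and the test values below are hypothetical; field names mirror the diff):

    from pydantic import BaseModel, model_validator

    class TimeWindow(BaseModel):
        # Stand-in for TimeWindowMixin, for illustration only
        time_resolution_minutes: int
        interval_start_minutes: int
        interval_end_minutes: int

        @model_validator(mode="after")
        def validate_intervals(cls, values):
            start = values.interval_start_minutes
            end = values.interval_end_minutes
            resolution = values.time_resolution_minutes
            if start > end:
                raise ValueError(f"interval_start_minutes ({start}) must be <= interval_end_minutes ({end})")
            if start % resolution != 0:
                raise ValueError(f"interval_start_minutes ({start}) must be divisible by time_resolution_minutes ({resolution})")
            if end % resolution != 0:
                raise ValueError(f"interval_end_minutes ({end}) must be divisible by time_resolution_minutes ({resolution})")
            return values

    # Raises a ValidationError wrapping:
    # interval_start_minutes (-35) must be divisible by time_resolution_minutes (30)
    TimeWindow(time_resolution_minutes=30, interval_start_minutes=-35, interval_end_minutes=60)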
 
ocf_data_sampler/load/satellite.py
@@ -1,7 +1,5 @@
 """Satellite loader"""
 
-import subprocess
-
 import xarray as xr
 from ocf_data_sampler.load.utils import (
     check_time_unique_increasing,
@@ -10,63 +8,59 @@ from ocf_data_sampler.load.utils import (
 )
 
 
-def _get_single_sat_data(zarr_path: str) -> xr.Dataset:
-    """Helper function to open a Zarr from either a local or GCP path.
+def get_single_sat_data(zarr_path: str) -> xr.Dataset:
+    """Helper function to open a zarr from either a local or GCP path
 
     Args:
-        zarr_path: Path to a Zarr file. Wildcards (*) are supported **only** for local paths.
-            GCS paths (gs://) **do not support** wildcards.
+        zarr_path: path to a zarr file. Wildcards (*) are supported only for local paths
+            GCS paths (gs://) do not support wildcards
 
     Returns:
-        An xarray Dataset containing satellite data.
+        An xarray Dataset containing satellite data
 
     Raises:
-        ValueError: If a wildcard (*) is used in a GCS (gs://) path.
+        ValueError: If a wildcard (*) is used in a GCS (gs://) path
     """
 
-    # These kwargs are used if the path contains "*"
-    openmf_kwargs = dict(
-        engine="zarr",
-        concat_dim="time",
-        combine="nested",
-        chunks="auto",
-        join="override",
-    )
-
     # Raise an error if a wildcard is used in a GCP path
     if "gs://" in str(zarr_path) and "*" in str(zarr_path):
-        raise ValueError("Wildcard (*) paths are not supported for GCP (gs://) URLs.")
+        raise ValueError("Wildcard (*) paths are not supported for GCP (gs://) URLs")
 
     # Handle multi-file dataset for local paths
     if "*" in str(zarr_path):
-        ds = xr.open_mfdataset(zarr_path, **openmf_kwargs)
+        ds = xr.open_mfdataset(
+            zarr_path,
+            engine="zarr",
+            concat_dim="time",
+            combine="nested",
+            chunks="auto",
+            join="override",
+        )
+        check_time_unique_increasing(ds.time)
     else:
        ds = xr.open_dataset(zarr_path, engine="zarr", chunks="auto")
 
-    # Ensure time is unique and sorted
-    ds = ds.drop_duplicates("time").sortby("time")
-
     return ds
 
 
 def open_sat_data(zarr_path: str | list[str]) -> xr.DataArray:
-    """Lazily opens the Zarr store.
+    """Lazily opens the zarr store
 
     Args:
-        zarr_path: Cloud URL or local path pattern, or list of these. If GCS URL, it must start with
-            'gs://'.
+        zarr_path: Cloud URL or local path pattern, or list of these. If GCS URL,
+            it must start with 'gs://'
     """
 
     # Open the data
     if isinstance(zarr_path, (list, tuple)):
         ds = xr.combine_nested(
-            [_get_single_sat_data(path) for path in zarr_path],
+            [get_single_sat_data(path) for path in zarr_path],
             concat_dim="time",
             combine_attrs="override",
             join="override",
         )
     else:
-        ds = _get_single_sat_data(zarr_path)
+        ds = get_single_sat_data(zarr_path)
 
     ds = ds.rename(
         {
@@ -76,9 +70,7 @@ def open_sat_data(zarr_path: str | list[str]) -> xr.DataArray:
     )
 
     check_time_unique_increasing(ds.time_utc)
-
     ds = make_spatial_coords_increasing(ds, x_coord="x_geostationary", y_coord="y_geostationary")
-
     ds = ds.transpose("time_utc", "channel", "x_geostationary", "y_geostationary")
 
     # TODO: should we control the dtype of the DataArray?
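Note: a hedged usage sketch of `open_sat_data` after this change (the paths below are hypothetical):

    from ocf_data_sampler.load.satellite import open_sat_data

    # Local paths may use wildcards; matching stores are concatenated along time,
    # and the time coordinate is checked to be unique and increasing
    da = open_sat_data("/data/satellite/2023_*.zarr")

    # GCS paths must be explicit - a wildcard here raises ValueError
    da = open_sat_data("gs://bucket/satellite/2023_01.zarr")

    # A list of stores is combined along the time dimension
    da = open_sat_data(["/data/sat_2023.zarr", "/data/sat_2024.zarr"])

The behavioural change worth flagging: the old `_get_single_sat_data` silently repaired the time axis with `drop_duplicates("time").sortby("time")`, while the new code only asserts order via `check_time_unique_increasing`, so stores with duplicate or unsorted times now fail instead of being fixed up.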
ocf_data_sampler/load/site.py
@@ -20,7 +20,7 @@ def open_site(generation_file_path: str, metadata_file_path: str) -> xr.DataArra
 
     assert metadata_df.index.is_unique
 
-    # Ensure metadata aligns with the site_id dimension in data_ds
+    # Ensure metadata aligns with the site_id dimension in generation_ds
     metadata_df = metadata_df.reindex(generation_ds.site_id.values)
 
     # Assign coordinates to the Dataset using the aligned metadata
ocf_data_sampler/numpy_sample/gsp.py
@@ -18,9 +18,13 @@ class GSPSampleKey:
 
 
 def convert_gsp_to_numpy_sample(da: xr.DataArray, t0_idx: int | None = None) -> dict:
-    """Convert from Xarray to NumpySample"""
+    """Convert from Xarray to NumpySample
+
+    Args:
+        da: Xarray DataArray containing GSP data
+        t0_idx: Index of the t0 timestamp in the time dimension of the GSP data
+    """
 
-    # Extract values from the DataArray
     sample = {
         GSPSampleKey.gsp: da.values,
         GSPSampleKey.nominal_capacity_mwp: da.isel(time_utc=0)["nominal_capacity_mwp"].values,
ocf_data_sampler/numpy_sample/nwp.py
@@ -12,30 +12,24 @@ class NWPSampleKey:
     step = 'nwp_step'
     target_time_utc = 'nwp_target_time_utc'
     t0_idx = 'nwp_t0_idx'
-    y_osgb = 'nwp_y_osgb'
-    x_osgb = 'nwp_x_osgb'
-
 
 
 def convert_nwp_to_numpy_sample(da: xr.DataArray, t0_idx: int | None = None) -> dict:
-    """Convert from Xarray to NWP NumpySample"""
+    """Convert from Xarray to NWP NumpySample
+
+    Args:
+        da: Xarray DataArray containing NWP data
+        t0_idx: Index of the t0 timestamp in the time dimension of the NWP
+    """
 
-    # Create example and add t if available
     sample = {
         NWPSampleKey.nwp: da.values,
         NWPSampleKey.channel_names: da.channel.values,
         NWPSampleKey.init_time_utc: da.init_time_utc.values.astype(float),
         NWPSampleKey.step: (da.step.values / pd.Timedelta("1h")).astype(int),
+        NWPSampleKey.target_time_utc: da.target_time_utc.values.astype(float),
     }
 
-    if "target_time_utc" in da.coords:
-        sample[NWPSampleKey.target_time_utc] = da.target_time_utc.values.astype(float)
-
-    # TODO: Do we need this at all? Especially since it is only present in UKV data
-    for sample_key, dataset_key in ((NWPSampleKey.y_osgb, "y_osgb"), (NWPSampleKey.x_osgb, "x_osgb"),):
-        if dataset_key in da.coords:
-            sample[sample_key] = da[dataset_key].values
-
     if t0_idx is not None:
         sample[NWPSampleKey.t0_idx] = t0_idx
ocf_data_sampler/numpy_sample/satellite.py
@@ -1,4 +1,5 @@
 """Convert Satellite to NumpySample"""
+
 import xarray as xr
 
 
@@ -12,19 +13,21 @@ class SatelliteSampleKey:
 
 
 def convert_satellite_to_numpy_sample(da: xr.DataArray, t0_idx: int | None = None) -> dict:
-    """Convert from Xarray to NumpySample"""
+    """Convert from Xarray to NumpySample
+
+    Args:
+        da: xarray DataArray containing satellite data
+        t0_idx: Index of the t0 timestamp in the time dimension of the satellite data
+    """
     sample = {
         SatelliteSampleKey.satellite_actual: da.values,
         SatelliteSampleKey.time_utc: da.time_utc.values.astype(float),
+        SatelliteSampleKey.x_geostationary: da.x_geostationary.values,
+        SatelliteSampleKey.y_geostationary: da.y_geostationary.values,
     }
 
-    for sample_key, dataset_key in (
-        (SatelliteSampleKey.x_geostationary, "x_geostationary"),
-        (SatelliteSampleKey.y_geostationary, "y_geostationary"),
-    ):
-        sample[sample_key] = da[dataset_key].values
-
     if t0_idx is not None:
         sample[SatelliteSampleKey.t0_idx] = t0_idx
 
-    return sample
+    return sample
+
ocf_data_sampler/numpy_sample/site.py
@@ -18,9 +18,13 @@ class SiteSampleKey:
     time_cos = "site_time_cos"
 
 def convert_site_to_numpy_sample(da: xr.DataArray, t0_idx: int | None = None) -> dict:
-    """Convert from Xarray to NumpySample"""
+    """Convert from Xarray to NumpySample
+
+    Args:
+        da: xarray DataArray containing site data
+        t0_idx: Index of the t0 timestamp in the time dimension of the site data
+    """
 
-    # Extract values from the DataArray
     sample = {
         SiteSampleKey.generation: da.values,
         SiteSampleKey.capacity_kwp: da.isel(time_utc=0)["capacity_kwp"].values,
ocf_data_sampler/numpy_sample/sun_position.py
@@ -27,16 +27,15 @@ def calculate_azimuth_and_elevation(
         latitude=lat,
         method='nrel_numpy'
     )
-    azimuth = solpos["azimuth"].values
-    elevation = solpos["elevation"].values
-    return azimuth, elevation
+
+    return solpos["azimuth"].values, solpos["elevation"].values
 
 
 def make_sun_position_numpy_sample(
-        datetimes: pd.DatetimeIndex,
-        lon: float,
-        lat: float,
-        key_prefix: str = "gsp"
+    datetimes: pd.DatetimeIndex,
+    lon: float,
+    lat: float,
+    key_prefix: str = "gsp"
 ) -> dict:
     """Creates NumpySample with standardized solar coordinates
 
@@ -44,22 +43,22 @@
         datetimes: The datetimes to calculate solar angles for
         lon: The longitude
         lat: The latitude
+        key_prefix: The prefix to add to the keys in the NumpySample
     """
 
     azimuth, elevation = calculate_azimuth_and_elevation(datetimes, lon, lat)
 
     # Normalise
-
     # Azimuth is in range [0, 360] degrees
     azimuth = azimuth / 360
 
     # Elevation is in range [-90, 90] degrees
     elevation = elevation / 180 + 0.5
-
+
     # Make NumpySample
     sun_numpy_sample = {
         key_prefix + "_solar_azimuth": azimuth,
         key_prefix + "_solar_elevation": elevation,
     }
 
-    return sun_numpy_sample
+    return sun_numpy_sample
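Note: the normalisation maps azimuth from [0, 360] degrees and elevation from [-90, 90] degrees onto [0, 1]. A quick boundary check of the two formulas:

    # azimuth / 360:         0 -> 0.0, 180 -> 0.5, 360 -> 1.0
    # elevation / 180 + 0.5: -90 -> 0.0, 0 -> 0.5, 90 -> 1.0
    assert 180.0 / 360 == 0.5
    assert -90.0 / 180 + 0.5 == 0.0
    assert 90.0 / 180 + 0.5 == 1.0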
ocf_data_sampler/sample/__init__.py
@@ -1,10 +1,3 @@
 from ocf_data_sampler.sample.base import SampleBase
 from ocf_data_sampler.sample.uk_regional import UKRegionalSample
 from ocf_data_sampler.sample.site import SiteSample
-
-
-__all__ = [
-    'SampleBase',
-    'UKRegionalSample',
-    'SiteSample'
-]
ocf_data_sampler/sample/base.py (new file)
@@ -0,0 +1,79 @@
+""" Base class for handling flat/nested data structures with NWP consideration """
+
+import numpy as np
+import torch
+
+from typing import TypeAlias
+from abc import ABC, abstractmethod
+
+
+NumpySample: TypeAlias = dict[str, np.ndarray | dict[str, np.ndarray]]
+NumpyBatch: TypeAlias = dict[str, np.ndarray | dict[str, np.ndarray]]
+TensorBatch: TypeAlias = dict[str, torch.Tensor | dict[str, torch.Tensor]]
+
+
+class SampleBase(ABC):
+    """
+    Abstract base class for all sample types
+    Provides core data storage functionality
+    """
+
+    @abstractmethod
+    def to_numpy(self) -> NumpySample:
+        """Convert sample data to numpy format"""
+        raise NotImplementedError
+
+    @abstractmethod
+    def plot(self) -> None:
+        raise NotImplementedError
+
+    @abstractmethod
+    def save(self, path: str) -> None:
+        raise NotImplementedError
+
+    @classmethod
+    @abstractmethod
+    def load(cls, path: str) -> 'SampleBase':
+        raise NotImplementedError
+
+
+def batch_to_tensor(batch: NumpyBatch) -> TensorBatch:
+    """
+    Recursively converts numpy arrays in nested dict to torch tensors
+    Args:
+        batch: NumpyBatch with data in numpy arrays
+    Returns:
+        TensorBatch with data in torch tensors
+    """
+
+    for k, v in batch.items():
+        if isinstance(v, dict):
+            batch[k] = batch_to_tensor(v)
+        elif isinstance(v, np.ndarray):
+            if v.dtype == np.bool_:
+                batch[k] = torch.tensor(v, dtype=torch.bool)
+            elif np.issubdtype(v.dtype, np.number):
+                batch[k] = torch.as_tensor(v)
+    return batch
+
+
+def copy_batch_to_device(batch: TensorBatch, device: torch.device) -> TensorBatch:
+    """Recursively copies tensors in nested dict to specified device
+
+    Args:
+        batch: Nested dict with tensors to move
+        device: Device to move tensors to
+
+    Returns:
+        A dict with tensors moved to the new device
+    """
+    batch_copy = {}
+
+    for k, v in batch.items():
+        if isinstance(v, dict):
+            batch_copy[k] = copy_batch_to_device(v, device)
+        elif isinstance(v, torch.Tensor):
+            batch_copy[k] = v.to(device)
+        else:
+            batch_copy[k] = v
+    return batch_copy
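Note: a short usage sketch of the two new helpers on a nested batch (the keys below are illustrative, not the library's key constants):

    import numpy as np
    import torch

    from ocf_data_sampler.sample.base import batch_to_tensor, copy_batch_to_device

    batch = {
        "gsp": np.random.rand(4, 7),
        "nwp": {"ukv": {"nwp": np.random.rand(4, 2, 8, 8)}},
        "mask": np.array([True, False, True, True]),
    }

    tensor_batch = batch_to_tensor(batch)  # numpy arrays -> torch tensors, nesting preserved
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tensor_batch = copy_batch_to_device(tensor_batch, device)
    print(tensor_batch["mask"].dtype)  # torch.bool

`batch_to_tensor` mutates the dict it is given as well as returning it; `copy_batch_to_device` builds a new dict and passes non-tensor values through unchanged.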
ocf_data_sampler/sample/site.py (new file)
@@ -0,0 +1,44 @@
+"""PVNet Site sample implementation for netCDF data handling and conversion"""
+
+import xarray as xr
+
+from typing_extensions import override
+
+from ocf_data_sampler.sample.base import SampleBase, NumpySample
+from ocf_data_sampler.torch_datasets.datasets.site import convert_netcdf_to_numpy_sample
+
+
+class SiteSample(SampleBase):
+    """Handles PVNet site specific netCDF operations"""
+
+    def __init__(self, data: xr.Dataset):
+
+        if not isinstance(data, xr.Dataset):
+            raise TypeError(f"Data must be xarray Dataset - Found type {type(data)}")
+
+        self._data = data
+
+    @override
+    def to_numpy(self) -> NumpySample:
+        return convert_netcdf_to_numpy_sample(self._data)
+
+    def save(self, path: str) -> None:
+        """Save site sample data as netCDF
+
+        Args:
+            path: Path to save the netCDF file
+        """
+        self._data.to_netcdf(path, mode="w", engine="h5netcdf")
+
+    @classmethod
+    def load(cls, path: str) -> 'SiteSample':
+        """Load site sample data from netCDF
+
+        Args:
+            path: Path to load the netCDF file from
+        """
+        return cls(xr.open_dataset(path))
+
+    # TODO - placeholder for now
+    def plot(self) -> None:
+        raise NotImplementedError("Plotting not yet implemented for SiteSample")
ocf_data_sampler/sample/uk_regional.py (new file)
@@ -0,0 +1,75 @@
+"""PVNet UK Regional sample implementation for dataset handling and visualisation"""
+
+from typing_extensions import override
+
+import torch
+
+from ocf_data_sampler.sample.base import SampleBase, NumpySample
+from ocf_data_sampler.numpy_sample import (
+    NWPSampleKey,
+    GSPSampleKey,
+    SatelliteSampleKey
+)
+
+
+class UKRegionalSample(SampleBase):
+    """Handles UK Regional PVNet data operations"""
+
+    def __init__(self, data: NumpySample):
+        self._data = data
+
+    @override
+    def to_numpy(self) -> NumpySample:
+        return self._data
+
+    def save(self, path: str) -> None:
+        """Save PVNet sample as pickle format using torch.save
+
+        Args:
+            path: Path to save the sample data to
+        """
+        torch.save(self._data, path)
+
+    @classmethod
+    def load(cls, path: str) -> 'UKRegionalSample':
+        """Load PVNet sample data from .pt format
+
+        Args:
+            path: Path to load the sample data from
+        """
+        # TODO: We should move away from using torch.load(..., weights_only=False)
+        return cls(torch.load(path, weights_only=False))
+
+    def plot(self) -> None:
+        """Creates visualisations for NWP, GSP, solar position, and satellite data"""
+        from matplotlib import pyplot as plt
+
+        fig, axes = plt.subplots(2, 2, figsize=(12, 8))
+
+        if NWPSampleKey.nwp in self._data:
+            first_nwp = list(self._data[NWPSampleKey.nwp].values())[0]
+            if 'nwp' in first_nwp:
+                axes[0, 1].imshow(first_nwp['nwp'][0])
+                title = 'NWP (First Channel)'
+                if NWPSampleKey.channel_names in first_nwp:
+                    channel_names = first_nwp[NWPSampleKey.channel_names]
+                    if channel_names:
+                        title = f'NWP: {channel_names[0]}'
+                axes[0, 1].set_title(title)
+
+        if GSPSampleKey.gsp in self._data:
+            axes[0, 0].plot(self._data[GSPSampleKey.gsp])
+            axes[0, 0].set_title('GSP Generation')
+
+        if GSPSampleKey.solar_azimuth in self._data and GSPSampleKey.solar_elevation in self._data:
+            axes[1, 1].plot(self._data[GSPSampleKey.solar_azimuth], label='Azimuth')
+            axes[1, 1].plot(self._data[GSPSampleKey.solar_elevation], label='Elevation')
+            axes[1, 1].set_title('Solar Position')
+            axes[1, 1].legend()
+
+        if SatelliteSampleKey.satellite_actual in self._data:
+            axes[1, 0].imshow(self._data[SatelliteSampleKey.satellite_actual])
+            axes[1, 0].set_title('Satellite Data')
+
+        plt.tight_layout()
+        plt.show()
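Note: a hedged round-trip sketch of the new save/load pair (the "gsp" key and the path are illustrative):

    import numpy as np

    from ocf_data_sampler.sample import UKRegionalSample

    sample = UKRegionalSample({"gsp": np.random.rand(7)})
    sample.save("sample.pt")  # pickled via torch.save
    loaded = UKRegionalSample.load("sample.pt")  # torch.load(..., weights_only=False)
    assert np.allclose(sample.to_numpy()["gsp"], loaded.to_numpy()["gsp"])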
ocf_data_sampler/select/dropout.py (new file)
@@ -0,0 +1,52 @@
+"""Functions for simulating dropout in time series data
+
+This is used for the following types of data: GSP, Satellite and Site
+This is not used for NWP
+"""
+import numpy as np
+import pandas as pd
+import xarray as xr
+
+
+def draw_dropout_time(
+    t0: pd.Timestamp,
+    dropout_timedeltas: list[pd.Timedelta],
+    dropout_frac: float,
+) -> pd.Timestamp:
+    """Randomly pick a dropout time from a list of timedeltas
+
+    Args:
+        t0: The forecast init-time
+        dropout_timedeltas: List of timedeltas relative to t0 to pick from
+        dropout_frac: Probability that dropout will be applied. This should be between 0 and 1
+            inclusive
+    """
+
+    if dropout_frac > 0:
+        assert len(dropout_timedeltas) > 0, "To apply dropout dropout_timedeltas must be provided"
+
+    for t in dropout_timedeltas:
+        assert t <= pd.Timedelta("0min"), "Dropout timedeltas must be negative"
+
+    assert 0 <= dropout_frac <= 1
+
+    if (len(dropout_timedeltas) == 0) or (np.random.uniform() >= dropout_frac):
+        dropout_time = t0
+    else:
+        dropout_time = t0 + np.random.choice(dropout_timedeltas)
+
+    return dropout_time
+
+
+def apply_dropout_time(
+    ds: xr.DataArray,
+    dropout_time: pd.Timestamp,
+) -> xr.DataArray:
+    """Apply dropout time to the data
+
+    Args:
+        ds: Xarray DataArray with 'time_utc' coordinate
+        dropout_time: Time after which data is set to NaN
+    """
+    # This replaces the times after the dropout with NaNs
+    return ds.where(ds.time_utc <= dropout_time)
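Note: a usage sketch combining the two functions (timestamps, frequency and probability are illustrative):

    import numpy as np
    import pandas as pd
    import xarray as xr

    from ocf_data_sampler.select.dropout import apply_dropout_time, draw_dropout_time

    t0 = pd.Timestamp("2024-06-01 12:00")
    times = pd.date_range(t0 - pd.Timedelta("60min"), t0, freq="5min")
    da = xr.DataArray(np.random.rand(len(times)), dims=["time_utc"], coords={"time_utc": times})

    # Half of all draws keep the full history; the other half cut off 30 or 60 minutes before t0
    dropout_time = draw_dropout_time(
        t0,
        dropout_timedeltas=[pd.Timedelta("-30min"), pd.Timedelta("-60min")],
        dropout_frac=0.5,
    )
    da = apply_dropout_time(da, dropout_time)  # values after dropout_time become NaN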
ocf_data_sampler/select/fill_time_periods.py
@@ -1,10 +1,12 @@
-"""fill time periods"""
+"""Fill time periods between start and end dates at specified frequency"""
 
 import pandas as pd
 import numpy as np
 
 
 def fill_time_periods(time_periods: pd.DataFrame, freq: pd.Timedelta) -> pd.DatetimeIndex:
+    """Generate DatetimeIndex for all timestamps between start and end dates"""
+
     start_dts = pd.to_datetime(time_periods["start_dt"].values).ceil(freq)
     end_dts = pd.to_datetime(time_periods["end_dt"].values)
     date_ranges = [pd.date_range(start_dt, end_dt, freq=freq) for start_dt, end_dt in zip(start_dts, end_dts)]
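Note: a worked example of the ceiling behaviour, assuming the per-period ranges are concatenated into one DatetimeIndex as the new docstring describes (values are illustrative):

    import pandas as pd

    from ocf_data_sampler.select.fill_time_periods import fill_time_periods

    time_periods = pd.DataFrame({
        "start_dt": [pd.Timestamp("2024-01-01 00:07")],
        "end_dt": [pd.Timestamp("2024-01-01 01:00")],
    })

    # start_dt is ceiled to the frequency, so the first timestamp is 00:30, not 00:07
    print(fill_time_periods(time_periods, freq=pd.Timedelta("30min")))
    # DatetimeIndex(['2024-01-01 00:30:00', '2024-01-01 01:00:00'], ...)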
ocf_data_sampler/select/find_contiguous_time_periods.py
@@ -5,7 +5,6 @@ import pandas as pd
 from ocf_data_sampler.load.utils import check_time_unique_increasing
 
 
-
 def find_contiguous_time_periods(
     datetimes: pd.DatetimeIndex,
     min_seq_length: int,
ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py
@@ -186,9 +186,8 @@ class PVNetUKRegionalDataset(Dataset):
             gsp_ids: List of GSP IDs to create samples for. Defaults to all
         """
 
-        config = load_yaml_configuration(config_filename)
-
-        # Validate channels for NWP and satellite data
+        # config = load_yaml_configuration(config_filename)
+        config: Configuration = load_yaml_configuration(config_filename)
         validate_nwp_channels(config)
         validate_satellite_channels(config)