ocf-data-sampler 0.0.19__py3-none-any.whl → 0.0.43__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ocf-data-sampler might be problematic. Click here for more details.

Files changed (64)
  1. ocf_data_sampler/config/__init__.py +5 -0
  2. ocf_data_sampler/config/load.py +33 -0
  3. ocf_data_sampler/config/model.py +246 -0
  4. ocf_data_sampler/config/save.py +73 -0
  5. ocf_data_sampler/constants.py +173 -0
  6. ocf_data_sampler/load/load_dataset.py +55 -0
  7. ocf_data_sampler/load/nwp/providers/ecmwf.py +5 -2
  8. ocf_data_sampler/load/site.py +30 -0
  9. ocf_data_sampler/numpy_sample/__init__.py +8 -0
  10. ocf_data_sampler/numpy_sample/collate.py +75 -0
  11. ocf_data_sampler/numpy_sample/gsp.py +34 -0
  12. ocf_data_sampler/numpy_sample/nwp.py +42 -0
  13. ocf_data_sampler/numpy_sample/satellite.py +30 -0
  14. ocf_data_sampler/numpy_sample/site.py +30 -0
  15. ocf_data_sampler/{numpy_batch → numpy_sample}/sun_position.py +9 -10
  16. ocf_data_sampler/select/__init__.py +8 -1
  17. ocf_data_sampler/select/dropout.py +4 -3
  18. ocf_data_sampler/select/find_contiguous_time_periods.py +40 -75
  19. ocf_data_sampler/select/geospatial.py +160 -0
  20. ocf_data_sampler/select/location.py +62 -0
  21. ocf_data_sampler/select/select_spatial_slice.py +13 -16
  22. ocf_data_sampler/select/select_time_slice.py +24 -33
  23. ocf_data_sampler/select/spatial_slice_for_dataset.py +53 -0
  24. ocf_data_sampler/select/time_slice_for_dataset.py +125 -0
  25. ocf_data_sampler/torch_datasets/__init__.py +2 -1
  26. ocf_data_sampler/torch_datasets/process_and_combine.py +131 -0
  27. ocf_data_sampler/torch_datasets/pvnet_uk_regional.py +11 -425
  28. ocf_data_sampler/torch_datasets/site.py +405 -0
  29. ocf_data_sampler/torch_datasets/valid_time_periods.py +116 -0
  30. ocf_data_sampler/utils.py +10 -0
  31. ocf_data_sampler-0.0.43.dist-info/METADATA +154 -0
  32. ocf_data_sampler-0.0.43.dist-info/RECORD +71 -0
  33. {ocf_data_sampler-0.0.19.dist-info → ocf_data_sampler-0.0.43.dist-info}/WHEEL +1 -1
  34. {ocf_data_sampler-0.0.19.dist-info → ocf_data_sampler-0.0.43.dist-info}/top_level.txt +1 -0
  35. scripts/refactor_site.py +50 -0
  36. tests/config/test_config.py +161 -0
  37. tests/config/test_save.py +37 -0
  38. tests/conftest.py +86 -1
  39. tests/load/test_load_gsp.py +15 -0
  40. tests/load/test_load_nwp.py +21 -0
  41. tests/load/test_load_satellite.py +17 -0
  42. tests/load/test_load_sites.py +14 -0
  43. tests/numpy_sample/test_collate.py +26 -0
  44. tests/numpy_sample/test_gsp.py +38 -0
  45. tests/numpy_sample/test_nwp.py +52 -0
  46. tests/numpy_sample/test_satellite.py +40 -0
  47. tests/numpy_sample/test_sun_position.py +81 -0
  48. tests/select/test_dropout.py +75 -0
  49. tests/select/test_fill_time_periods.py +28 -0
  50. tests/select/test_find_contiguous_time_periods.py +202 -0
  51. tests/select/test_location.py +67 -0
  52. tests/select/test_select_spatial_slice.py +154 -0
  53. tests/select/test_select_time_slice.py +272 -0
  54. tests/torch_datasets/conftest.py +18 -0
  55. tests/torch_datasets/test_process_and_combine.py +126 -0
  56. tests/torch_datasets/test_pvnet_uk_regional.py +59 -0
  57. tests/torch_datasets/test_site.py +129 -0
  58. ocf_data_sampler/numpy_batch/__init__.py +0 -7
  59. ocf_data_sampler/numpy_batch/gsp.py +0 -20
  60. ocf_data_sampler/numpy_batch/nwp.py +0 -33
  61. ocf_data_sampler/numpy_batch/satellite.py +0 -23
  62. ocf_data_sampler-0.0.19.dist-info/METADATA +0 -22
  63. ocf_data_sampler-0.0.19.dist-info/RECORD +0 -32
  64. {ocf_data_sampler-0.0.19.dist-info → ocf_data_sampler-0.0.43.dist-info}/LICENSE +0 -0
@@ -0,0 +1,62 @@
1
+ """location"""
2
+
3
+ from typing import Optional
4
+
5
+ import numpy as np
6
+ from pydantic import BaseModel, Field, model_validator
7
+
8
+
9
allowed_coordinate_systems = ["osgb", "lon_lat", "geostationary", "idx"]

# Inclusive (min, max) bounds on x and y for each allowed coordinate system.
# "idx" (pixel index) coordinates are only bounded below.
_COORDINATE_BOUNDS = {
    "osgb": {"x": (-103976.3, 652897.98), "y": (-16703.87, 1199851.44)},
    "lon_lat": {"x": (-180, 180), "y": (-90, 90)},
    "geostationary": {
        "x": (-5568748.275756836, 5567248.074173927),
        "y": (1393687.2151494026, 5570748.323202133),
    },
    "idx": {"x": (0, np.inf), "y": (0, np.inf)},
}


def _check_coordinate_in_bounds(axis: str, value: float, coordinate_system: str) -> None:
    """Raise ValueError if `value` lies outside the allowed range of `axis` ("x" or "y").

    Unknown coordinate systems are skipped here so that this check never depends on
    validator execution order; they are rejected by `Location.validate_coordinate_system`.
    """
    bounds = _COORDINATE_BOUNDS.get(coordinate_system)
    if bounds is None:
        return
    lower, upper = bounds[axis]
    # Written as a negated range check so that NaN (which compares False against
    # everything) is rejected instead of silently passing validation.
    if not (lower <= value <= upper):
        raise ValueError(
            f"{axis} = {value} must be within {[lower, upper]} "
            f"for {coordinate_system} coordinate system"
        )


class Location(BaseModel):
    """Represent a spatial location.

    Attributes:
        coordinate_system: One of `allowed_coordinate_systems`; defaults to "osgb".
        x: x-coordinate in `coordinate_system` (longitude for "lon_lat").
        y: y-coordinate in `coordinate_system` (latitude for "lon_lat").
        id: Optional integer identifier for the location.
    """

    coordinate_system: Optional[str] = "osgb"  # one of allowed_coordinate_systems
    x: float
    y: float
    id: Optional[int] = Field(None)

    @model_validator(mode='after')
    def validate_coordinate_system(self):
        """Validate that 'coordinate_system' is one of `allowed_coordinate_systems`."""
        if self.coordinate_system not in allowed_coordinate_systems:
            raise ValueError(f"coordinate_system = {self.coordinate_system} is not in {allowed_coordinate_systems}")
        return self

    @model_validator(mode='after')
    def validate_x(self):
        """Validate that 'x' lies within the bounds of the coordinate system (NaN rejected)."""
        _check_coordinate_in_bounds("x", self.x, self.coordinate_system)
        return self

    @model_validator(mode='after')
    def validate_y(self):
        """Validate that 'y' lies within the bounds of the coordinate system (NaN rejected)."""
        _check_coordinate_in_bounds("y", self.y, self.coordinate_system)
        return self
@@ -5,15 +5,15 @@ import logging
5
5
  import numpy as np
6
6
  import xarray as xr
7
7
 
8
- from ocf_datapipes.utils import Location
9
- from ocf_datapipes.utils.geospatial import (
10
- lon_lat_to_geostationary_area_coords,
8
+ from ocf_data_sampler.select.location import Location
9
+ from ocf_data_sampler.select.geospatial import (
11
10
  lon_lat_to_osgb,
11
+ lon_lat_to_geostationary_area_coords,
12
12
  osgb_to_geostationary_area_coords,
13
13
  osgb_to_lon_lat,
14
14
  spatial_coord_type,
15
15
  )
16
- from ocf_datapipes.utils.utils import searchsorted
16
+
17
17
 
18
18
  logger = logging.getLogger(__name__)
19
19
 
@@ -45,9 +45,6 @@ def convert_coords_to_match_xarray(
45
45
  if from_coords == "osgb":
46
46
  x, y = osgb_to_geostationary_area_coords(x, y, da)
47
47
 
48
- elif from_coords == "lon_lat":
49
- x, y = lon_lat_to_geostationary_area_coords(x, y, da)
50
-
51
48
  elif target_coords == "lon_lat":
52
49
  if from_coords == "osgb":
53
50
  x, y = osgb_to_lon_lat(x, y)
@@ -105,7 +102,7 @@ def _get_idx_of_pixel_closest_to_poi(
105
102
 
106
103
  def _get_idx_of_pixel_closest_to_poi_geostationary(
107
104
  da: xr.DataArray,
108
- center_osgb: Location,
105
+ center: Location,
109
106
  ) -> Location:
110
107
  """
111
108
  Return x and y index location of pixel at center of region of interest.
@@ -120,7 +117,12 @@ def _get_idx_of_pixel_closest_to_poi_geostationary(
120
117
 
121
118
  _, x_dim, y_dim = spatial_coord_type(da)
122
119
 
123
- x, y = osgb_to_geostationary_area_coords(x=center_osgb.x, y=center_osgb.y, xr_data=da)
120
+ if center.coordinate_system == 'osgb':
121
+ x, y = osgb_to_geostationary_area_coords(x=center.x, y=center.y, xr_data=da)
122
+ elif center.coordinate_system == 'lon_lat':
123
+ x, y = lon_lat_to_geostationary_area_coords(longitude=center.x, latitude=center.y, xr_data=da)
124
+ else:
125
+ x,y = center.x, center.y
124
126
  center_geostationary = Location(x=x, y=y, coordinate_system="geostationary")
125
127
 
126
128
  # Check that the requested point lies within the data
@@ -130,13 +132,8 @@ def _get_idx_of_pixel_closest_to_poi_geostationary(
130
132
  f"{y} is not in the interval {da[y_dim].min().values}: {da[y_dim].max().values}"
131
133
 
132
134
  # Get the index into x and y nearest to x_center_geostationary and y_center_geostationary:
133
- x_index_at_center = searchsorted(
134
- da[x_dim].values, center_geostationary.x, assume_ascending=True
135
- )
136
-
137
- y_index_at_center = searchsorted(
138
- da[y_dim].values, center_geostationary.y, assume_ascending=True
139
- )
135
+ x_index_at_center = np.searchsorted(da[x_dim].values, center_geostationary.x)
136
+ y_index_at_center = np.searchsorted(da[y_dim].values, center_geostationary.y)
140
137
 
141
138
  return Location(x=x_index_at_center, y=y_index_at_center, coordinate_system="idx")
142
139
 
@@ -39,23 +39,14 @@ def _sel_fillinterp(
39
39
  def select_time_slice(
40
40
  ds: xr.DataArray,
41
41
  t0: pd.Timestamp,
42
+ interval_start: pd.Timedelta,
43
+ interval_end: pd.Timedelta,
42
44
  sample_period_duration: pd.Timedelta,
43
- history_duration: pd.Timedelta | None = None,
44
- forecast_duration: pd.Timedelta | None = None,
45
- interval_start: pd.Timedelta | None = None,
46
- interval_end: pd.Timedelta | None = None,
47
45
  fill_selection: bool = False,
48
46
  max_steps_gap: int = 0,
49
47
  ):
50
48
  """Select a time slice from a Dataset or DataArray."""
51
- used_duration = history_duration is not None and forecast_duration is not None
52
- used_intervals = interval_start is not None and interval_end is not None
53
- assert used_duration ^ used_intervals, "Either durations, or intervals must be supplied"
54
49
  assert max_steps_gap >= 0, "max_steps_gap must be >= 0 "
55
-
56
- if used_duration:
57
- interval_start = - history_duration
58
- interval_end = forecast_duration
59
50
 
60
51
  if fill_selection and max_steps_gap == 0:
61
52
  _sel = _sel_fillnan
@@ -75,11 +66,11 @@ def select_time_slice(
75
66
 
76
67
 
77
68
  def select_time_slice_nwp(
78
- ds: xr.DataArray,
69
+ da: xr.DataArray,
79
70
  t0: pd.Timestamp,
71
+ interval_start: pd.Timedelta,
72
+ interval_end: pd.Timedelta,
80
73
  sample_period_duration: pd.Timedelta,
81
- history_duration: pd.Timedelta,
82
- forecast_duration: pd.Timedelta,
83
74
  dropout_timedeltas: list[pd.Timedelta] | None = None,
84
75
  dropout_frac: float | None = 0,
85
76
  accum_channels: list[str] = [],
@@ -92,31 +83,31 @@ def select_time_slice_nwp(
92
83
  ), "dropout timedeltas must be negative"
93
84
  assert len(dropout_timedeltas) >= 1
94
85
  assert 0 <= dropout_frac <= 1
95
- _consider_dropout = (dropout_timedeltas is not None) and dropout_frac > 0
86
+ consider_dropout = (dropout_timedeltas is not None) and dropout_frac > 0
96
87
 
97
88
 
98
89
  # The accumatation and non-accumulation channels
99
90
  accum_channels = np.intersect1d(
100
- ds[channel_dim_name].values, accum_channels
91
+ da[channel_dim_name].values, accum_channels
101
92
  )
102
93
  non_accum_channels = np.setdiff1d(
103
- ds[channel_dim_name].values, accum_channels
94
+ da[channel_dim_name].values, accum_channels
104
95
  )
105
96
 
106
- start_dt = (t0 - history_duration).ceil(sample_period_duration)
107
- end_dt = (t0 + forecast_duration).ceil(sample_period_duration)
97
+ start_dt = (t0 + interval_start).ceil(sample_period_duration)
98
+ end_dt = (t0 + interval_end).ceil(sample_period_duration)
108
99
 
109
100
  target_times = pd.date_range(start_dt, end_dt, freq=sample_period_duration)
110
101
 
111
102
  # Maybe apply NWP dropout
112
- if _consider_dropout and (np.random.uniform() < dropout_frac):
103
+ if consider_dropout and (np.random.uniform() < dropout_frac):
113
104
  dt = np.random.choice(dropout_timedeltas)
114
105
  t0_available = t0 + dt
115
106
  else:
116
107
  t0_available = t0
117
108
 
118
109
  # Forecasts made up to and including t0
119
- available_init_times = ds.init_time_utc.sel(
110
+ available_init_times = da.init_time_utc.sel(
120
111
  init_time_utc=slice(None, t0_available)
121
112
  )
122
113
 
@@ -139,7 +130,7 @@ def select_time_slice_nwp(
139
130
  step_indexer = xr.DataArray(steps, coords=coords)
140
131
 
141
132
  if len(accum_channels) == 0:
142
- xr_sel = ds.sel(step=step_indexer, init_time_utc=init_time_indexer)
133
+ da_sel = da.sel(step=step_indexer, init_time_utc=init_time_indexer)
143
134
 
144
135
  else:
145
136
  # First minimise the size of the dataset we are diffing
@@ -149,7 +140,7 @@ def select_time_slice_nwp(
149
140
  min_step = min(steps)
150
141
  max_step = max(steps) + sample_period_duration
151
142
 
152
- xr_min = ds.sel(
143
+ da_min = da.sel(
153
144
  {
154
145
  "init_time_utc": unique_init_times,
155
146
  "step": slice(min_step, max_step),
@@ -157,28 +148,28 @@ def select_time_slice_nwp(
157
148
  )
158
149
 
159
150
  # Slice out the data which does not need to be diffed
160
- xr_non_accum = xr_min.sel({channel_dim_name: non_accum_channels})
161
- xr_sel_non_accum = xr_non_accum.sel(
151
+ da_non_accum = da_min.sel({channel_dim_name: non_accum_channels})
152
+ da_sel_non_accum = da_non_accum.sel(
162
153
  step=step_indexer, init_time_utc=init_time_indexer
163
154
  )
164
155
 
165
156
  # Slice out the channels which need to be diffed
166
- xr_accum = xr_min.sel({channel_dim_name: accum_channels})
157
+ da_accum = da_min.sel({channel_dim_name: accum_channels})
167
158
 
168
159
  # Take the diff and slice requested data
169
- xr_accum = xr_accum.diff(dim="step", label="lower")
170
- xr_sel_accum = xr_accum.sel(step=step_indexer, init_time_utc=init_time_indexer)
160
+ da_accum = da_accum.diff(dim="step", label="lower")
161
+ da_sel_accum = da_accum.sel(step=step_indexer, init_time_utc=init_time_indexer)
171
162
 
172
163
  # Join diffed and non-diffed variables
173
- xr_sel = xr.concat([xr_sel_non_accum, xr_sel_accum], dim=channel_dim_name)
164
+ da_sel = xr.concat([da_sel_non_accum, da_sel_accum], dim=channel_dim_name)
174
165
 
175
166
  # Reorder the variable back to the original order
176
- xr_sel = xr_sel.sel({channel_dim_name: ds[channel_dim_name].values})
167
+ da_sel = da_sel.sel({channel_dim_name: da[channel_dim_name].values})
177
168
 
178
169
  # Rename the diffed channels
179
- xr_sel[channel_dim_name] = [
170
+ da_sel[channel_dim_name] = [
180
171
  f"diff_{v}" if v in accum_channels else v
181
- for v in xr_sel[channel_dim_name].values
172
+ for v in da_sel[channel_dim_name].values
182
173
  ]
183
174
 
184
- return xr_sel
175
+ return da_sel
@@ -0,0 +1,53 @@
1
+ """ Functions for selecting data around a given location """
2
+ from ocf_data_sampler.config import Configuration
3
+ from ocf_data_sampler.select.location import Location
4
+ from ocf_data_sampler.select.select_spatial_slice import select_spatial_slice_pixels
5
+
6
+
7
def slice_datasets_by_space(
    datasets_dict: dict,
    location: Location,
    config: Configuration,
) -> dict:
    """Slice the dictionary of input data sources around a given location.

    Args:
        datasets_dict: Dictionary of the input data sources. Keys must be a subset of
            {"nwp", "sat", "gsp", "site"}; "nwp" maps to a dict of NWP sources whose keys
            must also appear in `config.input_data.nwp`.
        location: The location to sample around. Its `id` selects the GSP / site.
        config: Configuration object.

    Returns:
        A new dictionary with the same keys, each source sliced around `location`.
    """
    allowed_keys = {"nwp", "sat", "gsp", "site"}
    assert set(datasets_dict.keys()).issubset(allowed_keys), (
        f"Unexpected keys in datasets_dict: {set(datasets_dict.keys()) - allowed_keys}"
    )

    sliced_datasets_dict = {}

    if "nwp" in datasets_dict:
        sliced_datasets_dict["nwp"] = {}

        # Iterate over the NWP sources that were actually provided and look each one up
        # in the config. Iterating over the config instead raises KeyError whenever a
        # configured source is absent from datasets_dict.
        for nwp_key, da_nwp in datasets_dict["nwp"].items():
            nwp_config = config.input_data.nwp[nwp_key]

            sliced_datasets_dict["nwp"][nwp_key] = select_spatial_slice_pixels(
                da_nwp,
                location,
                height_pixels=nwp_config.image_size_pixels_height,
                width_pixels=nwp_config.image_size_pixels_width,
            )

    if "sat" in datasets_dict:
        sat_config = config.input_data.satellite

        sliced_datasets_dict["sat"] = select_spatial_slice_pixels(
            datasets_dict["sat"],
            location,
            height_pixels=sat_config.image_size_pixels_height,
            width_pixels=sat_config.image_size_pixels_width,
        )

    # GSP and site data are indexed by ID rather than sliced spatially
    if "gsp" in datasets_dict:
        sliced_datasets_dict["gsp"] = datasets_dict["gsp"].sel(gsp_id=location.id)

    if "site" in datasets_dict:
        sliced_datasets_dict["site"] = datasets_dict["site"].sel(site_id=location.id)

    return sliced_datasets_dict
@@ -0,0 +1,125 @@
1
+ """ Slice datasets by time"""
2
+ import pandas as pd
3
+
4
+ from ocf_data_sampler.config import Configuration
5
+ from ocf_data_sampler.select.dropout import draw_dropout_time, apply_dropout_time
6
+ from ocf_data_sampler.select.select_time_slice import select_time_slice_nwp, select_time_slice
7
+ from ocf_data_sampler.utils import minutes
8
+
9
+
10
def _slice_nwp(da_nwp, t0: pd.Timestamp, nwp_config):
    """Time-slice a single NWP source according to its config, including NWP dropout."""
    return select_time_slice_nwp(
        da_nwp,
        t0,
        sample_period_duration=minutes(nwp_config.time_resolution_minutes),
        interval_start=minutes(nwp_config.interval_start_minutes),
        interval_end=minutes(nwp_config.interval_end_minutes),
        dropout_timedeltas=minutes(nwp_config.dropout_timedeltas_minutes),
        dropout_frac=nwp_config.dropout_fraction,
        accum_channels=nwp_config.accum_channels,
    )


def _slice_with_dropout(da, t0: pd.Timestamp, source_config, max_steps_gap: int = 0):
    """Time-slice a source to its configured window, then apply a randomly drawn dropout."""
    windowed = select_time_slice(
        da,
        t0,
        sample_period_duration=minutes(source_config.time_resolution_minutes),
        interval_start=minutes(source_config.interval_start_minutes),
        interval_end=minutes(source_config.interval_end_minutes),
        max_steps_gap=max_steps_gap,
    )
    dropout_time = draw_dropout_time(
        t0,
        dropout_timedeltas=minutes(source_config.dropout_timedeltas_minutes),
        dropout_frac=source_config.dropout_fraction,
    )
    return apply_dropout_time(windowed, dropout_time)


def slice_datasets_by_time(
    datasets_dict: dict,
    t0: pd.Timestamp,
    config: Configuration,
) -> dict:
    """Cut every input data source down to its configured time window around t0.

    Args:
        datasets_dict: Dictionary of the input data sources; may contain "nwp" (a dict of
            NWP sources keyed like `config.input_data.nwp`), "sat", "gsp" and "site".
        t0: The init-time
        config: Configuration object.

    Returns:
        A new dictionary of time-sliced data. A "gsp" input yields two entries:
        "gsp_future" (one time step after t0 up to interval_end, no dropout) and "gsp"
        (interval_start up to t0, with dropout applied).
    """
    out = {}

    if "nwp" in datasets_dict:
        out["nwp"] = {
            key: _slice_nwp(da, t0, config.input_data.nwp[key])
            for key, da in datasets_dict["nwp"].items()
        }

    if "sat" in datasets_dict:
        # Satellite data tolerates gaps of up to two missing time steps
        out["sat"] = _slice_with_dropout(
            datasets_dict["sat"], t0, config.input_data.satellite, max_steps_gap=2
        )

    if "gsp" in datasets_dict:
        gsp_config = config.input_data.gsp
        da_gsp = datasets_dict["gsp"]
        step = minutes(gsp_config.time_resolution_minutes)

        # Future values start one step after t0 and are never dropped out
        out["gsp_future"] = select_time_slice(
            da_gsp,
            t0,
            sample_period_duration=step,
            interval_start=step,
            interval_end=minutes(gsp_config.interval_end_minutes),
        )

        # Historical values end at t0 and are subject to dropout
        gsp_history = select_time_slice(
            da_gsp,
            t0,
            sample_period_duration=step,
            interval_start=minutes(gsp_config.interval_start_minutes),
            interval_end=minutes(0),
        )
        gsp_dropout_time = draw_dropout_time(
            t0,
            dropout_timedeltas=minutes(gsp_config.dropout_timedeltas_minutes),
            dropout_frac=gsp_config.dropout_fraction,
        )
        out["gsp"] = apply_dropout_time(gsp_history, gsp_dropout_time)

    if "site" in datasets_dict:
        # NOTE(review): unlike GSP, dropout is applied to the whole site window, which
        # includes times after t0 when interval_end_minutes > 0 — confirm this is intended.
        out["site"] = _slice_with_dropout(datasets_dict["site"], t0, config.input_data.site)

    return out
@@ -1 +1,2 @@
1
-
1
+ from .pvnet_uk_regional import PVNetUKRegionalDataset
2
+ from .site import SitesDataset
@@ -0,0 +1,131 @@
1
+ import numpy as np
2
+ import pandas as pd
3
+ import xarray as xr
4
+ from typing import Optional
5
+
6
+ from ocf_data_sampler.config import Configuration
7
+ from ocf_data_sampler.constants import NWP_MEANS, NWP_STDS,RSS_MEAN,RSS_STD
8
+ from ocf_data_sampler.numpy_sample import (
9
+ convert_nwp_to_numpy_sample,
10
+ convert_satellite_to_numpy_sample,
11
+ convert_gsp_to_numpy_sample,
12
+ make_sun_position_numpy_sample,
13
+ )
14
+ from ocf_data_sampler.numpy_sample.gsp import GSPSampleKey
15
+ from ocf_data_sampler.numpy_sample.nwp import NWPSampleKey
16
+ from ocf_data_sampler.select.geospatial import osgb_to_lon_lat
17
+ from ocf_data_sampler.select.location import Location
18
+ from ocf_data_sampler.utils import minutes
19
+
20
+
21
def process_and_combine_datasets(
    dataset_dict: dict,
    config: Configuration,
    t0: Optional[pd.Timestamp] = None,
    location: Optional[Location] = None,
    target_key: str = 'gsp'
) -> dict:
    """Normalise and convert data to numpy arrays, combined into a single sample.

    Args:
        dataset_dict: Sliced input data. May contain "nwp" (a dict of NWP DataArrays keyed
            like `config.input_data.nwp`), "sat", and "gsp" together with "gsp_future".
        config: Configuration object.
        t0: The init-time. Required when `target_key == "gsp"`.
        location: The sample location. Required when "gsp" is in `dataset_dict` or
            `target_key == "gsp"`.
        target_key: Name of the prediction target; when "gsp", sun-position features
            covering the GSP window are added.

    Returns:
        The merged NumpySample dictionary, with NaNs in numeric arrays replaced by zeros.

    Raises:
        ValueError: If `location` or `t0` is required but not provided.
    """
    # Fail early with a clear message rather than an AttributeError/TypeError on None
    if ("gsp" in dataset_dict or target_key == "gsp") and location is None:
        raise ValueError("`location` must be provided when processing GSP data or a 'gsp' target")
    if target_key == "gsp" and t0 is None:
        raise ValueError("`t0` must be provided when target_key is 'gsp'")

    numpy_modalities = []

    if "nwp" in dataset_dict:
        nwp_numpy_modalities = {}

        for nwp_key, da_nwp in dataset_dict["nwp"].items():
            # Standardise using the provider's mean/std constants
            provider = config.input_data.nwp[nwp_key].provider
            da_nwp = (da_nwp - NWP_MEANS[provider]) / NWP_STDS[provider]
            # Convert to NumpySample
            nwp_numpy_modalities[nwp_key] = convert_nwp_to_numpy_sample(da_nwp)

        # Combine the NWPs into NumpySample
        numpy_modalities.append({NWPSampleKey.nwp: nwp_numpy_modalities})

    if "sat" in dataset_dict:
        # Standardise
        da_sat = (dataset_dict["sat"] - RSS_MEAN) / RSS_STD
        # Convert to NumpySample
        numpy_modalities.append(convert_satellite_to_numpy_sample(da_sat))

    gsp_config = config.input_data.gsp

    if "gsp" in dataset_dict:
        da_gsp = xr.concat([dataset_dict["gsp"], dataset_dict["gsp_future"]], dim="time_utc")
        # Normalise by the effective capacity
        da_gsp = da_gsp / da_gsp.effective_capacity_mwp

        # The history window begins interval_start_minutes before t0, so t0 sits this many
        # steps into the concatenated series. Cast to int: it is used as an index.
        t0_idx = int(-gsp_config.interval_start_minutes / gsp_config.time_resolution_minutes)

        numpy_modalities.append(convert_gsp_to_numpy_sample(da_gsp, t0_idx=t0_idx))

        # Add coordinate data
        # TODO: Do we need all of these?
        numpy_modalities.append(
            {
                GSPSampleKey.gsp_id: location.id,
                GSPSampleKey.x_osgb: location.x,
                GSPSampleKey.y_osgb: location.y,
            }
        )

    if target_key == 'gsp':
        # Make sun coords NumpySample covering the full GSP window
        datetimes = pd.date_range(
            t0 + minutes(gsp_config.interval_start_minutes),
            t0 + minutes(gsp_config.interval_end_minutes),
            freq=minutes(gsp_config.time_resolution_minutes),
        )

        if location.coordinate_system == "lon_lat":
            # Already longitude / latitude; no conversion needed
            lon, lat = location.x, location.y
        else:
            # Location is assumed to be in OSGB here
            lon, lat = osgb_to_lon_lat(location.x, location.y)

        numpy_modalities.append(
            make_sun_position_numpy_sample(datetimes, lon, lat, key_prefix=target_key)
        )

    # Combine all the modalities and fill NaNs
    combined_sample = merge_dicts(numpy_modalities)
    combined_sample = fill_nans_in_arrays(combined_sample)

    return combined_sample
98
+
99
def merge_dicts(list_of_dicts: list[dict]) -> dict:
    """Flatten a list of dictionaries into one new dictionary.

    When a key occurs in several dictionaries, the value from the last one wins, while
    the key keeps the position of its first occurrence.
    """
    # TODO: duplicate keys are silently overwritten rather than reported
    return {key: value for mapping in list_of_dicts for key, value in mapping.items()}
106
+
107
def fill_nans_in_arrays(sample: dict) -> dict:
    """Replace NaN values with zeros in every numeric np.ndarray of the sample dictionary.

    Nested dictionaries (e.g. the NWP sub-sample) are processed recursively. Only NaNs are
    replaced; infinities and all other values are left untouched.

    Operation is performed in-place on the arrays (they must be writeable), and the same
    `sample` object is returned.
    """
    for value in sample.values():
        if isinstance(value, np.ndarray) and np.issubdtype(value.dtype, np.number):
            nan_mask = np.isnan(value)
            if nan_mask.any():
                # Assign through the mask rather than np.nan_to_num, which would also clamp
                # +/-inf to the largest finite float, contradicting "fill NaNs" semantics.
                value[nan_mask] = 0

        # Recursion is included to reach NWP arrays in subdict
        elif isinstance(value, dict):
            fill_nans_in_arrays(value)

    return sample
122
+
123
+
124
def compute(xarray_dict: dict) -> dict:
    """Eagerly load a nested dictionary of xarray DataArrays.

    Every non-dict leaf is replaced by ``leaf.compute(scheduler="single-threaded")``;
    nested dictionaries are handled recursively. The dictionary is updated in place and
    also returned.
    """
    for key, value in xarray_dict.items():
        xarray_dict[key] = (
            compute(value)
            if isinstance(value, dict)
            else value.compute(scheduler="single-threaded")
        )
    return xarray_dict