ocf-data-sampler 0.0.23__py3-none-any.whl → 0.0.25__py3-none-any.whl

This diff shows the changes between publicly available package versions as they appear in their public registry. It is provided for informational purposes only.

Potentially problematic release: this version of ocf-data-sampler might be problematic.
Files changed (35)
  1. ocf_data_sampler/config/model.py +34 -0
  2. ocf_data_sampler/load/load_dataset.py +55 -0
  3. ocf_data_sampler/load/nwp/providers/ecmwf.py +5 -2
  4. ocf_data_sampler/load/site.py +30 -0
  5. ocf_data_sampler/numpy_batch/__init__.py +4 -3
  6. ocf_data_sampler/numpy_batch/gsp.py +12 -12
  7. ocf_data_sampler/numpy_batch/nwp.py +14 -14
  8. ocf_data_sampler/numpy_batch/satellite.py +8 -8
  9. ocf_data_sampler/numpy_batch/site.py +29 -0
  10. ocf_data_sampler/select/__init__.py +8 -1
  11. ocf_data_sampler/select/dropout.py +2 -1
  12. ocf_data_sampler/select/geospatial.py +43 -1
  13. ocf_data_sampler/select/select_spatial_slice.py +8 -2
  14. ocf_data_sampler/select/spatial_slice_for_dataset.py +53 -0
  15. ocf_data_sampler/select/time_slice_for_dataset.py +124 -0
  16. ocf_data_sampler/time_functions.py +11 -0
  17. ocf_data_sampler/torch_datasets/process_and_combine.py +153 -0
  18. ocf_data_sampler/torch_datasets/pvnet_uk_regional.py +8 -418
  19. ocf_data_sampler/torch_datasets/site.py +196 -0
  20. ocf_data_sampler/torch_datasets/valid_time_periods.py +108 -0
  21. {ocf_data_sampler-0.0.23.dist-info → ocf_data_sampler-0.0.25.dist-info}/METADATA +1 -1
  22. ocf_data_sampler-0.0.25.dist-info/RECORD +66 -0
  23. {ocf_data_sampler-0.0.23.dist-info → ocf_data_sampler-0.0.25.dist-info}/WHEEL +1 -1
  24. {ocf_data_sampler-0.0.23.dist-info → ocf_data_sampler-0.0.25.dist-info}/top_level.txt +1 -0
  25. scripts/refactor_site.py +50 -0
  26. tests/conftest.py +62 -0
  27. tests/load/test_load_sites.py +14 -0
  28. tests/numpy_batch/test_gsp.py +1 -2
  29. tests/numpy_batch/test_nwp.py +1 -3
  30. tests/numpy_batch/test_satellite.py +1 -3
  31. tests/numpy_batch/test_sun_position.py +7 -7
  32. tests/torch_datasets/test_pvnet_uk_regional.py +4 -6
  33. tests/torch_datasets/test_site.py +85 -0
  34. ocf_data_sampler-0.0.23.dist-info/RECORD +0 -54
  35. {ocf_data_sampler-0.0.23.dist-info → ocf_data_sampler-0.0.25.dist-info}/LICENSE +0 -0
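Items 12, 14-17 and 20 in this list are pieces of the formerly monolithic pvnet_uk_regional.py, split out into reusable modules. A minimal sketch of how the relocated pieces compose after the refactor, assuming the moved functions keep the signatures of the deleted originals shown further down; the config path here is a placeholder:

# Sketch only: composition inferred from the new imports in pvnet_uk_regional.py below.
from ocf_data_sampler.config import load_yaml_configuration
from ocf_data_sampler.load.load_dataset import get_dataset_dict
from ocf_data_sampler.select import slice_datasets_by_space, slice_datasets_by_time
from ocf_data_sampler.torch_datasets.process_and_combine import (
    compute,
    process_and_combine_datasets,
)
from ocf_data_sampler.torch_datasets.pvnet_uk_regional import (
    find_valid_t0_times,
    get_gsp_locations,
)

config = load_yaml_configuration("config.yaml")  # placeholder path
datasets_dict = get_dataset_dict(config)

# Pick a location and an init-time for which all inputs are available
location = get_gsp_locations()[0]
t0 = find_valid_t0_times(datasets_dict, config)[0]

# Slice spatially, then temporally, load eagerly, and build the numpy sample
sliced = slice_datasets_by_space(datasets_dict, location, config)
sliced = slice_datasets_by_time(sliced, t0, config)
sample = process_and_combine_datasets(compute(sliced), config, t0, location)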
ocf_data_sampler/torch_datasets/process_and_combine.py (new file)
@@ -0,0 +1,153 @@
+ import numpy as np
+ import pandas as pd
+ import xarray as xr
+
+ from ocf_data_sampler.config import Configuration
+ from ocf_data_sampler.constants import NWP_MEANS, NWP_STDS
+ from ocf_data_sampler.numpy_batch import (
+     convert_nwp_to_numpy_batch,
+     convert_satellite_to_numpy_batch,
+     convert_gsp_to_numpy_batch,
+     make_sun_position_numpy_batch,
+     convert_site_to_numpy_batch,
+ )
+ from ocf_data_sampler.numpy_batch.gsp import GSPBatchKey
+ from ocf_data_sampler.numpy_batch.nwp import NWPBatchKey
+ from ocf_data_sampler.select.geospatial import osgb_to_lon_lat
+ from ocf_data_sampler.select.location import Location
+ from ocf_data_sampler.time_functions import minutes
+
+
+ def process_and_combine_datasets(
+     dataset_dict: dict,
+     config: Configuration,
+     t0: pd.Timestamp,
+     location: Location,
+     sun_position_key: str = 'gsp'
+ ) -> dict:
+     """Normalize and convert data to numpy arrays"""
+
+     numpy_modalities = []
+
+     if "nwp" in dataset_dict:
+
+         nwp_numpy_modalities = dict()
+
+         for nwp_key, da_nwp in dataset_dict["nwp"].items():
+             # Standardise
+             provider = config.input_data.nwp[nwp_key].nwp_provider
+             da_nwp = (da_nwp - NWP_MEANS[provider]) / NWP_STDS[provider]
+             # Convert to NumpyBatch
+             nwp_numpy_modalities[nwp_key] = convert_nwp_to_numpy_batch(da_nwp)
+
+         # Combine the NWPs into NumpyBatch
+         numpy_modalities.append({NWPBatchKey.nwp: nwp_numpy_modalities})
+
+     if "sat" in dataset_dict:
+         # Satellite is already in the range [0-1] so no need to standardise
+         da_sat = dataset_dict["sat"]
+
+         # Convert to NumpyBatch
+         numpy_modalities.append(convert_satellite_to_numpy_batch(da_sat))
+
+     gsp_config = config.input_data.gsp
+
+     if "gsp" in dataset_dict:
+         da_gsp = xr.concat([dataset_dict["gsp"], dataset_dict["gsp_future"]], dim="time_utc")
+         da_gsp = da_gsp / da_gsp.effective_capacity_mwp
+
+         numpy_modalities.append(
+             convert_gsp_to_numpy_batch(
+                 da_gsp, t0_idx=gsp_config.history_minutes // gsp_config.time_resolution_minutes
+             )
+         )
+
+         # Add coordinate data
+         # TODO: Do we need all of these?
+         numpy_modalities.append(
+             {
+                 GSPBatchKey.gsp_id: location.id,
+                 GSPBatchKey.x_osgb: location.x,
+                 GSPBatchKey.y_osgb: location.y,
+             }
+         )
+
+
+     if "site" in dataset_dict:
+         site_config = config.input_data.site
+         da_sites = dataset_dict["site"]
+         da_sites = da_sites / da_sites.capacity_kwp
+
+         numpy_modalities.append(
+             convert_site_to_numpy_batch(
+                 da_sites, t0_idx=site_config.history_minutes / site_config.time_resolution_minutes
+             )
+         )
+
+     if sun_position_key == 'gsp':
+         # Make sun coords NumpyBatch
+         datetimes = pd.date_range(
+             t0 - minutes(gsp_config.history_minutes),
+             t0 + minutes(gsp_config.forecast_minutes),
+             freq=minutes(gsp_config.time_resolution_minutes),
+         )
+
+         lon, lat = osgb_to_lon_lat(location.x, location.y)
+         key_prefix = "gsp"
+
+     elif sun_position_key == 'site':
+         # Make sun coords NumpyBatch
+         datetimes = pd.date_range(
+             t0 - minutes(site_config.history_minutes),
+             t0 + minutes(site_config.forecast_minutes),
+             freq=minutes(site_config.time_resolution_minutes),
+         )
+
+         lon, lat = location.x, location.y
+         key_prefix = "site"
+
+     numpy_modalities.append(
+         make_sun_position_numpy_batch(datetimes, lon, lat, key_prefix=key_prefix)
+     )
+
+     # Combine all the modalities and fill NaNs
+     combined_sample = merge_dicts(numpy_modalities)
+     combined_sample = fill_nans_in_arrays(combined_sample)
+
+     return combined_sample
+
+
+ def merge_dicts(list_of_dicts: list[dict]) -> dict:
+     """Merge a list of dictionaries into a single dictionary"""
+     # TODO: This doesn't account for duplicate keys, which will be overwritten
+     combined_dict = {}
+     for d in list_of_dicts:
+         combined_dict.update(d)
+     return combined_dict
+
+
+ def fill_nans_in_arrays(batch: dict) -> dict:
+     """Fills all NaN values in each np.ndarray in the batch dictionary with zeros.
+
+     Operation is performed in-place on the batch.
+     """
+     for k, v in batch.items():
+         if isinstance(v, np.ndarray) and np.issubdtype(v.dtype, np.number):
+             if np.isnan(v).any():
+                 batch[k] = np.nan_to_num(v, copy=False, nan=0.0)
+
+         # Recursion is included to reach NWP arrays in subdict
+         elif isinstance(v, dict):
+             fill_nans_in_arrays(v)
+
+     return batch
+
+
+ def compute(xarray_dict: dict) -> dict:
+     """Eagerly load a nested dictionary of xarray DataArrays"""
+     for k, v in xarray_dict.items():
+         if isinstance(v, dict):
+             xarray_dict[k] = compute(v)
+         else:
+             xarray_dict[k] = v.compute(scheduler="single-threaded")
+     return xarray_dict
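A small illustration of the two combining helpers above (this snippet is not part of the package). merge_dicts silently overwrites duplicate keys, which is the TODO noted in its body, and fill_nans_in_arrays recurses into nested dicts such as the per-provider NWP sub-dictionary:

import numpy as np
from ocf_data_sampler.torch_datasets.process_and_combine import (
    fill_nans_in_arrays,
    merge_dicts,
)

a = {"gsp": np.array([1.0, np.nan])}
b = {"nwp": {"ukv": np.array([np.nan, 2.0])}}

combined = merge_dicts([a, b])          # single dict with both keys
filled = fill_nans_in_arrays(combined)  # NaNs replaced with 0.0, in place

print(filled["gsp"])         # [1. 0.]
print(filled["nwp"]["ukv"])  # [0. 2.]  (reached via the dict recursion)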
ocf_data_sampler/torch_datasets/pvnet_uk_regional.py
@@ -2,100 +2,20 @@
 
  import numpy as np
  import pandas as pd
+ import pkg_resources
  import xarray as xr
  from torch.utils.data import Dataset
- import pkg_resources
-
- from ocf_data_sampler.load.gsp import open_gsp
- from ocf_data_sampler.load.nwp import open_nwp
- from ocf_data_sampler.load.satellite import open_sat_data
-
- from ocf_data_sampler.select.find_contiguous_time_periods import (
-     find_contiguous_t0_periods, find_contiguous_t0_periods_nwp,
-     intersection_of_multiple_dataframes_of_periods,
- )
- from ocf_data_sampler.select.fill_time_periods import fill_time_periods
- from ocf_data_sampler.select.select_time_slice import select_time_slice, select_time_slice_nwp
- from ocf_data_sampler.select.dropout import draw_dropout_time, apply_dropout_time
- from ocf_data_sampler.select.select_spatial_slice import select_spatial_slice_pixels
-
- from ocf_data_sampler.numpy_batch import (
-     convert_gsp_to_numpy_batch,
-     convert_nwp_to_numpy_batch,
-     convert_satellite_to_numpy_batch,
-     make_sun_position_numpy_batch,
- )
-
 
  from ocf_data_sampler.config import Configuration, load_yaml_configuration
- from ocf_data_sampler.numpy_batch.nwp import NWPBatchKey
- from ocf_data_sampler.numpy_batch.gsp import GSPBatchKey
-
- from ocf_data_sampler.select.location import Location
- from ocf_data_sampler.select.geospatial import osgb_to_lon_lat
-
- from ocf_data_sampler.constants import NWP_MEANS, NWP_STDS
-
-
-
+ from ocf_data_sampler.load.load_dataset import get_dataset_dict
+ from ocf_data_sampler.select import fill_time_periods, Location, slice_datasets_by_space, slice_datasets_by_time
+ from ocf_data_sampler.time_functions import minutes
+ from ocf_data_sampler.torch_datasets.process_and_combine import process_and_combine_datasets, compute
+ from ocf_data_sampler.torch_datasets.valid_time_periods import find_valid_time_periods
 
  xr.set_options(keep_attrs=True)
 
 
-
- def minutes(minutes: list[float]):
-     """Timedelta minutes
-
-     Args:
-         m: minutes
-     """
-     return pd.to_timedelta(minutes, unit="m")
-
-
- def get_dataset_dict(config: Configuration) -> dict[xr.DataArray, dict[xr.DataArray]]:
-     """Construct dictionary of all of the input data sources
-
-     Args:
-         config: Configuration file
-     """
-
-     in_config = config.input_data
-
-     datasets_dict = {}
-
-     # Load GSP data unless the path is None
-     if in_config.gsp.gsp_zarr_path:
-         da_gsp = open_gsp(zarr_path=in_config.gsp.gsp_zarr_path).compute()
-
-         # Remove national GSP
-         datasets_dict["gsp"] = da_gsp.sel(gsp_id=slice(1, None))
-
-     # Load NWP data if in config
-     if in_config.nwp:
-
-         datasets_dict["nwp"] = {}
-         for nwp_source, nwp_config in in_config.nwp.items():
-
-             da_nwp = open_nwp(nwp_config.nwp_zarr_path, provider=nwp_config.nwp_provider)
-
-             da_nwp = da_nwp.sel(channel=list(nwp_config.nwp_channels))
-
-             datasets_dict["nwp"][nwp_source] = da_nwp
-
-     # Load satellite data if in config
-     if in_config.satellite:
-         sat_config = config.input_data.satellite
-
-         da_sat = open_sat_data(sat_config.satellite_zarr_path)
-
-         da_sat = da_sat.sel(channel=list(sat_config.satellite_channels))
-
-         datasets_dict["sat"] = da_sat
-
-     return datasets_dict
-
-
-
  def find_valid_t0_times(
      datasets_dict: dict,
      config: Configuration,
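The local minutes helper deleted above now lives in ocf_data_sampler/time_functions.py (item 16, +11 lines); the relocated version is not shown in this diff. Presumably it is essentially the same one-liner, here sketched with the parameter renamed so the docstring's "m" matches (the deleted original documented "m" but named the argument minutes):

import pandas as pd

def minutes(m):
    """Convert minutes to a pandas Timedelta (or TimedeltaIndex for a list).

    Sketch of the relocated helper; the actual file is not shown in this diff.
    """
    return pd.to_timedelta(m, unit="m")

minutes(30)       # Timedelta('0 days 00:30:00')
minutes([0, 60])  # TimedeltaIndex(['0 days 00:00:00', '0 days 01:00:00'], ...)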
@@ -103,96 +23,11 @@ def find_valid_t0_times(
      """Find the t0 times where all of the requested input data is available
 
      Args:
-         datasets_dict: A dictionary of input datasets
+         datasets_dict: A dictionary of input datasets
      config: Configuration file
      """
 
-     assert set(datasets_dict.keys()).issubset({"nwp", "sat", "gsp"})
-
-     contiguous_time_periods: dict[str: pd.DataFrame] = {}  # Used to store contiguous time periods from each data source
-
-     if "nwp" in datasets_dict:
-         for nwp_key, nwp_config in config.input_data.nwp.items():
-
-             da = datasets_dict["nwp"][nwp_key]
-
-             if nwp_config.dropout_timedeltas_minutes is None:
-                 max_dropout = minutes(0)
-             else:
-                 max_dropout = minutes(np.max(np.abs(nwp_config.dropout_timedeltas_minutes)))
-
-             if nwp_config.max_staleness_minutes is None:
-                 max_staleness = None
-             else:
-                 max_staleness = minutes(nwp_config.max_staleness_minutes)
-
-             # The last step of the forecast is lost if we have to diff channels
-             if len(nwp_config.nwp_accum_channels) > 0:
-                 end_buffer = minutes(nwp_config.time_resolution_minutes)
-             else:
-                 end_buffer = minutes(0)
-
-             # This is the max staleness we can use considering the max step of the input data
-             max_possible_staleness = (
-                 pd.Timedelta(da["step"].max().item())
-                 - minutes(nwp_config.forecast_minutes)
-                 - end_buffer
-             )
-
-             # Default to use max possible staleness unless specified in config
-             if max_staleness is None:
-                 max_staleness = max_possible_staleness
-             else:
-                 # Make sure the max acceptable staleness isn't longer than the max possible
-                 assert max_staleness <= max_possible_staleness
-
-             time_periods = find_contiguous_t0_periods_nwp(
-                 datetimes=pd.DatetimeIndex(da["init_time_utc"]),
-                 history_duration=minutes(nwp_config.history_minutes),
-                 max_staleness=max_staleness,
-                 max_dropout=max_dropout,
-             )
-
-             contiguous_time_periods[f'nwp_{nwp_key}'] = time_periods
-
-     if "sat" in datasets_dict:
-         sat_config = config.input_data.satellite
-
-         time_periods = find_contiguous_t0_periods(
-             pd.DatetimeIndex(datasets_dict["sat"]["time_utc"]),
-             sample_period_duration=minutes(sat_config.time_resolution_minutes),
-             history_duration=minutes(sat_config.history_minutes),
-             forecast_duration=minutes(sat_config.forecast_minutes),
-         )
-
-         contiguous_time_periods['sat'] = time_periods
-
-     if "gsp" in datasets_dict:
-         gsp_config = config.input_data.gsp
-
-         time_periods = find_contiguous_t0_periods(
-             pd.DatetimeIndex(datasets_dict["gsp"]["time_utc"]),
-             sample_period_duration=minutes(gsp_config.time_resolution_minutes),
-             history_duration=minutes(gsp_config.history_minutes),
-             forecast_duration=minutes(gsp_config.forecast_minutes),
-         )
-
-         contiguous_time_periods['gsp'] = time_periods
-
-     # just get the values (not the keys)
-     contiguous_time_periods_values = list(contiguous_time_periods.values())
-
-     # Find joint overlapping contiguous time periods
-     if len(contiguous_time_periods_values) > 1:
-         valid_time_periods = intersection_of_multiple_dataframes_of_periods(
-             contiguous_time_periods_values
-         )
-     else:
-         valid_time_periods = contiguous_time_periods_values[0]
-
-     # check there are some valid time periods
-     if len(valid_time_periods) == 0:
-         raise ValueError(f"No valid time periods found, {contiguous_time_periods=}")
+     valid_time_periods = find_valid_time_periods(datasets_dict, config)
 
      # Fill out the contiguous time periods to get the t0 times
      valid_t0_times = fill_time_periods(
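The roughly 85 deleted lines above now live in find_valid_time_periods (ocf_data_sampler/torch_datasets/valid_time_periods.py, item 20). Their core step was intersecting each source's contiguous [start, end] periods; a toy illustration of that intersection logic in plain pandas (the start_dt/end_dt column names are assumptions for illustration, not necessarily the package's schema):

import pandas as pd

# One contiguous period per source (column names assumed)
sat = pd.DataFrame({"start_dt": [pd.Timestamp("2024-01-01 00:00")],
                    "end_dt":   [pd.Timestamp("2024-01-01 12:00")]})
nwp = pd.DataFrame({"start_dt": [pd.Timestamp("2024-01-01 06:00")],
                    "end_dt":   [pd.Timestamp("2024-01-01 18:00")]})

# Intersection of two periods: latest start, earliest end, kept if non-empty
start = max(sat["start_dt"].iloc[0], nwp["start_dt"].iloc[0])
end = min(sat["end_dt"].iloc[0], nwp["end_dt"].iloc[0])
if start <= end:
    print(start, end)  # valid t0 times may then be drawn from 06:00 to 12:00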
@@ -203,250 +38,6 @@ def find_valid_t0_times(
      return valid_t0_times
 
 
- def slice_datasets_by_space(
-     datasets_dict: dict,
-     location: Location,
-     config: Configuration,
- ) -> dict:
-     """Slice a dictionaries of input data sources around a given location
-
-     Args:
-         datasets_dict: Dictionary of the input data sources
-         location: The location to sample around
-         config: Configuration object.
-     """
-
-     assert set(datasets_dict.keys()).issubset({"nwp", "sat", "gsp"})
-
-     sliced_datasets_dict = {}
-
-     if "nwp" in datasets_dict:
-
-         sliced_datasets_dict["nwp"] = {}
-
-         for nwp_key, nwp_config in config.input_data.nwp.items():
-
-             sliced_datasets_dict["nwp"][nwp_key] = select_spatial_slice_pixels(
-                 datasets_dict["nwp"][nwp_key],
-                 location,
-                 height_pixels=nwp_config.nwp_image_size_pixels_height,
-                 width_pixels=nwp_config.nwp_image_size_pixels_width,
-             )
-
-     if "sat" in datasets_dict:
-         sat_config = config.input_data.satellite
-
-         sliced_datasets_dict["sat"] = select_spatial_slice_pixels(
-             datasets_dict["sat"],
-             location,
-             height_pixels=sat_config.satellite_image_size_pixels_height,
-             width_pixels=sat_config.satellite_image_size_pixels_width,
-         )
-
-     if "gsp" in datasets_dict:
-         sliced_datasets_dict["gsp"] = datasets_dict["gsp"].sel(gsp_id=location.id)
-
-     return sliced_datasets_dict
-
-
- def slice_datasets_by_time(
-     datasets_dict: dict,
-     t0: pd.Timedelta,
-     config: Configuration,
- ) -> dict:
-     """Slice a dictionaries of input data sources around a given t0 time
-
-     Args:
-         datasets_dict: Dictionary of the input data sources
-         t0: The init-time
-         config: Configuration object.
-     """
-
-     sliced_datasets_dict = {}
-
-     if "nwp" in datasets_dict:
-
-         sliced_datasets_dict["nwp"] = {}
-
-         for nwp_key, da_nwp in datasets_dict["nwp"].items():
-
-             nwp_config = config.input_data.nwp[nwp_key]
-
-             sliced_datasets_dict["nwp"][nwp_key] = select_time_slice_nwp(
-                 da_nwp,
-                 t0,
-                 sample_period_duration=minutes(nwp_config.time_resolution_minutes),
-                 history_duration=minutes(nwp_config.history_minutes),
-                 forecast_duration=minutes(nwp_config.forecast_minutes),
-                 dropout_timedeltas=minutes(nwp_config.dropout_timedeltas_minutes),
-                 dropout_frac=nwp_config.dropout_fraction,
-                 accum_channels=nwp_config.nwp_accum_channels,
-             )
-
-     if "sat" in datasets_dict:
-
-         sat_config = config.input_data.satellite
-
-         sliced_datasets_dict["sat"] = select_time_slice(
-             datasets_dict["sat"],
-             t0,
-             sample_period_duration=minutes(sat_config.time_resolution_minutes),
-             interval_start=minutes(-sat_config.history_minutes),
-             interval_end=minutes(-sat_config.live_delay_minutes),
-             max_steps_gap=2,
-         )
-
-         # Randomly sample dropout
-         sat_dropout_time = draw_dropout_time(
-             t0,
-             dropout_timedeltas=minutes(sat_config.dropout_timedeltas_minutes),
-             dropout_frac=sat_config.dropout_fraction,
-         )
-
-         # Apply the dropout
-         sliced_datasets_dict["sat"] = apply_dropout_time(
-             sliced_datasets_dict["sat"],
-             sat_dropout_time,
-         )
-
-     if "gsp" in datasets_dict:
-         gsp_config = config.input_data.gsp
-
-         sliced_datasets_dict["gsp_future"] = select_time_slice(
-             datasets_dict["gsp"],
-             t0,
-             sample_period_duration=minutes(gsp_config.time_resolution_minutes),
-             interval_start=minutes(30),
-             interval_end=minutes(gsp_config.forecast_minutes),
-         )
-
-         sliced_datasets_dict["gsp"] = select_time_slice(
-             datasets_dict["gsp"],
-             t0,
-             sample_period_duration=minutes(gsp_config.time_resolution_minutes),
-             interval_start=-minutes(gsp_config.history_minutes),
-             interval_end=minutes(0),
-         )
-
-         # Dropout on the GSP, but not the future GSP
-         gsp_dropout_time = draw_dropout_time(
-             t0,
-             dropout_timedeltas=minutes(gsp_config.dropout_timedeltas_minutes),
-             dropout_frac=gsp_config.dropout_fraction,
-         )
-
-         sliced_datasets_dict["gsp"] = apply_dropout_time(sliced_datasets_dict["gsp"], gsp_dropout_time)
-
-     return sliced_datasets_dict
-
-
- def fill_nans_in_arrays(batch: dict) -> dict:
-     """Fills all NaN values in each np.ndarray in the batch dictionary with zeros.
-
-     Operation is performed in-place on the batch.
-     """
-     for k, v in batch.items():
-         if isinstance(v, np.ndarray) and np.issubdtype(v.dtype, np.number):
-             if np.isnan(v).any():
-                 batch[k] = np.nan_to_num(v, copy=False, nan=0.0)
-
-         # Recursion is included to reach NWP arrays in subdict
-         elif isinstance(v, dict):
-             fill_nans_in_arrays(v)
-
-     return batch
-
-
-
- def merge_dicts(list_of_dicts: list[dict]) -> dict:
-     """Merge a list of dictionaries into a single dictionary"""
-     # TODO: This doesn't account for duplicate keys, which will be overwritten
-     combined_dict = {}
-     for d in list_of_dicts:
-         combined_dict.update(d)
-     return combined_dict
-
-
- def process_and_combine_datasets(
-     dataset_dict: dict,
-     config: Configuration,
-     t0: pd.Timedelta,
-     location: Location,
- ) -> dict:
-     """Normalize and convert data to numpy arrays"""
-
-     numpy_modalities = []
-
-     if "nwp" in dataset_dict:
-
-         nwp_numpy_modalities = dict()
-
-         for nwp_key, da_nwp in dataset_dict["nwp"].items():
-             # Standardise
-             provider = config.input_data.nwp[nwp_key].nwp_provider
-             da_nwp = (da_nwp - NWP_MEANS[provider]) / NWP_STDS[provider]
-             # Convert to NumpyBatch
-             nwp_numpy_modalities[nwp_key] = convert_nwp_to_numpy_batch(da_nwp)
-
-         # Combine the NWPs into NumpyBatch
-         numpy_modalities.append({NWPBatchKey.nwp: nwp_numpy_modalities})
-
-     if "sat" in dataset_dict:
-         # Satellite is already in the range [0-1] so no need to standardise
-         da_sat = dataset_dict["sat"]
-
-         # Convert to NumpyBatch
-         numpy_modalities.append(convert_satellite_to_numpy_batch(da_sat))
-
-     gsp_config = config.input_data.gsp
-
-     if "gsp" in dataset_dict:
-         da_gsp = xr.concat([dataset_dict["gsp"], dataset_dict["gsp_future"]], dim="time_utc")
-         da_gsp = da_gsp / da_gsp.effective_capacity_mwp
-
-         numpy_modalities.append(
-             convert_gsp_to_numpy_batch(
-                 da_gsp,
-                 t0_idx=gsp_config.history_minutes / gsp_config.time_resolution_minutes
-             )
-         )
-
-     # Make sun coords NumpyBatch
-     datetimes = pd.date_range(
-         t0-minutes(gsp_config.history_minutes),
-         t0+minutes(gsp_config.forecast_minutes),
-         freq=minutes(gsp_config.time_resolution_minutes),
-     )
-
-     lon, lat = osgb_to_lon_lat(location.x, location.y)
-
-     numpy_modalities.append(make_sun_position_numpy_batch(datetimes, lon, lat))
-
-     # Add coordinate data
-     # TODO: Do we need all of these?
-     numpy_modalities.append({
-         GSPBatchKey.gsp_id: location.id,
-         GSPBatchKey.gsp_x_osgb: location.x,
-         GSPBatchKey.gsp_y_osgb: location.y,
-     })
-
-     # Combine all the modalities and fill NaNs
-     combined_sample = merge_dicts(numpy_modalities)
-     combined_sample = fill_nans_in_arrays(combined_sample)
-
-     return combined_sample
-
-
- def compute(xarray_dict: dict) -> dict:
-     """Eagerly load a nested dictionary of xarray DataArrays"""
-     for k, v in xarray_dict.items():
-         if isinstance(v, dict):
-             xarray_dict[k] = compute(v)
-         else:
-             xarray_dict[k] = v.compute(scheduler="single-threaded")
-     return xarray_dict
-
-
  def get_gsp_locations(gsp_ids: list[int] | None = None) -> list[Location]:
      """Get list of locations of all GSPs"""
 
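One behavioural fix worth noting in the deleted block above: the old process_and_combine_datasets computed the GSP t0_idx with true division, while the replacement in process_and_combine.py (first hunk of this diff) uses floor division, so the index is now an int rather than a float (the new site branch, by contrast, still uses /):

history_minutes, time_resolution_minutes = 120, 30  # example values
history_minutes / time_resolution_minutes    # 4.0  (float, old code)
history_minutes // time_resolution_minutes   # 4    (int, new code)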
@@ -473,7 +64,6 @@ def get_gsp_locations(gsp_ids: list[int] | None = None) -> list[Location]:
      return locations
 
 
-
  class PVNetUKRegionalDataset(Dataset):
      def __init__(
          self,