ocf-data-sampler 0.1.10__py3-none-any.whl → 0.1.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of ocf-data-sampler might be problematic.

Files changed (77)
  1. ocf_data_sampler/config/load.py +3 -3
  2. ocf_data_sampler/config/model.py +86 -72
  3. ocf_data_sampler/config/save.py +5 -4
  4. ocf_data_sampler/constants.py +140 -12
  5. ocf_data_sampler/load/gsp.py +6 -5
  6. ocf_data_sampler/load/load_dataset.py +5 -6
  7. ocf_data_sampler/load/nwp/nwp.py +17 -5
  8. ocf_data_sampler/load/nwp/providers/ecmwf.py +6 -7
  9. ocf_data_sampler/load/nwp/providers/gfs.py +36 -0
  10. ocf_data_sampler/load/nwp/providers/icon.py +46 -0
  11. ocf_data_sampler/load/nwp/providers/ukv.py +4 -5
  12. ocf_data_sampler/load/nwp/providers/utils.py +3 -1
  13. ocf_data_sampler/load/satellite.py +27 -36
  14. ocf_data_sampler/load/site.py +11 -7
  15. ocf_data_sampler/load/utils.py +21 -16
  16. ocf_data_sampler/numpy_sample/collate.py +10 -9
  17. ocf_data_sampler/numpy_sample/datetime_features.py +3 -5
  18. ocf_data_sampler/numpy_sample/gsp.py +15 -13
  19. ocf_data_sampler/numpy_sample/nwp.py +17 -23
  20. ocf_data_sampler/numpy_sample/satellite.py +17 -14
  21. ocf_data_sampler/numpy_sample/site.py +8 -7
  22. ocf_data_sampler/numpy_sample/sun_position.py +19 -25
  23. ocf_data_sampler/sample/__init__.py +0 -7
  24. ocf_data_sampler/sample/base.py +23 -44
  25. ocf_data_sampler/sample/site.py +25 -69
  26. ocf_data_sampler/sample/uk_regional.py +52 -103
  27. ocf_data_sampler/select/dropout.py +42 -27
  28. ocf_data_sampler/select/fill_time_periods.py +15 -3
  29. ocf_data_sampler/select/find_contiguous_time_periods.py +87 -75
  30. ocf_data_sampler/select/geospatial.py +63 -54
  31. ocf_data_sampler/select/location.py +16 -51
  32. ocf_data_sampler/select/select_spatial_slice.py +105 -89
  33. ocf_data_sampler/select/select_time_slice.py +71 -58
  34. ocf_data_sampler/select/spatial_slice_for_dataset.py +7 -6
  35. ocf_data_sampler/select/time_slice_for_dataset.py +17 -16
  36. ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py +126 -118
  37. ocf_data_sampler/torch_datasets/datasets/site.py +135 -101
  38. ocf_data_sampler/torch_datasets/utils/merge_and_fill_utils.py +6 -2
  39. ocf_data_sampler/torch_datasets/utils/valid_time_periods.py +23 -22
  40. ocf_data_sampler/torch_datasets/utils/validate_channels.py +23 -19
  41. ocf_data_sampler/utils.py +3 -1
  42. {ocf_data_sampler-0.1.10.dist-info → ocf_data_sampler-0.1.16.dist-info}/METADATA +7 -18
  43. ocf_data_sampler-0.1.16.dist-info/RECORD +56 -0
  44. {ocf_data_sampler-0.1.10.dist-info → ocf_data_sampler-0.1.16.dist-info}/WHEEL +1 -1
  45. {ocf_data_sampler-0.1.10.dist-info → ocf_data_sampler-0.1.16.dist-info}/top_level.txt +1 -1
  46. scripts/refactor_site.py +62 -33
  47. utils/compute_icon_mean_stddev.py +72 -0
  48. ocf_data_sampler-0.1.10.dist-info/LICENSE +0 -21
  49. ocf_data_sampler-0.1.10.dist-info/RECORD +0 -82
  50. tests/__init__.py +0 -0
  51. tests/config/test_config.py +0 -113
  52. tests/config/test_load.py +0 -7
  53. tests/config/test_save.py +0 -28
  54. tests/conftest.py +0 -286
  55. tests/load/test_load_gsp.py +0 -15
  56. tests/load/test_load_nwp.py +0 -21
  57. tests/load/test_load_satellite.py +0 -17
  58. tests/load/test_load_sites.py +0 -14
  59. tests/numpy_sample/test_collate.py +0 -21
  60. tests/numpy_sample/test_datetime_features.py +0 -37
  61. tests/numpy_sample/test_gsp.py +0 -38
  62. tests/numpy_sample/test_nwp.py +0 -52
  63. tests/numpy_sample/test_satellite.py +0 -40
  64. tests/numpy_sample/test_sun_position.py +0 -81
  65. tests/select/test_dropout.py +0 -75
  66. tests/select/test_fill_time_periods.py +0 -28
  67. tests/select/test_find_contiguous_time_periods.py +0 -202
  68. tests/select/test_location.py +0 -67
  69. tests/select/test_select_spatial_slice.py +0 -154
  70. tests/select/test_select_time_slice.py +0 -275
  71. tests/test_sample/test_base.py +0 -164
  72. tests/test_sample/test_site_sample.py +0 -195
  73. tests/test_sample/test_uk_regional_sample.py +0 -163
  74. tests/torch_datasets/test_merge_and_fill_utils.py +0 -40
  75. tests/torch_datasets/test_pvnet_uk.py +0 -167
  76. tests/torch_datasets/test_site.py +0 -226
  77. tests/torch_datasets/test_validate_channels_utils.py +0 -78

ocf_data_sampler/torch_datasets/datasets/site.py

@@ -1,59 +1,62 @@
-"""Torch dataset for sites"""
+"""Torch dataset for sites."""
 
 import logging
+
 import numpy as np
 import pandas as pd
 import xarray as xr
-from typing import Tuple
-
 from torch.utils.data import Dataset
+from typing_extensions import override
 
 from ocf_data_sampler.config import Configuration, load_yaml_configuration
+from ocf_data_sampler.constants import NWP_MEANS, NWP_STDS, RSS_MEAN, RSS_STD
 from ocf_data_sampler.load.load_dataset import get_dataset_dict
+from ocf_data_sampler.numpy_sample import (
+    NWPSampleKey,
+    convert_nwp_to_numpy_sample,
+    convert_satellite_to_numpy_sample,
+    convert_site_to_numpy_sample,
+    make_datetime_numpy_dict,
+    make_sun_position_numpy_sample,
+)
 from ocf_data_sampler.select import (
     Location,
     fill_time_periods,
     find_contiguous_t0_periods,
     intersection_of_multiple_dataframes_of_periods,
-    slice_datasets_by_time, slice_datasets_by_space
+    slice_datasets_by_space,
+    slice_datasets_by_time,
 )
-from ocf_data_sampler.utils import minutes
-from ocf_data_sampler.torch_datasets.utils.valid_time_periods import find_valid_time_periods
-from ocf_data_sampler.torch_datasets.utils.merge_and_fill_utils import merge_dicts, fill_nans_in_arrays
-
-from ocf_data_sampler.numpy_sample import (
-    convert_site_to_numpy_sample,
-    convert_satellite_to_numpy_sample,
-    convert_nwp_to_numpy_sample,
-    make_datetime_numpy_dict,
-    make_sun_position_numpy_sample,
+from ocf_data_sampler.torch_datasets.utils.merge_and_fill_utils import (
+    fill_nans_in_arrays,
+    merge_dicts,
 )
-from ocf_data_sampler.numpy_sample import NWPSampleKey
-from ocf_data_sampler.constants import NWP_MEANS, NWP_STDS, RSS_MEAN, RSS_STD
-
+from ocf_data_sampler.torch_datasets.utils.valid_time_periods import find_valid_time_periods
 from ocf_data_sampler.torch_datasets.utils.validate_channels import (
     validate_nwp_channels,
     validate_satellite_channels,
 )
+from ocf_data_sampler.utils import minutes
 
 xr.set_options(keep_attrs=True)
 
 
 class SitesDataset(Dataset):
+    """A torch Dataset for creating PVNet Site samples."""
+
     def __init__(
         self,
         config_filename: str,
         start_time: str | None = None,
         end_time: str | None = None,
-    ):
-        """A torch Dataset for creating PVNet Site samples
+    ) -> None:
+        """A torch Dataset for creating PVNet Site samples.
 
         Args:
             config_filename: Path to the configuration file
             start_time: Limit the init-times to be after this
             end_time: Limit the init-times to be before this
         """
-
         config: Configuration = load_yaml_configuration(config_filename)
         validate_nwp_channels(config)
         validate_satellite_channels(config)

@@ -65,28 +68,31 @@ class SitesDataset(Dataset):
         self.config = config
 
         # get all locations
-        self.locations = self.get_locations(datasets_dict['site'])
+        self.locations = self.get_locations(datasets_dict["site"])
 
         # Get t0 times where all input data is available
         valid_t0_and_site_ids = self.find_valid_t0_and_site_ids(datasets_dict)
 
         # Filter t0 times to given range
         if start_time is not None:
-            valid_t0_and_site_ids \
-                = valid_t0_and_site_ids[valid_t0_and_site_ids['t0'] >= pd.Timestamp(start_time)]
+            valid_t0_and_site_ids = valid_t0_and_site_ids[
+                valid_t0_and_site_ids["t0"] >= pd.Timestamp(start_time)
+            ]
 
         if end_time is not None:
-            valid_t0_and_site_ids \
-                = valid_t0_and_site_ids[valid_t0_and_site_ids['t0'] <= pd.Timestamp(end_time)]
+            valid_t0_and_site_ids = valid_t0_and_site_ids[
+                valid_t0_and_site_ids["t0"] <= pd.Timestamp(end_time)
+            ]
 
         # Assign coords and indices to self
         self.valid_t0_and_site_ids = valid_t0_and_site_ids
 
-    def __len__(self):
+    @override
+    def __len__(self) -> int:
         return len(self.valid_t0_and_site_ids)
-
-    def __getitem__(self, idx):
 
+    @override
+    def __getitem__(self, idx: int) -> dict:
        # Get the coordinates of the sample
        t0, site_id = self.valid_t0_and_site_ids.iloc[idx]

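Note: a minimal usage sketch of the reworked dataset, for orientation (the config path and date bounds below are placeholders, not taken from this diff):

```python
from torch.utils.data import DataLoader

from ocf_data_sampler.torch_datasets.datasets.site import SitesDataset

# "site_config.yaml" is a hypothetical path to a valid sampler configuration
dataset = SitesDataset(
    "site_config.yaml",
    start_time="2023-01-01",
    end_time="2023-12-31",
)

# batch_size=None disables collation, so each item is one (t0, site_id) sample
loader = DataLoader(dataset, batch_size=None)
sample = next(iter(loader))
```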
@@ -97,7 +103,7 @@ class SitesDataset(Dataset):
         return self._get_sample(t0, location)
 
     def _get_sample(self, t0: pd.Timestamp, location: Location) -> dict:
-        """Generate the PVNet sample for given coordinates
+        """Generate the PVNet sample for given coordinates.
 
         Args:
             t0: init-time for sample

@@ -106,7 +112,7 @@ class SitesDataset(Dataset):
         sample_dict = slice_datasets_by_space(self.datasets_dict, location, self.config)
         sample_dict = slice_datasets_by_time(sample_dict, t0, self.config)
 
-        sample = self.process_and_combine_site_sample_dict(sample_dict)
+        sample = self.process_and_combine_site_sample_dict(sample_dict, t0)
         sample = sample.compute()
         return sample
 

@@ -119,20 +125,20 @@ class SitesDataset(Dataset):
             t0: init-time for sample
             site_id: site id as int
         """
-
         location = self.get_location_from_site_id(site_id)
 
         return self._get_sample(t0, location)
-
-    def get_location_from_site_id(self, site_id):
-        """Get location from system id"""
 
+    def get_location_from_site_id(self, site_id: int) -> Location:
+        """Get location from system id."""
         locations = [loc for loc in self.locations if loc.id == site_id]
         if len(locations) == 0:
             raise ValueError(f"Location not found for site_id {site_id}")
 
         if len(locations) > 1:
-            logging.warning(f"Multiple locations found for site_id {site_id}, but will take the first")
+            logging.warning(
+                f"Multiple locations found for site_id {site_id}, but will take the first",
+            )
 
         return locations[0]
 

@@ -140,7 +146,7 @@ class SitesDataset(Dataset):
         self,
         datasets_dict: dict,
     ) -> pd.DataFrame:
-        """Find the t0 times where all of the requested input data is available
+        """Find the t0 times where all of the requested input data is available.
 
         The idea is to
         1. Get valid time period for nwp and satellite

@@ -150,9 +156,8 @@ class SitesDataset(Dataset):
             datasets_dict: A dictionary of input datasets
             config: Configuration file
         """
-
         # 1. Get valid time period for nwp and satellite
-        datasets_without_site = {k:v for k, v in datasets_dict.items() if k!="site"}
+        datasets_without_site = {k: v for k, v in datasets_dict.items() if k != "site"}
         valid_time_periods = find_valid_time_periods(datasets_without_site, self.config)
 
         # 2. Now lets loop over each location in system id and find the valid periods

@@ -166,39 +171,37 @@ class SitesDataset(Dataset):
 
             # drop any nan values
             # not sure this is right?
-            site = site.dropna(dim='time_utc')
+            site = site.dropna(dim="time_utc")
 
             # Get the valid time periods for this location
             time_periods = find_contiguous_t0_periods(
                 pd.DatetimeIndex(site["time_utc"]),
-                sample_period_duration=minutes(site_config.time_resolution_minutes),
+                time_resolution=minutes(site_config.time_resolution_minutes),
                 interval_start=minutes(site_config.interval_start_minutes),
                 interval_end=minutes(site_config.interval_end_minutes),
             )
             valid_time_periods_per_site = intersection_of_multiple_dataframes_of_periods(
-                [valid_time_periods, time_periods]
+                [valid_time_periods, time_periods],
             )
 
             # Fill out the contiguous time periods to get the t0 times
             valid_t0_times_per_site = fill_time_periods(
                 valid_time_periods_per_site,
-                freq=minutes(site_config.time_resolution_minutes)
+                freq=minutes(site_config.time_resolution_minutes),
             )
 
             valid_t0_per_site = pd.DataFrame(index=valid_t0_times_per_site)
-            valid_t0_per_site['site_id'] = site_id
+            valid_t0_per_site["site_id"] = site_id
             valid_t0_and_site_ids.append(valid_t0_per_site)
 
         valid_t0_and_site_ids = pd.concat(valid_t0_and_site_ids)
-        valid_t0_and_site_ids.index.name = 't0'
+        valid_t0_and_site_ids.index.name = "t0"
         valid_t0_and_site_ids.reset_index(inplace=True)
 
         return valid_t0_and_site_ids
 
-
-    def get_locations(self, site_xr: xr.Dataset):
-        """Get list of locations of all sites"""
-
+    def get_locations(self, site_xr: xr.Dataset) -> list[Location]:
+        """Get list of locations of all sites."""
        locations = []
        for site_id in site_xr.site_id.values:
            site = site_xr.sel(site_id=site_id)
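Note: `find_valid_t0_and_site_ids` returns a plain DataFrame pairing each valid init-time with a site, which `__getitem__` indexes into. A toy illustration of the shape (values hypothetical):

```python
import pandas as pd

# One row per valid (t0, site) pair; column names as set in the diff above
valid_t0_and_site_ids = pd.DataFrame({
    "t0": pd.to_datetime(["2023-06-01 09:00", "2023-06-01 09:30", "2023-06-01 09:00"]),
    "site_id": [1, 1, 2],
})
# __getitem__ then unpacks a row: t0, site_id = valid_t0_and_site_ids.iloc[idx]
```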
@@ -206,7 +209,7 @@ class SitesDataset(Dataset):
                 id=site_id,
                 x=site.longitude.values,
                 y=site.latitude.values,
-                coordinate_system="lon_lat"
+                coordinate_system="lon_lat",
             )
             locations.append(location)
 

@@ -215,29 +218,29 @@ class SitesDataset(Dataset):
     def process_and_combine_site_sample_dict(
         self,
         dataset_dict: dict,
+        t0: pd.Timestamp,
     ) -> xr.Dataset:
-        """
-        Normalize and combine data into a single xr Dataset
+        """Normalize and combine data into a single xr Dataset.
 
         Args:
             dataset_dict: dict containing sliced xr DataArrays
             config: Configuration for the model
+            t0: The initial timestamp of the sample
 
         Returns:
             xr.Dataset: A merged Dataset with nans filled in.
-
-        """
 
+        """
         data_arrays = []
 
         if "nwp" in dataset_dict:
             for nwp_key, da_nwp in dataset_dict["nwp"].items():
                 provider = self.config.input_data.nwp[nwp_key].provider
-
+
                 # Standardise
                 da_nwp = (da_nwp - NWP_MEANS[provider]) / NWP_STDS[provider]
                 data_arrays.append((f"nwp-{provider}", da_nwp))
-
+
         if "sat" in dataset_dict:
             da_sat = dataset_dict["sat"]
 

@@ -257,33 +260,57 @@ class SitesDataset(Dataset):
         datetimes = pd.DatetimeIndex(combined_sample_dataset.site__time_utc.values)
         datetime_features = make_datetime_numpy_dict(datetimes=datetimes, key_prefix="site_")
         combined_sample_dataset = combined_sample_dataset.assign_coords(
-            {k: ("site__time_utc", v) for k, v in datetime_features.items()}
+            {k: ("site__time_utc", v) for k, v in datetime_features.items()},
         )
 
-        # add sun features
-        sun_position_features = make_sun_position_numpy_sample(
-            datetimes=datetimes,
-            lon=combined_sample_dataset.site__longitude.values,
-            lat=combined_sample_dataset.site__latitude.values,
-            key_prefix="site_",
-        )
-        combined_sample_dataset = combined_sample_dataset.assign_coords(
-            {k: ("site__time_utc", v) for k, v in sun_position_features.items()}
+        # Only add solar position if explicitly configured
+        has_solar_config = (
+            hasattr(self.config.input_data, "solar_position") and
+            self.config.input_data.solar_position is not None
         )
 
-        # TODO include t0_index in xr dataset?
+        if has_solar_config:
+            solar_config = self.config.input_data.solar_position
+
+            # Datetime range - solar config params
+            solar_datetimes = pd.date_range(
+                t0 + minutes(solar_config.interval_start_minutes),
+                t0 + minutes(solar_config.interval_end_minutes),
+                freq=minutes(solar_config.time_resolution_minutes),
+            )
+
+            # Calculate sun position features
+            sun_position_features = make_sun_position_numpy_sample(
+                datetimes=solar_datetimes,
+                lon=combined_sample_dataset.site__longitude.values,
+                lat=combined_sample_dataset.site__latitude.values,
+            )
+
+            # Dimension state for solar position data
+            solar_dim_name = "solar_time_utc"
+            combined_sample_dataset = combined_sample_dataset.assign_coords(
+                {solar_dim_name: solar_datetimes},
+            )
+
+            # Assign solar position values
+            for key, values in sun_position_features.items():
+                combined_sample_dataset = combined_sample_dataset.assign_coords(
+                    {key: (solar_dim_name, values)},
+                )
+
+        # TODO include t0_index in xr dataset?
 
         # Fill any nan values
         return combined_sample_dataset.fillna(0.0)
 
     def merge_data_arrays(
-        self, normalised_data_arrays: list[Tuple[str, xr.DataArray]]
+        self,
+        normalised_data_arrays: list[tuple[str, xr.DataArray]],
     ) -> xr.Dataset:
-        """
-        Combine a list of DataArrays into a single Dataset with unique naming conventions.
+        """Combine a list of DataArrays into a single Dataset with unique naming conventions.
 
         Args:
-            list_of_arrays: List of tuples where each tuple contains:
+            normalised_data_arrays: List of tuples where each tuple contains:
                 - A string (key name).
                 - An xarray.DataArray.
 
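Note: a worked example of the new solar-position window. The `interval_start_minutes`, `interval_end_minutes`, and `time_resolution_minutes` fields are the `solar_position` config parameters read in the diff above; the numbers are hypothetical, and `minutes` is re-implemented locally on the assumption that it returns a `pd.Timedelta`:

```python
import pandas as pd

def minutes(m: int) -> pd.Timedelta:
    # Assumed to behave like ocf_data_sampler.utils.minutes
    return pd.Timedelta(minutes=m)

t0 = pd.Timestamp("2023-06-01 12:00")
# Hypothetical config: 30 min before t0 to 60 min after, at 30 min resolution
solar_datetimes = pd.date_range(t0 + minutes(-30), t0 + minutes(60), freq=minutes(30))
print(list(solar_datetimes))
# [2023-06-01 11:30, 2023-06-01 12:00, 2023-06-01 12:30, 2023-06-01 13:00]
```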
@@ -295,7 +322,7 @@ class SitesDataset(Dataset):
         for key, data_array in normalised_data_arrays:
             # Ensure all attributes are strings for consistency
             data_array = data_array.assign_attrs(
-                {attr_key: str(attr_value) for attr_key, attr_value in data_array.attrs.items()}
+                {attr_key: str(attr_value) for attr_key, attr_value in data_array.attrs.items()},
             )
 
             # Convert DataArray to Dataset with the variable name as the key

@@ -303,15 +330,16 @@ class SitesDataset(Dataset):
 
             # Prepend key name to all dimension and coordinate names for uniqueness
             dataset = dataset.rename(
-                {dim: f"{key}__{dim}" for dim in dataset.dims if dim not in dataset.coords}
+                {dim: f"{key}__{dim}" for dim in dataset.dims if dim not in dataset.coords},
             )
             dataset = dataset.rename(
-                {coord: f"{key}__{coord}" for coord in dataset.coords}
+                {coord: f"{key}__{coord}" for coord in dataset.coords},
             )
 
             # Handle concatenation dimension if applicable
             concat_dim = (
-                f"{key}__target_time_utc" if f"{key}__target_time_utc" in dataset.coords
+                f"{key}__target_time_utc"
+                if f"{key}__target_time_utc" in dataset.coords
                 else f"{key}__time_utc"
             )
 

@@ -325,20 +353,22 @@ class SitesDataset(Dataset):
 
         # Ensure all datasets are valid xarray.Dataset objects
         for ds in datasets:
-            assert isinstance(ds, xr.Dataset), f"Object is not an xr.Dataset: {type(ds)}"
+            if not isinstance(ds, xr.Dataset):
+                raise ValueError(f"Object is not an xr.Dataset: {type(ds)}")
 
         # Merge all prepared datasets
         combined_dataset = xr.merge(datasets)
 
         return combined_dataset
 
+
 # ----- functions to load presaved samples ------
 
-def convert_netcdf_to_numpy_sample(ds: xr.Dataset) -> dict:
-    """Convert a netcdf dataset to a numpy sample"""
 
+def convert_netcdf_to_numpy_sample(ds: xr.Dataset) -> dict:
+    """Convert a netcdf dataset to a numpy sample."""
     # convert the single dataset to a dict of arrays
-    sample_dict = convert_from_dataset_to_dict_datasets(ds)
+    sample_dict = convert_from_dataset_to_dict_datasets(ds)
 
     if "satellite" in sample_dict:
         # rename satellite to satellite actual # TODO this could be improves
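Note: the `{key}__{dim}` prefixing above is what lets `convert_from_dataset_to_dict_datasets` (further down) split a combined sample back into per-input datasets. A small sketch of the convention with toy data:

```python
import numpy as np
import pandas as pd
import xarray as xr

da = xr.DataArray(
    np.zeros(3),
    coords={"time_utc": pd.date_range("2023-06-01", periods=3, freq="30min")},
    dims="time_utc",
    name="site",
)
# Prefix dims/coords with the source key, as merge_data_arrays does
ds = da.to_dataset().rename({"time_utc": "site__time_utc"})
print(list(ds.coords))  # ['site__time_utc']
```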
@@ -349,14 +379,21 @@ def convert_netcdf_to_numpy_sample(ds: xr.Dataset) -> dict:
         dataset_dict=sample_dict,
     )
 
-    # TODO think about normalization, maybe its done not in sample creation, maybe its done afterwards,
-    # to allow it to be flexible
+    # Extraction of solar position coords
+    solar_keys = ["solar_azimuth", "solar_elevation"]
+    for key in solar_keys:
+        if key in ds.coords:
+            sample[key] = ds.coords[key].values
+
+    # TODO think about normalization:
+    # * maybe its done not in sample creation, maybe its done afterwards,
+    #   to allow it to be flexible
 
     return sample
 
+
 def convert_from_dataset_to_dict_datasets(combined_dataset: xr.Dataset) -> dict[str, xr.DataArray]:
-    """
-    Convert a combined sample dataset to a dict of datasets for each input
+    """Convert a combined sample dataset to a dict of datasets for each input.
 
     Args:
         combined_dataset: The combined NetCDF dataset
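Note: a minimal sketch of the presaved-sample path shown above (the filename is a placeholder for a sample previously written to NetCDF by this pipeline):

```python
import xarray as xr

from ocf_data_sampler.torch_datasets.datasets.site import convert_netcdf_to_numpy_sample

ds = xr.open_dataset("sample.nc")  # hypothetical presaved sample
numpy_sample = convert_netcdf_to_numpy_sample(ds)

# From 0.1.16, solar position coords are carried over when present in the file
has_solar = "solar_azimuth" in numpy_sample
```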
@@ -374,10 +411,10 @@ def convert_from_dataset_to_dict_datasets(combined_dataset: xr.Dataset) -> dict[
             if f"{key}__" not in dim:
                 dataset: xr.Dataset = dataset.drop(dim)
         dataset = dataset.rename(
-            {dim: dim.split(f"{key}__")[1] for dim in dataset.dims if dim not in dataset.coords}
+            {dim: dim.split(f"{key}__")[1] for dim in dataset.dims if dim not in dataset.coords},
         )
         dataset: xr.Dataset = dataset.rename(
-            {coord: coord.split(f"{key}__")[1] for coord in dataset.coords}
+            {coord: coord.split(f"{key}__")[1] for coord in dataset.coords},
         )
         # Split the dataset by the prefix
         datasets[key] = dataset

@@ -391,22 +428,21 @@ def nest_nwp_source_dict(d: dict, sep: str = "/") -> dict:
     """Re-nest a dictionary where the NWP values are nested under keys 'nwp/<key>'."""
     nwp_prefix = f"nwp{sep}"
     new_dict = {k: v for k, v in d.items() if not k.startswith(nwp_prefix)}
-    nwp_keys = [k for k in d.keys() if k.startswith(nwp_prefix)]
+    nwp_keys = [k for k in d if k.startswith(nwp_prefix)]
     if len(nwp_keys) > 0:
         nwp_subdict = {k.removeprefix(nwp_prefix): d[k] for k in nwp_keys}
         new_dict["nwp"] = nwp_subdict
     return new_dict
 
+
 def convert_to_numpy_and_combine(
     dataset_dict: dict,
 ) -> dict:
-    """Convert input data in a dict to numpy arrays"""
-
+    """Convert input data in a dict to numpy arrays."""
     numpy_modalities = []
 
     if "nwp" in dataset_dict:
-
-        nwp_numpy_modalities = dict()
+        nwp_numpy_modalities = {}
         for nwp_key, da_nwp in dataset_dict["nwp"].items():
             # Convert to NumpySample
             nwp_numpy_modalities[nwp_key] = convert_nwp_to_numpy_sample(da_nwp)
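Note: the behaviour of `nest_nwp_source_dict`, for reference (toy values):

```python
from ocf_data_sampler.torch_datasets.datasets.site import nest_nwp_source_dict

flat = {"nwp/ukv": 1, "nwp/ecmwf": 2, "sat": 3}
print(nest_nwp_source_dict(flat))
# {'sat': 3, 'nwp': {'ukv': 1, 'ecmwf': 2}}
```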
@@ -427,7 +463,7 @@ def convert_to_numpy_and_combine(
         numpy_modalities.append(
             convert_site_to_numpy_sample(
                 da_sites,
-            )
+            ),
         )
 
     # Combine all the modalities and fill NaNs

@@ -437,25 +473,23 @@ def convert_to_numpy_and_combine(
     return combined_sample
 
 
-def coarsen_data(xr_data: xr.Dataset, coarsen_to_deg: float=0.1):
-    """
-    Coarsen the data to a specified resolution in degrees.
-
+def coarsen_data(xr_data: xr.Dataset, coarsen_to_deg: float = 0.1) -> xr.Dataset:
+    """Coarsen the data to a specified resolution in degrees.
+
     Args:
         xr_data: xarray dataset to coarsen
         coarsen_to_deg: resolution to coarsen to in degrees
     """
-
     if "latitude" in xr_data.coords and "longitude" in xr_data.coords:
-        step = np.abs(xr_data.latitude.values[1]-xr_data.latitude.values[0])
-        step = np.round(step,4)
-        coarsen_factor = int(coarsen_to_deg/step)
+        step = np.abs(xr_data.latitude.values[1] - xr_data.latitude.values[0])
+        step = np.round(step, 4)
+        coarsen_factor = int(coarsen_to_deg / step)
         if coarsen_factor > 1:
             xr_data = xr_data.coarsen(
                 latitude=coarsen_factor,
                 longitude=coarsen_factor,
                 boundary="pad",
-                coord_func="min"
+                coord_func="min",
             ).mean()
-
-    return xr_data
+
+    return xr_data
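Note: a worked example of the coarsening arithmetic, on a hypothetical 0.05° grid coarsened to the default 0.1°:

```python
import numpy as np
import xarray as xr

# Hypothetical 0.05-degree grid: 20 x 20 points
lats = np.arange(50.0, 51.0, 0.05)
lons = np.arange(-1.0, 0.0, 0.05)
ds = xr.Dataset(
    {"t2m": (("latitude", "longitude"), np.random.rand(len(lats), len(lons)))},
    coords={"latitude": lats, "longitude": lons},
)

step = np.round(np.abs(ds.latitude.values[1] - ds.latitude.values[0]), 4)  # 0.05
coarsen_factor = int(0.1 / step)  # 2
coarse = ds.coarsen(
    latitude=coarsen_factor, longitude=coarsen_factor, boundary="pad", coord_func="min",
).mean()
print(dict(coarse.sizes))  # {'latitude': 10, 'longitude': 10}
```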
ocf_data_sampler/torch_datasets/utils/merge_and_fill_utils.py

@@ -1,13 +1,17 @@
+"""Utility functions for merging dictionaries and filling NaNs in arrays."""
+
 import numpy as np
 
+
 def merge_dicts(list_of_dicts: list[dict]) -> dict:
-    """Merge a list of dictionaries into a single dictionary"""
+    """Merge a list of dictionaries into a single dictionary."""
     # TODO: This doesn't account for duplicate keys, which will be overwritten
     combined_dict = {}
     for d in list_of_dicts:
         combined_dict.update(d)
     return combined_dict
 
+
 def fill_nans_in_arrays(sample: dict) -> dict:
     """Fills all NaN values in each np.ndarray in the sample dictionary with zeros.
 
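Note: as the TODO above says, later dictionaries win on duplicate keys:

```python
from ocf_data_sampler.torch_datasets.utils.merge_and_fill_utils import merge_dicts

print(merge_dicts([{"a": 1, "b": 2}, {"b": 3}]))
# {'a': 1, 'b': 3}
```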
@@ -22,4 +26,4 @@ def fill_nans_in_arrays(sample: dict) -> dict:
         elif isinstance(v, dict):
             fill_nans_in_arrays(v)
 
-    return sample
+    return sample

ocf_data_sampler/torch_datasets/utils/valid_time_periods.py

@@ -1,34 +1,31 @@
+"""Functions pertaining to finding valid time periods for the input data."""
+
 import numpy as np
 import pandas as pd
 
 from ocf_data_sampler.config import Configuration
 from ocf_data_sampler.select.find_contiguous_time_periods import (
+    find_contiguous_t0_periods,
     find_contiguous_t0_periods_nwp,
-    find_contiguous_t0_periods,
     intersection_of_multiple_dataframes_of_periods,
 )
 from ocf_data_sampler.utils import minutes
 
 
-
-def find_valid_time_periods(
-    datasets_dict: dict,
-    config: Configuration,
-):
-    """Find the t0 times where all of the requested input data is available
+def find_valid_time_periods(datasets_dict: dict, config: Configuration) -> pd.DataFrame:
+    """Find the t0 times where all of the requested input data is available.
 
     Args:
         datasets_dict: A dictionary of input datasets
         config: Configuration file
     """
+    if not set(datasets_dict.keys()).issubset({"nwp", "sat", "gsp"}):
+        raise ValueError(f"Invalid keys in datasets_dict: {datasets_dict.keys()}")
 
-    assert set(datasets_dict.keys()).issubset({"nwp", "sat", "gsp"})
-
-    contiguous_time_periods: dict[str: pd.DataFrame] = {}  # Used to store contiguous time periods from each data source
-
+    # Used to store contiguous time periods from each data source
+    contiguous_time_periods: dict[str : pd.DataFrame] = {}
     if "nwp" in datasets_dict:
         for nwp_key, nwp_config in config.input_data.nwp.items():
-
             da = datasets_dict["nwp"][nwp_key]
 
             if nwp_config.dropout_timedeltas_minutes is None:

@@ -59,8 +56,12 @@ def find_valid_time_periods(
                 max_staleness = max_possible_staleness
             else:
                 # Make sure the max acceptable staleness isn't longer than the max possible
-                assert max_staleness <= max_possible_staleness
-
+                if max_staleness > max_possible_staleness:
+                    raise ValueError(
+                        f"max_staleness_minutes is too long for the input data, "
+                        f"{max_staleness=}, {max_possible_staleness=}",
+                    )
+
             # Find the first forecast step
             first_forecast_step = pd.Timedelta(da["step"].min().item())
 
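Note: to make the staleness check concrete, a hypothetical calculation (the derivation of `max_possible_staleness` is assumed from the surrounding code, which subtracts the sample's forecast horizon from the longest available step):

```python
import pandas as pd

# Hypothetical: forecasts run 48 h past init; samples need data out to t0 + 180 min
max_step = pd.Timedelta("48h")
interval_end = pd.Timedelta("180min")
max_possible_staleness = max_step - interval_end

print(max_possible_staleness)                        # 1 days 21:00:00 (45 h)
print(pd.Timedelta("44h") > max_possible_staleness)  # False -> accepted
print(pd.Timedelta("46h") > max_possible_staleness)  # True  -> now raises ValueError
```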
@@ -69,34 +70,34 @@ def find_valid_time_periods(
                 interval_start=minutes(nwp_config.interval_start_minutes),
                 max_staleness=max_staleness,
                 max_dropout=max_dropout,
-                first_forecast_step = first_forecast_step,
+                first_forecast_step=first_forecast_step,
             )
 
-            contiguous_time_periods[f'nwp_{nwp_key}'] = time_periods
+            contiguous_time_periods[f"nwp_{nwp_key}"] = time_periods
 
     if "sat" in datasets_dict:
         sat_config = config.input_data.satellite
 
         time_periods = find_contiguous_t0_periods(
             pd.DatetimeIndex(datasets_dict["sat"]["time_utc"]),
-            sample_period_duration=minutes(sat_config.time_resolution_minutes),
+            time_resolution=minutes(sat_config.time_resolution_minutes),
             interval_start=minutes(sat_config.interval_start_minutes),
             interval_end=minutes(sat_config.interval_end_minutes),
         )
 
-        contiguous_time_periods['sat'] = time_periods
+        contiguous_time_periods["sat"] = time_periods
 
     if "gsp" in datasets_dict:
         gsp_config = config.input_data.gsp
 
         time_periods = find_contiguous_t0_periods(
             pd.DatetimeIndex(datasets_dict["gsp"]["time_utc"]),
-            sample_period_duration=minutes(gsp_config.time_resolution_minutes),
+            time_resolution=minutes(gsp_config.time_resolution_minutes),
             interval_start=minutes(gsp_config.interval_start_minutes),
             interval_end=minutes(gsp_config.interval_end_minutes),
         )
 
-        contiguous_time_periods['gsp'] = time_periods
+        contiguous_time_periods["gsp"] = time_periods
 
     # just get the values (not the keys)
     contiguous_time_periods_values = list(contiguous_time_periods.values())

@@ -104,7 +105,7 @@ def find_valid_time_periods(
 
     # Find joint overlapping contiguous time periods
     if len(contiguous_time_periods_values) > 1:
         valid_time_periods = intersection_of_multiple_dataframes_of_periods(
-            contiguous_time_periods_values
+            contiguous_time_periods_values,
         )
     else:
         valid_time_periods = contiguous_time_periods_values[0]

@@ -113,4 +114,4 @@ def find_valid_time_periods(
     if len(valid_time_periods) == 0:
         raise ValueError(f"No valid time periods found, {contiguous_time_periods=}")
 
-    return valid_time_periods
+    return valid_time_periods
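Note: a toy illustration of the "joint overlapping contiguous time periods" computed above. The `start_dt`/`end_dt` column names are assumed from the package's period convention and are not shown in this diff:

```python
import pandas as pd

from ocf_data_sampler.select.find_contiguous_time_periods import (
    intersection_of_multiple_dataframes_of_periods,
)

nwp_periods = pd.DataFrame(
    {"start_dt": [pd.Timestamp("2023-01-01 00:00")],
     "end_dt": [pd.Timestamp("2023-01-01 12:00")]},
)
sat_periods = pd.DataFrame(
    {"start_dt": [pd.Timestamp("2023-01-01 06:00")],
     "end_dt": [pd.Timestamp("2023-01-01 18:00")]},
)
overlap = intersection_of_multiple_dataframes_of_periods([nwp_periods, sat_periods])
# Expected: a single period from 06:00 to 12:00
```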