ocf-data-sampler 0.1.11__py3-none-any.whl → 0.1.17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of ocf-data-sampler has been flagged as possibly problematic.

Files changed (78)
  1. ocf_data_sampler/config/load.py +3 -3
  2. ocf_data_sampler/config/model.py +146 -64
  3. ocf_data_sampler/config/save.py +5 -4
  4. ocf_data_sampler/load/gsp.py +6 -5
  5. ocf_data_sampler/load/load_dataset.py +5 -6
  6. ocf_data_sampler/load/nwp/nwp.py +17 -5
  7. ocf_data_sampler/load/nwp/providers/ecmwf.py +6 -7
  8. ocf_data_sampler/load/nwp/providers/gfs.py +36 -0
  9. ocf_data_sampler/load/nwp/providers/icon.py +46 -0
  10. ocf_data_sampler/load/nwp/providers/ukv.py +4 -5
  11. ocf_data_sampler/load/nwp/providers/utils.py +3 -1
  12. ocf_data_sampler/load/satellite.py +9 -10
  13. ocf_data_sampler/load/site.py +10 -6
  14. ocf_data_sampler/load/utils.py +21 -16
  15. ocf_data_sampler/numpy_sample/collate.py +10 -9
  16. ocf_data_sampler/numpy_sample/datetime_features.py +3 -5
  17. ocf_data_sampler/numpy_sample/gsp.py +12 -14
  18. ocf_data_sampler/numpy_sample/nwp.py +12 -12
  19. ocf_data_sampler/numpy_sample/satellite.py +9 -9
  20. ocf_data_sampler/numpy_sample/site.py +5 -8
  21. ocf_data_sampler/numpy_sample/sun_position.py +16 -21
  22. ocf_data_sampler/sample/base.py +15 -17
  23. ocf_data_sampler/sample/site.py +13 -20
  24. ocf_data_sampler/sample/uk_regional.py +29 -35
  25. ocf_data_sampler/select/dropout.py +16 -14
  26. ocf_data_sampler/select/fill_time_periods.py +15 -5
  27. ocf_data_sampler/select/find_contiguous_time_periods.py +88 -75
  28. ocf_data_sampler/select/geospatial.py +63 -54
  29. ocf_data_sampler/select/location.py +16 -51
  30. ocf_data_sampler/select/select_spatial_slice.py +105 -89
  31. ocf_data_sampler/select/select_time_slice.py +71 -58
  32. ocf_data_sampler/select/spatial_slice_for_dataset.py +7 -6
  33. ocf_data_sampler/select/time_slice_for_dataset.py +17 -16
  34. ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py +140 -131
  35. ocf_data_sampler/torch_datasets/datasets/site.py +152 -112
  36. ocf_data_sampler/torch_datasets/utils/__init__.py +3 -0
  37. ocf_data_sampler/torch_datasets/utils/channel_dict_to_dataarray.py +11 -0
  38. ocf_data_sampler/torch_datasets/utils/merge_and_fill_utils.py +6 -2
  39. ocf_data_sampler/torch_datasets/utils/valid_time_periods.py +23 -22
  40. ocf_data_sampler/utils.py +3 -1
  41. {ocf_data_sampler-0.1.11.dist-info → ocf_data_sampler-0.1.17.dist-info}/METADATA +7 -18
  42. ocf_data_sampler-0.1.17.dist-info/RECORD +56 -0
  43. {ocf_data_sampler-0.1.11.dist-info → ocf_data_sampler-0.1.17.dist-info}/WHEEL +1 -1
  44. {ocf_data_sampler-0.1.11.dist-info → ocf_data_sampler-0.1.17.dist-info}/top_level.txt +1 -1
  45. scripts/refactor_site.py +63 -33
  46. utils/compute_icon_mean_stddev.py +72 -0
  47. ocf_data_sampler/constants.py +0 -222
  48. ocf_data_sampler/torch_datasets/utils/validate_channels.py +0 -82
  49. ocf_data_sampler-0.1.11.dist-info/LICENSE +0 -21
  50. ocf_data_sampler-0.1.11.dist-info/RECORD +0 -82
  51. tests/__init__.py +0 -0
  52. tests/config/test_config.py +0 -113
  53. tests/config/test_load.py +0 -7
  54. tests/config/test_save.py +0 -28
  55. tests/conftest.py +0 -319
  56. tests/load/test_load_gsp.py +0 -15
  57. tests/load/test_load_nwp.py +0 -21
  58. tests/load/test_load_satellite.py +0 -17
  59. tests/load/test_load_sites.py +0 -14
  60. tests/numpy_sample/test_collate.py +0 -21
  61. tests/numpy_sample/test_datetime_features.py +0 -37
  62. tests/numpy_sample/test_gsp.py +0 -38
  63. tests/numpy_sample/test_nwp.py +0 -13
  64. tests/numpy_sample/test_satellite.py +0 -40
  65. tests/numpy_sample/test_sun_position.py +0 -81
  66. tests/select/test_dropout.py +0 -69
  67. tests/select/test_fill_time_periods.py +0 -28
  68. tests/select/test_find_contiguous_time_periods.py +0 -202
  69. tests/select/test_location.py +0 -67
  70. tests/select/test_select_spatial_slice.py +0 -154
  71. tests/select/test_select_time_slice.py +0 -275
  72. tests/test_sample/test_base.py +0 -164
  73. tests/test_sample/test_site_sample.py +0 -165
  74. tests/test_sample/test_uk_regional_sample.py +0 -136
  75. tests/torch_datasets/test_merge_and_fill_utils.py +0 -40
  76. tests/torch_datasets/test_pvnet_uk.py +0 -154
  77. tests/torch_datasets/test_site.py +0 -226
  78. tests/torch_datasets/test_validate_channels_utils.py +0 -78
ocf_data_sampler/torch_datasets/datasets/site.py

@@ -1,63 +1,58 @@
-"""Torch dataset for sites"""
+"""Torch dataset for sites."""
 
 import logging
+
 import numpy as np
 import pandas as pd
 import xarray as xr
-from typing import Tuple
-
 from torch.utils.data import Dataset
+from typing_extensions import override
 
 from ocf_data_sampler.config import Configuration, load_yaml_configuration
 from ocf_data_sampler.load.load_dataset import get_dataset_dict
+from ocf_data_sampler.numpy_sample import (
+    NWPSampleKey,
+    convert_nwp_to_numpy_sample,
+    convert_satellite_to_numpy_sample,
+    convert_site_to_numpy_sample,
+    make_datetime_numpy_dict,
+    make_sun_position_numpy_sample,
+)
 from ocf_data_sampler.select import (
     Location,
     fill_time_periods,
     find_contiguous_t0_periods,
     intersection_of_multiple_dataframes_of_periods,
-    slice_datasets_by_time, slice_datasets_by_space
-)
-from ocf_data_sampler.utils import minutes
-from ocf_data_sampler.torch_datasets.utils.valid_time_periods import find_valid_time_periods
-from ocf_data_sampler.torch_datasets.utils.merge_and_fill_utils import merge_dicts, fill_nans_in_arrays
-
-from ocf_data_sampler.numpy_sample import (
-    convert_site_to_numpy_sample,
-    convert_satellite_to_numpy_sample,
-    convert_nwp_to_numpy_sample,
-    make_datetime_numpy_dict,
-    make_sun_position_numpy_sample,
+    slice_datasets_by_space,
+    slice_datasets_by_time,
 )
-from ocf_data_sampler.numpy_sample import NWPSampleKey
-from ocf_data_sampler.constants import NWP_MEANS, NWP_STDS, RSS_MEAN, RSS_STD
-
-from ocf_data_sampler.torch_datasets.utils.validate_channels import (
-    validate_nwp_channels,
-    validate_satellite_channels,
+from ocf_data_sampler.torch_datasets.utils import channel_dict_to_dataarray, find_valid_time_periods
+from ocf_data_sampler.torch_datasets.utils.merge_and_fill_utils import (
+    fill_nans_in_arrays,
+    merge_dicts,
 )
+from ocf_data_sampler.utils import minutes
 
 xr.set_options(keep_attrs=True)
 
 
 class SitesDataset(Dataset):
+    """A torch Dataset for creating PVNet Site samples."""
+
     def __init__(
         self,
         config_filename: str,
         start_time: str | None = None,
         end_time: str | None = None,
-    ):
-        """A torch Dataset for creating PVNet Site samples
+    ) -> None:
+        """A torch Dataset for creating PVNet Site samples.
 
         Args:
             config_filename: Path to the configuration file
             start_time: Limit the init-times to be after this
             end_time: Limit the init-times to be before this
         """
-
         config: Configuration = load_yaml_configuration(config_filename)
-        validate_nwp_channels(config)
-        validate_satellite_channels(config)
-
         datasets_dict = get_dataset_dict(config.input_data)
 
         # Assign config and input data to self
@@ -65,28 +60,31 @@ class SitesDataset(Dataset):
         self.config = config
 
         # get all locations
-        self.locations = self.get_locations(datasets_dict['site'])
+        self.locations = self.get_locations(datasets_dict["site"])
 
         # Get t0 times where all input data is available
         valid_t0_and_site_ids = self.find_valid_t0_and_site_ids(datasets_dict)
 
         # Filter t0 times to given range
         if start_time is not None:
-            valid_t0_and_site_ids \
-                = valid_t0_and_site_ids[valid_t0_and_site_ids['t0'] >= pd.Timestamp(start_time)]
+            valid_t0_and_site_ids = valid_t0_and_site_ids[
+                valid_t0_and_site_ids["t0"] >= pd.Timestamp(start_time)
+            ]
 
         if end_time is not None:
-            valid_t0_and_site_ids \
-                = valid_t0_and_site_ids[valid_t0_and_site_ids['t0'] <= pd.Timestamp(end_time)]
+            valid_t0_and_site_ids = valid_t0_and_site_ids[
+                valid_t0_and_site_ids["t0"] <= pd.Timestamp(end_time)
+            ]
 
         # Assign coords and indices to self
         self.valid_t0_and_site_ids = valid_t0_and_site_ids
 
-    def __len__(self):
+    @override
+    def __len__(self) -> int:
         return len(self.valid_t0_and_site_ids)
-
-    def __getitem__(self, idx):
 
+    @override
+    def __getitem__(self, idx: int) -> dict:
         # Get the coordinates of the sample
         t0, site_id = self.valid_t0_and_site_ids.iloc[idx]
 
@@ -97,7 +95,7 @@ class SitesDataset(Dataset):
         return self._get_sample(t0, location)
 
     def _get_sample(self, t0: pd.Timestamp, location: Location) -> dict:
-        """Generate the PVNet sample for given coordinates
+        """Generate the PVNet sample for given coordinates.
 
         Args:
             t0: init-time for sample
@@ -106,7 +104,7 @@ class SitesDataset(Dataset):
         sample_dict = slice_datasets_by_space(self.datasets_dict, location, self.config)
         sample_dict = slice_datasets_by_time(sample_dict, t0, self.config)
 
-        sample = self.process_and_combine_site_sample_dict(sample_dict)
+        sample = self.process_and_combine_site_sample_dict(sample_dict, t0)
         sample = sample.compute()
         return sample
 
@@ -119,20 +117,20 @@ class SitesDataset(Dataset):
             t0: init-time for sample
             site_id: site id as int
         """
-
         location = self.get_location_from_site_id(site_id)
 
         return self._get_sample(t0, location)
-
-    def get_location_from_site_id(self, site_id):
-        """Get location from system id"""
 
+    def get_location_from_site_id(self, site_id: int) -> Location:
+        """Get location from system id."""
         locations = [loc for loc in self.locations if loc.id == site_id]
         if len(locations) == 0:
            raise ValueError(f"Location not found for site_id {site_id}")
 
         if len(locations) > 1:
-            logging.warning(f"Multiple locations found for site_id {site_id}, but will take the first")
+            logging.warning(
+                f"Multiple locations found for site_id {site_id}, but will take the first",
+            )
 
         return locations[0]
 
@@ -140,7 +138,7 @@ class SitesDataset(Dataset):
         self,
         datasets_dict: dict,
     ) -> pd.DataFrame:
-        """Find the t0 times where all of the requested input data is available
+        """Find the t0 times where all of the requested input data is available.
 
         The idea is to
         1. Get valid time period for nwp and satellite
@@ -150,9 +148,8 @@ class SitesDataset(Dataset):
             datasets_dict: A dictionary of input datasets
             config: Configuration file
         """
-
         # 1. Get valid time period for nwp and satellite
-        datasets_without_site = {k:v for k, v in datasets_dict.items() if k!="site"}
+        datasets_without_site = {k: v for k, v in datasets_dict.items() if k != "site"}
         valid_time_periods = find_valid_time_periods(datasets_without_site, self.config)
 
         # 2. Now lets loop over each location in system id and find the valid periods
@@ -166,39 +163,37 @@ class SitesDataset(Dataset):
 
             # drop any nan values
            # not sure this is right?
-            site = site.dropna(dim='time_utc')
+            site = site.dropna(dim="time_utc")
 
             # Get the valid time periods for this location
             time_periods = find_contiguous_t0_periods(
                 pd.DatetimeIndex(site["time_utc"]),
-                sample_period_duration=minutes(site_config.time_resolution_minutes),
+                time_resolution=minutes(site_config.time_resolution_minutes),
                 interval_start=minutes(site_config.interval_start_minutes),
                 interval_end=minutes(site_config.interval_end_minutes),
             )
             valid_time_periods_per_site = intersection_of_multiple_dataframes_of_periods(
-                [valid_time_periods, time_periods]
+                [valid_time_periods, time_periods],
             )
 
             # Fill out the contiguous time periods to get the t0 times
             valid_t0_times_per_site = fill_time_periods(
                 valid_time_periods_per_site,
-                freq=minutes(site_config.time_resolution_minutes)
+                freq=minutes(site_config.time_resolution_minutes),
             )
 
             valid_t0_per_site = pd.DataFrame(index=valid_t0_times_per_site)
-            valid_t0_per_site['site_id'] = site_id
+            valid_t0_per_site["site_id"] = site_id
             valid_t0_and_site_ids.append(valid_t0_per_site)
 
         valid_t0_and_site_ids = pd.concat(valid_t0_and_site_ids)
-        valid_t0_and_site_ids.index.name = 't0'
+        valid_t0_and_site_ids.index.name = "t0"
         valid_t0_and_site_ids.reset_index(inplace=True)
 
         return valid_t0_and_site_ids
 
-
-    def get_locations(self, site_xr: xr.Dataset):
-        """Get list of locations of all sites"""
-
+    def get_locations(self, site_xr: xr.Dataset) -> list[Location]:
+        """Get list of locations of all sites."""
         locations = []
         for site_id in site_xr.site_id.values:
             site = site_xr.sel(site_id=site_id)
@@ -206,7 +201,7 @@ class SitesDataset(Dataset):
                 id=site_id,
                 x=site.longitude.values,
                 y=site.latitude.values,
-                coordinate_system="lon_lat"
+                coordinate_system="lon_lat",
             )
             locations.append(location)
 
@@ -215,34 +210,48 @@ class SitesDataset(Dataset):
     def process_and_combine_site_sample_dict(
         self,
         dataset_dict: dict,
+        t0: pd.Timestamp,
     ) -> xr.Dataset:
-        """
-        Normalize and combine data into a single xr Dataset
+        """Normalize and combine data into a single xr Dataset.
 
         Args:
             dataset_dict: dict containing sliced xr DataArrays
-            config: Configuration for the model
+            t0: The initial timestamp of the sample
 
         Returns:
             xr.Dataset: A merged Dataset with nans filled in.
-
-        """
 
+        """
         data_arrays = []
 
         if "nwp" in dataset_dict:
             for nwp_key, da_nwp in dataset_dict["nwp"].items():
                 provider = self.config.input_data.nwp[nwp_key].provider
-
+
                 # Standardise
-                da_nwp = (da_nwp - NWP_MEANS[provider]) / NWP_STDS[provider]
+                da_channel_means = channel_dict_to_dataarray(
+                    self.config.input_data.nwp[nwp_key].channel_means,
+                )
+                da_channel_stds = channel_dict_to_dataarray(
+                    self.config.input_data.nwp[nwp_key].channel_stds,
+                )
+
+                da_nwp = (da_nwp - da_channel_means) / da_channel_stds
+
                 data_arrays.append((f"nwp-{provider}", da_nwp))
-
+
         if "sat" in dataset_dict:
             da_sat = dataset_dict["sat"]
 
-            # Standardise
-            da_sat = (da_sat - RSS_MEAN) / RSS_STD
+            da_channel_means = channel_dict_to_dataarray(
+                self.config.input_data.satellite.channel_means,
+            )
+            da_channel_stds = channel_dict_to_dataarray(
+                self.config.input_data.satellite.channel_stds,
+            )
+
+            da_sat = (da_sat - da_channel_means) / da_channel_stds
+
             data_arrays.append(("satellite", da_sat))
 
         if "site" in dataset_dict:
@@ -257,33 +266,57 @@ class SitesDataset(Dataset):
         datetimes = pd.DatetimeIndex(combined_sample_dataset.site__time_utc.values)
         datetime_features = make_datetime_numpy_dict(datetimes=datetimes, key_prefix="site_")
         combined_sample_dataset = combined_sample_dataset.assign_coords(
-            {k: ("site__time_utc", v) for k, v in datetime_features.items()}
+            {k: ("site__time_utc", v) for k, v in datetime_features.items()},
         )
 
-        # add sun features
-        sun_position_features = make_sun_position_numpy_sample(
-            datetimes=datetimes,
-            lon=combined_sample_dataset.site__longitude.values,
-            lat=combined_sample_dataset.site__latitude.values,
-            key_prefix="site_",
-        )
-        combined_sample_dataset = combined_sample_dataset.assign_coords(
-            {k: ("site__time_utc", v) for k, v in sun_position_features.items()}
+        # Only add solar position if explicitly configured
+        has_solar_config = (
+            hasattr(self.config.input_data, "solar_position") and
+            self.config.input_data.solar_position is not None
         )
 
-        # TODO include t0_index in xr dataset?
+        if has_solar_config:
+            solar_config = self.config.input_data.solar_position
+
+            # Datetime range - solar config params
+            solar_datetimes = pd.date_range(
+                t0 + minutes(solar_config.interval_start_minutes),
+                t0 + minutes(solar_config.interval_end_minutes),
+                freq=minutes(solar_config.time_resolution_minutes),
+            )
+
+            # Calculate sun position features
+            sun_position_features = make_sun_position_numpy_sample(
+                datetimes=solar_datetimes,
+                lon=combined_sample_dataset.site__longitude.values,
+                lat=combined_sample_dataset.site__latitude.values,
+            )
+
+            # Dimension state for solar position data
+            solar_dim_name = "solar_time_utc"
+            combined_sample_dataset = combined_sample_dataset.assign_coords(
+                {solar_dim_name: solar_datetimes},
+            )
+
+            # Assign solar position values
+            for key, values in sun_position_features.items():
+                combined_sample_dataset = combined_sample_dataset.assign_coords(
+                    {key: (solar_dim_name, values)},
+                )
+
+        # TODO include t0_index in xr dataset?
 
         # Fill any nan values
         return combined_sample_dataset.fillna(0.0)
 
     def merge_data_arrays(
-        self, normalised_data_arrays: list[Tuple[str, xr.DataArray]]
+        self,
+        normalised_data_arrays: list[tuple[str, xr.DataArray]],
     ) -> xr.Dataset:
-        """
-        Combine a list of DataArrays into a single Dataset with unique naming conventions.
+        """Combine a list of DataArrays into a single Dataset with unique naming conventions.
 
         Args:
-            list_of_arrays: List of tuples where each tuple contains:
+            normalised_data_arrays: List of tuples where each tuple contains:
                 - A string (key name).
                 - An xarray.DataArray.
 
@@ -295,7 +328,7 @@ class SitesDataset(Dataset):
         for key, data_array in normalised_data_arrays:
             # Ensure all attributes are strings for consistency
             data_array = data_array.assign_attrs(
-                {attr_key: str(attr_value) for attr_key, attr_value in data_array.attrs.items()}
+                {attr_key: str(attr_value) for attr_key, attr_value in data_array.attrs.items()},
             )
 
             # Convert DataArray to Dataset with the variable name as the key
@@ -303,15 +336,16 @@ class SitesDataset(Dataset):
 
             # Prepend key name to all dimension and coordinate names for uniqueness
             dataset = dataset.rename(
-                {dim: f"{key}__{dim}" for dim in dataset.dims if dim not in dataset.coords}
+                {dim: f"{key}__{dim}" for dim in dataset.dims if dim not in dataset.coords},
             )
             dataset = dataset.rename(
-                {coord: f"{key}__{coord}" for coord in dataset.coords}
+                {coord: f"{key}__{coord}" for coord in dataset.coords},
             )
 
             # Handle concatenation dimension if applicable
             concat_dim = (
-                f"{key}__target_time_utc" if f"{key}__target_time_utc" in dataset.coords
+                f"{key}__target_time_utc"
+                if f"{key}__target_time_utc" in dataset.coords
                 else f"{key}__time_utc"
             )
 
@@ -325,20 +359,22 @@ class SitesDataset(Dataset):
 
         # Ensure all datasets are valid xarray.Dataset objects
         for ds in datasets:
-            assert isinstance(ds, xr.Dataset), f"Object is not an xr.Dataset: {type(ds)}"
+            if not isinstance(ds, xr.Dataset):
+                raise ValueError(f"Object is not an xr.Dataset: {type(ds)}")
 
         # Merge all prepared datasets
         combined_dataset = xr.merge(datasets)
 
         return combined_dataset
 
+
 # ----- functions to load presaved samples ------
 
-def convert_netcdf_to_numpy_sample(ds: xr.Dataset) -> dict:
-    """Convert a netcdf dataset to a numpy sample"""
 
+def convert_netcdf_to_numpy_sample(ds: xr.Dataset) -> dict:
+    """Convert a netcdf dataset to a numpy sample."""
     # convert the single dataset to a dict of arrays
-    sample_dict = convert_from_dataset_to_dict_datasets(ds)
+    sample_dict = convert_from_dataset_to_dict_datasets(ds)
 
     if "satellite" in sample_dict:
         # rename satellite to satellite actual # TODO this could be improves
@@ -349,14 +385,21 @@ def convert_netcdf_to_numpy_sample(ds: xr.Dataset) -> dict:
         dataset_dict=sample_dict,
     )
 
-    # TODO think about normalization, maybe its done not in sample creation, maybe its done afterwards,
-    # to allow it to be flexible
+    # Extraction of solar position coords
+    solar_keys = ["solar_azimuth", "solar_elevation"]
+    for key in solar_keys:
+        if key in ds.coords:
+            sample[key] = ds.coords[key].values
+
+    # TODO think about normalization:
+    # * maybe its done not in sample creation, maybe its done afterwards,
+    #   to allow it to be flexible
 
     return sample
 
+
 def convert_from_dataset_to_dict_datasets(combined_dataset: xr.Dataset) -> dict[str, xr.DataArray]:
-    """
-    Convert a combined sample dataset to a dict of datasets for each input
+    """Convert a combined sample dataset to a dict of datasets for each input.
 
     Args:
         combined_dataset: The combined NetCDF dataset
@@ -374,10 +417,10 @@ def convert_from_dataset_to_dict_datasets(combined_dataset: xr.Dataset) -> dict[str, xr.DataArray]:
             if f"{key}__" not in dim:
                 dataset: xr.Dataset = dataset.drop(dim)
         dataset = dataset.rename(
-            {dim: dim.split(f"{key}__")[1] for dim in dataset.dims if dim not in dataset.coords}
+            {dim: dim.split(f"{key}__")[1] for dim in dataset.dims if dim not in dataset.coords},
        )
        dataset: xr.Dataset = dataset.rename(
-            {coord: coord.split(f"{key}__")[1] for coord in dataset.coords}
+            {coord: coord.split(f"{key}__")[1] for coord in dataset.coords},
        )
        # Split the dataset by the prefix
        datasets[key] = dataset
@@ -391,22 +434,21 @@ def nest_nwp_source_dict(d: dict, sep: str = "/") -> dict:
     """Re-nest a dictionary where the NWP values are nested under keys 'nwp/<key>'."""
     nwp_prefix = f"nwp{sep}"
     new_dict = {k: v for k, v in d.items() if not k.startswith(nwp_prefix)}
-    nwp_keys = [k for k in d.keys() if k.startswith(nwp_prefix)]
+    nwp_keys = [k for k in d if k.startswith(nwp_prefix)]
     if len(nwp_keys) > 0:
         nwp_subdict = {k.removeprefix(nwp_prefix): d[k] for k in nwp_keys}
         new_dict["nwp"] = nwp_subdict
     return new_dict
 
+
 def convert_to_numpy_and_combine(
     dataset_dict: dict,
 ) -> dict:
-    """Convert input data in a dict to numpy arrays"""
-
+    """Convert input data in a dict to numpy arrays."""
     numpy_modalities = []
 
     if "nwp" in dataset_dict:
-
-        nwp_numpy_modalities = dict()
+        nwp_numpy_modalities = {}
         for nwp_key, da_nwp in dataset_dict["nwp"].items():
             # Convert to NumpySample
             nwp_numpy_modalities[nwp_key] = convert_nwp_to_numpy_sample(da_nwp)
@@ -427,7 +469,7 @@ def convert_to_numpy_and_combine(
     numpy_modalities.append(
         convert_site_to_numpy_sample(
             da_sites,
-        )
+        ),
    )
 
    # Combine all the modalities and fill NaNs
@@ -437,25 +479,23 @@
     return combined_sample
 
 
-def coarsen_data(xr_data: xr.Dataset, coarsen_to_deg: float=0.1):
-    """
-    Coarsen the data to a specified resolution in degrees.
-
+def coarsen_data(xr_data: xr.Dataset, coarsen_to_deg: float = 0.1) -> xr.Dataset:
+    """Coarsen the data to a specified resolution in degrees.
+
     Args:
         xr_data: xarray dataset to coarsen
         coarsen_to_deg: resolution to coarsen to in degrees
     """
-
     if "latitude" in xr_data.coords and "longitude" in xr_data.coords:
-        step = np.abs(xr_data.latitude.values[1]-xr_data.latitude.values[0])
-        step = np.round(step,4)
-        coarsen_factor = int(coarsen_to_deg/step)
+        step = np.abs(xr_data.latitude.values[1] - xr_data.latitude.values[0])
+        step = np.round(step, 4)
+        coarsen_factor = int(coarsen_to_deg / step)
         if coarsen_factor > 1:
             xr_data = xr_data.coarsen(
                 latitude=coarsen_factor,
                 longitude=coarsen_factor,
                 boundary="pad",
-                coord_func="min"
+                coord_func="min",
             ).mean()
-
-    return xr_data
+
+    return xr_data
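The coarsen_data helper above derives its coarsening factor from the latitude step, so a 0.05-degree grid coarsened to 0.1 degrees gives a factor of 2 and every 2x2 block of cells is averaged. A runnable toy example with made-up coordinates:

    import numpy as np
    import xarray as xr

    # 0.05-degree toy grid
    lats = np.arange(50.0, 50.2, 0.05)
    lons = np.arange(-1.0, -0.8, 0.05)
    ds = xr.Dataset(
        {"t2m": (("latitude", "longitude"), np.random.rand(len(lats), len(lons)))},
        coords={"latitude": lats, "longitude": lons},
    )

    step = np.round(abs(lats[1] - lats[0]), 4)  # 0.05
    coarsen_factor = int(0.1 / step)            # 2
    coarse = ds.coarsen(
        latitude=coarsen_factor,
        longitude=coarsen_factor,
        boundary="pad",
        coord_func="min",
    ).mean()
    # coarse.t2m has shape (2, 2): each value is the mean of a 2x2 block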
ocf_data_sampler/torch_datasets/utils/__init__.py (new file)

@@ -0,0 +1,3 @@
+from .channel_dict_to_dataarray import channel_dict_to_dataarray
+from .merge_and_fill_utils import fill_nans_in_arrays, merge_dicts
+from .valid_time_periods import find_valid_time_periods
ocf_data_sampler/torch_datasets/utils/channel_dict_to_dataarray.py (new file)

@@ -0,0 +1,11 @@
+"""Converts a dictionary of channel values to a DataArray."""
+
+import xarray as xr
+
+
+def channel_dict_to_dataarray(channel_dict: dict[str, float]) -> xr.DataArray:
+    """Converts a dictionary of channel values to a DataArray."""
+    return xr.DataArray(
+        list(channel_dict.values()),
+        coords={"channel": list(channel_dict.keys())},
+    )
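The helper returns a 1-D DataArray indexed by a "channel" coordinate, which is what lets the standardisation in site.py broadcast per channel. For example (with made-up satellite statistics; the repr shown in the comment is approximate):

    from ocf_data_sampler.torch_datasets.utils import channel_dict_to_dataarray

    da = channel_dict_to_dataarray({"IR_016": 0.17, "VIS006": 0.09})
    # <xarray.DataArray (channel: 2)>
    # array([0.17, 0.09])
    # Coordinates:
    #   * channel  (channel) 'IR_016' 'VIS006'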
ocf_data_sampler/torch_datasets/utils/merge_and_fill_utils.py

@@ -1,13 +1,17 @@
+"""Utility functions for merging dictionaries and filling NaNs in arrays."""
+
 import numpy as np
 
+
 def merge_dicts(list_of_dicts: list[dict]) -> dict:
-    """Merge a list of dictionaries into a single dictionary"""
+    """Merge a list of dictionaries into a single dictionary."""
     # TODO: This doesn't account for duplicate keys, which will be overwritten
     combined_dict = {}
     for d in list_of_dicts:
         combined_dict.update(d)
     return combined_dict
 
+
 def fill_nans_in_arrays(sample: dict) -> dict:
     """Fills all NaN values in each np.ndarray in the sample dictionary with zeros.
 
@@ -22,4 +26,4 @@ def fill_nans_in_arrays(sample: dict) -> dict:
         elif isinstance(v, dict):
             fill_nans_in_arrays(v)
 
-    return sample
+    return sample
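Together these cover the numpy side of sample assembly: merge_dicts flattens a list of per-modality dicts into one sample, and fill_nans_in_arrays zero-fills NaNs recursively, including in nested dicts such as the "nwp" sub-dict. A small usage sketch with toy arrays:

    import numpy as np

    from ocf_data_sampler.torch_datasets.utils import fill_nans_in_arrays, merge_dicts

    sample = merge_dicts([
        {"satellite": np.array([1.0, np.nan])},
        {"nwp": {"ukv": np.array([np.nan, 2.0])}},
    ])
    sample = fill_nans_in_arrays(sample)
    # sample["satellite"] -> array([1., 0.]); sample["nwp"]["ukv"] -> array([0., 2.])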