ocf-data-sampler 0.1.0__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of ocf-data-sampler has been flagged as potentially problematic.

@@ -1,75 +1,64 @@
-from ocf_data_sampler.numpy_sample import NWPSampleKey
-
 import numpy as np
-import logging
-from typing import Union
-
-logger = logging.getLogger(__name__)
-
-
-
-def stack_np_samples_into_batch(dict_list):
-    # """
-    # Stacks Numpy samples into a batch
 
-    # Args:
-    #     dict_list: A list of dict-like Numpy samples to stack
 
-    # Returns:
-    #     The stacked NumpySample object, aka a batch
-    # """
+def stack_np_samples_into_batch(dict_list: list[dict]) -> dict:
+    """Stacks list of dict samples into a dict where all samples are joined along a new axis
 
-    if not dict_list:
-        raise ValueError("Input is empty")
+    Args:
+        dict_list: A list of dict-like samples to stack
 
-    # Extract keys from first dict - structure
-    sample = {}
-    sample_keys = list(dict_list[0].keys())
+    Returns:
+        Dict of the samples stacked with new batch dimension on axis 0
+    """
+
+    batch = {}
+
+    keys = list(dict_list[0].keys())
+
+    for key in keys:
+        # NWP is nested so treat separately
+        if key == "nwp":
+            batch["nwp"] = {}
+
+            # Unpack NWP provider keys
+            nwp_providers = list(dict_list[0]["nwp"].keys())
+
+            for nwp_provider in nwp_providers:
+                # Keys can be different for different NWPs
+                nwp_keys = list(dict_list[0]["nwp"][nwp_provider].keys())
+
+                # Create dict to store NWP batch for this provider
+                nwp_provider_batch = {}
+
+                for nwp_key in nwp_keys:
+                    # Stack values under each NWP key for this provider
+                    nwp_provider_batch[nwp_key] = stack_data_list(
+                        [d["nwp"][nwp_provider][nwp_key] for d in dict_list],
+                        nwp_key,
+                    )
+
+                batch["nwp"][nwp_provider] = nwp_provider_batch
 
-    # Process - handle NWP separately due to nested structure
-    for sample_key in sample_keys:
-        if sample_key == "nwp":
-            sample["nwp"] = process_nwp_data(dict_list)
         else:
-            # Stack arrays for the given key across all dicts
-            sample[sample_key] = stack_data_list([d[sample_key] for d in dict_list], sample_key)
-    return sample
+            batch[key] = stack_data_list([d[key] for d in dict_list], key)
 
+    return batch
 
-def process_nwp_data(dict_list):
-    """Stacks data for NWP, handling nested structure"""
-
-    nwp_sample = {}
-    nwp_sources = dict_list[0]["nwp"].keys()
 
-    # Stack data for each NWP source independently
-    for nwp_source in nwp_sources:
-        nested_keys = dict_list[0]["nwp"][nwp_source].keys()
-        nwp_sample[nwp_source] = {
-            key: stack_data_list([d["nwp"][nwp_source][key] for d in dict_list], key)
-            for key in nested_keys
-        }
-    return nwp_sample
+def _key_is_constant(key: str):
+    return key.endswith("t0_idx") or key.endswith("channel_names")
 
-def _key_is_constant(sample_key):
-    return sample_key.endswith("t0_idx") or sample_key == NWPSampleKey.channel_names
 
-
-def stack_data_list(data_list: list,sample_key: Union[str, NWPSampleKey],):
-    """How to combine data entries for each key
+def stack_data_list(data_list: list, key: str):
+    """Stack a sequence of data elements along a new axis
 
     Args:
-        data_list: List of data entries to combine
-        sample_key: Key identifying the data type
+        data_list: List of data elements to combine
+        key: string identifying the data type
     """
-    if _key_is_constant(sample_key):
+    if _key_is_constant(key):
         # These are always the same for all examples.
         return data_list[0]
-    try:
+    else:
        return np.stack(data_list)
-    except Exception as e:
-        logger.debug(f"Could not stack the following shapes together, ({sample_key})")
-        shapes = [example.shape for example in data_list]
-        logger.debug(shapes)
-        logger.error(e)
-        raise e
+
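The rewritten collate module drops the NWPSampleKey dependency and the logging fallback in favour of plain string keys and a straight np.stack. A minimal sketch of how the new function behaves, directly following the code above (the sample keys and shapes here are illustrative, not taken from the package):

import numpy as np
from ocf_data_sampler.numpy_sample.collate import stack_np_samples_into_batch

# Two hypothetical samples: one flat key, one constant key, one nested NWP key
sample = {
    "gsp": np.zeros(7),
    "gsp_t0_idx": 3,  # keys ending in "t0_idx" are treated as constant
    "nwp": {"ukv": {"nwp": np.zeros((4, 1, 2, 2))}},
}
batch = stack_np_samples_into_batch([sample, sample])

assert batch["gsp"].shape == (2, 7)  # new batch dimension on axis 0
assert batch["gsp_t0_idx"] == 3  # constant keys are taken from the first sample
assert batch["nwp"]["ukv"]["nwp"].shape == (2, 4, 1, 2, 2)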
@@ -65,7 +65,9 @@ class UKRegionalSample(SampleBase):
             raise ValueError(f"Only .pt format is supported: {path.suffix}")
 
         instance = cls()
-        instance._data = torch.load(path)
+        # TODO: We should move away from using torch.load(..., weights_only=False)
+        # This is not recommended
+        instance._data = torch.load(path, weights_only=False)
         logger.debug(f"Successfully loaded UKRegionalSample from {path}")
         return instance
 
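For context on the TODO: torch.load with weights_only=False unpickles arbitrary Python objects, so a malicious .pt file can execute code on load. A hedged sketch of the safer direction the comment points at (the file name is a placeholder):

import torch

# weights_only=True restricts unpickling to plain tensors and standard
# containers, and raises if the file holds anything else
try:
    data = torch.load("sample.pt", weights_only=True)
except Exception:
    # Only fall back for files from a trusted source; this path can run
    # arbitrary code embedded in the pickle
    data = torch.load("sample.pt", weights_only=False)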
@@ -4,7 +4,7 @@ import pandas as pd
 import numpy as np
 
 
-def fill_time_periods(time_periods: pd.DataFrame, freq: pd.Timedelta):
+def fill_time_periods(time_periods: pd.DataFrame, freq: pd.Timedelta) -> pd.DatetimeIndex:
     start_dts = pd.to_datetime(time_periods["start_dt"].values).ceil(freq)
     end_dts = pd.to_datetime(time_periods["end_dt"].values)
     date_ranges = [pd.date_range(start_dt, end_dt, freq=freq) for start_dt, end_dt in zip(start_dts, end_dts)]
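Judging from the visible body, fill_time_periods expands each [start_dt, end_dt] row into regular timestamps at the given frequency, ceiling each start to the frequency first; the 0.1.2 change only adds the return annotation. A small usage sketch (input values made up):

import pandas as pd
from ocf_data_sampler.select.fill_time_periods import fill_time_periods

time_periods = pd.DataFrame(
    {
        "start_dt": ["2023-01-01 00:02", "2023-01-01 06:00"],
        "end_dt": ["2023-01-01 01:00", "2023-01-01 07:00"],
    }
)

# 00:02 is ceiled to 00:30, so the first window yields 00:30 and 01:00
times = fill_time_periods(time_periods, freq=pd.Timedelta("30min"))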
@@ -1,5 +1,6 @@
 """ Slice datasets by time"""
 import pandas as pd
+import xarray as xr
 
 from ocf_data_sampler.config import Configuration
 from ocf_data_sampler.select.dropout import draw_dropout_time, apply_dropout_time
@@ -64,16 +65,8 @@ def slice_datasets_by_time(
 
     if "gsp" in datasets_dict:
         gsp_config = config.input_data.gsp
-
-        sliced_datasets_dict["gsp_future"] = select_time_slice(
-            datasets_dict["gsp"],
-            t0,
-            sample_period_duration=minutes(gsp_config.time_resolution_minutes),
-            interval_start=minutes(gsp_config.time_resolution_minutes),
-            interval_end=minutes(gsp_config.interval_end_minutes),
-        )
-
-        sliced_datasets_dict["gsp"] = select_time_slice(
+
+        da_gsp_past = select_time_slice(
             datasets_dict["gsp"],
             t0,
             sample_period_duration=minutes(gsp_config.time_resolution_minutes),
@@ -81,17 +74,27 @@ def slice_datasets_by_time(
             interval_end=minutes(0),
         )
 
-        # Dropout on the GSP, but not the future GSP
+        # Dropout on the past GSP, but not the future GSP
         gsp_dropout_time = draw_dropout_time(
             t0,
             dropout_timedeltas=minutes(gsp_config.dropout_timedeltas_minutes),
             dropout_frac=gsp_config.dropout_fraction,
         )
 
-        sliced_datasets_dict["gsp"] = apply_dropout_time(
-            sliced_datasets_dict["gsp"],
+        da_gsp_past = apply_dropout_time(
+            da_gsp_past,
             gsp_dropout_time
         )
+
+        da_gsp_future = select_time_slice(
+            datasets_dict["gsp"],
+            t0,
+            sample_period_duration=minutes(gsp_config.time_resolution_minutes),
+            interval_start=minutes(gsp_config.time_resolution_minutes),
+            interval_end=minutes(gsp_config.interval_end_minutes),
+        )
+
+        sliced_datasets_dict["gsp"] = xr.concat([da_gsp_past, da_gsp_future], dim="time_utc")
 
     if "site" in datasets_dict:
         site_config = config.input_data.site
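The refactor replaces the separate "gsp" and "gsp_future" entries with a single "gsp" entry: dropout is applied only to the past slice (ending at t0), then past and future are joined back together. An illustrative sketch of the concat step with made-up values:

import pandas as pd
import xarray as xr

past = xr.DataArray(
    [0.1, 0.2],
    coords={"time_utc": pd.to_datetime(["2024-01-01 11:30", "2024-01-01 12:00"])},
    dims="time_utc",
)
future = xr.DataArray(
    [0.3],
    coords={"time_utc": pd.to_datetime(["2024-01-01 12:30"])},
    dims="time_utc",
)

# One contiguous series; dropout only ever touched the past portion
combined = xr.concat([past, future], dim="time_utc")
assert len(combined) == 3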
@@ -1,11 +1,6 @@
-from .pvnet_uk_regional import PVNetUKRegionalDataset
+from .pvnet_uk import PVNetUKRegionalDataset, PVNetUKConcurrentDataset
 
 from .site import (
     convert_netcdf_to_numpy_sample,
     SitesDataset
-)
-
-__all__ = [
-    'convert_netcdf_to_numpy_sample',
-    'SitesDataset'
-]
+)
@@ -1,15 +1,20 @@
-"""Torch dataset for PVNet"""
+"""Torch dataset for UK PVNet"""
+
+import pkg_resources
 
 import numpy as np
 import pandas as pd
-import pkg_resources
 import xarray as xr
 from torch.utils.data import Dataset
 from ocf_data_sampler.config import Configuration, load_yaml_configuration
 from ocf_data_sampler.load.load_dataset import get_dataset_dict
-from ocf_data_sampler.select import fill_time_periods, Location, slice_datasets_by_space, slice_datasets_by_time
+from ocf_data_sampler.select import (
+    fill_time_periods,
+    Location,
+    slice_datasets_by_space,
+    slice_datasets_by_time,
+)
 from ocf_data_sampler.utils import minutes
-from ocf_data_sampler.torch_datasets.utils.valid_time_periods import find_valid_time_periods
 from ocf_data_sampler.constants import NWP_MEANS, NWP_STDS, RSS_MEAN, RSS_STD
 from ocf_data_sampler.numpy_sample import (
     convert_nwp_to_numpy_sample,
@@ -17,13 +22,16 @@ from ocf_data_sampler.numpy_sample import (
     convert_gsp_to_numpy_sample,
     make_sun_position_numpy_sample,
 )
+from ocf_data_sampler.numpy_sample.gsp import GSPSampleKey
+from ocf_data_sampler.numpy_sample.nwp import NWPSampleKey
+from ocf_data_sampler.numpy_sample.collate import stack_np_samples_into_batch
+from ocf_data_sampler.select.geospatial import osgb_to_lon_lat
+from ocf_data_sampler.torch_datasets.utils.valid_time_periods import find_valid_time_periods
 from ocf_data_sampler.torch_datasets.utils.merge_and_fill_utils import (
     merge_dicts,
     fill_nans_in_arrays,
 )
-from ocf_data_sampler.numpy_sample.gsp import GSPSampleKey
-from ocf_data_sampler.numpy_sample.nwp import NWPSampleKey
-from ocf_data_sampler.select.geospatial import osgb_to_lon_lat
+
 
 xr.set_options(keep_attrs=True)
 
@@ -65,9 +73,10 @@ def process_and_combine_datasets(
     gsp_config = config.input_data.gsp
 
     if "gsp" in dataset_dict:
-        da_gsp = xr.concat([dataset_dict["gsp"], dataset_dict["gsp_future"]], dim="time_utc")
+        da_gsp = dataset_dict["gsp"]
         da_gsp = da_gsp / da_gsp.effective_capacity_mwp
-
+
+        # Convert to NumpyBatch
         numpy_modalities.append(
             convert_gsp_to_numpy_sample(
                 da_gsp,
@@ -105,6 +114,7 @@ def process_and_combine_datasets(
 
     return combined_sample
 
+
 def compute(xarray_dict: dict) -> dict:
     """Eagerly load a nested dictionary of xarray DataArrays"""
     for k, v in xarray_dict.items():
@@ -114,10 +124,8 @@ def compute(xarray_dict: dict) -> dict:
             xarray_dict[k] = v.compute(scheduler="single-threaded")
     return xarray_dict
 
-def find_valid_t0_times(
-    datasets_dict: dict,
-    config: Configuration,
-):
+
+def find_valid_t0_times(datasets_dict: dict, config: Configuration) -> pd.DatetimeIndex:
     """Find the t0 times where all of the requested input data is available
 
     Args:
@@ -167,7 +175,7 @@ class PVNetUKRegionalDataset(Dataset):
         self,
         config_filename: str,
         start_time: str | None = None,
-        end_time: str| None = None,
+        end_time: str | None = None,
         gsp_ids: list[int] | None = None,
     ):
         """A torch Dataset for creating PVNet UK GSP samples
@@ -253,7 +261,7 @@ class PVNetUKRegionalDataset(Dataset):
     def get_sample(self, t0: pd.Timestamp, gsp_id: int) -> dict:
         """Generate a sample for the given coordinates.
 
-        Useful for users to generate samples by GSP ID.
+        Useful for users to generate specific samples.
 
         Args:
             t0: init-time for sample
@@ -265,4 +273,94 @@
 
         location = self.location_lookup[gsp_id]
 
-        return self._get_sample(t0, location)
+        return self._get_sample(t0, location)
+
+
+class PVNetUKConcurrentDataset(Dataset):
+    def __init__(
+        self,
+        config_filename: str,
+        start_time: str | None = None,
+        end_time: str | None = None,
+        gsp_ids: list[int] | None = None,
+    ):
+        """A torch Dataset for creating concurrent samples of PVNet UK regional data
+
+        Each concurrent sample includes the data from all GSPs for a single t0 time
+
+        Args:
+            config_filename: Path to the configuration file
+            start_time: Limit the init-times to be after this
+            end_time: Limit the init-times to be before this
+            gsp_ids: List of all GSP IDs included in each sample. Defaults to all
+        """
+
+        config = load_yaml_configuration(config_filename)
+
+        datasets_dict = get_dataset_dict(config)
+
+        # Get t0 times where all input data is available
+        valid_t0_times = find_valid_t0_times(datasets_dict, config)
+
+        # Filter t0 times to given range
+        if start_time is not None:
+            valid_t0_times = valid_t0_times[valid_t0_times>=pd.Timestamp(start_time)]
+
+        if end_time is not None:
+            valid_t0_times = valid_t0_times[valid_t0_times<=pd.Timestamp(end_time)]
+
+        # Construct list of locations to sample from
+        locations = get_gsp_locations(gsp_ids)
+
+        # Assign coords and indices to self
+        self.valid_t0_times = valid_t0_times
+        self.locations = locations
+
+        # Assign config and input data to self
+        self.datasets_dict = datasets_dict
+        self.config = config
+
+
+    def __len__(self):
+        return len(self.valid_t0_times)
+
+
+    def _get_sample(self, t0: pd.Timestamp) -> dict:
+        """Generate a concurrent PVNet sample for given init-time
+
+        Args:
+            t0: init-time for sample
+        """
+        # Slice by time then load to avoid loading the data multiple times from disk
+        sample_dict = slice_datasets_by_time(self.datasets_dict, t0, self.config)
+        sample_dict = compute(sample_dict)
+
+        gsp_samples = []
+
+        # Prepare sample for each GSP
+        for location in self.locations:
+            gsp_sample_dict = slice_datasets_by_space(sample_dict, location, self.config)
+            gsp_numpy_sample = process_and_combine_datasets(
+                gsp_sample_dict, self.config, t0, location
+            )
+            gsp_samples.append(gsp_numpy_sample)
+
+        # Stack GSP samples
+        return stack_np_samples_into_batch(gsp_samples)
+
+
+    def __getitem__(self, idx):
+        return self._get_sample(self.valid_t0_times[idx])
+
+
+    def get_sample(self, t0: pd.Timestamp) -> dict:
+        """Generate a sample for the given init-time.
+
+        Useful for users to generate specific samples.
+
+        Args:
+            t0: init-time for sample
+        """
+        # Check data is available for init-time t0
+        assert t0 in self.valid_t0_times
+        return self._get_sample(t0)
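Since each item of PVNetUKConcurrentDataset is already a batch stacked over GSPs (axis 0, via stack_np_samples_into_batch), a DataLoader should hand items through unchanged. A hedged usage sketch (the config path and GSP IDs are placeholders):

from torch.utils.data import DataLoader
from ocf_data_sampler.torch_datasets.datasets.pvnet_uk import PVNetUKConcurrentDataset

dataset = PVNetUKConcurrentDataset("config.yaml", gsp_ids=[1, 2, 3])

# batch_size=None passes each pre-stacked concurrent sample through as-is
dataloader = DataLoader(dataset, batch_size=None, shuffle=False)
sample = next(iter(dataloader))

# The leading dimension is the number of GSPs, e.g. (3, 7) for "gsp"
print(sample["gsp"].shape)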
@@ -1,6 +1,6 @@
 Metadata-Version: 2.2
 Name: ocf_data_sampler
-Version: 0.1.0
+Version: 0.1.2
 Summary: Sample from weather data for renewable energy prediction
 Author: James Fulton, Peter Dudfield, and the Open Climate Fix team
 Author-email: info@openclimatefix.org
@@ -19,7 +19,7 @@ ocf_data_sampler/load/nwp/providers/ecmwf.py,sha256=2iR1Iy542lo51rC6XFLV-3pbUE68
 ocf_data_sampler/load/nwp/providers/ukv.py,sha256=79Bm7q-K_GJPYMy62SUIZbRWRF4-tIaB1dYPEgLD9vo,1207
 ocf_data_sampler/load/nwp/providers/utils.py,sha256=Sy2exG1wpXLLhMXYdsfR-DZMR3txG1_bBmBdchlc-yA,848
 ocf_data_sampler/numpy_sample/__init__.py,sha256=nY5C6CcuxiWZ_jrXRzWtN7WyKXhJImSiVTIG6Rz4B_4,401
-ocf_data_sampler/numpy_sample/collate.py,sha256=y8QFUhskaAfOMP22aVkexwyGAwLHbNE-q1pOZ6txWKA,2227
+ocf_data_sampler/numpy_sample/collate.py,sha256=Onl_aKhsZ4pbFJsh70orjsHk523GHxrpRirH2vJq_GA,1911
 ocf_data_sampler/numpy_sample/datetime_features.py,sha256=U-9uRplfZ7VYFA4qBduI8OkG2x_65RYIP8wrLG4i-Nw,1441
 ocf_data_sampler/numpy_sample/gsp.py,sha256=5UaWO_aGRRVQo82wnDaT4zBKHihOnIsXiwgPjM8vGFM,1005
 ocf_data_sampler/numpy_sample/nwp.py,sha256=_seQNWsut3IzPsrpipqImjnaM3XNHZCy5_5be6syivk,1297
@@ -29,32 +29,32 @@ ocf_data_sampler/numpy_sample/sun_position.py,sha256=UklhucCxCT6GMlAhCWL6c4cfWrd
 ocf_data_sampler/sample/__init__.py,sha256=02CM7E5nKkGiYbVW-kvzjNd4RaqGuHCkDChtmDBDUoA,248
 ocf_data_sampler/sample/base.py,sha256=4U78tczCRsKMDwU4HkD20nyGyYjIBSZV5neF2mT--2M,1197
 ocf_data_sampler/sample/site.py,sha256=0BvDXs0kxTjUq7kWpeoITK_uN4uE0w1IvEFXZUoKOb0,2507
-ocf_data_sampler/sample/uk_regional.py,sha256=FPaFi6qaTsi1ag42pfVKDZhopt3cDjQsF4rVI8k2qWo,4244
+ocf_data_sampler/sample/uk_regional.py,sha256=D1A6nQB1PYCmxb3FzU9gqbNufQfx__wcprcDm50jCJw,4381
 ocf_data_sampler/select/__init__.py,sha256=E4AJulEbO2K-o0UlG1fgaEteuf_1ZFjHTvrotXSb4YU,332
 ocf_data_sampler/select/dropout.py,sha256=HCx5Wzk8Oh2Z9vV94Jy-ALJsHtGduwvMaQOleQXp5z0,1142
-ocf_data_sampler/select/fill_time_periods.py,sha256=iTtMjIPFYG5xtUYYedAFBLjTWWUa7t7WQ0-yksWf0-E,440
+ocf_data_sampler/select/fill_time_periods.py,sha256=h0XD1Ds_wUUoy-7bILxmN8AIbjlQ6YdXRKuCk_Is5jo,460
 ocf_data_sampler/select/find_contiguous_time_periods.py,sha256=q7IaNfX95A3z9XHqbhgtkZ4Js1gn5K9Qyp6DVLbsL-Q,11093
 ocf_data_sampler/select/geospatial.py,sha256=4xL-9y674jjoaXeqE52NHCHVfknciE4OEGsZtn9DvP4,4911
 ocf_data_sampler/select/location.py,sha256=26Y5ZjfFngShBwXieuWSoOA-RLaRzci4TTmcDk3Wg7U,2015
 ocf_data_sampler/select/select_spatial_slice.py,sha256=WNxwur9Q5oetvogATw8-hNejDuEwrXHzuZIovFDjNJA,11488
 ocf_data_sampler/select/select_time_slice.py,sha256=9M-yvDv9K77XfEys_OIR31_aVB56sNWk3BnCnkCgcPI,4725
 ocf_data_sampler/select/spatial_slice_for_dataset.py,sha256=3tRrMBXr7s4CnClbVSIq7hpls3H4Y3qYTDwswcxCCCE,1763
-ocf_data_sampler/select/time_slice_for_dataset.py,sha256=P7cAARfDzjttGDvpKt2zuA4WkLoTmSXy_lBpI8RiA6k,4249
-ocf_data_sampler/torch_datasets/datasets/__init__.py,sha256=JMfMQ6DxCWQiwm-Xwdy_b0gnGqZnORSR_SGrLM1QEe4,201
-ocf_data_sampler/torch_datasets/datasets/pvnet_uk_regional.py,sha256=xxeX4Js9LQpydehi3BS7k9psqkYGzgJuM17uTYux40M,8742
+ocf_data_sampler/select/time_slice_for_dataset.py,sha256=Z7pOiilSHScxmBKZNG18K5J-S4ifdXXAYGZoHRHD3AY,4324
+ocf_data_sampler/torch_datasets/datasets/__init__.py,sha256=jfJSFcR0eO1AqeH7S3KnGjsBqVZT5w3oyi784PUR6Q0,146
+ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py,sha256=rodvVSR4sh8qZ2hLdI8qAc3lyxq5U7cVGfS4rRKCzbs,11944
 ocf_data_sampler/torch_datasets/datasets/site.py,sha256=5T8nkTMUHHFidZRuFOunYeKAqNuyZ8V7sikBoBOBwwA,16033
 ocf_data_sampler/torch_datasets/utils/merge_and_fill_utils.py,sha256=hIbekql64eXsNDFIoEc--GWxwdVWrh2qKegdOi70Bow,874
 ocf_data_sampler/torch_datasets/utils/valid_time_periods.py,sha256=Qo65qUHtle_bW5tLTYr7empHTRv-lpjvfx_6GNJj3Xg,4371
 scripts/refactor_site.py,sha256=asZ27hQ4IyXgCCUaFJqcz1ObBNcV2W3ywqHBpSXA_fc,1728
 tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-tests/conftest.py,sha256=DfrH0Pm552Tnl35eZn2UHCfOn2lHRiEQCcUCJIhycSU,8021
+tests/conftest.py,sha256=RlC7YYtBLipUzFS1tQxela1SgHCxSpReUKEJ4429PwQ,7689
 tests/config/test_config.py,sha256=Vq_kTL5tJcwEP-hXD_Nah5O6cgafo99iX6Fw1AN5NDY,5288
 tests/config/test_save.py,sha256=rA_XVxP1pOxB--5Ebujz4T5o-VbcrCbg2VSlSq2iI0o,1318
 tests/load/test_load_gsp.py,sha256=aT_nqaSXmUTcdHzuTT7AmXJr3R31k4OEN-Fv3eLxlQE,424
 tests/load/test_load_nwp.py,sha256=3qyyDkB1q9t3tyAwogfotNrxqUOpXXimco1CImoEWGg,753
-tests/load/test_load_satellite.py,sha256=STX5AqqmOAgUgE9R1xyq_sM3P1b8NKdGjO-hDhayfxM,524
+tests/load/test_load_satellite.py,sha256=IQ8ISRZKCEoi8IsJoPpXZJTolD0mwjnl2E7762RM_PM,524
 tests/load/test_load_sites.py,sha256=T9lSEnGPI8FQISudVYHHNTHeplNS62Vrx48jaZ6J_Jo,364
-tests/numpy_sample/test_collate.py,sha256=ngbJ8vIewnAvkXx-PpfuSMVNM82_SYaZPLhJkZZw7s0,867
+tests/numpy_sample/test_collate.py,sha256=RqHCD5_LTRpe4r6kqC_2TKhmhM_IHYM0ZtFUvSjDqcM,654
 tests/numpy_sample/test_datetime_features.py,sha256=o4t3KeKFdGrOBQ77rNFcDuDMQSD23ileCS5T5AP3wG4,1769
 tests/numpy_sample/test_gsp.py,sha256=FLlq4SlJ-9cSRAepf4_ksA6PsUVKegnKEAc5pUojCJ0,1458
 tests/numpy_sample/test_nwp.py,sha256=yf4u7mAU0E3FQ4xAH6YjuHuHBzzFoXjHSFNkOVJUdSM,1455
@@ -69,12 +69,11 @@ tests/select/test_select_time_slice.py,sha256=nYrdlmZlGEygJKiE26bADiluNPN1qt5kD4
 tests/test_sample/test_base.py,sha256=ljtB38MmscTGN6OvUgclBceNnfx6m7AN8iHYDml9XW4,2189
 tests/test_sample/test_site_sample.py,sha256=Gln-Or060cUWvA7Q7c1vsthgCttOAM2z9yBI9zUIrDw,6238
 tests/test_sample/test_uk_regional_sample.py,sha256=gkeQWC2wC757jKJz_QBmDMFQjn3R54q_tEo948yyxCY,4840
-tests/torch_datasets/conftest.py,sha256=eRCzHE7cxS4AoskExkCGFDBeqItktAYNAdkfpMoFCeE,629
-tests/torch_datasets/test_merge_and_fill_utils.py,sha256=ueA0A7gZaWEgNdsU8p3CnKuvSnlleTUjEhSw2HUUROM,1229
-tests/torch_datasets/test_pvnet_uk_regional.py,sha256=FCiFueeFqrsXe7gWguSjBz5ZeUrvyhGbGw81gaVvkHM,5087
-tests/torch_datasets/test_site.py,sha256=J1ZDE5V5MRlq7EuZ1zUu-aFRGTDJIiO-ZZzkOXvDdWA,6757
-ocf_data_sampler-0.1.0.dist-info/LICENSE,sha256=F-Q3UFCR-BECSocV55BFDpn4YKxve9PKrm-lTt6o_Tg,1073
-ocf_data_sampler-0.1.0.dist-info/METADATA,sha256=9hgn-WJMx51JmLEiZmd5oiEwjZtcutfZObTtnwxUT2k,12173
-ocf_data_sampler-0.1.0.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
-ocf_data_sampler-0.1.0.dist-info/top_level.txt,sha256=Faob6N6cFdPc5eUpCTYcXgCaNhi4XLLteUL5W5ayYmg,31
-ocf_data_sampler-0.1.0.dist-info/RECORD,,
+tests/torch_datasets/test_merge_and_fill_utils.py,sha256=GtuQg82BM1eHQjT7Ik1x1zaVcuc7KJO4_NC9stXsd4s,1123
+tests/torch_datasets/test_pvnet_uk.py,sha256=OzT9ArdnWPa3iJKggxc2-7npkDqWmZyS5pzM4M08NZU,5566
+tests/torch_datasets/test_site.py,sha256=5MH5zkHFJXekwpnV6nHuSxt_sRNu9_mxiUjfWqmEhr0,6966
+ocf_data_sampler-0.1.2.dist-info/LICENSE,sha256=F-Q3UFCR-BECSocV55BFDpn4YKxve9PKrm-lTt6o_Tg,1073
+ocf_data_sampler-0.1.2.dist-info/METADATA,sha256=tWyfIpvmOufUWc7LquOXZ6g5Le_WhwihhZQBSZ0WKhA,12173
+ocf_data_sampler-0.1.2.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
+ocf_data_sampler-0.1.2.dist-info/top_level.txt,sha256=Faob6N6cFdPc5eUpCTYcXgCaNhi4XLLteUL5W5ayYmg,31
+ocf_data_sampler-0.1.2.dist-info/RECORD,,
tests/conftest.py CHANGED
@@ -1,14 +1,15 @@
+import pytest
+
 import os
 import numpy as np
 import pandas as pd
-import pytest
 import xarray as xr
-import tempfile
-from typing import Generator
+import dask.array
 
 from ocf_data_sampler.config.model import Site
 from ocf_data_sampler.config import load_yaml_configuration, save_yaml_configuration
 
+
 _top_test_directory = os.path.dirname(os.path.realpath(__file__))
 
 @pytest.fixture()
@@ -18,40 +19,27 @@ def test_config_filename():
 
 @pytest.fixture(scope="session")
 def config_filename():
-    return f"{os.path.dirname(os.path.abspath(__file__))}/test_data/configs/pvnet_test_config.yaml"
+    return f"{_top_test_directory}/test_data/configs/pvnet_test_config.yaml"
 
 
 @pytest.fixture(scope="session")
-def sat_zarr_path():
-
-    # Load dataset which only contains coordinates, but no data
-    ds = xr.open_zarr(
-        f"{os.path.dirname(os.path.abspath(__file__))}/test_data/non_hrv_shell.zarr.zip"
-    ).compute()
-
-    # Add time coord
-    ds = ds.assign_coords(time=pd.date_range("2023-01-01 00:00", "2023-01-02 23:55", freq="5min"))
-
-    # Add data to dataset
-    ds["data"] = xr.DataArray(
-        np.zeros([len(ds[c]) for c in ds.coords], dtype=np.float32),
-        coords=ds.coords,
-    )
-
-    # Transpose to variables, time, y, x (just in case)
-    ds = ds.transpose("variable", "time", "y_geostationary", "x_geostationary")
+def session_tmp_path(tmp_path_factory):
+    return tmp_path_factory.mktemp("data")
 
-    # add 100,000 to x_geostationary, this to make sure the fix index is within the satellite image
-    ds["x_geostationary"] = ds["x_geostationary"] - 200_000
 
-    # Add some NaNs
-    ds["data"].values[:, :, 0, 0] = np.nan
-
-    # make sure channel values are strings
-    ds["variable"] = ds["variable"].astype(str)
-
-    # add data attrs area
-    ds["data"].attrs["area"] = (
+@pytest.fixture(scope="session")
+def sat_zarr_path(session_tmp_path):
+
+    # Define coords for satellite-like dataset
+    variables = [
+        'IR_016', 'IR_039', 'IR_087', 'IR_097', 'IR_108', 'IR_120',
+        'IR_134', 'VIS006', 'VIS008', 'WV_062', 'WV_073',
+    ]
+    x = np.linspace(start=15002, stop=-1824245, num=100)
+    y = np.linspace(start=4191563, stop=5304712, num=100)
+    times = pd.date_range("2023-01-01 00:00", "2023-01-01 23:55", freq="5min")
+
+    area_string = (
         """msg_seviri_rss_3km:
           description: MSG SEVIRI Rapid Scanning Service area definition with 3 km resolution
           projection:
@@ -73,16 +61,31 @@ def sat_zarr_path():
           units: m
         """
     )
-
-    # Specifiy chunking
-    ds = ds.chunk({"time": 10, "variable": -1, "y_geostationary": -1, "x_geostationary": -1})
+
+    # Create satellite-like data with some NaNs
+    data = dask.array.zeros(
+        shape=(len(variables), len(times), len(y), len(x)),
+        chunks=(-1, 10, -1, -1),
+        dtype=np.float32
+    )
+    data[:, 10, :, :] = np.nan
+
+    ds = xr.DataArray(
+        data=data,
+        coords=dict(
+            variable=variables,
+            time=times,
+            y_geostationary=y,
+            x_geostationary=x,
+        ),
+        attrs=dict(area=area_string),
+    ).to_dataset(name="data")
 
     # Save temporarily as a zarr
-    with tempfile.TemporaryDirectory() as tmpdir:
-        zarr_path = f"{tmpdir}/test_sat.zarr"
-        ds.to_zarr(zarr_path)
+    zarr_path = session_tmp_path / "test_sat.zarr"
+    ds.to_zarr(zarr_path)
 
-        yield zarr_path
+    yield zarr_path
 
 
 @pytest.fixture(scope="session")
@@ -112,7 +115,7 @@ def ds_nwp_ukv():
 
 
 @pytest.fixture(scope="session")
-def nwp_ukv_zarr_path(ds_nwp_ukv):
+def nwp_ukv_zarr_path(session_tmp_path, ds_nwp_ukv):
     ds = ds_nwp_ukv.chunk(
         {
             "init_time": 1,
@@ -122,10 +125,9 @@ def nwp_ukv_zarr_path(ds_nwp_ukv):
             "y": 50,
         }
     )
-    with tempfile.TemporaryDirectory() as tmpdir:
-        filename = tmpdir + "/ukv_nwp.zarr"
-        ds.to_zarr(filename)
-        yield filename
+    zarr_path = session_tmp_path / "ukv_nwp.zarr"
+    ds.to_zarr(zarr_path)
+    yield zarr_path
 
 
 @pytest.fixture(scope="session")
@@ -155,7 +157,7 @@ def ds_nwp_ecmwf():
 
 
 @pytest.fixture(scope="session")
-def nwp_ecmwf_zarr_path(ds_nwp_ecmwf):
+def nwp_ecmwf_zarr_path(session_tmp_path, ds_nwp_ecmwf):
     ds = ds_nwp_ecmwf.chunk(
         {
             "init_time": 1,
@@ -165,10 +167,10 @@ def nwp_ecmwf_zarr_path(ds_nwp_ecmwf):
             "latitude": 50,
         }
    )
-    with tempfile.TemporaryDirectory() as tmpdir:
-        filename = tmpdir + "/ukv_ecmwf.zarr"
-        ds.to_zarr(filename)
-        yield filename
+
+    zarr_path = session_tmp_path / "ukv_ecmwf.zarr"
+    ds.to_zarr(zarr_path)
+    yield zarr_path
 
 
 @pytest.fixture(scope="session")
@@ -201,7 +203,7 @@ def ds_uk_gsp():
 
 
 @pytest.fixture(scope="session")
-def data_sites() -> Generator[Site, None, None]:
+def data_sites(session_tmp_path) -> Site:
    """
    Make fake data for sites
    Returns: filename for netcdf file, and csv metadata
@@ -245,30 +247,27 @@ def data_sites() -> Generator[Site, None, None]:
         "generation_kw": da_gen,
    })
 
-    with tempfile.TemporaryDirectory() as tmpdir:
-        filename = tmpdir + "/sites.netcdf"
-        filename_csv = tmpdir + "/sites_metadata.csv"
-        generation.to_netcdf(filename)
-        meta_df.to_csv(filename_csv)
-
-        site = Site(
-            file_path=filename,
-            metadata_file_path=filename_csv,
-            interval_start_minutes=-30,
-            interval_end_minutes=60,
-            time_resolution_minutes=30,
-        )
+    filename = f"{session_tmp_path}/sites.netcdf"
+    filename_csv = f"{session_tmp_path}/sites_metadata.csv"
+    generation.to_netcdf(filename)
+    meta_df.to_csv(filename_csv)
+
+    site = Site(
+        file_path=filename,
+        metadata_file_path=filename_csv,
+        interval_start_minutes=-30,
+        interval_end_minutes=60,
+        time_resolution_minutes=30,
+    )
 
-        yield site
+    yield site
 
 
 @pytest.fixture(scope="session")
-def uk_gsp_zarr_path(ds_uk_gsp):
-
-    with tempfile.TemporaryDirectory() as tmpdir:
-        filename = tmpdir + "/uk_gsp.zarr"
-        ds_uk_gsp.to_zarr(filename)
-        yield filename
+def uk_gsp_zarr_path(session_tmp_path, ds_uk_gsp):
+    zarr_path = session_tmp_path / "uk_gsp.zarr"
+    ds_uk_gsp.to_zarr(zarr_path)
+    yield zarr_path
 
 
 @pytest.fixture()
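The conftest changes replace per-fixture tempfile.TemporaryDirectory blocks (whose contents vanish as soon as the with-block exits) with a single session-scoped directory from pytest's tmp_path_factory, so session-scoped fixtures can safely yield paths that outlive their own function bodies. A minimal sketch of the pattern, independent of this package:

import pytest

@pytest.fixture(scope="session")
def shared_dir(tmp_path_factory):
    # Created once per test session and kept until the session ends
    return tmp_path_factory.mktemp("shared")

def test_writes_artifact(shared_dir):
    (shared_dir / "artifact.txt").write_text("available to all session fixtures")
    assert (shared_dir / "artifact.txt").exists()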
@@ -8,10 +8,10 @@ def test_open_satellite(sat_zarr_path):
 
     assert isinstance(da, xr.DataArray)
     assert da.dims == ("time_utc", "channel", "x_geostationary", "y_geostationary")
-    # 576 is 2 days of data at 5 minutes intervals, 12 * 24 * 2
+    # 288 is 1 day of data at 5 minute intervals, 12 * 24
     # There are 11 channels
-    # There are 49 x 20 pixels
-    assert da.shape == (576, 11, 49, 20)
+    # There are 100 x 100 pixels
+    assert da.shape == (288, 11, 100, 100)
     assert np.issubdtype(da.dtype, np.number)
 
 
@@ -1,17 +1,12 @@
-from ocf_data_sampler.numpy_sample import GSPSampleKey, SatelliteSampleKey
 from ocf_data_sampler.numpy_sample.collate import stack_np_samples_into_batch
-from ocf_data_sampler.torch_datasets.datasets.pvnet_uk_regional import PVNetUKRegionalDataset
+from ocf_data_sampler.torch_datasets.datasets.pvnet_uk import PVNetUKRegionalDataset
 
 
-def test_pvnet(pvnet_config_filename):
+def test_stack_np_samples_into_batch(pvnet_config_filename):
 
     # Create dataset object
     dataset = PVNetUKRegionalDataset(pvnet_config_filename)
 
-    assert len(dataset.locations) == 317
-    assert len(dataset.valid_t0_times) == 39
-    assert len(dataset) == 317 * 39
-
     # Generate 2 samples
     sample1 = dataset[0]
     sample2 = dataset[1]
@@ -22,5 +17,5 @@ def test_pvnet(pvnet_config_filename):
     assert "nwp" in batch
     assert isinstance(batch["nwp"], dict)
     assert "ukv" in batch["nwp"]
-    assert GSPSampleKey.gsp in batch
-    assert SatelliteSampleKey.satellite_actual in batch
+    assert "gsp" in batch
+    assert "satellite_actual" in batch
@@ -33,9 +33,7 @@ def test_fill_nans_in_arrays():
 
     result = fill_nans_in_arrays(nested_dict)
 
-    assert not np.isnan(result["array1"]).any()
     assert np.array_equal(result["array1"], np.array([1.0, 0.0, 3.0, 0.0]))
-    assert not np.isnan(result["nested"]["array2"]).any()
     assert np.array_equal(result["nested"]["array2"], np.array([0.0, 2.0, 0.0, 4.0]))
     assert result["string_key"] == "not_an_array"
 
@@ -0,0 +1,166 @@
+import numpy as np
+import pandas as pd
+import xarray as xr
+import dask.array
+
+from ocf_data_sampler.config import load_yaml_configuration, save_yaml_configuration
+from ocf_data_sampler.torch_datasets.datasets.pvnet_uk import (
+    PVNetUKRegionalDataset,
+    PVNetUKConcurrentDataset,
+    process_and_combine_datasets,
+    compute,
+)
+from ocf_data_sampler.select.location import Location
+
+def test_process_and_combine_datasets(pvnet_config_filename):
+
+    # Load in config for function and define location
+    config = load_yaml_configuration(pvnet_config_filename)
+    t0 = pd.Timestamp("2024-01-01 00:00")
+    location = Location(coordinate_system="osgb", x=1234, y=5678, id=1)
+
+    nwp_data = xr.DataArray(
+        np.random.rand(4, 2, 2, 2),
+        dims=["time_utc", "channel", "y", "x"],
+        coords={
+            "time_utc": pd.date_range("2024-01-01 00:00", periods=4, freq="h"),
+            "channel": ["t2m", "dswrf"],
+            "step": ("time_utc", pd.timedelta_range(start='0h', periods=4, freq='h')),
+            "init_time_utc": pd.Timestamp("2024-01-01 00:00")
+        }
+    )
+
+    sat_data = xr.DataArray(
+        np.random.rand(7, 1, 2, 2),
+        dims=["time_utc", "channel", "y", "x"],
+        coords={
+            "time_utc": pd.date_range("2024-01-01 00:00", periods=7, freq="5min"),
+            "channel": ["HRV"],
+            "x_geostationary": (["y", "x"], np.array([[1, 2], [1, 2]])),
+            "y_geostationary": (["y", "x"], np.array([[1, 1], [2, 2]]))
+        }
+    )
+
+    # Combine as dict
+    dataset_dict = {
+        "nwp": {"ukv": nwp_data},
+        "sat": sat_data
+    }
+
+    # Call relevant function
+    sample = process_and_combine_datasets(dataset_dict, config, t0, location)
+
+    # Assert result is dict - check and validate
+    assert isinstance(sample, dict)
+    assert "nwp" in sample
+    assert sample["satellite_actual"].shape == (7, 1, 2, 2)
+    assert sample["nwp"]["ukv"]["nwp"].shape == (4, 1, 2, 2)
+
+
+def test_compute():
+    """Test compute function with dask array"""
+    da_dask = xr.DataArray(dask.array.random.random((5, 5)))
+
+    # Create a nested dictionary with dask array
+    lazy_data_dict = {
+        "array1": da_dask,
+        "nested": {
+            "array2": da_dask
+        }
+    }
+
+    computed_data_dict = compute(lazy_data_dict)
+
+    # Assert that the result is no longer lazy
+    assert isinstance(computed_data_dict["array1"].data, np.ndarray)
+    assert isinstance(computed_data_dict["nested"]["array2"].data, np.ndarray)
+
+
+def test_pvnet_uk_regional_dataset(pvnet_config_filename):
+
+    # Create dataset object
+    dataset = PVNetUKRegionalDataset(pvnet_config_filename)
+
+    assert len(dataset.locations) == 317  # Number of regional GSPs
+    # NB. I have not checked the value (39 below) is in fact correct
+    assert len(dataset.valid_t0_times) == 39
+    assert len(dataset) == 317*39
+
+    # Generate a sample
+    sample = dataset[0]
+
+    assert isinstance(sample, dict)
+
+    for key in [
+        "nwp", "satellite_actual", "gsp",
+        "gsp_solar_azimuth", "gsp_solar_elevation",
+    ]:
+        assert key in sample
+
+    for nwp_source in ["ukv"]:
+        assert nwp_source in sample["nwp"]
+
+    # Check the shape of the data is correct
+    # 30 minutes of 5 minute data (inclusive), one channel, 2x2 pixels
+    assert sample["satellite_actual"].shape == (7, 1, 2, 2)
+    # 3 hours of 60 minute data (inclusive), one channel, 2x2 pixels
+    assert sample["nwp"]["ukv"]["nwp"].shape == (4, 1, 2, 2)
+    # 3 hours of 30 minute data (inclusive)
+    assert sample["gsp"].shape == (7,)
+    # Solar angles have same shape as GSP data
+    assert sample["gsp_solar_azimuth"].shape == (7,)
+    assert sample["gsp_solar_elevation"].shape == (7,)
+
+
+def test_pvnet_no_gsp(tmp_path, pvnet_config_filename):
+
+    # Create new config without GSP inputs
+    config = load_yaml_configuration(pvnet_config_filename)
+    config.input_data.gsp.zarr_path = ''
+    new_config_path = tmp_path / "pvnet_config_no_gsp.yaml"
+    save_yaml_configuration(config, new_config_path)
+
+    # Create dataset object
+    dataset = PVNetUKRegionalDataset(new_config_path)
+
+    # Generate a sample
+    _ = dataset[0]
+
+
+def test_pvnet_uk_concurrent_dataset(pvnet_config_filename):
+
+    # Create dataset object using a limited set of GSPs for test
+    gsp_ids = [1,2,3]
+    num_gsps = len(gsp_ids)
+
+    dataset = PVNetUKConcurrentDataset(pvnet_config_filename, gsp_ids=gsp_ids)
+
+    assert len(dataset.locations) == num_gsps  # Number of regional GSPs
+    # NB. I have not checked the value (39 below) is in fact correct
+    assert len(dataset.valid_t0_times) == 39
+    assert len(dataset) == 39
+
+    # Generate a sample
+    sample = dataset[0]
+
+    assert isinstance(sample, dict)
+
+    for key in [
+        "nwp", "satellite_actual", "gsp",
+        "gsp_solar_azimuth", "gsp_solar_elevation",
+    ]:
+        assert key in sample
+
+    for nwp_source in ["ukv"]:
+        assert nwp_source in sample["nwp"]
+
+    # Check the shape of the data is correct
+    # 30 minutes of 5 minute data (inclusive), one channel, 2x2 pixels
+    assert sample["satellite_actual"].shape == (num_gsps, 7, 1, 2, 2)
+    # 3 hours of 60 minute data (inclusive), one channel, 2x2 pixels
+    assert sample["nwp"]["ukv"]["nwp"].shape == (num_gsps, 4, 1, 2, 2)
+    # 3 hours of 30 minute data (inclusive)
+    assert sample["gsp"].shape == (num_gsps, 7,)
+    # Solar angles have same shape as GSP data
+    assert sample["gsp_solar_azimuth"].shape == (num_gsps, 7,)
+    assert sample["gsp_solar_elevation"].shape == (num_gsps, 7,)
@@ -1,11 +1,37 @@
-import pandas as pd
+import pytest
+
 import numpy as np
-from ocf_data_sampler.torch_datasets.datasets.site import SitesDataset, convert_from_dataset_to_dict_datasets, coarsen_data
-from xarray import Dataset, DataArray
+import pandas as pd
 import xarray as xr
 
 from torch.utils.data import DataLoader
 
+from ocf_data_sampler.config import load_yaml_configuration, save_yaml_configuration
+from ocf_data_sampler.torch_datasets.datasets.site import (
+    SitesDataset, convert_from_dataset_to_dict_datasets, coarsen_data
+)
+
+
+
+@pytest.fixture()
+def site_config_filename(tmp_path, config_filename, nwp_ukv_zarr_path, sat_zarr_path, data_sites):
+
+    # adjust config to point to the zarr file
+    config = load_yaml_configuration(config_filename)
+    config.input_data.nwp["ukv"].zarr_path = nwp_ukv_zarr_path
+    config.input_data.satellite.zarr_path = sat_zarr_path
+    config.input_data.site = data_sites
+    config.input_data.gsp = None
+
+    filename = f"{tmp_path}/configuration.yaml"
+    save_yaml_configuration(config, filename)
+    yield filename
+
+
+@pytest.fixture()
+def sites_dataset(site_config_filename):
+    return SitesDataset(site_config_filename)
+
 
 def test_site(site_config_filename):
 
@@ -18,7 +44,7 @@ def test_site(site_config_filename):
     # Generate a sample
     sample = dataset[0]
 
-    assert isinstance(sample, Dataset)
+    assert isinstance(sample, xr.Dataset)
 
     # Expected dimensions and data variables
     expected_dims = {
@@ -85,21 +111,14 @@ def test_site_time_filter_end(site_config_filename):
     assert len(dataset) == 0
 
 
-def test_site_get_sample(site_config_filename):
-
-    # Create dataset object
-    dataset = SitesDataset(site_config_filename)
-
-    assert len(dataset) == 410
-    sample = dataset.get_sample(t0=pd.Timestamp("2023-01-01 12:00"), site_id=1)
+def test_site_get_sample(sites_dataset):
+    sample = sites_dataset.get_sample(t0=pd.Timestamp("2023-01-01 12:00"), site_id=1)
 
 
-def test_convert_from_dataset_to_dict_datasets(site_config_filename):
-    # Create dataset object
-    dataset = SitesDataset(site_config_filename)
+def test_convert_from_dataset_to_dict_datasets(sites_dataset):
 
-    # Generate two samples
-    sample_xr = dataset[0]
+    # Generate sample
+    sample_xr = sites_dataset[0]
 
     sample = convert_from_dataset_to_dict_datasets(sample_xr)
 
@@ -109,9 +128,7 @@ def test_convert_from_dataset_to_dict_datasets(sites_dataset):
         assert key in sample
 
 
-def test_site_dataset_with_dataloader(site_config_filename):
-    # Create dataset object
-    dataset = SitesDataset(site_config_filename)
+def test_site_dataset_with_dataloader(sites_dataset):
 
     expected_coods = {
         "site__solar_azimuth",
@@ -122,10 +139,6 @@ def test_site_dataset_with_dataloader(site_config_filename):
         "site__date_sin",
     }
 
-    sample = dataset[0]
-    for key in expected_coods:
-        assert key in sample
-
     dataloader_kwargs = dict(
         shuffle=False,
@@ -141,25 +154,23 @@ def test_site_dataset_with_dataloader(site_config_filename):
         persistent_workers=False,  # Not needed since we only enter the dataloader loop once
     )
 
-    dataloader = DataLoader(dataset, collate_fn=None, batch_size=None)
+    dataloader = DataLoader(sites_dataset, collate_fn=None, batch_size=None)
 
-    for i, sample in zip(range(1), dataloader):
+    sample = next(iter(dataloader))
+
+    # check that expected_dims is in the sample
+    for key in expected_coods:
+        assert key in sample
 
-        # check that expected_dims is in the sample
-        for key in expected_coods:
-            assert key in sample
 
+def test_process_and_combine_site_sample_dict(sites_dataset):
 
-def test_process_and_combine_site_sample_dict(site_config_filename):
-    # Load config
-    # config = load_yaml_configuration(pvnet_config_filename)
-    site_ds = SitesDataset(site_config_filename)
     # Specify minimal structure for testing
     raw_nwp_values = np.random.rand(4, 1, 2, 2)  # Single channel
     fake_site_values = np.random.rand(197)
     site_dict = {
         "nwp": {
-            "ukv": DataArray(
+            "ukv": xr.DataArray(
                 raw_nwp_values,
                 dims=["time_utc", "channel", "y", "x"],
                 coords={
@@ -168,7 +179,7 @@ def test_process_and_combine_site_sample_dict(site_config_filename):
                },
            )
        },
-        "site": DataArray(
+        "site": xr.DataArray(
            fake_site_values,
            dims=["time_utc"],
            coords={
@@ -183,10 +194,10 @@ def test_process_and_combine_site_sample_dict(site_config_filename):
     print(f"Input site_dict: {site_dict}")
 
     # Call function
-    result = site_ds.process_and_combine_site_sample_dict(site_dict)
+    result = sites_dataset.process_and_combine_site_sample_dict(site_dict)
 
     # Assert to validate output structure
-    assert isinstance(result, Dataset), "Result should be an xarray.Dataset"
+    assert isinstance(result, xr.Dataset), "Result should be an xarray.Dataset"
     assert len(result.data_vars) > 0, "Dataset should contain data variables"
 
     # Validate variable via assertion and shape of such
@@ -1,18 +0,0 @@
-import pytest
-
-from ocf_data_sampler.config import load_yaml_configuration, save_yaml_configuration
-
-
-@pytest.fixture()
-def site_config_filename(tmp_path, config_filename, nwp_ukv_zarr_path, sat_zarr_path, data_sites):
-
-    # adjust config to point to the zarr file
-    config = load_yaml_configuration(config_filename)
-    config.input_data.nwp["ukv"].zarr_path = nwp_ukv_zarr_path
-    config.input_data.satellite.zarr_path = sat_zarr_path
-    config.input_data.site = data_sites
-    config.input_data.gsp = None
-
-    filename = f"{tmp_path}/configuration.yaml"
-    save_yaml_configuration(config, filename)
-    return filename
@@ -1,136 +0,0 @@
-import numpy as np
-import pandas as pd
-import xarray as xr
-import dask.array as da
-import tempfile
-
-from ocf_data_sampler.torch_datasets.datasets.pvnet_uk_regional import PVNetUKRegionalDataset
-from ocf_data_sampler.config.save import save_yaml_configuration
-from ocf_data_sampler.config.load import load_yaml_configuration
-from ocf_data_sampler.numpy_sample import NWPSampleKey, GSPSampleKey, SatelliteSampleKey
-from ocf_data_sampler.torch_datasets.datasets.pvnet_uk_regional import process_and_combine_datasets, compute
-from ocf_data_sampler.select.location import Location
-
-def test_process_and_combine_datasets(pvnet_config_filename):
-
-    # Load in config for function and define location
-    config = load_yaml_configuration(pvnet_config_filename)
-    t0 = pd.Timestamp("2024-01-01 00:00")
-    location = Location(coordinate_system="osgb", x=1234, y=5678, id=1)
-
-    nwp_data = xr.DataArray(
-        np.random.rand(4, 2, 2, 2),
-        dims=["time_utc", "channel", "y", "x"],
-        coords={
-            "time_utc": pd.date_range("2024-01-01 00:00", periods=4, freq="h"),
-            "channel": ["t2m", "dswrf"],
-            "step": ("time_utc", pd.timedelta_range(start='0h', periods=4, freq='h')),
-            "init_time_utc": pd.Timestamp("2024-01-01 00:00")
-        }
-    )
-
-    sat_data = xr.DataArray(
-        np.random.rand(7, 1, 2, 2),
-        dims=["time_utc", "channel", "y", "x"],
-        coords={
-            "time_utc": pd.date_range("2024-01-01 00:00", periods=7, freq="5min"),
-            "channel": ["HRV"],
-            "x_geostationary": (["y", "x"], np.array([[1, 2], [1, 2]])),
-            "y_geostationary": (["y", "x"], np.array([[1, 1], [2, 2]]))
-        }
-    )
-
-    # Combine as dict
-    dataset_dict = {
-        "nwp": {"ukv": nwp_data},
-        "sat": sat_data
-    }
-
-    # Call relevant function
-    result = process_and_combine_datasets(dataset_dict, config, t0, location)
-
-    # Assert result is dict - check and validate
-    assert isinstance(result, dict)
-    assert NWPSampleKey.nwp in result
-    assert result[SatelliteSampleKey.satellite_actual].shape == (7, 1, 2, 2)
-    assert result[NWPSampleKey.nwp]["ukv"][NWPSampleKey.nwp].shape == (4, 1, 2, 2)
-
-def test_compute():
-    """Test compute function with dask array"""
-    da_dask = xr.DataArray(da.random.random((5, 5)))
-
-    # Create a nested dictionary with dask array
-    nested_dict = {
-        "array1": da_dask,
-        "nested": {
-            "array2": da_dask
-        }
-    }
-
-    # Ensure initial data is lazy - i.e. not yet computed
-    assert not isinstance(nested_dict["array1"].data, np.ndarray)
-    assert not isinstance(nested_dict["nested"]["array2"].data, np.ndarray)
-
-    # Call the compute function
-    result = compute(nested_dict)
-
-    # Assert that the result is an xarray DataArray and no longer lazy
-    assert isinstance(result["array1"], xr.DataArray)
-    assert isinstance(result["nested"]["array2"], xr.DataArray)
-    assert isinstance(result["array1"].data, np.ndarray)
-    assert isinstance(result["nested"]["array2"].data, np.ndarray)
-
-    # Ensure there no NaN values in computed data
-    assert not np.isnan(result["array1"].data).any()
-    assert not np.isnan(result["nested"]["array2"].data).any()
-
-def test_pvnet(pvnet_config_filename):
-
-    # Create dataset object
-    dataset = PVNetUKRegionalDataset(pvnet_config_filename)
-
-    assert len(dataset.locations) == 317  # no of GSPs not including the National level
-    # NB. I have not checked this value is in fact correct, but it does seem to stay constant
-    assert len(dataset.valid_t0_times) == 39
-    assert len(dataset) == 317*39
-
-    # Generate a sample
-    sample = dataset[0]
-
-    assert isinstance(sample, dict)
-
-    for key in [
-        NWPSampleKey.nwp, SatelliteSampleKey.satellite_actual, GSPSampleKey.gsp,
-        GSPSampleKey.solar_azimuth, GSPSampleKey.solar_elevation,
-    ]:
-        assert key in sample
-
-    for nwp_source in ["ukv"]:
-        assert nwp_source in sample[NWPSampleKey.nwp]
-
-    # check the shape of the data is correct
-    # 30 minutes of 5 minute data (inclusive), one channel, 2x2 pixels
-    assert sample[SatelliteSampleKey.satellite_actual].shape == (7, 1, 2, 2)
-    # 3 hours of 60 minute data (inclusive), one channel, 2x2 pixels
-    assert sample[NWPSampleKey.nwp]["ukv"][NWPSampleKey.nwp].shape == (4, 1, 2, 2)
-    # 3 hours of 30 minute data (inclusive)
-    assert sample[GSPSampleKey.gsp].shape == (7,)
-    # Solar angles have same shape as GSP data
-    assert sample[GSPSampleKey.solar_azimuth].shape == (7,)
-    assert sample[GSPSampleKey.solar_elevation].shape == (7,)
-
-def test_pvnet_no_gsp(pvnet_config_filename):
-
-    # load config
-    config = load_yaml_configuration(pvnet_config_filename)
-    # remove gsp
-    config.input_data.gsp.zarr_path = ''
-
-    # save temp config file
-    with tempfile.NamedTemporaryFile() as temp_config_file:
-        save_yaml_configuration(config, temp_config_file.name)
-        # Create dataset object
-        dataset = PVNetUKRegionalDataset(temp_config_file.name)
-
-        # Generate a sample
-        _ = dataset[0]