ocf-data-sampler 0.1.9__py3-none-any.whl → 0.1.11__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release.

This version of ocf-data-sampler might be problematic.

Files changed (29)
  1. ocf_data_sampler/config/model.py +25 -23
  2. ocf_data_sampler/load/satellite.py +21 -29
  3. ocf_data_sampler/load/site.py +1 -1
  4. ocf_data_sampler/numpy_sample/gsp.py +6 -2
  5. ocf_data_sampler/numpy_sample/nwp.py +7 -13
  6. ocf_data_sampler/numpy_sample/satellite.py +11 -8
  7. ocf_data_sampler/numpy_sample/site.py +6 -2
  8. ocf_data_sampler/numpy_sample/sun_position.py +9 -10
  9. ocf_data_sampler/sample/__init__.py +0 -7
  10. ocf_data_sampler/sample/base.py +16 -35
  11. ocf_data_sampler/sample/site.py +28 -65
  12. ocf_data_sampler/sample/uk_regional.py +52 -97
  13. ocf_data_sampler/select/dropout.py +38 -25
  14. ocf_data_sampler/select/fill_time_periods.py +3 -1
  15. ocf_data_sampler/select/find_contiguous_time_periods.py +0 -1
  16. ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py +2 -3
  17. ocf_data_sampler/torch_datasets/datasets/site.py +9 -5
  18. {ocf_data_sampler-0.1.9.dist-info → ocf_data_sampler-0.1.11.dist-info}/METADATA +1 -1
  19. {ocf_data_sampler-0.1.9.dist-info → ocf_data_sampler-0.1.11.dist-info}/RECORD +29 -29
  20. tests/config/test_config.py +3 -3
  21. tests/conftest.py +33 -0
  22. tests/numpy_sample/test_nwp.py +3 -42
  23. tests/select/test_dropout.py +7 -13
  24. tests/test_sample/test_site_sample.py +5 -35
  25. tests/test_sample/test_uk_regional_sample.py +8 -35
  26. tests/torch_datasets/test_pvnet_uk.py +6 -19
  27. {ocf_data_sampler-0.1.9.dist-info → ocf_data_sampler-0.1.11.dist-info}/LICENSE +0 -0
  28. {ocf_data_sampler-0.1.9.dist-info → ocf_data_sampler-0.1.11.dist-info}/WHEEL +0 -0
  29. {ocf_data_sampler-0.1.9.dist-info → ocf_data_sampler-0.1.11.dist-info}/top_level.txt +0 -0
ocf_data_sampler/sample/uk_regional.py CHANGED
@@ -1,120 +1,75 @@
- """
- PVNet - UK Regional sample / dataset implementation
- """
+ """PVNet UK Regional sample implementation for dataset handling and visualisation"""
  
- import numpy as np
- import pandas as pd
- import torch
- import logging
+ from typing_extensions import override
  
- from typing import Dict, Any, Union, List, Optional
- from pathlib import Path
+ import torch
  
+ from ocf_data_sampler.sample.base import SampleBase, NumpySample
  from ocf_data_sampler.numpy_sample import (
      NWPSampleKey,
      GSPSampleKey,
      SatelliteSampleKey
  )
  
- from ocf_data_sampler.sample.base import SampleBase
- 
- try:
-     import matplotlib.pyplot as plt
-     MATPLOTLIB_AVAILABLE = True
- except ImportError:
-     MATPLOTLIB_AVAILABLE = False
-     plt = None
- 
- 
- logger = logging.getLogger(__name__)
- 
  
  class UKRegionalSample(SampleBase):
-     """ Sample class specific to UK Regional PVNet """
+     """Handles UK Regional PVNet data operations"""
  
-     def __init__(self):
-         logger.debug("Initialise UKRegionalSample instance")
-         super().__init__()
-         self._data = {}
+     def __init__(self, data: NumpySample):
+         self._data = data
  
-     def to_numpy(self) -> Dict[str, Any]:
-         """ Convert sample data to numpy format """
-         logger.debug("Converting sample data to numpy format")
+     @override
+     def to_numpy(self) -> NumpySample:
          return self._data
  
-     def save(self, path: Union[str, Path]) -> None:
-         """ Save PVNet sample as .pt """
-         logger.debug(f"Saving UKRegionalSample to {path}")
-         path = Path(path)
- 
-         if path.suffix != '.pt':
-             logger.error(f"Invalid file format: {path.suffix}")
-             raise ValueError(f"Only .pt format is supported: {path.suffix}")
+     def save(self, path: str) -> None:
+         """Save PVNet sample as pickle format using torch.save
  
+         Args:
+             path: Path to save the sample data to
+         """
          torch.save(self._data, path)
-         logger.debug(f"Successfully saved UKRegionalSample to {path}")
  
      @classmethod
-     def load(cls, path: Union[str, Path]) -> 'UKRegionalSample':
-         """ Load PVNet sample data from .pt """
-         logger.debug(f"Attempting to load UKRegionalSample from {path}")
-         path = Path(path)
+     def load(cls, path: str) -> 'UKRegionalSample':
+         """Load PVNet sample data from .pt format
  
-         if path.suffix != '.pt':
-             logger.error(f"Invalid file format: {path.suffix}")
-             raise ValueError(f"Only .pt format is supported: {path.suffix}")
- 
-         instance = cls()
+         Args:
+             path: Path to load the sample data from
+         """
          # TODO: We should move away from using torch.load(..., weights_only=False)
-         # This is not recommended
-         instance._data = torch.load(path, weights_only=False)
-         logger.debug(f"Successfully loaded UKRegionalSample from {path}")
-         return instance
- 
-     def plot(self, **kwargs) -> None:
-         """ Sample visualisation definition """
-         logger.debug("Creating UKRegionalSample visualisation")
- 
-         if not MATPLOTLIB_AVAILABLE:
-             raise ImportError(
-                 "Matplotlib required for plotting"
-                 "Install via 'ocf_data_sampler[plot]'"
-             )
- 
-         try:
-             fig, axes = plt.subplots(2, 2, figsize=(12, 8))
- 
-             if NWPSampleKey.nwp in self._data:
-                 logger.debug("Plotting NWP data")
-                 first_nwp = list(self._data[NWPSampleKey.nwp].values())[0]
-                 if 'nwp' in first_nwp:
-                     axes[0, 1].imshow(first_nwp['nwp'][0])
-                     axes[0, 1].set_title('NWP (First Channel)')
-                     if NWPSampleKey.channel_names in first_nwp:
-                         channel_names = first_nwp[NWPSampleKey.channel_names]
-                         if len(channel_names) > 0:
-                             axes[0, 1].set_title(f'NWP: {channel_names[0]}')
+         return cls(torch.load(path, weights_only=False))
  
-             if GSPSampleKey.gsp in self._data:
-                 logger.debug("Plotting GSP generation data")
-                 axes[0, 0].plot(self._data[GSPSampleKey.gsp])
-                 axes[0, 0].set_title('GSP Generation')
- 
-             if GSPSampleKey.solar_azimuth in self._data and GSPSampleKey.solar_elevation in self._data:
-                 logger.debug("Plotting solar position data")
-                 axes[1, 1].plot(self._data[GSPSampleKey.solar_azimuth], label='Azimuth')
-                 axes[1, 1].plot(self._data[GSPSampleKey.solar_elevation], label='Elevation')
-                 axes[1, 1].set_title('Solar Position')
-                 axes[1, 1].legend()
+     def plot(self) -> None:
+         """Creates visualisations for NWP, GSP, solar position, and satellite data"""
+         from matplotlib import pyplot as plt
  
-             if SatelliteSampleKey.satellite_actual in self._data:
-                 logger.debug("Plotting satellite data")
-                 axes[1, 0].imshow(self._data[SatelliteSampleKey.satellite_actual])
-                 axes[1, 0].set_title('Satellite Data')
- 
-             plt.tight_layout()
-             plt.show()
-             logger.debug("Successfully created visualisation")
-         except Exception as e:
-             logger.error(f"Error creating visualisation: {str(e)}")
-             raise
+         fig, axes = plt.subplots(2, 2, figsize=(12, 8))
+ 
+         if NWPSampleKey.nwp in self._data:
+             first_nwp = list(self._data[NWPSampleKey.nwp].values())[0]
+             if 'nwp' in first_nwp:
+                 axes[0, 1].imshow(first_nwp['nwp'][0])
+                 title = 'NWP (First Channel)'
+                 if NWPSampleKey.channel_names in first_nwp:
+                     channel_names = first_nwp[NWPSampleKey.channel_names]
+                     if channel_names:
+                         title = f'NWP: {channel_names[0]}'
+                 axes[0, 1].set_title(title)
+ 
+         if GSPSampleKey.gsp in self._data:
+             axes[0, 0].plot(self._data[GSPSampleKey.gsp])
+             axes[0, 0].set_title('GSP Generation')
+ 
+         if GSPSampleKey.solar_azimuth in self._data and GSPSampleKey.solar_elevation in self._data:
+             axes[1, 1].plot(self._data[GSPSampleKey.solar_azimuth], label='Azimuth')
+             axes[1, 1].plot(self._data[GSPSampleKey.solar_elevation], label='Elevation')
+             axes[1, 1].set_title('Solar Position')
+             axes[1, 1].legend()
+ 
+         if SatelliteSampleKey.satellite_actual in self._data:
+             axes[1, 0].imshow(self._data[SatelliteSampleKey.satellite_actual])
+             axes[1, 0].set_title('Satellite Data')
+ 
+         plt.tight_layout()
+         plt.show()
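
Note: with this change the sample data is passed to the constructor and the pathlib/suffix checks are gone, so a save/load round trip looks roughly like the sketch below. This is a minimal illustration based only on the signatures in the hunk above; the toy sample dict is a placeholder, not a real PVNet sample.

    import numpy as np
    from ocf_data_sampler.numpy_sample import GSPSampleKey
    from ocf_data_sampler.sample.uk_regional import UKRegionalSample

    # Placeholder sample - real samples also carry NWP, satellite and solar-position keys
    sample_data = {GSPSampleKey.gsp: np.random.rand(7)}

    sample = UKRegionalSample(sample_data)       # data now goes straight into the constructor
    sample.save("sample.pt")                     # plain torch.save, no .pt suffix validation any more

    loaded = UKRegionalSample.load("sample.pt")  # wraps torch.load(..., weights_only=False)
    assert (loaded.to_numpy()[GSPSampleKey.gsp] == sample_data[GSPSampleKey.gsp]).all()
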
ocf_data_sampler/select/dropout.py CHANGED
@@ -1,39 +1,52 @@
- """ Functions for simulating dropout in time series data """
+ """Functions for simulating dropout in time series data
+ 
+ This is used for the following types of data: GSP, Satellite and Site
+ This is not used for NWP
+ """
  import numpy as np
  import pandas as pd
  import xarray as xr
  
  
  def draw_dropout_time(
-     t0: pd.Timestamp,
-     dropout_timedeltas: list[pd.Timedelta] | pd.Timedelta | None,
-     dropout_frac: float = 0,
- ):
- 
-     if dropout_timedeltas is not None:
-         assert len(dropout_timedeltas) >= 1, "Must include list of relative dropout timedeltas"
-         assert all(
-             [t <= pd.Timedelta("0min") for t in dropout_timedeltas]
-         ), "dropout timedeltas must be negative"
+     t0: pd.Timestamp,
+     dropout_timedeltas: list[pd.Timedelta],
+     dropout_frac: float,
+ ) -> pd.Timestamp:
+     """Randomly pick a dropout time from a list of timedeltas
+ 
+     Args:
+         t0: The forecast init-time
+         dropout_timedeltas: List of timedeltas relative to t0 to pick from
+         dropout_frac: Probability that dropout will be applied. This should be between 0 and 1
+             inclusive
+     """
+ 
+     if dropout_frac>0:
+         assert len(dropout_timedeltas) > 0, "To apply dropout dropout_timedeltas must be provided"
+ 
+         for t in dropout_timedeltas:
+             assert t <= pd.Timedelta("0min"), "Dropout timedeltas must be negative"
+ 
      assert 0 <= dropout_frac <= 1
  
-     if (dropout_timedeltas is None) or (np.random.uniform() >= dropout_frac):
-         dropout_time = None
+     if (len(dropout_timedeltas) == 0) or (np.random.uniform() >= dropout_frac):
+         dropout_time = t0
      else:
-         t0_datetime_utc = pd.Timestamp(t0)
-         dt = np.random.choice(dropout_timedeltas)
-         dropout_time = t0_datetime_utc + dt
+         dropout_time = t0 + np.random.choice(dropout_timedeltas)
  
      return dropout_time
  
  
  def apply_dropout_time(
-     ds: xr.DataArray,
-     dropout_time: pd.Timestamp | None,
- ):
- 
-     if dropout_time is None:
-         return ds
-     else:
-         # This replaces the times after the dropout with NaNs
-         return ds.where(ds.time_utc <= dropout_time)
+     ds: xr.DataArray,
+     dropout_time: pd.Timestamp,
+ ) -> xr.DataArray:
+     """Apply dropout time to the data
+ 
+     Args:
+         ds: Xarray DataArray with 'time_utc' coordiante
+         dropout_time: Time after which data is set to NaN
+     """
+     # This replaces the times after the dropout with NaNs
+     return ds.where(ds.time_utc <= dropout_time)
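
Note: under the new signatures draw_dropout_time always returns a timestamp (t0 itself when no dropout is drawn), so apply_dropout_time can be called unconditionally. A rough usage sketch with dummy data (illustrative only, not taken from the package's tests):

    import numpy as np
    import pandas as pd
    import xarray as xr
    from ocf_data_sampler.select.dropout import apply_dropout_time, draw_dropout_time

    t0 = pd.Timestamp("2021-01-01 04:00")
    times = pd.date_range(t0 - pd.Timedelta("1h"), t0, freq="5min")
    da = xr.DataArray(np.random.rand(len(times)), coords={"time_utc": times})

    # With dropout_frac=1 a (negative) timedelta is always drawn and added to t0
    dropout_time = draw_dropout_time(
        t0,
        dropout_timedeltas=[pd.Timedelta("-30min"), pd.Timedelta("-60min")],
        dropout_frac=1,
    )

    # Values observed after the drawn dropout time become NaN;
    # if dropout_time == t0, this history-only array is left unchanged
    da_dropped = apply_dropout_time(da, dropout_time)
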
ocf_data_sampler/select/fill_time_periods.py CHANGED
@@ -1,10 +1,12 @@
- """fill time periods"""
+ """Fill time periods between start and end dates at specified frequency"""
  
  import pandas as pd
  import numpy as np
  
  
  def fill_time_periods(time_periods: pd.DataFrame, freq: pd.Timedelta) -> pd.DatetimeIndex:
+     """Generate DatetimeIndex for all timestamps between start and end dates"""
+ 
      start_dts = pd.to_datetime(time_periods["start_dt"].values).ceil(freq)
      end_dts = pd.to_datetime(time_periods["end_dt"].values)
      date_ranges = [pd.date_range(start_dt, end_dt, freq=freq) for start_dt, end_dt in zip(start_dts, end_dts)]
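
Note: only the first half of the function body is visible in this hunk. Assuming the remainder concatenates the per-row date ranges into one DatetimeIndex (not shown here), a call looks roughly like:

    import pandas as pd
    from ocf_data_sampler.select.fill_time_periods import fill_time_periods

    # One row per contiguous period; start times are ceiled to the frequency
    time_periods = pd.DataFrame(
        {
            "start_dt": pd.to_datetime(["2024-01-01 00:05", "2024-01-02 12:00"]),
            "end_dt": pd.to_datetime(["2024-01-01 01:00", "2024-01-02 13:00"]),
        }
    )

    timestamps = fill_time_periods(time_periods, freq=pd.Timedelta("30min"))
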
ocf_data_sampler/select/find_contiguous_time_periods.py CHANGED
@@ -5,7 +5,6 @@ import pandas as pd
  from ocf_data_sampler.load.utils import check_time_unique_increasing
  
  
- 
  def find_contiguous_time_periods(
      datetimes: pd.DatetimeIndex,
      min_seq_length: int,
ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py CHANGED
@@ -186,9 +186,8 @@ class PVNetUKRegionalDataset(Dataset):
              gsp_ids: List of GSP IDs to create samples for. Defaults to all
          """
  
-         config = load_yaml_configuration(config_filename)
- 
-         # Validate channels for NWP and satellite data
+         # config = load_yaml_configuration(config_filename)
+         config: Configuration = load_yaml_configuration(config_filename)
          validate_nwp_channels(config)
          validate_satellite_channels(config)
  
ocf_data_sampler/torch_datasets/datasets/site.py CHANGED
@@ -20,7 +20,6 @@ from ocf_data_sampler.select import (
  from ocf_data_sampler.utils import minutes
  from ocf_data_sampler.torch_datasets.utils.valid_time_periods import find_valid_time_periods
  from ocf_data_sampler.torch_datasets.utils.merge_and_fill_utils import merge_dicts, fill_nans_in_arrays
- from ocf_data_sampler.torch_datasets.utils.validate_channels import validate_nwp_channels
  
  from ocf_data_sampler.numpy_sample import (
      convert_site_to_numpy_sample,
@@ -30,8 +29,12 @@ from ocf_data_sampler.numpy_sample import (
      make_sun_position_numpy_sample,
  )
  from ocf_data_sampler.numpy_sample import NWPSampleKey
- from ocf_data_sampler.constants import NWP_MEANS, NWP_STDS
+ from ocf_data_sampler.constants import NWP_MEANS, NWP_STDS, RSS_MEAN, RSS_STD
  
+ from ocf_data_sampler.torch_datasets.utils.validate_channels import (
+     validate_nwp_channels,
+     validate_satellite_channels,
+ )
  
  xr.set_options(keep_attrs=True)
  
@@ -52,9 +55,8 @@ class SitesDataset(Dataset):
          """
  
          config: Configuration = load_yaml_configuration(config_filename)
- 
-         # Validate NWP channels
          validate_nwp_channels(config)
+         validate_satellite_channels(config)
  
          datasets_dict = get_dataset_dict(config.input_data)
  
@@ -237,8 +239,10 @@ class SitesDataset(Dataset):
              data_arrays.append((f"nwp-{provider}", da_nwp))
  
          if "sat" in dataset_dict:
-             # TODO add some satellite normalisation
             da_sat = dataset_dict["sat"]
+ 
+             # Standardise
+             da_sat = (da_sat - RSS_MEAN) / RSS_STD
              data_arrays.append(("satellite", da_sat))
  
          if "site" in dataset_dict:
{ocf_data_sampler-0.1.9.dist-info → ocf_data_sampler-0.1.11.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.2
  Name: ocf_data_sampler
- Version: 0.1.9
+ Version: 0.1.11
  Summary: Sample from weather data for renewable energy prediction
  Author: James Fulton, Peter Dudfield, and the Open Climate Fix team
  Author-email: info@openclimatefix.org
{ocf_data_sampler-0.1.9.dist-info → ocf_data_sampler-0.1.11.dist-info}/RECORD CHANGED
@@ -3,14 +3,14 @@ ocf_data_sampler/constants.py,sha256=0HYNmqwBaHVTAEEx9qzk6WD9YInh0gSKLeI3pyq7aNs
  ocf_data_sampler/utils.py,sha256=rKA0BHAyAG4f90zEcgxp25EEYrXS-aOVNzttZ6Mzv2k,250
  ocf_data_sampler/config/__init__.py,sha256=O29mbH0XG2gIY1g3BaveGCnpBO2SFqdu-qzJ7a6evl0,223
  ocf_data_sampler/config/load.py,sha256=sKCKmhkkeFvvkNL5xmnFvdAulaCtV4-rigPsFvVDPDc,634
- ocf_data_sampler/config/model.py,sha256=IMJhsjL_oGh2c50q8pBnCnArY4qHQcBc_M8jqlEeD0c,7129
+ ocf_data_sampler/config/model.py,sha256=8PO-23uVy_JjWOJKgaZWdNMehQsAI-Jn8t0lcmBycwg,6992
  ocf_data_sampler/config/save.py,sha256=OqCPT3e0d7vMI2g2iRzmifPD7GscDkFQztU_qE5I0JY,1066
  ocf_data_sampler/data/uk_gsp_locations.csv,sha256=RSh7DRh55E3n8lVAaWXGTaXXHevZZtI58td4d4DhGos,10415772
  ocf_data_sampler/load/__init__.py,sha256=T5Zj1PGt0aiiNEN7Ra1Ac-cBsNKhphmmHy_8g7XU_w0,219
  ocf_data_sampler/load/gsp.py,sha256=uRxEORH7J99JAJ-D38nm0iJFOQh7dkm_NCXcpbYkyvo,857
  ocf_data_sampler/load/load_dataset.py,sha256=PHUGSm4hFHfS9nfIP2KjHHCp325O4br7uGBdQH_DP7g,1603
- ocf_data_sampler/load/satellite.py,sha256=4MRJBFDHxx5WXu_6X71wEBznJTIuldEVnu9d6DVoLPI,2436
- ocf_data_sampler/load/site.py,sha256=74M_7RYwEc1bU4idjs3ZmQrx9I8mJXm6H4lwEL-h9n0,1226
+ ocf_data_sampler/load/satellite.py,sha256=SEQZ9oPe-asEeZeEMDkB1xWK5hErhWMagxohFcBl6KI,2294
+ ocf_data_sampler/load/site.py,sha256=hMdoF6sn2PcSBfF2soj7nuQoK9SItaxDXco5nk2n-44,1232
  ocf_data_sampler/load/utils.py,sha256=sAEkPMS9LXVCrc5pANQo97zaoEItVg9hoNj2ZWfx_Ug,1405
  ocf_data_sampler/load/nwp/__init__.py,sha256=SmcrnbygO5xtCKmGR4wtHrj-HI7nOAvnAtfuvRufBGQ,25
  ocf_data_sampler/load/nwp/nwp.py,sha256=Jyq1dE7DN0iSe6iSEGA76uu9LoeJz9FzfEUkq6ZZExQ,565
@@ -21,19 +21,19 @@ ocf_data_sampler/load/nwp/providers/utils.py,sha256=MFOZ5ZXLu3-SxYVJExdlo30b3y3s
  ocf_data_sampler/numpy_sample/__init__.py,sha256=nY5C6CcuxiWZ_jrXRzWtN7WyKXhJImSiVTIG6Rz4B_4,401
  ocf_data_sampler/numpy_sample/collate.py,sha256=oX5axq30sCsSquhNbmWAVMjM54HT1v3MCMopYHcO5Q0,1950
  ocf_data_sampler/numpy_sample/datetime_features.py,sha256=D0RajbnBjg15qjYk16h2H0XO4wH3fw-x0--4VC2nq0s,1204
- ocf_data_sampler/numpy_sample/gsp.py,sha256=5UaWO_aGRRVQo82wnDaT4zBKHihOnIsXiwgPjM8vGFM,1005
- ocf_data_sampler/numpy_sample/nwp.py,sha256=_seQNWsut3IzPsrpipqImjnaM3XNHZCy5_5be6syivk,1297
- ocf_data_sampler/numpy_sample/satellite.py,sha256=8OaTvkPjzSjotcdKsa6BKmmlBKDBunbhDN4Pjo0Grxs,910
- ocf_data_sampler/numpy_sample/site.py,sha256=I-cAXCOF0SDdm5Hx43lFqYZ3jh61kltLQK-fc4_nNu0,1314
- ocf_data_sampler/numpy_sample/sun_position.py,sha256=UklhucCxCT6GMlAhCWL6c4cfWrdc1cWgegrYaqUoHOY,1611
- ocf_data_sampler/sample/__init__.py,sha256=02CM7E5nKkGiYbVW-kvzjNd4RaqGuHCkDChtmDBDUoA,248
- ocf_data_sampler/sample/base.py,sha256=q3wpqoW4JXRmzfar6ed7UMn1nxBxSJXNvMLJmHXy1dw,2856
- ocf_data_sampler/sample/site.py,sha256=0BvDXs0kxTjUq7kWpeoITK_uN4uE0w1IvEFXZUoKOb0,2507
- ocf_data_sampler/sample/uk_regional.py,sha256=D1A6nQB1PYCmxb3FzU9gqbNufQfx__wcprcDm50jCJw,4381
+ ocf_data_sampler/numpy_sample/gsp.py,sha256=uBquCFCoWuhJKY8sXpgsTCUDWUuLuv1XeixtFnFw6KU,1115
+ ocf_data_sampler/numpy_sample/nwp.py,sha256=Tiba-es23XeyMoEPgZUpLT6EnJCGU9A_1MdY6qkE7bM,1015
+ ocf_data_sampler/numpy_sample/satellite.py,sha256=RdXMdGGXysUx-AdL9T33yFOlxprtIdPNBKKX99-mhpY,991
+ ocf_data_sampler/numpy_sample/site.py,sha256=TvoEU85fmjYW8pD9UZOyUUACjimdQYxEzulQXunRO6Q,1425
+ ocf_data_sampler/numpy_sample/sun_position.py,sha256=ithM--eztAhiIQ1g52tlxgj-tMKbsJzx8mk6CgV2tzk,1613
+ ocf_data_sampler/sample/__init__.py,sha256=zdS73NTnxFX_j8uh9tT-IXiURB6635wbneM1koWYV1o,169
+ ocf_data_sampler/sample/base.py,sha256=IH3HbfqEUwjHmq-h2eJYLd8Jk-0ZcOylnehMyCPMV38,2223
+ ocf_data_sampler/sample/site.py,sha256=ONf2Yz5zi8Ombd_znA4T7NXbO01F76kQsBZv6rfnC74,1343
+ ocf_data_sampler/sample/uk_regional.py,sha256=KhJ5Ik1pZRp7PgIJjGIrE4i7SQnIdVjUbBHnfn-7ghg,2649
  ocf_data_sampler/select/__init__.py,sha256=E4AJulEbO2K-o0UlG1fgaEteuf_1ZFjHTvrotXSb4YU,332
- ocf_data_sampler/select/dropout.py,sha256=HCx5Wzk8Oh2Z9vV94Jy-ALJsHtGduwvMaQOleQXp5z0,1142
- ocf_data_sampler/select/fill_time_periods.py,sha256=h0XD1Ds_wUUoy-7bILxmN8AIbjlQ6YdXRKuCk_Is5jo,460
- ocf_data_sampler/select/find_contiguous_time_periods.py,sha256=Nvz4gLCbbKzAe3sQXfxgExL9NtZVk1WNORvHs94DQ_k,11130
+ ocf_data_sampler/select/dropout.py,sha256=Pgov9P7rQMkSdqluG_hwm8loGyYNFOg-3PJUBLN_kjU,1526
+ ocf_data_sampler/select/fill_time_periods.py,sha256=EIcXG-77aQVOAYNwbDBEv6SGf6DO2p1WMEf96iW4MEM,596
+ ocf_data_sampler/select/find_contiguous_time_periods.py,sha256=IwPQwvgu4cOiAZ5Gbjflv3fnQCcs0EVK0g4V6yqqSgw,11129
  ocf_data_sampler/select/geospatial.py,sha256=4xL-9y674jjoaXeqE52NHCHVfknciE4OEGsZtn9DvP4,4911
  ocf_data_sampler/select/location.py,sha256=26Y5ZjfFngShBwXieuWSoOA-RLaRzci4TTmcDk3Wg7U,2015
  ocf_data_sampler/select/select_spatial_slice.py,sha256=WNxwur9Q5oetvogATw8-hNejDuEwrXHzuZIovFDjNJA,11488
@@ -41,15 +41,15 @@ ocf_data_sampler/select/select_time_slice.py,sha256=9M-yvDv9K77XfEys_OIR31_aVB56
  ocf_data_sampler/select/spatial_slice_for_dataset.py,sha256=3tRrMBXr7s4CnClbVSIq7hpls3H4Y3qYTDwswcxCCCE,1763
  ocf_data_sampler/select/time_slice_for_dataset.py,sha256=Z7pOiilSHScxmBKZNG18K5J-S4ifdXXAYGZoHRHD3AY,4324
  ocf_data_sampler/torch_datasets/datasets/__init__.py,sha256=jfJSFcR0eO1AqeH7S3KnGjsBqVZT5w3oyi784PUR6Q0,146
- ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py,sha256=N85duDyEm6LIYgYIpLhrpxHddMIcvFosuZg8rzIztwE,12267
- ocf_data_sampler/torch_datasets/datasets/site.py,sha256=L_4w967ZxPjd7vHRkPtj7ZSmamEShKRT28j9_f-enJY,16228
+ ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py,sha256=ZgfvVCcEU3dj3RoY0zdBdKGppC7Wm81qecqB17gYTmE,12286
+ ocf_data_sampler/torch_datasets/datasets/site.py,sha256=_uHmqg-VJu-MHgXc5JFDX1noPfH6E8nY4XhQmsrOav4,16325
  ocf_data_sampler/torch_datasets/utils/merge_and_fill_utils.py,sha256=hIbekql64eXsNDFIoEc--GWxwdVWrh2qKegdOi70Bow,874
  ocf_data_sampler/torch_datasets/utils/valid_time_periods.py,sha256=Qo65qUHtle_bW5tLTYr7empHTRv-lpjvfx_6GNJj3Xg,4371
  ocf_data_sampler/torch_datasets/utils/validate_channels.py,sha256=u2EpiFAKAOHpmvINhOUJCT8Vbc-cle6qJ3YNVse4yLs,2884
  scripts/refactor_site.py,sha256=xaJGxt2_WObIPrPAnRiOMMB68r-5Q51jWRx409AcscM,1747
  tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- tests/conftest.py,sha256=RlC7YYtBLipUzFS1tQxela1SgHCxSpReUKEJ4429PwQ,7689
- tests/config/test_config.py,sha256=VQjNiucIk5VnPQdGA6Mr-RNd9CwGI06AiikChTHrcnY,3969
+ tests/conftest.py,sha256=k7nM3u2YJmkMupN4SIbJP3BRoxNR1dpIoo2fPFf0abg,8588
+ tests/config/test_config.py,sha256=CzYVhAUpgT4lvQdIddtVxtJeMqYL_TJolfeIwaaohq4,3969
  tests/config/test_load.py,sha256=8nui2UsgK_eufWGD74yXvf-6eY_SxBFKhDmGYUtRQxw,260
  tests/config/test_save.py,sha256=BxSd2S50-bRPIXP_4iX0B6Wt7pRFJnUbLYtzfLaqlAs,915
  tests/load/test_load_gsp.py,sha256=aT_nqaSXmUTcdHzuTT7AmXJr3R31k4OEN-Fv3eLxlQE,424
@@ -59,24 +59,24 @@ tests/load/test_load_sites.py,sha256=6V-U3_EtBklkV7w-hOoR4nba3dSaZ_cnjuRWFs8kYVU
  tests/numpy_sample/test_collate.py,sha256=RqHCD5_LTRpe4r6kqC_2TKhmhM_IHYM0ZtFUvSjDqcM,654
  tests/numpy_sample/test_datetime_features.py,sha256=iR9WdBLj1nIBNqoaTFE9rkUaH1eKFJSNb96nwiEaQH0,1449
  tests/numpy_sample/test_gsp.py,sha256=FLlq4SlJ-9cSRAepf4_ksA6PsUVKegnKEAc5pUojCJ0,1458
- tests/numpy_sample/test_nwp.py,sha256=yf4u7mAU0E3FQ4xAH6YjuHuHBzzFoXjHSFNkOVJUdSM,1455
+ tests/numpy_sample/test_nwp.py,sha256=Lnd-PMa6gI-fSIJkSZ554QiHFfnwxeXZxLg-rpuBv1U,442
  tests/numpy_sample/test_satellite.py,sha256=cCqtn5See-uSNfh89COGTUQNuFm6sIZ8QmBVHsuUeRI,1189
  tests/numpy_sample/test_sun_position.py,sha256=_ENYzsNBVPdNXf--FI-UUFqw2u5w7_zqw6LcENU2uZM,2504
- tests/select/test_dropout.py,sha256=kiycl7RxAQYMCZJlokmx6Da5h_oBpSs8Is8pmSW4gOU,2413
+ tests/select/test_dropout.py,sha256=aQuSSqZF9RxBjN9-ogkQ8O-_zktAM30CrT1Lz7j1hMg,2222
  tests/select/test_fill_time_periods.py,sha256=o59f2YRe5b0vJrG3B0aYZkYeHnpNk4s6EJxdXZluNQg,907
  tests/select/test_find_contiguous_time_periods.py,sha256=kOga_V7er5We7ewMARXaKdM3agOhsvZYx8inXtUn1PM,5976
  tests/select/test_location.py,sha256=_WZk2FPYeJ-nIfCJS6Sp_yaVEEo7m31DmMFoZzgyCts,2712
  tests/select/test_select_spatial_slice.py,sha256=7EX9b6g-pMdACQx3yefjs5do2s-Rho2UmKevV4oglsU,5147
  tests/select/test_select_time_slice.py,sha256=nYrdlmZlGEygJKiE26bADiluNPN1qt5kD4FrI2vtxUw,9686
  tests/test_sample/test_base.py,sha256=sD9NZghYQWbkAcQP9YXypWZowqYkO3xeNMH-_mEoD5I,4833
- tests/test_sample/test_site_sample.py,sha256=Gln-Or060cUWvA7Q7c1vsthgCttOAM2z9yBI9zUIrDw,6238
- tests/test_sample/test_uk_regional_sample.py,sha256=gkeQWC2wC757jKJz_QBmDMFQjn3R54q_tEo948yyxCY,4840
+ tests/test_sample/test_site_sample.py,sha256=8HNenhIWYouCQu4y389PDQGokSPI5jQ4lS4CG-eA1Y8,5382
+ tests/test_sample/test_uk_regional_sample.py,sha256=MFibX9-M8mFK7vwMPu58gAG2VoY6y7w7chW5BlZclwk,3962
  tests/torch_datasets/test_merge_and_fill_utils.py,sha256=GtuQg82BM1eHQjT7Ik1x1zaVcuc7KJO4_NC9stXsd4s,1123
- tests/torch_datasets/test_pvnet_uk.py,sha256=F0D-DugFgVtt8G1q7lylmPLrOZj6H6YPNd9s_6Wn_yM,5594
+ tests/torch_datasets/test_pvnet_uk.py,sha256=hgD_IDa4D8cgc4cgK1UqKYkT6sFlrTMAvgVn_iwD5_4,5086
  tests/torch_datasets/test_site.py,sha256=t57vAR_RRWcbG_kEFk6VrFCYzVxwFG6qJKBnRHF02fM,7000
  tests/torch_datasets/test_validate_channels_utils.py,sha256=Rzdweu98j1of45jCOUrSiBtyPlf-dDaCceulf0H7ml8,2921
- ocf_data_sampler-0.1.9.dist-info/LICENSE,sha256=F-Q3UFCR-BECSocV55BFDpn4YKxve9PKrm-lTt6o_Tg,1073
- ocf_data_sampler-0.1.9.dist-info/METADATA,sha256=Lfu8Yrj4CSlqPzGhk0iDy5r5zCLd5REnGAlVcFuKuow,12173
- ocf_data_sampler-0.1.9.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
- ocf_data_sampler-0.1.9.dist-info/top_level.txt,sha256=Faob6N6cFdPc5eUpCTYcXgCaNhi4XLLteUL5W5ayYmg,31
- ocf_data_sampler-0.1.9.dist-info/RECORD,,
+ ocf_data_sampler-0.1.11.dist-info/LICENSE,sha256=F-Q3UFCR-BECSocV55BFDpn4YKxve9PKrm-lTt6o_Tg,1073
+ ocf_data_sampler-0.1.11.dist-info/METADATA,sha256=d8wctSlRyDbP1_yYHFvIGQgEC8DmOkM8h-ITI4XFuPw,12174
+ ocf_data_sampler-0.1.11.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
+ ocf_data_sampler-0.1.11.dist-info/top_level.txt,sha256=Faob6N6cFdPc5eUpCTYcXgCaNhi4XLLteUL5W5ayYmg,31
+ ocf_data_sampler-0.1.11.dist-info/RECORD,,
tests/config/test_config.py CHANGED
@@ -30,7 +30,7 @@ def test_incorrect_interval_start_minutes(test_config_filename):
      configuration.input_data.nwp['ukv'].interval_start_minutes = -1111
      with pytest.raises(
          ValueError,
-         match="interval_start_minutes must be divisible by time_resolution_minutes"
+         match="interval_start_minutes.*must be divisible.*time_resolution_minutes.*"
      ):
          _ = Configuration(**configuration.model_dump())
  
@@ -45,7 +45,7 @@ def test_incorrect_interval_end_minutes(test_config_filename):
      configuration.input_data.nwp['ukv'].interval_end_minutes = 1111
      with pytest.raises(
          ValueError,
-         match="interval_end_minutes must be divisible by time_resolution_minutes"
+         match="interval_end_minutes.*must be divisible.*time_resolution_minutes.*"
      ):
          _ = Configuration(**configuration.model_dump())
  
@@ -103,7 +103,7 @@ def test_inconsistent_dropout_use(test_config_filename):
  
      configuration = load_yaml_configuration(test_config_filename)
      configuration.input_data.satellite.dropout_fraction= 1.0
-     configuration.input_data.satellite.dropout_timedeltas_minutes = None
+     configuration.input_data.satellite.dropout_timedeltas_minutes = []
  
      with pytest.raises(ValueError, match="To dropout fraction > 0 requires a list of dropout timedeltas"):
          _ = Configuration(**configuration.model_dump())
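
Note: the updated test expects an empty list of dropout timedeltas (rather than None) combined with a non-zero dropout fraction to be rejected when the Configuration is built. Since the tests call configuration.model_dump(), the config model appears to be pydantic v2; a hypothetical sketch of such a check (the class and field layout here are illustrative, not the package's actual model) could look like:

    from pydantic import BaseModel, model_validator

    class SatelliteConfigSketch(BaseModel):
        """Hypothetical stand-in for the satellite input-data config."""
        dropout_timedeltas_minutes: list[int] = []
        dropout_fraction: float = 0.0

        @model_validator(mode="after")
        def check_dropout(self):
            if self.dropout_fraction > 0 and len(self.dropout_timedeltas_minutes) == 0:
                raise ValueError("To dropout fraction > 0 requires a list of dropout timedeltas")
            return self
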
tests/conftest.py CHANGED
@@ -130,6 +130,39 @@ def nwp_ukv_zarr_path(session_tmp_path, ds_nwp_ukv):
      yield zarr_path
  
  
+ @pytest.fixture()
+ def ds_nwp_ukv_time_sliced():
+ 
+     t0 = pd.to_datetime("2024-01-02 00:00")
+ 
+     x = np.arange(-100, 100, 10)
+     y = np.arange(-100, 100, 10)
+     steps = pd.timedelta_range("0h", "8h", freq="1h")
+     target_times = t0 + steps
+ 
+     channels = ["t", "dswrf"]
+     init_times = pd.to_datetime([t0]*len(steps))
+ 
+     # Create dummy time-sliced NWP data
+     da_nwp = xr.DataArray(
+         np.random.normal(size=(len(target_times), len(channels), len(x), len(y))),
+         coords=dict(
+             target_time_utc=(["target_time_utc"], target_times),
+             channel=(["channel"], channels),
+             x_osgb=(["x_osgb"], x),
+             y_osgb=(["y_osgb"], y),
+         )
+     )
+ 
+     # Add extra non-coordinate dimensions
+     da_nwp = da_nwp.assign_coords(
+         init_time_utc=("target_time_utc", init_times),
+         step=("target_time_utc", steps),
+     )
+ 
+     return da_nwp
+ 
+ 
  @pytest.fixture(scope="session")
  def ds_nwp_ecmwf():
      init_times = pd.date_range(start="2023-01-01 00:00", freq="6h", periods=24 * 7)
tests/numpy_sample/test_nwp.py CHANGED
@@ -1,52 +1,13 @@
- import numpy as np
- import pandas as pd
- import xarray as xr
- 
- import pytest
- 
  from ocf_data_sampler.numpy_sample import convert_nwp_to_numpy_sample, NWPSampleKey
  
- @pytest.fixture(scope="module")
- def da_nwp_like():
-     """Create dummy data which looks like time-sliced NWP data"""
- 
-     t0 = pd.to_datetime("2024-01-02 00:00")
- 
-     x = np.arange(-100, 100, 10)
-     y = np.arange(-100, 100, 10)
-     steps = pd.timedelta_range("0h", "8h", freq="1h")
-     target_times = t0 + steps
- 
-     channels = ["t", "dswrf"]
-     init_times = pd.to_datetime([t0]*len(steps))
- 
-     # Create dummy time-sliced NWP data
-     da_nwp = xr.DataArray(
-         np.random.normal(size=(len(target_times), len(channels), len(x), len(y))),
-         coords=dict(
-             target_times_utc=(["target_times_utc"], target_times),
-             channel=(["channel"], channels),
-             x_osgb=(["x_osgb"], x),
-             y_osgb=(["y_osgb"], y),
-         )
-     )
- 
-     # Add extra non-coordinate dimensions
-     da_nwp = da_nwp.assign_coords(
-         init_time_utc=("target_times_utc", init_times),
-         step=("target_times_utc", steps),
-     )
- 
-     return da_nwp
- 
  
- def test_convert_nwp_to_numpy_sample(da_nwp_like):
+ def test_convert_nwp_to_numpy_sample(ds_nwp_ukv_time_sliced):
  
      # Call the function
-     numpy_sample = convert_nwp_to_numpy_sample(da_nwp_like)
+     numpy_sample = convert_nwp_to_numpy_sample(ds_nwp_ukv_time_sliced)
  
      # Assert the output type
      assert isinstance(numpy_sample, dict)
  
      # Assert the shape of the numpy sample
-     assert (numpy_sample[NWPSampleKey.nwp] == da_nwp_like.values).all()
+     assert (numpy_sample[NWPSampleKey.nwp] == ds_nwp_ukv_time_sliced.values).all()
tests/select/test_dropout.py CHANGED
@@ -14,10 +14,8 @@ def da_sample():
      datetimes = pd.date_range("2024-01-01 12:00", "2024-01-01 13:00", freq="5min")
  
      da_sat = xr.DataArray(
-         np.random.normal(size=(len(datetimes),)),
-         coords=dict(
-             time_utc=(["time_utc"], datetimes),
-         )
+         np.random.normal(size=(len(datetimes))),
+         coords=dict(time_utc=datetimes)
      )
      return da_sat
  
@@ -29,7 +27,7 @@ def test_draw_dropout_time():
      dropout_time = draw_dropout_time(t0, dropout_timedeltas, dropout_frac=1)
  
      assert isinstance(dropout_time, pd.Timestamp)
-     assert dropout_time-t0 in dropout_timedeltas
+     assert (dropout_time-t0) in dropout_timedeltas
  
  
  def test_draw_dropout_time_partial():
@@ -48,21 +46,17 @@ def test_draw_dropout_time_partial():
      dropouts == {None} | set(t0 + dt for dt in dropout_timedeltas)
  
  
- def test_draw_dropout_time_none():
+ def test_draw_dropout_time_null():
      t0 = pd.Timestamp("2021-01-01 04:00:00")
  
-     # No dropout timedeltas
-     dropout_time = draw_dropout_time(t0, dropout_timedeltas=None, dropout_frac=1)
-     assert dropout_time is None
- 
      # Dropout fraction is 0
      dropout_timedeltas = [pd.Timedelta(-30, "min")]
      dropout_time = draw_dropout_time(t0, dropout_timedeltas=dropout_timedeltas, dropout_frac=0)
-     assert dropout_time is None
+     assert dropout_time==t0
  
      # No dropout timedeltas and dropout fraction is 0
-     dropout_time = draw_dropout_time(t0, dropout_timedeltas=None, dropout_frac=0)
-     assert dropout_time is None
+     dropout_time = draw_dropout_time(t0, dropout_timedeltas=[], dropout_frac=0)
+     assert dropout_time==t0
  
  
  @pytest.mark.parametrize("t0_str", ["12:00", "12:30", "13:00"])