ocf-data-sampler 0.1.10__py3-none-any.whl → 0.1.16__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.

This version of ocf-data-sampler was flagged as potentially problematic.

Files changed (77):
  1. ocf_data_sampler/config/load.py +3 -3
  2. ocf_data_sampler/config/model.py +86 -72
  3. ocf_data_sampler/config/save.py +5 -4
  4. ocf_data_sampler/constants.py +140 -12
  5. ocf_data_sampler/load/gsp.py +6 -5
  6. ocf_data_sampler/load/load_dataset.py +5 -6
  7. ocf_data_sampler/load/nwp/nwp.py +17 -5
  8. ocf_data_sampler/load/nwp/providers/ecmwf.py +6 -7
  9. ocf_data_sampler/load/nwp/providers/gfs.py +36 -0
  10. ocf_data_sampler/load/nwp/providers/icon.py +46 -0
  11. ocf_data_sampler/load/nwp/providers/ukv.py +4 -5
  12. ocf_data_sampler/load/nwp/providers/utils.py +3 -1
  13. ocf_data_sampler/load/satellite.py +27 -36
  14. ocf_data_sampler/load/site.py +11 -7
  15. ocf_data_sampler/load/utils.py +21 -16
  16. ocf_data_sampler/numpy_sample/collate.py +10 -9
  17. ocf_data_sampler/numpy_sample/datetime_features.py +3 -5
  18. ocf_data_sampler/numpy_sample/gsp.py +15 -13
  19. ocf_data_sampler/numpy_sample/nwp.py +17 -23
  20. ocf_data_sampler/numpy_sample/satellite.py +17 -14
  21. ocf_data_sampler/numpy_sample/site.py +8 -7
  22. ocf_data_sampler/numpy_sample/sun_position.py +19 -25
  23. ocf_data_sampler/sample/__init__.py +0 -7
  24. ocf_data_sampler/sample/base.py +23 -44
  25. ocf_data_sampler/sample/site.py +25 -69
  26. ocf_data_sampler/sample/uk_regional.py +52 -103
  27. ocf_data_sampler/select/dropout.py +42 -27
  28. ocf_data_sampler/select/fill_time_periods.py +15 -3
  29. ocf_data_sampler/select/find_contiguous_time_periods.py +87 -75
  30. ocf_data_sampler/select/geospatial.py +63 -54
  31. ocf_data_sampler/select/location.py +16 -51
  32. ocf_data_sampler/select/select_spatial_slice.py +105 -89
  33. ocf_data_sampler/select/select_time_slice.py +71 -58
  34. ocf_data_sampler/select/spatial_slice_for_dataset.py +7 -6
  35. ocf_data_sampler/select/time_slice_for_dataset.py +17 -16
  36. ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py +126 -118
  37. ocf_data_sampler/torch_datasets/datasets/site.py +135 -101
  38. ocf_data_sampler/torch_datasets/utils/merge_and_fill_utils.py +6 -2
  39. ocf_data_sampler/torch_datasets/utils/valid_time_periods.py +23 -22
  40. ocf_data_sampler/torch_datasets/utils/validate_channels.py +23 -19
  41. ocf_data_sampler/utils.py +3 -1
  42. {ocf_data_sampler-0.1.10.dist-info → ocf_data_sampler-0.1.16.dist-info}/METADATA +7 -18
  43. ocf_data_sampler-0.1.16.dist-info/RECORD +56 -0
  44. {ocf_data_sampler-0.1.10.dist-info → ocf_data_sampler-0.1.16.dist-info}/WHEEL +1 -1
  45. {ocf_data_sampler-0.1.10.dist-info → ocf_data_sampler-0.1.16.dist-info}/top_level.txt +1 -1
  46. scripts/refactor_site.py +62 -33
  47. utils/compute_icon_mean_stddev.py +72 -0
  48. ocf_data_sampler-0.1.10.dist-info/LICENSE +0 -21
  49. ocf_data_sampler-0.1.10.dist-info/RECORD +0 -82
  50. tests/__init__.py +0 -0
  51. tests/config/test_config.py +0 -113
  52. tests/config/test_load.py +0 -7
  53. tests/config/test_save.py +0 -28
  54. tests/conftest.py +0 -286
  55. tests/load/test_load_gsp.py +0 -15
  56. tests/load/test_load_nwp.py +0 -21
  57. tests/load/test_load_satellite.py +0 -17
  58. tests/load/test_load_sites.py +0 -14
  59. tests/numpy_sample/test_collate.py +0 -21
  60. tests/numpy_sample/test_datetime_features.py +0 -37
  61. tests/numpy_sample/test_gsp.py +0 -38
  62. tests/numpy_sample/test_nwp.py +0 -52
  63. tests/numpy_sample/test_satellite.py +0 -40
  64. tests/numpy_sample/test_sun_position.py +0 -81
  65. tests/select/test_dropout.py +0 -75
  66. tests/select/test_fill_time_periods.py +0 -28
  67. tests/select/test_find_contiguous_time_periods.py +0 -202
  68. tests/select/test_location.py +0 -67
  69. tests/select/test_select_spatial_slice.py +0 -154
  70. tests/select/test_select_time_slice.py +0 -275
  71. tests/test_sample/test_base.py +0 -164
  72. tests/test_sample/test_site_sample.py +0 -195
  73. tests/test_sample/test_uk_regional_sample.py +0 -163
  74. tests/torch_datasets/test_merge_and_fill_utils.py +0 -40
  75. tests/torch_datasets/test_pvnet_uk.py +0 -167
  76. tests/torch_datasets/test_site.py +0 -226
  77. tests/torch_datasets/test_validate_channels_utils.py +0 -78
ocf_data_sampler/numpy_sample/site.py

```diff
@@ -1,33 +1,34 @@
-"""Convert site to Numpy Sample"""
+"""Convert site to Numpy Sample."""
 
 import xarray as xr
 
 
 class SiteSampleKey:
+    """Keys for the site sample dictionary."""
 
     generation = "site"
     capacity_kwp = "site_capacity_kwp"
     time_utc = "site_time_utc"
     t0_idx = "site_t0_idx"
     id = "site_id"
-    solar_azimuth = "site_solar_azimuth"
-    solar_elevation = "site_solar_elevation"
     date_sin = "site_date_sin"
     date_cos = "site_date_cos"
     time_sin = "site_time_sin"
     time_cos = "site_time_cos"
 
+
 def convert_site_to_numpy_sample(da: xr.DataArray, t0_idx: int | None = None) -> dict:
-    """Convert from Xarray to NumpySample"""
+    """Convert from Xarray to NumpySample.
 
-    # Extract values from the DataArray
+    Args:
+        da: xarray DataArray containing site data
+        t0_idx: Index of the t0 timestamp in the time dimension of the site data
+    """
     sample = {
         SiteSampleKey.generation: da.values,
         SiteSampleKey.capacity_kwp: da.isel(time_utc=0)["capacity_kwp"].values,
         SiteSampleKey.time_utc: da["time_utc"].values.astype(float),
         SiteSampleKey.id: da["site_id"].values,
-        SiteSampleKey.solar_azimuth: da["solar_azimuth"].values,
-        SiteSampleKey.solar_elevation: da["solar_elevation"].values,
         SiteSampleKey.date_sin: da["date_sin"].values,
         SiteSampleKey.date_cos: da["date_cos"].values,
         SiteSampleKey.time_sin: da["time_sin"].values,
```
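With the solar keys gone from both `SiteSampleKey` and the returned dict (solar position now comes from the reworked `sun_position` module below), a caller only needs the datetime-feature coordinates. A minimal sketch of the converter after this change; the toy DataArray's coordinate layout is an assumption based on the keys read above:

```python
import numpy as np
import pandas as pd
import xarray as xr

from ocf_data_sampler.numpy_sample.site import SiteSampleKey, convert_site_to_numpy_sample

times = pd.date_range("2024-06-01", periods=4, freq="30min")
trig = ("time_utc", np.zeros(4))  # stand-in values for the sin/cos datetime features
da = xr.DataArray(
    np.random.rand(4),
    dims=["time_utc"],
    coords={
        "time_utc": times,
        "capacity_kwp": ("time_utc", np.full(4, 5.0)),
        "site_id": 1,
        "date_sin": trig, "date_cos": trig, "time_sin": trig, "time_cos": trig,
    },
)

sample = convert_site_to_numpy_sample(da, t0_idx=1)
print(sample[SiteSampleKey.generation].shape)  # (4,)
```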
ocf_data_sampler/numpy_sample/sun_position.py

```diff
@@ -1,16 +1,17 @@
+"""Module for calculating solar position."""
 
-import pvlib
 import numpy as np
 import pandas as pd
+import pvlib
 
 
 def calculate_azimuth_and_elevation(
-        datetimes: pd.DatetimeIndex,
-        lon: float,
-        lat: float
+    datetimes: pd.DatetimeIndex,
+    lon: float,
+    lat: float,
 ) -> tuple[np.ndarray, np.ndarray]:
-    """Calculate the solar coordinates for multiple datetimes at a single location
-
+    """Calculate the solar coordinates for multiple datetimes at a single location.
+
     Args:
         datetimes: The datetimes to calculate for
         lon: The longitude
@@ -20,46 +21,39 @@ def calculate_azimuth_and_elevation(
         np.ndarray: The azimuth of the datetimes in degrees
         np.ndarray: The elevation of the datetimes in degrees
     """
-
     solpos = pvlib.solarposition.get_solarposition(
         time=datetimes,
         longitude=lon,
         latitude=lat,
-        method='nrel_numpy'
+        method="nrel_numpy",
     )
-    azimuth = solpos["azimuth"].values
-    elevation = solpos["elevation"].values
-    return azimuth, elevation
+
+    return solpos["azimuth"].values, solpos["elevation"].values
 
 
 def make_sun_position_numpy_sample(
-        datetimes: pd.DatetimeIndex,
-        lon: float,
-        lat: float,
-        key_prefix: str = "gsp"
+    datetimes: pd.DatetimeIndex,
+    lon: float,
+    lat: float,
 ) -> dict:
-    """Creates NumpySample with standardized solar coordinates
+    """Creates NumpySample with standardized solar coordinates.
 
     Args:
         datetimes: The datetimes to calculate solar angles for
         lon: The longitude
         lat: The latitude
     """
-
     azimuth, elevation = calculate_azimuth_and_elevation(datetimes, lon, lat)
 
     # Normalise
-
     # Azimuth is in range [0, 360] degrees
     azimuth = azimuth / 360
 
-    # Elevation is in range [-90, 90] degrees
+    # Elevation is in range [-90, 90] degrees
     elevation = elevation / 180 + 0.5
-
+
     # Make NumpySample
-    sun_numpy_sample = {
-        key_prefix + "_solar_azimuth": azimuth,
-        key_prefix + "_solar_elevation": elevation,
+    return {
+        "solar_azimuth": azimuth,
+        "solar_elevation": elevation,
     }
-
-    return sun_numpy_sample
```
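The `key_prefix` argument is gone, so solar angles are produced once per sample under unprefixed keys, and both outputs are normalised into [0, 1]: azimuth divided by 360, elevation mapped from [-90, 90] via e/180 + 0.5. A short usage sketch; the coordinates are illustrative (roughly London):

```python
import pandas as pd

from ocf_data_sampler.numpy_sample.sun_position import make_sun_position_numpy_sample

datetimes = pd.date_range("2024-06-21 00:00", periods=48, freq="30min")
sample = make_sun_position_numpy_sample(datetimes, lon=-0.13, lat=51.5)

# Keys are now unprefixed, and values are normalised to [0, 1]
for key in ("solar_azimuth", "solar_elevation"):
    assert ((sample[key] >= 0) & (sample[key] <= 1)).all()
```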
ocf_data_sampler/sample/__init__.py

```diff
@@ -1,10 +1,3 @@
 from ocf_data_sampler.sample.base import SampleBase
 from ocf_data_sampler.sample.uk_regional import UKRegionalSample
 from ocf_data_sampler.sample.site import SiteSample
-
-
-__all__ = [
-    'SampleBase',
-    'UKRegionalSample',
-    'SiteSample'
-]
```
ocf_data_sampler/sample/base.py

```diff
@@ -1,69 +1,49 @@
-"""
-Base class definition - abstract
-Handling of both flat and nested structures - consideration for NWP
-"""
+"""Base class for handling flat/nested data structures with NWP consideration."""
 
-import logging
-import numpy as np
-import torch
-import xarray as xr
-
-from pathlib import Path
-from typing import Any, Dict, Optional, Union, TypeAlias
 from abc import ABC, abstractmethod
+from typing import TypeAlias
 
+import numpy as np
+import torch
 
-logger = logging.getLogger(__name__)
-
-NumpySample: TypeAlias = Dict[str, Union[np.ndarray, Dict[str, np.ndarray]]]
-NumpyBatch: TypeAlias = Dict[str, Union[np.ndarray, Dict[str, np.ndarray]]]
-TensorBatch: TypeAlias = Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]]
+NumpySample: TypeAlias = dict[str, np.ndarray | dict[str, np.ndarray]]
+NumpyBatch: TypeAlias = dict[str, np.ndarray | dict[str, np.ndarray]]
+TensorBatch: TypeAlias = dict[str, torch.Tensor | dict[str, torch.Tensor]]
 
 
 class SampleBase(ABC):
-    """
-    Abstract base class for all sample types
-    Provides core data storage functionality
-    """
-
-    def __init__(self, data: Optional[Union[NumpySample, xr.Dataset]] = None):
-        """ Initialise data container """
-        logger.debug("Initialising SampleBase instance")
-        self._data = data
+    """Abstract base class for all sample types."""
 
     @abstractmethod
     def to_numpy(self) -> NumpySample:
-        """ Convert data to a numpy array representation """
+        """Convert sample data to numpy format."""
         raise NotImplementedError
 
     @abstractmethod
-    def plot(self, **kwargs) -> None:
-        """ Abstract method for plotting """
+    def plot(self) -> None:
+        """Create a visualisation of the data."""
         raise NotImplementedError
 
     @abstractmethod
-    def save(self, path: Union[str, Path]) -> None:
-        """ Abstract method for saving sample data """
+    def save(self, path: str) -> None:
+        """Saves the sample to disk in the implementations' required format."""
         raise NotImplementedError
 
     @classmethod
     @abstractmethod
-    def load(cls, path: Union[str, Path]) -> 'SampleBase':
-        """ Abstract class method for loading sample data """
+    def load(cls, path: str) -> "SampleBase":
+        """Load a sample from disk from the implementations' format."""
        raise NotImplementedError
 
 
 def batch_to_tensor(batch: NumpyBatch) -> TensorBatch:
-    """
-    Moves ndarrays in a nested dict to torch tensors
+    """Recursively converts numpy arrays in nested dict to torch tensors.
+
     Args:
         batch: NumpyBatch with data in numpy arrays
     Returns:
         TensorBatch with data in torch tensors
     """
-    if not batch:
-        raise ValueError("Cannot convert empty batch to tensors")
-
     for k, v in batch.items():
         if isinstance(v, dict):
             batch[k] = batch_to_tensor(v)
@@ -75,22 +55,21 @@ def batch_to_tensor(batch: NumpyBatch) -> TensorBatch:
     return batch
 
 
-def copy_batch_to_device(batch: dict, device: torch.device) -> dict:
-    """
-    Moves tensor leaves in a nested dict to a new device.
+def copy_batch_to_device(batch: TensorBatch, device: torch.device) -> TensorBatch:
+    """Recursively copies tensors in nested dict to specified device.
 
     Args:
-        batch: Nested dict with tensors to move.
-        device: Device to move tensors to.
+        batch: Nested dict with tensors to move
+        device: Device to move tensors to
 
     Returns:
-        A dict with tensors moved to the new device.
+        A dict with tensors moved to the new device
     """
     batch_copy = {}
 
     for k, v in batch.items():
         if isinstance(v, dict):
-            batch_copy[k] = copy_batch_to_device(v, device)
+            batch_copy[k] = copy_batch_to_device(v, device)
         elif isinstance(v, torch.Tensor):
             batch_copy[k] = v.to(device)
         else:
```
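The empty-batch `ValueError` guard has been dropped, and both helpers keep their recursive nested-dict behaviour (the leaf ndarray-to-tensor conversion sits in lines elided between the two hunks). A minimal sketch of the intended call pattern, with a hypothetical two-source batch whose shapes are illustrative only:

```python
import numpy as np
import torch

from ocf_data_sampler.sample.base import batch_to_tensor, copy_batch_to_device

# Hypothetical nested NumpyBatch: a top-level array plus a per-provider NWP dict
batch = {
    "gsp": np.random.rand(4, 7),
    "nwp": {"ukv": np.random.rand(4, 2, 8, 8)},
}

tensor_batch = batch_to_tensor(batch)  # converts arrays in place and returns the dict
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tensor_batch = copy_batch_to_device(tensor_batch, device)  # returns a new nested dict

print(type(tensor_batch["nwp"]["ukv"]))  # <class 'torch.Tensor'>
```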
ocf_data_sampler/sample/site.py

```diff
@@ -1,81 +1,37 @@
-"""
-PVNet - Site sample / dataset implementation
-"""
+"""PVNet Site sample implementation for netCDF data handling and conversion."""
 
-import logging
 import xarray as xr
-import numpy as np
+from typing_extensions import override
 
-from pathlib import Path
-from typing import Dict, Any, Union
-
-from ocf_data_sampler.sample.base import SampleBase
+from ocf_data_sampler.sample.base import NumpySample, SampleBase
 from ocf_data_sampler.torch_datasets.datasets.site import convert_netcdf_to_numpy_sample
 
 
-logger = logging.getLogger(__name__)
-
-
 class SiteSample(SampleBase):
-    """ Sample class specific to Site PVNet """
+    """Handles PVNet site specific netCDF operations."""
 
-    def __init__(self):
-        logger.debug("Initialise SiteSample instance")
-        super().__init__()
-        self._data = {}
+    def __init__(self, data: xr.Dataset) -> None:
+        """Initializes the SiteSample object with the given xarray Dataset."""
+        if not isinstance(data, xr.Dataset):
+            raise TypeError(f"Data must be xarray Dataset - Found type {type(data)}")
+        self._data = data
 
-    def to_numpy(self) -> Dict[str, Any]:
-        """ Convert sample numpy arrays - netCDF conversion """
-        logger.debug("Converting site sample to numpy format")
-
-        try:
-            if not isinstance(self._data, xr.Dataset):
-                raise TypeError("Data must be xarray Dataset")
-
-            numpy_data = convert_netcdf_to_numpy_sample(self._data)
+    @override
+    def to_numpy(self) -> NumpySample:
+        return convert_netcdf_to_numpy_sample(self._data)
 
-            logger.debug("Successfully converted to numpy format")
-            return numpy_data
-
-        except Exception as e:
-            logger.error(f"Error converting to numpy: {str(e)}")
-            raise
-
-    def save(self, path: Union[str, Path]) -> None:
-        """ Save site sample as netCDF - h5netcdf engine """
-        logger.debug(f"Saving SiteSample to {path}")
-        path = Path(path)
-
-        if path.suffix != '.nc':
-            logger.error(f"Invalid file format - {path.suffix}")
-            raise ValueError("Only .nc format is supported")
-
-        if not isinstance(self._data, xr.Dataset):
-            raise TypeError("Data must be xarray Dataset for saving")
-
-        self._data.to_netcdf(
-            path,
-            mode="w",
-            engine="h5netcdf"
-        )
-        logger.debug(f"Successfully saved SiteSample - {path}")
+    @override
+    def save(self, path: str) -> None:
+        # Saves as NetCDF
+        self._data.to_netcdf(path, mode="w", engine="h5netcdf")
 
     @classmethod
-    def load(cls, path: str) -> None:
-        """ Load site sample from netCDF """
-        logger.debug(f"Loading SiteSample from {path}")
-        path = Path(path)
-
-        if path.suffix != '.nc':
-            logger.error(f"Invalid file format - {path.suffix}")
-            raise ValueError("Only .nc format is supported")
-
-        instance = cls()
-        instance._data = xr.open_dataset(path)
-        logger.debug(f"Loaded SiteSample from {path}")
-        return instance
-
-    # TO DO - placeholder for now
-    def plot(self, **kwargs) -> None:
-        """ Plot sample data - placeholder """
-        pass
+    @override
+    def load(cls, path: str) -> "SiteSample":
+        # Loads from NetCDF
+        return cls(xr.open_dataset(path))
+
+    @override
+    def plot(self) -> None:
+        # TODO - placeholder for now
+        raise NotImplementedError("Plotting not yet implemented for SiteSample")
```
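The constructor now validates its input up front, and the old `.nc` extension check is gone, so any writable path is accepted. A hypothetical round trip under those assumptions; it requires the h5netcdf engine to be installed, and the toy Dataset is only a stand-in for a real sample produced by the site dataset pipeline:

```python
import numpy as np
import pandas as pd
import xarray as xr

from ocf_data_sampler.sample.site import SiteSample

# Stand-in Dataset; real samples come from the site torch dataset pipeline
ds = xr.Dataset(
    {"generation": ("time_utc", np.random.rand(4))},
    coords={"time_utc": pd.date_range("2024-06-01", periods=4, freq="30min")},
)

sample = SiteSample(ds)               # raises TypeError for non-Dataset input
sample.save("sample.nc")              # written via the h5netcdf engine
loaded = SiteSample.load("sample.nc")
```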
ocf_data_sampler/sample/uk_regional.py

```diff
@@ -1,120 +1,69 @@
-"""
-PVNet - UK Regional sample / dataset implementation
-"""
+"""PVNet UK Regional sample implementation for dataset handling and visualisation."""
 
-import numpy as np
-import pandas as pd
 import torch
-import logging
-
-from typing import Dict, Any, Union, List, Optional
-from pathlib import Path
+from typing_extensions import override
 
 from ocf_data_sampler.numpy_sample import (
-    NWPSampleKey,
     GSPSampleKey,
-    SatelliteSampleKey
+    NWPSampleKey,
+    SatelliteSampleKey,
 )
-
-from ocf_data_sampler.sample.base import SampleBase
-
-try:
-    import matplotlib.pyplot as plt
-    MATPLOTLIB_AVAILABLE = True
-except ImportError:
-    MATPLOTLIB_AVAILABLE = False
-    plt = None
-
-
-logger = logging.getLogger(__name__)
+from ocf_data_sampler.sample.base import NumpySample, SampleBase
 
 
 class UKRegionalSample(SampleBase):
-    """ Sample class specific to UK Regional PVNet """
+    """Handles UK Regional PVNet data operations."""
 
-    def __init__(self):
-        logger.debug("Initialise UKRegionalSample instance")
-        super().__init__()
-        self._data = {}
+    def __init__(self, data: NumpySample) -> None:
+        """Initialises UK Regional sample with data."""
+        self._data = data
 
-    def to_numpy(self) -> Dict[str, Any]:
-        """ Convert sample data to numpy format """
-        logger.debug("Converting sample data to numpy format")
+    @override
+    def to_numpy(self) -> NumpySample:
         return self._data
 
-    def save(self, path: Union[str, Path]) -> None:
-        """ Save PVNet sample as .pt """
-        logger.debug(f"Saving UKRegionalSample to {path}")
-        path = Path(path)
-
-        if path.suffix != '.pt':
-            logger.error(f"Invalid file format: {path.suffix}")
-            raise ValueError(f"Only .pt format is supported: {path.suffix}")
-
+    @override
+    def save(self, path: str) -> None:
+        # Saves to pickle format
         torch.save(self._data, path)
-        logger.debug(f"Successfully saved UKRegionalSample to {path}")
 
     @classmethod
-    def load(cls, path: Union[str, Path]) -> 'UKRegionalSample':
-        """ Load PVNet sample data from .pt """
-        logger.debug(f"Attempting to load UKRegionalSample from {path}")
-        path = Path(path)
-
-        if path.suffix != '.pt':
-            logger.error(f"Invalid file format: {path.suffix}")
-            raise ValueError(f"Only .pt format is supported: {path.suffix}")
-
-        instance = cls()
+    @override
+    def load(cls, path: str) -> "UKRegionalSample":
+        # Loads from .pt format
         # TODO: We should move away from using torch.load(..., weights_only=False)
-        # This is not recommended
-        instance._data = torch.load(path, weights_only=False)
-        logger.debug(f"Successfully loaded UKRegionalSample from {path}")
-        return instance
-
-    def plot(self, **kwargs) -> None:
-        """ Sample visualisation definition """
-        logger.debug("Creating UKRegionalSample visualisation")
-
-        if not MATPLOTLIB_AVAILABLE:
-            raise ImportError(
-                "Matplotlib required for plotting"
-                "Install via 'ocf_data_sampler[plot]'"
-            )
-
-        try:
-            fig, axes = plt.subplots(2, 2, figsize=(12, 8))
-
-            if NWPSampleKey.nwp in self._data:
-                logger.debug("Plotting NWP data")
-                first_nwp = list(self._data[NWPSampleKey.nwp].values())[0]
-                if 'nwp' in first_nwp:
-                    axes[0, 1].imshow(first_nwp['nwp'][0])
-                    axes[0, 1].set_title('NWP (First Channel)')
-                    if NWPSampleKey.channel_names in first_nwp:
-                        channel_names = first_nwp[NWPSampleKey.channel_names]
-                        if len(channel_names) > 0:
-                            axes[0, 1].set_title(f'NWP: {channel_names[0]}')
-
-            if GSPSampleKey.gsp in self._data:
-                logger.debug("Plotting GSP generation data")
-                axes[0, 0].plot(self._data[GSPSampleKey.gsp])
-                axes[0, 0].set_title('GSP Generation')
-
-            if GSPSampleKey.solar_azimuth in self._data and GSPSampleKey.solar_elevation in self._data:
-                logger.debug("Plotting solar position data")
-                axes[1, 1].plot(self._data[GSPSampleKey.solar_azimuth], label='Azimuth')
-                axes[1, 1].plot(self._data[GSPSampleKey.solar_elevation], label='Elevation')
-                axes[1, 1].set_title('Solar Position')
-                axes[1, 1].legend()
-
-            if SatelliteSampleKey.satellite_actual in self._data:
-                logger.debug("Plotting satellite data")
-                axes[1, 0].imshow(self._data[SatelliteSampleKey.satellite_actual])
-                axes[1, 0].set_title('Satellite Data')
-
-            plt.tight_layout()
-            plt.show()
-            logger.debug("Successfully created visualisation")
-        except Exception as e:
-            logger.error(f"Error creating visualisation: {str(e)}")
-            raise
+        return cls(torch.load(path, weights_only=False))
+
+    @override
+    def plot(self) -> None:
+        from matplotlib import pyplot as plt
+
+        fig, axes = plt.subplots(2, 2, figsize=(12, 8))
+
+        if NWPSampleKey.nwp in self._data:
+            first_nwp = next(iter(self._data[NWPSampleKey.nwp].values()))
+            if "nwp" in first_nwp:
+                axes[0, 1].imshow(first_nwp["nwp"][0])
+                title = "NWP (First Channel)"
+                if NWPSampleKey.channel_names in first_nwp:
+                    channel_names = first_nwp[NWPSampleKey.channel_names]
+                    if channel_names:
+                        title = f"NWP: {channel_names[0]}"
+                axes[0, 1].set_title(title)
+
+        if GSPSampleKey.gsp in self._data:
+            axes[0, 0].plot(self._data[GSPSampleKey.gsp])
+            axes[0, 0].set_title("GSP Generation")
+
+        if "solar_azimuth" in self._data and "solar_elevation" in self._data:
+            axes[1, 1].plot(self._data["solar_azimuth"], label="Azimuth")
+            axes[1, 1].plot(self._data["solar_elevation"], label="Elevation")
+            axes[1, 1].set_title("Solar Position")
+            axes[1, 1].legend()
+
+        if SatelliteSampleKey.satellite_actual in self._data:
+            axes[1, 0].imshow(self._data[SatelliteSampleKey.satellite_actual])
+            axes[1, 0].set_title("Satellite Data")
+
+        plt.tight_layout()
+        plt.show()
```
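`save` and `load` no longer enforce a `.pt` suffix, and matplotlib is now imported lazily inside `plot`, so it is only needed when plotting. A minimal sketch with a hypothetical sample dict; the array shapes are illustrative only:

```python
import numpy as np

from ocf_data_sampler.numpy_sample import GSPSampleKey
from ocf_data_sampler.sample.uk_regional import UKRegionalSample

data = {
    GSPSampleKey.gsp: np.random.rand(21),
    "solar_azimuth": np.random.rand(21),
    "solar_elevation": np.random.rand(21),
}

sample = UKRegionalSample(data)
sample.save("sample.pt")                     # torch.save under the hood
loaded = UKRegionalSample.load("sample.pt")  # torch.load(..., weights_only=False)
loaded.plot()                                # only the GSP and solar panels are populated here
```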
ocf_data_sampler/select/dropout.py

```diff
@@ -1,39 +1,54 @@
-""" Functions for simulating dropout in time series data """
+"""Functions for simulating dropout in time series data.
+
+This is used for the following types of data: GSP, Satellite and Site
+This is not used for NWP
+"""
+
 import numpy as np
 import pandas as pd
 import xarray as xr
 
 
 def draw_dropout_time(
-        t0: pd.Timestamp,
-        dropout_timedeltas: list[pd.Timedelta] | pd.Timedelta | None,
-        dropout_frac: float = 0,
-):
-
-    if dropout_timedeltas is not None:
-        assert len(dropout_timedeltas) >= 1, "Must include list of relative dropout timedeltas"
-        assert all(
-            [t <= pd.Timedelta("0min") for t in dropout_timedeltas]
-        ), "dropout timedeltas must be negative"
-    assert 0 <= dropout_frac <= 1
-
-    if (dropout_timedeltas is None) or (np.random.uniform() >= dropout_frac):
-        dropout_time = None
+    t0: pd.Timestamp,
+    dropout_timedeltas: list[pd.Timedelta],
+    dropout_frac: float,
+) -> pd.Timestamp:
+    """Randomly pick a dropout time from a list of timedeltas.
+
+    Args:
+        t0: The forecast init-time
+        dropout_timedeltas: List of timedeltas relative to t0 to pick from
+        dropout_frac: Probability that dropout will be applied.
+            This should be between 0 and 1 inclusive
+    """
+    if dropout_frac > 0 and len(dropout_timedeltas) == 0:
+        raise ValueError("To apply dropout, dropout_timedeltas must be provided")
+
+    for t in dropout_timedeltas:
+        if t > pd.Timedelta("0min"):
+            raise ValueError("Dropout timedeltas must be negative")
+
+    if not (0 <= dropout_frac <= 1):
+        raise ValueError("dropout_frac must be between 0 and 1 inclusive")
+
+    if (len(dropout_timedeltas) == 0) or (np.random.uniform() >= dropout_frac):
+        dropout_time = t0
     else:
-        t0_datetime_utc = pd.Timestamp(t0)
-        dt = np.random.choice(dropout_timedeltas)
-        dropout_time = t0_datetime_utc + dt
+        dropout_time = t0 + np.random.choice(dropout_timedeltas)
 
     return dropout_time
 
 
 def apply_dropout_time(
-        ds: xr.DataArray,
-        dropout_time: pd.Timestamp | None,
-):
-
-    if dropout_time is None:
-        return ds
-    else:
-        # This replaces the times after the dropout with NaNs
-        return ds.where(ds.time_utc <= dropout_time)
+    ds: xr.DataArray,
+    dropout_time: pd.Timestamp,
+) -> xr.DataArray:
+    """Apply dropout time to the data.
+
+    Args:
+        ds: Xarray DataArray with 'time_utc' coordinate
+        dropout_time: Time after which data is set to NaN
+    """
+    # This replaces the times after the dropout with NaNs
+    return ds.where(ds.time_utc <= dropout_time)
```
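A behaviour change worth noting: when no dropout is drawn, the function now returns `t0` itself rather than `None`, so `apply_dropout_time` always masks; this is harmless because timestamps up to and including `t0` are kept. A small sketch with toy data:

```python
import numpy as np
import pandas as pd
import xarray as xr

from ocf_data_sampler.select.dropout import apply_dropout_time, draw_dropout_time

times = pd.date_range("2024-06-01 08:00", periods=7, freq="30min")
da = xr.DataArray(np.random.rand(7), dims=["time_utc"], coords={"time_utc": times})

t0 = times[4]
# Simulate data arriving 30 or 60 minutes late, half of the time
dropout_time = draw_dropout_time(
    t0,
    dropout_timedeltas=[pd.Timedelta("-30min"), pd.Timedelta("-60min")],
    dropout_frac=0.5,
)

# Values after dropout_time become NaN; if no dropout was drawn, dropout_time == t0
da_dropped = apply_dropout_time(da, dropout_time)
```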
ocf_data_sampler/select/fill_time_periods.py

```diff
@@ -1,11 +1,23 @@
-"""fill time periods"""
+"""Fill time periods between specified start and end dates."""
 
-import pandas as pd
 import numpy as np
+import pandas as pd
 
 
 def fill_time_periods(time_periods: pd.DataFrame, freq: pd.Timedelta) -> pd.DatetimeIndex:
+    """Create range of timestamps between given start and end times.
+
+    Each of the continuous periods (i.e. each row of the input DataFrame) is filled with the
+    specified frequency.
+
+    Args:
+        time_periods: DataFrame with columns 'start_dt' and 'end_dt'
+        freq: Frequency to fill time periods with
+    """
     start_dts = pd.to_datetime(time_periods["start_dt"].values).ceil(freq)
     end_dts = pd.to_datetime(time_periods["end_dt"].values)
-    date_ranges = [pd.date_range(start_dt, end_dt, freq=freq) for start_dt, end_dt in zip(start_dts, end_dts)]
+    date_ranges = [
+        pd.date_range(start_dt, end_dt, freq=freq)
+        for start_dt, end_dt in zip(start_dts, end_dts, strict=False)
+    ]
     return pd.DatetimeIndex(np.concatenate(date_ranges))
```
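The behaviour is unchanged: each start time is ceiled to the frequency, then each row is expanded into its own date range and the results are concatenated. A quick worked example with made-up periods:

```python
import pandas as pd

from ocf_data_sampler.select.fill_time_periods import fill_time_periods

# Two continuous periods, filled independently at 30-minute frequency
time_periods = pd.DataFrame(
    {
        "start_dt": ["2024-06-01 00:10", "2024-06-02 12:00"],
        "end_dt": ["2024-06-01 02:00", "2024-06-02 13:00"],
    }
)

timestamps = fill_time_periods(time_periods, freq=pd.Timedelta("30min"))
print(timestamps[0])    # 2024-06-01 00:30:00 (00:10 ceiled to the 30-min grid)
print(len(timestamps))  # 4 + 3 = 7 timestamps
```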