ocf-data-sampler 0.1.11__py3-none-any.whl → 0.1.16__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their public registry.

This version of ocf-data-sampler has been flagged as potentially problematic.

Files changed (76):
  1. ocf_data_sampler/config/load.py +3 -3
  2. ocf_data_sampler/config/model.py +73 -61
  3. ocf_data_sampler/config/save.py +5 -4
  4. ocf_data_sampler/constants.py +140 -12
  5. ocf_data_sampler/load/gsp.py +6 -5
  6. ocf_data_sampler/load/load_dataset.py +5 -6
  7. ocf_data_sampler/load/nwp/nwp.py +17 -5
  8. ocf_data_sampler/load/nwp/providers/ecmwf.py +6 -7
  9. ocf_data_sampler/load/nwp/providers/gfs.py +36 -0
  10. ocf_data_sampler/load/nwp/providers/icon.py +46 -0
  11. ocf_data_sampler/load/nwp/providers/ukv.py +4 -5
  12. ocf_data_sampler/load/nwp/providers/utils.py +3 -1
  13. ocf_data_sampler/load/satellite.py +9 -10
  14. ocf_data_sampler/load/site.py +10 -6
  15. ocf_data_sampler/load/utils.py +21 -16
  16. ocf_data_sampler/numpy_sample/collate.py +10 -9
  17. ocf_data_sampler/numpy_sample/datetime_features.py +3 -5
  18. ocf_data_sampler/numpy_sample/gsp.py +12 -14
  19. ocf_data_sampler/numpy_sample/nwp.py +12 -12
  20. ocf_data_sampler/numpy_sample/satellite.py +9 -9
  21. ocf_data_sampler/numpy_sample/site.py +5 -8
  22. ocf_data_sampler/numpy_sample/sun_position.py +16 -21
  23. ocf_data_sampler/sample/base.py +15 -17
  24. ocf_data_sampler/sample/site.py +13 -20
  25. ocf_data_sampler/sample/uk_regional.py +29 -35
  26. ocf_data_sampler/select/dropout.py +16 -14
  27. ocf_data_sampler/select/fill_time_periods.py +15 -5
  28. ocf_data_sampler/select/find_contiguous_time_periods.py +88 -75
  29. ocf_data_sampler/select/geospatial.py +63 -54
  30. ocf_data_sampler/select/location.py +16 -51
  31. ocf_data_sampler/select/select_spatial_slice.py +105 -89
  32. ocf_data_sampler/select/select_time_slice.py +71 -58
  33. ocf_data_sampler/select/spatial_slice_for_dataset.py +7 -6
  34. ocf_data_sampler/select/time_slice_for_dataset.py +17 -16
  35. ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py +126 -118
  36. ocf_data_sampler/torch_datasets/datasets/site.py +135 -101
  37. ocf_data_sampler/torch_datasets/utils/merge_and_fill_utils.py +6 -2
  38. ocf_data_sampler/torch_datasets/utils/valid_time_periods.py +23 -22
  39. ocf_data_sampler/torch_datasets/utils/validate_channels.py +23 -19
  40. ocf_data_sampler/utils.py +3 -1
  41. {ocf_data_sampler-0.1.11.dist-info → ocf_data_sampler-0.1.16.dist-info}/METADATA +7 -18
  42. ocf_data_sampler-0.1.16.dist-info/RECORD +56 -0
  43. {ocf_data_sampler-0.1.11.dist-info → ocf_data_sampler-0.1.16.dist-info}/WHEEL +1 -1
  44. {ocf_data_sampler-0.1.11.dist-info → ocf_data_sampler-0.1.16.dist-info}/top_level.txt +1 -1
  45. scripts/refactor_site.py +62 -33
  46. utils/compute_icon_mean_stddev.py +72 -0
  47. ocf_data_sampler-0.1.11.dist-info/LICENSE +0 -21
  48. ocf_data_sampler-0.1.11.dist-info/RECORD +0 -82
  49. tests/__init__.py +0 -0
  50. tests/config/test_config.py +0 -113
  51. tests/config/test_load.py +0 -7
  52. tests/config/test_save.py +0 -28
  53. tests/conftest.py +0 -319
  54. tests/load/test_load_gsp.py +0 -15
  55. tests/load/test_load_nwp.py +0 -21
  56. tests/load/test_load_satellite.py +0 -17
  57. tests/load/test_load_sites.py +0 -14
  58. tests/numpy_sample/test_collate.py +0 -21
  59. tests/numpy_sample/test_datetime_features.py +0 -37
  60. tests/numpy_sample/test_gsp.py +0 -38
  61. tests/numpy_sample/test_nwp.py +0 -13
  62. tests/numpy_sample/test_satellite.py +0 -40
  63. tests/numpy_sample/test_sun_position.py +0 -81
  64. tests/select/test_dropout.py +0 -69
  65. tests/select/test_fill_time_periods.py +0 -28
  66. tests/select/test_find_contiguous_time_periods.py +0 -202
  67. tests/select/test_location.py +0 -67
  68. tests/select/test_select_spatial_slice.py +0 -154
  69. tests/select/test_select_time_slice.py +0 -275
  70. tests/test_sample/test_base.py +0 -164
  71. tests/test_sample/test_site_sample.py +0 -165
  72. tests/test_sample/test_uk_regional_sample.py +0 -136
  73. tests/torch_datasets/test_merge_and_fill_utils.py +0 -40
  74. tests/torch_datasets/test_pvnet_uk.py +0 -154
  75. tests/torch_datasets/test_site.py +0 -226
  76. tests/torch_datasets/test_validate_channels_utils.py +0 -78

ocf_data_sampler/sample/base.py

@@ -1,11 +1,10 @@
- """ Base class for handling flat/nested data structures with NWP consideration """
+ """Base class for handling flat/nested data structures with NWP consideration."""

- import numpy as np
- import torch
-
- from typing import TypeAlias
  from abc import ABC, abstractmethod
+ from typing import TypeAlias

+ import numpy as np
+ import torch

  NumpySample: TypeAlias = dict[str, np.ndarray | dict[str, np.ndarray]]
  NumpyBatch: TypeAlias = dict[str, np.ndarray | dict[str, np.ndarray]]

@@ -13,39 +12,38 @@ TensorBatch: TypeAlias = dict[str, torch.Tensor | dict[str, torch.Tensor]]


  class SampleBase(ABC):
-     """
-     Abstract base class for all sample types
-     Provides core data storage functionality
-     """
+     """Abstract base class for all sample types."""

      @abstractmethod
      def to_numpy(self) -> NumpySample:
-         """Convert sample data to numpy format"""
+         """Convert sample data to numpy format."""
          raise NotImplementedError

      @abstractmethod
      def plot(self) -> None:
+         """Create a visualisation of the data."""
          raise NotImplementedError

      @abstractmethod
      def save(self, path: str) -> None:
+         """Saves the sample to disk in the implementations' required format."""
          raise NotImplementedError

      @classmethod
      @abstractmethod
-     def load(cls, path: str) -> 'SampleBase':
+     def load(cls, path: str) -> "SampleBase":
+         """Load a sample from disk from the implementations' format."""
          raise NotImplementedError


  def batch_to_tensor(batch: NumpyBatch) -> TensorBatch:
-     """
-     Recursively converts numpy arrays in nested dict to torch tensors
+     """Recursively converts numpy arrays in nested dict to torch tensors.
+
      Args:
          batch: NumpyBatch with data in numpy arrays
      Returns:
          TensorBatch with data in torch tensors
      """
-
      for k, v in batch.items():
          if isinstance(v, dict):
              batch[k] = batch_to_tensor(v)

@@ -58,12 +56,12 @@ def batch_to_tensor(batch: NumpyBatch) -> TensorBatch:


  def copy_batch_to_device(batch: TensorBatch, device: torch.device) -> TensorBatch:
-     """Recursively copies tensors in nested dict to specified device
+     """Recursively copies tensors in nested dict to specified device.

      Args:
          batch: Nested dict with tensors to move
          device: Device to move tensors to
-
+
      Returns:
          A dict with tensors moved to the new device
      """

@@ -71,7 +69,7 @@ def copy_batch_to_device(batch: TensorBatch, device: torch.device) -> TensorBatch:

      for k, v in batch.items():
          if isinstance(v, dict):
-             batch_copy[k] = copy_batch_to_device(v, device)
+             batch_copy[k] = copy_batch_to_device(v, device)
          elif isinstance(v, torch.Tensor):
              batch_copy[k] = v.to(device)
          else:
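
A minimal usage sketch for the two helpers above, assuming ocf-data-sampler 0.1.16 is installed; the nested-dict layout is illustrative rather than a real PVNet sample:

import numpy as np
import torch

from ocf_data_sampler.sample.base import batch_to_tensor, copy_batch_to_device

# Illustrative batch: one flat entry plus one nested (NWP-style) entry
batch = {
    "gsp": np.random.rand(4, 7),
    "nwp": {"ukv": np.random.rand(4, 2, 8, 8)},
}

tensor_batch = batch_to_tensor(batch)  # numpy -> torch, recursing into sub-dicts
assert isinstance(tensor_batch["nwp"]["ukv"], torch.Tensor)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tensor_batch = copy_batch_to_device(tensor_batch, device)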

ocf_data_sampler/sample/site.py

@@ -1,44 +1,37 @@
- """PVNet Site sample implementation for netCDF data handling and conversion"""
+ """PVNet Site sample implementation for netCDF data handling and conversion."""

  import xarray as xr
-
  from typing_extensions import override

- from ocf_data_sampler.sample.base import SampleBase, NumpySample
+ from ocf_data_sampler.sample.base import NumpySample, SampleBase
  from ocf_data_sampler.torch_datasets.datasets.site import convert_netcdf_to_numpy_sample


  class SiteSample(SampleBase):
-     """Handles PVNet site specific netCDF operations"""
+     """Handles PVNet site specific netCDF operations."""

-     def __init__(self, data: xr.Dataset):
-
+     def __init__(self, data: xr.Dataset) -> None:
+         """Initializes the SiteSample object with the given xarray Dataset."""
          if not isinstance(data, xr.Dataset):
              raise TypeError(f"Data must be xarray Dataset - Found type {type(data)}")
-
          self._data = data

      @override
-     def to_numpy(self) -> NumpySample: 
+     def to_numpy(self) -> NumpySample:
          return convert_netcdf_to_numpy_sample(self._data)

+     @override
      def save(self, path: str) -> None:
-         """Save site sample data as netCDF
-
-         Args:
-             path: Path to save the netCDF file
-         """
+         # Saves as NetCDF
          self._data.to_netcdf(path, mode="w", engine="h5netcdf")

      @classmethod
-     def load(cls, path: str) -> 'SiteSample':
-         """Load site sample data from netCDF
-
-         Args:
-             path: Path to load the netCDF file from
-         """
+     @override
+     def load(cls, path: str) -> "SiteSample":
+         # Loads from NetCDF
          return cls(xr.open_dataset(path))

-     # TODO - placeholder for now
+     @override
      def plot(self) -> None:
+         # TODO - placeholder for now
          raise NotImplementedError("Plotting not yet implemented for SiteSample")
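
A hedged round-trip sketch for the reworked SiteSample; the toy Dataset below is a stand-in for a real site sample, and save() needs the h5netcdf engine available:

import numpy as np
import xarray as xr

from ocf_data_sampler.sample.site import SiteSample

# Toy stand-in for a real site sample (the variable name is hypothetical)
ds = xr.Dataset({"generation": ("time_utc", np.random.rand(10))})

sample = SiteSample(ds)           # non-Dataset input now raises TypeError
sample.save("sample.nc")          # written with engine="h5netcdf"
restored = SiteSample.load("sample.nc")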

ocf_data_sampler/sample/uk_regional.py

@@ -1,75 +1,69 @@
- """PVNet UK Regional sample implementation for dataset handling and visualisation"""
- 
- from typing_extensions import override
+ """PVNet UK Regional sample implementation for dataset handling and visualisation."""

  import torch
+ from typing_extensions import override

- from ocf_data_sampler.sample.base import SampleBase, NumpySample
  from ocf_data_sampler.numpy_sample import (
-     NWPSampleKey,
      GSPSampleKey,
-     SatelliteSampleKey
+     NWPSampleKey,
+     SatelliteSampleKey,
  )
+ from ocf_data_sampler.sample.base import NumpySample, SampleBase


  class UKRegionalSample(SampleBase):
-     """Handles UK Regional PVNet data operations"""
+     """Handles UK Regional PVNet data operations."""

-     def __init__(self, data: NumpySample):
+     def __init__(self, data: NumpySample) -> None:
+         """Initialises UK Regional sample with data."""
          self._data = data

      @override
      def to_numpy(self) -> NumpySample:
          return self._data

+     @override
      def save(self, path: str) -> None:
-         """Save PVNet sample as pickle format using torch.save
- 
-         Args:
-             path: Path to save the sample data to
-         """
+         # Saves to pickle format
          torch.save(self._data, path)

      @classmethod
-     def load(cls, path: str) -> 'UKRegionalSample':
-         """Load PVNet sample data from .pt format
- 
-         Args:
-             path: Path to load the sample data from
-         """
+     @override
+     def load(cls, path: str) -> "UKRegionalSample":
+         # Loads from .pt format
          # TODO: We should move away from using torch.load(..., weights_only=False)
          return cls(torch.load(path, weights_only=False))

+     @override
      def plot(self) -> None:
-         """Creates visualisations for NWP, GSP, solar position, and satellite data"""
          from matplotlib import pyplot as plt

          fig, axes = plt.subplots(2, 2, figsize=(12, 8))
- 
+ 
          if NWPSampleKey.nwp in self._data:
-             first_nwp = list(self._data[NWPSampleKey.nwp].values())[0]
-             if 'nwp' in first_nwp:
-                 axes[0, 1].imshow(first_nwp['nwp'][0])
-                 title = 'NWP (First Channel)'
+             first_nwp = next(iter(self._data[NWPSampleKey.nwp].values()))
+             if "nwp" in first_nwp:
+                 axes[0, 1].imshow(first_nwp["nwp"][0])
+                 title = "NWP (First Channel)"
                  if NWPSampleKey.channel_names in first_nwp:
                      channel_names = first_nwp[NWPSampleKey.channel_names]
                      if channel_names:
-                         title = f'NWP: {channel_names[0]}'
+                         title = f"NWP: {channel_names[0]}"
                  axes[0, 1].set_title(title)

          if GSPSampleKey.gsp in self._data:
              axes[0, 0].plot(self._data[GSPSampleKey.gsp])
-             axes[0, 0].set_title('GSP Generation')
- 
-         if GSPSampleKey.solar_azimuth in self._data and GSPSampleKey.solar_elevation in self._data:
-             axes[1, 1].plot(self._data[GSPSampleKey.solar_azimuth], label='Azimuth')
-             axes[1, 1].plot(self._data[GSPSampleKey.solar_elevation], label='Elevation')
-             axes[1, 1].set_title('Solar Position')
-             axes[1, 1].legend()
+             axes[0, 0].set_title("GSP Generation")
+ 
+         if "solar_azimuth" in self._data and "solar_elevation" in self._data:
+             axes[1, 1].plot(self._data["solar_azimuth"], label="Azimuth")
+             axes[1, 1].plot(self._data["solar_elevation"], label="Elevation")
+             axes[1, 1].set_title("Solar Position")
+             axes[1, 1].legend()

          if SatelliteSampleKey.satellite_actual in self._data:
              axes[1, 0].imshow(self._data[SatelliteSampleKey.satellite_actual])
-             axes[1, 0].set_title('Satellite Data')
- 
+             axes[1, 0].set_title("Satellite Data")
+ 
          plt.tight_layout()
          plt.show()
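
A sketch of the pickle round-trip for UKRegionalSample; the single "gsp" key below is a placeholder for the full key set a real sample carries:

import numpy as np

from ocf_data_sampler.sample.uk_regional import UKRegionalSample

sample = UKRegionalSample({"gsp": np.random.rand(7)})  # placeholder contents
sample.save("sample.pt")                               # pickled via torch.save

# Per the TODO above, load still relies on torch.load(..., weights_only=False)
restored = UKRegionalSample.load("sample.pt")
assert "gsp" in restored.to_numpy()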

ocf_data_sampler/select/dropout.py

@@ -1,8 +1,9 @@
- """Functions for simulating dropout in time series data
+ """Functions for simulating dropout in time series data.

  This is used for the following types of data: GSP, Satellite and Site
  This is not used for NWP
  """
+
  import numpy as np
  import pandas as pd
  import xarray as xr

@@ -13,22 +14,23 @@ def draw_dropout_time(
      dropout_timedeltas: list[pd.Timedelta],
      dropout_frac: float,
  ) -> pd.Timestamp:
-     """Randomly pick a dropout time from a list of timedeltas
- 
+     """Randomly pick a dropout time from a list of timedeltas.
+ 
      Args:
          t0: The forecast init-time
          dropout_timedeltas: List of timedeltas relative to t0 to pick from
-         dropout_frac: Probability that dropout will be applied. This should be between 0 and 1
-             inclusive
+         dropout_frac: Probability that dropout will be applied.
+             This should be between 0 and 1 inclusive
      """
-
-     if dropout_frac>0:
-         assert len(dropout_timedeltas) > 0, "To apply dropout dropout_timedeltas must be provided"
+     if dropout_frac > 0 and len(dropout_timedeltas) == 0:
+         raise ValueError("To apply dropout, dropout_timedeltas must be provided")

      for t in dropout_timedeltas:
-         assert t <= pd.Timedelta("0min"), "Dropout timedeltas must be negative"
+         if t > pd.Timedelta("0min"):
+             raise ValueError("Dropout timedeltas must be negative")

-     assert 0 <= dropout_frac <= 1
+     if not (0 <= dropout_frac <= 1):
+         raise ValueError("dropout_frac must be between 0 and 1 inclusive")

      if (len(dropout_timedeltas) == 0) or (np.random.uniform() >= dropout_frac):
          dropout_time = t0

@@ -41,11 +43,11 @@
  def apply_dropout_time(
      ds: xr.DataArray,
      dropout_time: pd.Timestamp,
- ) -> xr.DataArray:
-     """Apply dropout time to the data
- 
+ ) -> xr.DataArray:
+     """Apply dropout time to the data.
+ 
      Args:
-         ds: Xarray DataArray with 'time_utc' coordiante
+         ds: Xarray DataArray with 'time_utc' coordinate
          dropout_time: Time after which data is set to NaN
      """
      # This replaces the times after the dropout with NaNs
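
To illustrate the new ValueError checks, a small sketch of drawing and applying a dropout time on a toy DataArray (values and shapes are illustrative):

import numpy as np
import pandas as pd
import xarray as xr

from ocf_data_sampler.select.dropout import apply_dropout_time, draw_dropout_time

t0 = pd.Timestamp("2024-06-01 12:00")
dropout_time = draw_dropout_time(
    t0=t0,
    dropout_timedeltas=[pd.Timedelta("-30min"), pd.Timedelta("-60min")],  # must be negative
    dropout_frac=0.5,  # values outside [0, 1] now raise ValueError
)

# Toy DataArray with the required "time_utc" coordinate
times = pd.date_range("2024-06-01 10:00", t0, freq="30min")
da = xr.DataArray(np.random.rand(len(times)), coords={"time_utc": times}, dims="time_utc")

da = apply_dropout_time(da, dropout_time)  # values after dropout_time become NaN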

ocf_data_sampler/select/fill_time_periods.py

@@ -1,13 +1,23 @@
- """Fill time periods between start and end dates at specified frequency"""
+ """Fill time periods between specified start and end dates."""

- import pandas as pd
  import numpy as np
+ import pandas as pd


  def fill_time_periods(time_periods: pd.DataFrame, freq: pd.Timedelta) -> pd.DatetimeIndex:
-     """Generate DatetimeIndex for all timestamps between start and end dates"""
-
+     """Create range of timestamps between given start and end times.
+ 
+     Each of the continuous periods (i.e. each row of the input DataFrame) is filled with the
+     specified frequency.
+ 
+     Args:
+         time_periods: DataFrame with columns 'start_dt' and 'end_dt'
+         freq: Frequency to fill time periods with
+     """
      start_dts = pd.to_datetime(time_periods["start_dt"].values).ceil(freq)
      end_dts = pd.to_datetime(time_periods["end_dt"].values)
-     date_ranges = [pd.date_range(start_dt, end_dt, freq=freq) for start_dt, end_dt in zip(start_dts, end_dts)]
+     date_ranges = [
+         pd.date_range(start_dt, end_dt, freq=freq)
+         for start_dt, end_dt in zip(start_dts, end_dts, strict=False)
+     ]
      return pd.DatetimeIndex(np.concatenate(date_ranges))
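
A quick sketch of fill_time_periods: each row of the DataFrame is filled independently at the given frequency, then the results are concatenated into one DatetimeIndex:

import pandas as pd

from ocf_data_sampler.select.fill_time_periods import fill_time_periods

periods = pd.DataFrame(
    {
        "start_dt": pd.to_datetime(["2024-01-01 00:00", "2024-01-02 06:00"]),
        "end_dt": pd.to_datetime(["2024-01-01 02:00", "2024-01-02 07:00"]),
    }
)
timestamps = fill_time_periods(periods, freq=pd.Timedelta("30min"))
print(timestamps)  # 00:00-02:00 and 06:00-07:00, both at 30-minute steps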

ocf_data_sampler/select/find_contiguous_time_periods.py

@@ -1,9 +1,12 @@
- """Get contiguous time periods for training"""
+ """Get contiguous time periods."""

  import numpy as np
  import pandas as pd
+
  from ocf_data_sampler.load.utils import check_time_unique_increasing

+ ZERO_TDELTA = pd.Timedelta(0)
+

  def find_contiguous_time_periods(
      datetimes: pd.DatetimeIndex,

@@ -14,20 +17,20 @@

      Args:
          datetimes: pd.DatetimeIndex. Must be sorted.
-         min_seq_length: Sequences of min_seq_length or shorter will be discarded. Typically, this
-             would be set to the `total_seq_length` of each machine learning example.
+         min_seq_length: Sequences of min_seq_length or shorter will be discarded.
          max_gap_duration: If any pair of consecutive `datetimes` is more than `max_gap_duration`
              apart, then this pair of `datetimes` will be considered a "gap" between two contiguous
-             sequences. Typically, `max_gap_duration` would be set to the sample period of
-             the timeseries.
+             sequences.

      Returns:
-         pd.DataFrame where each row represents a single time period. The pd.DataFrame 
+         pd.DataFrame where each row represents a single time period. The pd.DataFrame
          has two columns: `start_dt` and `end_dt` (where 'dt' is short for 'datetime').
      """
      # Sanity checks.
-     assert len(datetimes) > 0
-     assert min_seq_length > 1
+     if len(datetimes) == 0:
+         raise ValueError("No datetimes to use")
+     if min_seq_length <= 1:
+         raise ValueError(f"{min_seq_length=} must be greater than 1")
      check_time_unique_increasing(datetimes)

      # Find indices of gaps larger than max_gap:

@@ -43,77 +46,75 @@
      # Capture the last segment of dt_index.
      segment_boundaries = np.concatenate((segment_boundaries, [len(datetimes)]))

-     periods: list[dict[str, pd.Timestamp]] = []
+     periods: list[list[pd.Timestamp]] = []
      start_i = 0
      for next_start_i in segment_boundaries:
          n_timesteps = next_start_i - start_i
          if n_timesteps > min_seq_length:
              end_i = next_start_i - 1
-             period = {"start_dt": datetimes[start_i], "end_dt": datetimes[end_i]}
-             periods.append(period)
+             periods.append([datetimes[start_i], datetimes[end_i]])
          start_i = next_start_i

-     assert len(periods) > 0, (
-         f"Did not find an periods from {datetimes}. " f"{min_seq_length=} {max_gap_duration=}"
-     )
+     if len(periods) == 0:
+         raise ValueError(
+             f"Did not find any periods from {datetimes}. {min_seq_length=} {max_gap_duration=}",
+         )

-     return pd.DataFrame(periods)
+     return pd.DataFrame(periods, columns=["start_dt", "end_dt"])


  def trim_contiguous_time_periods(
-     contiguous_time_periods: pd.DataFrame, 
+     contiguous_time_periods: pd.DataFrame,
      interval_start: pd.Timedelta,
      interval_end: pd.Timedelta,
  ) -> pd.DataFrame:
-     """Trim the contiguous time periods to allow for history and forecast durations.
+     """Trims contiguous time periods to account for history requirements and forecast horizons.

      Args:
-         contiguous_time_periods: DataFrame where each row represents a single time period. The
-             DataFrame must have `start_dt` and `end_dt` columns.
+         contiguous_time_periods: pd.DataFrame where each row represents a single time period.
+             The pd.DataFrame must have `start_dt` and `end_dt` columns.
          interval_start: The start of the interval with respect to t0
          interval_end: The end of the interval with respect to t0

- 
      Returns:
-         The contiguous_time_periods DataFrame with the `start_dt` and `end_dt` columns updated.
+         The contiguous_time_periods pd.DataFrame with the `start_dt` and `end_dt` columns updated.
      """
-     contiguous_time_periods = contiguous_time_periods.copy()
- 
-     contiguous_time_periods["start_dt"] -= interval_start
-     contiguous_time_periods["end_dt"] -= interval_end
+     # Make a copy so the data is not edited in place.
+     trimmed_time_periods = contiguous_time_periods.copy()
+     trimmed_time_periods["start_dt"] -= interval_start
+     trimmed_time_periods["end_dt"] -= interval_end

-     valid_mask = contiguous_time_periods["start_dt"] <= contiguous_time_periods["end_dt"]
-     contiguous_time_periods = contiguous_time_periods.loc[valid_mask]
- 
-     return contiguous_time_periods
+     valid_mask = trimmed_time_periods["start_dt"] <= trimmed_time_periods["end_dt"]

+     return trimmed_time_periods.loc[valid_mask]


  def find_contiguous_t0_periods(
-     datetimes: pd.DatetimeIndex,
-     interval_start: pd.Timedelta,
-     interval_end: pd.Timedelta,
-     sample_period_duration: pd.Timedelta,
- ) -> pd.DataFrame:
+     datetimes: pd.DatetimeIndex,
+     interval_start: pd.Timedelta,
+     interval_end: pd.Timedelta,
+     time_resolution: pd.Timedelta,
+ ) -> pd.DataFrame:
      """Return a pd.DataFrame where each row records the boundary of a contiguous time period.

      Args:
-         datetimes: pd.DatetimeIndex. Must be sorted.
+         datetimes: pd.DatetimeIndex
          interval_start: The start of the interval with respect to t0
          interval_end: The end of the interval with respect to t0
-         sample_period_duration: The sample frequency of the timeseries
- 
+         time_resolution: The sample frequency of the timeseries

      Returns:
          pd.DataFrame where each row represents a single time period. The pd.DataFrame
          has two columns: `start_dt` and `end_dt` (where 'dt' is short for 'datetime').
      """
+     check_time_unique_increasing(datetimes)
+
      total_duration = interval_end - interval_start
- 
+ 
      contiguous_time_periods = find_contiguous_time_periods(
          datetimes=datetimes,
-         min_seq_length=int(total_duration / sample_period_duration) + 1,
-         max_gap_duration=sample_period_duration,
+         min_seq_length=int(total_duration / time_resolution) + 1,
+         max_gap_duration=time_resolution,
      )

      contiguous_t0_periods = trim_contiguous_time_periods(

@@ -122,7 +123,11 @@
          interval_end=interval_end,
      )

-     assert len(contiguous_t0_periods) > 0
+     if len(contiguous_t0_periods) == 0:
+         raise ValueError(
+             f"No contiguous time periods found for {datetimes}. "
+             f"{interval_start=} {interval_end=} {time_resolution=}",
+         )

      return contiguous_t0_periods
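
Before the NWP-specific variant below, a sketch of find_contiguous_t0_periods with the renamed time_resolution argument, on a synthetic 5-minutely index with a gap carved out:

import pandas as pd

from ocf_data_sampler.select.find_contiguous_time_periods import find_contiguous_t0_periods

datetimes = pd.date_range("2024-01-01 00:00", "2024-01-01 06:00", freq="5min")
datetimes = datetimes.delete(range(20, 30))  # remove 01:40-02:25 to create a gap

t0_periods = find_contiguous_t0_periods(
    datetimes=datetimes,
    interval_start=pd.Timedelta("-60min"),  # history required before t0
    interval_end=pd.Timedelta("30min"),     # forecast horizon after t0
    time_resolution=pd.Timedelta("5min"),   # previously `sample_period_duration`
)
print(t0_periods)  # two rows of valid [start_dt, end_dt] t0 ranges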
 
@@ -131,54 +136,59 @@ def find_contiguous_t0_periods_nwp(
      init_times: pd.DatetimeIndex,
      interval_start: pd.Timedelta,
      max_staleness: pd.Timedelta,
-     max_dropout: pd.Timedelta = pd.Timedelta(0),
-     first_forecast_step: pd.Timedelta = pd.Timedelta(0),
- 
+     max_dropout: pd.Timedelta = ZERO_TDELTA,
+     first_forecast_step: pd.Timedelta = ZERO_TDELTA,
  ) -> pd.DataFrame:
-     """Get all time periods from the NWP init times which are valid as t0 datetimes.
+     """Get all time periods from the NWP init-times which are valid as t0 datetimes.

      Args:
          init_times: The initialisation times of the available forecasts
-         interval_start: The start of the desired data interval with respect to t0
-         max_staleness: Up to how long after an init time are we willing to use the forecast. Each
-             init time will only be used up to this t0 time regardless of the forecast valid time.
-         max_dropout: What is the maximum amount of dropout that will be used. This must be <=
-             max_staleness.
-         first_forecast_step: The timedelta of the first step of the forecast. By default we assume
-             the first valid time of the forecast is the same as its init time.
+         interval_start: The start of the time interval with respect to t0
+         max_staleness: Up to how long after an init-time are we willing to use the forecast.
+             Each init-time will only be used up to this t0 time regardless of the forecast valid
+             time.
+         max_dropout: What is the maximum amount of dropout that will be used.
+             This must be <= max_staleness.
+         first_forecast_step: The timedelta of the first step of the forecast.
+             By default we assume the first valid time of the forecast
+             is the same as its init-time.

      Returns:
-         pd.DataFrame where each row represents a single time period. The pd.DataFrame 
+         pd.DataFrame where each row represents a single time period. The pd.DataFrame
          has two columns: `start_dt` and `end_dt` (where 'dt' is short for 'datetime').
      """
      # Sanity checks.
-     assert len(init_times) > 0
-     assert init_times.is_monotonic_increasing
-     assert init_times.is_unique
-     assert max_staleness >= pd.Timedelta(0)
-     assert pd.Timedelta(0) <= max_dropout <= max_staleness
+     if len(init_times) == 0:
+         raise ValueError("No init-times to use")
+
+     check_time_unique_increasing(init_times)
+
+     if max_staleness < pd.Timedelta(0):
+         raise ValueError("The max staleness must be positive")
+     if not (pd.Timedelta(0) <= max_dropout <= max_staleness):
+         raise ValueError("The max dropout must be between 0 and the max staleness")

-     hist_drop_buffer = max(first_forecast_step-interval_start, max_dropout)
+     history_drop_buffer = max(first_forecast_step - interval_start, max_dropout)

      # Store contiguous periods
-     contiguous_periods = []
+     contiguous_periods: list[list[pd.Timestamp]] = []

-     # Begin the first period allowing for the time to the first_forecast_step, the length of the 
+     # Begin the first period allowing for the time to the first_forecast_step, the length of the
      # interval sampled from before t0, and the dropout
-     start_this_period = init_times[0] + hist_drop_buffer
+     start_this_period = init_times[0] + history_drop_buffer

      # The first forecast is valid up to the max staleness
      end_this_period = init_times[0] + max_staleness

      for dt_init in init_times[1:]:
-         # If the previous init time becomes stale before the next init becomes valid (whilst also
-         # considering dropout) then the contiguous period breaks
-         # Else if the previous init time becomes stale before the fist step of the next forecast
+         # If the previous init-time becomes stale before the next init-time becomes valid (whilst
+         # also considering dropout) then the contiguous period breaks
+         # Else if the previous init-time becomes stale before the fist step of the next forecast
          # then this also causes a break in the contiguous period
-         if (end_this_period < dt_init + max(max_dropout, first_forecast_step)):
+         if end_this_period < dt_init + max(max_dropout, first_forecast_step):
              contiguous_periods.append([start_this_period, end_this_period])
              # The new period begins with the same conditions as the first period
-             start_this_period = dt_init + hist_drop_buffer
+             start_this_period = dt_init + history_drop_buffer
          end_this_period = dt_init + max_staleness

      contiguous_periods.append([start_this_period, end_this_period])

@@ -189,11 +199,13 @@
  def intersection_of_multiple_dataframes_of_periods(
      time_periods: list[pd.DataFrame],
  ) -> pd.DataFrame:
-     """Find the intersection of a list of time periods.
+     """Find the intersection of list of time periods.

-     See the docstring of intersection_of_2_dataframes_of_periods() for more details.
+     Consecutively updates intersection of time periods.
+     See the docstring of intersection_of_2_dataframes_of_periods() for further details.
      """
-     assert len(time_periods) > 0
+     if len(time_periods) == 0:
+         raise ValueError("No time periods to intersect")
      intersection = time_periods[0]
      for time_period in time_periods[1:]:
          intersection = intersection_of_2_dataframes_of_periods(intersection, time_period)

@@ -209,7 +221,8 @@ def intersection_of_2_dataframes_of_periods(a: pd.DataFrame, b: pd.DataFrame) -> pd.DataFrame:
      A typical use-case is that each pd.DataFrame represents all the time periods where
      a `DataSource` has contiguous, valid data.

-     Here's a graphical example of two pd.DataFrames of time periods and their intersection:
+     Graphical representation of two pd.DataFrames of time periods and their intersection,
+     as follows:

      ----------------------> TIME ->---------------------
      a: |-----| |----| |----------| |-----------|

@@ -217,9 +230,9 @@
      intersection: |--| |-| |--| |---|

      Args:
-         a: pd.DataFrame where each row represents a time period. The pd.DataFrame has 
+         a: pd.DataFrame where each row represents a time period. The pd.DataFrame has
              two columns: start_dt and end_dt.
-         b: pd.DataFrame where each row represents a time period. The pd.DataFrame has 
+         b: pd.DataFrame where each row represents a time period. The pd.DataFrame has
              two columns: start_dt and end_dt.

      Returns:

@@ -238,7 +251,7 @@ intersection_of_2_dataframes_of_periods(a: pd.DataFrame, b: pd.DataFrame) -> pd.DataFrame:
      # and `a` must always end after `b` starts:

      # TODO: <= and >= because we should allow overlap time periods of length 1. e.g.
-     # a: |----|      or |---| 
+     # a: |----|      or |---|
      # b:    |--|        |---|
      # These aren't allowed if we use < and >.
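
Finally, a sketch of intersecting two sets of valid periods (single-row DataFrames for brevity; the overlap here is 06:00 to 12:00):

import pandas as pd

from ocf_data_sampler.select.find_contiguous_time_periods import (
    intersection_of_2_dataframes_of_periods,
)

a = pd.DataFrame(
    {"start_dt": [pd.Timestamp("2024-01-01 00:00")], "end_dt": [pd.Timestamp("2024-01-01 12:00")]}
)
b = pd.DataFrame(
    {"start_dt": [pd.Timestamp("2024-01-01 06:00")], "end_dt": [pd.Timestamp("2024-01-01 18:00")]}
)
print(intersection_of_2_dataframes_of_periods(a, b))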