ocf-data-sampler 0.1.11__py3-none-any.whl → 0.1.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of ocf-data-sampler may be problematic.
- ocf_data_sampler/config/load.py +3 -3
- ocf_data_sampler/config/model.py +73 -61
- ocf_data_sampler/config/save.py +5 -4
- ocf_data_sampler/constants.py +140 -12
- ocf_data_sampler/load/gsp.py +6 -5
- ocf_data_sampler/load/load_dataset.py +5 -6
- ocf_data_sampler/load/nwp/nwp.py +17 -5
- ocf_data_sampler/load/nwp/providers/ecmwf.py +6 -7
- ocf_data_sampler/load/nwp/providers/gfs.py +36 -0
- ocf_data_sampler/load/nwp/providers/icon.py +46 -0
- ocf_data_sampler/load/nwp/providers/ukv.py +4 -5
- ocf_data_sampler/load/nwp/providers/utils.py +3 -1
- ocf_data_sampler/load/satellite.py +9 -10
- ocf_data_sampler/load/site.py +10 -6
- ocf_data_sampler/load/utils.py +21 -16
- ocf_data_sampler/numpy_sample/collate.py +10 -9
- ocf_data_sampler/numpy_sample/datetime_features.py +3 -5
- ocf_data_sampler/numpy_sample/gsp.py +12 -14
- ocf_data_sampler/numpy_sample/nwp.py +12 -12
- ocf_data_sampler/numpy_sample/satellite.py +9 -9
- ocf_data_sampler/numpy_sample/site.py +5 -8
- ocf_data_sampler/numpy_sample/sun_position.py +16 -21
- ocf_data_sampler/sample/base.py +15 -17
- ocf_data_sampler/sample/site.py +13 -20
- ocf_data_sampler/sample/uk_regional.py +29 -35
- ocf_data_sampler/select/dropout.py +16 -14
- ocf_data_sampler/select/fill_time_periods.py +15 -5
- ocf_data_sampler/select/find_contiguous_time_periods.py +88 -75
- ocf_data_sampler/select/geospatial.py +63 -54
- ocf_data_sampler/select/location.py +16 -51
- ocf_data_sampler/select/select_spatial_slice.py +105 -89
- ocf_data_sampler/select/select_time_slice.py +71 -58
- ocf_data_sampler/select/spatial_slice_for_dataset.py +7 -6
- ocf_data_sampler/select/time_slice_for_dataset.py +17 -16
- ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py +126 -118
- ocf_data_sampler/torch_datasets/datasets/site.py +135 -101
- ocf_data_sampler/torch_datasets/utils/merge_and_fill_utils.py +6 -2
- ocf_data_sampler/torch_datasets/utils/valid_time_periods.py +23 -22
- ocf_data_sampler/torch_datasets/utils/validate_channels.py +23 -19
- ocf_data_sampler/utils.py +3 -1
- {ocf_data_sampler-0.1.11.dist-info → ocf_data_sampler-0.1.16.dist-info}/METADATA +7 -18
- ocf_data_sampler-0.1.16.dist-info/RECORD +56 -0
- {ocf_data_sampler-0.1.11.dist-info → ocf_data_sampler-0.1.16.dist-info}/WHEEL +1 -1
- {ocf_data_sampler-0.1.11.dist-info → ocf_data_sampler-0.1.16.dist-info}/top_level.txt +1 -1
- scripts/refactor_site.py +62 -33
- utils/compute_icon_mean_stddev.py +72 -0
- ocf_data_sampler-0.1.11.dist-info/LICENSE +0 -21
- ocf_data_sampler-0.1.11.dist-info/RECORD +0 -82
- tests/__init__.py +0 -0
- tests/config/test_config.py +0 -113
- tests/config/test_load.py +0 -7
- tests/config/test_save.py +0 -28
- tests/conftest.py +0 -319
- tests/load/test_load_gsp.py +0 -15
- tests/load/test_load_nwp.py +0 -21
- tests/load/test_load_satellite.py +0 -17
- tests/load/test_load_sites.py +0 -14
- tests/numpy_sample/test_collate.py +0 -21
- tests/numpy_sample/test_datetime_features.py +0 -37
- tests/numpy_sample/test_gsp.py +0 -38
- tests/numpy_sample/test_nwp.py +0 -13
- tests/numpy_sample/test_satellite.py +0 -40
- tests/numpy_sample/test_sun_position.py +0 -81
- tests/select/test_dropout.py +0 -69
- tests/select/test_fill_time_periods.py +0 -28
- tests/select/test_find_contiguous_time_periods.py +0 -202
- tests/select/test_location.py +0 -67
- tests/select/test_select_spatial_slice.py +0 -154
- tests/select/test_select_time_slice.py +0 -275
- tests/test_sample/test_base.py +0 -164
- tests/test_sample/test_site_sample.py +0 -165
- tests/test_sample/test_uk_regional_sample.py +0 -136
- tests/torch_datasets/test_merge_and_fill_utils.py +0 -40
- tests/torch_datasets/test_pvnet_uk.py +0 -154
- tests/torch_datasets/test_site.py +0 -226
- tests/torch_datasets/test_validate_channels_utils.py +0 -78
ocf_data_sampler/sample/base.py
CHANGED
```diff
@@ -1,11 +1,10 @@
-"""
+"""Base class for handling flat/nested data structures with NWP consideration."""
 
-import numpy as np
-import torch
-
-from typing import TypeAlias
 from abc import ABC, abstractmethod
+from typing import TypeAlias
 
+import numpy as np
+import torch
 
 NumpySample: TypeAlias = dict[str, np.ndarray | dict[str, np.ndarray]]
 NumpyBatch: TypeAlias = dict[str, np.ndarray | dict[str, np.ndarray]]
@@ -13,39 +12,38 @@ TensorBatch: TypeAlias = dict[str, torch.Tensor | dict[str, torch.Tensor]]
 
 
 class SampleBase(ABC):
-    """
-    Abstract base class for all sample types
-    Provides core data storage functionality
-    """
+    """Abstract base class for all sample types."""
 
     @abstractmethod
     def to_numpy(self) -> NumpySample:
-        """Convert sample data to numpy format"""
+        """Convert sample data to numpy format."""
         raise NotImplementedError
 
     @abstractmethod
     def plot(self) -> None:
+        """Create a visualisation of the data."""
         raise NotImplementedError
 
     @abstractmethod
     def save(self, path: str) -> None:
+        """Saves the sample to disk in the implementations' required format."""
         raise NotImplementedError
 
     @classmethod
     @abstractmethod
-    def load(cls, path: str) ->
+    def load(cls, path: str) -> "SampleBase":
+        """Load a sample from disk from the implementations' format."""
        raise NotImplementedError
 
 
 def batch_to_tensor(batch: NumpyBatch) -> TensorBatch:
-    """
-
+    """Recursively converts numpy arrays in nested dict to torch tensors.
+
     Args:
         batch: NumpyBatch with data in numpy arrays
     Returns:
         TensorBatch with data in torch tensors
     """
-
     for k, v in batch.items():
         if isinstance(v, dict):
             batch[k] = batch_to_tensor(v)
@@ -58,12 +56,12 @@ def batch_to_tensor(batch: NumpyBatch) -> TensorBatch:
 
 
 def copy_batch_to_device(batch: TensorBatch, device: torch.device) -> TensorBatch:
-    """Recursively copies tensors in nested dict to specified device
+    """Recursively copies tensors in nested dict to specified device.
 
     Args:
         batch: Nested dict with tensors to move
         device: Device to move tensors to
-
+
     Returns:
         A dict with tensors moved to the new device
     """
@@ -71,7 +69,7 @@ def copy_batch_to_device(batch: TensorBatch, device: torch.device) -> TensorBatch:
 
     for k, v in batch.items():
         if isinstance(v, dict):
-            batch_copy[k] = copy_batch_to_device(v, device)
+            batch_copy[k] = copy_batch_to_device(v, device)
         elif isinstance(v, torch.Tensor):
             batch_copy[k] = v.to(device)
         else:
```
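The `base.py` changes are a tidy-up: docstrings gain summaries and full stops, imports are re-ordered, and `load` gains an explicit `"SampleBase"` return annotation, while the `batch_to_tensor` / `copy_batch_to_device` helpers keep their behaviour. A minimal usage sketch of those two helpers, assuming version 0.1.16; the keys and shapes below are invented for illustration:

```python
import numpy as np
import torch

from ocf_data_sampler.sample.base import batch_to_tensor, copy_batch_to_device

# A toy NumpyBatch: flat entries convert directly, nested dicts convert recursively
batch = {
    "gsp": np.random.rand(4, 12),
    "nwp": {"ukv": np.random.rand(4, 2, 8, 8)},
}

tensor_batch = batch_to_tensor(batch)
tensor_batch = copy_batch_to_device(tensor_batch, torch.device("cpu"))
assert isinstance(tensor_batch["nwp"]["ukv"], torch.Tensor)
```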
ocf_data_sampler/sample/site.py
CHANGED
```diff
@@ -1,44 +1,37 @@
-"""PVNet Site sample implementation for netCDF data handling and conversion"""
+"""PVNet Site sample implementation for netCDF data handling and conversion."""
 
 import xarray as xr
-
 from typing_extensions import override
 
-from ocf_data_sampler.sample.base import
+from ocf_data_sampler.sample.base import NumpySample, SampleBase
 from ocf_data_sampler.torch_datasets.datasets.site import convert_netcdf_to_numpy_sample
 
 
 class SiteSample(SampleBase):
-    """Handles PVNet site specific netCDF operations"""
+    """Handles PVNet site specific netCDF operations."""
 
-    def __init__(self, data: xr.Dataset):
-
+    def __init__(self, data: xr.Dataset) -> None:
+        """Initializes the SiteSample object with the given xarray Dataset."""
         if not isinstance(data, xr.Dataset):
             raise TypeError(f"Data must be xarray Dataset - Found type {type(data)}")
-
         self._data = data
 
     @override
-    def to_numpy(self) -> NumpySample:
+    def to_numpy(self) -> NumpySample:
         return convert_netcdf_to_numpy_sample(self._data)
 
+    @override
     def save(self, path: str) -> None:
-
-
-        Args:
-            path: Path to save the netCDF file
-        """
+        # Saves as NetCDF
         self._data.to_netcdf(path, mode="w", engine="h5netcdf")
 
     @classmethod
-
-
-
-        Args:
-            path: Path to load the netCDF file from
-        """
+    @override
+    def load(cls, path: str) -> "SiteSample":
+        # Loads from NetCDF
         return cls(xr.open_dataset(path))
 
-
+    @override
     def plot(self) -> None:
+        # TODO - placeholder for now
         raise NotImplementedError("Plotting not yet implemented for SiteSample")
```
ocf_data_sampler/sample/uk_regional.py
CHANGED
```diff
@@ -1,75 +1,69 @@
-"""PVNet UK Regional sample implementation for dataset handling and visualisation"""
-
-from typing_extensions import override
+"""PVNet UK Regional sample implementation for dataset handling and visualisation."""
 
 import torch
+from typing_extensions import override
 
-from ocf_data_sampler.sample.base import SampleBase, NumpySample
 from ocf_data_sampler.numpy_sample import (
-    NWPSampleKey,
     GSPSampleKey,
-
+    NWPSampleKey,
+    SatelliteSampleKey,
 )
+from ocf_data_sampler.sample.base import NumpySample, SampleBase
 
 
 class UKRegionalSample(SampleBase):
-    """Handles UK Regional PVNet data operations"""
+    """Handles UK Regional PVNet data operations."""
 
-    def __init__(self, data: NumpySample):
+    def __init__(self, data: NumpySample) -> None:
+        """Initialises UK Regional sample with data."""
         self._data = data
 
     @override
     def to_numpy(self) -> NumpySample:
         return self._data
 
+    @override
     def save(self, path: str) -> None:
-
-
-        Args:
-            path: Path to save the sample data to
-        """
+        # Saves to pickle format
         torch.save(self._data, path)
 
     @classmethod
-
-
-
-        Args:
-            path: Path to load the sample data from
-        """
+    @override
+    def load(cls, path: str) -> "UKRegionalSample":
+        # Loads from .pt format
         # TODO: We should move away from using torch.load(..., weights_only=False)
         return cls(torch.load(path, weights_only=False))
 
-
+    @override
     def plot(self) -> None:
-        """Creates visualisations for NWP, GSP, solar position, and satellite data"""
         from matplotlib import pyplot as plt
 
         fig, axes = plt.subplots(2, 2, figsize=(12, 8))
-
+
         if NWPSampleKey.nwp in self._data:
-            first_nwp =
-            if
-                axes[0, 1].imshow(first_nwp[
-                title =
+            first_nwp = next(iter(self._data[NWPSampleKey.nwp].values()))
+            if "nwp" in first_nwp:
+                axes[0, 1].imshow(first_nwp["nwp"][0])
+                title = "NWP (First Channel)"
                 if NWPSampleKey.channel_names in first_nwp:
                     channel_names = first_nwp[NWPSampleKey.channel_names]
                     if channel_names:
-                        title = f
+                        title = f"NWP: {channel_names[0]}"
                 axes[0, 1].set_title(title)
 
         if GSPSampleKey.gsp in self._data:
             axes[0, 0].plot(self._data[GSPSampleKey.gsp])
-            axes[0, 0].set_title(
-
-            if
-
-
-
-
+            axes[0, 0].set_title("GSP Generation")
+
+        if "solar_azimuth" in self._data and "solar_elevation" in self._data:
+            axes[1, 1].plot(self._data["solar_azimuth"], label="Azimuth")
+            axes[1, 1].plot(self._data["solar_elevation"], label="Elevation")
+            axes[1, 1].set_title("Solar Position")
+            axes[1, 1].legend()
 
         if SatelliteSampleKey.satellite_actual in self._data:
             axes[1, 0].imshow(self._data[SatelliteSampleKey.satellite_actual])
-            axes[1, 0].set_title(
-
+            axes[1, 0].set_title("Satellite Data")
+
         plt.tight_layout()
         plt.show()
```
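`UKRegionalSample.plot` now also draws solar azimuth/elevation when present, and `save`/`load` remain thin wrappers around `torch.save`/`torch.load`. A hedged sketch of the `.pt` round trip, using a stand-in payload rather than a full UK Regional sample; the `"gsp"` key is assumed here for illustration:

```python
import numpy as np

from ocf_data_sampler.sample.uk_regional import UKRegionalSample

sample = UKRegionalSample({"gsp": np.random.rand(12)})  # stand-in payload
sample.save("sample.pt")                                # torch.save under the hood
restored = UKRegionalSample.load("sample.pt")           # torch.load(..., weights_only=False)
assert (restored.to_numpy()["gsp"] == sample.to_numpy()["gsp"]).all()
```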
ocf_data_sampler/select/dropout.py
CHANGED
```diff
@@ -1,8 +1,9 @@
-"""Functions for simulating dropout in time series data
+"""Functions for simulating dropout in time series data.
 
 This is used for the following types of data: GSP, Satellite and Site
 This is not used for NWP
 """
+
 import numpy as np
 import pandas as pd
 import xarray as xr
@@ -13,22 +14,23 @@ def draw_dropout_time(
     dropout_timedeltas: list[pd.Timedelta],
     dropout_frac: float,
 ) -> pd.Timestamp:
-    """Randomly pick a dropout time from a list of timedeltas
-
+    """Randomly pick a dropout time from a list of timedeltas.
+
     Args:
         t0: The forecast init-time
         dropout_timedeltas: List of timedeltas relative to t0 to pick from
-        dropout_frac: Probability that dropout will be applied.
-            inclusive
+        dropout_frac: Probability that dropout will be applied.
+            This should be between 0 and 1 inclusive
     """
-
-
-    assert len(dropout_timedeltas) > 0, "To apply dropout dropout_timedeltas must be provided"
+    if dropout_frac > 0 and len(dropout_timedeltas) == 0:
+        raise ValueError("To apply dropout, dropout_timedeltas must be provided")
 
     for t in dropout_timedeltas:
-
+        if t > pd.Timedelta("0min"):
+            raise ValueError("Dropout timedeltas must be negative")
 
-
+    if not (0 <= dropout_frac <= 1):
+        raise ValueError("dropout_frac must be between 0 and 1 inclusive")
 
     if (len(dropout_timedeltas) == 0) or (np.random.uniform() >= dropout_frac):
         dropout_time = t0
@@ -41,11 +43,11 @@ def draw_dropout_time(
 def apply_dropout_time(
     ds: xr.DataArray,
     dropout_time: pd.Timestamp,
-
-    """Apply dropout time to the data
-
+) -> xr.DataArray:
+    """Apply dropout time to the data.
+
     Args:
-        ds: Xarray DataArray with 'time_utc'
+        ds: Xarray DataArray with 'time_utc' coordinate
         dropout_time: Time after which data is set to NaN
     """
     # This replaces the times after the dropout with NaNs
```
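The dropout helpers swap their `assert` checks for explicit `ValueError`s: timedeltas must be negative and `dropout_frac` must lie within [0, 1]. A small sketch of `draw_dropout_time` with made-up times:

```python
import pandas as pd

from ocf_data_sampler.select.dropout import draw_dropout_time

t0 = pd.Timestamp("2024-01-01 12:00")

# Timedeltas must be negative and dropout_frac within [0, 1];
# invalid inputs now raise ValueError rather than tripping an assert
dropout_time = draw_dropout_time(
    t0=t0,
    dropout_timedeltas=[pd.Timedelta("-30min"), pd.Timedelta("-60min")],
    dropout_frac=0.5,
)
# With probability 0.5 this is t0 itself; otherwise t0 minus one of the timedeltas
print(dropout_time)
```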
ocf_data_sampler/select/fill_time_periods.py
CHANGED
```diff
@@ -1,13 +1,23 @@
-"""Fill time periods between start and end dates
+"""Fill time periods between specified start and end dates."""
 
-import pandas as pd
 import numpy as np
+import pandas as pd
 
 
 def fill_time_periods(time_periods: pd.DataFrame, freq: pd.Timedelta) -> pd.DatetimeIndex:
-    """
-
+    """Create range of timestamps between given start and end times.
+
+    Each of the continuous periods (i.e. each row of the input DataFrame) is filled with the
+    specified frequency.
+
+    Args:
+        time_periods: DataFrame with columns 'start_dt' and 'end_dt'
+        freq: Frequency to fill time periods with
+    """
     start_dts = pd.to_datetime(time_periods["start_dt"].values).ceil(freq)
     end_dts = pd.to_datetime(time_periods["end_dt"].values)
-    date_ranges = [
+    date_ranges = [
+        pd.date_range(start_dt, end_dt, freq=freq)
+        for start_dt, end_dt in zip(start_dts, end_dts, strict=False)
+    ]
     return pd.DatetimeIndex(np.concatenate(date_ranges))
```
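`fill_time_periods` gains a proper docstring and a multi-line comprehension, but its behaviour is unchanged: each `start_dt` is ceiled to the frequency and each row is expanded into a date range. A worked sketch with arbitrary dates:

```python
import pandas as pd

from ocf_data_sampler.select.fill_time_periods import fill_time_periods

periods = pd.DataFrame(
    {
        "start_dt": [pd.Timestamp("2024-01-01 00:00"), pd.Timestamp("2024-01-02 06:00")],
        "end_dt": [pd.Timestamp("2024-01-01 02:00"), pd.Timestamp("2024-01-02 07:00")],
    },
)

timestamps = fill_time_periods(periods, freq=pd.Timedelta("30min"))
# -> DatetimeIndex: 00:00, 00:30, ..., 02:00, then 06:00, 06:30, 07:00
```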
ocf_data_sampler/select/find_contiguous_time_periods.py
CHANGED
```diff
@@ -1,9 +1,12 @@
-"""Get contiguous time periods
+"""Get contiguous time periods."""
 
 import numpy as np
 import pandas as pd
+
 from ocf_data_sampler.load.utils import check_time_unique_increasing
 
+ZERO_TDELTA = pd.Timedelta(0)
+
 
 def find_contiguous_time_periods(
     datetimes: pd.DatetimeIndex,
@@ -14,20 +17,20 @@ def find_contiguous_time_periods(
 
     Args:
         datetimes: pd.DatetimeIndex. Must be sorted.
-        min_seq_length: Sequences of min_seq_length or shorter will be discarded.
-            would be set to the `total_seq_length` of each machine learning example.
+        min_seq_length: Sequences of min_seq_length or shorter will be discarded.
         max_gap_duration: If any pair of consecutive `datetimes` is more than `max_gap_duration`
             apart, then this pair of `datetimes` will be considered a "gap" between two contiguous
-            sequences.
-            the timeseries.
+            sequences.
 
     Returns:
-        pd.DataFrame where each row represents a single time period.
-
+        pd.DataFrame where each row represents a single time period. The pd.DataFrame
+        has two columns: `start_dt` and `end_dt` (where 'dt' is short for 'datetime').
     """
     # Sanity checks.
-
-
+    if len(datetimes) == 0:
+        raise ValueError("No datetimes to use")
+    if min_seq_length <= 1:
+        raise ValueError(f"{min_seq_length=} must be greater than 1")
     check_time_unique_increasing(datetimes)
 
     # Find indices of gaps larger than max_gap:
@@ -43,77 +46,75 @@ def find_contiguous_time_periods(
     # Capture the last segment of dt_index.
     segment_boundaries = np.concatenate((segment_boundaries, [len(datetimes)]))
 
-    periods: list[
+    periods: list[list[pd.Timestamp]] = []
     start_i = 0
     for next_start_i in segment_boundaries:
         n_timesteps = next_start_i - start_i
         if n_timesteps > min_seq_length:
             end_i = next_start_i - 1
-
-            periods.append(period)
+            periods.append([datetimes[start_i], datetimes[end_i]])
         start_i = next_start_i
 
-
-
-
+    if len(periods) == 0:
+        raise ValueError(
+            f"Did not find any periods from {datetimes}. {min_seq_length=} {max_gap_duration=}",
+        )
 
-    return pd.DataFrame(periods)
+    return pd.DataFrame(periods, columns=["start_dt", "end_dt"])
 
 
 def trim_contiguous_time_periods(
-    contiguous_time_periods: pd.DataFrame,
+    contiguous_time_periods: pd.DataFrame,
     interval_start: pd.Timedelta,
     interval_end: pd.Timedelta,
 ) -> pd.DataFrame:
-    """
+    """Trims contiguous time periods to account for history requirements and forecast horizons.
 
     Args:
-        contiguous_time_periods: DataFrame where each row represents a single time period.
-            DataFrame must have `start_dt` and `end_dt` columns.
+        contiguous_time_periods: pd.DataFrame where each row represents a single time period.
+            The pd.DataFrame must have `start_dt` and `end_dt` columns.
         interval_start: The start of the interval with respect to t0
         interval_end: The end of the interval with respect to t0
 
-
     Returns:
-        The contiguous_time_periods DataFrame with the `start_dt` and `end_dt` columns updated.
+        The contiguous_time_periods pd.DataFrame with the `start_dt` and `end_dt` columns updated.
     """
-
-
-
-
+    # Make a copy so the data is not edited in place.
+    trimmed_time_periods = contiguous_time_periods.copy()
+    trimmed_time_periods["start_dt"] -= interval_start
+    trimmed_time_periods["end_dt"] -= interval_end
 
-    valid_mask =
-    contiguous_time_periods = contiguous_time_periods.loc[valid_mask]
-
-    return contiguous_time_periods
+    valid_mask = trimmed_time_periods["start_dt"] <= trimmed_time_periods["end_dt"]
 
+    return trimmed_time_periods.loc[valid_mask]
 
 
 def find_contiguous_t0_periods(
-
-
-
-
-
+    datetimes: pd.DatetimeIndex,
+    interval_start: pd.Timedelta,
+    interval_end: pd.Timedelta,
+    time_resolution: pd.Timedelta,
+) -> pd.DataFrame:
     """Return a pd.DataFrame where each row records the boundary of a contiguous time period.
 
     Args:
-        datetimes: pd.DatetimeIndex
+        datetimes: pd.DatetimeIndex
         interval_start: The start of the interval with respect to t0
         interval_end: The end of the interval with respect to t0
-
-
+        time_resolution: The sample frequency of the timeseries
 
     Returns:
         pd.DataFrame where each row represents a single time period. The pd.DataFrame
         has two columns: `start_dt` and `end_dt` (where 'dt' is short for 'datetime').
     """
+    check_time_unique_increasing(datetimes)
+
     total_duration = interval_end - interval_start
-
+
     contiguous_time_periods = find_contiguous_time_periods(
         datetimes=datetimes,
-        min_seq_length=int(total_duration /
-        max_gap_duration=
+        min_seq_length=int(total_duration / time_resolution) + 1,
+        max_gap_duration=time_resolution,
     )
 
     contiguous_t0_periods = trim_contiguous_time_periods(
@@ -122,7 +123,11 @@ def find_contiguous_t0_periods(
         interval_end=interval_end,
     )
 
-
+    if len(contiguous_t0_periods) == 0:
+        raise ValueError(
+            f"No contiguous time periods found for {datetimes}. "
+            f"{interval_start=} {interval_end=} {time_resolution=}",
+        )
 
     return contiguous_t0_periods
 
@@ -131,54 +136,59 @@ def find_contiguous_t0_periods_nwp(
     init_times: pd.DatetimeIndex,
     interval_start: pd.Timedelta,
     max_staleness: pd.Timedelta,
-    max_dropout: pd.Timedelta =
-    first_forecast_step: pd.Timedelta =
-
+    max_dropout: pd.Timedelta = ZERO_TDELTA,
+    first_forecast_step: pd.Timedelta = ZERO_TDELTA,
 ) -> pd.DataFrame:
-    """Get all time periods from the NWP init
+    """Get all time periods from the NWP init-times which are valid as t0 datetimes.
 
     Args:
         init_times: The initialisation times of the available forecasts
-        interval_start: The start of the
-        max_staleness: Up to how long after an init
-            init
-
-
-
-
+        interval_start: The start of the time interval with respect to t0
+        max_staleness: Up to how long after an init-time are we willing to use the forecast.
+            Each init-time will only be used up to this t0 time regardless of the forecast valid
+            time.
+        max_dropout: What is the maximum amount of dropout that will be used.
+            This must be <= max_staleness.
+        first_forecast_step: The timedelta of the first step of the forecast.
+            By default we assume the first valid time of the forecast
+            is the same as its init-time.
 
     Returns:
-        pd.DataFrame where each row represents a single time period.
+        pd.DataFrame where each row represents a single time period. The pd.DataFrame
         has two columns: `start_dt` and `end_dt` (where 'dt' is short for 'datetime').
     """
     # Sanity checks.
-
-
-
-
-
+    if len(init_times) == 0:
+        raise ValueError("No init-times to use")
+
+    check_time_unique_increasing(init_times)
+
+    if max_staleness < pd.Timedelta(0):
+        raise ValueError("The max staleness must be positive")
+    if not (pd.Timedelta(0) <= max_dropout <= max_staleness):
+        raise ValueError("The max dropout must be between 0 and the max staleness")
 
-
+    history_drop_buffer = max(first_forecast_step - interval_start, max_dropout)
 
     # Store contiguous periods
-    contiguous_periods = []
+    contiguous_periods: list[list[pd.Timestamp]] = []
 
-    # Begin the first period allowing for the time to the first_forecast_step, the length of the
+    # Begin the first period allowing for the time to the first_forecast_step, the length of the
     # interval sampled from before t0, and the dropout
-    start_this_period = init_times[0] +
+    start_this_period = init_times[0] + history_drop_buffer
 
     # The first forecast is valid up to the max staleness
     end_this_period = init_times[0] + max_staleness
 
     for dt_init in init_times[1:]:
-        # If the previous init
-        # considering dropout) then the contiguous period breaks
-        # Else if the previous init
+        # If the previous init-time becomes stale before the next init-time becomes valid (whilst
+        # also considering dropout) then the contiguous period breaks
+        # Else if the previous init-time becomes stale before the fist step of the next forecast
         # then this also causes a break in the contiguous period
-        if
+        if end_this_period < dt_init + max(max_dropout, first_forecast_step):
             contiguous_periods.append([start_this_period, end_this_period])
             # The new period begins with the same conditions as the first period
-            start_this_period = dt_init +
+            start_this_period = dt_init + history_drop_buffer
         end_this_period = dt_init + max_staleness
 
     contiguous_periods.append([start_this_period, end_this_period])
@@ -189,11 +199,13 @@ def find_contiguous_t0_periods_nwp(
 def intersection_of_multiple_dataframes_of_periods(
     time_periods: list[pd.DataFrame],
 ) -> pd.DataFrame:
-    """Find the intersection of
+    """Find the intersection of list of time periods.
 
-
+    Consecutively updates intersection of time periods.
+    See the docstring of intersection_of_2_dataframes_of_periods() for further details.
     """
-
+    if len(time_periods) == 0:
+        raise ValueError("No time periods to intersect")
     intersection = time_periods[0]
     for time_period in time_periods[1:]:
         intersection = intersection_of_2_dataframes_of_periods(intersection, time_period)
@@ -209,7 +221,8 @@ def intersection_of_2_dataframes_of_periods(a: pd.DataFrame, b: pd.DataFrame) -> pd.DataFrame:
     A typical use-case is that each pd.DataFrame represents all the time periods where
     a `DataSource` has contiguous, valid data.
 
-
+    Graphical representation of two pd.DataFrames of time periods and their intersection,
+    as follows:
 
     ----------------------> TIME ->---------------------
     a: |-----| |----| |----------| |-----------|
@@ -217,9 +230,9 @@ def intersection_of_2_dataframes_of_periods(a: pd.DataFrame, b: pd.DataFrame) -> pd.DataFrame:
     intersection: |--| |-| |--| |---|
 
     Args:
-        a: pd.DataFrame where each row represents a time period.
+        a: pd.DataFrame where each row represents a time period. The pd.DataFrame has
             two columns: start_dt and end_dt.
-        b: pd.DataFrame where each row represents a time period.
+        b: pd.DataFrame where each row represents a time period. The pd.DataFrame has
             two columns: start_dt and end_dt.
 
     Returns:
@@ -238,7 +251,7 @@ def intersection_of_2_dataframes_of_periods(a: pd.DataFrame, b: pd.DataFrame) -> pd.DataFrame:
     # and `a` must always end after `b` starts:
 
     # TODO: <= and >= because we should allow overlap time periods of length 1. e.g.
-    # a: |----| or |---|
+    # a: |----| or |---|
     # b: |--| |---|
     # These aren't allowed if we use < and >.
 
```
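The refactor here replaces asserts with `ValueError`s, names the `ZERO_TDELTA` default, and threads an explicit `time_resolution` argument through `find_contiguous_t0_periods`. A sketch of the t0-period workflow with invented inputs; the gap in the index below deliberately splits the data into two periods:

```python
import pandas as pd

from ocf_data_sampler.select.find_contiguous_time_periods import (
    find_contiguous_t0_periods,
    intersection_of_multiple_dataframes_of_periods,
)

# Two half-hourly runs separated by a 12-hour gap
datetimes = pd.DatetimeIndex(
    list(pd.date_range("2024-01-01 00:00", "2024-01-01 12:00", freq="30min"))
    + list(pd.date_range("2024-01-02 00:00", "2024-01-02 12:00", freq="30min")),
)

periods = find_contiguous_t0_periods(
    datetimes=datetimes,
    interval_start=pd.Timedelta("-1h"),  # one hour of history before t0
    interval_end=pd.Timedelta("2h"),     # two hours of forecast after t0
    time_resolution=pd.Timedelta("30min"),
)
# Two rows: one valid t0 window per contiguous run; raises ValueError if none remain
valid = intersection_of_multiple_dataframes_of_periods([periods, periods])
print(valid)
```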