ocf-data-sampler 0.0.26__py3-none-any.whl → 0.0.28__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ocf-data-sampler might be problematic.
- ocf_data_sampler/config/model.py +46 -46
- ocf_data_sampler/select/find_contiguous_time_periods.py +40 -75
- ocf_data_sampler/select/select_time_slice.py +24 -33
- ocf_data_sampler/select/time_slice_for_dataset.py +17 -16
- ocf_data_sampler/torch_datasets/process_and_combine.py +12 -13
- ocf_data_sampler/torch_datasets/pvnet_uk_regional.py +1 -1
- ocf_data_sampler/torch_datasets/site.py +10 -10
- ocf_data_sampler/torch_datasets/valid_time_periods.py +19 -11
- ocf_data_sampler/{time_functions.py → utils.py} +1 -2
- {ocf_data_sampler-0.0.26.dist-info → ocf_data_sampler-0.0.28.dist-info}/METADATA +1 -1
- {ocf_data_sampler-0.0.26.dist-info → ocf_data_sampler-0.0.28.dist-info}/RECORD +18 -18
- {ocf_data_sampler-0.0.26.dist-info → ocf_data_sampler-0.0.28.dist-info}/WHEEL +1 -1
- tests/config/test_config.py +14 -8
- tests/conftest.py +7 -5
- tests/select/test_find_contiguous_time_periods.py +8 -8
- tests/select/test_select_time_slice.py +31 -43
- {ocf_data_sampler-0.0.26.dist-info → ocf_data_sampler-0.0.28.dist-info}/LICENSE +0 -0
- {ocf_data_sampler-0.0.26.dist-info → ocf_data_sampler-0.0.28.dist-info}/top_level.txt +0 -0
ocf_data_sampler/config/model.py
CHANGED
@@ -14,7 +14,8 @@ import logging
 from typing import Dict, List, Optional
 from typing_extensions import Self

-from pydantic import BaseModel, Field, RootModel, field_validator, model_validator
+from pydantic import BaseModel, Field, RootModel, field_validator, ValidationInfo, model_validator
+
 from ocf_data_sampler.constants import NWP_PROVIDERS

 logger = logging.getLogger(__name__)
@@ -40,6 +41,45 @@ class General(Base):
     )


+class TimeWindowMixin(Base):
+    """Mixin class, to add interval start, end and resolution minutes"""
+
+    time_resolution_minutes: int = Field(
+        ...,
+        gt=0,
+        description="The temporal resolution of the data in minutes",
+    )
+
+    interval_start_minutes: int = Field(
+        ...,
+        description="Data interval starts at `t0 + interval_start_minutes`",
+    )
+
+    interval_end_minutes: int = Field(
+        ...,
+        description="Data interval ends at `t0 + interval_end_minutes`",
+    )
+
+    @model_validator(mode='after')
+    def check_interval_range(cls, values):
+        if values.interval_start_minutes > values.interval_end_minutes:
+            raise ValueError('interval_start_minutes must be <= interval_end_minutes')
+        return values
+
+    @field_validator("interval_start_minutes")
+    def interval_start_minutes_divide_by_time_resolution(cls, v: int, info: ValidationInfo) -> int:
+        if v % info.data["time_resolution_minutes"] != 0:
+            raise ValueError("interval_start_minutes must be divisible by time_resolution_minutes")
+        return v
+
+    @field_validator("interval_end_minutes")
+    def interval_end_minutes_divide_by_time_resolution(cls, v: int, info: ValidationInfo) -> int:
+        if v % info.data["time_resolution_minutes"] != 0:
+            raise ValueError("interval_end_minutes must be divisible by time_resolution_minutes")
+        return v
+
+
+
 # noinspection PyMethodParameters
 class DropoutMixin(Base):
     """Mixin class, to add dropout minutes"""
@@ -76,54 +116,18 @@ class DropoutMixin(Base):
         return self


-# noinspection PyMethodParameters
-class TimeWindowMixin(Base):
-    """Time resolution mix in"""
-
-    time_resolution_minutes: int = Field(
-        ...,
-        gt=0,
-        description="The temporal resolution of the data in minutes",
-    )
-
-    forecast_minutes: int = Field(
-        ...,
-        ge=0,
-        description="how many minutes to forecast in the future",
-    )
-    history_minutes: int = Field(
-        ...,
-        ge=0,
-        description="how many historic minutes to use",
-    )
-
-    @field_validator("forecast_minutes")
-    def forecast_minutes_divide_by_time_resolution(cls, v, values) -> int:
-        if v % values.data["time_resolution_minutes"] != 0:
-            message = "Forecast duration must be divisible by time resolution"
-            logger.error(message)
-            raise Exception(message)
-        return v
-
-    @field_validator("history_minutes")
-    def history_minutes_divide_by_time_resolution(cls, v, values) -> int:
-        if v % values.data["time_resolution_minutes"] != 0:
-            message = "History duration must be divisible by time resolution"
-            logger.error(message)
-            raise Exception(message)
-        return v
-
-
 class SpatialWindowMixin(Base):
     """Mixin class, to add path and image size"""

     image_size_pixels_height: int = Field(
         ...,
+        ge=0,
         description="The number of pixels of the height of the region of interest",
     )

     image_size_pixels_width: int = Field(
         ...,
+        ge=0,
         description="The number of pixels of the width of the region of interest",
     )

@@ -140,10 +144,6 @@ class Satellite(TimeWindowMixin, DropoutMixin, SpatialWindowMixin):
         ..., description="the satellite channels that are used"
     )

-    live_delay_minutes: int = Field(
-        ..., description="The expected delay in minutes of the satellite data"
-    )
-

 # noinspection PyMethodParameters
 class NWP(TimeWindowMixin, DropoutMixin, SpatialWindowMixin):
@@ -169,6 +169,7 @@ class NWP(TimeWindowMixin, DropoutMixin, SpatialWindowMixin):
         " the maximum forecast horizon of the NWP and the requested forecast length.",
     )

+
     @field_validator("provider")
     def validate_provider(cls, v: str) -> str:
         """Validate 'provider'"""
@@ -227,11 +228,10 @@ class Site(TimeWindowMixin, DropoutMixin):
     # TODO validate the csv for metadata


+
 # noinspection PyPep8Naming
 class InputData(Base):
-    """
-    Input data model.
-    """
+    """Input data model"""

     satellite: Optional[Satellite] = None
     nwp: Optional[MultiNWP] = None
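Taken together, the model.py changes replace the old history_minutes/forecast_minutes pair with a signed window around t0: interval_start_minutes (typically negative) and interval_end_minutes, both required to be multiples of time_resolution_minutes, while Satellite loses its live_delay_minutes field. A minimal standalone sketch of the new validation behaviour (it reproduces the mixin above so it runs without the package; assumes pydantic v2, and the example values are illustrative):

from pydantic import BaseModel, Field, ValidationInfo, field_validator, model_validator

class TimeWindowMixin(BaseModel):
    # Same fields and rules as the class added in the diff above
    time_resolution_minutes: int = Field(..., gt=0)
    interval_start_minutes: int
    interval_end_minutes: int

    @model_validator(mode="after")
    def check_interval_range(self):
        if self.interval_start_minutes > self.interval_end_minutes:
            raise ValueError("interval_start_minutes must be <= interval_end_minutes")
        return self

    @field_validator("interval_start_minutes", "interval_end_minutes")
    def divisible_by_time_resolution(cls, v: int, info: ValidationInfo) -> int:
        # time_resolution_minutes is declared first, so it is already validated and in info.data
        if v % info.data["time_resolution_minutes"] != 0:
            raise ValueError("interval minutes must be divisible by time_resolution_minutes")
        return v

TimeWindowMixin(time_resolution_minutes=30, interval_start_minutes=-120, interval_end_minutes=480)  # ok
TimeWindowMixin(time_resolution_minutes=30, interval_start_minutes=-45, interval_end_minutes=480)   # ValidationError

A config that previously set history_minutes=120 and forecast_minutes=480 would now set interval_start_minutes=-120 and interval_end_minutes=480.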
ocf_data_sampler/select/find_contiguous_time_periods.py
CHANGED
@@ -63,16 +63,16 @@ def find_contiguous_time_periods(

 def trim_contiguous_time_periods(
     contiguous_time_periods: pd.DataFrame,
-    history_duration: pd.Timedelta,
-    forecast_duration: pd.Timedelta,
+    interval_start: pd.Timedelta,
+    interval_end: pd.Timedelta,
 ) -> pd.DataFrame:
     """Trim the contiguous time periods to allow for history and forecast durations.

     Args:
         contiguous_time_periods: DataFrame where each row represents a single time period. The
             DataFrame must have `start_dt` and `end_dt` columns.
-        history_duration: Length of the historical slice used for a sample
-        forecast_duration: Length of the forecast slice used for a sample
+        interval_start: The start of the interval with respect to t0
+        interval_end: The end of the interval with respect to t0


     Returns:
@@ -80,8 +80,8 @@ def trim_contiguous_time_periods(
     """
     contiguous_time_periods = contiguous_time_periods.copy()

-    contiguous_time_periods["start_dt"] += history_duration
-    contiguous_time_periods["end_dt"] -= forecast_duration
+    contiguous_time_periods["start_dt"] -= interval_start
+    contiguous_time_periods["end_dt"] -= interval_end

     valid_mask = contiguous_time_periods["start_dt"] <= contiguous_time_periods["end_dt"]
     contiguous_time_periods = contiguous_time_periods.loc[valid_mask]
@@ -92,16 +92,16 @@ def trim_contiguous_time_periods(

 def find_contiguous_t0_periods(
     datetimes: pd.DatetimeIndex,
-    history_duration: pd.Timedelta,
-    forecast_duration: pd.Timedelta,
+    interval_start: pd.Timedelta,
+    interval_end: pd.Timedelta,
     sample_period_duration: pd.Timedelta,
 ) -> pd.DataFrame:
     """Return a pd.DataFrame where each row records the boundary of a contiguous time period.

     Args:
         datetimes: pd.DatetimeIndex. Must be sorted.
-        history_duration: Length of the historical slice used for a sample
-        forecast_duration: Length of the forecast slice used for a sample
+        interval_start: The start of the interval with respect to t0
+        interval_end: The end of the interval with respect to t0
         sample_period_duration: The sample frequency of the timeseries


@@ -109,7 +109,7 @@ def find_contiguous_t0_periods(
     pd.DataFrame where each row represents a single time period. The pd.DataFrame
     has two columns: `start_dt` and `end_dt` (where 'dt' is short for 'datetime').
     """
-    total_duration = history_duration + forecast_duration
+    total_duration = interval_end - interval_start

     contiguous_time_periods = find_contiguous_time_periods(
         datetimes=datetimes,
@@ -119,8 +119,8 @@ def find_contiguous_t0_periods(

     contiguous_t0_periods = trim_contiguous_time_periods(
         contiguous_time_periods=contiguous_time_periods,
-        history_duration=history_duration,
-        forecast_duration=forecast_duration,
+        interval_start=interval_start,
+        interval_end=interval_end,
     )

     assert len(contiguous_t0_periods) > 0
@@ -128,92 +128,57 @@ def find_contiguous_t0_periods(
     return contiguous_t0_periods


-def _find_contiguous_t0_periods_nwp(
-    ds,
-    history_duration: pd.Timedelta,
-    forecast_duration: pd.Timedelta,
-    max_staleness: pd.Timedelta | None = None,
-    max_dropout: pd.Timedelta = pd.Timedelta(0),
-    time_dim: str = "init_time_utc",
-    end_buffer: pd.Timedelta = pd.Timedelta(0),
-):
-
-    assert "step" in ds.coords
-    # It is possible to use up to this amount of max staleness for the dataset and slice
-    # required
-    possible_max_staleness = (
-        pd.Timedelta(ds["step"].max().item())
-        - forecast_duration
-        - end_buffer
-    )
-
-    # If max_staleness is set to None we set it based on the max step ahead of the input
-    # forecast data
-    if max_staleness is None:
-        max_staleness = possible_max_staleness
-    else:
-        # Make sure the max acceptable staleness isn't longer than the max possible
-        assert max_staleness <= possible_max_staleness
-        max_staleness = max_staleness
-
-    contiguous_time_periods = find_contiguous_t0_periods_nwp(
-        datetimes=pd.DatetimeIndex(ds[time_dim]),
-        history_duration=history_duration,
-        max_staleness=max_staleness,
-        max_dropout=max_dropout,
-    )
-    return contiguous_time_periods
-
-
-
 def find_contiguous_t0_periods_nwp(
-    datetimes: pd.DatetimeIndex,
-    history_duration: pd.Timedelta,
+    init_times: pd.DatetimeIndex,
+    interval_start: pd.Timedelta,
     max_staleness: pd.Timedelta,
     max_dropout: pd.Timedelta = pd.Timedelta(0),
+    first_forecast_step: pd.Timedelta = pd.Timedelta(0),
+
 ) -> pd.DataFrame:
     """Get all time periods from the NWP init times which are valid as t0 datetimes.

     Args:
-        datetimes: Sorted pd.DatetimeIndex of the forecast init times
-        history_duration: Length of the historical slice used for a sample
-        max_staleness: Up to how long after an init time are we willing to use the forecast. Each
-            init time will only be used up to this t0 time regardless of the forecast
-            valid time.
+        init_times: The initialisation times of the available forecasts
+        interval_start: The start of the desired data interval with respect to t0
+        max_staleness: Up to how long after an init time are we willing to use the forecast. Each
+            init time will only be used up to this t0 time regardless of the forecast valid time.
         max_dropout: What is the maximum amount of dropout that will be used. This must be <=
            max_staleness.
+        first_forecast_step: The timedelta of the first step of the forecast. By default we assume
+            the first valid time of the forecast is the same as its init time.

     Returns:
         pd.DataFrame where each row represents a single time period. The pd.DataFrame
         has two columns: `start_dt` and `end_dt` (where 'dt' is short for 'datetime').
     """
     # Sanity checks.
-    assert len(datetimes) > 0
-    assert datetimes.is_monotonic_increasing
-    assert datetimes.is_unique
-    assert history_duration >= pd.Timedelta(0)
+    assert len(init_times) > 0
+    assert init_times.is_monotonic_increasing
+    assert init_times.is_unique
     assert max_staleness >= pd.Timedelta(0)
-    assert max_dropout <= max_staleness
+    assert pd.Timedelta(0) <= max_dropout <= max_staleness

-    hist_drop_buffer = max(history_duration, max_dropout)
+    hist_drop_buffer = max(first_forecast_step-interval_start, max_dropout)

     # Store contiguous periods
     contiguous_periods = []

-    # Begin the first period allowing for the required history duration and the dropout
-    start_this_period = datetimes[0] + hist_drop_buffer
+    # Begin the first period allowing for the time to the first_forecast_step, the length of the
+    # interval sampled from before t0, and the dropout
+    start_this_period = init_times[0] + hist_drop_buffer

     # The first forecast is valid up to the max staleness
-    end_this_period = datetimes[0] + max_staleness
-
-    for dt_init in datetimes[1:]:
-        # If the previous init time becomes stale before the next init becomes valid whilst also
-        # considering dropout
-        # then the contiguous period breaks
-        if end_this_period < dt_init + max_dropout:
+    end_this_period = init_times[0] + max_staleness
+
+    for dt_init in init_times[1:]:
+        # If the previous init time becomes stale before the next init becomes valid (whilst also
+        # considering dropout) then the contiguous period breaks
+        # Else if the previous init time becomes stale before the fist step of the next forecast
+        # then this also causes a break in the contiguous period
+        if (end_this_period < dt_init + max(max_dropout, first_forecast_step)):
             contiguous_periods.append([start_this_period, end_this_period])
-
-            # And start a new period
+            # The new period begins with the same conditions as the first period
             start_this_period = dt_init + hist_drop_buffer
             end_this_period = dt_init + max_staleness

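Under the new convention, any t0 inside a returned period guarantees that the slice [t0 + interval_start, t0 + interval_end] is covered by a single forecast no staler than max_staleness, with first_forecast_step handling NWPs whose first valid time is after the init time. A usage sketch against the new signature (the init times and thresholds are illustrative):

import pandas as pd
from ocf_data_sampler.select.find_contiguous_time_periods import find_contiguous_t0_periods_nwp

# 3-hourly init times with the 06:00 run missing
init_times = pd.DatetimeIndex(
    ["2024-01-01 00:00", "2024-01-01 03:00", "2024-01-01 09:00", "2024-01-01 12:00"]
)

periods = find_contiguous_t0_periods_nwp(
    init_times=init_times,
    interval_start=pd.Timedelta("-2h"),  # two hours of history before each t0
    max_staleness=pd.Timedelta("6h"),    # a forecast may be used up to 6h after its init
    max_dropout=pd.Timedelta("1h"),
)
# hist_drop_buffer = max(0 - (-2h), 1h) = 2h, so the first period starts at 02:00;
# the missing 06:00 run ends it at 09:00, and a second period spans 11:00-18:00
print(periods)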
ocf_data_sampler/select/select_time_slice.py
CHANGED
@@ -39,23 +39,14 @@ def _sel_fillinterp(
 def select_time_slice(
     ds: xr.DataArray,
     t0: pd.Timestamp,
+    interval_start: pd.Timedelta,
+    interval_end: pd.Timedelta,
     sample_period_duration: pd.Timedelta,
-    history_duration: pd.Timedelta | None = None,
-    forecast_duration: pd.Timedelta | None = None,
-    interval_start: pd.Timedelta | None = None,
-    interval_end: pd.Timedelta | None = None,
     fill_selection: bool = False,
     max_steps_gap: int = 0,
 ):
     """Select a time slice from a Dataset or DataArray."""
-    used_duration = history_duration is not None and forecast_duration is not None
-    used_intervals = interval_start is not None and interval_end is not None
-    assert used_duration ^ used_intervals, "Either durations, or intervals must be supplied"
     assert max_steps_gap >= 0, "max_steps_gap must be >= 0 "
-
-    if used_duration:
-        interval_start = - history_duration
-        interval_end = forecast_duration

     if fill_selection and max_steps_gap == 0:
         _sel = _sel_fillnan
@@ -75,11 +66,11 @@ def select_time_slice(


 def select_time_slice_nwp(
-    ds: xr.DataArray,
+    da: xr.DataArray,
     t0: pd.Timestamp,
+    interval_start: pd.Timedelta,
+    interval_end: pd.Timedelta,
     sample_period_duration: pd.Timedelta,
-    history_duration: pd.Timedelta,
-    forecast_duration: pd.Timedelta,
     dropout_timedeltas: list[pd.Timedelta] | None = None,
     dropout_frac: float | None = 0,
     accum_channels: list[str] = [],
@@ -92,31 +83,31 @@ def select_time_slice_nwp(
         ), "dropout timedeltas must be negative"
         assert len(dropout_timedeltas) >= 1
     assert 0 <= dropout_frac <= 1
-
+    consider_dropout = (dropout_timedeltas is not None) and dropout_frac > 0


     # The accumatation and non-accumulation channels
     accum_channels = np.intersect1d(
-        ds[channel_dim_name].values, accum_channels
+        da[channel_dim_name].values, accum_channels
     )
     non_accum_channels = np.setdiff1d(
-        ds[channel_dim_name].values, accum_channels
+        da[channel_dim_name].values, accum_channels
     )

-    start_dt = (t0 - history_duration).ceil(sample_period_duration)
-    end_dt = (t0 + forecast_duration).ceil(sample_period_duration)
+    start_dt = (t0 + interval_start).ceil(sample_period_duration)
+    end_dt = (t0 + interval_end).ceil(sample_period_duration)

     target_times = pd.date_range(start_dt, end_dt, freq=sample_period_duration)

     # Maybe apply NWP dropout
-    if (dropout_timedeltas is not None) and (np.random.uniform() < dropout_frac):
+    if consider_dropout and (np.random.uniform() < dropout_frac):
         dt = np.random.choice(dropout_timedeltas)
         t0_available = t0 + dt
     else:
         t0_available = t0

     # Forecasts made up to and including t0
-    available_init_times = ds.init_time_utc.sel(
+    available_init_times = da.init_time_utc.sel(
         init_time_utc=slice(None, t0_available)
     )
@@ -139,7 +130,7 @@ def select_time_slice_nwp(
     step_indexer = xr.DataArray(steps, coords=coords)

     if len(accum_channels) == 0:
-        ds_sel = ds.sel(step=step_indexer, init_time_utc=init_time_indexer)
+        da_sel = da.sel(step=step_indexer, init_time_utc=init_time_indexer)

     else:
         # First minimise the size of the dataset we are diffing
@@ -149,7 +140,7 @@ def select_time_slice_nwp(
         min_step = min(steps)
         max_step = max(steps) + sample_period_duration

-        ds_min = ds.sel(
+        da_min = da.sel(
             {
                 "init_time_utc": unique_init_times,
                 "step": slice(min_step, max_step),
@@ -157,28 +148,28 @@ def select_time_slice_nwp(
         )

         # Slice out the data which does not need to be diffed
-        ds_non_accum = ds_min.sel({channel_dim_name: non_accum_channels})
-        ds_sel_non_accum = ds_non_accum.sel(
+        da_non_accum = da_min.sel({channel_dim_name: non_accum_channels})
+        da_sel_non_accum = da_non_accum.sel(
             step=step_indexer, init_time_utc=init_time_indexer
         )

         # Slice out the channels which need to be diffed
-        ds_accum = ds_min.sel({channel_dim_name: accum_channels})
+        da_accum = da_min.sel({channel_dim_name: accum_channels})

         # Take the diff and slice requested data
-        ds_accum = ds_accum.diff(dim="step", label="lower")
-        ds_sel_accum = ds_accum.sel(step=step_indexer, init_time_utc=init_time_indexer)
+        da_accum = da_accum.diff(dim="step", label="lower")
+        da_sel_accum = da_accum.sel(step=step_indexer, init_time_utc=init_time_indexer)

         # Join diffed and non-diffed variables
-        ds_sel = xr.concat([ds_sel_non_accum, ds_sel_accum], dim=channel_dim_name)
+        da_sel = xr.concat([da_sel_non_accum, da_sel_accum], dim=channel_dim_name)

         # Reorder the variable back to the original order
-        ds_sel = ds_sel.sel({channel_dim_name: ds[channel_dim_name].values})
+        da_sel = da_sel.sel({channel_dim_name: da[channel_dim_name].values})

         # Rename the diffed channels
-        ds_sel[channel_dim_name] = [
+        da_sel[channel_dim_name] = [
             f"diff_{v}" if v in accum_channels else v
-            for v in ds_sel[channel_dim_name].values
+            for v in da_sel[channel_dim_name].values
         ]

-    return ds_sel
+    return da_sel
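select_time_slice and select_time_slice_nwp now take the signed interval directly, so the returned window is always [t0 + interval_start, t0 + interval_end] and the old dual duration/interval pathway is gone. A sketch of the new call shape on synthetic satellite-like data (the DataArray is a stand-in; real callers pass loaded satellite data):

import numpy as np
import pandas as pd
import xarray as xr
from ocf_data_sampler.select.select_time_slice import select_time_slice

times = pd.date_range("2024-01-02 00:00", "2024-01-02 23:55", freq="5min")
da = xr.DataArray(np.random.rand(len(times)), coords={"time_utc": times})

sample = select_time_slice(
    da,
    t0=pd.Timestamp("2024-01-02 12:00"),
    interval_start=pd.Timedelta("-30min"),  # previously history_duration=30min
    interval_end=pd.Timedelta("60min"),     # previously forecast_duration=60min
    sample_period_duration=pd.Timedelta("5min"),
)
# Expect 19 timestamps: 11:30, 11:35, ..., 13:00
assert len(sample.time_utc) == 19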
ocf_data_sampler/select/time_slice_for_dataset.py
CHANGED
@@ -4,7 +4,7 @@ import pandas as pd
 from ocf_data_sampler.config import Configuration
 from ocf_data_sampler.select.dropout import draw_dropout_time, apply_dropout_time
 from ocf_data_sampler.select.select_time_slice import select_time_slice_nwp, select_time_slice
-from ocf_data_sampler.time_functions import minutes
+from ocf_data_sampler.utils import minutes


 def slice_datasets_by_time(
@@ -23,19 +23,19 @@ def slice_datasets_by_time(
     sliced_datasets_dict = {}

     if "nwp" in datasets_dict:
-
+
         sliced_datasets_dict["nwp"] = {}
-
+
         for nwp_key, da_nwp in datasets_dict["nwp"].items():
-
+
             nwp_config = config.input_data.nwp[nwp_key]

             sliced_datasets_dict["nwp"][nwp_key] = select_time_slice_nwp(
                 da_nwp,
                 t0,
                 sample_period_duration=minutes(nwp_config.time_resolution_minutes),
-                history_duration=minutes(nwp_config.history_minutes),
-                forecast_duration=minutes(nwp_config.forecast_minutes),
+                interval_start=minutes(nwp_config.interval_start_minutes),
+                interval_end=minutes(nwp_config.interval_end_minutes),
                 dropout_timedeltas=minutes(nwp_config.dropout_timedeltas_minutes),
                 dropout_frac=nwp_config.dropout_fraction,
                 accum_channels=nwp_config.accum_channels,
@@ -49,8 +49,8 @@ def slice_datasets_by_time(
             datasets_dict["sat"],
             t0,
             sample_period_duration=minutes(sat_config.time_resolution_minutes),
-            interval_start=minutes(-sat_config.history_minutes),
-            interval_end=minutes(sat_config.forecast_minutes),
+            interval_start=minutes(sat_config.interval_start_minutes),
+            interval_end=minutes(sat_config.interval_end_minutes),
             max_steps_gap=2,
         )

@@ -74,15 +74,15 @@ def slice_datasets_by_time(
             datasets_dict["gsp"],
             t0,
             sample_period_duration=minutes(gsp_config.time_resolution_minutes),
-            interval_start=minutes(gsp_config.time_resolution_minutes),
-            interval_end=minutes(gsp_config.forecast_minutes),
+            interval_start=minutes(gsp_config.time_resolution_minutes),
+            interval_end=minutes(gsp_config.interval_end_minutes),
         )
-
+
         sliced_datasets_dict["gsp"] = select_time_slice(
             datasets_dict["gsp"],
             t0,
             sample_period_duration=minutes(gsp_config.time_resolution_minutes),
-            interval_start=minutes(-gsp_config.history_minutes),
+            interval_start=minutes(gsp_config.interval_start_minutes),
             interval_end=minutes(0),
         )

@@ -94,9 +94,10 @@ def slice_datasets_by_time(
         )

         sliced_datasets_dict["gsp"] = apply_dropout_time(
-            sliced_datasets_dict["gsp"], gsp_dropout_time
+            sliced_datasets_dict["gsp"],
+            gsp_dropout_time
         )
-
+
     if "site" in datasets_dict:
         site_config = config.input_data.site

@@ -104,8 +105,8 @@ def slice_datasets_by_time(
             datasets_dict["site"],
             t0,
             sample_period_duration=minutes(site_config.time_resolution_minutes),
-            interval_start=minutes(-site_config.history_minutes),
-            interval_end=minutes(site_config.forecast_minutes),
+            interval_start=minutes(site_config.interval_start_minutes),
+            interval_end=minutes(site_config.interval_end_minutes),
         )

         # Randomly sample dropout
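Every call site above converts configured integer minutes to pd.Timedelta through the relocated helper (time_functions.py → utils.py). The diff does not show utils.py itself; judging from its usage on scalars (interval_start_minutes) and lists (dropout_timedeltas_minutes), a plausible, purely hypothetical reconstruction is:

import pandas as pd

def minutes(value):
    # Hypothetical sketch of ocf_data_sampler.utils.minutes: wrap integer
    # minutes (or a list of them) as pandas Timedelta(s)
    return pd.to_timedelta(value, unit="m")

minutes(-30)         # Timedelta('-1 days +23:30:00')
minutes([-60, -30])  # TimedeltaIndex with two entries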
ocf_data_sampler/torch_datasets/process_and_combine.py
CHANGED
@@ -15,7 +15,7 @@ from ocf_data_sampler.numpy_batch.gsp import GSPBatchKey
 from ocf_data_sampler.numpy_batch.nwp import NWPBatchKey
 from ocf_data_sampler.select.geospatial import osgb_to_lon_lat
 from ocf_data_sampler.select.location import Location
-from ocf_data_sampler.time_functions import minutes
+from ocf_data_sampler.utils import minutes


 def process_and_combine_datasets(
@@ -23,7 +23,7 @@ def process_and_combine_datasets(
     config: Configuration,
     t0: pd.Timestamp,
     location: Location,
-
+    target_key: str = 'gsp'
 ) -> dict:
     """Normalize and convert data to numpy arrays"""

@@ -58,7 +58,8 @@ def process_and_combine_datasets(

         numpy_modalities.append(
             convert_gsp_to_numpy_batch(
-                da_gsp, t0_idx=gsp_config.history_minutes / gsp_config.time_resolution_minutes
+                da_gsp,
+                t0_idx=-gsp_config.interval_start_minutes / gsp_config.time_resolution_minutes
             )
         )

@@ -80,34 +81,32 @@ def process_and_combine_datasets(

         numpy_modalities.append(
             convert_site_to_numpy_batch(
-                da_sites, t0_idx=site_config.history_minutes / site_config.time_resolution_minutes
+                da_sites, t0_idx=-site_config.interval_start_minutes / site_config.time_resolution_minutes
             )
         )

-    if "gsp" in datasets_dict:
+    if target_key == 'gsp':
         # Make sun coords NumpyBatch
         datetimes = pd.date_range(
-            t0 - minutes(gsp_config.history_minutes),
-            t0 + minutes(gsp_config.forecast_minutes),
+            t0+minutes(gsp_config.interval_start_minutes),
+            t0+minutes(gsp_config.interval_end_minutes),
             freq=minutes(gsp_config.time_resolution_minutes),
         )

         lon, lat = osgb_to_lon_lat(location.x, location.y)
-        key_prefix = "gsp"

-    elif "site" in datasets_dict:
+    elif target_key == 'site':
         # Make sun coords NumpyBatch
         datetimes = pd.date_range(
-            t0 - minutes(site_config.history_minutes),
-            t0 + minutes(site_config.forecast_minutes),
+            t0+minutes(site_config.interval_start_minutes),
+            t0+minutes(site_config.interval_end_minutes),
             freq=minutes(site_config.time_resolution_minutes),
         )

         lon, lat = location.x, location.y
-        key_prefix = "site"

     numpy_modalities.append(
-        make_sun_position_numpy_batch(datetimes, lon, lat, key_prefix=key_prefix)
+        make_sun_position_numpy_batch(datetimes, lon, lat, key_prefix=target_key)
     )

     # Combine all the modalities and fill NaNs
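The t0_idx passed to the converters is the index of t0 within the sampled window, which now follows directly from the interval convention: t0_idx = -interval_start_minutes / time_resolution_minutes. A worked check with illustrative values:

import pandas as pd

interval_start_minutes = -120
interval_end_minutes = 60
time_resolution_minutes = 30

t0 = pd.Timestamp("2024-01-02 12:00")
window = pd.date_range(
    t0 + pd.Timedelta(interval_start_minutes, "min"),
    t0 + pd.Timedelta(interval_end_minutes, "min"),
    freq=pd.Timedelta(time_resolution_minutes, "min"),
)

t0_idx = -interval_start_minutes / time_resolution_minutes  # 4.0
assert window[int(t0_idx)] == t0  # 10:00, 10:30, 11:00, 11:30, [12:00], ...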
ocf_data_sampler/torch_datasets/pvnet_uk_regional.py
CHANGED
@@ -9,7 +9,7 @@ from torch.utils.data import Dataset
 from ocf_data_sampler.config import Configuration, load_yaml_configuration
 from ocf_data_sampler.load.load_dataset import get_dataset_dict
 from ocf_data_sampler.select import fill_time_periods, Location, slice_datasets_by_space, slice_datasets_by_time
-from ocf_data_sampler.time_functions import minutes
+from ocf_data_sampler.utils import minutes
 from ocf_data_sampler.torch_datasets.process_and_combine import process_and_combine_datasets, compute
 from ocf_data_sampler.torch_datasets.valid_time_periods import find_valid_time_periods

ocf_data_sampler/torch_datasets/site.py
CHANGED
@@ -14,7 +14,7 @@ from ocf_data_sampler.select import (
     intersection_of_multiple_dataframes_of_periods,
     slice_datasets_by_time, slice_datasets_by_space
 )
-from ocf_data_sampler.time_functions import minutes
+from ocf_data_sampler.utils import minutes
 from ocf_data_sampler.torch_datasets.process_and_combine import process_and_combine_datasets, compute
 from ocf_data_sampler.torch_datasets.valid_time_periods import find_valid_time_periods

@@ -22,8 +22,8 @@ xr.set_options(keep_attrs=True)


 def find_valid_t0_and_site_ids(
-    datasets_dict: dict,
-    config: Configuration,
+    datasets_dict: dict,
+    config: Configuration,
 ) -> pd.DataFrame:
     """Find the t0 times where all of the requested input data is available

@@ -57,8 +57,8 @@ def find_valid_t0_and_site_ids(
         time_periods = find_contiguous_t0_periods(
             pd.DatetimeIndex(site["time_utc"]),
             sample_period_duration=minutes(site_config.time_resolution_minutes),
-            history_duration=minutes(site_config.history_minutes),
-            forecast_duration=minutes(site_config.forecast_minutes),
+            interval_start=minutes(site_config.interval_start_minutes),
+            interval_end=minutes(site_config.interval_end_minutes),
         )
         valid_time_periods_per_site = intersection_of_multiple_dataframes_of_periods(
             [valid_time_periods, time_periods]
@@ -100,10 +100,10 @@ def get_locations(site_xr: xr.Dataset):

 class SitesDataset(Dataset):
     def __init__(
-        self,
-        config_filename: str,
-        start_time: str | None = None,
-        end_time: str | None = None,
+        self,
+        config_filename: str,
+        start_time: str | None = None,
+        end_time: str | None = None,
     ):
         """A torch Dataset for creating PVNet Site samples

@@ -154,7 +154,7 @@ class SitesDataset(Dataset):
         sample_dict = slice_datasets_by_time(sample_dict, t0, self.config)
         sample_dict = compute(sample_dict)

-        sample = process_and_combine_datasets(sample_dict, self.config, t0, location,
+        sample = process_and_combine_datasets(sample_dict, self.config, t0, location, target_key='site')

         return sample

ocf_data_sampler/torch_datasets/valid_time_periods.py
CHANGED
@@ -2,9 +2,13 @@ import numpy as np
 import pandas as pd

 from ocf_data_sampler.config import Configuration
-from ocf_data_sampler.select.find_contiguous_time_periods import find_contiguous_t0_periods_nwp, find_contiguous_t0_periods, intersection_of_multiple_dataframes_of_periods
-from ocf_data_sampler.time_functions import minutes
-
+from ocf_data_sampler.select.find_contiguous_time_periods import (
+    find_contiguous_t0_periods_nwp,
+    find_contiguous_t0_periods,
+    intersection_of_multiple_dataframes_of_periods,
+)
+from ocf_data_sampler.utils import minutes
+


 def find_valid_time_periods(
@@ -46,7 +50,7 @@ def find_valid_time_periods(
         # This is the max staleness we can use considering the max step of the input data
         max_possible_staleness = (
             pd.Timedelta(da["step"].max().item())
-            - minutes(nwp_config.forecast_minutes)
+            - minutes(nwp_config.interval_end_minutes)
             - end_buffer
         )

@@ -56,12 +60,16 @@ def find_valid_time_periods(
         else:
             # Make sure the max acceptable staleness isn't longer than the max possible
             assert max_staleness <= max_possible_staleness
+
+        # Find the first forecast step
+        first_forecast_step = pd.Timedelta(da["step"].min().item())

         time_periods = find_contiguous_t0_periods_nwp(
-            datetimes=pd.DatetimeIndex(da["init_time_utc"]),
-            history_duration=minutes(nwp_config.history_minutes),
+            init_times=pd.DatetimeIndex(da["init_time_utc"]),
+            interval_start=minutes(nwp_config.interval_start_minutes),
             max_staleness=max_staleness,
             max_dropout=max_dropout,
+            first_forecast_step = first_forecast_step,
         )

         contiguous_time_periods[f'nwp_{nwp_key}'] = time_periods
@@ -72,8 +80,8 @@ def find_valid_time_periods(
         time_periods = find_contiguous_t0_periods(
             pd.DatetimeIndex(datasets_dict["sat"]["time_utc"]),
             sample_period_duration=minutes(sat_config.time_resolution_minutes),
-            history_duration=minutes(sat_config.history_minutes),
-            forecast_duration=minutes(sat_config.forecast_minutes),
+            interval_start=minutes(sat_config.interval_start_minutes),
+            interval_end=minutes(sat_config.interval_end_minutes),
         )

         contiguous_time_periods['sat'] = time_periods
@@ -84,8 +92,8 @@ def find_valid_time_periods(
         time_periods = find_contiguous_t0_periods(
             pd.DatetimeIndex(datasets_dict["gsp"]["time_utc"]),
             sample_period_duration=minutes(gsp_config.time_resolution_minutes),
-            history_duration=minutes(gsp_config.history_minutes),
-            forecast_duration=minutes(gsp_config.forecast_minutes),
+            interval_start=minutes(gsp_config.interval_start_minutes),
+            interval_end=minutes(gsp_config.interval_end_minutes),
         )

         contiguous_time_periods['gsp'] = time_periods
@@ -105,4 +113,4 @@ def find_valid_time_periods(
     if len(valid_time_periods) == 0:
         raise ValueError(f"No valid time periods found, {contiguous_time_periods=}")

-    return valid_time_periods
+    return valid_time_periods
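Two things change in the staleness logic above: the bound now subtracts minutes(nwp_config.interval_end_minutes) rather than the old forecast length, and the minimum step of the data is forwarded as first_forecast_step so providers whose first valid time is ahead of the init time are handled. The arithmetic reduces to this sketch (the step range is illustrative):

import pandas as pd

steps = pd.timedelta_range("1h", "48h", freq="1h")  # e.g. an NWP with steps 1h..48h

interval_end = pd.Timedelta("3h")  # minutes(nwp_config.interval_end_minutes)
end_buffer = pd.Timedelta("0h")

max_possible_staleness = steps.max() - interval_end - end_buffer  # 45h
first_forecast_step = steps.min()  # 1h, passed through to find_contiguous_t0_periods_nwp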
ocf_data_sampler-0.0.26.dist-info/RECORD → ocf_data_sampler-0.0.28.dist-info/RECORD
CHANGED
@@ -1,9 +1,9 @@
 ocf_data_sampler/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
 ocf_data_sampler/constants.py,sha256=tUwHrsGShqIn5Izze4i32_xB6X0v67rvQwIYB-P5PJQ,3355
-ocf_data_sampler/time_functions.py,sha256=
+ocf_data_sampler/utils.py,sha256=rKA0BHAyAG4f90zEcgxp25EEYrXS-aOVNzttZ6Mzv2k,250
 ocf_data_sampler/config/__init__.py,sha256=YXnAkgHViHB26hSsjiv32b6EbpG-A1kKTkARJf0_RkY,212
 ocf_data_sampler/config/load.py,sha256=4f7vPHAIAmd-55tPxoIzn7F_TI_ue4NxkDcLPoVWl0g,943
-ocf_data_sampler/config/model.py,sha256=
+ocf_data_sampler/config/model.py,sha256=sXmh7IadwXDT-7lxEl5_b3vjovZgZYR77EXy4GHaf4w,7276
 ocf_data_sampler/config/save.py,sha256=wKdctbv0dxIIiQtcRHLRxpWQVhEFQ_FCWg-oNaRLIps,1093
 ocf_data_sampler/data/uk_gsp_locations.csv,sha256=RSh7DRh55E3n8lVAaWXGTaXXHevZZtI58td4d4DhGos,10415772
 ocf_data_sampler/load/__init__.py,sha256=MjgfxilTzyz1RYFoBEeAXmE9hyjknLvdmlHPmlAoiQY,44
@@ -27,22 +27,22 @@ ocf_data_sampler/numpy_batch/sun_position.py,sha256=zw2bjtcjsm_tvKk0r_MZmgfYUJLH
 ocf_data_sampler/select/__init__.py,sha256=E4AJulEbO2K-o0UlG1fgaEteuf_1ZFjHTvrotXSb4YU,332
 ocf_data_sampler/select/dropout.py,sha256=HCx5Wzk8Oh2Z9vV94Jy-ALJsHtGduwvMaQOleQXp5z0,1142
 ocf_data_sampler/select/fill_time_periods.py,sha256=iTtMjIPFYG5xtUYYedAFBLjTWWUa7t7WQ0-yksWf0-E,440
-ocf_data_sampler/select/find_contiguous_time_periods.py,sha256=
+ocf_data_sampler/select/find_contiguous_time_periods.py,sha256=q7IaNfX95A3z9XHqbhgtkZ4Js1gn5K9Qyp6DVLbsL-Q,11093
 ocf_data_sampler/select/geospatial.py,sha256=4xL-9y674jjoaXeqE52NHCHVfknciE4OEGsZtn9DvP4,4911
 ocf_data_sampler/select/location.py,sha256=26Y5ZjfFngShBwXieuWSoOA-RLaRzci4TTmcDk3Wg7U,2015
 ocf_data_sampler/select/select_spatial_slice.py,sha256=WNxwur9Q5oetvogATw8-hNejDuEwrXHzuZIovFDjNJA,11488
-ocf_data_sampler/select/select_time_slice.py,sha256=
+ocf_data_sampler/select/select_time_slice.py,sha256=D5P_cSvnv8Qs49K5au7lPxDr9U_VmDn42s5leMzHt0k,6122
 ocf_data_sampler/select/spatial_slice_for_dataset.py,sha256=3tRrMBXr7s4CnClbVSIq7hpls3H4Y3qYTDwswcxCCCE,1763
-ocf_data_sampler/select/time_slice_for_dataset.py,sha256=
+ocf_data_sampler/select/time_slice_for_dataset.py,sha256=LMw8KnOCKnPjD0m4UubAWERpaiQtzRKkI2cSh5a0A-M,4335
 ocf_data_sampler/torch_datasets/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
-ocf_data_sampler/torch_datasets/process_and_combine.py,sha256=
-ocf_data_sampler/torch_datasets/pvnet_uk_regional.py,sha256=
-ocf_data_sampler/torch_datasets/site.py,sha256=
-ocf_data_sampler/torch_datasets/valid_time_periods.py,sha256=
+ocf_data_sampler/torch_datasets/process_and_combine.py,sha256=4k6f6PlMqrg3luMwGw3764iOyfuUNUePKyoikYGaRMI,4953
+ocf_data_sampler/torch_datasets/pvnet_uk_regional.py,sha256=QRFqbdfNchVWj4y70n-rJdFvFGvQj-WpZLdFqWjnOTw,5543
+ocf_data_sampler/torch_datasets/site.py,sha256=lo2ULurfWNu9vzBC6H4pdKMMpUMIT8_FWC1l_1mgIOM,6596
+ocf_data_sampler/torch_datasets/valid_time_periods.py,sha256=Qo65qUHtle_bW5tLTYr7empHTRv-lpjvfx_6GNJj3Xg,4371
 scripts/refactor_site.py,sha256=asZ27hQ4IyXgCCUaFJqcz1ObBNcV2W3ywqHBpSXA_fc,1728
 tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-tests/conftest.py,sha256=
-tests/config/test_config.py,sha256=
+tests/conftest.py,sha256=N-_XgXpWeTRhkwP_NVh2mBORt2LKkM4mbkm-O62RN5I,7363
+tests/config/test_config.py,sha256=eaye_F7-el4tTP4n2vRME8qlV0b2jaKUX4HhgOUpa7E,5203
 tests/load/test_load_gsp.py,sha256=aT_nqaSXmUTcdHzuTT7AmXJr3R31k4OEN-Fv3eLxlQE,424
 tests/load/test_load_nwp.py,sha256=3qyyDkB1q9t3tyAwogfotNrxqUOpXXimco1CImoEWGg,753
 tests/load/test_load_satellite.py,sha256=STX5AqqmOAgUgE9R1xyq_sM3P1b8NKdGjO-hDhayfxM,524
@@ -53,14 +53,14 @@ tests/numpy_batch/test_satellite.py,sha256=8a4ZwMLpsOmYKmwI1oW_su_hwkCNYMEJAEfa0
 tests/numpy_batch/test_sun_position.py,sha256=FYQ7KtlN0V5LlEjgI-cKjTMtGHUCxiMvxkRYTdMAgEE,2485
 tests/select/test_dropout.py,sha256=kiycl7RxAQYMCZJlokmx6Da5h_oBpSs8Is8pmSW4gOU,2413
 tests/select/test_fill_time_periods.py,sha256=o59f2YRe5b0vJrG3B0aYZkYeHnpNk4s6EJxdXZluNQg,907
-tests/select/test_find_contiguous_time_periods.py,sha256=
+tests/select/test_find_contiguous_time_periods.py,sha256=kOga_V7er5We7ewMARXaKdM3agOhsvZYx8inXtUn1PM,5976
 tests/select/test_location.py,sha256=_WZk2FPYeJ-nIfCJS6Sp_yaVEEo7m31DmMFoZzgyCts,2712
 tests/select/test_select_spatial_slice.py,sha256=7EX9b6g-pMdACQx3yefjs5do2s-Rho2UmKevV4oglsU,5147
-tests/select/test_select_time_slice.py,sha256=
+tests/select/test_select_time_slice.py,sha256=QOhoR3qsr7RBGze4yohcViZ-ad1zYQzIKzxlnf0ymnU,9603
 tests/torch_datasets/test_pvnet_uk_regional.py,sha256=8gxjJO8FhY-ImX6eGnihDFsa8fhU2Zb4bVJaToJwuwo,2653
 tests/torch_datasets/test_site.py,sha256=yTv6tAT6lha5yLYJiC8DNms1dct8o_ObPV97dHZyT7I,2719
-ocf_data_sampler-0.0.26.dist-info/LICENSE,sha256=F-Q3UFCR-BECSocV55BFDpn4YKxve9PKrm-lTt6o_Tg,1073
-ocf_data_sampler-0.0.26.dist-info/METADATA,sha256=
-ocf_data_sampler-0.0.26.dist-info/WHEEL,sha256=
-ocf_data_sampler-0.0.26.dist-info/top_level.txt,sha256=Faob6N6cFdPc5eUpCTYcXgCaNhi4XLLteUL5W5ayYmg,31
-ocf_data_sampler-0.0.26.dist-info/RECORD,,
+ocf_data_sampler-0.0.28.dist-info/LICENSE,sha256=F-Q3UFCR-BECSocV55BFDpn4YKxve9PKrm-lTt6o_Tg,1073
+ocf_data_sampler-0.0.28.dist-info/METADATA,sha256=N0tSasiSNQVsvz3iAIi6_zoggS0FHmdo0YepfKCdjv4,5269
+ocf_data_sampler-0.0.28.dist-info/WHEEL,sha256=R06PA3UVYHThwHvxuRWMqaGcr-PuniXahwjmQRFMEkY,91
+ocf_data_sampler-0.0.28.dist-info/top_level.txt,sha256=Faob6N6cFdPc5eUpCTYcXgCaNhi4XLLteUL5W5ayYmg,31
+ocf_data_sampler-0.0.28.dist-info/RECORD,,
tests/config/test_config.py
CHANGED
@@ -68,27 +68,33 @@ def test_extra_field_error():
     _ = Configuration(**configuration_dict)


-def test_incorrect_history_minutes(test_config_filename):
+def test_incorrect_interval_start_minutes(test_config_filename):
     """
-    Check a history length not divisible by time resolution causes error
+    Check a history length not divisible by time resolution causes error
     """

     configuration = load_yaml_configuration(test_config_filename)

-    configuration.input_data.nwp['ukv'].history_minutes = 1111
-    with pytest.raises(Exception):
+    configuration.input_data.nwp['ukv'].interval_start_minutes = -1111
+    with pytest.raises(
+        ValueError,
+        match="interval_start_minutes must be divisible by time_resolution_minutes"
+    ):
         _ = Configuration(**configuration.model_dump())


-def test_incorrect_forecast_minutes(test_config_filename):
+def test_incorrect_interval_end_minutes(test_config_filename):
     """
-    Check a forecast length not divisible by time resolution causes error
+    Check a forecast length not divisible by time resolution causes error
     """

     configuration = load_yaml_configuration(test_config_filename)

-    configuration.input_data.nwp['ukv'].forecast_minutes = 1111
-    with pytest.raises(Exception):
+    configuration.input_data.nwp['ukv'].interval_end_minutes = 1111
+    with pytest.raises(
+        ValueError,
+        match="interval_end_minutes must be divisible by time_resolution_minutes"
+    ):
         _ = Configuration(**configuration.model_dump())

tests/conftest.py
CHANGED
@@ -250,11 +250,13 @@ def data_sites() -> Site:
     generation.to_netcdf(filename)
     meta_df.to_csv(filename_csv)

-    site = Site(
-        file_path=filename,
-        metadata_file_path=filename_csv,
-        forecast_minutes=60,
-        history_minutes=30)
+    site = Site(
+        file_path=filename,
+        metadata_file_path=filename_csv,
+        interval_start_minutes=-30,
+        interval_end_minutes=60,
+        time_resolution_minutes=30,
+    )

     yield site

tests/select/test_find_contiguous_time_periods.py
CHANGED
@@ -11,8 +11,8 @@ def test_find_contiguous_t0_periods():

     # Create 5-minutely data timestamps
     freq = pd.Timedelta(5, "min")
-    history_duration = pd.Timedelta(60, "min")
-    forecast_duration = pd.Timedelta(15, "min")
+    interval_start = pd.Timedelta(-60, "min")
+    interval_end = pd.Timedelta(15, "min")

     datetimes = (
         pd.date_range("2023-01-01 12:00", "2023-01-01 17:00", freq=freq)
@@ -21,8 +21,8 @@ def test_find_contiguous_t0_periods():

     periods = find_contiguous_t0_periods(
         datetimes=datetimes,
-        history_duration=history_duration,
-        forecast_duration=forecast_duration,
+        interval_start=interval_start,
+        interval_end=interval_end,
         sample_period_duration=freq,
     )

@@ -135,7 +135,7 @@ def test_find_contiguous_t0_periods_nwp():
     # Create 3-hourly init times with a few time stamps missing
     freq = pd.Timedelta(3, "h")

-    datetimes = (
+    init_times = (
         pd.date_range("2023-01-01 03:00", "2023-01-02 21:00", freq=freq)
         .delete([1, 4, 5, 6, 7, 9, 10])
     )
@@ -146,13 +146,13 @@ def test_find_contiguous_t0_periods_nwp():
     max_dropouts_hr = [0, 0, 0, 0, 3]

     for i in range(len(expected_results)):
-        history_duration = pd.Timedelta(history_durations_hr[i], "h")
+        interval_start = pd.Timedelta(-history_durations_hr[i], "h")
         max_staleness = pd.Timedelta(max_stalenesses_hr[i], "h")
         max_dropout = pd.Timedelta(max_dropouts_hr[i], "h")

         time_periods = find_contiguous_t0_periods_nwp(
-            datetimes=datetimes,
-            history_duration=history_duration,
+            init_times=init_times,
+            interval_start=interval_start,
             max_staleness=max_staleness,
             max_dropout=max_dropout,
         )
tests/select/test_select_time_slice.py
CHANGED
@@ -55,31 +55,19 @@ def test_select_time_slice(da_sat_like, t0_str):

     # Slice parameters
     t0 = pd.Timestamp(f"2024-01-02 {t0_str}")
-    history_duration = pd.Timedelta(0, "min")
-    forecast_duration = pd.Timedelta(60, "min")
+    interval_start = pd.Timedelta(-0, "min")
+    interval_end = pd.Timedelta(60, "min")
     freq = pd.Timedelta("5min")

     # Expect to return these timestamps from the selection
-    expected_datetimes = pd.date_range(t0 - history_duration, t0 + forecast_duration, freq=freq)
+    expected_datetimes = pd.date_range(t0 +interval_start, t0 + interval_end, freq=freq)

-    # Make the selection
+    # Make the selection
     sat_sample = select_time_slice(
-        ds=da_sat_like,
+        da_sat_like,
         t0=t0,
-        history_duration=history_duration,
-        forecast_duration=forecast_duration,
-        sample_period_duration=freq,
-    )
-
-    # Check the returned times are as expected
-    assert (sat_sample.time_utc == expected_datetimes).all()
-
-    # Make the selection using the `interval_[x]` parameters
-    sat_sample = select_time_slice(
-        ds=da_sat_like,
-        t0=t0,
-        interval_start=-history_duration,
-        interval_end=forecast_duration,
+        interval_start=interval_start,
+        interval_end=interval_end,
         sample_period_duration=freq,
     )

@@ -93,8 +81,8 @@ def test_select_time_slice_out_of_bounds(da_sat_like, t0_str):

     # Slice parameters
     t0 = pd.Timestamp(f"2024-01-02 {t0_str}")
-    history_duration = pd.Timedelta(30, "min")
-    forecast_duration = pd.Timedelta(60, "min")
+    interval_start = pd.Timedelta(-30, "min")
+    interval_end = pd.Timedelta(60, "min")
     freq = pd.Timedelta("5min")

     # The data is available between these times
@@ -102,14 +90,14 @@ def test_select_time_slice_out_of_bounds(da_sat_like, t0_str):
     max_time = da_sat_like.time_utc.max()

     # Expect to return these timestamps from the selection
-    expected_datetimes = pd.date_range(t0 - history_duration, t0 + forecast_duration, freq=freq)
+    expected_datetimes = pd.date_range(t0 + interval_start, t0 + interval_end, freq=freq)

     # Make the partially out of bounds selection
     sat_sample = select_time_slice(
-        ds=da_sat_like,
+        da_sat_like,
         t0=t0,
-        history_duration=history_duration,
-        forecast_duration=forecast_duration,
+        interval_start=interval_start,
+        interval_end=interval_end,
         sample_period_duration=freq,
         fill_selection=True
     )
@@ -138,8 +126,8 @@ def test_select_time_slice_nwp_basic(da_nwp_like, t0_str):

     # Slice parameters
     t0 = pd.Timestamp(f"2024-01-02 {t0_str}")
-    history_duration = pd.Timedelta(6, "h")
-    forecast_duration = pd.Timedelta(3, "h")
+    interval_start = pd.Timedelta(-6, "h")
+    interval_end = pd.Timedelta(3, "h")
     freq = pd.Timedelta("1h")

     # Make the selection
@@ -147,8 +135,8 @@ def test_select_time_slice_nwp_basic(da_nwp_like, t0_str):
         da_nwp_like,
         t0,
         sample_period_duration=freq,
-        history_duration=history_duration,
-        forecast_duration=forecast_duration,
+        interval_start=interval_start,
+        interval_end=interval_end,
         dropout_timedeltas = None,
         dropout_frac = 0,
         accum_channels = [],
@@ -156,7 +144,7 @@ def test_select_time_slice_nwp_basic(da_nwp_like, t0_str):
     )

     # Check the target-times are as expected
-    expected_target_times = pd.date_range(t0 - history_duration, t0 + forecast_duration, freq=freq)
+    expected_target_times = pd.date_range(t0 + interval_start, t0 + interval_end, freq=freq)
     assert (da_slice.target_time_utc==expected_target_times).all()

     # Check the init-times are as expected
@@ -172,8 +160,8 @@ def test_select_time_slice_nwp_with_dropout(da_nwp_like, dropout_hours):
     """Test the functionality of select_time_slice_nwp with dropout"""

     t0 = pd.Timestamp("2024-01-02 12:00")
-    history_duration = pd.Timedelta(6, "h")
-    forecast_duration = pd.Timedelta(3, "h")
+    interval_start = pd.Timedelta(-6, "h")
+    interval_end = pd.Timedelta(3, "h")
     freq = pd.Timedelta("1h")
     dropout_timedelta = pd.Timedelta(f"-{dropout_hours}h")

@@ -181,8 +169,8 @@ def test_select_time_slice_nwp_with_dropout(da_nwp_like, dropout_hours):
         da_nwp_like,
         t0,
         sample_period_duration=freq,
-        history_duration=history_duration,
-        forecast_duration=forecast_duration,
+        interval_start=interval_start,
+        interval_end=interval_end,
         dropout_timedeltas = [dropout_timedelta],
         dropout_frac = 1,
         accum_channels = [],
@@ -190,7 +178,7 @@ def test_select_time_slice_nwp_with_dropout(da_nwp_like, dropout_hours):
     )

     # Check the target-times are as expected
-    expected_target_times = pd.date_range(t0 - history_duration, t0 + forecast_duration, freq=freq)
+    expected_target_times = pd.date_range(t0 + interval_start, t0 + interval_end, freq=freq)
     assert (da_slice.target_time_utc==expected_target_times).all()

     # Check the init-times are as expected considering the delay
@@ -207,9 +195,9 @@ def test_select_time_slice_nwp_with_dropout_and_accum(da_nwp_like, t0_str):

     # Slice parameters
     t0 = pd.Timestamp(f"2024-01-02 {t0_str}")
-    history_duration = pd.Timedelta(6, "h")
-    forecast_duration = pd.Timedelta(3, "h")
-    freq = pd.Timedelta("1h")
+    interval_start = pd.Timedelta(-6, "h")
+    interval_end = pd.Timedelta(3, "h")
+    freq = pd.Timedelta("1H")
     dropout_timedelta = pd.Timedelta("-2h")

     t0_delayed = (t0 + dropout_timedelta).floor(NWP_FREQ)
@@ -218,8 +206,8 @@ def test_select_time_slice_nwp_with_dropout_and_accum(da_nwp_like, t0_str):
         da_nwp_like,
         t0,
         sample_period_duration=freq,
-        history_duration=history_duration,
-        forecast_duration=forecast_duration,
+        interval_start=interval_start,
+        interval_end=interval_end,
         dropout_timedeltas=[dropout_timedelta],
         dropout_frac=1,
         accum_channels=["dswrf"],
@@ -227,7 +215,7 @@ def test_select_time_slice_nwp_with_dropout_and_accum(da_nwp_like, t0_str):
     )

     # Check the target-times are as expected
-    expected_target_times = pd.date_range(t0 - history_duration, t0 + forecast_duration, freq=freq)
+    expected_target_times = pd.date_range(t0 + interval_start, t0 + interval_end, freq=freq)
     assert (da_slice.target_time_utc==expected_target_times).all()

     # Check the init-times are as expected considering the delay
@@ -254,7 +242,7 @@ def test_select_time_slice_nwp_with_dropout_and_accum(da_nwp_like, t0_str):
         init_time_utc=t0_delayed,
         channel="dswrf",
     ).diff(dim="step", label="lower")
-    .sel(step=slice(t0-t0_delayed - history_duration, t0-t0_delayed + forecast_duration))
+    .sel(step=slice(t0-t0_delayed + interval_start, t0-t0_delayed + interval_end))
     )

     # Check the values are the same
@@ -275,7 +263,7 @@ def test_select_time_slice_nwp_with_dropout_and_accum(da_nwp_like, t0_str):
         init_time_utc=t0_delayed,
         channel="t",
     )
-    .sel(step=slice(t0-t0_delayed - history_duration, t0-t0_delayed + forecast_duration))
+    .sel(step=slice(t0-t0_delayed + interval_start, t0-t0_delayed + interval_end))
     )

     # Check the values are the same
ocf_data_sampler-0.0.26.dist-info/LICENSE → ocf_data_sampler-0.0.28.dist-info/LICENSE
File without changes
ocf_data_sampler-0.0.26.dist-info/top_level.txt → ocf_data_sampler-0.0.28.dist-info/top_level.txt
File without changes