ocf-data-sampler 0.0.25__py3-none-any.whl → 0.0.27__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Note: this version of ocf-data-sampler has been flagged as potentially problematic.

@@ -15,6 +15,7 @@ from typing import Dict, List, Optional
 from typing_extensions import Self
 
 from pydantic import BaseModel, Field, RootModel, field_validator, ValidationInfo, model_validator
+
 from ocf_data_sampler.constants import NWP_PROVIDERS
 
 logger = logging.getLogger(__name__)
@@ -34,26 +35,50 @@ class Base(BaseModel):
 class General(Base):
     """General pydantic model"""
 
-    name: str = Field("example", description="The name of this configuration file.")
+    name: str = Field("example", description="The name of this configuration file")
     description: str = Field(
         "example configuration", description="Description of this configuration file"
     )
 
 
-class DataSourceMixin(Base):
-    """Mixin class, to add forecast and history minutes"""
+class TimeWindowMixin(Base):
+    """Mixin class, to add interval start, end and resolution minutes"""
 
-    forecast_minutes: int = Field(
+    time_resolution_minutes: int = Field(
         ...,
-        ge=0,
-        description="how many minutes to forecast in the future. ",
+        gt=0,
+        description="The temporal resolution of the data in minutes",
     )
-    history_minutes: int = Field(
+
+    interval_start_minutes: int = Field(
         ...,
-        ge=0,
-        description="how many historic minutes to use. ",
+        description="Data interval starts at `t0 + interval_start_minutes`",
     )
 
+    interval_end_minutes: int = Field(
+        ...,
+        description="Data interval ends at `t0 + interval_end_minutes`",
+    )
+
+    @model_validator(mode='after')
+    def check_interval_range(cls, values):
+        if values.interval_start_minutes > values.interval_end_minutes:
+            raise ValueError('interval_start_minutes must be <= interval_end_minutes')
+        return values
+
+    @field_validator("interval_start_minutes")
+    def interval_start_minutes_divide_by_time_resolution(cls, v: int, info: ValidationInfo) -> int:
+        if v % info.data["time_resolution_minutes"] != 0:
+            raise ValueError("interval_start_minutes must be divisible by time_resolution_minutes")
+        return v
+
+    @field_validator("interval_end_minutes")
+    def interval_end_minutes_divide_by_time_resolution(cls, v: int, info: ValidationInfo) -> int:
+        if v % info.data["time_resolution_minutes"] != 0:
+            raise ValueError("interval_end_minutes must be divisible by time_resolution_minutes")
+        return v
+
+
 
 # noinspection PyMethodParameters
 class DropoutMixin(Base):
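
The change above replaces the old duration-style fields (history_minutes, forecast_minutes) with a signed window around t0. A minimal migration sketch, with illustrative values only (not taken from the package):

    # Illustrative mapping from the old duration-style settings to the new
    # interval-style fields: 60 min of history and 120 min of forecast become
    # a window from t0 - 60 min to t0 + 120 min.
    old_style = {"history_minutes": 60, "forecast_minutes": 120}

    new_style = {
        "time_resolution_minutes": 30,
        "interval_start_minutes": -old_style["history_minutes"],  # -60
        "interval_end_minutes": old_style["forecast_minutes"],    # 120
    }

    # The new validators require both offsets to be multiples of
    # time_resolution_minutes, and interval_start_minutes <= interval_end_minutes.
    assert new_style["interval_start_minutes"] % new_style["time_resolution_minutes"] == 0
    assert new_style["interval_end_minutes"] % new_style["time_resolution_minutes"] == 0
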
@@ -65,7 +90,12 @@ class DropoutMixin(Base):
         "negative or zero.",
     )
 
-    dropout_fraction: float = Field(0, description="Chance of dropout being applied to each sample")
+    dropout_fraction: float = Field(
+        default=0,
+        description="Chance of dropout being applied to each sample",
+        ge=0,
+        le=1,
+    )
 
     @field_validator("dropout_timedeltas_minutes")
     def dropout_timedeltas_minutes_negative(cls, v: List[int]) -> List[int]:
@@ -75,12 +105,6 @@ class DropoutMixin(Base):
         assert m <= 0, "Dropout timedeltas must be negative"
         return v
 
-    @field_validator("dropout_fraction")
-    def dropout_fraction_valid(cls, v: float) -> float:
-        """Validate 'dropout_fraction'"""
-        assert 0 <= v <= 1, "Dropout fraction must be between 0 and 1"
-        return v
-
     @model_validator(mode="after")
     def dropout_instructions_consistent(self) -> Self:
         if self.dropout_fraction == 0:
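
The hand-written range check for dropout_fraction is dropped in favour of ge/le bounds declared on the Field itself. A standalone pydantic v2 sketch of the same pattern (demo model, not the package's class):

    from pydantic import BaseModel, Field, ValidationError

    class DropoutDemo(BaseModel):
        # Bounds live on the Field, so no custom validator is needed
        dropout_fraction: float = Field(default=0, ge=0, le=1)

    DropoutDemo(dropout_fraction=0.5)  # accepted
    try:
        DropoutDemo(dropout_fraction=1.5)
    except ValidationError as err:
        print(err)  # reports the value exceeds the upper bound of 1
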
@@ -92,93 +116,51 @@ class DropoutMixin(Base):
         return self
 
 
-# noinspection PyMethodParameters
-class TimeResolutionMixin(Base):
-    """Time resolution mix in"""
+class SpatialWindowMixin(Base):
+    """Mixin class, to add path and image size"""
 
-    time_resolution_minutes: int = Field(
+    image_size_pixels_height: int = Field(
         ...,
-        description="The temporal resolution of the data in minutes",
+        ge=0,
+        description="The number of pixels of the height of the region of interest",
     )
 
-
-class Site(DataSourceMixin, TimeResolutionMixin, DropoutMixin):
-    """Site configuration model"""
-
-    file_path: str = Field(
+    image_size_pixels_width: int = Field(
         ...,
-        description="The NetCDF files holding the power timeseries.",
-    )
-    metadata_file_path: str = Field(
-        ...,
-        description="The CSV files describing power system",
+        ge=0,
+        description="The number of pixels of the width of the region of interest",
     )
 
-    @field_validator("forecast_minutes")
-    def forecast_minutes_divide_by_time_resolution(cls, v: int, info: ValidationInfo) -> int:
-        """Check forecast length requested will give stable number of timesteps"""
-        if v % info.data["time_resolution_minutes"] != 0:
-            message = "Forecast duration must be divisible by time resolution"
-            logger.error(message)
-            raise Exception(message)
-        return v
-
-    @field_validator("history_minutes")
-    def history_minutes_divide_by_time_resolution(cls, v: int, info: ValidationInfo) -> int:
-        """Check history length requested will give stable number of timesteps"""
-        if v % info.data["time_resolution_minutes"] != 0:
-            message = "History duration must be divisible by time resolution"
-            logger.error(message)
-            raise Exception(message)
-        return v
-
-    # TODO validate the netcdf for sites
-    # TODO validate the csv for metadata
 
-class Satellite(DataSourceMixin, TimeResolutionMixin, DropoutMixin):
+class Satellite(TimeWindowMixin, DropoutMixin, SpatialWindowMixin):
     """Satellite configuration model"""
-
-    # Todo: remove 'satellite' from names
-    satellite_zarr_path: str | tuple[str] | list[str] = Field(
-        ...,
-        description="The path or list of paths which hold the satellite zarr",
-    )
-    satellite_channels: list[str] = Field(
-        ..., description="the satellite channels that are used"
-    )
-    satellite_image_size_pixels_height: int = Field(
+
+    zarr_path: str | tuple[str] | list[str] = Field(
         ...,
-        description="The number of pixels of the height of the region of interest"
-        " for non-HRV satellite channels.",
+        description="The path or list of paths which hold the data zarr",
     )
 
-    satellite_image_size_pixels_width: int = Field(
-        ...,
-        description="The number of pixels of the width of the region "
-        "of interest for non-HRV satellite channels.",
-    )
-
-    live_delay_minutes: int = Field(
-        ..., description="The expected delay in minutes of the satellite data"
+    channels: list[str] = Field(
+        ..., description="the satellite channels that are used"
     )
 
 
 # noinspection PyMethodParameters
-class NWP(DataSourceMixin, TimeResolutionMixin, DropoutMixin):
+class NWP(TimeWindowMixin, DropoutMixin, SpatialWindowMixin):
     """NWP configuration model"""
-
-    nwp_zarr_path: str | tuple[str] | list[str] = Field(
+
+    zarr_path: str | tuple[str] | list[str] = Field(
         ...,
-        description="The path which holds the NWP zarr",
+        description="The path or list of paths which hold the data zarr",
     )
-    nwp_channels: list[str] = Field(
+
+    channels: list[str] = Field(
         ..., description="the channels used in the nwp data"
     )
-    nwp_accum_channels: list[str] = Field([], description="the nwp channels which need to be diffed")
-    nwp_image_size_pixels_height: int = Field(..., description="The size of NWP spacial crop in pixels")
-    nwp_image_size_pixels_width: int = Field(..., description="The size of NWP spacial crop in pixels")
 
-    nwp_provider: str = Field(..., description="The provider of the NWP data")
+    provider: str = Field(..., description="The provider of the NWP data")
+
+    accum_channels: list[str] = Field([], description="the nwp channels which need to be diffed")
 
     max_staleness_minutes: Optional[int] = Field(
         None,
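
Field names on the data-source models lose their redundant prefixes, the per-source image-size fields move into SpatialWindowMixin, and Satellite's live_delay_minutes is dropped. A hedged sketch of the rename, using placeholder paths and channels ("ukv" is assumed to be a member of NWP_PROVIDERS):

    old_nwp = {
        "nwp_zarr_path": "s3://bucket/nwp.zarr",
        "nwp_channels": ["t2m", "dswrf"],
        "nwp_accum_channels": ["dswrf"],
        "nwp_image_size_pixels_height": 24,
        "nwp_image_size_pixels_width": 24,
        "nwp_provider": "ukv",
    }

    # The migration is a mechanical strip of the "nwp_" prefix (similarly
    # "satellite_" and "gsp_" for the other sources).
    new_nwp = {k.removeprefix("nwp_"): v for k, v in old_nwp.items()}
    assert new_nwp["zarr_path"] == "s3://bucket/nwp.zarr"
    assert new_nwp["provider"] == "ukv"
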
@@ -188,32 +170,15 @@ class NWP(DataSourceMixin, TimeResolutionMixin, DropoutMixin):
     )
 
 
-    @field_validator("nwp_provider")
-    def validate_nwp_provider(cls, v: str) -> str:
-        """Validate 'nwp_provider'"""
+    @field_validator("provider")
+    def validate_provider(cls, v: str) -> str:
+        """Validate 'provider'"""
         if v.lower() not in NWP_PROVIDERS:
             message = f"NWP provider {v} is not in {NWP_PROVIDERS}"
             logger.warning(message)
             raise Exception(message)
         return v
 
-    # Todo: put into time mixin when moving intervals there
-    @field_validator("forecast_minutes")
-    def forecast_minutes_divide_by_time_resolution(cls, v: int, info: ValidationInfo) -> int:
-        if v % info.data["time_resolution_minutes"] != 0:
-            message = "Forecast duration must be divisible by time resolution"
-            logger.error(message)
-            raise Exception(message)
-        return v
-
-    @field_validator("history_minutes")
-    def history_minutes_divide_by_time_resolution(cls, v: int, info: ValidationInfo) -> int:
-        if v % info.data["time_resolution_minutes"] != 0:
-            message = "History duration must be divisible by time resolution"
-            logger.error(message)
-            raise Exception(message)
-        return v
-
 
 class MultiNWP(RootModel):
     """Configuration for multiple NWPs"""
@@ -241,34 +206,32 @@ class MultiNWP(RootModel):
         return self.root.items()
 
 
-# noinspection PyMethodParameters
-class GSP(DataSourceMixin, TimeResolutionMixin, DropoutMixin):
+class GSP(TimeWindowMixin, DropoutMixin):
     """GSP configuration model"""
 
-    gsp_zarr_path: str = Field(..., description="The path which holds the GSP zarr")
+    zarr_path: str = Field(..., description="The path which holds the GSP zarr")
 
-    @field_validator("forecast_minutes")
-    def forecast_minutes_divide_by_time_resolution(cls, v: int, info: ValidationInfo) -> int:
-        if v % info.data["time_resolution_minutes"] != 0:
-            message = "Forecast duration must be divisible by time resolution"
-            logger.error(message)
-            raise Exception(message)
-        return v
 
-    @field_validator("history_minutes")
-    def history_minutes_divide_by_time_resolution(cls, v: int, info: ValidationInfo) -> int:
-        if v % info.data["time_resolution_minutes"] != 0:
-            message = "History duration must be divisible by time resolution"
-            logger.error(message)
-            raise Exception(message)
-        return v
+class Site(TimeWindowMixin, DropoutMixin):
+    """Site configuration model"""
+
+    file_path: str = Field(
+        ...,
+        description="The NetCDF files holding the power timeseries.",
+    )
+    metadata_file_path: str = Field(
+        ...,
+        description="The CSV files describing power system",
+    )
+
+    # TODO validate the netcdf for sites
+    # TODO validate the csv for metadata
+
 
 
 # noinspection PyPep8Naming
 class InputData(Base):
-    """
-    Input data model.
-    """
+    """Input data model"""
 
     satellite: Optional[Satellite] = None
     nwp: Optional[MultiNWP] = None
@@ -280,4 +243,4 @@ class Configuration(Base):
     """Configuration model for the dataset"""
 
     general: General = General()
-    input_data: InputData = InputData()
+    input_data: InputData = InputData()
@@ -20,8 +20,8 @@ def get_dataset_dict(config: Configuration) -> dict[str, dict[xr.DataArray]]:
     datasets_dict = {}
 
     # Load GSP data unless the path is None
-    if in_config.gsp and in_config.gsp.gsp_zarr_path:
-        da_gsp = open_gsp(zarr_path=in_config.gsp.gsp_zarr_path).compute()
+    if in_config.gsp and in_config.gsp.zarr_path:
+        da_gsp = open_gsp(zarr_path=in_config.gsp.zarr_path).compute()
 
         # Remove national GSP
         datasets_dict["gsp"] = da_gsp.sel(gsp_id=slice(1, None))
@@ -32,9 +32,9 @@ def get_dataset_dict(config: Configuration) -> dict[str, dict[xr.DataArray]]:
         datasets_dict["nwp"] = {}
         for nwp_source, nwp_config in in_config.nwp.items():
 
-            da_nwp = open_nwp(nwp_config.nwp_zarr_path, provider=nwp_config.nwp_provider)
+            da_nwp = open_nwp(nwp_config.zarr_path, provider=nwp_config.provider)
 
-            da_nwp = da_nwp.sel(channel=list(nwp_config.nwp_channels))
+            da_nwp = da_nwp.sel(channel=list(nwp_config.channels))
 
             datasets_dict["nwp"][nwp_source] = da_nwp
 
@@ -42,9 +42,9 @@ def get_dataset_dict(config: Configuration) -> dict[str, dict[xr.DataArray]]:
     if in_config.satellite:
         sat_config = config.input_data.satellite
 
-        da_sat = open_sat_data(sat_config.satellite_zarr_path)
+        da_sat = open_sat_data(sat_config.zarr_path)
 
-        da_sat = da_sat.sel(channel=list(sat_config.satellite_channels))
+        da_sat = da_sat.sel(channel=list(sat_config.channels))
 
         datasets_dict["sat"] = da_sat
 
@@ -63,16 +63,16 @@ def find_contiguous_time_periods(
 
 def trim_contiguous_time_periods(
     contiguous_time_periods: pd.DataFrame,
-    history_duration: pd.Timedelta,
-    forecast_duration: pd.Timedelta,
+    interval_start: pd.Timedelta,
+    interval_end: pd.Timedelta,
 ) -> pd.DataFrame:
     """Trim the contiguous time periods to allow for history and forecast durations.
 
     Args:
         contiguous_time_periods: DataFrame where each row represents a single time period. The
             DataFrame must have `start_dt` and `end_dt` columns.
-        history_duration: Length of the historical slice used for a sample
-        forecast_duration: Length of the forecast slice used for a sample
+        interval_start: The start of the interval with respect to t0
+        interval_end: The end of the interval with respect to t0
 
 
     Returns:
@@ -80,8 +80,8 @@ def trim_contiguous_time_periods(
     """
     contiguous_time_periods = contiguous_time_periods.copy()
 
-    contiguous_time_periods["start_dt"] += history_duration
-    contiguous_time_periods["end_dt"] -= forecast_duration
+    contiguous_time_periods["start_dt"] -= interval_start
+    contiguous_time_periods["end_dt"] -= interval_end
 
     valid_mask = contiguous_time_periods["start_dt"] <= contiguous_time_periods["end_dt"]
     contiguous_time_periods = contiguous_time_periods.loc[valid_mask]
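
Because interval_start is normally negative (history) and interval_end positive (forecast), subtracting them trims each period from both ends exactly as the old `+= history_duration` / `-= forecast_duration` pair did. A self-contained check of the arithmetic, with made-up times:

    import pandas as pd

    # One contiguous data period from 00:00 to 12:00
    periods = pd.DataFrame(
        {
            "start_dt": [pd.Timestamp("2024-01-01 00:00")],
            "end_dt": [pd.Timestamp("2024-01-01 12:00")],
        }
    )

    interval_start = pd.Timedelta(minutes=-60)  # 60 min of history before t0
    interval_end = pd.Timedelta(minutes=120)    # 120 min of forecast after t0

    # Same arithmetic as the new trim_contiguous_time_periods body
    periods["start_dt"] -= interval_start  # 00:00 - (-60 min) -> 01:00
    periods["end_dt"] -= interval_end      # 12:00 - 120 min   -> 10:00

    # Any t0 in [01:00, 10:00] keeps its full sampling window inside the data
    print(periods)
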
@@ -92,16 +92,16 @@
 
 def find_contiguous_t0_periods(
     datetimes: pd.DatetimeIndex,
-    history_duration: pd.Timedelta,
-    forecast_duration: pd.Timedelta,
+    interval_start: pd.Timedelta,
+    interval_end: pd.Timedelta,
     sample_period_duration: pd.Timedelta,
 ) -> pd.DataFrame:
     """Return a pd.DataFrame where each row records the boundary of a contiguous time period.
 
     Args:
         datetimes: pd.DatetimeIndex. Must be sorted.
-        history_duration: Length of the historical slice used for each sample
-        forecast_duration: Length of the forecast slice used for each sample
+        interval_start: The start of the interval with respect to t0
+        interval_end: The end of the interval with respect to t0
         sample_period_duration: The sample frequency of the timeseries
 
 
@@ -109,7 +109,7 @@ def find_contiguous_t0_periods(
         pd.DataFrame where each row represents a single time period. The pd.DataFrame
         has two columns: `start_dt` and `end_dt` (where 'dt' is short for 'datetime').
     """
-    total_duration = history_duration + forecast_duration
+    total_duration = interval_end - interval_start
 
     contiguous_time_periods = find_contiguous_time_periods(
         datetimes=datetimes,
@@ -119,8 +119,8 @@
 
     contiguous_t0_periods = trim_contiguous_time_periods(
         contiguous_time_periods=contiguous_time_periods,
-        history_duration=history_duration,
-        forecast_duration=forecast_duration,
+        interval_start=interval_start,
+        interval_end=interval_end,
     )
 
     assert len(contiguous_t0_periods) > 0
@@ -128,92 +128,57 @@
     return contiguous_t0_periods
 
 
-def _find_contiguous_t0_periods_nwp(
-    ds,
-    history_duration: pd.Timedelta,
-    forecast_duration: pd.Timedelta,
-    max_staleness: pd.Timedelta | None = None,
-    max_dropout: pd.Timedelta = pd.Timedelta(0),
-    time_dim: str = "init_time_utc",
-    end_buffer: pd.Timedelta = pd.Timedelta(0),
-):
-
-    assert "step" in ds.coords
-    # It is possible to use up to this amount of max staleness for the dataset and slice
-    # required
-    possible_max_staleness = (
-        pd.Timedelta(ds["step"].max().item())
-        - forecast_duration
-        - end_buffer
-    )
-
-    # If max_staleness is set to None we set it based on the max step ahead of the input
-    # forecast data
-    if max_staleness is None:
-        max_staleness = possible_max_staleness
-    else:
-        # Make sure the max acceptable staleness isn't longer than the max possible
-        assert max_staleness <= possible_max_staleness
-        max_staleness = max_staleness
-
-    contiguous_time_periods = find_contiguous_t0_periods_nwp(
-        datetimes=pd.DatetimeIndex(ds[time_dim]),
-        history_duration=history_duration,
-        max_staleness=max_staleness,
-        max_dropout=max_dropout,
-    )
-    return contiguous_time_periods
-
-
-
 def find_contiguous_t0_periods_nwp(
-    datetimes: pd.DatetimeIndex,
-    history_duration: pd.Timedelta,
+    init_times: pd.DatetimeIndex,
+    interval_start: pd.Timedelta,
     max_staleness: pd.Timedelta,
     max_dropout: pd.Timedelta = pd.Timedelta(0),
+    first_forecast_step: pd.Timedelta = pd.Timedelta(0),
 ) -> pd.DataFrame:
     """Get all time periods from the NWP init times which are valid as t0 datetimes.
 
     Args:
-        datetimes: Sorted pd.DatetimeIndex
-        history_duration: Length of the historical slice used for a sample
-        max_staleness: Up to how long after an NWP forecast init_time are we willing to use the
-            forecast. Each init time will only be used up to this t0 time regardless of the forecast
-            valid time.
+        init_times: The initialisation times of the available forecasts
+        interval_start: The start of the desired data interval with respect to t0
+        max_staleness: Up to how long after an init time are we willing to use the forecast. Each
+            init time will only be used up to this t0 time regardless of the forecast valid time.
         max_dropout: What is the maximum amount of dropout that will be used. This must be <=
            max_staleness.
+        first_forecast_step: The timedelta of the first step of the forecast. By default we assume
+            the first valid time of the forecast is the same as its init time.
 
     Returns:
        pd.DataFrame where each row represents a single time period. The pd.DataFrame
        has two columns: `start_dt` and `end_dt` (where 'dt' is short for 'datetime').
    """
     # Sanity checks.
-    assert len(datetimes) > 0
-    assert datetimes.is_monotonic_increasing
-    assert datetimes.is_unique
-    assert history_duration >= pd.Timedelta(0)
+    assert len(init_times) > 0
+    assert init_times.is_monotonic_increasing
+    assert init_times.is_unique
     assert max_staleness >= pd.Timedelta(0)
-    assert max_dropout <= max_staleness
+    assert pd.Timedelta(0) <= max_dropout <= max_staleness
 
-    hist_drop_buffer = max(history_duration, max_dropout)
+    hist_drop_buffer = max(first_forecast_step - interval_start, max_dropout)
 
     # Store contiguous periods
     contiguous_periods = []
 
-    # Start first period allowing for history slice and max dropout
-    start_this_period = datetimes[0] + hist_drop_buffer
+    # Begin the first period allowing for the time to the first_forecast_step, the length of the
+    # interval sampled from before t0, and the dropout
+    start_this_period = init_times[0] + hist_drop_buffer
 
     # The first forecast is valid up to the max staleness
-    end_this_period = datetimes[0] + max_staleness
-
-    for dt_init in datetimes[1:]:
-        # If the previous init time becomes stale before the next init becomes valid whilst also
-        # considering dropout - then the contiguous period breaks, and new starts with considering
-        # dropout and history duration
-        if end_this_period < dt_init + max_dropout:
+    end_this_period = init_times[0] + max_staleness
+
+    for dt_init in init_times[1:]:
+        # If the previous init time becomes stale before the next init becomes valid (whilst also
+        # considering dropout) then the contiguous period breaks
+        # Else if the previous init time becomes stale before the first step of the next forecast
+        # then this also causes a break in the contiguous period
+        if end_this_period < dt_init + max(max_dropout, first_forecast_step):
            contiguous_periods.append([start_this_period, end_this_period])
-
-            # And start a new period
+            # The new period begins with the same conditions as the first period
            start_this_period = dt_init + hist_drop_buffer
            end_this_period = dt_init + max_staleness
 
@@ -39,23 +39,14 @@ def _sel_fillinterp(
 def select_time_slice(
     ds: xr.DataArray,
     t0: pd.Timestamp,
+    interval_start: pd.Timedelta,
+    interval_end: pd.Timedelta,
     sample_period_duration: pd.Timedelta,
-    history_duration: pd.Timedelta | None = None,
-    forecast_duration: pd.Timedelta | None = None,
-    interval_start: pd.Timedelta | None = None,
-    interval_end: pd.Timedelta | None = None,
     fill_selection: bool = False,
     max_steps_gap: int = 0,
 ):
     """Select a time slice from a Dataset or DataArray."""
-    used_duration = history_duration is not None and forecast_duration is not None
-    used_intervals = interval_start is not None and interval_end is not None
-    assert used_duration ^ used_intervals, "Either durations, or intervals must be supplied"
     assert max_steps_gap >= 0, "max_steps_gap must be >= 0 "
-
-    if used_duration:
-        interval_start = - history_duration
-        interval_end = forecast_duration
 
     if fill_selection and max_steps_gap == 0:
         _sel = _sel_fillnan
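
Call sites change accordingly: the either/or pair of keyword groups is gone, and the window is always passed as signed offsets from t0. A hedged before/after sketch of a caller (ds, t0 and freq are placeholder names):

    # Before: durations, converted to intervals internally
    # select_time_slice(
    #     ds, t0, sample_period_duration=freq,
    #     history_duration=pd.Timedelta("1h"), forecast_duration=pd.Timedelta("2h"),
    # )
    #
    # After: the same window as signed offsets from t0
    # select_time_slice(
    #     ds, t0,
    #     interval_start=pd.Timedelta("-1h"), interval_end=pd.Timedelta("2h"),
    #     sample_period_duration=freq,
    # )
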
@@ -75,11 +66,11 @@
 
 
 def select_time_slice_nwp(
-    ds: xr.DataArray,
+    da: xr.DataArray,
     t0: pd.Timestamp,
+    interval_start: pd.Timedelta,
+    interval_end: pd.Timedelta,
     sample_period_duration: pd.Timedelta,
-    history_duration: pd.Timedelta,
-    forecast_duration: pd.Timedelta,
     dropout_timedeltas: list[pd.Timedelta] | None = None,
     dropout_frac: float | None = 0,
     accum_channels: list[str] = [],
@@ -92,31 +83,31 @@ def select_time_slice_nwp(
         ), "dropout timedeltas must be negative"
         assert len(dropout_timedeltas) >= 1
     assert 0 <= dropout_frac <= 1
-    _consider_dropout = (dropout_timedeltas is not None) and dropout_frac > 0
+    consider_dropout = (dropout_timedeltas is not None) and dropout_frac > 0
 
 
     # The accumulation and non-accumulation channels
     accum_channels = np.intersect1d(
-        ds[channel_dim_name].values, accum_channels
+        da[channel_dim_name].values, accum_channels
     )
     non_accum_channels = np.setdiff1d(
-        ds[channel_dim_name].values, accum_channels
+        da[channel_dim_name].values, accum_channels
     )
 
-    start_dt = (t0 - history_duration).ceil(sample_period_duration)
-    end_dt = (t0 + forecast_duration).ceil(sample_period_duration)
+    start_dt = (t0 + interval_start).ceil(sample_period_duration)
+    end_dt = (t0 + interval_end).ceil(sample_period_duration)
 
     target_times = pd.date_range(start_dt, end_dt, freq=sample_period_duration)
 
     # Maybe apply NWP dropout
-    if _consider_dropout and (np.random.uniform() < dropout_frac):
+    if consider_dropout and (np.random.uniform() < dropout_frac):
         dt = np.random.choice(dropout_timedeltas)
         t0_available = t0 + dt
     else:
         t0_available = t0
 
     # Forecasts made up to and including t0
-    available_init_times = ds.init_time_utc.sel(
+    available_init_times = da.init_time_utc.sel(
         init_time_utc=slice(None, t0_available)
     )
 
@@ -139,7 +130,7 @@ def select_time_slice_nwp(
     step_indexer = xr.DataArray(steps, coords=coords)
 
     if len(accum_channels) == 0:
-        xr_sel = ds.sel(step=step_indexer, init_time_utc=init_time_indexer)
+        da_sel = da.sel(step=step_indexer, init_time_utc=init_time_indexer)
 
     else:
         # First minimise the size of the dataset we are diffing
@@ -149,7 +140,7 @@ def select_time_slice_nwp(
         min_step = min(steps)
         max_step = max(steps) + sample_period_duration
 
-        xr_min = ds.sel(
+        da_min = da.sel(
             {
                 "init_time_utc": unique_init_times,
                 "step": slice(min_step, max_step),
@@ -157,28 +148,28 @@ def select_time_slice_nwp(
         )
 
         # Slice out the data which does not need to be diffed
-        xr_non_accum = xr_min.sel({channel_dim_name: non_accum_channels})
-        xr_sel_non_accum = xr_non_accum.sel(
+        da_non_accum = da_min.sel({channel_dim_name: non_accum_channels})
+        da_sel_non_accum = da_non_accum.sel(
             step=step_indexer, init_time_utc=init_time_indexer
         )
 
         # Slice out the channels which need to be diffed
-        xr_accum = xr_min.sel({channel_dim_name: accum_channels})
+        da_accum = da_min.sel({channel_dim_name: accum_channels})
 
         # Take the diff and slice requested data
-        xr_accum = xr_accum.diff(dim="step", label="lower")
-        xr_sel_accum = xr_accum.sel(step=step_indexer, init_time_utc=init_time_indexer)
+        da_accum = da_accum.diff(dim="step", label="lower")
+        da_sel_accum = da_accum.sel(step=step_indexer, init_time_utc=init_time_indexer)
 
         # Join diffed and non-diffed variables
-        xr_sel = xr.concat([xr_sel_non_accum, xr_sel_accum], dim=channel_dim_name)
+        da_sel = xr.concat([da_sel_non_accum, da_sel_accum], dim=channel_dim_name)
 
         # Reorder the variable back to the original order
-        xr_sel = xr_sel.sel({channel_dim_name: ds[channel_dim_name].values})
+        da_sel = da_sel.sel({channel_dim_name: da[channel_dim_name].values})
 
         # Rename the diffed channels
-        xr_sel[channel_dim_name] = [
+        da_sel[channel_dim_name] = [
             f"diff_{v}" if v in accum_channels else v
-            for v in xr_sel[channel_dim_name].values
+            for v in da_sel[channel_dim_name].values
         ]
 
-    return xr_sel
+    return da_sel
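
Apart from the xr_/ds naming, the accumulated-channel handling is unchanged: accumulated variables are differenced along step and re-labelled with a diff_ prefix. A standalone toy reproduction of that pattern (fake data, channel names arbitrary):

    import numpy as np
    import pandas as pd
    import xarray as xr

    da = xr.DataArray(
        np.cumsum(np.ones((4, 2)), axis=0),  # fake accumulated values, 4 steps
        dims=("step", "channel"),
        coords={
            "step": pd.timedelta_range(0, periods=4, freq="1h"),
            "channel": ["dswrf", "t2m"],
        },
    )
    accum_channels = ["dswrf"]

    # Difference the accumulated channel along step (label="lower" keeps the
    # left edge of each step interval); trim the other channel to match
    da_accum = da.sel(channel=accum_channels).diff(dim="step", label="lower")
    da_rest = da.sel(channel=["t2m"]).isel(step=slice(0, -1))

    out = xr.concat([da_rest, da_accum], dim="channel")
    out["channel"] = [f"diff_{c}" if c in accum_channels else c for c in out.channel.values]
    print(out.channel.values)  # ['t2m' 'diff_dswrf']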