ocf-data-sampler 0.0.26__tar.gz → 0.0.28__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.



Files changed (72)
  1. {ocf_data_sampler-0.0.26/ocf_data_sampler.egg-info → ocf_data_sampler-0.0.28}/PKG-INFO +1 -1
  2. {ocf_data_sampler-0.0.26 → ocf_data_sampler-0.0.28}/ocf_data_sampler/config/model.py +46 -46
  3. {ocf_data_sampler-0.0.26 → ocf_data_sampler-0.0.28}/ocf_data_sampler/select/find_contiguous_time_periods.py +40 -75
  4. {ocf_data_sampler-0.0.26 → ocf_data_sampler-0.0.28}/ocf_data_sampler/select/select_time_slice.py +24 -33
  5. {ocf_data_sampler-0.0.26 → ocf_data_sampler-0.0.28}/ocf_data_sampler/select/time_slice_for_dataset.py +17 -16
  6. {ocf_data_sampler-0.0.26 → ocf_data_sampler-0.0.28}/ocf_data_sampler/torch_datasets/process_and_combine.py +12 -13
  7. {ocf_data_sampler-0.0.26 → ocf_data_sampler-0.0.28}/ocf_data_sampler/torch_datasets/pvnet_uk_regional.py +1 -1
  8. {ocf_data_sampler-0.0.26 → ocf_data_sampler-0.0.28}/ocf_data_sampler/torch_datasets/site.py +10 -10
  9. {ocf_data_sampler-0.0.26 → ocf_data_sampler-0.0.28}/ocf_data_sampler/torch_datasets/valid_time_periods.py +19 -11
  10. ocf_data_sampler-0.0.26/ocf_data_sampler/time_functions.py → ocf_data_sampler-0.0.28/ocf_data_sampler/utils.py +1 -2
  11. {ocf_data_sampler-0.0.26 → ocf_data_sampler-0.0.28/ocf_data_sampler.egg-info}/PKG-INFO +1 -1
  12. {ocf_data_sampler-0.0.26 → ocf_data_sampler-0.0.28}/ocf_data_sampler.egg-info/SOURCES.txt +1 -1
  13. {ocf_data_sampler-0.0.26 → ocf_data_sampler-0.0.28}/pyproject.toml +1 -1
  14. {ocf_data_sampler-0.0.26 → ocf_data_sampler-0.0.28}/tests/config/test_config.py +14 -8
  15. {ocf_data_sampler-0.0.26 → ocf_data_sampler-0.0.28}/tests/conftest.py +7 -5
  16. {ocf_data_sampler-0.0.26 → ocf_data_sampler-0.0.28}/tests/select/test_find_contiguous_time_periods.py +8 -8
  17. {ocf_data_sampler-0.0.26 → ocf_data_sampler-0.0.28}/tests/select/test_select_time_slice.py +31 -43
  18. {ocf_data_sampler-0.0.26 → ocf_data_sampler-0.0.28}/LICENSE +0 -0
  19. {ocf_data_sampler-0.0.26 → ocf_data_sampler-0.0.28}/MANIFEST.in +0 -0
  20. {ocf_data_sampler-0.0.26 → ocf_data_sampler-0.0.28}/README.md +0 -0
  21. {ocf_data_sampler-0.0.26 → ocf_data_sampler-0.0.28}/ocf_data_sampler/__init__.py +0 -0
  22. {ocf_data_sampler-0.0.26 → ocf_data_sampler-0.0.28}/ocf_data_sampler/config/__init__.py +0 -0
  23. {ocf_data_sampler-0.0.26 → ocf_data_sampler-0.0.28}/ocf_data_sampler/config/load.py +0 -0
  24. {ocf_data_sampler-0.0.26 → ocf_data_sampler-0.0.28}/ocf_data_sampler/config/save.py +0 -0
  25. {ocf_data_sampler-0.0.26 → ocf_data_sampler-0.0.28}/ocf_data_sampler/constants.py +0 -0
  26. {ocf_data_sampler-0.0.26 → ocf_data_sampler-0.0.28}/ocf_data_sampler/data/uk_gsp_locations.csv +0 -0
  27. {ocf_data_sampler-0.0.26 → ocf_data_sampler-0.0.28}/ocf_data_sampler/load/__init__.py +0 -0
  28. {ocf_data_sampler-0.0.26 → ocf_data_sampler-0.0.28}/ocf_data_sampler/load/gsp.py +0 -0
  29. {ocf_data_sampler-0.0.26 → ocf_data_sampler-0.0.28}/ocf_data_sampler/load/load_dataset.py +0 -0
  30. {ocf_data_sampler-0.0.26 → ocf_data_sampler-0.0.28}/ocf_data_sampler/load/nwp/__init__.py +0 -0
  31. {ocf_data_sampler-0.0.26 → ocf_data_sampler-0.0.28}/ocf_data_sampler/load/nwp/nwp.py +0 -0
  32. {ocf_data_sampler-0.0.26 → ocf_data_sampler-0.0.28}/ocf_data_sampler/load/nwp/providers/__init__.py +0 -0
  33. {ocf_data_sampler-0.0.26 → ocf_data_sampler-0.0.28}/ocf_data_sampler/load/nwp/providers/ecmwf.py +0 -0
  34. {ocf_data_sampler-0.0.26 → ocf_data_sampler-0.0.28}/ocf_data_sampler/load/nwp/providers/ukv.py +0 -0
  35. {ocf_data_sampler-0.0.26 → ocf_data_sampler-0.0.28}/ocf_data_sampler/load/nwp/providers/utils.py +0 -0
  36. {ocf_data_sampler-0.0.26 → ocf_data_sampler-0.0.28}/ocf_data_sampler/load/satellite.py +0 -0
  37. {ocf_data_sampler-0.0.26 → ocf_data_sampler-0.0.28}/ocf_data_sampler/load/site.py +0 -0
  38. {ocf_data_sampler-0.0.26 → ocf_data_sampler-0.0.28}/ocf_data_sampler/load/utils.py +0 -0
  39. {ocf_data_sampler-0.0.26 → ocf_data_sampler-0.0.28}/ocf_data_sampler/numpy_batch/__init__.py +0 -0
  40. {ocf_data_sampler-0.0.26 → ocf_data_sampler-0.0.28}/ocf_data_sampler/numpy_batch/gsp.py +0 -0
  41. {ocf_data_sampler-0.0.26 → ocf_data_sampler-0.0.28}/ocf_data_sampler/numpy_batch/nwp.py +0 -0
  42. {ocf_data_sampler-0.0.26 → ocf_data_sampler-0.0.28}/ocf_data_sampler/numpy_batch/satellite.py +0 -0
  43. {ocf_data_sampler-0.0.26 → ocf_data_sampler-0.0.28}/ocf_data_sampler/numpy_batch/site.py +0 -0
  44. {ocf_data_sampler-0.0.26 → ocf_data_sampler-0.0.28}/ocf_data_sampler/numpy_batch/sun_position.py +0 -0
  45. {ocf_data_sampler-0.0.26 → ocf_data_sampler-0.0.28}/ocf_data_sampler/select/__init__.py +0 -0
  46. {ocf_data_sampler-0.0.26 → ocf_data_sampler-0.0.28}/ocf_data_sampler/select/dropout.py +0 -0
  47. {ocf_data_sampler-0.0.26 → ocf_data_sampler-0.0.28}/ocf_data_sampler/select/fill_time_periods.py +0 -0
  48. {ocf_data_sampler-0.0.26 → ocf_data_sampler-0.0.28}/ocf_data_sampler/select/geospatial.py +0 -0
  49. {ocf_data_sampler-0.0.26 → ocf_data_sampler-0.0.28}/ocf_data_sampler/select/location.py +0 -0
  50. {ocf_data_sampler-0.0.26 → ocf_data_sampler-0.0.28}/ocf_data_sampler/select/select_spatial_slice.py +0 -0
  51. {ocf_data_sampler-0.0.26 → ocf_data_sampler-0.0.28}/ocf_data_sampler/select/spatial_slice_for_dataset.py +0 -0
  52. {ocf_data_sampler-0.0.26 → ocf_data_sampler-0.0.28}/ocf_data_sampler/torch_datasets/__init__.py +0 -0
  53. {ocf_data_sampler-0.0.26 → ocf_data_sampler-0.0.28}/ocf_data_sampler.egg-info/dependency_links.txt +0 -0
  54. {ocf_data_sampler-0.0.26 → ocf_data_sampler-0.0.28}/ocf_data_sampler.egg-info/requires.txt +0 -0
  55. {ocf_data_sampler-0.0.26 → ocf_data_sampler-0.0.28}/ocf_data_sampler.egg-info/top_level.txt +0 -0
  56. {ocf_data_sampler-0.0.26 → ocf_data_sampler-0.0.28}/scripts/refactor_site.py +0 -0
  57. {ocf_data_sampler-0.0.26 → ocf_data_sampler-0.0.28}/setup.cfg +0 -0
  58. {ocf_data_sampler-0.0.26 → ocf_data_sampler-0.0.28}/tests/__init__.py +0 -0
  59. {ocf_data_sampler-0.0.26 → ocf_data_sampler-0.0.28}/tests/load/test_load_gsp.py +0 -0
  60. {ocf_data_sampler-0.0.26 → ocf_data_sampler-0.0.28}/tests/load/test_load_nwp.py +0 -0
  61. {ocf_data_sampler-0.0.26 → ocf_data_sampler-0.0.28}/tests/load/test_load_satellite.py +0 -0
  62. {ocf_data_sampler-0.0.26 → ocf_data_sampler-0.0.28}/tests/load/test_load_sites.py +0 -0
  63. {ocf_data_sampler-0.0.26 → ocf_data_sampler-0.0.28}/tests/numpy_batch/test_gsp.py +0 -0
  64. {ocf_data_sampler-0.0.26 → ocf_data_sampler-0.0.28}/tests/numpy_batch/test_nwp.py +0 -0
  65. {ocf_data_sampler-0.0.26 → ocf_data_sampler-0.0.28}/tests/numpy_batch/test_satellite.py +0 -0
  66. {ocf_data_sampler-0.0.26 → ocf_data_sampler-0.0.28}/tests/numpy_batch/test_sun_position.py +0 -0
  67. {ocf_data_sampler-0.0.26 → ocf_data_sampler-0.0.28}/tests/select/test_dropout.py +0 -0
  68. {ocf_data_sampler-0.0.26 → ocf_data_sampler-0.0.28}/tests/select/test_fill_time_periods.py +0 -0
  69. {ocf_data_sampler-0.0.26 → ocf_data_sampler-0.0.28}/tests/select/test_location.py +0 -0
  70. {ocf_data_sampler-0.0.26 → ocf_data_sampler-0.0.28}/tests/select/test_select_spatial_slice.py +0 -0
  71. {ocf_data_sampler-0.0.26 → ocf_data_sampler-0.0.28}/tests/torch_datasets/test_pvnet_uk_regional.py +0 -0
  72. {ocf_data_sampler-0.0.26 → ocf_data_sampler-0.0.28}/tests/torch_datasets/test_site.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ocf_data_sampler
-Version: 0.0.26
+Version: 0.0.28
 Summary: Sample from weather data for renewable energy prediction
 Author: James Fulton, Peter Dudfield, and the Open Climate Fix team
 Author-email: info@openclimatefix.org
ocf_data_sampler/config/model.py
@@ -14,7 +14,8 @@ import logging
 from typing import Dict, List, Optional
 from typing_extensions import Self
 
-from pydantic import BaseModel, Field, RootModel, field_validator, model_validator
+from pydantic import BaseModel, Field, RootModel, field_validator, ValidationInfo, model_validator
+
 from ocf_data_sampler.constants import NWP_PROVIDERS
 
 logger = logging.getLogger(__name__)
@@ -40,6 +41,45 @@ class General(Base):
     )
 
 
+class TimeWindowMixin(Base):
+    """Mixin class, to add interval start, end and resolution minutes"""
+
+    time_resolution_minutes: int = Field(
+        ...,
+        gt=0,
+        description="The temporal resolution of the data in minutes",
+    )
+
+    interval_start_minutes: int = Field(
+        ...,
+        description="Data interval starts at `t0 + interval_start_minutes`",
+    )
+
+    interval_end_minutes: int = Field(
+        ...,
+        description="Data interval ends at `t0 + interval_end_minutes`",
+    )
+
+    @model_validator(mode='after')
+    def check_interval_range(cls, values):
+        if values.interval_start_minutes > values.interval_end_minutes:
+            raise ValueError('interval_start_minutes must be <= interval_end_minutes')
+        return values
+
+    @field_validator("interval_start_minutes")
+    def interval_start_minutes_divide_by_time_resolution(cls, v: int, info: ValidationInfo) -> int:
+        if v % info.data["time_resolution_minutes"] != 0:
+            raise ValueError("interval_start_minutes must be divisible by time_resolution_minutes")
+        return v
+
+    @field_validator("interval_end_minutes")
+    def interval_end_minutes_divide_by_time_resolution(cls, v: int, info: ValidationInfo) -> int:
+        if v % info.data["time_resolution_minutes"] != 0:
+            raise ValueError("interval_end_minutes must be divisible by time_resolution_minutes")
+        return v
+
+
+
 # noinspection PyMethodParameters
 class DropoutMixin(Base):
     """Mixin class, to add dropout minutes"""
@@ -76,54 +116,18 @@ class DropoutMixin(Base):
         return self
 
 
-# noinspection PyMethodParameters
-class TimeWindowMixin(Base):
-    """Time resolution mix in"""
-
-    time_resolution_minutes: int = Field(
-        ...,
-        gt=0,
-        description="The temporal resolution of the data in minutes",
-    )
-
-    forecast_minutes: int = Field(
-        ...,
-        ge=0,
-        description="how many minutes to forecast in the future",
-    )
-    history_minutes: int = Field(
-        ...,
-        ge=0,
-        description="how many historic minutes to use",
-    )
-
-    @field_validator("forecast_minutes")
-    def forecast_minutes_divide_by_time_resolution(cls, v, values) -> int:
-        if v % values.data["time_resolution_minutes"] != 0:
-            message = "Forecast duration must be divisible by time resolution"
-            logger.error(message)
-            raise Exception(message)
-        return v
-
-    @field_validator("history_minutes")
-    def history_minutes_divide_by_time_resolution(cls, v, values) -> int:
-        if v % values.data["time_resolution_minutes"] != 0:
-            message = "History duration must be divisible by time resolution"
-            logger.error(message)
-            raise Exception(message)
-        return v
-
-
 class SpatialWindowMixin(Base):
     """Mixin class, to add path and image size"""
 
     image_size_pixels_height: int = Field(
         ...,
+        ge=0,
         description="The number of pixels of the height of the region of interest",
     )
 
     image_size_pixels_width: int = Field(
         ...,
+        ge=0,
         description="The number of pixels of the width of the region of interest",
     )
 
@@ -140,10 +144,6 @@ class Satellite(TimeWindowMixin, DropoutMixin, SpatialWindowMixin):
         ..., description="the satellite channels that are used"
     )
 
-    live_delay_minutes: int = Field(
-        ..., description="The expected delay in minutes of the satellite data"
-    )
-
 
 # noinspection PyMethodParameters
 class NWP(TimeWindowMixin, DropoutMixin, SpatialWindowMixin):
@@ -169,6 +169,7 @@ class NWP(TimeWindowMixin, DropoutMixin, SpatialWindowMixin):
         " the maximum forecast horizon of the NWP and the requested forecast length.",
     )
 
+
     @field_validator("provider")
     def validate_provider(cls, v: str) -> str:
         """Validate 'provider'"""
@@ -227,11 +228,10 @@ class Site(TimeWindowMixin, DropoutMixin):
     # TODO validate the csv for metadata
 
 
+
 # noinspection PyPep8Naming
 class InputData(Base):
-    """
-    Input data model.
-    """
+    """Input data model"""
 
     satellite: Optional[Satellite] = None
     nwp: Optional[MultiNWP] = None
ocf_data_sampler/select/find_contiguous_time_periods.py
@@ -63,16 +63,16 @@ def find_contiguous_time_periods(
 
 def trim_contiguous_time_periods(
     contiguous_time_periods: pd.DataFrame,
-    history_duration: pd.Timedelta,
-    forecast_duration: pd.Timedelta,
+    interval_start: pd.Timedelta,
+    interval_end: pd.Timedelta,
 ) -> pd.DataFrame:
     """Trim the contiguous time periods to allow for history and forecast durations.
 
     Args:
         contiguous_time_periods: DataFrame where each row represents a single time period. The
             DataFrame must have `start_dt` and `end_dt` columns.
-        history_duration: Length of the historical slice used for a sample
-        forecast_duration: Length of the forecast slice used for a sample
+        interval_start: The start of the interval with respect to t0
+        interval_end: The end of the interval with respect to t0
 
 
     Returns:
@@ -80,8 +80,8 @@ def trim_contiguous_time_periods(
     """
     contiguous_time_periods = contiguous_time_periods.copy()
 
-    contiguous_time_periods["start_dt"] += history_duration
-    contiguous_time_periods["end_dt"] -= forecast_duration
+    contiguous_time_periods["start_dt"] -= interval_start
+    contiguous_time_periods["end_dt"] -= interval_end
 
     valid_mask = contiguous_time_periods["start_dt"] <= contiguous_time_periods["end_dt"]
     contiguous_time_periods = contiguous_time_periods.loc[valid_mask]
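
The sign convention deserves a note: interval_start is typically negative (history before t0) and interval_end positive (forecast after t0), so subtracting both reproduces the old `+= history_duration` / `-= forecast_duration` trimming. A quick illustrative check with made-up values:

import pandas as pd

# One contiguous data period from midnight to noon
period = pd.DataFrame(
    {"start_dt": [pd.Timestamp("2024-01-01 00:00")],
     "end_dt": [pd.Timestamp("2024-01-01 12:00")]}
)

interval_start = pd.Timedelta(minutes=-60)   # sample begins 1 h before t0
interval_end = pd.Timedelta(minutes=480)     # sample ends 8 h after t0

# start_dt -= interval_start pushes the earliest valid t0 forward by 1 h
print(period["start_dt"] - interval_start)   # 2024-01-01 01:00
# end_dt -= interval_end pulls the latest valid t0 back by 8 h
print(period["end_dt"] - interval_end)       # 2024-01-01 04:00
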
@@ -92,16 +92,16 @@ def trim_contiguous_time_periods(
92
92
 
93
93
  def find_contiguous_t0_periods(
94
94
  datetimes: pd.DatetimeIndex,
95
- history_duration: pd.Timedelta,
96
- forecast_duration: pd.Timedelta,
95
+ interval_start: pd.Timedelta,
96
+ interval_end: pd.Timedelta,
97
97
  sample_period_duration: pd.Timedelta,
98
98
  ) -> pd.DataFrame:
99
99
  """Return a pd.DataFrame where each row records the boundary of a contiguous time period.
100
100
 
101
101
  Args:
102
102
  datetimes: pd.DatetimeIndex. Must be sorted.
103
- history_duration: Length of the historical slice used for each sample
104
- forecast_duration: Length of the forecast slice used for each sample
103
+ interval_start: The start of the interval with respect to t0
104
+ interval_end: The end of the interval with respect to t0
105
105
  sample_period_duration: The sample frequency of the timeseries
106
106
 
107
107
 
@@ -109,7 +109,7 @@ def find_contiguous_t0_periods(
     pd.DataFrame where each row represents a single time period. The pd.DataFrame
     has two columns: `start_dt` and `end_dt` (where 'dt' is short for 'datetime').
     """
-    total_duration = history_duration + forecast_duration
+    total_duration = interval_end - interval_start
 
     contiguous_time_periods = find_contiguous_time_periods(
         datetimes=datetimes,
@@ -119,8 +119,8 @@ def find_contiguous_t0_periods(
 
     contiguous_t0_periods = trim_contiguous_time_periods(
         contiguous_time_periods=contiguous_time_periods,
-        history_duration=history_duration,
-        forecast_duration=forecast_duration,
+        interval_start=interval_start,
+        interval_end=interval_end,
     )
 
     assert len(contiguous_t0_periods) > 0
@@ -128,92 +128,57 @@ def find_contiguous_t0_periods(
     return contiguous_t0_periods
 
 
-def _find_contiguous_t0_periods_nwp(
-    ds,
-    history_duration: pd.Timedelta,
-    forecast_duration: pd.Timedelta,
-    max_staleness: pd.Timedelta | None = None,
-    max_dropout: pd.Timedelta = pd.Timedelta(0),
-    time_dim: str = "init_time_utc",
-    end_buffer: pd.Timedelta = pd.Timedelta(0),
-):
-
-    assert "step" in ds.coords
-    # It is possible to use up to this amount of max staleness for the dataset and slice
-    # required
-    possible_max_staleness = (
-        pd.Timedelta(ds["step"].max().item())
-        - forecast_duration
-        - end_buffer
-    )
-
-    # If max_staleness is set to None we set it based on the max step ahead of the input
-    # forecast data
-    if max_staleness is None:
-        max_staleness = possible_max_staleness
-    else:
-        # Make sure the max acceptable staleness isn't longer than the max possible
-        assert max_staleness <= possible_max_staleness
-        max_staleness = max_staleness
-
-    contiguous_time_periods = find_contiguous_t0_periods_nwp(
-        datetimes=pd.DatetimeIndex(ds[time_dim]),
-        history_duration=history_duration,
-        max_staleness=max_staleness,
-        max_dropout=max_dropout,
-    )
-    return contiguous_time_periods
-
-
-
 def find_contiguous_t0_periods_nwp(
-    datetimes: pd.DatetimeIndex,
-    history_duration: pd.Timedelta,
+    init_times: pd.DatetimeIndex,
+    interval_start: pd.Timedelta,
     max_staleness: pd.Timedelta,
     max_dropout: pd.Timedelta = pd.Timedelta(0),
+    first_forecast_step: pd.Timedelta = pd.Timedelta(0),
+
 ) -> pd.DataFrame:
     """Get all time periods from the NWP init times which are valid as t0 datetimes.
 
     Args:
-        datetimes: Sorted pd.DatetimeIndex
-        history_duration: Length of the historical slice used for a sample
-        max_staleness: Up to how long after an NWP forecast init_time are we willing to use the
-            forecast. Each init time will only be used up to this t0 time regardless of the forecast
-            valid time.
+        init_times: The initialisation times of the available forecasts
+        interval_start: The start of the desired data interval with respect to t0
+        max_staleness: Up to how long after an init time are we willing to use the forecast. Each
+            init time will only be used up to this t0 time regardless of the forecast valid time.
         max_dropout: What is the maximum amount of dropout that will be used. This must be <=
            max_staleness.
+        first_forecast_step: The timedelta of the first step of the forecast. By default we assume
+            the first valid time of the forecast is the same as its init time.
 
     Returns:
        pd.DataFrame where each row represents a single time period. The pd.DataFrame
       has two columns: `start_dt` and `end_dt` (where 'dt' is short for 'datetime').
    """
    # Sanity checks.
-    assert len(datetimes) > 0
-    assert datetimes.is_monotonic_increasing
-    assert datetimes.is_unique
-    assert history_duration >= pd.Timedelta(0)
+    assert len(init_times) > 0
+    assert init_times.is_monotonic_increasing
+    assert init_times.is_unique
     assert max_staleness >= pd.Timedelta(0)
-    assert max_dropout <= max_staleness
+    assert pd.Timedelta(0) <= max_dropout <= max_staleness
 
-    hist_drop_buffer = max(history_duration, max_dropout)
+    hist_drop_buffer = max(first_forecast_step-interval_start, max_dropout)
 
    # Store contiguous periods
    contiguous_periods = []
 
-    # Start first period allowing for history slice and max dropout
-    start_this_period = datetimes[0] + hist_drop_buffer
+    # Begin the first period allowing for the time to the first_forecast_step, the length of the
+    # interval sampled from before t0, and the dropout
+    start_this_period = init_times[0] + hist_drop_buffer
 
    # The first forecast is valid up to the max staleness
-    end_this_period = datetimes[0] + max_staleness
-
-    for dt_init in datetimes[1:]:
-        # If the previous init time becomes stale before the next init becomes valid whilst also
-        # considering dropout - then the contiguous period breaks, and new starts with considering
-        # dropout and history duration
-        if end_this_period < dt_init + max_dropout:
+    end_this_period = init_times[0] + max_staleness
+
+    for dt_init in init_times[1:]:
+        # If the previous init time becomes stale before the next init becomes valid (whilst also
+        # considering dropout) then the contiguous period breaks
+        # Else if the previous init time becomes stale before the first step of the next forecast
+        # then this also causes a break in the contiguous period
+        if (end_this_period < dt_init + max(max_dropout, first_forecast_step)):
            contiguous_periods.append([start_this_period, end_this_period])
-
-            # And start a new period
+            # The new period begins with the same conditions as the first period
            start_this_period = dt_init + hist_drop_buffer
            end_this_period = dt_init + max_staleness
 
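
For orientation, a hedged usage sketch of the new signature (all values are invented; the import path follows the file this hunk belongs to):

import pandas as pd
from ocf_data_sampler.select.find_contiguous_time_periods import find_contiguous_t0_periods_nwp

# Eight hypothetical forecasts initialised every 6 hours
init_times = pd.DatetimeIndex(pd.date_range("2024-01-01", periods=8, freq="6h"))

periods = find_contiguous_t0_periods_nwp(
    init_times=init_times,
    interval_start=pd.Timedelta(minutes=-120),   # each sample uses 2 h of data before t0
    max_staleness=pd.Timedelta("12h"),           # use each forecast up to 12 h after init
    max_dropout=pd.Timedelta("3h"),
    first_forecast_step=pd.Timedelta("1h"),      # first valid time is init + 1 h
)
print(periods)  # DataFrame with `start_dt` and `end_dt` columns

With these numbers, hist_drop_buffer = max(1h + 2h, 3h) = 3h, and since each init becomes usable well before the previous one goes stale, the eight init times collapse into a single contiguous t0 period.
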
ocf_data_sampler/select/select_time_slice.py
@@ -39,23 +39,14 @@ def _sel_fillinterp(
 def select_time_slice(
     ds: xr.DataArray,
     t0: pd.Timestamp,
+    interval_start: pd.Timedelta,
+    interval_end: pd.Timedelta,
     sample_period_duration: pd.Timedelta,
-    history_duration: pd.Timedelta | None = None,
-    forecast_duration: pd.Timedelta | None = None,
-    interval_start: pd.Timedelta | None = None,
-    interval_end: pd.Timedelta | None = None,
     fill_selection: bool = False,
     max_steps_gap: int = 0,
 ):
     """Select a time slice from a Dataset or DataArray."""
-    used_duration = history_duration is not None and forecast_duration is not None
-    used_intervals = interval_start is not None and interval_end is not None
-    assert used_duration ^ used_intervals, "Either durations, or intervals must be supplied"
     assert max_steps_gap >= 0, "max_steps_gap must be >= 0 "
-
-    if used_duration:
-        interval_start = - history_duration
-        interval_end = forecast_duration
 
     if fill_selection and max_steps_gap == 0:
         _sel = _sel_fillnan
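
Callers migrate mechanically: a durations-based call becomes an intervals-based one with interval_start = -history_duration and interval_end = forecast_duration. A sketch under the assumption that the DataArray is indexed by the library's usual time_utc coordinate (toy data, not from the package):

import numpy as np
import pandas as pd
import xarray as xr
from ocf_data_sampler.select.select_time_slice import select_time_slice

# Two days of half-hourly placeholder values
times = pd.date_range("2024-01-01", periods=96, freq="30min")
ds = xr.DataArray(np.arange(96.0), coords={"time_utc": times}, dims="time_utc")

t0 = pd.Timestamp("2024-01-01 12:00")
history_duration = pd.Timedelta(minutes=60)
forecast_duration = pd.Timedelta(minutes=480)

# 0.0.26 accepted history_duration=/forecast_duration= keywords; in 0.0.28
# the same window is expressed as an interval around t0
sliced = select_time_slice(
    ds,
    t0,
    interval_start=-history_duration,   # was history_duration
    interval_end=forecast_duration,     # was forecast_duration
    sample_period_duration=pd.Timedelta(minutes=30),
)
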
@@ -75,11 +66,11 @@ def select_time_slice(
 
 
 def select_time_slice_nwp(
-    ds: xr.DataArray,
+    da: xr.DataArray,
     t0: pd.Timestamp,
+    interval_start: pd.Timedelta,
+    interval_end: pd.Timedelta,
     sample_period_duration: pd.Timedelta,
-    history_duration: pd.Timedelta,
-    forecast_duration: pd.Timedelta,
     dropout_timedeltas: list[pd.Timedelta] | None = None,
     dropout_frac: float | None = 0,
     accum_channels: list[str] = [],
@@ -92,31 +83,31 @@ def select_time_slice_nwp(
         ), "dropout timedeltas must be negative"
         assert len(dropout_timedeltas) >= 1
     assert 0 <= dropout_frac <= 1
-    _consider_dropout = (dropout_timedeltas is not None) and dropout_frac > 0
+    consider_dropout = (dropout_timedeltas is not None) and dropout_frac > 0
 
 
     # The accumulation and non-accumulation channels
     accum_channels = np.intersect1d(
-        ds[channel_dim_name].values, accum_channels
+        da[channel_dim_name].values, accum_channels
     )
     non_accum_channels = np.setdiff1d(
-        ds[channel_dim_name].values, accum_channels
+        da[channel_dim_name].values, accum_channels
     )
 
-    start_dt = (t0 - history_duration).ceil(sample_period_duration)
-    end_dt = (t0 + forecast_duration).ceil(sample_period_duration)
+    start_dt = (t0 + interval_start).ceil(sample_period_duration)
+    end_dt = (t0 + interval_end).ceil(sample_period_duration)
 
     target_times = pd.date_range(start_dt, end_dt, freq=sample_period_duration)
 
     # Maybe apply NWP dropout
-    if _consider_dropout and (np.random.uniform() < dropout_frac):
+    if consider_dropout and (np.random.uniform() < dropout_frac):
         dt = np.random.choice(dropout_timedeltas)
         t0_available = t0 + dt
     else:
         t0_available = t0
 
     # Forecasts made up to and including t0
-    available_init_times = ds.init_time_utc.sel(
+    available_init_times = da.init_time_utc.sel(
         init_time_utc=slice(None, t0_available)
     )
 
@@ -139,7 +130,7 @@ def select_time_slice_nwp(
     step_indexer = xr.DataArray(steps, coords=coords)
 
     if len(accum_channels) == 0:
-        xr_sel = ds.sel(step=step_indexer, init_time_utc=init_time_indexer)
+        da_sel = da.sel(step=step_indexer, init_time_utc=init_time_indexer)
 
     else:
         # First minimise the size of the dataset we are diffing
@@ -149,7 +140,7 @@ def select_time_slice_nwp(
         min_step = min(steps)
         max_step = max(steps) + sample_period_duration
 
-        xr_min = ds.sel(
+        da_min = da.sel(
             {
                 "init_time_utc": unique_init_times,
                 "step": slice(min_step, max_step),
@@ -157,28 +148,28 @@ def select_time_slice_nwp(
         )
 
         # Slice out the data which does not need to be diffed
-        xr_non_accum = xr_min.sel({channel_dim_name: non_accum_channels})
-        xr_sel_non_accum = xr_non_accum.sel(
+        da_non_accum = da_min.sel({channel_dim_name: non_accum_channels})
+        da_sel_non_accum = da_non_accum.sel(
             step=step_indexer, init_time_utc=init_time_indexer
         )
 
         # Slice out the channels which need to be diffed
-        xr_accum = xr_min.sel({channel_dim_name: accum_channels})
+        da_accum = da_min.sel({channel_dim_name: accum_channels})
 
         # Take the diff and slice requested data
-        xr_accum = xr_accum.diff(dim="step", label="lower")
-        xr_sel_accum = xr_accum.sel(step=step_indexer, init_time_utc=init_time_indexer)
+        da_accum = da_accum.diff(dim="step", label="lower")
+        da_sel_accum = da_accum.sel(step=step_indexer, init_time_utc=init_time_indexer)
 
         # Join diffed and non-diffed variables
-        xr_sel = xr.concat([xr_sel_non_accum, xr_sel_accum], dim=channel_dim_name)
+        da_sel = xr.concat([da_sel_non_accum, da_sel_accum], dim=channel_dim_name)
 
         # Reorder the variable back to the original order
-        xr_sel = xr_sel.sel({channel_dim_name: ds[channel_dim_name].values})
+        da_sel = da_sel.sel({channel_dim_name: da[channel_dim_name].values})
 
         # Rename the diffed channels
-        xr_sel[channel_dim_name] = [
+        da_sel[channel_dim_name] = [
             f"diff_{v}" if v in accum_channels else v
-            for v in xr_sel[channel_dim_name].values
+            for v in da_sel[channel_dim_name].values
         ]
 
-    return xr_sel
+    return da_sel
ocf_data_sampler/select/time_slice_for_dataset.py
@@ -4,7 +4,7 @@ import pandas as pd
 from ocf_data_sampler.config import Configuration
 from ocf_data_sampler.select.dropout import draw_dropout_time, apply_dropout_time
 from ocf_data_sampler.select.select_time_slice import select_time_slice_nwp, select_time_slice
-from ocf_data_sampler.time_functions import minutes
+from ocf_data_sampler.utils import minutes
 
 
 def slice_datasets_by_time(
@@ -23,19 +23,19 @@ def slice_datasets_by_time(
     sliced_datasets_dict = {}
 
     if "nwp" in datasets_dict:
-
+
         sliced_datasets_dict["nwp"] = {}
-
+
         for nwp_key, da_nwp in datasets_dict["nwp"].items():
-
+
             nwp_config = config.input_data.nwp[nwp_key]
 
             sliced_datasets_dict["nwp"][nwp_key] = select_time_slice_nwp(
                 da_nwp,
                 t0,
                 sample_period_duration=minutes(nwp_config.time_resolution_minutes),
-                history_duration=minutes(nwp_config.history_minutes),
-                forecast_duration=minutes(nwp_config.forecast_minutes),
+                interval_start=minutes(nwp_config.interval_start_minutes),
+                interval_end=minutes(nwp_config.interval_end_minutes),
                 dropout_timedeltas=minutes(nwp_config.dropout_timedeltas_minutes),
                 dropout_frac=nwp_config.dropout_fraction,
                 accum_channels=nwp_config.accum_channels,
@@ -49,8 +49,8 @@ def slice_datasets_by_time(
             datasets_dict["sat"],
             t0,
             sample_period_duration=minutes(sat_config.time_resolution_minutes),
-            interval_start=minutes(-sat_config.history_minutes),
-            interval_end=minutes(-sat_config.live_delay_minutes),
+            interval_start=minutes(sat_config.interval_start_minutes),
+            interval_end=minutes(sat_config.interval_end_minutes),
             max_steps_gap=2,
         )
 
@@ -74,15 +74,15 @@ def slice_datasets_by_time(
             datasets_dict["gsp"],
             t0,
             sample_period_duration=minutes(gsp_config.time_resolution_minutes),
-            interval_start=minutes(30),
-            interval_end=minutes(gsp_config.forecast_minutes),
+            interval_start=minutes(gsp_config.time_resolution_minutes),
+            interval_end=minutes(gsp_config.interval_end_minutes),
         )
-
+
         sliced_datasets_dict["gsp"] = select_time_slice(
             datasets_dict["gsp"],
             t0,
             sample_period_duration=minutes(gsp_config.time_resolution_minutes),
-            interval_start=-minutes(gsp_config.history_minutes),
+            interval_start=minutes(gsp_config.interval_start_minutes),
             interval_end=minutes(0),
         )
 
@@ -94,9 +94,10 @@ def slice_datasets_by_time(
         )
 
         sliced_datasets_dict["gsp"] = apply_dropout_time(
-            sliced_datasets_dict["gsp"], gsp_dropout_time
+            sliced_datasets_dict["gsp"],
+            gsp_dropout_time
         )
-
+
     if "site" in datasets_dict:
         site_config = config.input_data.site
 
@@ -104,8 +105,8 @@ def slice_datasets_by_time(
             datasets_dict["site"],
             t0,
             sample_period_duration=minutes(site_config.time_resolution_minutes),
-            interval_start=-minutes(site_config.history_minutes),
-            interval_end=minutes(site_config.forecast_minutes),
+            interval_start=minutes(site_config.interval_start_minutes),
+            interval_end=minutes(site_config.interval_end_minutes),
         )
 
         # Randomly sample dropout
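
Taken together, these call-site changes imply a mechanical config migration for each input source. The sketch below is inferred from the hunks above, not from official migration notes, and the values are hypothetical:

# Old per-source config fields (0.0.26) and their new equivalents (0.0.28)
old = {"history_minutes": 60, "forecast_minutes": 480}

new = {
    # data begins `interval_start_minutes` after t0, so history becomes negative
    "interval_start_minutes": -old["history_minutes"],
    # data ends `interval_end_minutes` after t0
    "interval_end_minutes": old["forecast_minutes"],
}
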
ocf_data_sampler/torch_datasets/process_and_combine.py
@@ -15,7 +15,7 @@ from ocf_data_sampler.numpy_batch.gsp import GSPBatchKey
 from ocf_data_sampler.numpy_batch.nwp import NWPBatchKey
 from ocf_data_sampler.select.geospatial import osgb_to_lon_lat
 from ocf_data_sampler.select.location import Location
-from ocf_data_sampler.time_functions import minutes
+from ocf_data_sampler.utils import minutes
 
 
 def process_and_combine_datasets(
@@ -23,7 +23,7 @@ def process_and_combine_datasets(
     config: Configuration,
     t0: pd.Timestamp,
     location: Location,
-    sun_position_key: str = 'gsp'
+    target_key: str = 'gsp'
 ) -> dict:
     """Normalize and convert data to numpy arrays"""
 
@@ -58,7 +58,8 @@ def process_and_combine_datasets(
 
         numpy_modalities.append(
             convert_gsp_to_numpy_batch(
-                da_gsp, t0_idx=gsp_config.history_minutes // gsp_config.time_resolution_minutes
+                da_gsp,
+                t0_idx=-gsp_config.interval_start_minutes / gsp_config.time_resolution_minutes
             )
         )
 
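
The t0_idx arithmetic stays equivalent when interval_start_minutes = -history_minutes, though true division now yields a float where floor division gave an int; a quick check with invented values:

history_minutes = 120          # old config field
interval_start_minutes = -120  # new equivalent
time_resolution_minutes = 30

old_t0_idx = history_minutes // time_resolution_minutes         # 4 (int)
new_t0_idx = -interval_start_minutes / time_resolution_minutes  # 4.0 (float)
assert old_t0_idx == new_t0_idx
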
@@ -80,34 +81,32 @@ def process_and_combine_datasets(
 
         numpy_modalities.append(
             convert_site_to_numpy_batch(
-                da_sites, t0_idx=site_config.history_minutes / site_config.time_resolution_minutes
+                da_sites, t0_idx=-site_config.interval_start_minutes / site_config.time_resolution_minutes
             )
         )
 
-    if sun_position_key == 'gsp':
+    if target_key == 'gsp':
         # Make sun coords NumpyBatch
         datetimes = pd.date_range(
-            t0 - minutes(gsp_config.history_minutes),
-            t0 + minutes(gsp_config.forecast_minutes),
+            t0+minutes(gsp_config.interval_start_minutes),
+            t0+minutes(gsp_config.interval_end_minutes),
             freq=minutes(gsp_config.time_resolution_minutes),
         )
 
         lon, lat = osgb_to_lon_lat(location.x, location.y)
-        key_prefix = "gsp"
 
-    elif sun_position_key == 'site':
+    elif target_key == 'site':
         # Make sun coords NumpyBatch
         datetimes = pd.date_range(
-            t0 - minutes(site_config.history_minutes),
-            t0 + minutes(site_config.forecast_minutes),
+            t0+minutes(site_config.interval_start_minutes),
+            t0+minutes(site_config.interval_end_minutes),
             freq=minutes(site_config.time_resolution_minutes),
         )
 
         lon, lat = location.x, location.y
-        key_prefix = "site"
 
     numpy_modalities.append(
-        make_sun_position_numpy_batch(datetimes, lon, lat, key_prefix=key_prefix)
+        make_sun_position_numpy_batch(datetimes, lon, lat, key_prefix=target_key)
     )
 
     # Combine all the modalities and fill NaNs
ocf_data_sampler/torch_datasets/pvnet_uk_regional.py
@@ -9,7 +9,7 @@ from torch.utils.data import Dataset
 from ocf_data_sampler.config import Configuration, load_yaml_configuration
 from ocf_data_sampler.load.load_dataset import get_dataset_dict
 from ocf_data_sampler.select import fill_time_periods, Location, slice_datasets_by_space, slice_datasets_by_time
-from ocf_data_sampler.time_functions import minutes
+from ocf_data_sampler.utils import minutes
 from ocf_data_sampler.torch_datasets.process_and_combine import process_and_combine_datasets, compute
 from ocf_data_sampler.torch_datasets.valid_time_periods import find_valid_time_periods