ocf-data-sampler 0.0.7__py3-none-any.whl → 0.0.9__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
Potentially problematic release: this version of ocf-data-sampler has been flagged as potentially problematic.
- ocf_data_sampler/__init__.py +1 -0
- ocf_data_sampler/data/uk_gsp_locations.csv +319 -0
- ocf_data_sampler/numpy_batch/__init__.py +7 -0
- ocf_data_sampler/numpy_batch/gsp.py +23 -0
- ocf_data_sampler/numpy_batch/nwp.py +33 -0
- ocf_data_sampler/numpy_batch/satellite.py +23 -0
- ocf_data_sampler/numpy_batch/sun_position.py +66 -0
- ocf_data_sampler/select/__init__.py +1 -0
- ocf_data_sampler/select/dropout.py +38 -0
- ocf_data_sampler/select/fill_time_periods.py +11 -0
- ocf_data_sampler/select/find_contiguous_time_periods.py +301 -0
- ocf_data_sampler/select/select_spatial_slice.py +360 -0
- ocf_data_sampler/select/select_time_slice.py +184 -0
- ocf_data_sampler/torch_datasets/__init__.py +1 -0
- ocf_data_sampler/torch_datasets/pvnet_uk_regional.py +538 -0
- {ocf_data_sampler-0.0.7.dist-info → ocf_data_sampler-0.0.9.dist-info}/METADATA +1 -1
- ocf_data_sampler-0.0.9.dist-info/RECORD +22 -0
- ocf_data_sampler-0.0.9.dist-info/top_level.txt +2 -0
- ocf_data_sampler-0.0.7.dist-info/RECORD +0 -7
- ocf_data_sampler-0.0.7.dist-info/top_level.txt +0 -1
- {ocf_data_sampler-0.0.7.dist-info → ocf_data_sampler-0.0.9.dist-info}/LICENSE +0 -0
- {ocf_data_sampler-0.0.7.dist-info → ocf_data_sampler-0.0.9.dist-info}/WHEEL +0 -0
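Only a subset of the per-file diffs is reproduced below: the new `pvnet_uk_regional.py` module and the dist-info metadata files. If you want to regenerate a diff like this yourself, here is a rough sketch; it assumes both releases are still downloadable from PyPI and uses only pip plus the standard library:

```python
# Sketch: regenerate a diff like this one locally. Assumes network access
# and that both releases are still available on PyPI.
import difflib
import subprocess
import zipfile
from pathlib import Path

def fetch_wheel(version: str) -> Path:
    dest = Path("wheels") / version
    subprocess.run(
        ["pip", "download", f"ocf-data-sampler=={version}", "--no-deps", "-d", str(dest)],
        check=True,
    )
    return next(dest.glob("*.whl"))  # the single wheel pip just downloaded

def read_lines(wheel: Path, name: str) -> list[str]:
    with zipfile.ZipFile(wheel) as zf:
        if name not in zf.namelist():
            return []  # file absent in this version, e.g. anything new in 0.0.9
        return zf.read(name).decode("utf-8").splitlines(keepends=True)

old, new = fetch_wheel("0.0.7"), fetch_wheel("0.0.9")
name = "ocf_data_sampler/torch_datasets/pvnet_uk_regional.py"
print("".join(difflib.unified_diff(read_lines(old, name), read_lines(new, name), "0.0.7", "0.0.9")))
```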
ocf_data_sampler/torch_datasets/pvnet_uk_regional.py (new file, +538 lines)

```diff
@@ -0,0 +1,538 @@
+"""Torch dataset for PVNet"""
+
+import numpy as np
+import pandas as pd
+import xarray as xr
+from torch.utils.data import Dataset
+
+
+from ocf_data_sampler.load.gsp import open_gsp
+from ocf_data_sampler.load.nwp import open_nwp
+from ocf_data_sampler.load.satellite import open_sat_data
+
+from ocf_data_sampler.select.find_contiguous_time_periods import (
+    find_contiguous_t0_periods, find_contiguous_t0_periods_nwp,
+    intersection_of_multiple_dataframes_of_periods,
+)
+from ocf_data_sampler.select.fill_time_periods import fill_time_periods
+from ocf_data_sampler.select.select_time_slice import select_time_slice, select_time_slice_nwp
+from ocf_data_sampler.select.dropout import draw_dropout_time, apply_dropout_time
+from ocf_data_sampler.select.select_spatial_slice import select_spatial_slice_pixels
+
+from ocf_data_sampler.numpy_batch import (
+    convert_gsp_to_numpy_batch,
+    convert_nwp_to_numpy_batch,
+    convert_satellite_to_numpy_batch,
+    make_sun_position_numpy_batch,
+)
+
+
+from ocf_datapipes.config.model import Configuration
+from ocf_datapipes.config.load import load_yaml_configuration
+from ocf_datapipes.batch import BatchKey, NumpyBatch
+
+from ocf_datapipes.utils.location import Location
+from ocf_datapipes.utils.geospatial import osgb_to_lon_lat
+
+from ocf_datapipes.utils.consts import (
+    NWP_MEANS,
+    NWP_STDS,
+    RSS_MEAN,
+    RSS_STD,
+)
+
+from ocf_datapipes.training.common import concat_xr_time_utc, normalize_gsp
+
+
+
+xr.set_options(keep_attrs=True)
+
+
+
+def minutes(minutes: list[float]):
+    """Timedelta minutes
+
+    Args:
+        m: minutes
+    """
+    return pd.to_timedelta(minutes, unit="m")
+
+
+def get_dataset_dict(config: Configuration) -> dict[xr.DataArray, dict[xr.DataArray]]:
+    """Construct dictionary of all of the input data sources
+
+    Args:
+        config: Configuration file
+    """
+
+    in_config = config.input_data
+
+    datasets_dict = {}
+
+    # We always assume GSP will be included
+    da_gsp = open_gsp(zarr_path=in_config.gsp.gsp_zarr_path)
+
+    # Remove national GSP
+    datasets_dict["gsp"] = da_gsp.sel(gsp_id=slice(1, None))
+
+    # Load NWP data if in config
+    if in_config.nwp:
+
+        datasets_dict["nwp"] = {}
+        for nwp_source, nwp_config in in_config.nwp.items():
+
+            da_nwp = open_nwp(nwp_config.nwp_zarr_path, provider=nwp_config.nwp_provider)
+
+            da_nwp = da_nwp.sel(channel=list(nwp_config.nwp_channels))
+
+            datasets_dict["nwp"][nwp_source] = da_nwp
+
+    # Load satellite data if in config
+    if in_config.satellite:
+        sat_config = config.input_data.satellite
+
+        da_sat = open_sat_data(sat_config.satellite_zarr_path)
+
+        da_sat = da_sat.sel(channel=list(sat_config.satellite_channels))
+
+        datasets_dict["sat"] = da_sat
+
+    return datasets_dict
+
+
+
+def find_valid_t0_times(
+    datasets_dict: dict,
+    config: Configuration,
+):
+    """Find the t0 times where all of the requested input data is available
+
+    Args:
+        datasets_dict: A dictionary of input datasets
+        config: Configuration file
+    """
+
+    assert set(datasets_dict.keys()).issubset({"nwp", "sat", "gsp"})
+
+    contiguous_time_periods = []  # Used to store contiguous time periods from each data source
+
+    if "nwp" in datasets_dict:
+        for nwp_key, nwp_config in config.input_data.nwp.items():
+
+            da = datasets_dict["nwp"][nwp_key]
+
+            if nwp_config.dropout_timedeltas_minutes is None:
+                max_dropout = minutes(0)
+            else:
+                max_dropout = minutes(np.max(np.abs(nwp_config.dropout_timedeltas_minutes)))
+
+            if nwp_config.max_staleness_minutes is None:
+                max_staleness = None
+            else:
+                max_staleness = minutes(nwp_config.max_staleness_minutes)
+
+            # The last step of the forecast is lost if we have to diff channels
+            if len(nwp_config.nwp_accum_channels) > 0:
+                end_buffer = minutes(nwp_config.time_resolution_minutes)
+            else:
+                end_buffer = minutes(0)
+
+            # This is the max staleness we can use considering the max step of the input data
+            max_possible_staleness = (
+                pd.Timedelta(da["step"].max().item())
+                - minutes(nwp_config.forecast_minutes)
+                - end_buffer
+            )
+
+            # Default to use max possible staleness unless specified in config
+            if max_staleness is None:
+                max_staleness = max_possible_staleness
+            else:
+                # Make sure the max acceptable staleness isn't longer than the max possible
+                assert max_staleness <= max_possible_staleness
+
+            time_periods = find_contiguous_t0_periods_nwp(
+                datetimes=pd.DatetimeIndex(da["init_time_utc"]),
+                history_duration=minutes(nwp_config.history_minutes),
+                max_staleness=max_staleness,
+                max_dropout=max_dropout,
+            )
+
+            contiguous_time_periods.append(time_periods)
+
+    if "sat" in datasets_dict:
+        sat_config = config.input_data.satellite
+
+        time_periods = find_contiguous_t0_periods(
+            pd.DatetimeIndex(datasets_dict["sat"]["time_utc"]),
+            sample_period_duration=minutes(sat_config.time_resolution_minutes),
+            history_duration=minutes(sat_config.history_minutes),
+            forecast_duration=minutes(sat_config.forecast_minutes),
+        )
+
+        contiguous_time_periods.append(time_periods)
+
+    # GSP always assumed to be in data
+    gsp_config = config.input_data.gsp
+
+    time_periods = find_contiguous_t0_periods(
+        pd.DatetimeIndex(datasets_dict["gsp"]["time_utc"]),
+        sample_period_duration=minutes(gsp_config.time_resolution_minutes),
+        history_duration=minutes(gsp_config.history_minutes),
+        forecast_duration=minutes(gsp_config.forecast_minutes),
+    )
+
+    contiguous_time_periods.append(time_periods)
+
+    # Find joint overlapping contiguous time periods
+    if len(contiguous_time_periods) > 1:
+        valid_time_periods = intersection_of_multiple_dataframes_of_periods(
+            contiguous_time_periods
+        )
+    else:
+        valid_time_periods = contiguous_time_periods[0]
+
+    # Fill out the contiguous time periods to get the t0 times
+    valid_t0_times = fill_time_periods(
+        valid_time_periods,
+        freq=minutes(config.input_data.gsp.time_resolution_minutes)
+    )
+
+    return valid_t0_times
+
+
+def slice_datasets_by_space(
+    datasets_dict: dict,
+    location: Location,
+    config: Configuration,
+) -> dict:
+    """Slice a dictionaries of input data sources around a given location
+
+    Args:
+        datasets_dict: Dictionary of the input data sources
+        location: The location to sample around
+        config: Configuration object.
+    """
+
+    assert set(datasets_dict.keys()).issubset({"nwp", "sat", "gsp"})
+
+    sliced_datasets_dict = {}
+
+    if "nwp" in datasets_dict:
+
+        sliced_datasets_dict["nwp"] = {}
+
+        for nwp_key, nwp_config in config.input_data.nwp.items():
+
+            sliced_datasets_dict["nwp"][nwp_key] = select_spatial_slice_pixels(
+                datasets_dict["nwp"][nwp_key],
+                location,
+                height_pixels=nwp_config.nwp_image_size_pixels_height,
+                width_pixels=nwp_config.nwp_image_size_pixels_width,
+            )
+
+    if "sat" in datasets_dict:
+        sat_config = config.input_data.satellite
+
+        sliced_datasets_dict["sat"] = select_spatial_slice_pixels(
+            datasets_dict["sat"],
+            location,
+            height_pixels=sat_config.satellite_image_size_pixels_height,
+            width_pixels=sat_config.satellite_image_size_pixels_width,
+        )
+
+    # GSP always assumed to be in data
+    sliced_datasets_dict["gsp"] = datasets_dict["gsp"].sel(gsp_id=location.id)
+
+    return sliced_datasets_dict
+
+
+def slice_datasets_by_time(
+    datasets_dict: dict,
+    t0: pd.Timedelta,
+    config: Configuration,
+) -> dict:
+    """Slice a dictionaries of input data sources around a given t0 time
+
+    Args:
+        datasets_dict: Dictionary of the input data sources
+        t0: The init-time
+        config: Configuration object.
+    """
+
+    sliced_datasets_dict = {}
+
+    if "nwp" in datasets_dict:
+
+        sliced_datasets_dict["nwp"] = {}
+
+        for nwp_key, da_nwp in datasets_dict["nwp"].items():
+
+            nwp_config = config.input_data.nwp[nwp_key]
+
+            sliced_datasets_dict["nwp"][nwp_key] = select_time_slice_nwp(
+                da_nwp,
+                t0,
+                sample_period_duration=minutes(nwp_config.time_resolution_minutes),
+                history_duration=minutes(nwp_config.history_minutes),
+                forecast_duration=minutes(nwp_config.forecast_minutes),
+                dropout_timedeltas=minutes(nwp_config.dropout_timedeltas_minutes),
+                dropout_frac=nwp_config.dropout_fraction,
+                accum_channels=nwp_config.nwp_accum_channels,
+            )
+
+    if "sat" in datasets_dict:
+
+        sat_config = config.input_data.satellite
+
+        sliced_datasets_dict["sat"] = select_time_slice(
+            datasets_dict["sat"],
+            t0,
+            sample_period_duration=minutes(sat_config.time_resolution_minutes),
+            interval_start=minutes(-sat_config.history_minutes),
+            interval_end=minutes(-sat_config.live_delay_minutes),
+            max_steps_gap=2,
+        )
+
+        # Randomly sample dropout
+        sat_dropout_time = draw_dropout_time(
+            t0,
+            dropout_timedeltas=minutes(sat_config.dropout_timedeltas_minutes),
+            dropout_frac=sat_config.dropout_fraction,
+        )
+
+        # Apply the dropout
+        sliced_datasets_dict["sat"] = apply_dropout_time(
+            sliced_datasets_dict["sat"],
+            sat_dropout_time,
+        )
+
+    # GSP always assumed to be included
+    gsp_config = config.input_data.gsp
+
+    sliced_datasets_dict["gsp_future"] = select_time_slice(
+        datasets_dict["gsp"],
+        t0,
+        sample_period_duration=minutes(gsp_config.time_resolution_minutes),
+        interval_start=minutes(30),
+        interval_end=minutes(gsp_config.forecast_minutes),
+    )
+
+    sliced_datasets_dict["gsp"] = select_time_slice(
+        datasets_dict["gsp"],
+        t0,
+        sample_period_duration=minutes(gsp_config.time_resolution_minutes),
+        interval_start=-minutes(gsp_config.history_minutes),
+        interval_end=minutes(0),
+    )
+
+    # Dropout on the GSP, but not the future GSP
+    gsp_dropout_time = draw_dropout_time(
+        t0,
+        dropout_timedeltas=minutes(gsp_config.dropout_timedeltas_minutes),
+        dropout_frac=gsp_config.dropout_fraction,
+    )
+
+    sliced_datasets_dict["gsp"] = apply_dropout_time(sliced_datasets_dict["gsp"], gsp_dropout_time)
+
+    return sliced_datasets_dict
+
+
+def merge_dicts(list_of_dicts: list[dict]) -> dict:
+    """Merge a list of dictionaries into a single dictionary"""
+    # TODO: This doesn't account for duplicate keys, which will be overwritten
+    combined_dict = {}
+    for d in list_of_dicts:
+        combined_dict.update(d)
+    return combined_dict
+
+
+def process_and_combine_datasets(
+    dataset_dict: dict,
+    config: Configuration,
+    t0: pd.Timedelta,
+    location: Location,
+) -> NumpyBatch:
+    """Normalize and convert data to numpy arrays"""
+
+    numpy_modalities = []
+
+    if "nwp" in dataset_dict:
+
+        nwp_numpy_modalities = dict()
+
+        for nwp_key, da_nwp in dataset_dict["nwp"].items():
+            # Standardise
+            provider = config.input_data.nwp[nwp_key].nwp_provider
+            da_nwp = (da_nwp - NWP_MEANS[provider]) / NWP_STDS[provider]
+            # Convert to NumpyBatch
+            nwp_numpy_modalities[nwp_key] = convert_nwp_to_numpy_batch(da_nwp)
+
+        # Combine the NWPs into NumpyBatch
+        numpy_modalities.append({BatchKey.nwp: nwp_numpy_modalities})
+
+    if "sat" in dataset_dict:
+        # Standardise
+        # TODO: Since satellite is in range 0-1 already, so we don't need to standardize
+        da_sat = (dataset_dict["sat"] - RSS_MEAN) / RSS_STD
+        # Convert to NumpyBatch
+        numpy_modalities.append(convert_satellite_to_numpy_batch(da_sat))
+
+    # GSP always assumed to be in data
+    gsp_config = config.input_data.gsp
+    da_gsp = concat_xr_time_utc([dataset_dict["gsp"], dataset_dict["gsp_future"]])
+    da_gsp = normalize_gsp(da_gsp)
+
+    numpy_modalities.append(
+        convert_gsp_to_numpy_batch(
+            da_gsp,
+            t0_idx=gsp_config.history_minutes / gsp_config.time_resolution_minutes
+        )
+    )
+
+    # Make sun coords NumpyBatch
+    datetimes = pd.date_range(
+        t0-minutes(gsp_config.history_minutes),
+        t0+minutes(gsp_config.forecast_minutes),
+        freq=minutes(gsp_config.time_resolution_minutes),
+    )
+
+    lon, lat = osgb_to_lon_lat(location.x, location.y)
+
+    numpy_modalities.append(make_sun_position_numpy_batch(datetimes, lon, lat))
+
+    # Combine all the modalities
+    combined_sample = merge_dicts(numpy_modalities)
+
+    return combined_sample
+
+
+def compute(xarray_dict: dict) -> dict:
+    """Eagerly load a nested dictionary of xarray DataArrays"""
+    for k, v in xarray_dict.items():
+        if isinstance(v, dict):
+            xarray_dict[k] = compute(v)
+        else:
+            xarray_dict[k] = v.compute(scheduler="single-threaded")
+    return xarray_dict
+
+
+def get_locations(ga_gsp: xr.DataArray) -> list[Location]:
+    """Get list of locations of GSP"""
+    locations = []
+    for gsp_id in ga_gsp.gsp_id.values:
+        da = ga_gsp.sel(gsp_id=gsp_id)
+        locations.append(
+            Location(
+                coordinate_system = "osgb",
+                x=da.x_osgb.item(),
+                y=da.y_osgb.item(),
+                id=gsp_id,
+            )
+        )
+    return locations
+
+
+class PVNetUKRegionalDataset(Dataset):
+    def __init__(
+        self,
+        config_filename: str,
+        start_time: str | None = None,
+        end_time: str| None = None,
+    ):
+        """A torch Dataset for creating PVNet UK GSP samples
+
+        Args:
+            config_filename: Path to the configuration file
+            start_time: Limit the init-times to be after this
+            end_time: Limit the init-times to be before this
+        """
+
+        config = load_yaml_configuration(config_filename)
+
+        datasets_dict = get_dataset_dict(config)
+
+        # Get t0 times where all input data is available
+        valid_t0_times = find_valid_t0_times(datasets_dict, config)
+
+        # Filter t0 times to given range
+        if start_time is not None:
+            valid_t0_times = valid_t0_times[valid_t0_times>=pd.Timestamp(start_time)]
+
+        if end_time is not None:
+            valid_t0_times = valid_t0_times[valid_t0_times<=pd.Timestamp(end_time)]
+
+        # Construct list of locations to sample from
+        locations = get_locations(datasets_dict["gsp"])
+
+        # Construct a lookup for locations - useful for users to construct sample by GSP ID
+        location_lookup = {loc.id: loc for loc in locations}
+
+        # Construct indices for sampling
+        t_index, loc_index = np.meshgrid(
+            np.arange(len(valid_t0_times)),
+            np.arange(len(locations)),
+        )
+
+        # Make array of all possible (t0, location) coordinates. Each row is a single coordinate
+        index_pairs = np.stack((t_index.ravel(), loc_index.ravel())).T
+
+        # Assign coords and indices to self
+        self.valid_t0_times = valid_t0_times
+        self.locations = locations
+        self.location_lookup = location_lookup
+        self.index_pairs = index_pairs
+
+        # Assign config and input data to self
+        self.datasets_dict = datasets_dict
+        self.config = config
+
+
+    def __len__(self):
+        return len(self.index_pairs)
+
+
+    def _get_sample(self, t0: pd.Timestamp, location: Location) -> NumpyBatch:
+        """Generate the PVNet sample for given coordinates
+
+        Args:
+            t0: init-time for sample
+            location: location for sample
+        """
+        sample_dict = slice_datasets_by_space(self.datasets_dict, location, self.config)
+        sample_dict = slice_datasets_by_time(sample_dict, t0, self.config)
+        sample_dict = compute(sample_dict)
+
+        sample = process_and_combine_datasets(sample_dict, self.config, t0, location)
+
+        return sample
+
+
+    def __getitem__(self, idx):
+
+        # Get the coordinates of the sample
+        t_index, loc_index = self.index_pairs[idx]
+        location = self.locations[loc_index]
+        t0 = self.valid_t0_times[t_index]
+
+        # Generate the sample
+        return self._get_sample(t0, location)
+
+
+    def get_sample(self, t0: pd.Timestamp, gsp_id: int) -> NumpyBatch:
+        """Generate a sample for the given coordinates.
+
+        Useful for users to generate samples by GSP ID.
+
+        Args:
+            t0: init-time for sample
+            gsp_id: GSP ID
+        """
+        # Check the user has asked for a sample which we have the data for
+        assert t0 in self.valid_t0_times
+        assert gsp_id in self.location_lookup
+
+        location = self.location_lookup[gsp_id]
+
+
+        return self._get_sample(t0, location)
```
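That is the whole of the new module. For orientation, here is a minimal usage sketch of the dataset it adds; "pvnet_config.yaml" is a hypothetical path (the file must match the Configuration schema from ocf_datapipes), while the constructor and `get_sample` signatures are taken from the diff above:

```python
# Minimal usage sketch for the PVNetUKRegionalDataset added above.
# "pvnet_config.yaml" is a hypothetical config path; signatures follow the diff.
from ocf_data_sampler.torch_datasets.pvnet_uk_regional import PVNetUKRegionalDataset

dataset = PVNetUKRegionalDataset(
    "pvnet_config.yaml",
    start_time="2023-01-01",  # optional: drop init-times before this
    end_time="2023-12-31",    # optional: drop init-times after this
)

print(len(dataset))  # one entry per (t0, GSP) pair

# Index-based access, as used by a torch DataLoader
sample = dataset[0]

# Or request a specific init-time and GSP directly
t0 = dataset.valid_t0_times[0]
sample = dataset.get_sample(t0, gsp_id=1)
```

Each sample is a NumpyBatch, i.e. a dict keyed by BatchKey, and the index space is the full grid of valid init-times times GSP locations, so `len(dataset)` equals `len(valid_t0_times) * len(locations)`.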
ocf_data_sampler-0.0.9.dist-info/RECORD (new file, +22 lines)

```diff
@@ -0,0 +1,22 @@
+ocf_data_sampler/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
+ocf_data_sampler/data/uk_gsp_locations.csv,sha256=RSh7DRh55E3n8lVAaWXGTaXXHevZZtI58td4d4DhGos,10415772
+ocf_data_sampler/numpy_batch/__init__.py,sha256=mrtqwbGik5Zc9MYP5byfCTBm08wMtS2XnTsypC4fPMo,245
+ocf_data_sampler/numpy_batch/gsp.py,sha256=EL0_cJJNyvkQQcOat9vFA61pF4lema3BP_vB4ZS788U,805
+ocf_data_sampler/numpy_batch/nwp.py,sha256=Rv0yfDj902Z2oCwdlRjOs3Kh-F5Fgxjjylh99-lQ9ws,1105
+ocf_data_sampler/numpy_batch/satellite.py,sha256=e6eoNmiiHtzZbDVtBolFzDuE3qwhHN6bL9H86emAUsk,732
+ocf_data_sampler/numpy_batch/sun_position.py,sha256=UW6-WtjrKdCkcguolHUDSLhYFfarknQzzjlCX8YdEOM,1700
+ocf_data_sampler/select/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
+ocf_data_sampler/select/dropout.py,sha256=JYbjG5e8d48te7xj4I9pTWk43d6ksjGeyKFLSTuAOlY,1062
+ocf_data_sampler/select/fill_time_periods.py,sha256=iTtMjIPFYG5xtUYYedAFBLjTWWUa7t7WQ0-yksWf0-E,440
+ocf_data_sampler/select/find_contiguous_time_periods.py,sha256=6ioB8LeFpFNBMgKDxrgG3zqzNjkBF_jlV9yye2ZYT2E,11925
+ocf_data_sampler/select/select_spatial_slice.py,sha256=7BSzOFPMSBWpBWXSajWTfI8luUVsSgh4zN-rkr-AuUs,11470
+ocf_data_sampler/select/select_time_slice.py,sha256=XuksC9N03c5rV9OeWtxjGuoGyeJJGy4JMJe3w7m6oaw,6654
+ocf_data_sampler/torch_datasets/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
+ocf_data_sampler/torch_datasets/pvnet_uk_regional.py,sha256=rVKFfoHqSfm4C-eOXiqi5GwBJdMewRMIikvpjEJXi1s,17477
+tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+tests/conftest.py,sha256=OcArgF60paroZQqoP7xExRBF34nEyMuXd7dS7hD6p3w,5393
+ocf_data_sampler-0.0.9.dist-info/LICENSE,sha256=F-Q3UFCR-BECSocV55BFDpn4YKxve9PKrm-lTt6o_Tg,1073
+ocf_data_sampler-0.0.9.dist-info/METADATA,sha256=ubu-StG7JD9xjbT5TuwkRjanPG0F7Vuf56Bo1AX2u2c,587
+ocf_data_sampler-0.0.9.dist-info/WHEEL,sha256=HiCZjzuy6Dw0hdX5R3LCFPDmFS4BWl8H-8W39XfmgX4,91
+ocf_data_sampler-0.0.9.dist-info/top_level.txt,sha256=KaQn5qzkJGJP6hKWqsVAc9t0cMLjVvSTk8-kTrW79SA,23
+ocf_data_sampler-0.0.9.dist-info/RECORD,,
```
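Each RECORD row has the form `path,sha256=<hash>,<size>`, where the hash is the URL-safe base64 encoding of the file's SHA-256 digest with the `=` padding stripped. A small sketch to verify entries; the two asserts below match hashes listed in this RECORD:

```python
# Verify a wheel RECORD hash: urlsafe base64 of the SHA-256 digest, padding stripped.
import base64
import hashlib

def record_hash(data: bytes) -> str:
    """Hash file contents the way a wheel RECORD does."""
    digest = hashlib.sha256(data).digest()
    return base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii")

# tests/__init__.py is empty (size 0)
assert record_hash(b"") == "47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU"

# The three 1-byte __init__.py files share a hash: each is a single newline
assert record_hash(b"\n") == "AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs"
```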
ocf_data_sampler-0.0.7.dist-info/RECORD (removed, -7 lines)

```diff
@@ -1,7 +0,0 @@
-tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-tests/conftest.py,sha256=OcArgF60paroZQqoP7xExRBF34nEyMuXd7dS7hD6p3w,5393
-ocf_data_sampler-0.0.7.dist-info/LICENSE,sha256=F-Q3UFCR-BECSocV55BFDpn4YKxve9PKrm-lTt6o_Tg,1073
-ocf_data_sampler-0.0.7.dist-info/METADATA,sha256=ShgwFYETp1LErATZC4cmsHtLsRSwy4eT60i5kSS0rSI,587
-ocf_data_sampler-0.0.7.dist-info/WHEEL,sha256=HiCZjzuy6Dw0hdX5R3LCFPDmFS4BWl8H-8W39XfmgX4,91
-ocf_data_sampler-0.0.7.dist-info/top_level.txt,sha256=EdW7283x-lr_cbuivW8Ij7ANAP-ZJ9sLtILQseEPxXg,6
-ocf_data_sampler-0.0.7.dist-info/RECORD,,
```
ocf_data_sampler-0.0.7.dist-info/top_level.txt (removed, -1 line)

```diff
@@ -1 +0,0 @@
-tests
```

As the RECORD and top_level.txt diffs show, the 0.0.7 wheel contained only the `tests` package and its dist-info: the `ocf_data_sampler` package itself was not shipped until 0.0.9, whose top_level.txt gains two entries and whose RECORD lists both packages.

{ocf_data_sampler-0.0.7.dist-info → ocf_data_sampler-0.0.9.dist-info}/LICENSE: file without changes

{ocf_data_sampler-0.0.7.dist-info → ocf_data_sampler-0.0.9.dist-info}/WHEEL: file without changes