ocf-data-sampler 0.3.1__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of ocf-data-sampler might be problematic.
- ocf_data_sampler/numpy_sample/__init__.py +2 -3
- ocf_data_sampler/numpy_sample/datetime_features.py +16 -25
- ocf_data_sampler/numpy_sample/site.py +1 -8
- ocf_data_sampler/torch_datasets/datasets/__init__.py +1 -5
- ocf_data_sampler/torch_datasets/datasets/site.py +95 -366
- ocf_data_sampler/torch_datasets/sample/site.py +21 -12
- {ocf_data_sampler-0.3.1.dist-info → ocf_data_sampler-0.5.0.dist-info}/METADATA +3 -2
- {ocf_data_sampler-0.3.1.dist-info → ocf_data_sampler-0.5.0.dist-info}/RECORD +10 -10
- {ocf_data_sampler-0.3.1.dist-info → ocf_data_sampler-0.5.0.dist-info}/WHEEL +0 -0
- {ocf_data_sampler-0.3.1.dist-info → ocf_data_sampler-0.5.0.dist-info}/top_level.txt +0 -0
ocf_data_sampler/numpy_sample/__init__.py:

```diff
@@ -1,9 +1,8 @@
 """Conversion from Xarray to NumpySample"""
 
-from .datetime_features import make_datetime_numpy_dict
+from .datetime_features import encode_datetimes
 from .gsp import convert_gsp_to_numpy_sample, GSPSampleKey
 from .nwp import convert_nwp_to_numpy_sample, NWPSampleKey
 from .satellite import convert_satellite_to_numpy_sample, SatelliteSampleKey
 from .sun_position import make_sun_position_numpy_sample
-from .site import convert_site_to_numpy_sample
-
+from .site import convert_site_to_numpy_sample, SiteSampleKey
```
ocf_data_sampler/numpy_sample/datetime_features.py:

```diff
@@ -6,33 +6,24 @@ import pandas as pd
 from ocf_data_sampler.numpy_sample.common_types import NumpySample
 
 
-def make_datetime_numpy_dict(datetimes: pd.DatetimeIndex, key_prefix: str) -> NumpySample:
-    """Make dictionary of sin and cos datetime features.
+def encode_datetimes(datetimes: pd.DatetimeIndex) -> NumpySample:
+    """Creates dictionary of sin and cos datetime embeddings.
 
     Args:
-        datetimes: DatetimeIndex to create features for
+        datetimes: DatetimeIndex to create radian embeddings for
 
     Returns:
-        Dictionary of datetime features
+        Dictionary of datetime encodings
     """
-    day_of_year = datetimes.dayofyear
-    minute_of_day = datetimes.minute + datetimes.hour * 60
-
-    time_in_pi = (2 * np.pi) * (minute_of_day / (24 * 60))
-    date_in_pi = (2 * np.pi) * (day_of_year / 365)
-
-    time_numpy_sample = {}
-
-    time_numpy_sample[key_prefix + "_date_sin"] = np.sin(date_in_pi)
-    time_numpy_sample[key_prefix + "_date_cos"] = np.cos(date_in_pi)
-    time_numpy_sample[key_prefix + "_time_sin"] = np.sin(time_in_pi)
-    time_numpy_sample[key_prefix + "_time_cos"] = np.cos(time_in_pi)
-
-    return time_numpy_sample
+    day_of_year = datetimes.dayofyear
+    minute_of_day = datetimes.minute + datetimes.hour * 60
+
+    time_in_radians = (2 * np.pi) * (minute_of_day / (24 * 60))
+    date_in_radians = (2 * np.pi) * (day_of_year / 365)
+
+    return {
+        "date_sin": np.sin(date_in_radians),
+        "date_cos": np.cos(date_in_radians),
+        "time_sin": np.sin(time_in_radians),
+        "time_cos": np.cos(time_in_radians),
+    }
```
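For callers migrating from 0.3.x: `make_datetime_numpy_dict(datetimes, key_prefix=...)` is replaced by `encode_datetimes`, which takes only the index and returns un-prefixed keys. A minimal usage sketch (the date range is illustrative):

```python
import pandas as pd

from ocf_data_sampler.numpy_sample import encode_datetimes

# Any DatetimeIndex works; this half-hourly range is just an example.
datetimes = pd.date_range("2024-06-01 00:00", "2024-06-01 03:00", freq="30min")
encodings = encode_datetimes(datetimes)

# Keys are no longer prefixed (previously e.g. "site_date_sin").
assert set(encodings) == {"date_sin", "date_cos", "time_sin", "time_cos"}

# At midnight minute_of_day is 0, so time_sin == 0 and time_cos == 1.
print(encodings["time_sin"][0], encodings["time_cos"][0])  # 0.0 1.0
```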
ocf_data_sampler/numpy_sample/site.py:

```diff
@@ -13,10 +13,7 @@ class SiteSampleKey:
     time_utc = "site_time_utc"
     t0_idx = "site_t0_idx"
     id = "site_id"
-    date_sin = "site_date_sin"
-    date_cos = "site_date_cos"
-    time_sin = "site_time_sin"
-    time_cos = "site_time_cos"
+
 
 
 def convert_site_to_numpy_sample(da: xr.DataArray, t0_idx: int | None = None) -> NumpySample:
@@ -31,10 +28,6 @@ def convert_site_to_numpy_sample(da: xr.DataArray, t0_idx: int | None = None) -> NumpySample:
         SiteSampleKey.capacity_kwp: da.isel(time_utc=0)["capacity_kwp"].values,
         SiteSampleKey.time_utc: da["time_utc"].values.astype(float),
         SiteSampleKey.id: da["site_id"].values,
-        SiteSampleKey.date_sin: da["date_sin"].values,
-        SiteSampleKey.date_cos: da["date_cos"].values,
-        SiteSampleKey.time_sin: da["time_sin"].values,
-        SiteSampleKey.time_cos: da["time_cos"].values,
     }
 
     if t0_idx is not None:
```
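With the datetime keys gone, the hunk shows `convert_site_to_numpy_sample` reading only the generation values plus the `time_utc`, `site_id` and `capacity_kwp` coordinates. A toy sketch under that assumption (the hunk does not show the function's full body, and real inputs come from the package's site loader):

```python
import numpy as np
import pandas as pd
import xarray as xr

from ocf_data_sampler.numpy_sample import SiteSampleKey, convert_site_to_numpy_sample

# Hypothetical toy DataArray carrying the coordinates the converter reads.
times = pd.date_range("2024-06-01", periods=4, freq="30min")
da = xr.DataArray(
    np.random.rand(4),
    dims=["time_utc"],
    coords={
        "time_utc": times,
        "site_id": 1,
        "capacity_kwp": ("time_utc", np.full(4, 5.0)),
    },
)

sample = convert_site_to_numpy_sample(da, t0_idx=0)

# The sample no longer carries "site_date_sin" and friends; datetime
# features are merged in later via encode_datetimes.
assert SiteSampleKey.id in sample
```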
ocf_data_sampler/torch_datasets/datasets/site.py:

```diff
@@ -13,7 +13,7 @@ from ocf_data_sampler.numpy_sample import (
     convert_nwp_to_numpy_sample,
     convert_satellite_to_numpy_sample,
     convert_site_to_numpy_sample,
-    make_datetime_numpy_dict,
+    encode_datetimes,
     make_sun_position_numpy_sample,
 )
 from ocf_data_sampler.numpy_sample.collate import stack_np_samples_into_batch
```
```diff
@@ -58,6 +58,96 @@ def get_locations(site_xr: xr.Dataset) -> list[Location]:
 
     return locations
 
+def process_and_combine_datasets(
+    dataset_dict: dict,
+    config: Configuration,
+    t0: pd.Timestamp,
+) -> NumpySample:
+    """Normalise and convert data to numpy arrays.
+
+    Args:
+        dataset_dict: Dictionary of xarray datasets
+        config: Configuration object
+        t0: init-time for sample
+    """
+    numpy_modalities = []
+
+    if "nwp" in dataset_dict:
+        nwp_numpy_modalities = {}
+
+        for nwp_key, da_nwp in dataset_dict["nwp"].items():
+
+            # Standardise and convert to NumpyBatch
+
+            da_channel_means = channel_dict_to_dataarray(
+                config.input_data.nwp[nwp_key].channel_means,
+            )
+            da_channel_stds = channel_dict_to_dataarray(
+                config.input_data.nwp[nwp_key].channel_stds,
+            )
+
+            da_nwp = (da_nwp - da_channel_means) / da_channel_stds
+
+            nwp_numpy_modalities[nwp_key] = convert_nwp_to_numpy_sample(da_nwp)
+
+        # Combine the NWPs into NumpyBatch
+        numpy_modalities.append({NWPSampleKey.nwp: nwp_numpy_modalities})
+
+    if "sat" in dataset_dict:
+        da_sat = dataset_dict["sat"]
+
+        # Standardise and convert to NumpyBatch
+        da_channel_means = channel_dict_to_dataarray(config.input_data.satellite.channel_means)
+        da_channel_stds = channel_dict_to_dataarray(config.input_data.satellite.channel_stds)
+
+        da_sat = (da_sat - da_channel_means) / da_channel_stds
+
+        numpy_modalities.append(convert_satellite_to_numpy_sample(da_sat))
+
+    if "site" in dataset_dict:
+        da_sites = dataset_dict["site"]
+        da_sites = da_sites / da_sites.capacity_kwp
+
+        # Convert to NumpyBatch
+        numpy_modalities.append(
+            convert_site_to_numpy_sample(
+                da_sites,
+            ),
+        )
+
+        # add datetime features
+        datetimes = pd.DatetimeIndex(da_sites.time_utc.values)
+        datetime_features = encode_datetimes(datetimes=datetimes)
+
+        numpy_modalities.append(datetime_features)
+
+    # Only add solar position if explicitly configured
+    if config.input_data.solar_position is not None:
+        solar_config = config.input_data.solar_position
+
+        # Create datetime range for solar position calculation
+        datetimes = pd.date_range(
+            t0 + minutes(solar_config.interval_start_minutes),
+            t0 + minutes(solar_config.interval_end_minutes),
+            freq=minutes(solar_config.time_resolution_minutes),
+        )
+
+
+        # Calculate solar positions and add to modalities
+        numpy_modalities.append(
+            make_sun_position_numpy_sample(
+                datetimes,
+                da_sites.longitude.values,
+                da_sites.latitude.values,
+            ),
+        )
+
+    # Combine all the modalities and fill NaNs
+    combined_sample = merge_dicts(numpy_modalities)
+    combined_sample = fill_nans_in_arrays(combined_sample)
+
+    return combined_sample
+
 
 class SitesDataset(Dataset):
     """A torch Dataset for creating PVNet Site samples."""
```
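The `(da - mean) / std` standardisation above works because `channel_dict_to_dataarray` (shipped as ocf_data_sampler/torch_datasets/utils/channel_dict_to_dataarray.py per the RECORD) yields a channel-indexed DataArray that xarray aligns and broadcasts. A self-contained sketch with a stand-in helper and made-up channel names and statistics:

```python
import numpy as np
import xarray as xr

def channel_dict_to_dataarray_sketch(channel_dict: dict[str, float]) -> xr.DataArray:
    # Stand-in for the package helper: a 1-D DataArray indexed by channel.
    return xr.DataArray(
        list(channel_dict.values()),
        dims="channel",
        coords={"channel": list(channel_dict.keys())},
    )

# Toy NWP slice with dims (step, channel).
da_nwp = xr.DataArray(
    np.random.rand(2, 3),
    dims=["step", "channel"],
    coords={"channel": ["t2m", "dswrf", "u10"]},
)
da_means = channel_dict_to_dataarray_sketch({"t2m": 285.0, "dswrf": 150.0, "u10": 1.5})
da_stds = channel_dict_to_dataarray_sketch({"t2m": 10.0, "dswrf": 190.0, "u10": 4.0})

# Same expression as in the hunk: aligns on the "channel" coordinate and
# broadcasts over every other dimension.
da_norm = (da_nwp - da_means) / da_stds
```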
```diff
@@ -181,8 +271,9 @@ class SitesDataset(Dataset):
         sample_dict = slice_datasets_by_space(self.datasets_dict, location, self.config)
         sample_dict = slice_datasets_by_time(sample_dict, t0, self.config)
 
-        sample = self.process_and_combine_site_sample_dict(sample_dict, t0)
-        return sample.compute()
+        sample_dict = compute(sample_dict)
+
+        return process_and_combine_datasets(sample_dict, self.config, t0)
 
     def get_sample(self, t0: pd.Timestamp, site_id: int) -> dict:
         """Generate a sample for a given site id and t0.
```
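Net effect of the refactor: the old path built one merged, prefixed xarray Dataset per sample, while the new path slices, computes, and converts straight to numpy. A sketch of the new flow as it reads inside datasets/site.py, where these helpers are already imported (names follow the hunks):

```python
def build_site_sample(datasets_dict: dict, location, config, t0):
    """Sketch of SitesDataset._get_sample after this change."""
    # Slice the lazy xarray inputs spatially, then temporally.
    sample_dict = slice_datasets_by_space(datasets_dict, location, config)
    sample_dict = slice_datasets_by_time(sample_dict, t0, config)
    # Materialise the slices once, up front.
    sample_dict = compute(sample_dict)
    # Normalise and convert directly to a NumpySample; no intermediate
    # merged Dataset (the removed process_and_combine_site_sample_dict).
    return process_and_combine_datasets(sample_dict, config, t0)
```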
```diff
@@ -197,159 +288,6 @@ class SitesDataset(Dataset):
 
         return self._get_sample(t0, location)
 
-    def process_and_combine_site_sample_dict(
-        self,
-        dataset_dict: dict,
-        t0: pd.Timestamp,
-    ) -> xr.Dataset:
-        """Normalize and combine data into a single xr Dataset.
-
-        Args:
-            dataset_dict: dict containing sliced xr DataArrays
-            t0: The initial timestamp of the sample
-
-        Returns:
-            xr.Dataset: A merged Dataset with nans filled in.
-        """
-        data_arrays = []
-
-        if "nwp" in dataset_dict:
-            for nwp_key, da_nwp in dataset_dict["nwp"].items():
-                provider = self.config.input_data.nwp[nwp_key].provider
-
-                da_channel_means = channel_dict_to_dataarray(
-                    self.config.input_data.nwp[nwp_key].channel_means,
-                )
-                da_channel_stds = channel_dict_to_dataarray(
-                    self.config.input_data.nwp[nwp_key].channel_stds,
-                )
-
-                da_nwp = (da_nwp - da_channel_means) / da_channel_stds
-                data_arrays.append((f"nwp-{provider}", da_nwp))
-
-        if "sat" in dataset_dict:
-            da_sat = dataset_dict["sat"]
-
-            da_channel_means = channel_dict_to_dataarray(
-                self.config.input_data.satellite.channel_means,
-            )
-            da_channel_stds = channel_dict_to_dataarray(
-                self.config.input_data.satellite.channel_stds,
-            )
-
-            da_sat = (da_sat - da_channel_means) / da_channel_stds
-            data_arrays.append(("satellite", da_sat))
-
-        if "site" in dataset_dict:
-            da_sites = dataset_dict["site"]
-            da_sites = da_sites / da_sites.capacity_kwp
-            data_arrays.append(("site", da_sites))
-
-        combined_sample_dataset = self.merge_data_arrays(data_arrays)
-
-        # add datetime features
-        datetimes = pd.DatetimeIndex(combined_sample_dataset.site__time_utc.values)
-        datetime_features = make_datetime_numpy_dict(datetimes=datetimes, key_prefix="site_")
-        combined_sample_dataset = combined_sample_dataset.assign_coords(
-            {k: ("site__time_utc", v) for k, v in datetime_features.items()},
-        )
-
-        # Only add solar position if explicitly configured
-        has_solar_config = (
-            hasattr(self.config.input_data, "solar_position")
-            and self.config.input_data.solar_position is not None
-        )
-
-        if has_solar_config:
-            solar_config = self.config.input_data.solar_position
-
-            # Datetime range - solar config params
-            solar_datetimes = pd.date_range(
-                t0 + minutes(solar_config.interval_start_minutes),
-                t0 + minutes(solar_config.interval_end_minutes),
-                freq=minutes(solar_config.time_resolution_minutes),
-            )
-
-            # Calculate sun position features
-            sun_position_features = make_sun_position_numpy_sample(
-                datetimes=solar_datetimes,
-                lon=combined_sample_dataset.site__longitude.values,
-                lat=combined_sample_dataset.site__latitude.values,
-            )
-
-            # Use existing dimension for solar positions
-            # TODO decouple this as a separate data varaible
-            solar_dim_name = "site__time_utc"
-
-            # Assign solar position values
-            for key, values in sun_position_features.items():
-                combined_sample_dataset = combined_sample_dataset.assign_coords(
-                    {key: (solar_dim_name, values)},
-                )
-
-        # TODO include t0_index in xr dataset?
-
-        # Fill any nan values
-        return combined_sample_dataset.fillna(0.0)
-
-    def merge_data_arrays(
-        self,
-        normalised_data_arrays: list[tuple[str, xr.DataArray]],
-    ) -> xr.Dataset:
-        """Combine a list of DataArrays into a single Dataset with unique naming conventions.
-
-        Args:
-            normalised_data_arrays: List of tuples where each tuple contains:
-                - A string (key name).
-                - An xarray.DataArray.
-
-        Returns:
-            xr.Dataset: A merged Dataset with uniquely named variables, coordinates, and dimensions.
-        """
-        datasets = []
-
-        for key, data_array in normalised_data_arrays:
-            # Ensure all attributes are strings for consistency
-            data_array = data_array.assign_attrs(
-                {attr_key: str(attr_value) for attr_key, attr_value in data_array.attrs.items()},
-            )
-
-            # Convert DataArray to Dataset with the variable name as the key
-            dataset = data_array.to_dataset(name=key)
-
-            # Prepend key name to all dimension and coordinate names for uniqueness
-            dataset = dataset.rename(
-                {dim: f"{key}__{dim}" for dim in dataset.dims if dim not in dataset.coords},
-            )
-            dataset = dataset.rename(
-                {coord: f"{key}__{coord}" for coord in dataset.coords},
-            )
-
-            # Handle concatenation dimension if applicable
-            concat_dim = (
-                f"{key}__target_time_utc"
-                if f"{key}__target_time_utc" in dataset.coords
-                else f"{key}__time_utc"
-            )
-
-            if f"{key}__init_time_utc" in dataset.coords:
-                init_coord = f"{key}__init_time_utc"
-                if dataset[init_coord].ndim == 0:  # Check if scalar
-                    expanded_init_times = [dataset[init_coord].values] * len(dataset[concat_dim])
-                    dataset = dataset.assign_coords({init_coord: (concat_dim, expanded_init_times)})
-
-            datasets.append(dataset)
-
-        # Ensure all datasets are valid xarray.Dataset objects
-        for ds in datasets:
-            if not isinstance(ds, xr.Dataset):
-                raise ValueError(f"Object is not an xr.Dataset: {type(ds)}")
-
-        # Merge all prepared datasets
-        combined_dataset = xr.merge(datasets)
-
-        return combined_dataset
-
 
 class SitesDatasetConcurrent(Dataset):
     """A torch Dataset for creating PVNet Site batches with samples for all sites."""
```
```diff
@@ -394,93 +332,6 @@ class SitesDatasetConcurrent(Dataset):
         # Assign coords and indices to self
         self.valid_t0s = valid_t0s
 
-    @staticmethod
-    def process_and_combine_datasets(
-        dataset_dict: dict,
-        config: Configuration,
-        t0: pd.Timestamp,
-    ) -> NumpySample:
-        """Normalise and convert data to numpy arrays.
-
-        Args:
-            dataset_dict: Dictionary of xarray datasets
-            config: Configuration object
-            t0: init-time for sample
-        """
-        numpy_modalities = []
-
-        if "nwp" in dataset_dict:
-            nwp_numpy_modalities = {}
-
-            for nwp_key, da_nwp in dataset_dict["nwp"].items():
-                # Standardise and convert to NumpyBatch
-
-                da_channel_means = channel_dict_to_dataarray(
-                    config.input_data.nwp[nwp_key].channel_means,
-                )
-                da_channel_stds = channel_dict_to_dataarray(
-                    config.input_data.nwp[nwp_key].channel_stds,
-                )
-
-                da_nwp = (da_nwp - da_channel_means) / da_channel_stds
-
-                nwp_numpy_modalities[nwp_key] = convert_nwp_to_numpy_sample(da_nwp)
-
-            # Combine the NWPs into NumpyBatch
-            numpy_modalities.append({NWPSampleKey.nwp: nwp_numpy_modalities})
-
-        if "sat" in dataset_dict:
-            da_sat = dataset_dict["sat"]
-
-            # Standardise and convert to NumpyBatch
-            da_channel_means = channel_dict_to_dataarray(config.input_data.satellite.channel_means)
-            da_channel_stds = channel_dict_to_dataarray(config.input_data.satellite.channel_stds)
-
-            da_sat = (da_sat - da_channel_means) / da_channel_stds
-
-            numpy_modalities.append(convert_satellite_to_numpy_sample(da_sat))
-
-        if "site" in dataset_dict:
-            da_sites = dataset_dict["site"]
-            da_sites = da_sites / da_sites.capacity_kwp
-
-            # Convert to NumpyBatch
-            numpy_modalities.append(
-                convert_site_to_numpy_sample(
-                    da_sites,
-                ),
-            )
-
-        # Only add solar position if explicitly configured
-        has_solar_config = (
-            hasattr(config.input_data, "solar_position")
-            and config.input_data.solar_position is not None
-        )
-
-        if has_solar_config:
-            solar_config = config.input_data.solar_position
-
-            # Create datetime range for solar position calculation
-            datetimes = pd.date_range(
-                t0 + minutes(solar_config.interval_start_minutes),
-                t0 + minutes(solar_config.interval_end_minutes),
-                freq=minutes(solar_config.time_resolution_minutes),
-            )
-
-            # Calculate solar positions and add to modalities
-            numpy_modalities.append(
-                make_sun_position_numpy_sample(
-                    datetimes, da_sites.longitude.values, da_sites.latitude.values,
-                ),
-            )
-
-        # Combine all the modalities and fill NaNs
-        combined_sample = merge_dicts(numpy_modalities)
-        combined_sample = fill_nans_in_arrays(combined_sample)
-
-        return combined_sample
-
-
     def find_valid_t0s(
         self,
         datasets_dict: dict,
```
```diff
@@ -551,7 +402,7 @@ class SitesDatasetConcurrent(Dataset):
 
         for location in self.locations:
             site_sample_dict = slice_datasets_by_space(sample_dict, location, self.config)
-            site_numpy_sample = self.process_and_combine_datasets(
+            site_numpy_sample = process_and_combine_datasets(
                 site_sample_dict,
                 self.config,
                 t0,
```
```diff
@@ -561,128 +412,6 @@ class SitesDatasetConcurrent(Dataset):
         return stack_np_samples_into_batch(site_samples)
 
 
-# ----- functions to load presaved samples ------
-
-
-def convert_netcdf_to_numpy_sample(ds: xr.Dataset) -> dict:
-    """Convert a netcdf dataset to a numpy sample.
-
-    Args:
-        ds: xarray Dataset
-    """
-    # convert the single dataset to a dict of arrays
-    sample_dict = convert_from_dataset_to_dict_datasets(ds)
-
-    if "satellite" in sample_dict:
-        # rename satellite to sat # TODO this could be improved
-        sample_dict["sat"] = sample_dict.pop("satellite")
-
-    # process and combine the datasets
-    sample = convert_to_numpy_and_combine(
-        dataset_dict=sample_dict,
-    )
-
-    # Extraction of solar position coords
-    solar_keys = ["solar_azimuth", "solar_elevation"]
-    for key in solar_keys:
-        if key in ds.coords:
-            sample[key] = ds.coords[key].values
-
-    # TODO think about normalization:
-    # * maybe its done not in sample creation, maybe its done afterwards,
-    # to allow it to be flexible
-
-    return sample
-
-
-def convert_from_dataset_to_dict_datasets(combined_dataset: xr.Dataset) -> dict[str, xr.DataArray]:
-    """Convert a combined sample dataset to a dict of datasets for each input.
-
-    Args:
-        combined_dataset: The combined NetCDF dataset
-
-    Returns:
-        The uncombined datasets as a dict of xr.Datasets
-    """
-    # Split into datasets by splitting by the prefix added in combine_to_netcdf
-    datasets: dict[str, xr.DataArray] = {}
-
-    # Go through each data variable and split it into a dataset
-    for key, dataset in combined_dataset.items():
-        # If 'key__' doesn't exist in a dim or coordinate, remove it
-        for dim in list(dataset.coords):
-            if f"{key}__" not in dim:
-                dataset = dataset.drop_vars(dim)
-        dataset = dataset.rename(
-            {dim: dim.split(f"{key}__")[1] for dim in dataset.dims if dim not in dataset.coords},
-        )
-        dataset = dataset.rename(
-            {coord: coord.split(f"{key}__")[1] for coord in dataset.coords},
-        )
-        # Split the dataset by the prefix
-        datasets[key] = dataset
-
-    # Unflatten any NWP data
-    return nest_nwp_source_dict(datasets, sep="-")
-
-
-def nest_nwp_source_dict(
-    dataset_dict: dict[xr.Dataset],
-    sep: str = "-",
-) -> dict[str, xr.Dataset | dict[xr.Dataset]]:
-    """Re-nest a dictionary where the NWP values are nested under keys 'nwp-<key>'.
-
-    Args:
-        dataset_dict: Dictionary of datasets
-        sep: Separator to use to nest NWP keys
-    """
-    nwp_prefix = f"nwp{sep}"
-    new_dict = {k: v for k, v in dataset_dict.items() if not k.startswith(nwp_prefix)}
-    nwp_keys = [k for k in dataset_dict if k.startswith(nwp_prefix)]
-    if len(nwp_keys) > 0:
-        nwp_subdict = {k.removeprefix(nwp_prefix): dataset_dict[k] for k in nwp_keys}
-        new_dict["nwp"] = nwp_subdict
-    return new_dict
-
-
-def convert_to_numpy_and_combine(dataset_dict: dict[xr.Dataset]) -> NumpySample:
-    """Convert input data in a dict to numpy arrays.
-
-    Args:
-        dataset_dict: Dictionary of xarray Datasets
-    """
-    numpy_modalities = []
-
-    if "nwp" in dataset_dict:
-        nwp_numpy_modalities = {}
-        for nwp_key, da_nwp in dataset_dict["nwp"].items():
-            # Convert to NumpySample
-            nwp_numpy_modalities[nwp_key] = convert_nwp_to_numpy_sample(da_nwp)
-
-        # Combine the NWPs into NumpySample
-        numpy_modalities.append({NWPSampleKey.nwp: nwp_numpy_modalities})
-
-    if "sat" in dataset_dict:
-        # Satellite is already in the range [0-1] so no need to standardise
-        da_sat = dataset_dict["sat"]
-
-        # Convert to NumpySample
-        numpy_modalities.append(convert_satellite_to_numpy_sample(da_sat))
-
-    if "site" in dataset_dict:
-        da_sites = dataset_dict["site"]
-
-        numpy_modalities.append(
-            convert_site_to_numpy_sample(
-                da_sites,
-            ),
-        )
-
-    # Combine all the modalities and fill NaNs
-    combined_sample = merge_dicts(numpy_modalities)
-    return fill_nans_in_arrays(combined_sample)
-
-
 def coarsen_data(xr_data: xr.Dataset, coarsen_to_deg: float = 0.1) -> xr.Dataset:
     """Coarsen the data to a specified resolution in degrees.
 
```
ocf_data_sampler/torch_datasets/sample/site.py:

```diff
@@ -1,37 +1,46 @@
 """PVNet Site sample implementation for netCDF data handling and conversion."""
 
-import xarray as xr
+import torch
 from typing_extensions import override
 
 from ocf_data_sampler.numpy_sample.common_types import NumpySample
-from ocf_data_sampler.torch_datasets.datasets.site import convert_netcdf_to_numpy_sample
 
 from .base import SampleBase
 
 
+# TODO this is now similar to the UKRegionalSample
+# We should consider just having one Sample class for all datasets
 class SiteSample(SampleBase):
-    """Handles netCDF operations for site samples."""
+    """Handles SiteSample specific operations."""
 
-    def __init__(self, data: xr.Dataset) -> None:
-        """Initializes the SiteSample object with the given xarray Dataset."""
-        if not isinstance(data, xr.Dataset):
-            raise TypeError(f"Data must be xarray Dataset - Found type {type(data)}")
+    def __init__(self, data: NumpySample) -> None:
+        """Initializes the SiteSample object with the given NumpySample."""
         self._data = data
 
     @override
     def to_numpy(self) -> NumpySample:
-        return convert_netcdf_to_numpy_sample(self._data)
+        return self._data
 
     @override
     def save(self, path: str) -> None:
-        """Saves sample to netCDF at the specified path."""
-        self._data.to_netcdf(path)
+        """Saves sample to the specified path in pickle format."""
+        # Saves to pickle format
+        torch.save(self._data, path)
 
     @classmethod
     @override
     def load(cls, path: str) -> "SiteSample":
-        """Loads sample from the specified netCDF path."""
-        return cls(xr.open_dataset(path))
+        """Loads sample from the specified path.
+
+        Args:
+            path: Path to the saved sample file.
+
+        Returns:
+            A SiteSample instance with the loaded data.
+        """
+        # Loads from .pt format
+        # TODO: We should move away from using torch.load(..., weights_only=False)
+        return cls(torch.load(path, weights_only=False))
 
     @override
     def plot(self) -> None:
```
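The persistence format therefore changes outright: 0.3.x wrote netCDF through xarray, while 0.5.0 pickles the numpy dict with torch. A small round-trip sketch (the payload is a toy stand-in for a real sample):

```python
import numpy as np

from ocf_data_sampler.torch_datasets.sample.site import SiteSample

# Toy NumpySample-style payload; real samples come from SitesDataset.
sample = SiteSample({"site_id": np.array(1), "site": np.zeros(4)})
sample.save("sample.pt")

# weights_only=False is needed because the payload is an arbitrary dict of
# arrays rather than pure tensors, so only load files you trust.
loaded = SiteSample.load("sample.pt")
assert set(loaded.to_numpy()) == {"site_id", "site"}
```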
ocf_data_sampler-0.5.0.dist-info/METADATA:

```diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: ocf-data-sampler
-Version: 0.3.1
+Version: 0.5.0
 Author: James Fulton, Peter Dudfield
 Author-email: Open Climate Fix team <info@openclimatefix.org>
 License: MIT License
```
```diff
@@ -49,7 +49,7 @@ Requires-Dist: xarray-tensorstore==0.1.5
 # ocf-data-sampler
 
 <!-- ALL-CONTRIBUTORS-BADGE:START - Do not remove or modify this section -->
-[![All Contributors](…)](#contributors-)
+[![All Contributors](…)](#contributors-)
 <!-- ALL-CONTRIBUTORS-BADGE:END -->
 
 [![…](…)](https://github.com/openclimatefix/ocf-data-sampler/tags)
```
```diff
@@ -128,6 +128,7 @@ Thanks goes to these wonderful people ([emoji key](https://allcontributors.org/d
       <td align="center" valign="top" width="14.28%"><a href="http://siddharth7113.github.io"><img src="https://avatars.githubusercontent.com/u/114160268?v=4?s=100" width="100px;" alt="Siddharth"/><br /><sub><b>Siddharth</b></sub></a><br /><a href="https://github.com/openclimatefix/ocf-data-sampler/commits?author=siddharth7113" title="Code">💻</a></td>
       <td align="center" valign="top" width="14.28%"><a href="https://github.com/Sachin-G13"><img src="https://avatars.githubusercontent.com/u/190184500?v=4?s=100" width="100px;" alt="Sachin-G13"/><br /><sub><b>Sachin-G13</b></sub></a><br /><a href="https://github.com/openclimatefix/ocf-data-sampler/commits?author=Sachin-G13" title="Code">💻</a></td>
       <td align="center" valign="top" width="14.28%"><a href="https://drona-gyawali.github.io/"><img src="https://avatars.githubusercontent.com/u/170401554?v=4?s=100" width="100px;" alt="Dorna Raj Gyawali"/><br /><sub><b>Dorna Raj Gyawali</b></sub></a><br /><a href="https://github.com/openclimatefix/ocf-data-sampler/commits?author=drona-gyawali" title="Code">💻</a></td>
+      <td align="center" valign="top" width="14.28%"><a href="https://github.com/adnanhashmi25"><img src="https://avatars.githubusercontent.com/u/55550094?v=4?s=100" width="100px;" alt="Adnan Hashmi"/><br /><sub><b>Adnan Hashmi</b></sub></a><br /><a href="https://github.com/openclimatefix/ocf-data-sampler/commits?author=adnanhashmi25" title="Code">💻</a></td>
     </tr>
   </tbody>
 </table>
```
ocf_data_sampler-0.5.0.dist-info/RECORD:

```diff
@@ -22,14 +22,14 @@ ocf_data_sampler/load/nwp/providers/gfs.py,sha256=h6vm-Rfz1JGOE4P_fP1_XQJ3bugNbe
 ocf_data_sampler/load/nwp/providers/icon.py,sha256=iVZwLKRr_D74_kAu5MHir6pRKEfbTmIxFRZAxzmiYdI,1257
 ocf_data_sampler/load/nwp/providers/ukv.py,sha256=2i32VM9gnmWUpbL0qBSp_AKzuyKucXZPS8yklbcGlbc,1039
 ocf_data_sampler/load/nwp/providers/utils.py,sha256=cVwCiC8FqNpkZFSUGv1CRqIQlKdjx1sIsb2SIUlvWV8,2333
-ocf_data_sampler/numpy_sample/__init__.py,sha256=…
+ocf_data_sampler/numpy_sample/__init__.py,sha256=5bdpzM8hMAEe0XRSZ9AZFQdqEeBsEPhaF79Y8bDx3GQ,407
 ocf_data_sampler/numpy_sample/collate.py,sha256=hoxIc5SoHoIs3Nx37aRZzWChpswjy9lHUgaKgHIoo80,2039
 ocf_data_sampler/numpy_sample/common_types.py,sha256=9CjYHkUTx0ObduWh43fhsybZCTXvexql7qC2ptMDoek,377
-ocf_data_sampler/numpy_sample/datetime_features.py,sha256=…
+ocf_data_sampler/numpy_sample/datetime_features.py,sha256=ObHM42VnZB7_daQ5a42GeftoDWYtVMT-wDP8kRtY_84,857
 ocf_data_sampler/numpy_sample/gsp.py,sha256=aUHDIUSu2LMsVmR7TsTriZxVfv495QNL-scaxyJFHgQ,1149
 ocf_data_sampler/numpy_sample/nwp.py,sha256=lXqE2Il0xX5hzz76HHkiYmfDsXWWhmaA_6bSnmwbAXU,1078
 ocf_data_sampler/numpy_sample/satellite.py,sha256=RaYzYIcB1AmDrKeiqSpn4QVfBH-QMe26F1P5t1az2Jg,1111
-ocf_data_sampler/numpy_sample/site.py,sha256=…
+ocf_data_sampler/numpy_sample/site.py,sha256=4S19bzCN5lswVUrmWRfwpVsBPUE7bi0OIdxsD9wgvhU,982
 ocf_data_sampler/numpy_sample/sun_position.py,sha256=5tt-zNm6aRuZMsxZPaAxyg7HeikswfZCeHWXTHuO2K0,1555
 ocf_data_sampler/select/__init__.py,sha256=mK7Wu_-j9IXGTYrOuDf5yDDuU5a306b0iGKTAooNg_s,210
 ocf_data_sampler/select/dropout.py,sha256=BYpv8L771faPOyN7SdIJ5cwkpDve-ohClj95jjsHmjg,1973
```
```diff
@@ -39,12 +39,12 @@ ocf_data_sampler/select/geospatial.py,sha256=CDExkl36eZOKmdJPzUr_K0Wn3axHqv5nYo-
 ocf_data_sampler/select/location.py,sha256=AZvGR8y62opiW7zACGXjoOtBEWRfSLOZIA73O5Deu0c,1037
 ocf_data_sampler/select/select_spatial_slice.py,sha256=Hd4jGRUfIZRoWCirOQZeoLpaUnStB6KyFSTPX69wZLw,8790
 ocf_data_sampler/select/select_time_slice.py,sha256=HeHbwZ0CP03x0-LaJtpbSdtpLufwVTR73p6wH6O_PS8,5513
-ocf_data_sampler/torch_datasets/datasets/__init__.py,sha256=…
+ocf_data_sampler/torch_datasets/datasets/__init__.py,sha256=o0SsEXXZ6k9iL__5_RN1Sf60lw_eqK91P3UFEHAD2k0,102
 ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py,sha256=v63goKEMI6UgBPnQCnIbxhFFdwuP_sxgcPYY6iNfGkc,12257
-ocf_data_sampler/torch_datasets/datasets/site.py,sha256=…
+ocf_data_sampler/torch_datasets/datasets/site.py,sha256=_0A2kRq8B5WL5zWjKxNY9snAl_GwptohUt7c6DDa2AA,14812
 ocf_data_sampler/torch_datasets/sample/__init__.py,sha256=GL84vdZl_SjHDGVyh9Uekx2XhPYuZ0dnO3l6f6KXnHI,100
 ocf_data_sampler/torch_datasets/sample/base.py,sha256=cQ1oIyhdmlotejZK8B3Cw6MNvpdnBPD8G_o2h7Ye4Vc,2206
-ocf_data_sampler/torch_datasets/sample/site.py,sha256=…
+ocf_data_sampler/torch_datasets/sample/site.py,sha256=40NwNTqjL1WVhPdwe02zDHHfDLG2u_bvCfRCtGAtFc0,1466
 ocf_data_sampler/torch_datasets/sample/uk_regional.py,sha256=Xx5cBYUyaM6PGUWQ76MHT9hwj6IJ7WAOxbpmYFbJGhc,10483
 ocf_data_sampler/torch_datasets/utils/__init__.py,sha256=N7i_hHtWUDiJqsiJoDx4T_QuiYOuvIyulPrn6xEA4TY,309
 ocf_data_sampler/torch_datasets/utils/channel_dict_to_dataarray.py,sha256=un2IiyoAmTDIymdeMiPU899_86iCDMD-oIifjHlNyqw,555
```
```diff
@@ -56,7 +56,7 @@ ocf_data_sampler/torch_datasets/utils/validation_utils.py,sha256=YqmT-lExWlI8_ul
 scripts/download_gsp_location_data.py,sha256=rRDXMoqX-RYY4jPdxhdlxJGhWdl6r245F5UARgKV6P4,3121
 scripts/refactor_site.py,sha256=skzvsPP0Cn9yTKndzkilyNcGz4DZ88ctvCJ0XrBdc2A,3135
 utils/compute_icon_mean_stddev.py,sha256=a1oWMRMnny39rV-dvu8rcx85sb4bXzPFrR1gkUr4Jpg,2296
-ocf_data_sampler-0.3.1.dist-info/METADATA,sha256=…
-ocf_data_sampler-0.3.1.dist-info/WHEEL,sha256=…
-ocf_data_sampler-0.3.1.dist-info/top_level.txt,sha256=…
-ocf_data_sampler-0.3.1.dist-info/RECORD,,
+ocf_data_sampler-0.5.0.dist-info/METADATA,sha256=DUHmN65X_SR-1E8bTNfsCShFPJKIEvR9DWfAQoNyAt4,12588
+ocf_data_sampler-0.5.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ocf_data_sampler-0.5.0.dist-info/top_level.txt,sha256=LEFU4Uk-PEo72QGLAfnVZIUEm37Q8mKuMeg_Xk-p33g,31
+ocf_data_sampler-0.5.0.dist-info/RECORD,,
```

{ocf_data_sampler-0.3.1.dist-info → ocf_data_sampler-0.5.0.dist-info}/WHEEL: file without changes

{ocf_data_sampler-0.3.1.dist-info → ocf_data_sampler-0.5.0.dist-info}/top_level.txt: file without changes