ocf-data-sampler 0.4.0__tar.gz → 0.5.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of ocf-data-sampler has been flagged as potentially problematic.
- {ocf_data_sampler-0.4.0 → ocf_data_sampler-0.5.0}/PKG-INFO +1 -1
- {ocf_data_sampler-0.4.0 → ocf_data_sampler-0.5.0}/ocf_data_sampler/numpy_sample/__init__.py +1 -2
- ocf_data_sampler-0.5.0/ocf_data_sampler/torch_datasets/datasets/__init__.py +2 -0
- {ocf_data_sampler-0.4.0 → ocf_data_sampler-0.5.0}/ocf_data_sampler/torch_datasets/datasets/site.py +94 -355
- ocf_data_sampler-0.5.0/ocf_data_sampler/torch_datasets/sample/site.py +48 -0
- {ocf_data_sampler-0.4.0 → ocf_data_sampler-0.5.0}/ocf_data_sampler.egg-info/PKG-INFO +1 -1
- ocf_data_sampler-0.4.0/ocf_data_sampler/torch_datasets/datasets/__init__.py +0 -6
- ocf_data_sampler-0.4.0/ocf_data_sampler/torch_datasets/sample/site.py +0 -39
- {ocf_data_sampler-0.4.0 → ocf_data_sampler-0.5.0}/LICENSE +0 -0
- {ocf_data_sampler-0.4.0 → ocf_data_sampler-0.5.0}/README.md +0 -0
- {ocf_data_sampler-0.4.0 → ocf_data_sampler-0.5.0}/ocf_data_sampler/__init__.py +0 -0
- {ocf_data_sampler-0.4.0 → ocf_data_sampler-0.5.0}/ocf_data_sampler/config/__init__.py +0 -0
- {ocf_data_sampler-0.4.0 → ocf_data_sampler-0.5.0}/ocf_data_sampler/config/load.py +0 -0
- {ocf_data_sampler-0.4.0 → ocf_data_sampler-0.5.0}/ocf_data_sampler/config/model.py +0 -0
- {ocf_data_sampler-0.4.0 → ocf_data_sampler-0.5.0}/ocf_data_sampler/config/save.py +0 -0
- {ocf_data_sampler-0.4.0 → ocf_data_sampler-0.5.0}/ocf_data_sampler/data/uk_gsp_locations_20220314.csv +0 -0
- {ocf_data_sampler-0.4.0 → ocf_data_sampler-0.5.0}/ocf_data_sampler/data/uk_gsp_locations_20250109.csv +0 -0
- {ocf_data_sampler-0.4.0 → ocf_data_sampler-0.5.0}/ocf_data_sampler/load/__init__.py +0 -0
- {ocf_data_sampler-0.4.0 → ocf_data_sampler-0.5.0}/ocf_data_sampler/load/gsp.py +0 -0
- {ocf_data_sampler-0.4.0 → ocf_data_sampler-0.5.0}/ocf_data_sampler/load/load_dataset.py +0 -0
- {ocf_data_sampler-0.4.0 → ocf_data_sampler-0.5.0}/ocf_data_sampler/load/nwp/__init__.py +0 -0
- {ocf_data_sampler-0.4.0 → ocf_data_sampler-0.5.0}/ocf_data_sampler/load/nwp/nwp.py +0 -0
- {ocf_data_sampler-0.4.0 → ocf_data_sampler-0.5.0}/ocf_data_sampler/load/nwp/providers/__init__.py +0 -0
- {ocf_data_sampler-0.4.0 → ocf_data_sampler-0.5.0}/ocf_data_sampler/load/nwp/providers/cloudcasting.py +0 -0
- {ocf_data_sampler-0.4.0 → ocf_data_sampler-0.5.0}/ocf_data_sampler/load/nwp/providers/ecmwf.py +0 -0
- {ocf_data_sampler-0.4.0 → ocf_data_sampler-0.5.0}/ocf_data_sampler/load/nwp/providers/gfs.py +0 -0
- {ocf_data_sampler-0.4.0 → ocf_data_sampler-0.5.0}/ocf_data_sampler/load/nwp/providers/icon.py +0 -0
- {ocf_data_sampler-0.4.0 → ocf_data_sampler-0.5.0}/ocf_data_sampler/load/nwp/providers/ukv.py +0 -0
- {ocf_data_sampler-0.4.0 → ocf_data_sampler-0.5.0}/ocf_data_sampler/load/nwp/providers/utils.py +0 -0
- {ocf_data_sampler-0.4.0 → ocf_data_sampler-0.5.0}/ocf_data_sampler/load/open_tensorstore_zarrs.py +0 -0
- {ocf_data_sampler-0.4.0 → ocf_data_sampler-0.5.0}/ocf_data_sampler/load/satellite.py +0 -0
- {ocf_data_sampler-0.4.0 → ocf_data_sampler-0.5.0}/ocf_data_sampler/load/site.py +0 -0
- {ocf_data_sampler-0.4.0 → ocf_data_sampler-0.5.0}/ocf_data_sampler/load/utils.py +0 -0
- {ocf_data_sampler-0.4.0 → ocf_data_sampler-0.5.0}/ocf_data_sampler/numpy_sample/collate.py +0 -0
- {ocf_data_sampler-0.4.0 → ocf_data_sampler-0.5.0}/ocf_data_sampler/numpy_sample/common_types.py +0 -0
- {ocf_data_sampler-0.4.0 → ocf_data_sampler-0.5.0}/ocf_data_sampler/numpy_sample/datetime_features.py +0 -0
- {ocf_data_sampler-0.4.0 → ocf_data_sampler-0.5.0}/ocf_data_sampler/numpy_sample/gsp.py +0 -0
- {ocf_data_sampler-0.4.0 → ocf_data_sampler-0.5.0}/ocf_data_sampler/numpy_sample/nwp.py +0 -0
- {ocf_data_sampler-0.4.0 → ocf_data_sampler-0.5.0}/ocf_data_sampler/numpy_sample/satellite.py +0 -0
- {ocf_data_sampler-0.4.0 → ocf_data_sampler-0.5.0}/ocf_data_sampler/numpy_sample/site.py +0 -0
- {ocf_data_sampler-0.4.0 → ocf_data_sampler-0.5.0}/ocf_data_sampler/numpy_sample/sun_position.py +0 -0
- {ocf_data_sampler-0.4.0 → ocf_data_sampler-0.5.0}/ocf_data_sampler/select/__init__.py +0 -0
- {ocf_data_sampler-0.4.0 → ocf_data_sampler-0.5.0}/ocf_data_sampler/select/dropout.py +0 -0
- {ocf_data_sampler-0.4.0 → ocf_data_sampler-0.5.0}/ocf_data_sampler/select/fill_time_periods.py +0 -0
- {ocf_data_sampler-0.4.0 → ocf_data_sampler-0.5.0}/ocf_data_sampler/select/find_contiguous_time_periods.py +0 -0
- {ocf_data_sampler-0.4.0 → ocf_data_sampler-0.5.0}/ocf_data_sampler/select/geospatial.py +0 -0
- {ocf_data_sampler-0.4.0 → ocf_data_sampler-0.5.0}/ocf_data_sampler/select/location.py +0 -0
- {ocf_data_sampler-0.4.0 → ocf_data_sampler-0.5.0}/ocf_data_sampler/select/select_spatial_slice.py +0 -0
- {ocf_data_sampler-0.4.0 → ocf_data_sampler-0.5.0}/ocf_data_sampler/select/select_time_slice.py +0 -0
- {ocf_data_sampler-0.4.0 → ocf_data_sampler-0.5.0}/ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py +0 -0
- {ocf_data_sampler-0.4.0 → ocf_data_sampler-0.5.0}/ocf_data_sampler/torch_datasets/sample/__init__.py +0 -0
- {ocf_data_sampler-0.4.0 → ocf_data_sampler-0.5.0}/ocf_data_sampler/torch_datasets/sample/base.py +0 -0
- {ocf_data_sampler-0.4.0 → ocf_data_sampler-0.5.0}/ocf_data_sampler/torch_datasets/sample/uk_regional.py +0 -0
- {ocf_data_sampler-0.4.0 → ocf_data_sampler-0.5.0}/ocf_data_sampler/torch_datasets/utils/__init__.py +0 -0
- {ocf_data_sampler-0.4.0 → ocf_data_sampler-0.5.0}/ocf_data_sampler/torch_datasets/utils/channel_dict_to_dataarray.py +0 -0
- {ocf_data_sampler-0.4.0 → ocf_data_sampler-0.5.0}/ocf_data_sampler/torch_datasets/utils/merge_and_fill_utils.py +0 -0
- {ocf_data_sampler-0.4.0 → ocf_data_sampler-0.5.0}/ocf_data_sampler/torch_datasets/utils/spatial_slice_for_dataset.py +0 -0
- {ocf_data_sampler-0.4.0 → ocf_data_sampler-0.5.0}/ocf_data_sampler/torch_datasets/utils/time_slice_for_dataset.py +0 -0
- {ocf_data_sampler-0.4.0 → ocf_data_sampler-0.5.0}/ocf_data_sampler/torch_datasets/utils/valid_time_periods.py +0 -0
- {ocf_data_sampler-0.4.0 → ocf_data_sampler-0.5.0}/ocf_data_sampler/torch_datasets/utils/validation_utils.py +0 -0
- {ocf_data_sampler-0.4.0 → ocf_data_sampler-0.5.0}/ocf_data_sampler/utils.py +0 -0
- {ocf_data_sampler-0.4.0 → ocf_data_sampler-0.5.0}/ocf_data_sampler.egg-info/SOURCES.txt +0 -0
- {ocf_data_sampler-0.4.0 → ocf_data_sampler-0.5.0}/ocf_data_sampler.egg-info/dependency_links.txt +0 -0
- {ocf_data_sampler-0.4.0 → ocf_data_sampler-0.5.0}/ocf_data_sampler.egg-info/requires.txt +0 -0
- {ocf_data_sampler-0.4.0 → ocf_data_sampler-0.5.0}/ocf_data_sampler.egg-info/top_level.txt +0 -0
- {ocf_data_sampler-0.4.0 → ocf_data_sampler-0.5.0}/pyproject.toml +0 -0
- {ocf_data_sampler-0.4.0 → ocf_data_sampler-0.5.0}/scripts/download_gsp_location_data.py +0 -0
- {ocf_data_sampler-0.4.0 → ocf_data_sampler-0.5.0}/scripts/refactor_site.py +0 -0
- {ocf_data_sampler-0.4.0 → ocf_data_sampler-0.5.0}/setup.cfg +0 -0
- {ocf_data_sampler-0.4.0 → ocf_data_sampler-0.5.0}/utils/compute_icon_mean_stddev.py +0 -0
{ocf_data_sampler-0.4.0 → ocf_data_sampler-0.5.0}/ocf_data_sampler/numpy_sample/__init__.py
RENAMED

@@ -5,5 +5,4 @@ from .gsp import convert_gsp_to_numpy_sample, GSPSampleKey
 from .nwp import convert_nwp_to_numpy_sample, NWPSampleKey
 from .satellite import convert_satellite_to_numpy_sample, SatelliteSampleKey
 from .sun_position import make_sun_position_numpy_sample
-from .site import convert_site_to_numpy_sample
-
+from .site import convert_site_to_numpy_sample, SiteSampleKey
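The effect of this hunk is that SiteSampleKey is now re-exported from the numpy_sample package, matching the GSP, NWP, and satellite key classes above it. An illustrative import (assuming ocf-data-sampler 0.5.0 is installed):

    from ocf_data_sampler.numpy_sample import SiteSampleKey, convert_site_to_numpy_sample

    # SiteSampleKey holds the key names used for site entries in a NumpySample,
    # by analogy with GSPSampleKey and NWPSampleKey re-exported alongside it.
    print(SiteSampleKey)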
{ocf_data_sampler-0.4.0 → ocf_data_sampler-0.5.0}/ocf_data_sampler/torch_datasets/datasets/site.py
RENAMED

@@ -58,6 +58,96 @@ def get_locations(site_xr: xr.Dataset) -> list[Location]:
 
     return locations
 
+def process_and_combine_datasets(
+    dataset_dict: dict,
+    config: Configuration,
+    t0: pd.Timestamp,
+) -> NumpySample:
+    """Normalise and convert data to numpy arrays.
+
+    Args:
+        dataset_dict: Dictionary of xarray datasets
+        config: Configuration object
+        t0: init-time for sample
+    """
+    numpy_modalities = []
+
+    if "nwp" in dataset_dict:
+        nwp_numpy_modalities = {}
+
+        for nwp_key, da_nwp in dataset_dict["nwp"].items():
+
+            # Standardise and convert to NumpyBatch
+
+            da_channel_means = channel_dict_to_dataarray(
+                config.input_data.nwp[nwp_key].channel_means,
+            )
+            da_channel_stds = channel_dict_to_dataarray(
+                config.input_data.nwp[nwp_key].channel_stds,
+            )
+
+            da_nwp = (da_nwp - da_channel_means) / da_channel_stds
+
+            nwp_numpy_modalities[nwp_key] = convert_nwp_to_numpy_sample(da_nwp)
+
+        # Combine the NWPs into NumpyBatch
+        numpy_modalities.append({NWPSampleKey.nwp: nwp_numpy_modalities})
+
+    if "sat" in dataset_dict:
+        da_sat = dataset_dict["sat"]
+
+        # Standardise and convert to NumpyBatch
+        da_channel_means = channel_dict_to_dataarray(config.input_data.satellite.channel_means)
+        da_channel_stds = channel_dict_to_dataarray(config.input_data.satellite.channel_stds)
+
+        da_sat = (da_sat - da_channel_means) / da_channel_stds
+
+        numpy_modalities.append(convert_satellite_to_numpy_sample(da_sat))
+
+    if "site" in dataset_dict:
+        da_sites = dataset_dict["site"]
+        da_sites = da_sites / da_sites.capacity_kwp
+
+        # Convert to NumpyBatch
+        numpy_modalities.append(
+            convert_site_to_numpy_sample(
+                da_sites,
+            ),
+        )
+
+        # add datetime features
+        datetimes = pd.DatetimeIndex(da_sites.time_utc.values)
+        datetime_features = encode_datetimes(datetimes=datetimes)
+
+        numpy_modalities.append(datetime_features)
+
+    # Only add solar position if explicitly configured
+    if config.input_data.solar_position is not None:
+        solar_config = config.input_data.solar_position
+
+        # Create datetime range for solar position calculation
+        datetimes = pd.date_range(
+            t0 + minutes(solar_config.interval_start_minutes),
+            t0 + minutes(solar_config.interval_end_minutes),
+            freq=minutes(solar_config.time_resolution_minutes),
+        )
+
+
+        # Calculate solar positions and add to modalities
+        numpy_modalities.append(
+            make_sun_position_numpy_sample(
+                datetimes,
+                da_sites.longitude.values,
+                da_sites.latitude.values,
+            ),
+        )
+
+    # Combine all the modalities and fill NaNs
+    combined_sample = merge_dicts(numpy_modalities)
+    combined_sample = fill_nans_in_arrays(combined_sample)
+
+    return combined_sample
+
 
 class SitesDataset(Dataset):
     """A torch Dataset for creating PVNet Site samples."""
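The core of this new module-level process_and_combine_datasets is per-channel standardisation, (x - mean) / std, broadcast over the channel dimension before each modality is converted to numpy. A minimal self-contained sketch of that step in plain xarray; the channel names and statistics are invented, and the package's channel_dict_to_dataarray helper is emulated with a plain DataArray constructor:

    import numpy as np
    import xarray as xr

    # Toy NWP-like array: (time, channel, y, x); the channel names are made up
    da = xr.DataArray(
        np.random.rand(2, 3, 4, 4),
        dims=("time", "channel", "y", "x"),
        coords={"channel": ["t2m", "dswrf", "u10"]},
    )

    # Stand-ins for channel_dict_to_dataarray(config...channel_means / channel_stds)
    means = xr.DataArray([280.0, 150.0, 1.5], dims="channel", coords={"channel": da.channel})
    stds = xr.DataArray([15.0, 80.0, 3.0], dims="channel", coords={"channel": da.channel})

    # Broadcasts along "channel", exactly like (da_nwp - da_channel_means) / da_channel_stds
    da_norm = (da - means) / stds
    assert da_norm.shape == da.shape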
@@ -181,8 +271,9 @@ class SitesDataset(Dataset):
         sample_dict = slice_datasets_by_space(self.datasets_dict, location, self.config)
         sample_dict = slice_datasets_by_time(sample_dict, t0, self.config)
 
-
-
+        sample_dict = compute(sample_dict)
+
+        return process_and_combine_datasets(sample_dict, self.config, t0)
 
     def get_sample(self, t0: pd.Timestamp, site_id: int) -> dict:
         """Generate a sample for a given site id and t0.
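After this change, _get_sample computes the lazy slices and returns a NumpySample dict from the module-level process_and_combine_datasets; in 0.4.0 this path appears to have returned an xr.Dataset via the now-removed process_and_combine_site_sample_dict (see the next hunk). A hedged usage sketch; the constructor argument and config filename are illustrative assumptions, not the confirmed signature:

    from ocf_data_sampler.torch_datasets.datasets.site import SitesDataset

    # Hypothetical config path; a real config describes the NWP/satellite/site inputs
    dataset = SitesDataset("site_config.yaml")
    sample = dataset[0]  # a NumpySample dict produced by process_and_combine_datasets
    print(sorted(sample.keys()))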
@@ -197,159 +288,6 @@ class SitesDataset(Dataset):
 
         return self._get_sample(t0, location)
 
-    def process_and_combine_site_sample_dict(
-        self,
-        dataset_dict: dict,
-        t0: pd.Timestamp,
-    ) -> xr.Dataset:
-        """Normalize and combine data into a single xr Dataset.
-
-        Args:
-            dataset_dict: dict containing sliced xr DataArrays
-            t0: The initial timestamp of the sample
-
-        Returns:
-            xr.Dataset: A merged Dataset with nans filled in.
-        """
-        data_arrays = []
-
-        if "nwp" in dataset_dict:
-            for nwp_key, da_nwp in dataset_dict["nwp"].items():
-                provider = self.config.input_data.nwp[nwp_key].provider
-
-                da_channel_means = channel_dict_to_dataarray(
-                    self.config.input_data.nwp[nwp_key].channel_means,
-                )
-                da_channel_stds = channel_dict_to_dataarray(
-                    self.config.input_data.nwp[nwp_key].channel_stds,
-                )
-
-                da_nwp = (da_nwp - da_channel_means) / da_channel_stds
-                data_arrays.append((f"nwp-{provider}", da_nwp))
-
-        if "sat" in dataset_dict:
-            da_sat = dataset_dict["sat"]
-
-            da_channel_means = channel_dict_to_dataarray(
-                self.config.input_data.satellite.channel_means,
-            )
-            da_channel_stds = channel_dict_to_dataarray(
-                self.config.input_data.satellite.channel_stds,
-            )
-
-            da_sat = (da_sat - da_channel_means) / da_channel_stds
-            data_arrays.append(("satellite", da_sat))
-
-        if "site" in dataset_dict:
-            da_sites = dataset_dict["site"]
-            da_sites = da_sites / da_sites.capacity_kwp
-            data_arrays.append(("site", da_sites))
-
-        combined_sample_dataset = self.merge_data_arrays(data_arrays)
-
-        # add datetime features
-        datetimes = pd.DatetimeIndex(combined_sample_dataset.site__time_utc.values)
-        datetime_features = encode_datetimes(datetimes=datetimes)
-        combined_sample_dataset = combined_sample_dataset.assign_coords(
-            {k: ("site__time_utc", v) for k, v in datetime_features.items()},
-        )
-
-        # Only add solar position if explicitly configured
-        has_solar_config = (
-            hasattr(self.config.input_data, "solar_position")
-            and self.config.input_data.solar_position is not None
-        )
-
-        if has_solar_config:
-            solar_config = self.config.input_data.solar_position
-
-            # Datetime range - solar config params
-            solar_datetimes = pd.date_range(
-                t0 + minutes(solar_config.interval_start_minutes),
-                t0 + minutes(solar_config.interval_end_minutes),
-                freq=minutes(solar_config.time_resolution_minutes),
-            )
-
-            # Calculate sun position features
-            sun_position_features = make_sun_position_numpy_sample(
-                datetimes=solar_datetimes,
-                lon=combined_sample_dataset.site__longitude.values,
-                lat=combined_sample_dataset.site__latitude.values,
-            )
-
-            # Use existing dimension for solar positions
-            # TODO decouple this as a separate data varaible
-            solar_dim_name = "site__time_utc"
-
-            # Assign solar position values
-            for key, values in sun_position_features.items():
-                combined_sample_dataset = combined_sample_dataset.assign_coords(
-                    {key: (solar_dim_name, values)},
-                )
-
-        # TODO include t0_index in xr dataset?
-
-        # Fill any nan values
-        return combined_sample_dataset.fillna(0.0)
-
-    def merge_data_arrays(
-        self,
-        normalised_data_arrays: list[tuple[str, xr.DataArray]],
-    ) -> xr.Dataset:
-        """Combine a list of DataArrays into a single Dataset with unique naming conventions.
-
-        Args:
-            normalised_data_arrays: List of tuples where each tuple contains:
-                - A string (key name).
-                - An xarray.DataArray.
-
-        Returns:
-            xr.Dataset: A merged Dataset with uniquely named variables, coordinates, and dimensions.
-        """
-        datasets = []
-
-        for key, data_array in normalised_data_arrays:
-            # Ensure all attributes are strings for consistency
-            data_array = data_array.assign_attrs(
-                {attr_key: str(attr_value) for attr_key, attr_value in data_array.attrs.items()},
-            )
-
-            # Convert DataArray to Dataset with the variable name as the key
-            dataset = data_array.to_dataset(name=key)
-
-            # Prepend key name to all dimension and coordinate names for uniqueness
-            dataset = dataset.rename(
-                {dim: f"{key}__{dim}" for dim in dataset.dims if dim not in dataset.coords},
-            )
-            dataset = dataset.rename(
-                {coord: f"{key}__{coord}" for coord in dataset.coords},
-            )
-
-            # Handle concatenation dimension if applicable
-            concat_dim = (
-                f"{key}__target_time_utc"
-                if f"{key}__target_time_utc" in dataset.coords
-                else f"{key}__time_utc"
-            )
-
-            if f"{key}__init_time_utc" in dataset.coords:
-                init_coord = f"{key}__init_time_utc"
-                if dataset[init_coord].ndim == 0:  # Check if scalar
-                    expanded_init_times = [dataset[init_coord].values] * len(dataset[concat_dim])
-                    dataset = dataset.assign_coords({init_coord: (concat_dim, expanded_init_times)})
-
-            datasets.append(dataset)
-
-        # Ensure all datasets are valid xarray.Dataset objects
-        for ds in datasets:
-            if not isinstance(ds, xr.Dataset):
-                raise ValueError(f"Object is not an xr.Dataset: {type(ds)}")
-
-        # Merge all prepared datasets
-        combined_dataset = xr.merge(datasets)
-
-        return combined_dataset
-
 
 class SitesDatasetConcurrent(Dataset):
     """A torch Dataset for creating PVNet Site batches with samples for all sites."""
@@ -394,89 +332,6 @@ class SitesDatasetConcurrent(Dataset):
         # Assign coords and indices to self
         self.valid_t0s = valid_t0s
 
-    @staticmethod
-    def process_and_combine_datasets(
-        dataset_dict: dict,
-        config: Configuration,
-        t0: pd.Timestamp,
-    ) -> NumpySample:
-        """Normalise and convert data to numpy arrays.
-
-        Args:
-            dataset_dict: Dictionary of xarray datasets
-            config: Configuration object
-            t0: init-time for sample
-        """
-        numpy_modalities = []
-
-        if "nwp" in dataset_dict:
-            nwp_numpy_modalities = {}
-
-            for nwp_key, da_nwp in dataset_dict["nwp"].items():
-                # Standardise and convert to NumpyBatch
-
-                da_channel_means = channel_dict_to_dataarray(
-                    config.input_data.nwp[nwp_key].channel_means,
-                )
-                da_channel_stds = channel_dict_to_dataarray(
-                    config.input_data.nwp[nwp_key].channel_stds,
-                )
-
-                da_nwp = (da_nwp - da_channel_means) / da_channel_stds
-
-                nwp_numpy_modalities[nwp_key] = convert_nwp_to_numpy_sample(da_nwp)
-
-            # Combine the NWPs into NumpyBatch
-            numpy_modalities.append({NWPSampleKey.nwp: nwp_numpy_modalities})
-
-        if "sat" in dataset_dict:
-            da_sat = dataset_dict["sat"]
-
-            # Standardise and convert to NumpyBatch
-            da_channel_means = channel_dict_to_dataarray(config.input_data.satellite.channel_means)
-            da_channel_stds = channel_dict_to_dataarray(config.input_data.satellite.channel_stds)
-
-            da_sat = (da_sat - da_channel_means) / da_channel_stds
-
-            numpy_modalities.append(convert_satellite_to_numpy_sample(da_sat))
-
-        if "site" in dataset_dict:
-            da_sites = dataset_dict["site"]
-            da_sites = da_sites / da_sites.capacity_kwp
-
-            # Convert to NumpyBatch
-            numpy_modalities.append(convert_site_to_numpy_sample(da_sites))
-
-        # Only add solar position if explicitly configured
-        has_solar_config = (
-            hasattr(config.input_data, "solar_position")
-            and config.input_data.solar_position is not None
-        )
-
-        if has_solar_config:
-            solar_config = config.input_data.solar_position
-
-            # Create datetime range for solar position calculation
-            datetimes = pd.date_range(
-                t0 + minutes(solar_config.interval_start_minutes),
-                t0 + minutes(solar_config.interval_end_minutes),
-                freq=minutes(solar_config.time_resolution_minutes),
-            )
-
-            # Calculate solar positions and add to modalities
-            numpy_modalities.append(
-                make_sun_position_numpy_sample(
-                    datetimes, da_sites.longitude.values, da_sites.latitude.values,
-                ),
-            )
-
-        # Combine all the modalities and fill NaNs
-        combined_sample = merge_dicts(numpy_modalities)
-        combined_sample = fill_nans_in_arrays(combined_sample)
-
-        return combined_sample
-
-
     def find_valid_t0s(
         self,
         datasets_dict: dict,
@@ -547,7 +402,7 @@ class SitesDatasetConcurrent(Dataset):
 
         for location in self.locations:
            site_sample_dict = slice_datasets_by_space(sample_dict, location, self.config)
-            site_numpy_sample = self.process_and_combine_datasets(
+            site_numpy_sample = process_and_combine_datasets(
                 site_sample_dict,
                 self.config,
                 t0,
@@ -557,122 +412,6 @@ class SitesDatasetConcurrent(Dataset):
         return stack_np_samples_into_batch(site_samples)
 
 
-# ----- functions to load presaved samples ------
-
-
-def convert_netcdf_to_numpy_sample(ds: xr.Dataset) -> dict:
-    """Convert a netcdf dataset to a numpy sample.
-
-    Args:
-        ds: xarray Dataset
-    """
-    # convert the single dataset to a dict of arrays
-    sample_dict = convert_from_dataset_to_dict_datasets(ds)
-
-    if "satellite" in sample_dict:
-        # rename satellite to sat # TODO this could be improved
-        sample_dict["sat"] = sample_dict.pop("satellite")
-
-    # process and combine the datasets
-    sample = convert_to_numpy_and_combine(dataset_dict=sample_dict)
-
-    # Add solar coord and datetime features
-    keys = ["solar_azimuth", "solar_elevation", "date_sin", "date_cos", "time_sin", "time_cos"]
-    for key in keys:
-        if key in ds.coords:
-            sample[key] = ds.coords[key].values
-
-    # TODO think about normalization:
-    # * maybe its done not in sample creation, maybe its done afterwards,
-    # to allow it to be flexible
-
-    return sample
-
-
-def convert_from_dataset_to_dict_datasets(combined_dataset: xr.Dataset) -> dict[str, xr.DataArray]:
-    """Convert a combined sample dataset to a dict of datasets for each input.
-
-    Args:
-        combined_dataset: The combined NetCDF dataset
-
-    Returns:
-        The uncombined datasets as a dict of xr.Datasets
-    """
-    # Split into datasets by splitting by the prefix added in combine_to_netcdf
-    datasets: dict[str, xr.DataArray] = {}
-
-    # Go through each data variable and split it into a dataset
-    for key, dataset in combined_dataset.items():
-        # If 'key__' doesn't exist in a dim or coordinate, remove it
-        for dim in list(dataset.coords):
-            if f"{key}__" not in dim:
-                dataset = dataset.drop_vars(dim)
-        dataset = dataset.rename(
-            {dim: dim.split(f"{key}__")[1] for dim in dataset.dims if dim not in dataset.coords},
-        )
-        dataset = dataset.rename(
-            {coord: coord.split(f"{key}__")[1] for coord in dataset.coords},
-        )
-        # Split the dataset by the prefix
-        datasets[key] = dataset
-
-    # Unflatten any NWP data
-    return nest_nwp_source_dict(datasets, sep="-")
-
-
-def nest_nwp_source_dict(
-    dataset_dict: dict[xr.Dataset],
-    sep: str = "-",
-) -> dict[str, xr.Dataset | dict[xr.Dataset]]:
-    """Re-nest a dictionary where the NWP values are nested under keys 'nwp-<key>'.
-
-    Args:
-        dataset_dict: Dictionary of datasets
-        sep: Separator to use to nest NWP keys
-    """
-    nwp_prefix = f"nwp{sep}"
-    new_dict = {k: v for k, v in dataset_dict.items() if not k.startswith(nwp_prefix)}
-    nwp_keys = [k for k in dataset_dict if k.startswith(nwp_prefix)]
-    if len(nwp_keys) > 0:
-        nwp_subdict = {k.removeprefix(nwp_prefix): dataset_dict[k] for k in nwp_keys}
-        new_dict["nwp"] = nwp_subdict
-    return new_dict
-
-
-def convert_to_numpy_and_combine(dataset_dict: dict[xr.Dataset]) -> NumpySample:
-    """Convert input data in a dict to numpy arrays.
-
-    Args:
-        dataset_dict: Dictionary of xarray Datasets
-    """
-    numpy_modalities = []
-
-    if "nwp" in dataset_dict:
-        nwp_numpy_modalities = {}
-        for nwp_key, da_nwp in dataset_dict["nwp"].items():
-            # Convert to NumpySample
-            nwp_numpy_modalities[nwp_key] = convert_nwp_to_numpy_sample(da_nwp)
-
-        # Combine the NWPs into NumpySample
-        numpy_modalities.append({NWPSampleKey.nwp: nwp_numpy_modalities})
-
-    if "sat" in dataset_dict:
-        # Satellite is already in the range [0-1] so no need to standardise
-        da_sat = dataset_dict["sat"]
-
-        # Convert to NumpySample
-        numpy_modalities.append(convert_satellite_to_numpy_sample(da_sat))
-
-    if "site" in dataset_dict:
-        da_sites = dataset_dict["site"]
-
-        numpy_modalities.append(convert_site_to_numpy_sample(da_sites))
-
-    # Combine all the modalities and fill NaNs
-    combined_sample = merge_dicts(numpy_modalities)
-    return fill_nans_in_arrays(combined_sample)
-
-
 def coarsen_data(xr_data: xr.Dataset, coarsen_to_deg: float = 0.1) -> xr.Dataset:
     """Coarsen the data to a specified resolution in degrees.
 
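These NetCDF-oriented loaders (convert_netcdf_to_numpy_sample and its helpers) are removed outright; loading presaved samples is now handled by the torch-based SiteSample class added in ocf_data_sampler/torch_datasets/sample/site.py, shown below.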
ocf_data_sampler-0.5.0/ocf_data_sampler/torch_datasets/sample/site.py
ADDED

@@ -0,0 +1,48 @@
+"""PVNet Site sample implementation for netCDF data handling and conversion."""
+
+import torch
+from typing_extensions import override
+
+from ocf_data_sampler.numpy_sample.common_types import NumpySample
+
+from .base import SampleBase
+
+
+# TODO this is now similar to the UKRegionalSample
+# We should consider just having one Sample class for all datasets
+class SiteSample(SampleBase):
+    """Handles SiteSample specific operations."""
+
+    def __init__(self, data: NumpySample) -> None:
+        """Initializes the SiteSample object with the given NumpySample."""
+        self._data = data
+
+    @override
+    def to_numpy(self) -> NumpySample:
+        return self._data
+
+    @override
+    def save(self, path: str) -> None:
+        """Saves sample to the specified path in pickle format."""
+        # Saves to pickle format
+        torch.save(self._data, path)
+
+    @classmethod
+    @override
+    def load(cls, path: str) -> "SiteSample":
+        """Loads sample from the specified path.
+
+        Args:
+            path: Path to the saved sample file.
+
+        Returns:
+            A SiteSample instance with the loaded data.
+        """
+        # Loads from .pt format
+        # TODO: We should move away from using torch.load(..., weights_only=False)
+        return cls(torch.load(path, weights_only=False))
+
+    @override
+    def plot(self) -> None:
+        # TODO - placeholder for now
+        raise NotImplementedError("Plotting not yet implemented for SiteSample")
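A hedged round-trip sketch for the new SiteSample, using only the API visible in the added file above; the sample content is a stand-in dict rather than a real dataset sample:

    import numpy as np

    from ocf_data_sampler.torch_datasets.sample.site import SiteSample

    sample = SiteSample({"site": np.zeros(4)})
    sample.save("sample.pt")                 # torch.save under the hood
    restored = SiteSample.load("sample.pt")  # torch.load(..., weights_only=False)
    assert np.allclose(restored.to_numpy()["site"], 0.0)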
ocf_data_sampler-0.4.0/ocf_data_sampler/torch_datasets/sample/site.py
DELETED

@@ -1,39 +0,0 @@
-"""PVNet Site sample implementation for netCDF data handling and conversion."""
-
-import xarray as xr
-from typing_extensions import override
-
-from ocf_data_sampler.numpy_sample.common_types import NumpySample
-from ocf_data_sampler.torch_datasets.datasets.site import convert_netcdf_to_numpy_sample
-
-from .base import SampleBase
-
-
-class SiteSample(SampleBase):
-    """Handles PVNet site specific netCDF operations."""
-
-    def __init__(self, data: xr.Dataset) -> None:
-        """Initializes the SiteSample object with the given xarray Dataset."""
-        if not isinstance(data, xr.Dataset):
-            raise TypeError(f"Data must be xarray Dataset - Found type {type(data)}")
-        self._data = data
-
-    @override
-    def to_numpy(self) -> NumpySample:
-        return convert_netcdf_to_numpy_sample(self._data)
-
-    @override
-    def save(self, path: str) -> None:
-        # Saves as NetCDF
-        self._data.to_netcdf(path, mode="w", engine="h5netcdf")
-
-    @classmethod
-    @override
-    def load(cls, path: str) -> "SiteSample":
-        # Loads from NetCDF
-        return cls(xr.open_dataset(path, decode_timedelta=False))
-
-    @override
-    def plot(self) -> None:
-        # TODO - placeholder for now
-        raise NotImplementedError("Plotting not yet implemented for SiteSample")
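Taken together, the two files above change the presaved-sample format: 0.4.0 wrote NetCDF via to_netcdf and read it back with xr.open_dataset, while 0.5.0 pickles a NumpySample dict with torch.save and reloads it with torch.load(..., weights_only=False). Samples saved by 0.4.0 are therefore not directly loadable by 0.5.0 and would need to be regenerated. Note also that loading with weights_only=False relies on Python pickle, which can execute arbitrary code from an untrusted file; the source itself flags this with a TODO.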
All remaining files listed in the summary above are renamed from the 0.4.0 to the 0.5.0 package prefix without content changes.