ocf-data-sampler 0.2.36__tar.gz → 0.2.38__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ocf-data-sampler has been flagged by the registry as potentially problematic; see the registry's advisory page for this release for more details.
- {ocf_data_sampler-0.2.36 → ocf_data_sampler-0.2.38}/PKG-INFO +1 -1
- {ocf_data_sampler-0.2.36 → ocf_data_sampler-0.2.38}/ocf_data_sampler/load/nwp/providers/gfs.py +1 -2
- {ocf_data_sampler-0.2.36 → ocf_data_sampler-0.2.38}/ocf_data_sampler/select/find_contiguous_time_periods.py +9 -3
- {ocf_data_sampler-0.2.36 → ocf_data_sampler-0.2.38}/ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py +1 -11
- {ocf_data_sampler-0.2.36 → ocf_data_sampler-0.2.38}/ocf_data_sampler/torch_datasets/datasets/site.py +271 -61
- ocf_data_sampler-0.2.38/ocf_data_sampler/utils.py +21 -0
- {ocf_data_sampler-0.2.36 → ocf_data_sampler-0.2.38}/ocf_data_sampler.egg-info/PKG-INFO +1 -1
- ocf_data_sampler-0.2.36/ocf_data_sampler/utils.py +0 -12
- {ocf_data_sampler-0.2.36 → ocf_data_sampler-0.2.38}/LICENSE +0 -0
- {ocf_data_sampler-0.2.36 → ocf_data_sampler-0.2.38}/README.md +0 -0
- {ocf_data_sampler-0.2.36 → ocf_data_sampler-0.2.38}/ocf_data_sampler/__init__.py +0 -0
- {ocf_data_sampler-0.2.36 → ocf_data_sampler-0.2.38}/ocf_data_sampler/config/__init__.py +0 -0
- {ocf_data_sampler-0.2.36 → ocf_data_sampler-0.2.38}/ocf_data_sampler/config/load.py +0 -0
- {ocf_data_sampler-0.2.36 → ocf_data_sampler-0.2.38}/ocf_data_sampler/config/model.py +0 -0
- {ocf_data_sampler-0.2.36 → ocf_data_sampler-0.2.38}/ocf_data_sampler/config/save.py +0 -0
- {ocf_data_sampler-0.2.36 → ocf_data_sampler-0.2.38}/ocf_data_sampler/data/uk_gsp_locations_20220314.csv +0 -0
- {ocf_data_sampler-0.2.36 → ocf_data_sampler-0.2.38}/ocf_data_sampler/data/uk_gsp_locations_20250109.csv +0 -0
- {ocf_data_sampler-0.2.36 → ocf_data_sampler-0.2.38}/ocf_data_sampler/load/__init__.py +0 -0
- {ocf_data_sampler-0.2.36 → ocf_data_sampler-0.2.38}/ocf_data_sampler/load/gsp.py +0 -0
- {ocf_data_sampler-0.2.36 → ocf_data_sampler-0.2.38}/ocf_data_sampler/load/load_dataset.py +0 -0
- {ocf_data_sampler-0.2.36 → ocf_data_sampler-0.2.38}/ocf_data_sampler/load/nwp/__init__.py +0 -0
- {ocf_data_sampler-0.2.36 → ocf_data_sampler-0.2.38}/ocf_data_sampler/load/nwp/nwp.py +0 -0
- {ocf_data_sampler-0.2.36 → ocf_data_sampler-0.2.38}/ocf_data_sampler/load/nwp/providers/__init__.py +0 -0
- {ocf_data_sampler-0.2.36 → ocf_data_sampler-0.2.38}/ocf_data_sampler/load/nwp/providers/cloudcasting.py +0 -0
- {ocf_data_sampler-0.2.36 → ocf_data_sampler-0.2.38}/ocf_data_sampler/load/nwp/providers/ecmwf.py +0 -0
- {ocf_data_sampler-0.2.36 → ocf_data_sampler-0.2.38}/ocf_data_sampler/load/nwp/providers/icon.py +0 -0
- {ocf_data_sampler-0.2.36 → ocf_data_sampler-0.2.38}/ocf_data_sampler/load/nwp/providers/ukv.py +0 -0
- {ocf_data_sampler-0.2.36 → ocf_data_sampler-0.2.38}/ocf_data_sampler/load/nwp/providers/utils.py +0 -0
- {ocf_data_sampler-0.2.36 → ocf_data_sampler-0.2.38}/ocf_data_sampler/load/satellite.py +0 -0
- {ocf_data_sampler-0.2.36 → ocf_data_sampler-0.2.38}/ocf_data_sampler/load/site.py +0 -0
- {ocf_data_sampler-0.2.36 → ocf_data_sampler-0.2.38}/ocf_data_sampler/load/utils.py +0 -0
- {ocf_data_sampler-0.2.36 → ocf_data_sampler-0.2.38}/ocf_data_sampler/numpy_sample/__init__.py +0 -0
- {ocf_data_sampler-0.2.36 → ocf_data_sampler-0.2.38}/ocf_data_sampler/numpy_sample/collate.py +0 -0
- {ocf_data_sampler-0.2.36 → ocf_data_sampler-0.2.38}/ocf_data_sampler/numpy_sample/common_types.py +0 -0
- {ocf_data_sampler-0.2.36 → ocf_data_sampler-0.2.38}/ocf_data_sampler/numpy_sample/datetime_features.py +0 -0
- {ocf_data_sampler-0.2.36 → ocf_data_sampler-0.2.38}/ocf_data_sampler/numpy_sample/gsp.py +0 -0
- {ocf_data_sampler-0.2.36 → ocf_data_sampler-0.2.38}/ocf_data_sampler/numpy_sample/nwp.py +0 -0
- {ocf_data_sampler-0.2.36 → ocf_data_sampler-0.2.38}/ocf_data_sampler/numpy_sample/satellite.py +0 -0
- {ocf_data_sampler-0.2.36 → ocf_data_sampler-0.2.38}/ocf_data_sampler/numpy_sample/site.py +0 -0
- {ocf_data_sampler-0.2.36 → ocf_data_sampler-0.2.38}/ocf_data_sampler/numpy_sample/sun_position.py +0 -0
- {ocf_data_sampler-0.2.36 → ocf_data_sampler-0.2.38}/ocf_data_sampler/select/__init__.py +0 -0
- {ocf_data_sampler-0.2.36 → ocf_data_sampler-0.2.38}/ocf_data_sampler/select/dropout.py +0 -0
- {ocf_data_sampler-0.2.36 → ocf_data_sampler-0.2.38}/ocf_data_sampler/select/fill_time_periods.py +0 -0
- {ocf_data_sampler-0.2.36 → ocf_data_sampler-0.2.38}/ocf_data_sampler/select/geospatial.py +0 -0
- {ocf_data_sampler-0.2.36 → ocf_data_sampler-0.2.38}/ocf_data_sampler/select/location.py +0 -0
- {ocf_data_sampler-0.2.36 → ocf_data_sampler-0.2.38}/ocf_data_sampler/select/select_spatial_slice.py +0 -0
- {ocf_data_sampler-0.2.36 → ocf_data_sampler-0.2.38}/ocf_data_sampler/select/select_time_slice.py +0 -0
- {ocf_data_sampler-0.2.36 → ocf_data_sampler-0.2.38}/ocf_data_sampler/torch_datasets/datasets/__init__.py +0 -0
- {ocf_data_sampler-0.2.36 → ocf_data_sampler-0.2.38}/ocf_data_sampler/torch_datasets/sample/__init__.py +0 -0
- {ocf_data_sampler-0.2.36 → ocf_data_sampler-0.2.38}/ocf_data_sampler/torch_datasets/sample/base.py +0 -0
- {ocf_data_sampler-0.2.36 → ocf_data_sampler-0.2.38}/ocf_data_sampler/torch_datasets/sample/site.py +0 -0
- {ocf_data_sampler-0.2.36 → ocf_data_sampler-0.2.38}/ocf_data_sampler/torch_datasets/sample/uk_regional.py +0 -0
- {ocf_data_sampler-0.2.36 → ocf_data_sampler-0.2.38}/ocf_data_sampler/torch_datasets/utils/__init__.py +0 -0
- {ocf_data_sampler-0.2.36 → ocf_data_sampler-0.2.38}/ocf_data_sampler/torch_datasets/utils/channel_dict_to_dataarray.py +0 -0
- {ocf_data_sampler-0.2.36 → ocf_data_sampler-0.2.38}/ocf_data_sampler/torch_datasets/utils/merge_and_fill_utils.py +0 -0
- {ocf_data_sampler-0.2.36 → ocf_data_sampler-0.2.38}/ocf_data_sampler/torch_datasets/utils/spatial_slice_for_dataset.py +0 -0
- {ocf_data_sampler-0.2.36 → ocf_data_sampler-0.2.38}/ocf_data_sampler/torch_datasets/utils/time_slice_for_dataset.py +0 -0
- {ocf_data_sampler-0.2.36 → ocf_data_sampler-0.2.38}/ocf_data_sampler/torch_datasets/utils/valid_time_periods.py +0 -0
- {ocf_data_sampler-0.2.36 → ocf_data_sampler-0.2.38}/ocf_data_sampler/torch_datasets/utils/validation_utils.py +0 -0
- {ocf_data_sampler-0.2.36 → ocf_data_sampler-0.2.38}/ocf_data_sampler.egg-info/SOURCES.txt +0 -0
- {ocf_data_sampler-0.2.36 → ocf_data_sampler-0.2.38}/ocf_data_sampler.egg-info/dependency_links.txt +0 -0
- {ocf_data_sampler-0.2.36 → ocf_data_sampler-0.2.38}/ocf_data_sampler.egg-info/requires.txt +0 -0
- {ocf_data_sampler-0.2.36 → ocf_data_sampler-0.2.38}/ocf_data_sampler.egg-info/top_level.txt +0 -0
- {ocf_data_sampler-0.2.36 → ocf_data_sampler-0.2.38}/pyproject.toml +0 -0
- {ocf_data_sampler-0.2.36 → ocf_data_sampler-0.2.38}/scripts/download_gsp_location_data.py +0 -0
- {ocf_data_sampler-0.2.36 → ocf_data_sampler-0.2.38}/scripts/refactor_site.py +0 -0
- {ocf_data_sampler-0.2.36 → ocf_data_sampler-0.2.38}/setup.cfg +0 -0
- {ocf_data_sampler-0.2.36 → ocf_data_sampler-0.2.38}/utils/compute_icon_mean_stddev.py +0 -0
{ocf_data_sampler-0.2.36 → ocf_data_sampler-0.2.38}/ocf_data_sampler/load/nwp/providers/gfs.py
RENAMED
|
@@ -24,8 +24,7 @@ def open_gfs(zarr_path: str | list[str], public: bool = False) -> xr.DataArray:
|
|
|
24
24
|
|
|
25
25
|
# Open data
|
|
26
26
|
gfs: xr.Dataset = open_zarr_paths(zarr_path, time_dim="init_time_utc", public=public)
|
|
27
|
-
nwp: xr.DataArray = gfs.to_array()
|
|
28
|
-
nwp = nwp.rename({"variable": "channel"}) # `variable` appears when using `to_array`
|
|
27
|
+
nwp: xr.DataArray = gfs.to_array(dim="channel")
|
|
29
28
|
|
|
30
29
|
del gfs
|
|
31
30
|
|
|
@@ -242,6 +242,11 @@ def intersection_of_2_dataframes_of_periods(a: pd.DataFrame, b: pd.DataFrame) ->
|
|
|
242
242
|
if a.empty or b.empty:
|
|
243
243
|
return pd.DataFrame(columns=["start_dt", "end_dt"])
|
|
244
244
|
|
|
245
|
+
# Maybe switch these for efficiency in the next section. We will do the native python loop over
|
|
246
|
+
# the shorter dataframe
|
|
247
|
+
if len(a) > len(b):
|
|
248
|
+
a, b = b, a
|
|
249
|
+
|
|
245
250
|
all_intersecting_periods = []
|
|
246
251
|
for a_period in a.itertuples():
|
|
247
252
|
# Five ways in which two periods may overlap:
|
|
@@ -250,12 +255,12 @@ def intersection_of_2_dataframes_of_periods(a: pd.DataFrame, b: pd.DataFrame) ->
|
|
|
250
255
|
# In all five, `a` must always start before (or equal to) where `b` ends,
|
|
251
256
|
# and `a` must always end after (or equal to) where `b` starts.
|
|
252
257
|
|
|
253
|
-
overlapping_periods = b[(a_period.start_dt <= b.end_dt) & (a_period.end_dt >= b.start_dt)]
|
|
254
|
-
|
|
255
258
|
# There are two ways in which two periods may *not* overlap:
|
|
256
259
|
# a: |---| or |---|
|
|
257
260
|
# b: |---| |---|
|
|
258
|
-
# `
|
|
261
|
+
# `overlapping_periods` will not include periods which do *not* overlap.
|
|
262
|
+
|
|
263
|
+
overlapping_periods = b[(a_period.start_dt <= b.end_dt) & (a_period.end_dt >= b.start_dt)]
|
|
259
264
|
|
|
260
265
|
# Now find the intersection of each period in `overlapping_periods` with
|
|
261
266
|
# the period from `a` that starts at `a_start_dt` and ends at `a_end_dt`.
|
|
@@ -269,5 +274,6 @@ def intersection_of_2_dataframes_of_periods(a: pd.DataFrame, b: pd.DataFrame) ->
|
|
|
269
274
|
|
|
270
275
|
all_intersecting_periods.append(intersection)
|
|
271
276
|
|
|
277
|
+
|
|
272
278
|
all_intersecting_periods = pd.concat(all_intersecting_periods)
|
|
273
279
|
return all_intersecting_periods.sort_values(by="start_dt").reset_index(drop=True)
|
|
@@ -30,21 +30,11 @@ from ocf_data_sampler.torch_datasets.utils.merge_and_fill_utils import (
|
|
|
30
30
|
fill_nans_in_arrays,
|
|
31
31
|
merge_dicts,
|
|
32
32
|
)
|
|
33
|
-
from ocf_data_sampler.utils import minutes
|
|
33
|
+
from ocf_data_sampler.utils import compute, minutes
|
|
34
34
|
|
|
35
35
|
xr.set_options(keep_attrs=True)
|
|
36
36
|
|
|
37
37
|
|
|
38
|
-
def compute(xarray_dict: dict) -> dict:
|
|
39
|
-
"""Eagerly load a nested dictionary of xarray DataArrays."""
|
|
40
|
-
for k, v in xarray_dict.items():
|
|
41
|
-
if isinstance(v, dict):
|
|
42
|
-
xarray_dict[k] = compute(v)
|
|
43
|
-
else:
|
|
44
|
-
xarray_dict[k] = v.compute(scheduler="single-threaded")
|
|
45
|
-
return xarray_dict
|
|
46
|
-
|
|
47
|
-
|
|
48
38
|
def get_gsp_locations(
|
|
49
39
|
gsp_ids: list[int] | None = None,
|
|
50
40
|
version: str = "20220314",
|
{ocf_data_sampler-0.2.36 → ocf_data_sampler-0.2.38}/ocf_data_sampler/torch_datasets/datasets/site.py
RENAMED
|
@@ -6,7 +6,7 @@ import xarray as xr
|
|
|
6
6
|
from torch.utils.data import Dataset
|
|
7
7
|
from typing_extensions import override
|
|
8
8
|
|
|
9
|
-
from ocf_data_sampler.config import load_yaml_configuration
|
|
9
|
+
from ocf_data_sampler.config import Configuration, load_yaml_configuration
|
|
10
10
|
from ocf_data_sampler.load.load_dataset import get_dataset_dict
|
|
11
11
|
from ocf_data_sampler.numpy_sample import (
|
|
12
12
|
NWPSampleKey,
|
|
@@ -16,6 +16,7 @@ from ocf_data_sampler.numpy_sample import (
|
|
|
16
16
|
make_datetime_numpy_dict,
|
|
17
17
|
make_sun_position_numpy_sample,
|
|
18
18
|
)
|
|
19
|
+
from ocf_data_sampler.numpy_sample.collate import stack_np_samples_into_batch
|
|
19
20
|
from ocf_data_sampler.numpy_sample.common_types import NumpySample
|
|
20
21
|
from ocf_data_sampler.select import (
|
|
21
22
|
Location,
|
|
@@ -33,11 +34,31 @@ from ocf_data_sampler.torch_datasets.utils.merge_and_fill_utils import (
|
|
|
33
34
|
fill_nans_in_arrays,
|
|
34
35
|
merge_dicts,
|
|
35
36
|
)
|
|
36
|
-
from ocf_data_sampler.utils import minutes
|
|
37
|
+
from ocf_data_sampler.utils import compute, minutes
|
|
37
38
|
|
|
38
39
|
xr.set_options(keep_attrs=True)
|
|
39
40
|
|
|
40
41
|
|
|
42
|
+
def get_locations(site_xr: xr.Dataset) -> list[Location]:
|
|
43
|
+
"""Get list of locations of all sites.
|
|
44
|
+
|
|
45
|
+
Args:
|
|
46
|
+
site_xr: xarray Dataset of site data
|
|
47
|
+
"""
|
|
48
|
+
locations = []
|
|
49
|
+
for site_id in site_xr.site_id.values:
|
|
50
|
+
site = site_xr.sel(site_id=site_id)
|
|
51
|
+
location = Location(
|
|
52
|
+
id=site_id,
|
|
53
|
+
x=site.longitude.values,
|
|
54
|
+
y=site.latitude.values,
|
|
55
|
+
coordinate_system="lon_lat",
|
|
56
|
+
)
|
|
57
|
+
locations.append(location)
|
|
58
|
+
|
|
59
|
+
return locations
|
|
60
|
+
|
|
61
|
+
|
|
41
62
|
class SitesDataset(Dataset):
|
|
42
63
|
"""A torch Dataset for creating PVNet Site samples."""
|
|
43
64
|
|
|
@@ -62,7 +83,7 @@ class SitesDataset(Dataset):
|
|
|
62
83
|
self.config = config
|
|
63
84
|
|
|
64
85
|
# get all locations
|
|
65
|
-
self.locations =
|
|
86
|
+
self.locations = get_locations(datasets_dict["site"])
|
|
66
87
|
self.location_lookup = {loc.id: loc for loc in self.locations}
|
|
67
88
|
|
|
68
89
|
# Get t0 times where all input data is available
|
|
@@ -82,48 +103,6 @@ class SitesDataset(Dataset):
|
|
|
82
103
|
# Assign coords and indices to self
|
|
83
104
|
self.valid_t0_and_site_ids = valid_t0_and_site_ids
|
|
84
105
|
|
|
85
|
-
@override
|
|
86
|
-
def __len__(self) -> int:
|
|
87
|
-
return len(self.valid_t0_and_site_ids)
|
|
88
|
-
|
|
89
|
-
@override
|
|
90
|
-
def __getitem__(self, idx: int) -> dict:
|
|
91
|
-
# Get the coordinates of the sample
|
|
92
|
-
t0, site_id = self.valid_t0_and_site_ids.iloc[idx]
|
|
93
|
-
|
|
94
|
-
# get location from site id
|
|
95
|
-
location = self.location_lookup[site_id]
|
|
96
|
-
|
|
97
|
-
# Generate the sample
|
|
98
|
-
return self._get_sample(t0, location)
|
|
99
|
-
|
|
100
|
-
def _get_sample(self, t0: pd.Timestamp, location: Location) -> dict:
|
|
101
|
-
"""Generate the PVNet sample for given coordinates.
|
|
102
|
-
|
|
103
|
-
Args:
|
|
104
|
-
t0: init-time for sample
|
|
105
|
-
location: location for sample
|
|
106
|
-
"""
|
|
107
|
-
sample_dict = slice_datasets_by_space(self.datasets_dict, location, self.config)
|
|
108
|
-
sample_dict = slice_datasets_by_time(sample_dict, t0, self.config)
|
|
109
|
-
|
|
110
|
-
sample = self.process_and_combine_site_sample_dict(sample_dict, t0)
|
|
111
|
-
return sample.compute()
|
|
112
|
-
|
|
113
|
-
def get_sample(self, t0: pd.Timestamp, site_id: int) -> dict:
|
|
114
|
-
"""Generate a sample for a given site id and t0.
|
|
115
|
-
|
|
116
|
-
Useful for users to generate samples by t0 and site id
|
|
117
|
-
|
|
118
|
-
Args:
|
|
119
|
-
t0: init-time for sample
|
|
120
|
-
site_id: site id as int
|
|
121
|
-
"""
|
|
122
|
-
location = self.location_lookup[site_id]
|
|
123
|
-
|
|
124
|
-
return self._get_sample(t0, location)
|
|
125
|
-
|
|
126
|
-
|
|
127
106
|
def find_valid_t0_and_site_ids(
|
|
128
107
|
self,
|
|
129
108
|
datasets_dict: dict,
|
|
@@ -177,25 +156,46 @@ class SitesDataset(Dataset):
|
|
|
177
156
|
valid_t0_and_site_ids.index.name = "t0"
|
|
178
157
|
return valid_t0_and_site_ids.reset_index()
|
|
179
158
|
|
|
159
|
+
@override
|
|
160
|
+
def __len__(self) -> int:
|
|
161
|
+
return len(self.valid_t0_and_site_ids)
|
|
162
|
+
|
|
163
|
+
@override
|
|
164
|
+
def __getitem__(self, idx: int) -> dict:
|
|
165
|
+
# Get the coordinates of the sample
|
|
166
|
+
t0, site_id = self.valid_t0_and_site_ids.iloc[idx]
|
|
180
167
|
|
|
181
|
-
|
|
182
|
-
|
|
168
|
+
# get location from site id
|
|
169
|
+
location = self.location_lookup[site_id]
|
|
170
|
+
|
|
171
|
+
# Generate the sample
|
|
172
|
+
return self._get_sample(t0, location)
|
|
173
|
+
|
|
174
|
+
def _get_sample(self, t0: pd.Timestamp, location: Location) -> dict:
|
|
175
|
+
"""Generate the PVNet sample for given coordinates.
|
|
183
176
|
|
|
184
177
|
Args:
|
|
185
|
-
|
|
178
|
+
t0: init-time for sample
|
|
179
|
+
location: location for sample
|
|
186
180
|
"""
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
181
|
+
sample_dict = slice_datasets_by_space(self.datasets_dict, location, self.config)
|
|
182
|
+
sample_dict = slice_datasets_by_time(sample_dict, t0, self.config)
|
|
183
|
+
|
|
184
|
+
sample = self.process_and_combine_site_sample_dict(sample_dict, t0)
|
|
185
|
+
return sample.compute()
|
|
186
|
+
|
|
187
|
+
def get_sample(self, t0: pd.Timestamp, site_id: int) -> dict:
|
|
188
|
+
"""Generate a sample for a given site id and t0.
|
|
189
|
+
|
|
190
|
+
Useful for users to generate samples by t0 and site id
|
|
197
191
|
|
|
198
|
-
|
|
192
|
+
Args:
|
|
193
|
+
t0: init-time for sample
|
|
194
|
+
site_id: site id as int
|
|
195
|
+
"""
|
|
196
|
+
location = self.location_lookup[site_id]
|
|
197
|
+
|
|
198
|
+
return self._get_sample(t0, location)
|
|
199
199
|
|
|
200
200
|
def process_and_combine_site_sample_dict(
|
|
201
201
|
self,
|
|
@@ -256,8 +256,8 @@ class SitesDataset(Dataset):
|
|
|
256
256
|
|
|
257
257
|
# Only add solar position if explicitly configured
|
|
258
258
|
has_solar_config = (
|
|
259
|
-
hasattr(self.config.input_data, "solar_position")
|
|
260
|
-
self.config.input_data.solar_position is not None
|
|
259
|
+
hasattr(self.config.input_data, "solar_position")
|
|
260
|
+
and self.config.input_data.solar_position is not None
|
|
261
261
|
)
|
|
262
262
|
|
|
263
263
|
if has_solar_config:
|
|
@@ -351,6 +351,216 @@ class SitesDataset(Dataset):
|
|
|
351
351
|
return combined_dataset
|
|
352
352
|
|
|
353
353
|
|
|
354
|
+
class SitesDatasetConcurrent(Dataset):
|
|
355
|
+
"""A torch Dataset for creating PVNet Site batches with samples for all sites."""
|
|
356
|
+
|
|
357
|
+
def __init__(
|
|
358
|
+
self,
|
|
359
|
+
config_filename: str,
|
|
360
|
+
start_time: str | None = None,
|
|
361
|
+
end_time: str | None = None,
|
|
362
|
+
) -> None:
|
|
363
|
+
"""A torch Dataset for creating PVNet Site samples.
|
|
364
|
+
|
|
365
|
+
Args:
|
|
366
|
+
config_filename: Path to the configuration file
|
|
367
|
+
start_time: Limit the init-times to be after this
|
|
368
|
+
end_time: Limit the init-times to be before this
|
|
369
|
+
"""
|
|
370
|
+
config = load_yaml_configuration(config_filename)
|
|
371
|
+
datasets_dict = get_dataset_dict(config.input_data)
|
|
372
|
+
|
|
373
|
+
# Assign config and input data to self
|
|
374
|
+
self.datasets_dict = datasets_dict
|
|
375
|
+
self.config = config
|
|
376
|
+
|
|
377
|
+
# get all locations
|
|
378
|
+
self.locations = get_locations(datasets_dict["site"])
|
|
379
|
+
|
|
380
|
+
# Get t0 times where all input data is available
|
|
381
|
+
valid_t0s = self.find_valid_t0s(datasets_dict)
|
|
382
|
+
|
|
383
|
+
# Filter t0 times to given range
|
|
384
|
+
if start_time is not None:
|
|
385
|
+
valid_t0s = valid_t0s[
|
|
386
|
+
valid_t0s >= pd.Timestamp(start_time)
|
|
387
|
+
]
|
|
388
|
+
|
|
389
|
+
if end_time is not None:
|
|
390
|
+
valid_t0s = valid_t0s[
|
|
391
|
+
valid_t0s <= pd.Timestamp(end_time)
|
|
392
|
+
]
|
|
393
|
+
|
|
394
|
+
# Assign coords and indices to self
|
|
395
|
+
self.valid_t0s = valid_t0s
|
|
396
|
+
|
|
397
|
+
@staticmethod
|
|
398
|
+
def process_and_combine_datasets(
|
|
399
|
+
dataset_dict: dict,
|
|
400
|
+
config: Configuration,
|
|
401
|
+
t0: pd.Timestamp,
|
|
402
|
+
) -> NumpySample:
|
|
403
|
+
"""Normalise and convert data to numpy arrays.
|
|
404
|
+
|
|
405
|
+
Args:
|
|
406
|
+
dataset_dict: Dictionary of xarray datasets
|
|
407
|
+
config: Configuration object
|
|
408
|
+
t0: init-time for sample
|
|
409
|
+
"""
|
|
410
|
+
numpy_modalities = []
|
|
411
|
+
|
|
412
|
+
if "nwp" in dataset_dict:
|
|
413
|
+
nwp_numpy_modalities = {}
|
|
414
|
+
|
|
415
|
+
for nwp_key, da_nwp in dataset_dict["nwp"].items():
|
|
416
|
+
# Standardise and convert to NumpyBatch
|
|
417
|
+
|
|
418
|
+
da_channel_means = channel_dict_to_dataarray(
|
|
419
|
+
config.input_data.nwp[nwp_key].channel_means,
|
|
420
|
+
)
|
|
421
|
+
da_channel_stds = channel_dict_to_dataarray(
|
|
422
|
+
config.input_data.nwp[nwp_key].channel_stds,
|
|
423
|
+
)
|
|
424
|
+
|
|
425
|
+
da_nwp = (da_nwp - da_channel_means) / da_channel_stds
|
|
426
|
+
|
|
427
|
+
nwp_numpy_modalities[nwp_key] = convert_nwp_to_numpy_sample(da_nwp)
|
|
428
|
+
|
|
429
|
+
# Combine the NWPs into NumpyBatch
|
|
430
|
+
numpy_modalities.append({NWPSampleKey.nwp: nwp_numpy_modalities})
|
|
431
|
+
|
|
432
|
+
if "sat" in dataset_dict:
|
|
433
|
+
da_sat = dataset_dict["sat"]
|
|
434
|
+
|
|
435
|
+
# Standardise and convert to NumpyBatch
|
|
436
|
+
da_channel_means = channel_dict_to_dataarray(config.input_data.satellite.channel_means)
|
|
437
|
+
da_channel_stds = channel_dict_to_dataarray(config.input_data.satellite.channel_stds)
|
|
438
|
+
|
|
439
|
+
da_sat = (da_sat - da_channel_means) / da_channel_stds
|
|
440
|
+
|
|
441
|
+
numpy_modalities.append(convert_satellite_to_numpy_sample(da_sat))
|
|
442
|
+
|
|
443
|
+
if "site" in dataset_dict:
|
|
444
|
+
da_sites = dataset_dict["site"]
|
|
445
|
+
da_sites = da_sites / da_sites.capacity_kwp
|
|
446
|
+
|
|
447
|
+
# Convert to NumpyBatch
|
|
448
|
+
numpy_modalities.append(
|
|
449
|
+
convert_site_to_numpy_sample(
|
|
450
|
+
da_sites,
|
|
451
|
+
),
|
|
452
|
+
)
|
|
453
|
+
|
|
454
|
+
# Only add solar position if explicitly configured
|
|
455
|
+
has_solar_config = (
|
|
456
|
+
hasattr(config.input_data, "solar_position")
|
|
457
|
+
and config.input_data.solar_position is not None
|
|
458
|
+
)
|
|
459
|
+
|
|
460
|
+
if has_solar_config:
|
|
461
|
+
solar_config = config.input_data.solar_position
|
|
462
|
+
|
|
463
|
+
# Create datetime range for solar position calculation
|
|
464
|
+
datetimes = pd.date_range(
|
|
465
|
+
t0 + minutes(solar_config.interval_start_minutes),
|
|
466
|
+
t0 + minutes(solar_config.interval_end_minutes),
|
|
467
|
+
freq=minutes(solar_config.time_resolution_minutes),
|
|
468
|
+
)
|
|
469
|
+
|
|
470
|
+
# Calculate solar positions and add to modalities
|
|
471
|
+
numpy_modalities.append(
|
|
472
|
+
make_sun_position_numpy_sample(
|
|
473
|
+
datetimes, da_sites.longitude.values, da_sites.latitude.values,
|
|
474
|
+
),
|
|
475
|
+
)
|
|
476
|
+
|
|
477
|
+
# Combine all the modalities and fill NaNs
|
|
478
|
+
combined_sample = merge_dicts(numpy_modalities)
|
|
479
|
+
combined_sample = fill_nans_in_arrays(combined_sample)
|
|
480
|
+
|
|
481
|
+
return combined_sample
|
|
482
|
+
|
|
483
|
+
|
|
484
|
+
def find_valid_t0s(
|
|
485
|
+
self,
|
|
486
|
+
datasets_dict: dict,
|
|
487
|
+
) -> pd.DataFrame:
|
|
488
|
+
"""Find the t0 times where all of the requested input data is available.
|
|
489
|
+
|
|
490
|
+
The idea is to
|
|
491
|
+
1. Get valid time period for nwp and satellite
|
|
492
|
+
2. For the first site location, find valid periods for that location
|
|
493
|
+
Note there is an assumption that all sites have the same t0 data available
|
|
494
|
+
|
|
495
|
+
Args:
|
|
496
|
+
datasets_dict: A dictionary of input datasets
|
|
497
|
+
"""
|
|
498
|
+
# Get valid time period for nwp and satellite
|
|
499
|
+
datasets_without_site = {k: v for k, v in datasets_dict.items() if k != "site"}
|
|
500
|
+
valid_time_periods = find_valid_time_periods(datasets_without_site, self.config)
|
|
501
|
+
sites = datasets_dict["site"]
|
|
502
|
+
|
|
503
|
+
# Taking just the first site value, assume t0s the same for all of them
|
|
504
|
+
site_id = sites.site_id.values[0]
|
|
505
|
+
site_config = self.config.input_data.site
|
|
506
|
+
site = sites.sel(site_id=site_id)
|
|
507
|
+
# Drop NaN values
|
|
508
|
+
site = site.dropna(dim="time_utc")
|
|
509
|
+
|
|
510
|
+
# Obtain valid time periods for this location
|
|
511
|
+
time_periods = find_contiguous_t0_periods(
|
|
512
|
+
pd.DatetimeIndex(site["time_utc"]),
|
|
513
|
+
time_resolution=minutes(site_config.time_resolution_minutes),
|
|
514
|
+
interval_start=minutes(site_config.interval_start_minutes),
|
|
515
|
+
interval_end=minutes(site_config.interval_end_minutes),
|
|
516
|
+
)
|
|
517
|
+
valid_time_periods_per_site = intersection_of_multiple_dataframes_of_periods(
|
|
518
|
+
[valid_time_periods, time_periods],
|
|
519
|
+
)
|
|
520
|
+
|
|
521
|
+
# Fill out contiguous time periods to get t0 times
|
|
522
|
+
valid_t0_times_per_site = fill_time_periods(
|
|
523
|
+
valid_time_periods_per_site,
|
|
524
|
+
freq=minutes(site_config.time_resolution_minutes),
|
|
525
|
+
)
|
|
526
|
+
|
|
527
|
+
return valid_t0_times_per_site
|
|
528
|
+
|
|
529
|
+
@override
|
|
530
|
+
def __len__(self) -> int:
|
|
531
|
+
return len(self.valid_t0s)
|
|
532
|
+
|
|
533
|
+
@override
|
|
534
|
+
def __getitem__(self, idx: int) -> dict:
|
|
535
|
+
# Get the coordinates of the sample
|
|
536
|
+
t0 = self.valid_t0s[idx]
|
|
537
|
+
|
|
538
|
+
return self._get_batch(t0)
|
|
539
|
+
|
|
540
|
+
def _get_batch(self, t0: pd.Timestamp) -> dict:
|
|
541
|
+
"""Generate the PVNet batch for given coordinates.
|
|
542
|
+
|
|
543
|
+
Args:
|
|
544
|
+
t0: init-time for sample
|
|
545
|
+
"""
|
|
546
|
+
# slice by time first as we want to keep all site id info
|
|
547
|
+
sample_dict = slice_datasets_by_time(self.datasets_dict, t0, self.config)
|
|
548
|
+
sample_dict = compute(sample_dict)
|
|
549
|
+
|
|
550
|
+
site_samples = []
|
|
551
|
+
|
|
552
|
+
for location in self.locations:
|
|
553
|
+
site_sample_dict = slice_datasets_by_space(sample_dict, location, self.config)
|
|
554
|
+
site_numpy_sample = self.process_and_combine_datasets(
|
|
555
|
+
site_sample_dict,
|
|
556
|
+
self.config,
|
|
557
|
+
t0,
|
|
558
|
+
)
|
|
559
|
+
site_samples.append(site_numpy_sample)
|
|
560
|
+
|
|
561
|
+
return stack_np_samples_into_batch(site_samples)
|
|
562
|
+
|
|
563
|
+
|
|
354
564
|
# ----- functions to load presaved samples ------
|
|
355
565
|
|
|
356
566
|
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
"""Miscellaneous helper functions."""
|
|
2
|
+
|
|
3
|
+
import pandas as pd
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def minutes(minutes: int | list[float]) -> pd.Timedelta | pd.TimedeltaIndex:
|
|
7
|
+
"""Timedelta minutes.
|
|
8
|
+
|
|
9
|
+
Args:
|
|
10
|
+
minutes: the number of minutes, single value or list
|
|
11
|
+
"""
|
|
12
|
+
return pd.to_timedelta(minutes, unit="m")
|
|
13
|
+
|
|
14
|
+
def compute(xarray_dict: dict) -> dict:
|
|
15
|
+
"""Eagerly load a nested dictionary of xarray DataArrays."""
|
|
16
|
+
for k, v in xarray_dict.items():
|
|
17
|
+
if isinstance(v, dict):
|
|
18
|
+
xarray_dict[k] = compute(v)
|
|
19
|
+
else:
|
|
20
|
+
xarray_dict[k] = v.compute(scheduler="single-threaded")
|
|
21
|
+
return xarray_dict
|
|
@@ -1,12 +0,0 @@
|
|
|
1
|
-
"""Miscellaneous helper functions."""
|
|
2
|
-
|
|
3
|
-
import pandas as pd
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
def minutes(minutes: int | list[float]) -> pd.Timedelta | pd.TimedeltaIndex:
|
|
7
|
-
"""Timedelta minutes.
|
|
8
|
-
|
|
9
|
-
Args:
|
|
10
|
-
minutes: the number of minutes, single value or list
|
|
11
|
-
"""
|
|
12
|
-
return pd.to_timedelta(minutes, unit="m")
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{ocf_data_sampler-0.2.36 → ocf_data_sampler-0.2.38}/ocf_data_sampler/load/nwp/providers/__init__.py
RENAMED
|
File without changes
|
|
File without changes
|
{ocf_data_sampler-0.2.36 → ocf_data_sampler-0.2.38}/ocf_data_sampler/load/nwp/providers/ecmwf.py
RENAMED
|
File without changes
|
{ocf_data_sampler-0.2.36 → ocf_data_sampler-0.2.38}/ocf_data_sampler/load/nwp/providers/icon.py
RENAMED
|
File without changes
|
{ocf_data_sampler-0.2.36 → ocf_data_sampler-0.2.38}/ocf_data_sampler/load/nwp/providers/ukv.py
RENAMED
|
File without changes
|
{ocf_data_sampler-0.2.36 → ocf_data_sampler-0.2.38}/ocf_data_sampler/load/nwp/providers/utils.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{ocf_data_sampler-0.2.36 → ocf_data_sampler-0.2.38}/ocf_data_sampler/numpy_sample/__init__.py
RENAMED
|
File without changes
|
{ocf_data_sampler-0.2.36 → ocf_data_sampler-0.2.38}/ocf_data_sampler/numpy_sample/collate.py
RENAMED
|
File without changes
|
{ocf_data_sampler-0.2.36 → ocf_data_sampler-0.2.38}/ocf_data_sampler/numpy_sample/common_types.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{ocf_data_sampler-0.2.36 → ocf_data_sampler-0.2.38}/ocf_data_sampler/numpy_sample/satellite.py
RENAMED
|
File without changes
|
|
File without changes
|
{ocf_data_sampler-0.2.36 → ocf_data_sampler-0.2.38}/ocf_data_sampler/numpy_sample/sun_position.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
{ocf_data_sampler-0.2.36 → ocf_data_sampler-0.2.38}/ocf_data_sampler/select/fill_time_periods.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
{ocf_data_sampler-0.2.36 → ocf_data_sampler-0.2.38}/ocf_data_sampler/select/select_spatial_slice.py
RENAMED
|
File without changes
|
{ocf_data_sampler-0.2.36 → ocf_data_sampler-0.2.38}/ocf_data_sampler/select/select_time_slice.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
{ocf_data_sampler-0.2.36 → ocf_data_sampler-0.2.38}/ocf_data_sampler/torch_datasets/sample/base.py
RENAMED
|
File without changes
|
{ocf_data_sampler-0.2.36 → ocf_data_sampler-0.2.38}/ocf_data_sampler/torch_datasets/sample/site.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{ocf_data_sampler-0.2.36 → ocf_data_sampler-0.2.38}/ocf_data_sampler.egg-info/dependency_links.txt
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|