ocf-data-sampler 0.2.19__py3-none-any.whl → 0.2.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ocf-data-sampler might be problematic. Click here for more details.
- ocf_data_sampler/config/model.py +1 -0
- ocf_data_sampler/load/load_dataset.py +5 -1
- ocf_data_sampler/load/nwp/nwp.py +12 -2
- ocf_data_sampler/load/nwp/providers/cloudcasting.py +2 -4
- ocf_data_sampler/load/nwp/providers/ecmwf.py +1 -1
- ocf_data_sampler/load/nwp/providers/gfs.py +4 -3
- ocf_data_sampler/load/nwp/providers/icon.py +2 -2
- ocf_data_sampler/load/nwp/providers/ukv.py +1 -1
- ocf_data_sampler/load/nwp/providers/utils.py +16 -7
- ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py +11 -13
- {ocf_data_sampler-0.2.19.dist-info → ocf_data_sampler-0.2.21.dist-info}/METADATA +1 -1
- {ocf_data_sampler-0.2.19.dist-info → ocf_data_sampler-0.2.21.dist-info}/RECORD +14 -14
- {ocf_data_sampler-0.2.19.dist-info → ocf_data_sampler-0.2.21.dist-info}/WHEEL +1 -1
- {ocf_data_sampler-0.2.19.dist-info → ocf_data_sampler-0.2.21.dist-info}/top_level.txt +0 -0
ocf_data_sampler/config/model.py
CHANGED
|
@@ -211,6 +211,7 @@ class NWP(TimeWindowMixin, DropoutMixin, SpatialWindowMixin, NormalisationConsta
|
|
|
211
211
|
" used to construct an example. If set to None, then the max staleness is set according to"
|
|
212
212
|
" the maximum forecast horizon of the NWP and the requested forecast length.",
|
|
213
213
|
)
|
|
214
|
+
public: bool = Field(False, description="Whether the NWP data is public or private")
|
|
214
215
|
|
|
215
216
|
@field_validator("provider")
|
|
216
217
|
def validate_provider(cls, v: str) -> str:
|
|
@@ -38,7 +38,11 @@ def get_dataset_dict(
|
|
|
38
38
|
if input_config.nwp:
|
|
39
39
|
datasets_dict["nwp"] = {}
|
|
40
40
|
for nwp_source, nwp_config in input_config.nwp.items():
|
|
41
|
-
da_nwp = open_nwp(
|
|
41
|
+
da_nwp = open_nwp(
|
|
42
|
+
zarr_path=nwp_config.zarr_path,
|
|
43
|
+
provider=nwp_config.provider,
|
|
44
|
+
public=nwp_config.public,
|
|
45
|
+
)
|
|
42
46
|
|
|
43
47
|
da_nwp = da_nwp.sel(channel=list(nwp_config.channels))
|
|
44
48
|
|
ocf_data_sampler/load/nwp/nwp.py
CHANGED
|
@@ -9,18 +9,23 @@ from ocf_data_sampler.load.nwp.providers.icon import open_icon_eu
|
|
|
9
9
|
from ocf_data_sampler.load.nwp.providers.ukv import open_ukv
|
|
10
10
|
|
|
11
11
|
|
|
12
|
-
def open_nwp(zarr_path: str | list[str], provider: str) -> xr.DataArray:
|
|
12
|
+
def open_nwp(zarr_path: str | list[str], provider: str, public: bool = False) -> xr.DataArray:
|
|
13
13
|
"""Opens NWP zarr.
|
|
14
14
|
|
|
15
15
|
Args:
|
|
16
16
|
zarr_path: path(s) to the zarr(s) to open
|
|
17
17
|
provider: NWP provider
|
|
18
|
+
public: Whether the data is public or private (only for GFS)
|
|
18
19
|
|
|
19
20
|
Returns:
|
|
20
21
|
Xarray DataArray of the NWP data
|
|
21
22
|
"""
|
|
22
23
|
provider = provider.lower()
|
|
23
24
|
|
|
25
|
+
kwargs = {
|
|
26
|
+
"zarr_path": zarr_path,
|
|
27
|
+
}
|
|
28
|
+
|
|
24
29
|
if provider == "ukv":
|
|
25
30
|
_open_nwp = open_ukv
|
|
26
31
|
elif provider == "ecmwf":
|
|
@@ -29,9 +34,14 @@ def open_nwp(zarr_path: str | list[str], provider: str) -> xr.DataArray:
|
|
|
29
34
|
_open_nwp = open_icon_eu
|
|
30
35
|
elif provider == "gfs":
|
|
31
36
|
_open_nwp = open_gfs
|
|
37
|
+
|
|
38
|
+
# GFS has a public/private flag
|
|
39
|
+
if public:
|
|
40
|
+
kwargs["public"] = True
|
|
41
|
+
|
|
32
42
|
elif provider == "cloudcasting":
|
|
33
43
|
_open_nwp = open_cloudcasting
|
|
34
44
|
else:
|
|
35
45
|
raise ValueError(f"Unknown provider: {provider}")
|
|
36
46
|
|
|
37
|
-
return _open_nwp(
|
|
47
|
+
return _open_nwp(**kwargs)
|
|
@@ -1,7 +1,5 @@
|
|
|
1
1
|
"""Cloudcasting provider loader."""
|
|
2
2
|
|
|
3
|
-
from pathlib import Path
|
|
4
|
-
|
|
5
3
|
import xarray as xr
|
|
6
4
|
|
|
7
5
|
from ocf_data_sampler.load.nwp.providers.utils import open_zarr_paths
|
|
@@ -12,14 +10,14 @@ from ocf_data_sampler.load.utils import (
|
|
|
12
10
|
)
|
|
13
11
|
|
|
14
12
|
|
|
15
|
-
def open_cloudcasting(zarr_path:
|
|
13
|
+
def open_cloudcasting(zarr_path: str | list[str]) -> xr.DataArray:
|
|
16
14
|
"""Opens the satellite predictions from cloudcasting.
|
|
17
15
|
|
|
18
16
|
Cloudcasting is an OCF forecast product. We forecast future satellite images from recent
|
|
19
17
|
satellite images. More information can be found in the references below.
|
|
20
18
|
|
|
21
19
|
Args:
|
|
22
|
-
zarr_path: Path to the zarr to open
|
|
20
|
+
zarr_path: Path to the zarr(s) to open
|
|
23
21
|
|
|
24
22
|
Returns:
|
|
25
23
|
Xarray DataArray of the cloudcasting data
|
|
@@ -10,11 +10,12 @@ from ocf_data_sampler.load.utils import check_time_unique_increasing, make_spati
|
|
|
10
10
|
_log = logging.getLogger(__name__)
|
|
11
11
|
|
|
12
12
|
|
|
13
|
-
def open_gfs(zarr_path: str | list[str]) -> xr.DataArray:
|
|
13
|
+
def open_gfs(zarr_path: str | list[str], public: bool = False) -> xr.DataArray:
|
|
14
14
|
"""Opens the GFS data.
|
|
15
15
|
|
|
16
16
|
Args:
|
|
17
|
-
zarr_path: Path to the zarr to open
|
|
17
|
+
zarr_path: Path to the zarr(s) to open
|
|
18
|
+
public: Whether the data is public or private
|
|
18
19
|
|
|
19
20
|
Returns:
|
|
20
21
|
Xarray DataArray of the NWP data
|
|
@@ -22,7 +23,7 @@ def open_gfs(zarr_path: str | list[str]) -> xr.DataArray:
|
|
|
22
23
|
_log.info("Loading NWP GFS data")
|
|
23
24
|
|
|
24
25
|
# Open data
|
|
25
|
-
gfs: xr.Dataset = open_zarr_paths(zarr_path, time_dim="init_time_utc")
|
|
26
|
+
gfs: xr.Dataset = open_zarr_paths(zarr_path, time_dim="init_time_utc", public=public)
|
|
26
27
|
nwp: xr.DataArray = gfs.to_array()
|
|
27
28
|
nwp = nwp.rename({"variable": "channel"}) # `variable` appears when using `to_array`
|
|
28
29
|
|
|
@@ -19,7 +19,7 @@ def remove_isobaric_lelvels_from_coords(nwp: xr.Dataset) -> xr.Dataset:
|
|
|
19
19
|
return nwp.drop_vars(["isobaricInhPa", *variables_to_drop])
|
|
20
20
|
|
|
21
21
|
|
|
22
|
-
def open_icon_eu(zarr_path: str) -> xr.Dataset:
|
|
22
|
+
def open_icon_eu(zarr_path: str | list[str]) -> xr.Dataset:
|
|
23
23
|
"""Opens the ICON data.
|
|
24
24
|
|
|
25
25
|
ICON EU Data is on a regular lat/lon grid
|
|
@@ -27,7 +27,7 @@ def open_icon_eu(zarr_path: str) -> xr.Dataset:
|
|
|
27
27
|
Each of the variables is its own data variable
|
|
28
28
|
|
|
29
29
|
Args:
|
|
30
|
-
zarr_path: Path to the zarr to open
|
|
30
|
+
zarr_path: Path to the zarr(s) to open
|
|
31
31
|
|
|
32
32
|
Returns:
|
|
33
33
|
Xarray DataArray of the NWP data
|
|
@@ -3,32 +3,41 @@
|
|
|
3
3
|
import xarray as xr
|
|
4
4
|
|
|
5
5
|
|
|
6
|
-
def open_zarr_paths(
|
|
6
|
+
def open_zarr_paths(
|
|
7
|
+
zarr_path: str | list[str], time_dim: str = "init_time", public: bool = False,
|
|
8
|
+
) -> xr.Dataset:
|
|
7
9
|
"""Opens the NWP data.
|
|
8
10
|
|
|
9
11
|
Args:
|
|
10
12
|
zarr_path: Path to the zarr(s) to open
|
|
11
13
|
time_dim: Name of the time dimension
|
|
14
|
+
public: Whether the data is public or private
|
|
12
15
|
|
|
13
16
|
Returns:
|
|
14
17
|
The opened Xarray Dataset
|
|
15
18
|
"""
|
|
19
|
+
general_kwargs = {
|
|
20
|
+
"engine": "zarr",
|
|
21
|
+
"chunks": "auto",
|
|
22
|
+
"decode_timedelta": True,
|
|
23
|
+
}
|
|
24
|
+
|
|
25
|
+
if public:
|
|
26
|
+
# note this only works for s3 zarr paths at the moment
|
|
27
|
+
general_kwargs["storage_options"] = {"anon": True}
|
|
28
|
+
|
|
16
29
|
if type(zarr_path) in [list, tuple] or "*" in str(zarr_path): # Multi-file dataset
|
|
17
30
|
ds = xr.open_mfdataset(
|
|
18
31
|
zarr_path,
|
|
19
|
-
engine="zarr",
|
|
20
32
|
concat_dim=time_dim,
|
|
21
33
|
combine="nested",
|
|
22
|
-
|
|
23
|
-
decode_timedelta=True,
|
|
34
|
+
**general_kwargs,
|
|
24
35
|
).sortby(time_dim)
|
|
25
36
|
else:
|
|
26
37
|
ds = xr.open_dataset(
|
|
27
38
|
zarr_path,
|
|
28
|
-
engine="zarr",
|
|
29
39
|
consolidated=True,
|
|
30
40
|
mode="r",
|
|
31
|
-
|
|
32
|
-
decode_timedelta=True,
|
|
41
|
+
**general_kwargs,
|
|
33
42
|
)
|
|
34
43
|
return ds
|
|
@@ -1,6 +1,5 @@
|
|
|
1
1
|
"""Torch dataset for UK PVNet."""
|
|
2
2
|
|
|
3
|
-
import numpy as np
|
|
4
3
|
import pandas as pd
|
|
5
4
|
import xarray as xr
|
|
6
5
|
from torch.utils.data import Dataset
|
|
@@ -257,22 +256,12 @@ class PVNetUKRegionalDataset(AbstractPVNetUKDataset):
|
|
|
257
256
|
# Construct a lookup for locations - useful for users to construct sample by GSP ID
|
|
258
257
|
location_lookup = {loc.id: loc for loc in self.locations}
|
|
259
258
|
|
|
260
|
-
# Construct indices for sampling
|
|
261
|
-
t_index, loc_index = np.meshgrid(
|
|
262
|
-
np.arange(len(self.valid_t0_times)),
|
|
263
|
-
np.arange(len(self.locations)),
|
|
264
|
-
)
|
|
265
|
-
|
|
266
|
-
# Make array of all possible (t0, location) coordinates. Each row is a single coordinate
|
|
267
|
-
index_pairs = np.stack((t_index.ravel(), loc_index.ravel())).T
|
|
268
|
-
|
|
269
259
|
# Assign coords and indices to self
|
|
270
260
|
self.location_lookup = location_lookup
|
|
271
|
-
self.index_pairs = index_pairs
|
|
272
261
|
|
|
273
262
|
@override
|
|
274
263
|
def __len__(self) -> int:
|
|
275
|
-
return len(self.
|
|
264
|
+
return len(self.locations)*len(self.valid_t0_times)
|
|
276
265
|
|
|
277
266
|
def _get_sample(self, t0: pd.Timestamp, location: Location) -> NumpySample:
|
|
278
267
|
"""Generate the PVNet sample for given coordinates.
|
|
@@ -290,7 +279,16 @@ class PVNetUKRegionalDataset(AbstractPVNetUKDataset):
|
|
|
290
279
|
@override
|
|
291
280
|
def __getitem__(self, idx: int) -> NumpySample:
|
|
292
281
|
# Get the coordinates of the sample
|
|
293
|
-
|
|
282
|
+
|
|
283
|
+
if idx >= len(self):
|
|
284
|
+
raise ValueError(f"Index {idx} out of range for dataset of length {len(self)}")
|
|
285
|
+
|
|
286
|
+
# t_index will be between 0 and len(self.valid_t0_times)-1
|
|
287
|
+
t_index = idx % len(self.valid_t0_times)
|
|
288
|
+
|
|
289
|
+
# For each location, there are len(self.valid_t0_times) possible samples
|
|
290
|
+
loc_index = idx // len(self.valid_t0_times)
|
|
291
|
+
|
|
294
292
|
location = self.locations[loc_index]
|
|
295
293
|
t0 = self.valid_t0_times[t_index]
|
|
296
294
|
|
|
@@ -2,25 +2,25 @@ ocf_data_sampler/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,
|
|
|
2
2
|
ocf_data_sampler/utils.py,sha256=DjuneGGisl08ENvPZV_lrcX4b2NCKJC1ZpXgIpxuQi4,290
|
|
3
3
|
ocf_data_sampler/config/__init__.py,sha256=O29mbH0XG2gIY1g3BaveGCnpBO2SFqdu-qzJ7a6evl0,223
|
|
4
4
|
ocf_data_sampler/config/load.py,sha256=LL-7wemI8o4KPkx35j-wQ3HjsMvDgqXr7G46IcASfnU,632
|
|
5
|
-
ocf_data_sampler/config/model.py,sha256=
|
|
5
|
+
ocf_data_sampler/config/model.py,sha256=iqffLs_VDqw9jOTLWchVFK4c6FWxHAzCngSfkjLUyCY,10516
|
|
6
6
|
ocf_data_sampler/config/save.py,sha256=m8SPw5rXjkMm1rByjh3pK5StdBi4e8ysnn3jQopdRaI,1064
|
|
7
7
|
ocf_data_sampler/data/uk_gsp_locations_20220314.csv,sha256=RSh7DRh55E3n8lVAaWXGTaXXHevZZtI58td4d4DhGos,10415772
|
|
8
8
|
ocf_data_sampler/data/uk_gsp_locations_20250109.csv,sha256=XZISFatnbpO9j8LwaxNKFzQSjs6hcHFsV8a9uDDpy2E,9055334
|
|
9
9
|
ocf_data_sampler/load/__init__.py,sha256=-vQP9g0UOWdVbjEGyVX_ipa7R1btmiETIKAf6aw4d78,201
|
|
10
10
|
ocf_data_sampler/load/gsp.py,sha256=UfPxwHw2Dw2xYSO5Al28oTamgnEM_n_4bYXsqGwY5Tc,1884
|
|
11
|
-
ocf_data_sampler/load/load_dataset.py,sha256=
|
|
11
|
+
ocf_data_sampler/load/load_dataset.py,sha256=wSXPUQKgGRM6HC-yBXQ2IcDBQDckOSllmbGnhqikFMQ,2055
|
|
12
12
|
ocf_data_sampler/load/satellite.py,sha256=E7Ln7Y60Qr1RTV-_R71YoxXQM-Ca7Y1faIo3oKB2eFk,2292
|
|
13
13
|
ocf_data_sampler/load/site.py,sha256=zOzlWk6pYZBB5daqG8URGksmDXWKrkutUvN8uALAIh8,1468
|
|
14
14
|
ocf_data_sampler/load/utils.py,sha256=sZ0-zzconcLkVQwAkCYrqKDo98Hrh5ChdiQJv5Bh91g,2040
|
|
15
15
|
ocf_data_sampler/load/nwp/__init__.py,sha256=SmcrnbygO5xtCKmGR4wtHrj-HI7nOAvnAtfuvRufBGQ,25
|
|
16
|
-
ocf_data_sampler/load/nwp/nwp.py,sha256=
|
|
16
|
+
ocf_data_sampler/load/nwp/nwp.py,sha256=PNNYYREEGQT4sxGilNzuthKKOmVJdQL8R2r8bvzyEr0,1317
|
|
17
17
|
ocf_data_sampler/load/nwp/providers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
18
|
-
ocf_data_sampler/load/nwp/providers/cloudcasting.py,sha256=
|
|
19
|
-
ocf_data_sampler/load/nwp/providers/ecmwf.py,sha256=
|
|
20
|
-
ocf_data_sampler/load/nwp/providers/gfs.py,sha256=
|
|
21
|
-
ocf_data_sampler/load/nwp/providers/icon.py,sha256=
|
|
22
|
-
ocf_data_sampler/load/nwp/providers/ukv.py,sha256
|
|
23
|
-
ocf_data_sampler/load/nwp/providers/utils.py,sha256=
|
|
18
|
+
ocf_data_sampler/load/nwp/providers/cloudcasting.py,sha256=fozXpB3a2rNqQgnpRDC7xunxffh7Wwmc0kkCiYmDVJ4,1521
|
|
19
|
+
ocf_data_sampler/load/nwp/providers/ecmwf.py,sha256=an-gXsZwkPQvRXeza1U_4MNU5yEnVm0_8tn03rxTudI,997
|
|
20
|
+
ocf_data_sampler/load/nwp/providers/gfs.py,sha256=glBbo2kXtcTjQv_VNqA32lsdCCGB114Ovm-cibRWxTA,1088
|
|
21
|
+
ocf_data_sampler/load/nwp/providers/icon.py,sha256=6MkOfUk5dmv0XJZLrKMy1e8xipj2fHCTkYXuff7MgUY,1584
|
|
22
|
+
ocf_data_sampler/load/nwp/providers/ukv.py,sha256=Ka1KFZcJYPwr5vuxo-xWGVQC0pudheqGBonUnbyJCMg,1016
|
|
23
|
+
ocf_data_sampler/load/nwp/providers/utils.py,sha256=NrzE3JAtoc6oEywJHxPUdi_I4UJgJ_l5dxLZ4DLKvcg,1124
|
|
24
24
|
ocf_data_sampler/numpy_sample/__init__.py,sha256=nY5C6CcuxiWZ_jrXRzWtN7WyKXhJImSiVTIG6Rz4B_4,401
|
|
25
25
|
ocf_data_sampler/numpy_sample/collate.py,sha256=hoxIc5SoHoIs3Nx37aRZzWChpswjy9lHUgaKgHIoo80,2039
|
|
26
26
|
ocf_data_sampler/numpy_sample/common_types.py,sha256=9CjYHkUTx0ObduWh43fhsybZCTXvexql7qC2ptMDoek,377
|
|
@@ -39,7 +39,7 @@ ocf_data_sampler/select/location.py,sha256=AZvGR8y62opiW7zACGXjoOtBEWRfSLOZIA73O
|
|
|
39
39
|
ocf_data_sampler/select/select_spatial_slice.py,sha256=liAqIa-Amj58pOqx5r16i99HURj9oQ41j7gnPgRDQP4,8201
|
|
40
40
|
ocf_data_sampler/select/select_time_slice.py,sha256=HeHbwZ0CP03x0-LaJtpbSdtpLufwVTR73p6wH6O_PS8,5513
|
|
41
41
|
ocf_data_sampler/torch_datasets/datasets/__init__.py,sha256=jfJSFcR0eO1AqeH7S3KnGjsBqVZT5w3oyi784PUR6Q0,146
|
|
42
|
-
ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py,sha256=
|
|
42
|
+
ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py,sha256=cd4IyzYu8rMFgLHRXqYpnOIAZe4Yl21YdLmDQw45F7o,12545
|
|
43
43
|
ocf_data_sampler/torch_datasets/datasets/site.py,sha256=nRUlhXQQGVrTuBmE1QnwXAUsPTXz0dsezlQjwK71jIQ,17641
|
|
44
44
|
ocf_data_sampler/torch_datasets/sample/__init__.py,sha256=GL84vdZl_SjHDGVyh9Uekx2XhPYuZ0dnO3l6f6KXnHI,100
|
|
45
45
|
ocf_data_sampler/torch_datasets/sample/base.py,sha256=cQ1oIyhdmlotejZK8B3Cw6MNvpdnBPD8G_o2h7Ye4Vc,2206
|
|
@@ -55,7 +55,7 @@ ocf_data_sampler/torch_datasets/utils/validation_utils.py,sha256=YqmT-lExWlI8_ul
|
|
|
55
55
|
scripts/download_gsp_location_data.py,sha256=rRDXMoqX-RYY4jPdxhdlxJGhWdl6r245F5UARgKV6P4,3121
|
|
56
56
|
scripts/refactor_site.py,sha256=skzvsPP0Cn9yTKndzkilyNcGz4DZ88ctvCJ0XrBdc2A,3135
|
|
57
57
|
utils/compute_icon_mean_stddev.py,sha256=a1oWMRMnny39rV-dvu8rcx85sb4bXzPFrR1gkUr4Jpg,2296
|
|
58
|
-
ocf_data_sampler-0.2.
|
|
59
|
-
ocf_data_sampler-0.2.
|
|
60
|
-
ocf_data_sampler-0.2.
|
|
61
|
-
ocf_data_sampler-0.2.
|
|
58
|
+
ocf_data_sampler-0.2.21.dist-info/METADATA,sha256=PKoq-iYwiK9VGZLHgAQCQJm-U97WMkb7HToAsrPNzlw,11581
|
|
59
|
+
ocf_data_sampler-0.2.21.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
|
|
60
|
+
ocf_data_sampler-0.2.21.dist-info/top_level.txt,sha256=LEFU4Uk-PEo72QGLAfnVZIUEm37Q8mKuMeg_Xk-p33g,31
|
|
61
|
+
ocf_data_sampler-0.2.21.dist-info/RECORD,,
|
|
File without changes
|