ocf-data-sampler 0.1.9__tar.gz → 0.1.11__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ocf-data-sampler might be problematic. Click here for more details.
- {ocf_data_sampler-0.1.9/ocf_data_sampler.egg-info → ocf_data_sampler-0.1.11}/PKG-INFO +1 -1
- {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler/config/model.py +25 -23
- {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler/load/satellite.py +21 -29
- {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler/load/site.py +1 -1
- {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler/numpy_sample/gsp.py +6 -2
- {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler/numpy_sample/nwp.py +7 -13
- {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler/numpy_sample/satellite.py +11 -8
- {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler/numpy_sample/site.py +6 -2
- {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler/numpy_sample/sun_position.py +9 -10
- {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler/sample/__init__.py +0 -7
- ocf_data_sampler-0.1.11/ocf_data_sampler/sample/base.py +79 -0
- ocf_data_sampler-0.1.11/ocf_data_sampler/sample/site.py +44 -0
- ocf_data_sampler-0.1.11/ocf_data_sampler/sample/uk_regional.py +75 -0
- ocf_data_sampler-0.1.11/ocf_data_sampler/select/dropout.py +52 -0
- {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler/select/fill_time_periods.py +3 -1
- {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler/select/find_contiguous_time_periods.py +0 -1
- {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py +2 -3
- {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler/torch_datasets/datasets/site.py +9 -5
- {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11/ocf_data_sampler.egg-info}/PKG-INFO +1 -1
- {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/pyproject.toml +1 -1
- {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/tests/config/test_config.py +3 -3
- {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/tests/conftest.py +33 -0
- ocf_data_sampler-0.1.11/tests/numpy_sample/test_nwp.py +13 -0
- {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/tests/select/test_dropout.py +7 -13
- {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/tests/test_sample/test_site_sample.py +5 -35
- {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/tests/test_sample/test_uk_regional_sample.py +8 -35
- {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/tests/torch_datasets/test_pvnet_uk.py +6 -19
- ocf_data_sampler-0.1.9/ocf_data_sampler/sample/base.py +0 -98
- ocf_data_sampler-0.1.9/ocf_data_sampler/sample/site.py +0 -81
- ocf_data_sampler-0.1.9/ocf_data_sampler/sample/uk_regional.py +0 -120
- ocf_data_sampler-0.1.9/ocf_data_sampler/select/dropout.py +0 -39
- ocf_data_sampler-0.1.9/tests/numpy_sample/test_nwp.py +0 -52
- {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/LICENSE +0 -0
- {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/MANIFEST.in +0 -0
- {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/README.md +0 -0
- {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler/__init__.py +0 -0
- {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler/config/__init__.py +0 -0
- {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler/config/load.py +0 -0
- {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler/config/save.py +0 -0
- {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler/constants.py +0 -0
- {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler/data/uk_gsp_locations.csv +0 -0
- {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler/load/__init__.py +0 -0
- {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler/load/gsp.py +0 -0
- {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler/load/load_dataset.py +0 -0
- {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler/load/nwp/__init__.py +0 -0
- {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler/load/nwp/nwp.py +0 -0
- {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler/load/nwp/providers/__init__.py +0 -0
- {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler/load/nwp/providers/ecmwf.py +0 -0
- {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler/load/nwp/providers/ukv.py +0 -0
- {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler/load/nwp/providers/utils.py +0 -0
- {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler/load/utils.py +0 -0
- {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler/numpy_sample/__init__.py +0 -0
- {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler/numpy_sample/collate.py +0 -0
- {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler/numpy_sample/datetime_features.py +0 -0
- {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler/select/__init__.py +0 -0
- {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler/select/geospatial.py +0 -0
- {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler/select/location.py +0 -0
- {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler/select/select_spatial_slice.py +0 -0
- {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler/select/select_time_slice.py +0 -0
- {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler/select/spatial_slice_for_dataset.py +0 -0
- {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler/select/time_slice_for_dataset.py +0 -0
- {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler/torch_datasets/datasets/__init__.py +0 -0
- {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler/torch_datasets/utils/merge_and_fill_utils.py +0 -0
- {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler/torch_datasets/utils/valid_time_periods.py +0 -0
- {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler/torch_datasets/utils/validate_channels.py +0 -0
- {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler/utils.py +0 -0
- {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler.egg-info/SOURCES.txt +0 -0
- {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler.egg-info/dependency_links.txt +0 -0
- {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler.egg-info/requires.txt +0 -0
- {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler.egg-info/top_level.txt +0 -0
- {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/scripts/refactor_site.py +0 -0
- {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/setup.cfg +0 -0
- {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/tests/__init__.py +0 -0
- {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/tests/config/test_load.py +0 -0
- {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/tests/config/test_save.py +0 -0
- {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/tests/load/test_load_gsp.py +0 -0
- {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/tests/load/test_load_nwp.py +0 -0
- {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/tests/load/test_load_satellite.py +0 -0
- {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/tests/load/test_load_sites.py +0 -0
- {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/tests/numpy_sample/test_collate.py +0 -0
- {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/tests/numpy_sample/test_datetime_features.py +0 -0
- {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/tests/numpy_sample/test_gsp.py +0 -0
- {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/tests/numpy_sample/test_satellite.py +0 -0
- {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/tests/numpy_sample/test_sun_position.py +0 -0
- {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/tests/select/test_fill_time_periods.py +0 -0
- {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/tests/select/test_find_contiguous_time_periods.py +0 -0
- {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/tests/select/test_location.py +0 -0
- {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/tests/select/test_select_spatial_slice.py +0 -0
- {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/tests/select/test_select_time_slice.py +0 -0
- {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/tests/test_sample/test_base.py +0 -0
- {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/tests/torch_datasets/test_merge_and_fill_utils.py +0 -0
- {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/tests/torch_datasets/test_site.py +0 -0
- {ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/tests/torch_datasets/test_validate_channels_utils.py +0 -0
|
@@ -49,31 +49,34 @@ class TimeWindowMixin(Base):
|
|
|
49
49
|
...,
|
|
50
50
|
description="Data interval ends at `t0 + interval_end_minutes`",
|
|
51
51
|
)
|
|
52
|
-
|
|
52
|
+
|
|
53
53
|
@model_validator(mode='after')
|
|
54
|
-
def
|
|
55
|
-
|
|
56
|
-
|
|
54
|
+
def validate_intervals(cls, values):
|
|
55
|
+
start = values.interval_start_minutes
|
|
56
|
+
end = values.interval_end_minutes
|
|
57
|
+
resolution = values.time_resolution_minutes
|
|
58
|
+
if start > end:
|
|
59
|
+
raise ValueError(
|
|
60
|
+
f"interval_start_minutes ({start}) must be <= interval_end_minutes ({end})"
|
|
61
|
+
)
|
|
62
|
+
if (start % resolution != 0):
|
|
63
|
+
raise ValueError(
|
|
64
|
+
f"interval_start_minutes ({start}) must be divisible "
|
|
65
|
+
f"by time_resolution_minutes ({resolution})"
|
|
66
|
+
)
|
|
67
|
+
if (end % resolution != 0):
|
|
68
|
+
raise ValueError(
|
|
69
|
+
f"interval_end_minutes ({end}) must be divisible "
|
|
70
|
+
f"by time_resolution_minutes ({resolution})"
|
|
71
|
+
)
|
|
57
72
|
return values
|
|
58
73
|
|
|
59
|
-
@field_validator("interval_start_minutes")
|
|
60
|
-
def interval_start_minutes_divide_by_time_resolution(cls, v: int, info: ValidationInfo) -> int:
|
|
61
|
-
if v % info.data["time_resolution_minutes"] != 0:
|
|
62
|
-
raise ValueError("interval_start_minutes must be divisible by time_resolution_minutes")
|
|
63
|
-
return v
|
|
64
|
-
|
|
65
|
-
@field_validator("interval_end_minutes")
|
|
66
|
-
def interval_end_minutes_divide_by_time_resolution(cls, v: int, info: ValidationInfo) -> int:
|
|
67
|
-
if v % info.data["time_resolution_minutes"] != 0:
|
|
68
|
-
raise ValueError("interval_end_minutes must be divisible by time_resolution_minutes")
|
|
69
|
-
return v
|
|
70
|
-
|
|
71
74
|
|
|
72
75
|
class DropoutMixin(Base):
|
|
73
76
|
"""Mixin class, to add dropout minutes"""
|
|
74
77
|
|
|
75
|
-
dropout_timedeltas_minutes:
|
|
76
|
-
default=
|
|
78
|
+
dropout_timedeltas_minutes: List[int] = Field(
|
|
79
|
+
default=[],
|
|
77
80
|
description="List of possible minutes before t0 where data availability may start. Must be "
|
|
78
81
|
"negative or zero.",
|
|
79
82
|
)
|
|
@@ -88,18 +91,17 @@ class DropoutMixin(Base):
|
|
|
88
91
|
@field_validator("dropout_timedeltas_minutes")
|
|
89
92
|
def dropout_timedeltas_minutes_negative(cls, v: List[int]) -> List[int]:
|
|
90
93
|
"""Validate 'dropout_timedeltas_minutes'"""
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
assert m <= 0, "Dropout timedeltas must be negative"
|
|
94
|
+
for m in v:
|
|
95
|
+
assert m <= 0, "Dropout timedeltas must be negative"
|
|
94
96
|
return v
|
|
95
97
|
|
|
96
98
|
@model_validator(mode="after")
|
|
97
99
|
def dropout_instructions_consistent(self) -> Self:
|
|
98
100
|
if self.dropout_fraction == 0:
|
|
99
|
-
if self.dropout_timedeltas_minutes
|
|
101
|
+
if self.dropout_timedeltas_minutes != []:
|
|
100
102
|
raise ValueError("To use dropout timedeltas dropout fraction should be > 0")
|
|
101
103
|
else:
|
|
102
|
-
if self.dropout_timedeltas_minutes
|
|
104
|
+
if self.dropout_timedeltas_minutes == []:
|
|
103
105
|
raise ValueError("To dropout fraction > 0 requires a list of dropout timedeltas")
|
|
104
106
|
return self
|
|
105
107
|
|
|
@@ -1,7 +1,5 @@
|
|
|
1
1
|
"""Satellite loader"""
|
|
2
2
|
|
|
3
|
-
import subprocess
|
|
4
|
-
|
|
5
3
|
import xarray as xr
|
|
6
4
|
from ocf_data_sampler.load.utils import (
|
|
7
5
|
check_time_unique_increasing,
|
|
@@ -10,63 +8,59 @@ from ocf_data_sampler.load.utils import (
|
|
|
10
8
|
)
|
|
11
9
|
|
|
12
10
|
|
|
13
|
-
def
|
|
14
|
-
"""Helper function to open a
|
|
11
|
+
def get_single_sat_data(zarr_path: str) -> xr.Dataset:
|
|
12
|
+
"""Helper function to open a zarr from either a local or GCP path
|
|
15
13
|
|
|
16
14
|
Args:
|
|
17
|
-
zarr_path:
|
|
18
|
-
GCS paths (gs://)
|
|
15
|
+
zarr_path: path to a zarr file. Wildcards (*) are supported only for local paths
|
|
16
|
+
GCS paths (gs://) do not support wildcards
|
|
19
17
|
|
|
20
18
|
Returns:
|
|
21
|
-
An xarray Dataset containing satellite data
|
|
19
|
+
An xarray Dataset containing satellite data
|
|
22
20
|
|
|
23
21
|
Raises:
|
|
24
|
-
ValueError: If a wildcard (*) is used in a GCS (gs://) path
|
|
22
|
+
ValueError: If a wildcard (*) is used in a GCS (gs://) path
|
|
25
23
|
"""
|
|
26
24
|
|
|
27
|
-
# These kwargs are used if the path contains "*"
|
|
28
|
-
openmf_kwargs = dict(
|
|
29
|
-
engine="zarr",
|
|
30
|
-
concat_dim="time",
|
|
31
|
-
combine="nested",
|
|
32
|
-
chunks="auto",
|
|
33
|
-
join="override",
|
|
34
|
-
)
|
|
35
|
-
|
|
36
25
|
# Raise an error if a wildcard is used in a GCP path
|
|
37
26
|
if "gs://" in str(zarr_path) and "*" in str(zarr_path):
|
|
38
|
-
raise ValueError("Wildcard (*) paths are not supported for GCP (gs://) URLs
|
|
27
|
+
raise ValueError("Wildcard (*) paths are not supported for GCP (gs://) URLs")
|
|
39
28
|
|
|
40
29
|
# Handle multi-file dataset for local paths
|
|
41
30
|
if "*" in str(zarr_path):
|
|
42
|
-
ds = xr.open_mfdataset(
|
|
31
|
+
ds = xr.open_mfdataset(
|
|
32
|
+
zarr_path,
|
|
33
|
+
engine="zarr",
|
|
34
|
+
concat_dim="time",
|
|
35
|
+
combine="nested",
|
|
36
|
+
chunks="auto",
|
|
37
|
+
join="override",
|
|
38
|
+
)
|
|
39
|
+
check_time_unique_increasing(ds.time)
|
|
43
40
|
else:
|
|
44
41
|
ds = xr.open_dataset(zarr_path, engine="zarr", chunks="auto")
|
|
45
42
|
|
|
46
|
-
# Ensure time is unique and sorted
|
|
47
|
-
ds = ds.drop_duplicates("time").sortby("time")
|
|
48
|
-
|
|
49
43
|
return ds
|
|
50
44
|
|
|
51
45
|
|
|
52
46
|
def open_sat_data(zarr_path: str | list[str]) -> xr.DataArray:
|
|
53
|
-
"""Lazily opens the
|
|
47
|
+
"""Lazily opens the zarr store
|
|
54
48
|
|
|
55
49
|
Args:
|
|
56
|
-
zarr_path: Cloud URL or local path pattern, or list of these. If GCS URL,
|
|
57
|
-
|
|
50
|
+
zarr_path: Cloud URL or local path pattern, or list of these. If GCS URL,
|
|
51
|
+
it must start with 'gs://'
|
|
58
52
|
"""
|
|
59
53
|
|
|
60
54
|
# Open the data
|
|
61
55
|
if isinstance(zarr_path, (list, tuple)):
|
|
62
56
|
ds = xr.combine_nested(
|
|
63
|
-
[
|
|
57
|
+
[get_single_sat_data(path) for path in zarr_path],
|
|
64
58
|
concat_dim="time",
|
|
65
59
|
combine_attrs="override",
|
|
66
60
|
join="override",
|
|
67
61
|
)
|
|
68
62
|
else:
|
|
69
|
-
ds =
|
|
63
|
+
ds = get_single_sat_data(zarr_path)
|
|
70
64
|
|
|
71
65
|
ds = ds.rename(
|
|
72
66
|
{
|
|
@@ -76,9 +70,7 @@ def open_sat_data(zarr_path: str | list[str]) -> xr.DataArray:
|
|
|
76
70
|
)
|
|
77
71
|
|
|
78
72
|
check_time_unique_increasing(ds.time_utc)
|
|
79
|
-
|
|
80
73
|
ds = make_spatial_coords_increasing(ds, x_coord="x_geostationary", y_coord="y_geostationary")
|
|
81
|
-
|
|
82
74
|
ds = ds.transpose("time_utc", "channel", "x_geostationary", "y_geostationary")
|
|
83
75
|
|
|
84
76
|
# TODO: should we control the dtype of the DataArray?
|
|
@@ -20,7 +20,7 @@ def open_site(generation_file_path: str, metadata_file_path: str) -> xr.DataArra
|
|
|
20
20
|
|
|
21
21
|
assert metadata_df.index.is_unique
|
|
22
22
|
|
|
23
|
-
# Ensure metadata aligns with the site_id dimension in
|
|
23
|
+
# Ensure metadata aligns with the site_id dimension in generation_ds
|
|
24
24
|
metadata_df = metadata_df.reindex(generation_ds.site_id.values)
|
|
25
25
|
|
|
26
26
|
# Assign coordinates to the Dataset using the aligned metadata
|
|
@@ -18,9 +18,13 @@ class GSPSampleKey:
|
|
|
18
18
|
|
|
19
19
|
|
|
20
20
|
def convert_gsp_to_numpy_sample(da: xr.DataArray, t0_idx: int | None = None) -> dict:
|
|
21
|
-
"""Convert from Xarray to NumpySample
|
|
21
|
+
"""Convert from Xarray to NumpySample
|
|
22
|
+
|
|
23
|
+
Args:
|
|
24
|
+
da: Xarray DataArray containing GSP data
|
|
25
|
+
t0_idx: Index of the t0 timestamp in the time dimension of the GSP data
|
|
26
|
+
"""
|
|
22
27
|
|
|
23
|
-
# Extract values from the DataArray
|
|
24
28
|
sample = {
|
|
25
29
|
GSPSampleKey.gsp: da.values,
|
|
26
30
|
GSPSampleKey.nominal_capacity_mwp: da.isel(time_utc=0)["nominal_capacity_mwp"].values,
|
|
@@ -12,30 +12,24 @@ class NWPSampleKey:
|
|
|
12
12
|
step = 'nwp_step'
|
|
13
13
|
target_time_utc = 'nwp_target_time_utc'
|
|
14
14
|
t0_idx = 'nwp_t0_idx'
|
|
15
|
-
y_osgb = 'nwp_y_osgb'
|
|
16
|
-
x_osgb = 'nwp_x_osgb'
|
|
17
|
-
|
|
18
15
|
|
|
19
16
|
|
|
20
17
|
def convert_nwp_to_numpy_sample(da: xr.DataArray, t0_idx: int | None = None) -> dict:
|
|
21
|
-
"""Convert from Xarray to NWP NumpySample
|
|
18
|
+
"""Convert from Xarray to NWP NumpySample
|
|
19
|
+
|
|
20
|
+
Args:
|
|
21
|
+
da: Xarray DataArray containing NWP data
|
|
22
|
+
t0_idx: Index of the t0 timestamp in the time dimension of the NWP
|
|
23
|
+
"""
|
|
22
24
|
|
|
23
|
-
# Create example and add t if available
|
|
24
25
|
sample = {
|
|
25
26
|
NWPSampleKey.nwp: da.values,
|
|
26
27
|
NWPSampleKey.channel_names: da.channel.values,
|
|
27
28
|
NWPSampleKey.init_time_utc: da.init_time_utc.values.astype(float),
|
|
28
29
|
NWPSampleKey.step: (da.step.values / pd.Timedelta("1h")).astype(int),
|
|
30
|
+
NWPSampleKey.target_time_utc: da.target_time_utc.values.astype(float),
|
|
29
31
|
}
|
|
30
32
|
|
|
31
|
-
if "target_time_utc" in da.coords:
|
|
32
|
-
sample[NWPSampleKey.target_time_utc] = da.target_time_utc.values.astype(float)
|
|
33
|
-
|
|
34
|
-
# TODO: Do we need this at all? Especially since it is only present in UKV data
|
|
35
|
-
for sample_key, dataset_key in ((NWPSampleKey.y_osgb, "y_osgb"),(NWPSampleKey.x_osgb, "x_osgb"),):
|
|
36
|
-
if dataset_key in da.coords:
|
|
37
|
-
sample[sample_key] = da[dataset_key].values
|
|
38
|
-
|
|
39
33
|
if t0_idx is not None:
|
|
40
34
|
sample[NWPSampleKey.t0_idx] = t0_idx
|
|
41
35
|
|
{ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler/numpy_sample/satellite.py
RENAMED
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
"""Convert Satellite to NumpySample"""
|
|
2
|
+
|
|
2
3
|
import xarray as xr
|
|
3
4
|
|
|
4
5
|
|
|
@@ -12,19 +13,21 @@ class SatelliteSampleKey:
|
|
|
12
13
|
|
|
13
14
|
|
|
14
15
|
def convert_satellite_to_numpy_sample(da: xr.DataArray, t0_idx: int | None = None) -> dict:
|
|
15
|
-
"""Convert from Xarray to NumpySample
|
|
16
|
+
"""Convert from Xarray to NumpySample
|
|
17
|
+
|
|
18
|
+
Args:
|
|
19
|
+
da: xarray DataArray containing satellite data
|
|
20
|
+
t0_idx: Index of the t0 timestamp in the time dimension of the satellite data
|
|
21
|
+
"""
|
|
16
22
|
sample = {
|
|
17
23
|
SatelliteSampleKey.satellite_actual: da.values,
|
|
18
24
|
SatelliteSampleKey.time_utc: da.time_utc.values.astype(float),
|
|
25
|
+
SatelliteSampleKey.x_geostationary: da.x_geostationary.values,
|
|
26
|
+
SatelliteSampleKey.y_geostationary: da.y_geostationary.values,
|
|
19
27
|
}
|
|
20
28
|
|
|
21
|
-
for sample_key, dataset_key in (
|
|
22
|
-
(SatelliteSampleKey.x_geostationary, "x_geostationary"),
|
|
23
|
-
(SatelliteSampleKey.y_geostationary, "y_geostationary"),
|
|
24
|
-
):
|
|
25
|
-
sample[sample_key] = da[dataset_key].values
|
|
26
|
-
|
|
27
29
|
if t0_idx is not None:
|
|
28
30
|
sample[SatelliteSampleKey.t0_idx] = t0_idx
|
|
29
31
|
|
|
30
|
-
return sample
|
|
32
|
+
return sample
|
|
33
|
+
|
|
@@ -18,9 +18,13 @@ class SiteSampleKey:
|
|
|
18
18
|
time_cos = "site_time_cos"
|
|
19
19
|
|
|
20
20
|
def convert_site_to_numpy_sample(da: xr.DataArray, t0_idx: int | None = None) -> dict:
|
|
21
|
-
"""Convert from Xarray to NumpySample
|
|
21
|
+
"""Convert from Xarray to NumpySample
|
|
22
|
+
|
|
23
|
+
Args:
|
|
24
|
+
da: xarray DataArray containing site data
|
|
25
|
+
t0_idx: Index of the t0 timestamp in the time dimension of the site data
|
|
26
|
+
"""
|
|
22
27
|
|
|
23
|
-
# Extract values from the DataArray
|
|
24
28
|
sample = {
|
|
25
29
|
SiteSampleKey.generation: da.values,
|
|
26
30
|
SiteSampleKey.capacity_kwp: da.isel(time_utc=0)["capacity_kwp"].values,
|
{ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler/numpy_sample/sun_position.py
RENAMED
|
@@ -27,16 +27,15 @@ def calculate_azimuth_and_elevation(
|
|
|
27
27
|
latitude=lat,
|
|
28
28
|
method='nrel_numpy'
|
|
29
29
|
)
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
return azimuth, elevation
|
|
30
|
+
|
|
31
|
+
return solpos["azimuth"].values, solpos["elevation"].values
|
|
33
32
|
|
|
34
33
|
|
|
35
34
|
def make_sun_position_numpy_sample(
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
35
|
+
datetimes: pd.DatetimeIndex,
|
|
36
|
+
lon: float,
|
|
37
|
+
lat: float,
|
|
38
|
+
key_prefix: str = "gsp"
|
|
40
39
|
) -> dict:
|
|
41
40
|
"""Creates NumpySample with standardized solar coordinates
|
|
42
41
|
|
|
@@ -44,22 +43,22 @@ def make_sun_position_numpy_sample(
|
|
|
44
43
|
datetimes: The datetimes to calculate solar angles for
|
|
45
44
|
lon: The longitude
|
|
46
45
|
lat: The latitude
|
|
46
|
+
key_prefix: The prefix to add to the keys in the NumpySample
|
|
47
47
|
"""
|
|
48
48
|
|
|
49
49
|
azimuth, elevation = calculate_azimuth_and_elevation(datetimes, lon, lat)
|
|
50
50
|
|
|
51
51
|
# Normalise
|
|
52
|
-
|
|
53
52
|
# Azimuth is in range [0, 360] degrees
|
|
54
53
|
azimuth = azimuth / 360
|
|
55
54
|
|
|
56
55
|
# Elevation is in range [-90, 90] degrees
|
|
57
56
|
elevation = elevation / 180 + 0.5
|
|
58
|
-
|
|
57
|
+
|
|
59
58
|
# Make NumpySample
|
|
60
59
|
sun_numpy_sample = {
|
|
61
60
|
key_prefix + "_solar_azimuth": azimuth,
|
|
62
61
|
key_prefix + "_solar_elevation": elevation,
|
|
63
62
|
}
|
|
64
63
|
|
|
65
|
-
return sun_numpy_sample
|
|
64
|
+
return sun_numpy_sample
|
|
@@ -0,0 +1,79 @@
|
|
|
1
|
+
""" Base class for handling flat/nested data structures with NWP consideration """
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
import torch
|
|
5
|
+
|
|
6
|
+
from typing import TypeAlias
|
|
7
|
+
from abc import ABC, abstractmethod
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
NumpySample: TypeAlias = dict[str, np.ndarray | dict[str, np.ndarray]]
|
|
11
|
+
NumpyBatch: TypeAlias = dict[str, np.ndarray | dict[str, np.ndarray]]
|
|
12
|
+
TensorBatch: TypeAlias = dict[str, torch.Tensor | dict[str, torch.Tensor]]
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class SampleBase(ABC):
|
|
16
|
+
"""
|
|
17
|
+
Abstract base class for all sample types
|
|
18
|
+
Provides core data storage functionality
|
|
19
|
+
"""
|
|
20
|
+
|
|
21
|
+
@abstractmethod
|
|
22
|
+
def to_numpy(self) -> NumpySample:
|
|
23
|
+
"""Convert sample data to numpy format"""
|
|
24
|
+
raise NotImplementedError
|
|
25
|
+
|
|
26
|
+
@abstractmethod
|
|
27
|
+
def plot(self) -> None:
|
|
28
|
+
raise NotImplementedError
|
|
29
|
+
|
|
30
|
+
@abstractmethod
|
|
31
|
+
def save(self, path: str) -> None:
|
|
32
|
+
raise NotImplementedError
|
|
33
|
+
|
|
34
|
+
@classmethod
|
|
35
|
+
@abstractmethod
|
|
36
|
+
def load(cls, path: str) -> 'SampleBase':
|
|
37
|
+
raise NotImplementedError
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def batch_to_tensor(batch: NumpyBatch) -> TensorBatch:
|
|
41
|
+
"""
|
|
42
|
+
Recursively converts numpy arrays in nested dict to torch tensors
|
|
43
|
+
Args:
|
|
44
|
+
batch: NumpyBatch with data in numpy arrays
|
|
45
|
+
Returns:
|
|
46
|
+
TensorBatch with data in torch tensors
|
|
47
|
+
"""
|
|
48
|
+
|
|
49
|
+
for k, v in batch.items():
|
|
50
|
+
if isinstance(v, dict):
|
|
51
|
+
batch[k] = batch_to_tensor(v)
|
|
52
|
+
elif isinstance(v, np.ndarray):
|
|
53
|
+
if v.dtype == np.bool_:
|
|
54
|
+
batch[k] = torch.tensor(v, dtype=torch.bool)
|
|
55
|
+
elif np.issubdtype(v.dtype, np.number):
|
|
56
|
+
batch[k] = torch.as_tensor(v)
|
|
57
|
+
return batch
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def copy_batch_to_device(batch: TensorBatch, device: torch.device) -> TensorBatch:
|
|
61
|
+
"""Recursively copies tensors in nested dict to specified device
|
|
62
|
+
|
|
63
|
+
Args:
|
|
64
|
+
batch: Nested dict with tensors to move
|
|
65
|
+
device: Device to move tensors to
|
|
66
|
+
|
|
67
|
+
Returns:
|
|
68
|
+
A dict with tensors moved to the new device
|
|
69
|
+
"""
|
|
70
|
+
batch_copy = {}
|
|
71
|
+
|
|
72
|
+
for k, v in batch.items():
|
|
73
|
+
if isinstance(v, dict):
|
|
74
|
+
batch_copy[k] = copy_batch_to_device(v, device)
|
|
75
|
+
elif isinstance(v, torch.Tensor):
|
|
76
|
+
batch_copy[k] = v.to(device)
|
|
77
|
+
else:
|
|
78
|
+
batch_copy[k] = v
|
|
79
|
+
return batch_copy
|
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
"""PVNet Site sample implementation for netCDF data handling and conversion"""
|
|
2
|
+
|
|
3
|
+
import xarray as xr
|
|
4
|
+
|
|
5
|
+
from typing_extensions import override
|
|
6
|
+
|
|
7
|
+
from ocf_data_sampler.sample.base import SampleBase, NumpySample
|
|
8
|
+
from ocf_data_sampler.torch_datasets.datasets.site import convert_netcdf_to_numpy_sample
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class SiteSample(SampleBase):
|
|
12
|
+
"""Handles PVNet site specific netCDF operations"""
|
|
13
|
+
|
|
14
|
+
def __init__(self, data: xr.Dataset):
|
|
15
|
+
|
|
16
|
+
if not isinstance(data, xr.Dataset):
|
|
17
|
+
raise TypeError(f"Data must be xarray Dataset - Found type {type(data)}")
|
|
18
|
+
|
|
19
|
+
self._data = data
|
|
20
|
+
|
|
21
|
+
@override
|
|
22
|
+
def to_numpy(self) -> NumpySample:
|
|
23
|
+
return convert_netcdf_to_numpy_sample(self._data)
|
|
24
|
+
|
|
25
|
+
def save(self, path: str) -> None:
|
|
26
|
+
"""Save site sample data as netCDF
|
|
27
|
+
|
|
28
|
+
Args:
|
|
29
|
+
path: Path to save the netCDF file
|
|
30
|
+
"""
|
|
31
|
+
self._data.to_netcdf(path, mode="w", engine="h5netcdf")
|
|
32
|
+
|
|
33
|
+
@classmethod
|
|
34
|
+
def load(cls, path: str) -> 'SiteSample':
|
|
35
|
+
"""Load site sample data from netCDF
|
|
36
|
+
|
|
37
|
+
Args:
|
|
38
|
+
path: Path to load the netCDF file from
|
|
39
|
+
"""
|
|
40
|
+
return cls(xr.open_dataset(path))
|
|
41
|
+
|
|
42
|
+
# TODO - placeholder for now
|
|
43
|
+
def plot(self) -> None:
|
|
44
|
+
raise NotImplementedError("Plotting not yet implemented for SiteSample")
|
|
@@ -0,0 +1,75 @@
|
|
|
1
|
+
"""PVNet UK Regional sample implementation for dataset handling and visualisation"""
|
|
2
|
+
|
|
3
|
+
from typing_extensions import override
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
|
|
7
|
+
from ocf_data_sampler.sample.base import SampleBase, NumpySample
|
|
8
|
+
from ocf_data_sampler.numpy_sample import (
|
|
9
|
+
NWPSampleKey,
|
|
10
|
+
GSPSampleKey,
|
|
11
|
+
SatelliteSampleKey
|
|
12
|
+
)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class UKRegionalSample(SampleBase):
|
|
16
|
+
"""Handles UK Regional PVNet data operations"""
|
|
17
|
+
|
|
18
|
+
def __init__(self, data: NumpySample):
|
|
19
|
+
self._data = data
|
|
20
|
+
|
|
21
|
+
@override
|
|
22
|
+
def to_numpy(self) -> NumpySample:
|
|
23
|
+
return self._data
|
|
24
|
+
|
|
25
|
+
def save(self, path: str) -> None:
|
|
26
|
+
"""Save PVNet sample as pickle format using torch.save
|
|
27
|
+
|
|
28
|
+
Args:
|
|
29
|
+
path: Path to save the sample data to
|
|
30
|
+
"""
|
|
31
|
+
torch.save(self._data, path)
|
|
32
|
+
|
|
33
|
+
@classmethod
|
|
34
|
+
def load(cls, path: str) -> 'UKRegionalSample':
|
|
35
|
+
"""Load PVNet sample data from .pt format
|
|
36
|
+
|
|
37
|
+
Args:
|
|
38
|
+
path: Path to load the sample data from
|
|
39
|
+
"""
|
|
40
|
+
# TODO: We should move away from using torch.load(..., weights_only=False)
|
|
41
|
+
return cls(torch.load(path, weights_only=False))
|
|
42
|
+
|
|
43
|
+
def plot(self) -> None:
|
|
44
|
+
"""Creates visualisations for NWP, GSP, solar position, and satellite data"""
|
|
45
|
+
from matplotlib import pyplot as plt
|
|
46
|
+
|
|
47
|
+
fig, axes = plt.subplots(2, 2, figsize=(12, 8))
|
|
48
|
+
|
|
49
|
+
if NWPSampleKey.nwp in self._data:
|
|
50
|
+
first_nwp = list(self._data[NWPSampleKey.nwp].values())[0]
|
|
51
|
+
if 'nwp' in first_nwp:
|
|
52
|
+
axes[0, 1].imshow(first_nwp['nwp'][0])
|
|
53
|
+
title = 'NWP (First Channel)'
|
|
54
|
+
if NWPSampleKey.channel_names in first_nwp:
|
|
55
|
+
channel_names = first_nwp[NWPSampleKey.channel_names]
|
|
56
|
+
if channel_names:
|
|
57
|
+
title = f'NWP: {channel_names[0]}'
|
|
58
|
+
axes[0, 1].set_title(title)
|
|
59
|
+
|
|
60
|
+
if GSPSampleKey.gsp in self._data:
|
|
61
|
+
axes[0, 0].plot(self._data[GSPSampleKey.gsp])
|
|
62
|
+
axes[0, 0].set_title('GSP Generation')
|
|
63
|
+
|
|
64
|
+
if GSPSampleKey.solar_azimuth in self._data and GSPSampleKey.solar_elevation in self._data:
|
|
65
|
+
axes[1, 1].plot(self._data[GSPSampleKey.solar_azimuth], label='Azimuth')
|
|
66
|
+
axes[1, 1].plot(self._data[GSPSampleKey.solar_elevation], label='Elevation')
|
|
67
|
+
axes[1, 1].set_title('Solar Position')
|
|
68
|
+
axes[1, 1].legend()
|
|
69
|
+
|
|
70
|
+
if SatelliteSampleKey.satellite_actual in self._data:
|
|
71
|
+
axes[1, 0].imshow(self._data[SatelliteSampleKey.satellite_actual])
|
|
72
|
+
axes[1, 0].set_title('Satellite Data')
|
|
73
|
+
|
|
74
|
+
plt.tight_layout()
|
|
75
|
+
plt.show()
|
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
"""Functions for simulating dropout in time series data
|
|
2
|
+
|
|
3
|
+
This is used for the following types of data: GSP, Satellite and Site
|
|
4
|
+
This is not used for NWP
|
|
5
|
+
"""
|
|
6
|
+
import numpy as np
|
|
7
|
+
import pandas as pd
|
|
8
|
+
import xarray as xr
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def draw_dropout_time(
|
|
12
|
+
t0: pd.Timestamp,
|
|
13
|
+
dropout_timedeltas: list[pd.Timedelta],
|
|
14
|
+
dropout_frac: float,
|
|
15
|
+
) -> pd.Timestamp:
|
|
16
|
+
"""Randomly pick a dropout time from a list of timedeltas
|
|
17
|
+
|
|
18
|
+
Args:
|
|
19
|
+
t0: The forecast init-time
|
|
20
|
+
dropout_timedeltas: List of timedeltas relative to t0 to pick from
|
|
21
|
+
dropout_frac: Probability that dropout will be applied. This should be between 0 and 1
|
|
22
|
+
inclusive
|
|
23
|
+
"""
|
|
24
|
+
|
|
25
|
+
if dropout_frac>0:
|
|
26
|
+
assert len(dropout_timedeltas) > 0, "To apply dropout dropout_timedeltas must be provided"
|
|
27
|
+
|
|
28
|
+
for t in dropout_timedeltas:
|
|
29
|
+
assert t <= pd.Timedelta("0min"), "Dropout timedeltas must be negative"
|
|
30
|
+
|
|
31
|
+
assert 0 <= dropout_frac <= 1
|
|
32
|
+
|
|
33
|
+
if (len(dropout_timedeltas) == 0) or (np.random.uniform() >= dropout_frac):
|
|
34
|
+
dropout_time = t0
|
|
35
|
+
else:
|
|
36
|
+
dropout_time = t0 + np.random.choice(dropout_timedeltas)
|
|
37
|
+
|
|
38
|
+
return dropout_time
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def apply_dropout_time(
|
|
42
|
+
ds: xr.DataArray,
|
|
43
|
+
dropout_time: pd.Timestamp,
|
|
44
|
+
) -> xr.DataArray:
|
|
45
|
+
"""Apply dropout time to the data
|
|
46
|
+
|
|
47
|
+
Args:
|
|
48
|
+
ds: Xarray DataArray with 'time_utc' coordiante
|
|
49
|
+
dropout_time: Time after which data is set to NaN
|
|
50
|
+
"""
|
|
51
|
+
# This replaces the times after the dropout with NaNs
|
|
52
|
+
return ds.where(ds.time_utc <= dropout_time)
|
{ocf_data_sampler-0.1.9 → ocf_data_sampler-0.1.11}/ocf_data_sampler/select/fill_time_periods.py
RENAMED
|
@@ -1,10 +1,12 @@
|
|
|
1
|
-
"""
|
|
1
|
+
"""Fill time periods between start and end dates at specified frequency"""
|
|
2
2
|
|
|
3
3
|
import pandas as pd
|
|
4
4
|
import numpy as np
|
|
5
5
|
|
|
6
6
|
|
|
7
7
|
def fill_time_periods(time_periods: pd.DataFrame, freq: pd.Timedelta) -> pd.DatetimeIndex:
|
|
8
|
+
"""Generate DatetimeIndex for all timestamps between start and end dates"""
|
|
9
|
+
|
|
8
10
|
start_dts = pd.to_datetime(time_periods["start_dt"].values).ceil(freq)
|
|
9
11
|
end_dts = pd.to_datetime(time_periods["end_dt"].values)
|
|
10
12
|
date_ranges = [pd.date_range(start_dt, end_dt, freq=freq) for start_dt, end_dt in zip(start_dts, end_dts)]
|
|
@@ -186,9 +186,8 @@ class PVNetUKRegionalDataset(Dataset):
|
|
|
186
186
|
gsp_ids: List of GSP IDs to create samples for. Defaults to all
|
|
187
187
|
"""
|
|
188
188
|
|
|
189
|
-
config = load_yaml_configuration(config_filename)
|
|
190
|
-
|
|
191
|
-
# Validate channels for NWP and satellite data
|
|
189
|
+
# config = load_yaml_configuration(config_filename)
|
|
190
|
+
config: Configuration = load_yaml_configuration(config_filename)
|
|
192
191
|
validate_nwp_channels(config)
|
|
193
192
|
validate_satellite_channels(config)
|
|
194
193
|
|