ocf-data-sampler 0.1.10__py3-none-any.whl → 0.1.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ocf-data-sampler might be problematic. Click here for more details.
- ocf_data_sampler/config/load.py +3 -3
- ocf_data_sampler/config/model.py +86 -72
- ocf_data_sampler/config/save.py +5 -4
- ocf_data_sampler/constants.py +140 -12
- ocf_data_sampler/load/gsp.py +6 -5
- ocf_data_sampler/load/load_dataset.py +5 -6
- ocf_data_sampler/load/nwp/nwp.py +17 -5
- ocf_data_sampler/load/nwp/providers/ecmwf.py +6 -7
- ocf_data_sampler/load/nwp/providers/gfs.py +36 -0
- ocf_data_sampler/load/nwp/providers/icon.py +46 -0
- ocf_data_sampler/load/nwp/providers/ukv.py +4 -5
- ocf_data_sampler/load/nwp/providers/utils.py +3 -1
- ocf_data_sampler/load/satellite.py +27 -36
- ocf_data_sampler/load/site.py +11 -7
- ocf_data_sampler/load/utils.py +21 -16
- ocf_data_sampler/numpy_sample/collate.py +10 -9
- ocf_data_sampler/numpy_sample/datetime_features.py +3 -5
- ocf_data_sampler/numpy_sample/gsp.py +15 -13
- ocf_data_sampler/numpy_sample/nwp.py +17 -23
- ocf_data_sampler/numpy_sample/satellite.py +17 -14
- ocf_data_sampler/numpy_sample/site.py +8 -7
- ocf_data_sampler/numpy_sample/sun_position.py +19 -25
- ocf_data_sampler/sample/__init__.py +0 -7
- ocf_data_sampler/sample/base.py +23 -44
- ocf_data_sampler/sample/site.py +25 -69
- ocf_data_sampler/sample/uk_regional.py +52 -103
- ocf_data_sampler/select/dropout.py +42 -27
- ocf_data_sampler/select/fill_time_periods.py +15 -3
- ocf_data_sampler/select/find_contiguous_time_periods.py +87 -75
- ocf_data_sampler/select/geospatial.py +63 -54
- ocf_data_sampler/select/location.py +16 -51
- ocf_data_sampler/select/select_spatial_slice.py +105 -89
- ocf_data_sampler/select/select_time_slice.py +71 -58
- ocf_data_sampler/select/spatial_slice_for_dataset.py +7 -6
- ocf_data_sampler/select/time_slice_for_dataset.py +17 -16
- ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py +126 -118
- ocf_data_sampler/torch_datasets/datasets/site.py +135 -101
- ocf_data_sampler/torch_datasets/utils/merge_and_fill_utils.py +6 -2
- ocf_data_sampler/torch_datasets/utils/valid_time_periods.py +23 -22
- ocf_data_sampler/torch_datasets/utils/validate_channels.py +23 -19
- ocf_data_sampler/utils.py +3 -1
- {ocf_data_sampler-0.1.10.dist-info → ocf_data_sampler-0.1.16.dist-info}/METADATA +7 -18
- ocf_data_sampler-0.1.16.dist-info/RECORD +56 -0
- {ocf_data_sampler-0.1.10.dist-info → ocf_data_sampler-0.1.16.dist-info}/WHEEL +1 -1
- {ocf_data_sampler-0.1.10.dist-info → ocf_data_sampler-0.1.16.dist-info}/top_level.txt +1 -1
- scripts/refactor_site.py +62 -33
- utils/compute_icon_mean_stddev.py +72 -0
- ocf_data_sampler-0.1.10.dist-info/LICENSE +0 -21
- ocf_data_sampler-0.1.10.dist-info/RECORD +0 -82
- tests/__init__.py +0 -0
- tests/config/test_config.py +0 -113
- tests/config/test_load.py +0 -7
- tests/config/test_save.py +0 -28
- tests/conftest.py +0 -286
- tests/load/test_load_gsp.py +0 -15
- tests/load/test_load_nwp.py +0 -21
- tests/load/test_load_satellite.py +0 -17
- tests/load/test_load_sites.py +0 -14
- tests/numpy_sample/test_collate.py +0 -21
- tests/numpy_sample/test_datetime_features.py +0 -37
- tests/numpy_sample/test_gsp.py +0 -38
- tests/numpy_sample/test_nwp.py +0 -52
- tests/numpy_sample/test_satellite.py +0 -40
- tests/numpy_sample/test_sun_position.py +0 -81
- tests/select/test_dropout.py +0 -75
- tests/select/test_fill_time_periods.py +0 -28
- tests/select/test_find_contiguous_time_periods.py +0 -202
- tests/select/test_location.py +0 -67
- tests/select/test_select_spatial_slice.py +0 -154
- tests/select/test_select_time_slice.py +0 -275
- tests/test_sample/test_base.py +0 -164
- tests/test_sample/test_site_sample.py +0 -195
- tests/test_sample/test_uk_regional_sample.py +0 -163
- tests/torch_datasets/test_merge_and_fill_utils.py +0 -40
- tests/torch_datasets/test_pvnet_uk.py +0 -167
- tests/torch_datasets/test_site.py +0 -226
- tests/torch_datasets/test_validate_channels_utils.py +0 -78
ocf_data_sampler/config/load.py
CHANGED
|
@@ -1,13 +1,13 @@
|
|
|
1
|
-
"""Load configuration from a yaml file"""
|
|
1
|
+
"""Load configuration from a yaml file."""
|
|
2
2
|
|
|
3
3
|
import fsspec
|
|
4
4
|
from pyaml_env import parse_config
|
|
5
|
+
|
|
5
6
|
from ocf_data_sampler.config import Configuration
|
|
6
7
|
|
|
7
8
|
|
|
8
9
|
def load_yaml_configuration(filename: str) -> Configuration:
|
|
9
|
-
"""
|
|
10
|
-
Load a yaml file which has a configuration in it
|
|
10
|
+
"""Load a yaml file which has a configuration in it.
|
|
11
11
|
|
|
12
12
|
Args:
|
|
13
13
|
filename: the yaml file name that you want to load. Will load from local, AWS, or GCP
|
ocf_data_sampler/config/model.py
CHANGED
|
@@ -1,45 +1,45 @@
|
|
|
1
1
|
"""Configuration model for the dataset.
|
|
2
2
|
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
|
|
3
|
+
Absolute or relative zarr filepath(s).
|
|
4
|
+
Prefix with a protocol like s3:// to read from alternative filesystems.
|
|
6
5
|
"""
|
|
7
6
|
|
|
8
|
-
from
|
|
9
|
-
from typing_extensions import Self
|
|
7
|
+
from collections.abc import Iterator
|
|
10
8
|
|
|
11
|
-
from pydantic import BaseModel, Field, RootModel, field_validator,
|
|
9
|
+
from pydantic import BaseModel, Field, RootModel, field_validator, model_validator
|
|
10
|
+
from typing_extensions import override
|
|
12
11
|
|
|
13
12
|
from ocf_data_sampler.constants import NWP_PROVIDERS
|
|
14
13
|
|
|
15
14
|
|
|
16
15
|
class Base(BaseModel):
|
|
17
|
-
"""Pydantic Base model where no extras can be added"""
|
|
16
|
+
"""Pydantic Base model where no extras can be added."""
|
|
18
17
|
|
|
19
18
|
class Config:
|
|
20
|
-
"""
|
|
19
|
+
"""Config class."""
|
|
21
20
|
|
|
22
21
|
extra = "forbid" # forbid use of extra kwargs
|
|
23
22
|
|
|
24
23
|
|
|
25
24
|
class General(Base):
|
|
26
|
-
"""General pydantic model"""
|
|
25
|
+
"""General pydantic model."""
|
|
27
26
|
|
|
28
27
|
name: str = Field("example", description="The name of this configuration file")
|
|
29
28
|
description: str = Field(
|
|
30
|
-
"example configuration",
|
|
29
|
+
"example configuration",
|
|
30
|
+
description="Description of this configuration file",
|
|
31
31
|
)
|
|
32
32
|
|
|
33
33
|
|
|
34
34
|
class TimeWindowMixin(Base):
|
|
35
|
-
"""Mixin class, to add interval start, end and resolution minutes"""
|
|
35
|
+
"""Mixin class, to add interval start, end and resolution minutes."""
|
|
36
36
|
|
|
37
37
|
time_resolution_minutes: int = Field(
|
|
38
38
|
...,
|
|
39
39
|
gt=0,
|
|
40
40
|
description="The temporal resolution of the data in minutes",
|
|
41
41
|
)
|
|
42
|
-
|
|
42
|
+
|
|
43
43
|
interval_start_minutes: int = Field(
|
|
44
44
|
...,
|
|
45
45
|
description="Data interval starts at `t0 + interval_start_minutes`",
|
|
@@ -49,31 +49,35 @@ class TimeWindowMixin(Base):
|
|
|
49
49
|
...,
|
|
50
50
|
description="Data interval ends at `t0 + interval_end_minutes`",
|
|
51
51
|
)
|
|
52
|
-
|
|
53
|
-
@model_validator(mode='after')
|
|
54
|
-
def check_interval_range(cls, values):
|
|
55
|
-
if values.interval_start_minutes > values.interval_end_minutes:
|
|
56
|
-
raise ValueError('interval_start_minutes must be <= interval_end_minutes')
|
|
57
|
-
return values
|
|
58
|
-
|
|
59
|
-
@field_validator("interval_start_minutes")
|
|
60
|
-
def interval_start_minutes_divide_by_time_resolution(cls, v: int, info: ValidationInfo) -> int:
|
|
61
|
-
if v % info.data["time_resolution_minutes"] != 0:
|
|
62
|
-
raise ValueError("interval_start_minutes must be divisible by time_resolution_minutes")
|
|
63
|
-
return v
|
|
64
52
|
|
|
65
|
-
@
|
|
66
|
-
def
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
53
|
+
@model_validator(mode="after")
|
|
54
|
+
def validate_intervals(self) -> "TimeWindowMixin":
|
|
55
|
+
"""Validator for time interval fields."""
|
|
56
|
+
start = self.interval_start_minutes
|
|
57
|
+
end = self.interval_end_minutes
|
|
58
|
+
resolution = self.time_resolution_minutes
|
|
59
|
+
if start > end:
|
|
60
|
+
raise ValueError(
|
|
61
|
+
f"interval_start_minutes ({start}) must be <= interval_end_minutes ({end})",
|
|
62
|
+
)
|
|
63
|
+
if start % resolution != 0:
|
|
64
|
+
raise ValueError(
|
|
65
|
+
f"interval_start_minutes ({start}) must be divisible "
|
|
66
|
+
f"by time_resolution_minutes ({resolution})",
|
|
67
|
+
)
|
|
68
|
+
if end % resolution != 0:
|
|
69
|
+
raise ValueError(
|
|
70
|
+
f"interval_end_minutes ({end}) must be divisible "
|
|
71
|
+
f"by time_resolution_minutes ({resolution})",
|
|
72
|
+
)
|
|
73
|
+
return self
|
|
70
74
|
|
|
71
75
|
|
|
72
76
|
class DropoutMixin(Base):
|
|
73
|
-
"""Mixin class, to add dropout minutes"""
|
|
77
|
+
"""Mixin class, to add dropout minutes."""
|
|
74
78
|
|
|
75
|
-
dropout_timedeltas_minutes:
|
|
76
|
-
default=
|
|
79
|
+
dropout_timedeltas_minutes: list[int] = Field(
|
|
80
|
+
default=[],
|
|
77
81
|
description="List of possible minutes before t0 where data availability may start. Must be "
|
|
78
82
|
"negative or zero.",
|
|
79
83
|
)
|
|
@@ -86,26 +90,27 @@ class DropoutMixin(Base):
|
|
|
86
90
|
)
|
|
87
91
|
|
|
88
92
|
@field_validator("dropout_timedeltas_minutes")
|
|
89
|
-
def dropout_timedeltas_minutes_negative(cls, v:
|
|
90
|
-
"""Validate 'dropout_timedeltas_minutes'"""
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
93
|
+
def dropout_timedeltas_minutes_negative(cls, v: list[int]) -> list[int]:
|
|
94
|
+
"""Validate 'dropout_timedeltas_minutes'."""
|
|
95
|
+
for m in v:
|
|
96
|
+
if m > 0:
|
|
97
|
+
raise ValueError("Dropout timedeltas must be negative")
|
|
94
98
|
return v
|
|
95
99
|
|
|
96
100
|
@model_validator(mode="after")
|
|
97
|
-
def dropout_instructions_consistent(self) ->
|
|
101
|
+
def dropout_instructions_consistent(self) -> "DropoutMixin":
|
|
102
|
+
"""Validator for dropout instructions."""
|
|
98
103
|
if self.dropout_fraction == 0:
|
|
99
|
-
if self.dropout_timedeltas_minutes
|
|
104
|
+
if self.dropout_timedeltas_minutes != []:
|
|
100
105
|
raise ValueError("To use dropout timedeltas dropout fraction should be > 0")
|
|
101
106
|
else:
|
|
102
|
-
if self.dropout_timedeltas_minutes
|
|
107
|
+
if self.dropout_timedeltas_minutes == []:
|
|
103
108
|
raise ValueError("To dropout fraction > 0 requires a list of dropout timedeltas")
|
|
104
109
|
return self
|
|
105
110
|
|
|
106
111
|
|
|
107
112
|
class SpatialWindowMixin(Base):
|
|
108
|
-
"""Mixin class, to add path and image size"""
|
|
113
|
+
"""Mixin class, to add path and image size."""
|
|
109
114
|
|
|
110
115
|
image_size_pixels_height: int = Field(
|
|
111
116
|
...,
|
|
@@ -121,8 +126,8 @@ class SpatialWindowMixin(Base):
|
|
|
121
126
|
|
|
122
127
|
|
|
123
128
|
class Satellite(TimeWindowMixin, DropoutMixin, SpatialWindowMixin):
|
|
124
|
-
"""Satellite configuration model"""
|
|
125
|
-
|
|
129
|
+
"""Satellite configuration model."""
|
|
130
|
+
|
|
126
131
|
zarr_path: str | tuple[str] | list[str] = Field(
|
|
127
132
|
...,
|
|
128
133
|
description="Absolute or relative zarr filepath(s). Prefix with a protocol like s3:// "
|
|
@@ -130,82 +135,86 @@ class Satellite(TimeWindowMixin, DropoutMixin, SpatialWindowMixin):
|
|
|
130
135
|
)
|
|
131
136
|
|
|
132
137
|
channels: list[str] = Field(
|
|
133
|
-
...,
|
|
138
|
+
...,
|
|
139
|
+
description="the satellite channels that are used",
|
|
134
140
|
)
|
|
135
141
|
|
|
136
142
|
|
|
137
143
|
class NWP(TimeWindowMixin, DropoutMixin, SpatialWindowMixin):
|
|
138
|
-
"""NWP configuration model"""
|
|
139
|
-
|
|
144
|
+
"""NWP configuration model."""
|
|
145
|
+
|
|
140
146
|
zarr_path: str | tuple[str] | list[str] = Field(
|
|
141
147
|
...,
|
|
142
148
|
description="Absolute or relative zarr filepath(s). Prefix with a protocol like s3:// "
|
|
143
149
|
"to read from alternative filesystems.",
|
|
144
150
|
)
|
|
145
|
-
|
|
151
|
+
|
|
146
152
|
channels: list[str] = Field(
|
|
147
|
-
...,
|
|
153
|
+
...,
|
|
154
|
+
description="the channels used in the nwp data",
|
|
148
155
|
)
|
|
149
156
|
|
|
150
157
|
provider: str = Field(..., description="The provider of the NWP data")
|
|
151
158
|
|
|
152
159
|
accum_channels: list[str] = Field([], description="the nwp channels which need to be diffed")
|
|
153
160
|
|
|
154
|
-
max_staleness_minutes:
|
|
161
|
+
max_staleness_minutes: int | None = Field(
|
|
155
162
|
None,
|
|
156
163
|
description="Sets a limit on how stale an NWP init time is allowed to be whilst still being"
|
|
157
164
|
" used to construct an example. If set to None, then the max staleness is set according to"
|
|
158
165
|
" the maximum forecast horizon of the NWP and the requested forecast length.",
|
|
159
166
|
)
|
|
160
167
|
|
|
161
|
-
|
|
162
168
|
@field_validator("provider")
|
|
163
169
|
def validate_provider(cls, v: str) -> str:
|
|
164
|
-
"""
|
|
170
|
+
"""Validator for 'provider'."""
|
|
165
171
|
if v.lower() not in NWP_PROVIDERS:
|
|
166
|
-
|
|
167
|
-
raise Exception(message)
|
|
172
|
+
raise OSError(f"NWP provider {v} is not in {NWP_PROVIDERS}")
|
|
168
173
|
return v
|
|
169
174
|
|
|
170
175
|
|
|
171
176
|
class MultiNWP(RootModel):
|
|
172
|
-
"""Configuration for multiple NWPs"""
|
|
177
|
+
"""Configuration for multiple NWPs."""
|
|
173
178
|
|
|
174
|
-
root:
|
|
179
|
+
root: dict[str, NWP]
|
|
175
180
|
|
|
176
|
-
|
|
181
|
+
@override
|
|
182
|
+
def __getattr__(self, item: str) -> NWP:
|
|
177
183
|
return self.root[item]
|
|
178
184
|
|
|
179
|
-
|
|
185
|
+
@override
|
|
186
|
+
def __getitem__(self, item: str) -> NWP:
|
|
180
187
|
return self.root[item]
|
|
181
188
|
|
|
182
|
-
|
|
189
|
+
@override
|
|
190
|
+
def __len__(self) -> int:
|
|
183
191
|
return len(self.root)
|
|
184
192
|
|
|
185
|
-
|
|
193
|
+
@override
|
|
194
|
+
def __iter__(self) -> Iterator:
|
|
186
195
|
return iter(self.root)
|
|
187
196
|
|
|
188
|
-
def keys(self):
|
|
189
|
-
"""Returns dictionary-like keys"""
|
|
197
|
+
def keys(self) -> Iterator[str]:
|
|
198
|
+
"""Returns dictionary-like keys."""
|
|
190
199
|
return self.root.keys()
|
|
191
200
|
|
|
192
|
-
def items(self):
|
|
193
|
-
"""Returns dictionary-like items"""
|
|
201
|
+
def items(self) -> Iterator[tuple[str, NWP]]:
|
|
202
|
+
"""Returns dictionary-like items."""
|
|
194
203
|
return self.root.items()
|
|
195
204
|
|
|
196
205
|
|
|
197
206
|
class GSP(TimeWindowMixin, DropoutMixin):
|
|
198
|
-
"""GSP configuration model"""
|
|
207
|
+
"""GSP configuration model."""
|
|
199
208
|
|
|
200
209
|
zarr_path: str = Field(
|
|
201
|
-
...,
|
|
210
|
+
...,
|
|
202
211
|
description="Absolute or relative zarr filepath. Prefix with a protocol like s3:// "
|
|
203
212
|
"to read from alternative filesystems.",
|
|
204
213
|
)
|
|
205
214
|
|
|
206
215
|
|
|
207
216
|
class Site(TimeWindowMixin, DropoutMixin):
|
|
208
|
-
"""Site configuration model"""
|
|
217
|
+
"""Site configuration model."""
|
|
209
218
|
|
|
210
219
|
file_path: str = Field(
|
|
211
220
|
...,
|
|
@@ -220,17 +229,22 @@ class Site(TimeWindowMixin, DropoutMixin):
|
|
|
220
229
|
# TODO validate the csv for metadata
|
|
221
230
|
|
|
222
231
|
|
|
232
|
+
class SolarPosition(TimeWindowMixin):
|
|
233
|
+
"""Solar position configuration model."""
|
|
234
|
+
|
|
235
|
+
|
|
223
236
|
class InputData(Base):
|
|
224
|
-
"""Input data model"""
|
|
237
|
+
"""Input data model."""
|
|
225
238
|
|
|
226
|
-
satellite:
|
|
227
|
-
nwp:
|
|
228
|
-
gsp:
|
|
229
|
-
site:
|
|
239
|
+
satellite: Satellite | None = None
|
|
240
|
+
nwp: MultiNWP | None = None
|
|
241
|
+
gsp: GSP | None = None
|
|
242
|
+
site: Site | None = None
|
|
243
|
+
solar_position: SolarPosition | None = None
|
|
230
244
|
|
|
231
245
|
|
|
232
246
|
class Configuration(Base):
|
|
233
|
-
"""Configuration model for the dataset"""
|
|
247
|
+
"""Configuration model for the dataset."""
|
|
234
248
|
|
|
235
249
|
general: General = General()
|
|
236
250
|
input_data: InputData = InputData()
|
ocf_data_sampler/config/save.py
CHANGED
|
@@ -5,12 +5,14 @@ supporting local and cloud storage locations.
|
|
|
5
5
|
"""
|
|
6
6
|
|
|
7
7
|
import json
|
|
8
|
+
import os
|
|
9
|
+
|
|
8
10
|
import fsspec
|
|
9
11
|
import yaml
|
|
10
|
-
import os
|
|
11
12
|
|
|
12
13
|
from ocf_data_sampler.config import Configuration
|
|
13
14
|
|
|
15
|
+
|
|
14
16
|
def save_yaml_configuration(configuration: Configuration, filename: str) -> None:
|
|
15
17
|
"""Save a configuration object to a YAML file.
|
|
16
18
|
|
|
@@ -20,12 +22,11 @@ def save_yaml_configuration(configuration: Configuration, filename: str) -> None
|
|
|
20
22
|
cloud storage URL (e.g., 'gs://', 's3://'). For local paths,
|
|
21
23
|
absolute paths are recommended.
|
|
22
24
|
"""
|
|
23
|
-
|
|
24
25
|
if os.path.exists(filename):
|
|
25
26
|
raise FileExistsError(f"File already exists: {filename}")
|
|
26
27
|
|
|
27
28
|
# Serialize configuration to JSON-compatible dictionary
|
|
28
29
|
config_dict = json.loads(configuration.model_dump_json())
|
|
29
30
|
|
|
30
|
-
with fsspec.open(filename, mode=
|
|
31
|
-
yaml.safe_dump(config_dict, yaml_file, default_flow_style=False)
|
|
31
|
+
with fsspec.open(filename, mode="w") as yaml_file:
|
|
32
|
+
yaml.safe_dump(config_dict, yaml_file, default_flow_style=False)
|
ocf_data_sampler/constants.py
CHANGED
|
@@ -1,33 +1,37 @@
|
|
|
1
|
-
|
|
2
|
-
import numpy as np
|
|
1
|
+
"""Constants for the package."""
|
|
3
2
|
|
|
3
|
+
import numpy as np
|
|
4
|
+
import xarray as xr
|
|
5
|
+
from typing_extensions import override
|
|
4
6
|
|
|
5
7
|
NWP_PROVIDERS = [
|
|
6
8
|
"ukv",
|
|
7
9
|
"ecmwf",
|
|
8
|
-
"gfs"
|
|
10
|
+
"gfs",
|
|
11
|
+
"icon_eu",
|
|
9
12
|
]
|
|
10
|
-
# TODO add ICON
|
|
11
13
|
|
|
12
14
|
|
|
13
|
-
def _to_data_array(d):
|
|
15
|
+
def _to_data_array(d: dict) -> xr.DataArray:
|
|
16
|
+
"""Convert a dictionary to a DataArray."""
|
|
14
17
|
return xr.DataArray(
|
|
15
|
-
[d[k] for k in d
|
|
16
|
-
coords={"channel":
|
|
18
|
+
[d[k] for k in d],
|
|
19
|
+
coords={"channel": list(d.keys())},
|
|
17
20
|
).astype(np.float32)
|
|
18
21
|
|
|
19
22
|
|
|
20
23
|
class NWPStatDict(dict):
|
|
21
|
-
"""Custom dictionary class to hold NWP normalization stats"""
|
|
24
|
+
"""Custom dictionary class to hold NWP normalization stats."""
|
|
22
25
|
|
|
23
|
-
|
|
26
|
+
@override
|
|
27
|
+
def __getitem__(self, key: str) -> xr.DataArray:
|
|
24
28
|
if key not in NWP_PROVIDERS:
|
|
25
29
|
raise KeyError(f"{key} is not a supported NWP provider - {NWP_PROVIDERS}")
|
|
26
30
|
elif key in self.keys():
|
|
27
31
|
return super().__getitem__(key)
|
|
28
32
|
else:
|
|
29
33
|
raise KeyError(
|
|
30
|
-
f"Values for {key} not yet available in ocf-data-sampler {list(self.keys())}"
|
|
34
|
+
f"Values for {key} not yet available in ocf-data-sampler {list(self.keys())}",
|
|
31
35
|
)
|
|
32
36
|
|
|
33
37
|
|
|
@@ -173,16 +177,140 @@ GFS_MEAN = {
|
|
|
173
177
|
GFS_STD = _to_data_array(GFS_STD)
|
|
174
178
|
GFS_MEAN = _to_data_array(GFS_MEAN)
|
|
175
179
|
|
|
180
|
+
# ------ ICON-EU
|
|
181
|
+
# Statistics for ICON-EU variables
|
|
182
|
+
ICON_EU_STD = {
|
|
183
|
+
"alb_rad": 13.7881,
|
|
184
|
+
"alhfl_s": 73.7198,
|
|
185
|
+
"ashfl_s": 54.8027,
|
|
186
|
+
"asob_s": 55.8319,
|
|
187
|
+
"asob_t": 74.9360,
|
|
188
|
+
"aswdifd_s": 21.4940,
|
|
189
|
+
"aswdifu_s": 18.7688,
|
|
190
|
+
"aswdir_s": 54.4683,
|
|
191
|
+
"athb_s": 34.8575,
|
|
192
|
+
"athb_t": 42.9108,
|
|
193
|
+
"aumfl_s": 0.1460,
|
|
194
|
+
"avmfl_s": 0.1892,
|
|
195
|
+
"cape_con": 32.2570,
|
|
196
|
+
"cape_ml": 106.3998,
|
|
197
|
+
"clch": 39.9324,
|
|
198
|
+
"clcl": 36.3961,
|
|
199
|
+
"clcm": 41.1690,
|
|
200
|
+
"clct": 34.7696,
|
|
201
|
+
"clct_mod": 0.4227,
|
|
202
|
+
"cldepth": 0.1739,
|
|
203
|
+
"h_snow": 0.9012,
|
|
204
|
+
"hbas_con": 1306.6632,
|
|
205
|
+
"htop_con": 1810.5665,
|
|
206
|
+
"htop_dc": 459.0422,
|
|
207
|
+
"hzerocl": 1144.6469,
|
|
208
|
+
"pmsl": 1103.3301,
|
|
209
|
+
"ps": 4761.3184,
|
|
210
|
+
"qv_2m": 0.0024,
|
|
211
|
+
"qv_s": 0.0038,
|
|
212
|
+
"rain_con": 1.7097,
|
|
213
|
+
"rain_gsp": 4.2654,
|
|
214
|
+
"relhum_2m": 15.3779,
|
|
215
|
+
"rho_snow": 120.2461,
|
|
216
|
+
"runoff_g": 0.7410,
|
|
217
|
+
"runoff_s": 2.1930,
|
|
218
|
+
"snow_con": 1.1432,
|
|
219
|
+
"snow_gsp": 1.8154,
|
|
220
|
+
"snowlmt": 656.0699,
|
|
221
|
+
"synmsg_bt_cl_ir10.8": 17.9438,
|
|
222
|
+
"t_2m": 7.7973,
|
|
223
|
+
"t_g": 8.7053,
|
|
224
|
+
"t_snow": 134.6874,
|
|
225
|
+
"tch": 0.0052,
|
|
226
|
+
"tcm": 0.0133,
|
|
227
|
+
"td_2m": 7.1460,
|
|
228
|
+
"tmax_2m": 7.8218,
|
|
229
|
+
"tmin_2m": 7.8346,
|
|
230
|
+
"tot_prec": 5.6312,
|
|
231
|
+
"tqc": 0.0976,
|
|
232
|
+
"tqi": 0.0247,
|
|
233
|
+
"u_10m": 3.8351,
|
|
234
|
+
"v_10m": 5.0083,
|
|
235
|
+
"vmax_10m": 5.5037,
|
|
236
|
+
"w_snow": 286.1510,
|
|
237
|
+
"ww": 27.2974,
|
|
238
|
+
"z0": 0.3901,
|
|
239
|
+
}
|
|
240
|
+
|
|
241
|
+
ICON_EU_MEAN = {
|
|
242
|
+
"alb_rad": 15.4437,
|
|
243
|
+
"alhfl_s": -54.9398,
|
|
244
|
+
"ashfl_s": -19.4684,
|
|
245
|
+
"asob_s": 40.9305,
|
|
246
|
+
"asob_t": 61.9244,
|
|
247
|
+
"aswdifd_s": 19.7813,
|
|
248
|
+
"aswdifu_s": 8.8328,
|
|
249
|
+
"aswdir_s": 29.9820,
|
|
250
|
+
"athb_s": -53.9873,
|
|
251
|
+
"athb_t": -212.8088,
|
|
252
|
+
"aumfl_s": 0.0558,
|
|
253
|
+
"avmfl_s": 0.0078,
|
|
254
|
+
"cape_con": 16.7397,
|
|
255
|
+
"cape_ml": 21.2189,
|
|
256
|
+
"clch": 26.4262,
|
|
257
|
+
"clcl": 57.1591,
|
|
258
|
+
"clcm": 36.1702,
|
|
259
|
+
"clct": 72.9254,
|
|
260
|
+
"clct_mod": 0.5561,
|
|
261
|
+
"cldepth": 0.1356,
|
|
262
|
+
"h_snow": 0.0494,
|
|
263
|
+
"hbas_con": 108.4975,
|
|
264
|
+
"htop_con": 433.0623,
|
|
265
|
+
"htop_dc": 454.0859,
|
|
266
|
+
"hzerocl": 1696.6272,
|
|
267
|
+
"pmsl": 101778.8281,
|
|
268
|
+
"ps": 99114.4766,
|
|
269
|
+
"qv_2m": 0.0049,
|
|
270
|
+
"qv_s": 0.0065,
|
|
271
|
+
"rain_con": 0.4869,
|
|
272
|
+
"rain_gsp": 0.9783,
|
|
273
|
+
"relhum_2m": 78.2258,
|
|
274
|
+
"rho_snow": 62.5032,
|
|
275
|
+
"runoff_g": 0.1301,
|
|
276
|
+
"runoff_s": 0.4119,
|
|
277
|
+
"snow_con": 0.2188,
|
|
278
|
+
"snow_gsp": 0.4317,
|
|
279
|
+
"snowlmt": 1450.3241,
|
|
280
|
+
"synmsg_bt_cl_ir10.8": 265.0639,
|
|
281
|
+
"t_2m": 278.8212,
|
|
282
|
+
"t_g": 279.9216,
|
|
283
|
+
"t_snow": 162.5582,
|
|
284
|
+
"tch": 0.0047,
|
|
285
|
+
"tcm": 0.0091,
|
|
286
|
+
"td_2m": 274.9544,
|
|
287
|
+
"tmax_2m": 279.3550,
|
|
288
|
+
"tmin_2m": 278.2519,
|
|
289
|
+
"tot_prec": 2.1158,
|
|
290
|
+
"tqc": 0.0424,
|
|
291
|
+
"tqi": 0.0108,
|
|
292
|
+
"u_10m": 1.1902,
|
|
293
|
+
"v_10m": -0.4733,
|
|
294
|
+
"vmax_10m": 8.4152,
|
|
295
|
+
"w_snow": 14.5936,
|
|
296
|
+
"ww": 15.3570,
|
|
297
|
+
"z0": 0.2386,
|
|
298
|
+
}
|
|
299
|
+
|
|
300
|
+
ICON_EU_STD = _to_data_array(ICON_EU_STD)
|
|
301
|
+
ICON_EU_MEAN = _to_data_array(ICON_EU_MEAN)
|
|
176
302
|
|
|
177
303
|
NWP_STDS = NWPStatDict(
|
|
178
304
|
ukv=UKV_STD,
|
|
179
305
|
ecmwf=ECMWF_STD,
|
|
180
|
-
gfs=GFS_STD
|
|
306
|
+
gfs=GFS_STD,
|
|
307
|
+
icon_eu=ICON_EU_STD,
|
|
181
308
|
)
|
|
182
309
|
NWP_MEANS = NWPStatDict(
|
|
183
310
|
ukv=UKV_MEAN,
|
|
184
311
|
ecmwf=ECMWF_MEAN,
|
|
185
|
-
gfs=GFS_MEAN
|
|
312
|
+
gfs=GFS_MEAN,
|
|
313
|
+
icon_eu=ICON_EU_MEAN,
|
|
186
314
|
)
|
|
187
315
|
|
|
188
316
|
# ------ Satellite
|
ocf_data_sampler/load/gsp.py
CHANGED
|
@@ -1,26 +1,27 @@
|
|
|
1
|
-
|
|
1
|
+
"""Functions for loading GSP data."""
|
|
2
|
+
|
|
3
|
+
from importlib.resources import files
|
|
2
4
|
|
|
3
5
|
import pandas as pd
|
|
4
6
|
import xarray as xr
|
|
5
7
|
|
|
6
8
|
|
|
7
9
|
def open_gsp(zarr_path: str) -> xr.DataArray:
|
|
8
|
-
"""Open the GSP data
|
|
9
|
-
|
|
10
|
+
"""Open the GSP data.
|
|
11
|
+
|
|
10
12
|
Args:
|
|
11
13
|
zarr_path: Path to the GSP zarr data
|
|
12
14
|
|
|
13
15
|
Returns:
|
|
14
16
|
xr.DataArray: The opened GSP data
|
|
15
17
|
"""
|
|
16
|
-
|
|
17
18
|
ds = xr.open_zarr(zarr_path)
|
|
18
19
|
|
|
19
20
|
ds = ds.rename({"datetime_gmt": "time_utc"})
|
|
20
21
|
|
|
21
22
|
# Load UK GSP locations
|
|
22
23
|
df_gsp_loc = pd.read_csv(
|
|
23
|
-
|
|
24
|
+
files("ocf_data_sampler.data").joinpath("uk_gsp_locations.csv"),
|
|
24
25
|
index_col="gsp_id",
|
|
25
26
|
)
|
|
26
27
|
|
|
ocf_data_sampler/load/load_dataset.py
CHANGED
|
@@ -1,17 +1,17 @@
|
|
|
1
|
-
"""
|
|
1
|
+
"""Loads all data sources."""
|
|
2
|
+
|
|
2
3
|
import xarray as xr
|
|
3
4
|
|
|
4
5
|
from ocf_data_sampler.config import InputData
|
|
5
|
-
from ocf_data_sampler.load import
|
|
6
|
+
from ocf_data_sampler.load import open_gsp, open_nwp, open_sat_data, open_site
|
|
6
7
|
|
|
7
8
|
|
|
8
9
|
def get_dataset_dict(input_config: InputData) -> dict[str, dict[xr.DataArray] | xr.DataArray]:
|
|
9
|
-
"""Construct dictionary of all of the input data sources
|
|
10
|
+
"""Construct dictionary of all of the input data sources.
|
|
10
11
|
|
|
11
12
|
Args:
|
|
12
13
|
input_config: InputData configuration object
|
|
13
14
|
"""
|
|
14
|
-
|
|
15
15
|
datasets_dict = {}
|
|
16
16
|
|
|
17
17
|
# Load GSP data unless the path is None
|
|
@@ -23,10 +23,8 @@ def get_dataset_dict(input_config: InputData) -> dict[str, dict[xr.DataArray] |
|
|
|
23
23
|
|
|
24
24
|
# Load NWP data if in config
|
|
25
25
|
if input_config.nwp:
|
|
26
|
-
|
|
27
26
|
datasets_dict["nwp"] = {}
|
|
28
27
|
for nwp_source, nwp_config in input_config.nwp.items():
|
|
29
|
-
|
|
30
28
|
da_nwp = open_nwp(nwp_config.zarr_path, provider=nwp_config.provider)
|
|
31
29
|
|
|
32
30
|
da_nwp = da_nwp.sel(channel=list(nwp_config.channels))
|
|
@@ -48,6 +46,7 @@ def get_dataset_dict(input_config: InputData) -> dict[str, dict[xr.DataArray] |
|
|
|
48
46
|
generation_file_path=input_config.site.file_path,
|
|
49
47
|
metadata_file_path=input_config.site.metadata_file_path,
|
|
50
48
|
)
|
|
49
|
+
|
|
51
50
|
datasets_dict["site"] = da_sites
|
|
52
51
|
|
|
53
52
|
return datasets_dict
|
ocf_data_sampler/load/nwp/nwp.py
CHANGED
|
@@ -1,22 +1,34 @@
|
|
|
1
|
+
"""Module for opening NWP data."""
|
|
2
|
+
|
|
1
3
|
import xarray as xr
|
|
2
4
|
|
|
3
|
-
from ocf_data_sampler.load.nwp.providers.ukv import open_ukv
|
|
4
5
|
from ocf_data_sampler.load.nwp.providers.ecmwf import open_ifs
|
|
6
|
+
from ocf_data_sampler.load.nwp.providers.gfs import open_gfs
|
|
7
|
+
from ocf_data_sampler.load.nwp.providers.icon import open_icon_eu
|
|
8
|
+
from ocf_data_sampler.load.nwp.providers.ukv import open_ukv
|
|
5
9
|
|
|
6
10
|
|
|
7
11
|
def open_nwp(zarr_path: str | list[str], provider: str) -> xr.DataArray:
|
|
8
|
-
"""Opens NWP zarr
|
|
12
|
+
"""Opens NWP zarr.
|
|
9
13
|
|
|
10
14
|
Args:
|
|
11
15
|
zarr_path: path to the zarr file
|
|
12
16
|
provider: NWP provider
|
|
17
|
+
|
|
18
|
+
Returns:
|
|
19
|
+
Xarray DataArray of the NWP data
|
|
13
20
|
"""
|
|
21
|
+
provider = provider.lower()
|
|
14
22
|
|
|
15
|
-
if provider
|
|
23
|
+
if provider == "ukv":
|
|
16
24
|
_open_nwp = open_ukv
|
|
17
|
-
elif provider
|
|
25
|
+
elif provider == "ecmwf":
|
|
18
26
|
_open_nwp = open_ifs
|
|
27
|
+
elif provider == "icon-eu":
|
|
28
|
+
_open_nwp = open_icon_eu
|
|
29
|
+
elif provider == "gfs":
|
|
30
|
+
_open_nwp = open_gfs
|
|
19
31
|
else:
|
|
20
32
|
raise ValueError(f"Unknown provider: {provider}")
|
|
21
|
-
return _open_nwp(zarr_path)
|
|
22
33
|
|
|
34
|
+
return _open_nwp(zarr_path)
|