ocf-data-sampler 0.1.11__py3-none-any.whl → 0.1.17__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ocf-data-sampler might be problematic.
- ocf_data_sampler/config/load.py +3 -3
- ocf_data_sampler/config/model.py +146 -64
- ocf_data_sampler/config/save.py +5 -4
- ocf_data_sampler/load/gsp.py +6 -5
- ocf_data_sampler/load/load_dataset.py +5 -6
- ocf_data_sampler/load/nwp/nwp.py +17 -5
- ocf_data_sampler/load/nwp/providers/ecmwf.py +6 -7
- ocf_data_sampler/load/nwp/providers/gfs.py +36 -0
- ocf_data_sampler/load/nwp/providers/icon.py +46 -0
- ocf_data_sampler/load/nwp/providers/ukv.py +4 -5
- ocf_data_sampler/load/nwp/providers/utils.py +3 -1
- ocf_data_sampler/load/satellite.py +9 -10
- ocf_data_sampler/load/site.py +10 -6
- ocf_data_sampler/load/utils.py +21 -16
- ocf_data_sampler/numpy_sample/collate.py +10 -9
- ocf_data_sampler/numpy_sample/datetime_features.py +3 -5
- ocf_data_sampler/numpy_sample/gsp.py +12 -14
- ocf_data_sampler/numpy_sample/nwp.py +12 -12
- ocf_data_sampler/numpy_sample/satellite.py +9 -9
- ocf_data_sampler/numpy_sample/site.py +5 -8
- ocf_data_sampler/numpy_sample/sun_position.py +16 -21
- ocf_data_sampler/sample/base.py +15 -17
- ocf_data_sampler/sample/site.py +13 -20
- ocf_data_sampler/sample/uk_regional.py +29 -35
- ocf_data_sampler/select/dropout.py +16 -14
- ocf_data_sampler/select/fill_time_periods.py +15 -5
- ocf_data_sampler/select/find_contiguous_time_periods.py +88 -75
- ocf_data_sampler/select/geospatial.py +63 -54
- ocf_data_sampler/select/location.py +16 -51
- ocf_data_sampler/select/select_spatial_slice.py +105 -89
- ocf_data_sampler/select/select_time_slice.py +71 -58
- ocf_data_sampler/select/spatial_slice_for_dataset.py +7 -6
- ocf_data_sampler/select/time_slice_for_dataset.py +17 -16
- ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py +140 -131
- ocf_data_sampler/torch_datasets/datasets/site.py +152 -112
- ocf_data_sampler/torch_datasets/utils/__init__.py +3 -0
- ocf_data_sampler/torch_datasets/utils/channel_dict_to_dataarray.py +11 -0
- ocf_data_sampler/torch_datasets/utils/merge_and_fill_utils.py +6 -2
- ocf_data_sampler/torch_datasets/utils/valid_time_periods.py +23 -22
- ocf_data_sampler/utils.py +3 -1
- {ocf_data_sampler-0.1.11.dist-info → ocf_data_sampler-0.1.17.dist-info}/METADATA +7 -18
- ocf_data_sampler-0.1.17.dist-info/RECORD +56 -0
- {ocf_data_sampler-0.1.11.dist-info → ocf_data_sampler-0.1.17.dist-info}/WHEEL +1 -1
- {ocf_data_sampler-0.1.11.dist-info → ocf_data_sampler-0.1.17.dist-info}/top_level.txt +1 -1
- scripts/refactor_site.py +63 -33
- utils/compute_icon_mean_stddev.py +72 -0
- ocf_data_sampler/constants.py +0 -222
- ocf_data_sampler/torch_datasets/utils/validate_channels.py +0 -82
- ocf_data_sampler-0.1.11.dist-info/LICENSE +0 -21
- ocf_data_sampler-0.1.11.dist-info/RECORD +0 -82
- tests/__init__.py +0 -0
- tests/config/test_config.py +0 -113
- tests/config/test_load.py +0 -7
- tests/config/test_save.py +0 -28
- tests/conftest.py +0 -319
- tests/load/test_load_gsp.py +0 -15
- tests/load/test_load_nwp.py +0 -21
- tests/load/test_load_satellite.py +0 -17
- tests/load/test_load_sites.py +0 -14
- tests/numpy_sample/test_collate.py +0 -21
- tests/numpy_sample/test_datetime_features.py +0 -37
- tests/numpy_sample/test_gsp.py +0 -38
- tests/numpy_sample/test_nwp.py +0 -13
- tests/numpy_sample/test_satellite.py +0 -40
- tests/numpy_sample/test_sun_position.py +0 -81
- tests/select/test_dropout.py +0 -69
- tests/select/test_fill_time_periods.py +0 -28
- tests/select/test_find_contiguous_time_periods.py +0 -202
- tests/select/test_location.py +0 -67
- tests/select/test_select_spatial_slice.py +0 -154
- tests/select/test_select_time_slice.py +0 -275
- tests/test_sample/test_base.py +0 -164
- tests/test_sample/test_site_sample.py +0 -165
- tests/test_sample/test_uk_regional_sample.py +0 -136
- tests/torch_datasets/test_merge_and_fill_utils.py +0 -40
- tests/torch_datasets/test_pvnet_uk.py +0 -154
- tests/torch_datasets/test_site.py +0 -226
- tests/torch_datasets/test_validate_channels_utils.py +0 -78
ocf_data_sampler/config/load.py
CHANGED
```diff
@@ -1,13 +1,13 @@
-"""Load configuration from a yaml file"""
+"""Load configuration from a yaml file."""
 
 import fsspec
 from pyaml_env import parse_config
+
 from ocf_data_sampler.config import Configuration
 
 
 def load_yaml_configuration(filename: str) -> Configuration:
-    """
-    Load a yaml file which has a configuration in it
+    """Load a yaml file which has a configuration in it.
 
     Args:
         filename: the yaml file name that you want to load. Will load from local, AWS, or GCP
```
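For orientation, a minimal usage sketch of the loader above. The path is a placeholder, and the sketch assumes the package is installed:

```python
from ocf_data_sampler.config.load import load_yaml_configuration

# "config.yaml" is a placeholder path; a local, s3:// or gs:// path should
# all work, since the file is opened via fsspec.
config = load_yaml_configuration("config.yaml")
print(config.general.name)
```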
ocf_data_sampler/config/model.py
CHANGED
```diff
@@ -1,45 +1,50 @@
 """Configuration model for the dataset.
 
-
-
-
+Absolute or relative zarr filepath(s).
+Prefix with a protocol like s3:// to read from alternative filesystems.
 """
 
-from 
-from typing_extensions import Self
+from collections.abc import Iterator
 
-from pydantic import BaseModel, Field, RootModel, field_validator,
+from pydantic import BaseModel, Field, RootModel, field_validator, model_validator
+from typing_extensions import override
 
-
+NWP_PROVIDERS = [
+    "ukv",
+    "ecmwf",
+    "gfs",
+    "icon_eu",
+]
 
 
 class Base(BaseModel):
-    """Pydantic Base model where no extras can be added"""
+    """Pydantic Base model where no extras can be added."""
 
     class Config:
-        """
+        """Config class."""
 
         extra = "forbid"  # forbid use of extra kwargs
 
 
 class General(Base):
-    """General pydantic model"""
+    """General pydantic model."""
 
     name: str = Field("example", description="The name of this configuration file")
     description: str = Field(
-        "example configuration",
+        "example configuration",
+        description="Description of this configuration file",
     )
 
 
 class TimeWindowMixin(Base):
-    """Mixin class, to add interval start, end and resolution minutes"""
+    """Mixin class, to add interval start, end and resolution minutes."""
 
     time_resolution_minutes: int = Field(
         ...,
         gt=0,
         description="The temporal resolution of the data in minutes",
     )
-
+
     interval_start_minutes: int = Field(
         ...,
         description="Data interval starts at `t0 + interval_start_minutes`",
@@ -50,32 +55,33 @@ class TimeWindowMixin(Base):
         description="Data interval ends at `t0 + interval_end_minutes`",
     )
 
-    @model_validator(mode=
-    def validate_intervals(
-
-
-
+    @model_validator(mode="after")
+    def validate_intervals(self) -> "TimeWindowMixin":
+        """Validator for time interval fields."""
+        start = self.interval_start_minutes
+        end = self.interval_end_minutes
+        resolution = self.time_resolution_minutes
         if start > end:
             raise ValueError(
-                f"interval_start_minutes ({start}) must be <= interval_end_minutes ({end})"
+                f"interval_start_minutes ({start}) must be <= interval_end_minutes ({end})",
             )
-        if 
+        if start % resolution != 0:
             raise ValueError(
                 f"interval_start_minutes ({start}) must be divisible "
-                f"by time_resolution_minutes ({resolution})"
+                f"by time_resolution_minutes ({resolution})",
             )
-        if 
+        if end % resolution != 0:
             raise ValueError(
                 f"interval_end_minutes ({end}) must be divisible "
-                f"by time_resolution_minutes ({resolution})"
+                f"by time_resolution_minutes ({resolution})",
             )
-        return 
+        return self
 
 
 class DropoutMixin(Base):
-    """Mixin class, to add dropout minutes"""
+    """Mixin class, to add dropout minutes."""
 
-    dropout_timedeltas_minutes: 
+    dropout_timedeltas_minutes: list[int] = Field(
         default=[],
         description="List of possible minutes before t0 where data availability may start. Must be "
         "negative or zero.",
@@ -89,14 +95,16 @@ class DropoutMixin(Base):
     )
 
     @field_validator("dropout_timedeltas_minutes")
-    def dropout_timedeltas_minutes_negative(cls, v: 
-        """Validate 'dropout_timedeltas_minutes'"""
+    def dropout_timedeltas_minutes_negative(cls, v: list[int]) -> list[int]:
+        """Validate 'dropout_timedeltas_minutes'."""
         for m in v:
-
+            if m > 0:
+                raise ValueError("Dropout timedeltas must be negative")
         return v
 
     @model_validator(mode="after")
-    def dropout_instructions_consistent(self) -> 
+    def dropout_instructions_consistent(self) -> "DropoutMixin":
+        """Validator for dropout instructions."""
         if self.dropout_fraction == 0:
             if self.dropout_timedeltas_minutes != []:
                 raise ValueError("To use dropout timedeltas dropout fraction should be > 0")
@@ -107,7 +115,7 @@ class DropoutMixin(Base):
 
 
 class SpatialWindowMixin(Base):
-    """Mixin class, to add path and image size"""
+    """Mixin class, to add path and image size."""
 
     image_size_pixels_height: int = Field(
         ...,
@@ -122,9 +130,37 @@ class SpatialWindowMixin(Base):
     )
 
 
-class 
-    """
-
+class NormalisationValues(Base):
+    """Normalisation mean and standard deviation."""
+    mean: float = Field(..., description="Mean value for normalization")
+    std: float = Field(..., gt=0, description="Standard deviation (must be positive)")
+
+
+class NormalisationConstantsMixin(Base):
+    """Normalisation constants for multiple channels."""
+    normalisation_constants: dict[str, NormalisationValues]
+
+    @property
+    def channel_means(self) -> dict[str, float]:
+        """Return the channel means."""
+        return {
+            channel: norm_values.mean
+            for channel, norm_values in self.normalisation_constants.items()
+        }
+
+    @property
+    def channel_stds(self) -> dict[str, float]:
+        """Return the channel standard deviations."""
+        return {
+            channel: norm_values.std
+            for channel, norm_values in self.normalisation_constants.items()
+        }
+
+
+class Satellite(TimeWindowMixin, DropoutMixin, SpatialWindowMixin, NormalisationConstantsMixin):
+    """Satellite configuration model."""
+
     zarr_path: str | tuple[str] | list[str] = Field(
         ...,
         description="Absolute or relative zarr filepath(s). Prefix with a protocol like s3:// "
@@ -132,82 +168,123 @@ class Satellite(TimeWindowMixin, DropoutMixin, SpatialWindowMixin):
     )
 
     channels: list[str] = Field(
-        ...,
+        ...,
+        description="the satellite channels that are used",
     )
 
+    @model_validator(mode="after")
+    def check_all_channel_have_normalisation_constants(self) -> "Satellite":
+        """Check that all the channels have normalisation constants."""
+        normalisation_channels = set(self.normalisation_constants.keys())
+        missing_norm_values = set(self.channels) - set(normalisation_channels)
+        if len(missing_norm_values)>0:
+            raise ValueError(
+                "Normalsation constants must be provided for all channels. Missing values for "
+                f"channels: {missing_norm_values}",
+            )
+        return self
+
+
+class NWP(TimeWindowMixin, DropoutMixin, SpatialWindowMixin, NormalisationConstantsMixin):
+    """NWP configuration model."""
 
-class NWP(TimeWindowMixin, DropoutMixin, SpatialWindowMixin):
-    """NWP configuration model"""
-
     zarr_path: str | tuple[str] | list[str] = Field(
         ...,
         description="Absolute or relative zarr filepath(s). Prefix with a protocol like s3:// "
         "to read from alternative filesystems.",
     )
-
+
     channels: list[str] = Field(
-        ...,
+        ...,
+        description="the channels used in the nwp data",
    )
 
     provider: str = Field(..., description="The provider of the NWP data")
 
     accum_channels: list[str] = Field([], description="the nwp channels which need to be diffed")
 
-    max_staleness_minutes: 
+    max_staleness_minutes: int | None = Field(
         None,
         description="Sets a limit on how stale an NWP init time is allowed to be whilst still being"
         " used to construct an example. If set to None, then the max staleness is set according to"
         " the maximum forecast horizon of the NWP and the requested forecast length.",
     )
 
-
     @field_validator("provider")
     def validate_provider(cls, v: str) -> str:
-        """
+        """Validator for 'provider'."""
         if v.lower() not in NWP_PROVIDERS:
-
-            raise Exception(message)
+            raise OSError(f"NWP provider {v} is not in {NWP_PROVIDERS}")
         return v
 
+    @model_validator(mode="after")
+    def check_all_channel_have_normalisation_constants(self) -> "NWP":
+        """Check that all the channels have normalisation constants."""
+        normalisation_channels = set(self.normalisation_constants.keys())
+        non_accum_channels = [c for c in self.channels if c not in self.accum_channels]
+        accum_channel_names = [f"diff_{c}" for c in self.accum_channels]
+
+        missing_norm_values = set(non_accum_channels) - set(normalisation_channels)
+        if len(missing_norm_values)>0:
+            raise ValueError(
+                "Normalsation constants must be provided for all channels. Missing values for "
+                f"channels: {missing_norm_values}",
+            )
+
+        missing_norm_values = set(accum_channel_names) - set(normalisation_channels)
+        if len(missing_norm_values)>0:
+            raise ValueError(
+                "Normalsation constants must be provided for all channels. Accumulated "
+                "channels which will be diffed require normalisation constant names which "
+                "start with the prefix 'diff_'. The following channels were missing: "
+                f"{missing_norm_values}.",
+            )
+        return self
+
 
 class MultiNWP(RootModel):
-    """Configuration for multiple NWPs"""
+    """Configuration for multiple NWPs."""
 
-    root: 
+    root: dict[str, NWP]
 
-
+    @override
+    def __getattr__(self, item: str) -> NWP:
         return self.root[item]
 
-
+    @override
+    def __getitem__(self, item: str) -> NWP:
         return self.root[item]
 
-
+    @override
+    def __len__(self) -> int:
         return len(self.root)
 
-
+    @override
+    def __iter__(self) -> Iterator:
         return iter(self.root)
 
-    def keys(self):
-        """Returns dictionary-like keys"""
+    def keys(self) -> Iterator[str]:
+        """Returns dictionary-like keys."""
         return self.root.keys()
 
-    def items(self):
-        """Returns dictionary-like items"""
+    def items(self) -> Iterator[tuple[str, NWP]]:
+        """Returns dictionary-like items."""
         return self.root.items()
 
 
 class GSP(TimeWindowMixin, DropoutMixin):
-    """GSP configuration model"""
+    """GSP configuration model."""
 
     zarr_path: str = Field(
-        ...,
+        ...,
         description="Absolute or relative zarr filepath. Prefix with a protocol like s3:// "
         "to read from alternative filesystems.",
     )
 
 
 class Site(TimeWindowMixin, DropoutMixin):
-    """Site configuration model"""
+    """Site configuration model."""
 
     file_path: str = Field(
         ...,
@@ -222,17 +299,22 @@ class Site(TimeWindowMixin, DropoutMixin):
     # TODO validate the csv for metadata
 
 
+class SolarPosition(TimeWindowMixin):
+    """Solar position configuration model."""
+
+
 class InputData(Base):
-    """Input data model"""
+    """Input data model."""
 
-    satellite: 
-    nwp: 
-    gsp: 
-    site: 
+    satellite: Satellite | None = None
+    nwp: MultiNWP | None = None
+    gsp: GSP | None = None
+    site: Site | None = None
+    solar_position: SolarPosition | None = None
 
 
 class Configuration(Base):
-    """Configuration model for the dataset"""
+    """Configuration model for the dataset."""
 
     general: General = General()
     input_data: InputData = InputData()
```
ocf_data_sampler/config/save.py
CHANGED
```diff
@@ -5,12 +5,14 @@ supporting local and cloud storage locations.
 """
 
 import json
+import os
+
 import fsspec
 import yaml
-import os
 
 from ocf_data_sampler.config import Configuration
 
+
 def save_yaml_configuration(configuration: Configuration, filename: str) -> None:
     """Save a configuration object to a YAML file.
 
@@ -20,12 +22,11 @@ def save_yaml_configuration(configuration: Configuration, filename: str) -> None
         cloud storage URL (e.g., 'gs://', 's3://'). For local paths,
         absolute paths are recommended.
     """
-
     if os.path.exists(filename):
         raise FileExistsError(f"File already exists: {filename}")
 
     # Serialize configuration to JSON-compatible dictionary
     config_dict = json.loads(configuration.model_dump_json())
 
-    with fsspec.open(filename, mode=
-        yaml.safe_dump(config_dict, yaml_file, default_flow_style=False)
+    with fsspec.open(filename, mode="w") as yaml_file:
+        yaml.safe_dump(config_dict, yaml_file, default_flow_style=False)
```
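A small round-trip sketch combining the saver with the loader from `config/load.py`; the temporary path is illustrative only:

```python
import os
import tempfile

from ocf_data_sampler.config import Configuration
from ocf_data_sampler.config.load import load_yaml_configuration
from ocf_data_sampler.config.save import save_yaml_configuration

path = os.path.join(tempfile.mkdtemp(), "config.yaml")
save_yaml_configuration(Configuration(), path)
# Calling save_yaml_configuration again with the same path raises FileExistsError.
restored = load_yaml_configuration(path)
print(restored.general.name)  # "example"
```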
ocf_data_sampler/load/gsp.py
CHANGED
```diff
@@ -1,26 +1,27 @@
-
+"""Functions for loading GSP data."""
+
+from importlib.resources import files
 
 import pandas as pd
 import xarray as xr
 
 
 def open_gsp(zarr_path: str) -> xr.DataArray:
-    """Open the GSP data
-
+    """Open the GSP data.
+
     Args:
         zarr_path: Path to the GSP zarr data
 
     Returns:
         xr.DataArray: The opened GSP data
     """
-
     ds = xr.open_zarr(zarr_path)
 
     ds = ds.rename({"datetime_gmt": "time_utc"})
 
     # Load UK GSP locations
     df_gsp_loc = pd.read_csv(
-
+        files("ocf_data_sampler.data").joinpath("uk_gsp_locations.csv"),
         index_col="gsp_id",
     )
 
```
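The notable change here is that the GSP locations CSV is now resolved from inside the installed package via `importlib.resources` instead of a path relative to the working directory. Assuming a regular (non-zipped) install with the bundled `ocf_data_sampler/data/uk_gsp_locations.csv`, the lookup works like this:

```python
from importlib.resources import files

import pandas as pd

# Resolves to the CSV bundled inside the installed package, regardless of
# the current working directory.
csv_path = files("ocf_data_sampler.data").joinpath("uk_gsp_locations.csv")
df_gsp_loc = pd.read_csv(csv_path, index_col="gsp_id")
print(df_gsp_loc.head())
```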
ocf_data_sampler/load/load_dataset.py
CHANGED
```diff
@@ -1,17 +1,17 @@
-"""
+"""Loads all data sources."""
+
 import xarray as xr
 
 from ocf_data_sampler.config import InputData
-from ocf_data_sampler.load import 
+from ocf_data_sampler.load import open_gsp, open_nwp, open_sat_data, open_site
 
 
 def get_dataset_dict(input_config: InputData) -> dict[str, dict[xr.DataArray] | xr.DataArray]:
-    """Construct dictionary of all of the input data sources
+    """Construct dictionary of all of the input data sources.
 
     Args:
         input_config: InputData configuration object
     """
-
     datasets_dict = {}
 
     # Load GSP data unless the path is None
@@ -23,10 +23,8 @@ def get_dataset_dict(input_config: InputData) -> dict[str, dict[xr.DataArray] |
 
     # Load NWP data if in config
     if input_config.nwp:
-
         datasets_dict["nwp"] = {}
         for nwp_source, nwp_config in input_config.nwp.items():
-
             da_nwp = open_nwp(nwp_config.zarr_path, provider=nwp_config.provider)
 
             da_nwp = da_nwp.sel(channel=list(nwp_config.channels))
@@ -48,6 +46,7 @@ def get_dataset_dict(input_config: InputData) -> dict[str, dict[xr.DataArray] |
             generation_file_path=input_config.site.file_path,
             metadata_file_path=input_config.site.metadata_file_path,
         )
+
         datasets_dict["site"] = da_sites
 
     return datasets_dict
```
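A sketch of how this composes with the config loader; the YAML path is a placeholder, and which keys appear depends on which sections the config defines:

```python
from ocf_data_sampler.config.load import load_yaml_configuration
from ocf_data_sampler.load.load_dataset import get_dataset_dict

config = load_yaml_configuration("config.yaml")  # placeholder path
datasets = get_dataset_dict(config.input_data)
print(datasets.keys())  # e.g. dict_keys(['gsp', 'nwp', 'site'])
```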
ocf_data_sampler/load/nwp/nwp.py
CHANGED
```diff
@@ -1,22 +1,34 @@
+"""Module for opening NWP data."""
+
 import xarray as xr
 
-from ocf_data_sampler.load.nwp.providers.ukv import open_ukv
 from ocf_data_sampler.load.nwp.providers.ecmwf import open_ifs
+from ocf_data_sampler.load.nwp.providers.gfs import open_gfs
+from ocf_data_sampler.load.nwp.providers.icon import open_icon_eu
+from ocf_data_sampler.load.nwp.providers.ukv import open_ukv
 
 
 def open_nwp(zarr_path: str | list[str], provider: str) -> xr.DataArray:
-    """Opens NWP zarr
+    """Opens NWP zarr.
 
     Args:
         zarr_path: path to the zarr file
         provider: NWP provider
+
+    Returns:
+        Xarray DataArray of the NWP data
     """
+    provider = provider.lower()
 
-    if provider
+    if provider == "ukv":
         _open_nwp = open_ukv
-    elif provider
+    elif provider == "ecmwf":
         _open_nwp = open_ifs
+    elif provider == "icon-eu":
+        _open_nwp = open_icon_eu
+    elif provider == "gfs":
+        _open_nwp = open_gfs
     else:
         raise ValueError(f"Unknown provider: {provider}")
-    return _open_nwp(zarr_path)
 
+    return _open_nwp(zarr_path)
```
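Because the provider string is now lower-cased before dispatch, matching is case-insensitive. A usage sketch with placeholder paths:

```python
from ocf_data_sampler.load.nwp.nwp import open_nwp

# Placeholder zarr paths; the provider string selects the loader.
da_ukv = open_nwp("path/to/ukv.zarr", provider="UKV")         # -> open_ukv
da_gfs = open_nwp("path/to/gfs.zarr", provider="gfs")         # -> open_gfs
da_icon = open_nwp("path/to/icon.zarr", provider="ICON-EU")   # -> open_icon_eu
# Any other string raises ValueError(f"Unknown provider: {provider}").
```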
ocf_data_sampler/load/nwp/providers/ecmwf.py
CHANGED
```diff
@@ -1,17 +1,17 @@
-"""ECMWF provider loaders"""
+"""ECMWF provider loaders."""
 
 import xarray as xr
+
 from ocf_data_sampler.load.nwp.providers.utils import open_zarr_paths
 from ocf_data_sampler.load.utils import (
     check_time_unique_increasing,
+    get_xr_data_array_from_xr_dataset,
     make_spatial_coords_increasing,
-    get_xr_data_array_from_xr_dataset
 )
 
 
 def open_ifs(zarr_path: str | list[str]) -> xr.DataArray:
-    """
-    Opens the ECMWF IFS NWP data
+    """Opens the ECMWF IFS NWP data.
 
     Args:
         zarr_path: Path to the zarr to open
@@ -19,9 +19,8 @@ def open_ifs(zarr_path: str | list[str]) -> xr.DataArray:
     Returns:
         Xarray DataArray of the NWP data
     """
-
     ds = open_zarr_paths(zarr_path)
-
+
     # LEGACY SUPPORT - rename variable to channel if it exists
     ds = ds.rename({"init_time": "init_time_utc", "variable": "channel"})
 
@@ -30,6 +29,6 @@ def open_ifs(zarr_path: str | list[str]) -> xr.DataArray:
     ds = make_spatial_coords_increasing(ds, x_coord="longitude", y_coord="latitude")
 
     ds = ds.transpose("init_time_utc", "step", "channel", "longitude", "latitude")
-
+
     # TODO: should we control the dtype of the DataArray?
     return get_xr_data_array_from_xr_dataset(ds)
```
ocf_data_sampler/load/nwp/providers/gfs.py
ADDED
```diff
@@ -0,0 +1,36 @@
+"""Open GFS Forecast data."""
+
+import logging
+
+import xarray as xr
+
+from ocf_data_sampler.load.nwp.providers.utils import open_zarr_paths
+from ocf_data_sampler.load.utils import check_time_unique_increasing, make_spatial_coords_increasing
+
+_log = logging.getLogger(__name__)
+
+
+def open_gfs(zarr_path: str | list[str]) -> xr.DataArray:
+    """Opens the GFS data.
+
+    Args:
+        zarr_path: Path to the zarr to open
+
+    Returns:
+        Xarray DataArray of the NWP data
+    """
+    _log.info("Loading NWP GFS data")
+
+    # Open data
+    gfs: xr.Dataset = open_zarr_paths(zarr_path, time_dim="init_time_utc")
+    nwp: xr.DataArray = gfs.to_array()
+
+    del gfs
+
+    nwp = nwp.rename({"variable": "channel","init_time": "init_time_utc"})
+    check_time_unique_increasing(nwp.init_time_utc)
+    nwp = make_spatial_coords_increasing(nwp, x_coord="longitude", y_coord="latitude")
+
+    nwp = nwp.transpose("init_time_utc", "step", "channel", "latitude", "longitude")
+
+    return nwp
```
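`open_gfs` stacks the dataset's variables into a single DataArray along a new channel dimension. A self-contained toy demo of that `to_array` pattern on synthetic data:

```python
import numpy as np
import xarray as xr

# Two synthetic "variables" become two entries along the new dimension,
# which is then renamed from "variable" to "channel".
ds = xr.Dataset(
    {name: (("init_time_utc", "step"), np.zeros((2, 3))) for name in ["t2m", "dswrf"]},
)
da = ds.to_array().rename({"variable": "channel"})
print(da.dims)            # ('channel', 'init_time_utc', 'step')
print(da.channel.values)  # ['t2m' 'dswrf']
```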
ocf_data_sampler/load/nwp/providers/icon.py
ADDED
```diff
@@ -0,0 +1,46 @@
+"""DWD ICON Loading."""
+
+import xarray as xr
+
+from ocf_data_sampler.load.nwp.providers.utils import open_zarr_paths
+from ocf_data_sampler.load.utils import check_time_unique_increasing, make_spatial_coords_increasing
+
+
+def remove_isobaric_lelvels_from_coords(nwp: xr.Dataset) -> xr.Dataset:
+    """Removes the isobaric levels from the coordinates of the NWP data.
+
+    Args:
+        nwp: NWP data
+
+    Returns:
+        NWP data without isobaric levels in the coordinates
+    """
+    variables_to_drop = [var for var in nwp.data_vars if "isobaricInhPa" in nwp[var].dims]
+    return nwp.drop_vars(["isobaricInhPa", *variables_to_drop])
+
+
+def open_icon_eu(zarr_path: str) -> xr.Dataset:
+    """Opens the ICON data.
+
+    ICON EU Data is on a regular lat/lon grid
+    It has data on multiple pressure levels, as well as the surface
+    Each of the variables is its own data variable
+
+    Args:
+        zarr_path: Path to the zarr to open
+
+    Returns:
+        Xarray DataArray of the NWP data
+    """
+    # Open the data
+    nwp = open_zarr_paths(zarr_path, time_dim="time")
+    nwp = nwp.rename({"time": "init_time_utc"})
+    # Sanity checks.
+    check_time_unique_increasing(nwp.init_time_utc)
+    # 0-78 one hour steps, rest 3 hour steps
+    nwp = nwp.isel(step=slice(0, 78))
+    nwp = remove_isobaric_lelvels_from_coords(nwp)
+    nwp = nwp.to_array().rename({"variable": "channel"})
+    nwp = nwp.transpose("init_time_utc", "step", "channel", "latitude", "longitude")
+    nwp = make_spatial_coords_increasing(nwp, x_coord="longitude", y_coord="latitude")
+    return nwp
```
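A toy demonstration of the level-filtering helper added above: any data variable carrying the `isobaricInhPa` dimension is dropped together with the coordinate itself, leaving only surface variables:

```python
import numpy as np
import xarray as xr

nwp = xr.Dataset(
    {
        "t_2m": ("step", np.zeros(3)),                           # surface variable
        "t_iso": (("step", "isobaricInhPa"), np.zeros((3, 4))),  # pressure-level variable
    },
    coords={"isobaricInhPa": [1000, 850, 700, 500]},
)
variables_to_drop = [var for var in nwp.data_vars if "isobaricInhPa" in nwp[var].dims]
surface_only = nwp.drop_vars(["isobaricInhPa", *variables_to_drop])
print(variables_to_drop)             # ['t_iso']
print(list(surface_only.data_vars))  # ['t_2m']
```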
ocf_data_sampler/load/nwp/providers/ukv.py
CHANGED
```diff
@@ -1,18 +1,17 @@
-"""UKV provider loaders"""
+"""UKV provider loaders."""
 
 import xarray as xr
 
 from ocf_data_sampler.load.nwp.providers.utils import open_zarr_paths
 from ocf_data_sampler.load.utils import (
     check_time_unique_increasing,
+    get_xr_data_array_from_xr_dataset,
     make_spatial_coords_increasing,
-    get_xr_data_array_from_xr_dataset
 )
 
 
 def open_ukv(zarr_path: str | list[str]) -> xr.DataArray:
-    """
-    Opens the NWP data
+    """Opens the NWP data.
 
     Args:
         zarr_path: Path to the zarr to open
@@ -28,7 +27,7 @@ def open_ukv(zarr_path: str | list[str]) -> xr.DataArray:
             "variable": "channel",
             "x": "x_osgb",
             "y": "y_osgb",
-        }
+        },
     )
 
     check_time_unique_increasing(ds.init_time_utc)
```