ocf-data-sampler 0.1.11__py3-none-any.whl → 0.1.17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of ocf-data-sampler might be problematic.

Files changed (78)
  1. ocf_data_sampler/config/load.py +3 -3
  2. ocf_data_sampler/config/model.py +146 -64
  3. ocf_data_sampler/config/save.py +5 -4
  4. ocf_data_sampler/load/gsp.py +6 -5
  5. ocf_data_sampler/load/load_dataset.py +5 -6
  6. ocf_data_sampler/load/nwp/nwp.py +17 -5
  7. ocf_data_sampler/load/nwp/providers/ecmwf.py +6 -7
  8. ocf_data_sampler/load/nwp/providers/gfs.py +36 -0
  9. ocf_data_sampler/load/nwp/providers/icon.py +46 -0
  10. ocf_data_sampler/load/nwp/providers/ukv.py +4 -5
  11. ocf_data_sampler/load/nwp/providers/utils.py +3 -1
  12. ocf_data_sampler/load/satellite.py +9 -10
  13. ocf_data_sampler/load/site.py +10 -6
  14. ocf_data_sampler/load/utils.py +21 -16
  15. ocf_data_sampler/numpy_sample/collate.py +10 -9
  16. ocf_data_sampler/numpy_sample/datetime_features.py +3 -5
  17. ocf_data_sampler/numpy_sample/gsp.py +12 -14
  18. ocf_data_sampler/numpy_sample/nwp.py +12 -12
  19. ocf_data_sampler/numpy_sample/satellite.py +9 -9
  20. ocf_data_sampler/numpy_sample/site.py +5 -8
  21. ocf_data_sampler/numpy_sample/sun_position.py +16 -21
  22. ocf_data_sampler/sample/base.py +15 -17
  23. ocf_data_sampler/sample/site.py +13 -20
  24. ocf_data_sampler/sample/uk_regional.py +29 -35
  25. ocf_data_sampler/select/dropout.py +16 -14
  26. ocf_data_sampler/select/fill_time_periods.py +15 -5
  27. ocf_data_sampler/select/find_contiguous_time_periods.py +88 -75
  28. ocf_data_sampler/select/geospatial.py +63 -54
  29. ocf_data_sampler/select/location.py +16 -51
  30. ocf_data_sampler/select/select_spatial_slice.py +105 -89
  31. ocf_data_sampler/select/select_time_slice.py +71 -58
  32. ocf_data_sampler/select/spatial_slice_for_dataset.py +7 -6
  33. ocf_data_sampler/select/time_slice_for_dataset.py +17 -16
  34. ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py +140 -131
  35. ocf_data_sampler/torch_datasets/datasets/site.py +152 -112
  36. ocf_data_sampler/torch_datasets/utils/__init__.py +3 -0
  37. ocf_data_sampler/torch_datasets/utils/channel_dict_to_dataarray.py +11 -0
  38. ocf_data_sampler/torch_datasets/utils/merge_and_fill_utils.py +6 -2
  39. ocf_data_sampler/torch_datasets/utils/valid_time_periods.py +23 -22
  40. ocf_data_sampler/utils.py +3 -1
  41. {ocf_data_sampler-0.1.11.dist-info → ocf_data_sampler-0.1.17.dist-info}/METADATA +7 -18
  42. ocf_data_sampler-0.1.17.dist-info/RECORD +56 -0
  43. {ocf_data_sampler-0.1.11.dist-info → ocf_data_sampler-0.1.17.dist-info}/WHEEL +1 -1
  44. {ocf_data_sampler-0.1.11.dist-info → ocf_data_sampler-0.1.17.dist-info}/top_level.txt +1 -1
  45. scripts/refactor_site.py +63 -33
  46. utils/compute_icon_mean_stddev.py +72 -0
  47. ocf_data_sampler/constants.py +0 -222
  48. ocf_data_sampler/torch_datasets/utils/validate_channels.py +0 -82
  49. ocf_data_sampler-0.1.11.dist-info/LICENSE +0 -21
  50. ocf_data_sampler-0.1.11.dist-info/RECORD +0 -82
  51. tests/__init__.py +0 -0
  52. tests/config/test_config.py +0 -113
  53. tests/config/test_load.py +0 -7
  54. tests/config/test_save.py +0 -28
  55. tests/conftest.py +0 -319
  56. tests/load/test_load_gsp.py +0 -15
  57. tests/load/test_load_nwp.py +0 -21
  58. tests/load/test_load_satellite.py +0 -17
  59. tests/load/test_load_sites.py +0 -14
  60. tests/numpy_sample/test_collate.py +0 -21
  61. tests/numpy_sample/test_datetime_features.py +0 -37
  62. tests/numpy_sample/test_gsp.py +0 -38
  63. tests/numpy_sample/test_nwp.py +0 -13
  64. tests/numpy_sample/test_satellite.py +0 -40
  65. tests/numpy_sample/test_sun_position.py +0 -81
  66. tests/select/test_dropout.py +0 -69
  67. tests/select/test_fill_time_periods.py +0 -28
  68. tests/select/test_find_contiguous_time_periods.py +0 -202
  69. tests/select/test_location.py +0 -67
  70. tests/select/test_select_spatial_slice.py +0 -154
  71. tests/select/test_select_time_slice.py +0 -275
  72. tests/test_sample/test_base.py +0 -164
  73. tests/test_sample/test_site_sample.py +0 -165
  74. tests/test_sample/test_uk_regional_sample.py +0 -136
  75. tests/torch_datasets/test_merge_and_fill_utils.py +0 -40
  76. tests/torch_datasets/test_pvnet_uk.py +0 -154
  77. tests/torch_datasets/test_site.py +0 -226
  78. tests/torch_datasets/test_validate_channels_utils.py +0 -78
--- a/ocf_data_sampler/config/load.py
+++ b/ocf_data_sampler/config/load.py
@@ -1,13 +1,13 @@
-"""Load configuration from a yaml file"""
+"""Load configuration from a yaml file."""
 
 import fsspec
 from pyaml_env import parse_config
+
 from ocf_data_sampler.config import Configuration
 
 
 def load_yaml_configuration(filename: str) -> Configuration:
-    """
-    Load a yaml file which has a configuration in it
+    """Load a yaml file which has a configuration in it.
 
     Args:
         filename: the yaml file name that you want to load. Will load from local, AWS, or GCP
--- a/ocf_data_sampler/config/model.py
+++ b/ocf_data_sampler/config/model.py
@@ -1,45 +1,50 @@
 """Configuration model for the dataset.
 
-
-Absolute or relative zarr filepath(s). Prefix with a protocol like s3:// to read from alternative filesystems.
-
+Absolute or relative zarr filepath(s).
+Prefix with a protocol like s3:// to read from alternative filesystems.
 """
 
-from typing import Dict, List, Optional
-from typing_extensions import Self
+from collections.abc import Iterator
 
-from pydantic import BaseModel, Field, RootModel, field_validator, ValidationInfo, model_validator
+from pydantic import BaseModel, Field, RootModel, field_validator, model_validator
+from typing_extensions import override
 
-from ocf_data_sampler.constants import NWP_PROVIDERS
+NWP_PROVIDERS = [
+    "ukv",
+    "ecmwf",
+    "gfs",
+    "icon_eu",
+]
 
 
 class Base(BaseModel):
-    """Pydantic Base model where no extras can be added"""
+    """Pydantic Base model where no extras can be added."""
 
     class Config:
-        """config class"""
+        """Config class."""
 
         extra = "forbid"  # forbid use of extra kwargs
 
 
 class General(Base):
-    """General pydantic model"""
+    """General pydantic model."""
 
     name: str = Field("example", description="The name of this configuration file")
     description: str = Field(
-        "example configuration", description="Description of this configuration file"
+        "example configuration",
+        description="Description of this configuration file",
     )
 
 
 class TimeWindowMixin(Base):
-    """Mixin class, to add interval start, end and resolution minutes"""
+    """Mixin class, to add interval start, end and resolution minutes."""
 
     time_resolution_minutes: int = Field(
         ...,
         gt=0,
         description="The temporal resolution of the data in minutes",
     )
-
+
     interval_start_minutes: int = Field(
         ...,
         description="Data interval starts at `t0 + interval_start_minutes`",
@@ -50,32 +55,33 @@ class TimeWindowMixin(Base):
         description="Data interval ends at `t0 + interval_end_minutes`",
     )
 
-    @model_validator(mode='after')
-    def validate_intervals(cls, values):
-        start = values.interval_start_minutes
-        end = values.interval_end_minutes
-        resolution = values.time_resolution_minutes
+    @model_validator(mode="after")
+    def validate_intervals(self) -> "TimeWindowMixin":
+        """Validator for time interval fields."""
+        start = self.interval_start_minutes
+        end = self.interval_end_minutes
+        resolution = self.time_resolution_minutes
         if start > end:
             raise ValueError(
-                f"interval_start_minutes ({start}) must be <= interval_end_minutes ({end})"
+                f"interval_start_minutes ({start}) must be <= interval_end_minutes ({end})",
             )
-        if (start % resolution != 0):
+        if start % resolution != 0:
             raise ValueError(
                 f"interval_start_minutes ({start}) must be divisible "
-                f"by time_resolution_minutes ({resolution})"
+                f"by time_resolution_minutes ({resolution})",
             )
-        if (end % resolution != 0):
+        if end % resolution != 0:
             raise ValueError(
                 f"interval_end_minutes ({end}) must be divisible "
-                f"by time_resolution_minutes ({resolution})"
+                f"by time_resolution_minutes ({resolution})",
             )
-        return values
+        return self
 
 
 class DropoutMixin(Base):
-    """Mixin class, to add dropout minutes"""
+    """Mixin class, to add dropout minutes."""
 
-    dropout_timedeltas_minutes: List[int] = Field(
+    dropout_timedeltas_minutes: list[int] = Field(
         default=[],
         description="List of possible minutes before t0 where data availability may start. Must be "
         "negative or zero.",
@@ -89,14 +95,16 @@ class DropoutMixin(Base):
     )
 
     @field_validator("dropout_timedeltas_minutes")
-    def dropout_timedeltas_minutes_negative(cls, v: List[int]) -> List[int]:
-        """Validate 'dropout_timedeltas_minutes'"""
+    def dropout_timedeltas_minutes_negative(cls, v: list[int]) -> list[int]:
+        """Validate 'dropout_timedeltas_minutes'."""
         for m in v:
-            assert m <= 0, "Dropout timedeltas must be negative"
+            if m > 0:
+                raise ValueError("Dropout timedeltas must be negative")
         return v
 
     @model_validator(mode="after")
-    def dropout_instructions_consistent(self) -> Self:
+    def dropout_instructions_consistent(self) -> "DropoutMixin":
+        """Validator for dropout instructions."""
         if self.dropout_fraction == 0:
             if self.dropout_timedeltas_minutes != []:
                 raise ValueError("To use dropout timedeltas dropout fraction should be > 0")
@@ -107,7 +115,7 @@ class DropoutMixin(Base):
 
 
 class SpatialWindowMixin(Base):
-    """Mixin class, to add path and image size"""
+    """Mixin class, to add path and image size."""
 
     image_size_pixels_height: int = Field(
         ...,
@@ -122,9 +130,37 @@ class SpatialWindowMixin(Base):
     )
 
 
-class Satellite(TimeWindowMixin, DropoutMixin, SpatialWindowMixin):
-    """Satellite configuration model"""
-
+class NormalisationValues(Base):
+    """Normalisation mean and standard deviation."""
+    mean: float = Field(..., description="Mean value for normalization")
+    std: float = Field(..., gt=0, description="Standard deviation (must be positive)")
+
+
+class NormalisationConstantsMixin(Base):
+    """Normalisation constants for multiple channels."""
+    normalisation_constants: dict[str, NormalisationValues]
+
+    @property
+    def channel_means(self) -> dict[str, float]:
+        """Return the channel means."""
+        return {
+            channel: norm_values.mean
+            for channel, norm_values in self.normalisation_constants.items()
+        }
+
+
+    @property
+    def channel_stds(self) -> dict[str, float]:
+        """Return the channel standard deviations."""
+        return {
+            channel: norm_values.std
+            for channel, norm_values in self.normalisation_constants.items()
+        }
+
+
+class Satellite(TimeWindowMixin, DropoutMixin, SpatialWindowMixin, NormalisationConstantsMixin):
+    """Satellite configuration model."""
+
     zarr_path: str | tuple[str] | list[str] = Field(
         ...,
         description="Absolute or relative zarr filepath(s). Prefix with a protocol like s3:// "
@@ -132,82 +168,123 @@ class Satellite(TimeWindowMixin, DropoutMixin, SpatialWindowMixin):
     )
 
     channels: list[str] = Field(
-        ..., description="the satellite channels that are used"
+        ...,
+        description="the satellite channels that are used",
     )
 
+    @model_validator(mode="after")
+    def check_all_channel_have_normalisation_constants(self) -> "Satellite":
+        """Check that all the channels have normalisation constants."""
+        normalisation_channels = set(self.normalisation_constants.keys())
+        missing_norm_values = set(self.channels) - set(normalisation_channels)
+        if len(missing_norm_values)>0:
+            raise ValueError(
+                "Normalsation constants must be provided for all channels. Missing values for "
+                f"channels: {missing_norm_values}",
+            )
+        return self
+
+
+class NWP(TimeWindowMixin, DropoutMixin, SpatialWindowMixin, NormalisationConstantsMixin):
+    """NWP configuration model."""
 
-class NWP(TimeWindowMixin, DropoutMixin, SpatialWindowMixin):
-    """NWP configuration model"""
-
     zarr_path: str | tuple[str] | list[str] = Field(
         ...,
         description="Absolute or relative zarr filepath(s). Prefix with a protocol like s3:// "
         "to read from alternative filesystems.",
     )
-
+
     channels: list[str] = Field(
-        ..., description="the channels used in the nwp data"
+        ...,
+        description="the channels used in the nwp data",
     )
 
     provider: str = Field(..., description="The provider of the NWP data")
 
     accum_channels: list[str] = Field([], description="the nwp channels which need to be diffed")
 
-    max_staleness_minutes: Optional[int] = Field(
+    max_staleness_minutes: int | None = Field(
        None,
        description="Sets a limit on how stale an NWP init time is allowed to be whilst still being"
        " used to construct an example. If set to None, then the max staleness is set according to"
        " the maximum forecast horizon of the NWP and the requested forecast length.",
    )
 
-
     @field_validator("provider")
     def validate_provider(cls, v: str) -> str:
-        """Validate 'provider'"""
+        """Validator for 'provider'."""
         if v.lower() not in NWP_PROVIDERS:
-            message = f"NWP provider {v} is not in {NWP_PROVIDERS}"
-            raise Exception(message)
+            raise OSError(f"NWP provider {v} is not in {NWP_PROVIDERS}")
         return v
 
 
+    @model_validator(mode="after")
+    def check_all_channel_have_normalisation_constants(self) -> "NWP":
+        """Check that all the channels have normalisation constants."""
+        normalisation_channels = set(self.normalisation_constants.keys())
+        non_accum_channels = [c for c in self.channels if c not in self.accum_channels]
+        accum_channel_names = [f"diff_{c}" for c in self.accum_channels]
+
+        missing_norm_values = set(non_accum_channels) - set(normalisation_channels)
+        if len(missing_norm_values)>0:
+            raise ValueError(
+                "Normalsation constants must be provided for all channels. Missing values for "
+                f"channels: {missing_norm_values}",
+            )
+
+        missing_norm_values = set(accum_channel_names) - set(normalisation_channels)
+        if len(missing_norm_values)>0:
+            raise ValueError(
+                "Normalsation constants must be provided for all channels. Accumulated "
+                "channels which will be diffed require normalisation constant names which "
+                "start with the prefix 'diff_'. The following channels were missing: "
+                f"{missing_norm_values}.",
+            )
+        return self
+
+
 class MultiNWP(RootModel):
-    """Configuration for multiple NWPs"""
+    """Configuration for multiple NWPs."""
 
-    root: Dict[str, NWP]
+    root: dict[str, NWP]
 
-    def __getattr__(self, item):
+    @override
+    def __getattr__(self, item: str) -> NWP:
         return self.root[item]
 
-    def __getitem__(self, item):
+    @override
+    def __getitem__(self, item: str) -> NWP:
         return self.root[item]
 
-    def __len__(self):
+    @override
+    def __len__(self) -> int:
         return len(self.root)
 
-    def __iter__(self):
+    @override
+    def __iter__(self) -> Iterator:
         return iter(self.root)
 
-    def keys(self):
-        """Returns dictionary-like keys"""
+    def keys(self) -> Iterator[str]:
+        """Returns dictionary-like keys."""
         return self.root.keys()
 
-    def items(self):
-        """Returns dictionary-like items"""
+    def items(self) -> Iterator[tuple[str, NWP]]:
+        """Returns dictionary-like items."""
         return self.root.items()
 
 
 class GSP(TimeWindowMixin, DropoutMixin):
-    """GSP configuration model"""
+    """GSP configuration model."""
 
     zarr_path: str = Field(
-        ...,
+        ...,
         description="Absolute or relative zarr filepath. Prefix with a protocol like s3:// "
         "to read from alternative filesystems.",
     )
 
 
 class Site(TimeWindowMixin, DropoutMixin):
-    """Site configuration model"""
+    """Site configuration model."""
 
     file_path: str = Field(
         ...,
@@ -222,17 +299,22 @@ class Site(TimeWindowMixin, DropoutMixin):
     # TODO validate the csv for metadata
 
 
+class SolarPosition(TimeWindowMixin):
+    """Solar position configuration model."""
+
+
 class InputData(Base):
-    """Input data model"""
+    """Input data model."""
 
-    satellite: Optional[Satellite] = None
-    nwp: Optional[MultiNWP] = None
-    gsp: Optional[GSP] = None
-    site: Optional[Site] = None
+    satellite: Satellite | None = None
+    nwp: MultiNWP | None = None
+    gsp: GSP | None = None
+    site: Site | None = None
+    solar_position: SolarPosition | None = None
 
 
 class Configuration(Base):
-    """Configuration model for the dataset"""
+    """Configuration model for the dataset."""
 
     general: General = General()
     input_data: InputData = InputData()
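
Note: the new NormalisationConstantsMixin replaces the per-channel constants that previously lived in the deleted ocf_data_sampler/constants.py, so normalisation values now travel with the user's configuration, and the NWP validator additionally requires a "diff_"-prefixed entry for every accumulated channel. A minimal sketch of a conforming NWP config follows; the channel names, paths, and values are illustrative only, dropout fields are left at their defaults, and image_size_pixels_width is assumed to mirror the height field shown above:

from ocf_data_sampler.config.model import NWP

nwp_config = NWP(
    zarr_path="s3://bucket/ukv.zarr",  # hypothetical path
    provider="ukv",
    channels=["t", "dswrf"],
    accum_channels=["dswrf"],  # accumulated channel, will be diffed
    time_resolution_minutes=60,
    interval_start_minutes=-120,
    interval_end_minutes=480,
    image_size_pixels_height=24,
    image_size_pixels_width=24,
    normalisation_constants={
        # plain channels are keyed by name...
        "t": {"mean": 283.5, "std": 4.0},
        # ...accumulated channels by name with the "diff_" prefix
        "diff_dswrf": {"mean": 111.0, "std": 145.0},
    },
)
print(nwp_config.channel_means)  # {'t': 283.5, 'diff_dswrf': 111.0}

Omitting the "diff_dswrf" entry would trigger the second ValueError in the validator above.
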
--- a/ocf_data_sampler/config/save.py
+++ b/ocf_data_sampler/config/save.py
@@ -5,12 +5,14 @@ supporting local and cloud storage locations.
 """
 
 import json
+import os
+
 import fsspec
 import yaml
-import os
 
 from ocf_data_sampler.config import Configuration
 
+
 def save_yaml_configuration(configuration: Configuration, filename: str) -> None:
     """Save a configuration object to a YAML file.
 
@@ -20,12 +22,11 @@ def save_yaml_configuration(configuration: Configuration, filename: str) -> None
         cloud storage URL (e.g., 'gs://', 's3://'). For local paths,
         absolute paths are recommended.
     """
-
     if os.path.exists(filename):
         raise FileExistsError(f"File already exists: {filename}")
 
     # Serialize configuration to JSON-compatible dictionary
     config_dict = json.loads(configuration.model_dump_json())
 
-    with fsspec.open(filename, mode='w') as yaml_file:
-        yaml.safe_dump(config_dict, yaml_file, default_flow_style=False)
+    with fsspec.open(filename, mode="w") as yaml_file:
+        yaml.safe_dump(config_dict, yaml_file, default_flow_style=False)
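
A short usage sketch of this load/save pair (paths are illustrative):

from ocf_data_sampler.config.load import load_yaml_configuration
from ocf_data_sampler.config.save import save_yaml_configuration

# Both functions go through fsspec, so local paths and URLs like s3:// or gs:// work
config = load_yaml_configuration("config.yaml")
save_yaml_configuration(config, "config_copy.yaml")  # raises FileExistsError if the target exists
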
--- a/ocf_data_sampler/load/gsp.py
+++ b/ocf_data_sampler/load/gsp.py
@@ -1,26 +1,27 @@
-import pkg_resources
+"""Functions for loading GSP data."""
+
+from importlib.resources import files
 
 import pandas as pd
 import xarray as xr
 
 
 def open_gsp(zarr_path: str) -> xr.DataArray:
-    """Open the GSP data
-
+    """Open the GSP data.
+
     Args:
         zarr_path: Path to the GSP zarr data
 
     Returns:
         xr.DataArray: The opened GSP data
     """
-
     ds = xr.open_zarr(zarr_path)
 
     ds = ds.rename({"datetime_gmt": "time_utc"})
 
     # Load UK GSP locations
     df_gsp_loc = pd.read_csv(
-        pkg_resources.resource_filename(__name__, "../data/uk_gsp_locations.csv"),
+        files("ocf_data_sampler.data").joinpath("uk_gsp_locations.csv"),
         index_col="gsp_id",
     )
 
--- a/ocf_data_sampler/load/load_dataset.py
+++ b/ocf_data_sampler/load/load_dataset.py
@@ -1,17 +1,17 @@
-""" Loads all data sources """
+"""Loads all data sources."""
+
 import xarray as xr
 
 from ocf_data_sampler.config import InputData
-from ocf_data_sampler.load import open_nwp, open_gsp, open_sat_data, open_site
+from ocf_data_sampler.load import open_gsp, open_nwp, open_sat_data, open_site
 
 
 def get_dataset_dict(input_config: InputData) -> dict[str, dict[xr.DataArray] | xr.DataArray]:
-    """Construct dictionary of all of the input data sources
+    """Construct dictionary of all of the input data sources.
 
     Args:
         input_config: InputData configuration object
     """
-
     datasets_dict = {}
 
     # Load GSP data unless the path is None
@@ -23,10 +23,8 @@ def get_dataset_dict(input_config: InputData) -> dict[str, dict[xr.DataArray] |
 
     # Load NWP data if in config
     if input_config.nwp:
-
         datasets_dict["nwp"] = {}
         for nwp_source, nwp_config in input_config.nwp.items():
-
             da_nwp = open_nwp(nwp_config.zarr_path, provider=nwp_config.provider)
 
             da_nwp = da_nwp.sel(channel=list(nwp_config.channels))
@@ -48,6 +46,7 @@ def get_dataset_dict(input_config: InputData) -> dict[str, dict[xr.DataArray] |
         generation_file_path=input_config.site.file_path,
         metadata_file_path=input_config.site.metadata_file_path,
     )
+
     datasets_dict["site"] = da_sites
 
     return datasets_dict
--- a/ocf_data_sampler/load/nwp/nwp.py
+++ b/ocf_data_sampler/load/nwp/nwp.py
@@ -1,22 +1,34 @@
+"""Module for opening NWP data."""
+
 import xarray as xr
 
-from ocf_data_sampler.load.nwp.providers.ukv import open_ukv
 from ocf_data_sampler.load.nwp.providers.ecmwf import open_ifs
+from ocf_data_sampler.load.nwp.providers.gfs import open_gfs
+from ocf_data_sampler.load.nwp.providers.icon import open_icon_eu
+from ocf_data_sampler.load.nwp.providers.ukv import open_ukv
 
 
 def open_nwp(zarr_path: str | list[str], provider: str) -> xr.DataArray:
-    """Opens NWP zarr
+    """Opens NWP zarr.
 
     Args:
         zarr_path: path to the zarr file
         provider: NWP provider
+
+    Returns:
+        Xarray DataArray of the NWP data
     """
+    provider = provider.lower()
 
-    if provider.lower() == "ukv":
+    if provider == "ukv":
         _open_nwp = open_ukv
-    elif provider.lower() == "ecmwf":
+    elif provider == "ecmwf":
         _open_nwp = open_ifs
+    elif provider == "icon-eu":
+        _open_nwp = open_icon_eu
+    elif provider == "gfs":
+        _open_nwp = open_gfs
     else:
         raise ValueError(f"Unknown provider: {provider}")
-    return _open_nwp(zarr_path)
 
+    return _open_nwp(zarr_path)
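
A usage sketch of the dispatch above (path illustrative). The provider string is lower-cased before matching; note that this function matches "icon-eu", whereas the NWP_PROVIDERS list in config/model.py spells it "icon_eu".

from ocf_data_sampler.load import open_nwp

da = open_nwp("s3://bucket/gfs.zarr", provider="GFS")  # matched case-insensitively
# Every provider loader returns a DataArray transposed so that
# init_time_utc, step and channel come first
print(da.dims)
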
--- a/ocf_data_sampler/load/nwp/providers/ecmwf.py
+++ b/ocf_data_sampler/load/nwp/providers/ecmwf.py
@@ -1,17 +1,17 @@
-"""ECMWF provider loaders"""
+"""ECMWF provider loaders."""
 
 import xarray as xr
+
 from ocf_data_sampler.load.nwp.providers.utils import open_zarr_paths
 from ocf_data_sampler.load.utils import (
     check_time_unique_increasing,
+    get_xr_data_array_from_xr_dataset,
     make_spatial_coords_increasing,
-    get_xr_data_array_from_xr_dataset
 )
 
 
 def open_ifs(zarr_path: str | list[str]) -> xr.DataArray:
-    """
-    Opens the ECMWF IFS NWP data
+    """Opens the ECMWF IFS NWP data.
 
     Args:
         zarr_path: Path to the zarr to open
@@ -19,9 +19,8 @@ def open_ifs(zarr_path: str | list[str]) -> xr.DataArray:
     Returns:
         Xarray DataArray of the NWP data
     """
-
     ds = open_zarr_paths(zarr_path)
-
+
     # LEGACY SUPPORT - rename variable to channel if it exists
     ds = ds.rename({"init_time": "init_time_utc", "variable": "channel"})
 
@@ -30,6 +29,6 @@ def open_ifs(zarr_path: str | list[str]) -> xr.DataArray:
     ds = make_spatial_coords_increasing(ds, x_coord="longitude", y_coord="latitude")
 
     ds = ds.transpose("init_time_utc", "step", "channel", "longitude", "latitude")
-
+
     # TODO: should we control the dtype of the DataArray?
     return get_xr_data_array_from_xr_dataset(ds)
--- /dev/null
+++ b/ocf_data_sampler/load/nwp/providers/gfs.py
@@ -0,0 +1,36 @@
+"""Open GFS Forecast data."""
+
+import logging
+
+import xarray as xr
+
+from ocf_data_sampler.load.nwp.providers.utils import open_zarr_paths
+from ocf_data_sampler.load.utils import check_time_unique_increasing, make_spatial_coords_increasing
+
+_log = logging.getLogger(__name__)
+
+
+def open_gfs(zarr_path: str | list[str]) -> xr.DataArray:
+    """Opens the GFS data.
+
+    Args:
+        zarr_path: Path to the zarr to open
+
+    Returns:
+        Xarray DataArray of the NWP data
+    """
+    _log.info("Loading NWP GFS data")
+
+    # Open data
+    gfs: xr.Dataset = open_zarr_paths(zarr_path, time_dim="init_time_utc")
+    nwp: xr.DataArray = gfs.to_array()
+
+    del gfs
+
+    nwp = nwp.rename({"variable": "channel","init_time": "init_time_utc"})
+    check_time_unique_increasing(nwp.init_time_utc)
+    nwp = make_spatial_coords_increasing(nwp, x_coord="longitude", y_coord="latitude")
+
+    nwp = nwp.transpose("init_time_utc", "step", "channel", "latitude", "longitude")
+
+    return nwp
--- /dev/null
+++ b/ocf_data_sampler/load/nwp/providers/icon.py
@@ -0,0 +1,46 @@
+"""DWD ICON Loading."""
+
+import xarray as xr
+
+from ocf_data_sampler.load.nwp.providers.utils import open_zarr_paths
+from ocf_data_sampler.load.utils import check_time_unique_increasing, make_spatial_coords_increasing
+
+
+def remove_isobaric_lelvels_from_coords(nwp: xr.Dataset) -> xr.Dataset:
+    """Removes the isobaric levels from the coordinates of the NWP data.
+
+    Args:
+        nwp: NWP data
+
+    Returns:
+        NWP data without isobaric levels in the coordinates
+    """
+    variables_to_drop = [var for var in nwp.data_vars if "isobaricInhPa" in nwp[var].dims]
+    return nwp.drop_vars(["isobaricInhPa", *variables_to_drop])
+
+
+def open_icon_eu(zarr_path: str) -> xr.Dataset:
+    """Opens the ICON data.
+
+    ICON EU Data is on a regular lat/lon grid
+    It has data on multiple pressure levels, as well as the surface
+    Each of the variables is its own data variable
+
+    Args:
+        zarr_path: Path to the zarr to open
+
+    Returns:
+        Xarray DataArray of the NWP data
+    """
+    # Open the data
+    nwp = open_zarr_paths(zarr_path, time_dim="time")
+    nwp = nwp.rename({"time": "init_time_utc"})
+    # Sanity checks.
+    check_time_unique_increasing(nwp.init_time_utc)
+    # 0-78 one hour steps, rest 3 hour steps
+    nwp = nwp.isel(step=slice(0, 78))
+    nwp = remove_isobaric_lelvels_from_coords(nwp)
+    nwp = nwp.to_array().rename({"variable": "channel"})
+    nwp = nwp.transpose("init_time_utc", "step", "channel", "latitude", "longitude")
+    nwp = make_spatial_coords_increasing(nwp, x_coord="longitude", y_coord="latitude")
+    return nwp
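
A brief sketch of using the loader above (path illustrative). Despite the declared xr.Dataset return type, the to_array() call means a DataArray comes back, restricted to the hourly steps 0-78:

from ocf_data_sampler.load.nwp.providers.icon import open_icon_eu

nwp = open_icon_eu("path/to/icon_eu.zarr")  # hypothetical path
print(nwp.dims)  # ('init_time_utc', 'step', 'channel', 'latitude', 'longitude')
day_one = nwp.isel(step=slice(0, 24))  # all retained steps are hourly
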
--- a/ocf_data_sampler/load/nwp/providers/ukv.py
+++ b/ocf_data_sampler/load/nwp/providers/ukv.py
@@ -1,18 +1,17 @@
-"""UKV provider loaders"""
+"""UKV provider loaders."""
 
 import xarray as xr
 
 from ocf_data_sampler.load.nwp.providers.utils import open_zarr_paths
 from ocf_data_sampler.load.utils import (
     check_time_unique_increasing,
+    get_xr_data_array_from_xr_dataset,
     make_spatial_coords_increasing,
-    get_xr_data_array_from_xr_dataset
 )
 
 
 def open_ukv(zarr_path: str | list[str]) -> xr.DataArray:
-    """
-    Opens the NWP data
+    """Opens the NWP data.
 
     Args:
         zarr_path: Path to the zarr to open
@@ -28,7 +27,7 @@ def open_ukv(zarr_path: str | list[str]) -> xr.DataArray:
             "variable": "channel",
             "x": "x_osgb",
             "y": "y_osgb",
-        }
+        },
     )
 
     check_time_unique_increasing(ds.init_time_utc)