ocf-data-sampler 0.0.18__py3-none-any.whl → 0.0.42__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ocf-data-sampler might be problematic. Click here for more details.

Files changed (64) hide show
  1. ocf_data_sampler/config/__init__.py +5 -0
  2. ocf_data_sampler/config/load.py +33 -0
  3. ocf_data_sampler/config/model.py +246 -0
  4. ocf_data_sampler/config/save.py +73 -0
  5. ocf_data_sampler/constants.py +173 -0
  6. ocf_data_sampler/load/load_dataset.py +55 -0
  7. ocf_data_sampler/load/nwp/providers/ecmwf.py +5 -2
  8. ocf_data_sampler/load/site.py +30 -0
  9. ocf_data_sampler/numpy_sample/__init__.py +8 -0
  10. ocf_data_sampler/numpy_sample/collate.py +77 -0
  11. ocf_data_sampler/numpy_sample/gsp.py +34 -0
  12. ocf_data_sampler/numpy_sample/nwp.py +42 -0
  13. ocf_data_sampler/numpy_sample/satellite.py +30 -0
  14. ocf_data_sampler/numpy_sample/site.py +30 -0
  15. ocf_data_sampler/{numpy_batch → numpy_sample}/sun_position.py +9 -10
  16. ocf_data_sampler/select/__init__.py +8 -1
  17. ocf_data_sampler/select/dropout.py +4 -3
  18. ocf_data_sampler/select/find_contiguous_time_periods.py +40 -75
  19. ocf_data_sampler/select/geospatial.py +160 -0
  20. ocf_data_sampler/select/location.py +62 -0
  21. ocf_data_sampler/select/select_spatial_slice.py +13 -16
  22. ocf_data_sampler/select/select_time_slice.py +24 -33
  23. ocf_data_sampler/select/spatial_slice_for_dataset.py +53 -0
  24. ocf_data_sampler/select/time_slice_for_dataset.py +125 -0
  25. ocf_data_sampler/torch_datasets/__init__.py +2 -1
  26. ocf_data_sampler/torch_datasets/process_and_combine.py +131 -0
  27. ocf_data_sampler/torch_datasets/pvnet_uk_regional.py +19 -427
  28. ocf_data_sampler/torch_datasets/site.py +405 -0
  29. ocf_data_sampler/torch_datasets/valid_time_periods.py +116 -0
  30. ocf_data_sampler/utils.py +10 -0
  31. ocf_data_sampler-0.0.42.dist-info/METADATA +153 -0
  32. ocf_data_sampler-0.0.42.dist-info/RECORD +71 -0
  33. {ocf_data_sampler-0.0.18.dist-info → ocf_data_sampler-0.0.42.dist-info}/WHEEL +1 -1
  34. {ocf_data_sampler-0.0.18.dist-info → ocf_data_sampler-0.0.42.dist-info}/top_level.txt +1 -0
  35. scripts/refactor_site.py +50 -0
  36. tests/config/test_config.py +161 -0
  37. tests/config/test_save.py +37 -0
  38. tests/conftest.py +86 -1
  39. tests/load/test_load_gsp.py +15 -0
  40. tests/load/test_load_nwp.py +21 -0
  41. tests/load/test_load_satellite.py +17 -0
  42. tests/load/test_load_sites.py +14 -0
  43. tests/numpy_sample/test_collate.py +26 -0
  44. tests/numpy_sample/test_gsp.py +38 -0
  45. tests/numpy_sample/test_nwp.py +52 -0
  46. tests/numpy_sample/test_satellite.py +40 -0
  47. tests/numpy_sample/test_sun_position.py +81 -0
  48. tests/select/test_dropout.py +75 -0
  49. tests/select/test_fill_time_periods.py +28 -0
  50. tests/select/test_find_contiguous_time_periods.py +202 -0
  51. tests/select/test_location.py +67 -0
  52. tests/select/test_select_spatial_slice.py +154 -0
  53. tests/select/test_select_time_slice.py +272 -0
  54. tests/torch_datasets/conftest.py +18 -0
  55. tests/torch_datasets/test_process_and_combine.py +126 -0
  56. tests/torch_datasets/test_pvnet_uk_regional.py +59 -0
  57. tests/torch_datasets/test_site.py +129 -0
  58. ocf_data_sampler/numpy_batch/__init__.py +0 -7
  59. ocf_data_sampler/numpy_batch/gsp.py +0 -20
  60. ocf_data_sampler/numpy_batch/nwp.py +0 -33
  61. ocf_data_sampler/numpy_batch/satellite.py +0 -23
  62. ocf_data_sampler-0.0.18.dist-info/METADATA +0 -22
  63. ocf_data_sampler-0.0.18.dist-info/RECORD +0 -32
  64. {ocf_data_sampler-0.0.18.dist-info → ocf_data_sampler-0.0.42.dist-info}/LICENSE +0 -0
@@ -0,0 +1,5 @@
1
+ """Configuration model"""
2
+
3
+ from ocf_data_sampler.config.model import Configuration
4
+ from ocf_data_sampler.config.save import save_yaml_configuration
5
+ from ocf_data_sampler.config.load import load_yaml_configuration
@@ -0,0 +1,33 @@
1
+ """Loading configuration functions.
2
+
3
+ Example:
4
+
5
+ from ocf_data_sampler.config import load_yaml_configuration
6
+ configuration = load_yaml_configuration(filename)
7
+ """
8
+
9
+ import fsspec
10
+ from pathy import Pathy
11
+ from pyaml_env import parse_config
12
+
13
+ from ocf_data_sampler.config import Configuration
14
+
15
+
16
def load_yaml_configuration(filename: str | Pathy) -> Configuration:
    """Load a YAML configuration file into a ``Configuration`` object.

    Args:
        filename: Path of the YAML file to load. Local, AWS, and GCP paths
            are supported via the protocol prefix
            (e.g. 's3://bucket/config.yaml').

    Returns:
        The parsed ``Configuration`` pydantic object.
    """
    # parse_config resolves environment variables referenced in the YAML
    with fsspec.open(filename, mode="r") as stream:
        raw_config = parse_config(data=stream)

    # Validate the raw dictionary against the pydantic schema
    return Configuration(**raw_config)
@@ -0,0 +1,246 @@
1
+ """Configuration model for the dataset.
2
+
3
+ All paths must include the protocol prefix. For local files,
4
+ it's sufficient to just start with a '/'. For aws, start with 's3://',
5
+ for gcp start with 'gs://'.
6
+
7
+ Example:
8
+
9
+ from ocf_data_sampler.config import Configuration
10
+ config = Configuration(**config_dict)
11
+ """
12
+
13
+ import logging
14
+ from typing import Dict, List, Optional
15
+ from typing_extensions import Self
16
+
17
+ from pydantic import BaseModel, Field, RootModel, field_validator, ValidationInfo, model_validator
18
+
19
+ from ocf_data_sampler.constants import NWP_PROVIDERS
20
+
21
+ logger = logging.getLogger(__name__)
22
+
23
+ providers = ["pvoutput.org", "solar_sheffield_passiv"]
24
+
25
+
26
class Base(BaseModel):
    """Pydantic base model which rejects unknown fields"""

    class Config:
        """Pydantic config"""

        # Forbid use of extra kwargs so typos in config files fail loudly
        extra = "forbid"
33
+
34
+
35
class General(Base):
    """General metadata describing the configuration file"""

    name: str = Field("example", description="The name of this configuration file")
    description: str = Field(
        "example configuration",
        description="Description of this configuration file",
    )
42
+
43
+
44
class TimeWindowMixin(Base):
    """Mixin class, to add interval start, end and resolution minutes"""

    time_resolution_minutes: int = Field(
        ...,
        gt=0,
        description="The temporal resolution of the data in minutes",
    )

    interval_start_minutes: int = Field(
        ...,
        description="Data interval starts at `t0 + interval_start_minutes`",
    )

    interval_end_minutes: int = Field(
        ...,
        description="Data interval ends at `t0 + interval_end_minutes`",
    )

    @model_validator(mode="after")
    def check_interval_range(self) -> Self:
        """Ensure the interval start does not come after the interval end."""
        # Instance-style after-validator, consistent with the other
        # after-validators in this module (the old `(cls, values)` form
        # relied on pydantic's implicit classmethod wrapping)
        if self.interval_start_minutes > self.interval_end_minutes:
            raise ValueError("interval_start_minutes must be <= interval_end_minutes")
        return self

    @field_validator("interval_start_minutes", "interval_end_minutes")
    def interval_minutes_divide_by_time_resolution(cls, v: int, info: ValidationInfo) -> int:
        """Check the interval is a whole number of time-resolution steps."""
        # time_resolution_minutes is missing from info.data when its own
        # validation failed; skip the check then instead of raising a
        # confusing KeyError on top of the real validation error
        resolution = info.data.get("time_resolution_minutes")
        if resolution is not None and v % resolution != 0:
            raise ValueError(f"{info.field_name} must be divisible by time_resolution_minutes")
        return v
80
+
81
+
82
+
83
+ # noinspection PyMethodParameters
84
class DropoutMixin(Base):
    """Mixin class, to add dropout minutes"""

    dropout_timedeltas_minutes: Optional[List[int]] = Field(
        default=None,
        description="List of possible minutes before t0 where data availability may start. Must be "
        "negative or zero.",
    )

    dropout_fraction: float = Field(
        default=0,
        description="Chance of dropout being applied to each sample",
        ge=0,
        le=1,
    )

    @field_validator("dropout_timedeltas_minutes")
    def dropout_timedeltas_minutes_negative(cls, v: List[int]) -> List[int]:
        """Validate 'dropout_timedeltas_minutes'"""
        # Raise (not assert) so the check survives `python -O`
        if v is not None:
            for m in v:
                if m > 0:
                    raise ValueError("Dropout timedeltas must be negative or zero")
        return v

    @model_validator(mode="after")
    def dropout_instructions_consistent(self) -> Self:
        """Check the fraction and timedeltas are either both set or both unset."""
        if self.dropout_fraction == 0:
            if self.dropout_timedeltas_minutes is not None:
                raise ValueError("To use dropout timedeltas dropout fraction should be > 0")
        else:
            if self.dropout_timedeltas_minutes is None:
                raise ValueError("Dropout fraction > 0 requires a list of dropout timedeltas")
        return self
117
+
118
+
119
class SpatialWindowMixin(Base):
    """Mixin class adding the pixel dimensions of the region of interest"""

    image_size_pixels_height: int = Field(
        ...,
        ge=0,
        description="The number of pixels of the height of the region of interest",
    )

    image_size_pixels_width: int = Field(
        ...,
        ge=0,
        description="The number of pixels of the width of the region of interest",
    )
133
+
134
+
135
class Satellite(TimeWindowMixin, DropoutMixin, SpatialWindowMixin):
    """Satellite configuration model"""

    # A single zarr store, or a collection of stores to be combined
    zarr_path: str | tuple[str] | list[str] = Field(
        ...,
        description="The path or list of paths which hold the data zarr",
    )

    channels: list[str] = Field(..., description="the satellite channels that are used")
146
+
147
+
148
+ # noinspection PyMethodParameters
149
class NWP(TimeWindowMixin, DropoutMixin, SpatialWindowMixin):
    """NWP configuration model"""

    zarr_path: str | tuple[str] | list[str] = Field(
        ...,
        description="The path or list of paths which hold the data zarr",
    )

    channels: list[str] = Field(
        ..., description="the channels used in the nwp data"
    )

    provider: str = Field(..., description="The provider of the NWP data")

    accum_channels: list[str] = Field([], description="the nwp channels which need to be diffed")

    max_staleness_minutes: Optional[int] = Field(
        None,
        description="Sets a limit on how stale an NWP init time is allowed to be whilst still being"
        " used to construct an example. If set to None, then the max staleness is set according to"
        " the maximum forecast horizon of the NWP and the requested forecast length.",
    )

    @field_validator("provider")
    def validate_provider(cls, v: str) -> str:
        """Validate 'provider'"""
        if v.lower() not in NWP_PROVIDERS:
            message = f"NWP provider {v} is not in {NWP_PROVIDERS}"
            logger.warning(message)
            # ValueError (not a bare Exception) so pydantic reports a clean
            # validation error and callers can catch it specifically
            raise ValueError(message)
        return v
181
+
182
+
183
class MultiNWP(RootModel):
    """Configuration for multiple NWPs, keyed by provider name"""

    root: Dict[str, NWP]

    def __getattr__(self, item):
        # Translate missing keys to AttributeError so hasattr(), copy and
        # pickle protocols behave correctly (a KeyError here breaks them)
        try:
            return self.root[item]
        except KeyError:
            raise AttributeError(item) from None

    def __getitem__(self, item):
        return self.root[item]

    def __len__(self):
        return len(self.root)

    def __iter__(self):
        return iter(self.root)

    def keys(self):
        """Returns dictionary-like keys"""
        return self.root.keys()

    def items(self):
        """Returns dictionary-like items"""
        return self.root.items()
207
+
208
+
209
class GSP(TimeWindowMixin, DropoutMixin):
    """GSP configuration model"""

    # Single zarr store holding the GSP power timeseries
    zarr_path: str = Field(..., description="The path which holds the GSP zarr")
213
+
214
+
215
class Site(TimeWindowMixin, DropoutMixin):
    """Site configuration model"""

    file_path: str = Field(
        ...,
        description="The NetCDF files holding the power timeseries.",
    )

    metadata_file_path: str = Field(
        ...,
        description="The CSV files describing power system",
    )

    # TODO validate the netcdf for sites
    # TODO validate the csv for metadata
229
+
230
+
231
+
232
+ # noinspection PyPep8Naming
233
class InputData(Base):
    """Input data model; each data source is optional"""

    satellite: Optional[Satellite] = None
    nwp: Optional[MultiNWP] = None
    gsp: Optional[GSP] = None
    site: Optional[Site] = None
240
+
241
+
242
class Configuration(Base):
    """Configuration model for the dataset"""

    general: General = General()
    input_data: InputData = InputData()
@@ -0,0 +1,73 @@
1
+ """Save functions for the configuration model.
2
+
3
+ This module provides functionality to save configuration objects to YAML files,
4
+ supporting local and cloud storage locations.
5
+
6
+ Example:
7
+ from ocf_data_sampler.config import save_yaml_configuration
8
+ saved_path = save_yaml_configuration(config, "config.yaml")
9
+ """
10
+
11
+ import json
12
+
13
+ from pathlib import Path
14
+ from typing import Union
15
+
16
+ import fsspec
17
+ import yaml
18
+
19
+ from ocf_data_sampler.config import Configuration
20
+
21
+
22
def save_yaml_configuration(
    configuration: Configuration,
    filename: Union[str, Path],
) -> Path:
    """Save a configuration object to a YAML file.

    Args:
        configuration: Configuration object containing the settings to save
        filename: Destination path for the YAML file. Can be a local path or
            cloud storage URL (e.g., 'gs://', 's3://'). For local paths,
            absolute paths are recommended.

    Returns:
        Path: The path where the configuration was saved

    Raises:
        ValueError: If filename is None or if writing to the specified path fails
        TypeError: If the configuration cannot be serialized
    """
    if filename is None:
        raise ValueError("filename cannot be None")

    try:
        # Anything without a protocol prefix or leading '/' is treated as
        # relative to the current working directory
        if isinstance(filename, (str, Path)) and not any(
            str(filename).startswith(prefix) for prefix in ('gs://', 's3://', '/')
        ):
            filename = Path.cwd() / filename

        filepath = Path(filename)

        # For local files, fail early if the target directory is missing
        # (cloud URLs are never absolute paths, so they skip this check)
        if filepath.is_absolute():
            directory = filepath.parent
            if not directory.exists():
                raise ValueError("Directory does not exist")

        # Round-trip through JSON so pydantic types become plain builtins
        config_dict = json.loads(configuration.model_dump_json())

        # Save to YAML file using fsspec (handles local and cloud paths)
        with fsspec.open(str(filepath), mode='w') as yaml_file:
            yaml.safe_dump(config_dict, yaml_file, default_flow_style=False)

        return filepath

    except json.JSONDecodeError as e:
        raise TypeError(f"Failed to serialize configuration: {str(e)}") from e
    except PermissionError as e:
        # Include the target path so the caller knows what failed
        # (the original messages contained a literal "(unknown)" placeholder)
        raise ValueError(f"Permission denied when writing to {filepath}") from e
    except (IOError, OSError) as e:
        raise ValueError(f"Failed to write configuration to {filepath}: {str(e)}") from e
@@ -0,0 +1,173 @@
1
+ import xarray as xr
2
+ import numpy as np
3
+
4
+
5
# NWP providers with normalization statistics available in this module
NWP_PROVIDERS = [
    "ukv",
    "ecmwf",
]
9
+
10
+
11
def _to_data_array(d):
    """Convert a channel->value dict into a float32 DataArray indexed by channel."""
    channels = list(d.keys())
    values = [d[c] for c in channels]
    return xr.DataArray(values, coords={"channel": channels}).astype(np.float32)
16
+
17
+
18
class NWPStatDict(dict):
    """Custom dictionary class to hold NWP normalization stats.

    Lookups distinguish unknown providers from known-but-unpopulated ones.
    """

    def __getitem__(self, key):
        # Unknown provider entirely
        if key not in NWP_PROVIDERS:
            raise KeyError(f"{key} is not a supported NWP provider - {NWP_PROVIDERS}")
        # Known provider but stats not stored in this instance
        if key not in self.keys():
            raise KeyError(
                f"Values for {key} not yet available in ocf-data-sampler {list(self.keys())}"
            )
        return super().__getitem__(key)
30
+
31
+
32
# ------ UKV
# Means and std computed WITH version_7 and higher, MetOffice values
UKV_STD = {
    "cdcb": 2126.99350113,
    "lcc": 39.33210726,
    "mcc": 41.91144559,
    "hcc": 38.07184418,
    "sde": 0.1029753,
    "hcct": 18382.63958991,
    "dswrf": 190.47216887,
    "dlwrf": 39.45988077,
    "h": 1075.77812282,
    "t": 4.38818501,
    "r": 11.45012499,
    "dpt": 4.57250482,
    "vis": 21578.97975625,
    "si10": 3.94718813,
    "wdir10": 94.08407495,
    "prmsl": 1252.71790539,
    "prate": 0.00021497,
}

UKV_MEAN = {
    "cdcb": 1412.26599062,
    "lcc": 50.08362643,
    "mcc": 40.88984494,
    "hcc": 29.11949682,
    "sde": 0.00289545,
    "hcct": -18345.97478167,
    "dswrf": 111.28265039,
    "dlwrf": 325.03130139,
    "h": 2096.51991356,
    "t": 283.64913206,
    "r": 81.79229501,
    "dpt": 280.54379901,
    "vis": 32262.03285118,
    "si10": 6.88348448,
    "wdir10": 199.41891636,
    "prmsl": 101321.61574029,
    "prate": 3.45793433e-05,
}

# Rebind the plain dicts as channel-indexed float32 DataArrays
UKV_STD = _to_data_array(UKV_STD)
UKV_MEAN = _to_data_array(UKV_MEAN)

# ------ ECMWF
# These were calculated from 100 random init times of UK data from 2020-2023
ECMWF_STD = {
    "dlwrf": 15855867.0,
    "dswrf": 13025427.0,
    "duvrs": 1445635.25,
    "hcc": 0.42244860529899597,
    "lcc": 0.3791404366493225,
    "mcc": 0.38039860129356384,
    "prate": 9.81039775069803e-05,
    "sde": 0.000913831521756947,
    "sr": 16294988.0,
    "t2m": 3.692270040512085,
    "tcc": 0.37487083673477173,
    "u10": 5.531515598297119,
    "u100": 7.2320556640625,
    "u200": 8.049470901489258,
    "v10": 5.411230564117432,
    "v100": 6.944501876831055,
    "v200": 7.561611652374268,
    "diff_dlwrf": 131942.03125,
    "diff_dswrf": 715366.3125,
    "diff_duvrs": 81605.25,
    "diff_sr": 818950.6875,
}

ECMWF_MEAN = {
    "dlwrf": 27187026.0,
    "dswrf": 11458988.0,
    "duvrs": 1305651.25,
    "hcc": 0.3961029052734375,
    "lcc": 0.44901806116104126,
    "mcc": 0.3288780450820923,
    "prate": 3.108070450252853e-05,
    "sde": 8.107526082312688e-05,
    "sr": 12905302.0,
    "t2m": 283.48333740234375,
    "tcc": 0.7049227356910706,
    "u10": 1.7677178382873535,
    "u100": 2.393547296524048,
    "u200": 2.7963004112243652,
    "v10": 0.985887885093689,
    "v100": 1.4244288206100464,
    "v200": 1.6010299921035767,
    "diff_dlwrf": 1136464.0,
    "diff_dswrf": 420584.6875,
    "diff_duvrs": 48265.4765625,
    "diff_sr": 469169.5,
}

ECMWF_STD = _to_data_array(ECMWF_STD)
ECMWF_MEAN = _to_data_array(ECMWF_MEAN)

# Provider-keyed lookup tables; NWPStatDict raises a descriptive KeyError
# for unsupported or not-yet-populated providers
NWP_STDS = NWPStatDict(
    ukv=UKV_STD,
    ecmwf=ECMWF_STD,
)
NWP_MEANS = NWPStatDict(
    ukv=UKV_MEAN,
    ecmwf=ECMWF_MEAN,
)

# ------ Satellite
# RSS Mean and std values from randomised 20% of 2020 imagery

RSS_STD = {
    "HRV": 0.11405209,
    "IR_016": 0.21462157,
    "IR_039": 0.04618041,
    "IR_087": 0.06687243,
    "IR_097": 0.0468558,
    "IR_108": 0.17482725,
    "IR_120": 0.06115861,
    "IR_134": 0.04492306,
    "VIS006": 0.12184761,
    "VIS008": 0.13090034,
    "WV_062": 0.16111417,
    "WV_073": 0.12924142,
}

RSS_MEAN = {
    "HRV": 0.09298719,
    "IR_016": 0.17594202,
    "IR_039": 0.86167645,
    "IR_087": 0.7719318,
    "IR_097": 0.8014212,
    "IR_108": 0.71254843,
    "IR_120": 0.89058584,
    "IR_134": 0.944365,
    "VIS006": 0.09633306,
    "VIS008": 0.11426069,
    "WV_062": 0.7359355,
    "WV_073": 0.62479186,
}

RSS_STD = _to_data_array(RSS_STD)
RSS_MEAN = _to_data_array(RSS_MEAN)
@@ -0,0 +1,55 @@
1
+ """ Loads all data sources """
2
+ import xarray as xr
3
+
4
+ from ocf_data_sampler.config import Configuration
5
+ from ocf_data_sampler.load.gsp import open_gsp
6
+ from ocf_data_sampler.load.nwp import open_nwp
7
+ from ocf_data_sampler.load.satellite import open_sat_data
8
+ from ocf_data_sampler.load.site import open_site
9
+
10
+
11
def get_dataset_dict(config: Configuration) -> dict[str, dict[xr.DataArray]]:
    """Construct dictionary of all of the input data sources

    Args:
        config: Configuration file
    """
    in_config = config.input_data

    datasets_dict = {}

    # GSP: load eagerly, then drop the national total (gsp_id 0)
    if in_config.gsp and in_config.gsp.zarr_path:
        da_gsp = open_gsp(zarr_path=in_config.gsp.zarr_path).compute()
        datasets_dict["gsp"] = da_gsp.sel(gsp_id=slice(1, None))

    # NWP: one entry per configured provider, restricted to requested channels
    if in_config.nwp:
        nwp_datasets = {}
        for nwp_source, nwp_config in in_config.nwp.items():
            da_nwp = open_nwp(nwp_config.zarr_path, provider=nwp_config.provider)
            nwp_datasets[nwp_source] = da_nwp.sel(channel=list(nwp_config.channels))
        datasets_dict["nwp"] = nwp_datasets

    # Satellite: restricted to requested channels
    if in_config.satellite:
        sat_config = in_config.satellite
        da_sat = open_sat_data(sat_config.zarr_path)
        datasets_dict["sat"] = da_sat.sel(channel=list(sat_config.channels))

    # Site-level power timeseries
    if in_config.site:
        datasets_dict["site"] = open_site(in_config.site)

    return datasets_dict
@@ -9,7 +9,6 @@ from ocf_data_sampler.load.utils import (
9
9
  )
10
10
 
11
11
 
12
-
13
12
  def open_ifs(zarr_path: Path | str | list[Path] | list[str]) -> xr.DataArray:
14
13
  """
15
14
  Opens the ECMWF IFS NWP data
@@ -27,10 +26,14 @@ def open_ifs(zarr_path: Path | str | list[Path] | list[str]) -> xr.DataArray:
27
26
  ds = ds.rename(
28
27
  {
29
28
  "init_time": "init_time_utc",
30
- "variable": "channel",
31
29
  }
32
30
  )
33
31
 
32
+ # LEGACY SUPPORT
33
+ # rename variable to channel if it exists
34
+ if "variable" in ds:
35
+ ds = ds.rename({"variable": "channel"})
36
+
34
37
  # Check the timestamps are unique and increasing
35
38
  check_time_unique_increasing(ds.init_time_utc)
36
39
 
@@ -0,0 +1,30 @@
1
+ import pandas as pd
2
+ import xarray as xr
3
+ import numpy as np
4
+
5
+ from ocf_data_sampler.config.model import Site
6
+
7
+
8
def open_site(sites_config: Site) -> xr.DataArray:
    """Open the site generation data and attach metadata coordinates.

    Args:
        sites_config: Site configuration holding the paths to the generation
            NetCDF file and the site metadata CSV.

    Returns:
        DataArray of `generation_kw` with latitude, longitude and
        capacity_kwp coordinates along the site_id dimension.

    Raises:
        ValueError: If the metadata has duplicate site IDs, or if any site
            capacity is non-finite or not positive.
    """
    # Load site generation xr.Dataset
    site_generation_ds = xr.open_dataset(sites_config.file_path)

    # Load site metadata, indexed by site_id
    metadata_df = pd.read_csv(sites_config.metadata_file_path, index_col="site_id")

    # Check uniqueness before reindexing, which would otherwise fail with a
    # less helpful pandas error on a duplicated index
    if not metadata_df.index.is_unique:
        raise ValueError("Site metadata contains duplicate site_id entries")

    # Ensure metadata aligns with the site_id dimension in data_ds
    metadata_df = metadata_df.reindex(site_generation_ds.site_id.values)

    # Assign coordinates to the Dataset using the aligned metadata
    site_generation_ds = site_generation_ds.assign_coords(
        latitude=("site_id", metadata_df["latitude"].values),
        longitude=("site_id", metadata_df["longitude"].values),
        capacity_kwp=("site_id", metadata_df["capacity_kwp"].values),
    )

    # Sanity checks: raise (not assert) so they survive `python -O`
    capacities = site_generation_ds.capacity_kwp.values
    if not np.isfinite(capacities).all():
        raise ValueError("Site capacities contain non-finite values")
    if not (capacities > 0).all():
        raise ValueError("Site capacities must be positive")

    return site_generation_ds.generation_kw
@@ -0,0 +1,8 @@
1
+ """Conversion from Xarray to NumpySample"""
2
+
3
+ from .gsp import convert_gsp_to_numpy_sample, GSPSampleKey
4
+ from .nwp import convert_nwp_to_numpy_sample, NWPSampleKey
5
+ from .satellite import convert_satellite_to_numpy_sample, SatelliteSampleKey
6
+ from .sun_position import make_sun_position_numpy_sample
7
+ from .site import convert_site_to_numpy_sample
8
+