ocf-data-sampler 0.2.32__py3-none-any.whl → 0.2.34__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -214,6 +214,17 @@ class NWP(TimeWindowMixin, DropoutMixin, SpatialWindowMixin, NormalisationConsta
214
214
  )
215
215
  public: bool = Field(False, description="Whether the NWP data is public or private")
216
216
 
217
@model_validator(mode="after")
def validate_accum_channels_subset(self) -> "NWP":
    """Check that every entry of ``accum_channels`` also appears in ``channels``.

    Returns:
        The model instance itself, unchanged, when the check passes.

    Raises:
        ValueError: If ``accum_channels`` contains values that are absent
            from ``channels``. The message names the provider and lists the
            offending values.
    """
    invalid_channels = set(self.accum_channels).difference(self.channels)

    # Guard clause: nothing left over means accum_channels is a subset.
    if not invalid_channels:
        return self

    message = (
        f"NWP provider '{self.provider}': all values in 'accum_channels' should "
        f"be present in 'channels'. Extra values found: {invalid_channels}"
    )
    raise ValueError(message)
227
+
217
228
  @field_validator("provider")
218
229
  def validate_provider(cls, v: str) -> str:
219
230
  """Validator for 'provider'."""
@@ -2,6 +2,7 @@
2
2
 
3
3
  from importlib.resources import files
4
4
 
5
+ import numpy as np
5
6
  import pandas as pd
6
7
  import xarray as xr
7
8
 
@@ -26,11 +27,12 @@ def get_gsp_boundaries(version: str) -> pd.DataFrame:
26
27
  )
27
28
 
28
29
 
29
- def open_gsp(zarr_path: str,
30
- boundaries_version: str = "20220314",
31
- public: bool = False,
32
- ) -> xr.DataArray:
33
- """Open the GSP data.
30
+ def open_gsp(
31
+ zarr_path: str,
32
+ boundaries_version: str = "20220314",
33
+ public: bool = False,
34
+ ) -> xr.DataArray:
35
+ """Open the GSP data and validates its data types.
34
36
 
35
37
  Args:
36
38
  zarr_path: Path to the GSP zarr data
@@ -44,18 +46,16 @@ def open_gsp(zarr_path: str,
44
46
  # Load UK GSP locations
45
47
  df_gsp_loc = get_gsp_boundaries(boundaries_version)
46
48
 
47
- backend_kwargs ={}
49
+ backend_kwargs = {}
48
50
  # Open the GSP generation data
49
51
  if public:
50
- backend_kwargs ={"storage_options":{"anon": True}}
52
+ backend_kwargs = {"storage_options": {"anon": True}}
51
53
  # Currently only compatible with S3 bucket.
52
54
 
53
- ds = (
54
- xr.open_dataset(zarr_path,engine="zarr",backend_kwargs=backend_kwargs)
55
- .rename({"datetime_gmt": "time_utc"})
55
+ ds = xr.open_dataset(zarr_path, engine="zarr", backend_kwargs=backend_kwargs).rename(
56
+ {"datetime_gmt": "time_utc"},
56
57
  )
57
58
 
58
-
59
59
  if not (ds.gsp_id.isin(df_gsp_loc.index)).all():
60
60
  raise ValueError(
61
61
  "Some GSP IDs in the GSP generation data are not available in the locations file.",
@@ -72,4 +72,24 @@ def open_gsp(zarr_path: str,
72
72
  effective_capacity_mwp=ds.capacity_mwp,
73
73
  )
74
74
 
75
- return ds.generation_mw
75
+ gsp_da = ds.generation_mw
76
+
77
+ # Validate data types directly in loading function
78
+ if not np.issubdtype(gsp_da.dtype, np.floating):
79
+ raise TypeError(f"generation_mw should be floating, not {gsp_da.dtype}")
80
+
81
+ coord_dtypes = {
82
+ "time_utc": np.datetime64,
83
+ "gsp_id": np.integer,
84
+ "nominal_capacity_mwp": np.floating,
85
+ "effective_capacity_mwp": np.floating,
86
+ "x_osgb": np.floating,
87
+ "y_osgb": np.floating,
88
+ }
89
+
90
+ for coord, expected_dtype in coord_dtypes.items():
91
+ if not np.issubdtype(gsp_da.coords[coord].dtype, expected_dtype):
92
+ dtype = gsp_da.coords[coord].dtype
93
+ raise TypeError(f"{coord} should be {expected_dtype.__name__}, not {dtype}")
94
+
95
+ return gsp_da
@@ -1,5 +1,6 @@
1
1
  """Module for opening NWP data."""
2
2
 
3
+ import numpy as np
3
4
  import xarray as xr
4
5
 
5
6
  from ocf_data_sampler.load.nwp.providers.cloudcasting import open_cloudcasting
@@ -9,23 +10,85 @@ from ocf_data_sampler.load.nwp.providers.icon import open_icon_eu
9
10
  from ocf_data_sampler.load.nwp.providers.ukv import open_ukv
10
11
 
11
12
 
12
- def open_nwp(zarr_path: str | list[str], provider: str, public: bool = False) -> xr.DataArray:
13
- """Opens NWP zarr.
13
def _validate_nwp_data(data_array: "xr.DataArray", provider: str) -> None:
    """Validate the data dtype and coordinate dtypes of a loaded NWP DataArray.

    The check runs in three steps: the array values must be numeric, the
    provider must be one of the known providers, and every coordinate expected
    for that provider must be present with a dtype of the expected kind.

    Args:
        data_array: The xarray.DataArray to validate. Only its ``dtype`` and
            ``coords`` are inspected.
        provider: The NWP provider name. It is matched exactly against the
            lower-case keys below (``open_nwp`` lower-cases the name before
            calling this helper).

    Raises:
        TypeError: If the data or any expected coordinate has an unexpected dtype.
        ValueError: If the provider is unknown or an expected coordinate is missing.
    """
    if not np.issubdtype(data_array.dtype, np.number):
        raise TypeError(f"NWP data for {provider} should be numeric, not {data_array.dtype}")

    # Coordinates every provider is expected to expose.
    common_expected_dtypes = {
        "init_time_utc": np.datetime64,
        "step": np.timedelta64,
        "channel": np.str_,
    }

    # Providers on a latitude/longitude grid share the same spatial coordinates.
    geographic_spatial_dtypes = {
        "latitude": np.floating,
        "longitude": np.floating,
    }

    provider_specific_spatial_dtypes = {
        "ecmwf": geographic_spatial_dtypes,
        "icon-eu": geographic_spatial_dtypes,
        "gfs": geographic_spatial_dtypes,
        "mo_global": geographic_spatial_dtypes,
        "ukv": {
            "x_osgb": np.floating,
            "y_osgb": np.floating,
        },
        "cloudcasting": {
            "x_geostationary": np.floating,
            "y_geostationary": np.floating,
        },
    }

    # Reject unknown providers explicitly. Testing the merged mapping for
    # emptiness can never trigger, because the common coordinates are always
    # part of it, so an unknown provider would silently skip spatial checks.
    if provider not in provider_specific_spatial_dtypes:
        raise ValueError(f"Unknown provider: {provider}")

    expected_dtypes = {
        **common_expected_dtypes,
        **provider_specific_spatial_dtypes[provider],
    }

    for coord, expected_dtype in expected_dtypes.items():
        if coord not in data_array.coords:
            raise ValueError(f"Coordinate '{coord}' missing for provider '{provider}'")

        actual_dtype = data_array.coords[coord].dtype
        if not np.issubdtype(actual_dtype, expected_dtype):
            err_msg = (
                f"'{coord}' for {provider} should be {expected_dtype.__name__}, "
                f"not {actual_dtype}"
            )
            raise TypeError(err_msg)
73
+
74
+
75
+ def open_nwp(
76
+ zarr_path: str | list[str],
77
+ provider: str,
78
+ public: bool = False,
79
+ ) -> xr.DataArray:
80
+ """Opens NWP zarr and validates its structure and data types.
14
81
 
15
82
  Args:
16
83
  zarr_path: path to the zarr file
17
84
  provider: NWP provider
18
85
  public: Whether the data is public or private (only for GFS)
19
-
20
- Returns:
21
- Xarray DataArray of the NWP data
22
86
  """
23
87
  provider = provider.lower()
24
88
 
25
89
  kwargs = {
26
90
  "zarr_path": zarr_path,
27
91
  }
28
-
29
92
  if provider == "ukv":
30
93
  _open_nwp = open_ukv
31
94
  elif provider in ["ecmwf", "mo_global"]:
@@ -34,14 +97,15 @@ def open_nwp(zarr_path: str | list[str], provider: str, public: bool = False) ->
34
97
  _open_nwp = open_icon_eu
35
98
  elif provider == "gfs":
36
99
  _open_nwp = open_gfs
37
-
38
100
  # GFS has a public/private flag
39
101
  if public:
40
102
  kwargs["public"] = True
41
-
42
103
  elif provider == "cloudcasting":
43
104
  _open_nwp = open_cloudcasting
44
105
  else:
45
106
  raise ValueError(f"Unknown provider: {provider}")
46
107
 
47
- return _open_nwp(**kwargs)
108
+ data_array = _open_nwp(**kwargs)
109
+ _validate_nwp_data(data_array, provider)
110
+
111
+ return data_array
@@ -6,25 +6,12 @@ from ocf_data_sampler.load.nwp.providers.utils import open_zarr_paths
6
6
  from ocf_data_sampler.load.utils import check_time_unique_increasing, make_spatial_coords_increasing
7
7
 
8
8
 
9
- def remove_isobaric_lelvels_from_coords(nwp: xr.Dataset) -> xr.Dataset:
10
- """Removes the isobaric levels from the coordinates of the NWP data.
11
-
12
- Args:
13
- nwp: NWP data
14
-
15
- Returns:
16
- NWP data without isobaric levels in the coordinates
17
- """
18
- variables_to_drop = [var for var in nwp.data_vars if "isobaricInhPa" in nwp[var].dims]
19
- return nwp.drop_vars(["isobaricInhPa", *variables_to_drop])
20
-
21
-
22
- def open_icon_eu(zarr_path: str | list[str]) -> xr.Dataset:
9
+ def open_icon_eu(zarr_path: str | list[str]) -> xr.DataArray:
23
10
  """Opens the ICON data.
24
11
 
25
- ICON EU Data is on a regular lat/lon grid
26
- It has data on multiple pressure levels, as well as the surface
27
- Each of the variables is its own data variable
12
+ ICON EU Data is now expected to be on a regular lat/lon grid,
13
+ with a 'channel' dimension directly available (as per the updated fixture).
14
+ The 'isobaricInhPa' dimension is expected to be already handled.
28
15
 
29
16
  Args:
30
17
  zarr_path: Path to the zarr(s) to open
@@ -32,15 +19,19 @@ def open_icon_eu(zarr_path: str | list[str]) -> xr.Dataset:
32
19
  Returns:
33
20
  Xarray DataArray of the NWP data
34
21
  """
35
- # Open the data
36
- nwp = open_zarr_paths(zarr_path, time_dim="time")
37
- nwp = nwp.rename({"time": "init_time_utc"})
38
- # Sanity checks.
22
+ # Open and check initially
23
+ ds = open_zarr_paths(zarr_path, time_dim="init_time_utc")
24
+
25
+ if "icon_eu_data" in ds.data_vars:
26
+ nwp = ds["icon_eu_data"]
27
+ else:
28
+ raise ValueError("Could not find 'icon_eu_data' DataArray in the ICON-EU Zarr file.")
29
+
39
30
  check_time_unique_increasing(nwp.init_time_utc)
31
+
40
32
  # 0-78 one hour steps, rest 3 hour steps
41
33
  nwp = nwp.isel(step=slice(0, 78))
42
- nwp = remove_isobaric_lelvels_from_coords(nwp)
43
- nwp = nwp.to_array().rename({"variable": "channel"})
44
34
  nwp = nwp.transpose("init_time_utc", "step", "channel", "longitude", "latitude")
45
35
  nwp = make_spatial_coords_increasing(nwp, x_coord="longitude", y_coord="latitude")
36
+
46
37
  return nwp
@@ -1,5 +1,5 @@
1
1
  """Satellite loader."""
2
-
2
+ import numpy as np
3
3
  import xarray as xr
4
4
 
5
5
  from ocf_data_sampler.load.utils import (
@@ -44,7 +44,7 @@ def get_single_sat_data(zarr_path: str) -> xr.Dataset:
44
44
 
45
45
 
46
46
  def open_sat_data(zarr_path: str | list[str]) -> xr.DataArray:
47
- """Lazily opens the zarr store.
47
+ """Lazily opens the zarr store and validates data types.
48
48
 
49
49
  Args:
50
50
  zarr_path: Cloud URL or local path pattern, or list of these. If GCS URL,
@@ -72,5 +72,22 @@ def open_sat_data(zarr_path: str | list[str]) -> xr.DataArray:
72
72
  ds = make_spatial_coords_increasing(ds, x_coord="x_geostationary", y_coord="y_geostationary")
73
73
  ds = ds.transpose("time_utc", "channel", "x_geostationary", "y_geostationary")
74
74
 
75
- # TODO: should we control the dtype of the DataArray?
76
- return get_xr_data_array_from_xr_dataset(ds)
75
+ data_array = get_xr_data_array_from_xr_dataset(ds)
76
+
77
+ # Validate data types directly loading function
78
+ if not np.issubdtype(data_array.dtype, np.number):
79
+ raise TypeError(f"Satellite data should be numeric, not {data_array.dtype}")
80
+
81
+ coord_dtypes = {
82
+ "time_utc": np.datetime64,
83
+ "channel": np.str_,
84
+ "x_geostationary": np.floating,
85
+ "y_geostationary": np.floating,
86
+ }
87
+
88
+ for coord, expected_dtype in coord_dtypes.items():
89
+ if not np.issubdtype(data_array.coords[coord].dtype, expected_dtype):
90
+ dtype = data_array.coords[coord].dtype
91
+ raise TypeError(f"{coord} should be {expected_dtype.__name__}, not {dtype}")
92
+
93
+ return data_array
@@ -16,7 +16,6 @@ def open_site(generation_file_path: str, metadata_file_path: str) -> xr.DataArra
16
16
  xr.DataArray: The opened site generation data
17
17
  """
18
18
  generation_ds = xr.open_dataset(generation_file_path)
19
-
20
19
  metadata_df = pd.read_csv(metadata_file_path, index_col="site_id")
21
20
 
22
21
  if not metadata_df.index.is_unique:
@@ -38,4 +37,23 @@ def open_site(generation_file_path: str, metadata_file_path: str) -> xr.DataArra
38
37
  if not (generation_ds.capacity_kwp.values > 0).all():
39
38
  raise ValueError("capacity_kwp contains non-positive values")
40
39
 
41
- return generation_ds.generation_kw
40
+ site_da = generation_ds.generation_kw
41
+
42
+ # Validate data types directly in loading function
43
+ if not np.issubdtype(site_da.dtype, np.floating):
44
+ raise TypeError(f"Generation data should be float, not {site_da.dtype}")
45
+
46
+ coord_dtypes = {
47
+ "time_utc": np.datetime64,
48
+ "site_id": np.integer,
49
+ "capacity_kwp": np.floating,
50
+ "latitude": np.floating,
51
+ "longitude": np.floating,
52
+ }
53
+
54
+ for coord, expected_dtype in coord_dtypes.items():
55
+ if not np.issubdtype(site_da.coords[coord].dtype, expected_dtype):
56
+ dtype = site_da.coords[coord].dtype
57
+ raise TypeError(f"{coord} should be {expected_dtype.__name__}, not {dtype}")
58
+
59
+ return site_da
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ocf-data-sampler
3
- Version: 0.2.32
3
+ Version: 0.2.34
4
4
  Author: James Fulton, Peter Dudfield
5
5
  Author-email: Open Climate Fix team <info@openclimatefix.org>
6
6
  License: MIT License
@@ -2,23 +2,23 @@ ocf_data_sampler/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,
2
2
  ocf_data_sampler/utils.py,sha256=DjuneGGisl08ENvPZV_lrcX4b2NCKJC1ZpXgIpxuQi4,290
3
3
  ocf_data_sampler/config/__init__.py,sha256=O29mbH0XG2gIY1g3BaveGCnpBO2SFqdu-qzJ7a6evl0,223
4
4
  ocf_data_sampler/config/load.py,sha256=LL-7wemI8o4KPkx35j-wQ3HjsMvDgqXr7G46IcASfnU,632
5
- ocf_data_sampler/config/model.py,sha256=UwVQOjRBthbwhAWR5Rcs5cSXG3imLZ5pnd8vBeFseLE,10623
5
+ ocf_data_sampler/config/model.py,sha256=xX2PPywEFGYDsx_j9DX1GlwMRq3ovJR-mhmysMt_mO0,11116
6
6
  ocf_data_sampler/config/save.py,sha256=m8SPw5rXjkMm1rByjh3pK5StdBi4e8ysnn3jQopdRaI,1064
7
7
  ocf_data_sampler/data/uk_gsp_locations_20220314.csv,sha256=RSh7DRh55E3n8lVAaWXGTaXXHevZZtI58td4d4DhGos,10415772
8
8
  ocf_data_sampler/data/uk_gsp_locations_20250109.csv,sha256=XZISFatnbpO9j8LwaxNKFzQSjs6hcHFsV8a9uDDpy2E,9055334
9
9
  ocf_data_sampler/load/__init__.py,sha256=-vQP9g0UOWdVbjEGyVX_ipa7R1btmiETIKAf6aw4d78,201
10
- ocf_data_sampler/load/gsp.py,sha256=winSW3ibFbpsOr0ZRIjYUlqSW5C6SUb0dxkRZm3E8GI,2195
10
+ ocf_data_sampler/load/gsp.py,sha256=IrTA6z9quN08imKGHJLf8gRktarxn1-utNMNFD0zWQs,2944
11
11
  ocf_data_sampler/load/load_dataset.py,sha256=WjB3DvHbDQQYYnPmDFOWg_TQPgARZ5pu8fiRZSGtIg0,2099
12
- ocf_data_sampler/load/satellite.py,sha256=E7Ln7Y60Qr1RTV-_R71YoxXQM-Ca7Y1faIo3oKB2eFk,2292
13
- ocf_data_sampler/load/site.py,sha256=zOzlWk6pYZBB5daqG8URGksmDXWKrkutUvN8uALAIh8,1468
12
+ ocf_data_sampler/load/satellite.py,sha256=Gsc3oyPydEZLy6slUDtIpBCYLxWy9P3pD1VyI4W9-2w,2944
13
+ ocf_data_sampler/load/site.py,sha256=WtOy20VMHJIY0IwEemCdcecSDUGcVaLUown-4ixJw90,2147
14
14
  ocf_data_sampler/load/utils.py,sha256=sZ0-zzconcLkVQwAkCYrqKDo98Hrh5ChdiQJv5Bh91g,2040
15
15
  ocf_data_sampler/load/nwp/__init__.py,sha256=SmcrnbygO5xtCKmGR4wtHrj-HI7nOAvnAtfuvRufBGQ,25
16
- ocf_data_sampler/load/nwp/nwp.py,sha256=S0wsxQHKa-OZchmvhTAOSG1HLtAVRTo_ElXAIwz1pXo,1332
16
+ ocf_data_sampler/load/nwp/nwp.py,sha256=1GCoIX0_KmEmIB99e_C9JDo6wlrvsF13oC5r3lrqYBo,3495
17
17
  ocf_data_sampler/load/nwp/providers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
18
18
  ocf_data_sampler/load/nwp/providers/cloudcasting.py,sha256=fozXpB3a2rNqQgnpRDC7xunxffh7Wwmc0kkCiYmDVJ4,1521
19
19
  ocf_data_sampler/load/nwp/providers/ecmwf.py,sha256=an-gXsZwkPQvRXeza1U_4MNU5yEnVm0_8tn03rxTudI,997
20
20
  ocf_data_sampler/load/nwp/providers/gfs.py,sha256=glBbo2kXtcTjQv_VNqA32lsdCCGB114Ovm-cibRWxTA,1088
21
- ocf_data_sampler/load/nwp/providers/icon.py,sha256=6MkOfUk5dmv0XJZLrKMy1e8xipj2fHCTkYXuff7MgUY,1584
21
+ ocf_data_sampler/load/nwp/providers/icon.py,sha256=BnY3vAa5pHn1cyrImj0ymaRFKHanNtfD9_JO-4p2IZY,1241
22
22
  ocf_data_sampler/load/nwp/providers/ukv.py,sha256=Ka1KFZcJYPwr5vuxo-xWGVQC0pudheqGBonUnbyJCMg,1016
23
23
  ocf_data_sampler/load/nwp/providers/utils.py,sha256=NrzE3JAtoc6oEywJHxPUdi_I4UJgJ_l5dxLZ4DLKvcg,1124
24
24
  ocf_data_sampler/numpy_sample/__init__.py,sha256=nY5C6CcuxiWZ_jrXRzWtN7WyKXhJImSiVTIG6Rz4B_4,401
@@ -55,7 +55,7 @@ ocf_data_sampler/torch_datasets/utils/validation_utils.py,sha256=YqmT-lExWlI8_ul
55
55
  scripts/download_gsp_location_data.py,sha256=rRDXMoqX-RYY4jPdxhdlxJGhWdl6r245F5UARgKV6P4,3121
56
56
  scripts/refactor_site.py,sha256=skzvsPP0Cn9yTKndzkilyNcGz4DZ88ctvCJ0XrBdc2A,3135
57
57
  utils/compute_icon_mean_stddev.py,sha256=a1oWMRMnny39rV-dvu8rcx85sb4bXzPFrR1gkUr4Jpg,2296
58
- ocf_data_sampler-0.2.32.dist-info/METADATA,sha256=NZ29xNAjkgJgfpfnvHqjh3cn7aU8vlzOVBjjab9Ksj4,12184
59
- ocf_data_sampler-0.2.32.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
60
- ocf_data_sampler-0.2.32.dist-info/top_level.txt,sha256=LEFU4Uk-PEo72QGLAfnVZIUEm37Q8mKuMeg_Xk-p33g,31
61
- ocf_data_sampler-0.2.32.dist-info/RECORD,,
58
+ ocf_data_sampler-0.2.34.dist-info/METADATA,sha256=C6ux4M-XRmQsnG-yLwRYVdkKW6qPxP47KLGACzJRIss,12184
59
+ ocf_data_sampler-0.2.34.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
60
+ ocf_data_sampler-0.2.34.dist-info/top_level.txt,sha256=LEFU4Uk-PEo72QGLAfnVZIUEm37Q8mKuMeg_Xk-p33g,31
61
+ ocf_data_sampler-0.2.34.dist-info/RECORD,,