ocf-data-sampler 0.1.11__py3-none-any.whl → 0.1.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ocf-data-sampler might be problematic; see the package's registry page for more details.

Files changed (76):
  1. ocf_data_sampler/config/load.py +3 -3
  2. ocf_data_sampler/config/model.py +73 -61
  3. ocf_data_sampler/config/save.py +5 -4
  4. ocf_data_sampler/constants.py +140 -12
  5. ocf_data_sampler/load/gsp.py +6 -5
  6. ocf_data_sampler/load/load_dataset.py +5 -6
  7. ocf_data_sampler/load/nwp/nwp.py +17 -5
  8. ocf_data_sampler/load/nwp/providers/ecmwf.py +6 -7
  9. ocf_data_sampler/load/nwp/providers/gfs.py +36 -0
  10. ocf_data_sampler/load/nwp/providers/icon.py +46 -0
  11. ocf_data_sampler/load/nwp/providers/ukv.py +4 -5
  12. ocf_data_sampler/load/nwp/providers/utils.py +3 -1
  13. ocf_data_sampler/load/satellite.py +9 -10
  14. ocf_data_sampler/load/site.py +10 -6
  15. ocf_data_sampler/load/utils.py +21 -16
  16. ocf_data_sampler/numpy_sample/collate.py +10 -9
  17. ocf_data_sampler/numpy_sample/datetime_features.py +3 -5
  18. ocf_data_sampler/numpy_sample/gsp.py +12 -14
  19. ocf_data_sampler/numpy_sample/nwp.py +12 -12
  20. ocf_data_sampler/numpy_sample/satellite.py +9 -9
  21. ocf_data_sampler/numpy_sample/site.py +5 -8
  22. ocf_data_sampler/numpy_sample/sun_position.py +16 -21
  23. ocf_data_sampler/sample/base.py +15 -17
  24. ocf_data_sampler/sample/site.py +13 -20
  25. ocf_data_sampler/sample/uk_regional.py +29 -35
  26. ocf_data_sampler/select/dropout.py +16 -14
  27. ocf_data_sampler/select/fill_time_periods.py +15 -5
  28. ocf_data_sampler/select/find_contiguous_time_periods.py +88 -75
  29. ocf_data_sampler/select/geospatial.py +63 -54
  30. ocf_data_sampler/select/location.py +16 -51
  31. ocf_data_sampler/select/select_spatial_slice.py +105 -89
  32. ocf_data_sampler/select/select_time_slice.py +71 -58
  33. ocf_data_sampler/select/spatial_slice_for_dataset.py +7 -6
  34. ocf_data_sampler/select/time_slice_for_dataset.py +17 -16
  35. ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py +126 -118
  36. ocf_data_sampler/torch_datasets/datasets/site.py +135 -101
  37. ocf_data_sampler/torch_datasets/utils/merge_and_fill_utils.py +6 -2
  38. ocf_data_sampler/torch_datasets/utils/valid_time_periods.py +23 -22
  39. ocf_data_sampler/torch_datasets/utils/validate_channels.py +23 -19
  40. ocf_data_sampler/utils.py +3 -1
  41. {ocf_data_sampler-0.1.11.dist-info → ocf_data_sampler-0.1.16.dist-info}/METADATA +7 -18
  42. ocf_data_sampler-0.1.16.dist-info/RECORD +56 -0
  43. {ocf_data_sampler-0.1.11.dist-info → ocf_data_sampler-0.1.16.dist-info}/WHEEL +1 -1
  44. {ocf_data_sampler-0.1.11.dist-info → ocf_data_sampler-0.1.16.dist-info}/top_level.txt +1 -1
  45. scripts/refactor_site.py +62 -33
  46. utils/compute_icon_mean_stddev.py +72 -0
  47. ocf_data_sampler-0.1.11.dist-info/LICENSE +0 -21
  48. ocf_data_sampler-0.1.11.dist-info/RECORD +0 -82
  49. tests/__init__.py +0 -0
  50. tests/config/test_config.py +0 -113
  51. tests/config/test_load.py +0 -7
  52. tests/config/test_save.py +0 -28
  53. tests/conftest.py +0 -319
  54. tests/load/test_load_gsp.py +0 -15
  55. tests/load/test_load_nwp.py +0 -21
  56. tests/load/test_load_satellite.py +0 -17
  57. tests/load/test_load_sites.py +0 -14
  58. tests/numpy_sample/test_collate.py +0 -21
  59. tests/numpy_sample/test_datetime_features.py +0 -37
  60. tests/numpy_sample/test_gsp.py +0 -38
  61. tests/numpy_sample/test_nwp.py +0 -13
  62. tests/numpy_sample/test_satellite.py +0 -40
  63. tests/numpy_sample/test_sun_position.py +0 -81
  64. tests/select/test_dropout.py +0 -69
  65. tests/select/test_fill_time_periods.py +0 -28
  66. tests/select/test_find_contiguous_time_periods.py +0 -202
  67. tests/select/test_location.py +0 -67
  68. tests/select/test_select_spatial_slice.py +0 -154
  69. tests/select/test_select_time_slice.py +0 -275
  70. tests/test_sample/test_base.py +0 -164
  71. tests/test_sample/test_site_sample.py +0 -165
  72. tests/test_sample/test_uk_regional_sample.py +0 -136
  73. tests/torch_datasets/test_merge_and_fill_utils.py +0 -40
  74. tests/torch_datasets/test_pvnet_uk.py +0 -154
  75. tests/torch_datasets/test_site.py +0 -226
  76. tests/torch_datasets/test_validate_channels_utils.py +0 -78
@@ -1,36 +1,45 @@
1
- """Geospatial functions"""
1
+ """Geospatial coordinate transformation functions.
2
2
 
3
- from numbers import Number
4
- from typing import Union
3
+ Provides utilities for working with different coordinate systems
4
+ commonly used in geospatial applications, particularly for UK-based data.
5
+
6
+ Supports conversions between:
7
+ - OSGB36 (Ordnance Survey Great Britain, easting/northing in meters)
8
+ - WGS84 (World Geodetic System, latitude/longitude in degrees)
9
+ - Geostationary satellite coordinate systems
10
+ """
5
11
 
6
12
  import numpy as np
7
13
  import pyproj
14
+ import pyresample
8
15
  import xarray as xr
9
16
 
10
- # OSGB is also called "OSGB 1936 / British National Grid -- United
11
- # Kingdom Ordnance Survey". OSGB is used in many UK electricity
12
- # system maps, and is used by the UK Met Office UKV model. OSGB is a
13
- # Transverse Mercator projection, using 'easting' and 'northing'
14
- # coordinates which are in meters. See https://epsg.io/27700
17
+ # Coordinate Reference System (CRS) identifiers
18
+ # OSGB36: UK Ordnance Survey National Grid (easting/northing in meters)
19
+ # Refer to - https://epsg.io/27700
15
20
  OSGB36 = 27700
16
21
 
17
- # WGS84 is short for "World Geodetic System 1984", used in GPS. Uses
18
- # latitude and longitude.
22
+ # WGS84: World Geodetic System 1984 (latitude/longitude in degrees), used in GPS
19
23
  WGS84 = 4326
20
24
 
21
-
25
+ # Pre-init Transformer
22
26
  _osgb_to_lon_lat = pyproj.Transformer.from_crs(
23
- crs_from=OSGB36, crs_to=WGS84, always_xy=True
27
+ crs_from=OSGB36,
28
+ crs_to=WGS84,
29
+ always_xy=True,
24
30
  ).transform
25
31
  _lon_lat_to_osgb = pyproj.Transformer.from_crs(
26
- crs_from=WGS84, crs_to=OSGB36, always_xy=True
32
+ crs_from=WGS84,
33
+ crs_to=OSGB36,
34
+ always_xy=True,
27
35
  ).transform
28
36
 
29
37
 
30
38
  def osgb_to_lon_lat(
31
- x: Union[Number, np.ndarray], y: Union[Number, np.ndarray]
32
- ) -> tuple[Union[Number, np.ndarray], Union[Number, np.ndarray]]:
33
- """Change OSGB coordinates to lon, lat.
39
+ x: float | np.ndarray,
40
+ y: float | np.ndarray,
41
+ ) -> tuple[float | np.ndarray, float | np.ndarray]:
42
+ """Change OSGB coordinates to lon-lat.
34
43
 
35
44
  Args:
36
45
  x: osgb east-west
@@ -41,9 +50,9 @@ def osgb_to_lon_lat(
41
50
 
42
51
 
43
52
  def lon_lat_to_osgb(
44
- x: Union[Number, np.ndarray],
45
- y: Union[Number, np.ndarray],
46
- ) -> tuple[Union[Number, np.ndarray], Union[Number, np.ndarray]]:
53
+ x: float | np.ndarray,
54
+ y: float | np.ndarray,
55
+ ) -> tuple[float | np.ndarray, float | np.ndarray]:
47
56
  """Change lon-lat coordinates to OSGB.
48
57
 
49
58
  Args:
@@ -56,11 +65,11 @@ def lon_lat_to_osgb(
56
65
 
57
66
 
58
67
  def lon_lat_to_geostationary_area_coords(
59
- longitude: Union[Number, np.ndarray],
60
- latitude: Union[Number, np.ndarray],
68
+ longitude: float | np.ndarray,
69
+ latitude: float | np.ndarray,
61
70
  xr_data: xr.DataArray,
62
- ) -> tuple[Union[Number, np.ndarray], Union[Number, np.ndarray]]:
63
- """Loads geostationary area and transformation from lat-lon to geostationary coords
71
+ ) -> tuple[float | np.ndarray, float | np.ndarray]:
72
+ """Loads geostationary area and transformation from lat-lon to geostationary coords.
64
73
 
65
74
  Args:
66
75
  longitude: longitude
@@ -72,12 +81,13 @@ def lon_lat_to_geostationary_area_coords(
72
81
  """
73
82
  return coordinates_to_geostationary_area_coords(longitude, latitude, xr_data, WGS84)
74
83
 
84
+
75
85
  def osgb_to_geostationary_area_coords(
76
- x: Union[Number, np.ndarray],
77
- y: Union[Number, np.ndarray],
86
+ x: float | np.ndarray,
87
+ y: float | np.ndarray,
78
88
  xr_data: xr.DataArray,
79
- ) -> tuple[Union[Number, np.ndarray], Union[Number, np.ndarray]]:
80
- """Loads geostationary area and transformation from OSGB to geostationary coords
89
+ ) -> tuple[float | np.ndarray, float | np.ndarray]:
90
+ """Loads geostationary area and transformation from OSGB to geostationary coords.
81
91
 
82
92
  Args:
83
93
  x: osgb east-west
@@ -87,47 +97,45 @@ def osgb_to_geostationary_area_coords(
87
97
  Returns:
88
98
  Geostationary coords: x, y
89
99
  """
90
-
91
100
  return coordinates_to_geostationary_area_coords(x, y, xr_data, OSGB36)
92
101
 
93
102
 
94
-
95
103
  def coordinates_to_geostationary_area_coords(
96
- x: Union[Number, np.ndarray],
97
- y: Union[Number, np.ndarray],
104
+ x: float | np.ndarray,
105
+ y: float | np.ndarray,
98
106
  xr_data: xr.DataArray,
99
- crs_from: int
100
- ) -> tuple[Union[Number, np.ndarray], Union[Number, np.ndarray]]:
101
- """Loads geostationary area and transformation from respective coordiates to geostationary coords
102
-
103
- Args:
104
- x: osgb east-west, or latitude
105
- y: osgb north-south, or longitude
106
- xr_data: xarray object with geostationary area
107
- crs_from: the cordiates system of x,y
107
+ crs_from: int,
108
+ ) -> tuple[float | np.ndarray, float | np.ndarray]:
109
+ """Loads geostationary area and transforms to geostationary coords.
108
110
 
109
- Returns:
110
- Geostationary coords: x, y
111
- """
112
-
113
- assert crs_from in [OSGB36, WGS84], f"Unrecognized coordinate system: {crs_from}"
111
+ Args:
112
+ x: osgb east-west, or latitude
113
+ y: osgb north-south, or longitude
114
+ xr_data: xarray object with geostationary area
115
+ crs_from: the cordiates system of x,y
114
116
 
115
- # Only load these if using geostationary projection
116
- import pyresample
117
+ Returns:
118
+ Geostationary coords: x, y
119
+ """
120
+ if crs_from not in [OSGB36, WGS84]:
121
+ raise ValueError(f"Unrecognized coordinate system: {crs_from}")
117
122
 
118
123
  area_definition_yaml = xr_data.attrs["area"]
119
124
 
120
125
  geostationary_area_definition = pyresample.area_config.load_area_from_string(
121
- area_definition_yaml
126
+ area_definition_yaml,
122
127
  )
123
128
  geostationary_crs = geostationary_area_definition.crs
124
129
  osgb_to_geostationary = pyproj.Transformer.from_crs(
125
- crs_from=crs_from, crs_to=geostationary_crs, always_xy=True
130
+ crs_from=crs_from,
131
+ crs_to=geostationary_crs,
132
+ always_xy=True,
126
133
  ).transform
127
134
  return osgb_to_geostationary(xx=x, yy=y)
128
135
 
129
136
 
130
- def _coord_priority(available_coords):
137
+ def _coord_priority(available_coords: list[str]) -> tuple[str, str, str]:
138
+ """Determines the coordinate system of spatial coordinates present."""
131
139
  if "longitude" in available_coords:
132
140
  return "lon_lat", "longitude", "latitude"
133
141
  elif "x_geostationary" in available_coords:
@@ -138,7 +146,7 @@ def _coord_priority(available_coords):
138
146
  raise ValueError(f"Unrecognized coordinate system: {available_coords}")
139
147
 
140
148
 
141
- def spatial_coord_type(ds: xr.DataArray):
149
+ def spatial_coord_type(ds: xr.DataArray) -> tuple[str, str, str]:
142
150
  """Searches the data array to determine the kind of spatial coordinates present.
143
151
 
144
152
  This search has a preference for the dimension coordinates of the xarray object.
@@ -147,9 +155,10 @@ def spatial_coord_type(ds: xr.DataArray):
147
155
  ds: Dataset with spatial coords
148
156
 
149
157
  Returns:
150
- str: The kind of the coordinate system
151
- x_coord: Name of the x-coordinate
152
- y_coord: Name of the y-coordinate
158
+ Three strings with:
159
+ 1. The kind of the coordinate system
160
+ 2. Name of the x-coordinate
161
+ 3. Name of the y-coordinate
153
162
  """
154
163
  if isinstance(ds, xr.DataArray):
155
164
  # Search dimension coords of dataarray
@@ -1,62 +1,27 @@
1
- """location"""
1
+ """Location model with coordinate system validation."""
2
2
 
3
- from typing import Optional
4
-
5
- import numpy as np
6
3
  from pydantic import BaseModel, Field, model_validator
7
4
 
5
+ allowed_coordinate_systems = ["osgb", "lon_lat", "geostationary", "idx"]
8
6
 
9
- allowed_coordinate_systems =["osgb", "lon_lat", "geostationary", "idx"]
10
7
 
11
8
  class Location(BaseModel):
12
9
  """Represent a spatial location."""
13
10
 
14
- coordinate_system: Optional[str] = "osgb" # ["osgb", "lon_lat", "geostationary", "idx"]
15
- x: float
16
- y: float
17
- id: Optional[int] = Field(None)
18
-
19
- @model_validator(mode='after')
20
- def validate_coordinate_system(self):
21
- """Validate 'coordinate_system'"""
22
- if self.coordinate_system not in allowed_coordinate_systems:
23
- raise ValueError(f"coordinate_system = {self.coordinate_system} is not in {allowed_coordinate_systems}")
24
- return self
25
-
26
- @model_validator(mode='after')
27
- def validate_x(self):
28
- """Validate 'x'"""
29
- min_x: float
30
- max_x: float
31
-
32
- co = self.coordinate_system
33
- if co == "osgb":
34
- min_x, max_x = -103976.3, 652897.98
35
- if co == "lon_lat":
36
- min_x, max_x = -180, 180
37
- if co == "geostationary":
38
- min_x, max_x = -5568748.275756836, 5567248.074173927
39
- if co == "idx":
40
- min_x, max_x = 0, np.inf
41
- if self.x < min_x or self.x > max_x:
42
- raise ValueError(f"x = {self.x} must be within {[min_x, max_x]} for {co} coordinate system")
43
- return self
11
+ coordinate_system: str = Field(...,
12
+ description="Coordinate system for the location must be lon_lat, osgb, or geostationary",
13
+ )
44
14
 
45
- @model_validator(mode='after')
46
- def validate_y(self):
47
- """Validate 'y'"""
48
- min_y: float
49
- max_y: float
15
+ x: float = Field(..., description="x coordinate - i.e. east-west position")
16
+ y: float = Field(..., description="y coordinate - i.e. north-south position")
17
+ id: int | None = Field(None, description="ID of the location - e.g. GSP ID")
50
18
 
51
- co = self.coordinate_system
52
- if co == "osgb":
53
- min_y, max_y = -16703.87, 1199851.44
54
- if co == "lon_lat":
55
- min_y, max_y = -90, 90
56
- if co == "geostationary":
57
- min_y, max_y = 1393687.2151494026, 5570748.323202133
58
- if co == "idx":
59
- min_y, max_y = 0, np.inf
60
- if self.y < min_y or self.y > max_y:
61
- raise ValueError(f"y = {self.y} must be within {[min_y, max_y]} for {co} coordinate system")
19
+ @model_validator(mode="after")
20
+ def validate_coordinate_system(self) -> "Location":
21
+ """Validate 'coordinate_system'."""
22
+ if self.coordinate_system not in allowed_coordinate_systems:
23
+ raise ValueError(
24
+ f"coordinate_system = {self.coordinate_system} "
25
+ f"is not in {allowed_coordinate_systems}",
26
+ )
62
27
  return self
@@ -1,19 +1,18 @@
1
- """Select spatial slices"""
1
+ """Select spatial slices."""
2
2
 
3
3
  import logging
4
4
 
5
5
  import numpy as np
6
6
  import xarray as xr
7
7
 
8
- from ocf_data_sampler.select.location import Location
9
8
  from ocf_data_sampler.select.geospatial import (
10
- lon_lat_to_osgb,
11
9
  lon_lat_to_geostationary_area_coords,
10
+ lon_lat_to_osgb,
12
11
  osgb_to_geostationary_area_coords,
13
12
  osgb_to_lon_lat,
14
13
  spatial_coord_type,
15
14
  )
16
-
15
+ from ocf_data_sampler.select.location import Location
17
16
 
18
17
  logger = logging.getLogger(__name__)
19
18
 
@@ -22,12 +21,12 @@ logger = logging.getLogger(__name__)
22
21
 
23
22
 
24
23
  def convert_coords_to_match_xarray(
25
- x: float | np.ndarray,
26
- y: float | np.ndarray,
27
- from_coords: str,
28
- da: xr.DataArray
29
- ):
30
- """Convert x and y coords to cooridnate system matching xarray data
24
+ x: float | np.ndarray,
25
+ y: float | np.ndarray,
26
+ from_coords: str,
27
+ da: xr.DataArray,
28
+ ) -> tuple[float | np.ndarray, float | np.ndarray]:
29
+ """Convert x and y coords to cooridnate system matching xarray data.
31
30
 
32
31
  Args:
33
32
  x: Float or array-like
@@ -35,38 +34,42 @@ def convert_coords_to_match_xarray(
35
34
  from_coords: String describing coordinate system of x and y
36
35
  da: DataArray to which coordinates should be matched
37
36
  """
38
-
39
37
  target_coords, *_ = spatial_coord_type(da)
40
38
 
41
- assert from_coords in ["osgb", "lon_lat"]
42
- assert target_coords in ["geostationary", "osgb", "lon_lat"]
43
-
44
- if target_coords == "geostationary":
45
- if from_coords == "osgb":
39
+ match (from_coords, target_coords):
40
+ case ("osgb", "geostationary"):
46
41
  x, y = osgb_to_geostationary_area_coords(x, y, da)
47
42
 
48
- elif target_coords == "lon_lat":
49
- if from_coords == "osgb":
43
+ case ("osgb", "lon_lat"):
50
44
  x, y = osgb_to_lon_lat(x, y)
51
45
 
52
- # else the from_coords=="lon_lat" and we don't need to convert
46
+ case ("osgb", "osgb"):
47
+ pass
53
48
 
54
- elif target_coords == "osgb":
55
- if from_coords == "lon_lat":
49
+ case ("lon_lat", "osgb"):
56
50
  x, y = lon_lat_to_osgb(x, y)
57
51
 
58
- # else the from_coords=="osgb" and we don't need to convert
52
+ case ("lon_lat", "geostationary"):
53
+ x, y = lon_lat_to_geostationary_area_coords(x, y, da)
54
+
55
+ case ("lon_lat", "lon_lat"):
56
+ pass
57
+
58
+ case (_, _):
59
+ raise NotImplementedError(
60
+ f"Conversion from {from_coords} to {target_coords} is not supported",
61
+ )
59
62
 
60
63
  return x, y
61
64
 
62
- # TODO: This function and _get_idx_of_pixel_closest_to_poi_geostationary() should not be separate
65
+
66
+ # TODO: This function and _get_idx_of_pixel_closest_to_poi_geostationary() should not be separate
63
67
  # We should combine them, and consider making a Coord class to help with this
64
68
  def _get_idx_of_pixel_closest_to_poi(
65
69
  da: xr.DataArray,
66
70
  location: Location,
67
71
  ) -> Location:
68
- """
69
- Return x and y index location of pixel at center of region of interest.
72
+ """Return x and y index location of pixel at center of region of interest.
70
73
 
71
74
  Args:
72
75
  da: xarray DataArray
@@ -88,8 +91,14 @@ def _get_idx_of_pixel_closest_to_poi(
88
91
  )
89
92
 
90
93
  # Check that the requested point lies within the data
91
- assert da[x_dim].min() < x < da[x_dim].max()
92
- assert da[y_dim].min() < y < da[y_dim].max()
94
+ if not (da[x_dim].min() < x < da[x_dim].max()):
95
+ raise ValueError(
96
+ f"{x} is not in the interval {da[x_dim].min().values}: {da[x_dim].max().values}",
97
+ )
98
+ if not (da[y_dim].min() < y < da[y_dim].max()):
99
+ raise ValueError(
100
+ f"{y} is not in the interval {da[y_dim].min().values}: {da[y_dim].max().values}",
101
+ )
93
102
 
94
103
  x_index = da.get_index(x_dim)
95
104
  y_index = da.get_index(y_dim)
@@ -104,32 +113,38 @@ def _get_idx_of_pixel_closest_to_poi_geostationary(
104
113
  da: xr.DataArray,
105
114
  center: Location,
106
115
  ) -> Location:
107
- """
108
- Return x and y index location of pixel at center of region of interest.
116
+ """Return x and y index location of pixel at center of region of interest.
109
117
 
110
118
  Args:
111
119
  da: xarray DataArray
112
- center_osgb: Center in OSGB coordinates
120
+ center: Center in OSGB coordinates
113
121
 
114
122
  Returns:
115
123
  Location for the center pixel in geostationary coordinates
116
124
  """
117
-
118
125
  _, x_dim, y_dim = spatial_coord_type(da)
119
126
 
120
- if center.coordinate_system == 'osgb':
127
+ if center.coordinate_system == "osgb":
121
128
  x, y = osgb_to_geostationary_area_coords(x=center.x, y=center.y, xr_data=da)
122
- elif center.coordinate_system == 'lon_lat':
123
- x, y = lon_lat_to_geostationary_area_coords(longitude=center.x, latitude=center.y, xr_data=da)
129
+ elif center.coordinate_system == "lon_lat":
130
+ x, y = lon_lat_to_geostationary_area_coords(
131
+ longitude=center.x,
132
+ latitude=center.y,
133
+ xr_data=da,
134
+ )
124
135
  else:
125
- x,y = center.x, center.y
136
+ x, y = center.x, center.y
126
137
  center_geostationary = Location(x=x, y=y, coordinate_system="geostationary")
127
138
 
128
139
  # Check that the requested point lies within the data
129
- assert da[x_dim].min() < x < da[x_dim].max(), \
130
- f"{x} is not in the interval {da[x_dim].min().values}: {da[x_dim].max().values}"
131
- assert da[y_dim].min() < y < da[y_dim].max(), \
132
- f"{y} is not in the interval {da[y_dim].min().values}: {da[y_dim].max().values}"
140
+ if not (da[x_dim].min() < x < da[x_dim].max()):
141
+ raise ValueError(
142
+ f"{x} is not in the interval {da[x_dim].min().values}: {da[x_dim].max().values}",
143
+ )
144
+ if not (da[y_dim].min() < y < da[y_dim].max()):
145
+ raise ValueError(
146
+ f"{y} is not in the interval {da[y_dim].min().values}: {da[y_dim].max().values}",
147
+ )
133
148
 
134
149
  # Get the index into x and y nearest to x_center_geostationary and y_center_geostationary:
135
150
  x_index_at_center = np.searchsorted(da[x_dim].values, center_geostationary.x)
@@ -142,24 +157,25 @@ def _get_idx_of_pixel_closest_to_poi_geostationary(
142
157
 
143
158
 
144
159
  def _select_partial_spatial_slice_pixels(
145
- da,
146
- left_idx,
147
- right_idx,
148
- bottom_idx,
149
- top_idx,
150
- left_pad_pixels,
151
- right_pad_pixels,
152
- bottom_pad_pixels,
153
- top_pad_pixels,
154
- x_dim,
155
- y_dim,
156
- ):
157
- """Return spatial window of given pixel size when window partially overlaps input data"""
158
-
159
- # We should never be padding on both sides of a window. This would mean our desired window is
160
+ da: xr.DataArray,
161
+ left_idx: int,
162
+ right_idx: int,
163
+ bottom_idx: int,
164
+ top_idx: int,
165
+ left_pad_pixels: int,
166
+ right_pad_pixels: int,
167
+ bottom_pad_pixels: int,
168
+ top_pad_pixels: int,
169
+ x_dim: str,
170
+ y_dim: str,
171
+ ) -> xr.DataArray:
172
+ """Return spatial window of given pixel size when window partially overlaps input data."""
173
+ # We should never be padding on both sides of a window. This would mean our desired window is
160
174
  # larger than the size of the input data
161
- assert left_pad_pixels==0 or right_pad_pixels==0
162
- assert bottom_pad_pixels==0 or top_pad_pixels==0
175
+ if (left_pad_pixels != 0 and right_pad_pixels != 0) or (
176
+ bottom_pad_pixels != 0 and top_pad_pixels != 0
177
+ ):
178
+ raise ValueError("Cannot pad both sides of the window")
163
179
 
164
180
  dx = np.median(np.diff(da[x_dim].values))
165
181
  dy = np.median(np.diff(da[y_dim].values))
@@ -170,7 +186,7 @@ def _select_partial_spatial_slice_pixels(
170
186
  [
171
187
  da[x_dim].values[0] + np.arange(-left_pad_pixels, 0) * dx,
172
188
  da[x_dim].values[0:right_idx],
173
- ]
189
+ ],
174
190
  )
175
191
  da = da.isel({x_dim: slice(0, right_idx)}).reindex({x_dim: x_sel})
176
192
 
@@ -180,7 +196,7 @@ def _select_partial_spatial_slice_pixels(
180
196
  [
181
197
  da[x_dim].values[left_idx:],
182
198
  da[x_dim].values[-1] + np.arange(1, right_pad_pixels + 1) * dx,
183
- ]
199
+ ],
184
200
  )
185
201
  da = da.isel({x_dim: slice(left_idx, None)}).reindex({x_dim: x_sel})
186
202
 
@@ -194,7 +210,7 @@ def _select_partial_spatial_slice_pixels(
194
210
  [
195
211
  da[y_dim].values[0] + np.arange(-bottom_pad_pixels, 0) * dy,
196
212
  da[y_dim].values[0:top_idx],
197
- ]
213
+ ],
198
214
  )
199
215
  da = da.isel({y_dim: slice(0, top_idx)}).reindex({y_dim: y_sel})
200
216
 
@@ -204,7 +220,7 @@ def _select_partial_spatial_slice_pixels(
204
220
  [
205
221
  da[y_dim].values[bottom_idx:],
206
222
  da[y_dim].values[-1] + np.arange(1, top_pad_pixels + 1) * dy,
207
- ]
223
+ ],
208
224
  )
209
225
  da = da.isel({y_dim: slice(left_idx, None)}).reindex({y_dim: y_sel})
210
226
 
@@ -216,15 +232,15 @@ def _select_partial_spatial_slice_pixels(
216
232
 
217
233
 
218
234
  def _select_spatial_slice_pixels(
219
- da: xr.DataArray,
220
- center_idx: Location,
221
- width_pixels: int,
222
- height_pixels: int,
223
- x_dim: str,
224
- y_dim: str,
235
+ da: xr.DataArray,
236
+ center_idx: Location,
237
+ width_pixels: int,
238
+ height_pixels: int,
239
+ x_dim: str,
240
+ y_dim: str,
225
241
  allow_partial_slice: bool,
226
- ):
227
- """Select a spatial slice from an xarray object
242
+ ) -> xr.DataArray:
243
+ """Select a spatial slice from an xarray object.
228
244
 
229
245
  Args:
230
246
  da: xarray DataArray to slice from
@@ -235,11 +251,13 @@ def _select_spatial_slice_pixels(
235
251
  y_dim: Name of the y-dimension in `da`
236
252
  allow_partial_slice: Whether to allow a partially filled window
237
253
  """
238
-
239
- assert center_idx.coordinate_system == "idx"
254
+ if center_idx.coordinate_system != "idx":
255
+ raise ValueError(f"Expected center_idx to be in 'idx' coordinates, got '{center_idx}'")
240
256
  # TODO: It shouldn't take much effort to allow height and width to be odd
241
- assert (width_pixels % 2)==0, "Width must be an even number"
242
- assert (height_pixels % 2)==0, "Height must be an even number"
257
+ if (width_pixels % 2) != 0:
258
+ raise ValueError("Width must be an even number")
259
+ if (height_pixels % 2) != 0:
260
+ raise ValueError("Height must be an even number")
243
261
 
244
262
  half_width = width_pixels // 2
245
263
  half_height = height_pixels // 2
@@ -261,14 +279,12 @@ def _select_spatial_slice_pixels(
261
279
 
262
280
  if pad_required:
263
281
  if allow_partial_slice:
264
-
265
282
  left_pad_pixels = (-left_idx) if left_pad_required else 0
266
283
  right_pad_pixels = (right_idx - data_width_pixels) if right_pad_required else 0
267
284
 
268
285
  bottom_pad_pixels = (-bottom_idx) if bottom_pad_required else 0
269
286
  top_pad_pixels = (top_idx - data_height_pixels) if top_pad_required else 0
270
287
 
271
-
272
288
  da = _select_partial_spatial_slice_pixels(
273
289
  da,
274
290
  left_idx,
@@ -287,7 +303,7 @@ def _select_spatial_slice_pixels(
287
303
  f"Window for location {center_idx} not available. Missing (left, right, bottom, "
288
304
  f"top) pixels = ({left_pad_required}, {right_pad_required}, "
289
305
  f"{bottom_pad_required}, {top_pad_required}). "
290
- f"You may wish to set `allow_partial_slice=True`"
306
+ f"You may wish to set `allow_partial_slice=True`",
291
307
  )
292
308
 
293
309
  else:
@@ -295,17 +311,19 @@ def _select_spatial_slice_pixels(
295
311
  {
296
312
  x_dim: slice(left_idx, right_idx),
297
313
  y_dim: slice(bottom_idx, top_idx),
298
- }
314
+ },
299
315
  )
300
316
 
301
- assert len(da[x_dim]) == width_pixels, (
302
- f"Expected x-dim len {width_pixels} got {len(da[x_dim])} "
303
- f"for location {center_idx} for slice {left_idx}:{right_idx}"
304
- )
305
- assert len(da[y_dim]) == height_pixels, (
306
- f"Expected y-dim len {height_pixels} got {len(da[y_dim])} "
307
- f"for location {center_idx} for slice {bottom_idx}:{top_idx}"
308
- )
317
+ if len(da[x_dim]) != width_pixels:
318
+ raise ValueError(
319
+ f"Expected x-dim len {width_pixels} got {len(da[x_dim])} "
320
+ f"for location {center_idx} for slice {left_idx}:{right_idx}",
321
+ )
322
+ if len(da[y_dim]) != height_pixels:
323
+ raise ValueError(
324
+ f"Expected y-dim len {height_pixels} got {len(da[y_dim])} "
325
+ f"for location {center_idx} for slice {bottom_idx}:{top_idx}",
326
+ )
309
327
 
310
328
  return da
311
329
 
@@ -319,9 +337,8 @@ def select_spatial_slice_pixels(
319
337
  width_pixels: int,
320
338
  height_pixels: int,
321
339
  allow_partial_slice: bool = False,
322
- ):
323
- """
324
- Select spatial slice based off pixels from location point of interest
340
+ ) -> xr.DataArray:
341
+ """Select spatial slice based off pixels from location point of interest.
325
342
 
326
343
  If `allow_partial_slice` is set to True, then slices may be made which intersect the border
327
344
  of the input data. The additional x and y cordinates that would be required for this slice
@@ -336,7 +353,6 @@ def select_spatial_slice_pixels(
336
353
  width_pixels: Width of the slice in pixels
337
354
  allow_partial_slice: Whether to allow a partial slice.
338
355
  """
339
-
340
356
  xr_coords, x_dim, y_dim = spatial_coord_type(da)
341
357
 
342
358
  if xr_coords == "geostationary":
@@ -354,4 +370,4 @@ def select_spatial_slice_pixels(
354
370
  allow_partial_slice=allow_partial_slice,
355
371
  )
356
372
 
357
- return selected
373
+ return selected