ocf-data-sampler 0.5.9__py3-none-any.whl → 0.5.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ocf-data-sampler might be problematic. Click here for more details.

@@ -1,7 +1,6 @@
1
1
  """Geospatial coordinate transformation functions.
2
2
 
3
3
  Provides utilities for working with different coordinate systems
4
- commonly used in geospatial applications, particularly for UK-based data.
5
4
 
6
5
  Supports conversions between:
7
6
  - OSGB36 (Ordnance Survey Great Britain, easting/northing in meters)
@@ -11,159 +10,198 @@ Supports conversions between:
11
10
 
12
11
  import numpy as np
13
12
  import pyproj
14
- import pyresample
15
13
  import xarray as xr
14
+ from pyresample.area_config import load_area_from_string
16
15
 
17
16
  # Coordinate Reference System (CRS) identifiers
18
- # OSGB36: UK Ordnance Survey National Grid (easting/northing in meters)
19
- # Refer to - https://epsg.io/27700
20
- OSGB36 = 27700
21
17
 
22
- # WGS84: World Geodetic System 1984 (latitude/longitude in degrees), used in GPS
18
+ # OSGB36: UK Ordnance Survey National Grid (easting/northing in meters) - https://epsg.io/27700
19
+ OSGB36 = 27700
20
+ # WGS84: World Geodetic System 1984 (latitude/longitude in degrees) - https://epsg.io/4326
23
21
  WGS84 = 4326
24
22
 
25
- # Pre-init Transformer
26
- _osgb_to_lon_lat = pyproj.Transformer.from_crs(
27
- crs_from=OSGB36,
28
- crs_to=WGS84,
29
- always_xy=True,
30
- ).transform
31
- _lon_lat_to_osgb = pyproj.Transformer.from_crs(
32
- crs_from=WGS84,
33
- crs_to=OSGB36,
34
- always_xy=True,
35
- ).transform
23
+ # Pre-initialise coordinate Transformer objects
24
+ _osgb_to_lon_lat = pyproj.Transformer.from_crs(crs_from=OSGB36, crs_to=WGS84, always_xy=True)
25
+ _lon_lat_to_osgb = pyproj.Transformer.from_crs(crs_from=WGS84, crs_to=OSGB36, always_xy=True)
36
26
 
37
27
 
38
28
  def osgb_to_lon_lat(
39
29
  x: float | np.ndarray,
40
30
  y: float | np.ndarray,
41
31
  ) -> tuple[float | np.ndarray, float | np.ndarray]:
42
- """Change OSGB coordinates to lon-lat.
32
+ """Convert OSGB coordinates to lon-lat.
43
33
 
44
34
  Args:
45
35
  x: osgb east-west
46
- y: osgb north-south
47
- Return: 2-tuple of longitude (east-west), latitude (north-south)
36
+ y: osgb south-north
37
+
38
+ Return: longitude, latitude
48
39
  """
49
- return _osgb_to_lon_lat(xx=x, yy=y)
40
+ return _osgb_to_lon_lat.transform(xx=x, yy=y)
50
41
 
51
42
 
52
43
  def lon_lat_to_osgb(
53
44
  x: float | np.ndarray,
54
45
  y: float | np.ndarray,
55
46
  ) -> tuple[float | np.ndarray, float | np.ndarray]:
56
- """Change lon-lat coordinates to OSGB.
47
+ """Convert lon-lat coordinates to OSGB.
57
48
 
58
49
  Args:
59
50
  x: longitude east-west
60
- y: latitude north-south
51
+ y: latitude south-north
61
52
 
62
- Return: 2-tuple of OSGB x, y
53
+ Return: x_osgb, y_osgb
63
54
  """
64
- return _lon_lat_to_osgb(xx=x, yy=y)
55
+ return _lon_lat_to_osgb.transform(xx=x, yy=y)
56
+
57
+
58
+ def _get_geostationary_coord_transform(
59
+ crs_from: int,
60
+ area_string: str,
61
+ ) -> pyproj.transformer.Transformer:
62
+ """Loads geostationary area and builds a transformer to geostationary coords.
63
+
64
+ Args:
65
+ x: osgb east-west, or latitude
66
+ y: osgb south-north, or longitude
67
+ crs_from: the coordinate system to convert from (must be OSGB36 or WGS84)
68
+ area_string: String containing yaml geostationary area definition to convert to.
69
+
70
+ Returns: Coordinate Transformer
71
+ """
72
+ if crs_from not in [OSGB36, WGS84]:
73
+ raise ValueError(f"Unrecognized coordinate system: {crs_from}")
74
+
75
+ geostationary_crs = load_area_from_string(area_string).crs
76
+
77
+ return pyproj.Transformer.from_crs(
78
+ crs_from=crs_from,
79
+ crs_to=geostationary_crs,
80
+ always_xy=True,
81
+ )
65
82
 
66
83
 
67
84
  def lon_lat_to_geostationary_area_coords(
68
85
  longitude: float | np.ndarray,
69
86
  latitude: float | np.ndarray,
70
- xr_data: xr.DataArray,
87
+ area_string: str,
71
88
  ) -> tuple[float | np.ndarray, float | np.ndarray]:
72
- """Loads geostationary area and transformation from lat-lon to geostationary coords.
89
+ """Convert from lon-lat to geostationary coords.
73
90
 
74
91
  Args:
75
92
  longitude: longitude
76
93
  latitude: latitude
77
- xr_data: xarray object with geostationary area
94
+ area_string: String containing yaml geostationary area definition to convert to.
78
95
 
79
- Returns:
80
- Geostationary coords: x, y
96
+ Returns: x_geostationary, y_geostationary
81
97
  """
82
- return coordinates_to_geostationary_area_coords(longitude, latitude, xr_data, WGS84)
98
+ coord_transformer = _get_geostationary_coord_transform(WGS84, area_string)
99
+ return coord_transformer.transform(xx=longitude, yy=latitude)
83
100
 
84
101
 
85
102
  def osgb_to_geostationary_area_coords(
86
103
  x: float | np.ndarray,
87
104
  y: float | np.ndarray,
88
- xr_data: xr.DataArray,
105
+ area_string: str,
89
106
  ) -> tuple[float | np.ndarray, float | np.ndarray]:
90
- """Loads geostationary area and transformation from OSGB to geostationary coords.
107
+ """Convert from OSGB to geostationary coords.
91
108
 
92
109
  Args:
93
110
  x: osgb east-west
94
- y: osgb north-south
95
- xr_data: xarray object with geostationary area
111
+ y: osgb south-north
112
+ area_string: String containing yaml geostationary area definition to convert to.
113
+
114
+ Returns: x_geostationary, y_geostationary
115
+ """
116
+ coord_transformer = _get_geostationary_coord_transform(OSGB36, area_string)
117
+ return coord_transformer.transform(xx=x, yy=y)
118
+
119
+
120
+ def find_coord_system(da: xr.DataArray) -> tuple[str, str, str]:
121
+ """Searches the Xarray object to determine the spatial coordinate system.
122
+
123
+ Args:
124
+ da: DataArray with spatial coords
96
125
 
97
126
  Returns:
98
- Geostationary coords: x, y
127
+ Three strings with:
128
+ 1. The kind of the coordinate system
129
+ 2. Name of the x-coordinate
130
+ 3. Name of the y-coordinate
99
131
  """
100
- return coordinates_to_geostationary_area_coords(x, y, xr_data, OSGB36)
132
+ # We only look at the dimensional coords. It is possible that other coordinate systems are
133
+ # included as non-dimensional coords
134
+ dimensional_coords = set(da.xindexes)
135
+
136
+ coord_systems = {
137
+ "lon_lat": ["longitude", "latitude"],
138
+ "geostationary": ["x_geostationary", "y_geostationary"],
139
+ "osgb": ["x_osgb", "y_osgb"],
140
+ }
141
+
142
+ coords_systems_found = []
143
+
144
+ for coord_name, coord_set in coord_systems.items():
145
+ if set(coord_set) <= dimensional_coords:
146
+ coords_systems_found.append(coord_name)
147
+
148
+ if len(coords_systems_found)==0:
149
+ raise ValueError(
150
+ f"Did not find any coordinate pairs in the dimensional coords: {dimensional_coords}",
151
+ )
152
+ elif len(coords_systems_found)>1:
153
+ raise ValueError(
154
+ f"Found >1 ({coords_systems_found}) coordinate pairs in the dimensional coords: "
155
+ f"{dimensional_coords}",
156
+ )
157
+ else:
158
+ coord_system_name = coords_systems_found[0]
159
+ return coord_system_name, *coord_systems[coord_system_name]
101
160
 
102
161
 
103
- def coordinates_to_geostationary_area_coords(
162
+ def convert_coordinates(
104
163
  x: float | np.ndarray,
105
164
  y: float | np.ndarray,
106
- xr_data: xr.DataArray,
107
- crs_from: int,
165
+ from_coords: str,
166
+ target_coords: str,
167
+ area_string: str | None = None,
108
168
  ) -> tuple[float | np.ndarray, float | np.ndarray]:
109
- """Loads geostationary area and transforms to geostationary coords.
169
+ """Convert x and y coordinates from one coordinate system to another.
110
170
 
111
171
  Args:
112
- x: osgb east-west, or latitude
113
- y: osgb north-south, or longitude
114
- xr_data: xarray object with geostationary area
115
- crs_from: the cordiates system of x,y
172
+ x: The x-coordinate to convert.
173
+ y: The y-coordinate to convert.
174
+ from_coords: The coordinate system to convert from.
175
+ target_coords: The coordinate system to convert to
176
+ area_string: Optional string containing yaml geostationary area definition. Only used if
177
+ from_coords or target_coords is "geostationary"
116
178
 
117
179
  Returns:
118
- Geostationary coords: x, y
180
+ The converted (x, y) coordinates.
119
181
  """
120
- if crs_from not in [OSGB36, WGS84]:
121
- raise ValueError(f"Unrecognized coordinate system: {crs_from}")
182
+ if from_coords==target_coords:
183
+ return x, y
122
184
 
123
- area_definition_yaml = xr_data.attrs["area"]
185
+ if "geostationary" in (from_coords, target_coords) and area_string is not None:
186
+ ValueError("If using geostationary coords the `area_string` must be provided")
124
187
 
125
- geostationary_area_definition = pyresample.area_config.load_area_from_string(
126
- area_definition_yaml,
127
- )
128
- geostationary_crs = geostationary_area_definition.crs
129
- osgb_to_geostationary = pyproj.Transformer.from_crs(
130
- crs_from=crs_from,
131
- crs_to=geostationary_crs,
132
- always_xy=True,
133
- ).transform
134
- return osgb_to_geostationary(xx=x, yy=y)
135
-
136
-
137
- def _coord_priority(available_coords: list[str]) -> tuple[str, str, str]:
138
- """Determines the coordinate system of spatial coordinates present."""
139
- if "longitude" in available_coords:
140
- return "lon_lat", "longitude", "latitude"
141
- elif "x_geostationary" in available_coords:
142
- return "geostationary", "x_geostationary", "y_geostationary"
143
- elif "x_osgb" in available_coords:
144
- return "osgb", "x_osgb", "y_osgb"
145
- else:
146
- raise ValueError(f"Unrecognized coordinate system: {available_coords}")
188
+ match (from_coords, target_coords):
147
189
 
190
+ case ("osgb", "geostationary"):
191
+ x, y = osgb_to_geostationary_area_coords(x, y, area_string)
148
192
 
149
- def spatial_coord_type(ds: xr.DataArray) -> tuple[str, str, str]:
150
- """Searches the data array to determine the kind of spatial coordinates present.
193
+ case ("lon_lat", "geostationary"):
194
+ x, y = lon_lat_to_geostationary_area_coords(x, y, area_string)
151
195
 
152
- This search has a preference for the dimension coordinates of the xarray object.
196
+ case ("osgb", "lon_lat"):
197
+ x, y = osgb_to_lon_lat(x, y)
153
198
 
154
- Args:
155
- ds: Dataset with spatial coords
156
-
157
- Returns:
158
- Three strings with:
159
- 1. The kind of the coordinate system
160
- 2. Name of the x-coordinate
161
- 3. Name of the y-coordinate
162
- """
163
- if isinstance(ds, xr.DataArray):
164
- # Search dimension coords of dataarray
165
- coords = _coord_priority(ds.xindexes)
166
- else:
167
- raise ValueError(f"Unrecognized input type: {type(ds)}")
199
+ case ("lon_lat", "osgb"):
200
+ x, y = lon_lat_to_osgb(x, y)
168
201
 
169
- return coords
202
+ case (_, _):
203
+ raise NotImplementedError(
204
+ f"Conversion from {from_coords} to "
205
+ f"{target_coords} is not supported",
206
+ )
207
+ return x, y
@@ -1,27 +1,63 @@
1
- """Location model with coordinate system validation."""
1
+ """Location object."""
2
2
 
3
- from pydantic import BaseModel, Field, model_validator
4
3
 
5
- allowed_coordinate_systems = ["osgb", "lon_lat", "geostationary", "idx"]
4
+ allowed_coord_systems = {"osgb", "lon_lat", "geostationary"}
6
5
 
7
6
 
8
- class Location(BaseModel):
9
- """Represent a spatial location."""
7
+ class Location:
8
+ """A spatial location."""
10
9
 
11
- coordinate_system: str = Field(...,
12
- description="Coordinate system for the location must be lon_lat, osgb, or geostationary",
13
- )
10
+ def __init__(self, x: float, y: float, coord_system: int, id: int | str | None = None) -> None:
11
+ """A spatial location.
14
12
 
15
- x: float = Field(..., description="x coordinate - i.e. east-west position")
16
- y: float = Field(..., description="y coordinate - i.e. north-south position")
17
- id: int | None = Field(None, description="ID of the location - e.g. GSP ID")
13
+ Args:
14
+ x: The east-west / left-right location
15
+ y: The south-north / down-up location
16
+ coord_system: The coordinate system
17
+ id: The location ID
18
+ """
19
+ self._check_valid_coord_system(coord_system)
20
+ self._projections: dict[str, tuple[float, float]] = {coord_system: (x, y)}
21
+ self.id = id
18
22
 
19
- @model_validator(mode="after")
20
- def validate_coordinate_system(self) -> "Location":
21
- """Validate 'coordinate_system'."""
22
- if self.coordinate_system not in allowed_coordinate_systems:
23
+ @staticmethod
24
+ def _check_valid_coord_system(coord_system: str) -> None:
25
+ if coord_system not in allowed_coord_systems:
26
+ raise ValueError(f"Coordinate {coord_system} is not supported")
27
+
28
+ def in_coord_system(self, coord_system: str) -> tuple[float, float]:
29
+ """Get the location in a specified coordinate system.
30
+
31
+ Args:
32
+ coord_system: The desired output coordinate system
33
+ """
34
+ self._check_valid_coord_system(coord_system)
35
+
36
+ if coord_system in self._projections:
37
+ return self._projections[coord_system]
38
+ else:
23
39
  raise ValueError(
24
- f"coordinate_system = {self.coordinate_system} "
25
- f"is not in {allowed_coordinate_systems}",
40
+ "Requested the coodinate in {coord_system}. This has not yet been added. "
41
+ "The current available coordinate systems are "
42
+ f"{list(self.self._projections.keys())}",
26
43
  )
27
- return self
44
+
45
+ def add_coord_system(self, x: float, y: float, coord_system: int) -> None:
46
+ """Add the equivalent location in a different coordinate system.
47
+
48
+ Args:
49
+ x: The east-west / left-right coordinate
50
+ y: The south-north / down-up coordinate
51
+ coord_system: The coordinate system name
52
+ """
53
+ self._check_valid_coord_system(coord_system)
54
+ if coord_system in self._projections:
55
+ if not (x, y)==self._projections[coord_system]:
56
+ raise ValueError(
57
+ f"Tried to re-add coordinate projection {coord_system}, but the supplied"
58
+ f"coodrinate values ({x}, {y}) do not match the already stored values "
59
+ f"{self._projections[coord_system]}",
60
+ )
61
+ else:
62
+ self._projections[coord_system] = (x, y)
63
+
@@ -1,63 +1,13 @@
1
1
  """Select spatial slices."""
2
2
 
3
- import logging
4
-
5
3
  import numpy as np
6
4
  import xarray as xr
7
5
 
8
- from ocf_data_sampler.select.geospatial import (
9
- lon_lat_to_geostationary_area_coords,
10
- lon_lat_to_osgb,
11
- osgb_to_geostationary_area_coords,
12
- osgb_to_lon_lat,
13
- spatial_coord_type,
14
- )
6
+ from ocf_data_sampler.select.geospatial import find_coord_system
15
7
  from ocf_data_sampler.select.location import Location
16
8
 
17
- logger = logging.getLogger(__name__)
18
-
19
-
20
- def convert_coordinates(
21
- from_coords: str,
22
- x: float | np.ndarray,
23
- y: float | np.ndarray,
24
- da: xr.DataArray,
25
- ) -> tuple[float | np.ndarray, float | np.ndarray]:
26
- """Convert x and y coordinates to coordinate system matching xarray data.
27
-
28
- Args:
29
- from_coords: The coordinate system to convert from.
30
- x: The x-coordinate to convert.
31
- y: The y-coordinate to convert.
32
- da: The xarray DataArray used for context (e.g., for geostationary conversion).
33
-
34
- Returns:
35
- The converted (x, y) coordinates.
36
- """
37
- target_coords, *_ = spatial_coord_type(da)
38
-
39
- match (from_coords, target_coords):
40
- case ("osgb", "geostationary"):
41
- x, y = osgb_to_geostationary_area_coords(x, y, da)
42
- case ("osgb", "lon_lat"):
43
- x, y = osgb_to_lon_lat(x, y)
44
- case ("osgb", "osgb"):
45
- pass
46
- case ("lon_lat", "osgb"):
47
- x, y = lon_lat_to_osgb(x, y)
48
- case ("lon_lat", "geostationary"):
49
- x, y = lon_lat_to_geostationary_area_coords(x, y, da)
50
- case ("lon_lat", "lon_lat"):
51
- pass
52
- case (_, _):
53
- raise NotImplementedError(
54
- f"Conversion from {from_coords} to "
55
- f"{target_coords} is not supported",
56
- )
57
- return x, y
58
-
59
9
 
60
- def _get_pixel_index_location(da: xr.DataArray, location: Location) -> Location:
10
+ def _get_pixel_index_location(da: xr.DataArray, location: Location) -> tuple[int, int]:
61
11
  """Find pixel index location closest to given Location.
62
12
 
63
13
  Args:
@@ -65,14 +15,14 @@ def _get_pixel_index_location(da: xr.DataArray, location: Location) -> Location:
65
15
  location: The Location object representing the point of interest.
66
16
 
67
17
  Returns:
68
- A Location object with x and y attributes representing the pixel indices.
18
+ The pixel indices.
69
19
 
70
20
  Raises:
71
21
  ValueError: If the location is outside the bounds of the DataArray.
72
22
  """
73
- xr_coords, x_dim, y_dim = spatial_coord_type(da)
23
+ target_coords, x_dim, y_dim = find_coord_system(da)
74
24
 
75
- x, y = convert_coordinates(location.coordinate_system, location.x, location.y, da)
25
+ x, y = location.in_coord_system(target_coords)
76
26
 
77
27
  # Check that requested point lies within the data
78
28
  if not (da[x_dim].min() < x < da[x_dim].max()):
@@ -89,7 +39,7 @@ def _get_pixel_index_location(da: xr.DataArray, location: Location) -> Location:
89
39
  closest_x = x_index.get_indexer([x], method="nearest")[0]
90
40
  closest_y = y_index.get_indexer([y], method="nearest")[0]
91
41
 
92
- return Location(x=closest_x, y=closest_y, coordinate_system="idx")
42
+ return closest_x, closest_y
93
43
 
94
44
 
95
45
  def _select_padded_slice(
@@ -213,16 +163,16 @@ def select_spatial_slice_pixels(
213
163
  if (height_pixels % 2) != 0:
214
164
  raise ValueError("Height must be an even number")
215
165
 
216
- _, x_dim, y_dim = spatial_coord_type(da)
217
- center_idx = _get_pixel_index_location(da, location)
166
+ _, x_dim, y_dim = find_coord_system(da)
167
+ center_idx_x, center_idx_y = _get_pixel_index_location(da, location)
218
168
 
219
169
  half_width = width_pixels // 2
220
170
  half_height = height_pixels // 2
221
171
 
222
- left_idx = int(center_idx.x - half_width)
223
- right_idx = int(center_idx.x + half_width)
224
- bottom_idx = int(center_idx.y - half_height)
225
- top_idx = int(center_idx.y + half_height)
172
+ left_idx = int(center_idx_x - half_width)
173
+ right_idx = int(center_idx_x + half_width)
174
+ bottom_idx = int(center_idx_y - half_height)
175
+ top_idx = int(center_idx_y + half_height)
226
176
 
227
177
  data_width_pixels = len(da[x_dim])
228
178
  data_height_pixels = len(da[y_dim])
@@ -19,17 +19,15 @@ from ocf_data_sampler.numpy_sample.common_types import NumpyBatch, NumpySample
19
19
  from ocf_data_sampler.numpy_sample.gsp import GSPSampleKey
20
20
  from ocf_data_sampler.numpy_sample.nwp import NWPSampleKey
21
21
  from ocf_data_sampler.select import Location, fill_time_periods
22
- from ocf_data_sampler.select.geospatial import osgb_to_lon_lat
23
22
  from ocf_data_sampler.torch_datasets.utils import (
23
+ add_alterate_coordinate_projections,
24
24
  config_normalization_values_to_dicts,
25
+ fill_nans_in_arrays,
25
26
  find_valid_time_periods,
27
+ merge_dicts,
26
28
  slice_datasets_by_space,
27
29
  slice_datasets_by_time,
28
30
  )
29
- from ocf_data_sampler.torch_datasets.utils.merge_and_fill_utils import (
30
- fill_nans_in_arrays,
31
- merge_dicts,
32
- )
33
31
  from ocf_data_sampler.utils import minutes, tensorstore_compute
34
32
 
35
33
  xr.set_options(keep_attrs=True)
@@ -59,10 +57,10 @@ def get_gsp_locations(
59
57
  for gsp_id in gsp_ids:
60
58
  locations.append(
61
59
  Location(
62
- coordinate_system="osgb",
63
60
  x=df_gsp_loc.loc[gsp_id].x_osgb,
64
61
  y=df_gsp_loc.loc[gsp_id].y_osgb,
65
- id=gsp_id,
62
+ coord_system="osgb",
63
+ id=int(gsp_id),
66
64
  ),
67
65
  )
68
66
  return locations
@@ -100,10 +98,13 @@ class AbstractPVNetUKDataset(Dataset):
100
98
  valid_t0_times = valid_t0_times[valid_t0_times <= pd.Timestamp(end_time)]
101
99
 
102
100
  # Construct list of locations to sample from
103
- self.locations = get_gsp_locations(
104
- gsp_ids,
105
- version=config.input_data.gsp.boundaries_version,
101
+ locations = get_gsp_locations(gsp_ids, version=config.input_data.gsp.boundaries_version)
102
+ self.locations = add_alterate_coordinate_projections(
103
+ locations,
104
+ datasets_dict,
105
+ primary_coords="osgb",
106
106
  )
107
+
107
108
  self.valid_t0_times = valid_t0_times
108
109
 
109
110
  # Assign config and input data to self
@@ -171,11 +172,14 @@ class AbstractPVNetUKDataset(Dataset):
171
172
  )
172
173
 
173
174
  # Add GSP location data
175
+
176
+ osgb_x, osgb_y = location.in_coord_system("osgb")
177
+
174
178
  numpy_modalities.append(
175
179
  {
176
180
  GSPSampleKey.gsp_id: location.id,
177
- GSPSampleKey.x_osgb: location.x,
178
- GSPSampleKey.y_osgb: location.y,
181
+ GSPSampleKey.x_osgb: osgb_x,
182
+ GSPSampleKey.y_osgb: osgb_y,
179
183
  },
180
184
  )
181
185
 
@@ -191,7 +195,7 @@ class AbstractPVNetUKDataset(Dataset):
191
195
  )
192
196
 
193
197
  # Convert OSGB coordinates to lon/lat
194
- lon, lat = osgb_to_lon_lat(location.x, location.y)
198
+ lon, lat = location.in_coord_system("lon_lat")
195
199
 
196
200
  # Calculate solar positions and add to modalities
197
201
  numpy_modalities.append(make_sun_position_numpy_sample(datetimes, lon, lat))
@@ -25,15 +25,14 @@ from ocf_data_sampler.select import (
25
25
  intersection_of_multiple_dataframes_of_periods,
26
26
  )
27
27
  from ocf_data_sampler.torch_datasets.utils import (
28
+ add_alterate_coordinate_projections,
28
29
  config_normalization_values_to_dicts,
30
+ fill_nans_in_arrays,
29
31
  find_valid_time_periods,
32
+ merge_dicts,
30
33
  slice_datasets_by_space,
31
34
  slice_datasets_by_time,
32
35
  )
33
- from ocf_data_sampler.torch_datasets.utils.merge_and_fill_utils import (
34
- fill_nans_in_arrays,
35
- merge_dicts,
36
- )
37
36
  from ocf_data_sampler.utils import minutes, tensorstore_compute
38
37
 
39
38
  xr.set_options(keep_attrs=True)
@@ -52,7 +51,7 @@ def get_locations(site_xr: xr.Dataset) -> list[Location]:
52
51
  id=site_id,
53
52
  x=site.longitude.values,
54
53
  y=site.latitude.values,
55
- coordinate_system="lon_lat",
54
+ coord_system="lon_lat",
56
55
  )
57
56
  locations.append(location)
58
57
 
@@ -168,8 +167,14 @@ class SitesDataset(Dataset):
168
167
  self.datasets_dict = datasets_dict
169
168
  self.config = config
170
169
 
171
- # get all locations
172
- self.locations = get_locations(datasets_dict["site"])
170
+ # Construct list of locations to sample from
171
+ locations = get_locations(datasets_dict["site"])
172
+ self.locations = add_alterate_coordinate_projections(
173
+ locations,
174
+ datasets_dict,
175
+ primary_coords="lon_lat",
176
+ )
177
+
173
178
  self.location_lookup = {loc.id: loc for loc in self.locations}
174
179
 
175
180
  # Get t0 times where all input data is available
@@ -2,4 +2,5 @@ from .config_normalization_values_to_dicts import config_normalization_values_to
2
2
  from .merge_and_fill_utils import fill_nans_in_arrays, merge_dicts
3
3
  from .valid_time_periods import find_valid_time_periods
4
4
  from .spatial_slice_for_dataset import slice_datasets_by_space
5
- from .time_slice_for_dataset import slice_datasets_by_time
5
+ from .time_slice_for_dataset import slice_datasets_by_time
6
+ from .add_alterate_coordinate_projections import add_alterate_coordinate_projections
@@ -0,0 +1,77 @@
1
+ """Function for adding more projections to location objects."""
2
+
3
+ import numpy as np
4
+
5
+ from ocf_data_sampler.select import Location
6
+ from ocf_data_sampler.select.geospatial import convert_coordinates, find_coord_system
7
+
8
+
9
+ def add_alterate_coordinate_projections(
10
+ locations: list[Location],
11
+ datasets_dict: dict,
12
+ primary_coords: str,
13
+ ) -> list[Location]:
14
+ """Add (in-place) coordinate projections for all datasets to a set of locations.
15
+
16
+ Args:
17
+ locations: A list of locations
18
+ datasets_dict: The dataset dict to add projections for
19
+ primary_coords: The primary coords of the locations
20
+
21
+ Returns:
22
+ List of locations with all coordinate projections added
23
+ """
24
+ if primary_coords not in ["osgb", "lon_lat"]:
25
+ raise ValueError("Only osbg and lon_lat are currently supported")
26
+
27
+ xs, ys = np.array([loc.in_coord_system(primary_coords) for loc in locations]).T
28
+
29
+ datasets_list = []
30
+ if "nwp" in datasets_dict:
31
+ datasets_list.extend(datasets_dict["nwp"].values())
32
+ if "sat" in datasets_dict:
33
+ datasets_list.append(datasets_dict["sat"])
34
+
35
+ computed_coord_systems = {primary_coords}
36
+
37
+ # Find all the coord systems required by all datasets
38
+ for da in datasets_list:
39
+
40
+ # Find the coordinate system required by this dataset
41
+ coord_system, *_ = find_coord_system(da)
42
+
43
+ # Skip if the projections in this coord system have already been computed
44
+ if coord_system not in computed_coord_systems:
45
+
46
+ # If using geostationary coords we need to extract the area definition string
47
+ area_string = da.attrs["area"] if coord_system=="geostationary" else None
48
+
49
+ new_xs, new_ys = convert_coordinates(
50
+ x=xs,
51
+ y=ys,
52
+ from_coords=primary_coords,
53
+ target_coords=coord_system,
54
+ area_string=area_string,
55
+ )
56
+
57
+ # Add the projection to the locations objects
58
+ for x, y, loc in zip(new_xs, new_ys, locations, strict=True):
59
+ loc.add_coord_system(x, y, coord_system)
60
+
61
+ computed_coord_systems.add(coord_system)
62
+
63
+ # Add lon-lat to start since it is required to compute the solar coords
64
+ if "lon_lat" not in computed_coord_systems:
65
+ new_xs, new_ys = convert_coordinates(
66
+ x=xs,
67
+ y=ys,
68
+ from_coords=primary_coords,
69
+ target_coords="lon_lat",
70
+ area_string=None,
71
+ )
72
+
73
+ # Add the projection to the locations objects
74
+ for x, y, loc in zip(new_xs, new_ys, locations, strict=False):
75
+ loc.add_coord_system(x, y, "lon_lat")
76
+
77
+ return locations
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ocf-data-sampler
3
- Version: 0.5.9
3
+ Version: 0.5.10
4
4
  Author: James Fulton, Peter Dudfield
5
5
  Author-email: Open Climate Fix team <info@openclimatefix.org>
6
6
  License: MIT License
@@ -35,18 +35,19 @@ ocf_data_sampler/select/__init__.py,sha256=mK7Wu_-j9IXGTYrOuDf5yDDuU5a306b0iGKTA
35
35
  ocf_data_sampler/select/dropout.py,sha256=BYpv8L771faPOyN7SdIJ5cwkpDve-ohClj95jjsHmjg,1973
36
36
  ocf_data_sampler/select/fill_time_periods.py,sha256=TlGxp1xiAqnhdWfLy0pv3FuZc00dtimjWdLzr4JoTGA,865
37
37
  ocf_data_sampler/select/find_contiguous_time_periods.py,sha256=etkr6LuB7zxkfzWJ6SgHiULdRuFzFlq5bOUNd257Qx4,11545
38
- ocf_data_sampler/select/geospatial.py,sha256=CDExkl36eZOKmdJPzUr_K0Wn3axHqv5nYo-EkSiINcc,5032
39
- ocf_data_sampler/select/location.py,sha256=AZvGR8y62opiW7zACGXjoOtBEWRfSLOZIA73O5Deu0c,1037
40
- ocf_data_sampler/select/select_spatial_slice.py,sha256=Hd4jGRUfIZRoWCirOQZeoLpaUnStB6KyFSTPX69wZLw,8790
38
+ ocf_data_sampler/select/geospatial.py,sha256=rvMy_e--3tm-KAy9pU6b9-UMBQqH2sXykr3N_4SHYy4,6528
39
+ ocf_data_sampler/select/location.py,sha256=Qp0di-Pgq8WLjN9IBcTVTaRM3lckhr4ZVzaDRcgVXHw,2352
40
+ ocf_data_sampler/select/select_spatial_slice.py,sha256=0nwIRa1Wbmasxgz_PiDvXkNPUaYCdZNaUaOmkV4YIE0,7192
41
41
  ocf_data_sampler/select/select_time_slice.py,sha256=HeHbwZ0CP03x0-LaJtpbSdtpLufwVTR73p6wH6O_PS8,5513
42
42
  ocf_data_sampler/torch_datasets/datasets/__init__.py,sha256=o0SsEXXZ6k9iL__5_RN1Sf60lw_eqK91P3UFEHAD2k0,102
43
- ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py,sha256=qbyvTOZZNcGioeH-DDoJmSf_KLRidiuBQRnrvZXD6ts,12046
44
- ocf_data_sampler/torch_datasets/datasets/site.py,sha256=_FUV_KDe5k7acAmjE9Z2kYgxCFJZrLjziaZssIi1ipg,15465
43
+ ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py,sha256=1uwcbEKakXsqr_ZcSjpzf3esP5tA_YVlYgtj9vBjtdM,12115
44
+ ocf_data_sampler/torch_datasets/datasets/site.py,sha256=Z7pyuR3VHEJf8egxnZvpsdgBv7fMg2pSc61ORAuqOwQ,15607
45
45
  ocf_data_sampler/torch_datasets/sample/__init__.py,sha256=GL84vdZl_SjHDGVyh9Uekx2XhPYuZ0dnO3l6f6KXnHI,100
46
46
  ocf_data_sampler/torch_datasets/sample/base.py,sha256=cQ1oIyhdmlotejZK8B3Cw6MNvpdnBPD8G_o2h7Ye4Vc,2206
47
47
  ocf_data_sampler/torch_datasets/sample/site.py,sha256=40NwNTqjL1WVhPdwe02zDHHfDLG2u_bvCfRCtGAtFc0,1466
48
48
  ocf_data_sampler/torch_datasets/sample/uk_regional.py,sha256=Xx5cBYUyaM6PGUWQ76MHT9hwj6IJ7WAOxbpmYFbJGhc,10483
49
- ocf_data_sampler/torch_datasets/utils/__init__.py,sha256=_UHLL_yRzhLJVHi6ROSaSe8TGw80CAhU325uCZj7XkY,331
49
+ ocf_data_sampler/torch_datasets/utils/__init__.py,sha256=TNSYuSSmFgjsvvJxtoDrH645Z64CHsNUUQ0iayTccP4,416
50
+ ocf_data_sampler/torch_datasets/utils/add_alterate_coordinate_projections.py,sha256=w6Q4TyxNyl7PKAbhqiXvqOpnqIjwmOUcGREIvPNGYlQ,2666
50
51
  ocf_data_sampler/torch_datasets/utils/config_normalization_values_to_dicts.py,sha256=jS3DkAwOF1W3AQnvsdkBJ1C8Unm93kQbS8hgTCtFv2A,1743
51
52
  ocf_data_sampler/torch_datasets/utils/merge_and_fill_utils.py,sha256=we7BTxRH7B7jKayDT7YfNyfI3zZClz2Bk-HXKQIokgU,956
52
53
  ocf_data_sampler/torch_datasets/utils/spatial_slice_for_dataset.py,sha256=Hvz0wHSWMYYamf2oHNiGlzJcM4cAH6pL_7ZEvIBL2dE,1882
@@ -55,7 +56,7 @@ ocf_data_sampler/torch_datasets/utils/valid_time_periods.py,sha256=xcy75cVxl0Wrg
55
56
  ocf_data_sampler/torch_datasets/utils/validation_utils.py,sha256=YqmT-lExWlI8_ul3l0EP73Ik002fStr_bhsZh9mQqEU,4735
56
57
  scripts/download_gsp_location_data.py,sha256=rRDXMoqX-RYY4jPdxhdlxJGhWdl6r245F5UARgKV6P4,3121
57
58
  scripts/refactor_site.py,sha256=skzvsPP0Cn9yTKndzkilyNcGz4DZ88ctvCJ0XrBdc2A,3135
58
- ocf_data_sampler-0.5.9.dist-info/METADATA,sha256=LUgQmrakbDwIEfeP_3IojePDYDdvm15iUtftl5o8Rps,12816
59
- ocf_data_sampler-0.5.9.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
60
- ocf_data_sampler-0.5.9.dist-info/top_level.txt,sha256=deUxqmsONNAGZDNbsntbXH7BRA1MqWaUeAJrCo6q_xA,25
61
- ocf_data_sampler-0.5.9.dist-info/RECORD,,
59
+ ocf_data_sampler-0.5.10.dist-info/METADATA,sha256=01gjKqxXj0DCTPUBsUULi_WQACkpZlS3wWM5pJ8qrmw,12817
60
+ ocf_data_sampler-0.5.10.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
61
+ ocf_data_sampler-0.5.10.dist-info/top_level.txt,sha256=deUxqmsONNAGZDNbsntbXH7BRA1MqWaUeAJrCo6q_xA,25
62
+ ocf_data_sampler-0.5.10.dist-info/RECORD,,