ocf-data-sampler 0.1.10__py3-none-any.whl → 0.1.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of ocf-data-sampler might be problematic.

Files changed (77)
  1. ocf_data_sampler/config/load.py +3 -3
  2. ocf_data_sampler/config/model.py +86 -72
  3. ocf_data_sampler/config/save.py +5 -4
  4. ocf_data_sampler/constants.py +140 -12
  5. ocf_data_sampler/load/gsp.py +6 -5
  6. ocf_data_sampler/load/load_dataset.py +5 -6
  7. ocf_data_sampler/load/nwp/nwp.py +17 -5
  8. ocf_data_sampler/load/nwp/providers/ecmwf.py +6 -7
  9. ocf_data_sampler/load/nwp/providers/gfs.py +36 -0
  10. ocf_data_sampler/load/nwp/providers/icon.py +46 -0
  11. ocf_data_sampler/load/nwp/providers/ukv.py +4 -5
  12. ocf_data_sampler/load/nwp/providers/utils.py +3 -1
  13. ocf_data_sampler/load/satellite.py +27 -36
  14. ocf_data_sampler/load/site.py +11 -7
  15. ocf_data_sampler/load/utils.py +21 -16
  16. ocf_data_sampler/numpy_sample/collate.py +10 -9
  17. ocf_data_sampler/numpy_sample/datetime_features.py +3 -5
  18. ocf_data_sampler/numpy_sample/gsp.py +15 -13
  19. ocf_data_sampler/numpy_sample/nwp.py +17 -23
  20. ocf_data_sampler/numpy_sample/satellite.py +17 -14
  21. ocf_data_sampler/numpy_sample/site.py +8 -7
  22. ocf_data_sampler/numpy_sample/sun_position.py +19 -25
  23. ocf_data_sampler/sample/__init__.py +0 -7
  24. ocf_data_sampler/sample/base.py +23 -44
  25. ocf_data_sampler/sample/site.py +25 -69
  26. ocf_data_sampler/sample/uk_regional.py +52 -103
  27. ocf_data_sampler/select/dropout.py +42 -27
  28. ocf_data_sampler/select/fill_time_periods.py +15 -3
  29. ocf_data_sampler/select/find_contiguous_time_periods.py +87 -75
  30. ocf_data_sampler/select/geospatial.py +63 -54
  31. ocf_data_sampler/select/location.py +16 -51
  32. ocf_data_sampler/select/select_spatial_slice.py +105 -89
  33. ocf_data_sampler/select/select_time_slice.py +71 -58
  34. ocf_data_sampler/select/spatial_slice_for_dataset.py +7 -6
  35. ocf_data_sampler/select/time_slice_for_dataset.py +17 -16
  36. ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py +126 -118
  37. ocf_data_sampler/torch_datasets/datasets/site.py +135 -101
  38. ocf_data_sampler/torch_datasets/utils/merge_and_fill_utils.py +6 -2
  39. ocf_data_sampler/torch_datasets/utils/valid_time_periods.py +23 -22
  40. ocf_data_sampler/torch_datasets/utils/validate_channels.py +23 -19
  41. ocf_data_sampler/utils.py +3 -1
  42. {ocf_data_sampler-0.1.10.dist-info → ocf_data_sampler-0.1.16.dist-info}/METADATA +7 -18
  43. ocf_data_sampler-0.1.16.dist-info/RECORD +56 -0
  44. {ocf_data_sampler-0.1.10.dist-info → ocf_data_sampler-0.1.16.dist-info}/WHEEL +1 -1
  45. {ocf_data_sampler-0.1.10.dist-info → ocf_data_sampler-0.1.16.dist-info}/top_level.txt +1 -1
  46. scripts/refactor_site.py +62 -33
  47. utils/compute_icon_mean_stddev.py +72 -0
  48. ocf_data_sampler-0.1.10.dist-info/LICENSE +0 -21
  49. ocf_data_sampler-0.1.10.dist-info/RECORD +0 -82
  50. tests/__init__.py +0 -0
  51. tests/config/test_config.py +0 -113
  52. tests/config/test_load.py +0 -7
  53. tests/config/test_save.py +0 -28
  54. tests/conftest.py +0 -286
  55. tests/load/test_load_gsp.py +0 -15
  56. tests/load/test_load_nwp.py +0 -21
  57. tests/load/test_load_satellite.py +0 -17
  58. tests/load/test_load_sites.py +0 -14
  59. tests/numpy_sample/test_collate.py +0 -21
  60. tests/numpy_sample/test_datetime_features.py +0 -37
  61. tests/numpy_sample/test_gsp.py +0 -38
  62. tests/numpy_sample/test_nwp.py +0 -52
  63. tests/numpy_sample/test_satellite.py +0 -40
  64. tests/numpy_sample/test_sun_position.py +0 -81
  65. tests/select/test_dropout.py +0 -75
  66. tests/select/test_fill_time_periods.py +0 -28
  67. tests/select/test_find_contiguous_time_periods.py +0 -202
  68. tests/select/test_location.py +0 -67
  69. tests/select/test_select_spatial_slice.py +0 -154
  70. tests/select/test_select_time_slice.py +0 -275
  71. tests/test_sample/test_base.py +0 -164
  72. tests/test_sample/test_site_sample.py +0 -195
  73. tests/test_sample/test_uk_regional_sample.py +0 -163
  74. tests/torch_datasets/test_merge_and_fill_utils.py +0 -40
  75. tests/torch_datasets/test_pvnet_uk.py +0 -167
  76. tests/torch_datasets/test_site.py +0 -226
  77. tests/torch_datasets/test_validate_channels_utils.py +0 -78
ocf_data_sampler/select/find_contiguous_time_periods.py

@@ -1,9 +1,11 @@
-"""Get contiguous time periods for training"""
+"""Get contiguous time periods."""
 
 import numpy as np
 import pandas as pd
+
 from ocf_data_sampler.load.utils import check_time_unique_increasing
 
+ZERO_TDELTA = pd.Timedelta(0)
 
 
 def find_contiguous_time_periods(
@@ -15,20 +17,20 @@ def find_contiguous_time_periods(
 
     Args:
         datetimes: pd.DatetimeIndex. Must be sorted.
-        min_seq_length: Sequences of min_seq_length or shorter will be discarded. Typically, this
-            would be set to the `total_seq_length` of each machine learning example.
+        min_seq_length: Sequences of min_seq_length or shorter will be discarded.
         max_gap_duration: If any pair of consecutive `datetimes` is more than `max_gap_duration`
             apart, then this pair of `datetimes` will be considered a "gap" between two contiguous
-            sequences. Typically, `max_gap_duration` would be set to the sample period of
-            the timeseries.
+            sequences.
 
     Returns:
-        pd.DataFrame where each row represents a single time period. The pd.DataFrame
-        has two columns: `start_dt` and `end_dt` (where 'dt' is short for 'datetime').
+        pd.DataFrame where each row represents a single time period. The pd.DataFrame
+        has two columns: `start_dt` and `end_dt` (where 'dt' is short for 'datetime').
     """
     # Sanity checks.
-    assert len(datetimes) > 0
-    assert min_seq_length > 1
+    if len(datetimes) == 0:
+        raise ValueError("No datetimes to use")
+    if min_seq_length <= 1:
+        raise ValueError(f"{min_seq_length=} must be greater than 1")
     check_time_unique_increasing(datetimes)
 
     # Find indices of gaps larger than max_gap:
@@ -44,77 +46,75 @@ def find_contiguous_time_periods(
     # Capture the last segment of dt_index.
     segment_boundaries = np.concatenate((segment_boundaries, [len(datetimes)]))
 
-    periods: list[dict[str, pd.Timestamp]] = []
+    periods: list[list[pd.Timestamp]] = []
     start_i = 0
     for next_start_i in segment_boundaries:
         n_timesteps = next_start_i - start_i
         if n_timesteps > min_seq_length:
             end_i = next_start_i - 1
-            period = {"start_dt": datetimes[start_i], "end_dt": datetimes[end_i]}
-            periods.append(period)
+            periods.append([datetimes[start_i], datetimes[end_i]])
         start_i = next_start_i
 
-    assert len(periods) > 0, (
-        f"Did not find an periods from {datetimes}. " f"{min_seq_length=} {max_gap_duration=}"
-    )
+    if len(periods) == 0:
+        raise ValueError(
+            f"Did not find any periods from {datetimes}. {min_seq_length=} {max_gap_duration=}",
+        )
 
-    return pd.DataFrame(periods)
+    return pd.DataFrame(periods, columns=["start_dt", "end_dt"])
 
 
 def trim_contiguous_time_periods(
-    contiguous_time_periods: pd.DataFrame,
+    contiguous_time_periods: pd.DataFrame,
     interval_start: pd.Timedelta,
     interval_end: pd.Timedelta,
 ) -> pd.DataFrame:
-    """Trim the contiguous time periods to allow for history and forecast durations.
+    """Trims contiguous time periods to account for history requirements and forecast horizons.
 
     Args:
-        contiguous_time_periods: DataFrame where each row represents a single time period. The
-            DataFrame must have `start_dt` and `end_dt` columns.
+        contiguous_time_periods: pd.DataFrame where each row represents a single time period.
+            The pd.DataFrame must have `start_dt` and `end_dt` columns.
         interval_start: The start of the interval with respect to t0
         interval_end: The end of the interval with respect to t0
 
-
     Returns:
-        The contiguous_time_periods DataFrame with the `start_dt` and `end_dt` columns updated.
+        The contiguous_time_periods pd.DataFrame with the `start_dt` and `end_dt` columns updated.
     """
-    contiguous_time_periods = contiguous_time_periods.copy()
-
-    contiguous_time_periods["start_dt"] -= interval_start
-    contiguous_time_periods["end_dt"] -= interval_end
+    # Make a copy so the data is not edited in place.
+    trimmed_time_periods = contiguous_time_periods.copy()
+    trimmed_time_periods["start_dt"] -= interval_start
+    trimmed_time_periods["end_dt"] -= interval_end
 
-    valid_mask = contiguous_time_periods["start_dt"] <= contiguous_time_periods["end_dt"]
-    contiguous_time_periods = contiguous_time_periods.loc[valid_mask]
-
-    return contiguous_time_periods
+    valid_mask = trimmed_time_periods["start_dt"] <= trimmed_time_periods["end_dt"]
 
+    return trimmed_time_periods.loc[valid_mask]
 
 
 def find_contiguous_t0_periods(
-    datetimes: pd.DatetimeIndex,
-    interval_start: pd.Timedelta,
-    interval_end: pd.Timedelta,
-    sample_period_duration: pd.Timedelta,
-) -> pd.DataFrame:
+    datetimes: pd.DatetimeIndex,
+    interval_start: pd.Timedelta,
+    interval_end: pd.Timedelta,
+    time_resolution: pd.Timedelta,
+) -> pd.DataFrame:
     """Return a pd.DataFrame where each row records the boundary of a contiguous time period.
 
     Args:
-        datetimes: pd.DatetimeIndex. Must be sorted.
+        datetimes: pd.DatetimeIndex
        interval_start: The start of the interval with respect to t0
         interval_end: The end of the interval with respect to t0
-        sample_period_duration: The sample frequency of the timeseries
-
+        time_resolution: The sample frequency of the timeseries
 
     Returns:
         pd.DataFrame where each row represents a single time period. The pd.DataFrame
         has two columns: `start_dt` and `end_dt` (where 'dt' is short for 'datetime').
     """
+    check_time_unique_increasing(datetimes)
+
     total_duration = interval_end - interval_start
-
+
     contiguous_time_periods = find_contiguous_time_periods(
         datetimes=datetimes,
-        min_seq_length=int(total_duration / sample_period_duration) + 1,
-        max_gap_duration=sample_period_duration,
+        min_seq_length=int(total_duration / time_resolution) + 1,
+        max_gap_duration=time_resolution,
     )
 
     contiguous_t0_periods = trim_contiguous_time_periods(
@@ -123,7 +123,11 @@ def find_contiguous_t0_periods(
         interval_end=interval_end,
     )
 
-    assert len(contiguous_t0_periods) > 0
+    if len(contiguous_t0_periods) == 0:
+        raise ValueError(
+            f"No contiguous time periods found for {datetimes}. "
+            f"{interval_start=} {interval_end=} {time_resolution=}",
+        )
 
     return contiguous_t0_periods
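
The rename of `sample_period_duration` to `time_resolution` will break existing call sites. A minimal sketch of the updated call, using made-up timestamps (two 5-minutely runs separated by a one-hour gap):

import pandas as pd

from ocf_data_sampler.select.find_contiguous_time_periods import find_contiguous_t0_periods

# Two contiguous 5-minutely runs separated by a one-hour gap
datetimes = pd.date_range("2024-01-01 00:00", "2024-01-01 02:00", freq="5min").append(
    pd.date_range("2024-01-01 03:00", "2024-01-01 05:00", freq="5min"),
)

t0_periods = find_contiguous_t0_periods(
    datetimes=datetimes,
    interval_start=pd.Timedelta("-30min"),  # 30 minutes of history before t0
    interval_end=pd.Timedelta("60min"),     # 60 minutes of forecast after t0
    time_resolution=pd.Timedelta("5min"),
)

# Two rows: [00:30, 01:00] and [03:30, 04:00] - each t0 keeps the full
# [t0 - 30min, t0 + 60min] window inside its contiguous segment
print(t0_periods)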
 
@@ -132,54 +136,59 @@ def find_contiguous_t0_periods_nwp(
     init_times: pd.DatetimeIndex,
     interval_start: pd.Timedelta,
     max_staleness: pd.Timedelta,
-    max_dropout: pd.Timedelta = pd.Timedelta(0),
-    first_forecast_step: pd.Timedelta = pd.Timedelta(0),
-
+    max_dropout: pd.Timedelta = ZERO_TDELTA,
+    first_forecast_step: pd.Timedelta = ZERO_TDELTA,
 ) -> pd.DataFrame:
-    """Get all time periods from the NWP init times which are valid as t0 datetimes.
+    """Get all time periods from the NWP init-times which are valid as t0 datetimes.
 
     Args:
         init_times: The initialisation times of the available forecasts
-        interval_start: The start of the desired data interval with respect to t0
-        max_staleness: Up to how long after an init time are we willing to use the forecast. Each
-            init time will only be used up to this t0 time regardless of the forecast valid time.
-        max_dropout: What is the maximum amount of dropout that will be used. This must be <=
-            max_staleness.
-        first_forecast_step: The timedelta of the first step of the forecast. By default we assume
-            the first valid time of the forecast is the same as its init time.
+        interval_start: The start of the time interval with respect to t0
+        max_staleness: Up to how long after an init-time are we willing to use the forecast.
+            Each init-time will only be used up to this t0 time regardless of the forecast valid
+            time.
+        max_dropout: What is the maximum amount of dropout that will be used.
+            This must be <= max_staleness.
+        first_forecast_step: The timedelta of the first step of the forecast.
+            By default we assume the first valid time of the forecast
+            is the same as its init-time.
 
     Returns:
-        pd.DataFrame where each row represents a single time period. The pd.DataFrame
+        pd.DataFrame where each row represents a single time period. The pd.DataFrame
         has two columns: `start_dt` and `end_dt` (where 'dt' is short for 'datetime').
     """
     # Sanity checks.
-    assert len(init_times) > 0
-    assert init_times.is_monotonic_increasing
-    assert init_times.is_unique
-    assert max_staleness >= pd.Timedelta(0)
-    assert pd.Timedelta(0) <= max_dropout <= max_staleness
+    if len(init_times) == 0:
+        raise ValueError("No init-times to use")
+
+    check_time_unique_increasing(init_times)
+
+    if max_staleness < pd.Timedelta(0):
+        raise ValueError("The max staleness must be positive")
+    if not (pd.Timedelta(0) <= max_dropout <= max_staleness):
+        raise ValueError("The max dropout must be between 0 and the max staleness")
 
-    hist_drop_buffer = max(first_forecast_step-interval_start, max_dropout)
+    history_drop_buffer = max(first_forecast_step - interval_start, max_dropout)
 
     # Store contiguous periods
-    contiguous_periods = []
+    contiguous_periods: list[list[pd.Timestamp]] = []
 
-    # Begin the first period allowing for the time to the first_forecast_step, the length of the
+    # Begin the first period allowing for the time to the first_forecast_step, the length of the
     # interval sampled from before t0, and the dropout
-    start_this_period = init_times[0] + hist_drop_buffer
+    start_this_period = init_times[0] + history_drop_buffer
 
     # The first forecast is valid up to the max staleness
     end_this_period = init_times[0] + max_staleness
 
     for dt_init in init_times[1:]:
-        # If the previous init time becomes stale before the next init becomes valid (whilst also
-        # considering dropout) then the contiguous period breaks
-        # Else if the previous init time becomes stale before the fist step of the next forecast
+        # If the previous init-time becomes stale before the next init-time becomes valid (whilst
+        # also considering dropout) then the contiguous period breaks
+        # Else if the previous init-time becomes stale before the fist step of the next forecast
         # then this also causes a break in the contiguous period
-        if (end_this_period < dt_init + max(max_dropout, first_forecast_step)):
+        if end_this_period < dt_init + max(max_dropout, first_forecast_step):
             contiguous_periods.append([start_this_period, end_this_period])
             # The new period begins with the same conditions as the first period
-            start_this_period = dt_init + hist_drop_buffer
+            start_this_period = dt_init + history_drop_buffer
         end_this_period = dt_init + max_staleness
 
     contiguous_periods.append([start_this_period, end_this_period])
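
The staleness logic itself is unchanged; a short sketch of how it behaves, using illustrative 6-hourly init-times with the 12:00 run missing:

import pandas as pd

from ocf_data_sampler.select.find_contiguous_time_periods import find_contiguous_t0_periods_nwp

init_times = pd.DatetimeIndex(["2024-01-01 00:00", "2024-01-01 06:00", "2024-01-01 18:00"])

t0_periods = find_contiguous_t0_periods_nwp(
    init_times=init_times,
    interval_start=pd.Timedelta("-1h"),  # one hour of history before t0
    max_staleness=pd.Timedelta("9h"),    # each run may be used for up to 9 hours after init
)

# The 06:00 run goes stale at 15:00, before the 18:00 run arrives, so the
# contiguous period breaks: rows are [01:00, 15:00] and [19:00, 03:00 next day]
print(t0_periods)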
@@ -190,11 +199,13 @@ def find_contiguous_t0_periods_nwp(
 def intersection_of_multiple_dataframes_of_periods(
     time_periods: list[pd.DataFrame],
 ) -> pd.DataFrame:
-    """Find the intersection of a list of time periods.
+    """Find the intersection of list of time periods.
 
-    See the docstring of intersection_of_2_dataframes_of_periods() for more details.
+    Consecutively updates intersection of time periods.
+    See the docstring of intersection_of_2_dataframes_of_periods() for further details.
     """
-    assert len(time_periods) > 0
+    if len(time_periods) == 0:
+        raise ValueError("No time periods to intersect")
     intersection = time_periods[0]
     for time_period in time_periods[1:]:
         intersection = intersection_of_2_dataframes_of_periods(intersection, time_period)
@@ -210,7 +221,8 @@ def intersection_of_2_dataframes_of_periods(a: pd.DataFrame, b: pd.DataFrame) ->
     A typical use-case is that each pd.DataFrame represents all the time periods where
     a `DataSource` has contiguous, valid data.
 
-    Here's a graphical example of two pd.DataFrames of time periods and their intersection:
+    Graphical representation of two pd.DataFrames of time periods and their intersection,
+    as follows:
 
     ----------------------> TIME ->---------------------
     a: |-----| |----| |----------| |-----------|
@@ -218,9 +230,9 @@ def intersection_of_2_dataframes_of_periods(a: pd.DataFrame, b: pd.DataFrame) ->
     intersection: |--| |-| |--| |---|
 
     Args:
-        a: pd.DataFrame where each row represents a time period. The pd.DataFrame has
+        a: pd.DataFrame where each row represents a time period. The pd.DataFrame has
            two columns: start_dt and end_dt.
-        b: pd.DataFrame where each row represents a time period. The pd.DataFrame has
+        b: pd.DataFrame where each row represents a time period. The pd.DataFrame has
            two columns: start_dt and end_dt.
 
     Returns:
@@ -239,7 +251,7 @@ def intersection_of_2_dataframes_of_periods(a: pd.DataFrame, b: pd.DataFrame) ->
     # and `a` must always end after `b` starts:
 
     # TODO: <= and >= because we should allow overlap time periods of length 1. e.g.
-    #  a: |----|      or      |---|
+    #  a: |----|      or      |---|
     #  b:      |--|               |---|
     # These aren't allowed if we use < and >.
 
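A small usage sketch of the intersection helper, with two made-up sources; an empty input list now raises ValueError instead of tripping an assert:

import pandas as pd

from ocf_data_sampler.select.find_contiguous_time_periods import (
    intersection_of_multiple_dataframes_of_periods,
)

# Hypothetical valid periods for a satellite source and an NWP source
sat_periods = pd.DataFrame(
    [[pd.Timestamp("2024-01-01 00:00"), pd.Timestamp("2024-01-01 12:00")]],
    columns=["start_dt", "end_dt"],
)
nwp_periods = pd.DataFrame(
    [[pd.Timestamp("2024-01-01 06:00"), pd.Timestamp("2024-01-01 18:00")]],
    columns=["start_dt", "end_dt"],
)

# A single overlapping period, [06:00, 12:00]
overlap = intersection_of_multiple_dataframes_of_periods([sat_periods, nwp_periods])
print(overlap)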
ocf_data_sampler/select/geospatial.py

@@ -1,36 +1,45 @@
-"""Geospatial functions"""
+"""Geospatial coordinate transformation functions.
 
-from numbers import Number
-from typing import Union
+Provides utilities for working with different coordinate systems
+commonly used in geospatial applications, particularly for UK-based data.
+
+Supports conversions between:
+- OSGB36 (Ordnance Survey Great Britain, easting/northing in meters)
+- WGS84 (World Geodetic System, latitude/longitude in degrees)
+- Geostationary satellite coordinate systems
+"""
 
 import numpy as np
 import pyproj
+import pyresample
 import xarray as xr
 
-# OSGB is also called "OSGB 1936 / British National Grid -- United
-# Kingdom Ordnance Survey". OSGB is used in many UK electricity
-# system maps, and is used by the UK Met Office UKV model. OSGB is a
-# Transverse Mercator projection, using 'easting' and 'northing'
-# coordinates which are in meters. See https://epsg.io/27700
+# Coordinate Reference System (CRS) identifiers
+# OSGB36: UK Ordnance Survey National Grid (easting/northing in meters)
+# Refer to - https://epsg.io/27700
 OSGB36 = 27700
 
-# WGS84 is short for "World Geodetic System 1984", used in GPS. Uses
-# latitude and longitude.
+# WGS84: World Geodetic System 1984 (latitude/longitude in degrees), used in GPS
 WGS84 = 4326
 
-
+# Pre-init Transformer
 _osgb_to_lon_lat = pyproj.Transformer.from_crs(
-    crs_from=OSGB36, crs_to=WGS84, always_xy=True
+    crs_from=OSGB36,
+    crs_to=WGS84,
+    always_xy=True,
 ).transform
 _lon_lat_to_osgb = pyproj.Transformer.from_crs(
-    crs_from=WGS84, crs_to=OSGB36, always_xy=True
+    crs_from=WGS84,
+    crs_to=OSGB36,
+    always_xy=True,
 ).transform
 
 
 def osgb_to_lon_lat(
-    x: Union[Number, np.ndarray], y: Union[Number, np.ndarray]
-) -> tuple[Union[Number, np.ndarray], Union[Number, np.ndarray]]:
-    """Change OSGB coordinates to lon, lat.
+    x: float | np.ndarray,
+    y: float | np.ndarray,
+) -> tuple[float | np.ndarray, float | np.ndarray]:
+    """Change OSGB coordinates to lon-lat.
 
     Args:
         x: osgb east-west
@@ -41,9 +50,9 @@ def osgb_to_lon_lat(
 
 
 def lon_lat_to_osgb(
-    x: Union[Number, np.ndarray],
-    y: Union[Number, np.ndarray],
-) -> tuple[Union[Number, np.ndarray], Union[Number, np.ndarray]]:
+    x: float | np.ndarray,
+    y: float | np.ndarray,
+) -> tuple[float | np.ndarray, float | np.ndarray]:
     """Change lon-lat coordinates to OSGB.
 
     Args:
@@ -56,11 +65,11 @@ def lon_lat_to_osgb(
 
 
 def lon_lat_to_geostationary_area_coords(
-    longitude: Union[Number, np.ndarray],
-    latitude: Union[Number, np.ndarray],
+    longitude: float | np.ndarray,
+    latitude: float | np.ndarray,
     xr_data: xr.DataArray,
-) -> tuple[Union[Number, np.ndarray], Union[Number, np.ndarray]]:
-    """Loads geostationary area and transformation from lat-lon to geostationary coords
+) -> tuple[float | np.ndarray, float | np.ndarray]:
+    """Loads geostationary area and transformation from lat-lon to geostationary coords.
 
     Args:
         longitude: longitude
@@ -72,12 +81,13 @@ def lon_lat_to_geostationary_area_coords(
     """
     return coordinates_to_geostationary_area_coords(longitude, latitude, xr_data, WGS84)
 
+
 def osgb_to_geostationary_area_coords(
-    x: Union[Number, np.ndarray],
-    y: Union[Number, np.ndarray],
+    x: float | np.ndarray,
+    y: float | np.ndarray,
     xr_data: xr.DataArray,
-) -> tuple[Union[Number, np.ndarray], Union[Number, np.ndarray]]:
-    """Loads geostationary area and transformation from OSGB to geostationary coords
+) -> tuple[float | np.ndarray, float | np.ndarray]:
+    """Loads geostationary area and transformation from OSGB to geostationary coords.
 
     Args:
         x: osgb east-west
@@ -87,47 +97,45 @@ def osgb_to_geostationary_area_coords(
     Returns:
         Geostationary coords: x, y
     """
-
     return coordinates_to_geostationary_area_coords(x, y, xr_data, OSGB36)
 
 
-
 def coordinates_to_geostationary_area_coords(
-    x: Union[Number, np.ndarray],
-    y: Union[Number, np.ndarray],
+    x: float | np.ndarray,
+    y: float | np.ndarray,
     xr_data: xr.DataArray,
-    crs_from: int
-) -> tuple[Union[Number, np.ndarray], Union[Number, np.ndarray]]:
-    """Loads geostationary area and transformation from respective coordiates to geostationary coords
-
-    Args:
-        x: osgb east-west, or latitude
-        y: osgb north-south, or longitude
-        xr_data: xarray object with geostationary area
-        crs_from: the cordiates system of x,y
+    crs_from: int,
+) -> tuple[float | np.ndarray, float | np.ndarray]:
+    """Loads geostationary area and transforms to geostationary coords.
 
-    Returns:
-        Geostationary coords: x, y
-    """
-
-    assert crs_from in [OSGB36, WGS84], f"Unrecognized coordinate system: {crs_from}"
+    Args:
+        x: osgb east-west, or latitude
+        y: osgb north-south, or longitude
+        xr_data: xarray object with geostationary area
+        crs_from: the cordiates system of x,y
 
-    # Only load these if using geostationary projection
-    import pyresample
+    Returns:
+        Geostationary coords: x, y
+    """
+    if crs_from not in [OSGB36, WGS84]:
+        raise ValueError(f"Unrecognized coordinate system: {crs_from}")
 
     area_definition_yaml = xr_data.attrs["area"]
 
     geostationary_area_definition = pyresample.area_config.load_area_from_string(
-        area_definition_yaml
+        area_definition_yaml,
     )
     geostationary_crs = geostationary_area_definition.crs
     osgb_to_geostationary = pyproj.Transformer.from_crs(
-        crs_from=crs_from, crs_to=geostationary_crs, always_xy=True
+        crs_from=crs_from,
+        crs_to=geostationary_crs,
+        always_xy=True,
    ).transform
     return osgb_to_geostationary(xx=x, yy=y)
 
 
-def _coord_priority(available_coords):
+def _coord_priority(available_coords: list[str]) -> tuple[str, str, str]:
+    """Determines the coordinate system of spatial coordinates present."""
     if "longitude" in available_coords:
         return "lon_lat", "longitude", "latitude"
     elif "x_geostationary" in available_coords:
@@ -138,7 +146,7 @@ def _coord_priority(available_coords):
         raise ValueError(f"Unrecognized coordinate system: {available_coords}")
 
 
-def spatial_coord_type(ds: xr.DataArray):
+def spatial_coord_type(ds: xr.DataArray) -> tuple[str, str, str]:
     """Searches the data array to determine the kind of spatial coordinates present.
 
     This search has a preference for the dimension coordinates of the xarray object.
@@ -147,9 +155,10 @@ def spatial_coord_type(ds: xr.DataArray):
         ds: Dataset with spatial coords
 
     Returns:
-        str: The kind of the coordinate system
-        x_coord: Name of the x-coordinate
-        y_coord: Name of the y-coordinate
+        Three strings with:
+        1. The kind of the coordinate system
+        2. Name of the x-coordinate
+        3. Name of the y-coordinate
     """
     if isinstance(ds, xr.DataArray):
         # Search dimension coords of dataarray
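
The `Union[Number, ...]` annotations give way to `float | np.ndarray`, and `pyresample` moves from a function-local import to the top level. A quick round-trip sketch of the typed converters (the London coordinates are illustrative):

import numpy as np

from ocf_data_sampler.select.geospatial import lon_lat_to_osgb, osgb_to_lon_lat

# Round-trip approximate central-London coordinates through OSGB36
x_osgb, y_osgb = lon_lat_to_osgb(-0.1278, 51.5074)
lon, lat = osgb_to_lon_lat(x_osgb, y_osgb)
assert np.isclose(lon, -0.1278) and np.isclose(lat, 51.5074)

# The underlying pyproj transformers broadcast, so arrays convert in one call
lons, lats = osgb_to_lon_lat(np.array([430_000.0, 530_000.0]), np.array([110_000.0, 180_000.0]))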
ocf_data_sampler/select/location.py

@@ -1,62 +1,27 @@
-"""location"""
+"""Location model with coordinate system validation."""
 
-from typing import Optional
-
-import numpy as np
 from pydantic import BaseModel, Field, model_validator
 
+allowed_coordinate_systems = ["osgb", "lon_lat", "geostationary", "idx"]
 
-allowed_coordinate_systems =["osgb", "lon_lat", "geostationary", "idx"]
 
 class Location(BaseModel):
     """Represent a spatial location."""
 
-    coordinate_system: Optional[str] = "osgb"  # ["osgb", "lon_lat", "geostationary", "idx"]
-    x: float
-    y: float
-    id: Optional[int] = Field(None)
-
-    @model_validator(mode='after')
-    def validate_coordinate_system(self):
-        """Validate 'coordinate_system'"""
-        if self.coordinate_system not in allowed_coordinate_systems:
-            raise ValueError(f"coordinate_system = {self.coordinate_system} is not in {allowed_coordinate_systems}")
-        return self
-
-    @model_validator(mode='after')
-    def validate_x(self):
-        """Validate 'x'"""
-        min_x: float
-        max_x: float
-
-        co = self.coordinate_system
-        if co == "osgb":
-            min_x, max_x = -103976.3, 652897.98
-        if co == "lon_lat":
-            min_x, max_x = -180, 180
-        if co == "geostationary":
-            min_x, max_x = -5568748.275756836, 5567248.074173927
-        if co == "idx":
-            min_x, max_x = 0, np.inf
-        if self.x < min_x or self.x > max_x:
-            raise ValueError(f"x = {self.x} must be within {[min_x, max_x]} for {co} coordinate system")
-        return self
+    coordinate_system: str = Field(...,
+        description="Coordinate system for the location must be lon_lat, osgb, or geostationary",
+    )
 
-    @model_validator(mode='after')
-    def validate_y(self):
-        """Validate 'y'"""
-        min_y: float
-        max_y: float
+    x: float = Field(..., description="x coordinate - i.e. east-west position")
+    y: float = Field(..., description="y coordinate - i.e. north-south position")
+    id: int | None = Field(None, description="ID of the location - e.g. GSP ID")
 
-        co = self.coordinate_system
-        if co == "osgb":
-            min_y, max_y = -16703.87, 1199851.44
-        if co == "lon_lat":
-            min_y, max_y = -90, 90
-        if co == "geostationary":
-            min_y, max_y = 1393687.2151494026, 5570748.323202133
-        if co == "idx":
-            min_y, max_y = 0, np.inf
-        if self.y < min_y or self.y > max_y:
-            raise ValueError(f"y = {self.y} must be within {[min_y, max_y]} for {co} coordinate system")
+    @model_validator(mode="after")
+    def validate_coordinate_system(self) -> "Location":
+        """Validate 'coordinate_system'."""
+        if self.coordinate_system not in allowed_coordinate_systems:
+            raise ValueError(
+                f"coordinate_system = {self.coordinate_system} "
+                f"is not in {allowed_coordinate_systems}",
+            )
         return self
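
The per-coordinate-system x/y range validators are gone and `coordinate_system` is now a required field. A minimal sketch of the slimmed-down model (pydantic's ValidationError subclasses ValueError, so the except clause below catches it):

from ocf_data_sampler.select.location import Location

# coordinate_system no longer defaults to "osgb"; it must be supplied
loc = Location(coordinate_system="lon_lat", x=-0.1278, y=51.5074, id=1)

# Only the coordinate-system name is checked now; the 0.1.10 x/y range
# validation is removed, so out-of-range values pass silently
Location(coordinate_system="osgb", x=1e9, y=-1e9)

# An unrecognised coordinate system still fails validation
try:
    Location(coordinate_system="utm", x=0.0, y=0.0)
except ValueError as err:
    print(err)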