ocf-data-sampler 0.1.10__py3-none-any.whl → 0.1.16__py3-none-any.whl
- ocf_data_sampler/config/load.py +3 -3
- ocf_data_sampler/config/model.py +86 -72
- ocf_data_sampler/config/save.py +5 -4
- ocf_data_sampler/constants.py +140 -12
- ocf_data_sampler/load/gsp.py +6 -5
- ocf_data_sampler/load/load_dataset.py +5 -6
- ocf_data_sampler/load/nwp/nwp.py +17 -5
- ocf_data_sampler/load/nwp/providers/ecmwf.py +6 -7
- ocf_data_sampler/load/nwp/providers/gfs.py +36 -0
- ocf_data_sampler/load/nwp/providers/icon.py +46 -0
- ocf_data_sampler/load/nwp/providers/ukv.py +4 -5
- ocf_data_sampler/load/nwp/providers/utils.py +3 -1
- ocf_data_sampler/load/satellite.py +27 -36
- ocf_data_sampler/load/site.py +11 -7
- ocf_data_sampler/load/utils.py +21 -16
- ocf_data_sampler/numpy_sample/collate.py +10 -9
- ocf_data_sampler/numpy_sample/datetime_features.py +3 -5
- ocf_data_sampler/numpy_sample/gsp.py +15 -13
- ocf_data_sampler/numpy_sample/nwp.py +17 -23
- ocf_data_sampler/numpy_sample/satellite.py +17 -14
- ocf_data_sampler/numpy_sample/site.py +8 -7
- ocf_data_sampler/numpy_sample/sun_position.py +19 -25
- ocf_data_sampler/sample/__init__.py +0 -7
- ocf_data_sampler/sample/base.py +23 -44
- ocf_data_sampler/sample/site.py +25 -69
- ocf_data_sampler/sample/uk_regional.py +52 -103
- ocf_data_sampler/select/dropout.py +42 -27
- ocf_data_sampler/select/fill_time_periods.py +15 -3
- ocf_data_sampler/select/find_contiguous_time_periods.py +87 -75
- ocf_data_sampler/select/geospatial.py +63 -54
- ocf_data_sampler/select/location.py +16 -51
- ocf_data_sampler/select/select_spatial_slice.py +105 -89
- ocf_data_sampler/select/select_time_slice.py +71 -58
- ocf_data_sampler/select/spatial_slice_for_dataset.py +7 -6
- ocf_data_sampler/select/time_slice_for_dataset.py +17 -16
- ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py +126 -118
- ocf_data_sampler/torch_datasets/datasets/site.py +135 -101
- ocf_data_sampler/torch_datasets/utils/merge_and_fill_utils.py +6 -2
- ocf_data_sampler/torch_datasets/utils/valid_time_periods.py +23 -22
- ocf_data_sampler/torch_datasets/utils/validate_channels.py +23 -19
- ocf_data_sampler/utils.py +3 -1
- {ocf_data_sampler-0.1.10.dist-info → ocf_data_sampler-0.1.16.dist-info}/METADATA +7 -18
- ocf_data_sampler-0.1.16.dist-info/RECORD +56 -0
- {ocf_data_sampler-0.1.10.dist-info → ocf_data_sampler-0.1.16.dist-info}/WHEEL +1 -1
- {ocf_data_sampler-0.1.10.dist-info → ocf_data_sampler-0.1.16.dist-info}/top_level.txt +1 -1
- scripts/refactor_site.py +62 -33
- utils/compute_icon_mean_stddev.py +72 -0
- ocf_data_sampler-0.1.10.dist-info/LICENSE +0 -21
- ocf_data_sampler-0.1.10.dist-info/RECORD +0 -82
- tests/__init__.py +0 -0
- tests/config/test_config.py +0 -113
- tests/config/test_load.py +0 -7
- tests/config/test_save.py +0 -28
- tests/conftest.py +0 -286
- tests/load/test_load_gsp.py +0 -15
- tests/load/test_load_nwp.py +0 -21
- tests/load/test_load_satellite.py +0 -17
- tests/load/test_load_sites.py +0 -14
- tests/numpy_sample/test_collate.py +0 -21
- tests/numpy_sample/test_datetime_features.py +0 -37
- tests/numpy_sample/test_gsp.py +0 -38
- tests/numpy_sample/test_nwp.py +0 -52
- tests/numpy_sample/test_satellite.py +0 -40
- tests/numpy_sample/test_sun_position.py +0 -81
- tests/select/test_dropout.py +0 -75
- tests/select/test_fill_time_periods.py +0 -28
- tests/select/test_find_contiguous_time_periods.py +0 -202
- tests/select/test_location.py +0 -67
- tests/select/test_select_spatial_slice.py +0 -154
- tests/select/test_select_time_slice.py +0 -275
- tests/test_sample/test_base.py +0 -164
- tests/test_sample/test_site_sample.py +0 -195
- tests/test_sample/test_uk_regional_sample.py +0 -163
- tests/torch_datasets/test_merge_and_fill_utils.py +0 -40
- tests/torch_datasets/test_pvnet_uk.py +0 -167
- tests/torch_datasets/test_site.py +0 -226
- tests/torch_datasets/test_validate_channels_utils.py +0 -78
The page shows full diffs for three modules under `ocf_data_sampler/select/`. First, `ocf_data_sampler/select/find_contiguous_time_periods.py` (+87 -75); bare `-` lines mark removals whose text the diff viewer did not capture:

```diff
@@ -1,9 +1,11 @@
-"""Get contiguous time periods
+"""Get contiguous time periods."""
 
 import numpy as np
 import pandas as pd
+
 from ocf_data_sampler.load.utils import check_time_unique_increasing
 
+ZERO_TDELTA = pd.Timedelta(0)
 
 
 def find_contiguous_time_periods(
@@ -15,20 +17,20 @@ def find_contiguous_time_periods(
 
     Args:
         datetimes: pd.DatetimeIndex. Must be sorted.
-        min_seq_length: Sequences of min_seq_length or shorter will be discarded.
-            would be set to the `total_seq_length` of each machine learning example.
+        min_seq_length: Sequences of min_seq_length or shorter will be discarded.
         max_gap_duration: If any pair of consecutive `datetimes` is more than `max_gap_duration`
             apart, then this pair of `datetimes` will be considered a "gap" between two contiguous
-            sequences.
-            the timeseries.
+            sequences.
 
     Returns:
-        pd.DataFrame where each row represents a single time period.
-
+        pd.DataFrame where each row represents a single time period. The pd.DataFrame
+        has two columns: `start_dt` and `end_dt` (where 'dt' is short for 'datetime').
     """
     # Sanity checks.
-
-
+    if len(datetimes) == 0:
+        raise ValueError("No datetimes to use")
+    if min_seq_length <= 1:
+        raise ValueError(f"{min_seq_length=} must be greater than 1")
     check_time_unique_increasing(datetimes)
 
     # Find indices of gaps larger than max_gap:
@@ -44,77 +46,75 @@ def find_contiguous_time_periods(
     # Capture the last segment of dt_index.
     segment_boundaries = np.concatenate((segment_boundaries, [len(datetimes)]))
 
-    periods: list[
+    periods: list[list[pd.Timestamp]] = []
     start_i = 0
     for next_start_i in segment_boundaries:
        n_timesteps = next_start_i - start_i
         if n_timesteps > min_seq_length:
             end_i = next_start_i - 1
-
-            periods.append(period)
+            periods.append([datetimes[start_i], datetimes[end_i]])
         start_i = next_start_i
 
-
-
-
+    if len(periods) == 0:
+        raise ValueError(
+            f"Did not find any periods from {datetimes}. {min_seq_length=} {max_gap_duration=}",
+        )
 
-    return pd.DataFrame(periods)
+    return pd.DataFrame(periods, columns=["start_dt", "end_dt"])
 
 
 def trim_contiguous_time_periods(
-    contiguous_time_periods: pd.DataFrame,
+    contiguous_time_periods: pd.DataFrame,
     interval_start: pd.Timedelta,
     interval_end: pd.Timedelta,
 ) -> pd.DataFrame:
-    """
+    """Trims contiguous time periods to account for history requirements and forecast horizons.
 
     Args:
-        contiguous_time_periods: DataFrame where each row represents a single time period.
-            DataFrame must have `start_dt` and `end_dt` columns.
+        contiguous_time_periods: pd.DataFrame where each row represents a single time period.
+            The pd.DataFrame must have `start_dt` and `end_dt` columns.
         interval_start: The start of the interval with respect to t0
         interval_end: The end of the interval with respect to t0
 
-
     Returns:
-        The contiguous_time_periods DataFrame with the `start_dt` and `end_dt` columns updated.
+        The contiguous_time_periods pd.DataFrame with the `start_dt` and `end_dt` columns updated.
     """
-
-
-
-
+    # Make a copy so the data is not edited in place.
+    trimmed_time_periods = contiguous_time_periods.copy()
+    trimmed_time_periods["start_dt"] -= interval_start
+    trimmed_time_periods["end_dt"] -= interval_end
 
-    valid_mask =
-    contiguous_time_periods = contiguous_time_periods.loc[valid_mask]
-
-    return contiguous_time_periods
+    valid_mask = trimmed_time_periods["start_dt"] <= trimmed_time_periods["end_dt"]
 
+    return trimmed_time_periods.loc[valid_mask]
 
 
 def find_contiguous_t0_periods(
-
-
-
-
-
+    datetimes: pd.DatetimeIndex,
+    interval_start: pd.Timedelta,
+    interval_end: pd.Timedelta,
+    time_resolution: pd.Timedelta,
+) -> pd.DataFrame:
     """Return a pd.DataFrame where each row records the boundary of a contiguous time period.
 
     Args:
-        datetimes: pd.DatetimeIndex
+        datetimes: pd.DatetimeIndex
         interval_start: The start of the interval with respect to t0
         interval_end: The end of the interval with respect to t0
-
-
+        time_resolution: The sample frequency of the timeseries
 
     Returns:
         pd.DataFrame where each row represents a single time period. The pd.DataFrame
         has two columns: `start_dt` and `end_dt` (where 'dt' is short for 'datetime').
     """
+    check_time_unique_increasing(datetimes)
+
     total_duration = interval_end - interval_start
-
+
     contiguous_time_periods = find_contiguous_time_periods(
         datetimes=datetimes,
-        min_seq_length=int(total_duration /
-        max_gap_duration=
+        min_seq_length=int(total_duration / time_resolution) + 1,
+        max_gap_duration=time_resolution,
     )
 
     contiguous_t0_periods = trim_contiguous_time_periods(
@@ -123,7 +123,11 @@ def find_contiguous_t0_periods(
         interval_end=interval_end,
     )
 
-
+    if len(contiguous_t0_periods) == 0:
+        raise ValueError(
+            f"No contiguous time periods found for {datetimes}. "
+            f"{interval_start=} {interval_end=} {time_resolution=}",
+        )
 
     return contiguous_t0_periods
 
@@ -132,54 +136,59 @@ def find_contiguous_t0_periods_nwp(
     init_times: pd.DatetimeIndex,
     interval_start: pd.Timedelta,
     max_staleness: pd.Timedelta,
-    max_dropout: pd.Timedelta =
-    first_forecast_step: pd.Timedelta =
-
+    max_dropout: pd.Timedelta = ZERO_TDELTA,
+    first_forecast_step: pd.Timedelta = ZERO_TDELTA,
 ) -> pd.DataFrame:
-    """Get all time periods from the NWP init
+    """Get all time periods from the NWP init-times which are valid as t0 datetimes.
 
     Args:
         init_times: The initialisation times of the available forecasts
-        interval_start: The start of the
-        max_staleness: Up to how long after an init
-            init
-
-
-
-
+        interval_start: The start of the time interval with respect to t0
+        max_staleness: Up to how long after an init-time are we willing to use the forecast.
+            Each init-time will only be used up to this t0 time regardless of the forecast valid
+            time.
+        max_dropout: What is the maximum amount of dropout that will be used.
+            This must be <= max_staleness.
+        first_forecast_step: The timedelta of the first step of the forecast.
+            By default we assume the first valid time of the forecast
+            is the same as its init-time.
 
     Returns:
-        pd.DataFrame where each row represents a single time period.
+        pd.DataFrame where each row represents a single time period. The pd.DataFrame
         has two columns: `start_dt` and `end_dt` (where 'dt' is short for 'datetime').
     """
     # Sanity checks.
-
-
-
-
-
+    if len(init_times) == 0:
+        raise ValueError("No init-times to use")
+
+    check_time_unique_increasing(init_times)
+
+    if max_staleness < pd.Timedelta(0):
+        raise ValueError("The max staleness must be positive")
+    if not (pd.Timedelta(0) <= max_dropout <= max_staleness):
+        raise ValueError("The max dropout must be between 0 and the max staleness")
 
-
+    history_drop_buffer = max(first_forecast_step - interval_start, max_dropout)
 
     # Store contiguous periods
-    contiguous_periods = []
+    contiguous_periods: list[list[pd.Timestamp]] = []
 
-    # Begin the first period allowing for the time to the first_forecast_step, the length of the
+    # Begin the first period allowing for the time to the first_forecast_step, the length of the
     # interval sampled from before t0, and the dropout
-    start_this_period = init_times[0] +
+    start_this_period = init_times[0] + history_drop_buffer
 
     # The first forecast is valid up to the max staleness
     end_this_period = init_times[0] + max_staleness
 
     for dt_init in init_times[1:]:
-        # If the previous init
-        # considering dropout) then the contiguous period breaks
-        # Else if the previous init
+        # If the previous init-time becomes stale before the next init-time becomes valid (whilst
+        # also considering dropout) then the contiguous period breaks
+        # Else if the previous init-time becomes stale before the fist step of the next forecast
        # then this also causes a break in the contiguous period
-        if
+        if end_this_period < dt_init + max(max_dropout, first_forecast_step):
             contiguous_periods.append([start_this_period, end_this_period])
             # The new period begins with the same conditions as the first period
-            start_this_period = dt_init +
+            start_this_period = dt_init + history_drop_buffer
             end_this_period = dt_init + max_staleness
 
     contiguous_periods.append([start_this_period, end_this_period])
@@ -190,11 +199,13 @@ def find_contiguous_t0_periods_nwp(
 def intersection_of_multiple_dataframes_of_periods(
     time_periods: list[pd.DataFrame],
 ) -> pd.DataFrame:
-    """Find the intersection of
+    """Find the intersection of list of time periods.
 
-
+    Consecutively updates intersection of time periods.
+    See the docstring of intersection_of_2_dataframes_of_periods() for further details.
     """
-
+    if len(time_periods) == 0:
+        raise ValueError("No time periods to intersect")
     intersection = time_periods[0]
     for time_period in time_periods[1:]:
         intersection = intersection_of_2_dataframes_of_periods(intersection, time_period)
@@ -210,7 +221,8 @@ def intersection_of_2_dataframes_of_periods(a: pd.DataFrame, b: pd.DataFrame) ->
     A typical use-case is that each pd.DataFrame represents all the time periods where
     a `DataSource` has contiguous, valid data.
 
-
+    Graphical representation of two pd.DataFrames of time periods and their intersection,
+    as follows:
 
     ----------------------> TIME ->---------------------
     a: |-----| |----| |----------| |-----------|
@@ -218,9 +230,9 @@ def intersection_of_2_dataframes_of_periods(a: pd.DataFrame, b: pd.DataFrame) ->
     intersection: |--| |-| |--| |---|
 
     Args:
-        a: pd.DataFrame where each row represents a time period.
+        a: pd.DataFrame where each row represents a time period. The pd.DataFrame has
            two columns: start_dt and end_dt.
-        b: pd.DataFrame where each row represents a time period.
+        b: pd.DataFrame where each row represents a time period. The pd.DataFrame has
            two columns: start_dt and end_dt.
 
     Returns:
@@ -239,7 +251,7 @@ def intersection_of_2_dataframes_of_periods(a: pd.DataFrame, b: pd.DataFrame) ->
     # and `a` must always end after `b` starts:
 
     # TODO: <= and >= because we should allow overlap time periods of length 1. e.g.
-    # a: |----| or |---|
+    # a: |----| or |---|
     # b: |--| |---|
     # These aren't allowed if we use < and >.
 
```
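The reworked `find_contiguous_t0_periods` now takes `interval_start`, `interval_end`, and `time_resolution`, and raises `ValueError` rather than returning an empty result when no valid t0 window exists. Below is a minimal sketch of calling the new signature; the timeseries and durations are illustrative, not taken from the package's tests:

```python
import pandas as pd

from ocf_data_sampler.select.find_contiguous_time_periods import (
    find_contiguous_t0_periods,
)

# A 5-minutely timeseries with a one-hour gap in the middle.
datetimes = pd.date_range("2024-01-01 10:00", "2024-01-01 12:00", freq="5min").union(
    pd.date_range("2024-01-01 13:00", "2024-01-01 15:00", freq="5min"),
)

# Each sample wants 30 minutes of history before t0 and a 60-minute horizon after.
t0_periods = find_contiguous_t0_periods(
    datetimes=datetimes,
    interval_start=pd.Timedelta("-30min"),
    interval_end=pd.Timedelta("60min"),
    time_resolution=pd.Timedelta("5min"),
)

# Two rows with start_dt/end_dt columns: each contiguous segment trimmed by
# 30 minutes at the start and 60 minutes at the end. An empty result would
# now raise ValueError instead of silently returning an empty DataFrame.
print(t0_periods)
```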
Second, `ocf_data_sampler/select/geospatial.py` (+63 -54):

```diff
@@ -1,36 +1,45 @@
-"""Geospatial functions
+"""Geospatial coordinate transformation functions.
 
-
-
+Provides utilities for working with different coordinate systems
+commonly used in geospatial applications, particularly for UK-based data.
+
+Supports conversions between:
+- OSGB36 (Ordnance Survey Great Britain, easting/northing in meters)
+- WGS84 (World Geodetic System, latitude/longitude in degrees)
+- Geostationary satellite coordinate systems
+"""
 
 import numpy as np
 import pyproj
+import pyresample
 import xarray as xr
 
-#
-#
-#
-# Transverse Mercator projection, using 'easting' and 'northing'
-# coordinates which are in meters. See https://epsg.io/27700
+# Coordinate Reference System (CRS) identifiers
+# OSGB36: UK Ordnance Survey National Grid (easting/northing in meters)
+# Refer to - https://epsg.io/27700
 OSGB36 = 27700
 
-# WGS84
-# latitude and longitude.
+# WGS84: World Geodetic System 1984 (latitude/longitude in degrees), used in GPS
 WGS84 = 4326
 
-
+# Pre-init Transformer
 _osgb_to_lon_lat = pyproj.Transformer.from_crs(
-    crs_from=OSGB36,
+    crs_from=OSGB36,
+    crs_to=WGS84,
+    always_xy=True,
 ).transform
 _lon_lat_to_osgb = pyproj.Transformer.from_crs(
-    crs_from=WGS84,
+    crs_from=WGS84,
+    crs_to=OSGB36,
+    always_xy=True,
 ).transform
 
 
 def osgb_to_lon_lat(
-    x:
-
-
+    x: float | np.ndarray,
+    y: float | np.ndarray,
+) -> tuple[float | np.ndarray, float | np.ndarray]:
+    """Change OSGB coordinates to lon-lat.
 
     Args:
         x: osgb east-west
@@ -41,9 +50,9 @@ def osgb_to_lon_lat(
 
 
 def lon_lat_to_osgb(
-    x:
-    y:
-) -> tuple[
+    x: float | np.ndarray,
+    y: float | np.ndarray,
+) -> tuple[float | np.ndarray, float | np.ndarray]:
     """Change lon-lat coordinates to OSGB.
 
     Args:
@@ -56,11 +65,11 @@ def lon_lat_to_osgb(
 
 
 def lon_lat_to_geostationary_area_coords(
-    longitude:
-    latitude:
+    longitude: float | np.ndarray,
+    latitude: float | np.ndarray,
     xr_data: xr.DataArray,
-) -> tuple[
-    """Loads geostationary area and transformation from lat-lon to geostationary coords
+) -> tuple[float | np.ndarray, float | np.ndarray]:
+    """Loads geostationary area and transformation from lat-lon to geostationary coords.
 
     Args:
         longitude: longitude
@@ -72,12 +81,13 @@ def lon_lat_to_geostationary_area_coords(
     """
     return coordinates_to_geostationary_area_coords(longitude, latitude, xr_data, WGS84)
 
+
 def osgb_to_geostationary_area_coords(
-    x:
-    y:
+    x: float | np.ndarray,
+    y: float | np.ndarray,
     xr_data: xr.DataArray,
-) -> tuple[
-    """Loads geostationary area and transformation from OSGB to geostationary coords
+) -> tuple[float | np.ndarray, float | np.ndarray]:
+    """Loads geostationary area and transformation from OSGB to geostationary coords.
 
     Args:
         x: osgb east-west
@@ -87,47 +97,45 @@ def osgb_to_geostationary_area_coords(
     Returns:
         Geostationary coords: x, y
     """
-
     return coordinates_to_geostationary_area_coords(x, y, xr_data, OSGB36)
 
 
-
 def coordinates_to_geostationary_area_coords(
-    x:
-    y:
+    x: float | np.ndarray,
+    y: float | np.ndarray,
     xr_data: xr.DataArray,
-    crs_from: int
-) -> tuple[
-    """Loads geostationary area and
-
-    Args:
-        x: osgb east-west, or latitude
-        y: osgb north-south, or longitude
-        xr_data: xarray object with geostationary area
-        crs_from: the cordiates system of x,y
+    crs_from: int,
+) -> tuple[float | np.ndarray, float | np.ndarray]:
+    """Loads geostationary area and transforms to geostationary coords.
 
-
-
-
-
-
+    Args:
+        x: osgb east-west, or latitude
+        y: osgb north-south, or longitude
+        xr_data: xarray object with geostationary area
+        crs_from: the cordiates system of x,y
 
-
-
+    Returns:
+        Geostationary coords: x, y
+    """
+    if crs_from not in [OSGB36, WGS84]:
+        raise ValueError(f"Unrecognized coordinate system: {crs_from}")
 
     area_definition_yaml = xr_data.attrs["area"]
 
     geostationary_area_definition = pyresample.area_config.load_area_from_string(
-        area_definition_yaml
+        area_definition_yaml,
     )
     geostationary_crs = geostationary_area_definition.crs
     osgb_to_geostationary = pyproj.Transformer.from_crs(
-        crs_from=crs_from,
+        crs_from=crs_from,
+        crs_to=geostationary_crs,
+        always_xy=True,
     ).transform
     return osgb_to_geostationary(xx=x, yy=y)
 
 
-def _coord_priority(available_coords):
+def _coord_priority(available_coords: list[str]) -> tuple[str, str, str]:
+    """Determines the coordinate system of spatial coordinates present."""
     if "longitude" in available_coords:
         return "lon_lat", "longitude", "latitude"
     elif "x_geostationary" in available_coords:
@@ -138,7 +146,7 @@ def _coord_priority(available_coords):
         raise ValueError(f"Unrecognized coordinate system: {available_coords}")
 
 
-def spatial_coord_type(ds: xr.DataArray):
+def spatial_coord_type(ds: xr.DataArray) -> tuple[str, str, str]:
     """Searches the data array to determine the kind of spatial coordinates present.
 
     This search has a preference for the dimension coordinates of the xarray object.
@@ -147,9 +155,10 @@ def spatial_coord_type(ds: xr.DataArray):
         ds: Dataset with spatial coords
 
     Returns:
-
-
-
+        Three strings with:
+            1. The kind of the coordinate system
+            2. Name of the x-coordinate
+            3. Name of the y-coordinate
     """
     if isinstance(ds, xr.DataArray):
         # Search dimension coords of dataarray
```
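Because the module-level transformers are now built with `always_xy=True`, the converters take and return (x, y) ordered pairs, i.e. longitude before latitude. A small round-trip sketch, assuming the wheel is installed; the London coordinates are approximate and purely illustrative:

```python
import numpy as np

from ocf_data_sampler.select.geospatial import lon_lat_to_osgb, osgb_to_lon_lat

# lon/lat -> OSGB easting/northing in meters, then back again.
easting, northing = lon_lat_to_osgb(x=-0.1276, y=51.5072)
lon, lat = osgb_to_lon_lat(x=easting, y=northing)

# The round trip should reproduce the inputs to well within float tolerance.
assert np.isclose(lon, -0.1276) and np.isclose(lat, 51.5072)
```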
Third, `ocf_data_sampler/select/location.py` (+16 -51):

```diff
@@ -1,62 +1,27 @@
-"""
+"""Location model with coordinate system validation."""
 
-from typing import Optional
-
-import numpy as np
 from pydantic import BaseModel, Field, model_validator
 
+allowed_coordinate_systems = ["osgb", "lon_lat", "geostationary", "idx"]
 
-allowed_coordinate_systems =["osgb", "lon_lat", "geostationary", "idx"]
 
 class Location(BaseModel):
     """Represent a spatial location."""
 
-    coordinate_system:
-
-
-    id: Optional[int] = Field(None)
-
-    @model_validator(mode='after')
-    def validate_coordinate_system(self):
-        """Validate 'coordinate_system'"""
-        if self.coordinate_system not in allowed_coordinate_systems:
-            raise ValueError(f"coordinate_system = {self.coordinate_system} is not in {allowed_coordinate_systems}")
-        return self
-
-    @model_validator(mode='after')
-    def validate_x(self):
-        """Validate 'x'"""
-        min_x: float
-        max_x: float
-
-        co = self.coordinate_system
-        if co == "osgb":
-            min_x, max_x = -103976.3, 652897.98
-        if co == "lon_lat":
-            min_x, max_x = -180, 180
-        if co == "geostationary":
-            min_x, max_x = -5568748.275756836, 5567248.074173927
-        if co == "idx":
-            min_x, max_x = 0, np.inf
-        if self.x < min_x or self.x > max_x:
-            raise ValueError(f"x = {self.x} must be within {[min_x, max_x]} for {co} coordinate system")
-        return self
+    coordinate_system: str = Field(...,
+        description="Coordinate system for the location must be lon_lat, osgb, or geostationary",
+    )
 
-
-
-
-        min_y: float
-        max_y: float
+    x: float = Field(..., description="x coordinate - i.e. east-west position")
+    y: float = Field(..., description="y coordinate - i.e. north-south position")
+    id: int | None = Field(None, description="ID of the location - e.g. GSP ID")
 
-
-
-
-        if
-
-
-
-
-            min_y, max_y = 0, np.inf
-        if self.y < min_y or self.y > max_y:
-            raise ValueError(f"y = {self.y} must be within {[min_y, max_y]} for {co} coordinate system")
+    @model_validator(mode="after")
+    def validate_coordinate_system(self) -> "Location":
+        """Validate 'coordinate_system'."""
+        if self.coordinate_system not in allowed_coordinate_systems:
+            raise ValueError(
+                f"coordinate_system = {self.coordinate_system} "
+                f"is not in {allowed_coordinate_systems}",
+            )
         return self
```