ocf-data-sampler 0.0.19__py3-none-any.whl → 0.0.42__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of ocf-data-sampler might be problematic.

Files changed (64)
  1. ocf_data_sampler/config/__init__.py +5 -0
  2. ocf_data_sampler/config/load.py +33 -0
  3. ocf_data_sampler/config/model.py +246 -0
  4. ocf_data_sampler/config/save.py +73 -0
  5. ocf_data_sampler/constants.py +173 -0
  6. ocf_data_sampler/load/load_dataset.py +55 -0
  7. ocf_data_sampler/load/nwp/providers/ecmwf.py +5 -2
  8. ocf_data_sampler/load/site.py +30 -0
  9. ocf_data_sampler/numpy_sample/__init__.py +8 -0
  10. ocf_data_sampler/numpy_sample/collate.py +77 -0
  11. ocf_data_sampler/numpy_sample/gsp.py +34 -0
  12. ocf_data_sampler/numpy_sample/nwp.py +42 -0
  13. ocf_data_sampler/numpy_sample/satellite.py +30 -0
  14. ocf_data_sampler/numpy_sample/site.py +30 -0
  15. ocf_data_sampler/{numpy_batch → numpy_sample}/sun_position.py +9 -10
  16. ocf_data_sampler/select/__init__.py +8 -1
  17. ocf_data_sampler/select/dropout.py +4 -3
  18. ocf_data_sampler/select/find_contiguous_time_periods.py +40 -75
  19. ocf_data_sampler/select/geospatial.py +160 -0
  20. ocf_data_sampler/select/location.py +62 -0
  21. ocf_data_sampler/select/select_spatial_slice.py +13 -16
  22. ocf_data_sampler/select/select_time_slice.py +24 -33
  23. ocf_data_sampler/select/spatial_slice_for_dataset.py +53 -0
  24. ocf_data_sampler/select/time_slice_for_dataset.py +125 -0
  25. ocf_data_sampler/torch_datasets/__init__.py +2 -1
  26. ocf_data_sampler/torch_datasets/process_and_combine.py +131 -0
  27. ocf_data_sampler/torch_datasets/pvnet_uk_regional.py +11 -425
  28. ocf_data_sampler/torch_datasets/site.py +405 -0
  29. ocf_data_sampler/torch_datasets/valid_time_periods.py +116 -0
  30. ocf_data_sampler/utils.py +10 -0
  31. ocf_data_sampler-0.0.42.dist-info/METADATA +153 -0
  32. ocf_data_sampler-0.0.42.dist-info/RECORD +71 -0
  33. {ocf_data_sampler-0.0.19.dist-info → ocf_data_sampler-0.0.42.dist-info}/WHEEL +1 -1
  34. {ocf_data_sampler-0.0.19.dist-info → ocf_data_sampler-0.0.42.dist-info}/top_level.txt +1 -0
  35. scripts/refactor_site.py +50 -0
  36. tests/config/test_config.py +161 -0
  37. tests/config/test_save.py +37 -0
  38. tests/conftest.py +86 -1
  39. tests/load/test_load_gsp.py +15 -0
  40. tests/load/test_load_nwp.py +21 -0
  41. tests/load/test_load_satellite.py +17 -0
  42. tests/load/test_load_sites.py +14 -0
  43. tests/numpy_sample/test_collate.py +26 -0
  44. tests/numpy_sample/test_gsp.py +38 -0
  45. tests/numpy_sample/test_nwp.py +52 -0
  46. tests/numpy_sample/test_satellite.py +40 -0
  47. tests/numpy_sample/test_sun_position.py +81 -0
  48. tests/select/test_dropout.py +75 -0
  49. tests/select/test_fill_time_periods.py +28 -0
  50. tests/select/test_find_contiguous_time_periods.py +202 -0
  51. tests/select/test_location.py +67 -0
  52. tests/select/test_select_spatial_slice.py +154 -0
  53. tests/select/test_select_time_slice.py +272 -0
  54. tests/torch_datasets/conftest.py +18 -0
  55. tests/torch_datasets/test_process_and_combine.py +126 -0
  56. tests/torch_datasets/test_pvnet_uk_regional.py +59 -0
  57. tests/torch_datasets/test_site.py +129 -0
  58. ocf_data_sampler/numpy_batch/__init__.py +0 -7
  59. ocf_data_sampler/numpy_batch/gsp.py +0 -20
  60. ocf_data_sampler/numpy_batch/nwp.py +0 -33
  61. ocf_data_sampler/numpy_batch/satellite.py +0 -23
  62. ocf_data_sampler-0.0.19.dist-info/METADATA +0 -22
  63. ocf_data_sampler-0.0.19.dist-info/RECORD +0 -32
  64. {ocf_data_sampler-0.0.19.dist-info → ocf_data_sampler-0.0.42.dist-info}/LICENSE +0 -0
ocf_data_sampler/numpy_sample/collate.py
@@ -0,0 +1,77 @@
+ from ocf_data_sampler.numpy_sample import NWPSampleKey
+
+ import numpy as np
+ import logging
+ from typing import Union
+
+ logger = logging.getLogger(__name__)
+
+
+
+ def stack_np_examples_into_sample(dict_list):
+     """
+     Stacks Numpy examples into a sample
+
+     See also: `unstack_np_sample_into_examples()` for the opposite
+
+     Args:
+         dict_list: A list of dict-like Numpy examples to stack
+
+     Returns:
+         The stacked NumpySample object
+     """
+
+     if not dict_list:
+         raise ValueError("Input is empty")
+
+     # Extract keys from the first dict - these define the structure
+     sample = {}
+     sample_keys = list(dict_list[0].keys())
+
+     # Process - handle NWP separately due to its nested structure
+     for sample_key in sample_keys:
+         if sample_key == "nwp":
+             sample["nwp"] = process_nwp_data(dict_list)
+         else:
+             # Stack arrays for the given key across all dicts
+             sample[sample_key] = stack_data_list([d[sample_key] for d in dict_list], sample_key)
+     return sample
+
+
+ def process_nwp_data(dict_list):
+     """Stacks data for NWP, handling the nested structure"""
+
+     nwp_sample = {}
+     nwp_sources = dict_list[0]["nwp"].keys()
+
+     # Stack data for each NWP source independently
+     for nwp_source in nwp_sources:
+         nested_keys = dict_list[0]["nwp"][nwp_source].keys()
+         nwp_sample[nwp_source] = {
+             key: stack_data_list([d["nwp"][nwp_source][key] for d in dict_list], key)
+             for key in nested_keys
+         }
+     return nwp_sample
+
+ def _key_is_constant(sample_key):
+     return sample_key.endswith("t0_idx") or sample_key == NWPSampleKey.channel_names
+
+
+ def stack_data_list(data_list: list, sample_key: Union[str, NWPSampleKey]):
+     """How to combine data entries for each key
+
+     Args:
+         data_list: List of data entries to combine
+         sample_key: Key identifying the data type
+     """
+     if _key_is_constant(sample_key):
+         # These are always the same for all examples.
+         return data_list[0]
+     try:
+         return np.stack(data_list)
+     except Exception as e:
+         logger.debug(f"Could not stack the following shapes together for key ({sample_key})")
+         shapes = [example.shape for example in data_list]
+         logger.debug(shapes)
+         logger.error(e)
+         raise e
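
For orientation, a minimal usage sketch of the new collate helper (a sketch assuming stack_np_examples_into_sample is importable from ocf_data_sampler.numpy_sample.collate; the two example dicts are hypothetical):

    import numpy as np
    from ocf_data_sampler.numpy_sample.collate import stack_np_examples_into_sample

    examples = [
        {"gsp": np.zeros(7), "gsp_t0_idx": 3},
        {"gsp": np.ones(7), "gsp_t0_idx": 3},
    ]
    sample = stack_np_examples_into_sample(examples)
    assert sample["gsp"].shape == (2, 7)  # arrays are stacked along a new leading axis
    assert sample["gsp_t0_idx"] == 3      # keys ending in "t0_idx" are treated as constant
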
ocf_data_sampler/numpy_sample/gsp.py
@@ -0,0 +1,34 @@
+ """Convert GSP to Numpy Sample"""
+
+ import xarray as xr
+
+
+ class GSPSampleKey:
+
+     gsp = 'gsp'
+     nominal_capacity_mwp = 'gsp_nominal_capacity_mwp'
+     effective_capacity_mwp = 'gsp_effective_capacity_mwp'
+     time_utc = 'gsp_time_utc'
+     t0_idx = 'gsp_t0_idx'
+     solar_azimuth = 'gsp_solar_azimuth'
+     solar_elevation = 'gsp_solar_elevation'
+     gsp_id = 'gsp_id'
+     x_osgb = 'gsp_x_osgb'
+     y_osgb = 'gsp_y_osgb'
+
+
+ def convert_gsp_to_numpy_sample(da: xr.DataArray, t0_idx: int | None = None) -> dict:
+     """Convert from Xarray to NumpySample"""
+
+     # Extract values from the DataArray
+     example = {
+         GSPSampleKey.gsp: da.values,
+         GSPSampleKey.nominal_capacity_mwp: da.isel(time_utc=0)["nominal_capacity_mwp"].values,
+         GSPSampleKey.effective_capacity_mwp: da.isel(time_utc=0)["effective_capacity_mwp"].values,
+         GSPSampleKey.time_utc: da["time_utc"].values.astype(float),
+     }
+
+     if t0_idx is not None:
+         example[GSPSampleKey.t0_idx] = t0_idx
+
+     return example
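
A sketch of the input convert_gsp_to_numpy_sample expects: a one-dimensional DataArray with the capacity values carried as coordinates along time_utc (toy values, not the package's test fixtures):

    import numpy as np
    import pandas as pd
    import xarray as xr
    from ocf_data_sampler.numpy_sample.gsp import convert_gsp_to_numpy_sample

    times = pd.date_range("2024-06-01 12:00", periods=4, freq="30min")
    da = xr.DataArray(
        np.random.rand(4),
        dims=["time_utc"],
        coords={
            "time_utc": times,
            "nominal_capacity_mwp": ("time_utc", np.full(4, 100.0)),
            "effective_capacity_mwp": ("time_utc", np.full(4, 90.0)),
        },
    )
    sample = convert_gsp_to_numpy_sample(da, t0_idx=1)
    # keys: 'gsp', 'gsp_nominal_capacity_mwp', 'gsp_effective_capacity_mwp',
    #       'gsp_time_utc', 'gsp_t0_idx'
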
ocf_data_sampler/numpy_sample/nwp.py
@@ -0,0 +1,42 @@
+ """Convert NWP to NumpySample"""
+
+ import pandas as pd
+ import xarray as xr
+
+
+ class NWPSampleKey:
+
+     nwp = 'nwp'
+     channel_names = 'nwp_channel_names'
+     init_time_utc = 'nwp_init_time_utc'
+     step = 'nwp_step'
+     target_time_utc = 'nwp_target_time_utc'
+     t0_idx = 'nwp_t0_idx'
+     y_osgb = 'nwp_y_osgb'
+     x_osgb = 'nwp_x_osgb'
+
+
+
+ def convert_nwp_to_numpy_sample(da: xr.DataArray, t0_idx: int | None = None) -> dict:
+     """Convert from Xarray to NWP NumpySample"""
+
+     # Create the example and add the target times if available
+     example = {
+         NWPSampleKey.nwp: da.values,
+         NWPSampleKey.channel_names: da.channel.values,
+         NWPSampleKey.init_time_utc: da.init_time_utc.values.astype(float),
+         NWPSampleKey.step: (da.step.values / pd.Timedelta("1h")).astype(int),
+     }
+
+     if "target_time_utc" in da.coords:
+         example[NWPSampleKey.target_time_utc] = da.target_time_utc.values.astype(float)
+
+     # TODO: Do we need this at all? Especially since it is only present in UKV data
+     for sample_key, dataset_key in ((NWPSampleKey.y_osgb, "y_osgb"), (NWPSampleKey.x_osgb, "x_osgb")):
+         if dataset_key in da.coords:
+             example[sample_key] = da[dataset_key].values
+
+     if t0_idx is not None:
+         example[NWPSampleKey.t0_idx] = t0_idx
+
+     return example
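
The step conversion above turns timedelta64 forecast steps into integer hours; a quick sketch of that arithmetic in isolation:

    import pandas as pd

    steps = pd.to_timedelta(["0h", "1h", "3h"]).to_numpy()
    hours = (steps / pd.Timedelta("1h")).astype(int)
    print(hours)  # [0 1 3]
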
ocf_data_sampler/numpy_sample/satellite.py
@@ -0,0 +1,30 @@
+ """Convert Satellite to NumpySample"""
+ import xarray as xr
+
+
+ class SatelliteSampleKey:
+
+     satellite_actual = 'satellite_actual'
+     time_utc = 'satellite_time_utc'
+     x_geostationary = 'satellite_x_geostationary'
+     y_geostationary = 'satellite_y_geostationary'
+     t0_idx = 'satellite_t0_idx'
+
+
+ def convert_satellite_to_numpy_sample(da: xr.DataArray, t0_idx: int | None = None) -> dict:
+     """Convert from Xarray to NumpySample"""
+     example = {
+         SatelliteSampleKey.satellite_actual: da.values,
+         SatelliteSampleKey.time_utc: da.time_utc.values.astype(float),
+     }
+
+     for sample_key, dataset_key in (
+         (SatelliteSampleKey.x_geostationary, "x_geostationary"),
+         (SatelliteSampleKey.y_geostationary, "y_geostationary"),
+     ):
+         example[sample_key] = da[dataset_key].values
+
+     if t0_idx is not None:
+         example[SatelliteSampleKey.t0_idx] = t0_idx
+
+     return example
ocf_data_sampler/numpy_sample/site.py
@@ -0,0 +1,30 @@
+ """Convert site to Numpy Sample"""
+
+ import xarray as xr
+
+
+ class SiteSampleKey:
+
+     generation = "site"
+     capacity_kwp = "site_capacity_kwp"
+     time_utc = "site_time_utc"
+     t0_idx = "site_t0_idx"
+     solar_azimuth = "site_solar_azimuth"
+     solar_elevation = "site_solar_elevation"
+     id = "site_id"
+
+
+ def convert_site_to_numpy_sample(da: xr.DataArray, t0_idx: int | None = None) -> dict:
+     """Convert from Xarray to NumpySample"""
+
+     # Extract values from the DataArray
+     example = {
+         SiteSampleKey.generation: da.values,
+         SiteSampleKey.capacity_kwp: da.isel(time_utc=0)["capacity_kwp"].values,
+         SiteSampleKey.time_utc: da["time_utc"].values.astype(float),
+     }
+
+     if t0_idx is not None:
+         example[SiteSampleKey.t0_idx] = t0_idx
+
+     return example
ocf_data_sampler/{numpy_batch → numpy_sample}/sun_position.py
@@ -2,7 +2,6 @@
  import pvlib
  import numpy as np
  import pandas as pd
- from ocf_datapipes.batch import BatchKey, NumpyBatch


  def calculate_azimuth_and_elevation(
@@ -33,13 +32,13 @@ def calculate_azimuth_and_elevation(
      return azimuth, elevation


- def make_sun_position_numpy_batch(
+ def make_sun_position_numpy_sample(
      datetimes: pd.DatetimeIndex,
      lon: float,
      lat: float,
-     key_preffix: str = "gsp"
- ) -> NumpyBatch:
-     """Creates NumpyBatch with standardized solar coordinates
+     key_prefix: str = "gsp"
+ ) -> dict:
+     """Creates NumpySample with standardized solar coordinates

      Args:
          datetimes: The datetimes to calculate solar angles for
@@ -57,10 +56,10 @@ def make_sun_position_numpy_batch(
      # Elevation is in range [-90, 90] degrees
      elevation = elevation / 180 + 0.5

-     # Make NumpyBatch
-     sun_numpy_batch: NumpyBatch = {
-         BatchKey[key_preffix + "_solar_azimuth"]: azimuth,
-         BatchKey[key_preffix + "_solar_elevation"]: elevation,
+     # Make NumpySample
+     sun_numpy_sample = {
+         key_prefix + "_solar_azimuth": azimuth,
+         key_prefix + "_solar_elevation": elevation,
      }

-     return sun_numpy_batch
+     return sun_numpy_sample
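
After the rename the function returns a plain dict rather than an ocf_datapipes NumpyBatch; a minimal usage sketch (the coordinates are arbitrary):

    import pandas as pd
    from ocf_data_sampler.numpy_sample.sun_position import make_sun_position_numpy_sample

    datetimes = pd.date_range("2024-06-21 10:00", periods=3, freq="30min")
    sun = make_sun_position_numpy_sample(datetimes, lon=-1.25, lat=51.75, key_prefix="gsp")
    # A dict with keys "gsp_solar_azimuth" and "gsp_solar_elevation",
    # each an array rescaled to roughly [0, 1]
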
ocf_data_sampler/select/__init__.py
@@ -1 +1,8 @@
-
+ from .fill_time_periods import fill_time_periods
+ from .find_contiguous_time_periods import (
+     find_contiguous_t0_periods,
+     intersection_of_multiple_dataframes_of_periods,
+ )
+ from .location import Location
+ from .spatial_slice_for_dataset import slice_datasets_by_space
+ from .time_slice_for_dataset import slice_datasets_by_time
ocf_data_sampler/select/dropout.py
@@ -1,3 +1,4 @@
+ """Functions for simulating dropout in time series data"""
  import numpy as np
  import pandas as pd
  import xarray as xr
@@ -5,14 +6,14 @@ import xarray as xr

  def draw_dropout_time(
      t0: pd.Timestamp,
-     dropout_timedeltas: list[pd.Timedelta] | None,
+     dropout_timedeltas: list[pd.Timedelta] | pd.Timedelta | None,
      dropout_frac: float = 0,
  ):

      if dropout_timedeltas is not None:
          assert len(dropout_timedeltas) >= 1, "Must include list of relative dropout timedeltas"
          assert all(
-             [t < pd.Timedelta("0min") for t in dropout_timedeltas]
+             [t <= pd.Timedelta("0min") for t in dropout_timedeltas]
          ), "dropout timedeltas must be negative"
      assert 0 <= dropout_frac <= 1

@@ -35,4 +36,4 @@ def apply_dropout_time(
          return ds
      else:
          # This replaces the times after the dropout with NaNs
-         return ds.where(ds.time_utc <= dropout_time)
+         return ds.where(ds.time_utc <= dropout_time)
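
To illustrate the masking step in apply_dropout_time, here is the ds.where line from the diff reproduced on a toy series (everything else in this sketch is made up):

    import numpy as np
    import pandas as pd
    import xarray as xr

    times = pd.date_range("2024-01-01 11:00", "2024-01-01 13:00", freq="30min")
    ds = xr.DataArray(
        np.arange(len(times), dtype=float), coords={"time_utc": times}, dims="time_utc"
    )

    # A dropout time 30 minutes before a t0 of 12:00 hides everything after 11:30
    dropout_time = pd.Timestamp("2024-01-01 11:30")
    masked = ds.where(ds.time_utc <= dropout_time)  # values after 11:30 become NaN
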
ocf_data_sampler/select/find_contiguous_time_periods.py
@@ -63,16 +63,16 @@ def find_contiguous_time_periods(

  def trim_contiguous_time_periods(
      contiguous_time_periods: pd.DataFrame,
-     history_duration: pd.Timedelta,
-     forecast_duration: pd.Timedelta,
+     interval_start: pd.Timedelta,
+     interval_end: pd.Timedelta,
  ) -> pd.DataFrame:
      """Trim the contiguous time periods to allow for history and forecast durations.

      Args:
          contiguous_time_periods: DataFrame where each row represents a single time period. The
              DataFrame must have `start_dt` and `end_dt` columns.
-         history_duration: Length of the historical slice used for a sample
-         forecast_duration: Length of the forecast slice used for a sample
+         interval_start: The start of the interval with respect to t0
+         interval_end: The end of the interval with respect to t0


      Returns:
@@ -80,8 +80,8 @@ def trim_contiguous_time_periods(
      """
      contiguous_time_periods = contiguous_time_periods.copy()

-     contiguous_time_periods["start_dt"] += history_duration
-     contiguous_time_periods["end_dt"] -= forecast_duration
+     contiguous_time_periods["start_dt"] -= interval_start
+     contiguous_time_periods["end_dt"] -= interval_end

      valid_mask = contiguous_time_periods["start_dt"] <= contiguous_time_periods["end_dt"]
      contiguous_time_periods = contiguous_time_periods.loc[valid_mask]
@@ -92,16 +92,16 @@

  def find_contiguous_t0_periods(
      datetimes: pd.DatetimeIndex,
-     history_duration: pd.Timedelta,
-     forecast_duration: pd.Timedelta,
+     interval_start: pd.Timedelta,
+     interval_end: pd.Timedelta,
      sample_period_duration: pd.Timedelta,
  ) -> pd.DataFrame:
      """Return a pd.DataFrame where each row records the boundary of a contiguous time period.

      Args:
          datetimes: pd.DatetimeIndex. Must be sorted.
-         history_duration: Length of the historical slice used for each sample
-         forecast_duration: Length of the forecast slice used for each sample
+         interval_start: The start of the interval with respect to t0
+         interval_end: The end of the interval with respect to t0
          sample_period_duration: The sample frequency of the timeseries

@@ -109,7 +109,7 @@ def find_contiguous_t0_periods(
          pd.DataFrame where each row represents a single time period. The pd.DataFrame
          has two columns: `start_dt` and `end_dt` (where 'dt' is short for 'datetime').
      """
-     total_duration = history_duration + forecast_duration
+     total_duration = interval_end - interval_start

      contiguous_time_periods = find_contiguous_time_periods(
          datetimes=datetimes,
@@ -119,8 +119,8 @@

      contiguous_t0_periods = trim_contiguous_time_periods(
          contiguous_time_periods=contiguous_time_periods,
-         history_duration=history_duration,
-         forecast_duration=forecast_duration,
+         interval_start=interval_start,
+         interval_end=interval_end,
      )

      assert len(contiguous_t0_periods) > 0
@@ -128,92 +128,57 @@
      return contiguous_t0_periods


- def _find_contiguous_t0_periods_nwp(
-     ds,
-     history_duration: pd.Timedelta,
-     forecast_duration: pd.Timedelta,
-     max_staleness: pd.Timedelta | None = None,
-     max_dropout: pd.Timedelta = pd.Timedelta(0),
-     time_dim: str = "init_time_utc",
-     end_buffer: pd.Timedelta = pd.Timedelta(0),
- ):
-
-     assert "step" in ds.coords
-     # It is possible to use up to this amount of max staleness for the dataset and slice
-     # required
-     possible_max_staleness = (
-         pd.Timedelta(ds["step"].max().item())
-         - forecast_duration
-         - end_buffer
-     )
-
-     # If max_staleness is set to None we set it based on the max step ahead of the input
-     # forecast data
-     if max_staleness is None:
-         max_staleness = possible_max_staleness
-     else:
-         # Make sure the max acceptable staleness isn't longer than the max possible
-         assert max_staleness <= possible_max_staleness
-         max_staleness = max_staleness
-
-     contiguous_time_periods = find_contiguous_t0_periods_nwp(
-         datetimes=pd.DatetimeIndex(ds[time_dim]),
-         history_duration=history_duration,
-         max_staleness=max_staleness,
-         max_dropout=max_dropout,
-     )
-     return contiguous_time_periods
-
-
-
  def find_contiguous_t0_periods_nwp(
-     datetimes: pd.DatetimeIndex,
-     history_duration: pd.Timedelta,
+     init_times: pd.DatetimeIndex,
+     interval_start: pd.Timedelta,
      max_staleness: pd.Timedelta,
      max_dropout: pd.Timedelta = pd.Timedelta(0),
+     first_forecast_step: pd.Timedelta = pd.Timedelta(0),
+
  ) -> pd.DataFrame:
      """Get all time periods from the NWP init times which are valid as t0 datetimes.

      Args:
-         datetimes: Sorted pd.DatetimeIndex
-         history_duration: Length of the historical slice used for a sample
-         max_staleness: Up to how long after an NWP forecast init_time are we willing to use the
-             forecast. Each init time will only be used up to this t0 time regardless of the forecast
-             valid time.
+         init_times: The initialisation times of the available forecasts
+         interval_start: The start of the desired data interval with respect to t0
+         max_staleness: Up to how long after an init time are we willing to use the forecast. Each
+             init time will only be used up to this t0 time regardless of the forecast valid time.
          max_dropout: What is the maximum amount of dropout that will be used. This must be <=
              max_staleness.
+         first_forecast_step: The timedelta of the first step of the forecast. By default we assume
+             the first valid time of the forecast is the same as its init time.

      Returns:
          pd.DataFrame where each row represents a single time period. The pd.DataFrame
          has two columns: `start_dt` and `end_dt` (where 'dt' is short for 'datetime').
      """
      # Sanity checks.
-     assert len(datetimes) > 0
-     assert datetimes.is_monotonic_increasing
-     assert datetimes.is_unique
-     assert history_duration >= pd.Timedelta(0)
+     assert len(init_times) > 0
+     assert init_times.is_monotonic_increasing
+     assert init_times.is_unique
      assert max_staleness >= pd.Timedelta(0)
-     assert max_dropout <= max_staleness
+     assert pd.Timedelta(0) <= max_dropout <= max_staleness

-     hist_drop_buffer = max(history_duration, max_dropout)
+     hist_drop_buffer = max(first_forecast_step - interval_start, max_dropout)

      # Store contiguous periods
      contiguous_periods = []

-     # Start first period allowing for history slice and max dropout
-     start_this_period = datetimes[0] + hist_drop_buffer
+     # Begin the first period allowing for the time to the first_forecast_step, the length of the
+     # interval sampled from before t0, and the dropout
+     start_this_period = init_times[0] + hist_drop_buffer

      # The first forecast is valid up to the max staleness
-     end_this_period = datetimes[0] + max_staleness
-
-     for dt_init in datetimes[1:]:
-         # If the previous init time becomes stale before the next init becomes valid whilst also
-         # considering dropout - then the contiguous period breaks, and new starts with considering
-         # dropout and history duration
-         if end_this_period < dt_init + max_dropout:
+     end_this_period = init_times[0] + max_staleness
+
+     for dt_init in init_times[1:]:
+         # If the previous init time becomes stale before the next init becomes valid (whilst also
+         # considering dropout) then the contiguous period breaks
+         # Else if the previous init time becomes stale before the first step of the next forecast
+         # then this also causes a break in the contiguous period
+         if end_this_period < dt_init + max(max_dropout, first_forecast_step):
              contiguous_periods.append([start_this_period, end_this_period])
-
-             # And start a new period
+             # The new period begins with the same conditions as the first period
              start_this_period = dt_init + hist_drop_buffer
          end_this_period = dt_init + max_staleness

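For intuition on the new interval convention: a sample drawn at t0 spans t0 + interval_start to t0 + interval_end, with interval_start typically negative (history before t0). A toy check of the trimming arithmetic (a sketch, using the export added to select/__init__.py above):

    import pandas as pd
    from ocf_data_sampler.select import find_contiguous_t0_periods

    datetimes = pd.date_range("2024-01-01 00:00", "2024-01-01 02:00", freq="5min")
    periods = find_contiguous_t0_periods(
        datetimes=datetimes,
        interval_start=pd.Timedelta("-30min"),  # 30 min of history before t0
        interval_end=pd.Timedelta("15min"),     # 15 min of forecast after t0
        sample_period_duration=pd.Timedelta("5min"),
    )
    # start_dt is pushed forward by 30 min and end_dt pulled back by 15 min,
    # so valid t0 values run from 00:30 to 01:45
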
ocf_data_sampler/select/geospatial.py
@@ -0,0 +1,160 @@
+ """Geospatial functions"""
+
+ from numbers import Number
+ from typing import Union
+
+ import numpy as np
+ import pyproj
+ import xarray as xr
+
+ # OSGB is also called "OSGB 1936 / British National Grid -- United
+ # Kingdom Ordnance Survey". OSGB is used in many UK electricity
+ # system maps, and is used by the UK Met Office UKV model. OSGB is a
+ # Transverse Mercator projection, using 'easting' and 'northing'
+ # coordinates which are in meters. See https://epsg.io/27700
+ OSGB36 = 27700
+
+ # WGS84 is short for "World Geodetic System 1984", used in GPS. Uses
+ # latitude and longitude.
+ WGS84 = 4326
+
+
+ _osgb_to_lon_lat = pyproj.Transformer.from_crs(
+     crs_from=OSGB36, crs_to=WGS84, always_xy=True
+ ).transform
+ _lon_lat_to_osgb = pyproj.Transformer.from_crs(
+     crs_from=WGS84, crs_to=OSGB36, always_xy=True
+ ).transform
+
+
+ def osgb_to_lon_lat(
+     x: Union[Number, np.ndarray], y: Union[Number, np.ndarray]
+ ) -> tuple[Union[Number, np.ndarray], Union[Number, np.ndarray]]:
+     """Change OSGB coordinates to lon, lat.
+
+     Args:
+         x: osgb east-west
+         y: osgb north-south
+     Return: 2-tuple of longitude (east-west), latitude (north-south)
+     """
+     return _osgb_to_lon_lat(xx=x, yy=y)
+
+
+ def lon_lat_to_osgb(
+     x: Union[Number, np.ndarray],
+     y: Union[Number, np.ndarray],
+ ) -> tuple[Union[Number, np.ndarray], Union[Number, np.ndarray]]:
+     """Change lon-lat coordinates to OSGB.
+
+     Args:
+         x: longitude east-west
+         y: latitude north-south
+
+     Return: 2-tuple of OSGB x, y
+     """
+     return _lon_lat_to_osgb(xx=x, yy=y)
+
+
+ def lon_lat_to_geostationary_area_coords(
+     longitude: Union[Number, np.ndarray],
+     latitude: Union[Number, np.ndarray],
+     xr_data: xr.DataArray,
+ ) -> tuple[Union[Number, np.ndarray], Union[Number, np.ndarray]]:
+     """Loads geostationary area and transformation from lon-lat to geostationary coords
+
+     Args:
+         longitude: longitude
+         latitude: latitude
+         xr_data: xarray object with geostationary area
+
+     Returns:
+         Geostationary coords: x, y
+     """
+     return coordinates_to_geostationary_area_coords(longitude, latitude, xr_data, WGS84)
+
+ def osgb_to_geostationary_area_coords(
+     x: Union[Number, np.ndarray],
+     y: Union[Number, np.ndarray],
+     xr_data: xr.DataArray,
+ ) -> tuple[Union[Number, np.ndarray], Union[Number, np.ndarray]]:
+     """Loads geostationary area and transformation from OSGB to geostationary coords
+
+     Args:
+         x: osgb east-west
+         y: osgb north-south
+         xr_data: xarray object with geostationary area
+
+     Returns:
+         Geostationary coords: x, y
+     """
+
+     return coordinates_to_geostationary_area_coords(x, y, xr_data, OSGB36)
+
+
+
+ def coordinates_to_geostationary_area_coords(
+     x: Union[Number, np.ndarray],
+     y: Union[Number, np.ndarray],
+     xr_data: xr.DataArray,
+     crs_from: int
+ ) -> tuple[Union[Number, np.ndarray], Union[Number, np.ndarray]]:
+     """Loads geostationary area and transformation from the given coordinates to geostationary coords
+
+     Args:
+         x: osgb east-west, or longitude
+         y: osgb north-south, or latitude
+         xr_data: xarray object with geostationary area
+         crs_from: the coordinate system of x, y
+
+     Returns:
+         Geostationary coords: x, y
+     """
+
+     assert crs_from in [OSGB36, WGS84], f"Unrecognized coordinate system: {crs_from}"
+
+     # Only load this if using the geostationary projection
+     import pyresample
+
+     area_definition_yaml = xr_data.attrs["area"]
+
+     geostationary_area_definition = pyresample.area_config.load_area_from_string(
+         area_definition_yaml
+     )
+     geostationary_crs = geostationary_area_definition.crs
+     coords_to_geostationary = pyproj.Transformer.from_crs(
+         crs_from=crs_from, crs_to=geostationary_crs, always_xy=True
+     ).transform
+     return coords_to_geostationary(xx=x, yy=y)
+
+
+ def _coord_priority(available_coords):
+     if "longitude" in available_coords:
+         return "lon_lat", "longitude", "latitude"
+     elif "x_geostationary" in available_coords:
+         return "geostationary", "x_geostationary", "y_geostationary"
+     elif "x_osgb" in available_coords:
+         return "osgb", "x_osgb", "y_osgb"
+     else:
+         raise ValueError(f"Unrecognized coordinate system: {available_coords}")
+
+
+ def spatial_coord_type(ds: xr.DataArray):
+     """Searches the data array to determine the kind of spatial coordinates present.
+
+     This search has a preference for the dimension coordinates of the xarray object.
+
+     Args:
+         ds: DataArray with spatial coords
+
+     Returns:
+         str: The kind of the coordinate system
+         x_coord: Name of the x-coordinate
+         y_coord: Name of the y-coordinate
+     """
+     if isinstance(ds, xr.DataArray):
+         # Search the dimension coords of the dataarray
+         coords = _coord_priority(ds.xindexes)
+     else:
+         raise ValueError(f"Unrecognized input type: {type(ds)}")
+
+     return coords
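
A round-trip sketch of the OSGB/WGS84 helpers above (the test point is arbitrary):

    import numpy as np
    from ocf_data_sampler.select.geospatial import lon_lat_to_osgb, osgb_to_lon_lat

    x, y = lon_lat_to_osgb(-0.1276, 51.5072)  # a point near London
    lon, lat = osgb_to_lon_lat(x, y)
    assert np.isclose(lon, -0.1276) and np.isclose(lat, 51.5072)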