ocf-data-sampler 0.1.11__py3-none-any.whl → 0.1.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ocf-data-sampler might be problematic. Click here for more details.

Files changed (76) hide show
  1. ocf_data_sampler/config/load.py +3 -3
  2. ocf_data_sampler/config/model.py +73 -61
  3. ocf_data_sampler/config/save.py +5 -4
  4. ocf_data_sampler/constants.py +140 -12
  5. ocf_data_sampler/load/gsp.py +6 -5
  6. ocf_data_sampler/load/load_dataset.py +5 -6
  7. ocf_data_sampler/load/nwp/nwp.py +17 -5
  8. ocf_data_sampler/load/nwp/providers/ecmwf.py +6 -7
  9. ocf_data_sampler/load/nwp/providers/gfs.py +36 -0
  10. ocf_data_sampler/load/nwp/providers/icon.py +46 -0
  11. ocf_data_sampler/load/nwp/providers/ukv.py +4 -5
  12. ocf_data_sampler/load/nwp/providers/utils.py +3 -1
  13. ocf_data_sampler/load/satellite.py +9 -10
  14. ocf_data_sampler/load/site.py +10 -6
  15. ocf_data_sampler/load/utils.py +21 -16
  16. ocf_data_sampler/numpy_sample/collate.py +10 -9
  17. ocf_data_sampler/numpy_sample/datetime_features.py +3 -5
  18. ocf_data_sampler/numpy_sample/gsp.py +12 -14
  19. ocf_data_sampler/numpy_sample/nwp.py +12 -12
  20. ocf_data_sampler/numpy_sample/satellite.py +9 -9
  21. ocf_data_sampler/numpy_sample/site.py +5 -8
  22. ocf_data_sampler/numpy_sample/sun_position.py +16 -21
  23. ocf_data_sampler/sample/base.py +15 -17
  24. ocf_data_sampler/sample/site.py +13 -20
  25. ocf_data_sampler/sample/uk_regional.py +29 -35
  26. ocf_data_sampler/select/dropout.py +16 -14
  27. ocf_data_sampler/select/fill_time_periods.py +15 -5
  28. ocf_data_sampler/select/find_contiguous_time_periods.py +88 -75
  29. ocf_data_sampler/select/geospatial.py +63 -54
  30. ocf_data_sampler/select/location.py +16 -51
  31. ocf_data_sampler/select/select_spatial_slice.py +105 -89
  32. ocf_data_sampler/select/select_time_slice.py +71 -58
  33. ocf_data_sampler/select/spatial_slice_for_dataset.py +7 -6
  34. ocf_data_sampler/select/time_slice_for_dataset.py +17 -16
  35. ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py +126 -118
  36. ocf_data_sampler/torch_datasets/datasets/site.py +135 -101
  37. ocf_data_sampler/torch_datasets/utils/merge_and_fill_utils.py +6 -2
  38. ocf_data_sampler/torch_datasets/utils/valid_time_periods.py +23 -22
  39. ocf_data_sampler/torch_datasets/utils/validate_channels.py +23 -19
  40. ocf_data_sampler/utils.py +3 -1
  41. {ocf_data_sampler-0.1.11.dist-info → ocf_data_sampler-0.1.16.dist-info}/METADATA +7 -18
  42. ocf_data_sampler-0.1.16.dist-info/RECORD +56 -0
  43. {ocf_data_sampler-0.1.11.dist-info → ocf_data_sampler-0.1.16.dist-info}/WHEEL +1 -1
  44. {ocf_data_sampler-0.1.11.dist-info → ocf_data_sampler-0.1.16.dist-info}/top_level.txt +1 -1
  45. scripts/refactor_site.py +62 -33
  46. utils/compute_icon_mean_stddev.py +72 -0
  47. ocf_data_sampler-0.1.11.dist-info/LICENSE +0 -21
  48. ocf_data_sampler-0.1.11.dist-info/RECORD +0 -82
  49. tests/__init__.py +0 -0
  50. tests/config/test_config.py +0 -113
  51. tests/config/test_load.py +0 -7
  52. tests/config/test_save.py +0 -28
  53. tests/conftest.py +0 -319
  54. tests/load/test_load_gsp.py +0 -15
  55. tests/load/test_load_nwp.py +0 -21
  56. tests/load/test_load_satellite.py +0 -17
  57. tests/load/test_load_sites.py +0 -14
  58. tests/numpy_sample/test_collate.py +0 -21
  59. tests/numpy_sample/test_datetime_features.py +0 -37
  60. tests/numpy_sample/test_gsp.py +0 -38
  61. tests/numpy_sample/test_nwp.py +0 -13
  62. tests/numpy_sample/test_satellite.py +0 -40
  63. tests/numpy_sample/test_sun_position.py +0 -81
  64. tests/select/test_dropout.py +0 -69
  65. tests/select/test_fill_time_periods.py +0 -28
  66. tests/select/test_find_contiguous_time_periods.py +0 -202
  67. tests/select/test_location.py +0 -67
  68. tests/select/test_select_spatial_slice.py +0 -154
  69. tests/select/test_select_time_slice.py +0 -275
  70. tests/test_sample/test_base.py +0 -164
  71. tests/test_sample/test_site_sample.py +0 -165
  72. tests/test_sample/test_uk_regional_sample.py +0 -136
  73. tests/torch_datasets/test_merge_and_fill_utils.py +0 -40
  74. tests/torch_datasets/test_pvnet_uk.py +0 -154
  75. tests/torch_datasets/test_site.py +0 -226
  76. tests/torch_datasets/test_validate_channels_utils.py +0 -78
@@ -1,57 +1,80 @@
1
- import xarray as xr
2
- import pandas as pd
1
+ """Select a time slice from a Dataset or DataArray."""
2
+
3
3
  import numpy as np
4
+ import pandas as pd
5
+ import xarray as xr
6
+
4
7
 
5
8
  def select_time_slice(
6
- ds: xr.DataArray,
9
+ da: xr.DataArray,
7
10
  t0: pd.Timestamp,
8
11
  interval_start: pd.Timedelta,
9
12
  interval_end: pd.Timedelta,
10
- sample_period_duration: pd.Timedelta,
11
- ):
12
- """Select a time slice from a Dataset or DataArray."""
13
- t0_datetime_utc = pd.Timestamp(t0)
14
- start_dt = t0_datetime_utc + interval_start
15
- end_dt = t0_datetime_utc + interval_end
13
+ time_resolution: pd.Timedelta,
14
+ ) -> xr.DataArray:
15
+ """Select a time slice from a DataArray.
16
+
17
+ Args:
18
+ da: The DataArray to slice from
19
+ t0: The init-time
20
+ interval_start: The start of the interval with respect to t0
21
+ interval_end: The end of the interval with respect to t0
22
+ time_resolution: Distance between neighbouring timestamps
23
+ """
24
+ start_dt = t0 + interval_start
25
+ end_dt = t0 + interval_end
16
26
 
17
- start_dt = start_dt.ceil(sample_period_duration)
18
- end_dt = end_dt.ceil(sample_period_duration)
27
+ start_dt = start_dt.ceil(time_resolution)
28
+ end_dt = end_dt.ceil(time_resolution)
29
+
30
+ return da.sel(time_utc=slice(start_dt, end_dt))
19
31
 
20
- return ds.sel(time_utc=slice(start_dt, end_dt))
21
32
 
22
33
  def select_time_slice_nwp(
23
34
  da: xr.DataArray,
24
35
  t0: pd.Timestamp,
25
36
  interval_start: pd.Timedelta,
26
37
  interval_end: pd.Timedelta,
27
- sample_period_duration: pd.Timedelta,
38
+ time_resolution: pd.Timedelta,
28
39
  dropout_timedeltas: list[pd.Timedelta] | None = None,
29
40
  dropout_frac: float | None = 0,
30
- accum_channels: list[str] = [],
31
- channel_dim_name: str = "channel",
32
- ):
41
+ accum_channels: list[str] | None = None,
42
+ ) -> xr.DataArray:
43
+ """Select a time slice from an NWP DataArray.
44
+
45
+ Args:
46
+ da: The DataArray to slice from
47
+ t0: The init-time
48
+ interval_start: The start of the interval with respect to t0
49
+ interval_end: The end of the interval with respect to t0
50
+ time_resolution: Distance between neighbouring timestamps
51
+ dropout_timedeltas: List of possible timedeltas before t0 where data availability may start
52
+ dropout_frac: Probability to apply dropout
53
+ accum_channels: Channels which are accumulated and need to be differenced
54
+ """
55
+ if accum_channels is None:
56
+ accum_channels = []
57
+
33
58
  if dropout_timedeltas is not None:
34
- assert all(
35
- [t < pd.Timedelta(0) for t in dropout_timedeltas]
36
- ), "dropout timedeltas must be negative"
37
- assert len(dropout_timedeltas) >= 1
38
- assert 0 <= dropout_frac <= 1
39
- consider_dropout = (dropout_timedeltas is not None) and dropout_frac > 0
59
+ if not all(t < pd.Timedelta(0) for t in dropout_timedeltas):
60
+ raise ValueError("dropout timedeltas must be negative")
61
+ if len(dropout_timedeltas) < 1:
62
+ raise ValueError("dropout timedeltas must have at least one element")
63
+
64
+ if not (0 <= dropout_frac <= 1):
65
+ raise ValueError("dropout_frac must be between 0 and 1")
40
66
 
41
- # The accumatation and non-accumulation channels
42
- accum_channels = np.intersect1d(
43
- da[channel_dim_name].values, accum_channels
44
- )
45
- non_accum_channels = np.setdiff1d(
46
- da[channel_dim_name].values, accum_channels
47
- )
67
+ consider_dropout = (dropout_timedeltas is not None) and dropout_frac > 0
48
68
 
49
- start_dt = (t0 + interval_start).ceil(sample_period_duration)
50
- end_dt = (t0 + interval_end).ceil(sample_period_duration)
69
+ # The accumulated and non-accumulated channels
70
+ accum_channels = np.intersect1d(da.channel.values, accum_channels)
71
+ non_accum_channels = np.setdiff1d(da.channel.values, accum_channels)
51
72
 
52
- target_times = pd.date_range(start_dt, end_dt, freq=sample_period_duration)
73
+ start_dt = (t0 + interval_start).ceil(time_resolution)
74
+ end_dt = (t0 + interval_end).ceil(time_resolution)
75
+ target_times = pd.date_range(start_dt, end_dt, freq=time_resolution)
53
76
 
54
- # Maybe apply NWP dropout
77
+ # Potentially apply NWP dropout
55
78
  if consider_dropout and (np.random.uniform() < dropout_frac):
56
79
  dt = np.random.choice(dropout_timedeltas)
57
80
  t0_available = t0 + dt
@@ -59,9 +82,7 @@ def select_time_slice_nwp(
59
82
  t0_available = t0
60
83
 
61
84
  # Forecasts made up to and including t0
62
- available_init_times = da.init_time_utc.sel(
63
- init_time_utc=slice(None, t0_available)
64
- )
85
+ available_init_times = da.init_time_utc.sel(init_time_utc=slice(None, t0_available))
65
86
 
66
87
  # Find the most recent available init times for all target times
67
88
  selected_init_times = available_init_times.sel(
@@ -74,10 +95,10 @@ def select_time_slice_nwp(
74
95
 
75
96
  # We want one timestep for each of the target_times (obviously!). If we simply do
76
97
  # nwp.sel(init_time=init_times, step=steps) then we'll get the *product* of
77
- # init_times and steps, which is not what # we want! Instead, we use xarray's
78
- # vectorized-indexing mode by using a DataArray indexer. See the last example here:
98
+ # init_times and steps, which is not what we want! Instead, we use xarray's
99
+ # vectorised-indexing mode by using a DataArray indexer. See the last example here:
79
100
  # https://docs.xarray.dev/en/latest/user-guide/indexing.html#more-advanced-indexing
80
-
101
+
81
102
  coords = {"target_time_utc": target_times}
82
103
  init_time_indexer = xr.DataArray(selected_init_times, coords=coords)
83
104
  step_indexer = xr.DataArray(steps, coords=coords)
@@ -90,38 +111,30 @@ def select_time_slice_nwp(
90
111
  unique_init_times = np.unique(selected_init_times)
91
112
  # - find the min and max steps we slice over. Max is extended due to diff
92
113
  min_step = min(steps)
93
- max_step = max(steps) + sample_period_duration
114
+ max_step = max(steps) + time_resolution
94
115
 
95
- da_min = da.sel(
96
- {
97
- "init_time_utc": unique_init_times,
98
- "step": slice(min_step, max_step),
99
- }
100
- )
116
+ da_min = da.sel(init_time_utc=unique_init_times, step=slice(min_step, max_step))
101
117
 
102
118
  # Slice out the data which does not need to be diffed
103
- da_non_accum = da_min.sel({channel_dim_name: non_accum_channels})
104
- da_sel_non_accum = da_non_accum.sel(
105
- step=step_indexer, init_time_utc=init_time_indexer
106
- )
119
+ da_non_accum = da_min.sel(channel=non_accum_channels)
120
+ da_sel_non_accum = da_non_accum.sel(step=step_indexer, init_time_utc=init_time_indexer)
107
121
 
108
122
  # Slice out the channels which need to be diffed
109
- da_accum = da_min.sel({channel_dim_name: accum_channels})
110
-
123
+ da_accum = da_min.sel(channel=accum_channels)
124
+
111
125
  # Take the diff and slice requested data
112
126
  da_accum = da_accum.diff(dim="step", label="lower")
113
127
  da_sel_accum = da_accum.sel(step=step_indexer, init_time_utc=init_time_indexer)
114
128
 
115
129
  # Join diffed and non-diffed variables
116
- da_sel = xr.concat([da_sel_non_accum, da_sel_accum], dim=channel_dim_name)
117
-
130
+ da_sel = xr.concat([da_sel_non_accum, da_sel_accum], dim="channel")
131
+
118
132
  # Reorder the variable back to the original order
119
- da_sel = da_sel.sel({channel_dim_name: da[channel_dim_name].values})
133
+ da_sel = da_sel.sel(channel=da.channel.values)
120
134
 
121
135
  # Rename the diffed channels
122
- da_sel[channel_dim_name] = [
123
- f"diff_{v}" if v in accum_channels else v
124
- for v in da_sel[channel_dim_name].values
136
+ da_sel["channel"] = [
137
+ f"diff_{v}" if v in accum_channels else v for v in da_sel.channel.values
125
138
  ]
126
139
 
127
140
  return da_sel
@@ -1,4 +1,5 @@
1
- """ Functions for selecting data around a given location """
1
+ """Functions for selecting data around a given location."""
2
+
2
3
  from ocf_data_sampler.config import Configuration
3
4
  from ocf_data_sampler.select.location import Location
4
5
  from ocf_data_sampler.select.select_spatial_slice import select_spatial_slice_pixels
@@ -9,24 +10,24 @@ def slice_datasets_by_space(
9
10
  location: Location,
10
11
  config: Configuration,
11
12
  ) -> dict:
12
- """Slice the dictionary of input data sources around a given location
13
+ """Slice the dictionary of input data sources around a given location.
13
14
 
14
15
  Args:
15
16
  datasets_dict: Dictionary of the input data sources
16
17
  location: The location to sample around
17
18
  config: Configuration object.
18
19
  """
19
-
20
- assert set(datasets_dict.keys()).issubset({"nwp", "sat", "gsp", "site"})
20
+ if not set(datasets_dict.keys()).issubset({"nwp", "sat", "gsp", "site"}):
21
+ raise ValueError(
22
+ "'datasets_dict' should only contain keys 'nwp', 'sat', 'gsp', 'site'",
23
+ )
21
24
 
22
25
  sliced_datasets_dict = {}
23
26
 
24
27
  if "nwp" in datasets_dict:
25
-
26
28
  sliced_datasets_dict["nwp"] = {}
27
29
 
28
30
  for nwp_key, nwp_config in config.input_data.nwp.items():
29
-
30
31
  sliced_datasets_dict["nwp"][nwp_key] = select_spatial_slice_pixels(
31
32
  datasets_dict["nwp"][nwp_key],
32
33
  location,
@@ -1,25 +1,26 @@
1
- """ Slice datasets by time"""
1
+ """Slice datasets by time."""
2
+
2
3
  import pandas as pd
3
4
  import xarray as xr
4
5
 
5
6
  from ocf_data_sampler.config import Configuration
6
- from ocf_data_sampler.select.dropout import draw_dropout_time, apply_dropout_time
7
- from ocf_data_sampler.select.select_time_slice import select_time_slice_nwp, select_time_slice
7
+ from ocf_data_sampler.select.dropout import apply_dropout_time, draw_dropout_time
8
+ from ocf_data_sampler.select.select_time_slice import select_time_slice, select_time_slice_nwp
8
9
  from ocf_data_sampler.utils import minutes
9
10
 
11
+
10
12
  def slice_datasets_by_time(
11
13
  datasets_dict: dict,
12
14
  t0: pd.Timestamp,
13
15
  config: Configuration,
14
16
  ) -> dict:
15
- """Slice the dictionary of input data sources around a given t0 time
17
+ """Slice the dictionary of input data sources around a given t0 time.
16
18
 
17
19
  Args:
18
20
  datasets_dict: Dictionary of the input data sources
19
21
  t0: The init-time
20
22
  config: Configuration object.
21
23
  """
22
-
23
24
  sliced_datasets_dict = {}
24
25
 
25
26
  if "nwp" in datasets_dict:
@@ -31,7 +32,7 @@ def slice_datasets_by_time(
31
32
  sliced_datasets_dict["nwp"][nwp_key] = select_time_slice_nwp(
32
33
  da_nwp,
33
34
  t0,
34
- sample_period_duration=minutes(nwp_config.time_resolution_minutes),
35
+ time_resolution=minutes(nwp_config.time_resolution_minutes),
35
36
  interval_start=minutes(nwp_config.interval_start_minutes),
36
37
  interval_end=minutes(nwp_config.interval_end_minutes),
37
38
  dropout_timedeltas=minutes(nwp_config.dropout_timedeltas_minutes),
@@ -45,7 +46,7 @@ def slice_datasets_by_time(
45
46
  sliced_datasets_dict["sat"] = select_time_slice(
46
47
  datasets_dict["sat"],
47
48
  t0,
48
- sample_period_duration=minutes(sat_config.time_resolution_minutes),
49
+ time_resolution=minutes(sat_config.time_resolution_minutes),
49
50
  interval_start=minutes(sat_config.interval_start_minutes),
50
51
  interval_end=minutes(sat_config.interval_end_minutes),
51
52
  )
@@ -65,11 +66,11 @@ def slice_datasets_by_time(
65
66
 
66
67
  if "gsp" in datasets_dict:
67
68
  gsp_config = config.input_data.gsp
68
-
69
+
69
70
  da_gsp_past = select_time_slice(
70
71
  datasets_dict["gsp"],
71
72
  t0,
72
- sample_period_duration=minutes(gsp_config.time_resolution_minutes),
73
+ time_resolution=minutes(gsp_config.time_resolution_minutes),
73
74
  interval_start=minutes(gsp_config.interval_start_minutes),
74
75
  interval_end=minutes(0),
75
76
  )
@@ -82,18 +83,18 @@ def slice_datasets_by_time(
82
83
  )
83
84
 
84
85
  da_gsp_past = apply_dropout_time(
85
- da_gsp_past,
86
- gsp_dropout_time
86
+ da_gsp_past,
87
+ gsp_dropout_time,
87
88
  )
88
-
89
+
89
90
  da_gsp_future = select_time_slice(
90
91
  datasets_dict["gsp"],
91
92
  t0,
92
- sample_period_duration=minutes(gsp_config.time_resolution_minutes),
93
+ time_resolution=minutes(gsp_config.time_resolution_minutes),
93
94
  interval_start=minutes(gsp_config.time_resolution_minutes),
94
95
  interval_end=minutes(gsp_config.interval_end_minutes),
95
96
  )
96
-
97
+
97
98
  sliced_datasets_dict["gsp"] = xr.concat([da_gsp_past, da_gsp_future], dim="time_utc")
98
99
 
99
100
  if "site" in datasets_dict:
@@ -102,7 +103,7 @@ def slice_datasets_by_time(
102
103
  sliced_datasets_dict["site"] = select_time_slice(
103
104
  datasets_dict["site"],
104
105
  t0,
105
- sample_period_duration=minutes(site_config.time_resolution_minutes),
106
+ time_resolution=minutes(site_config.time_resolution_minutes),
106
107
  interval_start=minutes(site_config.interval_start_minutes),
107
108
  interval_end=minutes(site_config.interval_end_minutes),
108
109
  )
@@ -120,4 +121,4 @@ def slice_datasets_by_time(
120
121
  site_dropout_time,
121
122
  )
122
123
 
123
- return sliced_datasets_dict
124
+ return sliced_datasets_dict