ocf-data-sampler 0.5.14__tar.gz → 0.5.16__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ocf-data-sampler might be problematic. Click here for more details.

Files changed (73) hide show
  1. {ocf_data_sampler-0.5.14 → ocf_data_sampler-0.5.16}/PKG-INFO +1 -1
  2. {ocf_data_sampler-0.5.14 → ocf_data_sampler-0.5.16}/ocf_data_sampler/config/model.py +8 -17
  3. {ocf_data_sampler-0.5.14 → ocf_data_sampler-0.5.16}/ocf_data_sampler/numpy_sample/nwp.py +1 -1
  4. ocf_data_sampler-0.5.16/ocf_data_sampler/select/diff_channels.py +25 -0
  5. ocf_data_sampler-0.5.16/ocf_data_sampler/select/dropout.py +59 -0
  6. ocf_data_sampler-0.5.16/ocf_data_sampler/select/select_time_slice.py +107 -0
  7. {ocf_data_sampler-0.5.14 → ocf_data_sampler-0.5.16}/ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py +3 -1
  8. {ocf_data_sampler-0.5.14 → ocf_data_sampler-0.5.16}/ocf_data_sampler/torch_datasets/datasets/site.py +4 -3
  9. {ocf_data_sampler-0.5.14 → ocf_data_sampler-0.5.16}/ocf_data_sampler/torch_datasets/utils/__init__.py +2 -1
  10. ocf_data_sampler-0.5.16/ocf_data_sampler/torch_datasets/utils/diff_nwp_data.py +20 -0
  11. {ocf_data_sampler-0.5.14 → ocf_data_sampler-0.5.16}/ocf_data_sampler/torch_datasets/utils/time_slice_for_dataset.py +22 -30
  12. {ocf_data_sampler-0.5.14 → ocf_data_sampler-0.5.16}/ocf_data_sampler.egg-info/PKG-INFO +1 -1
  13. {ocf_data_sampler-0.5.14 → ocf_data_sampler-0.5.16}/ocf_data_sampler.egg-info/SOURCES.txt +2 -0
  14. ocf_data_sampler-0.5.14/ocf_data_sampler/select/dropout.py +0 -61
  15. ocf_data_sampler-0.5.14/ocf_data_sampler/select/select_time_slice.py +0 -143
  16. {ocf_data_sampler-0.5.14 → ocf_data_sampler-0.5.16}/LICENSE +0 -0
  17. {ocf_data_sampler-0.5.14 → ocf_data_sampler-0.5.16}/README.md +0 -0
  18. {ocf_data_sampler-0.5.14 → ocf_data_sampler-0.5.16}/ocf_data_sampler/__init__.py +0 -0
  19. {ocf_data_sampler-0.5.14 → ocf_data_sampler-0.5.16}/ocf_data_sampler/config/__init__.py +0 -0
  20. {ocf_data_sampler-0.5.14 → ocf_data_sampler-0.5.16}/ocf_data_sampler/config/load.py +0 -0
  21. {ocf_data_sampler-0.5.14 → ocf_data_sampler-0.5.16}/ocf_data_sampler/config/save.py +0 -0
  22. {ocf_data_sampler-0.5.14 → ocf_data_sampler-0.5.16}/ocf_data_sampler/data/uk_gsp_locations_20220314.csv +0 -0
  23. {ocf_data_sampler-0.5.14 → ocf_data_sampler-0.5.16}/ocf_data_sampler/data/uk_gsp_locations_20250109.csv +0 -0
  24. {ocf_data_sampler-0.5.14 → ocf_data_sampler-0.5.16}/ocf_data_sampler/load/__init__.py +0 -0
  25. {ocf_data_sampler-0.5.14 → ocf_data_sampler-0.5.16}/ocf_data_sampler/load/gsp.py +0 -0
  26. {ocf_data_sampler-0.5.14 → ocf_data_sampler-0.5.16}/ocf_data_sampler/load/load_dataset.py +0 -0
  27. {ocf_data_sampler-0.5.14 → ocf_data_sampler-0.5.16}/ocf_data_sampler/load/nwp/__init__.py +0 -0
  28. {ocf_data_sampler-0.5.14 → ocf_data_sampler-0.5.16}/ocf_data_sampler/load/nwp/nwp.py +0 -0
  29. {ocf_data_sampler-0.5.14 → ocf_data_sampler-0.5.16}/ocf_data_sampler/load/nwp/providers/__init__.py +0 -0
  30. {ocf_data_sampler-0.5.14 → ocf_data_sampler-0.5.16}/ocf_data_sampler/load/nwp/providers/cloudcasting.py +0 -0
  31. {ocf_data_sampler-0.5.14 → ocf_data_sampler-0.5.16}/ocf_data_sampler/load/nwp/providers/ecmwf.py +0 -0
  32. {ocf_data_sampler-0.5.14 → ocf_data_sampler-0.5.16}/ocf_data_sampler/load/nwp/providers/gfs.py +0 -0
  33. {ocf_data_sampler-0.5.14 → ocf_data_sampler-0.5.16}/ocf_data_sampler/load/nwp/providers/icon.py +0 -0
  34. {ocf_data_sampler-0.5.14 → ocf_data_sampler-0.5.16}/ocf_data_sampler/load/nwp/providers/ukv.py +0 -0
  35. {ocf_data_sampler-0.5.14 → ocf_data_sampler-0.5.16}/ocf_data_sampler/load/nwp/providers/utils.py +0 -0
  36. {ocf_data_sampler-0.5.14 → ocf_data_sampler-0.5.16}/ocf_data_sampler/load/open_xarray_tensorstore.py +0 -0
  37. {ocf_data_sampler-0.5.14 → ocf_data_sampler-0.5.16}/ocf_data_sampler/load/satellite.py +0 -0
  38. {ocf_data_sampler-0.5.14 → ocf_data_sampler-0.5.16}/ocf_data_sampler/load/site.py +0 -0
  39. {ocf_data_sampler-0.5.14 → ocf_data_sampler-0.5.16}/ocf_data_sampler/load/utils.py +0 -0
  40. {ocf_data_sampler-0.5.14 → ocf_data_sampler-0.5.16}/ocf_data_sampler/numpy_sample/__init__.py +0 -0
  41. {ocf_data_sampler-0.5.14 → ocf_data_sampler-0.5.16}/ocf_data_sampler/numpy_sample/collate.py +0 -0
  42. {ocf_data_sampler-0.5.14 → ocf_data_sampler-0.5.16}/ocf_data_sampler/numpy_sample/common_types.py +0 -0
  43. {ocf_data_sampler-0.5.14 → ocf_data_sampler-0.5.16}/ocf_data_sampler/numpy_sample/datetime_features.py +0 -0
  44. {ocf_data_sampler-0.5.14 → ocf_data_sampler-0.5.16}/ocf_data_sampler/numpy_sample/gsp.py +0 -0
  45. {ocf_data_sampler-0.5.14 → ocf_data_sampler-0.5.16}/ocf_data_sampler/numpy_sample/satellite.py +0 -0
  46. {ocf_data_sampler-0.5.14 → ocf_data_sampler-0.5.16}/ocf_data_sampler/numpy_sample/site.py +0 -0
  47. {ocf_data_sampler-0.5.14 → ocf_data_sampler-0.5.16}/ocf_data_sampler/numpy_sample/sun_position.py +0 -0
  48. {ocf_data_sampler-0.5.14 → ocf_data_sampler-0.5.16}/ocf_data_sampler/select/__init__.py +0 -0
  49. {ocf_data_sampler-0.5.14 → ocf_data_sampler-0.5.16}/ocf_data_sampler/select/fill_time_periods.py +0 -0
  50. {ocf_data_sampler-0.5.14 → ocf_data_sampler-0.5.16}/ocf_data_sampler/select/find_contiguous_time_periods.py +0 -0
  51. {ocf_data_sampler-0.5.14 → ocf_data_sampler-0.5.16}/ocf_data_sampler/select/geospatial.py +0 -0
  52. {ocf_data_sampler-0.5.14 → ocf_data_sampler-0.5.16}/ocf_data_sampler/select/location.py +0 -0
  53. {ocf_data_sampler-0.5.14 → ocf_data_sampler-0.5.16}/ocf_data_sampler/select/select_spatial_slice.py +0 -0
  54. {ocf_data_sampler-0.5.14 → ocf_data_sampler-0.5.16}/ocf_data_sampler/torch_datasets/datasets/__init__.py +0 -0
  55. {ocf_data_sampler-0.5.14 → ocf_data_sampler-0.5.16}/ocf_data_sampler/torch_datasets/sample/__init__.py +0 -0
  56. {ocf_data_sampler-0.5.14 → ocf_data_sampler-0.5.16}/ocf_data_sampler/torch_datasets/sample/base.py +0 -0
  57. {ocf_data_sampler-0.5.14 → ocf_data_sampler-0.5.16}/ocf_data_sampler/torch_datasets/sample/site.py +0 -0
  58. {ocf_data_sampler-0.5.14 → ocf_data_sampler-0.5.16}/ocf_data_sampler/torch_datasets/sample/uk_regional.py +0 -0
  59. {ocf_data_sampler-0.5.14 → ocf_data_sampler-0.5.16}/ocf_data_sampler/torch_datasets/utils/add_alterate_coordinate_projections.py +0 -0
  60. {ocf_data_sampler-0.5.14 → ocf_data_sampler-0.5.16}/ocf_data_sampler/torch_datasets/utils/config_normalization_values_to_dicts.py +0 -0
  61. {ocf_data_sampler-0.5.14 → ocf_data_sampler-0.5.16}/ocf_data_sampler/torch_datasets/utils/merge_and_fill_utils.py +0 -0
  62. {ocf_data_sampler-0.5.14 → ocf_data_sampler-0.5.16}/ocf_data_sampler/torch_datasets/utils/spatial_slice_for_dataset.py +0 -0
  63. {ocf_data_sampler-0.5.14 → ocf_data_sampler-0.5.16}/ocf_data_sampler/torch_datasets/utils/valid_time_periods.py +0 -0
  64. {ocf_data_sampler-0.5.14 → ocf_data_sampler-0.5.16}/ocf_data_sampler/torch_datasets/utils/validation_utils.py +0 -0
  65. {ocf_data_sampler-0.5.14 → ocf_data_sampler-0.5.16}/ocf_data_sampler/utils.py +0 -0
  66. {ocf_data_sampler-0.5.14 → ocf_data_sampler-0.5.16}/ocf_data_sampler.egg-info/dependency_links.txt +0 -0
  67. {ocf_data_sampler-0.5.14 → ocf_data_sampler-0.5.16}/ocf_data_sampler.egg-info/requires.txt +0 -0
  68. {ocf_data_sampler-0.5.14 → ocf_data_sampler-0.5.16}/ocf_data_sampler.egg-info/top_level.txt +0 -0
  69. {ocf_data_sampler-0.5.14 → ocf_data_sampler-0.5.16}/pyproject.toml +0 -0
  70. {ocf_data_sampler-0.5.14 → ocf_data_sampler-0.5.16}/scripts/download_gsp_location_data.py +0 -0
  71. {ocf_data_sampler-0.5.14 → ocf_data_sampler-0.5.16}/scripts/refactor_site.py +0 -0
  72. {ocf_data_sampler-0.5.14 → ocf_data_sampler-0.5.16}/setup.cfg +0 -0
  73. {ocf_data_sampler-0.5.14 → ocf_data_sampler-0.5.16}/tests/test_utils.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ocf-data-sampler
3
- Version: 0.5.14
3
+ Version: 0.5.16
4
4
  Author: James Fulton, Peter Dudfield
5
5
  Author-email: Open Climate Fix team <info@openclimatefix.org>
6
6
  License: MIT License
@@ -90,7 +90,7 @@ class DropoutMixin(Base):
90
90
  "negative or zero.",
91
91
  )
92
92
 
93
- dropout_fraction: float|list[float] = Field(
93
+ dropout_fraction: float | list[float] = Field(
94
94
  default=0,
95
95
  description="Either a float(Chance of dropout being applied to each sample) or a list of "
96
96
  "floats (probability that dropout of the corresponding timedelta is applied)",
@@ -106,31 +106,22 @@ class DropoutMixin(Base):
106
106
 
107
107
 
108
108
  @field_validator("dropout_fraction")
109
- def dropout_fractions(cls, dropout_frac: float|list[float]) -> float|list[float]:
109
+ def dropout_fractions(cls, dropout_frac: float | list[float]) -> float | list[float]:
110
110
  """Validate 'dropout_frac'."""
111
- from math import isclose
112
- if isinstance(dropout_frac, float):
113
- if not (dropout_frac <= 1):
114
- raise ValueError("Input should be less than or equal to 1")
115
- elif not (dropout_frac >= 0):
116
- raise ValueError("Input should be greater than or equal to 0")
111
+ if isinstance(dropout_frac, float | int):
112
+ if not (0<= dropout_frac <= 1):
113
+ raise ValueError("Dropout fractions must be in range [0, 1]")
117
114
 
118
115
  elif isinstance(dropout_frac, list):
119
116
  if not dropout_frac:
120
117
  raise ValueError("List cannot be empty")
121
118
 
122
- if not all(isinstance(i, float) for i in dropout_frac):
123
- raise ValueError("All elements in the list must be floats")
124
-
125
119
  if not all(0 <= i <= 1 for i in dropout_frac):
126
- raise ValueError("Each float in the list must be between 0 and 1")
127
-
128
- if not isclose(sum(dropout_frac), 1.0, rel_tol=1e-9):
129
- raise ValueError("Sum of all floats in the list must be 1.0")
120
+ raise ValueError("All dropout fractions must be in range [0, 1]")
130
121
 
122
+ if not (0 <= sum(dropout_frac) <= 1):
123
+ raise ValueError("The sum of dropout fractions must be in range [0, 1]")
131
124
 
132
- else:
133
- raise TypeError("Must be either a float or a list of floats")
134
125
  return dropout_frac
135
126
 
136
127
 
@@ -28,7 +28,7 @@ def convert_nwp_to_numpy_sample(da: xr.DataArray, t0_idx: int | None = None) ->
28
28
  NWPSampleKey.channel_names: da.channel.values,
29
29
  NWPSampleKey.init_time_utc: da.init_time_utc.values.astype(float),
30
30
  NWPSampleKey.step: (da.step.values / 3600).astype(int),
31
- NWPSampleKey.target_time_utc: da.target_time_utc.values.astype(float),
31
+ NWPSampleKey.target_time_utc: (da.init_time_utc.values + da.step.values).astype(float),
32
32
  }
33
33
 
34
34
  if t0_idx is not None:
@@ -0,0 +1,25 @@
1
+ """Takes the diff along the step axis for a given set of channels."""
2
+
3
+ import numpy as np
4
+ import xarray as xr
5
+
6
+
7
def diff_channels(da: xr.DataArray, accum_channels: list[str]) -> xr.DataArray:
    """Difference the accumulated channels of ``da`` along the step dimension, in place.

    For every channel named in ``accum_channels``, the value at step ``k`` is replaced by the
    value at step ``k + 1`` minus the value at step ``k``. The new values are written back to
    ``da`` itself, and the final step (which has no successor) is dropped from the result.

    Args:
        da: The DataArray to slice from. Its first two dimensions must be ("step", "channel").
        accum_channels: Channels which are accumulated and need to be differenced

    Returns:
        ``da`` without its final step.

    Raises:
        ValueError: If the first two dimensions of ``da`` are not step then channel.
    """
    expected_leading_dims = ("step", "channel")
    if da.dims[:2] != expected_leading_dims:
        raise ValueError("This function assumes the first two dimensions are step then channel")

    channel_idx = [
        idx for idx, name in enumerate(da.channel.values) if name in accum_channels
    ]

    # Operate on a copy of the raw array, then assign the copy back onto ``da``
    data = da.values.copy()
    # np.diff gives out[k] = data[k + 1] - data[k] along the step axis
    data[:-1, channel_idx] = np.diff(data[:, channel_idx], axis=0)
    da.values = data

    # The last step cannot be differenced forwards, so it is removed
    return da.isel(step=slice(0, -1))
@@ -0,0 +1,59 @@
1
+ """Functions for simulating dropout in time series data.
2
+
3
+ This is used for the following types of data: GSP, Satellite and Site
4
+ This is not used for NWP
5
+ """
6
+
7
+ import numpy as np
8
+ import pandas as pd
9
+ import xarray as xr
10
+
11
+
12
def apply_history_dropout(
    t0: pd.Timestamp,
    dropout_timedeltas: list[pd.Timedelta],
    dropout_frac: float | list[float],
    da: xr.DataArray,
) -> xr.DataArray:
    """Apply randomly sampled dropout to the historical part of some sequence data.

    Dropped out data is replaced with NaNs. At most one dropout timedelta is sampled; if one
    is, every value with ``t0 + timedelta < time_utc <= t0`` is masked, while values at or
    before the dropout time and values after ``t0`` are kept.

    Args:
        t0: The forecast init-time.
        dropout_timedeltas: List of timedeltas relative to t0 to pick from
        dropout_frac: Either a single probability that dropout is applied, which is shared
            equally between all the dropout timedeltas, or a list with the probability of each
            corresponding dropout timedelta. Each probability, and their sum, must be between
            0 and 1 inclusive; the remaining probability is the chance of no dropout.
        da: Xarray DataArray with 'time_utc' coordinate

    Returns:
        ``da`` itself if no dropout is sampled (or no dropout timedeltas are given), otherwise
        the result of ``da.where`` with the dropped-out history masked.

    Raises:
        ValueError: If any probability or their sum is outside [0, 1], or the lengths of
            ``dropout_timedeltas`` and a list ``dropout_frac`` differ.
    """
    if len(dropout_timedeltas) == 0:
        return da

    n = len(dropout_timedeltas)

    if isinstance(dropout_frac, float | int):
        if not (0 <= dropout_frac <= 1):
            raise ValueError("`dropout_frac` must be in range [0, 1]")
        # Equal chance for all dropout timedeltas
        probs = [dropout_frac / n] * n
    else:
        if not 0 <= sum(dropout_frac) <= 1:
            raise ValueError("The sum of `dropout_frac` must be in range [0, 1]")
        if len(dropout_frac) != n:
            raise ValueError("`dropout_timedeltas` and `dropout_frac` must have the same length")
        if not all(0 <= p <= 1 for p in dropout_frac):
            raise ValueError("All values in `dropout_frac` must be in range [0, 1]")
        probs = list(dropout_frac)  # copy so the caller's list is not modified

    # The remaining probability is the chance of no dropout. Clip at zero: after dividing a
    # float fraction between n timedeltas the sum can round slightly above 1, and a tiny
    # negative probability would make np.random.choice raise.
    probs.append(max(0.0, 1 - sum(probs)))

    # Sample an index rather than the timedelta itself; index n means "no dropout"
    choice = np.random.choice(n + 1, p=probs)
    if choice == n:
        return da

    dropout_time = t0 + dropout_timedeltas[choice]
    return da.where((da.time_utc <= dropout_time) | (da.time_utc > t0))
@@ -0,0 +1,107 @@
1
+ """Select a time slice from a Dataset or DataArray."""
2
+
3
+ import numpy as np
4
+ import pandas as pd
5
+ import xarray as xr
6
+
7
+
8
def select_time_slice(
    da: xr.DataArray,
    t0: pd.Timestamp,
    interval_start: pd.Timedelta,
    interval_end: pd.Timedelta,
    time_resolution: pd.Timedelta,
) -> xr.DataArray:
    """Return the part of ``da`` lying between two offsets from ``t0``.

    Both window edges are rounded up to a multiple of ``time_resolution`` before a
    label-based slice is taken along ``time_utc``.

    Args:
        da: The DataArray to slice from
        t0: The init-time
        interval_start: The start of the interval with respect to t0
        interval_end: The end of the interval with respect to t0
        time_resolution: Distance between neighbouring timestamps

    Returns:
        The selected DataArray.
    """
    window_start, window_end = (
        (t0 + offset).ceil(time_resolution) for offset in (interval_start, interval_end)
    )
    return da.sel(time_utc=slice(window_start, window_end))
31
+
32
+
33
def _sample_nwp_available_time(
    t0: pd.Timestamp,
    dropout_timedeltas: list[pd.Timedelta],
    dropout_frac: float | list[float],
) -> pd.Timestamp:
    """Validate the dropout arguments and sample the latest time at which NWP data is available."""
    if isinstance(dropout_frac, list | tuple):
        # One probability per dropout timedelta; the remainder is the chance of no dropout
        if len(dropout_frac) != len(dropout_timedeltas):
            raise ValueError("`dropout_timedeltas` and `dropout_frac` must have the same length")
        if not all(0 <= p <= 1 for p in dropout_frac) or not (0 <= sum(dropout_frac) <= 1):
            raise ValueError("dropout_frac values and their sum must be between 0 and 1")
        if len(dropout_frac) == 0 or sum(dropout_frac) == 0:
            return t0
        # Clip at zero so float rounding cannot produce a negative probability
        probs = [*dropout_frac, max(0.0, 1 - sum(dropout_frac))]
        choice = np.random.choice(len(probs), p=probs)
        if choice == len(dropout_timedeltas):  # the final option means "no dropout"
            return t0
        return t0 + dropout_timedeltas[choice]

    if not (0 <= dropout_frac <= 1):
        raise ValueError("dropout_frac must be between 0 and 1")

    consider_dropout = len(dropout_timedeltas) > 0 and dropout_frac > 0
    if consider_dropout and (np.random.uniform() < dropout_frac):
        return t0 + np.random.choice(dropout_timedeltas)
    return t0


def select_time_slice_nwp(
    da: xr.DataArray,
    t0: pd.Timestamp,
    interval_start: pd.Timedelta,
    interval_end: pd.Timedelta,
    time_resolution: pd.Timedelta,
    dropout_timedeltas: list[pd.Timedelta] | None = None,
    dropout_frac: float | list[float] | None = 0,
) -> xr.DataArray:
    """Select a time slice from an NWP DataArray.

    The target times run from ``t0 + interval_start`` to ``t0 + interval_end`` (both rounded
    up to ``time_resolution``) in steps of ``time_resolution``. For each target time, the most
    recent init-time available at ``t0`` (or at an earlier time if dropout is sampled) is used,
    together with the step which reaches that target time.

    Args:
        da: The DataArray to slice from. Assumes ``init_time_utc`` and ``step`` are sorted in
            ascending order.
        t0: The init-time
        interval_start: The start of the interval with respect to t0
        interval_end: The end of the interval with respect to t0
        time_resolution: Distance between neighbouring timestamps
        dropout_timedeltas: List of possible timedeltas before t0 where data availability may
            start. All must be negative.
        dropout_frac: Either the probability to apply dropout, in which case one of the
            dropout timedeltas is picked uniformly at random, or a list with the probability of
            each corresponding dropout timedelta whose sum is at most 1 (the remainder being
            the chance of no dropout). ``None`` is treated as 0.

    Returns:
        The selected data, with one entry along ``step`` for each target time.

    Raises:
        ValueError: If the dropout arguments are invalid.
    """
    # Input checking
    if dropout_timedeltas is None:
        dropout_timedeltas = []
    if dropout_frac is None:
        dropout_frac = 0

    if not all(t < pd.Timedelta(0) for t in dropout_timedeltas):
        raise ValueError("dropout timedeltas must be negative")

    # Potentially apply NWP dropout by moving back the time up to which forecasts are available
    t0_available = _sample_nwp_available_time(t0, dropout_timedeltas, dropout_frac)

    start_dt = (t0 + interval_start).ceil(time_resolution)
    end_dt = (t0 + interval_end).ceil(time_resolution)
    target_times = pd.date_range(start_dt, end_dt, freq=time_resolution)

    # Get the available and relevant init-times
    t_min = target_times[0] - da.step.values[-1]
    init_times = da.init_time_utc.values
    available_init_times = init_times[(t_min <= init_times) & (init_times <= t0_available)]

    # Find the most recent available init-times for all target-times
    selected_init_times = np.array(
        [available_init_times[available_init_times <= t][-1] for t in target_times],
    )

    # Find the required steps for all target-times
    steps = target_times - selected_init_times

    if len(np.unique(selected_init_times)) == 1:
        # Only one init-time is needed, so a label slice along step is used as it is faster
        da_sel = da.sel(init_time_utc=selected_init_times[0], step=slice(steps[0], steps[-1]))
        # The slice returns every step in the range. If the NWP steps are finer than (or not
        # aligned with) time_resolution, keep exactly the required steps so this path agrees
        # with the multi-init-time path below
        if not np.array_equal(da_sel.step.values, np.asarray(steps)):
            da_sel = da_sel.sel(step=steps)

    else:
        # We want one timestep for each target_time_hourly (obviously!) If we simply do
        # nwp.sel(init_time=init_times, step=steps) then we'll get the *product* of
        # init_times and steps, which is not what we want! Instead, we use xarray's
        # vectorised-indexing mode via using a DataArray indexer. See the last example here:
        # https://docs.xarray.dev/en/latest/user-guide/indexing.html#more-advanced-indexing
        coords = {"step": steps}
        init_time_indexer = xr.DataArray(selected_init_times, coords=coords)
        step_indexer = xr.DataArray(steps, coords=coords)
        da_sel = da.sel(init_time_utc=init_time_indexer, step=step_indexer)

    return da_sel
@@ -22,6 +22,7 @@ from ocf_data_sampler.select import Location, fill_time_periods
22
22
  from ocf_data_sampler.torch_datasets.utils import (
23
23
  add_alterate_coordinate_projections,
24
24
  config_normalization_values_to_dicts,
25
+ diff_nwp_data,
25
26
  fill_nans_in_arrays,
26
27
  find_valid_time_periods,
27
28
  merge_dicts,
@@ -259,7 +260,7 @@ class PVNetUKRegionalDataset(AbstractPVNetUKDataset):
259
260
  sample_dict = slice_datasets_by_space(self.datasets_dict, location, self.config)
260
261
  sample_dict = slice_datasets_by_time(sample_dict, t0, self.config)
261
262
  sample_dict = tensorstore_compute(sample_dict)
262
-
263
+ sample_dict = diff_nwp_data(sample_dict, self.config)
263
264
  return self.process_and_combine_datasets(sample_dict, t0, location)
264
265
 
265
266
  @override
@@ -318,6 +319,7 @@ class PVNetUKConcurrentDataset(AbstractPVNetUKDataset):
318
319
  # Slice by time then load to avoid loading the data multiple times from disk
319
320
  sample_dict = slice_datasets_by_time(self.datasets_dict, t0, self.config)
320
321
  sample_dict = tensorstore_compute(sample_dict)
322
+ sample_dict = diff_nwp_data(sample_dict, self.config)
321
323
 
322
324
  gsp_samples = []
323
325
 
@@ -27,6 +27,7 @@ from ocf_data_sampler.select import (
27
27
  from ocf_data_sampler.torch_datasets.utils import (
28
28
  add_alterate_coordinate_projections,
29
29
  config_normalization_values_to_dicts,
30
+ diff_nwp_data,
30
31
  fill_nans_in_arrays,
31
32
  find_valid_time_periods,
32
33
  merge_dicts,
@@ -57,6 +58,7 @@ def get_locations(site_xr: xr.Dataset) -> list[Location]:
57
58
 
58
59
  return locations
59
60
 
61
+
60
62
  def process_and_combine_datasets(
61
63
  dataset_dict: dict,
62
64
  config: Configuration,
@@ -80,8 +82,6 @@ def process_and_combine_datasets(
80
82
 
81
83
  for nwp_key, da_nwp in dataset_dict["nwp"].items():
82
84
 
83
- # Standardise and convert to NumpyBatch
84
-
85
85
  channel_means = means_dict["nwp"][nwp_key]
86
86
  channel_stds = stds_dict["nwp"][nwp_key]
87
87
 
@@ -276,8 +276,8 @@ class SitesDataset(Dataset):
276
276
  """
277
277
  sample_dict = slice_datasets_by_space(self.datasets_dict, location, self.config)
278
278
  sample_dict = slice_datasets_by_time(sample_dict, t0, self.config)
279
-
280
279
  sample_dict = tensorstore_compute(sample_dict)
280
+ sample_dict = diff_nwp_data(sample_dict, self.config)
281
281
 
282
282
  return process_and_combine_datasets(
283
283
  sample_dict,
@@ -414,6 +414,7 @@ class SitesDatasetConcurrent(Dataset):
414
414
  # slice by time first as we want to keep all site id info
415
415
  sample_dict = slice_datasets_by_time(self.datasets_dict, t0, self.config)
416
416
  sample_dict = tensorstore_compute(sample_dict)
417
+ sample_dict = diff_nwp_data(sample_dict, self.config)
417
418
 
418
419
  site_samples = []
419
420
 
@@ -3,4 +3,5 @@ from .merge_and_fill_utils import fill_nans_in_arrays, merge_dicts
3
3
  from .valid_time_periods import find_valid_time_periods
4
4
  from .spatial_slice_for_dataset import slice_datasets_by_space
5
5
  from .time_slice_for_dataset import slice_datasets_by_time
6
- from .add_alterate_coordinate_projections import add_alterate_coordinate_projections
6
+ from .add_alterate_coordinate_projections import add_alterate_coordinate_projections
7
+ from .diff_nwp_data import diff_nwp_data
@@ -0,0 +1,20 @@
1
+ """Take the in-place diff of some channels of the NWP data."""
2
+
3
+ from ocf_data_sampler.config import Configuration
4
+ from ocf_data_sampler.select.diff_channels import diff_channels
5
+
6
+
7
def diff_nwp_data(dataset_dict: dict, config: Configuration) -> dict:
    """Difference the accumulated channels of every NWP source, as listed in the config.

    The ``"nwp"`` sub-dictionary of ``dataset_dict`` is updated in place; sources whose config
    lists no ``accum_channels`` are left untouched.

    Args:
        dataset_dict: Dictionary of xarray datasets
        config: Configuration object

    Returns:
        The same ``dataset_dict`` object that was passed in.
    """
    if "nwp" not in dataset_dict:
        return dataset_dict

    nwp_dict = dataset_dict["nwp"]
    for key, da in nwp_dict.items():
        channels_to_diff = config.input_data.nwp[key].accum_channels
        if len(channels_to_diff) == 0:
            continue
        # diff_channels() modifies its input in place; store the step-trimmed result it returns
        nwp_dict[key] = diff_channels(da, channels_to_diff)

    return dataset_dict
@@ -1,10 +1,9 @@
1
1
  """Slice datasets by time."""
2
2
 
3
3
  import pandas as pd
4
- import xarray as xr
5
4
 
6
5
  from ocf_data_sampler.config import Configuration
7
- from ocf_data_sampler.select.dropout import apply_sampled_dropout_time
6
+ from ocf_data_sampler.select.dropout import apply_history_dropout
8
7
  from ocf_data_sampler.select.select_time_slice import select_time_slice, select_time_slice_nwp
9
8
  from ocf_data_sampler.utils import minutes
10
9
 
@@ -29,15 +28,23 @@ def slice_datasets_by_time(
29
28
  for nwp_key, da_nwp in datasets_dict["nwp"].items():
30
29
  nwp_config = config.input_data.nwp[nwp_key]
31
30
 
31
+ # Add a buffer if we need to diff some of the channels in time
32
+ if len(nwp_config.accum_channels)>0:
33
+ interval_end_mins = (
34
+ nwp_config.interval_end_minutes
35
+ + nwp_config.time_resolution_minutes
36
+ )
37
+ else:
38
+ interval_end_mins = nwp_config.interval_end_minutes
39
+
32
40
  sliced_datasets_dict["nwp"][nwp_key] = select_time_slice_nwp(
33
41
  da_nwp,
34
42
  t0,
35
43
  time_resolution=minutes(nwp_config.time_resolution_minutes),
36
44
  interval_start=minutes(nwp_config.interval_start_minutes),
37
- interval_end=minutes(nwp_config.interval_end_minutes),
45
+ interval_end=minutes(interval_end_mins),
38
46
  dropout_timedeltas=minutes(nwp_config.dropout_timedeltas_minutes),
39
47
  dropout_frac=nwp_config.dropout_fraction,
40
- accum_channels=nwp_config.accum_channels,
41
48
  )
42
49
 
43
50
  if "sat" in datasets_dict:
@@ -52,7 +59,7 @@ def slice_datasets_by_time(
52
59
  )
53
60
 
54
61
  # Apply the randomly sampled dropout
55
- sliced_datasets_dict["sat"] = apply_sampled_dropout_time(
62
+ sliced_datasets_dict["sat"] = apply_history_dropout(
56
63
  t0,
57
64
  dropout_timedeltas=minutes(sat_config.dropout_timedeltas_minutes),
58
65
  dropout_frac=sat_config.dropout_fraction,
@@ -62,59 +69,44 @@ def slice_datasets_by_time(
62
69
  if "gsp" in datasets_dict:
63
70
  gsp_config = config.input_data.gsp
64
71
 
65
- da_gsp_past = select_time_slice(
72
+ da_gsp = select_time_slice(
66
73
  datasets_dict["gsp"],
67
74
  t0,
68
75
  time_resolution=minutes(gsp_config.time_resolution_minutes),
69
76
  interval_start=minutes(gsp_config.interval_start_minutes),
70
- interval_end=minutes(0),
77
+ interval_end=minutes(gsp_config.interval_end_minutes),
71
78
  )
72
79
 
73
80
  # Dropout on the past GSP, but not the future GSP
74
- da_gsp_past = apply_sampled_dropout_time(
81
+ da_gsp = apply_history_dropout(
75
82
  t0,
76
83
  dropout_timedeltas=minutes(gsp_config.dropout_timedeltas_minutes),
77
84
  dropout_frac=gsp_config.dropout_fraction,
78
- da=da_gsp_past,
79
- )
80
-
81
- da_gsp_future = select_time_slice(
82
- datasets_dict["gsp"],
83
- t0,
84
- time_resolution=minutes(gsp_config.time_resolution_minutes),
85
- interval_start=minutes(gsp_config.time_resolution_minutes),
86
- interval_end=minutes(gsp_config.interval_end_minutes),
85
+ da=da_gsp,
87
86
  )
88
87
 
89
- sliced_datasets_dict["gsp"] = xr.concat([da_gsp_past, da_gsp_future], dim="time_utc")
88
+ sliced_datasets_dict["gsp"] = da_gsp
90
89
 
91
90
  if "site" in datasets_dict:
92
91
  site_config = config.input_data.site
93
92
 
94
- da_site_past = select_time_slice(
93
+ da_site = select_time_slice(
95
94
  datasets_dict["site"],
96
95
  t0,
97
96
  time_resolution=minutes(site_config.time_resolution_minutes),
98
97
  interval_start=minutes(site_config.interval_start_minutes),
99
- interval_end=minutes(0),
98
+ interval_end=minutes(site_config.interval_end_minutes),
100
99
  )
101
100
 
102
101
  # Apply the randomly sampled dropout on the past site not the future
103
- da_site_past = apply_sampled_dropout_time(
102
+ da_site = apply_history_dropout(
104
103
  t0,
105
104
  dropout_timedeltas=minutes(site_config.dropout_timedeltas_minutes),
106
105
  dropout_frac=site_config.dropout_fraction,
107
- da=da_site_past,
106
+ da=da_site,
108
107
  )
109
108
 
110
- da_site_future = select_time_slice(
111
- datasets_dict["site"],
112
- t0,
113
- time_resolution=minutes(site_config.time_resolution_minutes),
114
- interval_start=minutes(site_config.time_resolution_minutes),
115
- interval_end=minutes(site_config.interval_end_minutes),
116
- )
109
+ sliced_datasets_dict["site"] = da_site
117
110
 
118
- sliced_datasets_dict["site"] = xr.concat([da_site_past, da_site_future], dim="time_utc")
119
111
 
120
112
  return sliced_datasets_dict
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ocf-data-sampler
3
- Version: 0.5.14
3
+ Version: 0.5.16
4
4
  Author: James Fulton, Peter Dudfield
5
5
  Author-email: Open Climate Fix team <info@openclimatefix.org>
6
6
  License: MIT License
@@ -40,6 +40,7 @@ ocf_data_sampler/numpy_sample/satellite.py
40
40
  ocf_data_sampler/numpy_sample/site.py
41
41
  ocf_data_sampler/numpy_sample/sun_position.py
42
42
  ocf_data_sampler/select/__init__.py
43
+ ocf_data_sampler/select/diff_channels.py
43
44
  ocf_data_sampler/select/dropout.py
44
45
  ocf_data_sampler/select/fill_time_periods.py
45
46
  ocf_data_sampler/select/find_contiguous_time_periods.py
@@ -57,6 +58,7 @@ ocf_data_sampler/torch_datasets/sample/uk_regional.py
57
58
  ocf_data_sampler/torch_datasets/utils/__init__.py
58
59
  ocf_data_sampler/torch_datasets/utils/add_alterate_coordinate_projections.py
59
60
  ocf_data_sampler/torch_datasets/utils/config_normalization_values_to_dicts.py
61
+ ocf_data_sampler/torch_datasets/utils/diff_nwp_data.py
60
62
  ocf_data_sampler/torch_datasets/utils/merge_and_fill_utils.py
61
63
  ocf_data_sampler/torch_datasets/utils/spatial_slice_for_dataset.py
62
64
  ocf_data_sampler/torch_datasets/utils/time_slice_for_dataset.py
@@ -1,61 +0,0 @@
1
- """Functions for simulating dropout in time series data.
2
-
3
- This is used for the following types of data: GSP, Satellite and Site
4
- This is not used for NWP
5
- """
6
-
7
- import numpy as np
8
- import pandas as pd
9
- import xarray as xr
10
-
11
-
12
def apply_sampled_dropout_time(
    t0: pd.Timestamp,
    dropout_timedeltas: list[pd.Timedelta],
    dropout_frac: float | list[float],
    da: xr.DataArray,
) -> xr.DataArray:
    """Sample a dropout time relative to ``t0`` and mask all data after it with NaN.

    Args:
        t0: The forecast init-time
        dropout_timedeltas: List of timedeltas relative to t0 to pick from
        dropout_frac: Either the probability that dropout is applied (between 0 and 1
            inclusive), in which case a timedelta is picked uniformly; or a list with the
            probability of picking each corresponding timedelta, in which case dropout is
            always applied.
        da: Xarray DataArray with 'time_utc' coordinate

    Returns:
        ``da`` itself if no dropout is applied, otherwise ``da`` with every value whose
        ``time_utc`` is after the dropout time replaced by NaN.

    Raises:
        ValueError: On mismatched list lengths, a positive probability with no timedeltas,
            or a probability outside [0, 1].
    """
    if not isinstance(dropout_frac, list):
        # A single probability: first decide whether to drop out, then pick a timedelta
        if dropout_frac > 0 and len(dropout_timedeltas) == 0:
            raise ValueError("To apply dropout, dropout_timedeltas must be provided")
        if not (0 <= dropout_frac <= 1):
            raise ValueError("dropout_frac must be between 0 and 1 inclusive")

        apply_dropout = len(dropout_timedeltas) > 0 and np.random.uniform() < dropout_frac
        if not apply_dropout:
            return da
        dropout_time = t0 + np.random.choice(dropout_timedeltas)
    else:
        # One probability per timedelta: a timedelta is always picked
        if len(dropout_frac) != len(dropout_timedeltas):
            raise ValueError("Lengths of dropout_frac and dropout_timedeltas should match")
        dropout_time = t0 + np.random.choice(dropout_timedeltas, p=dropout_frac)

    # Everything after the dropout time becomes NaN
    return da.where(da.time_utc <= dropout_time)
@@ -1,143 +0,0 @@
1
- """Select a time slice from a Dataset or DataArray."""
2
-
3
- import numpy as np
4
- import pandas as pd
5
- import xarray as xr
6
-
7
-
8
def select_time_slice(
    da: xr.DataArray,
    t0: pd.Timestamp,
    interval_start: pd.Timedelta,
    interval_end: pd.Timedelta,
    time_resolution: pd.Timedelta,
) -> xr.DataArray:
    """Slice ``da`` along ``time_utc`` between two offsets from ``t0``.

    Each edge of the window is rounded up to a multiple of ``time_resolution``.

    Args:
        da: The DataArray to slice from
        t0: The init-time
        interval_start: The start of the interval with respect to t0
        interval_end: The end of the interval with respect to t0
        time_resolution: Distance between neighbouring timestamps

    Returns:
        The selected DataArray.
    """
    def _rounded_edge(offset: pd.Timedelta) -> pd.Timestamp:
        return (t0 + offset).ceil(time_resolution)

    return da.sel(time_utc=slice(_rounded_edge(interval_start), _rounded_edge(interval_end)))
31
-
32
-
33
def select_time_slice_nwp(
    da: xr.DataArray,
    t0: pd.Timestamp,
    interval_start: pd.Timedelta,
    interval_end: pd.Timedelta,
    time_resolution: pd.Timedelta,
    dropout_timedeltas: list[pd.Timedelta] | None = None,
    dropout_frac: float | None = 0,
    accum_channels: list[str] | None = None,
) -> xr.DataArray:
    """Select a time slice from an NWP DataArray.

    This is the 0.5.14 implementation, which also differences accumulated channels itself.

    Args:
        da: The DataArray to slice from
        t0: The init-time
        interval_start: The start of the interval with respect to t0
        interval_end: The end of the interval with respect to t0
        time_resolution: Distance between neighbouring timestamps
        dropout_timedeltas: List of possible timedeltas before t0 where data availability may start
        dropout_frac: Probability to apply dropout
        accum_channels: Channels which are accumulated and need to be differenced

    Returns:
        DataArray indexed along a ``target_time_utc`` dimension, one entry per target time.
        Differenced channels are renamed to ``diff_<name>``.

    Raises:
        ValueError: If a dropout timedelta is not negative or dropout_frac is outside [0, 1].
    """
    if accum_channels is None:
        accum_channels = []

    if dropout_timedeltas is None:
        dropout_timedeltas = []

    if len(dropout_timedeltas)>0:
        if not all(t < pd.Timedelta(0) for t in dropout_timedeltas):
            raise ValueError("dropout timedeltas must be negative")
        # NOTE(review): unreachable - this branch is only entered for a non-empty list
        if len(dropout_timedeltas) < 1:
            raise ValueError("dropout timedeltas must have at least one element")

    if not (0 <= dropout_frac <= 1):
        raise ValueError("dropout_frac must be between 0 and 1")

    consider_dropout = len(dropout_timedeltas) > 0 and dropout_frac > 0

    # The accumulated and non-accumulated channels (only those present in da)
    accum_channels = np.intersect1d(da.channel.values, accum_channels)
    non_accum_channels = np.setdiff1d(da.channel.values, accum_channels)

    start_dt = (t0 + interval_start).ceil(time_resolution)
    end_dt = (t0 + interval_end).ceil(time_resolution)
    target_times = pd.date_range(start_dt, end_dt, freq=time_resolution)

    # Potentially apply NWP dropout
    if consider_dropout and (np.random.uniform() < dropout_frac):
        dt = np.random.choice(dropout_timedeltas)
        t0_available = t0 + dt
    else:
        t0_available = t0

    # Forecasts made up to and including t0
    available_init_times = da.init_time_utc.sel(init_time_utc=slice(None, t0_available))

    # Find the most recent available init times for all target times
    selected_init_times = available_init_times.sel(
        init_time_utc=target_times,
        method="ffill",  # forward fill from init times to target times
    ).values

    # Find the required steps for all target times
    steps = target_times - selected_init_times

    # We want one timestep for each target_time_hourly (obviously!) If we simply do
    # nwp.sel(init_time=init_times, step=steps) then we'll get the *product* of
    # init_times and steps, which is not what we want! Instead, we use xarray's
    # vectorised-indexing mode via using a DataArray indexer. See the last example here:
    # https://docs.xarray.dev/en/latest/user-guide/indexing.html#more-advanced-indexing

    coords = {"target_time_utc": target_times}
    init_time_indexer = xr.DataArray(selected_init_times, coords=coords)
    step_indexer = xr.DataArray(steps, coords=coords)

    if len(accum_channels) == 0:
        da_sel = da.sel(step=step_indexer, init_time_utc=init_time_indexer)
    else:
        # First minimise the size of the dataset we are diffing
        # - find the init times we are slicing from
        unique_init_times = np.unique(selected_init_times)
        # - find the min and max steps we slice over. Max is extended due to diff
        # NOTE(review): extending by time_resolution presumably assumes the NWP step spacing
        # equals time_resolution - confirm
        min_step = min(steps)
        max_step = max(steps) + time_resolution

        da_min = da.sel(init_time_utc=unique_init_times, step=slice(min_step, max_step))

        # Slice out the data which does not need to be diffed
        da_non_accum = da_min.sel(channel=non_accum_channels)
        da_sel_non_accum = da_non_accum.sel(step=step_indexer, init_time_utc=init_time_indexer)

        # Slice out the channels which need to be diffed
        da_accum = da_min.sel(channel=accum_channels)

        # Take the diff and slice requested data. label="lower" labels each difference with
        # the earlier of the two steps it was computed from
        da_accum = da_accum.diff(dim="step", label="lower")
        da_sel_accum = da_accum.sel(step=step_indexer, init_time_utc=init_time_indexer)

        # Join diffed and non-diffed variables
        da_sel = xr.concat([da_sel_non_accum, da_sel_accum], dim="channel")

        # Reorder the variable back to the original order
        da_sel = da_sel.sel(channel=da.channel.values)

        # Rename the diffed channels
        da_sel["channel"] = [
            f"diff_{v}" if v in accum_channels else v for v in da_sel.channel.values
        ]

    return da_sel