ocf-data-sampler 0.5.15__py3-none-any.whl → 0.5.16__py3-none-any.whl

This diff represents the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in that registry.

ocf_data_sampler/numpy_sample/nwp.py
@@ -28,7 +28,7 @@ def convert_nwp_to_numpy_sample(da: xr.DataArray, t0_idx: int | None = None) ->
         NWPSampleKey.channel_names: da.channel.values,
         NWPSampleKey.init_time_utc: da.init_time_utc.values.astype(float),
         NWPSampleKey.step: (da.step.values / 3600).astype(int),
-        NWPSampleKey.target_time_utc: da.target_time_utc.values.astype(float),
+        NWPSampleKey.target_time_utc: (da.init_time_utc.values + da.step.values).astype(float),
     }
 
     if t0_idx is not None:
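This change removes the dependence on a target_time_utc coordinate: after the rewrite of select_time_slice_nwp below, the slice keeps its native init_time_utc and step coordinates, and the target times are recovered as their elementwise sum. A minimal sketch of that datetime arithmetic, using made-up times:

    import numpy as np

    # Made-up values: one init time per selected step, as produced by the
    # vectorised selection in select_time_slice.py below.
    init_times = np.array(["2024-06-01T00:00", "2024-06-01T06:00"], dtype="datetime64[s]")
    steps = np.array([3600, 7200], dtype="timedelta64[s]")

    # Elementwise addition gives one target time per step
    target_times = init_times + steps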
ocf_data_sampler/select/diff_channels.py (new file)
@@ -0,0 +1,25 @@
+"""Takes the diff along the step axis for a given set of channels."""
+
+import numpy as np
+import xarray as xr
+
+
+def diff_channels(da: xr.DataArray, accum_channels: list[str]) -> xr.DataArray:
+    """Perform in-place diff of the given channels of the DataArray in the steps dimension.
+
+    Args:
+        da: The DataArray to slice from
+        accum_channels: Channels which are accumulated and need to be differenced
+    """
+    if da.dims[:2] != ("step", "channel"):
+        raise ValueError("This function assumes the first two dimensions are step then channel")
+
+    all_channels = da.channel.values
+    accum_channel_inds = [i for i, c in enumerate(all_channels) if c in accum_channels]
+
+    # Make a copy of the values to avoid changing the underlying numpy array
+    vals = da.values.copy()
+    vals[:-1, accum_channel_inds] = np.diff(vals[:, accum_channel_inds], axis=0)
+    da.values = vals
+
+    return da.isel(step=slice(0, -1))
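diff_channels converts channels that accumulate over the forecast (for example, precipitation or radiation totals) into per-step increments, dropping the final step. A short usage sketch; the channel names here are hypothetical:

    import numpy as np
    import pandas as pd
    import xarray as xr

    from ocf_data_sampler.select.diff_channels import diff_channels

    # "dswrf" accumulates over steps; "t2m" is instantaneous (hypothetical names)
    da = xr.DataArray(
        np.array([[0.0, 10.0], [3.0, 11.0], [7.0, 12.0]]),
        dims=("step", "channel"),
        coords={"step": pd.to_timedelta([0, 1, 2], unit="h"), "channel": ["dswrf", "t2m"]},
    )

    da_diff = diff_channels(da, accum_channels=["dswrf"])
    # "dswrf" is now the per-step increase [3.0, 4.0]; "t2m" keeps [10.0, 11.0];
    # the final step has been dropped, so only two steps remain.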
ocf_data_sampler/select/select_time_slice.py
@@ -38,7 +38,6 @@ def select_time_slice_nwp(
     time_resolution: pd.Timedelta,
     dropout_timedeltas: list[pd.Timedelta] | None = None,
     dropout_frac: float | None = 0,
-    accum_channels: list[str] | None = None,
 ) -> xr.DataArray:
     """Select a time slice from an NWP DataArray.
 
@@ -50,11 +49,8 @@ def select_time_slice_nwp(
         time_resolution: Distance between neighbouring timestamps
         dropout_timedeltas: List of possible timedeltas before t0 where data availability may start
         dropout_frac: Probability to apply dropout
-        accum_channels: Channels which are accumulated and need to be differenced
     """
-    if accum_channels is None:
-        accum_channels = []
-
+    # Input checking
     if dropout_timedeltas is None:
         dropout_timedeltas = []
 
@@ -69,75 +65,43 @@ def select_time_slice_nwp(
 
     consider_dropout = len(dropout_timedeltas) > 0 and dropout_frac > 0
 
-    # The accumatated and non-accumulated channels
-    accum_channels = np.intersect1d(da.channel.values, accum_channels)
-    non_accum_channels = np.setdiff1d(da.channel.values, accum_channels)
-
     start_dt = (t0 + interval_start).ceil(time_resolution)
     end_dt = (t0 + interval_end).ceil(time_resolution)
     target_times = pd.date_range(start_dt, end_dt, freq=time_resolution)
 
     # Potentially apply NWP dropout
     if consider_dropout and (np.random.uniform() < dropout_frac):
-        dt = np.random.choice(dropout_timedeltas)
-        t0_available = t0 + dt
+        t0_available = t0 + np.random.choice(dropout_timedeltas)
     else:
         t0_available = t0
 
-    # Forecasts made up to and including t0
-    available_init_times = da.init_time_utc.sel(init_time_utc=slice(None, t0_available))
+    # Get the available and relevant init-times
+    t_min = target_times[0] - da.step.values[-1]
+    init_times = da.init_time_utc.values
+    available_init_times = init_times[(t_min<=init_times) & (init_times<=t0_available)]
 
-    # Find the most recent available init times for all target times
-    selected_init_times = available_init_times.sel(
-        init_time_utc=target_times,
-        method="ffill",  # forward fill from init times to target times
-    ).values
+    # Find the most recent available init-times for all target-times
+    selected_init_times = np.array(
+        [available_init_times[available_init_times<=t][-1] for t in target_times],
+    )
 
-    # Find the required steps for all target times
+    # Find the required steps for all target-times
     steps = target_times - selected_init_times
 
-    # We want one timestep for each target_time_hourly (obviously!) If we simply do
-    # nwp.sel(init_time=init_times, step=steps) then we'll get the *product* of
-    # init_times and steps, which is not what we want! Instead, we use xarray's
-    # vectorised-indexing mode via using a DataArray indexer. See the last example here:
-    # https://docs.xarray.dev/en/latest/user-guide/indexing.html#more-advanced-indexing
-
-    coords = {"target_time_utc": target_times}
-    init_time_indexer = xr.DataArray(selected_init_times, coords=coords)
-    step_indexer = xr.DataArray(steps, coords=coords)
+    # If we are only selecting from one init-time we can construct the slice so it's faster
+    if len(np.unique(selected_init_times))==1:
+        da_sel = da.sel(init_time_utc=selected_init_times[0], step=slice(steps[0], steps[-1]))
 
-    if len(accum_channels) == 0:
-        da_sel = da.sel(step=step_indexer, init_time_utc=init_time_indexer)
+    # If we are selecting from multiple init-times this is more complex and slower
     else:
-        # First minimise the size of the dataset we are diffing
-        # - find the init times we are slicing from
-        unique_init_times = np.unique(selected_init_times)
-        # - find the min and max steps we slice over. Max is extended due to diff
-        min_step = min(steps)
-        max_step = max(steps) + time_resolution
-
-        da_min = da.sel(init_time_utc=unique_init_times, step=slice(min_step, max_step))
-
-        # Slice out the data which does not need to be diffed
-        da_non_accum = da_min.sel(channel=non_accum_channels)
-        da_sel_non_accum = da_non_accum.sel(step=step_indexer, init_time_utc=init_time_indexer)
-
-        # Slice out the channels which need to be diffed
-        da_accum = da_min.sel(channel=accum_channels)
-
-        # Take the diff and slice requested data
-        da_accum = da_accum.diff(dim="step", label="lower")
-        da_sel_accum = da_accum.sel(step=step_indexer, init_time_utc=init_time_indexer)
-
-        # Join diffed and non-diffed variables
-        da_sel = xr.concat([da_sel_non_accum, da_sel_accum], dim="channel")
-
-        # Reorder the variable back to the original order
-        da_sel = da_sel.sel(channel=da.channel.values)
-
-        # Rename the diffed channels
-        da_sel["channel"] = [
-            f"diff_{v}" if v in accum_channels else v for v in da_sel.channel.values
-        ]
+        # We want one timestep for each target_time_hourly (obviously!) If we simply do
+        # nwp.sel(init_time=init_times, step=steps) then we'll get the *product* of
+        # init_times and steps, which is not what we want! Instead, we use xarray's
+        # vectorised-indexing mode via using a DataArray indexer. See the last example here:
+        # https://docs.xarray.dev/en/latest/user-guide/indexing.html#more-advanced-indexing
+        coords = {"step": steps}
+        init_time_indexer = xr.DataArray(selected_init_times, coords=coords)
+        step_indexer = xr.DataArray(steps, coords=coords)
+        da_sel = da.sel(init_time_utc=init_time_indexer, step=step_indexer)
 
     return da_sel
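The else-branch relies on xarray's pointwise ("vectorised") indexing: DataArray indexers that share a dimension select one value per (init_time, step) pair, rather than the product of the two label lists. A self-contained sketch with invented coordinates:

    import numpy as np
    import pandas as pd
    import xarray as xr

    # Invented toy forecast grid: 4 steps x 3 init-times
    da = xr.DataArray(
        np.arange(12).reshape(4, 3),
        dims=("step", "init_time_utc"),
        coords={
            "step": pd.to_timedelta([0, 1, 2, 3], unit="h"),
            "init_time_utc": pd.to_datetime(
                ["2024-06-01 00:00", "2024-06-01 06:00", "2024-06-01 12:00"]
            ),
        },
    )

    # One (init_time, step) pair per target time
    init_times = pd.to_datetime(["2024-06-01 06:00", "2024-06-01 12:00"])
    steps = pd.to_timedelta([2, 1], unit="h")

    coords = {"step": steps}
    init_time_indexer = xr.DataArray(init_times, coords=coords)
    step_indexer = xr.DataArray(steps, coords=coords)

    da_pointwise = da.sel(init_time_utc=init_time_indexer, step=step_indexer)
    da_product = da.sel(init_time_utc=init_times, step=steps)

    assert da_pointwise.shape == (2,)   # one value per pair
    assert da_product.shape == (2, 2)   # the unwanted product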
ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py
@@ -22,6 +22,7 @@ from ocf_data_sampler.select import Location, fill_time_periods
 from ocf_data_sampler.torch_datasets.utils import (
     add_alterate_coordinate_projections,
     config_normalization_values_to_dicts,
+    diff_nwp_data,
     fill_nans_in_arrays,
     find_valid_time_periods,
     merge_dicts,
@@ -259,7 +260,7 @@ class PVNetUKRegionalDataset(AbstractPVNetUKDataset):
         sample_dict = slice_datasets_by_space(self.datasets_dict, location, self.config)
         sample_dict = slice_datasets_by_time(sample_dict, t0, self.config)
         sample_dict = tensorstore_compute(sample_dict)
-
+        sample_dict = diff_nwp_data(sample_dict, self.config)
         return self.process_and_combine_datasets(sample_dict, t0, location)
 
     @override
@@ -318,6 +319,7 @@ class PVNetUKConcurrentDataset(AbstractPVNetUKDataset):
         # Slice by time then load to avoid loading the data multiple times from disk
         sample_dict = slice_datasets_by_time(self.datasets_dict, t0, self.config)
         sample_dict = tensorstore_compute(sample_dict)
+        sample_dict = diff_nwp_data(sample_dict, self.config)
 
         gsp_samples = []
 
ocf_data_sampler/torch_datasets/datasets/site.py
@@ -27,6 +27,7 @@ from ocf_data_sampler.select import (
 from ocf_data_sampler.torch_datasets.utils import (
     add_alterate_coordinate_projections,
     config_normalization_values_to_dicts,
+    diff_nwp_data,
     fill_nans_in_arrays,
     find_valid_time_periods,
     merge_dicts,
@@ -57,6 +58,7 @@ def get_locations(site_xr: xr.Dataset) -> list[Location]:
 
     return locations
 
+
 def process_and_combine_datasets(
     dataset_dict: dict,
     config: Configuration,
@@ -80,8 +82,6 @@ def process_and_combine_datasets(
 
     for nwp_key, da_nwp in dataset_dict["nwp"].items():
 
-        # Standardise and convert to NumpyBatch
-
         channel_means = means_dict["nwp"][nwp_key]
         channel_stds = stds_dict["nwp"][nwp_key]
 
@@ -276,8 +276,8 @@ class SitesDataset(Dataset):
         """
         sample_dict = slice_datasets_by_space(self.datasets_dict, location, self.config)
         sample_dict = slice_datasets_by_time(sample_dict, t0, self.config)
-
         sample_dict = tensorstore_compute(sample_dict)
+        sample_dict = diff_nwp_data(sample_dict, self.config)
 
         return process_and_combine_datasets(
             sample_dict,
@@ -414,6 +414,7 @@ class SitesDatasetConcurrent(Dataset):
         # slice by time first as we want to keep all site id info
         sample_dict = slice_datasets_by_time(self.datasets_dict, t0, self.config)
         sample_dict = tensorstore_compute(sample_dict)
+        sample_dict = diff_nwp_data(sample_dict, self.config)
 
         site_samples = []
 
ocf_data_sampler/torch_datasets/utils/__init__.py
@@ -3,4 +3,5 @@ from .merge_and_fill_utils import fill_nans_in_arrays, merge_dicts
 from .valid_time_periods import find_valid_time_periods
 from .spatial_slice_for_dataset import slice_datasets_by_space
 from .time_slice_for_dataset import slice_datasets_by_time
-from .add_alterate_coordinate_projections import add_alterate_coordinate_projections
+from .add_alterate_coordinate_projections import add_alterate_coordinate_projections
+from .diff_nwp_data import diff_nwp_data
ocf_data_sampler/torch_datasets/utils/diff_nwp_data.py (new file)
@@ -0,0 +1,20 @@
+"""Take the in-place diff of some channels of the NWP data."""
+
+from ocf_data_sampler.config import Configuration
+from ocf_data_sampler.select.diff_channels import diff_channels
+
+
+def diff_nwp_data(dataset_dict: dict, config: Configuration) -> dict:
+    """Take the in-place diff of some channels of the NWP data.
+
+    Args:
+        dataset_dict: Dictionary of xarray datasets
+        config: Configuration object
+    """
+    if "nwp" in dataset_dict:
+        for nwp_key, da_nwp in dataset_dict["nwp"].items():
+            accum_channels = config.input_data.nwp[nwp_key].accum_channels
+            if len(accum_channels)>0:
+                # diff_channels() is an in-place operation and modifies the input
+                dataset_dict["nwp"][nwp_key] = diff_channels(da_nwp, accum_channels)
+    return dataset_dict
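diff_nwp_data is the hook the datasets call right after tensorstore_compute (see the pvnet_uk.py and site.py changes above), so the differencing runs on arrays already loaded into memory. Because diff_channels mutates the values of its input, the returned array must replace the original, which is what the loop body does. A small runnable sketch with an assumed NWP key and channel names:

    import numpy as np
    import pandas as pd
    import xarray as xr

    from ocf_data_sampler.select.diff_channels import diff_channels

    # Assumed toy data: one NWP source keyed "ukv" with an accumulated
    # "dswrf" channel (both names are hypothetical).
    da_nwp = xr.DataArray(
        np.random.rand(3, 2),
        dims=("step", "channel"),
        coords={"step": pd.to_timedelta([0, 1, 2], unit="h"), "channel": ["dswrf", "t2m"]},
    )
    dataset_dict = {"nwp": {"ukv": da_nwp}}

    # As in diff_nwp_data above: re-assign the result, because diff_channels
    # mutates the values of the array passed in.
    dataset_dict["nwp"]["ukv"] = diff_channels(da_nwp, ["dswrf"])
    assert dataset_dict["nwp"]["ukv"].sizes["step"] == 2  # final step dropped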
ocf_data_sampler/torch_datasets/utils/time_slice_for_dataset.py
@@ -28,15 +28,23 @@ def slice_datasets_by_time(
     for nwp_key, da_nwp in datasets_dict["nwp"].items():
         nwp_config = config.input_data.nwp[nwp_key]
 
+        # Add a buffer if we need to diff some of the channels in time
+        if len(nwp_config.accum_channels)>0:
+            interval_end_mins = (
+                nwp_config.interval_end_minutes
+                + nwp_config.time_resolution_minutes
+            )
+        else:
+            interval_end_mins = nwp_config.interval_end_minutes
+
         sliced_datasets_dict["nwp"][nwp_key] = select_time_slice_nwp(
             da_nwp,
             t0,
             time_resolution=minutes(nwp_config.time_resolution_minutes),
             interval_start=minutes(nwp_config.interval_start_minutes),
-            interval_end=minutes(nwp_config.interval_end_minutes),
+            interval_end=minutes(interval_end_mins),
             dropout_timedeltas=minutes(nwp_config.dropout_timedeltas_minutes),
             dropout_frac=nwp_config.dropout_fraction,
-            accum_channels=nwp_config.accum_channels,
         )
 
     if "sat" in datasets_dict:
ocf_data_sampler-0.5.16.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: ocf-data-sampler
-Version: 0.5.15
+Version: 0.5.16
 Author: James Fulton, Peter Dudfield
 Author-email: Open Climate Fix team <info@openclimatefix.org>
 License: MIT License
ocf_data_sampler-0.5.16.dist-info/RECORD
@@ -27,36 +27,38 @@ ocf_data_sampler/numpy_sample/collate.py,sha256=hoxIc5SoHoIs3Nx37aRZzWChpswjy9lH
 ocf_data_sampler/numpy_sample/common_types.py,sha256=9CjYHkUTx0ObduWh43fhsybZCTXvexql7qC2ptMDoek,377
 ocf_data_sampler/numpy_sample/datetime_features.py,sha256=ObHM42VnZB7_daQ5a42GeftoDWYtVMT-wDP8kRtY_84,857
 ocf_data_sampler/numpy_sample/gsp.py,sha256=sOWX1ubeQSrK6_0vdy_RKVUvqzohOc5pBu7W4Co7iN8,983
-ocf_data_sampler/numpy_sample/nwp.py,sha256=lXqE2Il0xX5hzz76HHkiYmfDsXWWhmaA_6bSnmwbAXU,1078
+ocf_data_sampler/numpy_sample/nwp.py,sha256=AabiasD6OZDdfkPtYWpehV9XpaRHOiEr5g1nSdZdDv8,1095
 ocf_data_sampler/numpy_sample/satellite.py,sha256=RaYzYIcB1AmDrKeiqSpn4QVfBH-QMe26F1P5t1az2Jg,1111
 ocf_data_sampler/numpy_sample/site.py,sha256=4S19bzCN5lswVUrmWRfwpVsBPUE7bi0OIdxsD9wgvhU,982
 ocf_data_sampler/numpy_sample/sun_position.py,sha256=5tt-zNm6aRuZMsxZPaAxyg7HeikswfZCeHWXTHuO2K0,1555
 ocf_data_sampler/select/__init__.py,sha256=mK7Wu_-j9IXGTYrOuDf5yDDuU5a306b0iGKTAooNg_s,210
+ocf_data_sampler/select/diff_channels.py,sha256=W66JcI2pSM-7DnB76_Ag6kUv3f7FqMS-vNkb2467WAk,938
 ocf_data_sampler/select/dropout.py,sha256=i5NDP6oQnZBkQRJW-aXVrPXawktVKQz5VMexe5Ww51g,2021
 ocf_data_sampler/select/fill_time_periods.py,sha256=TlGxp1xiAqnhdWfLy0pv3FuZc00dtimjWdLzr4JoTGA,865
 ocf_data_sampler/select/find_contiguous_time_periods.py,sha256=etkr6LuB7zxkfzWJ6SgHiULdRuFzFlq5bOUNd257Qx4,11545
 ocf_data_sampler/select/geospatial.py,sha256=rvMy_e--3tm-KAy9pU6b9-UMBQqH2sXykr3N_4SHYy4,6528
 ocf_data_sampler/select/location.py,sha256=Qp0di-Pgq8WLjN9IBcTVTaRM3lckhr4ZVzaDRcgVXHw,2352
 ocf_data_sampler/select/select_spatial_slice.py,sha256=Ym_YJjZqeMPC5Bw_xMi7Re2-uCbUagm2KXhnAnstTHo,7200
-ocf_data_sampler/select/select_time_slice.py,sha256=HeHbwZ0CP03x0-LaJtpbSdtpLufwVTR73p6wH6O_PS8,5513
+ocf_data_sampler/select/select_time_slice.py,sha256=cpkdovJMvcjxSGfq9G0OJK5aDAeCXg7exWYrJnR4N2w,4116
 ocf_data_sampler/torch_datasets/datasets/__init__.py,sha256=o0SsEXXZ6k9iL__5_RN1Sf60lw_eqK91P3UFEHAD2k0,102
-ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py,sha256=wVx4QKHqak2FbxtryAxsVe6wpYM2n_YKgIKpiVs6gpE,12098
-ocf_data_sampler/torch_datasets/datasets/site.py,sha256=ivdSB_YpAqL8-Q1m_uKTGU2YlNQ1bZXdwialT_UpGuo,15590
+ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py,sha256=wUsIZ0Fhq5bbE8v02C0UPcFWIhWI7kfSka9UrWP0_m4,12240
+ocf_data_sampler/torch_datasets/datasets/site.py,sha256=OXrYSRrWUdQbEjsEPPJjam10zJKU6S3r5kA07RbpzFU,15680
 ocf_data_sampler/torch_datasets/sample/__init__.py,sha256=GL84vdZl_SjHDGVyh9Uekx2XhPYuZ0dnO3l6f6KXnHI,100
 ocf_data_sampler/torch_datasets/sample/base.py,sha256=cQ1oIyhdmlotejZK8B3Cw6MNvpdnBPD8G_o2h7Ye4Vc,2206
 ocf_data_sampler/torch_datasets/sample/site.py,sha256=40NwNTqjL1WVhPdwe02zDHHfDLG2u_bvCfRCtGAtFc0,1466
 ocf_data_sampler/torch_datasets/sample/uk_regional.py,sha256=Xx5cBYUyaM6PGUWQ76MHT9hwj6IJ7WAOxbpmYFbJGhc,10483
-ocf_data_sampler/torch_datasets/utils/__init__.py,sha256=TNSYuSSmFgjsvvJxtoDrH645Z64CHsNUUQ0iayTccP4,416
+ocf_data_sampler/torch_datasets/utils/__init__.py,sha256=4l1VcEmxHInU9G66zrimNMa8WcyKUASQST_iF9QfxUw,457
 ocf_data_sampler/torch_datasets/utils/add_alterate_coordinate_projections.py,sha256=w6Q4TyxNyl7PKAbhqiXvqOpnqIjwmOUcGREIvPNGYlQ,2666
 ocf_data_sampler/torch_datasets/utils/config_normalization_values_to_dicts.py,sha256=SGt1H2nXcaj44ND14-gHzvA7dkLfgjTacCq7rOkRGwg,1991
+ocf_data_sampler/torch_datasets/utils/diff_nwp_data.py,sha256=o7NpKWxKHhwMbol3xBAF087-tDgDUZeP0j8vG08E7Nc,816
 ocf_data_sampler/torch_datasets/utils/merge_and_fill_utils.py,sha256=we7BTxRH7B7jKayDT7YfNyfI3zZClz2Bk-HXKQIokgU,956
 ocf_data_sampler/torch_datasets/utils/spatial_slice_for_dataset.py,sha256=Hvz0wHSWMYYamf2oHNiGlzJcM4cAH6pL_7ZEvIBL2dE,1882
-ocf_data_sampler/torch_datasets/utils/time_slice_for_dataset.py,sha256=Q_-kCTtUieyEDpSElY1xwJct7Vsw0LAn5MbYSg2O6vg,3621
+ocf_data_sampler/torch_datasets/utils/time_slice_for_dataset.py,sha256=1r1J2KNSo1_imN9gpVf5AupJaZ7VSnSevS1o_wck440,3925
 ocf_data_sampler/torch_datasets/utils/valid_time_periods.py,sha256=xcy75cVxl0WrglnX5YUAFjXXlO2GwEBHWyqo8TDuiOA,4714
 ocf_data_sampler/torch_datasets/utils/validation_utils.py,sha256=YqmT-lExWlI8_ul3l0EP73Ik002fStr_bhsZh9mQqEU,4735
 scripts/download_gsp_location_data.py,sha256=rRDXMoqX-RYY4jPdxhdlxJGhWdl6r245F5UARgKV6P4,3121
 scripts/refactor_site.py,sha256=skzvsPP0Cn9yTKndzkilyNcGz4DZ88ctvCJ0XrBdc2A,3135
-ocf_data_sampler-0.5.15.dist-info/METADATA,sha256=AcLJpUOG6smk3WDSZkj3K8cjhvSg9z0lPoEKM16B6q8,12817
-ocf_data_sampler-0.5.15.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-ocf_data_sampler-0.5.15.dist-info/top_level.txt,sha256=deUxqmsONNAGZDNbsntbXH7BRA1MqWaUeAJrCo6q_xA,25
-ocf_data_sampler-0.5.15.dist-info/RECORD,,
+ocf_data_sampler-0.5.16.dist-info/METADATA,sha256=82UiAraNLrkhOMwZcLeK7Ckg3zgArx5BuzvfBOhy9m8,12817
+ocf_data_sampler-0.5.16.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ocf_data_sampler-0.5.16.dist-info/top_level.txt,sha256=deUxqmsONNAGZDNbsntbXH7BRA1MqWaUeAJrCo6q_xA,25
+ocf_data_sampler-0.5.16.dist-info/RECORD,,