ocf-data-sampler 0.0.23__py3-none-any.whl → 0.0.25__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ocf-data-sampler might be problematic.

Files changed (35)
  1. ocf_data_sampler/config/model.py +34 -0
  2. ocf_data_sampler/load/load_dataset.py +55 -0
  3. ocf_data_sampler/load/nwp/providers/ecmwf.py +5 -2
  4. ocf_data_sampler/load/site.py +30 -0
  5. ocf_data_sampler/numpy_batch/__init__.py +4 -3
  6. ocf_data_sampler/numpy_batch/gsp.py +12 -12
  7. ocf_data_sampler/numpy_batch/nwp.py +14 -14
  8. ocf_data_sampler/numpy_batch/satellite.py +8 -8
  9. ocf_data_sampler/numpy_batch/site.py +29 -0
  10. ocf_data_sampler/select/__init__.py +8 -1
  11. ocf_data_sampler/select/dropout.py +2 -1
  12. ocf_data_sampler/select/geospatial.py +43 -1
  13. ocf_data_sampler/select/select_spatial_slice.py +8 -2
  14. ocf_data_sampler/select/spatial_slice_for_dataset.py +53 -0
  15. ocf_data_sampler/select/time_slice_for_dataset.py +124 -0
  16. ocf_data_sampler/time_functions.py +11 -0
  17. ocf_data_sampler/torch_datasets/process_and_combine.py +153 -0
  18. ocf_data_sampler/torch_datasets/pvnet_uk_regional.py +8 -418
  19. ocf_data_sampler/torch_datasets/site.py +196 -0
  20. ocf_data_sampler/torch_datasets/valid_time_periods.py +108 -0
  21. {ocf_data_sampler-0.0.23.dist-info → ocf_data_sampler-0.0.25.dist-info}/METADATA +1 -1
  22. ocf_data_sampler-0.0.25.dist-info/RECORD +66 -0
  23. {ocf_data_sampler-0.0.23.dist-info → ocf_data_sampler-0.0.25.dist-info}/WHEEL +1 -1
  24. {ocf_data_sampler-0.0.23.dist-info → ocf_data_sampler-0.0.25.dist-info}/top_level.txt +1 -0
  25. scripts/refactor_site.py +50 -0
  26. tests/conftest.py +62 -0
  27. tests/load/test_load_sites.py +14 -0
  28. tests/numpy_batch/test_gsp.py +1 -2
  29. tests/numpy_batch/test_nwp.py +1 -3
  30. tests/numpy_batch/test_satellite.py +1 -3
  31. tests/numpy_batch/test_sun_position.py +7 -7
  32. tests/torch_datasets/test_pvnet_uk_regional.py +4 -6
  33. tests/torch_datasets/test_site.py +85 -0
  34. ocf_data_sampler-0.0.23.dist-info/RECORD +0 -54
  35. {ocf_data_sampler-0.0.23.dist-info → ocf_data_sampler-0.0.25.dist-info}/LICENSE +0 -0
ocf_data_sampler/config/model.py
@@ -102,6 +102,39 @@ class TimeResolutionMixin(Base):
     )
 
 
+class Site(DataSourceMixin, TimeResolutionMixin, DropoutMixin):
+    """Site configuration model"""
+
+    file_path: str = Field(
+        ...,
+        description="The NetCDF files holding the power timeseries.",
+    )
+    metadata_file_path: str = Field(
+        ...,
+        description="The CSV files describing power system",
+    )
+
+    @field_validator("forecast_minutes")
+    def forecast_minutes_divide_by_time_resolution(cls, v: int, info: ValidationInfo) -> int:
+        """Check forecast length requested will give stable number of timesteps"""
+        if v % info.data["time_resolution_minutes"] != 0:
+            message = "Forecast duration must be divisible by time resolution"
+            logger.error(message)
+            raise Exception(message)
+        return v
+
+    @field_validator("history_minutes")
+    def history_minutes_divide_by_time_resolution(cls, v: int, info: ValidationInfo) -> int:
+        """Check history length requested will give stable number of timesteps"""
+        if v % info.data["time_resolution_minutes"] != 0:
+            message = "History duration must be divisible by time resolution"
+            logger.error(message)
+            raise Exception(message)
+        return v
+
+    # TODO validate the netcdf for sites
+    # TODO validate the csv for metadata
+
 class Satellite(DataSourceMixin, TimeResolutionMixin, DropoutMixin):
     """Satellite configuration model"""
 
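The two validators above enforce a single divisibility rule. Stated standalone with illustrative values (the field names come from the mixins this class inherits):

time_resolution_minutes = 30
forecast_minutes = 480   # accepted: 480 % 30 == 0
history_minutes = 45     # would be rejected: 45 % 30 == 15

assert forecast_minutes % time_resolution_minutes == 0
assert history_minutes % time_resolution_minutes != 0  # a Site config with these values raises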
@@ -240,6 +273,7 @@ class InputData(Base):
     satellite: Optional[Satellite] = None
     nwp: Optional[MultiNWP] = None
     gsp: Optional[GSP] = None
+    site: Optional[Site] = None
 
 
 class Configuration(Base):
ocf_data_sampler/load/load_dataset.py
@@ -0,0 +1,55 @@
+""" Loads all data sources """
+import xarray as xr
+
+from ocf_data_sampler.config import Configuration
+from ocf_data_sampler.load.gsp import open_gsp
+from ocf_data_sampler.load.nwp import open_nwp
+from ocf_data_sampler.load.satellite import open_sat_data
+from ocf_data_sampler.load.site import open_site
+
+
+def get_dataset_dict(config: Configuration) -> dict[str, dict[xr.DataArray]]:
+    """Construct dictionary of all of the input data sources
+
+    Args:
+        config: Configuration file
+    """
+
+    in_config = config.input_data
+
+    datasets_dict = {}
+
+    # Load GSP data unless the path is None
+    if in_config.gsp and in_config.gsp.gsp_zarr_path:
+        da_gsp = open_gsp(zarr_path=in_config.gsp.gsp_zarr_path).compute()
+
+        # Remove national GSP
+        datasets_dict["gsp"] = da_gsp.sel(gsp_id=slice(1, None))
+
+    # Load NWP data if in config
+    if in_config.nwp:
+
+        datasets_dict["nwp"] = {}
+        for nwp_source, nwp_config in in_config.nwp.items():
+
+            da_nwp = open_nwp(nwp_config.nwp_zarr_path, provider=nwp_config.nwp_provider)
+
+            da_nwp = da_nwp.sel(channel=list(nwp_config.nwp_channels))
+
+            datasets_dict["nwp"][nwp_source] = da_nwp
+
+    # Load satellite data if in config
+    if in_config.satellite:
+        sat_config = config.input_data.satellite
+
+        da_sat = open_sat_data(sat_config.satellite_zarr_path)
+
+        da_sat = da_sat.sel(channel=list(sat_config.satellite_channels))
+
+        datasets_dict["sat"] = da_sat
+
+    if in_config.site:
+        da_sites = open_site(in_config.site)
+        datasets_dict["site"] = da_sites
+
+    return datasets_dict
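A hedged usage sketch for the new loader; config stands in for a fully populated Configuration built elsewhere, and the commented result shape follows from the function body above:

from ocf_data_sampler.config import Configuration
from ocf_data_sampler.load.load_dataset import get_dataset_dict

config: Configuration = ...  # placeholder: a populated Configuration object
datasets_dict = get_dataset_dict(config)

# Keys present depend on which config sections are set:
# datasets_dict["gsp"]  -> xr.DataArray with the national GSP removed
# datasets_dict["nwp"]  -> {nwp_source: xr.DataArray, ...}, one per provider
# datasets_dict["sat"]  -> xr.DataArray restricted to the configured channels
# datasets_dict["site"] -> xr.DataArray of site generation returned by open_site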
ocf_data_sampler/load/nwp/providers/ecmwf.py
@@ -9,7 +9,6 @@ from ocf_data_sampler.load.utils import (
 )
 
 
-
 def open_ifs(zarr_path: Path | str | list[Path] | list[str]) -> xr.DataArray:
     """
     Opens the ECMWF IFS NWP data
@@ -27,10 +26,14 @@ def open_ifs(zarr_path: Path | str | list[Path] | list[str]) -> xr.DataArray:
     ds = ds.rename(
         {
             "init_time": "init_time_utc",
-            "variable": "channel",
         }
     )
 
+    # LEGACY SUPPORT
+    # rename variable to channel if it exists
+    if "variable" in ds:
+        ds = ds.rename({"variable": "channel"})
+
     # Check the timestamps are unique and increasing
     check_time_unique_increasing(ds.init_time_utc)
 
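The rename-only-if-present pattern above, shown in isolation with synthetic datasets to make the legacy handling explicit (a minimal sketch):

import xarray as xr

old_style = xr.Dataset(coords={"variable": ["t2m", "u10"]})  # stores written before the rename
new_style = xr.Dataset(coords={"channel": ["t2m", "u10"]})   # stores already using "channel"

for ds in (old_style, new_style):
    if "variable" in ds:
        ds = ds.rename({"variable": "channel"})
    assert "channel" in ds  # both styles end up with the same coordinate name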
ocf_data_sampler/load/site.py
@@ -0,0 +1,30 @@
+import pandas as pd
+import xarray as xr
+import numpy as np
+
+from ocf_data_sampler.config.model import Site
+
+
+def open_site(sites_config: Site) -> xr.DataArray:
+
+    # Load site generation xr.Dataset
+    data_ds = xr.open_dataset(sites_config.file_path)
+
+    # Load site generation data
+    metadata_df = pd.read_csv(sites_config.metadata_file_path, index_col="site_id")
+
+    # Add coordinates
+    ds = data_ds.assign_coords(
+        latitude=(metadata_df.latitude.to_xarray()),
+        longitude=(metadata_df.longitude.to_xarray()),
+        capacity_kwp=data_ds.capacity_kwp,
+    )
+
+    # Sanity checks
+    assert np.isfinite(data_ds.capacity_kwp.values).all()
+    assert (data_ds.capacity_kwp.values > 0).all()
+    assert metadata_df.index.is_unique
+
+    return ds.generation_kw
+
+
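A sketch of the file layout open_site appears to expect, built from synthetic data. The dimension and variable names (time_utc, site_id, generation_kw, capacity_kwp) are inferred from the code above and should be treated as assumptions; constructing the Site model itself may also require the mixin fields not shown here.

import numpy as np
import pandas as pd
import xarray as xr

times = pd.date_range("2024-01-01", periods=4, freq="30min")
site_ids = [0, 1]

# Generation NetCDF: generation and capacity per (time_utc, site_id)
xr.Dataset(
    {
        "generation_kw": (("time_utc", "site_id"), np.random.rand(4, 2)),
        "capacity_kwp": (("time_utc", "site_id"), np.full((4, 2), 5.0)),
    },
    coords={"time_utc": times, "site_id": site_ids},
).to_netcdf("sites.nc")

# Metadata CSV: one row per site, keyed by site_id
pd.DataFrame(
    {"site_id": site_ids, "latitude": [51.5, 52.0], "longitude": [-0.1, -1.0]}
).to_csv("sites_metadata.csv", index=False)

# open_site(Site(file_path="sites.nc", metadata_file_path="sites_metadata.csv", ...))
# then returns the generation_kw DataArray with latitude, longitude and capacity_kwp
# attached as coordinates.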
ocf_data_sampler/numpy_batch/__init__.py
@@ -1,7 +1,8 @@
 """Conversion from Xarray to NumpyBatch"""
 
-from .gsp import convert_gsp_to_numpy_batch
-from .nwp import convert_nwp_to_numpy_batch
-from .satellite import convert_satellite_to_numpy_batch
+from .gsp import convert_gsp_to_numpy_batch, GSPBatchKey
+from .nwp import convert_nwp_to_numpy_batch, NWPBatchKey
+from .satellite import convert_satellite_to_numpy_batch, SatelliteBatchKey
 from .sun_position import make_sun_position_numpy_batch
+from .site import convert_site_to_numpy_batch
 
ocf_data_sampler/numpy_batch/gsp.py
@@ -6,15 +6,15 @@ import xarray as xr
 class GSPBatchKey:
 
     gsp = 'gsp'
-    gsp_nominal_capacity_mwp = 'gsp_nominal_capacity_mwp'
-    gsp_effective_capacity_mwp = 'gsp_effective_capacity_mwp'
-    gsp_time_utc = 'gsp_time_utc'
-    gsp_t0_idx = 'gsp_t0_idx'
-    gsp_solar_azimuth = 'gsp_solar_azimuth'
-    gsp_solar_elevation = 'gsp_solar_elevation'
+    nominal_capacity_mwp = 'gsp_nominal_capacity_mwp'
+    effective_capacity_mwp = 'gsp_effective_capacity_mwp'
+    time_utc = 'gsp_time_utc'
+    t0_idx = 'gsp_t0_idx'
+    solar_azimuth = 'gsp_solar_azimuth'
+    solar_elevation = 'gsp_solar_elevation'
     gsp_id = 'gsp_id'
-    gsp_x_osgb = 'gsp_x_osgb'
-    gsp_y_osgb = 'gsp_y_osgb'
+    x_osgb = 'gsp_x_osgb'
+    y_osgb = 'gsp_y_osgb'
 
 
 def convert_gsp_to_numpy_batch(da: xr.DataArray, t0_idx: int | None = None) -> dict:
@@ -22,12 +22,12 @@ def convert_gsp_to_numpy_batch(da: xr.DataArray, t0_idx: int | None = None) -> dict:
 
     example = {
         GSPBatchKey.gsp: da.values,
-        GSPBatchKey.gsp_nominal_capacity_mwp: da.isel(time_utc=0)["nominal_capacity_mwp"].values,
-        GSPBatchKey.gsp_effective_capacity_mwp: da.isel(time_utc=0)["effective_capacity_mwp"].values,
-        GSPBatchKey.gsp_time_utc: da["time_utc"].values.astype(float),
+        GSPBatchKey.nominal_capacity_mwp: da.isel(time_utc=0)["nominal_capacity_mwp"].values,
+        GSPBatchKey.effective_capacity_mwp: da.isel(time_utc=0)["effective_capacity_mwp"].values,
+        GSPBatchKey.time_utc: da["time_utc"].values.astype(float),
     }
 
     if t0_idx is not None:
-        example[GSPBatchKey.gsp_t0_idx] = t0_idx
+        example[GSPBatchKey.t0_idx] = t0_idx
 
     return example
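Note that only the class attribute names are shortened; the string values, and therefore the dictionary keys of the batches produced by convert_gsp_to_numpy_batch, are unchanged. A quick check:

from ocf_data_sampler.numpy_batch import GSPBatchKey

assert GSPBatchKey.t0_idx == "gsp_t0_idx"
assert GSPBatchKey.nominal_capacity_mwp == "gsp_nominal_capacity_mwp"
assert GSPBatchKey.x_osgb == "gsp_x_osgb"

The same applies to the NWPBatchKey and SatelliteBatchKey renames below.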
ocf_data_sampler/numpy_batch/nwp.py
@@ -7,13 +7,13 @@ import xarray as xr
 class NWPBatchKey:
 
     nwp = 'nwp'
-    nwp_channel_names = 'nwp_channel_names'
-    nwp_init_time_utc = 'nwp_init_time_utc'
-    nwp_step = 'nwp_step'
-    nwp_target_time_utc = 'nwp_target_time_utc'
-    nwp_t0_idx = 'nwp_t0_idx'
-    nwp_y_osgb = 'nwp_y_osgb'
-    nwp_x_osgb = 'nwp_x_osgb'
+    channel_names = 'nwp_channel_names'
+    init_time_utc = 'nwp_init_time_utc'
+    step = 'nwp_step'
+    target_time_utc = 'nwp_target_time_utc'
+    t0_idx = 'nwp_t0_idx'
+    y_osgb = 'nwp_y_osgb'
+    x_osgb = 'nwp_x_osgb'
 
 
 def convert_nwp_to_numpy_batch(da: xr.DataArray, t0_idx: int | None = None) -> dict:
@@ -21,23 +21,23 @@ def convert_nwp_to_numpy_batch(da: xr.DataArray, t0_idx: int | None = None) -> dict:
 
     example = {
         NWPBatchKey.nwp: da.values,
-        NWPBatchKey.nwp_channel_names: da.channel.values,
-        NWPBatchKey.nwp_init_time_utc: da.init_time_utc.values.astype(float),
-        NWPBatchKey.nwp_step: (da.step.values / pd.Timedelta("1h")).astype(int),
+        NWPBatchKey.channel_names: da.channel.values,
+        NWPBatchKey.init_time_utc: da.init_time_utc.values.astype(float),
+        NWPBatchKey.step: (da.step.values / pd.Timedelta("1h")).astype(int),
     }
 
     if "target_time_utc" in da.coords:
-        example[NWPBatchKey.nwp_target_time_utc] = da.target_time_utc.values.astype(float)
+        example[NWPBatchKey.target_time_utc] = da.target_time_utc.values.astype(float)
 
     # TODO: Do we need this at all? Especially since it is only present in UKV data
     for batch_key, dataset_key in (
-        (NWPBatchKey.nwp_y_osgb, "y_osgb"),
-        (NWPBatchKey.nwp_x_osgb, "x_osgb"),
+        (NWPBatchKey.y_osgb, "y_osgb"),
+        (NWPBatchKey.x_osgb, "x_osgb"),
     ):
         if dataset_key in da.coords:
             example[batch_key] = da[dataset_key].values
 
     if t0_idx is not None:
-        example[NWPBatchKey.nwp_t0_idx] = t0_idx
+        example[NWPBatchKey.t0_idx] = t0_idx
 
     return example
ocf_data_sampler/numpy_batch/satellite.py
@@ -5,26 +5,26 @@ import xarray as xr
 class SatelliteBatchKey:
 
     satellite_actual = 'satellite_actual'
-    satellite_time_utc = 'satellite_time_utc'
-    satellite_x_geostationary = 'satellite_x_geostationary'
-    satellite_y_geostationary = 'satellite_y_geostationary'
-    satellite_t0_idx = 'satellite_t0_idx'
+    time_utc = 'satellite_time_utc'
+    x_geostationary = 'satellite_x_geostationary'
+    y_geostationary = 'satellite_y_geostationary'
+    t0_idx = 'satellite_t0_idx'
 
 
 def convert_satellite_to_numpy_batch(da: xr.DataArray, t0_idx: int | None = None) -> dict:
     """Convert from Xarray to NumpyBatch"""
     example = {
         SatelliteBatchKey.satellite_actual: da.values,
-        SatelliteBatchKey.satellite_time_utc: da.time_utc.values.astype(float),
+        SatelliteBatchKey.time_utc: da.time_utc.values.astype(float),
     }
 
     for batch_key, dataset_key in (
-        (SatelliteBatchKey.satellite_x_geostationary, "x_geostationary"),
-        (SatelliteBatchKey.satellite_y_geostationary, "y_geostationary"),
+        (SatelliteBatchKey.x_geostationary, "x_geostationary"),
+        (SatelliteBatchKey.y_geostationary, "y_geostationary"),
     ):
         example[batch_key] = da[dataset_key].values
 
     if t0_idx is not None:
-        example[SatelliteBatchKey.satellite_t0_idx] = t0_idx
+        example[SatelliteBatchKey.t0_idx] = t0_idx
 
     return example
ocf_data_sampler/numpy_batch/site.py
@@ -0,0 +1,29 @@
+"""Convert site to Numpy Batch"""
+
+import xarray as xr
+
+
+class SiteBatchKey:
+
+    generation = "site"
+    site_capacity_kwp = "site_capacity_kwp"
+    site_time_utc = "site_time_utc"
+    site_t0_idx = "site_t0_idx"
+    site_solar_azimuth = "site_solar_azimuth"
+    site_solar_elevation = "site_solar_elevation"
+    site_id = "site_id"
+
+
+def convert_site_to_numpy_batch(da: xr.DataArray, t0_idx: int | None = None) -> dict:
+    """Convert from Xarray to NumpyBatch"""
+
+    example = {
+        SiteBatchKey.generation: da.values,
+        SiteBatchKey.site_capacity_kwp: da.isel(time_utc=0)["capacity_kwp"].values,
+        SiteBatchKey.site_time_utc: da["time_utc"].values.astype(float),
+    }
+
+    if t0_idx is not None:
+        example[SiteBatchKey.site_t0_idx] = t0_idx
+
+    return example
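A minimal sketch of the new converter on synthetic data; the time_utc dimension and the scalar capacity_kwp coordinate are assumptions based on how the function indexes its input:

import numpy as np
import pandas as pd
import xarray as xr
from ocf_data_sampler.numpy_batch.site import SiteBatchKey, convert_site_to_numpy_batch

times = pd.date_range("2024-01-01 10:00", periods=4, freq="30min")
da = xr.DataArray(
    np.random.rand(4),
    dims=["time_utc"],
    coords={"time_utc": times, "capacity_kwp": 5.0},
)

example = convert_site_to_numpy_batch(da, t0_idx=1)
assert example[SiteBatchKey.generation].shape == (4,)
assert example[SiteBatchKey.site_t0_idx] == 1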
ocf_data_sampler/select/__init__.py
@@ -1 +1,8 @@
-
+from .fill_time_periods import fill_time_periods
+from .find_contiguous_time_periods import (
+    find_contiguous_t0_periods,
+    intersection_of_multiple_dataframes_of_periods,
+)
+from .location import Location
+from .spatial_slice_for_dataset import slice_datasets_by_space
+from .time_slice_for_dataset import slice_datasets_by_time
ocf_data_sampler/select/dropout.py
@@ -1,3 +1,4 @@
+""" Functions for simulating dropout in time series data """
 import numpy as np
 import pandas as pd
 import xarray as xr
@@ -5,7 +6,7 @@ import xarray as xr
 
 def draw_dropout_time(
     t0: pd.Timestamp,
-    dropout_timedeltas: list[pd.Timedelta] | None,
+    dropout_timedeltas: list[pd.Timedelta] | pd.Timedelta | None,
     dropout_frac: float = 0,
 ):
 
ocf_data_sampler/select/geospatial.py
@@ -55,6 +55,23 @@ def lon_lat_to_osgb(
     return _lon_lat_to_osgb(xx=x, yy=y)
 
 
+def lon_lat_to_geostationary_area_coords(
+    longitude: Union[Number, np.ndarray],
+    latitude: Union[Number, np.ndarray],
+    xr_data: xr.DataArray,
+) -> tuple[Union[Number, np.ndarray], Union[Number, np.ndarray]]:
+    """Loads geostationary area and transformation from lat-lon to geostationary coords
+
+    Args:
+        longitude: longitude
+        latitude: latitude
+        xr_data: xarray object with geostationary area
+
+    Returns:
+        Geostationary coords: x, y
+    """
+    return coordinates_to_geostationary_area_coords(longitude, latitude, xr_data, WGS84)
+
 def osgb_to_geostationary_area_coords(
     x: Union[Number, np.ndarray],
     y: Union[Number, np.ndarray],
@@ -70,6 +87,31 @@ def osgb_to_geostationary_area_coords(
     Returns:
         Geostationary coords: x, y
     """
+
+    return coordinates_to_geostationary_area_coords(x, y, xr_data, OSGB36)
+
+
+
+def coordinates_to_geostationary_area_coords(
+    x: Union[Number, np.ndarray],
+    y: Union[Number, np.ndarray],
+    xr_data: xr.DataArray,
+    crs_from: int
+) -> tuple[Union[Number, np.ndarray], Union[Number, np.ndarray]]:
+    """Loads geostationary area and transformation from respective coordiates to geostationary coords
+
+    Args:
+        x: osgb east-west, or latitude
+        y: osgb north-south, or longitude
+        xr_data: xarray object with geostationary area
+        crs_from: the cordiates system of x,y
+
+    Returns:
+        Geostationary coords: x, y
+    """
+
+    assert crs_from in [OSGB36, WGS84], f"Unrecognized coordinate system: {crs_from}"
+
     # Only load these if using geostationary projection
     import pyresample
 
@@ -80,7 +122,7 @@
     )
     geostationary_crs = geostationary_area_definition.crs
     osgb_to_geostationary = pyproj.Transformer.from_crs(
-        crs_from=OSGB36, crs_to=geostationary_crs, always_xy=True
+        crs_from=crs_from, crs_to=geostationary_crs, always_xy=True
    ).transform
     return osgb_to_geostationary(xx=x, yy=y)
 
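The refactor funnels both public entry points through one helper whose only varying input is crs_from. A standalone pyproj sketch of that dispatch, with assumed EPSG codes (27700 for OSGB36, 4326 for WGS84); the geostationary target CRS itself still comes from the satellite data's area attribute and is not reproduced here:

import pyproj

OSGB36, WGS84 = 27700, 4326  # assumed to match the module constants

# Same shape of call as coordinates_to_geostationary_area_coords, with a plain CRS target:
transformer = pyproj.Transformer.from_crs(crs_from=WGS84, crs_to=OSGB36, always_xy=True)
x_osgb, y_osgb = transformer.transform(-0.12, 51.5)  # lon, lat near London -> easting, northing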
ocf_data_sampler/select/select_spatial_slice.py
@@ -8,6 +8,7 @@ import xarray as xr
 from ocf_data_sampler.select.location import Location
 from ocf_data_sampler.select.geospatial import (
     lon_lat_to_osgb,
+    lon_lat_to_geostationary_area_coords,
     osgb_to_geostationary_area_coords,
     osgb_to_lon_lat,
     spatial_coord_type,
@@ -101,7 +102,7 @@ def _get_idx_of_pixel_closest_to_poi(
 
 def _get_idx_of_pixel_closest_to_poi_geostationary(
     da: xr.DataArray,
-    center_osgb: Location,
+    center: Location,
 ) -> Location:
     """
     Return x and y index location of pixel at center of region of interest.
@@ -116,7 +117,12 @@ def _get_idx_of_pixel_closest_to_poi_geostationary(
 
     _, x_dim, y_dim = spatial_coord_type(da)
 
-    x, y = osgb_to_geostationary_area_coords(x=center_osgb.x, y=center_osgb.y, xr_data=da)
+    if center.coordinate_system == 'osgb':
+        x, y = osgb_to_geostationary_area_coords(x=center.x, y=center.y, xr_data=da)
+    elif center.coordinate_system == 'lon_lat':
+        x, y = lon_lat_to_geostationary_area_coords(longitude=center.x, latitude=center.y, xr_data=da)
+    else:
+        x,y = center.x, center.y
     center_geostationary = Location(x=x, y=y, coordinate_system="geostationary")
 
     # Check that the requested point lies within the data
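With this branch the centre of the region of interest can be supplied in any of the three systems. Illustrative Locations (field names taken from the usage above; any extra validation on Location is not shown here):

from ocf_data_sampler.select.location import Location

osgb_centre = Location(x=530_000, y=180_000, coordinate_system="osgb")
lonlat_centre = Location(x=-0.12, y=51.5, coordinate_system="lon_lat")
# A Location already in geostationary coordinates is passed through unchanged.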
ocf_data_sampler/select/spatial_slice_for_dataset.py
@@ -0,0 +1,53 @@
+""" Functions for selecting data around a given location """
+from ocf_data_sampler.config import Configuration
+from ocf_data_sampler.select.location import Location
+from ocf_data_sampler.select.select_spatial_slice import select_spatial_slice_pixels
+
+
+def slice_datasets_by_space(
+    datasets_dict: dict,
+    location: Location,
+    config: Configuration,
+) -> dict:
+    """Slice the dictionary of input data sources around a given location
+
+    Args:
+        datasets_dict: Dictionary of the input data sources
+        location: The location to sample around
+        config: Configuration object.
+    """
+
+    assert set(datasets_dict.keys()).issubset({"nwp", "sat", "gsp", "site"})
+
+    sliced_datasets_dict = {}
+
+    if "nwp" in datasets_dict:
+
+        sliced_datasets_dict["nwp"] = {}
+
+        for nwp_key, nwp_config in config.input_data.nwp.items():
+
+            sliced_datasets_dict["nwp"][nwp_key] = select_spatial_slice_pixels(
+                datasets_dict["nwp"][nwp_key],
+                location,
+                height_pixels=nwp_config.nwp_image_size_pixels_height,
+                width_pixels=nwp_config.nwp_image_size_pixels_width,
+            )
+
+    if "sat" in datasets_dict:
+        sat_config = config.input_data.satellite
+
+        sliced_datasets_dict["sat"] = select_spatial_slice_pixels(
+            datasets_dict["sat"],
+            location,
+            height_pixels=sat_config.satellite_image_size_pixels_height,
+            width_pixels=sat_config.satellite_image_size_pixels_width,
+        )
+
+    if "gsp" in datasets_dict:
+        sliced_datasets_dict["gsp"] = datasets_dict["gsp"].sel(gsp_id=location.id)
+
+    if "site" in datasets_dict:
+        sliced_datasets_dict["site"] = datasets_dict["site"].sel(site_id=location.id)
+
+    return sliced_datasets_dict
ocf_data_sampler/select/time_slice_for_dataset.py
@@ -0,0 +1,124 @@
+""" Slice datasets by time"""
+import pandas as pd
+
+from ocf_data_sampler.config import Configuration
+from ocf_data_sampler.select.dropout import draw_dropout_time, apply_dropout_time
+from ocf_data_sampler.select.select_time_slice import select_time_slice_nwp, select_time_slice
+from ocf_data_sampler.time_functions import minutes
+
+
+def slice_datasets_by_time(
+    datasets_dict: dict,
+    t0: pd.Timestamp,
+    config: Configuration,
+) -> dict:
+    """Slice the dictionary of input data sources around a given t0 time
+
+    Args:
+        datasets_dict: Dictionary of the input data sources
+        t0: The init-time
+        config: Configuration object.
+    """
+
+    sliced_datasets_dict = {}
+
+    if "nwp" in datasets_dict:
+
+        sliced_datasets_dict["nwp"] = {}
+
+        for nwp_key, da_nwp in datasets_dict["nwp"].items():
+
+            nwp_config = config.input_data.nwp[nwp_key]
+
+            sliced_datasets_dict["nwp"][nwp_key] = select_time_slice_nwp(
+                da_nwp,
+                t0,
+                sample_period_duration=minutes(nwp_config.time_resolution_minutes),
+                history_duration=minutes(nwp_config.history_minutes),
+                forecast_duration=minutes(nwp_config.forecast_minutes),
+                dropout_timedeltas=minutes(nwp_config.dropout_timedeltas_minutes),
+                dropout_frac=nwp_config.dropout_fraction,
+                accum_channels=nwp_config.nwp_accum_channels,
+            )
+
+    if "sat" in datasets_dict:
+
+        sat_config = config.input_data.satellite
+
+        sliced_datasets_dict["sat"] = select_time_slice(
+            datasets_dict["sat"],
+            t0,
+            sample_period_duration=minutes(sat_config.time_resolution_minutes),
+            interval_start=minutes(-sat_config.history_minutes),
+            interval_end=minutes(-sat_config.live_delay_minutes),
+            max_steps_gap=2,
+        )
+
+        # Randomly sample dropout
+        sat_dropout_time = draw_dropout_time(
+            t0,
+            dropout_timedeltas=minutes(sat_config.dropout_timedeltas_minutes),
+            dropout_frac=sat_config.dropout_fraction,
+        )
+
+        # Apply the dropout
+        sliced_datasets_dict["sat"] = apply_dropout_time(
+            sliced_datasets_dict["sat"],
+            sat_dropout_time,
+        )
+
+    if "gsp" in datasets_dict:
+        gsp_config = config.input_data.gsp
+
+        sliced_datasets_dict["gsp_future"] = select_time_slice(
+            datasets_dict["gsp"],
+            t0,
+            sample_period_duration=minutes(gsp_config.time_resolution_minutes),
+            interval_start=minutes(30),
+            interval_end=minutes(gsp_config.forecast_minutes),
+        )
+
+        sliced_datasets_dict["gsp"] = select_time_slice(
+            datasets_dict["gsp"],
+            t0,
+            sample_period_duration=minutes(gsp_config.time_resolution_minutes),
+            interval_start=-minutes(gsp_config.history_minutes),
+            interval_end=minutes(0),
+        )
+
+        # Dropout on the GSP, but not the future GSP
+        gsp_dropout_time = draw_dropout_time(
+            t0,
+            dropout_timedeltas=minutes(gsp_config.dropout_timedeltas_minutes),
+            dropout_frac=gsp_config.dropout_fraction,
+        )
+
+        sliced_datasets_dict["gsp"] = apply_dropout_time(
+            sliced_datasets_dict["gsp"], gsp_dropout_time
+        )
+
+    if "site" in datasets_dict:
+        site_config = config.input_data.site
+
+        sliced_datasets_dict["site"] = select_time_slice(
+            datasets_dict["site"],
+            t0,
+            sample_period_duration=minutes(site_config.time_resolution_minutes),
+            interval_start=-minutes(site_config.history_minutes),
+            interval_end=minutes(site_config.forecast_minutes),
+        )
+
+        # Randomly sample dropout
+        site_dropout_time = draw_dropout_time(
+            t0,
+            dropout_timedeltas=minutes(site_config.dropout_timedeltas_minutes),
+            dropout_frac=site_config.dropout_fraction,
+        )
+
+        # Apply the dropout
+        sliced_datasets_dict["site"] = apply_dropout_time(
+            sliced_datasets_dict["site"],
+            site_dropout_time,
+        )
+
+    return sliced_datasets_dict
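Together with load_dataset.py and spatial_slice_for_dataset.py, this is the sampling flow that pvnet_uk_regional.py previously carried inline. A hedged end-to-end sketch; config is a placeholder for a populated Configuration, and the Location would also need an id set to slice GSP or site data:

import pandas as pd
from ocf_data_sampler.load.load_dataset import get_dataset_dict
from ocf_data_sampler.select import Location, slice_datasets_by_space, slice_datasets_by_time

config = ...  # placeholder: a populated ocf_data_sampler.config.Configuration
datasets_dict = get_dataset_dict(config)

location = Location(x=-0.12, y=51.5, coordinate_system="lon_lat")
t0 = pd.Timestamp("2024-06-01 12:00")

sample = slice_datasets_by_space(datasets_dict, location, config)
sample = slice_datasets_by_time(sample, t0, config)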
ocf_data_sampler/time_functions.py
@@ -0,0 +1,11 @@
+import pandas as pd
+
+
+def minutes(minutes: int | list[float]) -> pd.Timedelta | pd.TimedeltaIndex:
+    """Timedelta minutes
+
+    Args:
+        minutes: the number of minutes, single value or list
+    """
+    minutes_delta = pd.to_timedelta(minutes, unit="m")
+    return minutes_delta
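A quick check of both input shapes the helper accepts:

import pandas as pd
from ocf_data_sampler.time_functions import minutes

assert minutes(30) == pd.Timedelta(minutes=30)
assert list(minutes([0, 30, 60])) == [pd.Timedelta(minutes=m) for m in (0, 30, 60)]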