ocf-data-sampler 0.0.31__py3-none-any.whl → 0.0.33__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of ocf-data-sampler has been flagged as a potentially problematic release.

ocf_data_sampler/load/site.py

@@ -8,23 +8,23 @@ from ocf_data_sampler.config.model import Site
 def open_site(sites_config: Site) -> xr.DataArray:
 
     # Load site generation xr.Dataset
-    data_ds = xr.open_dataset(sites_config.file_path)
+    site_generation_ds = xr.open_dataset(sites_config.file_path)
 
     # Load site generation data
     metadata_df = pd.read_csv(sites_config.metadata_file_path, index_col="site_id")
 
-    # Add coordinates
-    ds = data_ds.assign_coords(
-        latitude=(metadata_df.latitude.to_xarray()),
-        longitude=(metadata_df.longitude.to_xarray()),
-        capacity_kwp=data_ds.capacity_kwp,
+    # Ensure metadata aligns with the site_id dimension in site_generation_ds
+    metadata_df = metadata_df.reindex(site_generation_ds.site_id.values)
+
+    # Assign coordinates to the Dataset using the aligned metadata
+    site_generation_ds = site_generation_ds.assign_coords(
+        latitude=("site_id", metadata_df["latitude"].values),
+        longitude=("site_id", metadata_df["longitude"].values),
+        capacity_kwp=("site_id", metadata_df["capacity_kwp"].values),
     )
 
     # Sanity checks
-    assert np.isfinite(data_ds.capacity_kwp.values).all()
-    assert (data_ds.capacity_kwp.values > 0).all()
+    assert np.isfinite(site_generation_ds.capacity_kwp.values).all()
+    assert (site_generation_ds.capacity_kwp.values > 0).all()
     assert metadata_df.index.is_unique
-
-    return ds.generation_kw
-
-
+    return site_generation_ds.generation_kw
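The rewrite replaces index-aligned `to_xarray()` coordinate assignment with explicit `("site_id", values)` tuples, which are purely positional, so the new `reindex` call is what keeps metadata and generation data paired correctly when the two files list sites in different orders. A toy sketch of the failure mode it guards against (invented values, plain pandas/xarray):

    import numpy as np
    import pandas as pd
    import xarray as xr

    # Toy generation data: two sites stored in the order [1, 0]
    ds = xr.Dataset(
        {"generation_kw": (("time_utc", "site_id"), np.ones((3, 2)))},
        coords={
            "time_utc": pd.date_range("2024-01-01", periods=3, freq="30min"),
            "site_id": [1, 0],
        },
    )

    # Metadata indexed in the order [0, 1]
    metadata_df = pd.DataFrame(
        {"latitude": [51.5, 52.0], "longitude": [-0.1, 0.2]},
        index=pd.Index([0, 1], name="site_id"),
    )

    # Reindex to the Dataset's site order before positional assignment;
    # otherwise site 1 would silently receive site 0's coordinates
    metadata_df = metadata_df.reindex(ds.site_id.values)
    ds = ds.assign_coords(latitude=("site_id", metadata_df["latitude"].values))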
ocf_data_sampler/torch_datasets/__init__.py

@@ -1 +1,2 @@
-
+from .pvnet_uk_regional import PVNetUKRegionalDataset
+from .site import SitesDataset
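The previously empty `torch_datasets/__init__.py` now re-exports both dataset classes, so downstream code (including the updated tests later in this diff) can import them from the subpackage directly:

    from ocf_data_sampler.torch_datasets import PVNetUKRegionalDataset, SitesDataset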
ocf_data_sampler/torch_datasets/process_and_combine.py

@@ -1,6 +1,7 @@
 import numpy as np
 import pandas as pd
 import xarray as xr
+from typing import Tuple
 
 from ocf_data_sampler.config import Configuration
 from ocf_data_sampler.constants import NWP_MEANS, NWP_STDS
@@ -9,7 +10,6 @@ from ocf_data_sampler.numpy_batch import (
     convert_satellite_to_numpy_batch,
     convert_gsp_to_numpy_batch,
     make_sun_position_numpy_batch,
-    convert_site_to_numpy_batch,
 )
 from ocf_data_sampler.numpy_batch.gsp import GSPBatchKey
 from ocf_data_sampler.numpy_batch.nwp import NWPBatchKey
@@ -73,18 +73,6 @@ def process_and_combine_datasets(
         }
     )
 
-
-    if "site" in dataset_dict:
-        site_config = config.input_data.site
-        da_sites = dataset_dict["site"]
-        da_sites = da_sites / da_sites.capacity_kwp
-
-        numpy_modalities.append(
-            convert_site_to_numpy_batch(
-                da_sites, t0_idx=-site_config.interval_start_minutes / site_config.time_resolution_minutes
-            )
-        )
-
     if target_key == 'gsp':
         # Make sun coords NumpyBatch
         datetimes = pd.date_range(
@@ -95,16 +83,6 @@ def process_and_combine_datasets(
 
         lon, lat = osgb_to_lon_lat(location.x, location.y)
 
-    elif target_key == 'site':
-        # Make sun coords NumpyBatch
-        datetimes = pd.date_range(
-            t0+minutes(site_config.interval_start_minutes),
-            t0+minutes(site_config.interval_end_minutes),
-            freq=minutes(site_config.time_resolution_minutes),
-        )
-
-        lon, lat = location.x, location.y
-
     numpy_modalities.append(
         make_sun_position_numpy_batch(datetimes, lon, lat, key_prefix=target_key)
     )
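The site-specific branches removed above are superseded by the new `process_and_combine_site_sample_dict` function added in the next hunk, which returns an xarray Dataset instead of appending site data to the NumpyBatch modalities.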
@@ -115,6 +93,47 @@ def process_and_combine_datasets(
 
     return combined_sample
 
+def process_and_combine_site_sample_dict(
+    dataset_dict: dict,
+    config: Configuration,
+) -> xr.Dataset:
+    """
+    Normalize and combine data into a single xr Dataset
+
+    Args:
+        dataset_dict: dict containing sliced xr DataArrays
+        config: Configuration for the model
+
+    Returns:
+        xr.Dataset: A merged Dataset with nans filled in.
+
+    """
+
+    data_arrays = []
+
+    if "nwp" in dataset_dict:
+        for nwp_key, da_nwp in dataset_dict["nwp"].items():
+            # Standardise
+            provider = config.input_data.nwp[nwp_key].provider
+            da_nwp = (da_nwp - NWP_MEANS[provider]) / NWP_STDS[provider]
+            data_arrays.append((f"nwp-{provider}", da_nwp))
+
+    if "sat" in dataset_dict:
+        # TODO add some satellite normalisation
+        da_sat = dataset_dict["sat"]
+        data_arrays.append(("satellite", da_sat))
+
+    if "site" in dataset_dict:
+        da_sites = dataset_dict["site"]
+        da_sites = da_sites / da_sites.capacity_kwp
+        data_arrays.append(("sites", da_sites))
+
+    combined_sample_dataset = merge_arrays(data_arrays)
+
+    # Fill any nan values
+    return combined_sample_dataset.fillna(0.0)
+
 
 def merge_dicts(list_of_dicts: list[dict]) -> dict:
     """Merge a list of dictionaries into a single dictionary"""
@@ -124,6 +143,59 @@ def merge_dicts(list_of_dicts: list[dict]) -> dict:
         combined_dict.update(d)
     return combined_dict
 
+def merge_arrays(normalised_data_arrays: list[Tuple[str, xr.DataArray]]) -> xr.Dataset:
+    """
+    Combine a list of DataArrays into a single Dataset with unique naming conventions.
+
+    Args:
+        normalised_data_arrays: List of tuples where each tuple contains:
+            - A string (key name).
+            - An xarray.DataArray.
+
+    Returns:
+        xr.Dataset: A merged Dataset with uniquely named variables, coordinates, and dimensions.
+    """
+    datasets = []
+
+    for key, data_array in normalised_data_arrays:
+        # Ensure all attributes are strings for consistency
+        data_array = data_array.assign_attrs(
+            {attr_key: str(attr_value) for attr_key, attr_value in data_array.attrs.items()}
+        )
+
+        # Convert DataArray to Dataset with the variable name as the key
+        dataset = data_array.to_dataset(name=key)
+
+        # Prepend key name to all dimension and coordinate names for uniqueness
+        dataset = dataset.rename(
+            {dim: f"{key}__{dim}" for dim in dataset.dims if dim not in dataset.coords}
+        )
+        dataset = dataset.rename(
+            {coord: f"{key}__{coord}" for coord in dataset.coords}
+        )
+
+        # Handle concatenation dimension if applicable
+        concat_dim = (
+            f"{key}__target_time_utc" if f"{key}__target_time_utc" in dataset.coords
+            else f"{key}__time_utc"
+        )
+
+        if f"{key}__init_time_utc" in dataset.coords:
+            init_coord = f"{key}__init_time_utc"
+            if dataset[init_coord].ndim == 0:  # Check if scalar
+                expanded_init_times = [dataset[init_coord].values] * len(dataset[concat_dim])
+                dataset = dataset.assign_coords({init_coord: (concat_dim, expanded_init_times)})
+
+        datasets.append(dataset)
+
+    # Ensure all datasets are valid xarray.Dataset objects
+    for ds in datasets:
+        assert isinstance(ds, xr.Dataset), f"Object is not an xr.Dataset: {type(ds)}"
+
+    # Merge all prepared datasets
+    combined_dataset = xr.merge(datasets)
+
+    return combined_dataset
 
 
 def fill_nans_in_arrays(batch: dict) -> dict:
     """Fills all NaN values in each np.ndarray in the batch dictionary with zeros.
ocf_data_sampler/torch_datasets/site.py

@@ -15,7 +15,7 @@ from ocf_data_sampler.select import (
     slice_datasets_by_time, slice_datasets_by_space
 )
 from ocf_data_sampler.utils import minutes
-from ocf_data_sampler.torch_datasets.process_and_combine import process_and_combine_datasets, compute
+from ocf_data_sampler.torch_datasets.process_and_combine import process_and_combine_site_sample_dict
 from ocf_data_sampler.torch_datasets.valid_time_periods import find_valid_time_periods
 
 xr.set_options(keep_attrs=True)
@@ -152,10 +152,9 @@ class SitesDataset(Dataset):
         """
         sample_dict = slice_datasets_by_space(self.datasets_dict, location, self.config)
         sample_dict = slice_datasets_by_time(sample_dict, t0, self.config)
-        sample_dict = compute(sample_dict)
-
-        sample = process_and_combine_datasets(sample_dict, self.config, t0, location, target_key='site')
 
+        sample = process_and_combine_site_sample_dict(sample_dict, self.config)
+        sample = sample.compute()
         return sample
 
     def get_location_from_site_id(self, site_id):
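With this change a `SitesDataset` sample is an `xr.Dataset` rather than a NumpyBatch dict, and the sliced data stays lazy until the final `.compute()`. A hypothetical usage sketch (the config path and constructor call mirror the tests below but are assumptions, not documented API):

    from ocf_data_sampler.torch_datasets import SitesDataset

    dataset = SitesDataset("site_config.yaml")   # hypothetical config file
    sample = dataset[0]                          # xr.Dataset with prefixed dims, e.g. "sites__time_utc"
    generation = sample["sites"]                 # site generation, normalised by capacity_kwp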
ocf_data_sampler-0.0.33.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ocf_data_sampler
-Version: 0.0.31
+Version: 0.0.33
 Summary: Sample from weather data for renewable energy prediction
 Author: James Fulton, Peter Dudfield, and the Open Climate Fix team
 Author-email: info@openclimatefix.org
ocf_data_sampler-0.0.33.dist-info/RECORD

@@ -10,7 +10,7 @@ ocf_data_sampler/load/__init__.py,sha256=MjgfxilTzyz1RYFoBEeAXmE9hyjknLvdmlHPmlA
 ocf_data_sampler/load/gsp.py,sha256=Gcr1JVUOPKhFRDCSHtfPDjxx0BtyyEhXrZvGEKLPJ5I,759
 ocf_data_sampler/load/load_dataset.py,sha256=Ua3RaUg4PIYJkD9BKqTfN8IWUbezbhThJGgEkd9PcaE,1587
 ocf_data_sampler/load/satellite.py,sha256=3KlA1fx4SwxdzM-jC1WRaONXO0D6m0WxORnEnwUnZrA,2967
-ocf_data_sampler/load/site.py,sha256=ROif2XXIIgBz-JOOiHymTq1CMXswJ3AzENU9DJmYpcU,782
+ocf_data_sampler/load/site.py,sha256=P83uz01WBDzoZajdOH0m8FQt4-buKDlUk19N548KqhA,1086
 ocf_data_sampler/load/utils.py,sha256=EQGvVWlGMoSOdbDYuMfVAa0v6wmAOPmHIAemdrTB5v4,1406
 ocf_data_sampler/load/nwp/__init__.py,sha256=SmcrnbygO5xtCKmGR4wtHrj-HI7nOAvnAtfuvRufBGQ,25
 ocf_data_sampler/load/nwp/nwp.py,sha256=O4QnajEZem8BvBgTcYYDBhRhgqPYuJkolHmpMRmrXEA,610
@@ -34,10 +34,10 @@ ocf_data_sampler/select/select_spatial_slice.py,sha256=WNxwur9Q5oetvogATw8-hNejD
 ocf_data_sampler/select/select_time_slice.py,sha256=D5P_cSvnv8Qs49K5au7lPxDr9U_VmDn42s5leMzHt0k,6122
 ocf_data_sampler/select/spatial_slice_for_dataset.py,sha256=3tRrMBXr7s4CnClbVSIq7hpls3H4Y3qYTDwswcxCCCE,1763
 ocf_data_sampler/select/time_slice_for_dataset.py,sha256=LMw8KnOCKnPjD0m4UubAWERpaiQtzRKkI2cSh5a0A-M,4335
-ocf_data_sampler/torch_datasets/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
-ocf_data_sampler/torch_datasets/process_and_combine.py,sha256=4k6f6PlMqrg3luMwGw3764iOyfuUNUePKyoikYGaRMI,4953
+ocf_data_sampler/torch_datasets/__init__.py,sha256=nJUa2KzVa84ZoM0PT2AbDz26ennmAYc7M7WJVfypPMs,85
+ocf_data_sampler/torch_datasets/process_and_combine.py,sha256=WwwuzxXoq8S70R-tWABXUMO854TG8GWYnNhb1IU8MRY,7526
 ocf_data_sampler/torch_datasets/pvnet_uk_regional.py,sha256=QRFqbdfNchVWj4y70n-rJdFvFGvQj-WpZLdFqWjnOTw,5543
-ocf_data_sampler/torch_datasets/site.py,sha256=lo2ULurfWNu9vzBC6H4pdKMMpUMIT8_FWC1l_1mgIOM,6596
+ocf_data_sampler/torch_datasets/site.py,sha256=NYuhgm9ti9SRt1dcb_WrFYYo14NgVdOsaoPbc5FsnaA,6560
 ocf_data_sampler/torch_datasets/valid_time_periods.py,sha256=Qo65qUHtle_bW5tLTYr7empHTRv-lpjvfx_6GNJj3Xg,4371
 scripts/refactor_site.py,sha256=asZ27hQ4IyXgCCUaFJqcz1ObBNcV2W3ywqHBpSXA_fc,1728
 tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -56,11 +56,11 @@ tests/select/test_fill_time_periods.py,sha256=o59f2YRe5b0vJrG3B0aYZkYeHnpNk4s6EJ
 tests/select/test_find_contiguous_time_periods.py,sha256=kOga_V7er5We7ewMARXaKdM3agOhsvZYx8inXtUn1PM,5976
 tests/select/test_location.py,sha256=_WZk2FPYeJ-nIfCJS6Sp_yaVEEo7m31DmMFoZzgyCts,2712
 tests/select/test_select_spatial_slice.py,sha256=7EX9b6g-pMdACQx3yefjs5do2s-Rho2UmKevV4oglsU,5147
-tests/select/test_select_time_slice.py,sha256=QOhoR3qsr7RBGze4yohcViZ-ad1zYQzIKzxlnf0ymnU,9603
-tests/torch_datasets/test_pvnet_uk_regional.py,sha256=8gxjJO8FhY-ImX6eGnihDFsa8fhU2Zb4bVJaToJwuwo,2653
-tests/torch_datasets/test_site.py,sha256=yTv6tAT6lha5yLYJiC8DNms1dct8o_ObPV97dHZyT7I,2719
-ocf_data_sampler-0.0.31.dist-info/LICENSE,sha256=F-Q3UFCR-BECSocV55BFDpn4YKxve9PKrm-lTt6o_Tg,1073
-ocf_data_sampler-0.0.31.dist-info/METADATA,sha256=rFtc-y0PkztBWSGazWfr7WsPRo7SdnccosWrlaLJmk8,9559
-ocf_data_sampler-0.0.31.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
-ocf_data_sampler-0.0.31.dist-info/top_level.txt,sha256=Faob6N6cFdPc5eUpCTYcXgCaNhi4XLLteUL5W5ayYmg,31
-ocf_data_sampler-0.0.31.dist-info/RECORD,,
+tests/select/test_select_time_slice.py,sha256=K1EJR5TwZa9dJf_YTEHxGtvs398iy1xS2lr1BgJZkoo,9603
+tests/torch_datasets/test_pvnet_uk_regional.py,sha256=ZNyrisyhM1vw4q2qcHDuvX2uRi-v3U8Y8lOYx7cd8yM,2635
+tests/torch_datasets/test_site.py,sha256=YuVjWTI14_kmEOx23XE5J_RZ8UalCKD2xRv6mqYizB8,2872
+ocf_data_sampler-0.0.33.dist-info/LICENSE,sha256=F-Q3UFCR-BECSocV55BFDpn4YKxve9PKrm-lTt6o_Tg,1073
+ocf_data_sampler-0.0.33.dist-info/METADATA,sha256=HWhJVLbJkMM9t5i_GTDbgfit1essTQh_02sOOOgolAY,9559
+ocf_data_sampler-0.0.33.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
+ocf_data_sampler-0.0.33.dist-info/top_level.txt,sha256=Faob6N6cFdPc5eUpCTYcXgCaNhi4XLLteUL5W5ayYmg,31
+ocf_data_sampler-0.0.33.dist-info/RECORD,,
tests/select/test_select_time_slice.py

@@ -197,7 +197,7 @@ def test_select_time_slice_nwp_with_dropout_and_accum(da_nwp_like, t0_str):
     t0 = pd.Timestamp(f"2024-01-02 {t0_str}")
     interval_start = pd.Timedelta(-6, "h")
     interval_end = pd.Timedelta(3, "h")
-    freq = pd.Timedelta("1H")
+    freq = pd.Timedelta("1h")
     dropout_timedelta = pd.Timedelta("-2h")
 
     t0_delayed = (t0 + dropout_timedelta).floor(NWP_FREQ)
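This one-character fix tracks pandas' switch to lowercase offset aliases: pandas 2.2 deprecated the uppercase "H" alias in favour of "h", so the old spelling emits a FutureWarning on newer pandas. For example:

    import pandas as pd

    freq = pd.Timedelta("1h")  # preferred; "1H" is deprecated from pandas 2.2
    rng = pd.date_range("2024-01-02", periods=4, freq="1h")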
tests/torch_datasets/test_pvnet_uk_regional.py

@@ -1,7 +1,7 @@
 import pytest
 import tempfile
 
-from ocf_data_sampler.torch_datasets.pvnet_uk_regional import PVNetUKRegionalDataset
+from ocf_data_sampler.torch_datasets import PVNetUKRegionalDataset
 from ocf_data_sampler.config import load_yaml_configuration, save_yaml_configuration
 from ocf_data_sampler.numpy_batch import NWPBatchKey, GSPBatchKey, SatelliteBatchKey
 
tests/torch_datasets/test_site.py

@@ -1,11 +1,12 @@
 import pandas as pd
 import pytest
 
-from ocf_data_sampler.torch_datasets.site import SitesDataset
+from ocf_data_sampler.torch_datasets import SitesDataset
 from ocf_data_sampler.config import load_yaml_configuration, save_yaml_configuration
 from ocf_data_sampler.numpy_batch.nwp import NWPBatchKey
 from ocf_data_sampler.numpy_batch.site import SiteBatchKey
 from ocf_data_sampler.numpy_batch.satellite import SatelliteBatchKey
+from xarray import Dataset
 
 
 @pytest.fixture()
@@ -34,31 +35,26 @@ def test_site(site_config_filename):
     # Generate a sample
     sample = dataset[0]
 
-    assert isinstance(sample, dict)
+    assert isinstance(sample, Dataset)
 
-    for key in [
-        NWPBatchKey.nwp,
-        SatelliteBatchKey.satellite_actual,
-        SiteBatchKey.generation,
-        SiteBatchKey.site_solar_azimuth,
-        SiteBatchKey.site_solar_elevation,
-    ]:
-        assert key in sample
+    # Expected dimensions and data variables
+    expected_dims = {'satellite__x_geostationary', 'sites__time_utc', 'nwp-ukv__target_time_utc',
+                     'nwp-ukv__x_osgb', 'satellite__channel', 'satellite__y_geostationary',
+                     'satellite__time_utc', 'nwp-ukv__channel', 'nwp-ukv__y_osgb'}
+    expected_data_vars = {"nwp-ukv", "satellite", "sites"}
 
-    for nwp_source in ["ukv"]:
-        assert nwp_source in sample[NWPBatchKey.nwp]
+    # Check dimensions
+    assert set(sample.dims) == expected_dims, f"Missing or extra dimensions: {set(sample.dims) ^ expected_dims}"
+    # Check data variables
+    assert set(sample.data_vars) == expected_data_vars, f"Missing or extra data variables: {set(sample.data_vars) ^ expected_data_vars}"
 
     # check the shape of the data is correct
     # 30 minutes of 5 minute data (inclusive), one channel, 2x2 pixels
-    assert sample[SatelliteBatchKey.satellite_actual].shape == (7, 1, 2, 2)
+    assert sample["satellite"].values.shape == (7, 1, 2, 2)
     # 3 hours of 60 minute data (inclusive), one channel, 2x2 pixels
-    assert sample[NWPBatchKey.nwp]["ukv"][NWPBatchKey.nwp].shape == (4, 1, 2, 2)
-    # 3 hours of 30 minute data (inclusive)
-    assert sample[SiteBatchKey.generation].shape == (4,)
-    # Solar angles have same shape as GSP data
-    assert sample[SiteBatchKey.site_solar_azimuth].shape == (4,)
-    assert sample[SiteBatchKey.site_solar_elevation].shape == (4,)
-
+    assert sample["nwp-ukv"].values.shape == (4, 1, 2, 2)
+    # 1.5 hours of 30 minute data (inclusive)
+    assert sample["sites"].values.shape == (4,)
 
 def test_site_time_filter_start(site_config_filename):