ocf-data-sampler 0.0.37__py3-none-any.whl → 0.0.38__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ocf-data-sampler might be problematic. See the package's page on its registry for more details.

@@ -28,6 +28,7 @@ class NWPStatDict(dict):
28
28
  f"Values for {key} not yet available in ocf-data-sampler {list(self.keys())}"
29
29
  )
30
30
 
31
+
31
32
  # ------ UKV
32
33
  # Means and std computed WITH version_7 and higher, MetOffice values
33
34
  UKV_STD = {
@@ -49,6 +50,7 @@ UKV_STD = {
49
50
  "prmsl": 1252.71790539,
50
51
  "prate": 0.00021497,
51
52
  }
53
+
52
54
  UKV_MEAN = {
53
55
  "cdcb": 1412.26599062,
54
56
  "lcc": 50.08362643,
@@ -97,6 +99,7 @@ ECMWF_STD = {
97
99
  "diff_duvrs": 81605.25,
98
100
  "diff_sr": 818950.6875,
99
101
  }
102
+
100
103
  ECMWF_MEAN = {
101
104
  "dlwrf": 27187026.0,
102
105
  "dswrf": 11458988.0,
@@ -133,3 +136,38 @@ NWP_MEANS = NWPStatDict(
133
136
  ecmwf=ECMWF_MEAN,
134
137
  )
135
138
 
139
+ # ------ Satellite
140
+ # RSS Mean and std values from randomised 20% of 2020 imagery
141
+
142
+ RSS_STD = {
143
+ "HRV": 0.11405209,
144
+ "IR_016": 0.21462157,
145
+ "IR_039": 0.04618041,
146
+ "IR_087": 0.06687243,
147
+ "IR_097": 0.0468558,
148
+ "IR_108": 0.17482725,
149
+ "IR_120": 0.06115861,
150
+ "IR_134": 0.04492306,
151
+ "VIS006": 0.12184761,
152
+ "VIS008": 0.13090034,
153
+ "WV_062": 0.16111417,
154
+ "WV_073": 0.12924142,
155
+ }
156
+
157
+ RSS_MEAN = {
158
+ "HRV": 0.09298719,
159
+ "IR_016": 0.17594202,
160
+ "IR_039": 0.86167645,
161
+ "IR_087": 0.7719318,
162
+ "IR_097": 0.8014212,
163
+ "IR_108": 0.71254843,
164
+ "IR_120": 0.89058584,
165
+ "IR_134": 0.944365,
166
+ "VIS006": 0.09633306,
167
+ "VIS008": 0.11426069,
168
+ "WV_062": 0.7359355,
169
+ "WV_073": 0.62479186,
170
+ }
171
+
172
+ RSS_STD = _to_data_array(RSS_STD)
173
+ RSS_MEAN = _to_data_array(RSS_MEAN)
@@ -1,5 +1,4 @@
1
1
  """Convert NWP to NumpyBatch"""
2
-
3
2
  import pandas as pd
4
3
  import xarray as xr
5
4
 
@@ -13,6 +13,7 @@ class SatelliteBatchKey:
13
13
 
14
14
  def convert_satellite_to_numpy_batch(da: xr.DataArray, t0_idx: int | None = None) -> dict:
15
15
  """Convert from Xarray to NumpyBatch"""
16
+
16
17
  example = {
17
18
  SatelliteBatchKey.satellite_actual: da.values,
18
19
  SatelliteBatchKey.time_utc: da.time_utc.values.astype(float),
@@ -27,4 +28,4 @@ def convert_satellite_to_numpy_batch(da: xr.DataArray, t0_idx: int | None = None
27
28
  if t0_idx is not None:
28
29
  example[SatelliteBatchKey.t0_idx] = t0_idx
29
30
 
30
- return example
31
+ return example
@@ -4,7 +4,7 @@ import xarray as xr
4
4
  from typing import Tuple
5
5
 
6
6
  from ocf_data_sampler.config import Configuration
7
- from ocf_data_sampler.constants import NWP_MEANS, NWP_STDS
7
+ from ocf_data_sampler.constants import NWP_MEANS, NWP_STDS, RSS_MEAN, RSS_STD
8
8
  from ocf_data_sampler.numpy_batch import (
9
9
  convert_nwp_to_numpy_batch,
10
10
  convert_satellite_to_numpy_batch,
@@ -25,8 +25,8 @@ def process_and_combine_datasets(
25
25
  location: Location,
26
26
  target_key: str = 'gsp'
27
27
  ) -> dict:
28
- """Normalize and convert data to numpy arrays"""
29
28
 
29
+ """Normalise and convert data to numpy arrays"""
30
30
  numpy_modalities = []
31
31
 
32
32
  if "nwp" in dataset_dict:
@@ -37,19 +37,23 @@ def process_and_combine_datasets(
37
37
  # Standardise
38
38
  provider = config.input_data.nwp[nwp_key].provider
39
39
  da_nwp = (da_nwp - NWP_MEANS[provider]) / NWP_STDS[provider]
40
+
40
41
  # Convert to NumpyBatch
41
42
  nwp_numpy_modalities[nwp_key] = convert_nwp_to_numpy_batch(da_nwp)
42
43
 
43
44
  # Combine the NWPs into NumpyBatch
44
45
  numpy_modalities.append({NWPBatchKey.nwp: nwp_numpy_modalities})
45
46
 
47
+
46
48
  if "sat" in dataset_dict:
47
- # Satellite is already in the range [0-1] so no need to standardise
49
+ # Standardise
48
50
  da_sat = dataset_dict["sat"]
51
+ da_sat = (da_sat - RSS_MEAN) / RSS_STD
49
52
 
50
53
  # Convert to NumpyBatch
51
54
  numpy_modalities.append(convert_satellite_to_numpy_batch(da_sat))
52
55
 
56
+
53
57
  gsp_config = config.input_data.gsp
54
58
 
55
59
  if "gsp" in dataset_dict:
@@ -93,6 +97,7 @@ def process_and_combine_datasets(
93
97
 
94
98
  return combined_sample
95
99
 
100
+
96
101
  def process_and_combine_site_sample_dict(
97
102
  dataset_dict: dict,
98
103
  config: Configuration,
@@ -119,8 +124,9 @@ def process_and_combine_site_sample_dict(
119
124
  data_arrays.append((f"nwp-{provider}", da_nwp))
120
125
 
121
126
  if "sat" in dataset_dict:
122
- # TODO add some satellite normalisation
127
+ # Standardise
123
128
  da_sat = dataset_dict["sat"]
129
+ da_sat = (da_sat - RSS_MEAN) / RSS_STD
124
130
  data_arrays.append(("satellite", da_sat))
125
131
 
126
132
  if "site" in dataset_dict:
@@ -143,6 +149,7 @@ def merge_dicts(list_of_dicts: list[dict]) -> dict:
143
149
  combined_dict.update(d)
144
150
  return combined_dict
145
151
 
152
+
146
153
  def merge_arrays(normalised_data_arrays: list[Tuple[str, xr.DataArray]]) -> xr.Dataset:
147
154
  """
148
155
  Combine a list of DataArrays into a single Dataset with unique naming conventions.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ocf_data_sampler
3
- Version: 0.0.37
3
+ Version: 0.0.38
4
4
  Summary: Sample from weather data for renewable energy prediction
5
5
  Author: James Fulton, Peter Dudfield, and the Open Climate Fix team
6
6
  Author-email: info@openclimatefix.org
@@ -1,5 +1,5 @@
1
1
  ocf_data_sampler/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
2
- ocf_data_sampler/constants.py,sha256=tUwHrsGShqIn5Izze4i32_xB6X0v67rvQwIYB-P5PJQ,3355
2
+ ocf_data_sampler/constants.py,sha256=G2VfkE_-veq_0hNBQQOQCtCsfC37O5-QG9mJWEmln5s,4153
3
3
  ocf_data_sampler/utils.py,sha256=rKA0BHAyAG4f90zEcgxp25EEYrXS-aOVNzttZ6Mzv2k,250
4
4
  ocf_data_sampler/config/__init__.py,sha256=YXnAkgHViHB26hSsjiv32b6EbpG-A1kKTkARJf0_RkY,212
5
5
  ocf_data_sampler/config/load.py,sha256=4f7vPHAIAmd-55tPxoIzn7F_TI_ue4NxkDcLPoVWl0g,943
@@ -21,8 +21,8 @@ ocf_data_sampler/load/nwp/providers/utils.py,sha256=Sy2exG1wpXLLhMXYdsfR-DZMR3tx
21
21
  ocf_data_sampler/numpy_batch/__init__.py,sha256=8MgRF29rK9bKP4b4iHakaoGwBKUcjWZ-VFKjCcq53QA,336
22
22
  ocf_data_sampler/numpy_batch/collate.py,sha256=KyWdDi8AXD5YiokXXiqr2_X1SC1me1GrhnQMelg0Qx8,2202
23
23
  ocf_data_sampler/numpy_batch/gsp.py,sha256=QjQ25JmtufvdiSsxUkBTPhxouYGWPnnWze8pXr_aBno,960
24
- ocf_data_sampler/numpy_batch/nwp.py,sha256=dAehfRo5DL2Yb20ifHHl5cU1QOrm3ZOpQmN39fSUOw8,1255
25
- ocf_data_sampler/numpy_batch/satellite.py,sha256=3NoE_ElzMHwO60apqJeFAwI6J7eIxD0OWTyAVl-uJi8,903
24
+ ocf_data_sampler/numpy_batch/nwp.py,sha256=bEvBB9xGf7B8okPBZ-eZLK4PBWA0nvmmEFiN49dgqPU,1254
25
+ ocf_data_sampler/numpy_batch/satellite.py,sha256=VKo8eiSIcYhAdHHBUH697HMz7rBv6S9XZ6_XCZ-qG4Y,905
26
26
  ocf_data_sampler/numpy_batch/site.py,sha256=CWI0efUl8SrnGm0VNGdGwAqrmlT1XaVbJIUE2hSOz9E,744
27
27
  ocf_data_sampler/numpy_batch/sun_position.py,sha256=zw2bjtcjsm_tvKk0r_MZmgfYUJLHuLjLly2sMjwP3XI,1606
28
28
  ocf_data_sampler/select/__init__.py,sha256=E4AJulEbO2K-o0UlG1fgaEteuf_1ZFjHTvrotXSb4YU,332
@@ -36,7 +36,7 @@ ocf_data_sampler/select/select_time_slice.py,sha256=D5P_cSvnv8Qs49K5au7lPxDr9U_V
36
36
  ocf_data_sampler/select/spatial_slice_for_dataset.py,sha256=3tRrMBXr7s4CnClbVSIq7hpls3H4Y3qYTDwswcxCCCE,1763
37
37
  ocf_data_sampler/select/time_slice_for_dataset.py,sha256=LMw8KnOCKnPjD0m4UubAWERpaiQtzRKkI2cSh5a0A-M,4335
38
38
  ocf_data_sampler/torch_datasets/__init__.py,sha256=nJUa2KzVa84ZoM0PT2AbDz26ennmAYc7M7WJVfypPMs,85
39
- ocf_data_sampler/torch_datasets/process_and_combine.py,sha256=WwwuzxXoq8S70R-tWABXUMO854TG8GWYnNhb1IU8MRY,7526
39
+ ocf_data_sampler/torch_datasets/process_and_combine.py,sha256=ImfU4I75x7A57KCShWj6dr62tNtJqJ0ImKRiT0hijIQ,7564
40
40
  ocf_data_sampler/torch_datasets/pvnet_uk_regional.py,sha256=QRFqbdfNchVWj4y70n-rJdFvFGvQj-WpZLdFqWjnOTw,5543
41
41
  ocf_data_sampler/torch_datasets/site.py,sha256=NYuhgm9ti9SRt1dcb_WrFYYo14NgVdOsaoPbc5FsnaA,6560
42
42
  ocf_data_sampler/torch_datasets/valid_time_periods.py,sha256=Qo65qUHtle_bW5tLTYr7empHTRv-lpjvfx_6GNJj3Xg,4371
@@ -59,10 +59,11 @@ tests/select/test_find_contiguous_time_periods.py,sha256=kOga_V7er5We7ewMARXaKdM
59
59
  tests/select/test_location.py,sha256=_WZk2FPYeJ-nIfCJS6Sp_yaVEEo7m31DmMFoZzgyCts,2712
60
60
  tests/select/test_select_spatial_slice.py,sha256=7EX9b6g-pMdACQx3yefjs5do2s-Rho2UmKevV4oglsU,5147
61
61
  tests/select/test_select_time_slice.py,sha256=K1EJR5TwZa9dJf_YTEHxGtvs398iy1xS2lr1BgJZkoo,9603
62
+ tests/torch_datasets/test_process_and_combine.py,sha256=SWmrI59JVfMnHK78N5yhKzQR8b5kJ8TeMZke9Mlnc-o,5717
62
63
  tests/torch_datasets/test_pvnet_uk_regional.py,sha256=eqy0nQOWoHnqltlJlGmRlgIiIzPEwOC6o5A6GARryKA,2118
63
64
  tests/torch_datasets/test_site.py,sha256=YuVjWTI14_kmEOx23XE5J_RZ8UalCKD2xRv6mqYizB8,2872
64
- ocf_data_sampler-0.0.37.dist-info/LICENSE,sha256=F-Q3UFCR-BECSocV55BFDpn4YKxve9PKrm-lTt6o_Tg,1073
65
- ocf_data_sampler-0.0.37.dist-info/METADATA,sha256=tKixIA37U0AA76QsYmCIfLzpzE2aSGRmquSx69jX4aY,10290
66
- ocf_data_sampler-0.0.37.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
67
- ocf_data_sampler-0.0.37.dist-info/top_level.txt,sha256=Faob6N6cFdPc5eUpCTYcXgCaNhi4XLLteUL5W5ayYmg,31
68
- ocf_data_sampler-0.0.37.dist-info/RECORD,,
65
+ ocf_data_sampler-0.0.38.dist-info/LICENSE,sha256=F-Q3UFCR-BECSocV55BFDpn4YKxve9PKrm-lTt6o_Tg,1073
66
+ ocf_data_sampler-0.0.38.dist-info/METADATA,sha256=YbU2ymHq94ZLsyjlD1ZdKoYpVVDzUUmyWN7xRDBvQDM,10290
67
+ ocf_data_sampler-0.0.38.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
68
+ ocf_data_sampler-0.0.38.dist-info/top_level.txt,sha256=Faob6N6cFdPc5eUpCTYcXgCaNhi4XLLteUL5W5ayYmg,31
69
+ ocf_data_sampler-0.0.38.dist-info/RECORD,,
@@ -0,0 +1,165 @@
1
+ import pytest
2
+ import tempfile
3
+
4
+ import numpy as np
5
+ import pandas as pd
6
+ import xarray as xr
7
+ import dask.array as da
8
+
9
+ from ocf_data_sampler.config import load_yaml_configuration, save_yaml_configuration
10
+ from ocf_data_sampler.config import Configuration
11
+ from ocf_data_sampler.select.location import Location
12
+ from ocf_data_sampler.numpy_batch import NWPBatchKey, GSPBatchKey, SatelliteBatchKey
13
+ from ocf_data_sampler.torch_datasets import PVNetUKRegionalDataset
14
+
15
+ from ocf_data_sampler.torch_datasets.process_and_combine import (
16
+ process_and_combine_datasets,
17
+ process_and_combine_site_sample_dict,
18
+ merge_dicts,
19
+ fill_nans_in_arrays,
20
+ compute,
21
+ )
22
+
23
+
24
+ def test_process_and_combine_datasets(pvnet_config_filename):
25
+
26
+ # Load in config for function and define location
27
+ config = load_yaml_configuration(pvnet_config_filename)
28
+ t0 = pd.Timestamp("2024-01-01 00:00")
29
+ location = Location(coordinate_system="osgb", x=1234, y=5678, id=1)
30
+
31
+ nwp_data = xr.DataArray(
32
+ np.random.rand(4, 2, 2, 2),
33
+ dims=["time_utc", "channel", "y", "x"],
34
+ coords={
35
+ "time_utc": pd.date_range("2024-01-01 00:00", periods=4, freq="h"),
36
+ "channel": ["t2m", "dswrf"],
37
+ "step": ("time_utc", pd.timedelta_range(start='0h', periods=4, freq='h')),
38
+ "init_time_utc": pd.Timestamp("2024-01-01 00:00")
39
+ }
40
+ )
41
+
42
+ sat_data = xr.DataArray(
43
+ np.random.rand(7, 1, 2, 2),
44
+ dims=["time_utc", "channel", "y", "x"],
45
+ coords={
46
+ "time_utc": pd.date_range("2024-01-01 00:00", periods=7, freq="5min"),
47
+ "channel": ["HRV"],
48
+ "x_geostationary": (["y", "x"], np.array([[1, 2], [1, 2]])),
49
+ "y_geostationary": (["y", "x"], np.array([[1, 1], [2, 2]]))
50
+ }
51
+ )
52
+
53
+ # Combine as dict
54
+ dataset_dict = {
55
+ "nwp": {"ukv": nwp_data},
56
+ "sat": sat_data
57
+ }
58
+
59
+ # Call relevant function
60
+ result = process_and_combine_datasets(dataset_dict, config, t0, location)
61
+
62
+ # Assert result is dict - check and validate
63
+ assert isinstance(result, dict)
64
+ assert NWPBatchKey.nwp in result
65
+ assert result[SatelliteBatchKey.satellite_actual].shape == (7, 1, 2, 2)
66
+ assert result[NWPBatchKey.nwp]["ukv"][NWPBatchKey.nwp].shape == (4, 1, 2, 2)
67
+
68
+
69
+ def test_merge_dicts():
70
+ """Test merge_dicts function"""
71
+ dict1 = {"a": 1, "b": 2}
72
+ dict2 = {"c": 3, "d": 4}
73
+ dict3 = {"e": 5}
74
+
75
+ result = merge_dicts([dict1, dict2, dict3])
76
+ assert result == {"a": 1, "b": 2, "c": 3, "d": 4, "e": 5}
77
+
78
+ # Test key overwriting
79
+ dict4 = {"a": 10, "f": 6}
80
+ result = merge_dicts([dict1, dict4])
81
+ assert result["a"] == 10
82
+
83
+
84
+ def test_fill_nans_in_arrays():
85
+ """Test the fill_nans_in_arrays function"""
86
+ array_with_nans = np.array([1.0, np.nan, 3.0, np.nan])
87
+ nested_dict = {
88
+ "array1": array_with_nans,
89
+ "nested": {
90
+ "array2": np.array([np.nan, 2.0, np.nan, 4.0])
91
+ },
92
+ "string_key": "not_an_array"
93
+ }
94
+
95
+ result = fill_nans_in_arrays(nested_dict)
96
+
97
+ assert not np.isnan(result["array1"]).any()
98
+ assert np.array_equal(result["array1"], np.array([1.0, 0.0, 3.0, 0.0]))
99
+ assert not np.isnan(result["nested"]["array2"]).any()
100
+ assert np.array_equal(result["nested"]["array2"], np.array([0.0, 2.0, 0.0, 4.0]))
101
+ assert result["string_key"] == "not_an_array"
102
+
103
+
104
+ def test_compute():
105
+ """Test compute function with dask array"""
106
+ da_dask = xr.DataArray(da.random.random((5, 5)))
107
+
108
+ # Create a nested dictionary with dask array
109
+ nested_dict = {
110
+ "array1": da_dask,
111
+ "nested": {
112
+ "array2": da_dask
113
+ }
114
+ }
115
+
116
+ # Ensure initial data is lazy - i.e. not yet computed
117
+ assert not isinstance(nested_dict["array1"].data, np.ndarray)
118
+ assert not isinstance(nested_dict["nested"]["array2"].data, np.ndarray)
119
+
120
+ # Call the compute function
121
+ result = compute(nested_dict)
122
+
123
+ # Assert that the result is an xarray DataArray and no longer lazy
124
+ assert isinstance(result["array1"], xr.DataArray)
125
+ assert isinstance(result["nested"]["array2"], xr.DataArray)
126
+ assert isinstance(result["array1"].data, np.ndarray)
127
+ assert isinstance(result["nested"]["array2"].data, np.ndarray)
128
+
129
+ # Ensure there no NaN values in computed data
130
+ assert not np.isnan(result["array1"].data).any()
131
+ assert not np.isnan(result["nested"]["array2"].data).any()
132
+
133
+
134
+ def test_process_and_combine_site_sample_dict(pvnet_config_filename):
135
+ # Load config
136
+ config = load_yaml_configuration(pvnet_config_filename)
137
+
138
+ # Specify minimal structure for testing
139
+ raw_nwp_values = np.random.rand(4, 1, 2, 2) # Single channel
140
+ site_dict = {
141
+ "nwp": {
142
+ "ukv": xr.DataArray(
143
+ raw_nwp_values,
144
+ dims=["time_utc", "channel", "y", "x"],
145
+ coords={
146
+ "time_utc": pd.date_range("2024-01-01 00:00", periods=4, freq="h"),
147
+ "channel": ["dswrf"], # Single channel
148
+ },
149
+ )
150
+ }
151
+ }
152
+ print(f"Input site_dict: {site_dict}")
153
+
154
+ # Call function
155
+ result = process_and_combine_site_sample_dict(site_dict, config)
156
+
157
+ # Assert to validate output structure
158
+ assert isinstance(result, xr.Dataset), "Result should be an xarray.Dataset"
159
+ assert len(result.data_vars) > 0, "Dataset should contain data variables"
160
+
161
+ # Validate variable via assertion and shape of such
162
+ expected_variable = "nwp-ukv"
163
+ assert expected_variable in result.data_vars, f"Expected variable '{expected_variable}' not found"
164
+ nwp_result = result[expected_variable]
165
+ assert nwp_result.shape == (4, 1, 2, 2), f"Unexpected shape for '{expected_variable}': {nwp_result.shape}"