ocf-data-sampler 0.0.24__py3-none-any.whl → 0.0.25__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ocf-data-sampler might be problematic. Click here for more details.

@@ -0,0 +1,108 @@
1
+ import numpy as np
2
+ import pandas as pd
3
+
4
+ from ocf_data_sampler.config import Configuration
5
+ from ocf_data_sampler.select.find_contiguous_time_periods import find_contiguous_t0_periods_nwp, \
6
+ find_contiguous_t0_periods, intersection_of_multiple_dataframes_of_periods
7
+ from ocf_data_sampler.time_functions import minutes
8
+
9
+
10
def find_valid_time_periods(
    datasets_dict: dict,
    config: Configuration,
):
    """Find the t0 times where all of the requested input data is available

    Args:
        datasets_dict: A dictionary of input datasets (keys from {"nwp", "sat", "gsp"})
        config: Configuration file

    Returns:
        DataFrame of contiguous time periods common to all requested inputs

    Raises:
        ValueError: If no input sources are given, or no overlapping time
            periods are found between the given sources
    """

    assert set(datasets_dict.keys()).issubset({"nwp", "sat", "gsp"})

    # Used to store contiguous time periods from each data source
    # Fix: annotation was `dict[str: pd.DataFrame]` (a slice), not a valid generic
    contiguous_time_periods: dict[str, pd.DataFrame] = {}

    if "nwp" in datasets_dict:
        for nwp_key, nwp_config in config.input_data.nwp.items():

            da = datasets_dict["nwp"][nwp_key]

            if nwp_config.dropout_timedeltas_minutes is None:
                max_dropout = minutes(0)
            else:
                max_dropout = minutes(np.max(np.abs(nwp_config.dropout_timedeltas_minutes)))

            if nwp_config.max_staleness_minutes is None:
                max_staleness = None
            else:
                max_staleness = minutes(nwp_config.max_staleness_minutes)

            # The last step of the forecast is lost if we have to diff channels
            if len(nwp_config.nwp_accum_channels) > 0:
                end_buffer = minutes(nwp_config.time_resolution_minutes)
            else:
                end_buffer = minutes(0)

            # This is the max staleness we can use considering the max step of the input data
            max_possible_staleness = (
                pd.Timedelta(da["step"].max().item())
                - minutes(nwp_config.forecast_minutes)
                - end_buffer
            )

            # Default to use max possible staleness unless specified in config
            if max_staleness is None:
                max_staleness = max_possible_staleness
            else:
                # Make sure the max acceptable staleness isn't longer than the max possible
                assert max_staleness <= max_possible_staleness

            time_periods = find_contiguous_t0_periods_nwp(
                datetimes=pd.DatetimeIndex(da["init_time_utc"]),
                history_duration=minutes(nwp_config.history_minutes),
                max_staleness=max_staleness,
                max_dropout=max_dropout,
            )

            contiguous_time_periods[f'nwp_{nwp_key}'] = time_periods

    if "sat" in datasets_dict:
        sat_config = config.input_data.satellite

        contiguous_time_periods['sat'] = find_contiguous_t0_periods(
            pd.DatetimeIndex(datasets_dict["sat"]["time_utc"]),
            sample_period_duration=minutes(sat_config.time_resolution_minutes),
            history_duration=minutes(sat_config.history_minutes),
            forecast_duration=minutes(sat_config.forecast_minutes),
        )

    if "gsp" in datasets_dict:
        gsp_config = config.input_data.gsp

        contiguous_time_periods['gsp'] = find_contiguous_t0_periods(
            pd.DatetimeIndex(datasets_dict["gsp"]["time_utc"]),
            sample_period_duration=minutes(gsp_config.time_resolution_minutes),
            history_duration=minutes(gsp_config.history_minutes),
            forecast_duration=minutes(gsp_config.forecast_minutes),
        )

    # just get the values (not the keys)
    contiguous_time_periods_values = list(contiguous_time_periods.values())

    # Fix: an empty datasets_dict previously fell through to the
    # `[0]` index below and raised an unhelpful IndexError
    if len(contiguous_time_periods_values) == 0:
        raise ValueError("No input data sources found to compute valid time periods from")

    # Find joint overlapping contiguous time periods
    if len(contiguous_time_periods_values) > 1:
        valid_time_periods = intersection_of_multiple_dataframes_of_periods(
            contiguous_time_periods_values
        )
    else:
        valid_time_periods = contiguous_time_periods_values[0]

    # check there are some valid time periods
    if len(valid_time_periods) == 0:
        raise ValueError(f"No valid time periods found, {contiguous_time_periods=}")

    return valid_time_periods
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ocf_data_sampler
3
- Version: 0.0.24
3
+ Version: 0.0.25
4
4
  Summary: Sample from weather data for renewable energy prediction
5
5
  Author: James Fulton, Peter Dudfield, and the Open Climate Fix team
6
6
  Author-email: info@openclimatefix.org
@@ -1,41 +1,52 @@
1
1
  ocf_data_sampler/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
2
2
  ocf_data_sampler/constants.py,sha256=tUwHrsGShqIn5Izze4i32_xB6X0v67rvQwIYB-P5PJQ,3355
3
+ ocf_data_sampler/time_functions.py,sha256=R6ZlVEe6h4UlJeUW7paZYAMWveOv9MTjMsoISCwnsiE,284
3
4
  ocf_data_sampler/config/__init__.py,sha256=YXnAkgHViHB26hSsjiv32b6EbpG-A1kKTkARJf0_RkY,212
4
5
  ocf_data_sampler/config/load.py,sha256=4f7vPHAIAmd-55tPxoIzn7F_TI_ue4NxkDcLPoVWl0g,943
5
- ocf_data_sampler/config/model.py,sha256=bvU3BEMtcUh-N17fMVLTYtN-J2GcTM9Qq-CI5AfbE4Q,8128
6
+ ocf_data_sampler/config/model.py,sha256=5GO8SF_4iOZhCAyIJyENSl0dnDRIWrURgqwslrVWke8,9462
6
7
  ocf_data_sampler/config/save.py,sha256=wKdctbv0dxIIiQtcRHLRxpWQVhEFQ_FCWg-oNaRLIps,1093
7
8
  ocf_data_sampler/data/uk_gsp_locations.csv,sha256=RSh7DRh55E3n8lVAaWXGTaXXHevZZtI58td4d4DhGos,10415772
8
9
  ocf_data_sampler/load/__init__.py,sha256=MjgfxilTzyz1RYFoBEeAXmE9hyjknLvdmlHPmlAoiQY,44
9
10
  ocf_data_sampler/load/gsp.py,sha256=Gcr1JVUOPKhFRDCSHtfPDjxx0BtyyEhXrZvGEKLPJ5I,759
11
+ ocf_data_sampler/load/load_dataset.py,sha256=R4RAIVLVx6CHA6Qs61kD9sx834I_GMGAn6G7ZgwFMUA,1627
10
12
  ocf_data_sampler/load/satellite.py,sha256=3KlA1fx4SwxdzM-jC1WRaONXO0D6m0WxORnEnwUnZrA,2967
13
+ ocf_data_sampler/load/site.py,sha256=ROif2XXIIgBz-JOOiHymTq1CMXswJ3AzENU9DJmYpcU,782
11
14
  ocf_data_sampler/load/utils.py,sha256=EQGvVWlGMoSOdbDYuMfVAa0v6wmAOPmHIAemdrTB5v4,1406
12
15
  ocf_data_sampler/load/nwp/__init__.py,sha256=SmcrnbygO5xtCKmGR4wtHrj-HI7nOAvnAtfuvRufBGQ,25
13
16
  ocf_data_sampler/load/nwp/nwp.py,sha256=O4QnajEZem8BvBgTcYYDBhRhgqPYuJkolHmpMRmrXEA,610
14
17
  ocf_data_sampler/load/nwp/providers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
15
- ocf_data_sampler/load/nwp/providers/ecmwf.py,sha256=vW-p3vCyQ-CofKo555-gE7VDi5hlpjtjTLfHqWF0HEE,1175
18
+ ocf_data_sampler/load/nwp/providers/ecmwf.py,sha256=2iR1Iy542lo51rC6XFLV-3pbUE68dWjlHa6TVJzx3ac,1280
16
19
  ocf_data_sampler/load/nwp/providers/ukv.py,sha256=79Bm7q-K_GJPYMy62SUIZbRWRF4-tIaB1dYPEgLD9vo,1207
17
20
  ocf_data_sampler/load/nwp/providers/utils.py,sha256=Sy2exG1wpXLLhMXYdsfR-DZMR3txG1_bBmBdchlc-yA,848
18
- ocf_data_sampler/numpy_batch/__init__.py,sha256=M9b7L7kSgoZt2FKENDe-uFI_Qzs-fDecw-5qyrhnTm4,290
21
+ ocf_data_sampler/numpy_batch/__init__.py,sha256=8MgRF29rK9bKP4b4iHakaoGwBKUcjWZ-VFKjCcq53QA,336
19
22
  ocf_data_sampler/numpy_batch/gsp.py,sha256=QjQ25JmtufvdiSsxUkBTPhxouYGWPnnWze8pXr_aBno,960
20
23
  ocf_data_sampler/numpy_batch/nwp.py,sha256=dAehfRo5DL2Yb20ifHHl5cU1QOrm3ZOpQmN39fSUOw8,1255
21
24
  ocf_data_sampler/numpy_batch/satellite.py,sha256=3NoE_ElzMHwO60apqJeFAwI6J7eIxD0OWTyAVl-uJi8,903
25
+ ocf_data_sampler/numpy_batch/site.py,sha256=lJYMEot50UgSBnSOgADQMjUhky1YyWKYqwNsisyYV6w,789
22
26
  ocf_data_sampler/numpy_batch/sun_position.py,sha256=zw2bjtcjsm_tvKk0r_MZmgfYUJLHuLjLly2sMjwP3XI,1606
23
- ocf_data_sampler/select/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
24
- ocf_data_sampler/select/dropout.py,sha256=zDpVLMjGb70RRyYKN-WI2Kp3x9SznstT4cMcZ4dsvJg,1066
27
+ ocf_data_sampler/select/__init__.py,sha256=E4AJulEbO2K-o0UlG1fgaEteuf_1ZFjHTvrotXSb4YU,332
28
+ ocf_data_sampler/select/dropout.py,sha256=HCx5Wzk8Oh2Z9vV94Jy-ALJsHtGduwvMaQOleQXp5z0,1142
25
29
  ocf_data_sampler/select/fill_time_periods.py,sha256=iTtMjIPFYG5xtUYYedAFBLjTWWUa7t7WQ0-yksWf0-E,440
26
30
  ocf_data_sampler/select/find_contiguous_time_periods.py,sha256=6ioB8LeFpFNBMgKDxrgG3zqzNjkBF_jlV9yye2ZYT2E,11925
27
- ocf_data_sampler/select/geospatial.py,sha256=oHJoKEKubn3v3yKCVeuiPxuGroVA4RyrpNi6ARq5woE,3558
31
+ ocf_data_sampler/select/geospatial.py,sha256=4xL-9y674jjoaXeqE52NHCHVfknciE4OEGsZtn9DvP4,4911
28
32
  ocf_data_sampler/select/location.py,sha256=26Y5ZjfFngShBwXieuWSoOA-RLaRzci4TTmcDk3Wg7U,2015
29
- ocf_data_sampler/select/select_spatial_slice.py,sha256=hWIJe4_VzuQ2iiiQh7V17AXwTILT5kIkUvzG458J_Gw,11220
33
+ ocf_data_sampler/select/select_spatial_slice.py,sha256=WNxwur9Q5oetvogATw8-hNejDuEwrXHzuZIovFDjNJA,11488
30
34
  ocf_data_sampler/select/select_time_slice.py,sha256=41cch1fQr59fZgv7UHsNGc3OvoynrixT3bmr3_1d7cU,6628
35
+ ocf_data_sampler/select/spatial_slice_for_dataset.py,sha256=Nrc3j8DR5MM4BPPp9IQwaIMpoyOkc6AADMnfOjg-170,1791
36
+ ocf_data_sampler/select/time_slice_for_dataset.py,sha256=A9fxvurbM0JSRkrjyg5Lr70_Mj6t5OO7HFqHUZel9q4,4220
31
37
  ocf_data_sampler/torch_datasets/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
32
- ocf_data_sampler/torch_datasets/pvnet_uk_regional.py,sha256=Mlz_uyt8c8-uN0uaEiJV2DgF5WAqtWlsINFgA925CZI,19025
38
+ ocf_data_sampler/torch_datasets/process_and_combine.py,sha256=Lovc2UM3-HgUy2BoQEIr0gQTz3USW6ACRWo-iTgxjHs,4993
39
+ ocf_data_sampler/torch_datasets/pvnet_uk_regional.py,sha256=TpHALGU7hpo3iLbvD0nkoY6zu94Vq99W1V1qSGEcIW8,5552
40
+ ocf_data_sampler/torch_datasets/site.py,sha256=1k0fWXYwAAIWG5DX_j3tgNfY8gglfPGLNzNlZd8EnJs,6631
41
+ ocf_data_sampler/torch_datasets/valid_time_periods.py,sha256=dNJkBH5wdsFUjoFSmthU3yTqar6OPE77WsRQUebm-PY,4163
42
+ scripts/refactor_site.py,sha256=asZ27hQ4IyXgCCUaFJqcz1ObBNcV2W3ywqHBpSXA_fc,1728
33
43
  tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
34
- tests/conftest.py,sha256=O77dmow8mGpGPbZ6Pz7ma7cLaiV1k8mxW1eYg37Avrw,5585
44
+ tests/conftest.py,sha256=ZRktySCynj3NBbFRR4EFNLRLFMErkQsC-qQlmQzhbRg,7360
35
45
  tests/config/test_config.py,sha256=G_PD_pXib0zdRBPUIn0jjwJ9VyoKaO_TanLN1Mh5Ca4,5055
36
46
  tests/load/test_load_gsp.py,sha256=aT_nqaSXmUTcdHzuTT7AmXJr3R31k4OEN-Fv3eLxlQE,424
37
47
  tests/load/test_load_nwp.py,sha256=3qyyDkB1q9t3tyAwogfotNrxqUOpXXimco1CImoEWGg,753
38
48
  tests/load/test_load_satellite.py,sha256=STX5AqqmOAgUgE9R1xyq_sM3P1b8NKdGjO-hDhayfxM,524
49
+ tests/load/test_load_sites.py,sha256=T9lSEnGPI8FQISudVYHHNTHeplNS62Vrx48jaZ6J_Jo,364
39
50
  tests/numpy_batch/test_gsp.py,sha256=VANXV32K8aLX4dCdhCUnDorJmyNN-Bjc7Wc1N-RzWEk,548
40
51
  tests/numpy_batch/test_nwp.py,sha256=Fnj7cR-VR2Z0kMu8SrgnIayjxWnPWrYFjWSjMmnrh4Y,1445
41
52
  tests/numpy_batch/test_satellite.py,sha256=8a4ZwMLpsOmYKmwI1oW_su_hwkCNYMEJAEfa0dbsx1k,1179
@@ -47,8 +58,9 @@ tests/select/test_location.py,sha256=_WZk2FPYeJ-nIfCJS6Sp_yaVEEo7m31DmMFoZzgyCts
47
58
  tests/select/test_select_spatial_slice.py,sha256=7EX9b6g-pMdACQx3yefjs5do2s-Rho2UmKevV4oglsU,5147
48
59
  tests/select/test_select_time_slice.py,sha256=XC1J3DBBDnt81jcba5u-Hnd0yKv8GIQErLm-OECV6rs,10147
49
60
  tests/torch_datasets/test_pvnet_uk_regional.py,sha256=u3taw6p3oozM0_7cEEhCYbImAQPRldRhpruqSyV08Vg,2675
50
- ocf_data_sampler-0.0.24.dist-info/LICENSE,sha256=F-Q3UFCR-BECSocV55BFDpn4YKxve9PKrm-lTt6o_Tg,1073
51
- ocf_data_sampler-0.0.24.dist-info/METADATA,sha256=wwIltvHOvOd-L2KaOF3jsLOMx-QuY6yP6sNR0QddCdk,5269
52
- ocf_data_sampler-0.0.24.dist-info/WHEEL,sha256=OVMc5UfuAQiSplgO0_WdW7vXVGAt9Hdd6qtN4HotdyA,91
53
- ocf_data_sampler-0.0.24.dist-info/top_level.txt,sha256=KaQn5qzkJGJP6hKWqsVAc9t0cMLjVvSTk8-kTrW79SA,23
54
- ocf_data_sampler-0.0.24.dist-info/RECORD,,
61
+ tests/torch_datasets/test_site.py,sha256=5hdUP64neCDWEo2NMSd-MhbpuQjQvD6NOvhZ1DlMmo8,2733
62
+ ocf_data_sampler-0.0.25.dist-info/LICENSE,sha256=F-Q3UFCR-BECSocV55BFDpn4YKxve9PKrm-lTt6o_Tg,1073
63
+ ocf_data_sampler-0.0.25.dist-info/METADATA,sha256=p3SKEM4gRy0Z4LTcRWlgTrpjQ-QV89ar69tM9EwhudU,5269
64
+ ocf_data_sampler-0.0.25.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
65
+ ocf_data_sampler-0.0.25.dist-info/top_level.txt,sha256=Faob6N6cFdPc5eUpCTYcXgCaNhi4XLLteUL5W5ayYmg,31
66
+ ocf_data_sampler-0.0.25.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (75.2.0)
2
+ Generator: setuptools (75.3.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
@@ -1,2 +1,3 @@
1
1
  ocf_data_sampler
2
+ scripts
2
3
  tests
@@ -0,0 +1,50 @@
1
+ """ Helper functions for refactoring legacy site data """
2
+
3
+
4
def legacy_format(data_ds, metadata_df):
    """This formats old legacy data to the new format.

    1. This renames the columns in the metadata
    2. Re-formats the site data from data variables named by the site_id to
       a data array with a site_id dimension. Also adds capacity_kwp to the
       dataset as a time series for each site_id

    Note: ``metadata_df`` is modified in place (derived columns are added to
    the caller's frame).

    Args:
        data_ds: xarray Dataset of legacy site data — presumably one data
            variable per site_id, indexed by time (TODO confirm against caller)
        metadata_df: pandas DataFrame of site metadata

    Returns:
        xarray Dataset with "generation_kw" and "capacity_kwp" arrays indexed
        by (time_utc, site_id)
    """
    # Fix: xarray is used throughout but was never imported in this script,
    # so every call previously raised NameError
    import xarray as xr

    if "system_id" in metadata_df.columns:
        metadata_df["site_id"] = metadata_df["system_id"]

    if "capacity_megawatts" in metadata_df.columns:
        metadata_df["capacity_kwp"] = metadata_df["capacity_megawatts"] * 1000

    # only site data has the site_id as data variables.
    # We want to join them all together and create another coordinate called site_id
    if "0" in data_ds:
        gen_df = data_ds.to_dataframe()
        gen_da = xr.DataArray(
            data=gen_df.values,
            coords=(
                ("time_utc", gen_df.index.values),
                ("site_id", metadata_df["site_id"]),
            ),
            name="generation_kw",
        )

        # Fix: take an explicit copy. The original aliased gen_df and then
        # overwrote its columns below; since gen_da wraps gen_df.values, that
        # could silently corrupt the generation data with capacity values
        capacity_df = gen_df.copy()
        for col in capacity_df.columns:
            capacity_df[col] = metadata_df[metadata_df["site_id"].astype(str) == col][
                "capacity_kwp"
            ].iloc[0]
        capacity_da = xr.DataArray(
            data=capacity_df.values,
            coords=(
                ("time_utc", gen_df.index.values),
                ("site_id", metadata_df["site_id"]),
            ),
            name="capacity_kwp",
        )
        data_ds = xr.Dataset(
            {
                "generation_kw": gen_da,
                "capacity_kwp": capacity_da,
            }
        )
    return data_ds
tests/conftest.py CHANGED
@@ -6,6 +6,8 @@ import pytest
6
6
  import xarray as xr
7
7
  import tempfile
8
8
 
9
+ from ocf_data_sampler.config.model import Site
10
+
9
11
  _top_test_directory = os.path.dirname(os.path.realpath(__file__))
10
12
 
11
13
  @pytest.fixture()
@@ -197,6 +199,66 @@ def ds_uk_gsp():
197
199
  })
198
200
 
199
201
 
202
@pytest.fixture(scope="session")
def data_sites() -> Site:
    """Create fake site data on disk and yield a Site config pointing at it.

    Writes a netcdf file of generation/capacity data plus a csv of site
    metadata into a temporary directory that lives for the session.
    """
    timestamps = pd.date_range("2023-01-01 00:00", "2023-01-02 00:00", freq="30min")
    ids = list(range(0, 10))
    capacities_1d = np.array([0.1, 1.1, 4, 6, 8, 9, 15, 2, 3, 4])
    # these are quite specific for the fake satellite data
    lons = np.arange(-4, -3, 0.1)
    lats = np.arange(51, 52, 0.1)

    gen_values = np.random.uniform(0, 200, size=(len(timestamps), len(ids))).astype(np.float32)

    # repeat the 1D capacities so every timestamp carries the same capacity row
    cap_values = (np.tile(capacities_1d, len(timestamps))).reshape(len(timestamps), 10)

    dims = (
        ("time_utc", timestamps),
        ("site_id", ids),
    )

    cap_da = xr.DataArray(
        cap_values,
        coords=dims,
    )

    gen_da = xr.DataArray(
        gen_values,
        coords=dims,
    )

    # metadata frame matching the loader's expected columns
    meta_df = pd.DataFrame(columns=[], data=[])
    meta_df["site_id"] = ids
    meta_df["capacity_kwp"] = capacities_1d
    meta_df["longitude"] = lons
    meta_df["latitude"] = lats

    dataset = xr.Dataset({
        "capacity_kwp": cap_da,
        "generation_kw": gen_da,
    })

    with tempfile.TemporaryDirectory() as tmpdir:
        netcdf_path = tmpdir + "/sites.netcdf"
        csv_path = tmpdir + "/sites_metadata.csv"
        dataset.to_netcdf(netcdf_path)
        meta_df.to_csv(csv_path)

        # yield inside the context so the temporary files outlive the test body
        yield Site(
            file_path=netcdf_path,
            metadata_file_path=csv_path,
            time_resolution_minutes=30,
            forecast_minutes=60,
            history_minutes=30,
        )
260
+
261
+
200
262
  @pytest.fixture(scope="session")
201
263
  def uk_gsp_zarr_path(ds_uk_gsp):
202
264
 
@@ -0,0 +1,14 @@
1
+ from ocf_data_sampler.load.site import open_site
2
+ import xarray as xr
3
+
4
+
5
def test_open_site(data_sites):
    """open_site should return a (time_utc, site_id) DataArray with site coords."""
    site_da = open_site(data_sites)

    assert isinstance(site_da, xr.DataArray)
    assert site_da.dims == ("time_utc", "site_id")

    # coordinates carried over from the metadata csv
    for coord in ("capacity_kwp", "latitude", "longitude"):
        assert coord in site_da.coords

    # 49 half-hourly timestamps over 24h (inclusive) x 10 sites
    assert site_da.shape == (49, 10)
@@ -0,0 +1,85 @@
1
+ import pandas as pd
2
+ import pytest
3
+
4
+ from ocf_data_sampler.torch_datasets.site import SitesDataset
5
+ from ocf_data_sampler.config import load_yaml_configuration, save_yaml_configuration
6
+ from ocf_data_sampler.numpy_batch.nwp import NWPBatchKey
7
+ from ocf_data_sampler.numpy_batch.site import SiteBatchKey
8
+ from ocf_data_sampler.numpy_batch.satellite import SatelliteBatchKey
9
+
10
+
11
@pytest.fixture()
def site_config_filename(tmp_path, config_filename, nwp_ukv_zarr_path, sat_zarr_path, data_sites):
    """Write a configuration pointing at the temporary test data and return its path."""
    # redirect the template config at the fixture-created zarr/site files
    config = load_yaml_configuration(config_filename)
    config.input_data.nwp["ukv"].nwp_zarr_path = nwp_ukv_zarr_path
    config.input_data.satellite.satellite_zarr_path = sat_zarr_path
    config.input_data.site = data_sites
    # sites-only dataset: disable the GSP input
    config.input_data.gsp = None

    output_path = f"{tmp_path}/configuration.yaml"
    save_yaml_configuration(config, output_path)
    return output_path
24
+
25
+
26
def test_site(site_config_filename):
    """End-to-end check of dataset length, sample keys and array shapes."""
    dataset = SitesDataset(site_config_filename)

    # 10 sites x 41 valid t0 times
    assert len(dataset) == 10 * 41
    # TODO check 41

    sample = dataset[0]

    assert isinstance(sample, dict)

    expected_keys = (
        NWPBatchKey.nwp,
        SatelliteBatchKey.satellite_actual,
        SiteBatchKey.generation,
        SiteBatchKey.site_solar_azimuth,
        SiteBatchKey.site_solar_elevation,
    )
    for key in expected_keys:
        assert key in sample

    assert "ukv" in sample[NWPBatchKey.nwp]

    # check the shape of the data is correct
    # 30 minutes of 5 minute data (inclusive), one channel, 2x2 pixels
    assert sample[SatelliteBatchKey.satellite_actual].shape == (7, 1, 2, 2)
    # 3 hours of 60 minute data (inclusive), one channel, 2x2 pixels
    assert sample[NWPBatchKey.nwp]["ukv"][NWPBatchKey.nwp].shape == (4, 1, 2, 2)
    # 90 minutes of 30 minute data (inclusive): 30 min history + 60 min forecast
    assert sample[SiteBatchKey.generation].shape == (4,)
    # Solar angles have the same shape as the site generation data
    assert sample[SiteBatchKey.site_solar_azimuth].shape == (4,)
    assert sample[SiteBatchKey.site_solar_elevation].shape == (4,)
61
+
62
+
63
def test_site_time_filter_start(site_config_filename):
    """A start_time after all of the available data leaves nothing to sample."""
    dataset = SitesDataset(site_config_filename, start_time="2024-01-01")

    assert len(dataset) == 0
69
+
70
+
71
def test_site_time_filter_end(site_config_filename):
    """An end_time before all of the available data leaves nothing to sample."""
    dataset = SitesDataset(site_config_filename, end_time="2000-01-01")

    assert len(dataset) == 0
77
+
78
+
79
def test_site_get_sample(site_config_filename):
    """get_sample returns a sample for a specific t0 and site id."""
    dataset = SitesDataset(site_config_filename)

    assert len(dataset) == 410
    # NOTE(review): sample contents are not asserted on — this only checks
    # that get_sample runs for an explicit (t0, site_id) pair
    sample = dataset.get_sample(t0=pd.Timestamp("2023-01-01 12:00"), site_id=1)