ocf-data-sampler 0.0.24__py3-none-any.whl → 0.0.26__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of ocf-data-sampler might be problematic.

Files changed (28)
  1. ocf_data_sampler/config/model.py +84 -87
  2. ocf_data_sampler/load/load_dataset.py +55 -0
  3. ocf_data_sampler/load/nwp/providers/ecmwf.py +5 -2
  4. ocf_data_sampler/load/site.py +30 -0
  5. ocf_data_sampler/numpy_batch/__init__.py +1 -0
  6. ocf_data_sampler/numpy_batch/site.py +29 -0
  7. ocf_data_sampler/select/__init__.py +8 -1
  8. ocf_data_sampler/select/dropout.py +2 -1
  9. ocf_data_sampler/select/geospatial.py +43 -1
  10. ocf_data_sampler/select/select_spatial_slice.py +8 -2
  11. ocf_data_sampler/select/spatial_slice_for_dataset.py +53 -0
  12. ocf_data_sampler/select/time_slice_for_dataset.py +124 -0
  13. ocf_data_sampler/time_functions.py +11 -0
  14. ocf_data_sampler/torch_datasets/process_and_combine.py +153 -0
  15. ocf_data_sampler/torch_datasets/pvnet_uk_regional.py +8 -418
  16. ocf_data_sampler/torch_datasets/site.py +196 -0
  17. ocf_data_sampler/torch_datasets/valid_time_periods.py +108 -0
  18. {ocf_data_sampler-0.0.24.dist-info → ocf_data_sampler-0.0.26.dist-info}/METADATA +1 -1
  19. {ocf_data_sampler-0.0.24.dist-info → ocf_data_sampler-0.0.26.dist-info}/RECORD +28 -16
  20. {ocf_data_sampler-0.0.24.dist-info → ocf_data_sampler-0.0.26.dist-info}/WHEEL +1 -1
  21. {ocf_data_sampler-0.0.24.dist-info → ocf_data_sampler-0.0.26.dist-info}/top_level.txt +1 -0
  22. scripts/refactor_site.py +50 -0
  23. tests/config/test_config.py +9 -6
  24. tests/conftest.py +62 -0
  25. tests/load/test_load_sites.py +14 -0
  26. tests/torch_datasets/test_pvnet_uk_regional.py +4 -4
  27. tests/torch_datasets/test_site.py +85 -0
  28. {ocf_data_sampler-0.0.24.dist-info → ocf_data_sampler-0.0.26.dist-info}/LICENSE +0 -0
ocf_data_sampler/torch_datasets/valid_time_periods.py ADDED
@@ -0,0 +1,108 @@
+ import numpy as np
+ import pandas as pd
+
+ from ocf_data_sampler.config import Configuration
+ from ocf_data_sampler.select.find_contiguous_time_periods import find_contiguous_t0_periods_nwp, \
+     find_contiguous_t0_periods, intersection_of_multiple_dataframes_of_periods
+ from ocf_data_sampler.time_functions import minutes
+
+
+ def find_valid_time_periods(
+     datasets_dict: dict,
+     config: Configuration,
+ ):
+     """Find the t0 times where all of the requested input data is available
+
+     Args:
+         datasets_dict: A dictionary of input datasets
+         config: Configuration file
+     """
+
+     assert set(datasets_dict.keys()).issubset({"nwp", "sat", "gsp"})
+
+     contiguous_time_periods: dict[str: pd.DataFrame] = {}  # Used to store contiguous time periods from each data source
+
+     if "nwp" in datasets_dict:
+         for nwp_key, nwp_config in config.input_data.nwp.items():
+
+             da = datasets_dict["nwp"][nwp_key]
+
+             if nwp_config.dropout_timedeltas_minutes is None:
+                 max_dropout = minutes(0)
+             else:
+                 max_dropout = minutes(np.max(np.abs(nwp_config.dropout_timedeltas_minutes)))
+
+             if nwp_config.max_staleness_minutes is None:
+                 max_staleness = None
+             else:
+                 max_staleness = minutes(nwp_config.max_staleness_minutes)
+
+             # The last step of the forecast is lost if we have to diff channels
+             if len(nwp_config.accum_channels) > 0:
+                 end_buffer = minutes(nwp_config.time_resolution_minutes)
+             else:
+                 end_buffer = minutes(0)
+
+             # This is the max staleness we can use considering the max step of the input data
+             max_possible_staleness = (
+                 pd.Timedelta(da["step"].max().item())
+                 - minutes(nwp_config.forecast_minutes)
+                 - end_buffer
+             )
+
+             # Default to use max possible staleness unless specified in config
+             if max_staleness is None:
+                 max_staleness = max_possible_staleness
+             else:
+                 # Make sure the max acceptable staleness isn't longer than the max possible
+                 assert max_staleness <= max_possible_staleness
+
+             time_periods = find_contiguous_t0_periods_nwp(
+                 datetimes=pd.DatetimeIndex(da["init_time_utc"]),
+                 history_duration=minutes(nwp_config.history_minutes),
+                 max_staleness=max_staleness,
+                 max_dropout=max_dropout,
+             )
+
+             contiguous_time_periods[f'nwp_{nwp_key}'] = time_periods
+
+     if "sat" in datasets_dict:
+         sat_config = config.input_data.satellite
+
+         time_periods = find_contiguous_t0_periods(
+             pd.DatetimeIndex(datasets_dict["sat"]["time_utc"]),
+             sample_period_duration=minutes(sat_config.time_resolution_minutes),
+             history_duration=minutes(sat_config.history_minutes),
+             forecast_duration=minutes(sat_config.forecast_minutes),
+         )
+
+         contiguous_time_periods['sat'] = time_periods
+
+     if "gsp" in datasets_dict:
+         gsp_config = config.input_data.gsp
+
+         time_periods = find_contiguous_t0_periods(
+             pd.DatetimeIndex(datasets_dict["gsp"]["time_utc"]),
+             sample_period_duration=minutes(gsp_config.time_resolution_minutes),
+             history_duration=minutes(gsp_config.history_minutes),
+             forecast_duration=minutes(gsp_config.forecast_minutes),
+         )
+
+         contiguous_time_periods['gsp'] = time_periods
+
+     # just get the values (not the keys)
+     contiguous_time_periods_values = list(contiguous_time_periods.values())
+
+     # Find joint overlapping contiguous time periods
+     if len(contiguous_time_periods_values) > 1:
+         valid_time_periods = intersection_of_multiple_dataframes_of_periods(
+             contiguous_time_periods_values
+         )
+     else:
+         valid_time_periods = contiguous_time_periods_values[0]
+
+     # check there are some valid time periods
+     if len(valid_time_periods) == 0:
+         raise ValueError(f"No valid time periods found, {contiguous_time_periods=}")
+
+     return valid_time_periods
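For orientation, a minimal sketch of how this new helper is meant to be called. Only the names visible in the hunk above are taken from the package; how the datasets_dict gets built (for example via the new load_dataset.py, which is not shown in this diff) is an assumption, so the opened data objects are left as placeholders.

# Hedged usage sketch, not part of the package diff.
from ocf_data_sampler.config import load_yaml_configuration
from ocf_data_sampler.torch_datasets.valid_time_periods import find_valid_time_periods

config = load_yaml_configuration("configuration.yaml")  # hypothetical path

# datasets_dict may only contain the keys "nwp", "sat" and "gsp" (see the assert
# at the top of find_valid_time_periods); "nwp" is itself keyed by NWP source.
datasets_dict = {
    "nwp": {"ukv": nwp_da},  # placeholders for already-opened xarray objects
    "sat": sat_da,
    "gsp": gsp_da,
}

valid_time_periods = find_valid_time_periods(datasets_dict, config)
# Contiguous periods in which every requested input is available; an empty
# result raises ValueError rather than returning.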
{ocf_data_sampler-0.0.24.dist-info → ocf_data_sampler-0.0.26.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: ocf_data_sampler
- Version: 0.0.24
+ Version: 0.0.26
  Summary: Sample from weather data for renewable energy prediction
  Author: James Fulton, Peter Dudfield, and the Open Climate Fix team
  Author-email: info@openclimatefix.org
{ocf_data_sampler-0.0.24.dist-info → ocf_data_sampler-0.0.26.dist-info}/RECORD CHANGED
@@ -1,41 +1,52 @@
  ocf_data_sampler/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
  ocf_data_sampler/constants.py,sha256=tUwHrsGShqIn5Izze4i32_xB6X0v67rvQwIYB-P5PJQ,3355
+ ocf_data_sampler/time_functions.py,sha256=R6ZlVEe6h4UlJeUW7paZYAMWveOv9MTjMsoISCwnsiE,284
  ocf_data_sampler/config/__init__.py,sha256=YXnAkgHViHB26hSsjiv32b6EbpG-A1kKTkARJf0_RkY,212
  ocf_data_sampler/config/load.py,sha256=4f7vPHAIAmd-55tPxoIzn7F_TI_ue4NxkDcLPoVWl0g,943
- ocf_data_sampler/config/model.py,sha256=bvU3BEMtcUh-N17fMVLTYtN-J2GcTM9Qq-CI5AfbE4Q,8128
+ ocf_data_sampler/config/model.py,sha256=YnGOzt6T835h6bozWqrlMnUIHPo26U8o-DTKAKvv_24,7121
  ocf_data_sampler/config/save.py,sha256=wKdctbv0dxIIiQtcRHLRxpWQVhEFQ_FCWg-oNaRLIps,1093
  ocf_data_sampler/data/uk_gsp_locations.csv,sha256=RSh7DRh55E3n8lVAaWXGTaXXHevZZtI58td4d4DhGos,10415772
  ocf_data_sampler/load/__init__.py,sha256=MjgfxilTzyz1RYFoBEeAXmE9hyjknLvdmlHPmlAoiQY,44
  ocf_data_sampler/load/gsp.py,sha256=Gcr1JVUOPKhFRDCSHtfPDjxx0BtyyEhXrZvGEKLPJ5I,759
+ ocf_data_sampler/load/load_dataset.py,sha256=Ua3RaUg4PIYJkD9BKqTfN8IWUbezbhThJGgEkd9PcaE,1587
  ocf_data_sampler/load/satellite.py,sha256=3KlA1fx4SwxdzM-jC1WRaONXO0D6m0WxORnEnwUnZrA,2967
+ ocf_data_sampler/load/site.py,sha256=ROif2XXIIgBz-JOOiHymTq1CMXswJ3AzENU9DJmYpcU,782
  ocf_data_sampler/load/utils.py,sha256=EQGvVWlGMoSOdbDYuMfVAa0v6wmAOPmHIAemdrTB5v4,1406
  ocf_data_sampler/load/nwp/__init__.py,sha256=SmcrnbygO5xtCKmGR4wtHrj-HI7nOAvnAtfuvRufBGQ,25
  ocf_data_sampler/load/nwp/nwp.py,sha256=O4QnajEZem8BvBgTcYYDBhRhgqPYuJkolHmpMRmrXEA,610
  ocf_data_sampler/load/nwp/providers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- ocf_data_sampler/load/nwp/providers/ecmwf.py,sha256=vW-p3vCyQ-CofKo555-gE7VDi5hlpjtjTLfHqWF0HEE,1175
+ ocf_data_sampler/load/nwp/providers/ecmwf.py,sha256=2iR1Iy542lo51rC6XFLV-3pbUE68dWjlHa6TVJzx3ac,1280
  ocf_data_sampler/load/nwp/providers/ukv.py,sha256=79Bm7q-K_GJPYMy62SUIZbRWRF4-tIaB1dYPEgLD9vo,1207
  ocf_data_sampler/load/nwp/providers/utils.py,sha256=Sy2exG1wpXLLhMXYdsfR-DZMR3txG1_bBmBdchlc-yA,848
- ocf_data_sampler/numpy_batch/__init__.py,sha256=M9b7L7kSgoZt2FKENDe-uFI_Qzs-fDecw-5qyrhnTm4,290
+ ocf_data_sampler/numpy_batch/__init__.py,sha256=8MgRF29rK9bKP4b4iHakaoGwBKUcjWZ-VFKjCcq53QA,336
  ocf_data_sampler/numpy_batch/gsp.py,sha256=QjQ25JmtufvdiSsxUkBTPhxouYGWPnnWze8pXr_aBno,960
  ocf_data_sampler/numpy_batch/nwp.py,sha256=dAehfRo5DL2Yb20ifHHl5cU1QOrm3ZOpQmN39fSUOw8,1255
  ocf_data_sampler/numpy_batch/satellite.py,sha256=3NoE_ElzMHwO60apqJeFAwI6J7eIxD0OWTyAVl-uJi8,903
+ ocf_data_sampler/numpy_batch/site.py,sha256=lJYMEot50UgSBnSOgADQMjUhky1YyWKYqwNsisyYV6w,789
  ocf_data_sampler/numpy_batch/sun_position.py,sha256=zw2bjtcjsm_tvKk0r_MZmgfYUJLHuLjLly2sMjwP3XI,1606
- ocf_data_sampler/select/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
- ocf_data_sampler/select/dropout.py,sha256=zDpVLMjGb70RRyYKN-WI2Kp3x9SznstT4cMcZ4dsvJg,1066
+ ocf_data_sampler/select/__init__.py,sha256=E4AJulEbO2K-o0UlG1fgaEteuf_1ZFjHTvrotXSb4YU,332
+ ocf_data_sampler/select/dropout.py,sha256=HCx5Wzk8Oh2Z9vV94Jy-ALJsHtGduwvMaQOleQXp5z0,1142
  ocf_data_sampler/select/fill_time_periods.py,sha256=iTtMjIPFYG5xtUYYedAFBLjTWWUa7t7WQ0-yksWf0-E,440
  ocf_data_sampler/select/find_contiguous_time_periods.py,sha256=6ioB8LeFpFNBMgKDxrgG3zqzNjkBF_jlV9yye2ZYT2E,11925
- ocf_data_sampler/select/geospatial.py,sha256=oHJoKEKubn3v3yKCVeuiPxuGroVA4RyrpNi6ARq5woE,3558
+ ocf_data_sampler/select/geospatial.py,sha256=4xL-9y674jjoaXeqE52NHCHVfknciE4OEGsZtn9DvP4,4911
  ocf_data_sampler/select/location.py,sha256=26Y5ZjfFngShBwXieuWSoOA-RLaRzci4TTmcDk3Wg7U,2015
- ocf_data_sampler/select/select_spatial_slice.py,sha256=hWIJe4_VzuQ2iiiQh7V17AXwTILT5kIkUvzG458J_Gw,11220
+ ocf_data_sampler/select/select_spatial_slice.py,sha256=WNxwur9Q5oetvogATw8-hNejDuEwrXHzuZIovFDjNJA,11488
  ocf_data_sampler/select/select_time_slice.py,sha256=41cch1fQr59fZgv7UHsNGc3OvoynrixT3bmr3_1d7cU,6628
+ ocf_data_sampler/select/spatial_slice_for_dataset.py,sha256=3tRrMBXr7s4CnClbVSIq7hpls3H4Y3qYTDwswcxCCCE,1763
+ ocf_data_sampler/select/time_slice_for_dataset.py,sha256=5gcTGgQ1D524OhullNRWq3hxCwl2SoliGR210G-62JA,4216
  ocf_data_sampler/torch_datasets/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
- ocf_data_sampler/torch_datasets/pvnet_uk_regional.py,sha256=Mlz_uyt8c8-uN0uaEiJV2DgF5WAqtWlsINFgA925CZI,19025
+ ocf_data_sampler/torch_datasets/process_and_combine.py,sha256=GA-tGZLEMNAqX5Zun_7tPcTWVxlVtwejC9zfXPECwSk,4989
+ ocf_data_sampler/torch_datasets/pvnet_uk_regional.py,sha256=TpHALGU7hpo3iLbvD0nkoY6zu94Vq99W1V1qSGEcIW8,5552
+ ocf_data_sampler/torch_datasets/site.py,sha256=1k0fWXYwAAIWG5DX_j3tgNfY8gglfPGLNzNlZd8EnJs,6631
+ ocf_data_sampler/torch_datasets/valid_time_periods.py,sha256=vP25e7DpWAu4dACTFMJZm0bi304iUFdi1XySAmxi_c0,4159
+ scripts/refactor_site.py,sha256=asZ27hQ4IyXgCCUaFJqcz1ObBNcV2W3ywqHBpSXA_fc,1728
  tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- tests/conftest.py,sha256=O77dmow8mGpGPbZ6Pz7ma7cLaiV1k8mxW1eYg37Avrw,5585
- tests/config/test_config.py,sha256=G_PD_pXib0zdRBPUIn0jjwJ9VyoKaO_TanLN1Mh5Ca4,5055
+ tests/conftest.py,sha256=ZRktySCynj3NBbFRR4EFNLRLFMErkQsC-qQlmQzhbRg,7360
+ tests/config/test_config.py,sha256=C8NppoEVCMKxTTUf3o_z1Jb_I2DDH75XKpQ9x45U3Hw,5090
  tests/load/test_load_gsp.py,sha256=aT_nqaSXmUTcdHzuTT7AmXJr3R31k4OEN-Fv3eLxlQE,424
  tests/load/test_load_nwp.py,sha256=3qyyDkB1q9t3tyAwogfotNrxqUOpXXimco1CImoEWGg,753
  tests/load/test_load_satellite.py,sha256=STX5AqqmOAgUgE9R1xyq_sM3P1b8NKdGjO-hDhayfxM,524
+ tests/load/test_load_sites.py,sha256=T9lSEnGPI8FQISudVYHHNTHeplNS62Vrx48jaZ6J_Jo,364
  tests/numpy_batch/test_gsp.py,sha256=VANXV32K8aLX4dCdhCUnDorJmyNN-Bjc7Wc1N-RzWEk,548
  tests/numpy_batch/test_nwp.py,sha256=Fnj7cR-VR2Z0kMu8SrgnIayjxWnPWrYFjWSjMmnrh4Y,1445
  tests/numpy_batch/test_satellite.py,sha256=8a4ZwMLpsOmYKmwI1oW_su_hwkCNYMEJAEfa0dbsx1k,1179
@@ -46,9 +57,10 @@ tests/select/test_find_contiguous_time_periods.py,sha256=G6tJRJd0DMfH9EdfzlKWsmf
  tests/select/test_location.py,sha256=_WZk2FPYeJ-nIfCJS6Sp_yaVEEo7m31DmMFoZzgyCts,2712
  tests/select/test_select_spatial_slice.py,sha256=7EX9b6g-pMdACQx3yefjs5do2s-Rho2UmKevV4oglsU,5147
  tests/select/test_select_time_slice.py,sha256=XC1J3DBBDnt81jcba5u-Hnd0yKv8GIQErLm-OECV6rs,10147
- tests/torch_datasets/test_pvnet_uk_regional.py,sha256=u3taw6p3oozM0_7cEEhCYbImAQPRldRhpruqSyV08Vg,2675
- ocf_data_sampler-0.0.24.dist-info/LICENSE,sha256=F-Q3UFCR-BECSocV55BFDpn4YKxve9PKrm-lTt6o_Tg,1073
- ocf_data_sampler-0.0.24.dist-info/METADATA,sha256=wwIltvHOvOd-L2KaOF3jsLOMx-QuY6yP6sNR0QddCdk,5269
- ocf_data_sampler-0.0.24.dist-info/WHEEL,sha256=OVMc5UfuAQiSplgO0_WdW7vXVGAt9Hdd6qtN4HotdyA,91
- ocf_data_sampler-0.0.24.dist-info/top_level.txt,sha256=KaQn5qzkJGJP6hKWqsVAc9t0cMLjVvSTk8-kTrW79SA,23
- ocf_data_sampler-0.0.24.dist-info/RECORD,,
+ tests/torch_datasets/test_pvnet_uk_regional.py,sha256=8gxjJO8FhY-ImX6eGnihDFsa8fhU2Zb4bVJaToJwuwo,2653
+ tests/torch_datasets/test_site.py,sha256=yTv6tAT6lha5yLYJiC8DNms1dct8o_ObPV97dHZyT7I,2719
+ ocf_data_sampler-0.0.26.dist-info/LICENSE,sha256=F-Q3UFCR-BECSocV55BFDpn4YKxve9PKrm-lTt6o_Tg,1073
+ ocf_data_sampler-0.0.26.dist-info/METADATA,sha256=VRnSRX4dgDbz4k9bwSM66uqaHI4P97xC97_NsEIt5qU,5269
+ ocf_data_sampler-0.0.26.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
+ ocf_data_sampler-0.0.26.dist-info/top_level.txt,sha256=Faob6N6cFdPc5eUpCTYcXgCaNhi4XLLteUL5W5ayYmg,31
+ ocf_data_sampler-0.0.26.dist-info/RECORD,,
{ocf_data_sampler-0.0.24.dist-info → ocf_data_sampler-0.0.26.dist-info}/WHEEL CHANGED
@@ -1,5 +1,5 @@
  Wheel-Version: 1.0
- Generator: setuptools (75.2.0)
+ Generator: setuptools (75.3.0)
  Root-Is-Purelib: true
  Tag: py3-none-any
 
{ocf_data_sampler-0.0.24.dist-info → ocf_data_sampler-0.0.26.dist-info}/top_level.txt CHANGED
@@ -1,2 +1,3 @@
  ocf_data_sampler
+ scripts
  tests
scripts/refactor_site.py ADDED
@@ -0,0 +1,50 @@
+ """ Helper functions for refactoring legacy site data """
+
+
+ def legacy_format(data_ds, metadata_df):
+     """This formats old legacy data to the new format.
+
+     1. This renames the columns in the metadata
+     2. Re-formats the site data from data variables named by the site_id to
+        a data array with a site_id dimension. Also adds capacity_kwp to the dataset as a time series for each site_id
+     """
+
+     if "system_id" in metadata_df.columns:
+         metadata_df["site_id"] = metadata_df["system_id"]
+
+     if "capacity_megawatts" in metadata_df.columns:
+         metadata_df["capacity_kwp"] = metadata_df["capacity_megawatts"] * 1000
+
+     # only site data has the site_id as data variables.
+     # We want to join them all together and create another coordinate called site_id
+     if "0" in data_ds:
+         gen_df = data_ds.to_dataframe()
+         gen_da = xr.DataArray(
+             data=gen_df.values,
+             coords=(
+                 ("time_utc", gen_df.index.values),
+                 ("site_id", metadata_df["site_id"]),
+             ),
+             name="generation_kw",
+         )
+
+         capacity_df = gen_df
+         for col in capacity_df.columns:
+             capacity_df[col] = metadata_df[metadata_df["site_id"].astype(str) == col][
+                 "capacity_kwp"
+             ].iloc[0]
+         capacity_da = xr.DataArray(
+             data=capacity_df.values,
+             coords=(
+                 ("time_utc", gen_df.index.values),
+                 ("site_id", metadata_df["site_id"]),
+             ),
+             name="capacity_kwp",
+         )
+         data_ds = xr.Dataset(
+             {
+                 "generation_kw": gen_da,
+                 "capacity_kwp": capacity_da,
+             }
+         )
+     return data_ds
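A hedged sketch of how legacy_format might be driven to convert old site files. The import path follows the new scripts entry in top_level.txt, the file paths are placeholders, and since the 50-line module above contains no xarray import of its own, the sketch assumes "import xarray as xr" is added to it before running.

# Usage sketch only, not part of the diff; paths are hypothetical.
import pandas as pd
import xarray as xr

from scripts.refactor_site import legacy_format

data_ds = xr.open_dataset("legacy_site_data.netcdf")   # variables named "0", "1", ... per site
metadata_df = pd.read_csv("legacy_site_metadata.csv")  # system_id / capacity_megawatts columns

# Produces a dataset with generation_kw and capacity_kwp indexed by (time_utc, site_id)
new_ds = legacy_format(data_ds, metadata_df)
new_ds.to_netcdf("sites.netcdf")
metadata_df.to_csv("sites_metadata.csv")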
tests/config/test_config.py CHANGED
@@ -10,13 +10,13 @@ from ocf_data_sampler.config import (
  )
 
 
- def test_default():
+ def test_default_configuration():
      """Test default pydantic class"""
 
      _ = Configuration()
 
 
- def test_yaml_load_test_config(test_config_filename):
+ def test_load_yaml_configuration(test_config_filename):
      """
      Test that yaml loading works for 'test_config.yaml'
      and fails for an empty .yaml file
@@ -56,7 +56,7 @@ def test_yaml_save(test_config_filename):
      assert test_config == tmp_config
 
 
- def test_extra_field():
+ def test_extra_field_error():
      """
      Check an extra parameters in config causes error
      """
@@ -99,10 +99,11 @@ def test_incorrect_nwp_provider(test_config_filename):
 
      configuration = load_yaml_configuration(test_config_filename)
 
-     configuration.input_data.nwp['ukv'].nwp_provider = "unexpected_provider"
+     configuration.input_data.nwp['ukv'].provider = "unexpected_provider"
      with pytest.raises(Exception, match="NWP provider"):
          _ = Configuration(**configuration.model_dump())
 
+
  def test_incorrect_dropout(test_config_filename):
      """
      Check a dropout timedelta over 0 causes error and 0 doesn't
@@ -119,6 +120,7 @@ def test_incorrect_dropout(test_config_filename):
      configuration.input_data.nwp['ukv'].dropout_timedeltas_minutes = [0]
      _ = Configuration(**configuration.model_dump())
 
+
  def test_incorrect_dropout_fraction(test_config_filename):
      """
      Check dropout fraction outside of range causes error
@@ -127,11 +129,12 @@ def test_incorrect_dropout_fraction(test_config_filename):
      configuration = load_yaml_configuration(test_config_filename)
 
      configuration.input_data.nwp['ukv'].dropout_fraction= 1.1
-     with pytest.raises(Exception, match="Dropout fraction must be between 0 and 1"):
+
+     with pytest.raises(ValidationError, match="Input should be less than or equal to 1"):
          _ = Configuration(**configuration.model_dump())
 
      configuration.input_data.nwp['ukv'].dropout_fraction= -0.1
-     with pytest.raises(Exception, match="Dropout fraction must be between 0 and 1"):
+     with pytest.raises(ValidationError, match="Input should be greater than or equal to 0"):
          _ = Configuration(**configuration.model_dump())
 
 
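The new expected strings are pydantic v2's built-in messages for bound constraints rather than a custom validator message. A minimal standalone illustration of that behaviour; the actual field declaration in ocf_data_sampler.config.model is an assumption here, not taken from this diff.

# Sketch of the validation style these tests now match.
from pydantic import BaseModel, Field, ValidationError

class DropoutConfig(BaseModel):  # hypothetical stand-in for the real NWP config model
    dropout_fraction: float = Field(default=0, ge=0, le=1)

try:
    DropoutConfig(dropout_fraction=1.1)
except ValidationError as err:
    print(err)  # message contains "Input should be less than or equal to 1"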
tests/conftest.py CHANGED
@@ -6,6 +6,8 @@ import pytest
  import xarray as xr
  import tempfile
 
+ from ocf_data_sampler.config.model import Site
+
  _top_test_directory = os.path.dirname(os.path.realpath(__file__))
 
  @pytest.fixture()
@@ -197,6 +199,66 @@ def ds_uk_gsp():
      })
 
 
+ @pytest.fixture(scope="session")
+ def data_sites() -> Site:
+     """
+     Make fake data for sites
+     Returns: filename for netcdf file, and csv metadata
+     """
+     times = pd.date_range("2023-01-01 00:00", "2023-01-02 00:00", freq="30min")
+     site_ids = list(range(0,10))
+     capacity_kwp_1d = np.array([0.1,1.1,4,6,8,9,15,2,3,4])
+     # these are quite specific for the fake satellite data
+     longitude = np.arange(-4, -3, 0.1)
+     latitude = np.arange(51, 52, 0.1)
+
+     generation = np.random.uniform(0, 200, size=(len(times), len(site_ids))).astype(np.float32)
+
+     # repeat capacity in new dims len(times) times
+     capacity_kwp = (np.tile(capacity_kwp_1d, len(times))).reshape(len(times),10)
+
+     coords = (
+         ("time_utc", times),
+         ("site_id", site_ids),
+     )
+
+     da_cap = xr.DataArray(
+         capacity_kwp,
+         coords=coords,
+     )
+
+     da_gen = xr.DataArray(
+         generation,
+         coords=coords,
+     )
+
+     # metadata
+     meta_df = pd.DataFrame(columns=[], data = [])
+     meta_df['site_id'] = site_ids
+     meta_df['capacity_kwp'] = capacity_kwp_1d
+     meta_df['longitude'] = longitude
+     meta_df['latitude'] = latitude
+
+     generation = xr.Dataset({
+         "capacity_kwp": da_cap,
+         "generation_kw": da_gen,
+     })
+
+     with tempfile.TemporaryDirectory() as tmpdir:
+         filename = tmpdir + "/sites.netcdf"
+         filename_csv = tmpdir + "/sites_metadata.csv"
+         generation.to_netcdf(filename)
+         meta_df.to_csv(filename_csv)
+
+         site = Site(file_path=filename,
+                     metadata_file_path=filename_csv,
+                     time_resolution_minutes=30,
+                     forecast_minutes=60,
+                     history_minutes=30)
+
+         yield site
+
+
  @pytest.fixture(scope="session")
  def uk_gsp_zarr_path(ds_uk_gsp):
 
tests/load/test_load_sites.py ADDED
@@ -0,0 +1,14 @@
+ from ocf_data_sampler.load.site import open_site
+ import xarray as xr
+
+
+ def test_open_site(data_sites):
+     da = open_site(data_sites)
+
+     assert isinstance(da, xr.DataArray)
+     assert da.dims == ("time_utc", "site_id")
+
+     assert "capacity_kwp" in da.coords
+     assert "latitude" in da.coords
+     assert "longitude" in da.coords
+     assert da.shape == (49, 10)
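The expected shape follows directly from the data_sites fixture above: a 30-minute range from 2023-01-01 00:00 to 2023-01-02 00:00 inclusive gives 49 timestamps, and there are 10 site IDs. A quick check of that arithmetic:

# Sanity check of the (49, 10) shape asserted above.
import pandas as pd

times = pd.date_range("2023-01-01 00:00", "2023-01-02 00:00", freq="30min")
print(len(times))  # 49: 48 half-hour steps across 24 hours, plus both endpoints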
tests/torch_datasets/test_pvnet_uk_regional.py CHANGED
@@ -11,9 +11,9 @@ def pvnet_config_filename(tmp_path, config_filename, nwp_ukv_zarr_path, uk_gsp_z
 
      # adjust config to point to the zarr file
      config = load_yaml_configuration(config_filename)
-     config.input_data.nwp['ukv'].nwp_zarr_path = nwp_ukv_zarr_path
-     config.input_data.satellite.satellite_zarr_path = sat_zarr_path
-     config.input_data.gsp.gsp_zarr_path = uk_gsp_zarr_path
+     config.input_data.nwp['ukv'].zarr_path = nwp_ukv_zarr_path
+     config.input_data.satellite.zarr_path = sat_zarr_path
+     config.input_data.gsp.zarr_path = uk_gsp_zarr_path
 
      filename = f"{tmp_path}/configuration.yaml"
      save_yaml_configuration(config, filename)
@@ -60,7 +60,7 @@ def test_pvnet_no_gsp(pvnet_config_filename):
      # load config
      config = load_yaml_configuration(pvnet_config_filename)
      # remove gsp
-     config.input_data.gsp.gsp_zarr_path = ''
+     config.input_data.gsp.zarr_path = ''
 
      # save temp config file
      with tempfile.NamedTemporaryFile() as temp_config_file:
tests/torch_datasets/test_site.py ADDED
@@ -0,0 +1,85 @@
+ import pandas as pd
+ import pytest
+
+ from ocf_data_sampler.torch_datasets.site import SitesDataset
+ from ocf_data_sampler.config import load_yaml_configuration, save_yaml_configuration
+ from ocf_data_sampler.numpy_batch.nwp import NWPBatchKey
+ from ocf_data_sampler.numpy_batch.site import SiteBatchKey
+ from ocf_data_sampler.numpy_batch.satellite import SatelliteBatchKey
+
+
+ @pytest.fixture()
+ def site_config_filename(tmp_path, config_filename, nwp_ukv_zarr_path, sat_zarr_path, data_sites):
+
+     # adjust config to point to the zarr file
+     config = load_yaml_configuration(config_filename)
+     config.input_data.nwp["ukv"].zarr_path = nwp_ukv_zarr_path
+     config.input_data.satellite.zarr_path = sat_zarr_path
+     config.input_data.site = data_sites
+     config.input_data.gsp = None
+
+     filename = f"{tmp_path}/configuration.yaml"
+     save_yaml_configuration(config, filename)
+     return filename
+
+
+ def test_site(site_config_filename):
+
+     # Create dataset object
+     dataset = SitesDataset(site_config_filename)
+
+     assert len(dataset) == 10 * 41
+     # TODO check 41
+
+     # Generate a sample
+     sample = dataset[0]
+
+     assert isinstance(sample, dict)
+
+     for key in [
+         NWPBatchKey.nwp,
+         SatelliteBatchKey.satellite_actual,
+         SiteBatchKey.generation,
+         SiteBatchKey.site_solar_azimuth,
+         SiteBatchKey.site_solar_elevation,
+     ]:
+         assert key in sample
+
+     for nwp_source in ["ukv"]:
+         assert nwp_source in sample[NWPBatchKey.nwp]
+
+     # check the shape of the data is correct
+     # 30 minutes of 5 minute data (inclusive), one channel, 2x2 pixels
+     assert sample[SatelliteBatchKey.satellite_actual].shape == (7, 1, 2, 2)
+     # 3 hours of 60 minute data (inclusive), one channel, 2x2 pixels
+     assert sample[NWPBatchKey.nwp]["ukv"][NWPBatchKey.nwp].shape == (4, 1, 2, 2)
+     # 3 hours of 30 minute data (inclusive)
+     assert sample[SiteBatchKey.generation].shape == (4,)
+     # Solar angles have same shape as GSP data
+     assert sample[SiteBatchKey.site_solar_azimuth].shape == (4,)
+     assert sample[SiteBatchKey.site_solar_elevation].shape == (4,)
+
+
+ def test_site_time_filter_start(site_config_filename):
+
+     # Create dataset object
+     dataset = SitesDataset(site_config_filename, start_time="2024-01-01")
+
+     assert len(dataset) == 0
+
+
+ def test_site_time_filter_end(site_config_filename):
+
+     # Create dataset object
+     dataset = SitesDataset(site_config_filename, end_time="2000-01-01")
+
+     assert len(dataset) == 0
+
+
+ def test_site_get_sample(site_config_filename):
+
+     # Create dataset object
+     dataset = SitesDataset(site_config_filename)
+
+     assert len(dataset) == 410
+     sample = dataset.get_sample(t0=pd.Timestamp("2023-01-01 12:00"), site_id=1)
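Taken together, these tests sketch the intended workflow for the new site dataset. A hedged end-to-end sketch follows; only the SitesDataset constructor arguments and the get_sample signature appear in this diff, and the configuration path is a placeholder.

# Usage sketch, not part of the diff.
import pandas as pd

from ocf_data_sampler.torch_datasets.site import SitesDataset

# A configuration whose input_data.site points at the site netcdf + metadata csv,
# mirroring the site_config_filename fixture above.
dataset = SitesDataset("configuration.yaml", start_time="2023-01-01", end_time="2023-01-02")

print(len(dataset))   # number of available (t0, site_id) samples
sample = dataset[0]   # dict keyed by NWPBatchKey / SatelliteBatchKey / SiteBatchKey

# Or pick a specific site and forecast-init time, as in test_site_get_sample
sample = dataset.get_sample(t0=pd.Timestamp("2023-01-01 12:00"), site_id=1)

Because samples come back as plain dicts of numpy arrays, they can be batched downstream (for example with a torch DataLoader) if needed.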