ocf-data-sampler 0.1.11__py3-none-any.whl → 0.1.17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ocf-data-sampler might be problematic. Click here for more details.

Files changed (78) hide show
  1. ocf_data_sampler/config/load.py +3 -3
  2. ocf_data_sampler/config/model.py +146 -64
  3. ocf_data_sampler/config/save.py +5 -4
  4. ocf_data_sampler/load/gsp.py +6 -5
  5. ocf_data_sampler/load/load_dataset.py +5 -6
  6. ocf_data_sampler/load/nwp/nwp.py +17 -5
  7. ocf_data_sampler/load/nwp/providers/ecmwf.py +6 -7
  8. ocf_data_sampler/load/nwp/providers/gfs.py +36 -0
  9. ocf_data_sampler/load/nwp/providers/icon.py +46 -0
  10. ocf_data_sampler/load/nwp/providers/ukv.py +4 -5
  11. ocf_data_sampler/load/nwp/providers/utils.py +3 -1
  12. ocf_data_sampler/load/satellite.py +9 -10
  13. ocf_data_sampler/load/site.py +10 -6
  14. ocf_data_sampler/load/utils.py +21 -16
  15. ocf_data_sampler/numpy_sample/collate.py +10 -9
  16. ocf_data_sampler/numpy_sample/datetime_features.py +3 -5
  17. ocf_data_sampler/numpy_sample/gsp.py +12 -14
  18. ocf_data_sampler/numpy_sample/nwp.py +12 -12
  19. ocf_data_sampler/numpy_sample/satellite.py +9 -9
  20. ocf_data_sampler/numpy_sample/site.py +5 -8
  21. ocf_data_sampler/numpy_sample/sun_position.py +16 -21
  22. ocf_data_sampler/sample/base.py +15 -17
  23. ocf_data_sampler/sample/site.py +13 -20
  24. ocf_data_sampler/sample/uk_regional.py +29 -35
  25. ocf_data_sampler/select/dropout.py +16 -14
  26. ocf_data_sampler/select/fill_time_periods.py +15 -5
  27. ocf_data_sampler/select/find_contiguous_time_periods.py +88 -75
  28. ocf_data_sampler/select/geospatial.py +63 -54
  29. ocf_data_sampler/select/location.py +16 -51
  30. ocf_data_sampler/select/select_spatial_slice.py +105 -89
  31. ocf_data_sampler/select/select_time_slice.py +71 -58
  32. ocf_data_sampler/select/spatial_slice_for_dataset.py +7 -6
  33. ocf_data_sampler/select/time_slice_for_dataset.py +17 -16
  34. ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py +140 -131
  35. ocf_data_sampler/torch_datasets/datasets/site.py +152 -112
  36. ocf_data_sampler/torch_datasets/utils/__init__.py +3 -0
  37. ocf_data_sampler/torch_datasets/utils/channel_dict_to_dataarray.py +11 -0
  38. ocf_data_sampler/torch_datasets/utils/merge_and_fill_utils.py +6 -2
  39. ocf_data_sampler/torch_datasets/utils/valid_time_periods.py +23 -22
  40. ocf_data_sampler/utils.py +3 -1
  41. {ocf_data_sampler-0.1.11.dist-info → ocf_data_sampler-0.1.17.dist-info}/METADATA +7 -18
  42. ocf_data_sampler-0.1.17.dist-info/RECORD +56 -0
  43. {ocf_data_sampler-0.1.11.dist-info → ocf_data_sampler-0.1.17.dist-info}/WHEEL +1 -1
  44. {ocf_data_sampler-0.1.11.dist-info → ocf_data_sampler-0.1.17.dist-info}/top_level.txt +1 -1
  45. scripts/refactor_site.py +63 -33
  46. utils/compute_icon_mean_stddev.py +72 -0
  47. ocf_data_sampler/constants.py +0 -222
  48. ocf_data_sampler/torch_datasets/utils/validate_channels.py +0 -82
  49. ocf_data_sampler-0.1.11.dist-info/LICENSE +0 -21
  50. ocf_data_sampler-0.1.11.dist-info/RECORD +0 -82
  51. tests/__init__.py +0 -0
  52. tests/config/test_config.py +0 -113
  53. tests/config/test_load.py +0 -7
  54. tests/config/test_save.py +0 -28
  55. tests/conftest.py +0 -319
  56. tests/load/test_load_gsp.py +0 -15
  57. tests/load/test_load_nwp.py +0 -21
  58. tests/load/test_load_satellite.py +0 -17
  59. tests/load/test_load_sites.py +0 -14
  60. tests/numpy_sample/test_collate.py +0 -21
  61. tests/numpy_sample/test_datetime_features.py +0 -37
  62. tests/numpy_sample/test_gsp.py +0 -38
  63. tests/numpy_sample/test_nwp.py +0 -13
  64. tests/numpy_sample/test_satellite.py +0 -40
  65. tests/numpy_sample/test_sun_position.py +0 -81
  66. tests/select/test_dropout.py +0 -69
  67. tests/select/test_fill_time_periods.py +0 -28
  68. tests/select/test_find_contiguous_time_periods.py +0 -202
  69. tests/select/test_location.py +0 -67
  70. tests/select/test_select_spatial_slice.py +0 -154
  71. tests/select/test_select_time_slice.py +0 -275
  72. tests/test_sample/test_base.py +0 -164
  73. tests/test_sample/test_site_sample.py +0 -165
  74. tests/test_sample/test_uk_regional_sample.py +0 -136
  75. tests/torch_datasets/test_merge_and_fill_utils.py +0 -40
  76. tests/torch_datasets/test_pvnet_uk.py +0 -154
  77. tests/torch_datasets/test_site.py +0 -226
  78. tests/torch_datasets/test_validate_channels_utils.py +0 -78
@@ -1,82 +0,0 @@
1
- import xarray as xr
2
-
3
- from ocf_data_sampler.config import Configuration
4
- from ocf_data_sampler.constants import NWP_MEANS, NWP_STDS, RSS_MEAN, RSS_STD
5
-
6
-
7
- def validate_channels(
8
- data_channels: list,
9
- means_channels: list,
10
- stds_channels: list,
11
- source_name: str | None = None
12
- ) -> None:
13
- """
14
- Validates that all channels in data have corresponding normalisation constants.
15
-
16
- Args:
17
- data_channels: Set of channels from the data
18
- means_channels: Set of channels from means constants
19
- stds_channels: Set of channels from stds constants
20
- source_name: Name of data source (e.g., 'ecmwf', 'satellite') for error messages
21
-
22
- Raises:
23
- ValueError: If there's a mismatch between data channels and normalisation constants
24
- """
25
-
26
- data_set = set(data_channels)
27
- means_set = set(means_channels)
28
- stds_set = set(stds_channels)
29
-
30
- # Find missing channels in means
31
- missing_in_means = data_set - means_set
32
- if missing_in_means:
33
- raise ValueError(
34
- f"The following channels for {source_name} are missing in normalisation means: "
35
- f"{missing_in_means}"
36
- )
37
-
38
- # Find missing channels in stds
39
- missing_in_stds = data_set - stds_set
40
- if missing_in_stds:
41
- raise ValueError(
42
- f"The following channels for {source_name} are missing in normalisation stds: "
43
- f"{missing_in_stds}"
44
- )
45
-
46
-
47
- def validate_nwp_channels(config: Configuration) -> None:
48
- """Validate that NWP channels in config have corresponding normalisation constants.
49
-
50
- Args:
51
- config: Configuration object containing NWP channel information
52
-
53
- Raises:
54
- ValueError: If there's a mismatch between configured NWP channels and normalisation constants
55
- """
56
- if hasattr(config.input_data, "nwp"):
57
- for nwp_key, nwp_config in config.input_data.nwp.items():
58
- provider = nwp_config.provider
59
- validate_channels(
60
- data_channels=nwp_config.channels,
61
- means_channels=NWP_MEANS[provider].channel.values,
62
- stds_channels=NWP_STDS[provider].channel.values,
63
- source_name=provider
64
- )
65
-
66
-
67
- def validate_satellite_channels(config: Configuration) -> None:
68
- """Validate that satellite channels in config have corresponding normalisation constants.
69
-
70
- Args:
71
- config: Configuration object containing satellite channel information
72
-
73
- Raises:
74
- ValueError: If there's a mismatch between configured satellite channels and normalisation constants
75
- """
76
- if hasattr(config.input_data, "satellite"):
77
- validate_channels(
78
- data_channels=config.input_data.satellite.channels,
79
- means_channels=RSS_MEAN.channel.values,
80
- stds_channels=RSS_STD.channel.values,
81
- source_name="satellite"
82
- )
@@ -1,21 +0,0 @@
1
- MIT License
2
-
3
- Copyright (c) 2023 Open Climate Fix
4
-
5
- Permission is hereby granted, free of charge, to any person obtaining a copy
6
- of this software and associated documentation files (the "Software"), to deal
7
- in the Software without restriction, including without limitation the rights
8
- to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
- copies of the Software, and to permit persons to whom the Software is
10
- furnished to do so, subject to the following conditions:
11
-
12
- The above copyright notice and this permission notice shall be included in all
13
- copies or substantial portions of the Software.
14
-
15
- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
- IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
- FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
- AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
- LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
- OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
- SOFTWARE.
@@ -1,82 +0,0 @@
1
- ocf_data_sampler/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
2
- ocf_data_sampler/constants.py,sha256=0HYNmqwBaHVTAEEx9qzk6WD9YInh0gSKLeI3pyq7aNs,5077
3
- ocf_data_sampler/utils.py,sha256=rKA0BHAyAG4f90zEcgxp25EEYrXS-aOVNzttZ6Mzv2k,250
4
- ocf_data_sampler/config/__init__.py,sha256=O29mbH0XG2gIY1g3BaveGCnpBO2SFqdu-qzJ7a6evl0,223
5
- ocf_data_sampler/config/load.py,sha256=sKCKmhkkeFvvkNL5xmnFvdAulaCtV4-rigPsFvVDPDc,634
6
- ocf_data_sampler/config/model.py,sha256=8PO-23uVy_JjWOJKgaZWdNMehQsAI-Jn8t0lcmBycwg,6992
7
- ocf_data_sampler/config/save.py,sha256=OqCPT3e0d7vMI2g2iRzmifPD7GscDkFQztU_qE5I0JY,1066
8
- ocf_data_sampler/data/uk_gsp_locations.csv,sha256=RSh7DRh55E3n8lVAaWXGTaXXHevZZtI58td4d4DhGos,10415772
9
- ocf_data_sampler/load/__init__.py,sha256=T5Zj1PGt0aiiNEN7Ra1Ac-cBsNKhphmmHy_8g7XU_w0,219
10
- ocf_data_sampler/load/gsp.py,sha256=uRxEORH7J99JAJ-D38nm0iJFOQh7dkm_NCXcpbYkyvo,857
11
- ocf_data_sampler/load/load_dataset.py,sha256=PHUGSm4hFHfS9nfIP2KjHHCp325O4br7uGBdQH_DP7g,1603
12
- ocf_data_sampler/load/satellite.py,sha256=SEQZ9oPe-asEeZeEMDkB1xWK5hErhWMagxohFcBl6KI,2294
13
- ocf_data_sampler/load/site.py,sha256=hMdoF6sn2PcSBfF2soj7nuQoK9SItaxDXco5nk2n-44,1232
14
- ocf_data_sampler/load/utils.py,sha256=sAEkPMS9LXVCrc5pANQo97zaoEItVg9hoNj2ZWfx_Ug,1405
15
- ocf_data_sampler/load/nwp/__init__.py,sha256=SmcrnbygO5xtCKmGR4wtHrj-HI7nOAvnAtfuvRufBGQ,25
16
- ocf_data_sampler/load/nwp/nwp.py,sha256=Jyq1dE7DN0iSe6iSEGA76uu9LoeJz9FzfEUkq6ZZExQ,565
17
- ocf_data_sampler/load/nwp/providers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
18
- ocf_data_sampler/load/nwp/providers/ecmwf.py,sha256=8rYZKdV62AdczVNSOJ2G0BM4-fRFRV0_y5zkHgNYkQs,1004
19
- ocf_data_sampler/load/nwp/providers/ukv.py,sha256=dM_kvUI0xk9xEdslXqZGjOPP96PEw3qAci5mPUgUvxA,1014
20
- ocf_data_sampler/load/nwp/providers/utils.py,sha256=MFOZ5ZXLu3-SxYVJExdlo30b3y3s5ebRx3_6DO-33FQ,780
21
- ocf_data_sampler/numpy_sample/__init__.py,sha256=nY5C6CcuxiWZ_jrXRzWtN7WyKXhJImSiVTIG6Rz4B_4,401
22
- ocf_data_sampler/numpy_sample/collate.py,sha256=oX5axq30sCsSquhNbmWAVMjM54HT1v3MCMopYHcO5Q0,1950
23
- ocf_data_sampler/numpy_sample/datetime_features.py,sha256=D0RajbnBjg15qjYk16h2H0XO4wH3fw-x0--4VC2nq0s,1204
24
- ocf_data_sampler/numpy_sample/gsp.py,sha256=uBquCFCoWuhJKY8sXpgsTCUDWUuLuv1XeixtFnFw6KU,1115
25
- ocf_data_sampler/numpy_sample/nwp.py,sha256=Tiba-es23XeyMoEPgZUpLT6EnJCGU9A_1MdY6qkE7bM,1015
26
- ocf_data_sampler/numpy_sample/satellite.py,sha256=RdXMdGGXysUx-AdL9T33yFOlxprtIdPNBKKX99-mhpY,991
27
- ocf_data_sampler/numpy_sample/site.py,sha256=TvoEU85fmjYW8pD9UZOyUUACjimdQYxEzulQXunRO6Q,1425
28
- ocf_data_sampler/numpy_sample/sun_position.py,sha256=ithM--eztAhiIQ1g52tlxgj-tMKbsJzx8mk6CgV2tzk,1613
29
- ocf_data_sampler/sample/__init__.py,sha256=zdS73NTnxFX_j8uh9tT-IXiURB6635wbneM1koWYV1o,169
30
- ocf_data_sampler/sample/base.py,sha256=IH3HbfqEUwjHmq-h2eJYLd8Jk-0ZcOylnehMyCPMV38,2223
31
- ocf_data_sampler/sample/site.py,sha256=ONf2Yz5zi8Ombd_znA4T7NXbO01F76kQsBZv6rfnC74,1343
32
- ocf_data_sampler/sample/uk_regional.py,sha256=KhJ5Ik1pZRp7PgIJjGIrE4i7SQnIdVjUbBHnfn-7ghg,2649
33
- ocf_data_sampler/select/__init__.py,sha256=E4AJulEbO2K-o0UlG1fgaEteuf_1ZFjHTvrotXSb4YU,332
34
- ocf_data_sampler/select/dropout.py,sha256=Pgov9P7rQMkSdqluG_hwm8loGyYNFOg-3PJUBLN_kjU,1526
35
- ocf_data_sampler/select/fill_time_periods.py,sha256=EIcXG-77aQVOAYNwbDBEv6SGf6DO2p1WMEf96iW4MEM,596
36
- ocf_data_sampler/select/find_contiguous_time_periods.py,sha256=IwPQwvgu4cOiAZ5Gbjflv3fnQCcs0EVK0g4V6yqqSgw,11129
37
- ocf_data_sampler/select/geospatial.py,sha256=4xL-9y674jjoaXeqE52NHCHVfknciE4OEGsZtn9DvP4,4911
38
- ocf_data_sampler/select/location.py,sha256=26Y5ZjfFngShBwXieuWSoOA-RLaRzci4TTmcDk3Wg7U,2015
39
- ocf_data_sampler/select/select_spatial_slice.py,sha256=WNxwur9Q5oetvogATw8-hNejDuEwrXHzuZIovFDjNJA,11488
40
- ocf_data_sampler/select/select_time_slice.py,sha256=9M-yvDv9K77XfEys_OIR31_aVB56sNWk3BnCnkCgcPI,4725
41
- ocf_data_sampler/select/spatial_slice_for_dataset.py,sha256=3tRrMBXr7s4CnClbVSIq7hpls3H4Y3qYTDwswcxCCCE,1763
42
- ocf_data_sampler/select/time_slice_for_dataset.py,sha256=Z7pOiilSHScxmBKZNG18K5J-S4ifdXXAYGZoHRHD3AY,4324
43
- ocf_data_sampler/torch_datasets/datasets/__init__.py,sha256=jfJSFcR0eO1AqeH7S3KnGjsBqVZT5w3oyi784PUR6Q0,146
44
- ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py,sha256=ZgfvVCcEU3dj3RoY0zdBdKGppC7Wm81qecqB17gYTmE,12286
45
- ocf_data_sampler/torch_datasets/datasets/site.py,sha256=_uHmqg-VJu-MHgXc5JFDX1noPfH6E8nY4XhQmsrOav4,16325
46
- ocf_data_sampler/torch_datasets/utils/merge_and_fill_utils.py,sha256=hIbekql64eXsNDFIoEc--GWxwdVWrh2qKegdOi70Bow,874
47
- ocf_data_sampler/torch_datasets/utils/valid_time_periods.py,sha256=Qo65qUHtle_bW5tLTYr7empHTRv-lpjvfx_6GNJj3Xg,4371
48
- ocf_data_sampler/torch_datasets/utils/validate_channels.py,sha256=u2EpiFAKAOHpmvINhOUJCT8Vbc-cle6qJ3YNVse4yLs,2884
49
- scripts/refactor_site.py,sha256=xaJGxt2_WObIPrPAnRiOMMB68r-5Q51jWRx409AcscM,1747
50
- tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
51
- tests/conftest.py,sha256=k7nM3u2YJmkMupN4SIbJP3BRoxNR1dpIoo2fPFf0abg,8588
52
- tests/config/test_config.py,sha256=CzYVhAUpgT4lvQdIddtVxtJeMqYL_TJolfeIwaaohq4,3969
53
- tests/config/test_load.py,sha256=8nui2UsgK_eufWGD74yXvf-6eY_SxBFKhDmGYUtRQxw,260
54
- tests/config/test_save.py,sha256=BxSd2S50-bRPIXP_4iX0B6Wt7pRFJnUbLYtzfLaqlAs,915
55
- tests/load/test_load_gsp.py,sha256=aT_nqaSXmUTcdHzuTT7AmXJr3R31k4OEN-Fv3eLxlQE,424
56
- tests/load/test_load_nwp.py,sha256=3qyyDkB1q9t3tyAwogfotNrxqUOpXXimco1CImoEWGg,753
57
- tests/load/test_load_satellite.py,sha256=IQ8ISRZKCEoi8IsJoPpXZJTolD0mwjnl2E7762RM_PM,524
58
- tests/load/test_load_sites.py,sha256=6V-U3_EtBklkV7w-hOoR4nba3dSaZ_cnjuRWFs8kYVU,405
59
- tests/numpy_sample/test_collate.py,sha256=RqHCD5_LTRpe4r6kqC_2TKhmhM_IHYM0ZtFUvSjDqcM,654
60
- tests/numpy_sample/test_datetime_features.py,sha256=iR9WdBLj1nIBNqoaTFE9rkUaH1eKFJSNb96nwiEaQH0,1449
61
- tests/numpy_sample/test_gsp.py,sha256=FLlq4SlJ-9cSRAepf4_ksA6PsUVKegnKEAc5pUojCJ0,1458
62
- tests/numpy_sample/test_nwp.py,sha256=Lnd-PMa6gI-fSIJkSZ554QiHFfnwxeXZxLg-rpuBv1U,442
63
- tests/numpy_sample/test_satellite.py,sha256=cCqtn5See-uSNfh89COGTUQNuFm6sIZ8QmBVHsuUeRI,1189
64
- tests/numpy_sample/test_sun_position.py,sha256=_ENYzsNBVPdNXf--FI-UUFqw2u5w7_zqw6LcENU2uZM,2504
65
- tests/select/test_dropout.py,sha256=aQuSSqZF9RxBjN9-ogkQ8O-_zktAM30CrT1Lz7j1hMg,2222
66
- tests/select/test_fill_time_periods.py,sha256=o59f2YRe5b0vJrG3B0aYZkYeHnpNk4s6EJxdXZluNQg,907
67
- tests/select/test_find_contiguous_time_periods.py,sha256=kOga_V7er5We7ewMARXaKdM3agOhsvZYx8inXtUn1PM,5976
68
- tests/select/test_location.py,sha256=_WZk2FPYeJ-nIfCJS6Sp_yaVEEo7m31DmMFoZzgyCts,2712
69
- tests/select/test_select_spatial_slice.py,sha256=7EX9b6g-pMdACQx3yefjs5do2s-Rho2UmKevV4oglsU,5147
70
- tests/select/test_select_time_slice.py,sha256=nYrdlmZlGEygJKiE26bADiluNPN1qt5kD4FrI2vtxUw,9686
71
- tests/test_sample/test_base.py,sha256=sD9NZghYQWbkAcQP9YXypWZowqYkO3xeNMH-_mEoD5I,4833
72
- tests/test_sample/test_site_sample.py,sha256=8HNenhIWYouCQu4y389PDQGokSPI5jQ4lS4CG-eA1Y8,5382
73
- tests/test_sample/test_uk_regional_sample.py,sha256=MFibX9-M8mFK7vwMPu58gAG2VoY6y7w7chW5BlZclwk,3962
74
- tests/torch_datasets/test_merge_and_fill_utils.py,sha256=GtuQg82BM1eHQjT7Ik1x1zaVcuc7KJO4_NC9stXsd4s,1123
75
- tests/torch_datasets/test_pvnet_uk.py,sha256=hgD_IDa4D8cgc4cgK1UqKYkT6sFlrTMAvgVn_iwD5_4,5086
76
- tests/torch_datasets/test_site.py,sha256=t57vAR_RRWcbG_kEFk6VrFCYzVxwFG6qJKBnRHF02fM,7000
77
- tests/torch_datasets/test_validate_channels_utils.py,sha256=Rzdweu98j1of45jCOUrSiBtyPlf-dDaCceulf0H7ml8,2921
78
- ocf_data_sampler-0.1.11.dist-info/LICENSE,sha256=F-Q3UFCR-BECSocV55BFDpn4YKxve9PKrm-lTt6o_Tg,1073
79
- ocf_data_sampler-0.1.11.dist-info/METADATA,sha256=d8wctSlRyDbP1_yYHFvIGQgEC8DmOkM8h-ITI4XFuPw,12174
80
- ocf_data_sampler-0.1.11.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
81
- ocf_data_sampler-0.1.11.dist-info/top_level.txt,sha256=Faob6N6cFdPc5eUpCTYcXgCaNhi4XLLteUL5W5ayYmg,31
82
- ocf_data_sampler-0.1.11.dist-info/RECORD,,
tests/__init__.py DELETED
File without changes
@@ -1,113 +0,0 @@
1
- import pytest
2
- from pydantic import ValidationError
3
- from ocf_data_sampler.config import load_yaml_configuration, Configuration
4
-
5
-
6
- def test_default_configuration():
7
- """Test default pydantic class"""
8
- _ = Configuration()
9
-
10
-
11
- def test_extra_field_error():
12
- """
13
- Check an extra parameters in config causes error
14
- """
15
-
16
- configuration = Configuration()
17
- configuration_dict = configuration.model_dump()
18
- configuration_dict["extra_field"] = "extra_value"
19
- with pytest.raises(ValidationError, match="Extra inputs are not permitted"):
20
- _ = Configuration(**configuration_dict)
21
-
22
-
23
- def test_incorrect_interval_start_minutes(test_config_filename):
24
- """
25
- Check a history length not divisible by time resolution causes error
26
- """
27
-
28
- configuration = load_yaml_configuration(test_config_filename)
29
-
30
- configuration.input_data.nwp['ukv'].interval_start_minutes = -1111
31
- with pytest.raises(
32
- ValueError,
33
- match="interval_start_minutes.*must be divisible.*time_resolution_minutes.*"
34
- ):
35
- _ = Configuration(**configuration.model_dump())
36
-
37
-
38
- def test_incorrect_interval_end_minutes(test_config_filename):
39
- """
40
- Check a forecast length not divisible by time resolution causes error
41
- """
42
-
43
- configuration = load_yaml_configuration(test_config_filename)
44
-
45
- configuration.input_data.nwp['ukv'].interval_end_minutes = 1111
46
- with pytest.raises(
47
- ValueError,
48
- match="interval_end_minutes.*must be divisible.*time_resolution_minutes.*"
49
- ):
50
- _ = Configuration(**configuration.model_dump())
51
-
52
-
53
- def test_incorrect_nwp_provider(test_config_filename):
54
- """
55
- Check an unexpected nwp provider causes error
56
- """
57
-
58
- configuration = load_yaml_configuration(test_config_filename)
59
-
60
- configuration.input_data.nwp['ukv'].provider = "unexpected_provider"
61
- with pytest.raises(Exception, match="NWP provider"):
62
- _ = Configuration(**configuration.model_dump())
63
-
64
-
65
- def test_incorrect_dropout(test_config_filename):
66
- """
67
- Check a dropout timedelta over 0 causes error and 0 doesn't
68
- """
69
-
70
- configuration = load_yaml_configuration(test_config_filename)
71
-
72
- # check a positive number is not allowed
73
- configuration.input_data.nwp['ukv'].dropout_timedeltas_minutes = [120]
74
- with pytest.raises(Exception, match="Dropout timedeltas must be negative"):
75
- _ = Configuration(**configuration.model_dump())
76
-
77
- # check 0 is allowed
78
- configuration.input_data.nwp['ukv'].dropout_timedeltas_minutes = [0]
79
- _ = Configuration(**configuration.model_dump())
80
-
81
-
82
- def test_incorrect_dropout_fraction(test_config_filename):
83
- """
84
- Check dropout fraction outside of range causes error
85
- """
86
-
87
- configuration = load_yaml_configuration(test_config_filename)
88
-
89
- configuration.input_data.nwp['ukv'].dropout_fraction= 1.1
90
-
91
- with pytest.raises(ValidationError, match="Input should be less than or equal to 1"):
92
- _ = Configuration(**configuration.model_dump())
93
-
94
- configuration.input_data.nwp['ukv'].dropout_fraction= -0.1
95
- with pytest.raises(ValidationError, match="Input should be greater than or equal to 0"):
96
- _ = Configuration(**configuration.model_dump())
97
-
98
-
99
- def test_inconsistent_dropout_use(test_config_filename):
100
- """
101
- Check dropout fraction outside of range causes error
102
- """
103
-
104
- configuration = load_yaml_configuration(test_config_filename)
105
- configuration.input_data.satellite.dropout_fraction= 1.0
106
- configuration.input_data.satellite.dropout_timedeltas_minutes = []
107
-
108
- with pytest.raises(ValueError, match="To dropout fraction > 0 requires a list of dropout timedeltas"):
109
- _ = Configuration(**configuration.model_dump())
110
- configuration.input_data.satellite.dropout_fraction= 0.0
111
- configuration.input_data.satellite.dropout_timedeltas_minutes = [-120, -60]
112
- with pytest.raises(ValueError, match="To use dropout timedeltas dropout fraction should be > 0"):
113
- _ = Configuration(**configuration.model_dump())
tests/config/test_load.py DELETED
@@ -1,7 +0,0 @@
1
- from ocf_data_sampler.config import Configuration, load_yaml_configuration
2
-
3
-
4
- def test_load_yaml_configuration(test_config_filename):
5
- loaded_config = load_yaml_configuration(test_config_filename)
6
- assert isinstance(loaded_config, Configuration)
7
-
tests/config/test_save.py DELETED
@@ -1,28 +0,0 @@
1
- """Tests for configuration saving functionality."""
2
- import os
3
- from ocf_data_sampler.config import Configuration, save_yaml_configuration, load_yaml_configuration
4
-
5
-
6
- def test_save_yaml_configuration_basic(tmp_path):
7
- """Save an empty configuration object"""
8
- config = Configuration()
9
-
10
- filepath = f"{tmp_path}/config.yaml"
11
- save_yaml_configuration(config, filepath)
12
-
13
- assert os.path.exists(filepath)
14
-
15
-
16
- def test_save_load_yaml_configuration(tmp_path, test_config_filename):
17
- """Make sure a saved configuration is the same after loading"""
18
-
19
- # Start with this config
20
- initial_config = load_yaml_configuration(test_config_filename)
21
-
22
- # Save it
23
- filepath = f"{tmp_path}/config.yaml"
24
- save_yaml_configuration(initial_config, filepath)
25
-
26
- # Load it and check it is still the same
27
- loaded_config = load_yaml_configuration(filepath)
28
- assert loaded_config == initial_config
tests/conftest.py DELETED
@@ -1,319 +0,0 @@
1
- import pytest
2
-
3
- import os
4
- import numpy as np
5
- import pandas as pd
6
- import xarray as xr
7
- import dask.array
8
-
9
- from ocf_data_sampler.config.model import Site
10
- from ocf_data_sampler.config import load_yaml_configuration, save_yaml_configuration
11
-
12
-
13
- _top_test_directory = os.path.dirname(os.path.realpath(__file__))
14
-
15
- @pytest.fixture()
16
- def test_config_filename():
17
- return f"{_top_test_directory}/test_data/configs/test_config.yaml"
18
-
19
-
20
- @pytest.fixture(scope="session")
21
- def config_filename():
22
- return f"{_top_test_directory}/test_data/configs/pvnet_test_config.yaml"
23
-
24
-
25
- @pytest.fixture(scope="session")
26
- def session_tmp_path(tmp_path_factory):
27
- return tmp_path_factory.mktemp("data")
28
-
29
-
30
- @pytest.fixture(scope="session")
31
- def sat_zarr_path(session_tmp_path):
32
-
33
- # Define coords for satellite-like dataset
34
- variables = [
35
- 'IR_016', 'IR_039', 'IR_087', 'IR_097', 'IR_108', 'IR_120',
36
- 'IR_134', 'VIS006', 'VIS008', 'WV_062', 'WV_073',
37
- ]
38
- x = np.linspace(start=15002, stop=-1824245, num=100)
39
- y = np.linspace(start=4191563, stop=5304712, num=100)
40
- times = pd.date_range("2023-01-01 00:00", "2023-01-01 23:55", freq="5min")
41
-
42
- area_string = (
43
- """msg_seviri_rss_3km:
44
- description: MSG SEVIRI Rapid Scanning Service area definition with 3 km resolution
45
- projection:
46
- proj: geos
47
- lon_0: 9.5
48
- h: 35785831
49
- x_0: 0
50
- y_0: 0
51
- a: 6378169
52
- rf: 295.488065897014
53
- no_defs: null
54
- type: crs
55
- shape:
56
- height: 298
57
- width: 615
58
- area_extent:
59
- lower_left_xy: [28503.830075263977, 5090183.970808983]
60
- upper_right_xy: [-1816744.1169023514, 4196063.827395439]
61
- units: m
62
- """
63
- )
64
-
65
- # Create satellite-like data with some NaNs
66
- data = dask.array.zeros(
67
- shape=(len(variables), len(times), len(y), len(x)),
68
- chunks=(-1, 10, -1, -1),
69
- dtype=np.float32
70
- )
71
- data [:, 10, :, :] = np.nan
72
-
73
- ds = xr.DataArray(
74
- data=data,
75
- coords=dict(
76
- variable=variables,
77
- time=times,
78
- y_geostationary=y,
79
- x_geostationary=x,
80
- ),
81
- attrs=dict(area=area_string),
82
- ).to_dataset(name="data")
83
-
84
- # Save temporarily as a zarr
85
- zarr_path = session_tmp_path / "test_sat.zarr"
86
- ds.to_zarr(zarr_path)
87
-
88
- yield zarr_path
89
-
90
-
91
- @pytest.fixture(scope="session")
92
- def ds_nwp_ukv():
93
- init_times = pd.date_range(start="2023-01-01 00:00", freq="180min", periods=24 * 7)
94
- steps = pd.timedelta_range("0h", "10h", freq="1h")
95
-
96
- x = np.linspace(-239_000, 857_000, 50)
97
- y = np.linspace(-183_000, 1225_000, 100)
98
- variables = ["si10", "dswrf", "t", "prate"]
99
-
100
- coords = (
101
- ("init_time", init_times),
102
- ("variable", variables),
103
- ("step", steps),
104
- ("x", x),
105
- ("y", y),
106
- )
107
-
108
- nwp_array_shape = tuple(len(coord_values) for _, coord_values in coords)
109
-
110
- nwp_data = xr.DataArray(
111
- np.random.uniform(0, 200, size=nwp_array_shape).astype(np.float32),
112
- coords=coords,
113
- )
114
- return nwp_data.to_dataset(name="UKV")
115
-
116
-
117
- @pytest.fixture(scope="session")
118
- def nwp_ukv_zarr_path(session_tmp_path, ds_nwp_ukv):
119
- ds = ds_nwp_ukv.chunk(
120
- {
121
- "init_time": 1,
122
- "step": -1,
123
- "variable": -1,
124
- "x": 50,
125
- "y": 50,
126
- }
127
- )
128
- zarr_path = session_tmp_path / "ukv_nwp.zarr"
129
- ds.to_zarr(zarr_path)
130
- yield zarr_path
131
-
132
-
133
- @pytest.fixture()
134
- def ds_nwp_ukv_time_sliced():
135
-
136
- t0 = pd.to_datetime("2024-01-02 00:00")
137
-
138
- x = np.arange(-100, 100, 10)
139
- y = np.arange(-100, 100, 10)
140
- steps = pd.timedelta_range("0h", "8h", freq="1h")
141
- target_times = t0 + steps
142
-
143
- channels = ["t", "dswrf"]
144
- init_times = pd.to_datetime([t0]*len(steps))
145
-
146
- # Create dummy time-sliced NWP data
147
- da_nwp = xr.DataArray(
148
- np.random.normal(size=(len(target_times), len(channels), len(x), len(y))),
149
- coords=dict(
150
- target_time_utc=(["target_time_utc"], target_times),
151
- channel=(["channel"], channels),
152
- x_osgb=(["x_osgb"], x),
153
- y_osgb=(["y_osgb"], y),
154
- )
155
- )
156
-
157
- # Add extra non-coordinate dimensions
158
- da_nwp = da_nwp.assign_coords(
159
- init_time_utc=("target_time_utc", init_times),
160
- step=("target_time_utc", steps),
161
- )
162
-
163
- return da_nwp
164
-
165
-
166
- @pytest.fixture(scope="session")
167
- def ds_nwp_ecmwf():
168
- init_times = pd.date_range(start="2023-01-01 00:00", freq="6h", periods=24 * 7)
169
- steps = pd.timedelta_range("0h", "14h", freq="1h")
170
-
171
- lons = np.arange(-12, 3)
172
- lats = np.arange(48, 60)
173
- variables = ["t2m","dswrf", "mcc"]
174
-
175
- coords = (
176
- ("init_time", init_times),
177
- ("variable", variables),
178
- ("step", steps),
179
- ("longitude", lons),
180
- ("latitude", lats),
181
- )
182
-
183
- nwp_array_shape = tuple(len(coord_values) for _, coord_values in coords)
184
-
185
- nwp_data = xr.DataArray(
186
- np.random.uniform(0, 200, size=nwp_array_shape).astype(np.float32),
187
- coords=coords,
188
- )
189
- return nwp_data.to_dataset(name="ECMWF_UK")
190
-
191
-
192
- @pytest.fixture(scope="session")
193
- def nwp_ecmwf_zarr_path(session_tmp_path, ds_nwp_ecmwf):
194
- ds = ds_nwp_ecmwf.chunk(
195
- {
196
- "init_time": 1,
197
- "step": -1,
198
- "variable": -1,
199
- "longitude": 50,
200
- "latitude": 50,
201
- }
202
- )
203
-
204
- zarr_path = session_tmp_path / "ukv_ecmwf.zarr"
205
- ds.to_zarr(zarr_path)
206
- yield zarr_path
207
-
208
-
209
- @pytest.fixture(scope="session")
210
- def ds_uk_gsp():
211
- times = pd.date_range("2023-01-01 00:00", "2023-01-02 00:00", freq="30min")
212
- gsp_ids = np.arange(0, 318)
213
- capacity = np.ones((len(times), len(gsp_ids)))
214
- generation = np.random.uniform(0, 200, size=(len(times), len(gsp_ids))).astype(np.float32)
215
-
216
- coords = (
217
- ("datetime_gmt", times),
218
- ("gsp_id", gsp_ids),
219
- )
220
-
221
- da_cap = xr.DataArray(
222
- capacity,
223
- coords=coords,
224
- )
225
-
226
- da_gen = xr.DataArray(
227
- generation,
228
- coords=coords,
229
- )
230
-
231
- return xr.Dataset({
232
- "capacity_mwp": da_cap,
233
- "installedcapacity_mwp": da_cap,
234
- "generation_mw":da_gen
235
- })
236
-
237
-
238
- @pytest.fixture(scope="session")
239
- def data_sites(session_tmp_path) -> Site:
240
- """
241
- Make fake data for sites
242
- Returns: filename for netcdf file, and csv metadata
243
- """
244
- times = pd.date_range("2023-01-01 00:00", "2023-01-02 00:00", freq="30min")
245
- site_ids = list(range(0,10))
246
- capacity_kwp_1d = np.array([0.1,1.1,4,6,8,9,15,2,3,4])
247
- # these are quite specific for the fake satellite data
248
- longitude = np.arange(-4, -3, 0.1)
249
- latitude = np.arange(51, 52, 0.1)
250
-
251
- generation = np.random.uniform(0, 200, size=(len(times), len(site_ids))).astype(np.float32)
252
-
253
- # repeat capacity in new dims len(times) times
254
- capacity_kwp = (np.tile(capacity_kwp_1d, len(times))).reshape(len(times),10)
255
-
256
- coords = (
257
- ("time_utc", times),
258
- ("site_id", site_ids),
259
- )
260
-
261
- da_cap = xr.DataArray(
262
- capacity_kwp,
263
- coords=coords,
264
- )
265
-
266
- da_gen = xr.DataArray(
267
- generation,
268
- coords=coords,
269
- )
270
-
271
- # metadata
272
- meta_df = pd.DataFrame(columns=[], data = [])
273
- meta_df['site_id'] = site_ids
274
- meta_df['capacity_kwp'] = capacity_kwp_1d
275
- meta_df['longitude'] = longitude
276
- meta_df['latitude'] = latitude
277
-
278
- generation = xr.Dataset({
279
- "capacity_kwp": da_cap,
280
- "generation_kw": da_gen,
281
- })
282
-
283
- filename = f"{session_tmp_path}/sites.netcdf"
284
- filename_csv = f"{session_tmp_path}/sites_metadata.csv"
285
- generation.to_netcdf(filename)
286
- meta_df.to_csv(filename_csv)
287
-
288
- site = Site(
289
- file_path=filename,
290
- metadata_file_path=filename_csv,
291
- interval_start_minutes=-30,
292
- interval_end_minutes=60,
293
- time_resolution_minutes=30,
294
- )
295
-
296
- yield site
297
-
298
-
299
- @pytest.fixture(scope="session")
300
- def uk_gsp_zarr_path(session_tmp_path, ds_uk_gsp):
301
- zarr_path = session_tmp_path / "uk_gsp.zarr"
302
- ds_uk_gsp.to_zarr(zarr_path)
303
- yield zarr_path
304
-
305
-
306
- @pytest.fixture()
307
- def pvnet_config_filename(
308
- tmp_path, config_filename, nwp_ukv_zarr_path, uk_gsp_zarr_path, sat_zarr_path
309
- ):
310
-
311
- # adjust config to point to the zarr file
312
- config = load_yaml_configuration(config_filename)
313
- config.input_data.nwp["ukv"].zarr_path = nwp_ukv_zarr_path
314
- config.input_data.satellite.zarr_path = sat_zarr_path
315
- config.input_data.gsp.zarr_path = uk_gsp_zarr_path
316
-
317
- filename = f"{tmp_path}/configuration.yaml"
318
- save_yaml_configuration(config, filename)
319
- return filename
@@ -1,15 +0,0 @@
1
- from ocf_data_sampler.load.gsp import open_gsp
2
- import xarray as xr
3
-
4
-
5
- def test_open_gsp(uk_gsp_zarr_path):
6
- da = open_gsp(uk_gsp_zarr_path)
7
-
8
- assert isinstance(da, xr.DataArray)
9
- assert da.dims == ("time_utc", "gsp_id")
10
-
11
- assert "nominal_capacity_mwp" in da.coords
12
- assert "effective_capacity_mwp" in da.coords
13
- assert "x_osgb" in da.coords
14
- assert "y_osgb" in da.coords
15
- assert da.shape == (49, 318)