ocf-data-sampler 0.0.19__py3-none-any.whl → 0.0.43__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ocf-data-sampler might be problematic. Click here for more details.

Files changed (64) hide show
  1. ocf_data_sampler/config/__init__.py +5 -0
  2. ocf_data_sampler/config/load.py +33 -0
  3. ocf_data_sampler/config/model.py +246 -0
  4. ocf_data_sampler/config/save.py +73 -0
  5. ocf_data_sampler/constants.py +173 -0
  6. ocf_data_sampler/load/load_dataset.py +55 -0
  7. ocf_data_sampler/load/nwp/providers/ecmwf.py +5 -2
  8. ocf_data_sampler/load/site.py +30 -0
  9. ocf_data_sampler/numpy_sample/__init__.py +8 -0
  10. ocf_data_sampler/numpy_sample/collate.py +75 -0
  11. ocf_data_sampler/numpy_sample/gsp.py +34 -0
  12. ocf_data_sampler/numpy_sample/nwp.py +42 -0
  13. ocf_data_sampler/numpy_sample/satellite.py +30 -0
  14. ocf_data_sampler/numpy_sample/site.py +30 -0
  15. ocf_data_sampler/{numpy_batch → numpy_sample}/sun_position.py +9 -10
  16. ocf_data_sampler/select/__init__.py +8 -1
  17. ocf_data_sampler/select/dropout.py +4 -3
  18. ocf_data_sampler/select/find_contiguous_time_periods.py +40 -75
  19. ocf_data_sampler/select/geospatial.py +160 -0
  20. ocf_data_sampler/select/location.py +62 -0
  21. ocf_data_sampler/select/select_spatial_slice.py +13 -16
  22. ocf_data_sampler/select/select_time_slice.py +24 -33
  23. ocf_data_sampler/select/spatial_slice_for_dataset.py +53 -0
  24. ocf_data_sampler/select/time_slice_for_dataset.py +125 -0
  25. ocf_data_sampler/torch_datasets/__init__.py +2 -1
  26. ocf_data_sampler/torch_datasets/process_and_combine.py +131 -0
  27. ocf_data_sampler/torch_datasets/pvnet_uk_regional.py +11 -425
  28. ocf_data_sampler/torch_datasets/site.py +405 -0
  29. ocf_data_sampler/torch_datasets/valid_time_periods.py +116 -0
  30. ocf_data_sampler/utils.py +10 -0
  31. ocf_data_sampler-0.0.43.dist-info/METADATA +154 -0
  32. ocf_data_sampler-0.0.43.dist-info/RECORD +71 -0
  33. {ocf_data_sampler-0.0.19.dist-info → ocf_data_sampler-0.0.43.dist-info}/WHEEL +1 -1
  34. {ocf_data_sampler-0.0.19.dist-info → ocf_data_sampler-0.0.43.dist-info}/top_level.txt +1 -0
  35. scripts/refactor_site.py +50 -0
  36. tests/config/test_config.py +161 -0
  37. tests/config/test_save.py +37 -0
  38. tests/conftest.py +86 -1
  39. tests/load/test_load_gsp.py +15 -0
  40. tests/load/test_load_nwp.py +21 -0
  41. tests/load/test_load_satellite.py +17 -0
  42. tests/load/test_load_sites.py +14 -0
  43. tests/numpy_sample/test_collate.py +26 -0
  44. tests/numpy_sample/test_gsp.py +38 -0
  45. tests/numpy_sample/test_nwp.py +52 -0
  46. tests/numpy_sample/test_satellite.py +40 -0
  47. tests/numpy_sample/test_sun_position.py +81 -0
  48. tests/select/test_dropout.py +75 -0
  49. tests/select/test_fill_time_periods.py +28 -0
  50. tests/select/test_find_contiguous_time_periods.py +202 -0
  51. tests/select/test_location.py +67 -0
  52. tests/select/test_select_spatial_slice.py +154 -0
  53. tests/select/test_select_time_slice.py +272 -0
  54. tests/torch_datasets/conftest.py +18 -0
  55. tests/torch_datasets/test_process_and_combine.py +126 -0
  56. tests/torch_datasets/test_pvnet_uk_regional.py +59 -0
  57. tests/torch_datasets/test_site.py +129 -0
  58. ocf_data_sampler/numpy_batch/__init__.py +0 -7
  59. ocf_data_sampler/numpy_batch/gsp.py +0 -20
  60. ocf_data_sampler/numpy_batch/nwp.py +0 -33
  61. ocf_data_sampler/numpy_batch/satellite.py +0 -23
  62. ocf_data_sampler-0.0.19.dist-info/METADATA +0 -22
  63. ocf_data_sampler-0.0.19.dist-info/RECORD +0 -32
  64. {ocf_data_sampler-0.0.19.dist-info → ocf_data_sampler-0.0.43.dist-info}/LICENSE +0 -0
@@ -0,0 +1,14 @@
1
+ from ocf_data_sampler.load.site import open_site
2
+ import xarray as xr
3
+
4
+
5
def test_open_site(data_sites):
    """Check that `open_site` returns a DataArray with the expected layout."""
    site_da = open_site(data_sites)

    assert isinstance(site_da, xr.DataArray)
    assert site_da.dims == ("time_utc", "site_id")

    # The site metadata should be attached as coordinates
    for coord_name in ("capacity_kwp", "latitude", "longitude"):
        assert coord_name in site_da.coords

    assert site_da.shape == (49, 10)
@@ -0,0 +1,26 @@
1
+ from ocf_data_sampler.numpy_sample import GSPSampleKey, SatelliteSampleKey
2
+ from ocf_data_sampler.numpy_sample.collate import stack_np_samples_into_batch
3
+ from ocf_data_sampler.torch_datasets import PVNetUKRegionalDataset
4
+
5
+
6
def test_pvnet(pvnet_config_filename):
    """Stack two PVNet UK regional samples and check the batch structure."""
    dataset = PVNetUKRegionalDataset(pvnet_config_filename)

    n_locations, n_t0_times = 317, 39
    assert len(dataset.locations) == n_locations
    assert len(dataset.valid_t0_times) == n_t0_times
    assert len(dataset) == n_locations * n_t0_times

    # Draw the first two samples and collate them into one batch
    samples = [dataset[idx] for idx in (0, 1)]
    batch = stack_np_samples_into_batch(samples)

    assert isinstance(batch, dict)

    # NWP data is nested one level deeper, keyed by provider
    assert "nwp" in batch
    nwp_batch = batch["nwp"]
    assert isinstance(nwp_batch, dict)
    assert "ukv" in nwp_batch

    assert GSPSampleKey.gsp in batch
    assert SatelliteSampleKey.satellite_actual in batch
@@ -0,0 +1,38 @@
1
+ from ocf_data_sampler.load.gsp import open_gsp
2
+ import numpy as np
3
+
4
+ from ocf_data_sampler.numpy_sample import convert_gsp_to_numpy_sample, GSPSampleKey
5
+
6
def test_convert_gsp_to_numpy_sample(uk_gsp_zarr_path):
    """Convert a single-GSP time slice and verify keys, values and t0 index."""
    gsp_da = open_gsp(uk_gsp_zarr_path)
    da = gsp_da.isel(time_utc=slice(0, 10)).sel(gsp_id=1)

    numpy_sample = convert_gsp_to_numpy_sample(da)

    # Structure: a dict containing nothing beyond the known GSP keys
    assert isinstance(numpy_sample, dict), "Should be dict"
    allowed_keys = {
        GSPSampleKey.gsp,
        GSPSampleKey.nominal_capacity_mwp,
        GSPSampleKey.effective_capacity_mwp,
        GSPSampleKey.time_utc,
    }
    assert set(numpy_sample.keys()).issubset(allowed_keys), "Unexpected keys"

    # Content: generation values, time encoding and capacities
    assert np.array_equal(numpy_sample[GSPSampleKey.gsp], da.values), "GSP values mismatch"

    time_values = numpy_sample[GSPSampleKey.time_utc]
    assert isinstance(time_values, np.ndarray), "Time UTC should be numpy array"
    assert time_values.dtype == float, "Time UTC should be float type"

    first_step = da.isel(time_utc=0)
    assert (
        numpy_sample[GSPSampleKey.nominal_capacity_mwp]
        == first_step["nominal_capacity_mwp"].values
    )
    assert (
        numpy_sample[GSPSampleKey.effective_capacity_mwp]
        == first_step["effective_capacity_mwp"].values
    )

    # Passing t0_idx should store it in the sample
    t0_idx = 5
    numpy_sample_with_t0 = convert_gsp_to_numpy_sample(da, t0_idx=t0_idx)
    assert numpy_sample_with_t0[GSPSampleKey.t0_idx] == t0_idx, "t0_idx not correctly set"
36
+
37
+
38
+
@@ -0,0 +1,52 @@
1
+ import numpy as np
2
+ import pandas as pd
3
+ import xarray as xr
4
+
5
+ import pytest
6
+
7
+ from ocf_data_sampler.numpy_sample import convert_nwp_to_numpy_sample, NWPSampleKey
8
+
9
@pytest.fixture(scope="module")
def da_nwp_like():
    """Build a random DataArray which looks like time-sliced NWP data."""
    init_time = pd.to_datetime("2024-01-02 00:00")

    # One forecast step per hour, all from the same init time
    steps = pd.timedelta_range("0h", "8h", freq="1h")
    target_times = init_time + steps
    init_times = pd.to_datetime([init_time] * len(steps))

    channels = ["t", "dswrf"]
    x = np.arange(-100, 100, 10)
    y = np.arange(-100, 100, 10)

    values = np.random.normal(
        size=(len(target_times), len(channels), len(x), len(y))
    )

    # No explicit dims are passed, so the key order of this mapping matters
    coords = {
        "target_times_utc": (["target_times_utc"], target_times),
        "channel": (["channel"], channels),
        "x_osgb": (["x_osgb"], x),
        "y_osgb": (["y_osgb"], y),
    }
    da_nwp = xr.DataArray(values, coords=coords)

    # Attach init time and step as extra coordinates along the target-time dimension
    return da_nwp.assign_coords(
        init_time_utc=("target_times_utc", init_times),
        step=("target_times_utc", steps),
    )
41
+
42
+
43
def test_convert_nwp_to_numpy_sample(da_nwp_like):
    """The converted sample is a dict whose NWP entry equals the input values."""
    numpy_sample = convert_nwp_to_numpy_sample(da_nwp_like)

    assert isinstance(numpy_sample, dict)

    # The NWP array must match the source DataArray element-wise
    nwp_values = numpy_sample[NWPSampleKey.nwp]
    assert (nwp_values == da_nwp_like.values).all()
@@ -0,0 +1,40 @@
1
+
2
+ import numpy as np
3
+ import pandas as pd
4
+ import xarray as xr
5
+
6
+ import pytest
7
+
8
+ from ocf_data_sampler.numpy_sample import convert_satellite_to_numpy_sample, SatelliteSampleKey
9
+
10
+
11
@pytest.fixture(scope="module")
def da_sat_like():
    """Build a random DataArray which looks like satellite data."""
    datetimes = pd.date_range("2024-01-01 12:00", "2024-01-01 12:30", freq="5min")
    channels = ["VIS008", "IR016"]
    x = np.arange(-100, 100, 10)
    y = np.arange(-100, 100, 10)

    values = np.random.normal(size=(len(datetimes), len(channels), len(x), len(y)))

    # No explicit dims are passed, so the key order of this mapping matters
    coords = {
        "time_utc": (["time_utc"], datetimes),
        "channel": (["channel"], channels),
        "x_geostationary": (["x_geostationary"], x),
        "y_geostationary": (["y_geostationary"], y),
    }
    return xr.DataArray(values, coords=coords)
29
+
30
+
31
def test_convert_satellite_to_numpy_sample(da_sat_like):
    """The converted sample is a dict whose satellite entry equals the input values."""
    numpy_sample = convert_satellite_to_numpy_sample(da_sat_like)

    assert isinstance(numpy_sample, dict)

    # The satellite array must match the source DataArray element-wise
    sat_values = numpy_sample[SatelliteSampleKey.satellite_actual]
    assert (sat_values == da_sat_like.values).all()
@@ -0,0 +1,81 @@
1
+ import numpy as np
2
+ import pandas as pd
3
+ import pytest
4
+
5
+ from ocf_data_sampler.numpy_sample.sun_position import (
6
+ calculate_azimuth_and_elevation, make_sun_position_numpy_sample
7
+ )
8
+
9
+ from ocf_data_sampler.numpy_sample import GSPSampleKey
10
+
11
+
12
@pytest.mark.parametrize("lat", [0, 5, 10, 23.5])
def test_calculate_azimuth_and_elevation(lat):
    """Noon elevation on the summer solstice matches the expected angle."""
    # A single timestamp at midday on the summer solstice
    datetimes = pd.to_datetime(["2024-06-20 12:00"])
    n_times = len(datetimes)

    azimuth, elevation = calculate_azimuth_and_elevation(datetimes, lon=0, lat=lat)

    assert len(azimuth) == n_times
    assert len(elevation) == n_times

    # Elevation should be within 1 degree of 90 - (23.5 - lat) degrees
    expected_elevation = 90 - 23.5 + lat
    assert np.abs(elevation - expected_elevation) < 1
26
+
27
+
28
def test_calculate_azimuth_and_elevation_random():
    """Sun angles at random locations stay in bounds and span most of the range."""
    # Fixed seed, so the sampled locations are the same on every run
    np.random.seed(0)

    # Midday on the summer solstice
    datetimes = pd.to_datetime(["2024-06-20 12:00"])

    # Sample 100 random locations and record (azimuth, elevation) for each
    sun_angles = []
    for _ in range(100):
        lon = np.random.uniform(low=0, high=360)
        lat = np.random.uniform(low=-90, high=90)

        azimuth, elevation = calculate_azimuth_and_elevation(datetimes, lon=lon, lat=lat)
        sun_angles.append((azimuth.item(), elevation.item()))

    azimuths = np.array([angles[0] for angles in sun_angles])
    elevations = np.array([angles[1] for angles in sun_angles])

    # Hard bounds on both angles
    assert ((0 <= azimuths) & (azimuths <= 360)).all()
    assert ((-90 <= elevations) & (elevations <= 90)).all()

    # The samples should reach close to both ends of the azimuth range [0, 360]
    assert azimuths.min() < 30
    assert azimuths.max() > 330

    # ... and close to both ends of the elevation range [-90, 90]
    assert elevations.min() < -70
    assert elevations.max() > 70
65
+
66
+
67
def test_make_sun_position_numpy_sample():
    """The sun position sample has both GSP solar keys with values in [0, 1]."""
    lon, lat = 0, 51.5
    datetimes = pd.date_range("2024-06-20 12:00", "2024-06-20 16:00", freq="30min")

    sample = make_sun_position_numpy_sample(datetimes, lon, lat, key_prefix="gsp")

    solar_keys = (GSPSampleKey.solar_elevation, GSPSampleKey.solar_azimuth)

    for key in solar_keys:
        assert key in sample

    # The solar coords are normalised in the function
    for key in solar_keys:
        assert (sample[key] >= 0).all()
        assert (sample[key] <= 1).all()
@@ -0,0 +1,75 @@
1
+ from ocf_data_sampler.select.dropout import draw_dropout_time, apply_dropout_time
2
+
3
+ import numpy as np
4
+ import pandas as pd
5
+ import xarray as xr
6
+
7
+ import pytest
8
+
9
+
10
@pytest.fixture(scope="module")
def da_sample():
    """Build a random 1D DataArray indexed by 5-minutely `time_utc` values."""
    datetimes = pd.date_range("2024-01-01 12:00", "2024-01-01 13:00", freq="5min")
    values = np.random.normal(size=(len(datetimes),))

    return xr.DataArray(values, coords={"time_utc": (["time_utc"], datetimes)})
23
+
24
+
25
def test_draw_dropout_time():
    """With dropout_frac=1 a dropout time is drawn from the given timedeltas."""
    t0 = pd.Timestamp("2021-01-01 04:00:00")
    dropout_timedeltas = pd.to_timedelta([-30, -60], unit="min")

    dropout_time = draw_dropout_time(t0, dropout_timedeltas, dropout_frac=1)

    assert isinstance(dropout_time, pd.Timestamp)

    # The offset from t0 must be one of the candidate timedeltas
    applied_delta = dropout_time - t0
    assert applied_delta in dropout_timedeltas
33
+
34
+
35
def test_draw_dropout_time_partial():
    """With a partial dropout fraction, every possible outcome is eventually drawn.

    The possible outcomes are no dropout (`None`) and `t0` plus each of the
    dropout timedeltas.
    """
    t0 = pd.Timestamp("2021-01-01 04:00:00")

    dropout_timedeltas = pd.to_timedelta([-30, -60], unit="min")

    # Draw 100 times to have very high probability of seeing all dropouts
    # The chances of this failing by chance are approx ((2/3)^100)*3 = 7e-18
    dropouts = {
        draw_dropout_time(t0, dropout_timedeltas, dropout_frac=2/3)
        for _ in range(100)
    }

    # Check all expected dropouts are present. This comparison must be asserted:
    # as a bare expression it is evaluated and discarded, so the test cannot fail.
    expected_dropouts = {None} | {t0 + dt for dt in dropout_timedeltas}
    assert dropouts == expected_dropouts
49
+
50
+
51
def test_draw_dropout_time_none():
    """No dropout time is drawn without timedeltas or with a zero dropout fraction."""
    t0 = pd.Timestamp("2021-01-01 04:00:00")

    # Case 1: no dropout timedeltas supplied
    assert draw_dropout_time(t0, dropout_timedeltas=None, dropout_frac=1) is None

    # Case 2: timedeltas supplied, but the dropout fraction is 0
    dropout_timedeltas = [pd.Timedelta(-30, "min")]
    assert (
        draw_dropout_time(t0, dropout_timedeltas=dropout_timedeltas, dropout_frac=0)
        is None
    )

    # Case 3: no dropout timedeltas and the dropout fraction is 0
    assert draw_dropout_time(t0, dropout_timedeltas=None, dropout_frac=0) is None
66
+
67
+
68
@pytest.mark.parametrize("t0_str", ["12:00", "12:30", "13:00"])
def test_apply_dropout_time(da_sample, t0_str):
    """Values up to the dropout time are kept; later values are null."""
    dropout_time = pd.Timestamp(f"2024-01-01 {t0_str}")

    da_dropout = apply_dropout_time(da_sample, dropout_time)

    # The sample is 5-minutely, so the first dropped step is 5 minutes later
    first_dropped_time = dropout_time + pd.Timedelta(5, "min")

    kept = da_dropout.sel(time_utc=slice(None, dropout_time))
    dropped = da_dropout.sel(time_utc=slice(first_dropped_time, None))

    assert kept.notnull().all()
    assert dropped.isnull().all()
@@ -0,0 +1,28 @@
1
+ import pandas as pd
2
+
3
+ from ocf_data_sampler.select.fill_time_periods import fill_time_periods
4
+
5
def test_fill_time_periods():
    """Filling periods at 30-minute frequency yields the expected timestamps."""
    # (start, end) pairs, one per period
    period_bounds = [
        ("2021-01-01 04:10:00", "2021-01-01 06:00:00"),
        ("2021-01-01 09:00:00", "2021-01-01 09:00:00"),
        ("2021-01-01 09:15:00", "2021-01-01 09:20:00"),
        ("2021-01-01 12:00:00", "2021-01-01 14:45:00"),
    ]
    time_periods = pd.DataFrame(period_bounds, columns=["start_dt", "end_dt"])

    filled_time_periods = fill_time_periods(time_periods, pd.Timedelta("30min"))

    expected_clock_times = [
        "04:30", "05:00", "05:30", "06:00", "09:00", "12:00",
        "12:30", "13:00", "13:30", "14:00", "14:30",
    ]
    expected_times = pd.DatetimeIndex(
        [f"2021-01-01 {t}" for t in expected_clock_times]
    )

    pd.testing.assert_index_equal(filled_time_periods, expected_times)
@@ -0,0 +1,202 @@
1
+ import pandas as pd
2
+
3
+ from ocf_data_sampler.select.find_contiguous_time_periods import (
4
+ find_contiguous_t0_periods, find_contiguous_t0_periods_nwp,
5
+ intersection_of_multiple_dataframes_of_periods,
6
+ )
7
+
8
+
9
+
10
def test_find_contiguous_t0_periods():
    """Contiguous t0 periods are found in 5-minutely data with missing timestamps."""
    freq = pd.Timedelta(5, "min")

    # 5-minutely timestamps with three entries removed to create gaps
    all_datetimes = pd.date_range("2023-01-01 12:00", "2023-01-01 17:00", freq=freq)
    datetimes = all_datetimes.delete([5, 6, 30])

    periods = find_contiguous_t0_periods(
        datetimes=datetimes,
        interval_start=pd.Timedelta(-60, "min"),
        interval_end=pd.Timedelta(15, "min"),
        sample_period_duration=freq,
    )

    expected_starts = ["2023-01-01 13:35", "2023-01-01 15:35"]
    expected_ends = ["2023-01-01 14:10", "2023-01-01 16:45"]
    expected_results = pd.DataFrame(
        {
            "start_dt": pd.to_datetime(expected_starts),
            "end_dt": pd.to_datetime(expected_ends),
        },
    )

    assert periods.equals(expected_results)
47
+
48
+
49
def test_find_contiguous_t0_periods_nwp():
    """Contiguous NWP t0 periods match expectations for several settings."""
    # Create 3-hourly init times with a few time stamps missing
    freq = pd.Timedelta(3, "h")
    init_times = (
        pd.date_range("2023-01-01 03:00", "2023-01-02 21:00", freq=freq)
        .delete([1, 4, 5, 6, 7, 9, 10])
    )

    # Each case: (history duration [h], max staleness [h], max dropout [h],
    #             expected period starts, expected period ends)
    cases = [
        (
            0, 9, 0,
            ["2023-01-01 03:00", "2023-01-02 03:00"],
            ["2023-01-01 21:00", "2023-01-03 06:00"],
        ),
        (
            2, 9, 0,
            ["2023-01-01 05:00", "2023-01-02 05:00"],
            ["2023-01-01 21:00", "2023-01-03 06:00"],
        ),
        (
            2, 6, 0,
            ["2023-01-01 05:00", "2023-01-02 05:00", "2023-01-02 14:00"],
            ["2023-01-01 18:00", "2023-01-02 09:00", "2023-01-03 03:00"],
        ),
        (
            2, 3, 0,
            [
                "2023-01-01 05:00",
                "2023-01-01 11:00",
                "2023-01-02 05:00",
                "2023-01-02 14:00",
            ],
            [
                "2023-01-01 06:00",
                "2023-01-01 15:00",
                "2023-01-02 06:00",
                "2023-01-03 00:00",
            ],
        ),
        (
            2, 6, 3,
            [
                "2023-01-01 06:00",
                "2023-01-01 12:00",
                "2023-01-02 06:00",
                "2023-01-02 15:00",
            ],
            [
                "2023-01-01 09:00",
                "2023-01-01 18:00",
                "2023-01-02 09:00",
                "2023-01-03 03:00",
            ],
        ),
    ]

    for history_hr, max_staleness_hr, max_dropout_hr, starts, ends in cases:
        expected = pd.DataFrame(
            {
                "start_dt": pd.to_datetime(starts),
                "end_dt": pd.to_datetime(ends),
            },
        )

        time_periods = find_contiguous_t0_periods_nwp(
            init_times=init_times,
            interval_start=pd.Timedelta(-history_hr, "h"),
            max_staleness=pd.Timedelta(max_staleness_hr, "h"),
            max_dropout=pd.Timedelta(max_dropout_hr, "h"),
        )

        # Check if results are as expected
        assert time_periods.equals(expected)
162
+
163
+
164
def test_intersection_of_multiple_dataframes_of_periods():
    """The intersection of three period tables gives the expected overlapping periods."""

    def make_periods(starts, ends):
        # Build a periods table from parallel lists of start/end strings
        return pd.DataFrame(
            {
                "start_dt": pd.to_datetime(starts),
                "end_dt": pd.to_datetime(ends),
            },
        )

    periods_1 = make_periods(
        ["2023-01-01 05:00", "2023-01-01 14:10"],
        ["2023-01-01 13:35", "2023-01-01 18:00"],
    )
    periods_2 = make_periods(
        ["2023-01-01 12:00"],
        ["2023-01-02 00:00"],
    )
    periods_3 = make_periods(
        ["2023-01-01 00:00", "2023-01-01 13:00"],
        ["2023-01-01 12:30", "2023-01-01 23:00"],
    )

    expected_result = make_periods(
        ["2023-01-01 12:00", "2023-01-01 13:00", "2023-01-01 14:10"],
        ["2023-01-01 12:30", "2023-01-01 13:35", "2023-01-01 18:00"],
    )

    overlapping_periods = intersection_of_multiple_dataframes_of_periods(
        [periods_1, periods_2, periods_3]
    )

    # Check if results are as expected
    assert overlapping_periods.equals(expected_result)
@@ -0,0 +1,67 @@
1
+ from ocf_data_sampler.select.location import Location
2
+ import pytest
3
+
4
+
5
def test_make_valid_location_object_with_default_coordinate_system():
    """A Location built without a coordinate system keeps x/y and uses "osgb"."""
    x, y = -1000.5, 50000
    location = Location(x=x, y=y)
    assert location.x == x, "location.x value not set correctly"
    # The failure message names the attribute actually being checked (y, not x)
    assert location.y == y, "location.y value not set correctly"
    assert (
        location.coordinate_system == "osgb"
    ), "location.coordinate_system value not set correctly"
13
+
14
+
15
def test_make_valid_location_object_with_osgb_coordinate_system():
    """A Location built with an explicit "osgb" coordinate system keeps its inputs."""
    x, y, coordinate_system = 1.2, 22.9, "osgb"
    location = Location(x=x, y=y, coordinate_system=coordinate_system)
    assert location.x == x, "location.x value not set correctly"
    # The failure message names the attribute actually being checked (y, not x)
    assert location.y == y, "location.y value not set correctly"
    assert (
        location.coordinate_system == coordinate_system
    ), "location.coordinate_system value not set correctly"
23
+
24
+
25
def test_make_valid_location_object_with_lon_lat_coordinate_system():
    """A Location built with the "lon_lat" coordinate system keeps its inputs."""
    x, y, coordinate_system = 1.2, 1.2, "lon_lat"
    location = Location(x=x, y=y, coordinate_system=coordinate_system)
    assert location.x == x, "location.x value not set correctly"
    # The failure message names the attribute actually being checked (y, not x)
    assert location.y == y, "location.y value not set correctly"
    assert (
        location.coordinate_system == coordinate_system
    ), "location.coordinate_system value not set correctly"
33
+
34
+
35
def test_make_invalid_location_object_with_invalid_osgb_x():
    """Building an "osgb" Location with x=10000000 raises a validation error."""
    invalid_kwargs = dict(x=10000000, y=1.2, coordinate_system="osgb")

    with pytest.raises(ValueError) as err:
        Location(**invalid_kwargs)

    assert err.typename == "ValidationError"
40
+
41
+
42
def test_make_invalid_location_object_with_invalid_osgb_y():
    """Building an "osgb" Location with y=10000000 raises a validation error."""
    invalid_kwargs = dict(x=2.5, y=10000000, coordinate_system="osgb")

    with pytest.raises(ValueError) as err:
        Location(**invalid_kwargs)

    assert err.typename == "ValidationError"
47
+
48
+
49
def test_make_invalid_location_object_with_invalid_lon_lat_x():
    """Building a "lon_lat" Location with x=200 raises a validation error."""
    invalid_kwargs = dict(x=200, y=1.2, coordinate_system="lon_lat")

    with pytest.raises(ValueError) as err:
        Location(**invalid_kwargs)

    assert err.typename == "ValidationError"
54
+
55
+
56
def test_make_invalid_location_object_with_invalid_lon_lat_y():
    """Building a "lon_lat" Location with y=-200 raises a validation error."""
    invalid_kwargs = dict(x=2.5, y=-200, coordinate_system="lon_lat")

    with pytest.raises(ValueError) as err:
        Location(**invalid_kwargs)

    assert err.typename == "ValidationError"
61
+
62
+
63
def test_make_invalid_location_object_with_invalid_coordinate_system():
    """Building a Location with coordinate system "abcd" raises a validation error."""
    invalid_kwargs = dict(x=2.5, y=1000, coordinate_system="abcd")

    with pytest.raises(ValueError) as err:
        Location(**invalid_kwargs)

    assert err.typename == "ValidationError"