ocf-data-sampler 0.1.9__py3-none-any.whl → 0.1.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ocf-data-sampler might be problematic. Click here for more details.
- ocf_data_sampler/config/model.py +25 -23
- ocf_data_sampler/load/satellite.py +21 -29
- ocf_data_sampler/load/site.py +1 -1
- ocf_data_sampler/numpy_sample/gsp.py +6 -2
- ocf_data_sampler/numpy_sample/nwp.py +7 -13
- ocf_data_sampler/numpy_sample/satellite.py +11 -8
- ocf_data_sampler/numpy_sample/site.py +6 -2
- ocf_data_sampler/numpy_sample/sun_position.py +9 -10
- ocf_data_sampler/sample/__init__.py +0 -7
- ocf_data_sampler/sample/base.py +16 -35
- ocf_data_sampler/sample/site.py +28 -65
- ocf_data_sampler/sample/uk_regional.py +52 -97
- ocf_data_sampler/select/dropout.py +38 -25
- ocf_data_sampler/select/fill_time_periods.py +3 -1
- ocf_data_sampler/select/find_contiguous_time_periods.py +0 -1
- ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py +2 -3
- ocf_data_sampler/torch_datasets/datasets/site.py +9 -5
- {ocf_data_sampler-0.1.9.dist-info → ocf_data_sampler-0.1.11.dist-info}/METADATA +1 -1
- {ocf_data_sampler-0.1.9.dist-info → ocf_data_sampler-0.1.11.dist-info}/RECORD +29 -29
- tests/config/test_config.py +3 -3
- tests/conftest.py +33 -0
- tests/numpy_sample/test_nwp.py +3 -42
- tests/select/test_dropout.py +7 -13
- tests/test_sample/test_site_sample.py +5 -35
- tests/test_sample/test_uk_regional_sample.py +8 -35
- tests/torch_datasets/test_pvnet_uk.py +6 -19
- {ocf_data_sampler-0.1.9.dist-info → ocf_data_sampler-0.1.11.dist-info}/LICENSE +0 -0
- {ocf_data_sampler-0.1.9.dist-info → ocf_data_sampler-0.1.11.dist-info}/WHEEL +0 -0
- {ocf_data_sampler-0.1.9.dist-info → ocf_data_sampler-0.1.11.dist-info}/top_level.txt +0 -0
|
@@ -74,17 +74,9 @@ def sample_data():
|
|
|
74
74
|
)
|
|
75
75
|
|
|
76
76
|
|
|
77
|
-
def test_site_sample_init():
|
|
78
|
-
""" Test initialisation """
|
|
79
|
-
sample = SiteSample()
|
|
80
|
-
assert isinstance(sample._data, dict)
|
|
81
|
-
assert len(sample._data) == 0
|
|
82
|
-
|
|
83
|
-
|
|
84
77
|
def test_site_sample_with_data(sample_data):
|
|
85
78
|
""" Testing of defined sample with actual data """
|
|
86
|
-
sample = SiteSample()
|
|
87
|
-
sample._data = sample_data
|
|
79
|
+
sample = SiteSample(sample_data)
|
|
88
80
|
|
|
89
81
|
# Assert data structure
|
|
90
82
|
assert isinstance(sample._data, Dataset)
|
|
@@ -109,8 +101,7 @@ def test_site_sample_with_data(sample_data):
|
|
|
109
101
|
|
|
110
102
|
def test_save_load(tmp_path, sample_data):
|
|
111
103
|
""" Save and load functionality """
|
|
112
|
-
sample = SiteSample()
|
|
113
|
-
sample._data = sample_data
|
|
104
|
+
sample = SiteSample(sample_data)
|
|
114
105
|
filepath = tmp_path / "test_sample.nc"
|
|
115
106
|
sample.save(filepath)
|
|
116
107
|
|
|
@@ -127,36 +118,16 @@ def test_save_load(tmp_path, sample_data):
|
|
|
127
118
|
xr.testing.assert_identical(sample._data, loaded._data)
|
|
128
119
|
|
|
129
120
|
|
|
130
|
-
def test_invalid_save_format(sample_data):
|
|
131
|
-
""" Saving with invalid format """
|
|
132
|
-
sample = SiteSample()
|
|
133
|
-
sample._data = sample_data
|
|
134
|
-
with pytest.raises(ValueError, match="Only .nc format is supported"):
|
|
135
|
-
sample.save("invalid.txt")
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
def test_invalid_load_format():
|
|
139
|
-
""" Loading with invalid format """
|
|
140
|
-
with pytest.raises(ValueError, match="Only .nc format is supported"):
|
|
141
|
-
SiteSample.load("invalid.txt")
|
|
142
|
-
|
|
143
|
-
|
|
144
121
|
def test_invalid_data_type():
|
|
145
122
|
""" Handling of invalid data types """
|
|
146
|
-
sample = SiteSample()
|
|
147
|
-
sample._data = {"invalid": "data"}
|
|
148
123
|
|
|
149
124
|
with pytest.raises(TypeError, match="Data must be xarray Dataset"):
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
with pytest.raises(TypeError, match="Data must be xarray Dataset for saving"):
|
|
153
|
-
sample.save("test.nc")
|
|
125
|
+
_ = SiteSample({"invalid": "data"})
|
|
154
126
|
|
|
155
127
|
|
|
156
128
|
def test_to_numpy(sample_data):
|
|
157
129
|
""" To numpy conversion """
|
|
158
|
-
sample = SiteSample()
|
|
159
|
-
sample._data = sample_data
|
|
130
|
+
sample = SiteSample(sample_data)
|
|
160
131
|
numpy_data = sample.to_numpy()
|
|
161
132
|
|
|
162
133
|
# Assert structure
|
|
@@ -180,8 +151,7 @@ def test_to_numpy(sample_data):
|
|
|
180
151
|
|
|
181
152
|
def test_data_consistency(sample_data):
|
|
182
153
|
""" Consistency of data across operations """
|
|
183
|
-
sample = SiteSample()
|
|
184
|
-
sample._data = sample_data
|
|
154
|
+
sample = SiteSample(sample_data)
|
|
185
155
|
numpy_data = sample.to_numpy()
|
|
186
156
|
|
|
187
157
|
# Assert components remain consistent after conversion above
|
|
@@ -16,7 +16,6 @@ from ocf_data_sampler.numpy_sample import (
|
|
|
16
16
|
from ocf_data_sampler.sample.uk_regional import UKRegionalSample
|
|
17
17
|
|
|
18
18
|
|
|
19
|
-
# Fixture define
|
|
20
19
|
@pytest.fixture
|
|
21
20
|
def pvnet_config_filename(tmp_path):
|
|
22
21
|
""" Minimal config file - testing """
|
|
@@ -50,8 +49,8 @@ def pvnet_config_filename(tmp_path):
|
|
|
50
49
|
config_file.write_text(config_content)
|
|
51
50
|
return str(config_file)
|
|
52
51
|
|
|
53
|
-
|
|
54
|
-
def
|
|
52
|
+
@pytest.fixture
|
|
53
|
+
def numpy_sample():
|
|
55
54
|
""" Synthetic data generation """
|
|
56
55
|
|
|
57
56
|
# Field / spatial coordinates
|
|
@@ -73,18 +72,8 @@ def create_test_data():
|
|
|
73
72
|
}
|
|
74
73
|
|
|
75
74
|
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
""" Initialisation """
|
|
79
|
-
sample = UKRegionalSample()
|
|
80
|
-
assert hasattr(sample, '_data'), "Sample should have _data attribute"
|
|
81
|
-
assert isinstance(sample._data, dict)
|
|
82
|
-
assert len(sample._data) == 0
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
def test_sample_save_load():
|
|
86
|
-
sample = UKRegionalSample()
|
|
87
|
-
sample._data = create_test_data()
|
|
75
|
+
def test_sample_save_load(numpy_sample):
|
|
76
|
+
sample = UKRegionalSample(numpy_sample)
|
|
88
77
|
|
|
89
78
|
with tempfile.NamedTemporaryFile(suffix='.pt') as tf:
|
|
90
79
|
sample.save(tf.name)
|
|
@@ -105,24 +94,6 @@ def test_sample_save_load():
|
|
|
105
94
|
)
|
|
106
95
|
|
|
107
96
|
|
|
108
|
-
def test_save_unsupported_format():
|
|
109
|
-
""" Test saving - unsupported file format """
|
|
110
|
-
sample = UKRegionalSample()
|
|
111
|
-
sample._data = create_test_data()
|
|
112
|
-
|
|
113
|
-
with tempfile.NamedTemporaryFile(suffix='.npz') as tf:
|
|
114
|
-
with pytest.raises(ValueError, match="Only .pt format is supported"):
|
|
115
|
-
sample.save(tf.name)
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
def test_load_unsupported_format():
|
|
119
|
-
""" Test loading - unsupported file format """
|
|
120
|
-
|
|
121
|
-
with tempfile.NamedTemporaryFile(suffix='.npz') as tf:
|
|
122
|
-
with pytest.raises(ValueError, match="Only .pt format is supported"):
|
|
123
|
-
UKRegionalSample.load(tf.name)
|
|
124
|
-
|
|
125
|
-
|
|
126
97
|
def test_load_corrupted_file():
|
|
127
98
|
""" Test loading - corrupted / empty file """
|
|
128
99
|
|
|
@@ -136,8 +107,8 @@ def test_load_corrupted_file():
|
|
|
136
107
|
|
|
137
108
|
def test_to_numpy():
|
|
138
109
|
""" To numpy conversion check """
|
|
139
|
-
|
|
140
|
-
|
|
110
|
+
|
|
111
|
+
data = {
|
|
141
112
|
'nwp': {
|
|
142
113
|
'ukv': {
|
|
143
114
|
'nwp': np.random.rand(4, 1, 2, 2),
|
|
@@ -150,6 +121,8 @@ def test_to_numpy():
|
|
|
150
121
|
GSPSampleKey.solar_azimuth: np.random.rand(7),
|
|
151
122
|
GSPSampleKey.solar_elevation: np.random.rand(7)
|
|
152
123
|
}
|
|
124
|
+
|
|
125
|
+
sample = UKRegionalSample(data)
|
|
153
126
|
|
|
154
127
|
numpy_data = sample.to_numpy()
|
|
155
128
|
|
|
@@ -12,24 +12,14 @@ from ocf_data_sampler.torch_datasets.datasets.pvnet_uk import (
|
|
|
12
12
|
)
|
|
13
13
|
from ocf_data_sampler.select.location import Location
|
|
14
14
|
|
|
15
|
-
def test_process_and_combine_datasets(pvnet_config_filename):
|
|
16
15
|
|
|
17
|
-
|
|
16
|
+
def test_process_and_combine_datasets(pvnet_config_filename, ds_nwp_ukv_time_sliced):
|
|
17
|
+
|
|
18
18
|
config = load_yaml_configuration(pvnet_config_filename)
|
|
19
|
+
|
|
19
20
|
t0 = pd.Timestamp("2024-01-01 00:00")
|
|
20
21
|
location = Location(coordinate_system="osgb", x=1234, y=5678, id=1)
|
|
21
22
|
|
|
22
|
-
nwp_data = xr.DataArray(
|
|
23
|
-
np.random.rand(4, 2, 2, 2),
|
|
24
|
-
dims=["time_utc", "channel", "y", "x"],
|
|
25
|
-
coords={
|
|
26
|
-
"time_utc": pd.date_range("2024-01-01 00:00", periods=4, freq="h"),
|
|
27
|
-
"channel": ["t", "dswrf"],
|
|
28
|
-
"step": ("time_utc", pd.timedelta_range(start='0h', periods=4, freq='h')),
|
|
29
|
-
"init_time_utc": pd.Timestamp("2024-01-01 00:00")
|
|
30
|
-
}
|
|
31
|
-
)
|
|
32
|
-
|
|
33
23
|
sat_data = xr.DataArray(
|
|
34
24
|
np.random.rand(7, 1, 2, 2),
|
|
35
25
|
dims=["time_utc", "channel", "y", "x"],
|
|
@@ -41,20 +31,17 @@ def test_process_and_combine_datasets(pvnet_config_filename):
|
|
|
41
31
|
}
|
|
42
32
|
)
|
|
43
33
|
|
|
44
|
-
# Combine as dict
|
|
45
34
|
dataset_dict = {
|
|
46
|
-
"nwp": {"ukv":
|
|
35
|
+
"nwp": {"ukv": ds_nwp_ukv_time_sliced},
|
|
47
36
|
"sat": sat_data
|
|
48
37
|
}
|
|
49
38
|
|
|
50
|
-
# Call relevant function
|
|
51
39
|
sample = process_and_combine_datasets(dataset_dict, config, t0, location)
|
|
52
40
|
|
|
53
|
-
# Assert result is dict - check and validate
|
|
54
41
|
assert isinstance(sample, dict)
|
|
55
42
|
assert "nwp" in sample
|
|
56
|
-
assert sample["satellite_actual"].shape ==
|
|
57
|
-
assert sample["nwp"]["ukv"]["nwp"].shape ==
|
|
43
|
+
assert sample["satellite_actual"].shape == sat_data.shape
|
|
44
|
+
assert sample["nwp"]["ukv"]["nwp"].shape == ds_nwp_ukv_time_sliced.shape
|
|
58
45
|
assert "gsp_id" in sample
|
|
59
46
|
|
|
60
47
|
|
|
File without changes
|
|
File without changes
|
|
File without changes
|