ocf-data-sampler 0.1.9__py3-none-any.whl → 0.1.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ocf-data-sampler might be problematic. Click here for more details.

Files changed (29) hide show
  1. ocf_data_sampler/config/model.py +25 -23
  2. ocf_data_sampler/load/satellite.py +21 -29
  3. ocf_data_sampler/load/site.py +1 -1
  4. ocf_data_sampler/numpy_sample/gsp.py +6 -2
  5. ocf_data_sampler/numpy_sample/nwp.py +7 -13
  6. ocf_data_sampler/numpy_sample/satellite.py +11 -8
  7. ocf_data_sampler/numpy_sample/site.py +6 -2
  8. ocf_data_sampler/numpy_sample/sun_position.py +9 -10
  9. ocf_data_sampler/sample/__init__.py +0 -7
  10. ocf_data_sampler/sample/base.py +16 -35
  11. ocf_data_sampler/sample/site.py +28 -65
  12. ocf_data_sampler/sample/uk_regional.py +52 -97
  13. ocf_data_sampler/select/dropout.py +38 -25
  14. ocf_data_sampler/select/fill_time_periods.py +3 -1
  15. ocf_data_sampler/select/find_contiguous_time_periods.py +0 -1
  16. ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py +2 -3
  17. ocf_data_sampler/torch_datasets/datasets/site.py +9 -5
  18. {ocf_data_sampler-0.1.9.dist-info → ocf_data_sampler-0.1.11.dist-info}/METADATA +1 -1
  19. {ocf_data_sampler-0.1.9.dist-info → ocf_data_sampler-0.1.11.dist-info}/RECORD +29 -29
  20. tests/config/test_config.py +3 -3
  21. tests/conftest.py +33 -0
  22. tests/numpy_sample/test_nwp.py +3 -42
  23. tests/select/test_dropout.py +7 -13
  24. tests/test_sample/test_site_sample.py +5 -35
  25. tests/test_sample/test_uk_regional_sample.py +8 -35
  26. tests/torch_datasets/test_pvnet_uk.py +6 -19
  27. {ocf_data_sampler-0.1.9.dist-info → ocf_data_sampler-0.1.11.dist-info}/LICENSE +0 -0
  28. {ocf_data_sampler-0.1.9.dist-info → ocf_data_sampler-0.1.11.dist-info}/WHEEL +0 -0
  29. {ocf_data_sampler-0.1.9.dist-info → ocf_data_sampler-0.1.11.dist-info}/top_level.txt +0 -0
@@ -74,17 +74,9 @@ def sample_data():
74
74
  )
75
75
 
76
76
 
77
- def test_site_sample_init():
78
- """ Test initialisation """
79
- sample = SiteSample()
80
- assert isinstance(sample._data, dict)
81
- assert len(sample._data) == 0
82
-
83
-
84
77
  def test_site_sample_with_data(sample_data):
85
78
  """ Testing of defined sample with actual data """
86
- sample = SiteSample()
87
- sample._data = sample_data
79
+ sample = SiteSample(sample_data)
88
80
 
89
81
  # Assert data structure
90
82
  assert isinstance(sample._data, Dataset)
@@ -109,8 +101,7 @@ def test_site_sample_with_data(sample_data):
109
101
 
110
102
  def test_save_load(tmp_path, sample_data):
111
103
  """ Save and load functionality """
112
- sample = SiteSample()
113
- sample._data = sample_data
104
+ sample = SiteSample(sample_data)
114
105
  filepath = tmp_path / "test_sample.nc"
115
106
  sample.save(filepath)
116
107
 
@@ -127,36 +118,16 @@ def test_save_load(tmp_path, sample_data):
127
118
  xr.testing.assert_identical(sample._data, loaded._data)
128
119
 
129
120
 
130
- def test_invalid_save_format(sample_data):
131
- """ Saving with invalid format """
132
- sample = SiteSample()
133
- sample._data = sample_data
134
- with pytest.raises(ValueError, match="Only .nc format is supported"):
135
- sample.save("invalid.txt")
136
-
137
-
138
- def test_invalid_load_format():
139
- """ Loading with invalid format """
140
- with pytest.raises(ValueError, match="Only .nc format is supported"):
141
- SiteSample.load("invalid.txt")
142
-
143
-
144
121
  def test_invalid_data_type():
145
122
  """ Handling of invalid data types """
146
- sample = SiteSample()
147
- sample._data = {"invalid": "data"}
148
123
 
149
124
  with pytest.raises(TypeError, match="Data must be xarray Dataset"):
150
- sample.to_numpy()
151
-
152
- with pytest.raises(TypeError, match="Data must be xarray Dataset for saving"):
153
- sample.save("test.nc")
125
+ _ = SiteSample({"invalid": "data"})
154
126
 
155
127
 
156
128
  def test_to_numpy(sample_data):
157
129
  """ To numpy conversion """
158
- sample = SiteSample()
159
- sample._data = sample_data
130
+ sample = SiteSample(sample_data)
160
131
  numpy_data = sample.to_numpy()
161
132
 
162
133
  # Assert structure
@@ -180,8 +151,7 @@ def test_to_numpy(sample_data):
180
151
 
181
152
  def test_data_consistency(sample_data):
182
153
  """ Consistency of data across operations """
183
- sample = SiteSample()
184
- sample._data = sample_data
154
+ sample = SiteSample(sample_data)
185
155
  numpy_data = sample.to_numpy()
186
156
 
187
157
  # Assert components remain consistent after conversion above
@@ -16,7 +16,6 @@ from ocf_data_sampler.numpy_sample import (
16
16
  from ocf_data_sampler.sample.uk_regional import UKRegionalSample
17
17
 
18
18
 
19
- # Fixture define
20
19
  @pytest.fixture
21
20
  def pvnet_config_filename(tmp_path):
22
21
  """ Minimal config file - testing """
@@ -50,8 +49,8 @@ def pvnet_config_filename(tmp_path):
50
49
  config_file.write_text(config_content)
51
50
  return str(config_file)
52
51
 
53
-
54
- def create_test_data():
52
+ @pytest.fixture
53
+ def numpy_sample():
55
54
  """ Synthetic data generation """
56
55
 
57
56
  # Field / spatial coordinates
@@ -73,18 +72,8 @@ def create_test_data():
73
72
  }
74
73
 
75
74
 
76
- # UKRegionalSample testing
77
- def test_sample_init():
78
- """ Initialisation """
79
- sample = UKRegionalSample()
80
- assert hasattr(sample, '_data'), "Sample should have _data attribute"
81
- assert isinstance(sample._data, dict)
82
- assert len(sample._data) == 0
83
-
84
-
85
- def test_sample_save_load():
86
- sample = UKRegionalSample()
87
- sample._data = create_test_data()
75
+ def test_sample_save_load(numpy_sample):
76
+ sample = UKRegionalSample(numpy_sample)
88
77
 
89
78
  with tempfile.NamedTemporaryFile(suffix='.pt') as tf:
90
79
  sample.save(tf.name)
@@ -105,24 +94,6 @@ def test_sample_save_load():
105
94
  )
106
95
 
107
96
 
108
- def test_save_unsupported_format():
109
- """ Test saving - unsupported file format """
110
- sample = UKRegionalSample()
111
- sample._data = create_test_data()
112
-
113
- with tempfile.NamedTemporaryFile(suffix='.npz') as tf:
114
- with pytest.raises(ValueError, match="Only .pt format is supported"):
115
- sample.save(tf.name)
116
-
117
-
118
- def test_load_unsupported_format():
119
- """ Test loading - unsupported file format """
120
-
121
- with tempfile.NamedTemporaryFile(suffix='.npz') as tf:
122
- with pytest.raises(ValueError, match="Only .pt format is supported"):
123
- UKRegionalSample.load(tf.name)
124
-
125
-
126
97
  def test_load_corrupted_file():
127
98
  """ Test loading - corrupted / empty file """
128
99
 
@@ -136,8 +107,8 @@ def test_load_corrupted_file():
136
107
 
137
108
  def test_to_numpy():
138
109
  """ To numpy conversion check """
139
- sample = UKRegionalSample()
140
- sample._data = {
110
+
111
+ data = {
141
112
  'nwp': {
142
113
  'ukv': {
143
114
  'nwp': np.random.rand(4, 1, 2, 2),
@@ -150,6 +121,8 @@ def test_to_numpy():
150
121
  GSPSampleKey.solar_azimuth: np.random.rand(7),
151
122
  GSPSampleKey.solar_elevation: np.random.rand(7)
152
123
  }
124
+
125
+ sample = UKRegionalSample(data)
153
126
 
154
127
  numpy_data = sample.to_numpy()
155
128
 
@@ -12,24 +12,14 @@ from ocf_data_sampler.torch_datasets.datasets.pvnet_uk import (
12
12
  )
13
13
  from ocf_data_sampler.select.location import Location
14
14
 
15
- def test_process_and_combine_datasets(pvnet_config_filename):
16
15
 
17
- # Load in config for function and define location
16
+ def test_process_and_combine_datasets(pvnet_config_filename, ds_nwp_ukv_time_sliced):
17
+
18
18
  config = load_yaml_configuration(pvnet_config_filename)
19
+
19
20
  t0 = pd.Timestamp("2024-01-01 00:00")
20
21
  location = Location(coordinate_system="osgb", x=1234, y=5678, id=1)
21
22
 
22
- nwp_data = xr.DataArray(
23
- np.random.rand(4, 2, 2, 2),
24
- dims=["time_utc", "channel", "y", "x"],
25
- coords={
26
- "time_utc": pd.date_range("2024-01-01 00:00", periods=4, freq="h"),
27
- "channel": ["t", "dswrf"],
28
- "step": ("time_utc", pd.timedelta_range(start='0h', periods=4, freq='h')),
29
- "init_time_utc": pd.Timestamp("2024-01-01 00:00")
30
- }
31
- )
32
-
33
23
  sat_data = xr.DataArray(
34
24
  np.random.rand(7, 1, 2, 2),
35
25
  dims=["time_utc", "channel", "y", "x"],
@@ -41,20 +31,17 @@ def test_process_and_combine_datasets(pvnet_config_filename):
41
31
  }
42
32
  )
43
33
 
44
- # Combine as dict
45
34
  dataset_dict = {
46
- "nwp": {"ukv": nwp_data},
35
+ "nwp": {"ukv": ds_nwp_ukv_time_sliced},
47
36
  "sat": sat_data
48
37
  }
49
38
 
50
- # Call relevant function
51
39
  sample = process_and_combine_datasets(dataset_dict, config, t0, location)
52
40
 
53
- # Assert result is dict - check and validate
54
41
  assert isinstance(sample, dict)
55
42
  assert "nwp" in sample
56
- assert sample["satellite_actual"].shape == (7, 1, 2, 2)
57
- assert sample["nwp"]["ukv"]["nwp"].shape == (4, 2, 2, 2)
43
+ assert sample["satellite_actual"].shape == sat_data.shape
44
+ assert sample["nwp"]["ukv"]["nwp"].shape == ds_nwp_ukv_time_sliced.shape
58
45
  assert "gsp_id" in sample
59
46
 
60
47