ocf-data-sampler: ocf_data_sampler-0.1.10-py3-none-any.whl → ocf_data_sampler-0.1.11-py3-none-any.whl

This diff shows the contents of two publicly available versions of this package, as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions exactly as they appear in their public registry.

Potentially problematic release.


This version of ocf-data-sampler has been flagged as potentially problematic; see the package's registry page for more details.

@@ -16,7 +16,6 @@ from ocf_data_sampler.numpy_sample import (
16
16
  from ocf_data_sampler.sample.uk_regional import UKRegionalSample
17
17
 
18
18
 
19
- # Fixture define
20
19
  @pytest.fixture
21
20
  def pvnet_config_filename(tmp_path):
22
21
  """ Minimal config file - testing """
@@ -50,8 +49,8 @@ def pvnet_config_filename(tmp_path):
50
49
  config_file.write_text(config_content)
51
50
  return str(config_file)
52
51
 
53
-
54
- def create_test_data():
52
+ @pytest.fixture
53
+ def numpy_sample():
55
54
  """ Synthetic data generation """
56
55
 
57
56
  # Field / spatial coordinates
@@ -73,18 +72,8 @@ def create_test_data():
73
72
  }
74
73
 
75
74
 
76
- # UKRegionalSample testing
77
- def test_sample_init():
78
- """ Initialisation """
79
- sample = UKRegionalSample()
80
- assert hasattr(sample, '_data'), "Sample should have _data attribute"
81
- assert isinstance(sample._data, dict)
82
- assert len(sample._data) == 0
83
-
84
-
85
- def test_sample_save_load():
86
- sample = UKRegionalSample()
87
- sample._data = create_test_data()
75
+ def test_sample_save_load(numpy_sample):
76
+ sample = UKRegionalSample(numpy_sample)
88
77
 
89
78
  with tempfile.NamedTemporaryFile(suffix='.pt') as tf:
90
79
  sample.save(tf.name)
@@ -105,24 +94,6 @@ def test_sample_save_load():
105
94
  )
106
95
 
107
96
 
108
- def test_save_unsupported_format():
109
- """ Test saving - unsupported file format """
110
- sample = UKRegionalSample()
111
- sample._data = create_test_data()
112
-
113
- with tempfile.NamedTemporaryFile(suffix='.npz') as tf:
114
- with pytest.raises(ValueError, match="Only .pt format is supported"):
115
- sample.save(tf.name)
116
-
117
-
118
- def test_load_unsupported_format():
119
- """ Test loading - unsupported file format """
120
-
121
- with tempfile.NamedTemporaryFile(suffix='.npz') as tf:
122
- with pytest.raises(ValueError, match="Only .pt format is supported"):
123
- UKRegionalSample.load(tf.name)
124
-
125
-
126
97
  def test_load_corrupted_file():
127
98
  """ Test loading - corrupted / empty file """
128
99
 
@@ -136,8 +107,8 @@ def test_load_corrupted_file():
136
107
 
137
108
  def test_to_numpy():
138
109
  """ To numpy conversion check """
139
- sample = UKRegionalSample()
140
- sample._data = {
110
+
111
+ data = {
141
112
  'nwp': {
142
113
  'ukv': {
143
114
  'nwp': np.random.rand(4, 1, 2, 2),
@@ -150,6 +121,8 @@ def test_to_numpy():
150
121
  GSPSampleKey.solar_azimuth: np.random.rand(7),
151
122
  GSPSampleKey.solar_elevation: np.random.rand(7)
152
123
  }
124
+
125
+ sample = UKRegionalSample(data)
153
126
 
154
127
  numpy_data = sample.to_numpy()
155
128
 
@@ -12,24 +12,14 @@ from ocf_data_sampler.torch_datasets.datasets.pvnet_uk import (
12
12
  )
13
13
  from ocf_data_sampler.select.location import Location
14
14
 
15
- def test_process_and_combine_datasets(pvnet_config_filename):
16
15
 
17
- # Load in config for function and define location
16
+ def test_process_and_combine_datasets(pvnet_config_filename, ds_nwp_ukv_time_sliced):
17
+
18
18
  config = load_yaml_configuration(pvnet_config_filename)
19
+
19
20
  t0 = pd.Timestamp("2024-01-01 00:00")
20
21
  location = Location(coordinate_system="osgb", x=1234, y=5678, id=1)
21
22
 
22
- nwp_data = xr.DataArray(
23
- np.random.rand(4, 2, 2, 2),
24
- dims=["time_utc", "channel", "y", "x"],
25
- coords={
26
- "time_utc": pd.date_range("2024-01-01 00:00", periods=4, freq="h"),
27
- "channel": ["t", "dswrf"],
28
- "step": ("time_utc", pd.timedelta_range(start='0h', periods=4, freq='h')),
29
- "init_time_utc": pd.Timestamp("2024-01-01 00:00")
30
- }
31
- )
32
-
33
23
  sat_data = xr.DataArray(
34
24
  np.random.rand(7, 1, 2, 2),
35
25
  dims=["time_utc", "channel", "y", "x"],
@@ -41,20 +31,17 @@ def test_process_and_combine_datasets(pvnet_config_filename):
41
31
  }
42
32
  )
43
33
 
44
- # Combine as dict
45
34
  dataset_dict = {
46
- "nwp": {"ukv": nwp_data},
35
+ "nwp": {"ukv": ds_nwp_ukv_time_sliced},
47
36
  "sat": sat_data
48
37
  }
49
38
 
50
- # Call relevant function
51
39
  sample = process_and_combine_datasets(dataset_dict, config, t0, location)
52
40
 
53
- # Assert result is dict - check and validate
54
41
  assert isinstance(sample, dict)
55
42
  assert "nwp" in sample
56
- assert sample["satellite_actual"].shape == (7, 1, 2, 2)
57
- assert sample["nwp"]["ukv"]["nwp"].shape == (4, 2, 2, 2)
43
+ assert sample["satellite_actual"].shape == sat_data.shape
44
+ assert sample["nwp"]["ukv"]["nwp"].shape == ds_nwp_ukv_time_sliced.shape
58
45
  assert "gsp_id" in sample
59
46
 
60
47