ocf-data-sampler 0.2.12__py3-none-any.whl → 0.2.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ocf-data-sampler might be problematic. Click here for more details.

@@ -1,4 +1,3 @@
1
- import ocf_blosc2
2
1
  from ocf_data_sampler.load.gsp import open_gsp
3
2
  from ocf_data_sampler.load.nwp import open_nwp
4
3
  from ocf_data_sampler.load.satellite import open_sat_data
@@ -3,14 +3,18 @@
3
3
  import torch
4
4
  from typing_extensions import override
5
5
 
6
+ from ocf_data_sampler.config import Configuration
6
7
  from ocf_data_sampler.numpy_sample import (
7
8
  GSPSampleKey,
8
9
  NWPSampleKey,
9
10
  SatelliteSampleKey,
10
11
  )
11
12
  from ocf_data_sampler.numpy_sample.common_types import NumpySample
12
-
13
- from .base import SampleBase
13
+ from ocf_data_sampler.torch_datasets.sample.base import SampleBase
14
+ from ocf_data_sampler.torch_datasets.utils.validation_utils import (
15
+ calculate_expected_shapes,
16
+ check_dimensions,
17
+ )
14
18
 
15
19
 
16
20
  class UKRegionalSample(SampleBase):
@@ -22,22 +26,123 @@ class UKRegionalSample(SampleBase):
22
26
 
23
27
  @override
24
28
  def to_numpy(self) -> NumpySample:
29
+ """Returns the data as a NumPy sample."""
25
30
  return self._data
26
31
 
27
32
  @override
28
33
  def save(self, path: str) -> None:
34
+ """Saves sample to the specified path in pickle format."""
29
35
  # Saves to pickle format
30
36
  torch.save(self._data, path)
31
37
 
32
38
  @classmethod
33
39
  @override
34
40
  def load(cls, path: str) -> "UKRegionalSample":
41
+ """Loads sample from the specified path.
42
+
43
+ Args:
44
+ path: Path to the saved sample file.
45
+
46
+ Returns:
47
+ A UKRegionalSample instance with the loaded data.
48
+ """
35
49
  # Loads from .pt format
36
50
  # TODO: We should move away from using torch.load(..., weights_only=False)
37
51
  return cls(torch.load(path, weights_only=False))
38
52
 
53
+ def validate_sample(self, config: Configuration) -> bool:
54
+ """Validates that the sample has the expected structure and data shapes.
55
+
56
+ Args:
57
+ config: Configuration dict with expected shapes and required fields.
58
+
59
+ Returns:
60
+ bool: True if validation passes, otherwise raises an exception.
61
+ """
62
+ if not isinstance(config, Configuration):
63
+ raise TypeError("config must be Configuration object")
64
+
65
+ # Calculate expected shapes from configuration
66
+ expected_shapes = calculate_expected_shapes(config)
67
+
68
+ # Check GSP shape if specified
69
+ gsp_key = GSPSampleKey.gsp
70
+ # Check if GSP data is expected but missing
71
+ if gsp_key in expected_shapes and gsp_key not in self._data:
72
+ raise ValueError(f"Configuration expects GSP data ('{gsp_key}') but is missing.")
73
+
74
+ # Check GSP shape if data exists and is expected
75
+ if gsp_key in self._data and gsp_key in expected_shapes:
76
+ gsp_data = self._data[gsp_key]
77
+ check_dimensions(
78
+ actual_shape=gsp_data.shape,
79
+ expected_shape=expected_shapes[gsp_key],
80
+ name="GSP",
81
+ )
82
+
83
+ # Checks for NWP data - nested structure
84
+ nwp_key = NWPSampleKey.nwp
85
+ if nwp_key in expected_shapes and nwp_key not in self._data:
86
+ raise ValueError(f"Configuration expects NWP data ('{nwp_key}') but is missing.")
87
+
88
+ # Check NWP structure and shapes if data exists
89
+ if nwp_key in self._data:
90
+ nwp_data_all_providers = self._data[nwp_key]
91
+ if not isinstance(nwp_data_all_providers, dict):
92
+ raise ValueError(f"NWP data ('{nwp_key}') should be a dictionary.")
93
+
94
+ # Loop through providers present in actual data
95
+ for provider, provider_data in nwp_data_all_providers.items():
96
+ if "nwp" not in provider_data:
97
+ raise ValueError(f"Missing array key in NWP data for provider '{provider}'.")
98
+
99
+ if nwp_key in expected_shapes and provider in expected_shapes[nwp_key]:
100
+ nwp_array = provider_data["nwp"]
101
+ actual_shape = nwp_array.shape
102
+ expected_shape = expected_shapes[nwp_key][provider]
103
+
104
+ check_dimensions(
105
+ actual_shape=actual_shape,
106
+ expected_shape=expected_shape,
107
+ name=f"NWP data ({provider})",
108
+ )
109
+
110
+ # Validate satellite data
111
+ sat_key = SatelliteSampleKey.satellite_actual
112
+ # Check if Satellite data is expected but missing
113
+ if sat_key in expected_shapes and sat_key not in self._data:
114
+ raise ValueError(f"Configuration expects Satellite data ('{sat_key}') but is missing.")
115
+
116
+ # Check satellite shape if data exists and is expected
117
+ if sat_key in self._data and sat_key in expected_shapes:
118
+ sat_data = self._data[sat_key]
119
+ check_dimensions(
120
+ actual_shape=sat_data.shape,
121
+ expected_shape=expected_shapes[sat_key],
122
+ name="Satellite data",
123
+ )
124
+
125
+ # Validate solar coordinates data
126
+ solar_keys = ["solar_azimuth", "solar_elevation"]
127
+ # Check if solar coordinate is expected but missing
128
+ for solar_key in solar_keys:
129
+ if solar_key in expected_shapes and solar_key not in self._data:
130
+ raise ValueError(f"Configuration expects {solar_key} data but is missing.")
131
+
132
+ # Check solar coordinate shape if data exists and is expected
133
+ if solar_key in self._data and solar_key in expected_shapes:
134
+ solar_data = self._data[solar_key]
135
+ check_dimensions(
136
+ actual_shape=solar_data.shape,
137
+ expected_shape=expected_shapes[solar_key],
138
+ name=f"{solar_key.replace('_', ' ').title()} data",
139
+ )
140
+
141
+ return True
142
+
39
143
  @override
40
144
  def plot(self) -> None:
145
+ """Plots the sample data for visualization."""
41
146
  from matplotlib import pyplot as plt
42
147
 
43
148
  fig, axes = plt.subplots(2, 2, figsize=(12, 8))
@@ -58,10 +163,10 @@ class UKRegionalSample(SampleBase):
58
163
  axes[0, 0].set_title("GSP Generation")
59
164
 
60
165
  if "solar_azimuth" in self._data and "solar_elevation" in self._data:
61
- axes[1, 1].plot(self._data["solar_azimuth"], label="Azimuth")
62
- axes[1, 1].plot(self._data["solar_elevation"], label="Elevation")
63
- axes[1, 1].set_title("Solar Position")
64
- axes[1, 1].legend()
166
+ axes[1, 1].plot(self._data["solar_azimuth"], label="Azimuth")
167
+ axes[1, 1].plot(self._data["solar_elevation"], label="Elevation")
168
+ axes[1, 1].set_title("Solar Position")
169
+ axes[1, 1].legend()
65
170
 
66
171
  if SatelliteSampleKey.satellite_actual in self._data:
67
172
  axes[1, 0].imshow(self._data[SatelliteSampleKey.satellite_actual])
@@ -0,0 +1,108 @@
1
+ """Validate sample shape against expected shape - utility function."""
2
+
3
+ from ocf_data_sampler.config import Configuration
4
+ from ocf_data_sampler.numpy_sample import GSPSampleKey, NWPSampleKey, SatelliteSampleKey
5
+
6
+
7
+ def check_dimensions(
8
+ actual_shape: tuple[int, ...],
9
+ expected_shape: tuple[int, ...],
10
+ name: str,
11
+ ) -> None:
12
+ """Check if dimensions match between actual and expected shapes.
13
+
14
+ Args:
15
+ actual_shape: The actual shape of the data (e.g., array.shape).
16
+ expected_shape: The expected shape.
17
+ name: Name of the data component for clear error messages.
18
+
19
+ Raises:
20
+ ValueError: If dimensions don't match.
21
+ """
22
+ if actual_shape != expected_shape:
23
+ raise ValueError(
24
+ f"'{name}' shape mismatch: "
25
+ f"Actual shape: {actual_shape}, Expected shape: {expected_shape}",
26
+ )
27
+
28
+
29
+ def calculate_expected_shapes(
30
+ config: Configuration,
31
+ ) -> dict[str, tuple[int, ...]]:
32
+ """Calculate expected shapes from configuration.
33
+
34
+ Args:
35
+ config: Configuration object with shape information.
36
+
37
+ Returns:
38
+ Dictionary mapping data keys to their expected shapes.
39
+ """
40
+ expected_shapes = {}
41
+ input_data = config.input_data
42
+
43
+ # Calculate GSP shape
44
+ gsp_config = input_data.gsp
45
+ expected_shapes[GSPSampleKey.gsp] = (
46
+ _calculate_time_steps(
47
+ gsp_config.interval_start_minutes,
48
+ gsp_config.interval_end_minutes,
49
+ gsp_config.time_resolution_minutes,
50
+ ),
51
+ )
52
+
53
+ # Calculate NWP shape for multiple providers
54
+ expected_shapes[NWPSampleKey.nwp] = {}
55
+ for provider_key, provider_config in input_data.nwp.items():
56
+ expected_shapes[NWPSampleKey.nwp][provider_key] = (
57
+ _calculate_time_steps(
58
+ provider_config.interval_start_minutes,
59
+ provider_config.interval_end_minutes,
60
+ provider_config.time_resolution_minutes,
61
+ ),
62
+ len(provider_config.channels),
63
+ provider_config.image_size_pixels_height,
64
+ provider_config.image_size_pixels_width,
65
+ )
66
+
67
+ # Calculate satellite shape
68
+ sat_config = input_data.satellite
69
+ expected_shapes[SatelliteSampleKey.satellite_actual] = (
70
+ _calculate_time_steps(
71
+ sat_config.interval_start_minutes,
72
+ sat_config.interval_end_minutes,
73
+ sat_config.time_resolution_minutes,
74
+ ),
75
+ len(sat_config.channels),
76
+ sat_config.image_size_pixels_height,
77
+ sat_config.image_size_pixels_width,
78
+ )
79
+
80
+ # Calculate solar coordinates shapes
81
+ solar_config = input_data.solar_position
82
+ # For solar azimuth
83
+ expected_shapes["solar_azimuth"] = (
84
+ _calculate_time_steps(
85
+ solar_config.interval_start_minutes,
86
+ solar_config.interval_end_minutes,
87
+ solar_config.time_resolution_minutes,
88
+ ),
89
+ )
90
+ # For solar elevation
91
+ expected_shapes["solar_elevation"] = expected_shapes["solar_azimuth"]
92
+
93
+ return expected_shapes
94
+
95
+
96
+ def _calculate_time_steps(start_minutes: int, end_minutes: int, resolution_minutes: int) -> int:
97
+ """Calculate number of time steps based on interval and resolution.
98
+
99
+ Args:
100
+ start_minutes: Start of interval in minutes
101
+ end_minutes: End of interval in minutes
102
+ resolution_minutes: Time resolution in minutes
103
+
104
+ Returns:
105
+ Number of time steps
106
+ """
107
+ time_span = end_minutes - start_minutes
108
+ return (time_span // resolution_minutes) + 1
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ocf-data-sampler
3
- Version: 0.2.12
3
+ Version: 0.2.14
4
4
  Author: James Fulton, Peter Dudfield
5
5
  Author-email: Open Climate Fix team <info@openclimatefix.org>
6
6
  License: MIT License
@@ -38,11 +38,9 @@ Requires-Dist: zarr==2.18.3
38
38
  Requires-Dist: numcodecs==0.13.1
39
39
  Requires-Dist: dask
40
40
  Requires-Dist: matplotlib
41
- Requires-Dist: ocf_blosc2
42
41
  Requires-Dist: pvlib
43
42
  Requires-Dist: pydantic
44
43
  Requires-Dist: pyproj
45
- Requires-Dist: pathy
46
44
  Requires-Dist: pyaml_env
47
45
  Requires-Dist: pyresample
48
46
  Requires-Dist: h5netcdf
@@ -5,7 +5,7 @@ ocf_data_sampler/config/load.py,sha256=LL-7wemI8o4KPkx35j-wQ3HjsMvDgqXr7G46IcASf
5
5
  ocf_data_sampler/config/model.py,sha256=pb02qtCmWhJhrU3_T_gUzC7i2_JcO8xGwwhKGd8yMuk,10209
6
6
  ocf_data_sampler/config/save.py,sha256=m8SPw5rXjkMm1rByjh3pK5StdBi4e8ysnn3jQopdRaI,1064
7
7
  ocf_data_sampler/data/uk_gsp_locations.csv,sha256=RSh7DRh55E3n8lVAaWXGTaXXHevZZtI58td4d4DhGos,10415772
8
- ocf_data_sampler/load/__init__.py,sha256=T5Zj1PGt0aiiNEN7Ra1Ac-cBsNKhphmmHy_8g7XU_w0,219
8
+ ocf_data_sampler/load/__init__.py,sha256=-vQP9g0UOWdVbjEGyVX_ipa7R1btmiETIKAf6aw4d78,201
9
9
  ocf_data_sampler/load/gsp.py,sha256=keB3Nv_CNK1P6pS9Kdfc8PoZXTI1_YFN-spsvEv_Ewc,899
10
10
  ocf_data_sampler/load/load_dataset.py,sha256=0NyDxCDfgE_esKVW3s-rZEe16WB30FQ74ClWlrIo72M,1602
11
11
  ocf_data_sampler/load/satellite.py,sha256=E7Ln7Y60Qr1RTV-_R71YoxXQM-Ca7Y1faIo3oKB2eFk,2292
@@ -43,16 +43,17 @@ ocf_data_sampler/torch_datasets/datasets/site.py,sha256=nRUlhXQQGVrTuBmE1QnwXAUs
43
43
  ocf_data_sampler/torch_datasets/sample/__init__.py,sha256=GL84vdZl_SjHDGVyh9Uekx2XhPYuZ0dnO3l6f6KXnHI,100
44
44
  ocf_data_sampler/torch_datasets/sample/base.py,sha256=cQ1oIyhdmlotejZK8B3Cw6MNvpdnBPD8G_o2h7Ye4Vc,2206
45
45
  ocf_data_sampler/torch_datasets/sample/site.py,sha256=ZUEgn50g-GmqujOEtezNILF7wjokF80sDAA4OOldcRI,1268
46
- ocf_data_sampler/torch_datasets/sample/uk_regional.py,sha256=zpCeUw3eljOnoJTSUYW2R4kiWrY6hbuXjK8igJrXgPg,2441
46
+ ocf_data_sampler/torch_datasets/sample/uk_regional.py,sha256=8hDgaMg5Vb6eYitqYiljpAeTeTemwsYaRpZn7_3_XjI,7013
47
47
  ocf_data_sampler/torch_datasets/utils/__init__.py,sha256=N7i_hHtWUDiJqsiJoDx4T_QuiYOuvIyulPrn6xEA4TY,309
48
48
  ocf_data_sampler/torch_datasets/utils/channel_dict_to_dataarray.py,sha256=un2IiyoAmTDIymdeMiPU899_86iCDMD-oIifjHlNyqw,555
49
49
  ocf_data_sampler/torch_datasets/utils/merge_and_fill_utils.py,sha256=we7BTxRH7B7jKayDT7YfNyfI3zZClz2Bk-HXKQIokgU,956
50
50
  ocf_data_sampler/torch_datasets/utils/spatial_slice_for_dataset.py,sha256=Hvz0wHSWMYYamf2oHNiGlzJcM4cAH6pL_7ZEvIBL2dE,1882
51
51
  ocf_data_sampler/torch_datasets/utils/time_slice_for_dataset.py,sha256=1DN6VsWWdLvkpJxodZtBRDUgC4vJE2td_RP5J3ZqPNw,4268
52
52
  ocf_data_sampler/torch_datasets/utils/valid_time_periods.py,sha256=xcy75cVxl0WrglnX5YUAFjXXlO2GwEBHWyqo8TDuiOA,4714
53
+ ocf_data_sampler/torch_datasets/utils/validation_utils.py,sha256=2fwW-kpsMM2a-FWBG0YBT_r2LDIhhn7WokQ7GWvgx6U,3504
53
54
  scripts/refactor_site.py,sha256=skzvsPP0Cn9yTKndzkilyNcGz4DZ88ctvCJ0XrBdc2A,3135
54
55
  utils/compute_icon_mean_stddev.py,sha256=a1oWMRMnny39rV-dvu8rcx85sb4bXzPFrR1gkUr4Jpg,2296
55
- ocf_data_sampler-0.2.12.dist-info/METADATA,sha256=G0rtK1M1DqIZrvtV0dOlbfumxxmOB8P96pXr61QTEfU,11628
56
- ocf_data_sampler-0.2.12.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
57
- ocf_data_sampler-0.2.12.dist-info/top_level.txt,sha256=LEFU4Uk-PEo72QGLAfnVZIUEm37Q8mKuMeg_Xk-p33g,31
58
- ocf_data_sampler-0.2.12.dist-info/RECORD,,
56
+ ocf_data_sampler-0.2.14.dist-info/METADATA,sha256=fb2tvDSt9FrsJBD8mZvjJ8YwTgp3OfUjavuZda3cblA,11581
57
+ ocf_data_sampler-0.2.14.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
58
+ ocf_data_sampler-0.2.14.dist-info/top_level.txt,sha256=LEFU4Uk-PEo72QGLAfnVZIUEm37Q8mKuMeg_Xk-p33g,31
59
+ ocf_data_sampler-0.2.14.dist-info/RECORD,,