ocf-data-sampler 0.1.9__py3-none-any.whl → 0.1.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of ocf-data-sampler might be problematic.

Files changed (29)
  1. ocf_data_sampler/config/model.py +25 -23
  2. ocf_data_sampler/load/satellite.py +21 -29
  3. ocf_data_sampler/load/site.py +1 -1
  4. ocf_data_sampler/numpy_sample/gsp.py +6 -2
  5. ocf_data_sampler/numpy_sample/nwp.py +7 -13
  6. ocf_data_sampler/numpy_sample/satellite.py +11 -8
  7. ocf_data_sampler/numpy_sample/site.py +6 -2
  8. ocf_data_sampler/numpy_sample/sun_position.py +9 -10
  9. ocf_data_sampler/sample/__init__.py +0 -7
  10. ocf_data_sampler/sample/base.py +16 -35
  11. ocf_data_sampler/sample/site.py +28 -65
  12. ocf_data_sampler/sample/uk_regional.py +52 -97
  13. ocf_data_sampler/select/dropout.py +38 -25
  14. ocf_data_sampler/select/fill_time_periods.py +3 -1
  15. ocf_data_sampler/select/find_contiguous_time_periods.py +0 -1
  16. ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py +2 -3
  17. ocf_data_sampler/torch_datasets/datasets/site.py +9 -5
  18. {ocf_data_sampler-0.1.9.dist-info → ocf_data_sampler-0.1.11.dist-info}/METADATA +1 -1
  19. {ocf_data_sampler-0.1.9.dist-info → ocf_data_sampler-0.1.11.dist-info}/RECORD +29 -29
  20. tests/config/test_config.py +3 -3
  21. tests/conftest.py +33 -0
  22. tests/numpy_sample/test_nwp.py +3 -42
  23. tests/select/test_dropout.py +7 -13
  24. tests/test_sample/test_site_sample.py +5 -35
  25. tests/test_sample/test_uk_regional_sample.py +8 -35
  26. tests/torch_datasets/test_pvnet_uk.py +6 -19
  27. {ocf_data_sampler-0.1.9.dist-info → ocf_data_sampler-0.1.11.dist-info}/LICENSE +0 -0
  28. {ocf_data_sampler-0.1.9.dist-info → ocf_data_sampler-0.1.11.dist-info}/WHEEL +0 -0
  29. {ocf_data_sampler-0.1.9.dist-info → ocf_data_sampler-0.1.11.dist-info}/top_level.txt +0 -0
ocf_data_sampler/config/model.py

@@ -49,31 +49,34 @@ class TimeWindowMixin(Base):
         ...,
         description="Data interval ends at `t0 + interval_end_minutes`",
     )
-
+
     @model_validator(mode='after')
-    def check_interval_range(cls, values):
-        if values.interval_start_minutes > values.interval_end_minutes:
-            raise ValueError('interval_start_minutes must be <= interval_end_minutes')
+    def validate_intervals(cls, values):
+        start = values.interval_start_minutes
+        end = values.interval_end_minutes
+        resolution = values.time_resolution_minutes
+        if start > end:
+            raise ValueError(
+                f"interval_start_minutes ({start}) must be <= interval_end_minutes ({end})"
+            )
+        if (start % resolution != 0):
+            raise ValueError(
+                f"interval_start_minutes ({start}) must be divisible "
+                f"by time_resolution_minutes ({resolution})"
+            )
+        if (end % resolution != 0):
+            raise ValueError(
+                f"interval_end_minutes ({end}) must be divisible "
+                f"by time_resolution_minutes ({resolution})"
+            )
         return values
 
-    @field_validator("interval_start_minutes")
-    def interval_start_minutes_divide_by_time_resolution(cls, v: int, info: ValidationInfo) -> int:
-        if v % info.data["time_resolution_minutes"] != 0:
-            raise ValueError("interval_start_minutes must be divisible by time_resolution_minutes")
-        return v
-
-    @field_validator("interval_end_minutes")
-    def interval_end_minutes_divide_by_time_resolution(cls, v: int, info: ValidationInfo) -> int:
-        if v % info.data["time_resolution_minutes"] != 0:
-            raise ValueError("interval_end_minutes must be divisible by time_resolution_minutes")
-        return v
-
 
 class DropoutMixin(Base):
     """Mixin class, to add dropout minutes"""
 
-    dropout_timedeltas_minutes: Optional[List[int]] = Field(
-        default=None,
+    dropout_timedeltas_minutes: List[int] = Field(
+        default=[],
         description="List of possible minutes before t0 where data availability may start. Must be "
         "negative or zero.",
     )
@@ -88,18 +91,17 @@ class DropoutMixin(Base):
     @field_validator("dropout_timedeltas_minutes")
     def dropout_timedeltas_minutes_negative(cls, v: List[int]) -> List[int]:
         """Validate 'dropout_timedeltas_minutes'"""
-        if v is not None:
-            for m in v:
-                assert m <= 0, "Dropout timedeltas must be negative"
+        for m in v:
+            assert m <= 0, "Dropout timedeltas must be negative"
         return v
 
     @model_validator(mode="after")
     def dropout_instructions_consistent(self) -> Self:
         if self.dropout_fraction == 0:
-            if self.dropout_timedeltas_minutes is not None:
+            if self.dropout_timedeltas_minutes != []:
                 raise ValueError("To use dropout timedeltas dropout fraction should be > 0")
         else:
-            if self.dropout_timedeltas_minutes is None:
+            if self.dropout_timedeltas_minutes == []:
                 raise ValueError("To dropout fraction > 0 requires a list of dropout timedeltas")
         return self
 
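Note: the consolidated `@model_validator` raises the same ordering and divisibility errors the two deleted `field_validator`s did, but now with the offending values interpolated into the message. A hedged sketch of how the new check surfaces, assuming `TimeWindowMixin` can be constructed directly (in the package it is mixed into the per-source input config models):

    from pydantic import ValidationError
    from ocf_data_sampler.config.model import TimeWindowMixin

    try:
        # 7 is not divisible by the 5-minute resolution, so pydantic raises:
        # "interval_start_minutes (7) must be divisible by time_resolution_minutes (5)"
        TimeWindowMixin(
            time_resolution_minutes=5,
            interval_start_minutes=7,
            interval_end_minutes=60,
        )
    except ValidationError as err:
        print(err)

The new `default=[]` on `dropout_timedeltas_minutes` is safe under pydantic, which copies non-trivial field defaults per instance, so the usual shared-mutable-default pitfall does not apply.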
ocf_data_sampler/load/satellite.py

@@ -1,7 +1,5 @@
 """Satellite loader"""
 
-import subprocess
-
 import xarray as xr
 from ocf_data_sampler.load.utils import (
     check_time_unique_increasing,
@@ -10,63 +8,59 @@ from ocf_data_sampler.load.utils import (
 )
 
 
-def _get_single_sat_data(zarr_path: str) -> xr.Dataset:
-    """Helper function to open a Zarr from either a local or GCP path.
+def get_single_sat_data(zarr_path: str) -> xr.Dataset:
+    """Helper function to open a zarr from either a local or GCP path
 
     Args:
-        zarr_path: Path to a Zarr file. Wildcards (*) are supported **only** for local paths.
-            GCS paths (gs://) **do not support** wildcards.
+        zarr_path: path to a zarr file. Wildcards (*) are supported only for local paths
+            GCS paths (gs://) do not support wildcards
 
     Returns:
-        An xarray Dataset containing satellite data.
+        An xarray Dataset containing satellite data
 
     Raises:
-        ValueError: If a wildcard (*) is used in a GCS (gs://) path.
+        ValueError: If a wildcard (*) is used in a GCS (gs://) path
     """
 
-    # These kwargs are used if the path contains "*"
-    openmf_kwargs = dict(
-        engine="zarr",
-        concat_dim="time",
-        combine="nested",
-        chunks="auto",
-        join="override",
-    )
-
     # Raise an error if a wildcard is used in a GCP path
     if "gs://" in str(zarr_path) and "*" in str(zarr_path):
-        raise ValueError("Wildcard (*) paths are not supported for GCP (gs://) URLs.")
+        raise ValueError("Wildcard (*) paths are not supported for GCP (gs://) URLs")
 
     # Handle multi-file dataset for local paths
     if "*" in str(zarr_path):
-        ds = xr.open_mfdataset(zarr_path, **openmf_kwargs)
+        ds = xr.open_mfdataset(
+            zarr_path,
+            engine="zarr",
+            concat_dim="time",
+            combine="nested",
+            chunks="auto",
+            join="override",
+        )
+        check_time_unique_increasing(ds.time)
     else:
         ds = xr.open_dataset(zarr_path, engine="zarr", chunks="auto")
 
-    # Ensure time is unique and sorted
-    ds = ds.drop_duplicates("time").sortby("time")
-
     return ds
 
 
 def open_sat_data(zarr_path: str | list[str]) -> xr.DataArray:
-    """Lazily opens the Zarr store.
+    """Lazily opens the zarr store
 
     Args:
-        zarr_path: Cloud URL or local path pattern, or list of these. If GCS URL, it must start with
-            'gs://'.
+        zarr_path: Cloud URL or local path pattern, or list of these. If GCS URL,
+            it must start with 'gs://'
     """
 
     # Open the data
     if isinstance(zarr_path, (list, tuple)):
         ds = xr.combine_nested(
-            [_get_single_sat_data(path) for path in zarr_path],
+            [get_single_sat_data(path) for path in zarr_path],
             concat_dim="time",
             combine_attrs="override",
             join="override",
         )
     else:
-        ds = _get_single_sat_data(zarr_path)
+        ds = get_single_sat_data(zarr_path)
 
     ds = ds.rename(
         {
@@ -76,9 +70,7 @@ def open_sat_data(zarr_path: str | list[str]) -> xr.DataArray:
     )
 
     check_time_unique_increasing(ds.time_utc)
-
     ds = make_spatial_coords_increasing(ds, x_coord="x_geostationary", y_coord="y_geostationary")
-
     ds = ds.transpose("time_utc", "channel", "x_geostationary", "y_geostationary")
 
     # TODO: should we control the dtype of the DataArray?
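Note: replacing `drop_duplicates("time").sortby("time")` with `check_time_unique_increasing` is a behavioural change, not just a cleanup: duplicated or unsorted timestamps are now an error rather than being silently repaired. A hedged usage sketch of the renamed public entry point (paths are placeholders):

    from ocf_data_sampler.load.satellite import open_sat_data

    # Local glob: matched zarrs are concatenated along time; the time
    # coordinate must already be unique and increasing, since open_sat_data
    # now verifies this instead of repairing it
    da = open_sat_data("/data/sat/2023_*.zarr")

    # A GCS URL is fine without a wildcard; with one it raises ValueError
    da = open_sat_data("gs://bucket/satellite.zarr")

    # A list of stores is combined along time via xr.combine_nested
    da = open_sat_data(["/data/sat/jan.zarr", "/data/sat/feb.zarr"])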
ocf_data_sampler/load/site.py

@@ -20,7 +20,7 @@ def open_site(generation_file_path: str, metadata_file_path: str) -> xr.DataArray:
 
     assert metadata_df.index.is_unique
 
-    # Ensure metadata aligns with the site_id dimension in data_ds
+    # Ensure metadata aligns with the site_id dimension in generation_ds
     metadata_df = metadata_df.reindex(generation_ds.site_id.values)
 
     # Assign coordinates to the Dataset using the aligned metadata
ocf_data_sampler/numpy_sample/gsp.py

@@ -18,9 +18,13 @@ class GSPSampleKey:
 
 
 def convert_gsp_to_numpy_sample(da: xr.DataArray, t0_idx: int | None = None) -> dict:
-    """Convert from Xarray to NumpySample"""
+    """Convert from Xarray to NumpySample
+
+    Args:
+        da: Xarray DataArray containing GSP data
+        t0_idx: Index of the t0 timestamp in the time dimension of the GSP data
+    """
 
-    # Extract values from the DataArray
     sample = {
         GSPSampleKey.gsp: da.values,
         GSPSampleKey.nominal_capacity_mwp: da.isel(time_utc=0)["nominal_capacity_mwp"].values,
ocf_data_sampler/numpy_sample/nwp.py

@@ -12,30 +12,24 @@ class NWPSampleKey:
     step = 'nwp_step'
     target_time_utc = 'nwp_target_time_utc'
     t0_idx = 'nwp_t0_idx'
-    y_osgb = 'nwp_y_osgb'
-    x_osgb = 'nwp_x_osgb'
-
 
 
 def convert_nwp_to_numpy_sample(da: xr.DataArray, t0_idx: int | None = None) -> dict:
-    """Convert from Xarray to NWP NumpySample"""
+    """Convert from Xarray to NWP NumpySample
+
+    Args:
+        da: Xarray DataArray containing NWP data
+        t0_idx: Index of the t0 timestamp in the time dimension of the NWP
+    """
 
-    # Create example and add t if available
     sample = {
         NWPSampleKey.nwp: da.values,
         NWPSampleKey.channel_names: da.channel.values,
         NWPSampleKey.init_time_utc: da.init_time_utc.values.astype(float),
         NWPSampleKey.step: (da.step.values / pd.Timedelta("1h")).astype(int),
+        NWPSampleKey.target_time_utc: da.target_time_utc.values.astype(float),
    }
 
-    if "target_time_utc" in da.coords:
-        sample[NWPSampleKey.target_time_utc] = da.target_time_utc.values.astype(float)
-
-    # TODO: Do we need this at all? Especially since it is only present in UKV data
-    for sample_key, dataset_key in ((NWPSampleKey.y_osgb, "y_osgb"),(NWPSampleKey.x_osgb, "x_osgb"),):
-        if dataset_key in da.coords:
-            sample[sample_key] = da[dataset_key].values
-
     if t0_idx is not None:
         sample[NWPSampleKey.t0_idx] = t0_idx
 
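Note: `target_time_utc` is now read unconditionally (while the UKV-only `x_osgb`/`y_osgb` keys are dropped entirely), so the converter assumes that coordinate is always present on its input. A minimal sketch of a conforming DataArray; shapes and channel names are illustrative, not from the package:

    import numpy as np
    import pandas as pd
    import xarray as xr

    from ocf_data_sampler.numpy_sample.nwp import convert_nwp_to_numpy_sample

    init_time = pd.Timestamp("2024-01-01 00:00")
    steps = pd.timedelta_range("0h", "3h", freq="1h")
    da = xr.DataArray(
        np.zeros((len(steps), 2, 4, 4), dtype=np.float32),
        dims=("step", "channel", "y", "x"),
        coords={
            "step": steps,
            "channel": ["t", "dswrf"],  # illustrative channel names
            "init_time_utc": init_time,
            # target_time_utc is now required by the converter
            "target_time_utc": ("step", init_time + steps),
        },
    )
    sample = convert_nwp_to_numpy_sample(da, t0_idx=0)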
ocf_data_sampler/numpy_sample/satellite.py

@@ -1,4 +1,5 @@
 """Convert Satellite to NumpySample"""
+
 import xarray as xr
 
 
@@ -12,19 +13,21 @@ class SatelliteSampleKey:
 
 
 def convert_satellite_to_numpy_sample(da: xr.DataArray, t0_idx: int | None = None) -> dict:
-    """Convert from Xarray to NumpySample"""
+    """Convert from Xarray to NumpySample
+
+    Args:
+        da: xarray DataArray containing satellite data
+        t0_idx: Index of the t0 timestamp in the time dimension of the satellite data
+    """
     sample = {
         SatelliteSampleKey.satellite_actual: da.values,
         SatelliteSampleKey.time_utc: da.time_utc.values.astype(float),
+        SatelliteSampleKey.x_geostationary: da.x_geostationary.values,
+        SatelliteSampleKey.y_geostationary: da.y_geostationary.values,
     }
 
-    for sample_key, dataset_key in (
-        (SatelliteSampleKey.x_geostationary, "x_geostationary"),
-        (SatelliteSampleKey.y_geostationary, "y_geostationary"),
-    ):
-        sample[sample_key] = da[dataset_key].values
-
     if t0_idx is not None:
         sample[SatelliteSampleKey.t0_idx] = t0_idx
 
-    return sample
+    return sample
+
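Note: the geostationary coordinates are likewise read unconditionally now, so `x_geostationary` and `y_geostationary` must be present on the input. A minimal sketch with illustrative values:

    import numpy as np
    import pandas as pd
    import xarray as xr

    from ocf_data_sampler.numpy_sample.satellite import convert_satellite_to_numpy_sample

    da = xr.DataArray(
        np.zeros((3, 1, 2, 2), dtype=np.float32),
        dims=("time_utc", "channel", "x_geostationary", "y_geostationary"),
        coords={
            "time_utc": pd.date_range("2024-01-01", periods=3, freq="5min"),
            "channel": ["IR_016"],  # illustrative channel name
            "x_geostationary": [0.0, 1.0],
            "y_geostationary": [0.0, 1.0],
        },
    )
    # The returned dict now always carries the x/y geostationary arrays
    sample = convert_satellite_to_numpy_sample(da, t0_idx=1)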
ocf_data_sampler/numpy_sample/site.py

@@ -18,9 +18,13 @@ class SiteSampleKey:
     time_cos = "site_time_cos"
 
 def convert_site_to_numpy_sample(da: xr.DataArray, t0_idx: int | None = None) -> dict:
-    """Convert from Xarray to NumpySample"""
+    """Convert from Xarray to NumpySample
+
+    Args:
+        da: xarray DataArray containing site data
+        t0_idx: Index of the t0 timestamp in the time dimension of the site data
+    """
 
-    # Extract values from the DataArray
     sample = {
         SiteSampleKey.generation: da.values,
         SiteSampleKey.capacity_kwp: da.isel(time_utc=0)["capacity_kwp"].values,
ocf_data_sampler/numpy_sample/sun_position.py

@@ -27,16 +27,15 @@ def calculate_azimuth_and_elevation(
         latitude=lat,
         method='nrel_numpy'
     )
-    azimuth = solpos["azimuth"].values
-    elevation = solpos["elevation"].values
-    return azimuth, elevation
+
+    return solpos["azimuth"].values, solpos["elevation"].values
 
 
 def make_sun_position_numpy_sample(
-        datetimes: pd.DatetimeIndex,
-        lon: float,
-        lat: float,
-        key_prefix: str = "gsp"
+    datetimes: pd.DatetimeIndex,
+    lon: float,
+    lat: float,
+    key_prefix: str = "gsp"
 ) -> dict:
     """Creates NumpySample with standardized solar coordinates
 
@@ -44,22 +43,22 @@ def make_sun_position_numpy_sample(
         datetimes: The datetimes to calculate solar angles for
         lon: The longitude
         lat: The latitude
+        key_prefix: The prefix to add to the keys in the NumpySample
     """
 
     azimuth, elevation = calculate_azimuth_and_elevation(datetimes, lon, lat)
 
     # Normalise
-
     # Azimuth is in range [0, 360] degrees
     azimuth = azimuth / 360
 
     # Elevation is in range [-90, 90] degrees
     elevation = elevation / 180 + 0.5
-
+
     # Make NumpySample
     sun_numpy_sample = {
         key_prefix + "_solar_azimuth": azimuth,
         key_prefix + "_solar_elevation": elevation,
     }
 
-    return sun_numpy_sample
+    return sun_numpy_sample
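For reference, the normalisation maps azimuth from [0, 360] degrees into [0, 1] via azimuth / 360, and elevation from [-90, 90] degrees into [0, 1] via elevation / 180 + 0.5: an azimuth of 180 degrees becomes 0.5, and an elevation of -90 degrees becomes 0. A hedged usage sketch with illustrative coordinates:

    import pandas as pd

    from ocf_data_sampler.numpy_sample.sun_position import make_sun_position_numpy_sample

    datetimes = pd.date_range("2024-06-21 10:00", periods=4, freq="30min")
    sample = make_sun_position_numpy_sample(datetimes, lon=-1.25, lat=51.75, key_prefix="gsp")
    # sample["gsp_solar_azimuth"] and sample["gsp_solar_elevation"]
    # are arrays normalised into [0, 1]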
ocf_data_sampler/sample/__init__.py

@@ -1,10 +1,3 @@
 from ocf_data_sampler.sample.base import SampleBase
 from ocf_data_sampler.sample.uk_regional import UKRegionalSample
 from ocf_data_sampler.sample.site import SiteSample
-
-
-__all__ = [
-    'SampleBase',
-    'UKRegionalSample',
-    'SiteSample'
-]
ocf_data_sampler/sample/base.py

@@ -1,23 +1,15 @@
-"""
-Base class definition - abstract
-Handling of both flat and nested structures - consideration for NWP
-"""
+""" Base class for handling flat/nested data structures with NWP consideration """
 
-import logging
 import numpy as np
 import torch
-import xarray as xr
 
-from pathlib import Path
-from typing import Any, Dict, Optional, Union, TypeAlias
+from typing import TypeAlias
 from abc import ABC, abstractmethod
 
 
-logger = logging.getLogger(__name__)
-
-NumpySample: TypeAlias = Dict[str, Union[np.ndarray, Dict[str, np.ndarray]]]
-NumpyBatch: TypeAlias = Dict[str, Union[np.ndarray, Dict[str, np.ndarray]]]
-TensorBatch: TypeAlias = Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]]
+NumpySample: TypeAlias = dict[str, np.ndarray | dict[str, np.ndarray]]
+NumpyBatch: TypeAlias = dict[str, np.ndarray | dict[str, np.ndarray]]
+TensorBatch: TypeAlias = dict[str, torch.Tensor | dict[str, torch.Tensor]]
 
 
 class SampleBase(ABC):
@@ -26,43 +18,33 @@ class SampleBase(ABC):
     Provides core data storage functionality
     """
 
-    def __init__(self, data: Optional[Union[NumpySample, xr.Dataset]] = None):
-        """ Initialise data container """
-        logger.debug("Initialising SampleBase instance")
-        self._data = data
-
     @abstractmethod
     def to_numpy(self) -> NumpySample:
-        """ Convert data to a numpy array representation """
+        """Convert sample data to numpy format"""
         raise NotImplementedError
 
     @abstractmethod
-    def plot(self, **kwargs) -> None:
-        """ Abstract method for plotting """
+    def plot(self) -> None:
         raise NotImplementedError
 
     @abstractmethod
-    def save(self, path: Union[str, Path]) -> None:
-        """ Abstract method for saving sample data """
+    def save(self, path: str) -> None:
         raise NotImplementedError
 
     @classmethod
     @abstractmethod
-    def load(cls, path: Union[str, Path]) -> 'SampleBase':
-        """ Abstract class method for loading sample data """
+    def load(cls, path: str) -> 'SampleBase':
         raise NotImplementedError
 
 
 def batch_to_tensor(batch: NumpyBatch) -> TensorBatch:
     """
-    Moves ndarrays in a nested dict to torch tensors
+    Recursively converts numpy arrays in nested dict to torch tensors
     Args:
         batch: NumpyBatch with data in numpy arrays
     Returns:
         TensorBatch with data in torch tensors
     """
-    if not batch:
-        raise ValueError("Cannot convert empty batch to tensors")
 
     for k, v in batch.items():
         if isinstance(v, dict):
@@ -75,16 +57,15 @@ def batch_to_tensor(batch: NumpyBatch) -> TensorBatch:
     return batch
 
 
-def copy_batch_to_device(batch: dict, device: torch.device) -> dict:
-    """
-    Moves tensor leaves in a nested dict to a new device.
+def copy_batch_to_device(batch: TensorBatch, device: torch.device) -> TensorBatch:
+    """Recursively copies tensors in nested dict to specified device
 
     Args:
-        batch: Nested dict with tensors to move.
-        device: Device to move tensors to.
-
+        batch: Nested dict with tensors to move
+        device: Device to move tensors to
+
     Returns:
-        A dict with tensors moved to the new device.
+        A dict with tensors moved to the new device
     """
     batch_copy = {}
 
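Note: with the empty-batch guard removed, `batch_to_tensor({})` now returns an empty dict instead of raising. A usage sketch of the two helpers on a nested batch; the keys and shapes are illustrative, not from the package:

    import numpy as np
    import torch

    from ocf_data_sampler.sample.base import batch_to_tensor, copy_batch_to_device

    batch = {
        "gsp": np.random.rand(4, 7).astype(np.float32),
        "nwp": {"ukv": np.random.rand(4, 2, 8, 8).astype(np.float32)},  # nested dict
    }
    tensor_batch = batch_to_tensor(batch)  # ndarrays -> torch tensors, recursively
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tensor_batch = copy_batch_to_device(tensor_batch, device)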
ocf_data_sampler/sample/site.py

@@ -1,81 +1,44 @@
-"""
-PVNet - Site sample / dataset implementation
-"""
+"""PVNet Site sample implementation for netCDF data handling and conversion"""
 
-import logging
 import xarray as xr
-import numpy as np
 
-from pathlib import Path
-from typing import Dict, Any, Union
+from typing_extensions import override
 
-from ocf_data_sampler.sample.base import SampleBase
+from ocf_data_sampler.sample.base import SampleBase, NumpySample
 from ocf_data_sampler.torch_datasets.datasets.site import convert_netcdf_to_numpy_sample
 
 
-logger = logging.getLogger(__name__)
-
-
 class SiteSample(SampleBase):
-    """ Sample class specific to Site PVNet """
+    """Handles PVNet site specific netCDF operations"""
 
-    def __init__(self):
-        logger.debug("Initialise SiteSample instance")
-        super().__init__()
-        self._data = {}
-
-    def to_numpy(self) -> Dict[str, Any]:
-        """ Convert sample numpy arrays - netCDF conversion """
-        logger.debug("Converting site sample to numpy format")
+    def __init__(self, data: xr.Dataset):
+
+        if not isinstance(data, xr.Dataset):
+            raise TypeError(f"Data must be xarray Dataset - Found type {type(data)}")
 
-        try:
-            if not isinstance(self._data, xr.Dataset):
-                raise TypeError("Data must be xarray Dataset")
-
-            numpy_data = convert_netcdf_to_numpy_sample(self._data)
+        self._data = data
 
-            logger.debug("Successfully converted to numpy format")
-            return numpy_data
-
-        except Exception as e:
-            logger.error(f"Error converting to numpy: {str(e)}")
-            raise
+    @override
+    def to_numpy(self) -> NumpySample:
+        return convert_netcdf_to_numpy_sample(self._data)
 
-    def save(self, path: Union[str, Path]) -> None:
-        """ Save site sample as netCDF - h5netcdf engine """
-        logger.debug(f"Saving SiteSample to {path}")
-        path = Path(path)
+    def save(self, path: str) -> None:
+        """Save site sample data as netCDF
 
-        if path.suffix != '.nc':
-            logger.error(f"Invalid file format - {path.suffix}")
-            raise ValueError("Only .nc format is supported")
-
-        if not isinstance(self._data, xr.Dataset):
-            raise TypeError("Data must be xarray Dataset for saving")
-
-        self._data.to_netcdf(
-            path,
-            mode="w",
-            engine="h5netcdf"
-        )
-        logger.debug(f"Successfully saved SiteSample - {path}")
+        Args:
+            path: Path to save the netCDF file
+        """
+        self._data.to_netcdf(path, mode="w", engine="h5netcdf")
 
     @classmethod
-    def load(cls, path: str) -> None:
-        """ Load site sample from netCDF """
-        logger.debug(f"Loading SiteSample from {path}")
-        path = Path(path)
-
-        if path.suffix != '.nc':
-            logger.error(f"Invalid file format - {path.suffix}")
-            raise ValueError("Only .nc format is supported")
+    def load(cls, path: str) -> 'SiteSample':
+        """Load site sample data from netCDF
 
-        instance = cls()
-        instance._data = xr.open_dataset(path)
-        logger.debug(f"Loaded SiteSample from {path}")
-        return instance
-
-    # TO DO - placeholder for now
-    def plot(self, **kwargs) -> None:
-        """ Plot sample data - placeholder """
-        pass
+        Args:
+            path: Path to load the netCDF file from
+        """
+        return cls(xr.open_dataset(path))
+
+    # TODO - placeholder for now
+    def plot(self) -> None:
        raise NotImplementedError("Plotting not yet implemented for SiteSample")
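Note: the reworked class takes its `xr.Dataset` at construction instead of mutating `_data` after an empty `__init__`, and the `.nc`-suffix checks are gone, so any path writable by the h5netcdf engine is accepted. A hedged round-trip sketch; dataset contents and path are placeholders:

    import xarray as xr

    from ocf_data_sampler.sample.site import SiteSample

    ds = xr.Dataset({"generation": ("time_utc", [0.0, 0.5, 1.0])})  # placeholder contents
    sample = SiteSample(ds)   # raises TypeError if not an xr.Dataset
    sample.save("sample.nc")  # writes netCDF via h5netcdf; no suffix check anymore
    restored = SiteSample.load("sample.nc")
    # restored.to_numpy() delegates to convert_netcdf_to_numpy_sample and
    # expects a fully-formed site sample dataset, not this placeholder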