ocf-data-sampler 0.1.5__py3-none-any.whl → 0.1.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of ocf-data-sampler might be problematic.

Files changed (31)
  1. ocf_data_sampler/config/__init__.py +1 -1
  2. ocf_data_sampler/config/load.py +6 -17
  3. ocf_data_sampler/config/model.py +10 -20
  4. ocf_data_sampler/config/save.py +9 -62
  5. ocf_data_sampler/load/__init__.py +5 -1
  6. ocf_data_sampler/load/gsp.py +10 -6
  7. ocf_data_sampler/load/load_dataset.py +15 -17
  8. ocf_data_sampler/load/nwp/nwp.py +3 -4
  9. ocf_data_sampler/load/nwp/providers/ecmwf.py +6 -17
  10. ocf_data_sampler/load/nwp/providers/ukv.py +1 -9
  11. ocf_data_sampler/load/nwp/providers/utils.py +1 -5
  12. ocf_data_sampler/load/satellite.py +4 -8
  13. ocf_data_sampler/load/site.py +20 -13
  14. ocf_data_sampler/numpy_sample/collate.py +3 -4
  15. ocf_data_sampler/numpy_sample/datetime_features.py +14 -22
  16. ocf_data_sampler/sample/base.py +34 -3
  17. ocf_data_sampler/select/find_contiguous_time_periods.py +2 -2
  18. ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py +2 -2
  19. ocf_data_sampler/torch_datasets/datasets/site.py +1 -1
  20. {ocf_data_sampler-0.1.5.dist-info → ocf_data_sampler-0.1.7.dist-info}/METADATA +1 -1
  21. {ocf_data_sampler-0.1.5.dist-info → ocf_data_sampler-0.1.7.dist-info}/RECORD +31 -30
  22. tests/config/test_config.py +1 -47
  23. tests/config/test_load.py +7 -0
  24. tests/config/test_save.py +21 -30
  25. tests/load/test_load_sites.py +1 -1
  26. tests/numpy_sample/test_datetime_features.py +0 -10
  27. tests/test_sample/test_base.py +63 -2
  28. tests/torch_datasets/test_site.py +3 -3
  29. {ocf_data_sampler-0.1.5.dist-info → ocf_data_sampler-0.1.7.dist-info}/LICENSE +0 -0
  30. {ocf_data_sampler-0.1.5.dist-info → ocf_data_sampler-0.1.7.dist-info}/WHEEL +0 -0
  31. {ocf_data_sampler-0.1.5.dist-info → ocf_data_sampler-0.1.7.dist-info}/top_level.txt +0 -0
ocf_data_sampler/config/__init__.py CHANGED

@@ -1,5 +1,5 @@
 """Configuration model"""
 
-from ocf_data_sampler.config.model import Configuration
+from ocf_data_sampler.config.model import Configuration, InputData
 from ocf_data_sampler.config.save import save_yaml_configuration
 from ocf_data_sampler.config.load import load_yaml_configuration
ocf_data_sampler/config/load.py CHANGED

@@ -1,33 +1,22 @@
-"""Loading configuration functions.
-
-Example:
-
-    from ocf_data_sampler.config import load_yaml_configuration
-    configuration = load_yaml_configuration(filename)
-"""
+"""Load configuration from a yaml file"""
 
 import fsspec
-from pathy import Pathy
 from pyaml_env import parse_config
-
 from ocf_data_sampler.config import Configuration
 
 
-def load_yaml_configuration(filename: str | Pathy) -> Configuration:
+def load_yaml_configuration(filename: str) -> Configuration:
    """
    Load a yaml file which has a configuration in it

    Args:
-        filename: the file name that you want to load. Will load from local, AWS, or GCP
+        filename: the yaml file name that you want to load. Will load from local, AWS, or GCP
            depending on the protocol suffix (e.g. 's3://bucket/config.yaml').

-    Returns:pydantic class
+    Returns: pydantic class

    """
-    # load the file to a dictionary
    with fsspec.open(filename, mode="r") as stream:
        configuration = parse_config(data=stream)
-    # this means we can load ENVs in the yaml file
-    # turn into pydantic class
-    configuration = Configuration(**configuration)
-    return configuration
+
+    return Configuration(**configuration)
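
As a usage sketch of the simplified loader (the config path below is hypothetical; any protocol fsspec understands, such as s3:// or gs://, should work the same way):

```python
from ocf_data_sampler.config import load_yaml_configuration

# Hypothetical path - local files and fsspec-style URLs are both accepted
config = load_yaml_configuration("/data/config.yaml")
print(type(config).__name__)  # Configuration
```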
ocf_data_sampler/config/model.py CHANGED

@@ -1,16 +1,10 @@
 """Configuration model for the dataset.
 
-All paths must include the protocol prefix. For local files,
-it's sufficient to just start with a '/'. For aws, start with 's3://',
-for gcp start with 'gs://'.
 
-Example:
+Absolute or relative zarr filepath(s). Prefix with a protocol like s3:// to read from alternative filesystems.
 
-    from ocf_data_sampler.config import Configuration
-    config = Configuration(**config_dict)
 """
 
-import logging
 from typing import Dict, List, Optional
 from typing_extensions import Self
 
@@ -18,10 +12,6 @@ from pydantic import BaseModel, Field, RootModel, field_validator, ValidationInfo
 
 from ocf_data_sampler.constants import NWP_PROVIDERS
 
-logger = logging.getLogger(__name__)
-
-providers = ["pvoutput.org", "solar_sheffield_passiv"]
-
 
 class Base(BaseModel):
    """Pydantic Base model where no extras can be added"""
@@ -79,8 +69,6 @@ class TimeWindowMixin(Base):
        return v
 
 
-
-# noinspection PyMethodParameters
 class DropoutMixin(Base):
    """Mixin class, to add dropout minutes"""
 
@@ -137,7 +125,8 @@ class Satellite(TimeWindowMixin, DropoutMixin, SpatialWindowMixin):
 
    zarr_path: str | tuple[str] | list[str] = Field(
        ...,
-        description="The path or list of paths which hold the data zarr",
+        description="Absolute or relative zarr filepath(s). Prefix with a protocol like s3:// "
+        "to read from alternative filesystems.",
    )
 
    channels: list[str] = Field(
@@ -145,13 +134,13 @@ class Satellite(TimeWindowMixin, DropoutMixin, SpatialWindowMixin):
    )
 
 
-# noinspection PyMethodParameters
 class NWP(TimeWindowMixin, DropoutMixin, SpatialWindowMixin):
    """NWP configuration model"""
 
    zarr_path: str | tuple[str] | list[str] = Field(
        ...,
-        description="The path or list of paths which hold the data zarr",
+        description="Absolute or relative zarr filepath(s). Prefix with a protocol like s3:// "
+        "to read from alternative filesystems.",
    )
 
    channels: list[str] = Field(
@@ -175,7 +164,6 @@ class NWP(TimeWindowMixin, DropoutMixin, SpatialWindowMixin):
        """Validate 'provider'"""
        if v.lower() not in NWP_PROVIDERS:
            message = f"NWP provider {v} is not in {NWP_PROVIDERS}"
-            logger.warning(message)
            raise Exception(message)
        return v
 
@@ -209,7 +197,11 @@ class MultiNWP(RootModel):
 class GSP(TimeWindowMixin, DropoutMixin):
    """GSP configuration model"""
 
-    zarr_path: str = Field(..., description="The path which holds the GSP zarr")
+    zarr_path: str = Field(
+        ...,
+        description="Absolute or relative zarr filepath. Prefix with a protocol like s3:// "
+        "to read from alternative filesystems.",
+    )
 
 
 class Site(TimeWindowMixin, DropoutMixin):
@@ -228,8 +220,6 @@ class Site(TimeWindowMixin, DropoutMixin):
    # TODO validate the csv for metadata
 
 
-
-# noinspection PyPep8Naming
 class InputData(Base):
    """Input data model"""
 
ocf_data_sampler/config/save.py CHANGED

@@ -2,25 +2,16 @@
 
 This module provides functionality to save configuration objects to YAML files,
 supporting local and cloud storage locations.
-
-Example:
-    from ocf_data_sampler.config import save_yaml_configuration
-    saved_path = save_yaml_configuration(config, "config.yaml")
 """
 
 import json
-from pathlib import Path
-from typing import Union
-
 import fsspec
 import yaml
+import os
 
 from ocf_data_sampler.config import Configuration
 
-def save_yaml_configuration(
-    configuration: Configuration,
-    filename: Union[str, Path],
-) -> Path:
+def save_yaml_configuration(configuration: Configuration, filename: str) -> None:
    """Save a configuration object to a YAML file.
 
    Args:
@@ -28,57 +19,13 @@ def save_yaml_configuration(
        filename: Destination path for the YAML file. Can be a local path or
            cloud storage URL (e.g., 'gs://', 's3://'). For local paths,
            absolute paths are recommended.
-
-    Returns:
-        Path: The path where the configuration was saved
-
-    Raises:
-        ValueError: If filename is None, directory doesn't exist, or if writing to the specified path fails
-        TypeError: If the configuration cannot be serialized
    """
-    if filename is None:
-        raise ValueError("filename cannot be None")
-
-    try:
-        # Convert to absolute path if it's a relative path
-        if isinstance(filename, (str, Path)) and not any(
-            str(filename).startswith(prefix) for prefix in ('gs://', 's3://', '/')
-        ):
-            filename = Path.cwd() / filename
-
-        filepath = Path(filename)
-
-        # For local paths, check if parent directory exists before attempting to create
-        if filepath.is_absolute():
-            if not filepath.parent.exists():
-                raise ValueError("Directory does not exist")
-
-            # Only try to create directory if it's in a writable location
-            try:
-                filepath.parent.mkdir(parents=True, exist_ok=True)
-            except PermissionError:
-                raise ValueError(f"Permission denied when accessing directory {filepath.parent}")
-
-        # Serialize configuration to JSON-compatible dictionary
-        config_dict = json.loads(configuration.model_dump_json())
-
-        # Write to file directly for local paths
-        if filepath.is_absolute():
-            try:
-                with open(filepath, 'w') as f:
-                    yaml.safe_dump(config_dict, f, default_flow_style=False)
-            except PermissionError:
-                raise ValueError(f"Permission denied when writing to {filename}")
-        else:
-            # Use fsspec for cloud storage
-            with fsspec.open(str(filepath), mode='w') as yaml_file:
-                yaml.safe_dump(config_dict, yaml_file, default_flow_style=False)
+
+    if os.path.exists(filename):
+        raise FileExistsError(f"File already exists: {filename}")
 
-        return filepath
+    # Serialize configuration to JSON-compatible dictionary
+    config_dict = json.loads(configuration.model_dump_json())
 
-    except json.JSONDecodeError as e:
-        raise TypeError(f"Failed to serialize configuration: {str(e)}") from e
-    except (IOError, OSError) as e:
-        if "Permission denied" in str(e):
-            raise ValueError(f"Permission denied when writing to {filename}") from e
-        raise ValueError(f"Failed to write configuration to {filename}: {str(e)}") from e
+    with fsspec.open(filename, mode='w') as yaml_file:
+        yaml.safe_dump(config_dict, yaml_file, default_flow_style=False)
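
A sketch of the new save behaviour, assuming a writable local path (hypothetical below). Note the `os.path.exists` guard means a second save to the same path now raises rather than overwriting, and that this guard only applies to local filesystem paths, since cloud URLs are handled by fsspec:

```python
from ocf_data_sampler.config import (
    Configuration,
    load_yaml_configuration,
    save_yaml_configuration,
)

config = Configuration()
save_yaml_configuration(config, "/tmp/config.yaml")  # hypothetical path
assert load_yaml_configuration("/tmp/config.yaml") == config

try:
    save_yaml_configuration(config, "/tmp/config.yaml")
except FileExistsError as e:
    print(e)  # File already exists: /tmp/config.yaml
```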
ocf_data_sampler/load/__init__.py CHANGED

@@ -1 +1,5 @@
-from ocf_blosc2 import Blosc2  # noqa: F401
+import ocf_blosc2
+from ocf_data_sampler.load.gsp import open_gsp
+from ocf_data_sampler.load.nwp import open_nwp
+from ocf_data_sampler.load.satellite import open_sat_data
+from ocf_data_sampler.load.site import open_site
ocf_data_sampler/load/gsp.py CHANGED

@@ -1,16 +1,21 @@
-from pathlib import Path
 import pkg_resources
 
 import pandas as pd
 import xarray as xr
 
 
-def open_gsp(zarr_path: str | Path) -> xr.DataArray:
+def open_gsp(zarr_path: str) -> xr.DataArray:
+    """Open the GSP data
+
+    Args:
+        zarr_path: Path to the GSP zarr data
+
+    Returns:
+        xr.DataArray: The opened GSP data
+    """
 
-    # Load GSP generation xr.Dataset
    ds = xr.open_zarr(zarr_path)
 
-    # Rename to standard time name
    ds = ds.rename({"datetime_gmt": "time_utc"})
 
    # Load UK GSP locations
@@ -19,13 +24,12 @@ def open_gsp(zarr_path: str | Path) -> xr.DataArray:
        index_col="gsp_id",
    )
 
-    # Add coordinates
+    # Add locations and capacities as coordinates for each GSP and datetime
    ds = ds.assign_coords(
        x_osgb=(df_gsp_loc.x_osgb.to_xarray()),
        y_osgb=(df_gsp_loc.y_osgb.to_xarray()),
        nominal_capacity_mwp=ds.installedcapacity_mwp,
        effective_capacity_mwp=ds.capacity_mwp,
-
    )
 
    return ds.generation_mw
ocf_data_sampler/load/load_dataset.py CHANGED

@@ -1,36 +1,31 @@
 """ Loads all data sources """
 import xarray as xr
 
-from ocf_data_sampler.config import Configuration
-from ocf_data_sampler.load.gsp import open_gsp
-from ocf_data_sampler.load.nwp import open_nwp
-from ocf_data_sampler.load.satellite import open_sat_data
-from ocf_data_sampler.load.site import open_site
+from ocf_data_sampler.config import InputData
+from ocf_data_sampler.load import open_nwp, open_gsp, open_sat_data, open_site
 
 
-def get_dataset_dict(config: Configuration) -> dict[str, dict[xr.DataArray]]:
+def get_dataset_dict(input_config: InputData) -> dict[str, dict[xr.DataArray] | xr.DataArray]:
    """Construct dictionary of all of the input data sources
 
    Args:
-        config: Configuration file
+        input_config: InputData configuration object
    """
 
-    in_config = config.input_data
-
    datasets_dict = {}
 
    # Load GSP data unless the path is None
-    if in_config.gsp and in_config.gsp.zarr_path:
-        da_gsp = open_gsp(zarr_path=in_config.gsp.zarr_path).compute()
+    if input_config.gsp and input_config.gsp.zarr_path:
+        da_gsp = open_gsp(zarr_path=input_config.gsp.zarr_path).compute()
 
        # Remove national GSP
        datasets_dict["gsp"] = da_gsp.sel(gsp_id=slice(1, None))
 
    # Load NWP data if in config
-    if in_config.nwp:
+    if input_config.nwp:
 
        datasets_dict["nwp"] = {}
-        for nwp_source, nwp_config in in_config.nwp.items():
+        for nwp_source, nwp_config in input_config.nwp.items():
 
            da_nwp = open_nwp(nwp_config.zarr_path, provider=nwp_config.provider)
 
@@ -39,8 +34,8 @@ def get_dataset_dict(config: Configuration) -> dict[str, dict[xr.DataArray]]:
            datasets_dict["nwp"][nwp_source] = da_nwp
 
    # Load satellite data if in config
-    if in_config.satellite:
-        sat_config = config.input_data.satellite
+    if input_config.satellite:
+        sat_config = input_config.satellite
 
        da_sat = open_sat_data(sat_config.zarr_path)
 
@@ -48,8 +43,11 @@ def get_dataset_dict(config: Configuration) -> dict[str, dict[xr.DataArray]]:
 
        datasets_dict["sat"] = da_sat
 
-    if in_config.site:
-        da_sites = open_site(in_config.site)
+    if input_config.site:
+        da_sites = open_site(
+            generation_file_path=input_config.site.file_path,
+            metadata_file_path=input_config.site.metadata_file_path,
+        )
        datasets_dict["site"] = da_sites
 
    return datasets_dict
ocf_data_sampler/load/nwp/nwp.py CHANGED

@@ -1,15 +1,14 @@
-from pathlib import Path
 import xarray as xr
 
 from ocf_data_sampler.load.nwp.providers.ukv import open_ukv
 from ocf_data_sampler.load.nwp.providers.ecmwf import open_ifs
 
 
-def open_nwp(zarr_path: Path | str | list[Path] | list[str], provider: str) -> xr.DataArray:
-    """Opens NWP Zarr
+def open_nwp(zarr_path: str | list[str], provider: str) -> xr.DataArray:
+    """Opens NWP zarr
 
    Args:
-        zarr_path: Path to the Zarr file
+        zarr_path: path to the zarr file
        provider: NWP provider
    """
 
ocf_data_sampler/load/nwp/providers/ecmwf.py CHANGED

@@ -1,5 +1,5 @@
 """ECMWF provider loaders"""
-from pathlib import Path
+
 import xarray as xr
 from ocf_data_sampler.load.nwp.providers.utils import open_zarr_paths
 from ocf_data_sampler.load.utils import (
@@ -9,7 +9,7 @@ from ocf_data_sampler.load.utils import (
 )
 
 
-def open_ifs(zarr_path: Path | str | list[Path] | list[str]) -> xr.DataArray:
+def open_ifs(zarr_path: str | list[str]) -> xr.DataArray:
    """
    Opens the ECMWF IFS NWP data
 
@@ -19,25 +19,14 @@ def open_ifs(zarr_path: Path | str | list[Path] | list[str]) -> xr.DataArray:
    Returns:
        Xarray DataArray of the NWP data
    """
-    # Open the data
-    ds = open_zarr_paths(zarr_path)
-
-    # Rename
-    ds = ds.rename(
-        {
-            "init_time": "init_time_utc",
-        }
-    )
 
-    # LEGACY SUPPORT
-    # rename variable to channel if it exists
-    if "variable" in ds:
-        ds = ds.rename({"variable": "channel"})
+    ds = open_zarr_paths(zarr_path)
+
+    # LEGACY SUPPORT - rename variable to channel if it exists
+    ds = ds.rename({"init_time": "init_time_utc", "variable": "channel"})
 
-    # Check the timestamps are unique and increasing
    check_time_unique_increasing(ds.init_time_utc)
 
-    # Make sure the spatial coords are in increasing order
    ds = make_spatial_coords_increasing(ds, x_coord="longitude", y_coord="latitude")
 
    ds = ds.transpose("init_time_utc", "step", "channel", "longitude", "latitude")
ocf_data_sampler/load/nwp/providers/ukv.py CHANGED

@@ -2,8 +2,6 @@
 
 import xarray as xr
 
-from pathlib import Path
-
 from ocf_data_sampler.load.nwp.providers.utils import open_zarr_paths
 from ocf_data_sampler.load.utils import (
    check_time_unique_increasing,
@@ -12,7 +10,7 @@ from ocf_data_sampler.load.utils import (
 )
 
 
-def open_ukv(zarr_path: Path | str | list[Path] | list[str]) -> xr.DataArray:
+def open_ukv(zarr_path: str | list[str]) -> xr.DataArray:
    """
    Opens the NWP data
 
@@ -22,10 +20,8 @@ def open_ukv(zarr_path: Path | str | list[Path] | list[str]) -> xr.DataArray:
    Returns:
        Xarray DataArray of the NWP data
    """
-    # Open the data
    ds = open_zarr_paths(zarr_path)
 
-    # Rename
    ds = ds.rename(
        {
            "init_time": "init_time_utc",
@@ -35,15 +31,11 @@ def open_ukv(zarr_path: Path | str | list[Path] | list[str]) -> xr.DataArray:
        }
    )
 
-    # Check the timestamps are unique and increasing
    check_time_unique_increasing(ds.init_time_utc)
 
-    # Make sure the spatial coords are in increasing order
    ds = make_spatial_coords_increasing(ds, x_coord="x_osgb", y_coord="y_osgb")
 
    ds = ds.transpose("init_time_utc", "step", "channel", "x_osgb", "y_osgb")
 
    # TODO: should we control the dtype of the DataArray?
    return get_xr_data_array_from_xr_dataset(ds)
-
-
ocf_data_sampler/load/nwp/providers/utils.py CHANGED

@@ -1,11 +1,7 @@
-from pathlib import Path
 import xarray as xr
 
 
-def open_zarr_paths(
-    zarr_path: Path | str | list[Path] | list[str],
-    time_dim: str = "init_time"
-) -> xr.Dataset:
+def open_zarr_paths(zarr_path: str | list[str], time_dim: str = "init_time") -> xr.Dataset:
    """Opens the NWP data
 
    Args:
ocf_data_sampler/load/satellite.py CHANGED

@@ -1,7 +1,6 @@
 """Satellite loader"""
 
 import subprocess
-from pathlib import Path
 
 import xarray as xr
 from ocf_data_sampler.load.utils import (
@@ -11,7 +10,7 @@ from ocf_data_sampler.load.utils import (
 )
 
 
-def _get_single_sat_data(zarr_path: Path | str) -> xr.Dataset:
+def _get_single_sat_data(zarr_path: str) -> xr.Dataset:
    """Helper function to open a Zarr from either a local or GCP path.
 
    Args:
@@ -50,7 +49,7 @@ def _get_single_sat_data(zarr_path: Path | str) -> xr.Dataset:
    return ds
 
 
-def open_sat_data(zarr_path: Path | str | list[Path] | list[str]) -> xr.DataArray:
+def open_sat_data(zarr_path: str | list[str]) -> xr.DataArray:
    """Lazily opens the Zarr store.
 
    Args:
@@ -69,7 +68,6 @@ def open_sat_data(zarr_path: Path | str | list[Path] | list[str]) -> xr.DataArray:
    else:
        ds = _get_single_sat_data(zarr_path)
 
-    # Rename dimensions
    ds = ds.rename(
        {
            "variable": "channel",
@@ -77,13 +75,11 @@ def open_sat_data(zarr_path: Path | str | list[Path] | list[str]) -> xr.DataArray:
        }
    )
 
-    # Check timestamps
    check_time_unique_increasing(ds.time_utc)
 
-    # Ensure spatial coordinates are sorted
    ds = make_spatial_coords_increasing(ds, x_coord="x_geostationary", y_coord="y_geostationary")
-
+
    ds = ds.transpose("time_utc", "channel", "x_geostationary", "y_geostationary")
+
    # TODO: should we control the dtype of the DataArray?
-
    return get_xr_data_array_from_xr_dataset(ds)
ocf_data_sampler/load/site.py CHANGED

@@ -1,30 +1,37 @@
+import numpy as np
 import pandas as pd
 import xarray as xr
-import numpy as np
 
-from ocf_data_sampler.config.model import Site
 
+def open_site(generation_file_path: str, metadata_file_path: str) -> xr.DataArray:
+    """Open a site's generation data and metadata.
+
+    Args:
+        generation_file_path: Path to the site generation netcdf data
+        metadata_file_path: Path to the site csv metadata
 
-def open_site(sites_config: Site) -> xr.DataArray:
+    Returns:
+        xr.DataArray: The opened site generation data
+    """
 
-    # Load site generation xr.Dataset
-    site_generation_ds = xr.open_dataset(sites_config.file_path)
+    generation_ds = xr.open_dataset(generation_file_path)
 
-    # Load site generation data
-    metadata_df = pd.read_csv(sites_config.metadata_file_path, index_col="site_id")
+    metadata_df = pd.read_csv(metadata_file_path, index_col="site_id")
+
+    assert metadata_df.index.is_unique
 
    # Ensure metadata aligns with the site_id dimension in data_ds
-    metadata_df = metadata_df.reindex(site_generation_ds.site_id.values)
+    metadata_df = metadata_df.reindex(generation_ds.site_id.values)
 
    # Assign coordinates to the Dataset using the aligned metadata
-    site_generation_ds = site_generation_ds.assign_coords(
+    generation_ds = generation_ds.assign_coords(
        latitude=("site_id", metadata_df["latitude"].values),
        longitude=("site_id", metadata_df["longitude"].values),
        capacity_kwp=("site_id", metadata_df["capacity_kwp"].values),
    )
 
    # Sanity checks
-    assert np.isfinite(site_generation_ds.capacity_kwp.values).all()
-    assert (site_generation_ds.capacity_kwp.values > 0).all()
-    assert metadata_df.index.is_unique
-    return site_generation_ds.generation_kw
+    assert np.isfinite(generation_ds.capacity_kwp.values).all()
+    assert (generation_ds.capacity_kwp.values > 0).all()
+
+    return generation_ds.generation_kw
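
A sketch of the new explicit-paths signature (file paths hypothetical):

```python
from ocf_data_sampler.load import open_site

da = open_site(
    generation_file_path="/data/site_generation.nc",
    metadata_file_path="/data/site_metadata.csv",
)
# latitude, longitude and capacity_kwp are attached as site_id coordinates
print(list(da.coords))
```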
ocf_data_sampler/numpy_sample/collate.py CHANGED

@@ -45,11 +45,12 @@ def stack_np_samples_into_batch(dict_list: list[dict]) -> dict:
    return batch
 
 
-def _key_is_constant(key: str):
+def _key_is_constant(key: str) -> bool:
+    """Check if a key is for value which is constant for all samples"""
    return key.endswith("t0_idx") or key.endswith("channel_names")
 
 
-def stack_data_list(data_list: list, key: str):
+def stack_data_list(data_list: list, key: str) -> np.ndarray:
    """Stack a sequence of data elements along a new axis
 
    Args:
@@ -57,8 +58,6 @@ def stack_data_list(data_list: list, key: str):
        key: string identifying the data type
    """
    if _key_is_constant(key):
-        # These are always the same for all examples.
        return data_list[0]
    else:
        return np.stack(data_list)
-
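
For context, a sketch of how these helpers behave when collating, assuming `stack_np_samples_into_batch` applies `stack_data_list` per key as the helpers above suggest (sample keys illustrative):

```python
import numpy as np
from ocf_data_sampler.numpy_sample.collate import stack_np_samples_into_batch

samples = [
    {"sat": np.zeros((2, 2)), "sat_t0_idx": 3},
    {"sat": np.ones((2, 2)), "sat_t0_idx": 3},
]
batch = stack_np_samples_into_batch(samples)
print(batch["sat"].shape)   # (2, 2, 2) - stacked along a new leading axis
print(batch["sat_t0_idx"])  # 3 - constant keys are taken from the first sample
```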
ocf_data_sampler/numpy_sample/datetime_features.py CHANGED

@@ -2,20 +2,21 @@
 
 import numpy as np
 import pandas as pd
-from numpy.typing import NDArray
 
 
-def _get_date_time_in_pi(
-    dt: pd.DatetimeIndex,
-) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
-    """
-    Change the datetimes, into time and date scaled in radians
+def _get_date_time_in_pi(dt: pd.DatetimeIndex) -> tuple[np.ndarray, np.ndarray]:
+    """Create positional embeddings for the datetimes in radians
+
+    Args:
+        dt: DatetimeIndex to create radian embeddings for
+
+    Returns:
+        Tuple of numpy arrays containing radian coordinates for date and time
    """
 
    day_of_year = dt.dayofyear
    minute_of_day = dt.minute + dt.hour * 60
 
-    # converting into positions on sin-cos circle
    time_in_pi = (2 * np.pi) * (minute_of_day / (24 * 60))
    date_in_pi = (2 * np.pi) * (day_of_year / 365)
 
@@ -23,24 +24,15 @@ def _get_date_time_in_pi(
 
 
 def make_datetime_numpy_dict(datetimes: pd.DatetimeIndex, key_prefix: str = "wind") -> dict:
-    """ Make dictionary of datetime features"""
-
-    if datetimes.empty:
-        raise ValueError("Input datetimes is empty for 'make_datetime_numpy_dict' function")
-
-    time_numpy_sample = {}
+    """ Creates dictionary of cyclical datetime features - encoded """
 
    date_in_pi, time_in_pi = _get_date_time_in_pi(datetimes)
 
-    # Store
-    date_sin_batch_key = key_prefix + "_date_sin"
-    date_cos_batch_key = key_prefix + "_date_cos"
-    time_sin_batch_key = key_prefix + "_time_sin"
-    time_cos_batch_key = key_prefix + "_time_cos"
+    time_numpy_sample = {}
 
-    time_numpy_sample[date_sin_batch_key] = np.sin(date_in_pi)
-    time_numpy_sample[date_cos_batch_key] = np.cos(date_in_pi)
-    time_numpy_sample[time_sin_batch_key] = np.sin(time_in_pi)
-    time_numpy_sample[time_cos_batch_key] = np.cos(time_in_pi)
+    time_numpy_sample[key_prefix + "_date_sin"] = np.sin(date_in_pi)
+    time_numpy_sample[key_prefix + "_date_cos"] = np.cos(date_in_pi)
+    time_numpy_sample[key_prefix + "_time_sin"] = np.sin(time_in_pi)
+    time_numpy_sample[key_prefix + "_time_cos"] = np.cos(time_in_pi)
 
    return time_numpy_sample
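
A worked example of the cyclical encoding: noon gives time_in_pi = 2π · (720 / 1440) = π, so the time features are sin(π) ≈ 0 and cos(π) = −1. Note that the empty-input ValueError was removed, so an empty DatetimeIndex now yields empty feature arrays instead of raising:

```python
import pandas as pd
from ocf_data_sampler.numpy_sample.datetime_features import make_datetime_numpy_dict

features = make_datetime_numpy_dict(
    pd.DatetimeIndex(["2024-06-01 12:00"]), key_prefix="gsp"
)
print(features["gsp_time_sin"])  # ~0.0 (sin of pi, up to float error)
print(features["gsp_time_cos"])  # -1.0
```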
ocf_data_sampler/sample/base.py CHANGED

@@ -5,25 +5,34 @@ Handling of both flat and nested structures - consideration for NWP
 
 import logging
 import numpy as np
+import torch
+import xarray as xr
 
 from pathlib import Path
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, Optional, Union, TypeAlias
 from abc import ABC, abstractmethod
 
+
 logger = logging.getLogger(__name__)
 
+NumpySample: TypeAlias = Dict[str, Union[np.ndarray, Dict[str, np.ndarray]]]
+NumpyBatch: TypeAlias = Dict[str, Union[np.ndarray, Dict[str, np.ndarray]]]
+TensorBatch: TypeAlias = Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]]
+
+
 class SampleBase(ABC):
    """
    Abstract base class for all sample types
    Provides core data storage functionality
    """
 
-    def __init__(self):
+    def __init__(self, data: Optional[Union[NumpySample, xr.Dataset]] = None):
        """ Initialise data container """
        logger.debug("Initialising SampleBase instance")
+        self._data = data
 
    @abstractmethod
-    def to_numpy(self) -> Dict[str, Any]:
+    def to_numpy(self) -> NumpySample:
        """ Convert data to a numpy array representation """
        raise NotImplementedError
 
@@ -42,3 +51,25 @@
    def load(cls, path: Union[str, Path]) -> 'SampleBase':
        """ Abstract class method for loading sample data """
        raise NotImplementedError
+
+
+def batch_to_tensor(batch: NumpyBatch) -> TensorBatch:
+    """
+    Moves ndarrays in a nested dict to torch tensors
+    Args:
+        batch: NumpyBatch with data in numpy arrays
+    Returns:
+        TensorBatch with data in torch tensors
+    """
+    if not batch:
+        raise ValueError("Cannot convert empty batch to tensors")
+
+    for k, v in batch.items():
+        if isinstance(v, dict):
+            batch[k] = batch_to_tensor(v)
+        elif isinstance(v, np.ndarray):
+            if v.dtype == np.bool_:
+                batch[k] = torch.tensor(v, dtype=torch.bool)
+            elif np.issubdtype(v.dtype, np.number):
+                batch[k] = torch.as_tensor(v)
+    return batch
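
A usage sketch of the new helper: it recurses into nested dicts (e.g. the per-provider NWP entries), converts boolean and numeric ndarrays in place, and leaves anything else untouched (keys and shapes below are illustrative):

```python
import numpy as np
from ocf_data_sampler.sample.base import batch_to_tensor

batch = {
    "gsp": np.random.rand(4, 7),
    "nwp": {"ukv": np.random.rand(4, 2, 8, 8)},  # nested dicts are handled recursively
}
tensor_batch = batch_to_tensor(batch)
print(tensor_batch["nwp"]["ukv"].shape)  # torch.Size([4, 2, 8, 8])
```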
ocf_data_sampler/select/find_contiguous_time_periods.py CHANGED

@@ -2,6 +2,7 @@
 
 import numpy as np
 import pandas as pd
+from ocf_data_sampler.load.utils import check_time_unique_increasing
 
 
 
@@ -28,8 +29,7 @@ def find_contiguous_time_periods(
    # Sanity checks.
    assert len(datetimes) > 0
    assert min_seq_length > 1
-    assert datetimes.is_monotonic_increasing
-    assert datetimes.is_unique
+    check_time_unique_increasing(datetimes)
 
    # Find indices of gaps larger than max_gap:
    gap_mask = pd.TimedeltaIndex(np.diff(datetimes)) > max_gap_duration
ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py CHANGED

@@ -187,7 +187,7 @@ class PVNetUKRegionalDataset(Dataset):
 
        config = load_yaml_configuration(config_filename)
 
-        datasets_dict = get_dataset_dict(config)
+        datasets_dict = get_dataset_dict(config.input_data)
 
        # Get t0 times where all input data is available
        valid_t0_times = find_valid_t0_times(datasets_dict, config)
@@ -295,7 +295,7 @@ class PVNetUKConcurrentDataset(Dataset):
 
        config = load_yaml_configuration(config_filename)
 
-        datasets_dict = get_dataset_dict(config)
+        datasets_dict = get_dataset_dict(config.input_data)
 
        # Get t0 times where all input data is available
        valid_t0_times = find_valid_t0_times(datasets_dict, config)
ocf_data_sampler/torch_datasets/datasets/site.py CHANGED

@@ -47,7 +47,7 @@ class SitesDataset(Dataset):
        """
 
        config: Configuration = load_yaml_configuration(config_filename)
-        datasets_dict = get_dataset_dict(config)
+        datasets_dict = get_dataset_dict(config.input_data)
 
        # Assign config and input data to self
        self.datasets_dict = datasets_dict
{ocf_data_sampler-0.1.5.dist-info → ocf_data_sampler-0.1.7.dist-info}/METADATA CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.2
 Name: ocf_data_sampler
-Version: 0.1.5
+Version: 0.1.7
 Summary: Sample from weather data for renewable energy prediction
 Author: James Fulton, Peter Dudfield, and the Open Climate Fix team
 Author-email: info@openclimatefix.org
{ocf_data_sampler-0.1.5.dist-info → ocf_data_sampler-0.1.7.dist-info}/RECORD CHANGED

@@ -1,39 +1,39 @@
 ocf_data_sampler/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
 ocf_data_sampler/constants.py,sha256=ClteRIgp7EPlUPqIbkel83BfIaD7_VIDjUeHzUfyhnM,5079
 ocf_data_sampler/utils.py,sha256=rKA0BHAyAG4f90zEcgxp25EEYrXS-aOVNzttZ6Mzv2k,250
-ocf_data_sampler/config/__init__.py,sha256=YXnAkgHViHB26hSsjiv32b6EbpG-A1kKTkARJf0_RkY,212
-ocf_data_sampler/config/load.py,sha256=4f7vPHAIAmd-55tPxoIzn7F_TI_ue4NxkDcLPoVWl0g,943
-ocf_data_sampler/config/model.py,sha256=sXmh7IadwXDT-7lxEl5_b3vjovZgZYR77EXy4GHaf4w,7276
-ocf_data_sampler/config/save.py,sha256=gB44isAZWUlCe3L6VBkLkngWC9GFpcCfAM57gy-0dkg,3156
+ocf_data_sampler/config/__init__.py,sha256=O29mbH0XG2gIY1g3BaveGCnpBO2SFqdu-qzJ7a6evl0,223
+ocf_data_sampler/config/load.py,sha256=sKCKmhkkeFvvkNL5xmnFvdAulaCtV4-rigPsFvVDPDc,634
+ocf_data_sampler/config/model.py,sha256=IMJhsjL_oGh2c50q8pBnCnArY4qHQcBc_M8jqlEeD0c,7129
+ocf_data_sampler/config/save.py,sha256=OqCPT3e0d7vMI2g2iRzmifPD7GscDkFQztU_qE5I0JY,1066
 ocf_data_sampler/data/uk_gsp_locations.csv,sha256=RSh7DRh55E3n8lVAaWXGTaXXHevZZtI58td4d4DhGos,10415772
-ocf_data_sampler/load/__init__.py,sha256=MjgfxilTzyz1RYFoBEeAXmE9hyjknLvdmlHPmlAoiQY,44
-ocf_data_sampler/load/gsp.py,sha256=Gcr1JVUOPKhFRDCSHtfPDjxx0BtyyEhXrZvGEKLPJ5I,759
-ocf_data_sampler/load/load_dataset.py,sha256=Ua3RaUg4PIYJkD9BKqTfN8IWUbezbhThJGgEkd9PcaE,1587
-ocf_data_sampler/load/satellite.py,sha256=f2Q7FSyySOf7DeHxcigHd-vk-J-U4S2pXg_CnhnhuwU,2571
-ocf_data_sampler/load/site.py,sha256=P83uz01WBDzoZajdOH0m8FQt4-buKDlUk19N548KqhA,1086
+ocf_data_sampler/load/__init__.py,sha256=T5Zj1PGt0aiiNEN7Ra1Ac-cBsNKhphmmHy_8g7XU_w0,219
+ocf_data_sampler/load/gsp.py,sha256=uRxEORH7J99JAJ-D38nm0iJFOQh7dkm_NCXcpbYkyvo,857
+ocf_data_sampler/load/load_dataset.py,sha256=PHUGSm4hFHfS9nfIP2KjHHCp325O4br7uGBdQH_DP7g,1603
+ocf_data_sampler/load/satellite.py,sha256=4MRJBFDHxx5WXu_6X71wEBznJTIuldEVnu9d6DVoLPI,2436
+ocf_data_sampler/load/site.py,sha256=74M_7RYwEc1bU4idjs3ZmQrx9I8mJXm6H4lwEL-h9n0,1226
 ocf_data_sampler/load/utils.py,sha256=sAEkPMS9LXVCrc5pANQo97zaoEItVg9hoNj2ZWfx_Ug,1405
 ocf_data_sampler/load/nwp/__init__.py,sha256=SmcrnbygO5xtCKmGR4wtHrj-HI7nOAvnAtfuvRufBGQ,25
-ocf_data_sampler/load/nwp/nwp.py,sha256=O4QnajEZem8BvBgTcYYDBhRhgqPYuJkolHmpMRmrXEA,610
+ocf_data_sampler/load/nwp/nwp.py,sha256=Jyq1dE7DN0iSe6iSEGA76uu9LoeJz9FzfEUkq6ZZExQ,565
 ocf_data_sampler/load/nwp/providers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-ocf_data_sampler/load/nwp/providers/ecmwf.py,sha256=2iR1Iy542lo51rC6XFLV-3pbUE68dWjlHa6TVJzx3ac,1280
-ocf_data_sampler/load/nwp/providers/ukv.py,sha256=79Bm7q-K_GJPYMy62SUIZbRWRF4-tIaB1dYPEgLD9vo,1207
-ocf_data_sampler/load/nwp/providers/utils.py,sha256=Sy2exG1wpXLLhMXYdsfR-DZMR3txG1_bBmBdchlc-yA,848
+ocf_data_sampler/load/nwp/providers/ecmwf.py,sha256=8rYZKdV62AdczVNSOJ2G0BM4-fRFRV0_y5zkHgNYkQs,1004
+ocf_data_sampler/load/nwp/providers/ukv.py,sha256=dM_kvUI0xk9xEdslXqZGjOPP96PEw3qAci5mPUgUvxA,1014
+ocf_data_sampler/load/nwp/providers/utils.py,sha256=MFOZ5ZXLu3-SxYVJExdlo30b3y3s5ebRx3_6DO-33FQ,780
 ocf_data_sampler/numpy_sample/__init__.py,sha256=nY5C6CcuxiWZ_jrXRzWtN7WyKXhJImSiVTIG6Rz4B_4,401
-ocf_data_sampler/numpy_sample/collate.py,sha256=Onl_aKhsZ4pbFJsh70orjsHk523GHxrpRirH2vJq_GA,1911
-ocf_data_sampler/numpy_sample/datetime_features.py,sha256=U-9uRplfZ7VYFA4qBduI8OkG2x_65RYIP8wrLG4i-Nw,1441
+ocf_data_sampler/numpy_sample/collate.py,sha256=oX5axq30sCsSquhNbmWAVMjM54HT1v3MCMopYHcO5Q0,1950
+ocf_data_sampler/numpy_sample/datetime_features.py,sha256=D0RajbnBjg15qjYk16h2H0XO4wH3fw-x0--4VC2nq0s,1204
 ocf_data_sampler/numpy_sample/gsp.py,sha256=5UaWO_aGRRVQo82wnDaT4zBKHihOnIsXiwgPjM8vGFM,1005
 ocf_data_sampler/numpy_sample/nwp.py,sha256=_seQNWsut3IzPsrpipqImjnaM3XNHZCy5_5be6syivk,1297
 ocf_data_sampler/numpy_sample/satellite.py,sha256=8OaTvkPjzSjotcdKsa6BKmmlBKDBunbhDN4Pjo0Grxs,910
 ocf_data_sampler/numpy_sample/site.py,sha256=I-cAXCOF0SDdm5Hx43lFqYZ3jh61kltLQK-fc4_nNu0,1314
 ocf_data_sampler/numpy_sample/sun_position.py,sha256=UklhucCxCT6GMlAhCWL6c4cfWrdc1cWgegrYaqUoHOY,1611
 ocf_data_sampler/sample/__init__.py,sha256=02CM7E5nKkGiYbVW-kvzjNd4RaqGuHCkDChtmDBDUoA,248
-ocf_data_sampler/sample/base.py,sha256=4U78tczCRsKMDwU4HkD20nyGyYjIBSZV5neF2mT--2M,1197
+ocf_data_sampler/sample/base.py,sha256=qeKuWyyO8M4QX6QDbItioeCiss0fG05NXRtf0TCMQSc,2246
 ocf_data_sampler/sample/site.py,sha256=0BvDXs0kxTjUq7kWpeoITK_uN4uE0w1IvEFXZUoKOb0,2507
 ocf_data_sampler/sample/uk_regional.py,sha256=D1A6nQB1PYCmxb3FzU9gqbNufQfx__wcprcDm50jCJw,4381
 ocf_data_sampler/select/__init__.py,sha256=E4AJulEbO2K-o0UlG1fgaEteuf_1ZFjHTvrotXSb4YU,332
 ocf_data_sampler/select/dropout.py,sha256=HCx5Wzk8Oh2Z9vV94Jy-ALJsHtGduwvMaQOleQXp5z0,1142
 ocf_data_sampler/select/fill_time_periods.py,sha256=h0XD1Ds_wUUoy-7bILxmN8AIbjlQ6YdXRKuCk_Is5jo,460
-ocf_data_sampler/select/find_contiguous_time_periods.py,sha256=q7IaNfX95A3z9XHqbhgtkZ4Js1gn5K9Qyp6DVLbsL-Q,11093
+ocf_data_sampler/select/find_contiguous_time_periods.py,sha256=Nvz4gLCbbKzAe3sQXfxgExL9NtZVk1WNORvHs94DQ_k,11130
 ocf_data_sampler/select/geospatial.py,sha256=4xL-9y674jjoaXeqE52NHCHVfknciE4OEGsZtn9DvP4,4911
 ocf_data_sampler/select/location.py,sha256=26Y5ZjfFngShBwXieuWSoOA-RLaRzci4TTmcDk3Wg7U,2015
 ocf_data_sampler/select/select_spatial_slice.py,sha256=WNxwur9Q5oetvogATw8-hNejDuEwrXHzuZIovFDjNJA,11488
@@ -41,21 +41,22 @@ ocf_data_sampler/select/select_time_slice.py,sha256=9M-yvDv9K77XfEys_OIR31_aVB56
 ocf_data_sampler/select/spatial_slice_for_dataset.py,sha256=3tRrMBXr7s4CnClbVSIq7hpls3H4Y3qYTDwswcxCCCE,1763
 ocf_data_sampler/select/time_slice_for_dataset.py,sha256=Z7pOiilSHScxmBKZNG18K5J-S4ifdXXAYGZoHRHD3AY,4324
 ocf_data_sampler/torch_datasets/datasets/__init__.py,sha256=jfJSFcR0eO1AqeH7S3KnGjsBqVZT5w3oyi784PUR6Q0,146
-ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py,sha256=4lqniFbUNt1qWSct4ISavXg9C7FM5cdVu48JHd7A9Pk,11873
-ocf_data_sampler/torch_datasets/datasets/site.py,sha256=5T8nkTMUHHFidZRuFOunYeKAqNuyZ8V7sikBoBOBwwA,16033
+ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py,sha256=xuNJyCXZ4dZ9UldX1lqOoRSRNP39Vcy0DR77Vr7dxlk,11895
+ocf_data_sampler/torch_datasets/datasets/site.py,sha256=ZjvJS0mWUyQE7ZcrhS1TdMHaPrEZXVbBAv2vDwBvQwA,16044
 ocf_data_sampler/torch_datasets/utils/merge_and_fill_utils.py,sha256=hIbekql64eXsNDFIoEc--GWxwdVWrh2qKegdOi70Bow,874
 ocf_data_sampler/torch_datasets/utils/valid_time_periods.py,sha256=Qo65qUHtle_bW5tLTYr7empHTRv-lpjvfx_6GNJj3Xg,4371
 scripts/refactor_site.py,sha256=asZ27hQ4IyXgCCUaFJqcz1ObBNcV2W3ywqHBpSXA_fc,1728
 tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 tests/conftest.py,sha256=RlC7YYtBLipUzFS1tQxela1SgHCxSpReUKEJ4429PwQ,7689
-tests/config/test_config.py,sha256=Vq_kTL5tJcwEP-hXD_Nah5O6cgafo99iX6Fw1AN5NDY,5288
-tests/config/test_save.py,sha256=rA_XVxP1pOxB--5Ebujz4T5o-VbcrCbg2VSlSq2iI0o,1318
+tests/config/test_config.py,sha256=VQjNiucIk5VnPQdGA6Mr-RNd9CwGI06AiikChTHrcnY,3969
+tests/config/test_load.py,sha256=8nui2UsgK_eufWGD74yXvf-6eY_SxBFKhDmGYUtRQxw,260
+tests/config/test_save.py,sha256=BxSd2S50-bRPIXP_4iX0B6Wt7pRFJnUbLYtzfLaqlAs,915
 tests/load/test_load_gsp.py,sha256=aT_nqaSXmUTcdHzuTT7AmXJr3R31k4OEN-Fv3eLxlQE,424
 tests/load/test_load_nwp.py,sha256=3qyyDkB1q9t3tyAwogfotNrxqUOpXXimco1CImoEWGg,753
 tests/load/test_load_satellite.py,sha256=IQ8ISRZKCEoi8IsJoPpXZJTolD0mwjnl2E7762RM_PM,524
-tests/load/test_load_sites.py,sha256=T9lSEnGPI8FQISudVYHHNTHeplNS62Vrx48jaZ6J_Jo,364
+tests/load/test_load_sites.py,sha256=6V-U3_EtBklkV7w-hOoR4nba3dSaZ_cnjuRWFs8kYVU,405
 tests/numpy_sample/test_collate.py,sha256=RqHCD5_LTRpe4r6kqC_2TKhmhM_IHYM0ZtFUvSjDqcM,654
-tests/numpy_sample/test_datetime_features.py,sha256=o4t3KeKFdGrOBQ77rNFcDuDMQSD23ileCS5T5AP3wG4,1769
+tests/numpy_sample/test_datetime_features.py,sha256=iR9WdBLj1nIBNqoaTFE9rkUaH1eKFJSNb96nwiEaQH0,1449
 tests/numpy_sample/test_gsp.py,sha256=FLlq4SlJ-9cSRAepf4_ksA6PsUVKegnKEAc5pUojCJ0,1458
 tests/numpy_sample/test_nwp.py,sha256=yf4u7mAU0E3FQ4xAH6YjuHuHBzzFoXjHSFNkOVJUdSM,1455
 tests/numpy_sample/test_satellite.py,sha256=cCqtn5See-uSNfh89COGTUQNuFm6sIZ8QmBVHsuUeRI,1189
@@ -66,14 +67,14 @@ tests/select/test_find_contiguous_time_periods.py,sha256=kOga_V7er5We7ewMARXaKdM
 tests/select/test_location.py,sha256=_WZk2FPYeJ-nIfCJS6Sp_yaVEEo7m31DmMFoZzgyCts,2712
 tests/select/test_select_spatial_slice.py,sha256=7EX9b6g-pMdACQx3yefjs5do2s-Rho2UmKevV4oglsU,5147
 tests/select/test_select_time_slice.py,sha256=nYrdlmZlGEygJKiE26bADiluNPN1qt5kD4FrI2vtxUw,9686
-tests/test_sample/test_base.py,sha256=ljtB38MmscTGN6OvUgclBceNnfx6m7AN8iHYDml9XW4,2189
+tests/test_sample/test_base.py,sha256=CkqKCZbrq3Vb4T7bOwPh3_0p8OTl0LfSLNBctYC_jag,4199
 tests/test_sample/test_site_sample.py,sha256=Gln-Or060cUWvA7Q7c1vsthgCttOAM2z9yBI9zUIrDw,6238
 tests/test_sample/test_uk_regional_sample.py,sha256=gkeQWC2wC757jKJz_QBmDMFQjn3R54q_tEo948yyxCY,4840
 tests/torch_datasets/test_merge_and_fill_utils.py,sha256=GtuQg82BM1eHQjT7Ik1x1zaVcuc7KJO4_NC9stXsd4s,1123
 tests/torch_datasets/test_pvnet_uk.py,sha256=loueo7PUUYJVda3-vBn3bQIC_zgrTAThfx-GTDcBOZg,5596
-tests/torch_datasets/test_site.py,sha256=5MH5zkHFJXekwpnV6nHuSxt_sRNu9_mxiUjfWqmEhr0,6966
-ocf_data_sampler-0.1.5.dist-info/LICENSE,sha256=F-Q3UFCR-BECSocV55BFDpn4YKxve9PKrm-lTt6o_Tg,1073
-ocf_data_sampler-0.1.5.dist-info/METADATA,sha256=PetECVCNM6jys05FuPsOVmntGurbxTuW3n1_j7CYCLE,12173
-ocf_data_sampler-0.1.5.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
-ocf_data_sampler-0.1.5.dist-info/top_level.txt,sha256=Faob6N6cFdPc5eUpCTYcXgCaNhi4XLLteUL5W5ayYmg,31
-ocf_data_sampler-0.1.5.dist-info/RECORD,,
+tests/torch_datasets/test_site.py,sha256=t57vAR_RRWcbG_kEFk6VrFCYzVxwFG6qJKBnRHF02fM,7000
+ocf_data_sampler-0.1.7.dist-info/LICENSE,sha256=F-Q3UFCR-BECSocV55BFDpn4YKxve9PKrm-lTt6o_Tg,1073
+ocf_data_sampler-0.1.7.dist-info/METADATA,sha256=8SbL1qjkmeFDYdv1_hHBL9jxbSpt4aFCpx70rEEPeb0,12173
+ocf_data_sampler-0.1.7.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
+ocf_data_sampler-0.1.7.dist-info/top_level.txt,sha256=Faob6N6cFdPc5eUpCTYcXgCaNhi4XLLteUL5W5ayYmg,31
+ocf_data_sampler-0.1.7.dist-info/RECORD,,
tests/config/test_config.py CHANGED

@@ -1,59 +1,13 @@
-import tempfile
-
 import pytest
 from pydantic import ValidationError
-from pathlib import Path
-from ocf_data_sampler.config import (
-    load_yaml_configuration,
-    Configuration,
-    save_yaml_configuration
-)
+from ocf_data_sampler.config import load_yaml_configuration, Configuration
 
 
 def test_default_configuration():
    """Test default pydantic class"""
-
    _ = Configuration()
 
 
-def test_load_yaml_configuration(test_config_filename):
-    """
-    Test that yaml loading works for 'test_config.yaml'
-    and fails for an empty .yaml file
-    """
-    # Create temporary directory instead of file
-    with tempfile.TemporaryDirectory() as temp_dir:
-        # Create path for empty file
-        empty_file = Path(temp_dir) / "empty.yaml"
-
-        # Create an empty file
-        empty_file.touch()
-
-        # Test loading empty file
-        with pytest.raises(TypeError):
-            _ = load_yaml_configuration(str(empty_file))
-
-def test_yaml_save(test_config_filename):
-    """
-    Check configuration can be saved to a .yaml file
-    """
-    test_config = load_yaml_configuration(test_config_filename)
-
-    with tempfile.TemporaryDirectory() as temp_dir:
-        # Create path for config file
-        config_path = Path(temp_dir) / "test_config.yaml"
-
-        # Save configuration
-        saved_path = save_yaml_configuration(test_config, config_path)
-
-        # Verify file exists
-        assert saved_path.exists()
-
-        # Test loading saved configuration
-        loaded_config = load_yaml_configuration(str(saved_path))
-        assert loaded_config == test_config
-
-
 def test_extra_field_error():
    """
    Check an extra parameters in config causes error
tests/config/test_load.py ADDED

@@ -0,0 +1,7 @@
+from ocf_data_sampler.config import Configuration, load_yaml_configuration
+
+
+def test_load_yaml_configuration(test_config_filename):
+    loaded_config = load_yaml_configuration(test_config_filename)
+    assert isinstance(loaded_config, Configuration)
+
tests/config/test_save.py CHANGED
@@ -1,37 +1,28 @@
 """Tests for configuration saving functionality."""
-import pytest
-from pathlib import Path
-import tempfile
-import yaml
+import os
+from ocf_data_sampler.config import Configuration, save_yaml_configuration, load_yaml_configuration
 
-from ocf_data_sampler.config import Configuration, save_yaml_configuration
 
-@pytest.fixture
-def temp_dir():
-    """Create a temporary directory."""
-    with tempfile.TemporaryDirectory() as tmpdirname:
-        yield Path(tmpdirname)
-
-def test_save_yaml_configuration_basic(temp_dir):
-    """Test basic configuration saving functionality."""
+def test_save_yaml_configuration_basic(tmp_path):
+    """Save an empty configuration object"""
    config = Configuration()
-    filepath = temp_dir / "config.yaml"
-    result = save_yaml_configuration(config, filepath)
+
+    filepath = f"{tmp_path}/config.yaml"
+    save_yaml_configuration(config, filepath)
 
-    assert filepath.exists()
-    with open(filepath) as f:
-        loaded_yaml = yaml.safe_load(f)
-    assert isinstance(loaded_yaml, dict)
+    assert os.path.exists(filepath)
 
-def test_save_yaml_configuration_none_filename():
-    """Test that None filename raises ValueError."""
-    config = Configuration()
-    with pytest.raises(ValueError, match="filename cannot be None"):
-        save_yaml_configuration(config, None)
 
-def test_save_yaml_configuration_invalid_directory(temp_dir):
-    """Test handling of invalid directory paths."""
-    config = Configuration()
-    invalid_path = (temp_dir / "nonexistent" / "config.yaml").resolve()
-    with pytest.raises(ValueError, match="Directory does not exist"):
-        save_yaml_configuration(config, invalid_path)
+def test_save_load_yaml_configuration(tmp_path, test_config_filename):
+    """Make sure a saved configuration is the same after loading"""
+
+    # Start with this config
+    initial_config = load_yaml_configuration(test_config_filename)
+
+    # Save it
+    filepath = f"{tmp_path}/config.yaml"
+    save_yaml_configuration(initial_config, filepath)
+
+    # Load it and check it is still the same
+    loaded_config = load_yaml_configuration(filepath)
+    assert loaded_config == initial_config
tests/load/test_load_sites.py CHANGED

@@ -3,7 +3,7 @@ import xarray as xr
 
 
 def test_open_site(data_sites):
-    da = open_site(data_sites)
+    da = open_site(data_sites.file_path, data_sites.metadata_file_path)
 
    assert isinstance(da, xr.DataArray)
    assert da.dims == ("time_utc", "site_id")
tests/numpy_sample/test_datetime_features.py CHANGED

@@ -35,13 +35,3 @@ def test_make_datetime_numpy_batch_custom_key_prefix():
    # Assert dict contains expected quantity of keys and verify starting with custom prefix
    assert len(datetime_features) == 4
    assert all(key.startswith(key_prefix) for key in datetime_features.keys())
-
-
-def test_make_datetime_numpy_batch_empty_input():
-    # Verification that function raises error for empty input
-    datetimes = pd.DatetimeIndex([])
-
-    with pytest.raises(
-        ValueError, match="Input datetimes is empty for 'make_datetime_numpy_dict' function"
-    ):
-        make_datetime_numpy_dict(datetimes)
tests/test_sample/test_base.py CHANGED

@@ -3,11 +3,14 @@ Base class testing - SampleBase
 """
 
 import pytest
+import torch
 import numpy as np
 
 from pathlib import Path
-from ocf_data_sampler.sample.base import SampleBase
-
+from ocf_data_sampler.sample.base import (
+    SampleBase,
+    batch_to_tensor
+)
 
 class TestSample(SampleBase):
    """
@@ -84,3 +87,61 @@ def test_sample_base_to_numpy():
    assert isinstance(numpy_data, dict)
    assert all(isinstance(value, np.ndarray) for value in numpy_data.values())
    assert np.array_equal(numpy_data['list_data'], np.array([1, 2, 3]))
+
+
+def test_batch_to_tensor_nested():
+    """ Test nested dictionary conversion """
+    batch = {
+        'outer': {
+            'inner': np.array([1, 2, 3])
+        }
+    }
+    tensor_batch = batch_to_tensor(batch)
+
+    assert torch.equal(tensor_batch['outer']['inner'], torch.tensor([1, 2, 3]))
+
+
+def test_batch_to_tensor_mixed_types():
+    """ Test handling of mixed data types """
+    batch = {
+        'tensor_data': np.array([1, 2, 3]),
+        'string_data': 'not_a_tensor',
+        'nested': {
+            'numbers': np.array([4, 5, 6]),
+            'text': 'still_not_a_tensor'
+        }
+    }
+    tensor_batch = batch_to_tensor(batch)
+
+    assert isinstance(tensor_batch['tensor_data'], torch.Tensor)
+    assert isinstance(tensor_batch['string_data'], str)
+    assert isinstance(tensor_batch['nested']['numbers'], torch.Tensor)
+    assert isinstance(tensor_batch['nested']['text'], str)
+
+
+def test_batch_to_tensor_different_dtypes():
+    """ Test conversion of arrays with different dtypes """
+    batch = {
+        'float_data': np.array([1.0, 2.0, 3.0], dtype=np.float32),
+        'int_data': np.array([1, 2, 3], dtype=np.int64),
+        'bool_data': np.array([True, False, True], dtype=np.bool_)
+    }
+    tensor_batch = batch_to_tensor(batch)
+
+    assert isinstance(tensor_batch['bool_data'], torch.Tensor)
+    assert tensor_batch['float_data'].dtype == torch.float32
+    assert tensor_batch['int_data'].dtype == torch.int64
+    assert tensor_batch['bool_data'].dtype == torch.bool
+
+
+def test_batch_to_tensor_multidimensional():
+    """ Test conversion of multidimensional arrays """
+    batch = {
+        'matrix': np.array([[1, 2], [3, 4]]),
+        'tensor': np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
+    }
+    tensor_batch = batch_to_tensor(batch)
+
+    assert tensor_batch['matrix'].shape == (2, 2)
+    assert tensor_batch['tensor'].shape == (2, 2, 2)
+    assert torch.equal(tensor_batch['matrix'], torch.tensor([[1, 2], [3, 4]]))
tests/torch_datasets/test_site.py CHANGED

@@ -33,7 +33,7 @@ def sites_dataset(site_config_filename):
    return SitesDataset(site_config_filename)
 
 
-def test_site(site_config_filename):
+def test_site(tmp_path, site_config_filename):
 
    # Create dataset object
    dataset = SitesDataset(site_config_filename)
@@ -71,8 +71,8 @@ def test_site(site_config_filename):
    expected_data_vars = {"nwp-ukv", "satellite", "site"}
 
 
-    sample.to_netcdf("sample.nc")
-    sample = xr.open_dataset("sample.nc")
+    sample.to_netcdf(f"{tmp_path}/sample.nc")
+    sample = xr.open_dataset(f"{tmp_path}/sample.nc")
 
    # Check dimensions
    assert (