ocf-data-sampler 0.0.44__tar.gz → 0.0.46__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ocf-data-sampler might be problematic. Click here for more details.

Files changed (81)
  1. {ocf_data_sampler-0.0.44/ocf_data_sampler.egg-info → ocf_data_sampler-0.0.46}/PKG-INFO +3 -2
  2. {ocf_data_sampler-0.0.44 → ocf_data_sampler-0.0.46}/README.md +2 -1
  3. {ocf_data_sampler-0.0.44 → ocf_data_sampler-0.0.46}/ocf_data_sampler/config/save.py +22 -11
  4. {ocf_data_sampler-0.0.44 → ocf_data_sampler-0.0.46}/ocf_data_sampler/numpy_sample/__init__.py +1 -0
  5. ocf_data_sampler-0.0.46/ocf_data_sampler/numpy_sample/datetime_features.py +46 -0
  6. {ocf_data_sampler-0.0.44/ocf_data_sampler/torch_datasets → ocf_data_sampler-0.0.46/ocf_data_sampler/torch_datasets/datasets}/pvnet_uk_regional.py +102 -4
  7. {ocf_data_sampler-0.0.44/ocf_data_sampler/torch_datasets → ocf_data_sampler-0.0.46/ocf_data_sampler/torch_datasets/datasets}/site.py +23 -5
  8. ocf_data_sampler-0.0.46/ocf_data_sampler/torch_datasets/utils/merge_and_fill_utils.py +25 -0
  9. {ocf_data_sampler-0.0.44 → ocf_data_sampler-0.0.46/ocf_data_sampler.egg-info}/PKG-INFO +3 -2
  10. {ocf_data_sampler-0.0.44 → ocf_data_sampler-0.0.46}/ocf_data_sampler.egg-info/SOURCES.txt +8 -6
  11. {ocf_data_sampler-0.0.44 → ocf_data_sampler-0.0.46}/pyproject.toml +1 -1
  12. {ocf_data_sampler-0.0.44 → ocf_data_sampler-0.0.46}/tests/config/test_config.py +25 -27
  13. {ocf_data_sampler-0.0.44 → ocf_data_sampler-0.0.46}/tests/conftest.py +2 -2
  14. {ocf_data_sampler-0.0.44 → ocf_data_sampler-0.0.46}/tests/numpy_sample/test_collate.py +1 -1
  15. ocf_data_sampler-0.0.46/tests/numpy_sample/test_datetime_features.py +47 -0
  16. ocf_data_sampler-0.0.46/tests/torch_datasets/test_merge_and_fill_utils.py +42 -0
  17. ocf_data_sampler-0.0.44/tests/torch_datasets/test_process_and_combine.py → ocf_data_sampler-0.0.46/tests/torch_datasets/test_pvnet_uk_regional.py +57 -47
  18. {ocf_data_sampler-0.0.44 → ocf_data_sampler-0.0.46}/tests/torch_datasets/test_site.py +4 -4
  19. ocf_data_sampler-0.0.44/ocf_data_sampler/torch_datasets/process_and_combine.py +0 -131
  20. ocf_data_sampler-0.0.44/tests/torch_datasets/test_pvnet_uk_regional.py +0 -59
  21. {ocf_data_sampler-0.0.44 → ocf_data_sampler-0.0.46}/LICENSE +0 -0
  22. {ocf_data_sampler-0.0.44 → ocf_data_sampler-0.0.46}/MANIFEST.in +0 -0
  23. {ocf_data_sampler-0.0.44 → ocf_data_sampler-0.0.46}/ocf_data_sampler/__init__.py +0 -0
  24. {ocf_data_sampler-0.0.44 → ocf_data_sampler-0.0.46}/ocf_data_sampler/config/__init__.py +0 -0
  25. {ocf_data_sampler-0.0.44 → ocf_data_sampler-0.0.46}/ocf_data_sampler/config/load.py +0 -0
  26. {ocf_data_sampler-0.0.44 → ocf_data_sampler-0.0.46}/ocf_data_sampler/config/model.py +0 -0
  27. {ocf_data_sampler-0.0.44 → ocf_data_sampler-0.0.46}/ocf_data_sampler/constants.py +0 -0
  28. {ocf_data_sampler-0.0.44 → ocf_data_sampler-0.0.46}/ocf_data_sampler/data/uk_gsp_locations.csv +0 -0
  29. {ocf_data_sampler-0.0.44 → ocf_data_sampler-0.0.46}/ocf_data_sampler/load/__init__.py +0 -0
  30. {ocf_data_sampler-0.0.44 → ocf_data_sampler-0.0.46}/ocf_data_sampler/load/gsp.py +0 -0
  31. {ocf_data_sampler-0.0.44 → ocf_data_sampler-0.0.46}/ocf_data_sampler/load/load_dataset.py +0 -0
  32. {ocf_data_sampler-0.0.44 → ocf_data_sampler-0.0.46}/ocf_data_sampler/load/nwp/__init__.py +0 -0
  33. {ocf_data_sampler-0.0.44 → ocf_data_sampler-0.0.46}/ocf_data_sampler/load/nwp/nwp.py +0 -0
  34. {ocf_data_sampler-0.0.44 → ocf_data_sampler-0.0.46}/ocf_data_sampler/load/nwp/providers/__init__.py +0 -0
  35. {ocf_data_sampler-0.0.44 → ocf_data_sampler-0.0.46}/ocf_data_sampler/load/nwp/providers/ecmwf.py +0 -0
  36. {ocf_data_sampler-0.0.44 → ocf_data_sampler-0.0.46}/ocf_data_sampler/load/nwp/providers/ukv.py +0 -0
  37. {ocf_data_sampler-0.0.44 → ocf_data_sampler-0.0.46}/ocf_data_sampler/load/nwp/providers/utils.py +0 -0
  38. {ocf_data_sampler-0.0.44 → ocf_data_sampler-0.0.46}/ocf_data_sampler/load/satellite.py +0 -0
  39. {ocf_data_sampler-0.0.44 → ocf_data_sampler-0.0.46}/ocf_data_sampler/load/site.py +0 -0
  40. {ocf_data_sampler-0.0.44 → ocf_data_sampler-0.0.46}/ocf_data_sampler/load/utils.py +0 -0
  41. {ocf_data_sampler-0.0.44 → ocf_data_sampler-0.0.46}/ocf_data_sampler/numpy_sample/collate.py +0 -0
  42. {ocf_data_sampler-0.0.44 → ocf_data_sampler-0.0.46}/ocf_data_sampler/numpy_sample/gsp.py +0 -0
  43. {ocf_data_sampler-0.0.44 → ocf_data_sampler-0.0.46}/ocf_data_sampler/numpy_sample/nwp.py +0 -0
  44. {ocf_data_sampler-0.0.44 → ocf_data_sampler-0.0.46}/ocf_data_sampler/numpy_sample/satellite.py +0 -0
  45. {ocf_data_sampler-0.0.44 → ocf_data_sampler-0.0.46}/ocf_data_sampler/numpy_sample/site.py +0 -0
  46. {ocf_data_sampler-0.0.44 → ocf_data_sampler-0.0.46}/ocf_data_sampler/numpy_sample/sun_position.py +0 -0
  47. {ocf_data_sampler-0.0.44 → ocf_data_sampler-0.0.46}/ocf_data_sampler/select/__init__.py +0 -0
  48. {ocf_data_sampler-0.0.44 → ocf_data_sampler-0.0.46}/ocf_data_sampler/select/dropout.py +0 -0
  49. {ocf_data_sampler-0.0.44 → ocf_data_sampler-0.0.46}/ocf_data_sampler/select/fill_time_periods.py +0 -0
  50. {ocf_data_sampler-0.0.44 → ocf_data_sampler-0.0.46}/ocf_data_sampler/select/find_contiguous_time_periods.py +0 -0
  51. {ocf_data_sampler-0.0.44 → ocf_data_sampler-0.0.46}/ocf_data_sampler/select/geospatial.py +0 -0
  52. {ocf_data_sampler-0.0.44 → ocf_data_sampler-0.0.46}/ocf_data_sampler/select/location.py +0 -0
  53. {ocf_data_sampler-0.0.44 → ocf_data_sampler-0.0.46}/ocf_data_sampler/select/select_spatial_slice.py +0 -0
  54. {ocf_data_sampler-0.0.44 → ocf_data_sampler-0.0.46}/ocf_data_sampler/select/select_time_slice.py +0 -0
  55. {ocf_data_sampler-0.0.44 → ocf_data_sampler-0.0.46}/ocf_data_sampler/select/spatial_slice_for_dataset.py +0 -0
  56. {ocf_data_sampler-0.0.44 → ocf_data_sampler-0.0.46}/ocf_data_sampler/select/time_slice_for_dataset.py +0 -0
  57. {ocf_data_sampler-0.0.44/ocf_data_sampler/torch_datasets → ocf_data_sampler-0.0.46/ocf_data_sampler/torch_datasets/datasets}/__init__.py +0 -0
  58. {ocf_data_sampler-0.0.44/ocf_data_sampler/torch_datasets → ocf_data_sampler-0.0.46/ocf_data_sampler/torch_datasets/utils}/valid_time_periods.py +0 -0
  59. {ocf_data_sampler-0.0.44 → ocf_data_sampler-0.0.46}/ocf_data_sampler/utils.py +0 -0
  60. {ocf_data_sampler-0.0.44 → ocf_data_sampler-0.0.46}/ocf_data_sampler.egg-info/dependency_links.txt +0 -0
  61. {ocf_data_sampler-0.0.44 → ocf_data_sampler-0.0.46}/ocf_data_sampler.egg-info/requires.txt +0 -0
  62. {ocf_data_sampler-0.0.44 → ocf_data_sampler-0.0.46}/ocf_data_sampler.egg-info/top_level.txt +0 -0
  63. {ocf_data_sampler-0.0.44 → ocf_data_sampler-0.0.46}/scripts/refactor_site.py +0 -0
  64. {ocf_data_sampler-0.0.44 → ocf_data_sampler-0.0.46}/setup.cfg +0 -0
  65. {ocf_data_sampler-0.0.44 → ocf_data_sampler-0.0.46}/tests/__init__.py +0 -0
  66. {ocf_data_sampler-0.0.44 → ocf_data_sampler-0.0.46}/tests/config/test_save.py +0 -0
  67. {ocf_data_sampler-0.0.44 → ocf_data_sampler-0.0.46}/tests/load/test_load_gsp.py +0 -0
  68. {ocf_data_sampler-0.0.44 → ocf_data_sampler-0.0.46}/tests/load/test_load_nwp.py +0 -0
  69. {ocf_data_sampler-0.0.44 → ocf_data_sampler-0.0.46}/tests/load/test_load_satellite.py +0 -0
  70. {ocf_data_sampler-0.0.44 → ocf_data_sampler-0.0.46}/tests/load/test_load_sites.py +0 -0
  71. {ocf_data_sampler-0.0.44 → ocf_data_sampler-0.0.46}/tests/numpy_sample/test_gsp.py +0 -0
  72. {ocf_data_sampler-0.0.44 → ocf_data_sampler-0.0.46}/tests/numpy_sample/test_nwp.py +0 -0
  73. {ocf_data_sampler-0.0.44 → ocf_data_sampler-0.0.46}/tests/numpy_sample/test_satellite.py +0 -0
  74. {ocf_data_sampler-0.0.44 → ocf_data_sampler-0.0.46}/tests/numpy_sample/test_sun_position.py +0 -0
  75. {ocf_data_sampler-0.0.44 → ocf_data_sampler-0.0.46}/tests/select/test_dropout.py +0 -0
  76. {ocf_data_sampler-0.0.44 → ocf_data_sampler-0.0.46}/tests/select/test_fill_time_periods.py +0 -0
  77. {ocf_data_sampler-0.0.44 → ocf_data_sampler-0.0.46}/tests/select/test_find_contiguous_time_periods.py +0 -0
  78. {ocf_data_sampler-0.0.44 → ocf_data_sampler-0.0.46}/tests/select/test_location.py +0 -0
  79. {ocf_data_sampler-0.0.44 → ocf_data_sampler-0.0.46}/tests/select/test_select_spatial_slice.py +0 -0
  80. {ocf_data_sampler-0.0.44 → ocf_data_sampler-0.0.46}/tests/select/test_select_time_slice.py +0 -0
  81. {ocf_data_sampler-0.0.44 → ocf_data_sampler-0.0.46}/tests/torch_datasets/conftest.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: ocf_data_sampler
3
- Version: 0.0.44
3
+ Version: 0.0.46
4
4
  Summary: Sample from weather data for renewable energy prediction
5
5
  Author: James Fulton, Peter Dudfield, and the Open Climate Fix team
6
6
  Author-email: info@openclimatefix.org
@@ -56,7 +56,7 @@ Requires-Dist: mkdocs-material>=8.0; extra == "docs"
56
56
  # ocf-data-sampler
57
57
 
58
58
  <!-- ALL-CONTRIBUTORS-BADGE:START - Do not remove or modify this section -->
59
- [![All Contributors](https://img.shields.io/badge/all_contributors-10-orange.svg?style=flat-square)](#contributors-)
59
+ [![All Contributors](https://img.shields.io/badge/all_contributors-11-orange.svg?style=flat-square)](#contributors-)
60
60
  <!-- ALL-CONTRIBUTORS-BADGE:END -->
61
61
 
62
62
  [![tags badge](https://img.shields.io/github/v/tag/openclimatefix/ocf-data-sampler?include_prereleases&sort=semver&color=FFAC5F)](https://github.com/openclimatefix/ocf-data-sampler/tags)
@@ -135,6 +135,7 @@ Thanks goes to these wonderful people ([emoji key](https://allcontributors.org/d
135
135
  <td align="center" valign="top" width="14.28%"><a href="https://github.com/felix-e-h-p"><img src="https://avatars.githubusercontent.com/u/137530077?v=4?s=100" width="100px;" alt="Felix"/><br /><sub><b>Felix</b></sub></a><br /><a href="https://github.com/openclimatefix/ocf-data-sampler/commits?author=felix-e-h-p" title="Code">💻</a></td>
136
136
  <td align="center" valign="top" width="14.28%"><a href="https://timothyajaniportfolio-b6v3zq29k-timthegreat.vercel.app/"><img src="https://avatars.githubusercontent.com/u/60073728?v=4?s=100" width="100px;" alt="Ajani Timothy"/><br /><sub><b>Ajani Timothy</b></sub></a><br /><a href="https://github.com/openclimatefix/ocf-data-sampler/commits?author=Tim1119" title="Code">💻</a></td>
137
137
  <td align="center" valign="top" width="14.28%"><a href="https://rupeshmangalam.vercel.app/"><img src="https://avatars.githubusercontent.com/u/91172425?v=4?s=100" width="100px;" alt="Rupesh Mangalam"/><br /><sub><b>Rupesh Mangalam</b></sub></a><br /><a href="https://github.com/openclimatefix/ocf-data-sampler/commits?author=RupeshMangalam21" title="Code">💻</a></td>
138
+ <td align="center" valign="top" width="14.28%"><a href="http://siddharth7113.github.io"><img src="https://avatars.githubusercontent.com/u/114160268?v=4?s=100" width="100px;" alt="Siddharth"/><br /><sub><b>Siddharth</b></sub></a><br /><a href="https://github.com/openclimatefix/ocf-data-sampler/commits?author=siddharth7113" title="Code">💻</a></td>
138
139
  </tr>
139
140
  </tbody>
140
141
  </table>
@@ -1,7 +1,7 @@
1
1
  # ocf-data-sampler
2
2
 
3
3
  <!-- ALL-CONTRIBUTORS-BADGE:START - Do not remove or modify this section -->
4
- [![All Contributors](https://img.shields.io/badge/all_contributors-10-orange.svg?style=flat-square)](#contributors-)
4
+ [![All Contributors](https://img.shields.io/badge/all_contributors-11-orange.svg?style=flat-square)](#contributors-)
5
5
  <!-- ALL-CONTRIBUTORS-BADGE:END -->
6
6
 
7
7
  [![tags badge](https://img.shields.io/github/v/tag/openclimatefix/ocf-data-sampler?include_prereleases&sort=semver&color=FFAC5F)](https://github.com/openclimatefix/ocf-data-sampler/tags)
@@ -80,6 +80,7 @@ Thanks goes to these wonderful people ([emoji key](https://allcontributors.org/d
80
80
  <td align="center" valign="top" width="14.28%"><a href="https://github.com/felix-e-h-p"><img src="https://avatars.githubusercontent.com/u/137530077?v=4?s=100" width="100px;" alt="Felix"/><br /><sub><b>Felix</b></sub></a><br /><a href="https://github.com/openclimatefix/ocf-data-sampler/commits?author=felix-e-h-p" title="Code">💻</a></td>
81
81
  <td align="center" valign="top" width="14.28%"><a href="https://timothyajaniportfolio-b6v3zq29k-timthegreat.vercel.app/"><img src="https://avatars.githubusercontent.com/u/60073728?v=4?s=100" width="100px;" alt="Ajani Timothy"/><br /><sub><b>Ajani Timothy</b></sub></a><br /><a href="https://github.com/openclimatefix/ocf-data-sampler/commits?author=Tim1119" title="Code">💻</a></td>
82
82
  <td align="center" valign="top" width="14.28%"><a href="https://rupeshmangalam.vercel.app/"><img src="https://avatars.githubusercontent.com/u/91172425?v=4?s=100" width="100px;" alt="Rupesh Mangalam"/><br /><sub><b>Rupesh Mangalam</b></sub></a><br /><a href="https://github.com/openclimatefix/ocf-data-sampler/commits?author=RupeshMangalam21" title="Code">💻</a></td>
83
+ <td align="center" valign="top" width="14.28%"><a href="http://siddharth7113.github.io"><img src="https://avatars.githubusercontent.com/u/114160268?v=4?s=100" width="100px;" alt="Siddharth"/><br /><sub><b>Siddharth</b></sub></a><br /><a href="https://github.com/openclimatefix/ocf-data-sampler/commits?author=siddharth7113" title="Code">💻</a></td>
83
84
  </tr>
84
85
  </tbody>
85
86
  </table>
@@ -9,7 +9,6 @@ Example:
9
9
  """
10
10
 
11
11
  import json
12
-
13
12
  from pathlib import Path
14
13
  from typing import Union
15
14
 
@@ -18,7 +17,6 @@ import yaml
18
17
 
19
18
  from ocf_data_sampler.config import Configuration
20
19
 
21
-
22
20
  def save_yaml_configuration(
23
21
  configuration: Configuration,
24
22
  filename: Union[str, Path],
@@ -35,7 +33,7 @@ def save_yaml_configuration(
35
33
  Path: The path where the configuration was saved
36
34
 
37
35
  Raises:
38
- ValueError: If filename is None or if writing to the specified path fails
36
+ ValueError: If filename is None, directory doesn't exist, or if writing to the specified path fails
39
37
  TypeError: If the configuration cannot be serialized
40
38
  """
41
39
  if filename is None:
@@ -50,24 +48,37 @@ def save_yaml_configuration(
50
48
 
51
49
  filepath = Path(filename)
52
50
 
53
- # For local files, check if directory exists before proceeding
51
+ # For local paths, check if parent directory exists before attempting to create
54
52
  if filepath.is_absolute():
55
- directory = filepath.parent
56
- if not directory.exists():
53
+ if not filepath.parent.exists():
57
54
  raise ValueError("Directory does not exist")
55
+
56
+ # Only try to create directory if it's in a writable location
57
+ try:
58
+ filepath.parent.mkdir(parents=True, exist_ok=True)
59
+ except PermissionError:
60
+ raise ValueError(f"Permission denied when accessing directory {filepath.parent}")
58
61
 
59
62
  # Serialize configuration to JSON-compatible dictionary
60
63
  config_dict = json.loads(configuration.model_dump_json())
61
64
 
62
- # Save to YAML file using fsspec
63
- with fsspec.open(str(filepath), mode='w') as yaml_file:
64
- yaml.safe_dump(config_dict, yaml_file, default_flow_style=False)
65
+ # Write to file directly for local paths
66
+ if filepath.is_absolute():
67
+ try:
68
+ with open(filepath, 'w') as f:
69
+ yaml.safe_dump(config_dict, f, default_flow_style=False)
70
+ except PermissionError:
71
+ raise ValueError(f"Permission denied when writing to {filename}")
72
+ else:
73
+ # Use fsspec for cloud storage
74
+ with fsspec.open(str(filepath), mode='w') as yaml_file:
75
+ yaml.safe_dump(config_dict, yaml_file, default_flow_style=False)
65
76
 
66
77
  return filepath
67
78
 
68
79
  except json.JSONDecodeError as e:
69
80
  raise TypeError(f"Failed to serialize configuration: {str(e)}") from e
70
- except PermissionError as e:
71
- raise ValueError(f"Permission denied when writing to {filename}") from e
72
81
  except (IOError, OSError) as e:
82
+ if "Permission denied" in str(e):
83
+ raise ValueError(f"Permission denied when writing to {filename}") from e
73
84
  raise ValueError(f"Failed to write configuration to {filename}: {str(e)}") from e
@@ -1,5 +1,6 @@
1
1
  """Conversion from Xarray to NumpySample"""
2
2
 
3
+ from .datetime_features import make_datetime_numpy_dict
3
4
  from .gsp import convert_gsp_to_numpy_sample, GSPSampleKey
4
5
  from .nwp import convert_nwp_to_numpy_sample, NWPSampleKey
5
6
  from .satellite import convert_satellite_to_numpy_sample, SatelliteSampleKey
@@ -0,0 +1,46 @@
1
+ """Functions to create trigonometric date and time inputs"""
2
+
3
+ import numpy as np
4
+ import pandas as pd
5
+ from numpy.typing import NDArray
6
+
7
+
8
+ def _get_date_time_in_pi(
9
+ dt: pd.DatetimeIndex,
10
+ ) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
11
+ """
12
+ Change the datetimes, into time and date scaled in radians
13
+ """
14
+
15
+ day_of_year = dt.dayofyear
16
+ minute_of_day = dt.minute + dt.hour * 60
17
+
18
+ # converting into positions on sin-cos circle
19
+ time_in_pi = (2 * np.pi) * (minute_of_day / (24 * 60))
20
+ date_in_pi = (2 * np.pi) * (day_of_year / 365)
21
+
22
+ return date_in_pi, time_in_pi
23
+
24
+
25
+ def make_datetime_numpy_dict(datetimes: pd.DatetimeIndex, key_prefix: str = "wind") -> dict:
26
+ """ Make dictionary of datetime features"""
27
+
28
+ if datetimes.empty:
29
+ raise ValueError("Input datetimes is empty for 'make_datetime_numpy_dict' function")
30
+
31
+ time_numpy_sample = {}
32
+
33
+ date_in_pi, time_in_pi = _get_date_time_in_pi(datetimes)
34
+
35
+ # Store
36
+ date_sin_batch_key = key_prefix + "_date_sin"
37
+ date_cos_batch_key = key_prefix + "_date_cos"
38
+ time_sin_batch_key = key_prefix + "_time_sin"
39
+ time_cos_batch_key = key_prefix + "_time_cos"
40
+
41
+ time_numpy_sample[date_sin_batch_key] = np.sin(date_in_pi)
42
+ time_numpy_sample[date_cos_batch_key] = np.cos(date_in_pi)
43
+ time_numpy_sample[time_sin_batch_key] = np.sin(time_in_pi)
44
+ time_numpy_sample[time_cos_batch_key] = np.cos(time_in_pi)
45
+
46
+ return time_numpy_sample
@@ -5,16 +5,114 @@ import pandas as pd
5
5
  import pkg_resources
6
6
  import xarray as xr
7
7
  from torch.utils.data import Dataset
8
-
9
8
  from ocf_data_sampler.config import Configuration, load_yaml_configuration
10
9
  from ocf_data_sampler.load.load_dataset import get_dataset_dict
11
10
  from ocf_data_sampler.select import fill_time_periods, Location, slice_datasets_by_space, slice_datasets_by_time
12
11
  from ocf_data_sampler.utils import minutes
13
- from ocf_data_sampler.torch_datasets.process_and_combine import process_and_combine_datasets, compute
14
- from ocf_data_sampler.torch_datasets.valid_time_periods import find_valid_time_periods
12
+ from ocf_data_sampler.torch_datasets.utils.valid_time_periods import find_valid_time_periods
13
+ from ocf_data_sampler.constants import NWP_MEANS, NWP_STDS, RSS_MEAN, RSS_STD
14
+ from ocf_data_sampler.numpy_sample import (
15
+ convert_nwp_to_numpy_sample,
16
+ convert_satellite_to_numpy_sample,
17
+ convert_gsp_to_numpy_sample,
18
+ make_sun_position_numpy_sample,
19
+ )
20
+ from ocf_data_sampler.torch_datasets.utils.merge_and_fill_utils import (
21
+ merge_dicts,
22
+ fill_nans_in_arrays,
23
+ )
24
+ from ocf_data_sampler.numpy_sample.gsp import GSPSampleKey
25
+ from ocf_data_sampler.numpy_sample.nwp import NWPSampleKey
26
+ from ocf_data_sampler.select.geospatial import osgb_to_lon_lat
15
27
 
16
28
  xr.set_options(keep_attrs=True)
17
29
 
30
+ def process_and_combine_datasets(
31
+ dataset_dict: dict,
32
+ config: Configuration,
33
+ t0: pd.Timestamp,
34
+ location: Location,
35
+ target_key: str = 'gsp'
36
+ ) -> dict:
37
+
38
+ """Normalise and convert data to numpy arrays"""
39
+ numpy_modalities = []
40
+
41
+ if "nwp" in dataset_dict:
42
+
43
+ nwp_numpy_modalities = dict()
44
+
45
+ for nwp_key, da_nwp in dataset_dict["nwp"].items():
46
+ # Standardise
47
+ provider = config.input_data.nwp[nwp_key].provider
48
+ da_nwp = (da_nwp - NWP_MEANS[provider]) / NWP_STDS[provider]
49
+
50
+ # Convert to NumpyBatch
51
+ nwp_numpy_modalities[nwp_key] = convert_nwp_to_numpy_sample(da_nwp)
52
+
53
+ # Combine the NWPs into NumpyBatch
54
+ numpy_modalities.append({NWPSampleKey.nwp: nwp_numpy_modalities})
55
+
56
+
57
+ if "sat" in dataset_dict:
58
+ # Standardise
59
+ da_sat = dataset_dict["sat"]
60
+ da_sat = (da_sat - RSS_MEAN) / RSS_STD
61
+
62
+ # Convert to NumpyBatch
63
+ numpy_modalities.append(convert_satellite_to_numpy_sample(da_sat))
64
+
65
+ gsp_config = config.input_data.gsp
66
+
67
+ if "gsp" in dataset_dict:
68
+ da_gsp = xr.concat([dataset_dict["gsp"], dataset_dict["gsp_future"]], dim="time_utc")
69
+ da_gsp = da_gsp / da_gsp.effective_capacity_mwp
70
+
71
+ numpy_modalities.append(
72
+ convert_gsp_to_numpy_sample(
73
+ da_gsp,
74
+ t0_idx=-gsp_config.interval_start_minutes / gsp_config.time_resolution_minutes
75
+ )
76
+ )
77
+
78
+ # Add coordinate data
79
+ # TODO: Do we need all of these?
80
+ numpy_modalities.append(
81
+ {
82
+ GSPSampleKey.gsp_id: location.id,
83
+ GSPSampleKey.x_osgb: location.x,
84
+ GSPSampleKey.y_osgb: location.y,
85
+ }
86
+ )
87
+
88
+ if target_key == 'gsp':
89
+ # Make sun coords NumpySample
90
+ datetimes = pd.date_range(
91
+ t0+minutes(gsp_config.interval_start_minutes),
92
+ t0+minutes(gsp_config.interval_end_minutes),
93
+ freq=minutes(gsp_config.time_resolution_minutes),
94
+ )
95
+
96
+ lon, lat = osgb_to_lon_lat(location.x, location.y)
97
+
98
+ numpy_modalities.append(
99
+ make_sun_position_numpy_sample(datetimes, lon, lat, key_prefix=target_key)
100
+ )
101
+
102
+ # Combine all the modalities and fill NaNs
103
+ combined_sample = merge_dicts(numpy_modalities)
104
+ combined_sample = fill_nans_in_arrays(combined_sample)
105
+
106
+ return combined_sample
107
+
108
+ def compute(xarray_dict: dict) -> dict:
109
+ """Eagerly load a nested dictionary of xarray DataArrays"""
110
+ for k, v in xarray_dict.items():
111
+ if isinstance(v, dict):
112
+ xarray_dict[k] = compute(v)
113
+ else:
114
+ xarray_dict[k] = v.compute(scheduler="single-threaded")
115
+ return xarray_dict
18
116
 
19
117
  def find_valid_t0_times(
20
118
  datasets_dict: dict,
@@ -48,7 +146,7 @@ def get_gsp_locations(gsp_ids: list[int] | None = None) -> list[Location]:
48
146
 
49
147
  # Load UK GSP locations
50
148
  df_gsp_loc = pd.read_csv(
51
- pkg_resources.resource_filename(__name__, "../data/uk_gsp_locations.csv"),
149
+ pkg_resources.resource_filename(__name__, "../../data/uk_gsp_locations.csv"),
52
150
  index_col="gsp_id",
53
151
  )
54
152
 
@@ -17,12 +17,14 @@ from ocf_data_sampler.select import (
17
17
  slice_datasets_by_time, slice_datasets_by_space
18
18
  )
19
19
  from ocf_data_sampler.utils import minutes
20
- from ocf_data_sampler.torch_datasets.valid_time_periods import find_valid_time_periods
21
- from ocf_data_sampler.torch_datasets.process_and_combine import merge_dicts, fill_nans_in_arrays
20
+ from ocf_data_sampler.torch_datasets.utils.valid_time_periods import find_valid_time_periods
21
+ from ocf_data_sampler.torch_datasets.utils.merge_and_fill_utils import merge_dicts, fill_nans_in_arrays
22
22
  from ocf_data_sampler.numpy_sample import (
23
23
  convert_site_to_numpy_sample,
24
24
  convert_satellite_to_numpy_sample,
25
- convert_nwp_to_numpy_sample
25
+ convert_nwp_to_numpy_sample,
26
+ make_datetime_numpy_dict,
27
+ make_sun_position_numpy_sample,
26
28
  )
27
29
  from ocf_data_sampler.numpy_sample import NWPSampleKey
28
30
  from ocf_data_sampler.constants import NWP_MEANS, NWP_STDS
@@ -234,10 +236,26 @@ class SitesDataset(Dataset):
234
236
  da_sites = dataset_dict["site"]
235
237
  da_sites = da_sites / da_sites.capacity_kwp
236
238
  data_arrays.append(("site", da_sites))
237
-
239
+
238
240
  combined_sample_dataset = self.merge_data_arrays(data_arrays)
239
241
 
240
- # TODO add solar + time features for sites
242
+ # add datetime features
243
+ datetimes = pd.DatetimeIndex(combined_sample_dataset.site__time_utc.values)
244
+ datetime_features = make_datetime_numpy_dict(datetimes=datetimes, key_prefix="site")
245
+ datetime_features_xr = xr.Dataset(datetime_features, coords={"site__time_utc": datetimes})
246
+ combined_sample_dataset = xr.merge([combined_sample_dataset, datetime_features_xr])
247
+
248
+ # add sun features
249
+ sun_position_features = make_sun_position_numpy_sample(
250
+ datetimes=datetimes,
251
+ lon=combined_sample_dataset.site__longitude.values,
252
+ lat=combined_sample_dataset.site__latitude.values,
253
+ key_prefix="site",
254
+ )
255
+ sun_position_features_xr = xr.Dataset(
256
+ sun_position_features, coords={"site__time_utc": datetimes}
257
+ )
258
+ combined_sample_dataset = xr.merge([combined_sample_dataset, sun_position_features_xr])
241
259
 
242
260
  # Fill any nan values
243
261
  return combined_sample_dataset.fillna(0.0)
@@ -0,0 +1,25 @@
1
+ import numpy as np
2
+
3
+ def merge_dicts(list_of_dicts: list[dict]) -> dict:
4
+ """Merge a list of dictionaries into a single dictionary"""
5
+ # TODO: This doesn't account for duplicate keys, which will be overwritten
6
+ combined_dict = {}
7
+ for d in list_of_dicts:
8
+ combined_dict.update(d)
9
+ return combined_dict
10
+
11
+ def fill_nans_in_arrays(sample: dict) -> dict:
12
+ """Fills all NaN values in each np.ndarray in the sample dictionary with zeros.
13
+
14
+ Operation is performed in-place on the sample.
15
+ """
16
+ for k, v in sample.items():
17
+ if isinstance(v, np.ndarray) and np.issubdtype(v.dtype, np.number):
18
+ if np.isnan(v).any():
19
+ sample[k] = np.nan_to_num(v, copy=False, nan=0.0)
20
+
21
+ # Recursion is included to reach NWP arrays in subdict
22
+ elif isinstance(v, dict):
23
+ fill_nans_in_arrays(v)
24
+
25
+ return sample
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: ocf_data_sampler
3
- Version: 0.0.44
3
+ Version: 0.0.46
4
4
  Summary: Sample from weather data for renewable energy prediction
5
5
  Author: James Fulton, Peter Dudfield, and the Open Climate Fix team
6
6
  Author-email: info@openclimatefix.org
@@ -56,7 +56,7 @@ Requires-Dist: mkdocs-material>=8.0; extra == "docs"
56
56
  # ocf-data-sampler
57
57
 
58
58
  <!-- ALL-CONTRIBUTORS-BADGE:START - Do not remove or modify this section -->
59
- [![All Contributors](https://img.shields.io/badge/all_contributors-10-orange.svg?style=flat-square)](#contributors-)
59
+ [![All Contributors](https://img.shields.io/badge/all_contributors-11-orange.svg?style=flat-square)](#contributors-)
60
60
  <!-- ALL-CONTRIBUTORS-BADGE:END -->
61
61
 
62
62
  [![tags badge](https://img.shields.io/github/v/tag/openclimatefix/ocf-data-sampler?include_prereleases&sort=semver&color=FFAC5F)](https://github.com/openclimatefix/ocf-data-sampler/tags)
@@ -135,6 +135,7 @@ Thanks goes to these wonderful people ([emoji key](https://allcontributors.org/d
135
135
  <td align="center" valign="top" width="14.28%"><a href="https://github.com/felix-e-h-p"><img src="https://avatars.githubusercontent.com/u/137530077?v=4?s=100" width="100px;" alt="Felix"/><br /><sub><b>Felix</b></sub></a><br /><a href="https://github.com/openclimatefix/ocf-data-sampler/commits?author=felix-e-h-p" title="Code">💻</a></td>
136
136
  <td align="center" valign="top" width="14.28%"><a href="https://timothyajaniportfolio-b6v3zq29k-timthegreat.vercel.app/"><img src="https://avatars.githubusercontent.com/u/60073728?v=4?s=100" width="100px;" alt="Ajani Timothy"/><br /><sub><b>Ajani Timothy</b></sub></a><br /><a href="https://github.com/openclimatefix/ocf-data-sampler/commits?author=Tim1119" title="Code">💻</a></td>
137
137
  <td align="center" valign="top" width="14.28%"><a href="https://rupeshmangalam.vercel.app/"><img src="https://avatars.githubusercontent.com/u/91172425?v=4?s=100" width="100px;" alt="Rupesh Mangalam"/><br /><sub><b>Rupesh Mangalam</b></sub></a><br /><a href="https://github.com/openclimatefix/ocf-data-sampler/commits?author=RupeshMangalam21" title="Code">💻</a></td>
138
+ <td align="center" valign="top" width="14.28%"><a href="http://siddharth7113.github.io"><img src="https://avatars.githubusercontent.com/u/114160268?v=4?s=100" width="100px;" alt="Siddharth"/><br /><sub><b>Siddharth</b></sub></a><br /><a href="https://github.com/openclimatefix/ocf-data-sampler/commits?author=siddharth7113" title="Code">💻</a></td>
138
139
  </tr>
139
140
  </tbody>
140
141
  </table>
@@ -29,6 +29,7 @@ ocf_data_sampler/load/nwp/providers/ukv.py
29
29
  ocf_data_sampler/load/nwp/providers/utils.py
30
30
  ocf_data_sampler/numpy_sample/__init__.py
31
31
  ocf_data_sampler/numpy_sample/collate.py
32
+ ocf_data_sampler/numpy_sample/datetime_features.py
32
33
  ocf_data_sampler/numpy_sample/gsp.py
33
34
  ocf_data_sampler/numpy_sample/nwp.py
34
35
  ocf_data_sampler/numpy_sample/satellite.py
@@ -44,11 +45,11 @@ ocf_data_sampler/select/select_spatial_slice.py
44
45
  ocf_data_sampler/select/select_time_slice.py
45
46
  ocf_data_sampler/select/spatial_slice_for_dataset.py
46
47
  ocf_data_sampler/select/time_slice_for_dataset.py
47
- ocf_data_sampler/torch_datasets/__init__.py
48
- ocf_data_sampler/torch_datasets/process_and_combine.py
49
- ocf_data_sampler/torch_datasets/pvnet_uk_regional.py
50
- ocf_data_sampler/torch_datasets/site.py
51
- ocf_data_sampler/torch_datasets/valid_time_periods.py
48
+ ocf_data_sampler/torch_datasets/datasets/__init__.py
49
+ ocf_data_sampler/torch_datasets/datasets/pvnet_uk_regional.py
50
+ ocf_data_sampler/torch_datasets/datasets/site.py
51
+ ocf_data_sampler/torch_datasets/utils/merge_and_fill_utils.py
52
+ ocf_data_sampler/torch_datasets/utils/valid_time_periods.py
52
53
  scripts/refactor_site.py
53
54
  tests/__init__.py
54
55
  tests/conftest.py
@@ -59,6 +60,7 @@ tests/load/test_load_nwp.py
59
60
  tests/load/test_load_satellite.py
60
61
  tests/load/test_load_sites.py
61
62
  tests/numpy_sample/test_collate.py
63
+ tests/numpy_sample/test_datetime_features.py
62
64
  tests/numpy_sample/test_gsp.py
63
65
  tests/numpy_sample/test_nwp.py
64
66
  tests/numpy_sample/test_satellite.py
@@ -70,6 +72,6 @@ tests/select/test_location.py
70
72
  tests/select/test_select_spatial_slice.py
71
73
  tests/select/test_select_time_slice.py
72
74
  tests/torch_datasets/conftest.py
73
- tests/torch_datasets/test_process_and_combine.py
75
+ tests/torch_datasets/test_merge_and_fill_utils.py
74
76
  tests/torch_datasets/test_pvnet_uk_regional.py
75
77
  tests/torch_datasets/test_site.py
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "ocf_data_sampler"
7
- version = "0.0.44"
7
+ version = "0.0.46"
8
8
  license = { file = "LICENSE" }
9
9
  readme = "README.md"
10
10
  description = "Sample from weather data for renewable energy prediction"
@@ -2,7 +2,7 @@ import tempfile
2
2
 
3
3
  import pytest
4
4
  from pydantic import ValidationError
5
-
5
+ from pathlib import Path
6
6
  from ocf_data_sampler.config import (
7
7
  load_yaml_configuration,
8
8
  Configuration,
@@ -21,39 +21,37 @@ def test_load_yaml_configuration(test_config_filename):
21
21
  Test that yaml loading works for 'test_config.yaml'
22
22
  and fails for an empty .yaml file
23
23
  """
24
-
25
- # check we get an error if loading a file with no config
26
- with tempfile.NamedTemporaryFile(suffix=".yaml") as fp:
27
- filename = fp.name
28
-
29
- # check that temp file can't be loaded
24
+ # Create temporary directory instead of file
25
+ with tempfile.TemporaryDirectory() as temp_dir:
26
+ # Create path for empty file
27
+ empty_file = Path(temp_dir) / "empty.yaml"
28
+
29
+ # Create an empty file
30
+ empty_file.touch()
31
+
32
+ # Test loading empty file
30
33
  with pytest.raises(TypeError):
31
- _ = load_yaml_configuration(filename)
32
-
33
- # test can load test_config.yaml
34
- config = load_yaml_configuration(test_config_filename)
35
-
36
- assert isinstance(config, Configuration)
37
-
34
+ _ = load_yaml_configuration(str(empty_file))
38
35
 
39
36
  def test_yaml_save(test_config_filename):
40
37
  """
41
38
  Check configuration can be saved to a .yaml file
42
39
  """
43
-
44
40
  test_config = load_yaml_configuration(test_config_filename)
45
-
46
- with tempfile.NamedTemporaryFile(suffix=".yaml") as fp:
47
- filename = fp.name
48
-
49
- # save default config to file
50
- save_yaml_configuration(test_config, filename)
51
-
52
- # check the file can be loaded back
53
- tmp_config = load_yaml_configuration(filename)
54
-
55
- # check loaded configuration is the same as the one passed to save
56
- assert test_config == tmp_config
41
+
42
+ with tempfile.TemporaryDirectory() as temp_dir:
43
+ # Create path for config file
44
+ config_path = Path(temp_dir) / "test_config.yaml"
45
+
46
+ # Save configuration
47
+ saved_path = save_yaml_configuration(test_config, config_path)
48
+
49
+ # Verify file exists
50
+ assert saved_path.exists()
51
+
52
+ # Test loading saved configuration
53
+ loaded_config = load_yaml_configuration(str(saved_path))
54
+ assert loaded_config == test_config
57
55
 
58
56
 
59
57
  def test_extra_field_error():
@@ -1,10 +1,10 @@
1
1
  import os
2
-
3
2
  import numpy as np
4
3
  import pandas as pd
5
4
  import pytest
6
5
  import xarray as xr
7
6
  import tempfile
7
+ from typing import Generator
8
8
 
9
9
  from ocf_data_sampler.config.model import Site
10
10
  from ocf_data_sampler.config import load_yaml_configuration, save_yaml_configuration
@@ -201,7 +201,7 @@ def ds_uk_gsp():
201
201
 
202
202
 
203
203
  @pytest.fixture(scope="session")
204
- def data_sites() -> Site:
204
+ def data_sites() -> Generator[Site, None, None]:
205
205
  """
206
206
  Make fake data for sites
207
207
  Returns: filename for netcdf file, and csv metadata
@@ -1,6 +1,6 @@
1
1
  from ocf_data_sampler.numpy_sample import GSPSampleKey, SatelliteSampleKey
2
2
  from ocf_data_sampler.numpy_sample.collate import stack_np_samples_into_batch
3
- from ocf_data_sampler.torch_datasets import PVNetUKRegionalDataset
3
+ from ocf_data_sampler.torch_datasets.datasets.pvnet_uk_regional import PVNetUKRegionalDataset
4
4
 
5
5
 
6
6
  def test_pvnet(pvnet_config_filename):
@@ -0,0 +1,47 @@
1
+ import numpy as np
2
+ import pandas as pd
3
+ import pytest
4
+
5
+ from ocf_data_sampler.numpy_sample.datetime_features import make_datetime_numpy_dict
6
+
7
+
8
+ def test_calculate_azimuth_and_elevation():
9
+
10
+ # Pick the day of the summer solstice
11
+ datetimes = pd.to_datetime(["2024-06-20 12:00", "2024-06-20 12:30", "2024-06-20 13:00"])
12
+
13
+ # Calculate sun angles
14
+ datetime_features = make_datetime_numpy_dict(datetimes)
15
+
16
+ assert len(datetime_features) == 4
17
+
18
+ assert len(datetime_features["wind_date_sin"]) == len(datetimes)
19
+ assert (datetime_features["wind_date_cos"] != datetime_features["wind_date_sin"]).all()
20
+
21
+ # assert all values are between -1 and 1
22
+ assert all(np.abs(datetime_features["wind_date_sin"]) <= 1)
23
+ assert all(np.abs(datetime_features["wind_date_cos"]) <= 1)
24
+ assert all(np.abs(datetime_features["wind_time_sin"]) <= 1)
25
+ assert all(np.abs(datetime_features["wind_time_cos"]) <= 1)
26
+
27
+
28
+ def test_make_datetime_numpy_batch_custom_key_prefix():
29
+ # Test function correctly applies custom prefix to dict keys
30
+ datetimes = pd.to_datetime(["2024-06-20 12:00", "2024-06-20 12:30", "2024-06-20 13:00"])
31
+ key_prefix = "solar"
32
+
33
+ datetime_features = make_datetime_numpy_dict(datetimes, key_prefix=key_prefix)
34
+
35
+ # Assert dict contains expected quantity of keys and verify starting with custom prefix
36
+ assert len(datetime_features) == 4
37
+ assert all(key.startswith(key_prefix) for key in datetime_features.keys())
38
+
39
+
40
+ def test_make_datetime_numpy_batch_empty_input():
41
+ # Verification that function raises error for empty input
42
+ datetimes = pd.DatetimeIndex([])
43
+
44
+ with pytest.raises(
45
+ ValueError, match="Input datetimes is empty for 'make_datetime_numpy_dict' function"
46
+ ):
47
+ make_datetime_numpy_dict(datetimes)
@@ -0,0 +1,42 @@
1
+ import numpy as np
2
+
3
+ from ocf_data_sampler.torch_datasets.utils.merge_and_fill_utils import (
4
+ merge_dicts,
5
+ fill_nans_in_arrays,
6
+ )
7
+
8
+ def test_merge_dicts():
9
+ """Test merge_dicts function"""
10
+ dict1 = {"a": 1, "b": 2}
11
+ dict2 = {"c": 3, "d": 4}
12
+ dict3 = {"e": 5}
13
+
14
+ result = merge_dicts([dict1, dict2, dict3])
15
+ assert result == {"a": 1, "b": 2, "c": 3, "d": 4, "e": 5}
16
+
17
+ # Test key overwriting
18
+ dict4 = {"a": 10, "f": 6}
19
+ result = merge_dicts([dict1, dict4])
20
+ assert result["a"] == 10
21
+
22
+
23
+ def test_fill_nans_in_arrays():
24
+ """Test the fill_nans_in_arrays function"""
25
+ array_with_nans = np.array([1.0, np.nan, 3.0, np.nan])
26
+ nested_dict = {
27
+ "array1": array_with_nans,
28
+ "nested": {
29
+ "array2": np.array([np.nan, 2.0, np.nan, 4.0])
30
+ },
31
+ "string_key": "not_an_array"
32
+ }
33
+
34
+ result = fill_nans_in_arrays(nested_dict)
35
+
36
+ assert not np.isnan(result["array1"]).any()
37
+ assert np.array_equal(result["array1"], np.array([1.0, 0.0, 3.0, 0.0]))
38
+ assert not np.isnan(result["nested"]["array2"]).any()
39
+ assert np.array_equal(result["nested"]["array2"], np.array([0.0, 2.0, 0.0, 4.0]))
40
+ assert result["string_key"] == "not_an_array"
41
+
42
+
@@ -2,19 +2,14 @@ import numpy as np
2
2
  import pandas as pd
3
3
  import xarray as xr
4
4
  import dask.array as da
5
+ import tempfile
5
6
 
6
- from ocf_data_sampler.config import load_yaml_configuration
7
- from ocf_data_sampler.select.location import Location
7
+ from ocf_data_sampler.torch_datasets.datasets.pvnet_uk_regional import PVNetUKRegionalDataset
8
+ from ocf_data_sampler.config.save import save_yaml_configuration
9
+ from ocf_data_sampler.config.load import load_yaml_configuration
8
10
  from ocf_data_sampler.numpy_sample import NWPSampleKey, GSPSampleKey, SatelliteSampleKey
9
- from ocf_data_sampler.torch_datasets import PVNetUKRegionalDataset
10
-
11
- from ocf_data_sampler.torch_datasets.process_and_combine import (
12
- process_and_combine_datasets,
13
- merge_dicts,
14
- fill_nans_in_arrays,
15
- compute,
16
- )
17
-
11
+ from ocf_data_sampler.torch_datasets.datasets.pvnet_uk_regional import process_and_combine_datasets, compute
12
+ from ocf_data_sampler.select.location import Location
18
13
 
19
14
  def test_process_and_combine_datasets(pvnet_config_filename):
20
15
 
@@ -60,42 +55,6 @@ def test_process_and_combine_datasets(pvnet_config_filename):
60
55
  assert result[SatelliteSampleKey.satellite_actual].shape == (7, 1, 2, 2)
61
56
  assert result[NWPSampleKey.nwp]["ukv"][NWPSampleKey.nwp].shape == (4, 1, 2, 2)
62
57
 
63
-
64
- def test_merge_dicts():
65
- """Test merge_dicts function"""
66
- dict1 = {"a": 1, "b": 2}
67
- dict2 = {"c": 3, "d": 4}
68
- dict3 = {"e": 5}
69
-
70
- result = merge_dicts([dict1, dict2, dict3])
71
- assert result == {"a": 1, "b": 2, "c": 3, "d": 4, "e": 5}
72
-
73
- # Test key overwriting
74
- dict4 = {"a": 10, "f": 6}
75
- result = merge_dicts([dict1, dict4])
76
- assert result["a"] == 10
77
-
78
-
79
- def test_fill_nans_in_arrays():
80
- """Test the fill_nans_in_arrays function"""
81
- array_with_nans = np.array([1.0, np.nan, 3.0, np.nan])
82
- nested_dict = {
83
- "array1": array_with_nans,
84
- "nested": {
85
- "array2": np.array([np.nan, 2.0, np.nan, 4.0])
86
- },
87
- "string_key": "not_an_array"
88
- }
89
-
90
- result = fill_nans_in_arrays(nested_dict)
91
-
92
- assert not np.isnan(result["array1"]).any()
93
- assert np.array_equal(result["array1"], np.array([1.0, 0.0, 3.0, 0.0]))
94
- assert not np.isnan(result["nested"]["array2"]).any()
95
- assert np.array_equal(result["nested"]["array2"], np.array([0.0, 2.0, 0.0, 4.0]))
96
- assert result["string_key"] == "not_an_array"
97
-
98
-
99
58
  def test_compute():
100
59
  """Test compute function with dask array"""
101
60
  da_dask = xr.DataArray(da.random.random((5, 5)))
@@ -124,3 +83,54 @@ def test_compute():
124
83
  # Ensure there no NaN values in computed data
125
84
  assert not np.isnan(result["array1"].data).any()
126
85
  assert not np.isnan(result["nested"]["array2"].data).any()
86
+
87
+ def test_pvnet(pvnet_config_filename):
88
+
89
+ # Create dataset object
90
+ dataset = PVNetUKRegionalDataset(pvnet_config_filename)
91
+
92
+ assert len(dataset.locations) == 317 # no of GSPs not including the National level
93
+ # NB. I have not checked this value is in fact correct, but it does seem to stay constant
94
+ assert len(dataset.valid_t0_times) == 39
95
+ assert len(dataset) == 317*39
96
+
97
+ # Generate a sample
98
+ sample = dataset[0]
99
+
100
+ assert isinstance(sample, dict)
101
+
102
+ for key in [
103
+ NWPSampleKey.nwp, SatelliteSampleKey.satellite_actual, GSPSampleKey.gsp,
104
+ GSPSampleKey.solar_azimuth, GSPSampleKey.solar_elevation,
105
+ ]:
106
+ assert key in sample
107
+
108
+ for nwp_source in ["ukv"]:
109
+ assert nwp_source in sample[NWPSampleKey.nwp]
110
+
111
+ # check the shape of the data is correct
112
+ # 30 minutes of 5 minute data (inclusive), one channel, 2x2 pixels
113
+ assert sample[SatelliteSampleKey.satellite_actual].shape == (7, 1, 2, 2)
114
+ # 3 hours of 60 minute data (inclusive), one channel, 2x2 pixels
115
+ assert sample[NWPSampleKey.nwp]["ukv"][NWPSampleKey.nwp].shape == (4, 1, 2, 2)
116
+ # 3 hours of 30 minute data (inclusive)
117
+ assert sample[GSPSampleKey.gsp].shape == (7,)
118
+ # Solar angles have same shape as GSP data
119
+ assert sample[GSPSampleKey.solar_azimuth].shape == (7,)
120
+ assert sample[GSPSampleKey.solar_elevation].shape == (7,)
121
+
122
+ def test_pvnet_no_gsp(pvnet_config_filename):
123
+
124
+ # load config
125
+ config = load_yaml_configuration(pvnet_config_filename)
126
+ # remove gsp
127
+ config.input_data.gsp.zarr_path = ''
128
+
129
+ # save temp config file
130
+ with tempfile.NamedTemporaryFile() as temp_config_file:
131
+ save_yaml_configuration(config, temp_config_file.name)
132
+ # Create dataset object
133
+ dataset = PVNetUKRegionalDataset(temp_config_file.name)
134
+
135
+ # Generate a sample
136
+ _ = dataset[0]
@@ -1,8 +1,6 @@
1
1
  import pandas as pd
2
-
3
- from ocf_data_sampler.torch_datasets import SitesDataset
4
- from ocf_data_sampler.torch_datasets.site import convert_from_dataset_to_dict_datasets
5
2
  import numpy as np
3
+ from ocf_data_sampler.torch_datasets.datasets.site import SitesDataset, convert_from_dataset_to_dict_datasets
6
4
  from xarray import Dataset, DataArray
7
5
 
8
6
 
@@ -22,7 +20,9 @@ def test_site(site_config_filename):
22
20
  # Expected dimensions and data variables
23
21
  expected_dims = {'satellite__x_geostationary', 'site__time_utc', 'nwp-ukv__target_time_utc',
24
22
  'nwp-ukv__x_osgb', 'satellite__channel', 'satellite__y_geostationary',
25
- 'satellite__time_utc', 'nwp-ukv__channel', 'nwp-ukv__y_osgb'}
23
+ 'satellite__time_utc', 'nwp-ukv__channel', 'nwp-ukv__y_osgb', 'site_solar_azimuth',
24
+ 'site_solar_elevation', 'site_date_cos', 'site_time_cos', 'site_time_sin', 'site_date_sin'}
25
+
26
26
  expected_data_vars = {"nwp-ukv", "satellite", "site"}
27
27
 
28
28
  # Check dimensions
@@ -1,131 +0,0 @@
1
- import numpy as np
2
- import pandas as pd
3
- import xarray as xr
4
- from typing import Optional
5
-
6
- from ocf_data_sampler.config import Configuration
7
- from ocf_data_sampler.constants import NWP_MEANS, NWP_STDS,RSS_MEAN,RSS_STD
8
- from ocf_data_sampler.numpy_sample import (
9
- convert_nwp_to_numpy_sample,
10
- convert_satellite_to_numpy_sample,
11
- convert_gsp_to_numpy_sample,
12
- make_sun_position_numpy_sample,
13
- )
14
- from ocf_data_sampler.numpy_sample.gsp import GSPSampleKey
15
- from ocf_data_sampler.numpy_sample.nwp import NWPSampleKey
16
- from ocf_data_sampler.select.geospatial import osgb_to_lon_lat
17
- from ocf_data_sampler.select.location import Location
18
- from ocf_data_sampler.utils import minutes
19
-
20
-
21
- def process_and_combine_datasets(
22
- dataset_dict: dict,
23
- config: Configuration,
24
- t0: Optional[pd.Timestamp] = None,
25
- location: Optional[Location] = None,
26
- target_key: str = 'gsp'
27
- ) -> dict:
28
-
29
- """Normalise and convert data to numpy arrays"""
30
- numpy_modalities = []
31
-
32
- if "nwp" in dataset_dict:
33
-
34
- nwp_numpy_modalities = dict()
35
-
36
- for nwp_key, da_nwp in dataset_dict["nwp"].items():
37
- # Standardise
38
- provider = config.input_data.nwp[nwp_key].provider
39
- da_nwp = (da_nwp - NWP_MEANS[provider]) / NWP_STDS[provider]
40
- # Convert to NumpySample
41
- nwp_numpy_modalities[nwp_key] = convert_nwp_to_numpy_sample(da_nwp)
42
-
43
- # Combine the NWPs into NumpySample
44
- numpy_modalities.append({NWPSampleKey.nwp: nwp_numpy_modalities})
45
-
46
-
47
- if "sat" in dataset_dict:
48
- # Standardise
49
- da_sat = dataset_dict["sat"]
50
- da_sat = (da_sat - RSS_MEAN) / RSS_STD
51
-
52
- # Convert to NumpySample
53
- numpy_modalities.append(convert_satellite_to_numpy_sample(da_sat))
54
-
55
-
56
- gsp_config = config.input_data.gsp
57
-
58
- if "gsp" in dataset_dict:
59
- da_gsp = xr.concat([dataset_dict["gsp"], dataset_dict["gsp_future"]], dim="time_utc")
60
- da_gsp = da_gsp / da_gsp.effective_capacity_mwp
61
-
62
- numpy_modalities.append(
63
- convert_gsp_to_numpy_sample(
64
- da_gsp,
65
- t0_idx=-gsp_config.interval_start_minutes / gsp_config.time_resolution_minutes
66
- )
67
- )
68
-
69
- # Add coordinate data
70
- # TODO: Do we need all of these?
71
- numpy_modalities.append(
72
- {
73
- GSPSampleKey.gsp_id: location.id,
74
- GSPSampleKey.x_osgb: location.x,
75
- GSPSampleKey.y_osgb: location.y,
76
- }
77
- )
78
-
79
- if target_key == 'gsp':
80
- # Make sun coords NumpySample
81
- datetimes = pd.date_range(
82
- t0+minutes(gsp_config.interval_start_minutes),
83
- t0+minutes(gsp_config.interval_end_minutes),
84
- freq=minutes(gsp_config.time_resolution_minutes),
85
- )
86
-
87
- lon, lat = osgb_to_lon_lat(location.x, location.y)
88
-
89
- numpy_modalities.append(
90
- make_sun_position_numpy_sample(datetimes, lon, lat, key_prefix=target_key)
91
- )
92
-
93
- # Combine all the modalities and fill NaNs
94
- combined_sample = merge_dicts(numpy_modalities)
95
- combined_sample = fill_nans_in_arrays(combined_sample)
96
-
97
- return combined_sample
98
-
99
- def merge_dicts(list_of_dicts: list[dict]) -> dict:
100
- """Merge a list of dictionaries into a single dictionary"""
101
- # TODO: This doesn't account for duplicate keys, which will be overwritten
102
- combined_dict = {}
103
- for d in list_of_dicts:
104
- combined_dict.update(d)
105
- return combined_dict
106
-
107
- def fill_nans_in_arrays(sample: dict) -> dict:
108
- """Fills all NaN values in each np.ndarray in the sample dictionary with zeros.
109
-
110
- Operation is performed in-place on the sample.
111
- """
112
- for k, v in sample.items():
113
- if isinstance(v, np.ndarray) and np.issubdtype(v.dtype, np.number):
114
- if np.isnan(v).any():
115
- sample[k] = np.nan_to_num(v, copy=False, nan=0.0)
116
-
117
- # Recursion is included to reach NWP arrays in subdict
118
- elif isinstance(v, dict):
119
- fill_nans_in_arrays(v)
120
-
121
- return sample
122
-
123
-
124
- def compute(xarray_dict: dict) -> dict:
125
- """Eagerly load a nested dictionary of xarray DataArrays"""
126
- for k, v in xarray_dict.items():
127
- if isinstance(v, dict):
128
- xarray_dict[k] = compute(v)
129
- else:
130
- xarray_dict[k] = v.compute(scheduler="single-threaded")
131
- return xarray_dict
@@ -1,59 +0,0 @@
1
- import pytest
2
- import tempfile
3
-
4
- from ocf_data_sampler.torch_datasets import PVNetUKRegionalDataset
5
- from ocf_data_sampler.config import load_yaml_configuration, save_yaml_configuration
6
- from ocf_data_sampler.numpy_sample import NWPSampleKey, GSPSampleKey, SatelliteSampleKey
7
-
8
-
9
-
10
- def test_pvnet(pvnet_config_filename):
11
-
12
- # Create dataset object
13
- dataset = PVNetUKRegionalDataset(pvnet_config_filename)
14
-
15
- assert len(dataset.locations) == 317 # no of GSPs not including the National level
16
- # NB. I have not checked this value is in fact correct, but it does seem to stay constant
17
- assert len(dataset.valid_t0_times) == 39
18
- assert len(dataset) == 317*39
19
-
20
- # Generate a sample
21
- sample = dataset[0]
22
-
23
- assert isinstance(sample, dict)
24
-
25
- for key in [
26
- NWPSampleKey.nwp, SatelliteSampleKey.satellite_actual, GSPSampleKey.gsp,
27
- GSPSampleKey.solar_azimuth, GSPSampleKey.solar_elevation,
28
- ]:
29
- assert key in sample
30
-
31
- for nwp_source in ["ukv"]:
32
- assert nwp_source in sample[NWPSampleKey.nwp]
33
-
34
- # check the shape of the data is correct
35
- # 30 minutes of 5 minute data (inclusive), one channel, 2x2 pixels
36
- assert sample[SatelliteSampleKey.satellite_actual].shape == (7, 1, 2, 2)
37
- # 3 hours of 60 minute data (inclusive), one channel, 2x2 pixels
38
- assert sample[NWPSampleKey.nwp]["ukv"][NWPSampleKey.nwp].shape == (4, 1, 2, 2)
39
- # 3 hours of 30 minute data (inclusive)
40
- assert sample[GSPSampleKey.gsp].shape == (7,)
41
- # Solar angles have same shape as GSP data
42
- assert sample[GSPSampleKey.solar_azimuth].shape == (7,)
43
- assert sample[GSPSampleKey.solar_elevation].shape == (7,)
44
-
45
- def test_pvnet_no_gsp(pvnet_config_filename):
46
-
47
- # load config
48
- config = load_yaml_configuration(pvnet_config_filename)
49
- # remove gsp
50
- config.input_data.gsp.zarr_path = ''
51
-
52
- # save temp config file
53
- with tempfile.NamedTemporaryFile() as temp_config_file:
54
- save_yaml_configuration(config, temp_config_file.name)
55
- # Create dataset object
56
- dataset = PVNetUKRegionalDataset(temp_config_file.name)
57
-
58
- # Generate a sample
59
- _ = dataset[0]