ocf-data-sampler 0.3.1__tar.gz → 0.5.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ocf-data-sampler might be problematic. Click here for more details.

Files changed (71) hide show
  1. {ocf_data_sampler-0.3.1 → ocf_data_sampler-0.5.0}/PKG-INFO +3 -2
  2. {ocf_data_sampler-0.3.1 → ocf_data_sampler-0.5.0}/README.md +2 -1
  3. {ocf_data_sampler-0.3.1 → ocf_data_sampler-0.5.0}/ocf_data_sampler/numpy_sample/__init__.py +2 -3
  4. ocf_data_sampler-0.5.0/ocf_data_sampler/numpy_sample/datetime_features.py +29 -0
  5. {ocf_data_sampler-0.3.1 → ocf_data_sampler-0.5.0}/ocf_data_sampler/numpy_sample/site.py +1 -8
  6. ocf_data_sampler-0.5.0/ocf_data_sampler/torch_datasets/datasets/__init__.py +2 -0
  7. {ocf_data_sampler-0.3.1 → ocf_data_sampler-0.5.0}/ocf_data_sampler/torch_datasets/datasets/site.py +95 -366
  8. ocf_data_sampler-0.5.0/ocf_data_sampler/torch_datasets/sample/site.py +48 -0
  9. {ocf_data_sampler-0.3.1 → ocf_data_sampler-0.5.0}/ocf_data_sampler.egg-info/PKG-INFO +3 -2
  10. ocf_data_sampler-0.3.1/ocf_data_sampler/numpy_sample/datetime_features.py +0 -38
  11. ocf_data_sampler-0.3.1/ocf_data_sampler/torch_datasets/datasets/__init__.py +0 -6
  12. ocf_data_sampler-0.3.1/ocf_data_sampler/torch_datasets/sample/site.py +0 -39
  13. {ocf_data_sampler-0.3.1 → ocf_data_sampler-0.5.0}/LICENSE +0 -0
  14. {ocf_data_sampler-0.3.1 → ocf_data_sampler-0.5.0}/ocf_data_sampler/__init__.py +0 -0
  15. {ocf_data_sampler-0.3.1 → ocf_data_sampler-0.5.0}/ocf_data_sampler/config/__init__.py +0 -0
  16. {ocf_data_sampler-0.3.1 → ocf_data_sampler-0.5.0}/ocf_data_sampler/config/load.py +0 -0
  17. {ocf_data_sampler-0.3.1 → ocf_data_sampler-0.5.0}/ocf_data_sampler/config/model.py +0 -0
  18. {ocf_data_sampler-0.3.1 → ocf_data_sampler-0.5.0}/ocf_data_sampler/config/save.py +0 -0
  19. {ocf_data_sampler-0.3.1 → ocf_data_sampler-0.5.0}/ocf_data_sampler/data/uk_gsp_locations_20220314.csv +0 -0
  20. {ocf_data_sampler-0.3.1 → ocf_data_sampler-0.5.0}/ocf_data_sampler/data/uk_gsp_locations_20250109.csv +0 -0
  21. {ocf_data_sampler-0.3.1 → ocf_data_sampler-0.5.0}/ocf_data_sampler/load/__init__.py +0 -0
  22. {ocf_data_sampler-0.3.1 → ocf_data_sampler-0.5.0}/ocf_data_sampler/load/gsp.py +0 -0
  23. {ocf_data_sampler-0.3.1 → ocf_data_sampler-0.5.0}/ocf_data_sampler/load/load_dataset.py +0 -0
  24. {ocf_data_sampler-0.3.1 → ocf_data_sampler-0.5.0}/ocf_data_sampler/load/nwp/__init__.py +0 -0
  25. {ocf_data_sampler-0.3.1 → ocf_data_sampler-0.5.0}/ocf_data_sampler/load/nwp/nwp.py +0 -0
  26. {ocf_data_sampler-0.3.1 → ocf_data_sampler-0.5.0}/ocf_data_sampler/load/nwp/providers/__init__.py +0 -0
  27. {ocf_data_sampler-0.3.1 → ocf_data_sampler-0.5.0}/ocf_data_sampler/load/nwp/providers/cloudcasting.py +0 -0
  28. {ocf_data_sampler-0.3.1 → ocf_data_sampler-0.5.0}/ocf_data_sampler/load/nwp/providers/ecmwf.py +0 -0
  29. {ocf_data_sampler-0.3.1 → ocf_data_sampler-0.5.0}/ocf_data_sampler/load/nwp/providers/gfs.py +0 -0
  30. {ocf_data_sampler-0.3.1 → ocf_data_sampler-0.5.0}/ocf_data_sampler/load/nwp/providers/icon.py +0 -0
  31. {ocf_data_sampler-0.3.1 → ocf_data_sampler-0.5.0}/ocf_data_sampler/load/nwp/providers/ukv.py +0 -0
  32. {ocf_data_sampler-0.3.1 → ocf_data_sampler-0.5.0}/ocf_data_sampler/load/nwp/providers/utils.py +0 -0
  33. {ocf_data_sampler-0.3.1 → ocf_data_sampler-0.5.0}/ocf_data_sampler/load/open_tensorstore_zarrs.py +0 -0
  34. {ocf_data_sampler-0.3.1 → ocf_data_sampler-0.5.0}/ocf_data_sampler/load/satellite.py +0 -0
  35. {ocf_data_sampler-0.3.1 → ocf_data_sampler-0.5.0}/ocf_data_sampler/load/site.py +0 -0
  36. {ocf_data_sampler-0.3.1 → ocf_data_sampler-0.5.0}/ocf_data_sampler/load/utils.py +0 -0
  37. {ocf_data_sampler-0.3.1 → ocf_data_sampler-0.5.0}/ocf_data_sampler/numpy_sample/collate.py +0 -0
  38. {ocf_data_sampler-0.3.1 → ocf_data_sampler-0.5.0}/ocf_data_sampler/numpy_sample/common_types.py +0 -0
  39. {ocf_data_sampler-0.3.1 → ocf_data_sampler-0.5.0}/ocf_data_sampler/numpy_sample/gsp.py +0 -0
  40. {ocf_data_sampler-0.3.1 → ocf_data_sampler-0.5.0}/ocf_data_sampler/numpy_sample/nwp.py +0 -0
  41. {ocf_data_sampler-0.3.1 → ocf_data_sampler-0.5.0}/ocf_data_sampler/numpy_sample/satellite.py +0 -0
  42. {ocf_data_sampler-0.3.1 → ocf_data_sampler-0.5.0}/ocf_data_sampler/numpy_sample/sun_position.py +0 -0
  43. {ocf_data_sampler-0.3.1 → ocf_data_sampler-0.5.0}/ocf_data_sampler/select/__init__.py +0 -0
  44. {ocf_data_sampler-0.3.1 → ocf_data_sampler-0.5.0}/ocf_data_sampler/select/dropout.py +0 -0
  45. {ocf_data_sampler-0.3.1 → ocf_data_sampler-0.5.0}/ocf_data_sampler/select/fill_time_periods.py +0 -0
  46. {ocf_data_sampler-0.3.1 → ocf_data_sampler-0.5.0}/ocf_data_sampler/select/find_contiguous_time_periods.py +0 -0
  47. {ocf_data_sampler-0.3.1 → ocf_data_sampler-0.5.0}/ocf_data_sampler/select/geospatial.py +0 -0
  48. {ocf_data_sampler-0.3.1 → ocf_data_sampler-0.5.0}/ocf_data_sampler/select/location.py +0 -0
  49. {ocf_data_sampler-0.3.1 → ocf_data_sampler-0.5.0}/ocf_data_sampler/select/select_spatial_slice.py +0 -0
  50. {ocf_data_sampler-0.3.1 → ocf_data_sampler-0.5.0}/ocf_data_sampler/select/select_time_slice.py +0 -0
  51. {ocf_data_sampler-0.3.1 → ocf_data_sampler-0.5.0}/ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py +0 -0
  52. {ocf_data_sampler-0.3.1 → ocf_data_sampler-0.5.0}/ocf_data_sampler/torch_datasets/sample/__init__.py +0 -0
  53. {ocf_data_sampler-0.3.1 → ocf_data_sampler-0.5.0}/ocf_data_sampler/torch_datasets/sample/base.py +0 -0
  54. {ocf_data_sampler-0.3.1 → ocf_data_sampler-0.5.0}/ocf_data_sampler/torch_datasets/sample/uk_regional.py +0 -0
  55. {ocf_data_sampler-0.3.1 → ocf_data_sampler-0.5.0}/ocf_data_sampler/torch_datasets/utils/__init__.py +0 -0
  56. {ocf_data_sampler-0.3.1 → ocf_data_sampler-0.5.0}/ocf_data_sampler/torch_datasets/utils/channel_dict_to_dataarray.py +0 -0
  57. {ocf_data_sampler-0.3.1 → ocf_data_sampler-0.5.0}/ocf_data_sampler/torch_datasets/utils/merge_and_fill_utils.py +0 -0
  58. {ocf_data_sampler-0.3.1 → ocf_data_sampler-0.5.0}/ocf_data_sampler/torch_datasets/utils/spatial_slice_for_dataset.py +0 -0
  59. {ocf_data_sampler-0.3.1 → ocf_data_sampler-0.5.0}/ocf_data_sampler/torch_datasets/utils/time_slice_for_dataset.py +0 -0
  60. {ocf_data_sampler-0.3.1 → ocf_data_sampler-0.5.0}/ocf_data_sampler/torch_datasets/utils/valid_time_periods.py +0 -0
  61. {ocf_data_sampler-0.3.1 → ocf_data_sampler-0.5.0}/ocf_data_sampler/torch_datasets/utils/validation_utils.py +0 -0
  62. {ocf_data_sampler-0.3.1 → ocf_data_sampler-0.5.0}/ocf_data_sampler/utils.py +0 -0
  63. {ocf_data_sampler-0.3.1 → ocf_data_sampler-0.5.0}/ocf_data_sampler.egg-info/SOURCES.txt +0 -0
  64. {ocf_data_sampler-0.3.1 → ocf_data_sampler-0.5.0}/ocf_data_sampler.egg-info/dependency_links.txt +0 -0
  65. {ocf_data_sampler-0.3.1 → ocf_data_sampler-0.5.0}/ocf_data_sampler.egg-info/requires.txt +0 -0
  66. {ocf_data_sampler-0.3.1 → ocf_data_sampler-0.5.0}/ocf_data_sampler.egg-info/top_level.txt +0 -0
  67. {ocf_data_sampler-0.3.1 → ocf_data_sampler-0.5.0}/pyproject.toml +0 -0
  68. {ocf_data_sampler-0.3.1 → ocf_data_sampler-0.5.0}/scripts/download_gsp_location_data.py +0 -0
  69. {ocf_data_sampler-0.3.1 → ocf_data_sampler-0.5.0}/scripts/refactor_site.py +0 -0
  70. {ocf_data_sampler-0.3.1 → ocf_data_sampler-0.5.0}/setup.cfg +0 -0
  71. {ocf_data_sampler-0.3.1 → ocf_data_sampler-0.5.0}/utils/compute_icon_mean_stddev.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ocf-data-sampler
3
- Version: 0.3.1
3
+ Version: 0.5.0
4
4
  Author: James Fulton, Peter Dudfield
5
5
  Author-email: Open Climate Fix team <info@openclimatefix.org>
6
6
  License: MIT License
@@ -49,7 +49,7 @@ Requires-Dist: xarray-tensorstore==0.1.5
49
49
  # ocf-data-sampler
50
50
 
51
51
  <!-- ALL-CONTRIBUTORS-BADGE:START - Do not remove or modify this section -->
52
- [![All Contributors](https://img.shields.io/badge/all_contributors-13-orange.svg?style=flat-square)](#contributors-)
52
+ [![All Contributors](https://img.shields.io/badge/all_contributors-14-orange.svg?style=flat-square)](#contributors-)
53
53
  <!-- ALL-CONTRIBUTORS-BADGE:END -->
54
54
 
55
55
  [![tags badge](https://img.shields.io/github/v/tag/openclimatefix/ocf-data-sampler?include_prereleases&sort=semver&color=FFAC5F)](https://github.com/openclimatefix/ocf-data-sampler/tags)
@@ -128,6 +128,7 @@ Thanks goes to these wonderful people ([emoji key](https://allcontributors.org/d
128
128
  <td align="center" valign="top" width="14.28%"><a href="http://siddharth7113.github.io"><img src="https://avatars.githubusercontent.com/u/114160268?v=4?s=100" width="100px;" alt="Siddharth"/><br /><sub><b>Siddharth</b></sub></a><br /><a href="https://github.com/openclimatefix/ocf-data-sampler/commits?author=siddharth7113" title="Code">💻</a></td>
129
129
  <td align="center" valign="top" width="14.28%"><a href="https://github.com/Sachin-G13"><img src="https://avatars.githubusercontent.com/u/190184500?v=4?s=100" width="100px;" alt="Sachin-G13"/><br /><sub><b>Sachin-G13</b></sub></a><br /><a href="https://github.com/openclimatefix/ocf-data-sampler/commits?author=Sachin-G13" title="Code">💻</a></td>
130
130
  <td align="center" valign="top" width="14.28%"><a href="https://drona-gyawali.github.io/"><img src="https://avatars.githubusercontent.com/u/170401554?v=4?s=100" width="100px;" alt="Dorna Raj Gyawali"/><br /><sub><b>Dorna Raj Gyawali</b></sub></a><br /><a href="https://github.com/openclimatefix/ocf-data-sampler/commits?author=drona-gyawali" title="Code">💻</a></td>
131
+ <td align="center" valign="top" width="14.28%"><a href="https://github.com/adnanhashmi25"><img src="https://avatars.githubusercontent.com/u/55550094?v=4?s=100" width="100px;" alt="Adnan Hashmi"/><br /><sub><b>Adnan Hashmi</b></sub></a><br /><a href="https://github.com/openclimatefix/ocf-data-sampler/commits?author=adnanhashmi25" title="Code">💻</a></td>
131
132
  </tr>
132
133
  </tbody>
133
134
  </table>
@@ -1,7 +1,7 @@
1
1
  # ocf-data-sampler
2
2
 
3
3
  <!-- ALL-CONTRIBUTORS-BADGE:START - Do not remove or modify this section -->
4
- [![All Contributors](https://img.shields.io/badge/all_contributors-13-orange.svg?style=flat-square)](#contributors-)
4
+ [![All Contributors](https://img.shields.io/badge/all_contributors-14-orange.svg?style=flat-square)](#contributors-)
5
5
  <!-- ALL-CONTRIBUTORS-BADGE:END -->
6
6
 
7
7
  [![tags badge](https://img.shields.io/github/v/tag/openclimatefix/ocf-data-sampler?include_prereleases&sort=semver&color=FFAC5F)](https://github.com/openclimatefix/ocf-data-sampler/tags)
@@ -80,6 +80,7 @@ Thanks goes to these wonderful people ([emoji key](https://allcontributors.org/d
80
80
  <td align="center" valign="top" width="14.28%"><a href="http://siddharth7113.github.io"><img src="https://avatars.githubusercontent.com/u/114160268?v=4?s=100" width="100px;" alt="Siddharth"/><br /><sub><b>Siddharth</b></sub></a><br /><a href="https://github.com/openclimatefix/ocf-data-sampler/commits?author=siddharth7113" title="Code">💻</a></td>
81
81
  <td align="center" valign="top" width="14.28%"><a href="https://github.com/Sachin-G13"><img src="https://avatars.githubusercontent.com/u/190184500?v=4?s=100" width="100px;" alt="Sachin-G13"/><br /><sub><b>Sachin-G13</b></sub></a><br /><a href="https://github.com/openclimatefix/ocf-data-sampler/commits?author=Sachin-G13" title="Code">💻</a></td>
82
82
  <td align="center" valign="top" width="14.28%"><a href="https://drona-gyawali.github.io/"><img src="https://avatars.githubusercontent.com/u/170401554?v=4?s=100" width="100px;" alt="Dorna Raj Gyawali"/><br /><sub><b>Dorna Raj Gyawali</b></sub></a><br /><a href="https://github.com/openclimatefix/ocf-data-sampler/commits?author=drona-gyawali" title="Code">💻</a></td>
83
+ <td align="center" valign="top" width="14.28%"><a href="https://github.com/adnanhashmi25"><img src="https://avatars.githubusercontent.com/u/55550094?v=4?s=100" width="100px;" alt="Adnan Hashmi"/><br /><sub><b>Adnan Hashmi</b></sub></a><br /><a href="https://github.com/openclimatefix/ocf-data-sampler/commits?author=adnanhashmi25" title="Code">💻</a></td>
83
84
  </tr>
84
85
  </tbody>
85
86
  </table>
@@ -1,9 +1,8 @@
1
1
  """Conversion from Xarray to NumpySample"""
2
2
 
3
- from .datetime_features import make_datetime_numpy_dict
3
+ from .datetime_features import encode_datetimes
4
4
  from .gsp import convert_gsp_to_numpy_sample, GSPSampleKey
5
5
  from .nwp import convert_nwp_to_numpy_sample, NWPSampleKey
6
6
  from .satellite import convert_satellite_to_numpy_sample, SatelliteSampleKey
7
7
  from .sun_position import make_sun_position_numpy_sample
8
- from .site import convert_site_to_numpy_sample
9
-
8
+ from .site import convert_site_to_numpy_sample, SiteSampleKey
@@ -0,0 +1,29 @@
1
+ """Functions to create trigonometric date and time inputs."""
2
+
3
+ import numpy as np
4
+ import pandas as pd
5
+
6
+ from ocf_data_sampler.numpy_sample.common_types import NumpySample
7
+
8
+
9
+ def encode_datetimes(datetimes: pd.DatetimeIndex) -> NumpySample:
10
+ """Creates dictionary of sin and cos datetime embeddings.
11
+
12
+ Args:
13
+ datetimes: DatetimeIndex to create radian embeddings for
14
+
15
+ Returns:
16
+ Dictionary of datetime encodings
17
+ """
18
+ day_of_year = datetimes.dayofyear
19
+ minute_of_day = datetimes.minute + datetimes.hour * 60
20
+
21
+ time_in_radians = (2 * np.pi) * (minute_of_day / (24 * 60))
22
+ date_in_radians = (2 * np.pi) * (day_of_year / 365)
23
+
24
+ return {
25
+ "date_sin": np.sin(date_in_radians),
26
+ "date_cos": np.cos(date_in_radians),
27
+ "time_sin": np.sin(time_in_radians),
28
+ "time_cos": np.cos(time_in_radians),
29
+ }
@@ -13,10 +13,7 @@ class SiteSampleKey:
13
13
  time_utc = "site_time_utc"
14
14
  t0_idx = "site_t0_idx"
15
15
  id = "site_id"
16
- date_sin = "site_date_sin"
17
- date_cos = "site_date_cos"
18
- time_sin = "site_time_sin"
19
- time_cos = "site_time_cos"
16
+
20
17
 
21
18
 
22
19
  def convert_site_to_numpy_sample(da: xr.DataArray, t0_idx: int | None = None) -> NumpySample:
@@ -31,10 +28,6 @@ def convert_site_to_numpy_sample(da: xr.DataArray, t0_idx: int | None = None) ->
31
28
  SiteSampleKey.capacity_kwp: da.isel(time_utc=0)["capacity_kwp"].values,
32
29
  SiteSampleKey.time_utc: da["time_utc"].values.astype(float),
33
30
  SiteSampleKey.id: da["site_id"].values,
34
- SiteSampleKey.date_sin: da["date_sin"].values,
35
- SiteSampleKey.date_cos: da["date_cos"].values,
36
- SiteSampleKey.time_sin: da["time_sin"].values,
37
- SiteSampleKey.time_cos: da["time_cos"].values,
38
31
  }
39
32
 
40
33
  if t0_idx is not None:
@@ -0,0 +1,2 @@
1
+ from .pvnet_uk import PVNetUKRegionalDataset, PVNetUKConcurrentDataset
2
+ from .site import SitesDataset
@@ -13,7 +13,7 @@ from ocf_data_sampler.numpy_sample import (
13
13
  convert_nwp_to_numpy_sample,
14
14
  convert_satellite_to_numpy_sample,
15
15
  convert_site_to_numpy_sample,
16
- make_datetime_numpy_dict,
16
+ encode_datetimes,
17
17
  make_sun_position_numpy_sample,
18
18
  )
19
19
  from ocf_data_sampler.numpy_sample.collate import stack_np_samples_into_batch
@@ -58,6 +58,96 @@ def get_locations(site_xr: xr.Dataset) -> list[Location]:
58
58
 
59
59
  return locations
60
60
 
61
+ def process_and_combine_datasets(
62
+ dataset_dict: dict,
63
+ config: Configuration,
64
+ t0: pd.Timestamp,
65
+ ) -> NumpySample:
66
+ """Normalise and convert data to numpy arrays.
67
+
68
+ Args:
69
+ dataset_dict: Dictionary of xarray datasets
70
+ config: Configuration object
71
+ t0: init-time for sample
72
+ """
73
+ numpy_modalities = []
74
+
75
+ if "nwp" in dataset_dict:
76
+ nwp_numpy_modalities = {}
77
+
78
+ for nwp_key, da_nwp in dataset_dict["nwp"].items():
79
+
80
+ # Standardise and convert to NumpyBatch
81
+
82
+ da_channel_means = channel_dict_to_dataarray(
83
+ config.input_data.nwp[nwp_key].channel_means,
84
+ )
85
+ da_channel_stds = channel_dict_to_dataarray(
86
+ config.input_data.nwp[nwp_key].channel_stds,
87
+ )
88
+
89
+ da_nwp = (da_nwp - da_channel_means) / da_channel_stds
90
+
91
+ nwp_numpy_modalities[nwp_key] = convert_nwp_to_numpy_sample(da_nwp)
92
+
93
+ # Combine the NWPs into NumpyBatch
94
+ numpy_modalities.append({NWPSampleKey.nwp: nwp_numpy_modalities})
95
+
96
+ if "sat" in dataset_dict:
97
+ da_sat = dataset_dict["sat"]
98
+
99
+ # Standardise and convert to NumpyBatch
100
+ da_channel_means = channel_dict_to_dataarray(config.input_data.satellite.channel_means)
101
+ da_channel_stds = channel_dict_to_dataarray(config.input_data.satellite.channel_stds)
102
+
103
+ da_sat = (da_sat - da_channel_means) / da_channel_stds
104
+
105
+ numpy_modalities.append(convert_satellite_to_numpy_sample(da_sat))
106
+
107
+ if "site" in dataset_dict:
108
+ da_sites = dataset_dict["site"]
109
+ da_sites = da_sites / da_sites.capacity_kwp
110
+
111
+ # Convert to NumpyBatch
112
+ numpy_modalities.append(
113
+ convert_site_to_numpy_sample(
114
+ da_sites,
115
+ ),
116
+ )
117
+
118
+ # add datetime features
119
+ datetimes = pd.DatetimeIndex(da_sites.time_utc.values)
120
+ datetime_features = encode_datetimes(datetimes=datetimes)
121
+
122
+ numpy_modalities.append(datetime_features)
123
+
124
+ # Only add solar position if explicitly configured
125
+ if config.input_data.solar_position is not None:
126
+ solar_config = config.input_data.solar_position
127
+
128
+ # Create datetime range for solar position calculation
129
+ datetimes = pd.date_range(
130
+ t0 + minutes(solar_config.interval_start_minutes),
131
+ t0 + minutes(solar_config.interval_end_minutes),
132
+ freq=minutes(solar_config.time_resolution_minutes),
133
+ )
134
+
135
+
136
+ # Calculate solar positions and add to modalities
137
+ numpy_modalities.append(
138
+ make_sun_position_numpy_sample(
139
+ datetimes,
140
+ da_sites.longitude.values,
141
+ da_sites.latitude.values,
142
+ ),
143
+ )
144
+
145
+ # Combine all the modalities and fill NaNs
146
+ combined_sample = merge_dicts(numpy_modalities)
147
+ combined_sample = fill_nans_in_arrays(combined_sample)
148
+
149
+ return combined_sample
150
+
61
151
 
62
152
  class SitesDataset(Dataset):
63
153
  """A torch Dataset for creating PVNet Site samples."""
@@ -181,8 +271,9 @@ class SitesDataset(Dataset):
181
271
  sample_dict = slice_datasets_by_space(self.datasets_dict, location, self.config)
182
272
  sample_dict = slice_datasets_by_time(sample_dict, t0, self.config)
183
273
 
184
- sample = self.process_and_combine_site_sample_dict(sample_dict, t0)
185
- return sample.compute()
274
+ sample_dict = compute(sample_dict)
275
+
276
+ return process_and_combine_datasets(sample_dict, self.config, t0)
186
277
 
187
278
  def get_sample(self, t0: pd.Timestamp, site_id: int) -> dict:
188
279
  """Generate a sample for a given site id and t0.
@@ -197,159 +288,6 @@ class SitesDataset(Dataset):
197
288
 
198
289
  return self._get_sample(t0, location)
199
290
 
200
- def process_and_combine_site_sample_dict(
201
- self,
202
- dataset_dict: dict,
203
- t0: pd.Timestamp,
204
- ) -> xr.Dataset:
205
- """Normalize and combine data into a single xr Dataset.
206
-
207
- Args:
208
- dataset_dict: dict containing sliced xr DataArrays
209
- t0: The initial timestamp of the sample
210
-
211
- Returns:
212
- xr.Dataset: A merged Dataset with nans filled in.
213
- """
214
- data_arrays = []
215
-
216
- if "nwp" in dataset_dict:
217
- for nwp_key, da_nwp in dataset_dict["nwp"].items():
218
- provider = self.config.input_data.nwp[nwp_key].provider
219
-
220
- da_channel_means = channel_dict_to_dataarray(
221
- self.config.input_data.nwp[nwp_key].channel_means,
222
- )
223
- da_channel_stds = channel_dict_to_dataarray(
224
- self.config.input_data.nwp[nwp_key].channel_stds,
225
- )
226
-
227
- da_nwp = (da_nwp - da_channel_means) / da_channel_stds
228
- data_arrays.append((f"nwp-{provider}", da_nwp))
229
-
230
- if "sat" in dataset_dict:
231
- da_sat = dataset_dict["sat"]
232
-
233
- da_channel_means = channel_dict_to_dataarray(
234
- self.config.input_data.satellite.channel_means,
235
- )
236
- da_channel_stds = channel_dict_to_dataarray(
237
- self.config.input_data.satellite.channel_stds,
238
- )
239
-
240
- da_sat = (da_sat - da_channel_means) / da_channel_stds
241
- data_arrays.append(("satellite", da_sat))
242
-
243
- if "site" in dataset_dict:
244
- da_sites = dataset_dict["site"]
245
- da_sites = da_sites / da_sites.capacity_kwp
246
- data_arrays.append(("site", da_sites))
247
-
248
- combined_sample_dataset = self.merge_data_arrays(data_arrays)
249
-
250
- # add datetime features
251
- datetimes = pd.DatetimeIndex(combined_sample_dataset.site__time_utc.values)
252
- datetime_features = make_datetime_numpy_dict(datetimes=datetimes, key_prefix="site_")
253
- combined_sample_dataset = combined_sample_dataset.assign_coords(
254
- {k: ("site__time_utc", v) for k, v in datetime_features.items()},
255
- )
256
-
257
- # Only add solar position if explicitly configured
258
- has_solar_config = (
259
- hasattr(self.config.input_data, "solar_position")
260
- and self.config.input_data.solar_position is not None
261
- )
262
-
263
- if has_solar_config:
264
- solar_config = self.config.input_data.solar_position
265
-
266
- # Datetime range - solar config params
267
- solar_datetimes = pd.date_range(
268
- t0 + minutes(solar_config.interval_start_minutes),
269
- t0 + minutes(solar_config.interval_end_minutes),
270
- freq=minutes(solar_config.time_resolution_minutes),
271
- )
272
-
273
- # Calculate sun position features
274
- sun_position_features = make_sun_position_numpy_sample(
275
- datetimes=solar_datetimes,
276
- lon=combined_sample_dataset.site__longitude.values,
277
- lat=combined_sample_dataset.site__latitude.values,
278
- )
279
-
280
- # Use existing dimension for solar positions
281
- # TODO decouple this as a separate data varaible
282
- solar_dim_name = "site__time_utc"
283
-
284
- # Assign solar position values
285
- for key, values in sun_position_features.items():
286
- combined_sample_dataset = combined_sample_dataset.assign_coords(
287
- {key: (solar_dim_name, values)},
288
- )
289
-
290
- # TODO include t0_index in xr dataset?
291
-
292
- # Fill any nan values
293
- return combined_sample_dataset.fillna(0.0)
294
-
295
- def merge_data_arrays(
296
- self,
297
- normalised_data_arrays: list[tuple[str, xr.DataArray]],
298
- ) -> xr.Dataset:
299
- """Combine a list of DataArrays into a single Dataset with unique naming conventions.
300
-
301
- Args:
302
- normalised_data_arrays: List of tuples where each tuple contains:
303
- - A string (key name).
304
- - An xarray.DataArray.
305
-
306
- Returns:
307
- xr.Dataset: A merged Dataset with uniquely named variables, coordinates, and dimensions.
308
- """
309
- datasets = []
310
-
311
- for key, data_array in normalised_data_arrays:
312
- # Ensure all attributes are strings for consistency
313
- data_array = data_array.assign_attrs(
314
- {attr_key: str(attr_value) for attr_key, attr_value in data_array.attrs.items()},
315
- )
316
-
317
- # Convert DataArray to Dataset with the variable name as the key
318
- dataset = data_array.to_dataset(name=key)
319
-
320
- # Prepend key name to all dimension and coordinate names for uniqueness
321
- dataset = dataset.rename(
322
- {dim: f"{key}__{dim}" for dim in dataset.dims if dim not in dataset.coords},
323
- )
324
- dataset = dataset.rename(
325
- {coord: f"{key}__{coord}" for coord in dataset.coords},
326
- )
327
-
328
- # Handle concatenation dimension if applicable
329
- concat_dim = (
330
- f"{key}__target_time_utc"
331
- if f"{key}__target_time_utc" in dataset.coords
332
- else f"{key}__time_utc"
333
- )
334
-
335
- if f"{key}__init_time_utc" in dataset.coords:
336
- init_coord = f"{key}__init_time_utc"
337
- if dataset[init_coord].ndim == 0: # Check if scalar
338
- expanded_init_times = [dataset[init_coord].values] * len(dataset[concat_dim])
339
- dataset = dataset.assign_coords({init_coord: (concat_dim, expanded_init_times)})
340
-
341
- datasets.append(dataset)
342
-
343
- # Ensure all datasets are valid xarray.Dataset objects
344
- for ds in datasets:
345
- if not isinstance(ds, xr.Dataset):
346
- raise ValueError(f"Object is not an xr.Dataset: {type(ds)}")
347
-
348
- # Merge all prepared datasets
349
- combined_dataset = xr.merge(datasets)
350
-
351
- return combined_dataset
352
-
353
291
 
354
292
  class SitesDatasetConcurrent(Dataset):
355
293
  """A torch Dataset for creating PVNet Site batches with samples for all sites."""
@@ -394,93 +332,6 @@ class SitesDatasetConcurrent(Dataset):
394
332
  # Assign coords and indices to self
395
333
  self.valid_t0s = valid_t0s
396
334
 
397
- @staticmethod
398
- def process_and_combine_datasets(
399
- dataset_dict: dict,
400
- config: Configuration,
401
- t0: pd.Timestamp,
402
- ) -> NumpySample:
403
- """Normalise and convert data to numpy arrays.
404
-
405
- Args:
406
- dataset_dict: Dictionary of xarray datasets
407
- config: Configuration object
408
- t0: init-time for sample
409
- """
410
- numpy_modalities = []
411
-
412
- if "nwp" in dataset_dict:
413
- nwp_numpy_modalities = {}
414
-
415
- for nwp_key, da_nwp in dataset_dict["nwp"].items():
416
- # Standardise and convert to NumpyBatch
417
-
418
- da_channel_means = channel_dict_to_dataarray(
419
- config.input_data.nwp[nwp_key].channel_means,
420
- )
421
- da_channel_stds = channel_dict_to_dataarray(
422
- config.input_data.nwp[nwp_key].channel_stds,
423
- )
424
-
425
- da_nwp = (da_nwp - da_channel_means) / da_channel_stds
426
-
427
- nwp_numpy_modalities[nwp_key] = convert_nwp_to_numpy_sample(da_nwp)
428
-
429
- # Combine the NWPs into NumpyBatch
430
- numpy_modalities.append({NWPSampleKey.nwp: nwp_numpy_modalities})
431
-
432
- if "sat" in dataset_dict:
433
- da_sat = dataset_dict["sat"]
434
-
435
- # Standardise and convert to NumpyBatch
436
- da_channel_means = channel_dict_to_dataarray(config.input_data.satellite.channel_means)
437
- da_channel_stds = channel_dict_to_dataarray(config.input_data.satellite.channel_stds)
438
-
439
- da_sat = (da_sat - da_channel_means) / da_channel_stds
440
-
441
- numpy_modalities.append(convert_satellite_to_numpy_sample(da_sat))
442
-
443
- if "site" in dataset_dict:
444
- da_sites = dataset_dict["site"]
445
- da_sites = da_sites / da_sites.capacity_kwp
446
-
447
- # Convert to NumpyBatch
448
- numpy_modalities.append(
449
- convert_site_to_numpy_sample(
450
- da_sites,
451
- ),
452
- )
453
-
454
- # Only add solar position if explicitly configured
455
- has_solar_config = (
456
- hasattr(config.input_data, "solar_position")
457
- and config.input_data.solar_position is not None
458
- )
459
-
460
- if has_solar_config:
461
- solar_config = config.input_data.solar_position
462
-
463
- # Create datetime range for solar position calculation
464
- datetimes = pd.date_range(
465
- t0 + minutes(solar_config.interval_start_minutes),
466
- t0 + minutes(solar_config.interval_end_minutes),
467
- freq=minutes(solar_config.time_resolution_minutes),
468
- )
469
-
470
- # Calculate solar positions and add to modalities
471
- numpy_modalities.append(
472
- make_sun_position_numpy_sample(
473
- datetimes, da_sites.longitude.values, da_sites.latitude.values,
474
- ),
475
- )
476
-
477
- # Combine all the modalities and fill NaNs
478
- combined_sample = merge_dicts(numpy_modalities)
479
- combined_sample = fill_nans_in_arrays(combined_sample)
480
-
481
- return combined_sample
482
-
483
-
484
335
  def find_valid_t0s(
485
336
  self,
486
337
  datasets_dict: dict,
@@ -551,7 +402,7 @@ class SitesDatasetConcurrent(Dataset):
551
402
 
552
403
  for location in self.locations:
553
404
  site_sample_dict = slice_datasets_by_space(sample_dict, location, self.config)
554
- site_numpy_sample = self.process_and_combine_datasets(
405
+ site_numpy_sample = process_and_combine_datasets(
555
406
  site_sample_dict,
556
407
  self.config,
557
408
  t0,
@@ -561,128 +412,6 @@ class SitesDatasetConcurrent(Dataset):
561
412
  return stack_np_samples_into_batch(site_samples)
562
413
 
563
414
 
564
- # ----- functions to load presaved samples ------
565
-
566
-
567
- def convert_netcdf_to_numpy_sample(ds: xr.Dataset) -> dict:
568
- """Convert a netcdf dataset to a numpy sample.
569
-
570
- Args:
571
- ds: xarray Dataset
572
- """
573
- # convert the single dataset to a dict of arrays
574
- sample_dict = convert_from_dataset_to_dict_datasets(ds)
575
-
576
- if "satellite" in sample_dict:
577
- # rename satellite to sat # TODO this could be improved
578
- sample_dict["sat"] = sample_dict.pop("satellite")
579
-
580
- # process and combine the datasets
581
- sample = convert_to_numpy_and_combine(
582
- dataset_dict=sample_dict,
583
- )
584
-
585
- # Extraction of solar position coords
586
- solar_keys = ["solar_azimuth", "solar_elevation"]
587
- for key in solar_keys:
588
- if key in ds.coords:
589
- sample[key] = ds.coords[key].values
590
-
591
- # TODO think about normalization:
592
- # * maybe its done not in sample creation, maybe its done afterwards,
593
- # to allow it to be flexible
594
-
595
- return sample
596
-
597
-
598
- def convert_from_dataset_to_dict_datasets(combined_dataset: xr.Dataset) -> dict[str, xr.DataArray]:
599
- """Convert a combined sample dataset to a dict of datasets for each input.
600
-
601
- Args:
602
- combined_dataset: The combined NetCDF dataset
603
-
604
- Returns:
605
- The uncombined datasets as a dict of xr.Datasets
606
- """
607
- # Split into datasets by splitting by the prefix added in combine_to_netcdf
608
- datasets: dict[str, xr.DataArray] = {}
609
-
610
- # Go through each data variable and split it into a dataset
611
- for key, dataset in combined_dataset.items():
612
- # If 'key__' doesn't exist in a dim or coordinate, remove it
613
- for dim in list(dataset.coords):
614
- if f"{key}__" not in dim:
615
- dataset = dataset.drop_vars(dim)
616
- dataset = dataset.rename(
617
- {dim: dim.split(f"{key}__")[1] for dim in dataset.dims if dim not in dataset.coords},
618
- )
619
- dataset = dataset.rename(
620
- {coord: coord.split(f"{key}__")[1] for coord in dataset.coords},
621
- )
622
- # Split the dataset by the prefix
623
- datasets[key] = dataset
624
-
625
- # Unflatten any NWP data
626
- return nest_nwp_source_dict(datasets, sep="-")
627
-
628
-
629
- def nest_nwp_source_dict(
630
- dataset_dict: dict[xr.Dataset],
631
- sep: str = "-",
632
- ) -> dict[str, xr.Dataset | dict[xr.Dataset]]:
633
- """Re-nest a dictionary where the NWP values are nested under keys 'nwp-<key>'.
634
-
635
- Args:
636
- dataset_dict: Dictionary of datasets
637
- sep: Separator to use to nest NWP keys
638
- """
639
- nwp_prefix = f"nwp{sep}"
640
- new_dict = {k: v for k, v in dataset_dict.items() if not k.startswith(nwp_prefix)}
641
- nwp_keys = [k for k in dataset_dict if k.startswith(nwp_prefix)]
642
- if len(nwp_keys) > 0:
643
- nwp_subdict = {k.removeprefix(nwp_prefix): dataset_dict[k] for k in nwp_keys}
644
- new_dict["nwp"] = nwp_subdict
645
- return new_dict
646
-
647
-
648
- def convert_to_numpy_and_combine(dataset_dict: dict[xr.Dataset]) -> NumpySample:
649
- """Convert input data in a dict to numpy arrays.
650
-
651
- Args:
652
- dataset_dict: Dictionary of xarray Datasets
653
- """
654
- numpy_modalities = []
655
-
656
- if "nwp" in dataset_dict:
657
- nwp_numpy_modalities = {}
658
- for nwp_key, da_nwp in dataset_dict["nwp"].items():
659
- # Convert to NumpySample
660
- nwp_numpy_modalities[nwp_key] = convert_nwp_to_numpy_sample(da_nwp)
661
-
662
- # Combine the NWPs into NumpySample
663
- numpy_modalities.append({NWPSampleKey.nwp: nwp_numpy_modalities})
664
-
665
- if "sat" in dataset_dict:
666
- # Satellite is already in the range [0-1] so no need to standardise
667
- da_sat = dataset_dict["sat"]
668
-
669
- # Convert to NumpySample
670
- numpy_modalities.append(convert_satellite_to_numpy_sample(da_sat))
671
-
672
- if "site" in dataset_dict:
673
- da_sites = dataset_dict["site"]
674
-
675
- numpy_modalities.append(
676
- convert_site_to_numpy_sample(
677
- da_sites,
678
- ),
679
- )
680
-
681
- # Combine all the modalities and fill NaNs
682
- combined_sample = merge_dicts(numpy_modalities)
683
- return fill_nans_in_arrays(combined_sample)
684
-
685
-
686
415
  def coarsen_data(xr_data: xr.Dataset, coarsen_to_deg: float = 0.1) -> xr.Dataset:
687
416
  """Coarsen the data to a specified resolution in degrees.
688
417
 
@@ -0,0 +1,48 @@
1
+ """PVNet Site sample implementation for netCDF data handling and conversion."""
2
+
3
+ import torch
4
+ from typing_extensions import override
5
+
6
+ from ocf_data_sampler.numpy_sample.common_types import NumpySample
7
+
8
+ from .base import SampleBase
9
+
10
+
11
+ # TODO this is now similar to the UKRegionalSample
12
+ # We should consider just having one Sample class for all datasets
13
+ class SiteSample(SampleBase):
14
+ """Handles SiteSample specific operations."""
15
+
16
+ def __init__(self, data: NumpySample) -> None:
17
+ """Initializes the SiteSample object with the given NumpySample."""
18
+ self._data = data
19
+
20
+ @override
21
+ def to_numpy(self) -> NumpySample:
22
+ return self._data
23
+
24
+ @override
25
+ def save(self, path: str) -> None:
26
+ """Saves sample to the specified path in pickle format."""
27
+ # Saves to pickle format
28
+ torch.save(self._data, path)
29
+
30
+ @classmethod
31
+ @override
32
+ def load(cls, path: str) -> "SiteSample":
33
+ """Loads sample from the specified path.
34
+
35
+ Args:
36
+ path: Path to the saved sample file.
37
+
38
+ Returns:
39
+ A SiteSample instance with the loaded data.
40
+ """
41
+ # Loads from .pt format
42
+ # TODO: We should move away from using torch.load(..., weights_only=False)
43
+ return cls(torch.load(path, weights_only=False))
44
+
45
+ @override
46
+ def plot(self) -> None:
47
+ # TODO - placeholder for now
48
+ raise NotImplementedError("Plotting not yet implemented for SiteSample")
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ocf-data-sampler
3
- Version: 0.3.1
3
+ Version: 0.5.0
4
4
  Author: James Fulton, Peter Dudfield
5
5
  Author-email: Open Climate Fix team <info@openclimatefix.org>
6
6
  License: MIT License
@@ -49,7 +49,7 @@ Requires-Dist: xarray-tensorstore==0.1.5
49
49
  # ocf-data-sampler
50
50
 
51
51
  <!-- ALL-CONTRIBUTORS-BADGE:START - Do not remove or modify this section -->
52
- [![All Contributors](https://img.shields.io/badge/all_contributors-13-orange.svg?style=flat-square)](#contributors-)
52
+ [![All Contributors](https://img.shields.io/badge/all_contributors-14-orange.svg?style=flat-square)](#contributors-)
53
53
  <!-- ALL-CONTRIBUTORS-BADGE:END -->
54
54
 
55
55
  [![tags badge](https://img.shields.io/github/v/tag/openclimatefix/ocf-data-sampler?include_prereleases&sort=semver&color=FFAC5F)](https://github.com/openclimatefix/ocf-data-sampler/tags)
@@ -128,6 +128,7 @@ Thanks goes to these wonderful people ([emoji key](https://allcontributors.org/d
128
128
  <td align="center" valign="top" width="14.28%"><a href="http://siddharth7113.github.io"><img src="https://avatars.githubusercontent.com/u/114160268?v=4?s=100" width="100px;" alt="Siddharth"/><br /><sub><b>Siddharth</b></sub></a><br /><a href="https://github.com/openclimatefix/ocf-data-sampler/commits?author=siddharth7113" title="Code">💻</a></td>
129
129
  <td align="center" valign="top" width="14.28%"><a href="https://github.com/Sachin-G13"><img src="https://avatars.githubusercontent.com/u/190184500?v=4?s=100" width="100px;" alt="Sachin-G13"/><br /><sub><b>Sachin-G13</b></sub></a><br /><a href="https://github.com/openclimatefix/ocf-data-sampler/commits?author=Sachin-G13" title="Code">💻</a></td>
130
130
  <td align="center" valign="top" width="14.28%"><a href="https://drona-gyawali.github.io/"><img src="https://avatars.githubusercontent.com/u/170401554?v=4?s=100" width="100px;" alt="Dorna Raj Gyawali"/><br /><sub><b>Dorna Raj Gyawali</b></sub></a><br /><a href="https://github.com/openclimatefix/ocf-data-sampler/commits?author=drona-gyawali" title="Code">💻</a></td>
131
+ <td align="center" valign="top" width="14.28%"><a href="https://github.com/adnanhashmi25"><img src="https://avatars.githubusercontent.com/u/55550094?v=4?s=100" width="100px;" alt="Adnan Hashmi"/><br /><sub><b>Adnan Hashmi</b></sub></a><br /><a href="https://github.com/openclimatefix/ocf-data-sampler/commits?author=adnanhashmi25" title="Code">💻</a></td>
131
132
  </tr>
132
133
  </tbody>
133
134
  </table>
@@ -1,38 +0,0 @@
1
- """Functions to create trigonometric date and time inputs."""
2
-
3
- import numpy as np
4
- import pandas as pd
5
-
6
- from ocf_data_sampler.numpy_sample.common_types import NumpySample
7
-
8
-
9
- def _get_date_time_in_pi(dt: pd.DatetimeIndex) -> tuple[np.ndarray, np.ndarray]:
10
- """Create positional embeddings for the datetimes in radians.
11
-
12
- Args:
13
- dt: DatetimeIndex to create radian embeddings for
14
-
15
- Returns:
16
- Tuple of numpy arrays containing radian coordinates for date and time
17
- """
18
- day_of_year = dt.dayofyear
19
- minute_of_day = dt.minute + dt.hour * 60
20
-
21
- time_in_pi = (2 * np.pi) * (minute_of_day / (24 * 60))
22
- date_in_pi = (2 * np.pi) * (day_of_year / 365)
23
-
24
- return date_in_pi, time_in_pi
25
-
26
-
27
- def make_datetime_numpy_dict(datetimes: pd.DatetimeIndex, key_prefix: str = "wind") -> NumpySample:
28
- """Creates dictionary of cyclical datetime features - encoded."""
29
- date_in_pi, time_in_pi = _get_date_time_in_pi(datetimes)
30
-
31
- time_numpy_sample = {}
32
-
33
- time_numpy_sample[key_prefix + "_date_sin"] = np.sin(date_in_pi)
34
- time_numpy_sample[key_prefix + "_date_cos"] = np.cos(date_in_pi)
35
- time_numpy_sample[key_prefix + "_time_sin"] = np.sin(time_in_pi)
36
- time_numpy_sample[key_prefix + "_time_cos"] = np.cos(time_in_pi)
37
-
38
- return time_numpy_sample
@@ -1,6 +0,0 @@
1
- from .pvnet_uk import PVNetUKRegionalDataset, PVNetUKConcurrentDataset
2
-
3
- from .site import (
4
- convert_netcdf_to_numpy_sample,
5
- SitesDataset
6
- )
@@ -1,39 +0,0 @@
1
- """PVNet Site sample implementation for netCDF data handling and conversion."""
2
-
3
- import xarray as xr
4
- from typing_extensions import override
5
-
6
- from ocf_data_sampler.numpy_sample.common_types import NumpySample
7
- from ocf_data_sampler.torch_datasets.datasets.site import convert_netcdf_to_numpy_sample
8
-
9
- from .base import SampleBase
10
-
11
-
12
- class SiteSample(SampleBase):
13
- """Handles PVNet site specific netCDF operations."""
14
-
15
- def __init__(self, data: xr.Dataset) -> None:
16
- """Initializes the SiteSample object with the given xarray Dataset."""
17
- if not isinstance(data, xr.Dataset):
18
- raise TypeError(f"Data must be xarray Dataset - Found type {type(data)}")
19
- self._data = data
20
-
21
- @override
22
- def to_numpy(self) -> NumpySample:
23
- return convert_netcdf_to_numpy_sample(self._data)
24
-
25
- @override
26
- def save(self, path: str) -> None:
27
- # Saves as NetCDF
28
- self._data.to_netcdf(path, mode="w", engine="h5netcdf")
29
-
30
- @classmethod
31
- @override
32
- def load(cls, path: str) -> "SiteSample":
33
- # Loads from NetCDF
34
- return cls(xr.open_dataset(path, decode_timedelta=False))
35
-
36
- @override
37
- def plot(self) -> None:
38
- # TODO - placeholder for now
39
- raise NotImplementedError("Plotting not yet implemented for SiteSample")