ocf-data-sampler 0.1.16__tar.gz → 0.2.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of ocf-data-sampler might be problematic.

Files changed (64)
  1. {ocf_data_sampler-0.1.16 → ocf_data_sampler-0.2.0}/PKG-INFO +2 -3
  2. {ocf_data_sampler-0.1.16 → ocf_data_sampler-0.2.0}/README.md +1 -2
  3. {ocf_data_sampler-0.1.16 → ocf_data_sampler-0.2.0}/ocf_data_sampler/config/model.py +73 -3
  4. {ocf_data_sampler-0.1.16 → ocf_data_sampler-0.2.0}/ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py +16 -15
  5. {ocf_data_sampler-0.1.16 → ocf_data_sampler-0.2.0}/ocf_data_sampler/torch_datasets/datasets/site.py +19 -13
  6. ocf_data_sampler-0.2.0/ocf_data_sampler/torch_datasets/utils/__init__.py +3 -0
  7. ocf_data_sampler-0.2.0/ocf_data_sampler/torch_datasets/utils/channel_dict_to_dataarray.py +11 -0
  8. {ocf_data_sampler-0.1.16 → ocf_data_sampler-0.2.0}/ocf_data_sampler.egg-info/PKG-INFO +2 -3
  9. {ocf_data_sampler-0.1.16 → ocf_data_sampler-0.2.0}/ocf_data_sampler.egg-info/SOURCES.txt +2 -2
  10. {ocf_data_sampler-0.1.16 → ocf_data_sampler-0.2.0}/scripts/refactor_site.py +7 -6
  11. ocf_data_sampler-0.1.16/ocf_data_sampler/constants.py +0 -350
  12. ocf_data_sampler-0.1.16/ocf_data_sampler/torch_datasets/utils/validate_channels.py +0 -86
  13. {ocf_data_sampler-0.1.16 → ocf_data_sampler-0.2.0}/LICENSE +0 -0
  14. {ocf_data_sampler-0.1.16 → ocf_data_sampler-0.2.0}/ocf_data_sampler/__init__.py +0 -0
  15. {ocf_data_sampler-0.1.16 → ocf_data_sampler-0.2.0}/ocf_data_sampler/config/__init__.py +0 -0
  16. {ocf_data_sampler-0.1.16 → ocf_data_sampler-0.2.0}/ocf_data_sampler/config/load.py +0 -0
  17. {ocf_data_sampler-0.1.16 → ocf_data_sampler-0.2.0}/ocf_data_sampler/config/save.py +0 -0
  18. {ocf_data_sampler-0.1.16 → ocf_data_sampler-0.2.0}/ocf_data_sampler/data/uk_gsp_locations.csv +0 -0
  19. {ocf_data_sampler-0.1.16 → ocf_data_sampler-0.2.0}/ocf_data_sampler/load/__init__.py +0 -0
  20. {ocf_data_sampler-0.1.16 → ocf_data_sampler-0.2.0}/ocf_data_sampler/load/gsp.py +0 -0
  21. {ocf_data_sampler-0.1.16 → ocf_data_sampler-0.2.0}/ocf_data_sampler/load/load_dataset.py +0 -0
  22. {ocf_data_sampler-0.1.16 → ocf_data_sampler-0.2.0}/ocf_data_sampler/load/nwp/__init__.py +0 -0
  23. {ocf_data_sampler-0.1.16 → ocf_data_sampler-0.2.0}/ocf_data_sampler/load/nwp/nwp.py +0 -0
  24. {ocf_data_sampler-0.1.16 → ocf_data_sampler-0.2.0}/ocf_data_sampler/load/nwp/providers/__init__.py +0 -0
  25. {ocf_data_sampler-0.1.16 → ocf_data_sampler-0.2.0}/ocf_data_sampler/load/nwp/providers/ecmwf.py +0 -0
  26. {ocf_data_sampler-0.1.16 → ocf_data_sampler-0.2.0}/ocf_data_sampler/load/nwp/providers/gfs.py +0 -0
  27. {ocf_data_sampler-0.1.16 → ocf_data_sampler-0.2.0}/ocf_data_sampler/load/nwp/providers/icon.py +0 -0
  28. {ocf_data_sampler-0.1.16 → ocf_data_sampler-0.2.0}/ocf_data_sampler/load/nwp/providers/ukv.py +0 -0
  29. {ocf_data_sampler-0.1.16 → ocf_data_sampler-0.2.0}/ocf_data_sampler/load/nwp/providers/utils.py +0 -0
  30. {ocf_data_sampler-0.1.16 → ocf_data_sampler-0.2.0}/ocf_data_sampler/load/satellite.py +0 -0
  31. {ocf_data_sampler-0.1.16 → ocf_data_sampler-0.2.0}/ocf_data_sampler/load/site.py +0 -0
  32. {ocf_data_sampler-0.1.16 → ocf_data_sampler-0.2.0}/ocf_data_sampler/load/utils.py +0 -0
  33. {ocf_data_sampler-0.1.16 → ocf_data_sampler-0.2.0}/ocf_data_sampler/numpy_sample/__init__.py +0 -0
  34. {ocf_data_sampler-0.1.16 → ocf_data_sampler-0.2.0}/ocf_data_sampler/numpy_sample/collate.py +0 -0
  35. {ocf_data_sampler-0.1.16 → ocf_data_sampler-0.2.0}/ocf_data_sampler/numpy_sample/datetime_features.py +0 -0
  36. {ocf_data_sampler-0.1.16 → ocf_data_sampler-0.2.0}/ocf_data_sampler/numpy_sample/gsp.py +0 -0
  37. {ocf_data_sampler-0.1.16 → ocf_data_sampler-0.2.0}/ocf_data_sampler/numpy_sample/nwp.py +0 -0
  38. {ocf_data_sampler-0.1.16 → ocf_data_sampler-0.2.0}/ocf_data_sampler/numpy_sample/satellite.py +0 -0
  39. {ocf_data_sampler-0.1.16 → ocf_data_sampler-0.2.0}/ocf_data_sampler/numpy_sample/site.py +0 -0
  40. {ocf_data_sampler-0.1.16 → ocf_data_sampler-0.2.0}/ocf_data_sampler/numpy_sample/sun_position.py +0 -0
  41. {ocf_data_sampler-0.1.16 → ocf_data_sampler-0.2.0}/ocf_data_sampler/sample/__init__.py +0 -0
  42. {ocf_data_sampler-0.1.16 → ocf_data_sampler-0.2.0}/ocf_data_sampler/sample/base.py +0 -0
  43. {ocf_data_sampler-0.1.16 → ocf_data_sampler-0.2.0}/ocf_data_sampler/sample/site.py +0 -0
  44. {ocf_data_sampler-0.1.16 → ocf_data_sampler-0.2.0}/ocf_data_sampler/sample/uk_regional.py +0 -0
  45. {ocf_data_sampler-0.1.16 → ocf_data_sampler-0.2.0}/ocf_data_sampler/select/__init__.py +0 -0
  46. {ocf_data_sampler-0.1.16 → ocf_data_sampler-0.2.0}/ocf_data_sampler/select/dropout.py +0 -0
  47. {ocf_data_sampler-0.1.16 → ocf_data_sampler-0.2.0}/ocf_data_sampler/select/fill_time_periods.py +0 -0
  48. {ocf_data_sampler-0.1.16 → ocf_data_sampler-0.2.0}/ocf_data_sampler/select/find_contiguous_time_periods.py +0 -0
  49. {ocf_data_sampler-0.1.16 → ocf_data_sampler-0.2.0}/ocf_data_sampler/select/geospatial.py +0 -0
  50. {ocf_data_sampler-0.1.16 → ocf_data_sampler-0.2.0}/ocf_data_sampler/select/location.py +0 -0
  51. {ocf_data_sampler-0.1.16 → ocf_data_sampler-0.2.0}/ocf_data_sampler/select/select_spatial_slice.py +0 -0
  52. {ocf_data_sampler-0.1.16 → ocf_data_sampler-0.2.0}/ocf_data_sampler/select/select_time_slice.py +0 -0
  53. {ocf_data_sampler-0.1.16 → ocf_data_sampler-0.2.0}/ocf_data_sampler/select/spatial_slice_for_dataset.py +0 -0
  54. {ocf_data_sampler-0.1.16 → ocf_data_sampler-0.2.0}/ocf_data_sampler/select/time_slice_for_dataset.py +0 -0
  55. {ocf_data_sampler-0.1.16 → ocf_data_sampler-0.2.0}/ocf_data_sampler/torch_datasets/datasets/__init__.py +0 -0
  56. {ocf_data_sampler-0.1.16 → ocf_data_sampler-0.2.0}/ocf_data_sampler/torch_datasets/utils/merge_and_fill_utils.py +0 -0
  57. {ocf_data_sampler-0.1.16 → ocf_data_sampler-0.2.0}/ocf_data_sampler/torch_datasets/utils/valid_time_periods.py +0 -0
  58. {ocf_data_sampler-0.1.16 → ocf_data_sampler-0.2.0}/ocf_data_sampler/utils.py +0 -0
  59. {ocf_data_sampler-0.1.16 → ocf_data_sampler-0.2.0}/ocf_data_sampler.egg-info/dependency_links.txt +0 -0
  60. {ocf_data_sampler-0.1.16 → ocf_data_sampler-0.2.0}/ocf_data_sampler.egg-info/requires.txt +0 -0
  61. {ocf_data_sampler-0.1.16 → ocf_data_sampler-0.2.0}/ocf_data_sampler.egg-info/top_level.txt +0 -0
  62. {ocf_data_sampler-0.1.16 → ocf_data_sampler-0.2.0}/pyproject.toml +0 -0
  63. {ocf_data_sampler-0.1.16 → ocf_data_sampler-0.2.0}/setup.cfg +0 -0
  64. {ocf_data_sampler-0.1.16 → ocf_data_sampler-0.2.0}/utils/compute_icon_mean_stddev.py +0 -0
{ocf_data_sampler-0.1.16 → ocf_data_sampler-0.2.0}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.2
  Name: ocf-data-sampler
- Version: 0.1.16
+ Version: 0.2.0
  Author: James Fulton, Peter Dudfield
  Author-email: Open Climate Fix team <info@openclimatefix.org>
  License: MIT License
@@ -60,8 +60,7 @@ Requires-Dist: h5netcdf
  We are currently migrating to this repo from [ocf_datapipes](https://github.com/openclimatefix/ocf_datapipes/), which performs the same functions but is built around `PyTorch DataPipes`, which are quite cumbersome to work with and are no longer maintained by PyTorch. **ocf-data-sampler** uses `PyTorch Datasets`, and we've taken the opportunity to make the code much cleaner and more manageable.

  > [!Note]
- > This repository is in development and is replacing [ocf_datapipes](https://github.com/openclimatefix/ocf_datapipes/).
- > It might not be ready for use out of the box! We would really appreciate any help to let us make the transition faster.
+ > This repository is still in early development development and large changes to the user facing functions may still occur.

  ## Documentation

{ocf_data_sampler-0.1.16 → ocf_data_sampler-0.2.0}/README.md
@@ -12,8 +12,7 @@
  We are currently migrating to this repo from [ocf_datapipes](https://github.com/openclimatefix/ocf_datapipes/), which performs the same functions but is built around `PyTorch DataPipes`, which are quite cumbersome to work with and are no longer maintained by PyTorch. **ocf-data-sampler** uses `PyTorch Datasets`, and we've taken the opportunity to make the code much cleaner and more manageable.

  > [!Note]
- > This repository is in development and is replacing [ocf_datapipes](https://github.com/openclimatefix/ocf_datapipes/).
- > It might not be ready for use out of the box! We would really appreciate any help to let us make the transition faster.
+ > This repository is still in early development development and large changes to the user facing functions may still occur.

  ## Documentation

{ocf_data_sampler-0.1.16 → ocf_data_sampler-0.2.0}/ocf_data_sampler/config/model.py
@@ -9,7 +9,12 @@ from collections.abc import Iterator
  from pydantic import BaseModel, Field, RootModel, field_validator, model_validator
  from typing_extensions import override

- from ocf_data_sampler.constants import NWP_PROVIDERS
+ NWP_PROVIDERS = [
+ "ukv",
+ "ecmwf",
+ "gfs",
+ "icon_eu",
+ ]


  class Base(BaseModel):
@@ -125,7 +130,35 @@ class SpatialWindowMixin(Base):
  )


- class Satellite(TimeWindowMixin, DropoutMixin, SpatialWindowMixin):
+ class NormalisationValues(Base):
+ """Normalisation mean and standard deviation."""
+ mean: float = Field(..., description="Mean value for normalization")
+ std: float = Field(..., gt=0, description="Standard deviation (must be positive)")
+
+
+ class NormalisationConstantsMixin(Base):
+ """Normalisation constants for multiple channels."""
+ normalisation_constants: dict[str, NormalisationValues]
+
+ @property
+ def channel_means(self) -> dict[str, float]:
+ """Return the channel means."""
+ return {
+ channel: norm_values.mean
+ for channel, norm_values in self.normalisation_constants.items()
+ }
+
+
+ @property
+ def channel_stds(self) -> dict[str, float]:
+ """Return the channel standard deviations."""
+ return {
+ channel: norm_values.std
+ for channel, norm_values in self.normalisation_constants.items()
+ }
+
+
+ class Satellite(TimeWindowMixin, DropoutMixin, SpatialWindowMixin, NormalisationConstantsMixin):
  """Satellite configuration model."""

  zarr_path: str | tuple[str] | list[str] = Field(
@@ -139,8 +172,20 @@ class Satellite(TimeWindowMixin, DropoutMixin, SpatialWindowMixin):
  description="the satellite channels that are used",
  )

+ @model_validator(mode="after")
+ def check_all_channel_have_normalisation_constants(self) -> "Satellite":
+ """Check that all the channels have normalisation constants."""
+ normalisation_channels = set(self.normalisation_constants.keys())
+ missing_norm_values = set(self.channels) - set(normalisation_channels)
+ if len(missing_norm_values)>0:
+ raise ValueError(
+ "Normalsation constants must be provided for all channels. Missing values for "
+ f"channels: {missing_norm_values}",
+ )
+ return self

- class NWP(TimeWindowMixin, DropoutMixin, SpatialWindowMixin):
+
+ class NWP(TimeWindowMixin, DropoutMixin, SpatialWindowMixin, NormalisationConstantsMixin):
  """NWP configuration model."""

  zarr_path: str | tuple[str] | list[str] = Field(
@@ -173,6 +218,31 @@ class NWP(TimeWindowMixin, DropoutMixin, SpatialWindowMixin):
  return v


+ @model_validator(mode="after")
+ def check_all_channel_have_normalisation_constants(self) -> "NWP":
+ """Check that all the channels have normalisation constants."""
+ normalisation_channels = set(self.normalisation_constants.keys())
+ non_accum_channels = [c for c in self.channels if c not in self.accum_channels]
+ accum_channel_names = [f"diff_{c}" for c in self.accum_channels]
+
+ missing_norm_values = set(non_accum_channels) - set(normalisation_channels)
+ if len(missing_norm_values)>0:
+ raise ValueError(
+ "Normalsation constants must be provided for all channels. Missing values for "
+ f"channels: {missing_norm_values}",
+ )
+
+ missing_norm_values = set(accum_channel_names) - set(normalisation_channels)
+ if len(missing_norm_values)>0:
+ raise ValueError(
+ "Normalsation constants must be provided for all channels. Accumulated "
+ "channels which will be diffed require normalisation constant names which "
+ "start with the prefix 'diff_'. The following channels were missing: "
+ f"{missing_norm_values}.",
+ )
+ return self
+
+
  class MultiNWP(RootModel):
  """Configuration for multiple NWPs."""

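The upshot of the config changes above is that normalisation statistics no longer ship as package-level constants but are supplied per channel in the configuration. Below is a minimal sketch of the new models, assuming they can be imported from ocf_data_sampler.config.model; the channel names and values are illustrative only, borrowed from the old ECMWF constants further down this diff.

    # Sketch only: class and field names come from the 0.2.0 diff above;
    # the import path and the example values are assumptions.
    from ocf_data_sampler.config.model import NormalisationConstantsMixin, NormalisationValues

    norm = NormalisationConstantsMixin(
        normalisation_constants={
            "t2m": NormalisationValues(mean=283.48333740234375, std=3.692270040512085),
            # Accumulated NWP channels need their constants under a "diff_"-prefixed name
            "diff_sr": NormalisationValues(mean=469169.5, std=818950.6875),
        },
    )

    print(norm.channel_means)  # {'t2m': 283.48333740234375, 'diff_sr': 469169.5}
    print(norm.channel_stds)   # {'t2m': 3.692270040512085, 'diff_sr': 818950.6875}

The Satellite and NWP validators shown above then raise a ValueError at config-load time if any configured channel (or "diff_"-prefixed accumulated channel) is missing from normalisation_constants.
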
{ocf_data_sampler-0.1.16 → ocf_data_sampler-0.2.0}/ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py
@@ -9,7 +9,6 @@ from torch.utils.data import Dataset
  from typing_extensions import override

  from ocf_data_sampler.config import Configuration, load_yaml_configuration
- from ocf_data_sampler.constants import NWP_MEANS, NWP_STDS, RSS_MEAN, RSS_STD
  from ocf_data_sampler.load.load_dataset import get_dataset_dict
  from ocf_data_sampler.numpy_sample import (
  convert_gsp_to_numpy_sample,
@@ -27,15 +26,11 @@ from ocf_data_sampler.select (
  slice_datasets_by_time,
  )
  from ocf_data_sampler.select.geospatial import osgb_to_lon_lat
+ from ocf_data_sampler.torch_datasets.utils import channel_dict_to_dataarray, find_valid_time_periods
  from ocf_data_sampler.torch_datasets.utils.merge_and_fill_utils import (
  fill_nans_in_arrays,
  merge_dicts,
  )
- from ocf_data_sampler.torch_datasets.utils.valid_time_periods import find_valid_time_periods
- from ocf_data_sampler.torch_datasets.utils.validate_channels import (
- validate_nwp_channels,
- validate_satellite_channels,
- )
  from ocf_data_sampler.utils import minutes

  xr.set_options(keep_attrs=True)
@@ -54,10 +49,18 @@ def process_and_combine_datasets(
  nwp_numpy_modalities = {}

  for nwp_key, da_nwp in dataset_dict["nwp"].items():
- provider = config.input_data.nwp[nwp_key].provider

  # Standardise and convert to NumpyBatch
- da_nwp = (da_nwp - NWP_MEANS[provider]) / NWP_STDS[provider]
+
+ da_channel_means = channel_dict_to_dataarray(
+ config.input_data.nwp[nwp_key].channel_means,
+ )
+ da_channel_stds = channel_dict_to_dataarray(
+ config.input_data.nwp[nwp_key].channel_stds,
+ )
+
+ da_nwp = (da_nwp - da_channel_means) / da_channel_stds
+
  nwp_numpy_modalities[nwp_key] = convert_nwp_to_numpy_sample(da_nwp)

  # Combine the NWPs into NumpyBatch
@@ -67,7 +70,11 @@
  da_sat = dataset_dict["sat"]

  # Standardise and convert to NumpyBatch
- da_sat = (da_sat - RSS_MEAN) / RSS_STD
+ da_channel_means = channel_dict_to_dataarray(config.input_data.satellite.channel_means)
+ da_channel_stds = channel_dict_to_dataarray(config.input_data.satellite.channel_stds)
+
+ da_sat = (da_sat - da_channel_means) / da_channel_stds
+
  numpy_modalities.append(convert_satellite_to_numpy_sample(da_sat))

  if "gsp" in dataset_dict:
@@ -194,8 +201,6 @@ class PVNetUKRegionalDataset(Dataset):
  """
  # config = load_yaml_configuration(config_filename)
  config: Configuration = load_yaml_configuration(config_filename)
- validate_nwp_channels(config)
- validate_satellite_channels(config)

  datasets_dict = get_dataset_dict(config.input_data)

@@ -305,10 +310,6 @@ class PVNetUKConcurrentDataset(Dataset):
  """
  config = load_yaml_configuration(config_filename)

- # Validate channels for NWP and satellite data
- validate_nwp_channels(config)
- validate_satellite_channels(config)
-
  datasets_dict = get_dataset_dict(config.input_data)

  # Get t0 times where all input data is available
{ocf_data_sampler-0.1.16 → ocf_data_sampler-0.2.0}/ocf_data_sampler/torch_datasets/datasets/site.py
@@ -9,7 +9,6 @@ from torch.utils.data import Dataset
  from typing_extensions import override

  from ocf_data_sampler.config import Configuration, load_yaml_configuration
- from ocf_data_sampler.constants import NWP_MEANS, NWP_STDS, RSS_MEAN, RSS_STD
  from ocf_data_sampler.load.load_dataset import get_dataset_dict
  from ocf_data_sampler.numpy_sample import (
  NWPSampleKey,
@@ -27,15 +26,11 @@ from ocf_data_sampler.select (
  slice_datasets_by_space,
  slice_datasets_by_time,
  )
+ from ocf_data_sampler.torch_datasets.utils import channel_dict_to_dataarray, find_valid_time_periods
  from ocf_data_sampler.torch_datasets.utils.merge_and_fill_utils import (
  fill_nans_in_arrays,
  merge_dicts,
  )
- from ocf_data_sampler.torch_datasets.utils.valid_time_periods import find_valid_time_periods
- from ocf_data_sampler.torch_datasets.utils.validate_channels import (
- validate_nwp_channels,
- validate_satellite_channels,
- )
  from ocf_data_sampler.utils import minutes

  xr.set_options(keep_attrs=True)
@@ -58,9 +53,6 @@ class SitesDataset(Dataset):
  end_time: Limit the init-times to be before this
  """
  config: Configuration = load_yaml_configuration(config_filename)
- validate_nwp_channels(config)
- validate_satellite_channels(config)
-
  datasets_dict = get_dataset_dict(config.input_data)

  # Assign config and input data to self
@@ -224,7 +216,6 @@

  Args:
  dataset_dict: dict containing sliced xr DataArrays
- config: Configuration for the model
  t0: The initial timestamp of the sample

  Returns:
@@ -238,14 +229,29 @@
  provider = self.config.input_data.nwp[nwp_key].provider

  # Standardise
- da_nwp = (da_nwp - NWP_MEANS[provider]) / NWP_STDS[provider]
+ da_channel_means = channel_dict_to_dataarray(
+ self.config.input_data.nwp[nwp_key].channel_means,
+ )
+ da_channel_stds = channel_dict_to_dataarray(
+ self.config.input_data.nwp[nwp_key].channel_stds,
+ )
+
+ da_nwp = (da_nwp - da_channel_means) / da_channel_stds
+
  data_arrays.append((f"nwp-{provider}", da_nwp))

  if "sat" in dataset_dict:
  da_sat = dataset_dict["sat"]

- # Standardise
- da_sat = (da_sat - RSS_MEAN) / RSS_STD
+ da_channel_means = channel_dict_to_dataarray(
+ self.config.input_data.satellite.channel_means,
+ )
+ da_channel_stds = channel_dict_to_dataarray(
+ self.config.input_data.satellite.channel_stds,
+ )
+
+ da_sat = (da_sat - da_channel_means) / da_channel_stds
+
  data_arrays.append(("satellite", da_sat))

  if "site" in dataset_dict:
ocf_data_sampler-0.2.0/ocf_data_sampler/torch_datasets/utils/__init__.py
@@ -0,0 +1,3 @@
+ from .channel_dict_to_dataarray import channel_dict_to_dataarray
+ from .merge_and_fill_utils import fill_nans_in_arrays, merge_dicts
+ from .valid_time_periods import find_valid_time_periods
ocf_data_sampler-0.2.0/ocf_data_sampler/torch_datasets/utils/channel_dict_to_dataarray.py
@@ -0,0 +1,11 @@
+ """Converts a dictionary of channel values to a DataArray."""
+
+ import xarray as xr
+
+
+ def channel_dict_to_dataarray(channel_dict: dict[str, float]) -> xr.DataArray:
+ """Converts a dictionary of channel values to a DataArray."""
+ return xr.DataArray(
+ list(channel_dict.values()),
+ coords={"channel": list(channel_dict.keys())},
+ )
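This new helper replaces the DataArray-building logic that previously lived in ocf_data_sampler/constants.py; the datasets above call it on the per-channel means and stds coming from the config. A small usage sketch follows; the channel names and values are illustrative (taken from the old RSS satellite constants removed below), and the import path follows the new utils __init__.py above.

    # Sketch of the standardisation pattern used by the 0.2.0 datasets.
    import numpy as np
    import xarray as xr

    from ocf_data_sampler.torch_datasets.utils import channel_dict_to_dataarray

    means = {"VIS006": 0.09633306, "VIS008": 0.11426069}
    stds = {"VIS006": 0.12184761, "VIS008": 0.13090034}

    da_means = channel_dict_to_dataarray(means)  # 1-D DataArray with a "channel" coordinate
    da_stds = channel_dict_to_dataarray(stds)

    # Dummy satellite-like data with a matching "channel" dimension
    da_sat = xr.DataArray(
        np.random.rand(4, 2),
        dims=("time", "channel"),
        coords={"channel": list(means.keys())},
    )

    # Broadcasts over the shared "channel" coordinate, as in process_and_combine_datasets
    da_sat_norm = (da_sat - da_means) / da_stds
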
{ocf_data_sampler-0.1.16 → ocf_data_sampler-0.2.0}/ocf_data_sampler.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.2
  Name: ocf-data-sampler
- Version: 0.1.16
+ Version: 0.2.0
  Author: James Fulton, Peter Dudfield
  Author-email: Open Climate Fix team <info@openclimatefix.org>
  License: MIT License
@@ -60,8 +60,7 @@ Requires-Dist: h5netcdf
  We are currently migrating to this repo from [ocf_datapipes](https://github.com/openclimatefix/ocf_datapipes/), which performs the same functions but is built around `PyTorch DataPipes`, which are quite cumbersome to work with and are no longer maintained by PyTorch. **ocf-data-sampler** uses `PyTorch Datasets`, and we've taken the opportunity to make the code much cleaner and more manageable.

  > [!Note]
- > This repository is in development and is replacing [ocf_datapipes](https://github.com/openclimatefix/ocf_datapipes/).
- > It might not be ready for use out of the box! We would really appreciate any help to let us make the transition faster.
+ > This repository is still in early development development and large changes to the user facing functions may still occur.

  ## Documentation

{ocf_data_sampler-0.1.16 → ocf_data_sampler-0.2.0}/ocf_data_sampler.egg-info/SOURCES.txt
@@ -2,7 +2,6 @@ LICENSE
  README.md
  pyproject.toml
  ocf_data_sampler/__init__.py
- ocf_data_sampler/constants.py
  ocf_data_sampler/utils.py
  ocf_data_sampler.egg-info/PKG-INFO
  ocf_data_sampler.egg-info/SOURCES.txt
@@ -53,8 +52,9 @@ ocf_data_sampler/select/time_slice_for_dataset.py
  ocf_data_sampler/torch_datasets/datasets/__init__.py
  ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py
  ocf_data_sampler/torch_datasets/datasets/site.py
+ ocf_data_sampler/torch_datasets/utils/__init__.py
+ ocf_data_sampler/torch_datasets/utils/channel_dict_to_dataarray.py
  ocf_data_sampler/torch_datasets/utils/merge_and_fill_utils.py
  ocf_data_sampler/torch_datasets/utils/valid_time_periods.py
- ocf_data_sampler/torch_datasets/utils/validate_channels.py
  scripts/refactor_site.py
  utils/compute_icon_mean_stddev.py
{ocf_data_sampler-0.1.16 → ocf_data_sampler-0.2.0}/scripts/refactor_site.py
@@ -1,16 +1,18 @@
- import xarray as xr
+ """Refactor legacy site data into a more structured format."""
+
  import pandas as pd
+ import xarray as xr
+

  def legacy_format(data_ds: xr.Dataset, metadata_df: pd.DataFrame) -> xr.Dataset:
- """
- Converts old legacy site data into a more structured format.
+ """Converts old legacy site data into a more structured format.

  This function does three main things:
  1. Renames some columns in the metadata to keep things consistent.
  2. Reshapes site data so that instead of having separate variables for each site,
  we use a `site_id` dimension—makes life easier for analysis.
  3. Adds `capacity_kwp` as a time series so that each site has its capacity info.
-
+
  Parameters:
  data_ds (xr.Dataset): The dataset containing legacy site data.
  metadata_df (pd.DataFrame): A DataFrame with metadata about the sites.
@@ -18,11 +20,10 @@ def legacy_format(data_ds: xr.Dataset, metadata_df: pd.DataFrame) -> xr.Dataset:
  Returns:
  xr.Dataset: Reformatted dataset with `generation_kw` and `capacity_kwp`.
  """
-
  # Step 1: Rename metadata columns to match the new expected format
  if "system_id" in metadata_df.columns:
  metadata_df = metadata_df.rename(columns={"system_id": "site_id"})
-
+
  # Convert capacity from megawatts to kilowatts if needed
  if "capacity_megawatts" in metadata_df.columns:
  metadata_df["capacity_kwp"] = metadata_df["capacity_megawatts"] * 1000
ocf_data_sampler-0.1.16/ocf_data_sampler/constants.py
@@ -1,350 +0,0 @@
- """Constants for the package."""
-
- import numpy as np
- import xarray as xr
- from typing_extensions import override
-
- NWP_PROVIDERS = [
- "ukv",
- "ecmwf",
- "gfs",
- "icon_eu",
- ]
-
-
- def _to_data_array(d: dict) -> xr.DataArray:
- """Convert a dictionary to a DataArray."""
- return xr.DataArray(
- [d[k] for k in d],
- coords={"channel": list(d.keys())},
- ).astype(np.float32)
-
-
- class NWPStatDict(dict):
- """Custom dictionary class to hold NWP normalization stats."""
-
- @override
- def __getitem__(self, key: str) -> xr.DataArray:
- if key not in NWP_PROVIDERS:
- raise KeyError(f"{key} is not a supported NWP provider - {NWP_PROVIDERS}")
- elif key in self.keys():
- return super().__getitem__(key)
- else:
- raise KeyError(
- f"Values for {key} not yet available in ocf-data-sampler {list(self.keys())}",
- )
-
-
- # ------ UKV
- # Means and std computed WITH version_7 and higher, MetOffice values
- UKV_STD = {
- "cdcb": 2126.99350113,
- "lcc": 39.33210726,
- "mcc": 41.91144559,
- "hcc": 38.07184418,
- "sde": 0.1029753,
- "hcct": 18382.63958991,
- "dswrf": 190.47216887,
- "dlwrf": 39.45988077,
- "h": 1075.77812282,
- "t": 4.38818501,
- "r": 11.45012499,
- "dpt": 4.57250482,
- "vis": 21578.97975625,
- "si10": 3.94718813,
- "wdir10": 94.08407495,
- "prmsl": 1252.71790539,
- "prate": 0.00021497,
- }
-
- UKV_MEAN = {
- "cdcb": 1412.26599062,
- "lcc": 50.08362643,
- "mcc": 40.88984494,
- "hcc": 29.11949682,
- "sde": 0.00289545,
- "hcct": -18345.97478167,
- "dswrf": 111.28265039,
- "dlwrf": 325.03130139,
- "h": 2096.51991356,
- "t": 283.64913206,
- "r": 81.79229501,
- "dpt": 280.54379901,
- "vis": 32262.03285118,
- "si10": 6.88348448,
- "wdir10": 199.41891636,
- "prmsl": 101321.61574029,
- "prate": 3.45793433e-05,
- }
-
- UKV_STD = _to_data_array(UKV_STD)
- UKV_MEAN = _to_data_array(UKV_MEAN)
-
- # ------ ECMWF
- # These were calculated from 100 random init times of UK data from 2020-2023
- ECMWF_STD = {
- "dlwrf": 15855867.0,
- "dswrf": 13025427.0,
- "duvrs": 1445635.25,
- "hcc": 0.42244860529899597,
- "lcc": 0.3791404366493225,
- "mcc": 0.38039860129356384,
- "prate": 9.81039775069803e-05,
- "sd": 0.000913831521756947,
- "sr": 16294988.0,
- "t2m": 3.692270040512085,
- "tcc": 0.37487083673477173,
- "u10": 5.531515598297119,
- "u100": 7.2320556640625,
- "u200": 8.049470901489258,
- "v10": 5.411230564117432,
- "v100": 6.944501876831055,
- "v200": 7.561611652374268,
- "diff_dlwrf": 131942.03125,
- "diff_dswrf": 715366.3125,
- "diff_duvrs": 81605.25,
- "diff_sr": 818950.6875,
- }
-
- ECMWF_MEAN = {
- "dlwrf": 27187026.0,
- "dswrf": 11458988.0,
- "duvrs": 1305651.25,
- "hcc": 0.3961029052734375,
- "lcc": 0.44901806116104126,
- "mcc": 0.3288780450820923,
- "prate": 3.108070450252853e-05,
- "sd": 8.107526082312688e-05,
- "sr": 12905302.0,
- "t2m": 283.48333740234375,
- "tcc": 0.7049227356910706,
- "u10": 1.7677178382873535,
- "u100": 2.393547296524048,
- "u200": 2.7963004112243652,
- "v10": 0.985887885093689,
- "v100": 1.4244288206100464,
- "v200": 1.6010299921035767,
- "diff_dlwrf": 1136464.0,
- "diff_dswrf": 420584.6875,
- "diff_duvrs": 48265.4765625,
- "diff_sr": 469169.5,
- }
-
- ECMWF_STD = _to_data_array(ECMWF_STD)
- ECMWF_MEAN = _to_data_array(ECMWF_MEAN)
-
- # ------ GFS
- GFS_STD = {
- "dlwrf": 96.305916,
- "dswrf": 246.18533,
- "hcc": 42.525383,
- "lcc": 44.3732,
- "mcc": 43.150745,
- "prate": 0.00010159573,
- "r": 25.440672,
- "sde": 0.43345627,
- "t": 22.825893,
- "tcc": 41.030598,
- "u10": 5.470838,
- "u100": 6.8899174,
- "v10": 4.7401133,
- "v100": 6.076132,
- "vis": 8294.022,
- "u": 10.614556,
- "v": 7.176398,
- }
-
- GFS_MEAN = {
- "dlwrf": 298.342,
- "dswrf": 168.12321,
- "hcc": 35.272,
- "lcc": 43.578342,
- "mcc": 33.738823,
- "prate": 2.8190969e-05,
- "r": 18.359747,
- "sde": 0.36937004,
- "t": 278.5223,
- "tcc": 66.841606,
- "u10": -0.0022310058,
- "u100": 0.0823025,
- "v10": 0.06219831,
- "v100": 0.0797807,
- "vis": 19628.32,
- "u": 11.645444,
- "v": 0.12330122,
- }
-
- GFS_STD = _to_data_array(GFS_STD)
- GFS_MEAN = _to_data_array(GFS_MEAN)
-
- # ------ ICON-EU
- # Statistics for ICON-EU variables
- ICON_EU_STD = {
- "alb_rad": 13.7881,
- "alhfl_s": 73.7198,
- "ashfl_s": 54.8027,
- "asob_s": 55.8319,
- "asob_t": 74.9360,
- "aswdifd_s": 21.4940,
- "aswdifu_s": 18.7688,
- "aswdir_s": 54.4683,
- "athb_s": 34.8575,
- "athb_t": 42.9108,
- "aumfl_s": 0.1460,
- "avmfl_s": 0.1892,
- "cape_con": 32.2570,
- "cape_ml": 106.3998,
- "clch": 39.9324,
- "clcl": 36.3961,
- "clcm": 41.1690,
- "clct": 34.7696,
- "clct_mod": 0.4227,
- "cldepth": 0.1739,
- "h_snow": 0.9012,
- "hbas_con": 1306.6632,
- "htop_con": 1810.5665,
- "htop_dc": 459.0422,
- "hzerocl": 1144.6469,
- "pmsl": 1103.3301,
- "ps": 4761.3184,
- "qv_2m": 0.0024,
- "qv_s": 0.0038,
- "rain_con": 1.7097,
- "rain_gsp": 4.2654,
- "relhum_2m": 15.3779,
- "rho_snow": 120.2461,
- "runoff_g": 0.7410,
- "runoff_s": 2.1930,
- "snow_con": 1.1432,
- "snow_gsp": 1.8154,
- "snowlmt": 656.0699,
- "synmsg_bt_cl_ir10.8": 17.9438,
- "t_2m": 7.7973,
- "t_g": 8.7053,
- "t_snow": 134.6874,
- "tch": 0.0052,
- "tcm": 0.0133,
- "td_2m": 7.1460,
- "tmax_2m": 7.8218,
- "tmin_2m": 7.8346,
- "tot_prec": 5.6312,
- "tqc": 0.0976,
- "tqi": 0.0247,
- "u_10m": 3.8351,
- "v_10m": 5.0083,
- "vmax_10m": 5.5037,
- "w_snow": 286.1510,
- "ww": 27.2974,
- "z0": 0.3901,
- }
-
- ICON_EU_MEAN = {
- "alb_rad": 15.4437,
- "alhfl_s": -54.9398,
- "ashfl_s": -19.4684,
- "asob_s": 40.9305,
- "asob_t": 61.9244,
- "aswdifd_s": 19.7813,
- "aswdifu_s": 8.8328,
- "aswdir_s": 29.9820,
- "athb_s": -53.9873,
- "athb_t": -212.8088,
- "aumfl_s": 0.0558,
- "avmfl_s": 0.0078,
- "cape_con": 16.7397,
- "cape_ml": 21.2189,
- "clch": 26.4262,
- "clcl": 57.1591,
- "clcm": 36.1702,
- "clct": 72.9254,
- "clct_mod": 0.5561,
- "cldepth": 0.1356,
- "h_snow": 0.0494,
- "hbas_con": 108.4975,
- "htop_con": 433.0623,
- "htop_dc": 454.0859,
- "hzerocl": 1696.6272,
- "pmsl": 101778.8281,
- "ps": 99114.4766,
- "qv_2m": 0.0049,
- "qv_s": 0.0065,
- "rain_con": 0.4869,
- "rain_gsp": 0.9783,
- "relhum_2m": 78.2258,
- "rho_snow": 62.5032,
- "runoff_g": 0.1301,
- "runoff_s": 0.4119,
- "snow_con": 0.2188,
- "snow_gsp": 0.4317,
- "snowlmt": 1450.3241,
- "synmsg_bt_cl_ir10.8": 265.0639,
- "t_2m": 278.8212,
- "t_g": 279.9216,
- "t_snow": 162.5582,
- "tch": 0.0047,
- "tcm": 0.0091,
- "td_2m": 274.9544,
- "tmax_2m": 279.3550,
- "tmin_2m": 278.2519,
- "tot_prec": 2.1158,
- "tqc": 0.0424,
- "tqi": 0.0108,
- "u_10m": 1.1902,
- "v_10m": -0.4733,
- "vmax_10m": 8.4152,
- "w_snow": 14.5936,
- "ww": 15.3570,
- "z0": 0.2386,
- }
-
- ICON_EU_STD = _to_data_array(ICON_EU_STD)
- ICON_EU_MEAN = _to_data_array(ICON_EU_MEAN)
-
- NWP_STDS = NWPStatDict(
- ukv=UKV_STD,
- ecmwf=ECMWF_STD,
- gfs=GFS_STD,
- icon_eu=ICON_EU_STD,
- )
- NWP_MEANS = NWPStatDict(
- ukv=UKV_MEAN,
- ecmwf=ECMWF_MEAN,
- gfs=GFS_MEAN,
- icon_eu=ICON_EU_MEAN,
- )
-
- # ------ Satellite
- # RSS Mean and std values from randomised 20% of 2020 imagery
-
- RSS_STD = {
- "HRV": 0.11405209,
- "IR_016": 0.21462157,
- "IR_039": 0.04618041,
- "IR_087": 0.06687243,
- "IR_097": 0.0468558,
- "IR_108": 0.17482725,
- "IR_120": 0.06115861,
- "IR_134": 0.04492306,
- "VIS006": 0.12184761,
- "VIS008": 0.13090034,
- "WV_062": 0.16111417,
- "WV_073": 0.12924142,
- }
-
- RSS_MEAN = {
- "HRV": 0.09298719,
- "IR_016": 0.17594202,
- "IR_039": 0.86167645,
- "IR_087": 0.7719318,
- "IR_097": 0.8014212,
- "IR_108": 0.71254843,
- "IR_120": 0.89058584,
- "IR_134": 0.944365,
- "VIS006": 0.09633306,
- "VIS008": 0.11426069,
- "WV_062": 0.7359355,
- "WV_073": 0.62479186,
- }
-
- RSS_STD = _to_data_array(RSS_STD)
- RSS_MEAN = _to_data_array(RSS_MEAN)
ocf_data_sampler-0.1.16/ocf_data_sampler/torch_datasets/utils/validate_channels.py
@@ -1,86 +0,0 @@
- """Functions for checking that normalisation statistics exist for the data channels requested."""
-
- from ocf_data_sampler.config import Configuration
- from ocf_data_sampler.constants import NWP_MEANS, NWP_STDS, RSS_MEAN, RSS_STD
-
-
- def validate_channels(
- data_channels: list,
- means_channels: list,
- stds_channels: list,
- source_name: str | None = None,
- ) -> None:
- """Validates that all channels in data have corresponding normalisation constants.
-
- Args:
- data_channels: Set of channels from the data
- means_channels: Set of channels from means constants
- stds_channels: Set of channels from stds constants
- source_name: Name of data source (e.g., 'ecmwf', 'satellite') for error messages
-
- Raises:
- ValueError: If there's a mismatch between data channels and normalisation constants
- """
- data_set = set(data_channels)
- means_set = set(means_channels)
- stds_set = set(stds_channels)
-
- # Find missing channels in means
- missing_in_means = data_set - means_set
- if missing_in_means:
- raise ValueError(
- f"The following channels for {source_name} are missing in normalisation means: "
- f"{missing_in_means}",
- )
-
- # Find missing channels in stds
- missing_in_stds = data_set - stds_set
- if missing_in_stds:
- raise ValueError(
- f"The following channels for {source_name} are missing in normalisation stds: "
- f"{missing_in_stds}",
- )
-
-
- def validate_nwp_channels(config: Configuration) -> None:
- """Validate that NWP channels in config have corresponding normalisation constants.
-
- Args:
- config: Configuration object containing NWP channel information
-
- Raises:
- ValueError: If there's a mismatch between configured NWP channels
- and normalisation constants
- """
- if hasattr(config.input_data, "nwp") and (
- config.input_data.nwp is not None
- ):
- for _, nwp_config in config.input_data.nwp.items():
- provider = nwp_config.provider
- validate_channels(
- data_channels=nwp_config.channels,
- means_channels=NWP_MEANS[provider].channel.values,
- stds_channels=NWP_STDS[provider].channel.values,
- source_name=provider,
- )
-
-
- def validate_satellite_channels(config: Configuration) -> None:
- """Validate that satellite channels in config have corresponding normalisation constants.
-
- Args:
- config: Configuration object containing satellite channel information
-
- Raises:
- ValueError: If there's a mismatch between configured satellite channels
- and normalisation constants
- """
- if hasattr(config.input_data, "satellite") and (
- config.input_data.satellite is not None
- ):
- validate_channels(
- data_channels=config.input_data.satellite.channels,
- means_channels=RSS_MEAN.channel.values,
- stds_channels=RSS_STD.channel.values,
- source_name="satellite",
- )