ocf-data-sampler 0.0.19__py3-none-any.whl → 0.0.22__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ocf-data-sampler might be problematic. Click here for more details.

Files changed (37)
  1. ocf_data_sampler/config/__init__.py +5 -0
  2. ocf_data_sampler/config/load.py +33 -0
  3. ocf_data_sampler/config/model.py +249 -0
  4. ocf_data_sampler/config/save.py +36 -0
  5. ocf_data_sampler/constants.py +135 -0
  6. ocf_data_sampler/numpy_batch/gsp.py +21 -8
  7. ocf_data_sampler/numpy_batch/nwp.py +13 -3
  8. ocf_data_sampler/numpy_batch/satellite.py +15 -8
  9. ocf_data_sampler/numpy_batch/sun_position.py +5 -6
  10. ocf_data_sampler/select/dropout.py +2 -2
  11. ocf_data_sampler/select/geospatial.py +118 -0
  12. ocf_data_sampler/select/location.py +62 -0
  13. ocf_data_sampler/select/select_spatial_slice.py +5 -14
  14. ocf_data_sampler/torch_datasets/pvnet_uk_regional.py +16 -20
  15. ocf_data_sampler-0.0.22.dist-info/METADATA +88 -0
  16. ocf_data_sampler-0.0.22.dist-info/RECORD +54 -0
  17. {ocf_data_sampler-0.0.19.dist-info → ocf_data_sampler-0.0.22.dist-info}/WHEEL +1 -1
  18. tests/config/test_config.py +152 -0
  19. tests/conftest.py +6 -1
  20. tests/load/test_load_gsp.py +15 -0
  21. tests/load/test_load_nwp.py +21 -0
  22. tests/load/test_load_satellite.py +17 -0
  23. tests/numpy_batch/test_gsp.py +22 -0
  24. tests/numpy_batch/test_nwp.py +54 -0
  25. tests/numpy_batch/test_satellite.py +42 -0
  26. tests/numpy_batch/test_sun_position.py +81 -0
  27. tests/select/test_dropout.py +75 -0
  28. tests/select/test_fill_time_periods.py +28 -0
  29. tests/select/test_find_contiguous_time_periods.py +202 -0
  30. tests/select/test_location.py +67 -0
  31. tests/select/test_select_spatial_slice.py +154 -0
  32. tests/select/test_select_time_slice.py +284 -0
  33. tests/torch_datasets/test_pvnet_uk_regional.py +74 -0
  34. ocf_data_sampler-0.0.19.dist-info/METADATA +0 -22
  35. ocf_data_sampler-0.0.19.dist-info/RECORD +0 -32
  36. {ocf_data_sampler-0.0.19.dist-info → ocf_data_sampler-0.0.22.dist-info}/LICENSE +0 -0
  37. {ocf_data_sampler-0.0.19.dist-info → ocf_data_sampler-0.0.22.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,118 @@
1
+ """Geospatial functions"""
2
+
3
+ from numbers import Number
4
+ from typing import Union
5
+
6
+ import numpy as np
7
+ import pyproj
8
+ import xarray as xr
9
+
10
# OSGB ("OSGB 1936 / British National Grid -- United Kingdom Ordnance Survey")
# is a Transverse Mercator projection whose 'easting' and 'northing'
# coordinates are in meters. It is used in many UK electricity system maps and
# by the UK Met Office UKV model. See https://epsg.io/27700
OSGB36 = 27700

# WGS84 ("World Geodetic System 1984", as used by GPS) expresses positions as
# latitude and longitude. See https://epsg.io/4326
WGS84 = 4326


# Transformers between OSGB and WGS84 are built once at import time and reused
# by the conversion functions. always_xy=True makes pyproj take and return
# coordinates in (x, y) order, i.e. (easting, northing) / (longitude, latitude),
# regardless of the axis order defined by the CRS authority.
_osgb_to_lon_lat = pyproj.Transformer.from_crs(OSGB36, WGS84, always_xy=True).transform
_lon_lat_to_osgb = pyproj.Transformer.from_crs(WGS84, OSGB36, always_xy=True).transform
28
+
29
+
30
def osgb_to_lon_lat(
    x: Union[Number, np.ndarray], y: Union[Number, np.ndarray]
) -> tuple[Union[Number, np.ndarray], Union[Number, np.ndarray]]:
    """Convert OSGB easting/northing coordinates to longitude/latitude.

    Args:
        x: OSGB easting (east-west), scalar or array
        y: OSGB northing (north-south), scalar or array

    Returns:
        2-tuple of (longitude, latitude), i.e. (east-west, north-south)
    """
    # The module-level transformer was created with always_xy=True, so it
    # consumes (easting, northing) and yields (longitude, latitude).
    return _osgb_to_lon_lat(x, y)
41
+
42
+
43
def lon_lat_to_osgb(
    x: Union[Number, np.ndarray],
    y: Union[Number, np.ndarray],
) -> tuple[Union[Number, np.ndarray], Union[Number, np.ndarray]]:
    """Convert longitude/latitude coordinates to OSGB easting/northing.

    Args:
        x: longitude (east-west), scalar or array
        y: latitude (north-south), scalar or array

    Returns:
        2-tuple of OSGB (x, y), i.e. (easting, northing)
    """
    # The module-level transformer was created with always_xy=True, so it
    # consumes (longitude, latitude) and yields (easting, northing).
    return _lon_lat_to_osgb(x, y)
56
+
57
+
58
# Cache of OSGB -> geostationary transform functions, keyed by the area-definition
# YAML string stored in the data's attrs. Parsing the area definition and building
# a pyproj Transformer is comparatively expensive, and this conversion is called
# from the spatial-slicing code path, so each distinct area is processed only once
# per process. NOTE(review): presumably only one or a few distinct satellite areas
# are ever seen, so the cache stays small - confirm if arbitrary areas can occur.
_osgb_to_geostationary_transforms: dict = {}


def _get_osgb_to_geostationary_transform(area_definition_yaml: str):
    """Return a (cached) OSGB -> geostationary transform for the given area YAML."""
    transform = _osgb_to_geostationary_transforms.get(area_definition_yaml)
    if transform is None:
        # Only load pyresample if using geostationary projection
        import pyresample

        geostationary_area_definition = pyresample.area_config.load_area_from_string(
            area_definition_yaml
        )
        transform = pyproj.Transformer.from_crs(
            crs_from=OSGB36,
            crs_to=geostationary_area_definition.crs,
            always_xy=True,
        ).transform
        _osgb_to_geostationary_transforms[area_definition_yaml] = transform
    return transform


def osgb_to_geostationary_area_coords(
    x: Union[Number, np.ndarray],
    y: Union[Number, np.ndarray],
    xr_data: xr.DataArray,
) -> tuple[Union[Number, np.ndarray], Union[Number, np.ndarray]]:
    """Transform OSGB coordinates into the geostationary projection of `xr_data`.

    The geostationary area is read from the YAML area definition stored in
    ``xr_data.attrs["area"]``. The resulting transformer is cached per distinct
    area definition, so repeated calls on the same area do not re-parse the
    YAML or rebuild the pyproj Transformer.

    Args:
        x: osgb east-west
        y: osgb north-south
        xr_data: xarray object with geostationary area

    Returns:
        Geostationary coords: x, y

    Raises:
        KeyError: if ``xr_data.attrs`` has no "area" entry.
    """
    osgb_to_geostationary = _get_osgb_to_geostationary_transform(xr_data.attrs["area"])
    return osgb_to_geostationary(xx=x, yy=y)
86
+
87
+
88
def _coord_priority(available_coords):
    """Pick the spatial coordinate system from a collection of coordinate names.

    Candidates are checked in priority order - lon/lat, then geostationary,
    then OSGB - and the first whose x-coordinate name is present in
    ``available_coords`` is returned.

    Args:
        available_coords: Container of coordinate names supporting ``in``.

    Returns:
        Tuple of (coordinate system kind, x-coordinate name, y-coordinate name).

    Raises:
        ValueError: If none of the known x-coordinate names is present.
    """
    candidates = (
        ("lon_lat", "longitude", "latitude"),
        ("geostationary", "x_geostationary", "y_geostationary"),
        ("osgb", "x_osgb", "y_osgb"),
    )
    for kind, x_name, y_name in candidates:
        if x_name in available_coords:
            return kind, x_name, y_name
    raise ValueError(f"Unrecognized coordinate system: {available_coords}")
97
+
98
+
99
def spatial_coord_type(ds: xr.DataArray):
    """Determine which spatial coordinate system a DataArray uses.

    Only coordinates that carry an xarray index (``ds.xindexes``) are
    searched, so the dimension coordinates of the object take precedence.

    Args:
        ds: DataArray with spatial coords

    Returns:
        Tuple of (coordinate system kind, x-coordinate name, y-coordinate name),
        e.g. ("osgb", "x_osgb", "y_osgb").

    Raises:
        ValueError: If ``ds`` is not an ``xr.DataArray``, or if none of the
            recognised spatial coordinates are indexed.
    """
    if not isinstance(ds, xr.DataArray):
        raise ValueError(f"Unrecognized input type: {type(ds)}")
    return _coord_priority(ds.xindexes)
@@ -0,0 +1,62 @@
1
+ """location"""
2
+
3
+ from typing import Optional
4
+
5
+ import numpy as np
6
+ from pydantic import BaseModel, Field, model_validator
7
+
8
+
9
allowed_coordinate_systems = ["osgb", "lon_lat", "geostationary", "idx"]

# Inclusive (min, max) bounds of valid x and y values for each coordinate system.
# "idx" locations are array indices, so they only need to be non-negative.
_X_BOUNDS = {
    "osgb": (-103976.3, 652897.98),
    "lon_lat": (-180, 180),
    "geostationary": (-5568748.275756836, 5567248.074173927),
    "idx": (0, np.inf),
}
_Y_BOUNDS = {
    "osgb": (-16703.87, 1199851.44),
    "lon_lat": (-90, 90),
    "geostationary": (1393687.2151494026, 5570748.323202133),
    "idx": (0, np.inf),
}


def _validate_within_bounds(name, value, bounds, coordinate_system):
    """Raise ValueError if `value` lies outside `bounds[coordinate_system]`.

    Unknown coordinate systems are skipped here; they are reported by
    `Location.validate_coordinate_system` instead.
    """
    if coordinate_system not in bounds:
        return
    min_value, max_value = bounds[coordinate_system]
    # Negated chained comparison so that NaN, which compares False against
    # everything, is rejected instead of silently passing validation.
    if not (min_value <= value <= max_value):
        raise ValueError(
            f"{name} = {value} must be within {[min_value, max_value]} "
            f"for {coordinate_system} coordinate system"
        )


class Location(BaseModel):
    """Represent a spatial location."""

    # Coordinate system that x and y are expressed in; must be one of
    # `allowed_coordinate_systems`.
    coordinate_system: Optional[str] = "osgb"  # ["osgb", "lon_lat", "geostationary", "idx"]
    x: float
    y: float
    # Optional identifier for the location (the PVNet UK regional dataset
    # uses it as the GSP ID).
    id: Optional[int] = Field(None)

    @model_validator(mode='after')
    def validate_coordinate_system(self):
        """Validate 'coordinate_system'"""
        if self.coordinate_system not in allowed_coordinate_systems:
            raise ValueError(f"coordinate_system = {self.coordinate_system} is not in {allowed_coordinate_systems}")
        return self

    @model_validator(mode='after')
    def validate_x(self):
        """Validate that 'x' is within the bounds of its coordinate system"""
        _validate_within_bounds("x", self.x, _X_BOUNDS, self.coordinate_system)
        return self

    @model_validator(mode='after')
    def validate_y(self):
        """Validate that 'y' is within the bounds of its coordinate system"""
        _validate_within_bounds("y", self.y, _Y_BOUNDS, self.coordinate_system)
        return self
@@ -5,15 +5,14 @@ import logging
5
5
  import numpy as np
6
6
  import xarray as xr
7
7
 
8
- from ocf_datapipes.utils import Location
9
- from ocf_datapipes.utils.geospatial import (
10
- lon_lat_to_geostationary_area_coords,
8
+ from ocf_data_sampler.select.location import Location
9
+ from ocf_data_sampler.select.geospatial import (
11
10
  lon_lat_to_osgb,
12
11
  osgb_to_geostationary_area_coords,
13
12
  osgb_to_lon_lat,
14
13
  spatial_coord_type,
15
14
  )
16
- from ocf_datapipes.utils.utils import searchsorted
15
+
17
16
 
18
17
  logger = logging.getLogger(__name__)
19
18
 
@@ -45,9 +44,6 @@ def convert_coords_to_match_xarray(
45
44
  if from_coords == "osgb":
46
45
  x, y = osgb_to_geostationary_area_coords(x, y, da)
47
46
 
48
- elif from_coords == "lon_lat":
49
- x, y = lon_lat_to_geostationary_area_coords(x, y, da)
50
-
51
47
  elif target_coords == "lon_lat":
52
48
  if from_coords == "osgb":
53
49
  x, y = osgb_to_lon_lat(x, y)
@@ -130,13 +126,8 @@ def _get_idx_of_pixel_closest_to_poi_geostationary(
130
126
  f"{y} is not in the interval {da[y_dim].min().values}: {da[y_dim].max().values}"
131
127
 
132
128
  # Get the index into x and y nearest to x_center_geostationary and y_center_geostationary:
133
- x_index_at_center = searchsorted(
134
- da[x_dim].values, center_geostationary.x, assume_ascending=True
135
- )
136
-
137
- y_index_at_center = searchsorted(
138
- da[y_dim].values, center_geostationary.y, assume_ascending=True
139
- )
129
+ x_index_at_center = np.searchsorted(da[x_dim].values, center_geostationary.x)
130
+ y_index_at_center = np.searchsorted(da[y_dim].values, center_geostationary.y)
140
131
 
141
132
  return Location(x=x_index_at_center, y=y_index_at_center, coordinate_system="idx")
142
133
 
@@ -27,19 +27,15 @@ from ocf_data_sampler.numpy_batch import (
27
27
  )
28
28
 
29
29
 
30
- from ocf_datapipes.config.model import Configuration
31
- from ocf_datapipes.config.load import load_yaml_configuration
32
- from ocf_datapipes.batch import BatchKey, NumpyBatch
30
+ from ocf_data_sampler.config import Configuration, load_yaml_configuration
31
+ from ocf_data_sampler.numpy_batch.nwp import NWPBatchKey
32
+ from ocf_data_sampler.numpy_batch.gsp import GSPBatchKey
33
33
 
34
- from ocf_datapipes.utils.location import Location
35
- from ocf_datapipes.utils.geospatial import osgb_to_lon_lat
34
+ from ocf_data_sampler.select.location import Location
35
+ from ocf_data_sampler.select.geospatial import osgb_to_lon_lat
36
36
 
37
- from ocf_datapipes.utils.consts import (
38
- NWP_MEANS,
39
- NWP_STDS,
40
- )
37
+ from ocf_data_sampler.constants import NWP_MEANS, NWP_STDS
41
38
 
42
- from ocf_datapipes.training.common import concat_xr_time_utc, normalize_gsp
43
39
 
44
40
 
45
41
 
@@ -344,7 +340,7 @@ def slice_datasets_by_time(
344
340
  return sliced_datasets_dict
345
341
 
346
342
 
347
- def fill_nans_in_arrays(batch: NumpyBatch) -> NumpyBatch:
343
+ def fill_nans_in_arrays(batch: dict) -> dict:
348
344
  """Fills all NaN values in each np.ndarray in the batch dictionary with zeros.
349
345
 
350
346
  Operation is performed in-place on the batch.
@@ -376,7 +372,7 @@ def process_and_combine_datasets(
376
372
  config: Configuration,
377
373
  t0: pd.Timedelta,
378
374
  location: Location,
379
- ) -> NumpyBatch:
375
+ ) -> dict:
380
376
  """Normalize and convert data to numpy arrays"""
381
377
 
382
378
  numpy_modalities = []
@@ -393,7 +389,7 @@ def process_and_combine_datasets(
393
389
  nwp_numpy_modalities[nwp_key] = convert_nwp_to_numpy_batch(da_nwp)
394
390
 
395
391
  # Combine the NWPs into NumpyBatch
396
- numpy_modalities.append({BatchKey.nwp: nwp_numpy_modalities})
392
+ numpy_modalities.append({NWPBatchKey.nwp: nwp_numpy_modalities})
397
393
 
398
394
  if "sat" in dataset_dict:
399
395
  # Satellite is already in the range [0-1] so no need to standardise
@@ -405,8 +401,8 @@ def process_and_combine_datasets(
405
401
  gsp_config = config.input_data.gsp
406
402
 
407
403
  if "gsp" in dataset_dict:
408
- da_gsp = concat_xr_time_utc([dataset_dict["gsp"], dataset_dict["gsp_future"]])
409
- da_gsp = normalize_gsp(da_gsp)
404
+ da_gsp = xr.concat([dataset_dict["gsp"], dataset_dict["gsp_future"]], dim="time_utc")
405
+ da_gsp = da_gsp / da_gsp.effective_capacity_mwp
410
406
 
411
407
  numpy_modalities.append(
412
408
  convert_gsp_to_numpy_batch(
@@ -429,9 +425,9 @@ def process_and_combine_datasets(
429
425
  # Add coordinate data
430
426
  # TODO: Do we need all of these?
431
427
  numpy_modalities.append({
432
- BatchKey.gsp_id: location.id,
433
- BatchKey.gsp_x_osgb: location.x,
434
- BatchKey.gsp_y_osgb: location.y,
428
+ GSPBatchKey.gsp_id: location.id,
429
+ GSPBatchKey.gsp_x_osgb: location.x,
430
+ GSPBatchKey.gsp_y_osgb: location.y,
435
431
  })
436
432
 
437
433
  # Combine all the modalities and fill NaNs
@@ -539,7 +535,7 @@ class PVNetUKRegionalDataset(Dataset):
539
535
  return len(self.index_pairs)
540
536
 
541
537
 
542
- def _get_sample(self, t0: pd.Timestamp, location: Location) -> NumpyBatch:
538
+ def _get_sample(self, t0: pd.Timestamp, location: Location) -> dict:
543
539
  """Generate the PVNet sample for given coordinates
544
540
 
545
541
  Args:
@@ -566,7 +562,7 @@ class PVNetUKRegionalDataset(Dataset):
566
562
  return self._get_sample(t0, location)
567
563
 
568
564
 
569
- def get_sample(self, t0: pd.Timestamp, gsp_id: int) -> NumpyBatch:
565
+ def get_sample(self, t0: pd.Timestamp, gsp_id: int) -> dict:
570
566
  """Generate a sample for the given coordinates.
571
567
 
572
568
  Useful for users to generate samples by GSP ID.
@@ -0,0 +1,88 @@
1
+ Metadata-Version: 2.1
2
+ Name: ocf_data_sampler
3
+ Version: 0.0.22
4
+ Summary: Sample from weather data for renewable energy prediction
5
+ Author: James Fulton, Peter Dudfield, and the Open Climate Fix team
6
+ Author-email: info@openclimatefix.org
7
+ Maintainer: Open Climate Fix Ltd
8
+ License: MIT License
9
+
10
+ Copyright (c) 2023 Open Climate Fix
11
+
12
+ Permission is hereby granted, free of charge, to any person obtaining a copy
13
+ of this software and associated documentation files (the "Software"), to deal
14
+ in the Software without restriction, including without limitation the rights
15
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
16
+ copies of the Software, and to permit persons to whom the Software is
17
+ furnished to do so, subject to the following conditions:
18
+
19
+ The above copyright notice and this permission notice shall be included in all
20
+ copies or substantial portions of the Software.
21
+
22
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
23
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
24
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
25
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
26
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
27
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
28
+ SOFTWARE.
29
+
30
+ Project-URL: homepage, https://github.com/openclimatefix
31
+ Project-URL: repository, https://github.com/openclimatefix/ocf-data-sampler
32
+ Keywords: weather data,renewable energy prediction,sample weather data
33
+ Classifier: License :: OSI Approved :: MIT License
34
+ Classifier: Programming Language :: Python :: 3.8
35
+ Classifier: Operating System :: POSIX :: Linux
36
+ Requires-Python: >=3.8
37
+ Description-Content-Type: text/markdown
38
+ License-File: LICENSE
39
+ Requires-Dist: torch
40
+ Requires-Dist: numpy
41
+ Requires-Dist: pandas
42
+ Requires-Dist: xarray
43
+ Requires-Dist: zarr
44
+ Requires-Dist: dask
45
+ Requires-Dist: ocf-blosc2
46
+ Requires-Dist: pvlib
47
+ Requires-Dist: pydantic
48
+ Requires-Dist: pyproj
49
+ Requires-Dist: pathy
50
+ Requires-Dist: pyaml-env
51
+ Requires-Dist: pyresample
52
+ Provides-Extra: docs
53
+ Requires-Dist: mkdocs>=1.2; extra == "docs"
54
+ Requires-Dist: mkdocs-material>=8.0; extra == "docs"
55
+
56
+ # OCF Data Sampler
57
+ <!-- ALL-CONTRIBUTORS-BADGE:START - Do not remove or modify this section -->
58
+ [![All Contributors](https://img.shields.io/badge/all_contributors-5-orange.svg?style=flat-square)](#contributors-)
59
+ <!-- ALL-CONTRIBUTORS-BADGE:END -->
60
+ [![ease of contribution: easy](https://img.shields.io/badge/ease%20of%20contribution:%20easy-32bd50)](https://github.com/openclimatefix/ocf-meta-repo?tab=readme-ov-file#overview-of-ocfs-nowcasting-repositories)
61
+
62
+ A repo for sampling from weather data for renewable energy prediction
63
+
64
+ ## Contributors ✨
65
+
66
+ Thanks goes to these wonderful people ([emoji key](https://allcontributors.org/docs/en/emoji-key)):
67
+
68
+ <!-- ALL-CONTRIBUTORS-LIST:START - Do not remove or modify this section -->
69
+ <!-- prettier-ignore-start -->
70
+ <!-- markdownlint-disable -->
71
+ <table>
72
+ <tbody>
73
+ <tr>
74
+ <td align="center" valign="top" width="14.28%"><a href="https://github.com/dfulu"><img src="https://avatars.githubusercontent.com/u/41546094?v=4?s=100" width="100px;" alt="James Fulton"/><br /><sub><b>James Fulton</b></sub></a><br /><a href="https://github.com/openclimatefix/ocf-data-sampler/commits?author=dfulu" title="Code">💻</a></td>
75
+ <td align="center" valign="top" width="14.28%"><a href="https://github.com/AUdaltsova"><img src="https://avatars.githubusercontent.com/u/43303448?v=4?s=100" width="100px;" alt="Alexandra Udaltsova"/><br /><sub><b>Alexandra Udaltsova</b></sub></a><br /><a href="https://github.com/openclimatefix/ocf-data-sampler/commits?author=AUdaltsova" title="Code">💻</a></td>
76
+ <td align="center" valign="top" width="14.28%"><a href="https://github.com/Sukh-P"><img src="https://avatars.githubusercontent.com/u/42407101?v=4?s=100" width="100px;" alt="Sukhil Patel"/><br /><sub><b>Sukhil Patel</b></sub></a><br /><a href="https://github.com/openclimatefix/ocf-data-sampler/commits?author=Sukh-P" title="Code">💻</a></td>
77
+ <td align="center" valign="top" width="14.28%"><a href="https://github.com/peterdudfield"><img src="https://avatars.githubusercontent.com/u/34686298?v=4?s=100" width="100px;" alt="Peter Dudfield"/><br /><sub><b>Peter Dudfield</b></sub></a><br /><a href="https://github.com/openclimatefix/ocf-data-sampler/commits?author=peterdudfield" title="Code">💻</a></td>
78
+ <td align="center" valign="top" width="14.28%"><a href="https://github.com/VikramsDataScience"><img src="https://avatars.githubusercontent.com/u/45002417?v=4?s=100" width="100px;" alt="Vikram Pande"/><br /><sub><b>Vikram Pande</b></sub></a><br /><a href="https://github.com/openclimatefix/ocf-data-sampler/commits?author=VikramsDataScience" title="Code">💻</a></td>
79
+ </tr>
80
+ </tbody>
81
+ </table>
82
+
83
+ <!-- markdownlint-restore -->
84
+ <!-- prettier-ignore-end -->
85
+
86
+ <!-- ALL-CONTRIBUTORS-LIST:END -->
87
+
88
+ This project follows the [all-contributors](https://github.com/all-contributors/all-contributors) specification. Contributions of any kind welcome!
@@ -0,0 +1,54 @@
1
+ ocf_data_sampler/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
2
+ ocf_data_sampler/constants.py,sha256=tUwHrsGShqIn5Izze4i32_xB6X0v67rvQwIYB-P5PJQ,3355
3
+ ocf_data_sampler/config/__init__.py,sha256=YXnAkgHViHB26hSsjiv32b6EbpG-A1kKTkARJf0_RkY,212
4
+ ocf_data_sampler/config/load.py,sha256=4f7vPHAIAmd-55tPxoIzn7F_TI_ue4NxkDcLPoVWl0g,943
5
+ ocf_data_sampler/config/model.py,sha256=bvU3BEMtcUh-N17fMVLTYtN-J2GcTM9Qq-CI5AfbE4Q,8128
6
+ ocf_data_sampler/config/save.py,sha256=wKdctbv0dxIIiQtcRHLRxpWQVhEFQ_FCWg-oNaRLIps,1093
7
+ ocf_data_sampler/data/uk_gsp_locations.csv,sha256=RSh7DRh55E3n8lVAaWXGTaXXHevZZtI58td4d4DhGos,10415772
8
+ ocf_data_sampler/load/__init__.py,sha256=MjgfxilTzyz1RYFoBEeAXmE9hyjknLvdmlHPmlAoiQY,44
9
+ ocf_data_sampler/load/gsp.py,sha256=Gcr1JVUOPKhFRDCSHtfPDjxx0BtyyEhXrZvGEKLPJ5I,759
10
+ ocf_data_sampler/load/satellite.py,sha256=3KlA1fx4SwxdzM-jC1WRaONXO0D6m0WxORnEnwUnZrA,2967
11
+ ocf_data_sampler/load/utils.py,sha256=EQGvVWlGMoSOdbDYuMfVAa0v6wmAOPmHIAemdrTB5v4,1406
12
+ ocf_data_sampler/load/nwp/__init__.py,sha256=SmcrnbygO5xtCKmGR4wtHrj-HI7nOAvnAtfuvRufBGQ,25
13
+ ocf_data_sampler/load/nwp/nwp.py,sha256=O4QnajEZem8BvBgTcYYDBhRhgqPYuJkolHmpMRmrXEA,610
14
+ ocf_data_sampler/load/nwp/providers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
15
+ ocf_data_sampler/load/nwp/providers/ecmwf.py,sha256=vW-p3vCyQ-CofKo555-gE7VDi5hlpjtjTLfHqWF0HEE,1175
16
+ ocf_data_sampler/load/nwp/providers/ukv.py,sha256=79Bm7q-K_GJPYMy62SUIZbRWRF4-tIaB1dYPEgLD9vo,1207
17
+ ocf_data_sampler/load/nwp/providers/utils.py,sha256=Sy2exG1wpXLLhMXYdsfR-DZMR3txG1_bBmBdchlc-yA,848
18
+ ocf_data_sampler/numpy_batch/__init__.py,sha256=mrtqwbGik5Zc9MYP5byfCTBm08wMtS2XnTsypC4fPMo,245
19
+ ocf_data_sampler/numpy_batch/gsp.py,sha256=FQg5infqOZ8QAsRCXq00HPplN8XnsbKoWUeV3N7TGK8,1008
20
+ ocf_data_sampler/numpy_batch/nwp.py,sha256=dt4gXZonrdkyoVmMnwkKyJ_xFnJ4M5xscr-pdNoRsc4,1311
21
+ ocf_data_sampler/numpy_batch/satellite.py,sha256=yWYmq6hEvX4LTJ95K_e0BRJJkBmRurQT2mI2KibmwCI,983
22
+ ocf_data_sampler/numpy_batch/sun_position.py,sha256=zw2bjtcjsm_tvKk0r_MZmgfYUJLHuLjLly2sMjwP3XI,1606
23
+ ocf_data_sampler/select/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
24
+ ocf_data_sampler/select/dropout.py,sha256=zDpVLMjGb70RRyYKN-WI2Kp3x9SznstT4cMcZ4dsvJg,1066
25
+ ocf_data_sampler/select/fill_time_periods.py,sha256=iTtMjIPFYG5xtUYYedAFBLjTWWUa7t7WQ0-yksWf0-E,440
26
+ ocf_data_sampler/select/find_contiguous_time_periods.py,sha256=6ioB8LeFpFNBMgKDxrgG3zqzNjkBF_jlV9yye2ZYT2E,11925
27
+ ocf_data_sampler/select/geospatial.py,sha256=oHJoKEKubn3v3yKCVeuiPxuGroVA4RyrpNi6ARq5woE,3558
28
+ ocf_data_sampler/select/location.py,sha256=26Y5ZjfFngShBwXieuWSoOA-RLaRzci4TTmcDk3Wg7U,2015
29
+ ocf_data_sampler/select/select_spatial_slice.py,sha256=hWIJe4_VzuQ2iiiQh7V17AXwTILT5kIkUvzG458J_Gw,11220
30
+ ocf_data_sampler/select/select_time_slice.py,sha256=41cch1fQr59fZgv7UHsNGc3OvoynrixT3bmr3_1d7cU,6628
31
+ ocf_data_sampler/torch_datasets/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
32
+ ocf_data_sampler/torch_datasets/pvnet_uk_regional.py,sha256=5AV_eFtQ842-E7kNeQPLvEvcmV4LDkJKYbV2rvU_wkE,19113
33
+ tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
34
+ tests/conftest.py,sha256=O77dmow8mGpGPbZ6Pz7ma7cLaiV1k8mxW1eYg37Avrw,5585
35
+ tests/config/test_config.py,sha256=G_PD_pXib0zdRBPUIn0jjwJ9VyoKaO_TanLN1Mh5Ca4,5055
36
+ tests/load/test_load_gsp.py,sha256=aT_nqaSXmUTcdHzuTT7AmXJr3R31k4OEN-Fv3eLxlQE,424
37
+ tests/load/test_load_nwp.py,sha256=3qyyDkB1q9t3tyAwogfotNrxqUOpXXimco1CImoEWGg,753
38
+ tests/load/test_load_satellite.py,sha256=STX5AqqmOAgUgE9R1xyq_sM3P1b8NKdGjO-hDhayfxM,524
39
+ tests/numpy_batch/test_gsp.py,sha256=ke4CsFn9ZKRgrciT5-OHS8jmd8nu9gGKKr0T2WzZK6M,592
40
+ tests/numpy_batch/test_nwp.py,sha256=yk77jZfba-dd5ImjOTWNTMNnjJR8SF0kSpZ-513vPTw,1490
41
+ tests/numpy_batch/test_satellite.py,sha256=DIVnVq7JYgoC6Y6xMtLCNUnk3ADzakCi0bZ44QTdPgQ,1230
42
+ tests/numpy_batch/test_sun_position.py,sha256=zw7ErTKARkW8NrpXJ9MeGp-dkNBJsCscxQx0dnZHg2c,2513
43
+ tests/select/test_dropout.py,sha256=kiycl7RxAQYMCZJlokmx6Da5h_oBpSs8Is8pmSW4gOU,2413
44
+ tests/select/test_fill_time_periods.py,sha256=o59f2YRe5b0vJrG3B0aYZkYeHnpNk4s6EJxdXZluNQg,907
45
+ tests/select/test_find_contiguous_time_periods.py,sha256=G6tJRJd0DMfH9EdfzlKWsmfTbtMwOf3w-2filjJzuIQ,5998
46
+ tests/select/test_location.py,sha256=_WZk2FPYeJ-nIfCJS6Sp_yaVEEo7m31DmMFoZzgyCts,2712
47
+ tests/select/test_select_spatial_slice.py,sha256=7EX9b6g-pMdACQx3yefjs5do2s-Rho2UmKevV4oglsU,5147
48
+ tests/select/test_select_time_slice.py,sha256=rH4h90HdQCoWE7vV7ivMEKhiCStQDEcMBCPamiDuO0k,10147
49
+ tests/torch_datasets/test_pvnet_uk_regional.py,sha256=r1SHtwaXQrOYU3EOH1OEp_Bamo338IMv-9Q_gKsUOa4,2789
50
+ ocf_data_sampler-0.0.22.dist-info/LICENSE,sha256=F-Q3UFCR-BECSocV55BFDpn4YKxve9PKrm-lTt6o_Tg,1073
51
+ ocf_data_sampler-0.0.22.dist-info/METADATA,sha256=iB2KIy-7AQjWsOQzQoINItbAlKFWv_3TIRcOJ1-f_Uw,5269
52
+ ocf_data_sampler-0.0.22.dist-info/WHEEL,sha256=OVMc5UfuAQiSplgO0_WdW7vXVGAt9Hdd6qtN4HotdyA,91
53
+ ocf_data_sampler-0.0.22.dist-info/top_level.txt,sha256=KaQn5qzkJGJP6hKWqsVAc9t0cMLjVvSTk8-kTrW79SA,23
54
+ ocf_data_sampler-0.0.22.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (75.1.0)
2
+ Generator: setuptools (75.2.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
@@ -0,0 +1,152 @@
1
+ import tempfile
2
+
3
+ import pytest
4
+ from pydantic import ValidationError
5
+
6
+ from ocf_data_sampler.config import (
7
+ load_yaml_configuration,
8
+ Configuration,
9
+ save_yaml_configuration
10
+ )
11
+
12
+
13
def test_default():
    """Check the pydantic Configuration can be built from its defaults alone"""
    Configuration()
17
+
18
+
19
def test_yaml_load_test_config(test_config_filename):
    """
    Check 'test_config.yaml' loads into a Configuration, while an empty
    .yaml file fails to load
    """
    # An empty .yaml file holds no configuration, so loading it must fail
    with tempfile.NamedTemporaryFile(suffix=".yaml") as empty_file:
        with pytest.raises(TypeError):
            load_yaml_configuration(empty_file.name)

    # The shipped test config, on the other hand, loads successfully
    loaded_config = load_yaml_configuration(test_config_filename)
    assert isinstance(loaded_config, Configuration)
37
+
38
+
39
def test_yaml_save(test_config_filename):
    """
    Check a configuration survives a save/load round trip through a .yaml file
    """
    original_config = load_yaml_configuration(test_config_filename)

    with tempfile.NamedTemporaryFile(suffix=".yaml") as fp:
        # Round-trip the configuration through a temporary .yaml file
        save_yaml_configuration(original_config, fp.name)
        reloaded_config = load_yaml_configuration(fp.name)

    # What comes back must equal what was saved
    assert original_config == reloaded_config
57
+
58
+
59
def test_extra_field():
    """
    Check an unexpected extra parameter in the config causes a validation error
    """
    config_dict = Configuration().model_dump()
    config_dict["extra_field"] = "extra_value"

    with pytest.raises(ValidationError, match="Extra inputs are not permitted"):
        Configuration(**config_dict)
69
+
70
+
71
def test_incorrect_forecast_minutes(test_config_filename):
    """
    Check a forecast length not divisible by the time resolution causes an error
    """
    configuration = load_yaml_configuration(test_config_filename)
    ukv_config = configuration.input_data.nwp['ukv']
    ukv_config.forecast_minutes = 1111

    with pytest.raises(Exception, match="duration must be divisible by time resolution"):
        Configuration(**configuration.model_dump())
81
+
82
+
83
def test_incorrect_history_minutes(test_config_filename):
    """
    Check a history length not divisible by the time resolution causes an error
    """
    configuration = load_yaml_configuration(test_config_filename)
    ukv_config = configuration.input_data.nwp['ukv']
    ukv_config.history_minutes = 1111

    with pytest.raises(Exception, match="duration must be divisible by time resolution"):
        Configuration(**configuration.model_dump())
93
+
94
+
95
def test_incorrect_nwp_provider(test_config_filename):
    """
    Check an unrecognised NWP provider causes an error
    """
    configuration = load_yaml_configuration(test_config_filename)
    ukv_config = configuration.input_data.nwp['ukv']
    ukv_config.nwp_provider = "unexpected_provider"

    with pytest.raises(Exception, match="NWP provider"):
        Configuration(**configuration.model_dump())
105
+
106
def test_incorrect_dropout(test_config_filename):
    """
    Check a dropout timedelta above 0 causes an error, while 0 is accepted
    """
    configuration = load_yaml_configuration(test_config_filename)
    ukv_config = configuration.input_data.nwp['ukv']

    # A positive timedelta is rejected...
    ukv_config.dropout_timedeltas_minutes = [120]
    with pytest.raises(Exception, match="Dropout timedeltas must be negative"):
        Configuration(**configuration.model_dump())

    # ...whereas a timedelta of exactly 0 is allowed
    ukv_config.dropout_timedeltas_minutes = [0]
    Configuration(**configuration.model_dump())
121
+
122
def test_incorrect_dropout_fraction(test_config_filename):
    """
    Check a dropout fraction outside the range [0, 1] causes an error
    """
    configuration = load_yaml_configuration(test_config_filename)
    ukv_config = configuration.input_data.nwp['ukv']

    # Values just above and just below the valid range must both be rejected
    for invalid_fraction in (1.1, -0.1):
        ukv_config.dropout_fraction = invalid_fraction
        with pytest.raises(Exception, match="Dropout fraction must be between 0 and 1"):
            Configuration(**configuration.model_dump())
136
+
137
+
138
def test_inconsistent_dropout_use(test_config_filename):
    """
    Check dropout fraction and dropout timedeltas must be configured together:
    a dropout fraction > 0 requires dropout timedeltas, and dropout timedeltas
    require a dropout fraction > 0
    """
    configuration = load_yaml_configuration(test_config_filename)
    satellite_config = configuration.input_data.satellite

    # Non-zero dropout fraction but no timedeltas to drop out at
    satellite_config.dropout_fraction = 1.0
    satellite_config.dropout_timedeltas_minutes = None
    with pytest.raises(ValueError, match="To dropout fraction > 0 requires a list of dropout timedeltas"):
        Configuration(**configuration.model_dump())

    # Dropout timedeltas given but the dropout fraction is zero
    satellite_config.dropout_fraction = 0.0
    satellite_config.dropout_timedeltas_minutes = [-120, -60]
    with pytest.raises(ValueError, match="To use dropout timedeltas dropout fraction should be > 0"):
        Configuration(**configuration.model_dump())
tests/conftest.py CHANGED
@@ -6,11 +6,16 @@ import pytest
6
6
  import xarray as xr
7
7
  import tempfile
8
8
 
9
# Symlink-resolved directory containing this conftest.py
_top_test_directory = os.path.dirname(os.path.realpath(__file__))


@pytest.fixture()
def test_config_filename():
    """Path to the generic test configuration YAML in the test data directory."""
    configs_dir = f"{_top_test_directory}/test_data/configs"
    return f"{configs_dir}/test_config.yaml"
9
14
 
10
15
 
11
16
  @pytest.fixture(scope="session")
12
17
  def config_filename():
13
- return f"{os.path.dirname(os.path.abspath(__file__))}/test_data/pvnet_test_config.yaml"
18
+ return f"{os.path.dirname(os.path.abspath(__file__))}/test_data/configs/pvnet_test_config.yaml"
14
19
 
15
20
 
16
21
  @pytest.fixture(scope="session")
@@ -0,0 +1,15 @@
1
+ from ocf_data_sampler.load.gsp import open_gsp
2
+ import xarray as xr
3
+
4
+
5
def test_open_gsp(uk_gsp_zarr_path):
    """Check the GSP zarr opens as a (time_utc, gsp_id) DataArray with the expected coords"""
    da = open_gsp(uk_gsp_zarr_path)

    assert isinstance(da, xr.DataArray)
    assert da.dims == ("time_utc", "gsp_id")

    for coord_name in ("nominal_capacity_mwp", "effective_capacity_mwp", "x_osgb", "y_osgb"):
        assert coord_name in da.coords

    assert da.shape == (49, 318)