ocf-data-sampler 0.0.18__py3-none-any.whl → 0.0.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of ocf-data-sampler might be problematic.

Files changed (32)
  1. ocf_data_sampler/config/__init__.py +5 -0
  2. ocf_data_sampler/config/load.py +33 -0
  3. ocf_data_sampler/config/model.py +249 -0
  4. ocf_data_sampler/config/save.py +36 -0
  5. ocf_data_sampler/select/dropout.py +2 -2
  6. ocf_data_sampler/select/geospatial.py +118 -0
  7. ocf_data_sampler/select/location.py +62 -0
  8. ocf_data_sampler/select/select_spatial_slice.py +5 -14
  9. ocf_data_sampler/torch_datasets/pvnet_uk_regional.py +10 -5
  10. ocf_data_sampler-0.0.21.dist-info/METADATA +83 -0
  11. ocf_data_sampler-0.0.21.dist-info/RECORD +53 -0
  12. {ocf_data_sampler-0.0.18.dist-info → ocf_data_sampler-0.0.21.dist-info}/WHEEL +1 -1
  13. tests/config/test_config.py +152 -0
  14. tests/conftest.py +6 -1
  15. tests/load/test_load_gsp.py +15 -0
  16. tests/load/test_load_nwp.py +21 -0
  17. tests/load/test_load_satellite.py +17 -0
  18. tests/numpy_batch/test_gsp.py +23 -0
  19. tests/numpy_batch/test_nwp.py +54 -0
  20. tests/numpy_batch/test_satellite.py +42 -0
  21. tests/numpy_batch/test_sun_position.py +81 -0
  22. tests/select/test_dropout.py +75 -0
  23. tests/select/test_fill_time_periods.py +28 -0
  24. tests/select/test_find_contiguous_time_periods.py +202 -0
  25. tests/select/test_location.py +67 -0
  26. tests/select/test_select_spatial_slice.py +154 -0
  27. tests/select/test_select_time_slice.py +284 -0
  28. tests/torch_datasets/test_pvnet_uk_regional.py +72 -0
  29. ocf_data_sampler-0.0.18.dist-info/METADATA +0 -22
  30. ocf_data_sampler-0.0.18.dist-info/RECORD +0 -32
  31. {ocf_data_sampler-0.0.18.dist-info → ocf_data_sampler-0.0.21.dist-info}/LICENSE +0 -0
  32. {ocf_data_sampler-0.0.18.dist-info → ocf_data_sampler-0.0.21.dist-info}/top_level.txt +0 -0
ocf_data_sampler/config/__init__.py
@@ -0,0 +1,5 @@
+ """Configuration model"""
+
+ from ocf_data_sampler.config.model import Configuration
+ from ocf_data_sampler.config.save import save_yaml_configuration
+ from ocf_data_sampler.config.load import load_yaml_configuration
ocf_data_sampler/config/load.py
@@ -0,0 +1,33 @@
+ """Loading configuration functions.
+
+ Example:
+
+     from ocf_data_sampler.config import load_yaml_configuration
+     configuration = load_yaml_configuration(filename)
+ """
+
+ import fsspec
+ from pathy import Pathy
+ from pyaml_env import parse_config
+
+ from ocf_data_sampler.config import Configuration
+
+
+ def load_yaml_configuration(filename: str | Pathy) -> Configuration:
+     """
+     Load a yaml file which has a configuration in it.
+
+     Args:
+         filename: the file name that you want to load. Will load from local, AWS, or GCP
+             depending on the protocol suffix (e.g. 's3://bucket/config.yaml').
+
+     Returns: pydantic class
+     """
+     # load the file to a dictionary
+     with fsspec.open(filename, mode="r") as stream:
+         configuration = parse_config(data=stream)
+     # parse_config means we can load ENVs in the yaml file
+     # turn into pydantic class
+     configuration = Configuration(**configuration)
+     return configuration
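
For reference, a minimal usage sketch of the new loader; the local path ./config.yaml is a hypothetical example, and an 's3://' or 'gs://' path would be read through fsspec in the same way:

    from ocf_data_sampler.config import load_yaml_configuration

    # Load a YAML config from disk into the pydantic Configuration model
    config = load_yaml_configuration("./config.yaml")  # hypothetical local path
    print(config.general.name)
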
ocf_data_sampler/config/model.py
@@ -0,0 +1,249 @@
+ """Configuration model for the dataset.
+
+ All paths must include the protocol prefix. For local files,
+ it's sufficient to just start with a '/'. For AWS, start with 's3://';
+ for GCP, start with 'gs://'.
+
+ Example:
+
+     from ocf_data_sampler.config import Configuration
+     config = Configuration(**config_dict)
+ """
+
+ import logging
+ from typing import Dict, List, Optional
+ from typing_extensions import Self
+
+ from pydantic import BaseModel, Field, RootModel, field_validator, ValidationInfo, model_validator
+ from ocf_datapipes.utils.consts import NWP_PROVIDERS
+
+ logger = logging.getLogger(__name__)
+
+ providers = ["pvoutput.org", "solar_sheffield_passiv"]
+
+
+ class Base(BaseModel):
+     """Pydantic Base model where no extras can be added"""
+
+     class Config:
+         """config class"""
+
+         extra = "forbid"  # forbid use of extra kwargs
+
+
+ class General(Base):
+     """General pydantic model"""
+
+     name: str = Field("example", description="The name of this configuration file.")
+     description: str = Field(
+         "example configuration", description="Description of this configuration file"
+     )
+
+
+ class DataSourceMixin(Base):
+     """Mixin class, to add forecast and history minutes"""
+
+     forecast_minutes: int = Field(
+         ...,
+         ge=0,
+         description="How many minutes to forecast in the future.",
+     )
+     history_minutes: int = Field(
+         ...,
+         ge=0,
+         description="How many historic minutes to use.",
+     )
+
+
+ # noinspection PyMethodParameters
+ class DropoutMixin(Base):
+     """Mixin class, to add dropout minutes"""
+
+     dropout_timedeltas_minutes: Optional[List[int]] = Field(
+         default=None,
+         description="List of possible minutes before t0 where data availability may start. Must be "
+         "negative or zero.",
+     )
+
+     dropout_fraction: float = Field(0, description="Chance of dropout being applied to each sample")
+
+     @field_validator("dropout_timedeltas_minutes")
+     def dropout_timedeltas_minutes_negative(cls, v: List[int]) -> List[int]:
+         """Validate 'dropout_timedeltas_minutes'"""
+         if v is not None:
+             for m in v:
+                 assert m <= 0, "Dropout timedeltas must be negative"
+         return v
+
+     @field_validator("dropout_fraction")
+     def dropout_fraction_valid(cls, v: float) -> float:
+         """Validate 'dropout_fraction'"""
+         assert 0 <= v <= 1, "Dropout fraction must be between 0 and 1"
+         return v
+
+     @model_validator(mode="after")
+     def dropout_instructions_consistent(self) -> Self:
+         if self.dropout_fraction == 0:
+             if self.dropout_timedeltas_minutes is not None:
+                 raise ValueError("To use dropout timedeltas dropout fraction should be > 0")
+         else:
+             if self.dropout_timedeltas_minutes is None:
+                 raise ValueError("Dropout fraction > 0 requires a list of dropout timedeltas")
+         return self
+
+
+ # noinspection PyMethodParameters
+ class TimeResolutionMixin(Base):
+     """Time resolution mixin"""
+
+     time_resolution_minutes: int = Field(
+         ...,
+         description="The temporal resolution of the data in minutes",
+     )
+
+
+ class Satellite(DataSourceMixin, TimeResolutionMixin, DropoutMixin):
+     """Satellite configuration model"""
+
+     # Todo: remove 'satellite' from names
+     satellite_zarr_path: str | tuple[str] | list[str] = Field(
+         ...,
+         description="The path or list of paths which hold the satellite zarr",
+     )
+     satellite_channels: list[str] = Field(
+         ..., description="the satellite channels that are used"
+     )
+     satellite_image_size_pixels_height: int = Field(
+         ...,
+         description="The number of pixels of the height of the region of interest"
+         " for non-HRV satellite channels.",
+     )
+
+     satellite_image_size_pixels_width: int = Field(
+         ...,
+         description="The number of pixels of the width of the region "
+         "of interest for non-HRV satellite channels.",
+     )
+
+     live_delay_minutes: int = Field(
+         ..., description="The expected delay in minutes of the satellite data"
+     )
+
+
+ # noinspection PyMethodParameters
+ class NWP(DataSourceMixin, TimeResolutionMixin, DropoutMixin):
+     """NWP configuration model"""
+
+     nwp_zarr_path: str | tuple[str] | list[str] = Field(
+         ...,
+         description="The path which holds the NWP zarr",
+     )
+     nwp_channels: list[str] = Field(
+         ..., description="the channels used in the nwp data"
+     )
+     nwp_accum_channels: list[str] = Field([], description="the nwp channels which need to be diffed")
+     nwp_image_size_pixels_height: int = Field(..., description="The size of NWP spatial crop in pixels")
+     nwp_image_size_pixels_width: int = Field(..., description="The size of NWP spatial crop in pixels")
+
+     nwp_provider: str = Field(..., description="The provider of the NWP data")
+
+     max_staleness_minutes: Optional[int] = Field(
+         None,
+         description="Sets a limit on how stale an NWP init time is allowed to be whilst still being"
+         " used to construct an example. If set to None, then the max staleness is set according to"
+         " the maximum forecast horizon of the NWP and the requested forecast length.",
+     )
+
+
+     @field_validator("nwp_provider")
+     def validate_nwp_provider(cls, v: str) -> str:
+         """Validate 'nwp_provider'"""
+         if v.lower() not in NWP_PROVIDERS:
+             message = f"NWP provider {v} is not in {NWP_PROVIDERS}"
+             logger.warning(message)
+             raise Exception(message)
+         return v
+
+     # Todo: put into time mixin when moving intervals there
+     @field_validator("forecast_minutes")
+     def forecast_minutes_divide_by_time_resolution(cls, v: int, info: ValidationInfo) -> int:
+         if v % info.data["time_resolution_minutes"] != 0:
+             message = "Forecast duration must be divisible by time resolution"
+             logger.error(message)
+             raise Exception(message)
+         return v
+
+     @field_validator("history_minutes")
+     def history_minutes_divide_by_time_resolution(cls, v: int, info: ValidationInfo) -> int:
+         if v % info.data["time_resolution_minutes"] != 0:
+             message = "History duration must be divisible by time resolution"
+             logger.error(message)
+             raise Exception(message)
+         return v
+
+
+ class MultiNWP(RootModel):
+     """Configuration for multiple NWPs"""
+
+     root: Dict[str, NWP]
+
+     def __getattr__(self, item):
+         return self.root[item]
+
+     def __getitem__(self, item):
+         return self.root[item]
+
+     def __len__(self):
+         return len(self.root)
+
+     def __iter__(self):
+         return iter(self.root)
+
+     def keys(self):
+         """Returns dictionary-like keys"""
+         return self.root.keys()
+
+     def items(self):
+         """Returns dictionary-like items"""
+         return self.root.items()
+
+
+ # noinspection PyMethodParameters
+ class GSP(DataSourceMixin, TimeResolutionMixin, DropoutMixin):
+     """GSP configuration model"""
+
+     gsp_zarr_path: str = Field(..., description="The path which holds the GSP zarr")
+
+     @field_validator("forecast_minutes")
+     def forecast_minutes_divide_by_time_resolution(cls, v: int, info: ValidationInfo) -> int:
+         if v % info.data["time_resolution_minutes"] != 0:
+             message = "Forecast duration must be divisible by time resolution"
+             logger.error(message)
+             raise Exception(message)
+         return v
+
+     @field_validator("history_minutes")
+     def history_minutes_divide_by_time_resolution(cls, v: int, info: ValidationInfo) -> int:
+         if v % info.data["time_resolution_minutes"] != 0:
+             message = "History duration must be divisible by time resolution"
+             logger.error(message)
+             raise Exception(message)
+         return v
+
+
+ # noinspection PyPep8Naming
+ class InputData(Base):
+     """
+     Input data model.
+     """
+
+     satellite: Optional[Satellite] = None
+     nwp: Optional[MultiNWP] = None
+     gsp: Optional[GSP] = None
+
+
+ class Configuration(Base):
+     """Configuration model for the dataset"""
+
+     general: General = General()
+     input_data: InputData = InputData()
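
To make the shape of the new model concrete, a hedged sketch of building a Configuration from a dict; the path and durations below are invented for illustration and are not package defaults:

    from ocf_data_sampler.config import Configuration

    config = Configuration(
        **{
            "general": {"name": "demo", "description": "demo config"},
            "input_data": {
                "gsp": {
                    "gsp_zarr_path": "/data/gsp.zarr",  # hypothetical path
                    "forecast_minutes": 480,            # divisible by 30
                    "history_minutes": 120,             # divisible by 30
                    "time_resolution_minutes": 30,
                },
            },
        }
    )

Because Base sets extra = "forbid", a misspelled key raises a validation error rather than being silently ignored.
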
ocf_data_sampler/config/save.py
@@ -0,0 +1,36 @@
+ """Save functions for the configuration model.
+
+ Example:
+
+     from ocf_data_sampler.config import save_yaml_configuration
+     configuration = save_yaml_configuration(config, filename)
+ """
+
+ import json
+
+ import fsspec
+ import yaml
+ from pathy import Pathy
+
+ from ocf_data_sampler.config import Configuration
+
+
+ def save_yaml_configuration(
+     configuration: Configuration, filename: str | Pathy
+ ):
+     """
+     Save a yaml file which has the configuration in it.
+
+     If `filename` is None then saves to configuration.output_data.filepath / configuration.yaml.
+
+     Will save to GCP, AWS, or local, depending on the protocol suffix of filepath.
+     """
+     # make a dictionary from the configuration,
+     # note that we make the object json-serialisable first, so that it can be saved to a yaml file
+     d = json.loads(configuration.model_dump_json())
+     if filename is None:
+         filename = Pathy(configuration.output_data.filepath) / "configuration.yaml"
+
+     # save to a yaml file
+     with fsspec.open(filename, "w") as yaml_file:
+         yaml.safe_dump(d, yaml_file, default_flow_style=False)
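
A round-trip sketch, with local paths assumed for illustration. Note that the docstring's None-filename branch refers to configuration.output_data, which the Configuration model in this release does not define, so passing an explicit filename is the safe usage:

    from ocf_data_sampler.config import load_yaml_configuration, save_yaml_configuration

    config = load_yaml_configuration("./config.yaml")      # hypothetical path
    save_yaml_configuration(config, "./config_copy.yaml")  # writes via fsspec
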
ocf_data_sampler/select/dropout.py
@@ -12,7 +12,7 @@ def draw_dropout_time(
      if dropout_timedeltas is not None:
          assert len(dropout_timedeltas) >= 1, "Must include list of relative dropout timedeltas"
          assert all(
-             [t < pd.Timedelta("0min") for t in dropout_timedeltas]
+             [t <= pd.Timedelta("0min") for t in dropout_timedeltas]
          ), "dropout timedeltas must be negative"
      assert 0 <= dropout_frac <= 1
 
@@ -35,4 +35,4 @@ def apply_dropout_time(
          return ds
      else:
          # This replaces the times after the dropout with NaNs
-         return ds.where(ds.time_utc <= dropout_time)
+         return ds.where(ds.time_utc <= dropout_time)
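
To show what the relaxed bound changes, a small runnable sketch (data and times invented): a timedelta of exactly zero is now accepted, meaning dropout may be drawn at t0 itself.

    import numpy as np
    import pandas as pd
    import xarray as xr

    # Zero is now allowed alongside strictly negative timedeltas
    dropout_timedeltas = [pd.Timedelta("-30min"), pd.Timedelta("0min")]
    assert all([t <= pd.Timedelta("0min") for t in dropout_timedeltas])

    # The masking step: values after the drawn dropout time become NaN
    times = pd.date_range("2024-01-01 00:00", periods=4, freq="30min")
    da = xr.DataArray(np.arange(4.0), coords={"time_utc": times}, dims="time_utc")
    masked = da.where(da.time_utc <= times[1])  # last two entries -> NaN
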
ocf_data_sampler/select/geospatial.py
@@ -0,0 +1,118 @@
+ """Geospatial functions"""
+
+ from numbers import Number
+ from typing import Union
+
+ import numpy as np
+ import pyproj
+ import xarray as xr
+
+ # OSGB is also called "OSGB 1936 / British National Grid -- United
+ # Kingdom Ordnance Survey". OSGB is used in many UK electricity
+ # system maps, and is used by the UK Met Office UKV model. OSGB is a
+ # Transverse Mercator projection, using 'easting' and 'northing'
+ # coordinates which are in meters. See https://epsg.io/27700
+ OSGB36 = 27700
+
+ # WGS84 is short for "World Geodetic System 1984", used in GPS. Uses
+ # latitude and longitude.
+ WGS84 = 4326
+
+
+ _osgb_to_lon_lat = pyproj.Transformer.from_crs(
+     crs_from=OSGB36, crs_to=WGS84, always_xy=True
+ ).transform
+ _lon_lat_to_osgb = pyproj.Transformer.from_crs(
+     crs_from=WGS84, crs_to=OSGB36, always_xy=True
+ ).transform
+
+
+ def osgb_to_lon_lat(
+     x: Union[Number, np.ndarray], y: Union[Number, np.ndarray]
+ ) -> tuple[Union[Number, np.ndarray], Union[Number, np.ndarray]]:
+     """Change OSGB coordinates to lon, lat.
+
+     Args:
+         x: osgb east-west
+         y: osgb north-south
+     Return: 2-tuple of longitude (east-west), latitude (north-south)
+     """
+     return _osgb_to_lon_lat(xx=x, yy=y)
+
+
+ def lon_lat_to_osgb(
+     x: Union[Number, np.ndarray],
+     y: Union[Number, np.ndarray],
+ ) -> tuple[Union[Number, np.ndarray], Union[Number, np.ndarray]]:
+     """Change lon-lat coordinates to OSGB.
+
+     Args:
+         x: longitude east-west
+         y: latitude north-south
+
+     Return: 2-tuple of OSGB x, y
+     """
+     return _lon_lat_to_osgb(xx=x, yy=y)
+
+
+ def osgb_to_geostationary_area_coords(
+     x: Union[Number, np.ndarray],
+     y: Union[Number, np.ndarray],
+     xr_data: xr.DataArray,
+ ) -> tuple[Union[Number, np.ndarray], Union[Number, np.ndarray]]:
+     """Loads geostationary area and transformation from OSGB to geostationary coords
+
+     Args:
+         x: osgb east-west
+         y: osgb north-south
+         xr_data: xarray object with geostationary area
+
+     Returns:
+         Geostationary coords: x, y
+     """
+     # Only load these if using geostationary projection
+     import pyresample
+
+     area_definition_yaml = xr_data.attrs["area"]
+
+     geostationary_area_definition = pyresample.area_config.load_area_from_string(
+         area_definition_yaml
+     )
+     geostationary_crs = geostationary_area_definition.crs
+     osgb_to_geostationary = pyproj.Transformer.from_crs(
+         crs_from=OSGB36, crs_to=geostationary_crs, always_xy=True
+     ).transform
+     return osgb_to_geostationary(xx=x, yy=y)
+
+
+ def _coord_priority(available_coords):
+     if "longitude" in available_coords:
+         return "lon_lat", "longitude", "latitude"
+     elif "x_geostationary" in available_coords:
+         return "geostationary", "x_geostationary", "y_geostationary"
+     elif "x_osgb" in available_coords:
+         return "osgb", "x_osgb", "y_osgb"
+     else:
+         raise ValueError(f"Unrecognized coordinate system: {available_coords}")
+
+
+ def spatial_coord_type(ds: xr.DataArray):
+     """Searches the data array to determine the kind of spatial coordinates present.
+
+     This search has a preference for the dimension coordinates of the xarray object.
+
+     Args:
+         ds: Dataset with spatial coords
+
+     Returns:
+         str: The kind of the coordinate system
+         x_coord: Name of the x-coordinate
+         y_coord: Name of the y-coordinate
+     """
+     if isinstance(ds, xr.DataArray):
+         # Search dimension coords of dataarray
+         coords = _coord_priority(ds.xindexes)
+     else:
+         raise ValueError(f"Unrecognized input type: {type(ds)}")
+
+     return coords
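
A quick sketch of the round-trip helpers; the coordinates are illustrative and correspond roughly to central London:

    from ocf_data_sampler.select.geospatial import lon_lat_to_osgb, osgb_to_lon_lat

    x_osgb, y_osgb = lon_lat_to_osgb(-0.1276, 51.5072)  # lon/lat -> OSGB easting/northing in metres
    lon, lat = osgb_to_lon_lat(x_osgb, y_osgb)          # back to lon/lat
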
ocf_data_sampler/select/location.py
@@ -0,0 +1,62 @@
+ """location"""
+
+ from typing import Optional
+
+ import numpy as np
+ from pydantic import BaseModel, Field, model_validator
+
+
+ allowed_coordinate_systems = ["osgb", "lon_lat", "geostationary", "idx"]
+
+ class Location(BaseModel):
+     """Represent a spatial location."""
+
+     coordinate_system: Optional[str] = "osgb"  # ["osgb", "lon_lat", "geostationary", "idx"]
+     x: float
+     y: float
+     id: Optional[int] = Field(None)
+
+     @model_validator(mode='after')
+     def validate_coordinate_system(self):
+         """Validate 'coordinate_system'"""
+         if self.coordinate_system not in allowed_coordinate_systems:
+             raise ValueError(f"coordinate_system = {self.coordinate_system} is not in {allowed_coordinate_systems}")
+         return self
+
+     @model_validator(mode='after')
+     def validate_x(self):
+         """Validate 'x'"""
+         min_x: float
+         max_x: float
+
+         co = self.coordinate_system
+         if co == "osgb":
+             min_x, max_x = -103976.3, 652897.98
+         if co == "lon_lat":
+             min_x, max_x = -180, 180
+         if co == "geostationary":
+             min_x, max_x = -5568748.275756836, 5567248.074173927
+         if co == "idx":
+             min_x, max_x = 0, np.inf
+         if self.x < min_x or self.x > max_x:
+             raise ValueError(f"x = {self.x} must be within {[min_x, max_x]} for {co} coordinate system")
+         return self
+
+     @model_validator(mode='after')
+     def validate_y(self):
+         """Validate 'y'"""
+         min_y: float
+         max_y: float
+
+         co = self.coordinate_system
+         if co == "osgb":
+             min_y, max_y = -16703.87, 1199851.44
+         if co == "lon_lat":
+             min_y, max_y = -90, 90
+         if co == "geostationary":
+             min_y, max_y = 1393687.2151494026, 5570748.323202133
+         if co == "idx":
+             min_y, max_y = 0, np.inf
+         if self.y < min_y or self.y > max_y:
+             raise ValueError(f"y = {self.y} must be within {[min_y, max_y]} for {co} coordinate system")
+         return self
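
A short sketch of the bounds checking (values invented): construction succeeds inside the allowed range and raises outside it. Pydantic's ValidationError subclasses ValueError, so the except clause below catches it.

    from ocf_data_sampler.select.location import Location

    loc = Location(x=-0.13, y=51.5, coordinate_system="lon_lat")  # within [-180, 180] x [-90, 90]

    try:
        Location(x=200.0, y=0.0, coordinate_system="lon_lat")     # x out of range
    except ValueError as err:
        print(err)
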
ocf_data_sampler/select/select_spatial_slice.py
@@ -5,15 +5,14 @@ import logging
  import numpy as np
  import xarray as xr
 
- from ocf_datapipes.utils import Location
- from ocf_datapipes.utils.geospatial import (
-     lon_lat_to_geostationary_area_coords,
+ from ocf_data_sampler.select.location import Location
+ from ocf_data_sampler.select.geospatial import (
      lon_lat_to_osgb,
      osgb_to_geostationary_area_coords,
      osgb_to_lon_lat,
      spatial_coord_type,
  )
- from ocf_datapipes.utils.utils import searchsorted
+
 
  logger = logging.getLogger(__name__)
 
@@ -45,9 +44,6 @@ def convert_coords_to_match_xarray(
          if from_coords == "osgb":
              x, y = osgb_to_geostationary_area_coords(x, y, da)
 
-         elif from_coords == "lon_lat":
-             x, y = lon_lat_to_geostationary_area_coords(x, y, da)
-
      elif target_coords == "lon_lat":
          if from_coords == "osgb":
              x, y = osgb_to_lon_lat(x, y)
@@ -130,13 +126,8 @@ def _get_idx_of_pixel_closest_to_poi_geostationary(
          f"{y} is not in the interval {da[y_dim].min().values}: {da[y_dim].max().values}"
 
      # Get the index into x and y nearest to x_center_geostationary and y_center_geostationary:
-     x_index_at_center = searchsorted(
-         da[x_dim].values, center_geostationary.x, assume_ascending=True
-     )
-
-     y_index_at_center = searchsorted(
-         da[y_dim].values, center_geostationary.y, assume_ascending=True
-     )
+     x_index_at_center = np.searchsorted(da[x_dim].values, center_geostationary.x)
+     y_index_at_center = np.searchsorted(da[y_dim].values, center_geostationary.y)
 
      return Location(x=x_index_at_center, y=y_index_at_center, coordinate_system="idx")
 
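
The custom searchsorted helper from ocf_datapipes is swapped for np.searchsorted, which assumes the coordinate array is sorted ascending; a one-liner to illustrate the behaviour (values invented):

    import numpy as np

    x_coords = np.array([0.0, 1000.0, 2000.0, 3000.0])  # must be ascending
    idx = np.searchsorted(x_coords, 1500.0)             # -> 2, the insertion index
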
ocf_data_sampler/torch_datasets/pvnet_uk_regional.py
@@ -27,8 +27,7 @@ from ocf_data_sampler.numpy_batch import (
  )
 
 
- from ocf_datapipes.config.model import Configuration
- from ocf_datapipes.config.load import load_yaml_configuration
+ from ocf_data_sampler.config import Configuration, load_yaml_configuration
  from ocf_datapipes.batch import BatchKey, NumpyBatch
 
  from ocf_datapipes.utils.location import Location
@@ -451,8 +450,12 @@ def compute(xarray_dict: dict) -> dict:
      return xarray_dict
 
 
- def get_gsp_locations() -> list[Location]:
+ def get_gsp_locations(gsp_ids: list[int] | None = None) -> list[Location]:
      """Get list of locations of all GSPs"""
+
+     if gsp_ids is None:
+         gsp_ids = [i for i in range(1, 318)]
+
      locations = []
 
      # Load UK GSP locations
@@ -461,7 +464,7 @@ def get_gsp_locations() -> list[Location]:
          index_col="gsp_id",
      )
 
-     for gsp_id in np.arange(1, 318):
+     for gsp_id in gsp_ids:
          locations.append(
              Location(
                  coordinate_system = "osgb",
@@ -480,6 +483,7 @@ class PVNetUKRegionalDataset(Dataset):
          config_filename: str,
          start_time: str | None = None,
          end_time: str | None = None,
+         gsp_ids: list[int] | None = None,
      ):
          """A torch Dataset for creating PVNet UK GSP samples
 
@@ -487,6 +491,7 @@
              config_filename: Path to the configuration file
              start_time: Limit the init-times to be after this
              end_time: Limit the init-times to be before this
+             gsp_ids: List of GSP IDs to create samples for. Defaults to all
          """
 
          config = load_yaml_configuration(config_filename)
@@ -504,7 +509,7 @@
          valid_t0_times = valid_t0_times[valid_t0_times <= pd.Timestamp(end_time)]
 
          # Construct list of locations to sample from
-         locations = get_gsp_locations()
+         locations = get_gsp_locations(gsp_ids)
 
          # Construct a lookup for locations - useful for users to construct sample by GSP ID
          location_lookup = {loc.id: loc for loc in locations}
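
Putting the new argument together, a hedged usage sketch; the config path, dates, and IDs below are placeholders:

    from ocf_data_sampler.torch_datasets.pvnet_uk_regional import PVNetUKRegionalDataset

    # Build samples for three GSPs only, instead of all 317
    dataset = PVNetUKRegionalDataset(
        config_filename="./pvnet_config.yaml",  # hypothetical path
        start_time="2023-01-01",
        end_time="2023-01-31",
        gsp_ids=[1, 2, 3],
    )
    sample = dataset[0]  # one (t0, location) sample
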