ocf-data-sampler 0.0.25__py3-none-any.whl → 0.0.26__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ocf-data-sampler might be problematic. Click here for more details.

@@ -14,7 +14,7 @@ import logging
14
14
  from typing import Dict, List, Optional
15
15
  from typing_extensions import Self
16
16
 
17
- from pydantic import BaseModel, Field, RootModel, field_validator, ValidationInfo, model_validator
17
+ from pydantic import BaseModel, Field, RootModel, field_validator, model_validator
18
18
  from ocf_data_sampler.constants import NWP_PROVIDERS
19
19
 
20
20
  logger = logging.getLogger(__name__)
@@ -34,27 +34,12 @@ class Base(BaseModel):
34
34
  class General(Base):
35
35
  """General pydantic model"""
36
36
 
37
- name: str = Field("example", description="The name of this configuration file.")
37
+ name: str = Field("example", description="The name of this configuration file")
38
38
  description: str = Field(
39
39
  "example configuration", description="Description of this configuration file"
40
40
  )
41
41
 
42
42
 
43
- class DataSourceMixin(Base):
44
- """Mixin class, to add forecast and history minutes"""
45
-
46
- forecast_minutes: int = Field(
47
- ...,
48
- ge=0,
49
- description="how many minutes to forecast in the future. ",
50
- )
51
- history_minutes: int = Field(
52
- ...,
53
- ge=0,
54
- description="how many historic minutes to use. ",
55
- )
56
-
57
-
58
43
  # noinspection PyMethodParameters
59
44
  class DropoutMixin(Base):
60
45
  """Mixin class, to add dropout minutes"""
@@ -65,7 +50,12 @@ class DropoutMixin(Base):
65
50
  "negative or zero.",
66
51
  )
67
52
 
68
- dropout_fraction: float = Field(0, description="Chance of dropout being applied to each sample")
53
+ dropout_fraction: float = Field(
54
+ default=0,
55
+ description="Chance of dropout being applied to each sample",
56
+ ge=0,
57
+ le=1,
58
+ )
69
59
 
70
60
  @field_validator("dropout_timedeltas_minutes")
71
61
  def dropout_timedeltas_minutes_negative(cls, v: List[int]) -> List[int]:
@@ -75,12 +65,6 @@ class DropoutMixin(Base):
75
65
  assert m <= 0, "Dropout timedeltas must be negative"
76
66
  return v
77
67
 
78
- @field_validator("dropout_fraction")
79
- def dropout_fraction_valid(cls, v: float) -> float:
80
- """Validate 'dropout_fraction'"""
81
- assert 0 <= v <= 1, "Dropout fraction must be between 0 and 1"
82
- return v
83
-
84
68
  @model_validator(mode="after")
85
69
  def dropout_instructions_consistent(self) -> Self:
86
70
  if self.dropout_fraction == 0:
@@ -93,69 +77,67 @@ class DropoutMixin(Base):
93
77
 
94
78
 
95
79
  # noinspection PyMethodParameters
96
- class TimeResolutionMixin(Base):
80
+ class TimeWindowMixin(Base):
97
81
  """Time resolution mix in"""
98
82
 
99
83
  time_resolution_minutes: int = Field(
100
84
  ...,
85
+ gt=0,
101
86
  description="The temporal resolution of the data in minutes",
102
87
  )
103
88
 
104
-
105
- class Site(DataSourceMixin, TimeResolutionMixin, DropoutMixin):
106
- """Site configuration model"""
107
-
108
- file_path: str = Field(
89
+ forecast_minutes: int = Field(
109
90
  ...,
110
- description="The NetCDF files holding the power timeseries.",
91
+ ge=0,
92
+ description="how many minutes to forecast in the future",
111
93
  )
112
- metadata_file_path: str = Field(
94
+ history_minutes: int = Field(
113
95
  ...,
114
- description="The CSV files describing power system",
96
+ ge=0,
97
+ description="how many historic minutes to use",
115
98
  )
116
99
 
117
100
  @field_validator("forecast_minutes")
118
- def forecast_minutes_divide_by_time_resolution(cls, v: int, info: ValidationInfo) -> int:
119
- """Check forecast length requested will give stable number of timesteps"""
120
- if v % info.data["time_resolution_minutes"] != 0:
101
+ def forecast_minutes_divide_by_time_resolution(cls, v, values) -> int:
102
+ if v % values.data["time_resolution_minutes"] != 0:
121
103
  message = "Forecast duration must be divisible by time resolution"
122
104
  logger.error(message)
123
105
  raise Exception(message)
124
106
  return v
125
107
 
126
108
  @field_validator("history_minutes")
127
- def history_minutes_divide_by_time_resolution(cls, v: int, info: ValidationInfo) -> int:
128
- """Check history length requested will give stable number of timesteps"""
129
- if v % info.data["time_resolution_minutes"] != 0:
109
+ def history_minutes_divide_by_time_resolution(cls, v, values) -> int:
110
+ if v % values.data["time_resolution_minutes"] != 0:
130
111
  message = "History duration must be divisible by time resolution"
131
112
  logger.error(message)
132
113
  raise Exception(message)
133
114
  return v
134
115
 
135
- # TODO validate the netcdf for sites
136
- # TODO validate the csv for metadata
137
116
 
138
- class Satellite(DataSourceMixin, TimeResolutionMixin, DropoutMixin):
139
- """Satellite configuration model"""
117
+ class SpatialWindowMixin(Base):
118
+ """Mixin class, to add path and image size"""
140
119
 
141
- # Todo: remove 'satellite' from names
142
- satellite_zarr_path: str | tuple[str] | list[str] = Field(
120
+ image_size_pixels_height: int = Field(
143
121
  ...,
144
- description="The path or list of paths which hold the satellite zarr",
122
+ description="The number of pixels of the height of the region of interest",
145
123
  )
146
- satellite_channels: list[str] = Field(
147
- ..., description="the satellite channels that are used"
148
- )
149
- satellite_image_size_pixels_height: int = Field(
124
+
125
+ image_size_pixels_width: int = Field(
150
126
  ...,
151
- description="The number of pixels of the height of the region of interest"
152
- " for non-HRV satellite channels.",
127
+ description="The number of pixels of the width of the region of interest",
153
128
  )
154
129
 
155
- satellite_image_size_pixels_width: int = Field(
130
+
131
+ class Satellite(TimeWindowMixin, DropoutMixin, SpatialWindowMixin):
132
+ """Satellite configuration model"""
133
+
134
+ zarr_path: str | tuple[str] | list[str] = Field(
156
135
  ...,
157
- description="The number of pixels of the width of the region "
158
- "of interest for non-HRV satellite channels.",
136
+ description="The path or list of paths which hold the data zarr",
137
+ )
138
+
139
+ channels: list[str] = Field(
140
+ ..., description="the satellite channels that are used"
159
141
  )
160
142
 
161
143
  live_delay_minutes: int = Field(
@@ -164,21 +146,21 @@ class Satellite(DataSourceMixin, TimeResolutionMixin, DropoutMixin):
164
146
 
165
147
 
166
148
  # noinspection PyMethodParameters
167
- class NWP(DataSourceMixin, TimeResolutionMixin, DropoutMixin):
149
+ class NWP(TimeWindowMixin, DropoutMixin, SpatialWindowMixin):
168
150
  """NWP configuration model"""
169
-
170
- nwp_zarr_path: str | tuple[str] | list[str] = Field(
151
+
152
+ zarr_path: str | tuple[str] | list[str] = Field(
171
153
  ...,
172
- description="The path which holds the NWP zarr",
154
+ description="The path or list of paths which hold the data zarr",
173
155
  )
174
- nwp_channels: list[str] = Field(
156
+
157
+ channels: list[str] = Field(
175
158
  ..., description="the channels used in the nwp data"
176
159
  )
177
- nwp_accum_channels: list[str] = Field([], description="the nwp channels which need to be diffed")
178
- nwp_image_size_pixels_height: int = Field(..., description="The size of NWP spacial crop in pixels")
179
- nwp_image_size_pixels_width: int = Field(..., description="The size of NWP spacial crop in pixels")
180
160
 
181
- nwp_provider: str = Field(..., description="The provider of the NWP data")
161
+ provider: str = Field(..., description="The provider of the NWP data")
162
+
163
+ accum_channels: list[str] = Field([], description="the nwp channels which need to be diffed")
182
164
 
183
165
  max_staleness_minutes: Optional[int] = Field(
184
166
  None,
@@ -187,33 +169,15 @@ class NWP(DataSourceMixin, TimeResolutionMixin, DropoutMixin):
187
169
  " the maximum forecast horizon of the NWP and the requested forecast length.",
188
170
  )
189
171
 
190
-
191
- @field_validator("nwp_provider")
192
- def validate_nwp_provider(cls, v: str) -> str:
193
- """Validate 'nwp_provider'"""
172
+ @field_validator("provider")
173
+ def validate_provider(cls, v: str) -> str:
174
+ """Validate 'provider'"""
194
175
  if v.lower() not in NWP_PROVIDERS:
195
176
  message = f"NWP provider {v} is not in {NWP_PROVIDERS}"
196
177
  logger.warning(message)
197
178
  raise Exception(message)
198
179
  return v
199
180
 
200
- # Todo: put into time mixin when moving intervals there
201
- @field_validator("forecast_minutes")
202
- def forecast_minutes_divide_by_time_resolution(cls, v: int, info: ValidationInfo) -> int:
203
- if v % info.data["time_resolution_minutes"] != 0:
204
- message = "Forecast duration must be divisible by time resolution"
205
- logger.error(message)
206
- raise Exception(message)
207
- return v
208
-
209
- @field_validator("history_minutes")
210
- def history_minutes_divide_by_time_resolution(cls, v: int, info: ValidationInfo) -> int:
211
- if v % info.data["time_resolution_minutes"] != 0:
212
- message = "History duration must be divisible by time resolution"
213
- logger.error(message)
214
- raise Exception(message)
215
- return v
216
-
217
181
 
218
182
  class MultiNWP(RootModel):
219
183
  """Configuration for multiple NWPs"""
@@ -241,27 +205,26 @@ class MultiNWP(RootModel):
241
205
  return self.root.items()
242
206
 
243
207
 
244
- # noinspection PyMethodParameters
245
- class GSP(DataSourceMixin, TimeResolutionMixin, DropoutMixin):
208
+ class GSP(TimeWindowMixin, DropoutMixin):
246
209
  """GSP configuration model"""
247
210
 
248
- gsp_zarr_path: str = Field(..., description="The path which holds the GSP zarr")
211
+ zarr_path: str = Field(..., description="The path which holds the GSP zarr")
249
212
 
250
- @field_validator("forecast_minutes")
251
- def forecast_minutes_divide_by_time_resolution(cls, v: int, info: ValidationInfo) -> int:
252
- if v % info.data["time_resolution_minutes"] != 0:
253
- message = "Forecast duration must be divisible by time resolution"
254
- logger.error(message)
255
- raise Exception(message)
256
- return v
257
213
 
258
- @field_validator("history_minutes")
259
- def history_minutes_divide_by_time_resolution(cls, v: int, info: ValidationInfo) -> int:
260
- if v % info.data["time_resolution_minutes"] != 0:
261
- message = "History duration must be divisible by time resolution"
262
- logger.error(message)
263
- raise Exception(message)
264
- return v
214
+ class Site(TimeWindowMixin, DropoutMixin):
215
+ """Site configuration model"""
216
+
217
+ file_path: str = Field(
218
+ ...,
219
+ description="The NetCDF files holding the power timeseries.",
220
+ )
221
+ metadata_file_path: str = Field(
222
+ ...,
223
+ description="The CSV files describing power system",
224
+ )
225
+
226
+ # TODO validate the netcdf for sites
227
+ # TODO validate the csv for metadata
265
228
 
266
229
 
267
230
  # noinspection PyPep8Naming
@@ -280,4 +243,4 @@ class Configuration(Base):
280
243
  """Configuration model for the dataset"""
281
244
 
282
245
  general: General = General()
283
- input_data: InputData = InputData()
246
+ input_data: InputData = InputData()
@@ -20,8 +20,8 @@ def get_dataset_dict(config: Configuration) -> dict[str, dict[xr.DataArray]]:
20
20
  datasets_dict = {}
21
21
 
22
22
  # Load GSP data unless the path is None
23
- if in_config.gsp and in_config.gsp.gsp_zarr_path:
24
- da_gsp = open_gsp(zarr_path=in_config.gsp.gsp_zarr_path).compute()
23
+ if in_config.gsp and in_config.gsp.zarr_path:
24
+ da_gsp = open_gsp(zarr_path=in_config.gsp.zarr_path).compute()
25
25
 
26
26
  # Remove national GSP
27
27
  datasets_dict["gsp"] = da_gsp.sel(gsp_id=slice(1, None))
@@ -32,9 +32,9 @@ def get_dataset_dict(config: Configuration) -> dict[str, dict[xr.DataArray]]:
32
32
  datasets_dict["nwp"] = {}
33
33
  for nwp_source, nwp_config in in_config.nwp.items():
34
34
 
35
- da_nwp = open_nwp(nwp_config.nwp_zarr_path, provider=nwp_config.nwp_provider)
35
+ da_nwp = open_nwp(nwp_config.zarr_path, provider=nwp_config.provider)
36
36
 
37
- da_nwp = da_nwp.sel(channel=list(nwp_config.nwp_channels))
37
+ da_nwp = da_nwp.sel(channel=list(nwp_config.channels))
38
38
 
39
39
  datasets_dict["nwp"][nwp_source] = da_nwp
40
40
 
@@ -42,9 +42,9 @@ def get_dataset_dict(config: Configuration) -> dict[str, dict[xr.DataArray]]:
42
42
  if in_config.satellite:
43
43
  sat_config = config.input_data.satellite
44
44
 
45
- da_sat = open_sat_data(sat_config.satellite_zarr_path)
45
+ da_sat = open_sat_data(sat_config.zarr_path)
46
46
 
47
- da_sat = da_sat.sel(channel=list(sat_config.satellite_channels))
47
+ da_sat = da_sat.sel(channel=list(sat_config.channels))
48
48
 
49
49
  datasets_dict["sat"] = da_sat
50
50
 
@@ -30,8 +30,8 @@ def slice_datasets_by_space(
30
30
  sliced_datasets_dict["nwp"][nwp_key] = select_spatial_slice_pixels(
31
31
  datasets_dict["nwp"][nwp_key],
32
32
  location,
33
- height_pixels=nwp_config.nwp_image_size_pixels_height,
34
- width_pixels=nwp_config.nwp_image_size_pixels_width,
33
+ height_pixels=nwp_config.image_size_pixels_height,
34
+ width_pixels=nwp_config.image_size_pixels_width,
35
35
  )
36
36
 
37
37
  if "sat" in datasets_dict:
@@ -40,8 +40,8 @@ def slice_datasets_by_space(
40
40
  sliced_datasets_dict["sat"] = select_spatial_slice_pixels(
41
41
  datasets_dict["sat"],
42
42
  location,
43
- height_pixels=sat_config.satellite_image_size_pixels_height,
44
- width_pixels=sat_config.satellite_image_size_pixels_width,
43
+ height_pixels=sat_config.image_size_pixels_height,
44
+ width_pixels=sat_config.image_size_pixels_width,
45
45
  )
46
46
 
47
47
  if "gsp" in datasets_dict:
@@ -38,7 +38,7 @@ def slice_datasets_by_time(
38
38
  forecast_duration=minutes(nwp_config.forecast_minutes),
39
39
  dropout_timedeltas=minutes(nwp_config.dropout_timedeltas_minutes),
40
40
  dropout_frac=nwp_config.dropout_fraction,
41
- accum_channels=nwp_config.nwp_accum_channels,
41
+ accum_channels=nwp_config.accum_channels,
42
42
  )
43
43
 
44
44
  if "sat" in datasets_dict:
@@ -35,7 +35,7 @@ def process_and_combine_datasets(
35
35
 
36
36
  for nwp_key, da_nwp in dataset_dict["nwp"].items():
37
37
  # Standardise
38
- provider = config.input_data.nwp[nwp_key].nwp_provider
38
+ provider = config.input_data.nwp[nwp_key].provider
39
39
  da_nwp = (da_nwp - NWP_MEANS[provider]) / NWP_STDS[provider]
40
40
  # Convert to NumpyBatch
41
41
  nwp_numpy_modalities[nwp_key] = convert_nwp_to_numpy_batch(da_nwp)
@@ -38,7 +38,7 @@ def find_valid_time_periods(
38
38
  max_staleness = minutes(nwp_config.max_staleness_minutes)
39
39
 
40
40
  # The last step of the forecast is lost if we have to diff channels
41
- if len(nwp_config.nwp_accum_channels) > 0:
41
+ if len(nwp_config.accum_channels) > 0:
42
42
  end_buffer = minutes(nwp_config.time_resolution_minutes)
43
43
  else:
44
44
  end_buffer = minutes(0)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ocf_data_sampler
3
- Version: 0.0.25
3
+ Version: 0.0.26
4
4
  Summary: Sample from weather data for renewable energy prediction
5
5
  Author: James Fulton, Peter Dudfield, and the Open Climate Fix team
6
6
  Author-email: info@openclimatefix.org
@@ -3,12 +3,12 @@ ocf_data_sampler/constants.py,sha256=tUwHrsGShqIn5Izze4i32_xB6X0v67rvQwIYB-P5PJQ
3
3
  ocf_data_sampler/time_functions.py,sha256=R6ZlVEe6h4UlJeUW7paZYAMWveOv9MTjMsoISCwnsiE,284
4
4
  ocf_data_sampler/config/__init__.py,sha256=YXnAkgHViHB26hSsjiv32b6EbpG-A1kKTkARJf0_RkY,212
5
5
  ocf_data_sampler/config/load.py,sha256=4f7vPHAIAmd-55tPxoIzn7F_TI_ue4NxkDcLPoVWl0g,943
6
- ocf_data_sampler/config/model.py,sha256=5GO8SF_4iOZhCAyIJyENSl0dnDRIWrURgqwslrVWke8,9462
6
+ ocf_data_sampler/config/model.py,sha256=YnGOzt6T835h6bozWqrlMnUIHPo26U8o-DTKAKvv_24,7121
7
7
  ocf_data_sampler/config/save.py,sha256=wKdctbv0dxIIiQtcRHLRxpWQVhEFQ_FCWg-oNaRLIps,1093
8
8
  ocf_data_sampler/data/uk_gsp_locations.csv,sha256=RSh7DRh55E3n8lVAaWXGTaXXHevZZtI58td4d4DhGos,10415772
9
9
  ocf_data_sampler/load/__init__.py,sha256=MjgfxilTzyz1RYFoBEeAXmE9hyjknLvdmlHPmlAoiQY,44
10
10
  ocf_data_sampler/load/gsp.py,sha256=Gcr1JVUOPKhFRDCSHtfPDjxx0BtyyEhXrZvGEKLPJ5I,759
11
- ocf_data_sampler/load/load_dataset.py,sha256=R4RAIVLVx6CHA6Qs61kD9sx834I_GMGAn6G7ZgwFMUA,1627
11
+ ocf_data_sampler/load/load_dataset.py,sha256=Ua3RaUg4PIYJkD9BKqTfN8IWUbezbhThJGgEkd9PcaE,1587
12
12
  ocf_data_sampler/load/satellite.py,sha256=3KlA1fx4SwxdzM-jC1WRaONXO0D6m0WxORnEnwUnZrA,2967
13
13
  ocf_data_sampler/load/site.py,sha256=ROif2XXIIgBz-JOOiHymTq1CMXswJ3AzENU9DJmYpcU,782
14
14
  ocf_data_sampler/load/utils.py,sha256=EQGvVWlGMoSOdbDYuMfVAa0v6wmAOPmHIAemdrTB5v4,1406
@@ -32,17 +32,17 @@ ocf_data_sampler/select/geospatial.py,sha256=4xL-9y674jjoaXeqE52NHCHVfknciE4OEGs
32
32
  ocf_data_sampler/select/location.py,sha256=26Y5ZjfFngShBwXieuWSoOA-RLaRzci4TTmcDk3Wg7U,2015
33
33
  ocf_data_sampler/select/select_spatial_slice.py,sha256=WNxwur9Q5oetvogATw8-hNejDuEwrXHzuZIovFDjNJA,11488
34
34
  ocf_data_sampler/select/select_time_slice.py,sha256=41cch1fQr59fZgv7UHsNGc3OvoynrixT3bmr3_1d7cU,6628
35
- ocf_data_sampler/select/spatial_slice_for_dataset.py,sha256=Nrc3j8DR5MM4BPPp9IQwaIMpoyOkc6AADMnfOjg-170,1791
36
- ocf_data_sampler/select/time_slice_for_dataset.py,sha256=A9fxvurbM0JSRkrjyg5Lr70_Mj6t5OO7HFqHUZel9q4,4220
35
+ ocf_data_sampler/select/spatial_slice_for_dataset.py,sha256=3tRrMBXr7s4CnClbVSIq7hpls3H4Y3qYTDwswcxCCCE,1763
36
+ ocf_data_sampler/select/time_slice_for_dataset.py,sha256=5gcTGgQ1D524OhullNRWq3hxCwl2SoliGR210G-62JA,4216
37
37
  ocf_data_sampler/torch_datasets/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
38
- ocf_data_sampler/torch_datasets/process_and_combine.py,sha256=Lovc2UM3-HgUy2BoQEIr0gQTz3USW6ACRWo-iTgxjHs,4993
38
+ ocf_data_sampler/torch_datasets/process_and_combine.py,sha256=GA-tGZLEMNAqX5Zun_7tPcTWVxlVtwejC9zfXPECwSk,4989
39
39
  ocf_data_sampler/torch_datasets/pvnet_uk_regional.py,sha256=TpHALGU7hpo3iLbvD0nkoY6zu94Vq99W1V1qSGEcIW8,5552
40
40
  ocf_data_sampler/torch_datasets/site.py,sha256=1k0fWXYwAAIWG5DX_j3tgNfY8gglfPGLNzNlZd8EnJs,6631
41
- ocf_data_sampler/torch_datasets/valid_time_periods.py,sha256=dNJkBH5wdsFUjoFSmthU3yTqar6OPE77WsRQUebm-PY,4163
41
+ ocf_data_sampler/torch_datasets/valid_time_periods.py,sha256=vP25e7DpWAu4dACTFMJZm0bi304iUFdi1XySAmxi_c0,4159
42
42
  scripts/refactor_site.py,sha256=asZ27hQ4IyXgCCUaFJqcz1ObBNcV2W3ywqHBpSXA_fc,1728
43
43
  tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
44
44
  tests/conftest.py,sha256=ZRktySCynj3NBbFRR4EFNLRLFMErkQsC-qQlmQzhbRg,7360
45
- tests/config/test_config.py,sha256=G_PD_pXib0zdRBPUIn0jjwJ9VyoKaO_TanLN1Mh5Ca4,5055
45
+ tests/config/test_config.py,sha256=C8NppoEVCMKxTTUf3o_z1Jb_I2DDH75XKpQ9x45U3Hw,5090
46
46
  tests/load/test_load_gsp.py,sha256=aT_nqaSXmUTcdHzuTT7AmXJr3R31k4OEN-Fv3eLxlQE,424
47
47
  tests/load/test_load_nwp.py,sha256=3qyyDkB1q9t3tyAwogfotNrxqUOpXXimco1CImoEWGg,753
48
48
  tests/load/test_load_satellite.py,sha256=STX5AqqmOAgUgE9R1xyq_sM3P1b8NKdGjO-hDhayfxM,524
@@ -57,10 +57,10 @@ tests/select/test_find_contiguous_time_periods.py,sha256=G6tJRJd0DMfH9EdfzlKWsmf
57
57
  tests/select/test_location.py,sha256=_WZk2FPYeJ-nIfCJS6Sp_yaVEEo7m31DmMFoZzgyCts,2712
58
58
  tests/select/test_select_spatial_slice.py,sha256=7EX9b6g-pMdACQx3yefjs5do2s-Rho2UmKevV4oglsU,5147
59
59
  tests/select/test_select_time_slice.py,sha256=XC1J3DBBDnt81jcba5u-Hnd0yKv8GIQErLm-OECV6rs,10147
60
- tests/torch_datasets/test_pvnet_uk_regional.py,sha256=u3taw6p3oozM0_7cEEhCYbImAQPRldRhpruqSyV08Vg,2675
61
- tests/torch_datasets/test_site.py,sha256=5hdUP64neCDWEo2NMSd-MhbpuQjQvD6NOvhZ1DlMmo8,2733
62
- ocf_data_sampler-0.0.25.dist-info/LICENSE,sha256=F-Q3UFCR-BECSocV55BFDpn4YKxve9PKrm-lTt6o_Tg,1073
63
- ocf_data_sampler-0.0.25.dist-info/METADATA,sha256=p3SKEM4gRy0Z4LTcRWlgTrpjQ-QV89ar69tM9EwhudU,5269
64
- ocf_data_sampler-0.0.25.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
65
- ocf_data_sampler-0.0.25.dist-info/top_level.txt,sha256=Faob6N6cFdPc5eUpCTYcXgCaNhi4XLLteUL5W5ayYmg,31
66
- ocf_data_sampler-0.0.25.dist-info/RECORD,,
60
+ tests/torch_datasets/test_pvnet_uk_regional.py,sha256=8gxjJO8FhY-ImX6eGnihDFsa8fhU2Zb4bVJaToJwuwo,2653
61
+ tests/torch_datasets/test_site.py,sha256=yTv6tAT6lha5yLYJiC8DNms1dct8o_ObPV97dHZyT7I,2719
62
+ ocf_data_sampler-0.0.26.dist-info/LICENSE,sha256=F-Q3UFCR-BECSocV55BFDpn4YKxve9PKrm-lTt6o_Tg,1073
63
+ ocf_data_sampler-0.0.26.dist-info/METADATA,sha256=VRnSRX4dgDbz4k9bwSM66uqaHI4P97xC97_NsEIt5qU,5269
64
+ ocf_data_sampler-0.0.26.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
65
+ ocf_data_sampler-0.0.26.dist-info/top_level.txt,sha256=Faob6N6cFdPc5eUpCTYcXgCaNhi4XLLteUL5W5ayYmg,31
66
+ ocf_data_sampler-0.0.26.dist-info/RECORD,,
@@ -10,13 +10,13 @@ from ocf_data_sampler.config import (
10
10
  )
11
11
 
12
12
 
13
- def test_default():
13
+ def test_default_configuration():
14
14
  """Test default pydantic class"""
15
15
 
16
16
  _ = Configuration()
17
17
 
18
18
 
19
- def test_yaml_load_test_config(test_config_filename):
19
+ def test_load_yaml_configuration(test_config_filename):
20
20
  """
21
21
  Test that yaml loading works for 'test_config.yaml'
22
22
  and fails for an empty .yaml file
@@ -56,7 +56,7 @@ def test_yaml_save(test_config_filename):
56
56
  assert test_config == tmp_config
57
57
 
58
58
 
59
- def test_extra_field():
59
+ def test_extra_field_error():
60
60
  """
61
61
  Check an extra parameters in config causes error
62
62
  """
@@ -99,10 +99,11 @@ def test_incorrect_nwp_provider(test_config_filename):
99
99
 
100
100
  configuration = load_yaml_configuration(test_config_filename)
101
101
 
102
- configuration.input_data.nwp['ukv'].nwp_provider = "unexpected_provider"
102
+ configuration.input_data.nwp['ukv'].provider = "unexpected_provider"
103
103
  with pytest.raises(Exception, match="NWP provider"):
104
104
  _ = Configuration(**configuration.model_dump())
105
105
 
106
+
106
107
  def test_incorrect_dropout(test_config_filename):
107
108
  """
108
109
  Check a dropout timedelta over 0 causes error and 0 doesn't
@@ -119,6 +120,7 @@ def test_incorrect_dropout(test_config_filename):
119
120
  configuration.input_data.nwp['ukv'].dropout_timedeltas_minutes = [0]
120
121
  _ = Configuration(**configuration.model_dump())
121
122
 
123
+
122
124
  def test_incorrect_dropout_fraction(test_config_filename):
123
125
  """
124
126
  Check dropout fraction outside of range causes error
@@ -127,11 +129,12 @@ def test_incorrect_dropout_fraction(test_config_filename):
127
129
  configuration = load_yaml_configuration(test_config_filename)
128
130
 
129
131
  configuration.input_data.nwp['ukv'].dropout_fraction= 1.1
130
- with pytest.raises(Exception, match="Dropout fraction must be between 0 and 1"):
132
+
133
+ with pytest.raises(ValidationError, match="Input should be less than or equal to 1"):
131
134
  _ = Configuration(**configuration.model_dump())
132
135
 
133
136
  configuration.input_data.nwp['ukv'].dropout_fraction= -0.1
134
- with pytest.raises(Exception, match="Dropout fraction must be between 0 and 1"):
137
+ with pytest.raises(ValidationError, match="Input should be greater than or equal to 0"):
135
138
  _ = Configuration(**configuration.model_dump())
136
139
 
137
140
 
@@ -11,9 +11,9 @@ def pvnet_config_filename(tmp_path, config_filename, nwp_ukv_zarr_path, uk_gsp_z
11
11
 
12
12
  # adjust config to point to the zarr file
13
13
  config = load_yaml_configuration(config_filename)
14
- config.input_data.nwp['ukv'].nwp_zarr_path = nwp_ukv_zarr_path
15
- config.input_data.satellite.satellite_zarr_path = sat_zarr_path
16
- config.input_data.gsp.gsp_zarr_path = uk_gsp_zarr_path
14
+ config.input_data.nwp['ukv'].zarr_path = nwp_ukv_zarr_path
15
+ config.input_data.satellite.zarr_path = sat_zarr_path
16
+ config.input_data.gsp.zarr_path = uk_gsp_zarr_path
17
17
 
18
18
  filename = f"{tmp_path}/configuration.yaml"
19
19
  save_yaml_configuration(config, filename)
@@ -60,7 +60,7 @@ def test_pvnet_no_gsp(pvnet_config_filename):
60
60
  # load config
61
61
  config = load_yaml_configuration(pvnet_config_filename)
62
62
  # remove gsp
63
- config.input_data.gsp.gsp_zarr_path = ''
63
+ config.input_data.gsp.zarr_path = ''
64
64
 
65
65
  # save temp config file
66
66
  with tempfile.NamedTemporaryFile() as temp_config_file:
@@ -13,8 +13,8 @@ def site_config_filename(tmp_path, config_filename, nwp_ukv_zarr_path, sat_zarr_
13
13
 
14
14
  # adjust config to point to the zarr file
15
15
  config = load_yaml_configuration(config_filename)
16
- config.input_data.nwp["ukv"].nwp_zarr_path = nwp_ukv_zarr_path
17
- config.input_data.satellite.satellite_zarr_path = sat_zarr_path
16
+ config.input_data.nwp["ukv"].zarr_path = nwp_ukv_zarr_path
17
+ config.input_data.satellite.zarr_path = sat_zarr_path
18
18
  config.input_data.site = data_sites
19
19
  config.input_data.gsp = None
20
20