ocf-data-sampler 0.0.25__tar.gz → 0.0.26__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ocf-data-sampler might be problematic. Click here for more details.
- {ocf_data_sampler-0.0.25/ocf_data_sampler.egg-info → ocf_data_sampler-0.0.26}/PKG-INFO +1 -1
- {ocf_data_sampler-0.0.25 → ocf_data_sampler-0.0.26}/ocf_data_sampler/config/model.py +66 -103
- {ocf_data_sampler-0.0.25 → ocf_data_sampler-0.0.26}/ocf_data_sampler/load/load_dataset.py +6 -6
- {ocf_data_sampler-0.0.25 → ocf_data_sampler-0.0.26}/ocf_data_sampler/select/spatial_slice_for_dataset.py +4 -4
- {ocf_data_sampler-0.0.25 → ocf_data_sampler-0.0.26}/ocf_data_sampler/select/time_slice_for_dataset.py +1 -1
- {ocf_data_sampler-0.0.25 → ocf_data_sampler-0.0.26}/ocf_data_sampler/torch_datasets/process_and_combine.py +1 -1
- {ocf_data_sampler-0.0.25 → ocf_data_sampler-0.0.26}/ocf_data_sampler/torch_datasets/valid_time_periods.py +1 -1
- {ocf_data_sampler-0.0.25 → ocf_data_sampler-0.0.26/ocf_data_sampler.egg-info}/PKG-INFO +1 -1
- {ocf_data_sampler-0.0.25 → ocf_data_sampler-0.0.26}/pyproject.toml +1 -1
- {ocf_data_sampler-0.0.25 → ocf_data_sampler-0.0.26}/tests/config/test_config.py +9 -6
- {ocf_data_sampler-0.0.25 → ocf_data_sampler-0.0.26}/tests/torch_datasets/test_pvnet_uk_regional.py +4 -4
- {ocf_data_sampler-0.0.25 → ocf_data_sampler-0.0.26}/tests/torch_datasets/test_site.py +2 -2
- {ocf_data_sampler-0.0.25 → ocf_data_sampler-0.0.26}/LICENSE +0 -0
- {ocf_data_sampler-0.0.25 → ocf_data_sampler-0.0.26}/MANIFEST.in +0 -0
- {ocf_data_sampler-0.0.25 → ocf_data_sampler-0.0.26}/README.md +0 -0
- {ocf_data_sampler-0.0.25 → ocf_data_sampler-0.0.26}/ocf_data_sampler/__init__.py +0 -0
- {ocf_data_sampler-0.0.25 → ocf_data_sampler-0.0.26}/ocf_data_sampler/config/__init__.py +0 -0
- {ocf_data_sampler-0.0.25 → ocf_data_sampler-0.0.26}/ocf_data_sampler/config/load.py +0 -0
- {ocf_data_sampler-0.0.25 → ocf_data_sampler-0.0.26}/ocf_data_sampler/config/save.py +0 -0
- {ocf_data_sampler-0.0.25 → ocf_data_sampler-0.0.26}/ocf_data_sampler/constants.py +0 -0
- {ocf_data_sampler-0.0.25 → ocf_data_sampler-0.0.26}/ocf_data_sampler/data/uk_gsp_locations.csv +0 -0
- {ocf_data_sampler-0.0.25 → ocf_data_sampler-0.0.26}/ocf_data_sampler/load/__init__.py +0 -0
- {ocf_data_sampler-0.0.25 → ocf_data_sampler-0.0.26}/ocf_data_sampler/load/gsp.py +0 -0
- {ocf_data_sampler-0.0.25 → ocf_data_sampler-0.0.26}/ocf_data_sampler/load/nwp/__init__.py +0 -0
- {ocf_data_sampler-0.0.25 → ocf_data_sampler-0.0.26}/ocf_data_sampler/load/nwp/nwp.py +0 -0
- {ocf_data_sampler-0.0.25 → ocf_data_sampler-0.0.26}/ocf_data_sampler/load/nwp/providers/__init__.py +0 -0
- {ocf_data_sampler-0.0.25 → ocf_data_sampler-0.0.26}/ocf_data_sampler/load/nwp/providers/ecmwf.py +0 -0
- {ocf_data_sampler-0.0.25 → ocf_data_sampler-0.0.26}/ocf_data_sampler/load/nwp/providers/ukv.py +0 -0
- {ocf_data_sampler-0.0.25 → ocf_data_sampler-0.0.26}/ocf_data_sampler/load/nwp/providers/utils.py +0 -0
- {ocf_data_sampler-0.0.25 → ocf_data_sampler-0.0.26}/ocf_data_sampler/load/satellite.py +0 -0
- {ocf_data_sampler-0.0.25 → ocf_data_sampler-0.0.26}/ocf_data_sampler/load/site.py +0 -0
- {ocf_data_sampler-0.0.25 → ocf_data_sampler-0.0.26}/ocf_data_sampler/load/utils.py +0 -0
- {ocf_data_sampler-0.0.25 → ocf_data_sampler-0.0.26}/ocf_data_sampler/numpy_batch/__init__.py +0 -0
- {ocf_data_sampler-0.0.25 → ocf_data_sampler-0.0.26}/ocf_data_sampler/numpy_batch/gsp.py +0 -0
- {ocf_data_sampler-0.0.25 → ocf_data_sampler-0.0.26}/ocf_data_sampler/numpy_batch/nwp.py +0 -0
- {ocf_data_sampler-0.0.25 → ocf_data_sampler-0.0.26}/ocf_data_sampler/numpy_batch/satellite.py +0 -0
- {ocf_data_sampler-0.0.25 → ocf_data_sampler-0.0.26}/ocf_data_sampler/numpy_batch/site.py +0 -0
- {ocf_data_sampler-0.0.25 → ocf_data_sampler-0.0.26}/ocf_data_sampler/numpy_batch/sun_position.py +0 -0
- {ocf_data_sampler-0.0.25 → ocf_data_sampler-0.0.26}/ocf_data_sampler/select/__init__.py +0 -0
- {ocf_data_sampler-0.0.25 → ocf_data_sampler-0.0.26}/ocf_data_sampler/select/dropout.py +0 -0
- {ocf_data_sampler-0.0.25 → ocf_data_sampler-0.0.26}/ocf_data_sampler/select/fill_time_periods.py +0 -0
- {ocf_data_sampler-0.0.25 → ocf_data_sampler-0.0.26}/ocf_data_sampler/select/find_contiguous_time_periods.py +0 -0
- {ocf_data_sampler-0.0.25 → ocf_data_sampler-0.0.26}/ocf_data_sampler/select/geospatial.py +0 -0
- {ocf_data_sampler-0.0.25 → ocf_data_sampler-0.0.26}/ocf_data_sampler/select/location.py +0 -0
- {ocf_data_sampler-0.0.25 → ocf_data_sampler-0.0.26}/ocf_data_sampler/select/select_spatial_slice.py +0 -0
- {ocf_data_sampler-0.0.25 → ocf_data_sampler-0.0.26}/ocf_data_sampler/select/select_time_slice.py +0 -0
- {ocf_data_sampler-0.0.25 → ocf_data_sampler-0.0.26}/ocf_data_sampler/time_functions.py +0 -0
- {ocf_data_sampler-0.0.25 → ocf_data_sampler-0.0.26}/ocf_data_sampler/torch_datasets/__init__.py +0 -0
- {ocf_data_sampler-0.0.25 → ocf_data_sampler-0.0.26}/ocf_data_sampler/torch_datasets/pvnet_uk_regional.py +0 -0
- {ocf_data_sampler-0.0.25 → ocf_data_sampler-0.0.26}/ocf_data_sampler/torch_datasets/site.py +0 -0
- {ocf_data_sampler-0.0.25 → ocf_data_sampler-0.0.26}/ocf_data_sampler.egg-info/SOURCES.txt +0 -0
- {ocf_data_sampler-0.0.25 → ocf_data_sampler-0.0.26}/ocf_data_sampler.egg-info/dependency_links.txt +0 -0
- {ocf_data_sampler-0.0.25 → ocf_data_sampler-0.0.26}/ocf_data_sampler.egg-info/requires.txt +0 -0
- {ocf_data_sampler-0.0.25 → ocf_data_sampler-0.0.26}/ocf_data_sampler.egg-info/top_level.txt +0 -0
- {ocf_data_sampler-0.0.25 → ocf_data_sampler-0.0.26}/scripts/refactor_site.py +0 -0
- {ocf_data_sampler-0.0.25 → ocf_data_sampler-0.0.26}/setup.cfg +0 -0
- {ocf_data_sampler-0.0.25 → ocf_data_sampler-0.0.26}/tests/__init__.py +0 -0
- {ocf_data_sampler-0.0.25 → ocf_data_sampler-0.0.26}/tests/conftest.py +0 -0
- {ocf_data_sampler-0.0.25 → ocf_data_sampler-0.0.26}/tests/load/test_load_gsp.py +0 -0
- {ocf_data_sampler-0.0.25 → ocf_data_sampler-0.0.26}/tests/load/test_load_nwp.py +0 -0
- {ocf_data_sampler-0.0.25 → ocf_data_sampler-0.0.26}/tests/load/test_load_satellite.py +0 -0
- {ocf_data_sampler-0.0.25 → ocf_data_sampler-0.0.26}/tests/load/test_load_sites.py +0 -0
- {ocf_data_sampler-0.0.25 → ocf_data_sampler-0.0.26}/tests/numpy_batch/test_gsp.py +0 -0
- {ocf_data_sampler-0.0.25 → ocf_data_sampler-0.0.26}/tests/numpy_batch/test_nwp.py +0 -0
- {ocf_data_sampler-0.0.25 → ocf_data_sampler-0.0.26}/tests/numpy_batch/test_satellite.py +0 -0
- {ocf_data_sampler-0.0.25 → ocf_data_sampler-0.0.26}/tests/numpy_batch/test_sun_position.py +0 -0
- {ocf_data_sampler-0.0.25 → ocf_data_sampler-0.0.26}/tests/select/test_dropout.py +0 -0
- {ocf_data_sampler-0.0.25 → ocf_data_sampler-0.0.26}/tests/select/test_fill_time_periods.py +0 -0
- {ocf_data_sampler-0.0.25 → ocf_data_sampler-0.0.26}/tests/select/test_find_contiguous_time_periods.py +0 -0
- {ocf_data_sampler-0.0.25 → ocf_data_sampler-0.0.26}/tests/select/test_location.py +0 -0
- {ocf_data_sampler-0.0.25 → ocf_data_sampler-0.0.26}/tests/select/test_select_spatial_slice.py +0 -0
- {ocf_data_sampler-0.0.25 → ocf_data_sampler-0.0.26}/tests/select/test_select_time_slice.py +0 -0
|
@@ -14,7 +14,7 @@ import logging
|
|
|
14
14
|
from typing import Dict, List, Optional
|
|
15
15
|
from typing_extensions import Self
|
|
16
16
|
|
|
17
|
-
from pydantic import BaseModel, Field, RootModel, field_validator,
|
|
17
|
+
from pydantic import BaseModel, Field, RootModel, field_validator, model_validator
|
|
18
18
|
from ocf_data_sampler.constants import NWP_PROVIDERS
|
|
19
19
|
|
|
20
20
|
logger = logging.getLogger(__name__)
|
|
@@ -34,27 +34,12 @@ class Base(BaseModel):
|
|
|
34
34
|
class General(Base):
|
|
35
35
|
"""General pydantic model"""
|
|
36
36
|
|
|
37
|
-
name: str = Field("example", description="The name of this configuration file
|
|
37
|
+
name: str = Field("example", description="The name of this configuration file")
|
|
38
38
|
description: str = Field(
|
|
39
39
|
"example configuration", description="Description of this configuration file"
|
|
40
40
|
)
|
|
41
41
|
|
|
42
42
|
|
|
43
|
-
class DataSourceMixin(Base):
|
|
44
|
-
"""Mixin class, to add forecast and history minutes"""
|
|
45
|
-
|
|
46
|
-
forecast_minutes: int = Field(
|
|
47
|
-
...,
|
|
48
|
-
ge=0,
|
|
49
|
-
description="how many minutes to forecast in the future. ",
|
|
50
|
-
)
|
|
51
|
-
history_minutes: int = Field(
|
|
52
|
-
...,
|
|
53
|
-
ge=0,
|
|
54
|
-
description="how many historic minutes to use. ",
|
|
55
|
-
)
|
|
56
|
-
|
|
57
|
-
|
|
58
43
|
# noinspection PyMethodParameters
|
|
59
44
|
class DropoutMixin(Base):
|
|
60
45
|
"""Mixin class, to add dropout minutes"""
|
|
@@ -65,7 +50,12 @@ class DropoutMixin(Base):
|
|
|
65
50
|
"negative or zero.",
|
|
66
51
|
)
|
|
67
52
|
|
|
68
|
-
dropout_fraction: float = Field(
|
|
53
|
+
dropout_fraction: float = Field(
|
|
54
|
+
default=0,
|
|
55
|
+
description="Chance of dropout being applied to each sample",
|
|
56
|
+
ge=0,
|
|
57
|
+
le=1,
|
|
58
|
+
)
|
|
69
59
|
|
|
70
60
|
@field_validator("dropout_timedeltas_minutes")
|
|
71
61
|
def dropout_timedeltas_minutes_negative(cls, v: List[int]) -> List[int]:
|
|
@@ -75,12 +65,6 @@ class DropoutMixin(Base):
|
|
|
75
65
|
assert m <= 0, "Dropout timedeltas must be negative"
|
|
76
66
|
return v
|
|
77
67
|
|
|
78
|
-
@field_validator("dropout_fraction")
|
|
79
|
-
def dropout_fraction_valid(cls, v: float) -> float:
|
|
80
|
-
"""Validate 'dropout_fraction'"""
|
|
81
|
-
assert 0 <= v <= 1, "Dropout fraction must be between 0 and 1"
|
|
82
|
-
return v
|
|
83
|
-
|
|
84
68
|
@model_validator(mode="after")
|
|
85
69
|
def dropout_instructions_consistent(self) -> Self:
|
|
86
70
|
if self.dropout_fraction == 0:
|
|
@@ -93,69 +77,67 @@ class DropoutMixin(Base):
|
|
|
93
77
|
|
|
94
78
|
|
|
95
79
|
# noinspection PyMethodParameters
|
|
96
|
-
class
|
|
80
|
+
class TimeWindowMixin(Base):
|
|
97
81
|
"""Time resolution mix in"""
|
|
98
82
|
|
|
99
83
|
time_resolution_minutes: int = Field(
|
|
100
84
|
...,
|
|
85
|
+
gt=0,
|
|
101
86
|
description="The temporal resolution of the data in minutes",
|
|
102
87
|
)
|
|
103
88
|
|
|
104
|
-
|
|
105
|
-
class Site(DataSourceMixin, TimeResolutionMixin, DropoutMixin):
|
|
106
|
-
"""Site configuration model"""
|
|
107
|
-
|
|
108
|
-
file_path: str = Field(
|
|
89
|
+
forecast_minutes: int = Field(
|
|
109
90
|
...,
|
|
110
|
-
|
|
91
|
+
ge=0,
|
|
92
|
+
description="how many minutes to forecast in the future",
|
|
111
93
|
)
|
|
112
|
-
|
|
94
|
+
history_minutes: int = Field(
|
|
113
95
|
...,
|
|
114
|
-
|
|
96
|
+
ge=0,
|
|
97
|
+
description="how many historic minutes to use",
|
|
115
98
|
)
|
|
116
99
|
|
|
117
100
|
@field_validator("forecast_minutes")
|
|
118
|
-
def forecast_minutes_divide_by_time_resolution(cls, v
|
|
119
|
-
|
|
120
|
-
if v % info.data["time_resolution_minutes"] != 0:
|
|
101
|
+
def forecast_minutes_divide_by_time_resolution(cls, v, values) -> int:
|
|
102
|
+
if v % values.data["time_resolution_minutes"] != 0:
|
|
121
103
|
message = "Forecast duration must be divisible by time resolution"
|
|
122
104
|
logger.error(message)
|
|
123
105
|
raise Exception(message)
|
|
124
106
|
return v
|
|
125
107
|
|
|
126
108
|
@field_validator("history_minutes")
|
|
127
|
-
def history_minutes_divide_by_time_resolution(cls, v
|
|
128
|
-
|
|
129
|
-
if v % info.data["time_resolution_minutes"] != 0:
|
|
109
|
+
def history_minutes_divide_by_time_resolution(cls, v, values) -> int:
|
|
110
|
+
if v % values.data["time_resolution_minutes"] != 0:
|
|
130
111
|
message = "History duration must be divisible by time resolution"
|
|
131
112
|
logger.error(message)
|
|
132
113
|
raise Exception(message)
|
|
133
114
|
return v
|
|
134
115
|
|
|
135
|
-
# TODO validate the netcdf for sites
|
|
136
|
-
# TODO validate the csv for metadata
|
|
137
116
|
|
|
138
|
-
class
|
|
139
|
-
"""
|
|
117
|
+
class SpatialWindowMixin(Base):
|
|
118
|
+
"""Mixin class, to add path and image size"""
|
|
140
119
|
|
|
141
|
-
|
|
142
|
-
satellite_zarr_path: str | tuple[str] | list[str] = Field(
|
|
120
|
+
image_size_pixels_height: int = Field(
|
|
143
121
|
...,
|
|
144
|
-
description="The
|
|
122
|
+
description="The number of pixels of the height of the region of interest",
|
|
145
123
|
)
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
)
|
|
149
|
-
satellite_image_size_pixels_height: int = Field(
|
|
124
|
+
|
|
125
|
+
image_size_pixels_width: int = Field(
|
|
150
126
|
...,
|
|
151
|
-
description="The number of pixels of the
|
|
152
|
-
" for non-HRV satellite channels.",
|
|
127
|
+
description="The number of pixels of the width of the region of interest",
|
|
153
128
|
)
|
|
154
129
|
|
|
155
|
-
|
|
130
|
+
|
|
131
|
+
class Satellite(TimeWindowMixin, DropoutMixin, SpatialWindowMixin):
|
|
132
|
+
"""Satellite configuration model"""
|
|
133
|
+
|
|
134
|
+
zarr_path: str | tuple[str] | list[str] = Field(
|
|
156
135
|
...,
|
|
157
|
-
description="The
|
|
158
|
-
|
|
136
|
+
description="The path or list of paths which hold the data zarr",
|
|
137
|
+
)
|
|
138
|
+
|
|
139
|
+
channels: list[str] = Field(
|
|
140
|
+
..., description="the satellite channels that are used"
|
|
159
141
|
)
|
|
160
142
|
|
|
161
143
|
live_delay_minutes: int = Field(
|
|
@@ -164,21 +146,21 @@ class Satellite(DataSourceMixin, TimeResolutionMixin, DropoutMixin):
|
|
|
164
146
|
|
|
165
147
|
|
|
166
148
|
# noinspection PyMethodParameters
|
|
167
|
-
class NWP(
|
|
149
|
+
class NWP(TimeWindowMixin, DropoutMixin, SpatialWindowMixin):
|
|
168
150
|
"""NWP configuration model"""
|
|
169
|
-
|
|
170
|
-
|
|
151
|
+
|
|
152
|
+
zarr_path: str | tuple[str] | list[str] = Field(
|
|
171
153
|
...,
|
|
172
|
-
description="The path which
|
|
154
|
+
description="The path or list of paths which hold the data zarr",
|
|
173
155
|
)
|
|
174
|
-
|
|
156
|
+
|
|
157
|
+
channels: list[str] = Field(
|
|
175
158
|
..., description="the channels used in the nwp data"
|
|
176
159
|
)
|
|
177
|
-
nwp_accum_channels: list[str] = Field([], description="the nwp channels which need to be diffed")
|
|
178
|
-
nwp_image_size_pixels_height: int = Field(..., description="The size of NWP spacial crop in pixels")
|
|
179
|
-
nwp_image_size_pixels_width: int = Field(..., description="The size of NWP spacial crop in pixels")
|
|
180
160
|
|
|
181
|
-
|
|
161
|
+
provider: str = Field(..., description="The provider of the NWP data")
|
|
162
|
+
|
|
163
|
+
accum_channels: list[str] = Field([], description="the nwp channels which need to be diffed")
|
|
182
164
|
|
|
183
165
|
max_staleness_minutes: Optional[int] = Field(
|
|
184
166
|
None,
|
|
@@ -187,33 +169,15 @@ class NWP(DataSourceMixin, TimeResolutionMixin, DropoutMixin):
|
|
|
187
169
|
" the maximum forecast horizon of the NWP and the requested forecast length.",
|
|
188
170
|
)
|
|
189
171
|
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
"""Validate 'nwp_provider'"""
|
|
172
|
+
@field_validator("provider")
|
|
173
|
+
def validate_provider(cls, v: str) -> str:
|
|
174
|
+
"""Validate 'provider'"""
|
|
194
175
|
if v.lower() not in NWP_PROVIDERS:
|
|
195
176
|
message = f"NWP provider {v} is not in {NWP_PROVIDERS}"
|
|
196
177
|
logger.warning(message)
|
|
197
178
|
raise Exception(message)
|
|
198
179
|
return v
|
|
199
180
|
|
|
200
|
-
# Todo: put into time mixin when moving intervals there
|
|
201
|
-
@field_validator("forecast_minutes")
|
|
202
|
-
def forecast_minutes_divide_by_time_resolution(cls, v: int, info: ValidationInfo) -> int:
|
|
203
|
-
if v % info.data["time_resolution_minutes"] != 0:
|
|
204
|
-
message = "Forecast duration must be divisible by time resolution"
|
|
205
|
-
logger.error(message)
|
|
206
|
-
raise Exception(message)
|
|
207
|
-
return v
|
|
208
|
-
|
|
209
|
-
@field_validator("history_minutes")
|
|
210
|
-
def history_minutes_divide_by_time_resolution(cls, v: int, info: ValidationInfo) -> int:
|
|
211
|
-
if v % info.data["time_resolution_minutes"] != 0:
|
|
212
|
-
message = "History duration must be divisible by time resolution"
|
|
213
|
-
logger.error(message)
|
|
214
|
-
raise Exception(message)
|
|
215
|
-
return v
|
|
216
|
-
|
|
217
181
|
|
|
218
182
|
class MultiNWP(RootModel):
|
|
219
183
|
"""Configuration for multiple NWPs"""
|
|
@@ -241,27 +205,26 @@ class MultiNWP(RootModel):
|
|
|
241
205
|
return self.root.items()
|
|
242
206
|
|
|
243
207
|
|
|
244
|
-
|
|
245
|
-
class GSP(DataSourceMixin, TimeResolutionMixin, DropoutMixin):
|
|
208
|
+
class GSP(TimeWindowMixin, DropoutMixin):
|
|
246
209
|
"""GSP configuration model"""
|
|
247
210
|
|
|
248
|
-
|
|
211
|
+
zarr_path: str = Field(..., description="The path which holds the GSP zarr")
|
|
249
212
|
|
|
250
|
-
@field_validator("forecast_minutes")
|
|
251
|
-
def forecast_minutes_divide_by_time_resolution(cls, v: int, info: ValidationInfo) -> int:
|
|
252
|
-
if v % info.data["time_resolution_minutes"] != 0:
|
|
253
|
-
message = "Forecast duration must be divisible by time resolution"
|
|
254
|
-
logger.error(message)
|
|
255
|
-
raise Exception(message)
|
|
256
|
-
return v
|
|
257
213
|
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
214
|
+
class Site(TimeWindowMixin, DropoutMixin):
|
|
215
|
+
"""Site configuration model"""
|
|
216
|
+
|
|
217
|
+
file_path: str = Field(
|
|
218
|
+
...,
|
|
219
|
+
description="The NetCDF files holding the power timeseries.",
|
|
220
|
+
)
|
|
221
|
+
metadata_file_path: str = Field(
|
|
222
|
+
...,
|
|
223
|
+
description="The CSV files describing power system",
|
|
224
|
+
)
|
|
225
|
+
|
|
226
|
+
# TODO validate the netcdf for sites
|
|
227
|
+
# TODO validate the csv for metadata
|
|
265
228
|
|
|
266
229
|
|
|
267
230
|
# noinspection PyPep8Naming
|
|
@@ -280,4 +243,4 @@ class Configuration(Base):
|
|
|
280
243
|
"""Configuration model for the dataset"""
|
|
281
244
|
|
|
282
245
|
general: General = General()
|
|
283
|
-
input_data: InputData = InputData()
|
|
246
|
+
input_data: InputData = InputData()
|
|
@@ -20,8 +20,8 @@ def get_dataset_dict(config: Configuration) -> dict[str, dict[xr.DataArray]]:
|
|
|
20
20
|
datasets_dict = {}
|
|
21
21
|
|
|
22
22
|
# Load GSP data unless the path is None
|
|
23
|
-
if in_config.gsp and in_config.gsp.
|
|
24
|
-
da_gsp = open_gsp(zarr_path=in_config.gsp.
|
|
23
|
+
if in_config.gsp and in_config.gsp.zarr_path:
|
|
24
|
+
da_gsp = open_gsp(zarr_path=in_config.gsp.zarr_path).compute()
|
|
25
25
|
|
|
26
26
|
# Remove national GSP
|
|
27
27
|
datasets_dict["gsp"] = da_gsp.sel(gsp_id=slice(1, None))
|
|
@@ -32,9 +32,9 @@ def get_dataset_dict(config: Configuration) -> dict[str, dict[xr.DataArray]]:
|
|
|
32
32
|
datasets_dict["nwp"] = {}
|
|
33
33
|
for nwp_source, nwp_config in in_config.nwp.items():
|
|
34
34
|
|
|
35
|
-
da_nwp = open_nwp(nwp_config.
|
|
35
|
+
da_nwp = open_nwp(nwp_config.zarr_path, provider=nwp_config.provider)
|
|
36
36
|
|
|
37
|
-
da_nwp = da_nwp.sel(channel=list(nwp_config.
|
|
37
|
+
da_nwp = da_nwp.sel(channel=list(nwp_config.channels))
|
|
38
38
|
|
|
39
39
|
datasets_dict["nwp"][nwp_source] = da_nwp
|
|
40
40
|
|
|
@@ -42,9 +42,9 @@ def get_dataset_dict(config: Configuration) -> dict[str, dict[xr.DataArray]]:
|
|
|
42
42
|
if in_config.satellite:
|
|
43
43
|
sat_config = config.input_data.satellite
|
|
44
44
|
|
|
45
|
-
da_sat = open_sat_data(sat_config.
|
|
45
|
+
da_sat = open_sat_data(sat_config.zarr_path)
|
|
46
46
|
|
|
47
|
-
da_sat = da_sat.sel(channel=list(sat_config.
|
|
47
|
+
da_sat = da_sat.sel(channel=list(sat_config.channels))
|
|
48
48
|
|
|
49
49
|
datasets_dict["sat"] = da_sat
|
|
50
50
|
|
|
@@ -30,8 +30,8 @@ def slice_datasets_by_space(
|
|
|
30
30
|
sliced_datasets_dict["nwp"][nwp_key] = select_spatial_slice_pixels(
|
|
31
31
|
datasets_dict["nwp"][nwp_key],
|
|
32
32
|
location,
|
|
33
|
-
height_pixels=nwp_config.
|
|
34
|
-
width_pixels=nwp_config.
|
|
33
|
+
height_pixels=nwp_config.image_size_pixels_height,
|
|
34
|
+
width_pixels=nwp_config.image_size_pixels_width,
|
|
35
35
|
)
|
|
36
36
|
|
|
37
37
|
if "sat" in datasets_dict:
|
|
@@ -40,8 +40,8 @@ def slice_datasets_by_space(
|
|
|
40
40
|
sliced_datasets_dict["sat"] = select_spatial_slice_pixels(
|
|
41
41
|
datasets_dict["sat"],
|
|
42
42
|
location,
|
|
43
|
-
height_pixels=sat_config.
|
|
44
|
-
width_pixels=sat_config.
|
|
43
|
+
height_pixels=sat_config.image_size_pixels_height,
|
|
44
|
+
width_pixels=sat_config.image_size_pixels_width,
|
|
45
45
|
)
|
|
46
46
|
|
|
47
47
|
if "gsp" in datasets_dict:
|
|
@@ -38,7 +38,7 @@ def slice_datasets_by_time(
|
|
|
38
38
|
forecast_duration=minutes(nwp_config.forecast_minutes),
|
|
39
39
|
dropout_timedeltas=minutes(nwp_config.dropout_timedeltas_minutes),
|
|
40
40
|
dropout_frac=nwp_config.dropout_fraction,
|
|
41
|
-
accum_channels=nwp_config.
|
|
41
|
+
accum_channels=nwp_config.accum_channels,
|
|
42
42
|
)
|
|
43
43
|
|
|
44
44
|
if "sat" in datasets_dict:
|
|
@@ -35,7 +35,7 @@ def process_and_combine_datasets(
|
|
|
35
35
|
|
|
36
36
|
for nwp_key, da_nwp in dataset_dict["nwp"].items():
|
|
37
37
|
# Standardise
|
|
38
|
-
provider = config.input_data.nwp[nwp_key].
|
|
38
|
+
provider = config.input_data.nwp[nwp_key].provider
|
|
39
39
|
da_nwp = (da_nwp - NWP_MEANS[provider]) / NWP_STDS[provider]
|
|
40
40
|
# Convert to NumpyBatch
|
|
41
41
|
nwp_numpy_modalities[nwp_key] = convert_nwp_to_numpy_batch(da_nwp)
|
|
@@ -38,7 +38,7 @@ def find_valid_time_periods(
|
|
|
38
38
|
max_staleness = minutes(nwp_config.max_staleness_minutes)
|
|
39
39
|
|
|
40
40
|
# The last step of the forecast is lost if we have to diff channels
|
|
41
|
-
if len(nwp_config.
|
|
41
|
+
if len(nwp_config.accum_channels) > 0:
|
|
42
42
|
end_buffer = minutes(nwp_config.time_resolution_minutes)
|
|
43
43
|
else:
|
|
44
44
|
end_buffer = minutes(0)
|
|
@@ -10,13 +10,13 @@ from ocf_data_sampler.config import (
|
|
|
10
10
|
)
|
|
11
11
|
|
|
12
12
|
|
|
13
|
-
def
|
|
13
|
+
def test_default_configuration():
|
|
14
14
|
"""Test default pydantic class"""
|
|
15
15
|
|
|
16
16
|
_ = Configuration()
|
|
17
17
|
|
|
18
18
|
|
|
19
|
-
def
|
|
19
|
+
def test_load_yaml_configuration(test_config_filename):
|
|
20
20
|
"""
|
|
21
21
|
Test that yaml loading works for 'test_config.yaml'
|
|
22
22
|
and fails for an empty .yaml file
|
|
@@ -56,7 +56,7 @@ def test_yaml_save(test_config_filename):
|
|
|
56
56
|
assert test_config == tmp_config
|
|
57
57
|
|
|
58
58
|
|
|
59
|
-
def
|
|
59
|
+
def test_extra_field_error():
|
|
60
60
|
"""
|
|
61
61
|
Check an extra parameters in config causes error
|
|
62
62
|
"""
|
|
@@ -99,10 +99,11 @@ def test_incorrect_nwp_provider(test_config_filename):
|
|
|
99
99
|
|
|
100
100
|
configuration = load_yaml_configuration(test_config_filename)
|
|
101
101
|
|
|
102
|
-
configuration.input_data.nwp['ukv'].
|
|
102
|
+
configuration.input_data.nwp['ukv'].provider = "unexpected_provider"
|
|
103
103
|
with pytest.raises(Exception, match="NWP provider"):
|
|
104
104
|
_ = Configuration(**configuration.model_dump())
|
|
105
105
|
|
|
106
|
+
|
|
106
107
|
def test_incorrect_dropout(test_config_filename):
|
|
107
108
|
"""
|
|
108
109
|
Check a dropout timedelta over 0 causes error and 0 doesn't
|
|
@@ -119,6 +120,7 @@ def test_incorrect_dropout(test_config_filename):
|
|
|
119
120
|
configuration.input_data.nwp['ukv'].dropout_timedeltas_minutes = [0]
|
|
120
121
|
_ = Configuration(**configuration.model_dump())
|
|
121
122
|
|
|
123
|
+
|
|
122
124
|
def test_incorrect_dropout_fraction(test_config_filename):
|
|
123
125
|
"""
|
|
124
126
|
Check dropout fraction outside of range causes error
|
|
@@ -127,11 +129,12 @@ def test_incorrect_dropout_fraction(test_config_filename):
|
|
|
127
129
|
configuration = load_yaml_configuration(test_config_filename)
|
|
128
130
|
|
|
129
131
|
configuration.input_data.nwp['ukv'].dropout_fraction= 1.1
|
|
130
|
-
|
|
132
|
+
|
|
133
|
+
with pytest.raises(ValidationError, match="Input should be less than or equal to 1"):
|
|
131
134
|
_ = Configuration(**configuration.model_dump())
|
|
132
135
|
|
|
133
136
|
configuration.input_data.nwp['ukv'].dropout_fraction= -0.1
|
|
134
|
-
with pytest.raises(
|
|
137
|
+
with pytest.raises(ValidationError, match="Input should be greater than or equal to 0"):
|
|
135
138
|
_ = Configuration(**configuration.model_dump())
|
|
136
139
|
|
|
137
140
|
|
{ocf_data_sampler-0.0.25 → ocf_data_sampler-0.0.26}/tests/torch_datasets/test_pvnet_uk_regional.py
RENAMED
|
@@ -11,9 +11,9 @@ def pvnet_config_filename(tmp_path, config_filename, nwp_ukv_zarr_path, uk_gsp_z
|
|
|
11
11
|
|
|
12
12
|
# adjust config to point to the zarr file
|
|
13
13
|
config = load_yaml_configuration(config_filename)
|
|
14
|
-
config.input_data.nwp['ukv'].
|
|
15
|
-
config.input_data.satellite.
|
|
16
|
-
config.input_data.gsp.
|
|
14
|
+
config.input_data.nwp['ukv'].zarr_path = nwp_ukv_zarr_path
|
|
15
|
+
config.input_data.satellite.zarr_path = sat_zarr_path
|
|
16
|
+
config.input_data.gsp.zarr_path = uk_gsp_zarr_path
|
|
17
17
|
|
|
18
18
|
filename = f"{tmp_path}/configuration.yaml"
|
|
19
19
|
save_yaml_configuration(config, filename)
|
|
@@ -60,7 +60,7 @@ def test_pvnet_no_gsp(pvnet_config_filename):
|
|
|
60
60
|
# load config
|
|
61
61
|
config = load_yaml_configuration(pvnet_config_filename)
|
|
62
62
|
# remove gsp
|
|
63
|
-
config.input_data.gsp.
|
|
63
|
+
config.input_data.gsp.zarr_path = ''
|
|
64
64
|
|
|
65
65
|
# save temp config file
|
|
66
66
|
with tempfile.NamedTemporaryFile() as temp_config_file:
|
|
@@ -13,8 +13,8 @@ def site_config_filename(tmp_path, config_filename, nwp_ukv_zarr_path, sat_zarr_
|
|
|
13
13
|
|
|
14
14
|
# adjust config to point to the zarr file
|
|
15
15
|
config = load_yaml_configuration(config_filename)
|
|
16
|
-
config.input_data.nwp["ukv"].
|
|
17
|
-
config.input_data.satellite.
|
|
16
|
+
config.input_data.nwp["ukv"].zarr_path = nwp_ukv_zarr_path
|
|
17
|
+
config.input_data.satellite.zarr_path = sat_zarr_path
|
|
18
18
|
config.input_data.site = data_sites
|
|
19
19
|
config.input_data.gsp = None
|
|
20
20
|
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{ocf_data_sampler-0.0.25 → ocf_data_sampler-0.0.26}/ocf_data_sampler/data/uk_gsp_locations.csv
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{ocf_data_sampler-0.0.25 → ocf_data_sampler-0.0.26}/ocf_data_sampler/load/nwp/providers/__init__.py
RENAMED
|
File without changes
|
{ocf_data_sampler-0.0.25 → ocf_data_sampler-0.0.26}/ocf_data_sampler/load/nwp/providers/ecmwf.py
RENAMED
|
File without changes
|
{ocf_data_sampler-0.0.25 → ocf_data_sampler-0.0.26}/ocf_data_sampler/load/nwp/providers/ukv.py
RENAMED
|
File without changes
|
{ocf_data_sampler-0.0.25 → ocf_data_sampler-0.0.26}/ocf_data_sampler/load/nwp/providers/utils.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{ocf_data_sampler-0.0.25 → ocf_data_sampler-0.0.26}/ocf_data_sampler/numpy_batch/__init__.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
{ocf_data_sampler-0.0.25 → ocf_data_sampler-0.0.26}/ocf_data_sampler/numpy_batch/satellite.py
RENAMED
|
File without changes
|
|
File without changes
|
{ocf_data_sampler-0.0.25 → ocf_data_sampler-0.0.26}/ocf_data_sampler/numpy_batch/sun_position.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
{ocf_data_sampler-0.0.25 → ocf_data_sampler-0.0.26}/ocf_data_sampler/select/fill_time_periods.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{ocf_data_sampler-0.0.25 → ocf_data_sampler-0.0.26}/ocf_data_sampler/select/select_spatial_slice.py
RENAMED
|
File without changes
|
{ocf_data_sampler-0.0.25 → ocf_data_sampler-0.0.26}/ocf_data_sampler/select/select_time_slice.py
RENAMED
|
File without changes
|
|
File without changes
|
{ocf_data_sampler-0.0.25 → ocf_data_sampler-0.0.26}/ocf_data_sampler/torch_datasets/__init__.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{ocf_data_sampler-0.0.25 → ocf_data_sampler-0.0.26}/ocf_data_sampler.egg-info/dependency_links.txt
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{ocf_data_sampler-0.0.25 → ocf_data_sampler-0.0.26}/tests/select/test_select_spatial_slice.py
RENAMED
|
File without changes
|
|
File without changes
|