ocf-data-sampler 0.5.11__tar.gz → 0.6.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/PKG-INFO +7 -3
- {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/README.md +5 -1
- {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/config/model.py +27 -41
- {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/load/gsp.py +4 -2
- {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/load/nwp/nwp.py +0 -1
- {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/load/nwp/providers/utils.py +1 -1
- {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/load/open_xarray_tensorstore.py +26 -7
- {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/load/satellite.py +1 -1
- ocf_data_sampler-0.6.2/ocf_data_sampler/load/site.py +73 -0
- {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/numpy_sample/nwp.py +1 -1
- ocf_data_sampler-0.6.2/ocf_data_sampler/select/diff_channels.py +25 -0
- ocf_data_sampler-0.6.2/ocf_data_sampler/select/dropout.py +59 -0
- {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/select/location.py +2 -2
- ocf_data_sampler-0.6.2/ocf_data_sampler/select/select_spatial_slice.py +110 -0
- ocf_data_sampler-0.6.2/ocf_data_sampler/select/select_time_slice.py +107 -0
- ocf_data_sampler-0.6.2/ocf_data_sampler/torch_datasets/datasets/picklecache.py +33 -0
- {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py +17 -15
- {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/torch_datasets/datasets/site.py +27 -38
- {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/torch_datasets/utils/__init__.py +2 -1
- ocf_data_sampler-0.6.2/ocf_data_sampler/torch_datasets/utils/config_normalization_values_to_dicts.py +59 -0
- ocf_data_sampler-0.6.2/ocf_data_sampler/torch_datasets/utils/diff_nwp_data.py +20 -0
- ocf_data_sampler-0.6.2/ocf_data_sampler/torch_datasets/utils/merge_and_fill_utils.py +50 -0
- {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/torch_datasets/utils/time_slice_for_dataset.py +22 -30
- ocf_data_sampler-0.5.11/ocf_data_sampler/torch_datasets/sample/base.py → ocf_data_sampler-0.6.2/ocf_data_sampler/torch_datasets/utils/torch_batch_utils.py +2 -29
- {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/utils.py +18 -6
- {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler.egg-info/PKG-INFO +7 -3
- {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler.egg-info/SOURCES.txt +4 -4
- {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/pyproject.toml +1 -1
- ocf_data_sampler-0.5.11/ocf_data_sampler/load/site.py +0 -59
- ocf_data_sampler-0.5.11/ocf_data_sampler/select/dropout.py +0 -61
- ocf_data_sampler-0.5.11/ocf_data_sampler/select/select_spatial_slice.py +0 -216
- ocf_data_sampler-0.5.11/ocf_data_sampler/select/select_time_slice.py +0 -143
- ocf_data_sampler-0.5.11/ocf_data_sampler/torch_datasets/sample/__init__.py +0 -3
- ocf_data_sampler-0.5.11/ocf_data_sampler/torch_datasets/sample/site.py +0 -48
- ocf_data_sampler-0.5.11/ocf_data_sampler/torch_datasets/sample/uk_regional.py +0 -262
- ocf_data_sampler-0.5.11/ocf_data_sampler/torch_datasets/utils/config_normalization_values_to_dicts.py +0 -57
- ocf_data_sampler-0.5.11/ocf_data_sampler/torch_datasets/utils/merge_and_fill_utils.py +0 -29
- {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/LICENSE +0 -0
- {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/__init__.py +0 -0
- {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/config/__init__.py +0 -0
- {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/config/load.py +0 -0
- {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/config/save.py +0 -0
- {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/data/uk_gsp_locations_20220314.csv +0 -0
- {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/data/uk_gsp_locations_20250109.csv +0 -0
- {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/load/__init__.py +0 -0
- {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/load/load_dataset.py +0 -0
- {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/load/nwp/__init__.py +0 -0
- {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/load/nwp/providers/__init__.py +0 -0
- {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/load/nwp/providers/cloudcasting.py +0 -0
- {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/load/nwp/providers/ecmwf.py +0 -0
- {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/load/nwp/providers/gfs.py +0 -0
- {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/load/nwp/providers/icon.py +0 -0
- {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/load/nwp/providers/ukv.py +0 -0
- {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/load/utils.py +0 -0
- {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/numpy_sample/__init__.py +0 -0
- {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/numpy_sample/collate.py +0 -0
- {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/numpy_sample/common_types.py +0 -0
- {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/numpy_sample/datetime_features.py +0 -0
- {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/numpy_sample/gsp.py +0 -0
- {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/numpy_sample/satellite.py +0 -0
- {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/numpy_sample/site.py +0 -0
- {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/numpy_sample/sun_position.py +0 -0
- {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/select/__init__.py +0 -0
- {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/select/fill_time_periods.py +0 -0
- {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/select/find_contiguous_time_periods.py +0 -0
- {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/select/geospatial.py +0 -0
- {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/torch_datasets/datasets/__init__.py +0 -0
- {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/torch_datasets/utils/add_alterate_coordinate_projections.py +0 -0
- {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/torch_datasets/utils/spatial_slice_for_dataset.py +0 -0
- {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/torch_datasets/utils/valid_time_periods.py +0 -0
- {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/torch_datasets/utils/validation_utils.py +0 -0
- {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler.egg-info/dependency_links.txt +0 -0
- {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler.egg-info/requires.txt +0 -0
- {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler.egg-info/top_level.txt +0 -0
- {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/scripts/download_gsp_location_data.py +0 -0
- {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/scripts/refactor_site.py +0 -0
- {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/setup.cfg +0 -0
- {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/tests/test_utils.py +0 -0
{ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/PKG-INFO
RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: ocf-data-sampler
-Version: 0.5.11
+Version: 0.6.2
 Author: James Fulton, Peter Dudfield
 Author-email: Open Climate Fix team <info@openclimatefix.org>
 License: MIT License
@@ -28,7 +28,7 @@ License: MIT License
 Project-URL: repository, https://github.com/openclimatefix/ocf-data-sampler
 Classifier: Programming Language :: Python :: 3
 Classifier: License :: OSI Approved :: MIT License
-Requires-Python:
+Requires-Python: <3.14,>=3.11
 Description-Content-Type: text/markdown
 Requires-Dist: torch
 Requires-Dist: numpy
@@ -50,7 +50,7 @@ Requires-Dist: zarr>=3
 # ocf-data-sampler
 
 <!-- ALL-CONTRIBUTORS-BADGE:START - Do not remove or modify this section -->
-[](#contributors-)
+[](#contributors-)
 <!-- ALL-CONTRIBUTORS-BADGE:END -->
 
 [](https://github.com/openclimatefix/ocf-data-sampler/tags)
@@ -137,6 +137,10 @@ Thanks goes to these wonderful people ([emoji key](https://allcontributors.org/d
       <td align="center" valign="top" width="14.28%"><a href="https://drona-gyawali.github.io/"><img src="https://avatars.githubusercontent.com/u/170401554?v=4?s=100" width="100px;" alt="Dorna Raj Gyawali"/><br /><sub><b>Dorna Raj Gyawali</b></sub></a><br /><a href="https://github.com/openclimatefix/ocf-data-sampler/commits?author=drona-gyawali" title="Code">💻</a></td>
       <td align="center" valign="top" width="14.28%"><a href="https://github.com/adnanhashmi25"><img src="https://avatars.githubusercontent.com/u/55550094?v=4?s=100" width="100px;" alt="Adnan Hashmi"/><br /><sub><b>Adnan Hashmi</b></sub></a><br /><a href="https://github.com/openclimatefix/ocf-data-sampler/commits?author=adnanhashmi25" title="Code">💻</a></td>
     </tr>
+    <tr>
+      <td align="center" valign="top" width="14.28%"><a href="https://github.com/utsav-pal"><img src="https://avatars.githubusercontent.com/u/159793156?v=4?s=100" width="100px;" alt="utsav-pal"/><br /><sub><b>utsav-pal</b></sub></a><br /><a href="https://github.com/openclimatefix/ocf-data-sampler/commits?author=utsav-pal" title="Code">💻</a></td>
+      <td align="center" valign="top" width="14.28%"><a href="https://github.com/zaryab-ali"><img src="https://avatars.githubusercontent.com/u/85732412?v=4?s=100" width="100px;" alt="zaryab-ali"/><br /><sub><b>zaryab-ali</b></sub></a><br /><a href="https://github.com/openclimatefix/ocf-data-sampler/commits?author=zaryab-ali" title="Code">💻</a></td>
+    </tr>
   </tbody>
 </table>
 
{ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/README.md
RENAMED
@@ -1,7 +1,7 @@
 # ocf-data-sampler
 
 <!-- ALL-CONTRIBUTORS-BADGE:START - Do not remove or modify this section -->
-[](#contributors-)
+[](#contributors-)
 <!-- ALL-CONTRIBUTORS-BADGE:END -->
 
 [](https://github.com/openclimatefix/ocf-data-sampler/tags)
@@ -88,6 +88,10 @@ Thanks goes to these wonderful people ([emoji key](https://allcontributors.org/d
       <td align="center" valign="top" width="14.28%"><a href="https://drona-gyawali.github.io/"><img src="https://avatars.githubusercontent.com/u/170401554?v=4?s=100" width="100px;" alt="Dorna Raj Gyawali"/><br /><sub><b>Dorna Raj Gyawali</b></sub></a><br /><a href="https://github.com/openclimatefix/ocf-data-sampler/commits?author=drona-gyawali" title="Code">💻</a></td>
       <td align="center" valign="top" width="14.28%"><a href="https://github.com/adnanhashmi25"><img src="https://avatars.githubusercontent.com/u/55550094?v=4?s=100" width="100px;" alt="Adnan Hashmi"/><br /><sub><b>Adnan Hashmi</b></sub></a><br /><a href="https://github.com/openclimatefix/ocf-data-sampler/commits?author=adnanhashmi25" title="Code">💻</a></td>
     </tr>
+    <tr>
+      <td align="center" valign="top" width="14.28%"><a href="https://github.com/utsav-pal"><img src="https://avatars.githubusercontent.com/u/159793156?v=4?s=100" width="100px;" alt="utsav-pal"/><br /><sub><b>utsav-pal</b></sub></a><br /><a href="https://github.com/openclimatefix/ocf-data-sampler/commits?author=utsav-pal" title="Code">💻</a></td>
+      <td align="center" valign="top" width="14.28%"><a href="https://github.com/zaryab-ali"><img src="https://avatars.githubusercontent.com/u/85732412?v=4?s=100" width="100px;" alt="zaryab-ali"/><br /><sub><b>zaryab-ali</b></sub></a><br /><a href="https://github.com/openclimatefix/ocf-data-sampler/commits?author=zaryab-ali" title="Code">💻</a></td>
+    </tr>
   </tbody>
 </table>
 
{ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/config/model.py
RENAMED
@@ -7,7 +7,7 @@ Prefix with a protocol like s3:// to read from alternative filesystems.
 from collections.abc import Iterator
 from typing import Literal
 
-from pydantic import BaseModel, Field, RootModel, field_validator, model_validator
+from pydantic import BaseModel, ConfigDict, Field, RootModel, field_validator, model_validator
 from typing_extensions import override
 
 NWP_PROVIDERS = [
@@ -23,10 +23,7 @@ NWP_PROVIDERS = [
 class Base(BaseModel):
     """Pydantic Base model where no extras can be added."""
 
-    class Config:
-        """Config class."""
-
-        extra = "forbid"  # forbid use of extra kwargs
+    model_config = ConfigDict(extra="forbid")
 
 
 class General(Base):
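The hunk above swaps the pydantic v1-style inner `Config` class for the v2 `ConfigDict` idiom. A minimal sketch of the behaviour, using hypothetical models rather than the package's own classes:

```python
# Sketch of the pydantic v2 extra="forbid" idiom used above.
# StrictBase and Example are hypothetical, not ocf-data-sampler classes.
from pydantic import BaseModel, ConfigDict, ValidationError


class StrictBase(BaseModel):
    model_config = ConfigDict(extra="forbid")


class Example(StrictBase):
    name: str


Example(name="ukv")  # fine
try:
    Example(name="ukv", unexpected=1)  # unknown kwarg is rejected
except ValidationError as err:
    print(err)
```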
@@ -90,12 +87,17 @@ class DropoutMixin(Base)
         "negative or zero.",
     )
 
-    dropout_fraction: float|list[float] = Field(
-        default=0,
+    dropout_fraction: float | list[float] = Field(
+        default=0.0,
         description="Either a float(Chance of dropout being applied to each sample) or a list of "
         "floats (probability that dropout of the corresponding timedelta is applied)",
     )
 
+    dropout_value: float = Field(
+        default=0.0,
+        description="The value to use for dropped out values. "
+        "Idea is to use -1, but to be backwards comptaible we've put the default as 0")
+
     @field_validator("dropout_timedeltas_minutes")
     def dropout_timedeltas_minutes_negative(cls, v: list[int]) -> list[int]:
         """Validate 'dropout_timedeltas_minutes'."""
@@ -106,31 +108,22 @@
 
 
     @field_validator("dropout_fraction")
-    def dropout_fractions(cls, dropout_frac: float|list[float]) -> float|list[float]:
+    def dropout_fractions(cls, dropout_frac: float | list[float]) -> float | list[float]:
         """Validate 'dropout_frac'."""
-
-        if isinstance(dropout_frac, float):
-            if not (dropout_frac <= 1):
-                raise ValueError("Input should be less than or equal to 1")
-            elif not (dropout_frac >= 0):
-                raise ValueError("Input should be greater than or equal to 0")
+        if isinstance(dropout_frac, float | int):
+            if not (0<= dropout_frac <= 1):
+                raise ValueError("Dropout fractions must be in range [0, 1]")
 
         elif isinstance(dropout_frac, list):
             if not dropout_frac:
                 raise ValueError("List cannot be empty")
 
-            if not all(isinstance(i, float) for i in dropout_frac):
-                raise ValueError("All elements in the list must be floats")
-
             if not all(0 <= i <= 1 for i in dropout_frac):
-                raise ValueError("
-
-            if not isclose(sum(dropout_frac), 1.0, rel_tol=1e-9):
-                raise ValueError("Sum of all floats in the list must be 1.0")
+                raise ValueError("All dropout fractions must be in range [0, 1]")
 
+            if not (0 <= sum(dropout_frac) <= 1):
+                raise ValueError("The sum of dropout fractions must be in range [0, 1]")
 
-        else:
-            raise TypeError("Must be either a float or a list of floats")
         return dropout_frac
 
 
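The rewritten validator relaxes the old rules: list elements no longer have to be exactly floats, and the list no longer has to sum to exactly 1.0; the sum may now be anywhere in [0, 1], with the remainder interpreted as the chance of no dropout. A standalone restatement of the new rules (a sketch, not the package's code):

```python
# Restatement of the rules the validator above enforces (sketch):
def check_dropout_fraction(frac: float | list[float]) -> None:
    if isinstance(frac, (float, int)):
        assert 0 <= frac <= 1          # scalar: a single dropout probability
    elif isinstance(frac, list):
        assert frac                    # list cannot be empty
        assert all(0 <= f <= 1 for f in frac)
        assert 0 <= sum(frac) <= 1     # remainder = probability of no dropout


check_dropout_fraction(0.1)              # ok
check_dropout_fraction([0.1, 0.2, 0.1])  # ok: sums to 0.4, so 0.6 chance of no dropout
# check_dropout_fraction([0.9, 0.9])     # fails: sum 1.8 (old rule wanted exactly 1.0)
```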
@@ -172,23 +165,6 @@
     """Normalisation constants for multiple channels."""
     normalisation_constants: dict[str, NormalisationValues]
 
-    @property
-    def channel_means(self) -> dict[str, float]:
-        """Return the channel means."""
-        return {
-            channel: norm_values.mean
-            for channel, norm_values in self.normalisation_constants.items()
-        }
-
-
-    @property
-    def channel_stds(self) -> dict[str, float]:
-        """Return the channel standard deviations."""
-        return {
-            channel: norm_values.std
-            for channel, norm_values in self.normalisation_constants.items()
-        }
-
 
 class Satellite(TimeWindowMixin, DropoutMixin, SpatialWindowMixin, NormalisationConstantsMixin):
     """Satellite configuration model."""
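The removed `channel_means`/`channel_stds` properties were thin dict comprehensions over `normalisation_constants`; per the file list, equivalent helpers now live in the new `torch_datasets/utils/config_normalization_values_to_dicts.py`. The same mapping can be written inline (a sketch; the `config.satellite` access path is illustrative, and `NormalisationValues` is assumed to expose `mean` and `std` as in the removed code):

```python
# Inline equivalent of the removed properties (sketch).
norm = config.satellite.normalisation_constants  # dict[str, NormalisationValues]
channel_means = {channel: nv.mean for channel, nv in norm.items()}
channel_stds = {channel: nv.std for channel, nv in norm.items()}
```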
@@ -363,9 +339,19 @@
     site: Site | None = None
     solar_position: SolarPosition | None = None
 
+    @model_validator(mode="after")
+    def check_site_or_gsp(self) -> "InputData":
+        """Ensure that either `site` or `gsp` is provided in the input data."""
+        if self.site is None and self.gsp is None:
+            raise ValueError(
+                "You must provide either `site` or `gsp` in the `input_data`",
+            )
+
+        return self
+
 
 class Configuration(Base):
     """Configuration model for the dataset."""
 
     general: General = General()
-    input_data: InputData = InputData
+    input_data: InputData = Field(default_factory=InputData)
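Two fixes in one hunk: `InputData` now refuses configurations with neither `site` nor `gsp`, and the `Configuration.input_data` default is built with `default_factory` instead of assigning the `InputData` class object itself as the default. A toy illustration of why the `default_factory` form is the correct one (hypothetical models, not the package's):

```python
# Why Field(default_factory=...) matters (toy models, not ocf-data-sampler's):
from pydantic import BaseModel, Field


class Inner(BaseModel):
    n: int = 0


class Outer(BaseModel):
    # inner: Inner = Inner                       # bug: the default is the class itself
    inner: Inner = Field(default_factory=Inner)  # correct: a fresh Inner() per Outer()


assert isinstance(Outer().inner, Inner)
```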
{ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/load/gsp.py
RENAMED
@@ -32,7 +32,7 @@ def open_gsp(
     boundaries_version: str = "20220314",
     public: bool = False,
 ) -> xr.DataArray:
-    """Open the GSP data and validates its data types.
+    """Open and eagerly load the GSP data and validates its data types.
 
     Args:
         zarr_path: Path to the GSP zarr data
@@ -93,4 +93,6 @@ def open_gsp(
             dtype = gsp_da.coords[coord].dtype
             raise TypeError(f"{coord} should be {expected_dtype.__name__}, not {dtype}")
 
-    return gsp_da
+    # Below we load the data eagerly into memory - this makes the dataset faster to sample from, but
+    # at the cost of a little extra memory usage
+    return gsp_da.compute()
{ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/load/nwp/providers/utils.py
RENAMED
@@ -75,7 +75,7 @@ def _tensostore_open_zarr_paths(zarr_path: str | list[str], time_dim: str) -> xr
         zarr_path = sorted(glob(zarr_path))
 
     if isinstance(zarr_path, list | tuple):
-        ds = open_zarrs(zarr_path, concat_dim=time_dim).sortby(time_dim)
+        ds = open_zarrs(zarr_path, concat_dim=time_dim, data_source="nwp").sortby(time_dim)
     else:
         ds = open_zarr(zarr_path)
     return ds
{ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/load/open_xarray_tensorstore.py
RENAMED
@@ -14,6 +14,7 @@ References:
 [2] https://www.apache.org/licenses/LICENSE-2.0
 """
 
+import logging
 import os.path
 import re
 
@@ -26,6 +27,7 @@ from xarray_tensorstore import (
     _TensorStoreAdapter,
 )
 
+logger = logging.getLogger(__name__)
 
 def _zarr_spec_from_path(path: str, zarr_format: int) -> ...:
     if re.match(r"\w+\://", path):  # path is a URI
@@ -127,6 +129,7 @@ def open_zarrs(
     concat_dim: str,
     context: ts.Context | None = None,
     mask_and_scale: bool = True,
+    data_source: str = "unknown",
 ) -> xr.Dataset:
     """Open multiple zarrs with TensorStore.
 
@@ -135,6 +138,7 @@ def open_zarrs(
         concat_dim: Dimension along which to concatenate the data variables.
         context: TensorStore context.
         mask_and_scale: Whether to mask and scale the data.
+        data_source: Which data source is being opened. Used for warning context.
 
     Returns:
         Concatenated Dataset with all data variables opened via TensorStore.
@@ -143,13 +147,28 @@ def open_zarrs(
         context = ts.Context()
 
     ds_list = [xr.open_zarr(p, mask_and_scale=mask_and_scale, decode_timedelta=True) for p in paths]
-    ds = xr.concat(
-        ds_list,
-        dim=concat_dim,
-        data_vars="minimal",
-        compat="equals",
-        combine_attrs="drop_conflicts",
-    )
+    try:
+        ds = xr.concat(
+            ds_list,
+            dim=concat_dim,
+            data_vars="minimal",
+            compat="equals",
+            combine_attrs="drop_conflicts",
+            join="exact",
+        )
+    except ValueError:
+        logger.warning(f"Coordinate mismatch found in {data_source} input data. "
+                       f"The coordinates will be overwritten! "
+                       f"This might be fine for satellite data. "
+                       f"Proceed with caution.")
+        ds = xr.concat(
+            ds_list,
+            dim=concat_dim,
+            data_vars="minimal",
+            compat="equals",
+            combine_attrs="drop_conflicts",
+            join="override",
+        )
 
     if mask_and_scale:
         _raise_if_mask_and_scale_used_for_data_vars(ds)
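The new concat logic tries a strict merge first (`join="exact"`) and only falls back to `join="override"`, which keeps the first dataset's coordinates, after logging a warning. The same strict-then-override pattern on toy in-memory datasets (a sketch; the real function concatenates TensorStore-backed zarrs):

```python
# Strict-then-override concat pattern, demonstrated on toy datasets (sketch).
import numpy as np
import xarray as xr

a = xr.Dataset({"v": ("x", np.zeros(3))}, coords={"x": [0.0, 1.0, 2.0], "t": 0})
b = xr.Dataset({"v": ("x", np.ones(3))}, coords={"x": [0.0, 1.0, 2.001], "t": 1})

try:
    ds = xr.concat([a, b], dim="t", join="exact")     # raises: x coords differ
except ValueError:
    ds = xr.concat([a, b], dim="t", join="override")  # keeps a's x coords

print(ds.x.values)  # [0. 1. 2.] - b's slightly-off coordinates were overwritten
```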
{ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/load/satellite.py
RENAMED
@@ -19,7 +19,7 @@ def open_sat_data(zarr_path: str | list[str]) -> xr.DataArray:
     """
     # Open the data
     if isinstance(zarr_path, list | tuple):
-        ds = open_zarrs(zarr_path, concat_dim="time")
+        ds = open_zarrs(zarr_path, concat_dim="time", data_source="satellite")
     else:
         ds = open_zarr(zarr_path)
 
ocf_data_sampler-0.6.2/ocf_data_sampler/load/site.py
ADDED
@@ -0,0 +1,73 @@
+"""Funcitons for loading site data."""
+
+import numpy as np
+import pandas as pd
+import xarray as xr
+
+
+def open_site(generation_file_path: str, metadata_file_path: str) -> xr.DataArray:
+    """Open a site's generation data and metadata.
+
+    Args:
+        generation_file_path: Path to the site generation netcdf data
+        metadata_file_path: Path to the site csv metadata
+
+    Returns:
+        xr.DataArray: The opened site generation data
+    """
+    generation_ds = xr.open_dataset(generation_file_path)
+    metadata_df = pd.read_csv(metadata_file_path, index_col="site_id")
+
+    if not metadata_df.index.is_unique:
+        raise ValueError("site_id is not unique in metadata")
+
+    # Ensure metadata aligns with the site_id dimension in generation_ds
+    metadata_df = metadata_df.reindex(generation_ds.site_id.values)
+
+    # Assign coordinates to the Dataset using the aligned metadata
+    # Check if variable capacity was passed with the generation data
+    # If not assign static capacity from metadata
+    if hasattr(generation_ds,"capacity_kwp"):
+        generation_ds = generation_ds.assign_coords(
+            latitude=(metadata_df.latitude.to_xarray()),
+            longitude=(metadata_df.longitude.to_xarray()),
+            capacity_kwp=generation_ds.capacity_kwp,
+        )
+    else:
+        generation_ds = generation_ds.assign_coords(
+            latitude=(metadata_df.latitude.to_xarray()),
+            longitude=(metadata_df.longitude.to_xarray()),
+            capacity_kwp=(metadata_df.capacity_kwp.to_xarray()),
+        )
+
+    # Sanity checks, to prevent inf or negative values
+    # Note NaNs are allowed in generation_kw as can have non overlapping time periods for sites
+    if np.isinf(generation_ds.generation_kw.values).all():
+        raise ValueError("generation_kw contains infinite (+/- inf) values")
+    if not (generation_ds.capacity_kwp.values > 0).all():
+        raise ValueError("capacity_kwp contains non-positive values")
+
+    site_da = generation_ds.generation_kw
+
+    # Validate data types directly in loading function
+    if not np.issubdtype(site_da.dtype, np.floating):
+        raise TypeError(f"Generation data should be float, not {site_da.dtype}")
+
+
+    coord_dtypes = {
+        "time_utc": (np.datetime64,),
+        "site_id": (np.integer,),
+        "capacity_kwp": (np.integer, np.floating),
+        "latitude": (np.floating,),
+        "longitude": (np.floating,),
+    }
+    for coord, expected_dtypes in coord_dtypes.items():
+        if not any(np.issubdtype(site_da.coords[coord].dtype, dt) for dt in expected_dtypes):
+            dtype = site_da.coords[coord].dtype
+            allowed = ", ".join(dt.__name__ for dt in expected_dtypes)
+            raise TypeError(f"{coord} should be one of ({allowed}), not {dtype}")
+
+    # Load the data eagerly into memory by calling compute
+    # this makes the dataset faster to sample from, but
+    # at the cost of a little extra memory usage
+    return site_da.compute()
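A usage sketch for the new loader (the file paths are hypothetical; the expected variables and columns follow the docstring and dtype checks above):

```python
# Usage sketch for open_site (paths hypothetical):
from ocf_data_sampler.load.site import open_site

site_da = open_site(
    generation_file_path="site_generation.nc",  # netCDF with a generation_kw variable
    metadata_file_path="site_metadata.csv",     # CSV indexed by site_id with
                                                # latitude, longitude, capacity_kwp
)
print(site_da.sizes)                   # dims include time_utc and site_id
print(site_da.coords["capacity_kwp"])  # per-site capacity attached as a coordinate
```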
{ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/numpy_sample/nwp.py
RENAMED
@@ -28,7 +28,7 @@ def convert_nwp_to_numpy_sample(da: xr.DataArray, t0_idx: int | None = None) ->
         NWPSampleKey.channel_names: da.channel.values,
         NWPSampleKey.init_time_utc: da.init_time_utc.values.astype(float),
         NWPSampleKey.step: (da.step.values / 3600).astype(int),
-        NWPSampleKey.target_time_utc: da.
+        NWPSampleKey.target_time_utc: (da.init_time_utc.values + da.step.values).astype(float),
     }
 
     if t0_idx is not None:
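The old line is truncated in this listing, but the replacement computes the valid times as init time plus step before casting to float like the neighbouring fields. The datetime arithmetic in plain numpy:

```python
# The valid-time arithmetic used in the new line, in plain numpy:
import numpy as np

init_time = np.datetime64("2024-01-01T06:00")
steps = np.array([0, 1, 2], dtype="timedelta64[h]")

target_times = init_time + steps  # broadcasts: one valid time per step
print(target_times)
# ['2024-01-01T06:00' '2024-01-01T07:00' '2024-01-01T08:00']
print(target_times.astype(float))  # the float representation stored in the sample
```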
ocf_data_sampler-0.6.2/ocf_data_sampler/select/diff_channels.py
ADDED
@@ -0,0 +1,25 @@
+"""Takes the diff along the step axis for a given set of channels."""
+
+import numpy as np
+import xarray as xr
+
+
+def diff_channels(da: xr.DataArray, accum_channels: list[str]) -> xr.DataArray:
+    """Perform in-place diff of the given channels of the DataArray in the steps dimension.
+
+    Args:
+        da: The DataArray to slice from
+        accum_channels: Channels which are accumulated and need to be differenced
+    """
+    if da.dims[:2] != ("step", "channel"):
+        raise ValueError("This function assumes the first two dimensions are step then channel")
+
+    all_channels = da.channel.values
+    accum_channel_inds = [i for i, c in enumerate(all_channels) if c in accum_channels]
+
+    # Make a copy of the values to avoid changing the underlying numpy array
+    vals = da.values.copy()
+    vals[:-1, accum_channel_inds] = np.diff(vals[:, accum_channel_inds], axis=0)
+    da.values = vals
+
+    return da.isel(step=slice(0, -1))
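What the differencing does, on a toy accumulated channel: an accumulation-since-init series becomes a per-interval series, and the final step is dropped because it has no successor to difference against. A sketch in plain numpy:

```python
# Effect of diff_channels on one accumulated channel, in plain numpy (sketch):
import numpy as np

accum = np.array([0.0, 2.0, 5.0, 9.0])  # e.g. total precipitation since init time
per_step = np.diff(accum)               # [2. 3. 4.] - amount within each interval
# diff_channels writes these into steps 0..n-2 for the selected channels only,
# leaves other channels untouched, then drops the final step for all channels.
```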
ocf_data_sampler-0.6.2/ocf_data_sampler/select/dropout.py
ADDED
@@ -0,0 +1,59 @@
+"""Functions for simulating dropout in time series data.
+
+This is used for the following types of data: GSP, Satellite and Site
+This is not used for NWP
+"""
+
+import numpy as np
+import pandas as pd
+import xarray as xr
+
+
+def apply_history_dropout(
+    t0: pd.Timestamp,
+    dropout_timedeltas: list[pd.Timedelta],
+    dropout_frac: float | list[float],
+    da: xr.DataArray,
+) -> xr.DataArray:
+    """Apply randomly sampled dropout to the historical part of some sequence data.
+
+    Dropped out data is replaced with NaNs
+
+    Args:
+        t0: The forecast init-time.
+        dropout_timedeltas: List of timedeltas relative to t0 to pick from
+        dropout_frac: The probabilit(ies) that each dropout timedelta will be applied. This should
+            be between 0 and 1 inclusive.
+        da: Xarray DataArray with 'time_utc' coordinate
+    """
+    if len(dropout_timedeltas)==0:
+        return da
+
+    if isinstance(dropout_frac, float | int):
+
+        if not (0<=dropout_frac<=1):
+            raise ValueError("`dropout_frac` must be in range [0, 1]")
+
+        # Create list with equal chance for all dropout timedeltas
+        n = len(dropout_timedeltas)
+        dropout_frac = [dropout_frac/n for _ in range(n)]
+    else:
+        if not 0<=sum(dropout_frac)<=1:
+            raise ValueError("The sum of `dropout_frac` must be in range [0, 1]")
+        if len(dropout_timedeltas)!=len(dropout_frac):
+            raise ValueError("`dropout_timedeltas` and `dropout_frac` must have the same length")
+
+        dropout_frac = [*dropout_frac]  # Make copy of the list so we can append to it
+
+    dropout_timedeltas = [*dropout_timedeltas]  # Make copy of the list so we can append to it
+
+    # Add chance of no dropout
+    dropout_frac.append(1-sum(dropout_frac))
+    dropout_timedeltas.append(None)
+
+    timedelta_choice = np.random.choice(dropout_timedeltas, p=dropout_frac)
+
+    if timedelta_choice is None:
+        return da
+    else:
+        return da.where((da.time_utc <= timedelta_choice + t0) | (da.time_utc> t0))
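A usage sketch on toy data. When the one-hour timedelta is drawn, every value at times in (t0 - 1h, t0] becomes NaN while earlier history and the choice of no dropout are untouched (the values and times here are illustrative):

```python
# Usage sketch for apply_history_dropout (toy data):
import numpy as np
import pandas as pd
import xarray as xr

from ocf_data_sampler.select.dropout import apply_history_dropout

times = pd.date_range("2024-01-01 00:00", periods=6, freq="30min")
da = xr.DataArray(np.arange(6.0), coords={"time_utc": times}, dims="time_utc")

da_dropped = apply_history_dropout(
    t0=pd.Timestamp("2024-01-01 02:30"),
    dropout_timedeltas=[pd.Timedelta("-1h")],
    dropout_frac=1.0,  # always apply, for the sake of the demonstration
    da=da,
)
# Values at 02:00 and 02:30 (inside (t0 - 1h, t0]) are now NaN.
```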
{ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/select/location.py
RENAMED
@@ -37,9 +37,9 @@
             return self._projections[coord_system]
         else:
             raise ValueError(
-                "Requested the coodinate in {coord_system}. This has not yet been added. "
+                f"Requested the coodinate in {coord_system}. This has not yet been added. "
                 "The current available coordinate systems are "
-                f"{list(self.
+                f"{list(self._projections.keys())}",
             )
 
     def add_coord_system(self, x: float, y: float, coord_system: int) -> None:
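The fix adds the missing `f` prefix so `{coord_system}` is actually interpolated (the second placeholder was already an f-string, shown truncated in the old line). The one-character difference, in isolation:

```python
# Why the added "f" prefix matters (toy demonstration):
coord_system = "osgb"
print("Requested the coordinate in {coord_system}.")   # no f: braces printed literally
print(f"Requested the coordinate in {coord_system}.")  # f-string: substitutes "osgb"
```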
ocf_data_sampler-0.6.2/ocf_data_sampler/select/select_spatial_slice.py
ADDED
@@ -0,0 +1,110 @@
+"""Select spatial slices."""
+
+import numpy as np
+import xarray as xr
+
+from ocf_data_sampler.select.geospatial import find_coord_system
+from ocf_data_sampler.select.location import Location
+
+
+def _get_pixel_index_location(da: xr.DataArray, location: Location) -> tuple[int, int]:
+    """Find pixel index location closest to given Location.
+
+    Args:
+        da: The xarray DataArray.
+        location: The Location object representing the point of interest.
+
+    Returns:
+        The pixel indices.
+
+    Raises:
+        ValueError: If the location is outside the bounds of the DataArray.
+    """
+    target_coords, x_dim, y_dim = find_coord_system(da)
+
+    x, y = location.in_coord_system(target_coords)
+
+    x_vals = da[x_dim].values
+    y_vals = da[y_dim].values
+
+    # Check that requested point lies within the data
+    if not (x_vals[0] < x < x_vals[-1]):
+        raise ValueError(
+            f"{x} is not in the interval {x_vals[0]}: {x_vals[-1]}",
+        )
+    if not (y_vals[0] < y < y_vals[-1]):
+        raise ValueError(
+            f"{y} is not in the interval {y_vals[0]}: {y_vals[-1]}",
+        )
+
+    closest_x = np.argmin(np.abs(x_vals - x))
+    closest_y = np.argmin(np.abs(y_vals - y))
+
+    return closest_x, closest_y
+
+
+def select_spatial_slice_pixels(
+    da: xr.DataArray,
+    location: Location,
+    width_pixels: int,
+    height_pixels: int,
+) -> xr.DataArray:
+    """Select spatial slice based off pixels from location point of interest.
+
+    Args:
+        da: xarray DataArray to slice from
+        location: Location of interest that will be the center of the returned slice
+        height_pixels: Height of the slice in pixels
+        width_pixels: Width of the slice in pixels
+
+    Returns:
+        The selected DataArray slice.
+
+    Raises:
+        ValueError: If the dimensions are not even or the slice is not allowed
+            when padding is required.
+    """
+    if (width_pixels % 2) != 0:
+        raise ValueError("Width must be an even number")
+    if (height_pixels % 2) != 0:
+        raise ValueError("Height must be an even number")
+
+    _, x_dim, y_dim = find_coord_system(da)
+    center_idx_x, center_idx_y = _get_pixel_index_location(da, location)
+
+    half_width = width_pixels // 2
+    half_height = height_pixels // 2
+
+    left_idx = int(center_idx_x - half_width)
+    right_idx = int(center_idx_x + half_width)
+    bottom_idx = int(center_idx_y - half_height)
+    top_idx = int(center_idx_y + half_height)
+
+    data_width_pixels = len(da[x_dim])
+    data_height_pixels = len(da[y_dim])
+
+    # Padding checks
+    slice_unavailable = (
+        left_idx < 0
+        or right_idx > data_width_pixels
+        or bottom_idx < 0
+        or top_idx > data_height_pixels
+    )
+
+    if slice_unavailable:
+        issues = []
+        if left_idx < 0:
+            issues.append(f"left_idx ({left_idx}) < 0")
+        if right_idx > data_width_pixels:
+            issues.append(f"right_idx ({right_idx}) > data_width_pixels ({data_width_pixels})")
+        if bottom_idx < 0:
+            issues.append(f"bottom_idx ({bottom_idx}) < 0")
+        if top_idx > data_height_pixels:
+            issues.append(f"top_idx ({top_idx}) > data_height_pixels ({data_height_pixels})")
+        issue_details = "\n - ".join(issues)
+        raise ValueError(f"Window for location {location} not available: \n - {issue_details}")
+
+    # Standard selection - without padding
+    da = da.isel({x_dim: slice(left_idx, right_idx), y_dim: slice(bottom_idx, top_idx)})
+
+    return da
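A usage sketch for the new slicing helper. Two assumptions here, since neither is shown in this diff: that `find_coord_system` recognises `longitude`/`latitude` dims, and that `Location` accepts x/y plus a coordinate-system name. The even-size requirement and the out-of-bounds error come straight from the code above:

```python
# Usage sketch for select_spatial_slice_pixels (assumptions noted above).
import numpy as np
import xarray as xr

from ocf_data_sampler.select.location import Location
from ocf_data_sampler.select.select_spatial_slice import select_spatial_slice_pixels

da = xr.DataArray(
    np.random.rand(100, 100),
    coords={"longitude": np.linspace(-10, 5, 100), "latitude": np.linspace(48, 60, 100)},
    dims=("longitude", "latitude"),
)

patch = select_spatial_slice_pixels(
    da=da,
    location=Location(x=-1.5, y=52.0, coord_system="lon_lat"),  # assumed signature
    width_pixels=24,   # must be even
    height_pixels=24,  # must be even
)
# patch is a 24x24-pixel window centred on the grid point nearest the location;
# if the window would overrun the data, a ValueError lists each violated bound.
```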