ocf-data-sampler 0.5.2__py3-none-any.whl → 0.5.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ocf_data_sampler/load/nwp/providers/utils.py +1 -1
- ocf_data_sampler/load/open_tensorstore_zarrs.py +2 -1
- ocf_data_sampler/load/satellite.py +1 -1
- ocf_data_sampler/load/xarray_tensorstore.py +299 -0
- ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py +14 -23
- ocf_data_sampler/torch_datasets/datasets/site.py +29 -15
- ocf_data_sampler/torch_datasets/utils/__init__.py +1 -1
- ocf_data_sampler/torch_datasets/utils/config_normalization_values_to_dicts.py +57 -0
- {ocf_data_sampler-0.5.2.dist-info → ocf_data_sampler-0.5.5.dist-info}/METADATA +5 -4
- {ocf_data_sampler-0.5.2.dist-info → ocf_data_sampler-0.5.5.dist-info}/RECORD +12 -11
- ocf_data_sampler/torch_datasets/utils/channel_dict_to_dataarray.py +0 -18
- {ocf_data_sampler-0.5.2.dist-info → ocf_data_sampler-0.5.5.dist-info}/WHEEL +0 -0
- {ocf_data_sampler-0.5.2.dist-info → ocf_data_sampler-0.5.5.dist-info}/top_level.txt +0 -0
ocf_data_sampler/load/satellite.py
@@ -1,13 +1,13 @@
 """Satellite loader."""
 import numpy as np
 import xarray as xr
-from xarray_tensorstore import open_zarr

 from ocf_data_sampler.load.utils import (
     check_time_unique_increasing,
     get_xr_data_array_from_xr_dataset,
     make_spatial_coords_increasing,
 )
+from ocf_data_sampler.load.xarray_tensorstore import open_zarr

 from .open_tensorstore_zarrs import open_zarrs

ocf_data_sampler/load/xarray_tensorstore.py (new file)
@@ -0,0 +1,299 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Utilities for loading TensorStore data into Xarray.
+
+Copied from https://github.com/google-research/tensorstore/blob/main/tensorstore/xarray.py
+But we added small changes so that it works for zarr3
+https://github.com/google/xarray-tensorstore/pull/22
+"""
+from __future__ import annotations
+
+import dataclasses
+import math
+import os.path
+import re
+from typing import TypeVar
+
+import numpy as np
+import tensorstore
+import xarray
+import zarr
+from xarray.core import indexing
+
+__version__ = "0.1.5"  # keep in sync with setup.py
+
+
+Index = TypeVar("Index", int, slice, np.ndarray, None)
+XarrayData = TypeVar("XarrayData", xarray.Dataset, xarray.DataArray)
+
+
+def _numpy_to_tensorstore_index(index: Index, size: int) -> Index:
+    """Switch from NumPy to TensorStore indexing conventions."""
+    # https://google.github.io/tensorstore/python/indexing.html#differences-compared-to-numpy-indexing
+    if index is None:
+        return None
+    elif isinstance(index, int):
+        # Negative integers do not count from the end in TensorStore
+        return index + size if index < 0 else index
+    elif isinstance(index, slice):
+        start = _numpy_to_tensorstore_index(index.start, size)
+        stop = _numpy_to_tensorstore_index(index.stop, size)
+        if stop is not None:
+            # TensorStore does not allow out of bounds slicing
+            stop = min(stop, size)
+        return slice(start, stop, index.step)
+    else:
+        assert isinstance(index, np.ndarray)  # noqa S101
+        return np.where(index < 0, index + size, index)
+
+
+@dataclasses.dataclass(frozen=True)
+class _TensorStoreAdapter(indexing.ExplicitlyIndexed):
+    """TensorStore array that can be wrapped by xarray.Variable.
+
+    We use Xarray's semi-internal ExplicitlyIndexed API so that Xarray will not
+    attempt to load our array into memory as a NumPy array. In the future, this
+    should be supported by public Xarray APIs, as part of the refactor discussed
+    in: https://github.com/pydata/xarray/issues/3981
+    """
+
+    array: tensorstore.TensorStore
+    future: tensorstore.Future | None = None
+
+    @property
+    def shape(self) -> tuple[int, ...]:
+        return self.array.shape
+
+    @property
+    def dtype(self) -> np.dtype:
+        return self.array.dtype.numpy_dtype
+
+    @property
+    def ndim(self) -> int:
+        return len(self.shape)
+
+    @property
+    def size(self) -> int:
+        return math.prod(self.shape)
+
+    def __getitem__(self, key: indexing.ExplicitIndexer) -> _TensorStoreAdapter:
+        index_tuple = tuple(map(_numpy_to_tensorstore_index, key.tuple, self.shape))
+        if isinstance(key, indexing.OuterIndexer):
+            # TODO(shoyer): fix this for newer versions of Xarray.
+            # We get the error message:
+            # AttributeError: '_TensorStoreAdapter' object has no attribute 'oindex'
+            indexed = self.array.oindex[index_tuple]
+        elif isinstance(key, indexing.VectorizedIndexer):
+            indexed = self.array.vindex[index_tuple]
+        else:
+            assert isinstance(key, indexing.BasicIndexer)  # noqa S101
+            indexed = self.array[index_tuple]
+        # Translate to the origin so repeated indexing is relative to the new bounds
+        # like NumPy, not absolute like TensorStore
+        translated = indexed[tensorstore.d[:].translate_to[0]]
+        return type(self)(translated)
+
+    def __setitem__(self, key: indexing.ExplicitIndexer, value) -> None:  # noqa ANN001
+        index_tuple = tuple(map(_numpy_to_tensorstore_index, key.tuple, self.shape))
+        if isinstance(key, indexing.OuterIndexer):
+            self.array.oindex[index_tuple] = value
+        elif isinstance(key, indexing.VectorizedIndexer):
+            self.array.vindex[index_tuple] = value
+        else:
+            assert isinstance(key, indexing.BasicIndexer)  # noqa S101
+            self.array[index_tuple] = value
+        # Invalidate the future so that the next read will pick up the new value
+        object.__setattr__(self, "future", None)
+
+    # xarray>2024.02.0 uses oindex and vindex properties, which are expected to
+    # return objects whose __getitem__ method supports the appropriate form of
+    # indexing.
+    @property
+    def oindex(self) -> _TensorStoreAdapter:
+        return self
+
+    @property
+    def vindex(self) -> _TensorStoreAdapter:
+        return self
+
+    def transpose(self, order: tuple[int, ...]) -> _TensorStoreAdapter:
+        transposed = self.array[tensorstore.d[order].transpose[:]]
+        return type(self)(transposed)
+
+    def read(self) -> _TensorStoreAdapter:
+        future = self.array.read()
+        return type(self)(self.array, future)
+
+    def __array__(self, dtype: np.dtype | None = None) -> np.ndarray:  # type: ignore
+        future = self.array.read() if self.future is None else self.future
+        return np.asarray(future.result(), dtype=dtype)
+
+    def get_duck_array(self) -> np.ndarray:
+        # special method for xarray to return an in-memory (computed) representation
+        return np.asarray(self)
+
+    # Work around the missing __copy__ and __deepcopy__ methods from TensorStore,
+    # which are needed for Xarray:
+    # https://github.com/google/tensorstore/issues/109
+    # TensorStore objects are immutable, so there's no need to actually copy them.
+
+    def __copy__(self) -> _TensorStoreAdapter:
+        return type(self)(self.array, self.future)
+
+    def __deepcopy__(self, memo) -> _TensorStoreAdapter:  # noqa ANN001
+        return self.__copy__()
+
+
+def _read_tensorstore(
+    array: indexing.ExplicitlyIndexed,
+) -> indexing.ExplicitlyIndexed:
+    """Starts async reading on a TensorStore array."""
+    return array.read() if isinstance(array, _TensorStoreAdapter) else array
+
+
+def read(xarraydata: XarrayData, /) -> XarrayData:
+    """Starts async reads on all TensorStore arrays."""
+    # pylint: disable=protected-access
+    if isinstance(xarraydata, xarray.Dataset):
+        data = {
+            name: _read_tensorstore(var.variable._data)
+            for name, var in xarraydata.data_vars.items()
+        }
+    elif isinstance(xarraydata, xarray.DataArray):
+        data = _read_tensorstore(xarraydata.variable._data)
+    else:
+        raise TypeError(f"argument is not a DataArray or Dataset: {xarraydata}")
+    # pylint: enable=protected-access
+    return xarraydata.copy(data=data)
+
+
+_DEFAULT_STORAGE_DRIVER = "file"
+
+
+def _zarr_spec_from_path(path: str, zarr_format: int) -> ...:
+    if re.match(r"\w+\://", path):  # path is a URI
+        kv_store = path
+    else:
+        kv_store = {"driver": _DEFAULT_STORAGE_DRIVER, "path": path}
+
+    if zarr_format == 2:
+        return {"driver": "zarr2", "kvstore": kv_store}
+    else:
+        return {"driver": "zarr3", "kvstore": kv_store}
+
+
+def _raise_if_mask_and_scale_used_for_data_vars(ds: xarray.Dataset) -> None:
+    """Check a dataset for data variables that would need masking or scaling."""
+    advice = (
+        "Consider re-opening with xarray_tensorstore.open_zarr(..., "
+        "mask_and_scale=False), or falling back to use xarray.open_zarr()."
+    )
+    for k in ds:
+        encoding = ds[k].encoding
+        for attr in ["_FillValue", "missing_value"]:
+            fill_value = encoding.get(attr, np.nan)
+            if fill_value == fill_value:  # pylint: disable=comparison-with-itself
+                raise ValueError(
+                    f"variable {k} has non-NaN fill value, which is not supported by"
+                    f" xarray-tensorstore: {fill_value}. {advice}",
+                )
+        for attr in ["scale_factor", "add_offset"]:
+            if attr in encoding:
+                raise ValueError(
+                    f"variable {k} uses scale/offset encoding, which is not supported"
+                    f" by xarray-tensorstore: {encoding}. {advice}",
+                )
+
+
+def open_zarr(
+    path: str,
+    *,
+    context: tensorstore.Context | None = None,
+    mask_and_scale: bool = True,
+    write: bool = False,
+) -> xarray.Dataset:
+    """Open an xarray.Dataset from Zarr using TensorStore.
+
+    For best performance, explicitly call `read()` to asynchronously load data
+    in parallel. Otherwise, xarray's `.compute()` method will load each variable's
+    data in sequence.
+
+    Example usage:
+
+        import xarray_tensorstore
+
+        ds = xarray_tensorstore.open_zarr(path)
+
+        # indexing & transposing is lazy
+        example = ds.sel(time='2020-01-01').transpose('longitude', 'latitude', ...)
+
+        # start reading data asynchronously
+        read_example = xarray_tensorstore.read(example)
+
+        # blocking conversion of the data into NumPy arrays
+        numpy_example = read_example.compute()
+
+    Args:
+        path: path or URI to Zarr group to open.
+        context: TensorStore configuration options to use when opening arrays.
+        mask_and_scale: if True (default), attempt to apply masking and scaling like
+            xarray.open_zarr(). This is only supported for coordinate variables and
+            otherwise will raise an error.
+        write: Allow write access. Defaults to False.
+
+    Returns:
+        Dataset with all data variables opened via TensorStore.
+    """
+    # We use xarray.open_zarr (which uses Zarr Python internally) to open the
+    # initial version of the dataset for a few reasons:
+    # 1. TensorStore does not support Zarr groups or array attributes, which we
+    #    need to open in the xarray.Dataset. We use Zarr Python instead of
+    #    parsing the raw Zarr metadata files ourselves.
+    # 2. TensorStore doesn't support non-standard Zarr dtypes like UTF-8 strings.
+    # 3. Xarray's open_zarr machinery does some pre-processing (e.g., from numeric
+    #    to datetime64 dtypes) that we would otherwise need to invoke explicitly
+    #    via xarray.decode_cf().
+    #
+    # Fortunately (2) and (3) are most commonly encountered on small coordinate
+    # arrays, for which the performance advantages of TensorStore are irrelevant.
+
+    if context is None:
+        context = tensorstore.Context()
+
+    # chunks=None means avoid using dask
+    ds = xarray.open_zarr(path, chunks=None, mask_and_scale=mask_and_scale)
+
+    # find out if its 2 or 3
+    try:
+        # this should work with zarr>=3 - https://github.com/zarr-developers/zarr-python
+        zarr_format = zarr.open(path).metadata.zarr_format
+    except:  # noqa E722
+        # try to open it, but if it fails, assume zarr_format 2
+        zarr_format = 2
+
+    if mask_and_scale:
+        # Data variables get replaced below with _TensorStoreAdapter arrays, which
+        # don't get masked or scaled. Raising an error avoids surprising users with
+        # incorrect data values.
+        _raise_if_mask_and_scale_used_for_data_vars(ds)
+
+    specs = {k: _zarr_spec_from_path(os.path.join(path, k), zarr_format) for k in ds}
+    array_futures = {
+        k: tensorstore.open(spec, read=True, write=write, context=context)
+        for k, spec in specs.items()
+    }
+    arrays = {k: v.result() for k, v in array_futures.items()}
+    new_data = {k: _TensorStoreAdapter(v) for k, v in arrays.items()}
+
+    return ds.copy(data=new_data)
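The vendored module keeps the upstream xarray-tensorstore workflow: open lazily, start asynchronous reads, then materialise. A minimal usage sketch based on the docstring above (the zarr path is hypothetical, not from the diff):

from ocf_data_sampler.load.xarray_tensorstore import open_zarr, read

# Open lazily; the zarr format (2 or 3) is auto-detected inside open_zarr()
ds = open_zarr("/data/satellite.zarr")  # hypothetical path

# Indexing and transposing stay lazy on the _TensorStoreAdapter arrays
subset = ds.isel(time=slice(0, 12))

# Start async reads of all variables in parallel, then block into NumPy
numpy_subset = read(subset).compute()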
ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py
@@ -21,7 +21,7 @@ from ocf_data_sampler.numpy_sample.nwp import NWPSampleKey
 from ocf_data_sampler.select import Location, fill_time_periods
 from ocf_data_sampler.select.geospatial import osgb_to_lon_lat
 from ocf_data_sampler.torch_datasets.utils import (
-    channel_dict_to_dataarray,
+    config_normalization_values_to_dicts,
     find_valid_time_periods,
     slice_datasets_by_space,
     slice_datasets_by_time,
@@ -110,11 +110,14 @@ class AbstractPVNetUKDataset(Dataset):
         self.config = config
         self.datasets_dict = datasets_dict

+        # Extract the normalisation values from the config for faster access
+        means_dict, stds_dict = config_normalization_values_to_dicts(config)
+        self.means_dict = means_dict
+        self.stds_dict = stds_dict

-    @staticmethod
     def process_and_combine_datasets(
+        self,
         dataset_dict: dict,
-        config: Configuration,
         t0: pd.Timestamp,
         location: Location,
     ) -> NumpySample:
@@ -122,7 +125,6 @@ class AbstractPVNetUKDataset(Dataset):

         Args:
             dataset_dict: Dictionary of xarray datasets
-            config: Configuration object
             t0: init-time for sample
             location: location of the sample
         """
@@ -134,13 +136,8 @@ class AbstractPVNetUKDataset(Dataset):
            for nwp_key, da_nwp in dataset_dict["nwp"].items():

                # Standardise and convert to NumpyBatch

-               da_channel_means = channel_dict_to_dataarray(
-                   config.input_data.nwp[nwp_key].channel_means,
-               )
-               da_channel_stds = channel_dict_to_dataarray(
-                   config.input_data.nwp[nwp_key].channel_stds,
-               )
+               da_channel_means = self.means_dict["nwp"][nwp_key]
+               da_channel_stds = self.stds_dict["nwp"][nwp_key]

                da_nwp = (da_nwp - da_channel_means) / da_channel_stds

@@ -153,15 +150,15 @@ class AbstractPVNetUKDataset(Dataset):
            da_sat = dataset_dict["sat"]

            # Standardise and convert to NumpyBatch
-           da_channel_means = channel_dict_to_dataarray(config.input_data.satellite.channel_means)
-           da_channel_stds = channel_dict_to_dataarray(config.input_data.satellite.channel_stds)
+           da_channel_means = self.means_dict["sat"]
+           da_channel_stds = self.stds_dict["sat"]

            da_sat = (da_sat - da_channel_means) / da_channel_stds

            numpy_modalities.append(convert_satellite_to_numpy_sample(da_sat))

        if "gsp" in dataset_dict:
-           gsp_config = config.input_data.gsp
+           gsp_config = self.config.input_data.gsp
            da_gsp = dataset_dict["gsp"]
            da_gsp = da_gsp / da_gsp.effective_capacity_mwp

@@ -183,13 +180,8 @@ class AbstractPVNetUKDataset(Dataset):
        )

        # Only add solar position if explicitly configured
-       has_solar_config = (
-           hasattr(config.input_data, "solar_position") and
-           config.input_data.solar_position is not None
-       )
-
-       if has_solar_config:
-           solar_config = config.input_data.solar_position
+       if self.config.input_data.solar_position is not None:
+           solar_config = self.config.input_data.solar_position

            # Create datetime range for solar position calculation
            datetimes = pd.date_range(
@@ -264,7 +256,7 @@ class PVNetUKRegionalDataset(AbstractPVNetUKDataset):
        sample_dict = slice_datasets_by_time(sample_dict, t0, self.config)
        sample_dict = compute(sample_dict)

-       return self.process_and_combine_datasets(sample_dict, self.config, t0, location)
+       return self.process_and_combine_datasets(sample_dict, t0, location)

    @override
    def __getitem__(self, idx: int) -> NumpySample:
@@ -330,7 +322,6 @@ class PVNetUKConcurrentDataset(AbstractPVNetUKDataset):
        gsp_sample_dict = slice_datasets_by_space(sample_dict, location, self.config)
        gsp_numpy_sample = self.process_and_combine_datasets(
            gsp_sample_dict,
-           self.config,
            t0,
            location,
        )
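The standardisation lines above, `da_nwp = (da_nwp - da_channel_means) / da_channel_stds`, rely on xarray aligning the cached 1D per-channel DataArrays against the sample's `channel` coordinate. A small self-contained sketch of that broadcasting behaviour (channel names and statistics are made up):

import numpy as np
import xarray as xr

# Hypothetical NWP cube with two channels
da_nwp = xr.DataArray(
    np.random.rand(3, 2, 4, 4),
    dims=("time", "channel", "y", "x"),
    coords={"channel": ["t2m", "dswrf"]},
)

# 1D per-channel constants, shaped like the output of channel_dict_to_dataarray()
da_means = xr.DataArray([283.0, 120.0], coords={"channel": ["t2m", "dswrf"]})
da_stds = xr.DataArray([10.0, 80.0], coords={"channel": ["t2m", "dswrf"]})

# xarray aligns on the shared "channel" coordinate and broadcasts over
# time/y/x, standardising each channel independently
da_norm = (da_nwp - da_means) / da_stds
assert da_norm.shape == da_nwp.shape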
ocf_data_sampler/torch_datasets/datasets/site.py
@@ -25,7 +25,7 @@ from ocf_data_sampler.select import (
     intersection_of_multiple_dataframes_of_periods,
 )
 from ocf_data_sampler.torch_datasets.utils import (
-    channel_dict_to_dataarray,
+    config_normalization_values_to_dicts,
     find_valid_time_periods,
     slice_datasets_by_space,
     slice_datasets_by_time,
@@ -62,6 +62,8 @@ def process_and_combine_datasets(
     dataset_dict: dict,
     config: Configuration,
     t0: pd.Timestamp,
+    means_dict: dict[str, xr.DataArray | dict[str, xr.DataArray]],
+    stds_dict: dict[str, xr.DataArray | dict[str, xr.DataArray]],
 ) -> NumpySample:
     """Normalise and convert data to numpy arrays.

@@ -69,6 +71,8 @@ def process_and_combine_datasets(
         dataset_dict: Dictionary of xarray datasets
         config: Configuration object
         t0: init-time for sample
+        means_dict: Nested dictionary of mean values for the input data sources
+        stds_dict: Nested dictionary of std values for the input data sources
     """
     numpy_modalities = []

@@ -79,12 +83,8 @@ def process_and_combine_datasets(

            # Standardise and convert to NumpyBatch

-           da_channel_means = channel_dict_to_dataarray(
-               config.input_data.nwp[nwp_key].channel_means,
-           )
-           da_channel_stds = channel_dict_to_dataarray(
-               config.input_data.nwp[nwp_key].channel_stds,
-           )
+           da_channel_means = means_dict["nwp"][nwp_key]
+           da_channel_stds = stds_dict["nwp"][nwp_key]

            da_nwp = (da_nwp - da_channel_means) / da_channel_stds

@@ -97,8 +97,8 @@ def process_and_combine_datasets(
        da_sat = dataset_dict["sat"]

        # Standardise and convert to NumpyBatch
-       da_channel_means = channel_dict_to_dataarray(config.input_data.satellite.channel_means)
-       da_channel_stds = channel_dict_to_dataarray(config.input_data.satellite.channel_stds)
+       da_channel_means = means_dict["sat"]
+       da_channel_stds = stds_dict["sat"]

        da_sat = (da_sat - da_channel_means) / da_channel_stds

@@ -109,11 +109,7 @@ def process_and_combine_datasets(
        da_sites = da_sites / da_sites.capacity_kwp

        # Convert to NumpyBatch
-       numpy_modalities.append(
-           convert_site_to_numpy_sample(
-               da_sites,
-           ),
-       )
+       numpy_modalities.append(convert_site_to_numpy_sample(da_sites))

        # add datetime features
        datetimes = pd.DatetimeIndex(da_sites.time_utc.values)
@@ -193,6 +189,11 @@ class SitesDataset(Dataset):
        # Assign coords and indices to self
        self.valid_t0_and_site_ids = valid_t0_and_site_ids

+       # Extract the normalisation values from the config for faster access
+       means_dict, stds_dict = config_normalization_values_to_dicts(config)
+       self.means_dict = means_dict
+       self.stds_dict = stds_dict
+
    def find_valid_t0_and_site_ids(
        self,
        datasets_dict: dict,
@@ -273,7 +274,13 @@ class SitesDataset(Dataset):

        sample_dict = compute(sample_dict)

-       return process_and_combine_datasets(sample_dict, self.config, t0)
+       return process_and_combine_datasets(
+           sample_dict,
+           self.config,
+           t0,
+           self.means_dict,
+           self.stds_dict,
+       )

    def get_sample(self, t0: pd.Timestamp, site_id: int) -> dict:
        """Generate a sample for a given site id and t0.
@@ -332,6 +339,11 @@ class SitesDatasetConcurrent(Dataset):
        # Assign coords and indices to self
        self.valid_t0s = valid_t0s

+       # Extract the normalisation values from the config for faster access
+       means_dict, stds_dict = config_normalization_values_to_dicts(config)
+       self.means_dict = means_dict
+       self.stds_dict = stds_dict
+
    def find_valid_t0s(
        self,
        datasets_dict: dict,
@@ -406,6 +418,8 @@ class SitesDatasetConcurrent(Dataset):
            site_sample_dict,
            self.config,
            t0,
+           self.means_dict,
+           self.stds_dict,
        )
        site_samples.append(site_numpy_sample)

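The point of threading `means_dict` and `stds_dict` through as arguments is that `process_and_combine_datasets` runs once per sample, so rebuilding the constant DataArrays inside it is repeated work; constructing them once in `__init__` amortises the cost across every `__getitem__` call. Schematically (a sketch of the pattern, not the real class):

import xarray as xr
from torch.utils.data import Dataset


class NormalisingDataset(Dataset):
    """Sketch: cache per-channel constants once instead of once per sample."""

    def __init__(self, samples: list[xr.DataArray], channel_means: dict[str, float]) -> None:
        self.samples = samples
        # Built once here and reused by every __getitem__ call
        self.da_means = xr.DataArray(
            list(channel_means.values()),
            coords={"channel": list(channel_means.keys())},
        )

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int) -> xr.DataArray:
        # No dict-to-DataArray conversion in the hot path
        return self.samples[idx] - self.da_means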
ocf_data_sampler/torch_datasets/utils/__init__.py
@@ -1,4 +1,4 @@
-from .channel_dict_to_dataarray import channel_dict_to_dataarray
+from .config_normalization_values_to_dicts import config_normalization_values_to_dicts
 from .merge_and_fill_utils import fill_nans_in_arrays, merge_dicts
 from .valid_time_periods import find_valid_time_periods
 from .spatial_slice_for_dataset import slice_datasets_by_space
ocf_data_sampler/torch_datasets/utils/config_normalization_values_to_dicts.py (new file)
@@ -0,0 +1,57 @@
+"""Utility function for converting channel dictionaries to xarray DataArrays."""
+
+import xarray as xr
+
+from ocf_data_sampler.config import Configuration
+
+
+def channel_dict_to_dataarray(channel_dict: dict[str, float]) -> xr.DataArray:
+    """Converts a dictionary of channel values to a DataArray.
+
+    Args:
+        channel_dict: Dictionary mapping channel names (str) to their values (float).
+
+    Returns:
+        xr.DataArray: A 1D DataArray with channels as coordinates.
+    """
+    return xr.DataArray(
+        list(channel_dict.values()),
+        coords={"channel": list(channel_dict.keys())},
+    )
+
+def config_normalization_values_to_dicts(
+    config: Configuration,
+) -> tuple[dict[str, xr.DataArray | dict[str, xr.DataArray]]]:
+    """Construct DataArrays of mean and std values from the config normalisation constants.
+
+    Args:
+        config: Data configuration.
+
+    Returns:
+        Means dict
+        Stds dict
+    """
+    means_dict = {}
+    stds_dict = {}
+
+    if config.input_data.nwp is not None:
+
+        means_dict["nwp"] = {}
+        stds_dict["nwp"] = {}
+
+        for nwp_key in config.input_data.nwp:
+            # Standardise and convert to NumpyBatch
+
+            means_dict["nwp"][nwp_key] = channel_dict_to_dataarray(
+                config.input_data.nwp[nwp_key].channel_means,
+            )
+            stds_dict["nwp"][nwp_key] = channel_dict_to_dataarray(
+                config.input_data.nwp[nwp_key].channel_stds,
+            )
+
+    if config.input_data.satellite is not None:
+
+        means_dict["sat"] = channel_dict_to_dataarray(config.input_data.satellite.channel_means)
+        stds_dict["sat"] = channel_dict_to_dataarray(config.input_data.satellite.channel_stds)
+
+    return means_dict, stds_dict
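A quick worked example of the helper above (channel names and values are made up):

from ocf_data_sampler.torch_datasets.utils.config_normalization_values_to_dicts import (
    channel_dict_to_dataarray,
)

# Hypothetical per-channel means
da_means = channel_dict_to_dataarray({"t2m": 283.0, "dswrf": 120.0})
assert da_means.dims == ("channel",)
assert list(da_means.channel.values) == ["t2m", "dswrf"]

# config_normalization_values_to_dicts(config) nests such DataArrays as
#   means_dict == {"nwp": {<nwp_key>: <DataArray>, ...}, "sat": <DataArray>}
# with keys taken from config.input_data.nwp ("sat" only if satellite is configured)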
{ocf_data_sampler-0.5.2.dist-info → ocf_data_sampler-0.5.5.dist-info}/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: ocf-data-sampler
-Version: 0.5.2
+Version: 0.5.5
 Author: James Fulton, Peter Dudfield
 Author-email: Open Climate Fix team <info@openclimatefix.org>
 License: MIT License
@@ -28,14 +28,14 @@ License: MIT License
 Project-URL: repository, https://github.com/openclimatefix/ocf-data-sampler
 Classifier: Programming Language :: Python :: 3
 Classifier: License :: OSI Approved :: MIT License
-Requires-Python: >=3.
+Requires-Python: >=3.11
 Description-Content-Type: text/markdown
 Requires-Dist: torch
 Requires-Dist: numpy
 Requires-Dist: pandas
 Requires-Dist: xarray
 Requires-Dist: zarr
-Requires-Dist: numcodecs
+Requires-Dist: numcodecs
 Requires-Dist: dask
 Requires-Dist: matplotlib
 Requires-Dist: pvlib
@@ -44,7 +44,8 @@ Requires-Dist: pyproj
 Requires-Dist: pyaml_env
 Requires-Dist: pyresample
 Requires-Dist: h5netcdf
-Requires-Dist: xarray-tensorstore
+Requires-Dist: tensorstore
+Requires-Dist: zarr>=3

 # ocf-data-sampler

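Given the new requirements (`tensorstore` plus `zarr>=3`), an environment can be sanity-checked with a short snippet like this (a sketch, not part of the package):

from importlib.metadata import version

import zarr

# zarr-python v3 is needed for the metadata.zarr_format probe used above
assert int(zarr.__version__.split(".")[0]) >= 3, zarr.__version__
print("tensorstore", version("tensorstore"), "zarr", zarr.__version__)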
{ocf_data_sampler-0.5.2.dist-info → ocf_data_sampler-0.5.5.dist-info}/RECORD
@@ -9,10 +9,11 @@ ocf_data_sampler/data/uk_gsp_locations_20250109.csv,sha256=XZISFatnbpO9j8LwaxNKF
 ocf_data_sampler/load/__init__.py,sha256=-vQP9g0UOWdVbjEGyVX_ipa7R1btmiETIKAf6aw4d78,201
 ocf_data_sampler/load/gsp.py,sha256=d30jQWnwFaLj6rKNMHdz1qD8fzF8q--RNnEXT7bGiX0,2981
 ocf_data_sampler/load/load_dataset.py,sha256=K8rWykjII-3g127If7WRRFivzHNx3SshCvZj4uQlf28,2089
-ocf_data_sampler/load/open_tensorstore_zarrs.py,sha256=
-ocf_data_sampler/load/satellite.py,sha256=
+ocf_data_sampler/load/open_tensorstore_zarrs.py,sha256=ElXmW7GhYDpsHZr7KjM-KIDNJMc4lmgzVIBwHx5Wl0Q,2748
+ocf_data_sampler/load/satellite.py,sha256=X5ZqFfMgab_WDwI7w1ZmdyMeh3GwV1g7mBd8tFgr8dM,1862
 ocf_data_sampler/load/site.py,sha256=WtOy20VMHJIY0IwEemCdcecSDUGcVaLUown-4ixJw90,2147
 ocf_data_sampler/load/utils.py,sha256=AGL0aOOQPrgqNBTjlBtR7Qg1PyQov3DFJo-y198u8pY,2044
+ocf_data_sampler/load/xarray_tensorstore.py,sha256=DSZl364Hn3QjcVxxPmBKU9rsc5BlJBdzL_SMrv-9os0,10997
 ocf_data_sampler/load/nwp/__init__.py,sha256=SmcrnbygO5xtCKmGR4wtHrj-HI7nOAvnAtfuvRufBGQ,25
 ocf_data_sampler/load/nwp/nwp.py,sha256=0E9shei3Mq1N7F-fBlEKY5Hm0_kI7ysY_rffnWIshvk,3612
 ocf_data_sampler/load/nwp/providers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -21,7 +22,7 @@ ocf_data_sampler/load/nwp/providers/ecmwf.py,sha256=P7JqfssmQq8eHKKXaBexsxts325A
 ocf_data_sampler/load/nwp/providers/gfs.py,sha256=h6vm-Rfz1JGOE4P_fP1_XQJ3bugNbeNAIyt56N8B1Dc,1066
 ocf_data_sampler/load/nwp/providers/icon.py,sha256=iVZwLKRr_D74_kAu5MHir6pRKEfbTmIxFRZAxzmiYdI,1257
 ocf_data_sampler/load/nwp/providers/ukv.py,sha256=2i32VM9gnmWUpbL0qBSp_AKzuyKucXZPS8yklbcGlbc,1039
-ocf_data_sampler/load/nwp/providers/utils.py,sha256=
+ocf_data_sampler/load/nwp/providers/utils.py,sha256=5LrLmy74AVY5uLwL2qEhy-yPqSYLoxOgN8W1v8FmaQA,2355
 ocf_data_sampler/numpy_sample/__init__.py,sha256=5bdpzM8hMAEe0XRSZ9AZFQdqEeBsEPhaF79Y8bDx3GQ,407
 ocf_data_sampler/numpy_sample/collate.py,sha256=hoxIc5SoHoIs3Nx37aRZzWChpswjy9lHUgaKgHIoo80,2039
 ocf_data_sampler/numpy_sample/common_types.py,sha256=9CjYHkUTx0ObduWh43fhsybZCTXvexql7qC2ptMDoek,377
@@ -40,14 +41,14 @@ ocf_data_sampler/select/location.py,sha256=AZvGR8y62opiW7zACGXjoOtBEWRfSLOZIA73O
 ocf_data_sampler/select/select_spatial_slice.py,sha256=Hd4jGRUfIZRoWCirOQZeoLpaUnStB6KyFSTPX69wZLw,8790
 ocf_data_sampler/select/select_time_slice.py,sha256=HeHbwZ0CP03x0-LaJtpbSdtpLufwVTR73p6wH6O_PS8,5513
 ocf_data_sampler/torch_datasets/datasets/__init__.py,sha256=o0SsEXXZ6k9iL__5_RN1Sf60lw_eqK91P3UFEHAD2k0,102
-ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py,sha256=
-ocf_data_sampler/torch_datasets/datasets/site.py,sha256=
+ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py,sha256=876oLukvb1nLtZQ8HBN3PWfN7urKH2xa45tVar7XrbM,12010
+ocf_data_sampler/torch_datasets/datasets/site.py,sha256=nn6N8daGxllYwCCiFKbCJANTl84NrDRl-nbNGcfXc3U,15429
 ocf_data_sampler/torch_datasets/sample/__init__.py,sha256=GL84vdZl_SjHDGVyh9Uekx2XhPYuZ0dnO3l6f6KXnHI,100
 ocf_data_sampler/torch_datasets/sample/base.py,sha256=cQ1oIyhdmlotejZK8B3Cw6MNvpdnBPD8G_o2h7Ye4Vc,2206
 ocf_data_sampler/torch_datasets/sample/site.py,sha256=40NwNTqjL1WVhPdwe02zDHHfDLG2u_bvCfRCtGAtFc0,1466
 ocf_data_sampler/torch_datasets/sample/uk_regional.py,sha256=Xx5cBYUyaM6PGUWQ76MHT9hwj6IJ7WAOxbpmYFbJGhc,10483
-ocf_data_sampler/torch_datasets/utils/__init__.py,sha256=
-ocf_data_sampler/torch_datasets/utils/channel_dict_to_dataarray.py,sha256=
+ocf_data_sampler/torch_datasets/utils/__init__.py,sha256=_UHLL_yRzhLJVHi6ROSaSe8TGw80CAhU325uCZj7XkY,331
+ocf_data_sampler/torch_datasets/utils/config_normalization_values_to_dicts.py,sha256=jS3DkAwOF1W3AQnvsdkBJ1C8Unm93kQbS8hgTCtFv2A,1743
 ocf_data_sampler/torch_datasets/utils/merge_and_fill_utils.py,sha256=we7BTxRH7B7jKayDT7YfNyfI3zZClz2Bk-HXKQIokgU,956
 ocf_data_sampler/torch_datasets/utils/spatial_slice_for_dataset.py,sha256=Hvz0wHSWMYYamf2oHNiGlzJcM4cAH6pL_7ZEvIBL2dE,1882
 ocf_data_sampler/torch_datasets/utils/time_slice_for_dataset.py,sha256=8E4a5v9dqr-sZOyBruuO-tjLPBbjtpYtdFY5z23aqnU,4365
@@ -56,7 +57,7 @@ ocf_data_sampler/torch_datasets/utils/validation_utils.py,sha256=YqmT-lExWlI8_ul
 scripts/download_gsp_location_data.py,sha256=rRDXMoqX-RYY4jPdxhdlxJGhWdl6r245F5UARgKV6P4,3121
 scripts/refactor_site.py,sha256=skzvsPP0Cn9yTKndzkilyNcGz4DZ88ctvCJ0XrBdc2A,3135
 utils/compute_icon_mean_stddev.py,sha256=a1oWMRMnny39rV-dvu8rcx85sb4bXzPFrR1gkUr4Jpg,2296
-ocf_data_sampler-0.5.2.dist-info/METADATA,sha256=
-ocf_data_sampler-0.5.2.dist-info/WHEEL,sha256=
-ocf_data_sampler-0.5.2.dist-info/top_level.txt,sha256=
-ocf_data_sampler-0.5.2.dist-info/RECORD,,
+ocf_data_sampler-0.5.5.dist-info/METADATA,sha256=R9MPrxfVGCnkBbUehSjd3taDZxeREDo_YaIv5ccqnyg,12581
+ocf_data_sampler-0.5.5.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ocf_data_sampler-0.5.5.dist-info/top_level.txt,sha256=LEFU4Uk-PEo72QGLAfnVZIUEm37Q8mKuMeg_Xk-p33g,31
+ocf_data_sampler-0.5.5.dist-info/RECORD,,
ocf_data_sampler/torch_datasets/utils/channel_dict_to_dataarray.py (deleted)
@@ -1,18 +0,0 @@
-"""Utility function for converting channel dictionaries to xarray DataArrays."""
-
-import xarray as xr
-
-
-def channel_dict_to_dataarray(channel_dict: dict[str, float]) -> xr.DataArray:
-    """Converts a dictionary of channel values to a DataArray.
-
-    Args:
-        channel_dict: Dictionary mapping channel names (str) to their values (float).
-
-    Returns:
-        xr.DataArray: A 1D DataArray with channels as coordinates.
-    """
-    return xr.DataArray(
-        list(channel_dict.values()),
-        coords={"channel": list(channel_dict.keys())},
-    )
{ocf_data_sampler-0.5.2.dist-info → ocf_data_sampler-0.5.5.dist-info}/WHEEL: file without changes
{ocf_data_sampler-0.5.2.dist-info → ocf_data_sampler-0.5.5.dist-info}/top_level.txt: file without changes