ocf-data-sampler 0.5.2__py3-none-any.whl → 0.5.5__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.

Note: this release of ocf-data-sampler has been flagged as potentially problematic.

--- a/ocf_data_sampler/load/nwp/providers/utils.py
+++ b/ocf_data_sampler/load/nwp/providers/utils.py
@@ -3,9 +3,9 @@
 from glob import glob
 
 import xarray as xr
-from xarray_tensorstore import open_zarr
 
 from ocf_data_sampler.load.open_tensorstore_zarrs import open_zarrs
+from ocf_data_sampler.load.xarray_tensorstore import open_zarr
 
 
 def open_zarr_paths(
--- a/ocf_data_sampler/load/open_tensorstore_zarrs.py
+++ b/ocf_data_sampler/load/open_tensorstore_zarrs.py
@@ -7,7 +7,8 @@ import os
 
 import tensorstore as ts
 import xarray as xr
-from xarray_tensorstore import (
+
+from ocf_data_sampler.load.xarray_tensorstore import (
     _raise_if_mask_and_scale_used_for_data_vars,
     _TensorStoreAdapter,
     _zarr_spec_from_path,
--- a/ocf_data_sampler/load/satellite.py
+++ b/ocf_data_sampler/load/satellite.py
@@ -1,13 +1,13 @@
 """Satellite loader."""
 import numpy as np
 import xarray as xr
-from xarray_tensorstore import open_zarr
 
 from ocf_data_sampler.load.utils import (
     check_time_unique_increasing,
     get_xr_data_array_from_xr_dataset,
     make_spatial_coords_increasing,
 )
+from ocf_data_sampler.load.xarray_tensorstore import open_zarr
 
 from .open_tensorstore_zarrs import open_zarrs
 
--- /dev/null
+++ b/ocf_data_sampler/load/xarray_tensorstore.py
@@ -0,0 +1,299 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Utilities for loading TensorStore data into Xarray.
+
+Copied from https://github.com/google-research/tensorstore/blob/main/tensorstore/xarray.py
+But we added small changes so that it works for zarr3
+https://github.com/google/xarray-tensorstore/pull/22
+"""
+from __future__ import annotations
+
+import dataclasses
+import math
+import os.path
+import re
+from typing import TypeVar
+
+import numpy as np
+import tensorstore
+import xarray
+import zarr
+from xarray.core import indexing
+
+__version__ = "0.1.5"  # keep in sync with setup.py
+
+
+Index = TypeVar("Index", int, slice, np.ndarray, None)
+XarrayData = TypeVar("XarrayData", xarray.Dataset, xarray.DataArray)
+
+
+def _numpy_to_tensorstore_index(index: Index, size: int) -> Index:
+    """Switch from NumPy to TensorStore indexing conventions."""
+    # https://google.github.io/tensorstore/python/indexing.html#differences-compared-to-numpy-indexing
+    if index is None:
+        return None
+    elif isinstance(index, int):
+        # Negative integers do not count from the end in TensorStore
+        return index + size if index < 0 else index
+    elif isinstance(index, slice):
+        start = _numpy_to_tensorstore_index(index.start, size)
+        stop = _numpy_to_tensorstore_index(index.stop, size)
+        if stop is not None:
+            # TensorStore does not allow out of bounds slicing
+            stop = min(stop, size)
+        return slice(start, stop, index.step)
+    else:
+        assert isinstance(index, np.ndarray)  # noqa S101
+        return np.where(index < 0, index + size, index)
+
+
+@dataclasses.dataclass(frozen=True)
+class _TensorStoreAdapter(indexing.ExplicitlyIndexed):
+    """TensorStore array that can be wrapped by xarray.Variable.
+
+    We use Xarray's semi-internal ExplicitlyIndexed API so that Xarray will not
+    attempt to load our array into memory as a NumPy array. In the future, this
+    should be supported by public Xarray APIs, as part of the refactor discussed
+    in: https://github.com/pydata/xarray/issues/3981
+    """
+
+    array: tensorstore.TensorStore
+    future: tensorstore.Future | None = None
+
+    @property
+    def shape(self) -> tuple[int, ...]:
+        return self.array.shape
+
+    @property
+    def dtype(self) -> np.dtype:
+        return self.array.dtype.numpy_dtype
+
+    @property
+    def ndim(self) -> int:
+        return len(self.shape)
+
+    @property
+    def size(self) -> int:
+        return math.prod(self.shape)
+
+    def __getitem__(self, key: indexing.ExplicitIndexer) -> _TensorStoreAdapter:
+        index_tuple = tuple(map(_numpy_to_tensorstore_index, key.tuple, self.shape))
+        if isinstance(key, indexing.OuterIndexer):
+            # TODO(shoyer): fix this for newer versions of Xarray.
+            # We get the error message:
+            # AttributeError: '_TensorStoreAdapter' object has no attribute 'oindex'
+            indexed = self.array.oindex[index_tuple]
+        elif isinstance(key, indexing.VectorizedIndexer):
+            indexed = self.array.vindex[index_tuple]
+        else:
+            assert isinstance(key, indexing.BasicIndexer)  # noqa S101
+            indexed = self.array[index_tuple]
+        # Translate to the origin so repeated indexing is relative to the new bounds
+        # like NumPy, not absolute like TensorStore
+        translated = indexed[tensorstore.d[:].translate_to[0]]
+        return type(self)(translated)
+
+    def __setitem__(self, key: indexing.ExplicitIndexer, value) -> None:  # noqa ANN001
+        index_tuple = tuple(map(_numpy_to_tensorstore_index, key.tuple, self.shape))
+        if isinstance(key, indexing.OuterIndexer):
+            self.array.oindex[index_tuple] = value
+        elif isinstance(key, indexing.VectorizedIndexer):
+            self.array.vindex[index_tuple] = value
+        else:
+            assert isinstance(key, indexing.BasicIndexer)  # noqa S101
+            self.array[index_tuple] = value
+        # Invalidate the future so that the next read will pick up the new value
+        object.__setattr__(self, "future", None)
+
+    # xarray>2024.02.0 uses oindex and vindex properties, which are expected to
+    # return objects whose __getitem__ method supports the appropriate form of
+    # indexing.
+    @property
+    def oindex(self) -> _TensorStoreAdapter:
+        return self
+
+    @property
+    def vindex(self) -> _TensorStoreAdapter:
+        return self
+
+    def transpose(self, order: tuple[int, ...]) -> _TensorStoreAdapter:
+        transposed = self.array[tensorstore.d[order].transpose[:]]
+        return type(self)(transposed)
+
+    def read(self) -> _TensorStoreAdapter:
+        future = self.array.read()
+        return type(self)(self.array, future)
+
+    def __array__(self, dtype: np.dtype | None = None) -> np.ndarray:  # type: ignore
+        future = self.array.read() if self.future is None else self.future
+        return np.asarray(future.result(), dtype=dtype)
+
+    def get_duck_array(self) -> np.ndarray:
+        # special method for xarray to return an in-memory (computed) representation
+        return np.asarray(self)
+
+    # Work around the missing __copy__ and __deepcopy__ methods from TensorStore,
+    # which are needed for Xarray:
+    # https://github.com/google/tensorstore/issues/109
+    # TensorStore objects are immutable, so there's no need to actually copy them.
+
+    def __copy__(self) -> _TensorStoreAdapter:
+        return type(self)(self.array, self.future)
+
+    def __deepcopy__(self, memo) -> _TensorStoreAdapter:  # noqa ANN001
+        return self.__copy__()
+
+
+def _read_tensorstore(
+    array: indexing.ExplicitlyIndexed,
+) -> indexing.ExplicitlyIndexed:
+    """Starts async reading on a TensorStore array."""
+    return array.read() if isinstance(array, _TensorStoreAdapter) else array
+
+
+def read(xarraydata: XarrayData, /) -> XarrayData:
+    """Starts async reads on all TensorStore arrays."""
+    # pylint: disable=protected-access
+    if isinstance(xarraydata, xarray.Dataset):
+        data = {
+            name: _read_tensorstore(var.variable._data)
+            for name, var in xarraydata.data_vars.items()
+        }
+    elif isinstance(xarraydata, xarray.DataArray):
+        data = _read_tensorstore(xarraydata.variable._data)
+    else:
+        raise TypeError(f"argument is not a DataArray or Dataset: {xarraydata}")
+    # pylint: enable=protected-access
+    return xarraydata.copy(data=data)
+
+
+_DEFAULT_STORAGE_DRIVER = "file"
+
+
+def _zarr_spec_from_path(path: str, zarr_format: int) -> ...:
+    if re.match(r"\w+\://", path):  # path is a URI
+        kv_store = path
+    else:
+        kv_store = {"driver": _DEFAULT_STORAGE_DRIVER, "path": path}
+
+    if zarr_format == 2:
+        return {"driver": "zarr2", "kvstore": kv_store}
+    else:
+        return {"driver": "zarr3", "kvstore": kv_store}
+
+
+def _raise_if_mask_and_scale_used_for_data_vars(ds: xarray.Dataset) -> None:
+    """Check a dataset for data variables that would need masking or scaling."""
+    advice = (
+        "Consider re-opening with xarray_tensorstore.open_zarr(..., "
+        "mask_and_scale=False), or falling back to use xarray.open_zarr()."
+    )
+    for k in ds:
+        encoding = ds[k].encoding
+        for attr in ["_FillValue", "missing_value"]:
+            fill_value = encoding.get(attr, np.nan)
+            if fill_value == fill_value:  # pylint: disable=comparison-with-itself
+                raise ValueError(
+                    f"variable {k} has non-NaN fill value, which is not supported by"
+                    f" xarray-tensorstore: {fill_value}. {advice}",
+                )
+        for attr in ["scale_factor", "add_offset"]:
+            if attr in encoding:
+                raise ValueError(
+                    f"variable {k} uses scale/offset encoding, which is not supported"
+                    f" by xarray-tensorstore: {encoding}. {advice}",
+                )
+
+
+def open_zarr(
+    path: str,
+    *,
+    context: tensorstore.Context | None = None,
+    mask_and_scale: bool = True,
+    write: bool = False,
+) -> xarray.Dataset:
+    """Open an xarray.Dataset from Zarr using TensorStore.
+
+    For best performance, explicitly call `read()` to asynchronously load data
+    in parallel. Otherwise, xarray's `.compute()` method will load each variable's
+    data in sequence.
+
+    Example usage:
+
+        import xarray_tensorstore
+
+        ds = xarray_tensorstore.open_zarr(path)
+
+        # indexing & transposing is lazy
+        example = ds.sel(time='2020-01-01').transpose('longitude', 'latitude', ...)
+
+        # start reading data asynchronously
+        read_example = xarray_tensorstore.read(example)
+
+        # blocking conversion of the data into NumPy arrays
+        numpy_example = read_example.compute()
+
+    Args:
+        path: path or URI to Zarr group to open.
+        context: TensorStore configuration options to use when opening arrays.
+        mask_and_scale: if True (default), attempt to apply masking and scaling like
+            xarray.open_zarr(). This is only supported for coordinate variables and
+            otherwise will raise an error.
+        write: Allow write access. Defaults to False.
+
+    Returns:
+        Dataset with all data variables opened via TensorStore.
+    """
+    # We use xarray.open_zarr (which uses Zarr Python internally) to open the
+    # initial version of the dataset for a few reasons:
+    # 1. TensorStore does not support Zarr groups or array attributes, which we
+    #    need to open in the xarray.Dataset. We use Zarr Python instead of
+    #    parsing the raw Zarr metadata files ourselves.
+    # 2. TensorStore doesn't support non-standard Zarr dtypes like UTF-8 strings.
+    # 3. Xarray's open_zarr machinery does some pre-processing (e.g., from numeric
+    #    to datetime64 dtypes) that we would otherwise need to invoke explicitly
+    #    via xarray.decode_cf().
+    #
+    # Fortunately (2) and (3) are most commonly encountered on small coordinate
+    # arrays, for which the performance advantages of TensorStore are irrelevant.
+
+    if context is None:
+        context = tensorstore.Context()
+
+    # chunks=None means avoid using dask
+    ds = xarray.open_zarr(path, chunks=None, mask_and_scale=mask_and_scale)
+
+    # find out if its 2 or 3
+    try:
+        # this should work with zarr>=3 - https://github.com/zarr-developers/zarr-python
+        zarr_format = zarr.open(path).metadata.zarr_format
+    except:  # noqa E722
+        # try to open it, but if it fails, assume zarr_format 2
+        zarr_format = 2
+
+    if mask_and_scale:
+        # Data variables get replaced below with _TensorStoreAdapter arrays, which
+        # don't get masked or scaled. Raising an error avoids surprising users with
+        # incorrect data values.
+        _raise_if_mask_and_scale_used_for_data_vars(ds)
+
+    specs = {k: _zarr_spec_from_path(os.path.join(path, k), zarr_format) for k in ds}
+    array_futures = {
+        k: tensorstore.open(spec, read=True, write=write, context=context)
+        for k, spec in specs.items()
+    }
+    arrays = {k: v.result() for k, v in array_futures.items()}
+    new_data = {k: _TensorStoreAdapter(v) for k, v in arrays.items()}
+
+    return ds.copy(data=new_data)
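
The vendored module keeps the upstream public API (`open_zarr` and `read`), so downstream code only changes its import path. A minimal usage sketch, assuming a local Zarr store at a hypothetical path:

    import ocf_data_sampler.load.xarray_tensorstore as xts

    # Works for both zarr v2 and v3 stores; the format is detected via zarr.open()
    ds = xts.open_zarr("/data/satellite.zarr")  # hypothetical path

    # Indexing stays lazy; no array data is fetched yet
    example = ds.isel(time=slice(0, 12))

    # Kick off asynchronous TensorStore reads for all data variables, then block
    example = xts.read(example)
    numpy_example = example.compute()
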
--- a/ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py
+++ b/ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py
@@ -21,7 +21,7 @@ from ocf_data_sampler.numpy_sample.nwp import NWPSampleKey
 from ocf_data_sampler.select import Location, fill_time_periods
 from ocf_data_sampler.select.geospatial import osgb_to_lon_lat
 from ocf_data_sampler.torch_datasets.utils import (
-    channel_dict_to_dataarray,
+    config_normalization_values_to_dicts,
     find_valid_time_periods,
     slice_datasets_by_space,
     slice_datasets_by_time,
@@ -110,11 +110,14 @@ class AbstractPVNetUKDataset(Dataset):
         self.config = config
         self.datasets_dict = datasets_dict
 
+        # Extract the normalisation values from the config for faster access
+        means_dict, stds_dict = config_normalization_values_to_dicts(config)
+        self.means_dict = means_dict
+        self.stds_dict = stds_dict
 
-    @staticmethod
     def process_and_combine_datasets(
+        self,
         dataset_dict: dict,
-        config: Configuration,
         t0: pd.Timestamp,
         location: Location,
     ) -> NumpySample:
@@ -122,7 +125,6 @@ class AbstractPVNetUKDataset(Dataset):
 
         Args:
             dataset_dict: Dictionary of xarray datasets
-            config: Configuration object
            t0: init-time for sample
             location: location of the sample
         """
@@ -134,13 +136,8 @@ class AbstractPVNetUKDataset(Dataset):
             for nwp_key, da_nwp in dataset_dict["nwp"].items():
 
                 # Standardise and convert to NumpyBatch
-
-                da_channel_means = channel_dict_to_dataarray(
-                    config.input_data.nwp[nwp_key].channel_means,
-                )
-                da_channel_stds = channel_dict_to_dataarray(
-                    config.input_data.nwp[nwp_key].channel_stds,
-                )
+                da_channel_means = self.means_dict["nwp"][nwp_key]
+                da_channel_stds = self.stds_dict["nwp"][nwp_key]
 
                 da_nwp = (da_nwp - da_channel_means) / da_channel_stds
 
@@ -153,15 +150,15 @@ class AbstractPVNetUKDataset(Dataset):
             da_sat = dataset_dict["sat"]
 
             # Standardise and convert to NumpyBatch
-            da_channel_means = channel_dict_to_dataarray(config.input_data.satellite.channel_means)
-            da_channel_stds = channel_dict_to_dataarray(config.input_data.satellite.channel_stds)
+            da_channel_means = self.means_dict["sat"]
+            da_channel_stds = self.stds_dict["sat"]
 
             da_sat = (da_sat - da_channel_means) / da_channel_stds
 
             numpy_modalities.append(convert_satellite_to_numpy_sample(da_sat))
 
         if "gsp" in dataset_dict:
-            gsp_config = config.input_data.gsp
+            gsp_config = self.config.input_data.gsp
             da_gsp = dataset_dict["gsp"]
             da_gsp = da_gsp / da_gsp.effective_capacity_mwp
 
@@ -183,13 +180,8 @@ class AbstractPVNetUKDataset(Dataset):
         )
 
         # Only add solar position if explicitly configured
-        has_solar_config = (
-            hasattr(config.input_data, "solar_position") and
-            config.input_data.solar_position is not None
-        )
-
-        if has_solar_config:
-            solar_config = config.input_data.solar_position
+        if self.config.input_data.solar_position is not None:
+            solar_config = self.config.input_data.solar_position
 
             # Create datetime range for solar position calculation
             datetimes = pd.date_range(
@@ -264,7 +256,7 @@ class PVNetUKRegionalDataset(AbstractPVNetUKDataset):
         sample_dict = slice_datasets_by_time(sample_dict, t0, self.config)
         sample_dict = compute(sample_dict)
 
-        return self.process_and_combine_datasets(sample_dict, self.config, t0, location)
+        return self.process_and_combine_datasets(sample_dict, t0, location)
 
     @override
     def __getitem__(self, idx: int) -> NumpySample:
@@ -330,7 +322,6 @@ class PVNetUKConcurrentDataset(AbstractPVNetUKDataset):
         gsp_sample_dict = slice_datasets_by_space(sample_dict, location, self.config)
         gsp_numpy_sample = self.process_and_combine_datasets(
             gsp_sample_dict,
-            self.config,
             t0,
             location,
         )
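
The net effect on callers: `process_and_combine_datasets` is now an instance method that reads the config and the precomputed normalisation DataArrays from `self`, so per-sample calls no longer rebuild the channel mean/std DataArrays. A before/after sketch (the `dataset` and `sample_dict` names are illustrative):

    # 0.5.2: static method, normalisation DataArrays rebuilt from config on every call
    sample = dataset.process_and_combine_datasets(sample_dict, dataset.config, t0, location)

    # 0.5.5: instance method, reuses DataArrays built once in __init__
    sample = dataset.process_and_combine_datasets(sample_dict, t0, location)
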
--- a/ocf_data_sampler/torch_datasets/datasets/site.py
+++ b/ocf_data_sampler/torch_datasets/datasets/site.py
@@ -25,7 +25,7 @@ from ocf_data_sampler.select import (
     intersection_of_multiple_dataframes_of_periods,
 )
 from ocf_data_sampler.torch_datasets.utils import (
-    channel_dict_to_dataarray,
+    config_normalization_values_to_dicts,
     find_valid_time_periods,
     slice_datasets_by_space,
     slice_datasets_by_time,
@@ -62,6 +62,8 @@ def process_and_combine_datasets(
     dataset_dict: dict,
     config: Configuration,
     t0: pd.Timestamp,
+    means_dict: dict[str, xr.DataArray | dict[str, xr.DataArray]],
+    stds_dict: dict[str, xr.DataArray | dict[str, xr.DataArray]],
 ) -> NumpySample:
     """Normalise and convert data to numpy arrays.
 
@@ -69,6 +71,8 @@ def process_and_combine_datasets(
     Args:
         dataset_dict: Dictionary of xarray datasets
         config: Configuration object
         t0: init-time for sample
+        means_dict: Nested dictionary of mean values for the input data sources
+        stds_dict: Nested dictionary of std values for the input data sources
     """
     numpy_modalities = []
@@ -79,12 +83,8 @@ def process_and_combine_datasets(
 
             # Standardise and convert to NumpyBatch
 
-            da_channel_means = channel_dict_to_dataarray(
-                config.input_data.nwp[nwp_key].channel_means,
-            )
-            da_channel_stds = channel_dict_to_dataarray(
-                config.input_data.nwp[nwp_key].channel_stds,
-            )
+            da_channel_means = means_dict["nwp"][nwp_key]
+            da_channel_stds = stds_dict["nwp"][nwp_key]
 
             da_nwp = (da_nwp - da_channel_means) / da_channel_stds
 
@@ -97,8 +97,8 @@ def process_and_combine_datasets(
         da_sat = dataset_dict["sat"]
 
         # Standardise and convert to NumpyBatch
-        da_channel_means = channel_dict_to_dataarray(config.input_data.satellite.channel_means)
-        da_channel_stds = channel_dict_to_dataarray(config.input_data.satellite.channel_stds)
+        da_channel_means = means_dict["sat"]
+        da_channel_stds = stds_dict["sat"]
 
         da_sat = (da_sat - da_channel_means) / da_channel_stds
 
@@ -109,11 +109,7 @@ def process_and_combine_datasets(
         da_sites = da_sites / da_sites.capacity_kwp
 
         # Convert to NumpyBatch
-        numpy_modalities.append(
-            convert_site_to_numpy_sample(
-                da_sites,
-            ),
-        )
+        numpy_modalities.append(convert_site_to_numpy_sample(da_sites))
 
     # add datetime features
     datetimes = pd.DatetimeIndex(da_sites.time_utc.values)
@@ -193,6 +189,11 @@ class SitesDataset(Dataset):
         # Assign coords and indices to self
         self.valid_t0_and_site_ids = valid_t0_and_site_ids
 
+        # Extract the normalisation values from the config for faster access
+        means_dict, stds_dict = config_normalization_values_to_dicts(config)
+        self.means_dict = means_dict
+        self.stds_dict = stds_dict
+
     def find_valid_t0_and_site_ids(
         self,
         datasets_dict: dict,
@@ -273,7 +274,13 @@ class SitesDataset(Dataset):
 
         sample_dict = compute(sample_dict)
 
-        return process_and_combine_datasets(sample_dict, self.config, t0)
+        return process_and_combine_datasets(
+            sample_dict,
+            self.config,
+            t0,
+            self.means_dict,
+            self.stds_dict,
+        )
 
     def get_sample(self, t0: pd.Timestamp, site_id: int) -> dict:
         """Generate a sample for a given site id and t0.
@@ -332,6 +339,11 @@ class SitesDatasetConcurrent(Dataset):
         # Assign coords and indices to self
         self.valid_t0s = valid_t0s
 
+        # Extract the normalisation values from the config for faster access
+        means_dict, stds_dict = config_normalization_values_to_dicts(config)
+        self.means_dict = means_dict
+        self.stds_dict = stds_dict
+
     def find_valid_t0s(
         self,
         datasets_dict: dict,
@@ -406,6 +418,8 @@ class SitesDatasetConcurrent(Dataset):
                 site_sample_dict,
                 self.config,
                 t0,
+                self.means_dict,
+                self.stds_dict,
             )
             site_samples.append(site_numpy_sample)
 
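
For the site datasets the function stays module-level, so the precomputed dictionaries are threaded through explicitly instead of being read from `self`. A sketch of the new call shape (the variable names are illustrative):

    means_dict, stds_dict = config_normalization_values_to_dicts(config)  # done once, in __init__
    sample = process_and_combine_datasets(sample_dict, config, t0, means_dict, stds_dict)
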
--- a/ocf_data_sampler/torch_datasets/utils/__init__.py
+++ b/ocf_data_sampler/torch_datasets/utils/__init__.py
@@ -1,4 +1,4 @@
-from .channel_dict_to_dataarray import channel_dict_to_dataarray
+from .config_normalization_values_to_dicts import config_normalization_values_to_dicts
 from .merge_and_fill_utils import fill_nans_in_arrays, merge_dicts
 from .valid_time_periods import find_valid_time_periods
 from .spatial_slice_for_dataset import slice_datasets_by_space
--- /dev/null
+++ b/ocf_data_sampler/torch_datasets/utils/config_normalization_values_to_dicts.py
@@ -0,0 +1,57 @@
+"""Utility function for converting channel dictionaries to xarray DataArrays."""
+
+import xarray as xr
+
+from ocf_data_sampler.config import Configuration
+
+
+def channel_dict_to_dataarray(channel_dict: dict[str, float]) -> xr.DataArray:
+    """Converts a dictionary of channel values to a DataArray.
+
+    Args:
+        channel_dict: Dictionary mapping channel names (str) to their values (float).
+
+    Returns:
+        xr.DataArray: A 1D DataArray with channels as coordinates.
+    """
+    return xr.DataArray(
+        list(channel_dict.values()),
+        coords={"channel": list(channel_dict.keys())},
+    )
+
+def config_normalization_values_to_dicts(
+    config: Configuration,
+) -> tuple[dict[str, xr.DataArray | dict[str, xr.DataArray]]]:
+    """Construct DataArrays of mean and std values from the config normalisation constants.
+
+    Args:
+        config: Data configuration.
+
+    Returns:
+        Means dict
+        Stds dict
+    """
+    means_dict = {}
+    stds_dict = {}
+
+    if config.input_data.nwp is not None:
+
+        means_dict["nwp"] = {}
+        stds_dict["nwp"] = {}
+
+        for nwp_key in config.input_data.nwp:
+            # Standardise and convert to NumpyBatch
+
+            means_dict["nwp"][nwp_key] = channel_dict_to_dataarray(
+                config.input_data.nwp[nwp_key].channel_means,
+            )
+            stds_dict["nwp"][nwp_key] = channel_dict_to_dataarray(
+                config.input_data.nwp[nwp_key].channel_stds,
+            )
+
+    if config.input_data.satellite is not None:
+
+        means_dict["sat"] = channel_dict_to_dataarray(config.input_data.satellite.channel_means)
+        stds_dict["sat"] = channel_dict_to_dataarray(config.input_data.satellite.channel_stds)
+
+    return means_dict, stds_dict
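
A sketch of the shapes this helper produces, with illustrative channel names, values, and NWP key:

    da = channel_dict_to_dataarray({"t2m": 285.1, "dswrf": 294.6})
    # -> 1D xr.DataArray of length 2 with coordinate channel=["t2m", "dswrf"]

    means_dict, stds_dict = config_normalization_values_to_dicts(config)
    # means_dict has a nested shape (keys depend on the config), roughly:
    # {"nwp": {"ukv": <xr.DataArray (channel,)>, ...}, "sat": <xr.DataArray (channel,)>}

    # Normalisation then broadcasts over the channel dimension:
    da_norm = (da_nwp - means_dict["nwp"]["ukv"]) / stds_dict["nwp"]["ukv"]
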
--- a/ocf_data_sampler-0.5.2.dist-info/METADATA
+++ b/ocf_data_sampler-0.5.5.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: ocf-data-sampler
-Version: 0.5.2
+Version: 0.5.5
 Author: James Fulton, Peter Dudfield
 Author-email: Open Climate Fix team <info@openclimatefix.org>
 License: MIT License
@@ -28,14 +28,14 @@ License: MIT License
 Project-URL: repository, https://github.com/openclimatefix/ocf-data-sampler
 Classifier: Programming Language :: Python :: 3
 Classifier: License :: OSI Approved :: MIT License
-Requires-Python: >=3.10
+Requires-Python: >=3.11
 Description-Content-Type: text/markdown
 Requires-Dist: torch
 Requires-Dist: numpy
 Requires-Dist: pandas
 Requires-Dist: xarray
 Requires-Dist: zarr
-Requires-Dist: numcodecs==0.13.1
+Requires-Dist: numcodecs
 Requires-Dist: dask
 Requires-Dist: matplotlib
 Requires-Dist: pvlib
@@ -44,7 +44,8 @@ Requires-Dist: pyproj
 Requires-Dist: pyaml_env
 Requires-Dist: pyresample
 Requires-Dist: h5netcdf
-Requires-Dist: xarray-tensorstore==0.1.5
+Requires-Dist: tensorstore
+Requires-Dist: zarr>=3
 
 # ocf-data-sampler
 
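
In short, the wheel no longer pins xarray-tensorstore==0.1.5 (the vendored copy replaces it) and instead depends on tensorstore and zarr>=3 directly; the numcodecs pin is also dropped and the Python floor moves to 3.11. Upgrading is just:

    pip install --upgrade "ocf-data-sampler==0.5.5"  # requires Python >=3.11
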
--- a/ocf_data_sampler-0.5.2.dist-info/RECORD
+++ b/ocf_data_sampler-0.5.5.dist-info/RECORD
@@ -9,10 +9,11 @@ ocf_data_sampler/data/uk_gsp_locations_20250109.csv,sha256=XZISFatnbpO9j8LwaxNKF
 ocf_data_sampler/load/__init__.py,sha256=-vQP9g0UOWdVbjEGyVX_ipa7R1btmiETIKAf6aw4d78,201
 ocf_data_sampler/load/gsp.py,sha256=d30jQWnwFaLj6rKNMHdz1qD8fzF8q--RNnEXT7bGiX0,2981
 ocf_data_sampler/load/load_dataset.py,sha256=K8rWykjII-3g127If7WRRFivzHNx3SshCvZj4uQlf28,2089
-ocf_data_sampler/load/open_tensorstore_zarrs.py,sha256=_RHWe0GmrBSA9s1TH5I9VCMPpeZEsuRuhDt5Vyyx5Fo,2725
-ocf_data_sampler/load/satellite.py,sha256=RylkJz8avxdM5pK_liaTlD1DTboyPMgykXJ4_Ek9WBA,1840
+ocf_data_sampler/load/open_tensorstore_zarrs.py,sha256=ElXmW7GhYDpsHZr7KjM-KIDNJMc4lmgzVIBwHx5Wl0Q,2748
+ocf_data_sampler/load/satellite.py,sha256=X5ZqFfMgab_WDwI7w1ZmdyMeh3GwV1g7mBd8tFgr8dM,1862
 ocf_data_sampler/load/site.py,sha256=WtOy20VMHJIY0IwEemCdcecSDUGcVaLUown-4ixJw90,2147
 ocf_data_sampler/load/utils.py,sha256=AGL0aOOQPrgqNBTjlBtR7Qg1PyQov3DFJo-y198u8pY,2044
+ocf_data_sampler/load/xarray_tensorstore.py,sha256=DSZl364Hn3QjcVxxPmBKU9rsc5BlJBdzL_SMrv-9os0,10997
 ocf_data_sampler/load/nwp/__init__.py,sha256=SmcrnbygO5xtCKmGR4wtHrj-HI7nOAvnAtfuvRufBGQ,25
 ocf_data_sampler/load/nwp/nwp.py,sha256=0E9shei3Mq1N7F-fBlEKY5Hm0_kI7ysY_rffnWIshvk,3612
 ocf_data_sampler/load/nwp/providers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -21,7 +22,7 @@ ocf_data_sampler/load/nwp/providers/ecmwf.py,sha256=P7JqfssmQq8eHKKXaBexsxts325A
 ocf_data_sampler/load/nwp/providers/gfs.py,sha256=h6vm-Rfz1JGOE4P_fP1_XQJ3bugNbeNAIyt56N8B1Dc,1066
 ocf_data_sampler/load/nwp/providers/icon.py,sha256=iVZwLKRr_D74_kAu5MHir6pRKEfbTmIxFRZAxzmiYdI,1257
 ocf_data_sampler/load/nwp/providers/ukv.py,sha256=2i32VM9gnmWUpbL0qBSp_AKzuyKucXZPS8yklbcGlbc,1039
-ocf_data_sampler/load/nwp/providers/utils.py,sha256=cVwCiC8FqNpkZFSUGv1CRqIQlKdjx1sIsb2SIUlvWV8,2333
+ocf_data_sampler/load/nwp/providers/utils.py,sha256=5LrLmy74AVY5uLwL2qEhy-yPqSYLoxOgN8W1v8FmaQA,2355
 ocf_data_sampler/numpy_sample/__init__.py,sha256=5bdpzM8hMAEe0XRSZ9AZFQdqEeBsEPhaF79Y8bDx3GQ,407
 ocf_data_sampler/numpy_sample/collate.py,sha256=hoxIc5SoHoIs3Nx37aRZzWChpswjy9lHUgaKgHIoo80,2039
 ocf_data_sampler/numpy_sample/common_types.py,sha256=9CjYHkUTx0ObduWh43fhsybZCTXvexql7qC2ptMDoek,377
@@ -40,14 +41,14 @@ ocf_data_sampler/select/location.py,sha256=AZvGR8y62opiW7zACGXjoOtBEWRfSLOZIA73O
 ocf_data_sampler/select/select_spatial_slice.py,sha256=Hd4jGRUfIZRoWCirOQZeoLpaUnStB6KyFSTPX69wZLw,8790
 ocf_data_sampler/select/select_time_slice.py,sha256=HeHbwZ0CP03x0-LaJtpbSdtpLufwVTR73p6wH6O_PS8,5513
 ocf_data_sampler/torch_datasets/datasets/__init__.py,sha256=o0SsEXXZ6k9iL__5_RN1Sf60lw_eqK91P3UFEHAD2k0,102
-ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py,sha256=v63goKEMI6UgBPnQCnIbxhFFdwuP_sxgcPYY6iNfGkc,12257
-ocf_data_sampler/torch_datasets/datasets/site.py,sha256=_0A2kRq8B5WL5zWjKxNY9snAl_GwptohUt7c6DDa2AA,14812
+ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py,sha256=876oLukvb1nLtZQ8HBN3PWfN7urKH2xa45tVar7XrbM,12010
+ocf_data_sampler/torch_datasets/datasets/site.py,sha256=nn6N8daGxllYwCCiFKbCJANTl84NrDRl-nbNGcfXc3U,15429
 ocf_data_sampler/torch_datasets/sample/__init__.py,sha256=GL84vdZl_SjHDGVyh9Uekx2XhPYuZ0dnO3l6f6KXnHI,100
 ocf_data_sampler/torch_datasets/sample/base.py,sha256=cQ1oIyhdmlotejZK8B3Cw6MNvpdnBPD8G_o2h7Ye4Vc,2206
 ocf_data_sampler/torch_datasets/sample/site.py,sha256=40NwNTqjL1WVhPdwe02zDHHfDLG2u_bvCfRCtGAtFc0,1466
 ocf_data_sampler/torch_datasets/sample/uk_regional.py,sha256=Xx5cBYUyaM6PGUWQ76MHT9hwj6IJ7WAOxbpmYFbJGhc,10483
-ocf_data_sampler/torch_datasets/utils/__init__.py,sha256=N7i_hHtWUDiJqsiJoDx4T_QuiYOuvIyulPrn6xEA4TY,309
-ocf_data_sampler/torch_datasets/utils/channel_dict_to_dataarray.py,sha256=un2IiyoAmTDIymdeMiPU899_86iCDMD-oIifjHlNyqw,555
+ocf_data_sampler/torch_datasets/utils/__init__.py,sha256=_UHLL_yRzhLJVHi6ROSaSe8TGw80CAhU325uCZj7XkY,331
+ocf_data_sampler/torch_datasets/utils/config_normalization_values_to_dicts.py,sha256=jS3DkAwOF1W3AQnvsdkBJ1C8Unm93kQbS8hgTCtFv2A,1743
 ocf_data_sampler/torch_datasets/utils/merge_and_fill_utils.py,sha256=we7BTxRH7B7jKayDT7YfNyfI3zZClz2Bk-HXKQIokgU,956
 ocf_data_sampler/torch_datasets/utils/spatial_slice_for_dataset.py,sha256=Hvz0wHSWMYYamf2oHNiGlzJcM4cAH6pL_7ZEvIBL2dE,1882
 ocf_data_sampler/torch_datasets/utils/time_slice_for_dataset.py,sha256=8E4a5v9dqr-sZOyBruuO-tjLPBbjtpYtdFY5z23aqnU,4365
@@ -56,7 +57,7 @@ ocf_data_sampler/torch_datasets/utils/validation_utils.py,sha256=YqmT-lExWlI8_ul
 scripts/download_gsp_location_data.py,sha256=rRDXMoqX-RYY4jPdxhdlxJGhWdl6r245F5UARgKV6P4,3121
 scripts/refactor_site.py,sha256=skzvsPP0Cn9yTKndzkilyNcGz4DZ88ctvCJ0XrBdc2A,3135
 utils/compute_icon_mean_stddev.py,sha256=a1oWMRMnny39rV-dvu8rcx85sb4bXzPFrR1gkUr4Jpg,2296
-ocf_data_sampler-0.5.2.dist-info/METADATA,sha256=mYEZX1GRP6sJoaRs3B5DY5SAFUl1r4OqkgJYXemLzOM,12580
-ocf_data_sampler-0.5.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-ocf_data_sampler-0.5.2.dist-info/top_level.txt,sha256=LEFU4Uk-PEo72QGLAfnVZIUEm37Q8mKuMeg_Xk-p33g,31
-ocf_data_sampler-0.5.2.dist-info/RECORD,,
+ocf_data_sampler-0.5.5.dist-info/METADATA,sha256=R9MPrxfVGCnkBbUehSjd3taDZxeREDo_YaIv5ccqnyg,12581
+ocf_data_sampler-0.5.5.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ocf_data_sampler-0.5.5.dist-info/top_level.txt,sha256=LEFU4Uk-PEo72QGLAfnVZIUEm37Q8mKuMeg_Xk-p33g,31
+ocf_data_sampler-0.5.5.dist-info/RECORD,,
--- a/ocf_data_sampler/torch_datasets/utils/channel_dict_to_dataarray.py
+++ /dev/null
@@ -1,18 +0,0 @@
-"""Utility function for converting channel dictionaries to xarray DataArrays."""
-
-import xarray as xr
-
-
-def channel_dict_to_dataarray(channel_dict: dict[str, float]) -> xr.DataArray:
-    """Converts a dictionary of channel values to a DataArray.
-
-    Args:
-        channel_dict: Dictionary mapping channel names (str) to their values (float).
-
-    Returns:
-        xr.DataArray: A 1D DataArray with channels as coordinates.
-    """
-    return xr.DataArray(
-        list(channel_dict.values()),
-        coords={"channel": list(channel_dict.keys())},
-    )