ocf-data-sampler 0.5.5__py3-none-any.whl → 0.5.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of ocf-data-sampler might be problematic.

@@ -4,8 +4,7 @@ from glob import glob
 
  import xarray as xr
 
- from ocf_data_sampler.load.open_tensorstore_zarrs import open_zarrs
- from ocf_data_sampler.load.xarray_tensorstore import open_zarr
+ from ocf_data_sampler.load.open_xarray_tensorstore import open_zarr, open_zarrs
 
 
  def open_zarr_paths(
@@ -0,0 +1,167 @@
+ """Utilities for loading TensorStore data into Xarray.
+
+ This module uses and adapts internal functions from the Google xarray-tensorstore project [1],
+ licensed under the Apache License, Version 2.0. See [2] for details.
+
+ Modifications copyright 2025 Open Climate Fix. Licensed under the MIT License.
+
+ Modifications from the original include:
+ - Adding support for opening multiple zarr files as a single xarray object
+ - Support for zarr 3 (see https://github.com/google/xarray-tensorstore/pull/22)
+
+ References:
+ [1] https://github.com/google-research/tensorstore/blob/main/tensorstore/xarray.py
+ [2] https://www.apache.org/licenses/LICENSE-2.0
+ """
+
+ import os.path
+ import re
+
+ import tensorstore as ts
+ import xarray as xr
+ import zarr
+ from xarray_tensorstore import (
+     _DEFAULT_STORAGE_DRIVER,
+     _raise_if_mask_and_scale_used_for_data_vars,
+     _TensorStoreAdapter,
+ )
+
+
+ def _zarr_spec_from_path(path: str, zarr_format: int) -> ...:
+     if re.match(r"\w+\://", path):  # path is a URI
+         kv_store = path
+     else:
+         kv_store = {"driver": _DEFAULT_STORAGE_DRIVER, "path": path}
+     return {"driver": f"zarr{zarr_format}", "kvstore": kv_store}
+
+
+ def _get_data_variable_array_futures(
+     path: str,
+     context: ts.Context | None,
+     variables: list[str],
+ ) -> dict[str, ts.Future]:
+     """Open all data variables in a zarr group and return futures.
+
+     Args:
+         path: path or URI to zarr group to open.
+         context: TensorStore configuration options to use when opening arrays.
+         variables: The variables in the zarr group to open.
+     """
+     zarr_format = zarr.open(path).metadata.zarr_format
+     specs = {k: _zarr_spec_from_path(os.path.join(path, k), zarr_format) for k in variables}
+     return {k: ts.open(spec, read=True, write=False, context=context) for k, spec in specs.items()}
+
+
+ def _tensorstore_open_zarrs(
+     paths: list[str],
+     data_vars: list[str],
+     concat_axes: list[int],
+     context: ts.Context,
+ ) -> dict[str, ts.TensorStore]:
+     """Open multiple zarrs with TensorStore.
+
+     Args:
+         paths: List of paths to zarr stores.
+         data_vars: List of data variable names to open.
+         concat_axes: List of axes along which to concatenate the data variables.
+         context: TensorStore context.
+     """
+     # Open all the variables from all the datasets - returned as futures
+     arrays_list: list[dict[str, ts.Future]] = []
+     for path in paths:
+         arrays_list.append(_get_data_variable_array_futures(path, context, data_vars))
+
+     # Wait for the async open operations
+     arrays_list = [{k: v.result() for k, v in arrays.items()} for arrays in arrays_list]
+
+     # Concatenate each of the variables along the required axis
+     arrays = {}
+     for k, axis in zip(data_vars, concat_axes, strict=True):
+         variable_arrays = [d[k] for d in arrays_list]
+         arrays[k] = ts.concat(variable_arrays, axis=axis)
+
+     return arrays
+
+
+ def open_zarr(
+     path: str,
+     context: ts.Context | None = None,
+     mask_and_scale: bool = True,
+ ) -> xr.Dataset:
+     """Open an xarray.Dataset from zarr using TensorStore.
+
+     Args:
+         path: path or URI to zarr group to open.
+         context: TensorStore configuration options to use when opening arrays.
+         mask_and_scale: if True (default), attempt to apply masking and scaling like
+             xarray.open_zarr(). This is only supported for coordinate variables and
+             otherwise will raise an error.
+
+     Returns:
+         Dataset with all data variables opened via TensorStore.
+     """
+     if context is None:
+         context = ts.Context()
+
+     # Avoid using dask by setting `chunks=None`
+     ds = xr.open_zarr(path, chunks=None, mask_and_scale=mask_and_scale)
+
+     if mask_and_scale:
+         _raise_if_mask_and_scale_used_for_data_vars(ds)
+
+     # Open all data variables using tensorstore - returned as futures
+     data_vars = list(ds.data_vars)
+     arrays = _get_data_variable_array_futures(path, context, data_vars)
+
+     # Wait for the async open operations
+     arrays = {k: v.result() for k, v in arrays.items()}
+
+     # Adapt the tensorstore arrays and plug them into the xarray object
+     new_data = {k: _TensorStoreAdapter(v) for k, v in arrays.items()}
+
+     return ds.copy(data=new_data)
+
+
+ def open_zarrs(
+     paths: list[str],
+     concat_dim: str,
+     context: ts.Context | None = None,
+     mask_and_scale: bool = True,
+ ) -> xr.Dataset:
+     """Open multiple zarrs with TensorStore.
+
+     Args:
+         paths: List of paths to zarr stores.
+         concat_dim: Dimension along which to concatenate the data variables.
+         context: TensorStore context.
+         mask_and_scale: Whether to mask and scale the data.
+
+     Returns:
+         Concatenated Dataset with all data variables opened via TensorStore.
+     """
+     if context is None:
+         context = ts.Context()
+
+     ds_list = [xr.open_zarr(p, mask_and_scale=mask_and_scale, decode_timedelta=True) for p in paths]
+     ds = xr.concat(
+         ds_list,
+         dim=concat_dim,
+         data_vars="minimal",
+         compat="equals",
+         combine_attrs="drop_conflicts",
+     )
+
+     if mask_and_scale:
+         _raise_if_mask_and_scale_used_for_data_vars(ds)
+
+     # Find the axis along which each data array must be concatenated
+     data_vars = list(ds.data_vars)
+     concat_axes = [ds[v].dims.index(concat_dim) for v in data_vars]
+
+     # Open and concat all zarrs so each variable is a single TensorStore array
+     arrays = _tensorstore_open_zarrs(paths, data_vars, concat_axes, context)
+
+     # Plug the arrays into the xarray object
+     new_data = {k: _TensorStoreAdapter(v) for k, v in arrays.items()}
+
+     return ds.copy(data=new_data)
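
For reference, a minimal sketch of how the two entry points added above can be called; the store paths here are hypothetical, and the real call sites live in the loaders changed below:

    import tensorstore as ts

    from ocf_data_sampler.load.open_xarray_tensorstore import open_zarr, open_zarrs

    # Single zarr store: data variables come back lazily as TensorStore-backed arrays
    ds = open_zarr("satellite.zarr")  # hypothetical path

    # Several stores concatenated along a shared dimension
    ds_multi = open_zarrs(
        ["sat_2023.zarr", "sat_2024.zarr"],  # hypothetical paths
        concat_dim="time",
        context=ts.Context(),
    )
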
@@ -2,14 +2,12 @@
  import numpy as np
  import xarray as xr
 
+ from ocf_data_sampler.load.open_xarray_tensorstore import open_zarr, open_zarrs
  from ocf_data_sampler.load.utils import (
      check_time_unique_increasing,
      get_xr_data_array_from_xr_dataset,
      make_spatial_coords_increasing,
  )
- from ocf_data_sampler.load.xarray_tensorstore import open_zarr
-
- from .open_tensorstore_zarrs import open_zarrs
 
 
  def open_sat_data(zarr_path: str | list[str]) -> xr.DataArray:
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: ocf-data-sampler
- Version: 0.5.5
+ Version: 0.5.7
  Author: James Fulton, Peter Dudfield
  Author-email: Open Climate Fix team <info@openclimatefix.org>
  License: MIT License
@@ -44,7 +44,7 @@ Requires-Dist: pyproj
  Requires-Dist: pyaml_env
  Requires-Dist: pyresample
  Requires-Dist: h5netcdf
- Requires-Dist: tensorstore
+ Requires-Dist: xarray-tensorstore==0.1.5
  Requires-Dist: zarr>=3
 
  # ocf-data-sampler
@@ -63,6 +63,12 @@ We are currently migrating to this repo from [ocf_datapipes](https://github.com/
  > [!Note]
  > This repository is still in early development and large changes to the user-facing functions may still occur.
 
+ ## Licence
+
+ This project is primarily licensed under the MIT License (see LICENSE).
+
+ It includes and adapts internal functions from the Google xarray-tensorstore project, licensed under the Apache License, Version 2.0.
+
 
  ## Documentation
  **ocf-data-sampler** doesn't have external documentation _yet_; you can read a bit about how our torch datasets work in the README [here](ocf_data_sampler/torch_datasets/README.md).
@@ -9,11 +9,10 @@ ocf_data_sampler/data/uk_gsp_locations_20250109.csv,sha256=XZISFatnbpO9j8LwaxNKF
  ocf_data_sampler/load/__init__.py,sha256=-vQP9g0UOWdVbjEGyVX_ipa7R1btmiETIKAf6aw4d78,201
  ocf_data_sampler/load/gsp.py,sha256=d30jQWnwFaLj6rKNMHdz1qD8fzF8q--RNnEXT7bGiX0,2981
  ocf_data_sampler/load/load_dataset.py,sha256=K8rWykjII-3g127If7WRRFivzHNx3SshCvZj4uQlf28,2089
- ocf_data_sampler/load/open_tensorstore_zarrs.py,sha256=ElXmW7GhYDpsHZr7KjM-KIDNJMc4lmgzVIBwHx5Wl0Q,2748
- ocf_data_sampler/load/satellite.py,sha256=X5ZqFfMgab_WDwI7w1ZmdyMeh3GwV1g7mBd8tFgr8dM,1862
+ ocf_data_sampler/load/open_xarray_tensorstore.py,sha256=kAqlIavGe1dcCPkzAtoZo2dFS-tW36E-wRE_3w1HMfg,5620
+ ocf_data_sampler/load/satellite.py,sha256=B-m0_Py_D0GwzwX5o-ixyeXntV5Z4k4MbmMBHZLUWMM,1831
  ocf_data_sampler/load/site.py,sha256=WtOy20VMHJIY0IwEemCdcecSDUGcVaLUown-4ixJw90,2147
  ocf_data_sampler/load/utils.py,sha256=AGL0aOOQPrgqNBTjlBtR7Qg1PyQov3DFJo-y198u8pY,2044
- ocf_data_sampler/load/xarray_tensorstore.py,sha256=DSZl364Hn3QjcVxxPmBKU9rsc5BlJBdzL_SMrv-9os0,10997
  ocf_data_sampler/load/nwp/__init__.py,sha256=SmcrnbygO5xtCKmGR4wtHrj-HI7nOAvnAtfuvRufBGQ,25
  ocf_data_sampler/load/nwp/nwp.py,sha256=0E9shei3Mq1N7F-fBlEKY5Hm0_kI7ysY_rffnWIshvk,3612
  ocf_data_sampler/load/nwp/providers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -22,7 +21,7 @@ ocf_data_sampler/load/nwp/providers/ecmwf.py,sha256=P7JqfssmQq8eHKKXaBexsxts325A
  ocf_data_sampler/load/nwp/providers/gfs.py,sha256=h6vm-Rfz1JGOE4P_fP1_XQJ3bugNbeNAIyt56N8B1Dc,1066
  ocf_data_sampler/load/nwp/providers/icon.py,sha256=iVZwLKRr_D74_kAu5MHir6pRKEfbTmIxFRZAxzmiYdI,1257
  ocf_data_sampler/load/nwp/providers/ukv.py,sha256=2i32VM9gnmWUpbL0qBSp_AKzuyKucXZPS8yklbcGlbc,1039
- ocf_data_sampler/load/nwp/providers/utils.py,sha256=5LrLmy74AVY5uLwL2qEhy-yPqSYLoxOgN8W1v8FmaQA,2355
+ ocf_data_sampler/load/nwp/providers/utils.py,sha256=IjJ3w7zDgXNFaVa4TMk8yVCvdzfrIRu5tn1OaaQ7Zso,2304
  ocf_data_sampler/numpy_sample/__init__.py,sha256=5bdpzM8hMAEe0XRSZ9AZFQdqEeBsEPhaF79Y8bDx3GQ,407
  ocf_data_sampler/numpy_sample/collate.py,sha256=hoxIc5SoHoIs3Nx37aRZzWChpswjy9lHUgaKgHIoo80,2039
  ocf_data_sampler/numpy_sample/common_types.py,sha256=9CjYHkUTx0ObduWh43fhsybZCTXvexql7qC2ptMDoek,377
@@ -57,7 +56,7 @@ ocf_data_sampler/torch_datasets/utils/validation_utils.py,sha256=YqmT-lExWlI8_ul
  scripts/download_gsp_location_data.py,sha256=rRDXMoqX-RYY4jPdxhdlxJGhWdl6r245F5UARgKV6P4,3121
  scripts/refactor_site.py,sha256=skzvsPP0Cn9yTKndzkilyNcGz4DZ88ctvCJ0XrBdc2A,3135
  utils/compute_icon_mean_stddev.py,sha256=a1oWMRMnny39rV-dvu8rcx85sb4bXzPFrR1gkUr4Jpg,2296
- ocf_data_sampler-0.5.5.dist-info/METADATA,sha256=R9MPrxfVGCnkBbUehSjd3taDZxeREDo_YaIv5ccqnyg,12581
- ocf_data_sampler-0.5.5.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
- ocf_data_sampler-0.5.5.dist-info/top_level.txt,sha256=LEFU4Uk-PEo72QGLAfnVZIUEm37Q8mKuMeg_Xk-p33g,31
- ocf_data_sampler-0.5.5.dist-info/RECORD,,
+ ocf_data_sampler-0.5.7.dist-info/METADATA,sha256=Nu2RLYiLYyU6nkLu8g__Q8EPFIgYMLu5cZLcLXAckXs,12816
+ ocf_data_sampler-0.5.7.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ ocf_data_sampler-0.5.7.dist-info/top_level.txt,sha256=LEFU4Uk-PEo72QGLAfnVZIUEm37Q8mKuMeg_Xk-p33g,31
+ ocf_data_sampler-0.5.7.dist-info/RECORD,,
@@ -1,93 +0,0 @@
- """Open multiple zarrs with TensorStore.
-
- This extends the functionality of xarray_tensorstore to open multiple zarr stores.
- """
-
- import os
-
- import tensorstore as ts
- import xarray as xr
-
- from ocf_data_sampler.load.xarray_tensorstore import (
-     _raise_if_mask_and_scale_used_for_data_vars,
-     _TensorStoreAdapter,
-     _zarr_spec_from_path,
- )
-
-
- def tensorstore_open_multi_zarrs(
-     paths: list[str],
-     data_vars: list[str],
-     concat_axes: list[int],
-     context: ts.Context,
-     write: bool,
- ) -> dict[str, ts.TensorStore]:
-     """Open multiple zarrs with TensorStore.
-
-     Args:
-         paths: List of paths to zarr stores.
-         data_vars: List of data variable names to open.
-         concat_axes: List of axes along which to concatenate the data variables.
-         context: TensorStore context.
-         write: Whether to open the stores for writing.
-     """
-     arrays_list = []
-     for path in paths:
-         specs = {k: _zarr_spec_from_path(os.path.join(path, k)) for k in data_vars}
-         array_futures = {
-             k: ts.open(spec, read=True, write=write, context=context)
-             for k, spec in specs.items()
-         }
-         arrays_list.append({k: v.result() for k, v in array_futures.items()})
-
-     arrays = {}
-     for k, axis in zip(data_vars, concat_axes, strict=False):
-         datasets = [d[k] for d in arrays_list]
-         arrays[k] = ts.concat(datasets, axis=axis)
-
-     return arrays
-
-
- def open_zarrs(
-     paths: list[str],
-     concat_dim: str,
-     *,
-     context: ts.Context | None = None,
-     mask_and_scale: bool = True,
-     write: bool = False,
- ) -> xr.Dataset:
-     """Open multiple zarrs with TensorStore.
-
-     Args:
-         paths: List of paths to zarr stores.
-         concat_dim: Dimension along which to concatenate the data variables.
-         context: TensorStore context.
-         mask_and_scale: Whether to mask and scale the data.
-         write: Whether to open the stores for writing.
-     """
-     if context is None:
-         context = ts.Context()
-
-     ds = xr.open_mfdataset(
-         paths,
-         concat_dim=concat_dim,
-         combine="nested",
-         mask_and_scale=mask_and_scale,
-         decode_timedelta=True,
-     )
-
-     if mask_and_scale:
-         # Data variables get replaced below with _TensorStoreAdapter arrays, which
-         # don't get masked or scaled. Raising an error avoids surprising users with
-         # incorrect data values.
-         _raise_if_mask_and_scale_used_for_data_vars(ds)
-
-     data_vars = list(ds.data_vars)
-
-     concat_axes = [ds[v].dims.index(concat_dim) for v in data_vars]
-
-     arrays = tensorstore_open_multi_zarrs(paths, data_vars, concat_axes, context, write)
-
-     new_data = {k: _TensorStoreAdapter(v) for k, v in arrays.items()}
-
-     return ds.copy(data=new_data)
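
Compared with this deleted module, the replacement open_zarrs shown earlier drops the write flag and builds the combined Dataset with an explicit xr.concat rather than xr.open_mfdataset. A rough sketch of the two approaches, assuming hypothetical stores that differ only along "time":

    import xarray as xr

    paths = ["a.zarr", "b.zarr"]  # hypothetical stores

    # Old approach (this file): nested combine via open_mfdataset
    ds_old = xr.open_mfdataset(
        paths, concat_dim="time", combine="nested",
        mask_and_scale=True, decode_timedelta=True,
    )

    # New approach (open_xarray_tensorstore.open_zarrs): explicit concat
    ds_new = xr.concat(
        [xr.open_zarr(p, mask_and_scale=True, decode_timedelta=True) for p in paths],
        dim="time", data_vars="minimal", compat="equals",
        combine_attrs="drop_conflicts",
    )
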
@@ -1,299 +0,0 @@
- # Copyright 2023 Google LLC
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     https://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """Utilities for loading TensorStore data into Xarray.
-
- Copied from https://github.com/google-research/tensorstore/blob/main/tensorstore/xarray.py,
- with small changes so that it works with zarr 3:
- https://github.com/google/xarray-tensorstore/pull/22
- """
- from __future__ import annotations
-
- import dataclasses
- import math
- import os.path
- import re
- from typing import TypeVar
-
- import numpy as np
- import tensorstore
- import xarray
- import zarr
- from xarray.core import indexing
-
- __version__ = "0.1.5"  # keep in sync with setup.py
-
-
- Index = TypeVar("Index", int, slice, np.ndarray, None)
- XarrayData = TypeVar("XarrayData", xarray.Dataset, xarray.DataArray)
-
-
- def _numpy_to_tensorstore_index(index: Index, size: int) -> Index:
-     """Switch from NumPy to TensorStore indexing conventions."""
-     # https://google.github.io/tensorstore/python/indexing.html#differences-compared-to-numpy-indexing
-     if index is None:
-         return None
-     elif isinstance(index, int):
-         # Negative integers do not count from the end in TensorStore
-         return index + size if index < 0 else index
-     elif isinstance(index, slice):
-         start = _numpy_to_tensorstore_index(index.start, size)
-         stop = _numpy_to_tensorstore_index(index.stop, size)
-         if stop is not None:
-             # TensorStore does not allow out of bounds slicing
-             stop = min(stop, size)
-         return slice(start, stop, index.step)
-     else:
-         assert isinstance(index, np.ndarray)  # noqa S101
-         return np.where(index < 0, index + size, index)
-
-
- @dataclasses.dataclass(frozen=True)
- class _TensorStoreAdapter(indexing.ExplicitlyIndexed):
-     """TensorStore array that can be wrapped by xarray.Variable.
-
-     We use Xarray's semi-internal ExplicitlyIndexed API so that Xarray will not
-     attempt to load our array into memory as a NumPy array. In the future, this
-     should be supported by public Xarray APIs, as part of the refactor discussed
-     in: https://github.com/pydata/xarray/issues/3981
-     """
-
-     array: tensorstore.TensorStore
-     future: tensorstore.Future | None = None
-
-     @property
-     def shape(self) -> tuple[int, ...]:
-         return self.array.shape
-
-     @property
-     def dtype(self) -> np.dtype:
-         return self.array.dtype.numpy_dtype
-
-     @property
-     def ndim(self) -> int:
-         return len(self.shape)
-
-     @property
-     def size(self) -> int:
-         return math.prod(self.shape)
-
-     def __getitem__(self, key: indexing.ExplicitIndexer) -> _TensorStoreAdapter:
-         index_tuple = tuple(map(_numpy_to_tensorstore_index, key.tuple, self.shape))
-         if isinstance(key, indexing.OuterIndexer):
-             # TODO(shoyer): fix this for newer versions of Xarray.
-             # We get the error message:
-             # AttributeError: '_TensorStoreAdapter' object has no attribute 'oindex'
-             indexed = self.array.oindex[index_tuple]
-         elif isinstance(key, indexing.VectorizedIndexer):
-             indexed = self.array.vindex[index_tuple]
-         else:
-             assert isinstance(key, indexing.BasicIndexer)  # noqa S101
-             indexed = self.array[index_tuple]
-         # Translate to the origin so repeated indexing is relative to the new bounds
-         # like NumPy, not absolute like TensorStore
-         translated = indexed[tensorstore.d[:].translate_to[0]]
-         return type(self)(translated)
-
-     def __setitem__(self, key: indexing.ExplicitIndexer, value) -> None:  # noqa ANN001
-         index_tuple = tuple(map(_numpy_to_tensorstore_index, key.tuple, self.shape))
-         if isinstance(key, indexing.OuterIndexer):
-             self.array.oindex[index_tuple] = value
-         elif isinstance(key, indexing.VectorizedIndexer):
-             self.array.vindex[index_tuple] = value
-         else:
-             assert isinstance(key, indexing.BasicIndexer)  # noqa S101
-             self.array[index_tuple] = value
-         # Invalidate the future so that the next read will pick up the new value
-         object.__setattr__(self, "future", None)
-
-     # xarray>2024.02.0 uses oindex and vindex properties, which are expected to
-     # return objects whose __getitem__ method supports the appropriate form of
-     # indexing.
-     @property
-     def oindex(self) -> _TensorStoreAdapter:
-         return self
-
-     @property
-     def vindex(self) -> _TensorStoreAdapter:
-         return self
-
-     def transpose(self, order: tuple[int, ...]) -> _TensorStoreAdapter:
-         transposed = self.array[tensorstore.d[order].transpose[:]]
-         return type(self)(transposed)
-
-     def read(self) -> _TensorStoreAdapter:
-         future = self.array.read()
-         return type(self)(self.array, future)
-
-     def __array__(self, dtype: np.dtype | None = None) -> np.ndarray:  # type: ignore
-         future = self.array.read() if self.future is None else self.future
-         return np.asarray(future.result(), dtype=dtype)
-
-     def get_duck_array(self) -> np.ndarray:
-         # Special method for xarray to return an in-memory (computed) representation
-         return np.asarray(self)
-
-     # Work around the missing __copy__ and __deepcopy__ methods from TensorStore,
-     # which are needed for Xarray:
-     # https://github.com/google/tensorstore/issues/109
-     # TensorStore objects are immutable, so there's no need to actually copy them.
-
-     def __copy__(self) -> _TensorStoreAdapter:
-         return type(self)(self.array, self.future)
-
-     def __deepcopy__(self, memo) -> _TensorStoreAdapter:  # noqa ANN001
-         return self.__copy__()
-
-
- def _read_tensorstore(
-     array: indexing.ExplicitlyIndexed,
- ) -> indexing.ExplicitlyIndexed:
-     """Starts async reading on a TensorStore array."""
-     return array.read() if isinstance(array, _TensorStoreAdapter) else array
-
-
- def read(xarraydata: XarrayData, /) -> XarrayData:
-     """Starts async reads on all TensorStore arrays."""
-     # pylint: disable=protected-access
-     if isinstance(xarraydata, xarray.Dataset):
-         data = {
-             name: _read_tensorstore(var.variable._data)
-             for name, var in xarraydata.data_vars.items()
-         }
-     elif isinstance(xarraydata, xarray.DataArray):
-         data = _read_tensorstore(xarraydata.variable._data)
-     else:
-         raise TypeError(f"argument is not a DataArray or Dataset: {xarraydata}")
-     # pylint: enable=protected-access
-     return xarraydata.copy(data=data)
-
-
- _DEFAULT_STORAGE_DRIVER = "file"
-
-
- def _zarr_spec_from_path(path: str, zarr_format: int) -> ...:
-     if re.match(r"\w+\://", path):  # path is a URI
-         kv_store = path
-     else:
-         kv_store = {"driver": _DEFAULT_STORAGE_DRIVER, "path": path}
-
-     if zarr_format == 2:
-         return {"driver": "zarr2", "kvstore": kv_store}
-     else:
-         return {"driver": "zarr3", "kvstore": kv_store}
-
-
- def _raise_if_mask_and_scale_used_for_data_vars(ds: xarray.Dataset) -> None:
-     """Check a dataset for data variables that would need masking or scaling."""
-     advice = (
-         "Consider re-opening with xarray_tensorstore.open_zarr(..., "
-         "mask_and_scale=False), or falling back to use xarray.open_zarr()."
-     )
-     for k in ds:
-         encoding = ds[k].encoding
-         for attr in ["_FillValue", "missing_value"]:
-             fill_value = encoding.get(attr, np.nan)
-             if fill_value == fill_value:  # pylint: disable=comparison-with-itself
-                 raise ValueError(
-                     f"variable {k} has non-NaN fill value, which is not supported by"
-                     f" xarray-tensorstore: {fill_value}. {advice}",
-                 )
-         for attr in ["scale_factor", "add_offset"]:
-             if attr in encoding:
-                 raise ValueError(
-                     f"variable {k} uses scale/offset encoding, which is not supported"
-                     f" by xarray-tensorstore: {encoding}. {advice}",
-                 )
-
-
- def open_zarr(
-     path: str,
-     *,
-     context: tensorstore.Context | None = None,
-     mask_and_scale: bool = True,
-     write: bool = False,
- ) -> xarray.Dataset:
-     """Open an xarray.Dataset from Zarr using TensorStore.
-
-     For best performance, explicitly call `read()` to asynchronously load data
-     in parallel. Otherwise, xarray's `.compute()` method will load each variable's
-     data in sequence.
-
-     Example usage:
-
-         import xarray_tensorstore
-
-         ds = xarray_tensorstore.open_zarr(path)
-
-         # indexing & transposing is lazy
-         example = ds.sel(time='2020-01-01').transpose('longitude', 'latitude', ...)
-
-         # start reading data asynchronously
-         read_example = xarray_tensorstore.read(example)
-
-         # blocking conversion of the data into NumPy arrays
-         numpy_example = read_example.compute()
-
-     Args:
-         path: path or URI to Zarr group to open.
-         context: TensorStore configuration options to use when opening arrays.
-         mask_and_scale: if True (default), attempt to apply masking and scaling like
-             xarray.open_zarr(). This is only supported for coordinate variables and
-             otherwise will raise an error.
-         write: Allow write access. Defaults to False.
-
-     Returns:
-         Dataset with all data variables opened via TensorStore.
-     """
-     # We use xarray.open_zarr (which uses Zarr Python internally) to open the
-     # initial version of the dataset for a few reasons:
-     # 1. TensorStore does not support Zarr groups or array attributes, which we
-     #    need to open in the xarray.Dataset. We use Zarr Python instead of
-     #    parsing the raw Zarr metadata files ourselves.
-     # 2. TensorStore doesn't support non-standard Zarr dtypes like UTF-8 strings.
-     # 3. Xarray's open_zarr machinery does some pre-processing (e.g., from numeric
-     #    to datetime64 dtypes) that we would otherwise need to invoke explicitly
-     #    via xarray.decode_cf().
-     #
-     # Fortunately (2) and (3) are most commonly encountered on small coordinate
-     # arrays, for which the performance advantages of TensorStore are irrelevant.
-
-     if context is None:
-         context = tensorstore.Context()
-
-     # chunks=None means avoid using dask
-     ds = xarray.open_zarr(path, chunks=None, mask_and_scale=mask_and_scale)
-
-     # Find out whether the store is zarr format 2 or 3
-     try:
-         # This should work with zarr>=3 - https://github.com/zarr-developers/zarr-python
-         zarr_format = zarr.open(path).metadata.zarr_format
-     except:  # noqa E722
-         # If reading the metadata fails, assume zarr format 2
-         zarr_format = 2
-
-     if mask_and_scale:
-         # Data variables get replaced below with _TensorStoreAdapter arrays, which
-         # don't get masked or scaled. Raising an error avoids surprising users with
-         # incorrect data values.
-         _raise_if_mask_and_scale_used_for_data_vars(ds)
-
-     specs = {k: _zarr_spec_from_path(os.path.join(path, k), zarr_format) for k in ds}
-     array_futures = {
-         k: tensorstore.open(spec, read=True, write=write, context=context)
-         for k, spec in specs.items()
-     }
-     arrays = {k: v.result() for k, v in array_futures.items()}
-     new_data = {k: _TensorStoreAdapter(v) for k, v in arrays.items()}
-
-     return ds.copy(data=new_data)
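
The NumPy-to-TensorStore index conversion is the subtle part of this deleted file. A standalone, runnable restatement of its integer and slice rules, following the TensorStore indexing docs linked in the code (the helper name below is ours, not the library's):

    import numpy as np

    def to_tensorstore_index(index, size):
        """Sketch of _numpy_to_tensorstore_index for int, slice, and array indexes."""
        if index is None:
            return None
        if isinstance(index, int):
            # TensorStore has no negative indexing, so shift by the axis size
            return index + size if index < 0 else index
        if isinstance(index, slice):
            start = to_tensorstore_index(index.start, size)
            stop = to_tensorstore_index(index.stop, size)
            if stop is not None:
                # TensorStore rejects out-of-bounds slices, so clamp the stop
                stop = min(stop, size)
            return slice(start, stop, index.step)
        return np.where(index < 0, index + size, index)

    assert to_tensorstore_index(-1, 10) == 9
    assert to_tensorstore_index(slice(5, 100), 10) == slice(5, 10, None)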