anndata 0.12.0rc4__py3-none-any.whl → 0.12.2__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as published to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
@@ -59,7 +59,7 @@ def _gen_dataframe_mapping(
     df = pd.DataFrame(
         anno,
         index=None if length is None else mk_index(length),
-        columns=None if len(anno) else [],
+        columns=None if anno else [],
     )
 
     if length is None:
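
Both the old and new expressions select `columns=[]` for an empty annotation mapping, which pins the frame to zero columns instead of letting pandas infer them; the new form simply tests truthiness rather than calling `len()`. A minimal sketch, assuming `anno` is a plain dict:

    import pandas as pd

    anno = {}  # no annotation columns provided
    # columns=[] pins the column axis to empty; columns=None would let
    # pandas infer the columns from `anno` instead.
    df = pd.DataFrame(anno, index=[str(i) for i in range(3)], columns=None if anno else [])
    assert df.shape == (3, 0)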
anndata/_core/anndata.py CHANGED
@@ -56,13 +56,13 @@ if TYPE_CHECKING:
 
     from zarr.storage import StoreLike
 
-    from ..compat import Index1D, XDataset
+    from ..compat import Index1D, Index1DNorm, XDataset
     from ..typing import XDataType
     from .aligned_mapping import AxisArraysView, LayersView, PairwiseArraysView
     from .index import Index
 
 
-class AnnData(metaclass=utils.DeprecationMixinMeta):
+class AnnData(metaclass=utils.DeprecationMixinMeta):  # noqa: PLW1641
     """\
     An annotated data matrix.
 
@@ -197,6 +197,11 @@ class AnnData(metaclass=utils.DeprecationMixinMeta):
 
     _accessors: ClassVar[set[str]] = set()
 
+    # view attributes
+    _adata_ref: AnnData | None
+    _oidx: Index1DNorm | None
+    _vidx: Index1DNorm | None
+
     @old_positionals(
         "obsm",
         "varm",
@@ -226,8 +231,8 @@ class AnnData(metaclass=utils.DeprecationMixinMeta):
         asview: bool = False,
         obsp: np.ndarray | Mapping[str, Sequence[Any]] | None = None,
         varp: np.ndarray | Mapping[str, Sequence[Any]] | None = None,
-        oidx: Index1D | None = None,
-        vidx: Index1D | None = None,
+        oidx: Index1DNorm | int | np.integer | None = None,
+        vidx: Index1DNorm | int | np.integer | None = None,
     ):
         # check for any multi-indices that aren’t later checked in coerce_array
         for attr, key in [(obs, "obs"), (var, "var"), (X, "X")]:
@@ -237,6 +242,8 @@ class AnnData(metaclass=utils.DeprecationMixinMeta):
             if not isinstance(X, AnnData):
                 msg = "`X` has to be an AnnData object."
                 raise ValueError(msg)
+            assert oidx is not None
+            assert vidx is not None
             self._init_as_view(X, oidx, vidx)
         else:
             self._init_as_actual(
@@ -256,7 +263,12 @@ class AnnData(metaclass=utils.DeprecationMixinMeta):
                 filemode=filemode,
             )
 
-    def _init_as_view(self, adata_ref: AnnData, oidx: Index, vidx: Index):
+    def _init_as_view(
+        self,
+        adata_ref: AnnData,
+        oidx: Index1DNorm | int | np.integer,
+        vidx: Index1DNorm | int | np.integer,
+    ):
         if adata_ref.isbacked and adata_ref.is_view:
             msg = (
                 "Currently, you cannot index repeatedly into a backed AnnData, "
@@ -277,6 +289,9 @@ class AnnData(metaclass=utils.DeprecationMixinMeta):
             vidx += adata_ref.n_vars * (vidx < 0)
             vidx = slice(vidx, vidx + 1, 1)
         if adata_ref.is_view:
+            assert adata_ref._adata_ref is not None
+            assert adata_ref._oidx is not None
+            assert adata_ref._vidx is not None
             prev_oidx, prev_vidx = adata_ref._oidx, adata_ref._vidx
             adata_ref = adata_ref._adata_ref
             oidx, vidx = _resolve_idxs((prev_oidx, prev_vidx), (oidx, vidx), adata_ref)
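
For context, the unchanged lines just above widen a scalar index into a length-1 slice so that a view keeps two dimensions. Roughly, assuming a plain integer index:

    import numpy as np

    def widen(idx: int, n: int) -> slice:
        idx += n * (idx < 0)  # map negative positions into range, as above
        return slice(idx, idx + 1, 1)

    assert widen(-1, 10) == slice(9, 10, 1)  # adata[-1] stays shape (1, n_vars)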
@@ -1004,7 +1019,9 @@ class AnnData(metaclass=utils.DeprecationMixinMeta):
 
         write_attribute(self.file._file, attr, value)
 
-    def _normalize_indices(self, index: Index | None) -> tuple[slice, slice]:
+    def _normalize_indices(
+        self, index: Index | None
+    ) -> tuple[Index1DNorm | int | np.integer, Index1DNorm | int | np.integer]:
         return _normalize_indices(index, self.obs_names, self.var_names)
 
     # TODO: this is not quite complete...
anndata/_core/index.py CHANGED
@@ -14,18 +14,18 @@ from ..compat import AwkArray, CSArray, CSMatrix, DaskArray, XDataArray
 from .xarray import Dataset2D
 
 if TYPE_CHECKING:
-    from ..compat import Index, Index1D
+    from ..compat import Index, Index1D, Index1DNorm
 
 
 def _normalize_indices(
     index: Index | None, names0: pd.Index, names1: pd.Index
-) -> tuple[slice, slice]:
+) -> tuple[Index1DNorm | int | np.integer, Index1DNorm | int | np.integer]:
     # deal with tuples of length 1
     if isinstance(index, tuple) and len(index) == 1:
         index = index[0]
     # deal with pd.Series
     if isinstance(index, pd.Series):
-        index: Index = index.values
+        index = index.values
     if isinstance(index, tuple):
         # TODO: The series should probably be aligned first
         index = tuple(i.values if isinstance(i, pd.Series) else i for i in index)
@@ -36,15 +36,8 @@ def _normalize_indices(
 
 
 def _normalize_index(  # noqa: PLR0911, PLR0912
-    indexer: slice
-    | np.integer
-    | int
-    | str
-    | Sequence[bool | int | np.integer]
-    | np.ndarray
-    | pd.Index,
-    index: pd.Index,
-) -> slice | int | np.ndarray:  # ndarray of int or bool
+    indexer: Index1D, index: pd.Index
+) -> Index1DNorm | int | np.integer:
     # TODO: why is this here? All tests pass without it and it seems at the minimum not strict enough.
     if not isinstance(index, pd.RangeIndex) and index.dtype in (np.float64, np.int64):
         msg = f"Don’t call _normalize_index with non-categorical/string names and non-range index {index}"
@@ -212,7 +205,7 @@ def _subset_awkarray(a: AwkArray, subset_idx: Index):
 
 # Registration for SparseDataset occurs in sparse_dataset.py
 @_subset.register(h5py.Dataset)
-def _subset_dataset(d, subset_idx):
+def _subset_dataset(d: h5py.Dataset, subset_idx: Index):
     if not isinstance(subset_idx, tuple):
         subset_idx = (subset_idx,)
     ordered = list(subset_idx)
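
`_subset_dataset` collects the indexers because h5py fancy indexing requires increasing indices; the usual pattern (illustrative names, not from this diff) is to sort, read, then invert the permutation:

    import numpy as np

    idx = np.array([4, 0, 2])           # unordered fancy index
    order = np.argsort(idx)             # h5py wants increasing indices
    inverse = np.argsort(order)         # undoes the sort after reading
    # data = dset[idx[order]][inverse]  # equivalent to in-memory dset[idx]
    assert np.array_equal(np.arange(5)[idx[order]][inverse], np.arange(5)[idx])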
anndata/_core/merge.py CHANGED
@@ -904,12 +904,6 @@ def concat_arrays( # noqa: PLR0911, PLR0912
             ],
             format="csr",
         )
-        scipy_version = Version(scipy.__version__)
-        # Bug where xstack produces a matrix not an array in 1.11.*
-        if use_sparse_array and (scipy_version.major, scipy_version.minor) == (1, 11):
-            if mat.format == "csc":
-                return sparse.csc_array(mat)
-            return sparse.csr_array(mat)
         return mat
     else:
         return np.concatenate(
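
The deleted block worked around a scipy 1.11 bug in which stacking sparse arrays returned a sparse matrix; on the scipy versions anndata now supports, stacking preserves the array type, so the re-wrap is unnecessary. A quick check, assuming scipy ≥ 1.12:

    import scipy.sparse as sparse

    a = sparse.csr_array([[1.0, 0.0], [0.0, 2.0]])
    b = sparse.csr_array([[3.0, 0.0]])
    mat = sparse.vstack([a, b], format="csr")
    assert isinstance(mat, sparse.csr_array)  # no csr_array(mat) re-wrap needed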
anndata/_core/raw.py CHANGED
@@ -17,7 +17,7 @@ if TYPE_CHECKING:
     from collections.abc import Mapping, Sequence
     from typing import ClassVar
 
-    from ..compat import CSMatrix
+    from ..compat import CSMatrix, Index, Index1DNorm
     from .aligned_mapping import AxisArraysView
     from .anndata import AnnData
     from .sparse_dataset import BaseCompressedSparseDataset
@@ -121,7 +121,7 @@ class Raw:
     def obs_names(self) -> pd.Index[str]:
         return self._adata.obs_names
 
-    def __getitem__(self, index):
+    def __getitem__(self, index: Index) -> Raw:
         oidx, vidx = self._normalize_indices(index)
 
         # To preserve two dimensional shape
@@ -169,7 +169,9 @@ class Raw:
             uns=self._adata.uns.copy(),
         )
 
-    def _normalize_indices(self, packed_index):
+    def _normalize_indices(
+        self, packed_index: Index
+    ) -> tuple[Index1DNorm | int | np.integer, Index1DNorm | int | np.integer]:
         # deal with slicing with pd.Series
         if isinstance(packed_index, pd.Series):
             packed_index = packed_index.values
anndata/_core/sparse_dataset.py CHANGED
@@ -165,7 +165,11 @@ class BackedSparseMatrix(_cs_matrix):
     def _get_contiguous_compressed_slice(
         self, s: slice
     ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
-        new_indptr = self.indptr[s.start : s.stop + 1].copy()
+        new_indptr = self.indptr[s.start : s.stop + 1]
+        # If indptr is cached, we need to make a copy of the subset
+        # so as not to alter the underlying cached data.
+        if isinstance(self.indptr, np.ndarray):
+            new_indptr = new_indptr.copy()
 
         start = new_indptr[0]
         stop = new_indptr[-1]
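
The conditional copy matters because NumPy slicing returns a view into the cached `indptr`; the later in-place adjustment of the slice (below this hunk) would otherwise write through to the cache. A sketch of the hazard:

    import numpy as np

    indptr = np.array([0, 2, 5, 7])  # cached for the whole matrix
    new_indptr = indptr[1:3]         # a NumPy view, not a copy
    new_indptr = new_indptr.copy()   # without this line ...
    new_indptr -= new_indptr[0]      # ... this would also rewrite `indptr`
    assert indptr[1] == 2            # cache left intact

When `indptr` is still an on-disk h5py or zarr dataset, slicing already materializes a fresh array, hence the `isinstance(self.indptr, np.ndarray)` guard.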
anndata/_core/views.py CHANGED
@@ -29,8 +29,12 @@ if TYPE_CHECKING:
     from collections.abc import Callable, Iterable, KeysView, Sequence
     from typing import Any, ClassVar
 
+    from numpy.typing import NDArray
+
     from anndata import AnnData
 
+    from ..compat import Index1DNorm
+
 
 @contextmanager
 def view_update(adata_view: AnnData, attr_name: str, keys: tuple[str, ...]):
@@ -433,18 +437,24 @@ except ImportError:
     pass
 
 
-def _resolve_idxs(old, new, adata):
-    t = tuple(_resolve_idx(old[i], new[i], adata.shape[i]) for i in (0, 1))
-    return t
+def _resolve_idxs(
+    old: tuple[Index1DNorm, Index1DNorm],
+    new: tuple[Index1DNorm, Index1DNorm],
+    adata: AnnData,
+) -> tuple[Index1DNorm, Index1DNorm]:
+    o, v = (_resolve_idx(old[i], new[i], adata.shape[i]) for i in (0, 1))
+    return o, v
 
 
 @singledispatch
-def _resolve_idx(old, new, l):
-    return old[new]
+def _resolve_idx(old: Index1DNorm, new: Index1DNorm, l: Literal[0, 1]) -> Index1DNorm:
+    raise NotImplementedError
 
 
 @_resolve_idx.register(np.ndarray)
-def _resolve_idx_ndarray(old, new, l):
+def _resolve_idx_ndarray(
+    old: NDArray[np.bool_] | NDArray[np.integer], new: Index1DNorm, l: Literal[0, 1]
+) -> NDArray[np.bool_] | NDArray[np.integer]:
     if is_bool_dtype(old) and is_bool_dtype(new):
         mask_new = np.zeros_like(old)
         mask_new[np.flatnonzero(old)[new]] = True
@@ -454,21 +464,17 @@ def _resolve_idx_ndarray(old, new, l):
     return old[new]
 
 
-@_resolve_idx.register(np.integer)
-@_resolve_idx.register(int)
-def _resolve_idx_scalar(old, new, l):
-    return np.array([old])[new]
-
-
 @_resolve_idx.register(slice)
-def _resolve_idx_slice(old, new, l):
+def _resolve_idx_slice(
+    old: slice, new: Index1DNorm, l: Literal[0, 1]
+) -> slice | NDArray[np.integer]:
     if isinstance(new, slice):
         return _resolve_idx_slice_slice(old, new, l)
     else:
         return np.arange(*old.indices(l))[new]
 
 
-def _resolve_idx_slice_slice(old, new, l):
+def _resolve_idx_slice_slice(old: slice, new: slice, l: Literal[0, 1]) -> slice:
     r = range(*old.indices(l))[new]
     # Convert back to slice
     start, stop, step = r.start, r.stop, r.step
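
`_resolve_idx_slice_slice` composes two successive slices by slicing a `range`, which supports slicing natively; a quick check of the idea:

    old, new, n = slice(2, 20, 2), slice(1, 4), 30
    r = range(*old.indices(n))[new]  # range objects can be sliced directly
    composed = slice(r.start, r.stop, r.step)
    assert list(range(n))[composed] == list(range(n))[old][new]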
anndata/_core/xarray.py CHANGED
@@ -184,18 +184,6 @@ class Dataset2D:
         Handler class for doing the iloc-style indexing using :meth:`~xarray.Dataset.isel`.
         """
 
-        @dataclass(frozen=True)
-        class IlocGetter:
-            _ds: XDataset
-            _coord: str
-
-            def __getitem__(self, idx) -> Dataset2D:
-                # xarray seems to have some code looking for a second entry in tuples,
-                # so we unpack the tuple
-                if isinstance(idx, tuple) and len(idx) == 1:
-                    idx = idx[0]
-                return Dataset2D(self._ds.isel(**{self._coord: idx}))
-
         return IlocGetter(self.ds, self.index_dim)
 
     # See https://github.com/pydata/xarray/blob/568f3c1638d2d34373408ce2869028faa3949446/xarray/core/dataset.py#L1239-L1248
@@ -245,7 +233,7 @@ class Dataset2D:
         if df.index.name != index_key and index_key is not None:
             df = df.set_index(index_key)
         for col in set(self.columns) - non_nullable_string_cols:
-            df[col] = pd.array(self[col].data, dtype="string")
+            df[col] = df[col].astype(dtype="string")
         df.index.name = None  # matches old AnnData object
         return df
 
@@ -389,9 +377,12 @@ class Dataset2D:
         }
         el = self.ds.drop_vars(extension_arrays.keys())
         el = el.reindex({index_dim: index}, method=None, fill_value=fill_value)
-        for col in self.ds:
-            el[col] = pd.Series(self.ds[col], index=self.index).reindex(
-                index, fill_value=fill_value
+        for col, data in extension_arrays.items():
+            el[col] = XDataArray.from_series(
+                pd.Series(data.data, index=self.index).reindex(
+                    index.rename(self.index.name) if index is not None else index,
+                    fill_value=fill_value,
+                )
             )
         return Dataset2D(el)
 
@@ -399,3 +390,16 @@ class Dataset2D:
     def _items(self):
         for col in self:
             yield col, self[col]
+
+
+@dataclass(frozen=True)
+class IlocGetter:
+    _ds: XDataset
+    _coord: str
+
+    def __getitem__(self, idx) -> Dataset2D:
+        # xarray seems to have some code looking for a second entry in tuples,
+        # so we unpack the tuple
+        if isinstance(idx, tuple) and len(idx) == 1:
+            idx = idx[0]
+        return Dataset2D(self._ds.isel(**{self._coord: idx}))
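
Hoisting `IlocGetter` to module level means the class is no longer redefined on every access to the `iloc` property; behavior is unchanged. A hypothetical use, assuming `ds2d` is a `Dataset2D`:

    sub = ds2d.iloc[[0, 2, 5]]  # positional selection via xarray's Dataset.isel
    row = ds2d.iloc[0:1]        # result stays a Dataset2D wrapper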
anndata/_io/h5ad.py CHANGED
@@ -4,7 +4,7 @@ import re
 from functools import partial
 from pathlib import Path
 from types import MappingProxyType
-from typing import TYPE_CHECKING, TypeVar
+from typing import TYPE_CHECKING, TypeVar, cast
 from warnings import warn
 
 import h5py
@@ -36,11 +36,12 @@ from .utils import (
 )
 
 if TYPE_CHECKING:
-    from collections.abc import Callable, Collection, Mapping, Sequence
+    from collections.abc import Callable, Collection, Container, Mapping, Sequence
     from os import PathLike
     from typing import Any, Literal
 
     from .._core.file_backing import AnnDataFileManager
+    from .._core.raw import Raw
 
 T = TypeVar("T")
 
@@ -82,29 +83,18 @@ def write_h5ad(
     # TODO: Use spec writing system for this
     # Currently can't use write_dispatched here because this function is also called to do an
     # inplace update of a backed object, which would delete "/"
-    f = f["/"]
+    f = cast("h5py.Group", f["/"])
     f.attrs.setdefault("encoding-type", "anndata")
     f.attrs.setdefault("encoding-version", "0.1.0")
 
-    if "X" in as_dense and isinstance(
-        adata.X, CSMatrix | BaseCompressedSparseDataset
-    ):
-        write_sparse_as_dense(f, "X", adata.X, dataset_kwargs=dataset_kwargs)
-    elif not (adata.isbacked and Path(adata.filename) == Path(filepath)):
-        # If adata.isbacked, X should already be up to date
-        write_elem(f, "X", adata.X, dataset_kwargs=dataset_kwargs)
-    if "raw/X" in as_dense and isinstance(
-        adata.raw.X, CSMatrix | BaseCompressedSparseDataset
-    ):
-        write_sparse_as_dense(
-            f, "raw/X", adata.raw.X, dataset_kwargs=dataset_kwargs
-        )
-        write_elem(f, "raw/var", adata.raw.var, dataset_kwargs=dataset_kwargs)
-        write_elem(
-            f, "raw/varm", dict(adata.raw.varm), dataset_kwargs=dataset_kwargs
-        )
-    elif adata.raw is not None:
-        write_elem(f, "raw", adata.raw, dataset_kwargs=dataset_kwargs)
+    _write_x(
+        f,
+        adata,  # accessing adata.X reopens adata.file if it’s backed
+        is_backed=adata.isbacked and adata.filename == filepath,
+        as_dense=as_dense,
+        dataset_kwargs=dataset_kwargs,
+    )
+    _write_raw(f, adata.raw, as_dense=as_dense, dataset_kwargs=dataset_kwargs)
     write_elem(f, "obs", adata.obs, dataset_kwargs=dataset_kwargs)
     write_elem(f, "var", adata.var, dataset_kwargs=dataset_kwargs)
     write_elem(f, "obsm", dict(adata.obsm), dataset_kwargs=dataset_kwargs)
@@ -115,6 +105,41 @@ def write_h5ad(
     write_elem(f, "uns", dict(adata.uns), dataset_kwargs=dataset_kwargs)
 
 
+def _write_x(
+    f: h5py.Group,
+    adata: AnnData,
+    *,
+    is_backed: bool,
+    as_dense: Container[str],
+    dataset_kwargs: Mapping[str, Any],
+) -> None:
+    if "X" in as_dense and isinstance(adata.X, CSMatrix | BaseCompressedSparseDataset):
+        write_sparse_as_dense(f, "X", adata.X, dataset_kwargs=dataset_kwargs)
+    elif is_backed:
+        pass  # If adata.isbacked, X should already be up to date
+    elif adata.X is None:
+        f.pop("X", None)
+    else:
+        write_elem(f, "X", adata.X, dataset_kwargs=dataset_kwargs)
+
+
+def _write_raw(
+    f: h5py.Group,
+    raw: Raw,
+    *,
+    as_dense: Container[str],
+    dataset_kwargs: Mapping[str, Any],
+) -> None:
+    if "raw/X" in as_dense and isinstance(
+        raw.X, CSMatrix | BaseCompressedSparseDataset
+    ):
+        write_sparse_as_dense(f, "raw/X", raw.X, dataset_kwargs=dataset_kwargs)
+        write_elem(f, "raw/var", raw.var, dataset_kwargs=dataset_kwargs)
+        write_elem(f, "raw/varm", dict(raw.varm), dataset_kwargs=dataset_kwargs)
+    elif raw is not None:
+        write_elem(f, "raw", raw, dataset_kwargs=dataset_kwargs)
+
+
 @report_write_key_on_error
 @write_spec(IOSpec("array", "0.2.0"))
 def write_sparse_as_dense(
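
Beyond splitting `write_h5ad` into `_write_x`/`_write_raw`, one behavior changes: an `X` of `None` now drops any stale `X` group from the file (`f.pop("X", None)`). The `as_dense` path works as before; a hypothetical call:

    import anndata as ad
    from scipy import sparse

    adata = ad.AnnData(X=sparse.random(50, 10, format="csr"))
    adata.write_h5ad("out.h5ad", as_dense=("X",))  # sparse in memory, dense on disk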
@@ -176,7 +201,7 @@ def read_h5ad_backed(
 
 def read_h5ad(
     filename: PathLike[str] | str,
-    backed: Literal["r", "r+"] | bool | None = None,
+    backed: Literal["r", "r+"] | bool | None = None,  # noqa: FBT001
     *,
     as_sparse: Sequence[str] = (),
     as_sparse_fmt: type[CSMatrix] = sparse.csr_matrix,
anndata/_io/read.py CHANGED
@@ -22,9 +22,11 @@ if TYPE_CHECKING:
     from collections.abc import Generator, Iterable, Iterator, Mapping
 
 
+@old_positionals("first_column_names", "dtype")
 def read_csv(
     filename: PathLike[str] | str | Iterator[str],
     delimiter: str | None = ",",
+    *,
     first_column_names: bool | None = None,
     dtype: str = "float32",
 ) -> AnnData:
@@ -46,7 +48,9 @@ def read_csv(
     dtype
         Numpy data type.
     """
-    return read_text(filename, delimiter, first_column_names, dtype)
+    return read_text(
+        filename, delimiter, first_column_names=first_column_names, dtype=dtype
+    )
 
 
 def read_excel(
@@ -331,9 +335,11 @@ def read_mtx(filename: PathLike[str] | str, dtype: str = "float32") -> AnnData:
     return AnnData(X)
 
 
+@old_positionals("first_column_names", "dtype")
 def read_text(
     filename: PathLike[str] | str | Iterator[str],
     delimiter: str | None = None,
+    *,
     first_column_names: bool | None = None,
     dtype: str = "float32",
 ) -> AnnData:
@@ -356,18 +362,26 @@ def read_text(
         Numpy data type.
     """
     if not isinstance(filename, PathLike | str | bytes):
-        return _read_text(filename, delimiter, first_column_names, dtype)
+        return _read_text(
+            filename, delimiter, first_column_names=first_column_names, dtype=dtype
+        )
 
     filename = Path(filename)
     if filename.suffix == ".gz":
         with gzip.open(str(filename), mode="rt") as f:
-            return _read_text(f, delimiter, first_column_names, dtype)
+            return _read_text(
+                f, delimiter, first_column_names=first_column_names, dtype=dtype
+            )
     elif filename.suffix == ".bz2":
         with bz2.open(str(filename), mode="rt") as f:
-            return _read_text(f, delimiter, first_column_names, dtype)
+            return _read_text(
+                f, delimiter, first_column_names=first_column_names, dtype=dtype
+            )
     else:
         with filename.open() as f:
-            return _read_text(f, delimiter, first_column_names, dtype)
+            return _read_text(
+                f, delimiter, first_column_names=first_column_names, dtype=dtype
+            )
 
 
 def _iter_lines(file_like: Iterable[str]) -> Generator[str, None, None]:
@@ -381,6 +395,7 @@ def _iter_lines(file_like: Iterable[str]) -> Generator[str, None, None]:
 def _read_text(  # noqa: PLR0912, PLR0915
     f: Iterator[str],
     delimiter: str | None,
+    *,
     first_column_names: bool | None,
     dtype: str,
 ) -> AnnData:
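
With the `old_positionals` decorator applied, passing `first_column_names` or `dtype` positionally to `read_csv`/`read_text` now triggers a deprecation warning; the keyword form is the supported spelling:

    import anndata as ad

    # ad.read_csv("data.csv", ",", True, "float64")  # deprecated positional form
    adata = ad.read_csv("data.csv", ",", first_column_names=True, dtype="float64")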
@@ -132,7 +132,7 @@ def read_sparse_as_dask(
     path_or_sparse_dataset = (
         Path(filename(elem))
         if isinstance(elem, H5Group)
-        else ad.io.sparse_dataset(elem)
+        else ad.io.sparse_dataset(elem, should_cache_indptr=False)
     )
     elem_name = get_elem_name(elem)
     shape: tuple[int, int] = tuple(elem.attrs["shape"])
@@ -177,21 +177,37 @@ def read_sparse_as_dask(
     return da_mtx
 
 
+def resolve_chunks(
+    elem: H5Array | ZarrArray,
+    chunks_arg: tuple[int, ...] | None,
+    shape: tuple[int, ...],
+) -> tuple[int, ...]:
+    shape = tuple(elem.shape)
+    if chunks_arg is not None:
+        # None and -1 on a given axis indicate that one should use the shape
+        # in `dask`'s semantics.
+        return tuple(
+            c if c not in {None, -1} else s
+            for c, s in zip(chunks_arg, shape, strict=True)
+        )
+    elif elem.chunks is None:  # h5 unchunked
+        return tuple(min(_DEFAULT_STRIDE, s) for s in shape)
+    return elem.chunks
+
+
 @_LAZY_REGISTRY.register_read(H5Array, IOSpec("string-array", "0.2.0"))
 def read_h5_string_array(
     elem: H5Array,
     *,
     _reader: LazyReader,
-    chunks: tuple[int, int] | None = None,
+    chunks: tuple[int] | None = None,
 ) -> DaskArray:
     import dask.array as da
 
     from anndata._io.h5ad import read_dataset
 
-    return da.from_array(
-        read_dataset(elem),
-        chunks=chunks if chunks is not None else (_DEFAULT_STRIDE,) * len(elem.shape),
-    )
+    chunks = resolve_chunks(elem, chunks, tuple(elem.shape))
+    return da.from_array(read_dataset(elem), chunks=chunks)
 
 
 @_LAZY_REGISTRY.register_read(H5Array, IOSpec("array", "0.2.0"))
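
The new `resolve_chunks` helper centralizes chunk selection for the lazy readers. Its semantics, illustrated for an assumed on-disk array of shape (1000, 50):

    # chunks_arg=(100, -1)           -> (100, 50): -1/None expand to the full
    #                                    axis, matching dask's semantics
    # chunks_arg=None, h5 unchunked  -> min(_DEFAULT_STRIDE, size) per axis
    # chunks_arg=None, otherwise     -> elem.chunks, the on-disk chunking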
@@ -204,13 +220,7 @@ def read_h5_array(
     elem_name: str = elem.name
     shape = tuple(elem.shape)
     dtype = elem.dtype
-    chunks = (
-        tuple(
-            c if c not in {None, -1} else s for c, s in zip(chunks, shape, strict=True)
-        )
-        if chunks is not None
-        else tuple(min(_DEFAULT_STRIDE, s) for s in shape)
-    )
+    chunks = resolve_chunks(elem, chunks, shape)
 
     chunk_layout = tuple(
         compute_chunk_layout_for_axis_size(chunks[i], shape[i])
@@ -228,7 +238,6 @@ def read_h5_array(
 def read_zarr_array(
     elem: ZarrArray, *, _reader: LazyReader, chunks: tuple[int, ...] | None = None
 ) -> DaskArray:
-    chunks: tuple[int, ...] = chunks if chunks is not None else elem.chunks
     import dask.array as da
 
     return da.from_zarr(elem, chunks=chunks)
@@ -284,9 +293,10 @@ def read_dataframe(
     *,
     _reader: LazyReader,
     use_range_index: bool = False,
+    chunks: tuple[int] | None = None,
 ) -> Dataset2D:
     elem_dict = {
-        k: _reader.read_elem(elem[k])
+        k: _reader.read_elem(elem[k], chunks=chunks)
         for k in [*elem.attrs["column-order"], elem.attrs["_index"]]
     }
     # If we use a range index, the coord axis needs to have the special dim name