anndata 0.12.1__py3-none-any.whl → 0.12.3__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes and reflects the changes between those versions as they appear in the registry.
anndata/__init__.py CHANGED
@@ -12,7 +12,6 @@ from ._core.extensions import register_anndata_namespace
 from ._core.merge import concat
 from ._core.raw import Raw
 from ._settings import settings
-from ._version import __version__
 from ._warnings import (
     ExperimentalFeatureWarning,
     ImplicitModificationWarning,
@@ -28,22 +27,6 @@ from . import abc, experimental, typing, io, types  # isort: skip
 # We use these in tests by attribute access
 from . import logging  # noqa: F401  # isort: skip

-_DEPRECATED_IO = (
-    "read_loom",
-    "read_hdf",
-    "read_excel",
-    "read_umi_tools",
-    "read_csv",
-    "read_text",
-    "read_mtx",
-)
-_DEPRECATED = {method: f"io.{method}" for method in _DEPRECATED_IO}
-
-
-def __getattr__(attr_name: str) -> Any:
-    return module_get_attr_redirect(attr_name, deprecated_mapping=_DEPRECATED)
-
-
 __all__ = [
     "AnnData",
     "ExperimentalFeatureWarning",
@@ -51,7 +34,6 @@ __all__ = [
     "OldFormatWarning",
     "Raw",
     "WriteWarning",
-    "__version__",
     "abc",
     "concat",
     "experimental",
@@ -63,3 +45,26 @@ __all__ = [
     "types",
     "typing",
 ]
+
+_DEPRECATED_IO = (
+    "read_loom",
+    "read_hdf",
+    "read_excel",
+    "read_umi_tools",
+    "read_csv",
+    "read_text",
+    "read_mtx",
+)
+_DEPRECATED = {method: f"io.{method}" for method in _DEPRECATED_IO}
+
+
+def __getattr__(attr_name: str) -> Any:
+    if attr_name == "__version__":
+        import warnings
+        from importlib.metadata import version
+
+        msg = "`__version__` is deprecated, use `importlib.metadata.version('anndata')` instead."
+        warnings.warn(msg, FutureWarning, stacklevel=2)
+        return version("anndata")
+
+    return module_get_attr_redirect(attr_name, deprecated_mapping=_DEPRECATED)
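
The net effect of this `__init__.py` change: `__version__` is no longer imported eagerly from the generated `_version` module or listed in `__all__`; it is served lazily by the module-level `__getattr__`, which warns and defers to package metadata. A minimal sketch of how downstream code can adapt (assuming anndata is installed):

    from importlib.metadata import version

    # Preferred: read the installed distribution's version from metadata.
    print(version("anndata"))  # e.g. "0.12.3"

    # Legacy attribute access still resolves, but now emits a FutureWarning
    # via the module-level __getattr__ shown above.
    import anndata

    print(anndata.__version__)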
anndata/_core/anndata.py CHANGED
@@ -42,11 +42,7 @@ from .index import _normalize_indices, _subset, get_vector
 from .raw import Raw
 from .sparse_dataset import BaseCompressedSparseDataset, sparse_dataset
 from .storage import coerce_array
-from .views import (
-    DictView,
-    _resolve_idxs,
-    as_view,
-)
+from .views import DictView, _resolve_idxs, as_view
 from .xarray import Dataset2D

 if TYPE_CHECKING:
@@ -56,7 +52,7 @@ if TYPE_CHECKING:

     from zarr.storage import StoreLike

-    from ..compat import Index1D, XDataset
+    from ..compat import Index1D, Index1DNorm, XDataset
     from ..typing import XDataType
     from .aligned_mapping import AxisArraysView, LayersView, PairwiseArraysView
     from .index import Index
@@ -197,6 +193,11 @@ class AnnData(metaclass=utils.DeprecationMixinMeta):  # noqa: PLW1641

     _accessors: ClassVar[set[str]] = set()

+    # view attributes
+    _adata_ref: AnnData | None
+    _oidx: Index1DNorm | None
+    _vidx: Index1DNorm | None
+
     @old_positionals(
         "obsm",
         "varm",
@@ -226,8 +227,8 @@ class AnnData(metaclass=utils.DeprecationMixinMeta):  # noqa: PLW1641
         asview: bool = False,
         obsp: np.ndarray | Mapping[str, Sequence[Any]] | None = None,
         varp: np.ndarray | Mapping[str, Sequence[Any]] | None = None,
-        oidx: Index1D | None = None,
-        vidx: Index1D | None = None,
+        oidx: Index1DNorm | int | np.integer | None = None,
+        vidx: Index1DNorm | int | np.integer | None = None,
     ):
         # check for any multi-indices that aren’t later checked in coerce_array
         for attr, key in [(obs, "obs"), (var, "var"), (X, "X")]:
@@ -237,6 +238,8 @@ class AnnData(metaclass=utils.DeprecationMixinMeta):  # noqa: PLW1641
             if not isinstance(X, AnnData):
                 msg = "`X` has to be an AnnData object."
                 raise ValueError(msg)
+            assert oidx is not None
+            assert vidx is not None
             self._init_as_view(X, oidx, vidx)
         else:
             self._init_as_actual(
@@ -256,7 +259,12 @@ class AnnData(metaclass=utils.DeprecationMixinMeta):  # noqa: PLW1641
                 filemode=filemode,
             )

-    def _init_as_view(self, adata_ref: AnnData, oidx: Index, vidx: Index):
+    def _init_as_view(
+        self,
+        adata_ref: AnnData,
+        oidx: Index1DNorm | int | np.integer,
+        vidx: Index1DNorm | int | np.integer,
+    ):
         if adata_ref.isbacked and adata_ref.is_view:
             msg = (
                 "Currently, you cannot index repeatedly into a backed AnnData, "
@@ -277,6 +285,9 @@ class AnnData(metaclass=utils.DeprecationMixinMeta):  # noqa: PLW1641
             vidx += adata_ref.n_vars * (vidx < 0)
             vidx = slice(vidx, vidx + 1, 1)
         if adata_ref.is_view:
+            assert adata_ref._adata_ref is not None
+            assert adata_ref._oidx is not None
+            assert adata_ref._vidx is not None
             prev_oidx, prev_vidx = adata_ref._oidx, adata_ref._vidx
             adata_ref = adata_ref._adata_ref
             oidx, vidx = _resolve_idxs((prev_oidx, prev_vidx), (oidx, vidx), adata_ref)
@@ -925,22 +936,27 @@ class AnnData(metaclass=utils.DeprecationMixinMeta):  # noqa: PLW1641
         Is sliced with `data` and `var` but behaves otherwise like a :term:`mapping`.
         """

+    @deprecated("obs (e.g. `k in adata.obs` or `str(adata.obs.columns.tolist())`)")
     def obs_keys(self) -> list[str]:
         """List keys of observation annotation :attr:`obs`."""
         return self._obs.keys().tolist()

+    @deprecated("var (e.g. `k in adata.var` or `str(adata.var.columns.tolist())`)")
     def var_keys(self) -> list[str]:
         """List keys of variable annotation :attr:`var`."""
         return self._var.keys().tolist()

+    @deprecated("obsm (e.g. `k in adata.obsm` or `adata.obsm.keys() | {'u'}`)")
     def obsm_keys(self) -> list[str]:
         """List keys of observation annotation :attr:`obsm`."""
         return list(self.obsm.keys())

+    @deprecated("varm (e.g. `k in adata.varm` or `adata.varm.keys() | {'u'}`)")
     def varm_keys(self) -> list[str]:
         """List keys of variable annotation :attr:`varm`."""
         return list(self.varm.keys())

+    @deprecated("uns (e.g. `k in adata.uns` or `sorted(adata.uns)`)")
     def uns_keys(self) -> list[str]:
         """List keys of unstructured annotation."""
         return sorted(self._uns.keys())
@@ -1004,7 +1020,9 @@ class AnnData(metaclass=utils.DeprecationMixinMeta):  # noqa: PLW1641

         write_attribute(self.file._file, attr, value)

-    def _normalize_indices(self, index: Index | None) -> tuple[slice, slice]:
+    def _normalize_indices(
+        self, index: Index | None
+    ) -> tuple[Index1DNorm | int | np.integer, Index1DNorm | int | np.integer]:
         return _normalize_indices(index, self.obs_names, self.var_names)

     # TODO: this is not quite complete...
@@ -1890,8 +1908,8 @@ class AnnData(metaclass=utils.DeprecationMixinMeta):  # noqa: PLW1641
             compression_opts=compression_opts,
             as_dense=as_dense,
         )
-
-        if self.isbacked:
+        # Only reset the filename if the AnnData object now points to a complete new copy
+        if self.isbacked and not self.is_view:
             self.file.filename = filename

     write = write_h5ad  # a shortcut and backwards compat
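
Most of the `anndata.py` changes are typing (the `Index1DNorm` alias and explicitly declared view attributes), plus two behavioral tweaks: `write_h5ad` now rebinds `self.file.filename` only when the object is not a view, and the `*_keys()` helpers are deprecated in favor of direct mapping access. A short sketch of the replacements suggested by the `@deprecated` messages above (the example object is illustrative):

    import anndata as ad
    import numpy as np
    import pandas as pd

    adata = ad.AnnData(
        X=np.zeros((3, 2)),
        obs=pd.DataFrame({"group": list("abc")}, index=[f"cell{i}" for i in range(3)]),
    )

    # adata.obs_keys() now warns; use the DataFrame directly:
    print(adata.obs.columns.tolist())  # ['group']
    print("group" in adata.obs)        # membership test

    # adata.obsm_keys() now warns; the mapping supports keys() and `in`:
    print(list(adata.obsm.keys()))     # []

    # adata.uns_keys() now warns:
    print(sorted(adata.uns))           # []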
anndata/_core/index.py CHANGED
@@ -14,18 +14,18 @@ from ..compat import AwkArray, CSArray, CSMatrix, DaskArray, XDataArray
 from .xarray import Dataset2D

 if TYPE_CHECKING:
-    from ..compat import Index, Index1D
+    from ..compat import Index, Index1D, Index1DNorm


 def _normalize_indices(
     index: Index | None, names0: pd.Index, names1: pd.Index
-) -> tuple[slice, slice]:
+) -> tuple[Index1DNorm | int | np.integer, Index1DNorm | int | np.integer]:
     # deal with tuples of length 1
     if isinstance(index, tuple) and len(index) == 1:
         index = index[0]
     # deal with pd.Series
     if isinstance(index, pd.Series):
-        index: Index = index.values
+        index = index.values
     if isinstance(index, tuple):
         # TODO: The series should probably be aligned first
         index = tuple(i.values if isinstance(i, pd.Series) else i for i in index)
@@ -36,15 +36,8 @@ def _normalize_indices(


 def _normalize_index(  # noqa: PLR0911, PLR0912
-    indexer: slice
-    | np.integer
-    | int
-    | str
-    | Sequence[bool | int | np.integer]
-    | np.ndarray
-    | pd.Index,
-    index: pd.Index,
-) -> slice | int | np.ndarray:  # ndarray of int or bool
+    indexer: Index1D, index: pd.Index
+) -> Index1DNorm | int | np.integer:
     # TODO: why is this here? All tests pass without it and it seems at the minimum not strict enough.
     if not isinstance(index, pd.RangeIndex) and index.dtype in (np.float64, np.int64):
         msg = f"Don’t call _normalize_index with non-categorical/string names and non-range index {index}"
@@ -212,7 +205,7 @@ def _subset_awkarray(a: AwkArray, subset_idx: Index):

 # Registration for SparseDataset occurs in sparse_dataset.py
 @_subset.register(h5py.Dataset)
-def _subset_dataset(d, subset_idx):
+def _subset_dataset(d: h5py.Dataset, subset_idx: Index):
     if not isinstance(subset_idx, tuple):
         subset_idx = (subset_idx,)
     ordered = list(subset_idx)
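
The sprawling inline union in `_normalize_index` collapses into the `Index1D`/`Index1DNorm` aliases, making explicit that normalization maps any accepted 1-D indexer onto a slice, a scalar integer, or an integer/boolean array. A rough, anndata-independent illustration of that mapping in plain pandas (not anndata's actual code path):

    import numpy as np
    import pandas as pd

    names = pd.Index(["a", "b", "c", "d"])

    print(names.get_loc("c"))             # label -> integer position: 2
    print(names.get_indexer(["b", "d"]))  # label sequence -> integer array: [1 3]

    mask = np.array([True, False, True, False])  # boolean masks and slices
    print(np.flatnonzero(mask))                  # already positional: [0 2]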
anndata/_core/merge.py CHANGED
@@ -14,9 +14,7 @@ from warnings import warn

 import numpy as np
 import pandas as pd
-import scipy
 from natsort import natsorted
-from packaging.version import Version
 from scipy import sparse

 from anndata._core.file_backing import to_memory
@@ -30,7 +28,6 @@ from ..compat import (
     CupyCSRMatrix,
     CupySparseMatrix,
     DaskArray,
-    _map_cat_to_str,
 )
 from ..utils import asarray, axis_len, warn_once
 from .anndata import AnnData
@@ -146,11 +143,16 @@ def equal_dask_array(a, b) -> bool:
         return False
     if isinstance(b, DaskArray) and tokenize(a) == tokenize(b):
         return True
-    if isinstance(a._meta, CSMatrix):
+    if isinstance(a._meta, np.ndarray):
+        return da.equal(a, b, where=~(da.isnan(a) & da.isnan(b))).all().compute()
+    if a.chunksize == b.chunksize and isinstance(
+        a._meta, CupySparseMatrix | CSMatrix | CSArray
+    ):
         # TODO: Maybe also do this in the other case?
         return da.map_blocks(equal, a, b, drop_axis=(0, 1)).all()
-    else:
-        return da.equal(a, b, where=~(da.isnan(a) == da.isnan(b))).all()
+    msg = "Misaligned chunks detected when checking for merge equality of dask arrays. Reading full arrays into memory."
+    warn(msg, UserWarning, stacklevel=3)
+    return equal(a.compute(), b.compute())


 @equal.register(np.ndarray)
@@ -185,15 +187,6 @@ def equal_sparse(a, b) -> bool:
            # Comparison broken for CSC matrices
            # https://github.com/cupy/cupy/issues/7757
            a, b = CupyCSRMatrix(a), CupyCSRMatrix(b)
-    if Version(scipy.__version__) >= Version("1.16.0rc1"):
-        # TODO: https://github.com/scipy/scipy/issues/23068
-        return bool(
-            a.format == b.format
-            and (a.shape == b.shape)
-            and np.all(a.indptr == b.indptr)
-            and np.all(a.indices == b.indices)
-            and np.all((a.data == b.data) | (np.isnan(a.data) & np.isnan(b.data)))
-        )
     comp = a != b
     if isinstance(comp, bool):
         return not comp
@@ -617,6 +610,9 @@ class Reindexer:
         sub_el = _subset(el, make_slice(indexer, axis, len(shape)))

         if any(indexer == -1):
+            # TODO: Remove this condition once https://github.com/dask/dask/pull/12078 is released
+            if isinstance(sub_el._meta, CSArray | CSMatrix) and np.isscalar(fill_value):
+                fill_value = np.array([[fill_value]])
             sub_el[make_slice(indexer == -1, axis, len(shape))] = fill_value

         return sub_el
@@ -1643,7 +1639,7 @@ def concat(  # noqa: PLR0912, PLR0913, PLR0915
     )
     if index_unique is not None:
         concat_indices = concat_indices.str.cat(
-            _map_cat_to_str(label_col), sep=index_unique
+            label_col.map(str, na_action="ignore"), sep=index_unique
         )
     concat_indices = pd.Index(concat_indices)

@@ -1748,15 +1744,10 @@ def concat(  # noqa: PLR0912, PLR0913, PLR0915
             for r, a in zip(reindexers, adatas, strict=True)
         ],
     )
-    alt_pairwise = merge(
-        [
-            {
-                k: r(r(v, axis=0), axis=1)
-                for k, v in getattr(a, f"{alt_axis_name}p").items()
-            }
-            for r, a in zip(reindexers, adatas, strict=True)
-        ]
-    )
+    alt_pairwise = merge([
+        {k: r(r(v, axis=0), axis=1) for k, v in getattr(a, f"{alt_axis_name}p").items()}
+        for r, a in zip(reindexers, adatas, strict=True)
+    ])
     uns = uns_merge([a.uns for a in adatas])

     raw = None
@@ -1785,17 +1776,15 @@ def concat(  # noqa: PLR0912, PLR0913, PLR0915
             "not concatenating `.raw` attributes."
         )
         warn(msg, UserWarning, stacklevel=2)
-    return AnnData(
-        **{
-            "X": X,
-            "layers": layers,
-            axis_name: concat_annot,
-            alt_axis_name: alt_annot,
-            f"{axis_name}m": concat_mapping,
-            f"{alt_axis_name}m": alt_mapping,
-            f"{axis_name}p": concat_pairwise,
-            f"{alt_axis_name}p": alt_pairwise,
-            "uns": uns,
-            "raw": raw,
-        }
-    )
+    return AnnData(**{
+        "X": X,
+        "layers": layers,
+        axis_name: concat_annot,
+        alt_axis_name: alt_annot,
+        f"{axis_name}m": concat_mapping,
+        f"{alt_axis_name}m": alt_mapping,
+        f"{axis_name}p": concat_pairwise,
+        f"{alt_axis_name}p": alt_pairwise,
+        "uns": uns,
+        "raw": raw,
+    })
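
Two of the `merge.py` fixes deserve a note. `equal_dask_array` now handles dense metadata with a corrected NaN mask and, for sparse-backed dask arrays with mismatched chunks, falls back to an in-memory comparison with a warning. The old mask `~(da.isnan(a) == da.isnan(b))` selected only positions whose NaN-ness differed, so genuinely different finite values were never compared; the new `~(da.isnan(a) & da.isnan(b))` compares everything except positions where both sides are NaN. A small numpy illustration of the semantics:

    import numpy as np

    a = np.array([1.0, np.nan, 3.0])
    b = np.array([1.0, np.nan, 4.0])

    # New semantics: values must match, and both-NaN counts as a match.
    both_nan = np.isnan(a) & np.isnan(b)
    print(bool(np.all((a == b) | both_nan)))  # False: the 3.0 vs 4.0 mismatch is caught

    # The old mask only selected positions whose NaN-ness differed,
    # so the finite mismatch above was never inspected at all.
    print(~(np.isnan(a) == np.isnan(b)))      # [False False False]

Separately, the `_map_cat_to_str` compat helper is gone: `label_col.map(str, na_action="ignore")` stringifies category labels directly while leaving missing values untouched.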
anndata/_core/raw.py CHANGED
@@ -17,7 +17,7 @@ if TYPE_CHECKING:
     from collections.abc import Mapping, Sequence
     from typing import ClassVar

-    from ..compat import CSMatrix
+    from ..compat import CSMatrix, Index, Index1DNorm
     from .aligned_mapping import AxisArraysView
     from .anndata import AnnData
     from .sparse_dataset import BaseCompressedSparseDataset
@@ -121,7 +121,7 @@ class Raw:
     def obs_names(self) -> pd.Index[str]:
         return self._adata.obs_names

-    def __getitem__(self, index):
+    def __getitem__(self, index: Index) -> Raw:
         oidx, vidx = self._normalize_indices(index)

         # To preserve two dimensional shape
@@ -169,7 +169,9 @@ class Raw:
             uns=self._adata.uns.copy(),
         )

-    def _normalize_indices(self, packed_index):
+    def _normalize_indices(
+        self, packed_index: Index
+    ) -> tuple[Index1DNorm | int | np.integer, Index1DNorm | int | np.integer]:
         # deal with slicing with pd.Series
         if isinstance(packed_index, pd.Series):
             packed_index = packed_index.values
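
The `raw.py` changes are typing only: `Raw.__getitem__` is annotated to accept the same `Index` grammar as `AnnData.__getitem__` and to return a `Raw`. A hedged usage sketch (default string var names assumed):

    import anndata as ad
    import numpy as np

    adata = ad.AnnData(X=np.ones((4, 3)))
    adata.raw = adata                     # snapshot the current state as .raw
    sub = adata.raw[1:3, ["0", "1"]]      # AnnData-style 2-D indexing
    print(type(sub).__name__, sub.shape)  # Raw (2, 2)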
anndata/_core/sparse_dataset.py CHANGED
@@ -16,6 +16,7 @@ import warnings
 from abc import ABC
 from collections.abc import Iterable
 from functools import cached_property
+from importlib.metadata import version
 from itertools import accumulate, chain, pairwise
 from math import floor
 from pathlib import Path
@@ -23,7 +24,6 @@ from typing import TYPE_CHECKING, NamedTuple

 import h5py
 import numpy as np
-import scipy
 import scipy.sparse as ss
 from packaging.version import Version
 from scipy.sparse import _sparsetools
@@ -54,7 +54,7 @@ else:
     from scipy.sparse import spmatrix as _cs_matrix


-SCIPY_1_15 = Version(scipy.__version__) >= Version("1.15rc0")
+SCIPY_1_15 = Version(version("scipy")) >= Version("1.15rc0")


 class BackedFormat(NamedTuple):
@@ -278,9 +278,9 @@
     indptr_slices = [slice(*(x.indptr[i : i + 2])) for i in row_idxs]
     # HDF5 cannot handle out-of-order integer indexing
     if isinstance(x.data, ZarrArray):
-        as_np_indptr = np.concatenate(
-            [np.arange(s.start, s.stop) for s in indptr_slices]
-        )
+        as_np_indptr = np.concatenate([
+            np.arange(s.start, s.stop) for s in indptr_slices
+        ])
         data = x.data[as_np_indptr]
         indices = x.indices[as_np_indptr]
     else:
@@ -309,9 +309,9 @@
     start_indptr = indptr_indices[0] - next(offsets)
     if len(slices) < 2:  # there is only one slice so no need to concatenate
         return data, indices, start_indptr
-    end_indptr = np.concatenate(
-        [s[1:] - o for s, o in zip(indptr_indices[1:], offsets, strict=True)]
-    )
+    end_indptr = np.concatenate([
+        s[1:] - o for s, o in zip(indptr_indices[1:], offsets, strict=True)
+    ])
     indptr = np.concatenate([start_indptr, end_indptr])
     return data, indices, indptr

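Both spellings of the SciPy version gate in `sparse_dataset.py` resolve the installed version; the metadata lookup simply avoids going through the `scipy` module attribute. Since `"1.15rc0"` parses as a pre-release, the gate also passes for final 1.15.x releases:

    from importlib.metadata import version
    from packaging.version import Version

    SCIPY_1_15 = Version(version("scipy")) >= Version("1.15rc0")
    print(SCIPY_1_15)  # True for any scipy >= 1.15.0rc0, including 1.15.0 finals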
anndata/_core/views.py CHANGED
@@ -29,8 +29,12 @@ if TYPE_CHECKING:
     from collections.abc import Callable, Iterable, KeysView, Sequence
     from typing import Any, ClassVar

+    from numpy.typing import NDArray
+
     from anndata import AnnData

+    from ..compat import Index1DNorm
+

 @contextmanager
 def view_update(adata_view: AnnData, attr_name: str, keys: tuple[str, ...]):
@@ -96,7 +100,7 @@ class _ViewMixin(_SetItemMixin):

     # TODO: This makes `deepcopy(obj)` return `obj._view_args.parent._adata_ref`, fix it
     def __deepcopy__(self, memo):
-        parent, attrname, keys = self._view_args
+        parent, attrname, _keys = self._view_args
         return deepcopy(getattr(parent._adata_ref, attrname))


@@ -433,18 +437,24 @@
     pass


-def _resolve_idxs(old, new, adata):
-    t = tuple(_resolve_idx(old[i], new[i], adata.shape[i]) for i in (0, 1))
-    return t
+def _resolve_idxs(
+    old: tuple[Index1DNorm, Index1DNorm],
+    new: tuple[Index1DNorm, Index1DNorm],
+    adata: AnnData,
+) -> tuple[Index1DNorm, Index1DNorm]:
+    o, v = (_resolve_idx(old[i], new[i], adata.shape[i]) for i in (0, 1))
+    return o, v


 @singledispatch
-def _resolve_idx(old, new, l):
-    return old[new]
+def _resolve_idx(old: Index1DNorm, new: Index1DNorm, l: Literal[0, 1]) -> Index1DNorm:
+    raise NotImplementedError


 @_resolve_idx.register(np.ndarray)
-def _resolve_idx_ndarray(old, new, l):
+def _resolve_idx_ndarray(
+    old: NDArray[np.bool_] | NDArray[np.integer], new: Index1DNorm, l: Literal[0, 1]
+) -> NDArray[np.bool_] | NDArray[np.integer]:
     if is_bool_dtype(old) and is_bool_dtype(new):
         mask_new = np.zeros_like(old)
         mask_new[np.flatnonzero(old)[new]] = True
@@ -454,21 +464,17 @@
     return old[new]


-@_resolve_idx.register(np.integer)
-@_resolve_idx.register(int)
-def _resolve_idx_scalar(old, new, l):
-    return np.array([old])[new]
-
-
 @_resolve_idx.register(slice)
-def _resolve_idx_slice(old, new, l):
+def _resolve_idx_slice(
+    old: slice, new: Index1DNorm, l: Literal[0, 1]
+) -> slice | NDArray[np.integer]:
     if isinstance(new, slice):
         return _resolve_idx_slice_slice(old, new, l)
     else:
         return np.arange(*old.indices(l))[new]


-def _resolve_idx_slice_slice(old, new, l):
+def _resolve_idx_slice_slice(old: slice, new: slice, l: Literal[0, 1]) -> slice:
     r = range(*old.indices(l))[new]
     # Convert back to slice
     start, stop, step = r.start, r.stop, r.step
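
The `views.py` hunks pin down the index-resolution helpers that collapse a view of a view into a single view of the root object. For the slice-on-slice case, `_resolve_idx_slice_slice` composes via `range` and converts back to a slice; a standalone illustration of that rule (not anndata's code):

    import numpy as np

    old, new, length = slice(1, 9, 2), slice(1, 3), 10

    r = range(*old.indices(length))[new]       # compose: range(3, 7, 2)
    composed = slice(r.start, r.stop, r.step)  # convert back to a slice

    x = np.arange(length)
    assert np.array_equal(x[old][new], x[composed])
    print(composed)                            # slice(3, 7, 2)

The `_resolve_idx_scalar` registration is dropped, presumably because scalar indices are already widened to length-1 slices in `_init_as_view` (visible in the `anndata.py` context above), leaving the scalar case unreachable.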
anndata/_core/xarray.py CHANGED
@@ -184,18 +184,6 @@
         Handler class for doing the iloc-style indexing using :meth:`~xarray.Dataset.isel`.
         """

-        @dataclass(frozen=True)
-        class IlocGetter:
-            _ds: XDataset
-            _coord: str
-
-            def __getitem__(self, idx) -> Dataset2D:
-                # xarray seems to have some code looking for a second entry in tuples,
-                # so we unpack the tuple
-                if isinstance(idx, tuple) and len(idx) == 1:
-                    idx = idx[0]
-                return Dataset2D(self._ds.isel(**{self._coord: idx}))
-
         return IlocGetter(self.ds, self.index_dim)

     # See https://github.com/pydata/xarray/blob/568f3c1638d2d34373408ce2869028faa3949446/xarray/core/dataset.py#L1239-L1248
@@ -402,3 +390,16 @@
     def _items(self):
         for col in self:
             yield col, self[col]
+
+
+@dataclass(frozen=True)
+class IlocGetter:
+    _ds: XDataset
+    _coord: str
+
+    def __getitem__(self, idx) -> Dataset2D:
+        # xarray seems to have some code looking for a second entry in tuples,
+        # so we unpack the tuple
+        if isinstance(idx, tuple) and len(idx) == 1:
+            idx = idx[0]
+        return Dataset2D(self._ds.isel(**{self._coord: idx}))
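
Finally, `IlocGetter` moves unchanged from the body of the `iloc` property to module scope, so the class is defined once at import time rather than re-created on every property access. A sketch of the positional selection it dispatches to (example data hypothetical):

    import numpy as np
    import xarray as xr

    ds = xr.Dataset(
        {"score": ("obs", np.array([0.1, 0.7, 0.3]))},
        coords={"obs": ["c0", "c1", "c2"]},
    )
    # IlocGetter.__getitem__ boils down to: Dataset2D(ds.isel(**{coord: idx}))
    print(ds.isel(obs=[0, 2]))  # positions 0 and 2, i.e. coords c0 and c2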