anndata 0.12.0rc1__py3-none-any.whl → 0.12.0rc3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47) hide show
  1. anndata/__init__.py +14 -19
  2. anndata/_core/access.py +1 -1
  3. anndata/_core/aligned_df.py +31 -4
  4. anndata/_core/aligned_mapping.py +9 -2
  5. anndata/_core/anndata.py +105 -103
  6. anndata/_core/file_backing.py +6 -0
  7. anndata/_core/index.py +14 -8
  8. anndata/_core/merge.py +229 -98
  9. anndata/_core/raw.py +1 -4
  10. anndata/_core/sparse_dataset.py +16 -17
  11. anndata/_core/storage.py +5 -10
  12. anndata/_core/views.py +13 -7
  13. anndata/_core/xarray.py +145 -0
  14. anndata/_io/__init__.py +3 -3
  15. anndata/_io/h5ad.py +6 -9
  16. anndata/_io/read.py +36 -24
  17. anndata/_io/specs/__init__.py +6 -6
  18. anndata/_io/specs/lazy_methods.py +15 -14
  19. anndata/_io/specs/methods.py +10 -16
  20. anndata/_io/specs/registry.py +9 -8
  21. anndata/_io/utils.py +10 -14
  22. anndata/_io/write.py +11 -14
  23. anndata/_io/zarr.py +15 -16
  24. anndata/_settings.py +6 -2
  25. anndata/_types.py +28 -24
  26. anndata/_version.py +32 -7
  27. anndata/_warnings.py +0 -6
  28. anndata/abc.py +1 -1
  29. anndata/compat/__init__.py +39 -70
  30. anndata/experimental/__init__.py +13 -8
  31. anndata/experimental/backed/_compat.py +1 -26
  32. anndata/experimental/backed/_io.py +13 -11
  33. anndata/experimental/backed/_lazy_arrays.py +11 -10
  34. anndata/experimental/merge.py +11 -11
  35. anndata/experimental/multi_files/_anncollection.py +24 -46
  36. anndata/experimental/pytorch/_annloader.py +1 -5
  37. anndata/io.py +3 -3
  38. anndata/tests/helpers.py +43 -72
  39. anndata/typing.py +3 -2
  40. anndata/utils.py +30 -21
  41. {anndata-0.12.0rc1.dist-info → anndata-0.12.0rc3.dist-info}/METADATA +37 -33
  42. anndata-0.12.0rc3.dist-info/RECORD +57 -0
  43. testing/anndata/_pytest.py +1 -1
  44. anndata/experimental/backed/_xarray.py +0 -146
  45. anndata-0.12.0rc1.dist-info/RECORD +0 -57
  46. {anndata-0.12.0rc1.dist-info → anndata-0.12.0rc3.dist-info}/WHEEL +0 -0
  47. {anndata-0.12.0rc1.dist-info → anndata-0.12.0rc3.dist-info}/licenses/LICENSE +0 -0
anndata/__init__.py CHANGED
@@ -23,10 +23,10 @@ from .io import read_h5ad, read_zarr
23
23
  from .utils import module_get_attr_redirect
24
24
 
25
25
  # Submodules need to be imported last
26
- from . import abc, experimental, typing, io, types # noqa: E402 isort: skip
26
+ from . import abc, experimental, typing, io, types # isort: skip
27
27
 
28
28
  # We use these in tests by attribute access
29
- from . import logging # noqa: F401, E402 isort: skip
29
+ from . import logging # noqa: F401 # isort: skip
30
30
 
31
31
  _DEPRECATED_IO = (
32
32
  "read_loom",
@@ -37,7 +37,7 @@ _DEPRECATED_IO = (
37
37
  "read_text",
38
38
  "read_mtx",
39
39
  )
40
- _DEPRECATED = dict((method, f"io.{method}") for method in _DEPRECATED_IO)
40
+ _DEPRECATED = {method: f"io.{method}" for method in _DEPRECATED_IO}
41
41
 
42
42
 
43
43
  def __getattr__(attr_name: str) -> Any:
@@ -45,26 +45,21 @@ def __getattr__(attr_name: str) -> Any:
45
45
 
46
46
 
47
47
  __all__ = [
48
- # Attributes
48
+ "AnnData",
49
+ "ExperimentalFeatureWarning",
50
+ "ImplicitModificationWarning",
51
+ "OldFormatWarning",
52
+ "Raw",
53
+ "WriteWarning",
49
54
  "__version__",
50
- "settings",
51
- # Submodules
52
55
  "abc",
56
+ "concat",
53
57
  "experimental",
54
- "typing",
55
- "types",
56
58
  "io",
57
- # Classes
58
- "AnnData",
59
- "Raw",
60
- # Functions
61
- "concat",
62
- "read_zarr",
63
59
  "read_h5ad",
60
+ "read_zarr",
64
61
  "register_anndata_namespace",
65
- # Warnings
66
- "OldFormatWarning",
67
- "WriteWarning",
68
- "ImplicitModificationWarning",
69
- "ExperimentalFeatureWarning",
62
+ "settings",
63
+ "types",
64
+ "typing",
70
65
  ]
anndata/_core/access.py CHANGED
@@ -13,7 +13,7 @@ class ElementRef(NamedTuple):
13
13
  keys: tuple[str, ...] = ()
14
14
 
15
15
  def __str__(self) -> str:
16
- return f".{self.attrname}" + "".join(map(lambda x: f"['{x}']", self.keys))
16
+ return f".{self.attrname}" + "".join(f"[{x!r}]" for x in self.keys)
17
17
 
18
18
  @property
19
19
  def _parent_el(self):
@@ -9,6 +9,8 @@ import pandas as pd
9
9
  from pandas.api.types import is_string_dtype
10
10
 
11
11
  from .._warnings import ImplicitModificationWarning
12
+ from ..compat import XDataset
13
+ from .xarray import Dataset2D
12
14
 
13
15
  if TYPE_CHECKING:
14
16
  from collections.abc import Iterable
@@ -50,7 +52,7 @@ def _gen_dataframe_mapping(
50
52
  df = pd.DataFrame(
51
53
  anno,
52
54
  index=anno[index_name],
53
- columns=[k for k in anno.keys() if k != index_name],
55
+ columns=[k for k in anno if k != index_name],
54
56
  )
55
57
  break
56
58
  else:
@@ -80,7 +82,8 @@ def _gen_dataframe_df(
80
82
  raise _mk_df_error(source, attr, length, len(anno))
81
83
  anno = anno.copy(deep=False)
82
84
  if not is_string_dtype(anno.index):
83
- warnings.warn("Transforming to str index.", ImplicitModificationWarning)
85
+ msg = "Transforming to str index."
86
+ warnings.warn(msg, ImplicitModificationWarning, stacklevel=2)
84
87
  anno.index = anno.index.astype(str)
85
88
  if not len(anno.columns):
86
89
  anno.columns = anno.columns.astype(str)
@@ -107,8 +110,8 @@ def _mk_df_error(
107
110
  expected: int,
108
111
  actual: int,
109
112
  ):
113
+ what = "row" if attr == "obs" else "column"
110
114
  if source == "X":
111
- what = "row" if attr == "obs" else "column"
112
115
  msg = (
113
116
  f"Observations annot. `{attr}` must have as many rows as `X` has {what}s "
114
117
  f"({expected}), but has {actual} rows."
@@ -116,6 +119,30 @@ def _mk_df_error(
116
119
  else:
117
120
  msg = (
118
121
  f"`shape` is inconsistent with `{attr}` "
119
- "({actual} {what}s instead of {expected})"
122
+ f"({actual} {what}s instead of {expected})"
120
123
  )
121
124
  return ValueError(msg)
125
+
126
+
127
+ @_gen_dataframe.register(Dataset2D)
128
+ def _gen_dataframe_xr(
129
+ anno: Dataset2D,
130
+ index_names: Iterable[str],
131
+ *,
132
+ source: Literal["X", "shape"],
133
+ attr: Literal["obs", "var"],
134
+ length: int | None = None,
135
+ ):
136
+ return anno
137
+
138
+
139
+ @_gen_dataframe.register(XDataset)
140
+ def _gen_dataframe_xdataset(
141
+ anno: XDataset,
142
+ index_names: Iterable[str],
143
+ *,
144
+ source: Literal["X", "shape"],
145
+ attr: Literal["obs", "var"],
146
+ length: int | None = None,
147
+ ):
148
+ return Dataset2D(anno)
@@ -11,7 +11,7 @@ import numpy as np
11
11
  import pandas as pd
12
12
 
13
13
  from .._warnings import ExperimentalFeatureWarning, ImplicitModificationWarning
14
- from ..compat import AwkArray, CSArray, CSMatrix
14
+ from ..compat import AwkArray, CSArray, CSMatrix, CupyArray, XDataset
15
15
  from ..utils import (
16
16
  axis_len,
17
17
  convert_to_dict,
@@ -23,6 +23,7 @@ from .access import ElementRef
23
23
  from .index import _subset
24
24
  from .storage import coerce_array
25
25
  from .views import as_view, view_update
26
+ from .xarray import Dataset2D
26
27
 
27
28
  if TYPE_CHECKING:
28
29
  from collections.abc import Callable, Iterable, Iterator, Mapping
@@ -75,6 +76,10 @@ class AlignedMappingBase(MutableMapping[str, Value], ABC):
75
76
  ExperimentalFeatureWarning,
76
77
  # stacklevel=3,
77
78
  )
79
+ elif isinstance(val, np.ndarray | CupyArray) and len(val.shape) == 1:
80
+ val = val.reshape((val.shape[0], 1))
81
+ elif isinstance(val, XDataset):
82
+ val = Dataset2D(data_vars=val.data_vars, coords=val.coords, attrs=val.attrs)
78
83
  for i, axis in enumerate(self.axes):
79
84
  if self.parent.shape[axis] == axis_len(val, i):
80
85
  continue
@@ -94,7 +99,6 @@ class AlignedMappingBase(MutableMapping[str, Value], ABC):
94
99
  f"Value had shape {actual_shape} while it should have had {right_shape}."
95
100
  )
96
101
  raise ValueError(msg)
97
-
98
102
  name = f"{self.attrname.title().rstrip('s')} {key!r}"
99
103
  return coerce_array(val, name=name, allow_df=self._allow_df)
100
104
 
@@ -274,6 +278,9 @@ class AxisArraysBase(AlignedMappingBase):
274
278
  else:
275
279
  msg = "Index.equals and pd.testing.assert_index_equal disagree"
276
280
  raise AssertionError(msg)
281
+ val.index.name = (
282
+ self.dim_names.name
283
+ ) # this is consistent with AnnData.obsm.setter and AnnData.varm.setter
277
284
  return super()._validate_value(val, key)
278
285
 
279
286
  @property
anndata/_core/anndata.py CHANGED
@@ -8,7 +8,7 @@ import warnings
8
8
  from collections import OrderedDict
9
9
  from collections.abc import Mapping, MutableMapping, Sequence
10
10
  from copy import copy, deepcopy
11
- from functools import partial, singledispatch
11
+ from functools import partial, singledispatchmethod
12
12
  from pathlib import Path
13
13
  from textwrap import dedent
14
14
  from typing import TYPE_CHECKING, cast
@@ -47,6 +47,7 @@ from .views import (
47
47
  _resolve_idxs,
48
48
  as_view,
49
49
  )
50
+ from .xarray import Dataset2D
50
51
 
51
52
  if TYPE_CHECKING:
52
53
  from collections.abc import Iterable
@@ -55,7 +56,7 @@ if TYPE_CHECKING:
55
56
 
56
57
  from zarr.storage import StoreLike
57
58
 
58
- from ..compat import Index1D
59
+ from ..compat import Index1D, XDataset
59
60
  from ..typing import XDataType
60
61
  from .aligned_mapping import AxisArraysView, LayersView, PairwiseArraysView
61
62
  from .index import Index
@@ -176,10 +177,10 @@ class AnnData(metaclass=utils.DeprecationMixinMeta):
176
177
  .. _scikit-learn: http://scikit-learn.org/
177
178
  """
178
179
 
179
- _BACKED_ATTRS = ["X", "raw.X"]
180
+ _BACKED_ATTRS: ClassVar[list[str]] = ["X", "raw.X"]
180
181
 
181
182
  # backwards compat
182
- _H5_ALIASES = dict(
183
+ _H5_ALIASES: ClassVar[dict[str, set[str]]] = dict(
183
184
  X={"X", "_X", "data", "_data"},
184
185
  obs={"obs", "_obs", "smp", "_smp"},
185
186
  var={"var", "_var"},
@@ -189,7 +190,7 @@ class AnnData(metaclass=utils.DeprecationMixinMeta):
189
190
  layers={"layers", "_layers"},
190
191
  )
191
192
 
192
- _H5_ALIASES_NAMES = dict(
193
+ _H5_ALIASES_NAMES: ClassVar[dict[str, set[str]]] = dict(
193
194
  obs={"obs_names", "smp_names", "row_names", "index"},
194
195
  var={"var_names", "col_names", "index"},
195
196
  )
@@ -207,7 +208,7 @@ class AnnData(metaclass=utils.DeprecationMixinMeta):
207
208
  "filemode",
208
209
  "asview",
209
210
  )
210
- def __init__(
211
+ def __init__( # noqa: PLR0913
211
212
  self,
212
213
  X: XDataType | pd.DataFrame | None = None,
213
214
  obs: pd.DataFrame | Mapping[str, Iterable[Any]] | None = None,
@@ -310,9 +311,10 @@ class AnnData(metaclass=utils.DeprecationMixinMeta):
310
311
  else:
311
312
  self._raw = None
312
313
 
313
- def _init_as_actual(
314
+ def _init_as_actual( # noqa: PLR0912, PLR0913, PLR0915
314
315
  self,
315
316
  X=None,
317
+ *,
316
318
  obs=None,
317
319
  var=None,
318
320
  uns=None,
@@ -390,10 +392,10 @@ class AnnData(metaclass=utils.DeprecationMixinMeta):
390
392
  _check_2d_shape(X)
391
393
  # if type doesn’t match, a copy is made, otherwise, use a view
392
394
  if dtype is not None:
393
- warnings.warn(
394
- "The dtype argument is deprecated and will be removed in late 2024.",
395
- FutureWarning,
395
+ msg = (
396
+ "The dtype argument is deprecated and will be removed in late 2024."
396
397
  )
398
+ warnings.warn(msg, FutureWarning, stacklevel=3)
397
399
  if issparse(X) or isinstance(X, ma.MaskedArray):
398
400
  # TODO: maybe use view on data attribute of sparse matrix
399
401
  # as in readwrite.read_10x_h5
@@ -412,7 +414,9 @@ class AnnData(metaclass=utils.DeprecationMixinMeta):
412
414
  n_obs, n_vars = (
413
415
  shape
414
416
  if shape is not None
415
- else _infer_shape(obs, var, obsm, varm, layers, obsp, varp)
417
+ else _infer_shape(
418
+ obs, var, obsm=obsm, varm=varm, layers=layers, obsp=obsp, varp=varp
419
+ )
416
420
  )
417
421
  source = "shape"
418
422
 
@@ -503,10 +507,7 @@ class AnnData(metaclass=utils.DeprecationMixinMeta):
503
507
  return sum(sizes.values())
504
508
 
505
509
  def _gen_repr(self, n_obs, n_vars) -> str:
506
- if self.isbacked:
507
- backed_at = f" backed at {str(self.filename)!r}"
508
- else:
509
- backed_at = ""
510
+ backed_at = f" backed at {str(self.filename)!r}" if self.isbacked else ""
510
511
  descr = f"AnnData object with n_obs × n_vars = {n_obs} × {n_vars}{backed_at}"
511
512
  for attr in [
512
513
  "obs",
@@ -574,7 +575,7 @@ class AnnData(metaclass=utils.DeprecationMixinMeta):
574
575
  # return X
575
576
 
576
577
  @X.setter
577
- def X(self, value: XDataType | None):
578
+ def X(self, value: XDataType | None): # noqa: PLR0912
578
579
  if value is None:
579
580
  if self.isbacked:
580
581
  msg = "Cannot currently remove data matrix from backed object."
@@ -627,34 +628,33 @@ class AnnData(metaclass=utils.DeprecationMixinMeta):
627
628
  X[oidx, vidx] = value
628
629
  else:
629
630
  self._set_backed("X", value)
630
- else:
631
- if self.is_view:
632
- if sparse.issparse(self._adata_ref._X) and isinstance(
633
- value, np.ndarray
634
- ):
635
- if isinstance(self._adata_ref.X, CSArray):
636
- memory_class = sparse.coo_array
637
- else:
638
- memory_class = sparse.coo_matrix
639
- value = memory_class(value)
640
- elif sparse.issparse(value) and isinstance(
641
- self._adata_ref._X, np.ndarray
642
- ):
643
- warnings.warn(
644
- "Trying to set a dense array with a sparse array on a view."
645
- "Densifying the sparse array."
646
- "This may incur excessive memory usage",
647
- stacklevel=2,
648
- )
649
- value = value.toarray()
631
+ elif self.is_view:
632
+ if sparse.issparse(self._adata_ref._X) and isinstance(
633
+ value, np.ndarray
634
+ ):
635
+ if isinstance(self._adata_ref.X, CSArray):
636
+ memory_class = sparse.coo_array
637
+ else:
638
+ memory_class = sparse.coo_matrix
639
+ value = memory_class(value)
640
+ elif sparse.issparse(value) and isinstance(
641
+ self._adata_ref._X, np.ndarray
642
+ ):
650
643
  warnings.warn(
651
- "Modifying `X` on a view results in data being overridden",
652
- ImplicitModificationWarning,
644
+ "Trying to set a dense array with a sparse array on a view."
645
+ "Densifying the sparse array."
646
+ "This may incur excessive memory usage",
653
647
  stacklevel=2,
654
648
  )
655
- self._adata_ref._X[oidx, vidx] = value
656
- else:
657
- self._X = value
649
+ value = value.toarray()
650
+ warnings.warn(
651
+ "Modifying `X` on a view results in data being overridden",
652
+ ImplicitModificationWarning,
653
+ stacklevel=2,
654
+ )
655
+ self._adata_ref._X[oidx, vidx] = value
656
+ else:
657
+ self._X = value
658
658
  else:
659
659
  msg = f"Data matrix has wrong shape {value.shape}, need to be {self.shape}."
660
660
  raise ValueError(msg)
@@ -747,10 +747,14 @@ class AnnData(metaclass=utils.DeprecationMixinMeta):
747
747
  """Number of variables/features."""
748
748
  return len(self.var_names)
749
749
 
750
- def _set_dim_df(self, value: pd.DataFrame, attr: Literal["obs", "var"]):
751
- if not isinstance(value, pd.DataFrame):
752
- msg = f"Can only assign pd.DataFrame to {attr}."
753
- raise ValueError(msg)
750
+ def _set_dim_df(self, value: pd.DataFrame | XDataset, attr: Literal["obs", "var"]):
751
+ value = _gen_dataframe(
752
+ value,
753
+ [f"{attr}_names", f"{'row' if attr == 'obs' else 'col'}_names"],
754
+ source="shape",
755
+ attr=attr,
756
+ length=self.n_obs if attr == "obs" else self.n_vars,
757
+ )
754
758
  raise_value_error_if_multiindex_columns(value, attr)
755
759
  value_idx = self._prep_dim_index(value.index, attr)
756
760
  if self.is_view:
@@ -805,12 +809,12 @@ class AnnData(metaclass=utils.DeprecationMixinMeta):
805
809
  v.index = value
806
810
 
807
811
  @property
808
- def obs(self) -> pd.DataFrame:
812
+ def obs(self) -> pd.DataFrame | Dataset2D:
809
813
  """One-dimensional annotation of observations (`pd.DataFrame`)."""
810
814
  return self._obs
811
815
 
812
816
  @obs.setter
813
- def obs(self, value: pd.DataFrame):
817
+ def obs(self, value: pd.DataFrame | XDataset):
814
818
  self._set_dim_df(value, "obs")
815
819
 
816
820
  @obs.deleter
@@ -828,12 +832,12 @@ class AnnData(metaclass=utils.DeprecationMixinMeta):
828
832
  self._set_dim_index(names, "obs")
829
833
 
830
834
  @property
831
- def var(self) -> pd.DataFrame:
835
+ def var(self) -> pd.DataFrame | Dataset2D:
832
836
  """One-dimensional annotation of variables/ features (`pd.DataFrame`)."""
833
837
  return self._var
834
838
 
835
839
  @var.setter
836
- def var(self, value: pd.DataFrame):
840
+ def var(self, value: pd.DataFrame | XDataset):
837
841
  self._set_dim_df(value, "var")
838
842
 
839
843
  @var.deleter
@@ -939,7 +943,7 @@ class AnnData(metaclass=utils.DeprecationMixinMeta):
939
943
 
940
944
  def uns_keys(self) -> list[str]:
941
945
  """List keys of unstructured annotation."""
942
- return sorted(list(self._uns.keys()))
946
+ return sorted(self._uns.keys())
943
947
 
944
948
  @property
945
949
  def isbacked(self) -> bool:
@@ -988,10 +992,7 @@ class AnnData(metaclass=utils.DeprecationMixinMeta):
988
992
  else:
989
993
  # change from memory to backing-mode
990
994
  # write the content of self to disk
991
- if self.raw is not None:
992
- as_dense = ("X", "raw/X")
993
- else:
994
- as_dense = ("X",)
995
+ as_dense = ("X", "raw/X") if self.raw is not None else ("X",)
995
996
  self.write(filename, as_dense=as_dense)
996
997
  # open new file for accessing
997
998
  self.file.open(filename, "r+")
@@ -1026,8 +1027,8 @@ class AnnData(metaclass=utils.DeprecationMixinMeta):
1026
1027
  oidx, vidx = self._normalize_indices(index)
1027
1028
  return AnnData(self, oidx=oidx, vidx=vidx, asview=True)
1028
1029
 
1030
+ @singledispatchmethod
1029
1031
  @staticmethod
1030
- @singledispatch
1031
1032
  def _remove_unused_categories(
1032
1033
  df_full: pd.DataFrame, df_sub: pd.DataFrame, uns: dict[str, Any]
1033
1034
  ):
@@ -1129,6 +1130,8 @@ class AnnData(metaclass=utils.DeprecationMixinMeta):
1129
1130
  dont_modify = True
1130
1131
  else:
1131
1132
  dfs = [df]
1133
+ del df
1134
+
1132
1135
  for df in dfs:
1133
1136
  string_cols = [
1134
1137
  key for key in df.columns if infer_dtype(df[key]) == "string"
@@ -1202,10 +1205,7 @@ class AnnData(metaclass=utils.DeprecationMixinMeta):
1202
1205
  """
1203
1206
  from anndata.compat import _safe_transpose
1204
1207
 
1205
- if not self.isbacked:
1206
- X = self.X
1207
- else:
1208
- X = self.file["X"]
1208
+ X = self.X if not self.isbacked else self.file["X"]
1209
1209
  if self.is_view:
1210
1210
  msg = (
1211
1211
  "You’re trying to transpose a view of an `AnnData`, "
@@ -1305,11 +1305,11 @@ class AnnData(metaclass=utils.DeprecationMixinMeta):
1305
1305
  if "X" in self.layers:
1306
1306
  pass
1307
1307
  else:
1308
- warnings.warn(
1308
+ msg = (
1309
1309
  "In a future version of AnnData, access to `.X` by passing"
1310
- " `layer='X'` will be removed. Instead pass `layer=None`.",
1311
- FutureWarning,
1310
+ " `layer='X'` will be removed. Instead pass `layer=None`."
1312
1311
  )
1312
+ warnings.warn(msg, FutureWarning, stacklevel=2)
1313
1313
  layer = None
1314
1314
  return get_vector(self, k, "obs", "var", layer=layer)
1315
1315
 
@@ -1337,11 +1337,11 @@ class AnnData(metaclass=utils.DeprecationMixinMeta):
1337
1337
  if "X" in self.layers:
1338
1338
  pass
1339
1339
  else:
1340
- warnings.warn(
1340
+ msg = (
1341
1341
  "In a future version of AnnData, access to `.X` by passing "
1342
- "`layer='X'` will be removed. Instead pass `layer=None`.",
1343
- FutureWarning,
1342
+ "`layer='X'` will be removed. Instead pass `layer=None`."
1344
1343
  )
1344
+ warnings.warn(msg, FutureWarning, stacklevel=2)
1345
1345
  layer = None
1346
1346
  return get_vector(self, k, "var", "obs", layer=layer)
1347
1347
 
@@ -1369,13 +1369,14 @@ class AnnData(metaclass=utils.DeprecationMixinMeta):
1369
1369
 
1370
1370
  def _mutated_copy(self, **kwargs):
1371
1371
  """Creating AnnData with attributes optionally specified via kwargs."""
1372
- if self.isbacked:
1373
- if "X" not in kwargs or (self.raw is not None and "raw" not in kwargs):
1374
- msg = (
1375
- "This function does not currently handle backed objects "
1376
- "internally, this should be dealt with before."
1377
- )
1378
- raise NotImplementedError(msg)
1372
+ if self.isbacked and (
1373
+ "X" not in kwargs or (self.raw is not None and "raw" not in kwargs)
1374
+ ):
1375
+ msg = (
1376
+ "This function does not currently handle backed objects "
1377
+ "internally, this should be dealt with before."
1378
+ )
1379
+ raise NotImplementedError(msg)
1379
1380
  new = {}
1380
1381
 
1381
1382
  for key in ["obs", "var", "obsm", "varm", "obsp", "varp", "layers"]:
@@ -1481,7 +1482,7 @@ class AnnData(metaclass=utils.DeprecationMixinMeta):
1481
1482
  *adatas: AnnData,
1482
1483
  join: str = "inner",
1483
1484
  batch_key: str = "batch",
1484
- batch_categories: Sequence[Any] = None,
1485
+ batch_categories: Sequence[Any] | None = None,
1485
1486
  uns_merge: str | None = None,
1486
1487
  index_unique: str | None = "-",
1487
1488
  fill_value=None,
@@ -1707,7 +1708,7 @@ class AnnData(metaclass=utils.DeprecationMixinMeta):
1707
1708
  return self.copy()
1708
1709
  elif len(adatas) == 1 and not isinstance(adatas[0], AnnData):
1709
1710
  adatas = adatas[0] # backwards compatibility
1710
- all_adatas = (self,) + tuple(adatas)
1711
+ all_adatas = (self, *adatas)
1711
1712
 
1712
1713
  out = concat(
1713
1714
  all_adatas,
@@ -1779,30 +1780,25 @@ class AnnData(metaclass=utils.DeprecationMixinMeta):
1779
1780
  raise AttributeError(msg)
1780
1781
 
1781
1782
  def _check_dimensions(self, key=None):
1782
- if key is None:
1783
- key = {"obsm", "varm"}
1784
- else:
1785
- key = {key}
1786
- if "obsm" in key:
1787
- if (
1788
- not all([axis_len(o, 0) == self.n_obs for o in self.obsm.values()])
1789
- and len(self.obsm.dim_names) != self.n_obs
1790
- ):
1791
- msg = (
1792
- "Observations annot. `obsm` must have number of rows of `X`"
1793
- f" ({self.n_obs}), but has {len(self.obsm)} rows."
1794
- )
1795
- raise ValueError(msg)
1796
- if "varm" in key:
1797
- if (
1798
- not all([axis_len(v, 0) == self.n_vars for v in self.varm.values()])
1799
- and len(self.varm.dim_names) != self.n_vars
1800
- ):
1801
- msg = (
1802
- "Variables annot. `varm` must have number of columns of `X`"
1803
- f" ({self.n_vars}), but has {len(self.varm)} rows."
1804
- )
1805
- raise ValueError(msg)
1783
+ key = {"obsm", "varm"} if key is None else {key}
1784
+ if "obsm" in key and (
1785
+ not all(axis_len(o, 0) == self.n_obs for o in self.obsm.values())
1786
+ and len(self.obsm.dim_names) != self.n_obs
1787
+ ):
1788
+ msg = (
1789
+ "Observations annot. `obsm` must have number of rows of `X`"
1790
+ f" ({self.n_obs}), but has {len(self.obsm)} rows."
1791
+ )
1792
+ raise ValueError(msg)
1793
+ if "varm" in key and (
1794
+ not all(axis_len(v, 0) == self.n_vars for v in self.varm.values())
1795
+ and len(self.varm.dim_names) != self.n_vars
1796
+ ):
1797
+ msg = (
1798
+ "Variables annot. `varm` must have number of columns of `X`"
1799
+ f" ({self.n_vars}), but has {len(self.varm)} rows."
1800
+ )
1801
+ raise ValueError(msg)
1806
1802
 
1807
1803
  @old_positionals("compression", "compression_opts", "as_dense")
1808
1804
  def write_h5ad(
@@ -2082,15 +2078,20 @@ class AnnData(metaclass=utils.DeprecationMixinMeta):
2082
2078
  m_attr[key] = self._get_and_delete_multicol_field(axis, key)
2083
2079
 
2084
2080
  def _get_and_delete_multicol_field(self, a, key_multicol):
2085
- keys = []
2086
- for k in getattr(self, a).columns:
2087
- if k.startswith(key_multicol):
2088
- keys.append(k)
2081
+ keys = [k for k in getattr(self, a).columns if k.startswith(key_multicol)]
2089
2082
  values = getattr(self, a)[keys].values
2090
2083
  getattr(self, a).drop(keys, axis=1, inplace=True)
2091
2084
  return values
2092
2085
 
2093
2086
 
2087
+ @AnnData._remove_unused_categories.register(Dataset2D)
2088
+ @staticmethod
2089
+ def _remove_unused_categories_xr(
2090
+ df_full: Dataset2D, df_sub: Dataset2D, uns: dict[str, Any]
2091
+ ):
2092
+ pass # this is handled automatically by the categorical arrays themselves i.e., they dedup upon access.
2093
+
2094
+
2094
2095
  def _check_2d_shape(X):
2095
2096
  """\
2096
2097
  Check shape of array or sparse matrix.
@@ -2112,7 +2113,7 @@ def _infer_shape_for_axis(
2112
2113
  for elem in [xxx, xxxm, xxxp]:
2113
2114
  if elem is not None and hasattr(elem, "shape"):
2114
2115
  return elem.shape[0]
2115
- for elem, id in zip([layers, xxxm, xxxp], ["layers", "xxxm", "xxxp"]):
2116
+ for elem, id in zip([layers, xxxm, xxxp], ["layers", "xxxm", "xxxp"], strict=True):
2116
2117
  if elem is not None:
2117
2118
  elem = cast("Mapping", elem)
2118
2119
  for sub_elem in elem.values():
@@ -2125,6 +2126,7 @@ def _infer_shape_for_axis(
2125
2126
  def _infer_shape(
2126
2127
  obs: pd.DataFrame | Mapping[str, Iterable[Any]] | None = None,
2127
2128
  var: pd.DataFrame | Mapping[str, Iterable[Any]] | None = None,
2129
+ *,
2128
2130
  obsm: np.ndarray | Mapping[str, Sequence[Any]] | None = None,
2129
2131
  varm: np.ndarray | Mapping[str, Sequence[Any]] | None = None,
2130
2132
  layers: Mapping[str, np.ndarray | sparse.spmatrix] | None = None,
@@ -10,6 +10,7 @@ import h5py
10
10
 
11
11
  from ..compat import AwkArray, DaskArray, ZarrArray, ZarrGroup
12
12
  from .sparse_dataset import BaseCompressedSparseDataset
13
+ from .xarray import Dataset2D
13
14
 
14
15
  if TYPE_CHECKING:
15
16
  from collections.abc import Iterator
@@ -162,6 +163,11 @@ def _(x: AwkArray, *, copy: bool = False):
162
163
  return x
163
164
 
164
165
 
166
+ @to_memory.register(Dataset2D)
167
+ def _(x: Dataset2D, *, copy: bool = False):
168
+ return x.to_memory(copy=copy)
169
+
170
+
165
171
  @singledispatch
166
172
  def filename(x):
167
173
  msg = f"Not implemented for {type(x)}"