lamindb 0.69.8__py3-none-any.whl → 0.69.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,7 +1,7 @@
  from __future__ import annotations

  from itertools import compress
- from typing import TYPE_CHECKING, Iterable
+ from typing import TYPE_CHECKING, Iterable, Optional

  import anndata as ad
  from anndata import AnnData
@@ -23,7 +23,7 @@ from lamindb.core.storage import LocalPathClasses
  from ._settings import settings

  if TYPE_CHECKING:
-     from lnschema_core.types import AnnDataLike, FieldAttr
+     from lnschema_core.types import FieldAttr

      from lamindb._query_set import QuerySet

@@ -132,10 +132,11 @@ def print_features(self: Data) -> str:


  def parse_feature_sets_from_anndata(
-     adata: AnnDataLike,
-     var_field: FieldAttr,
+     adata: AnnData,
+     var_field: FieldAttr | None = None,
      obs_field: FieldAttr = Feature.name,
-     **kwargs,
+     mute: bool = False,
+     organism: str | Registry | None = None,
  ) -> dict:
      data_parse = adata
      if not isinstance(adata, AnnData):  # is a path
@@ -149,29 +150,36 @@ def parse_feature_sets_from_anndata(
          data_parse = ad.read(filepath, backed="r")
          type = "float"
      else:
-         type = convert_numpy_dtype_to_lamin_feature_type(adata.X.dtype)
+         type = (
+             "float"
+             if adata.X is None
+             else convert_numpy_dtype_to_lamin_feature_type(adata.X.dtype)
+         )
      feature_sets = {}
-     logger.info("parsing feature names of X stored in slot 'var'")
-     logger.indent = " "
-     feature_set_var = FeatureSet.from_values(
-         data_parse.var.index,
-         var_field,
-         type=type,
-         **kwargs,
-     )
-     if feature_set_var is not None:
-         feature_sets["var"] = feature_set_var
-         logger.save(f"linked: {feature_set_var}")
-     logger.indent = ""
-     if feature_set_var is None:
-         logger.warning("skip linking features to artifact in slot 'var'")
+     if var_field is not None:
+         logger.info("parsing feature names of X stored in slot 'var'")
+         logger.indent = " "
+         feature_set_var = FeatureSet.from_values(
+             data_parse.var.index,
+             var_field,
+             type=type,
+             mute=mute,
+             organism=organism,
+         )
+         if feature_set_var is not None:
+             feature_sets["var"] = feature_set_var
+             logger.save(f"linked: {feature_set_var}")
+         logger.indent = ""
+         if feature_set_var is None:
+             logger.warning("skip linking features to artifact in slot 'var'")
      if len(data_parse.obs.columns) > 0:
          logger.info("parsing feature names of slot 'obs'")
          logger.indent = " "
          feature_set_obs = FeatureSet.from_df(
              df=data_parse.obs,
              field=obs_field,
-             **kwargs,
+             mute=mute,
+             organism=organism,
          )
          if feature_set_obs is not None:
              feature_sets["obs"] = feature_set_obs
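The hunk above replaces an opaque `**kwargs` passthrough with explicit `mute` and `organism` parameters, and makes `var_field` optional so that parsing of the `var` slot can be skipped entirely. A minimal, self-contained sketch of that pattern (plain Python, not lamindb itself; all names below are illustrative):

```python
from __future__ import annotations


def parse_feature_sets(
    var_names: list[str],
    var_field: str | None = None,  # None now means: skip the 'var' slot entirely
    mute: bool = False,            # forwarded explicitly instead of via **kwargs
    organism: str | None = None,
) -> dict:
    feature_sets: dict = {}
    if var_field is not None:  # previously this branch ran unconditionally
        feature_sets["var"] = {
            "field": var_field,
            "values": var_names,
            "mute": mute,
            "organism": organism,
        }
    return feature_sets


# the 'var' slot is skipped when no field is given
assert parse_feature_sets(["CD8A"], var_field=None) == {}
```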
@@ -224,7 +232,7 @@ class FeatureManager:
          slot = "columns" if slot is None else slot
          self._add_feature_set(feature_set=FeatureSet(features=features), slot=slot)

-     def add_from_df(self, field: FieldAttr = Feature.name, **kwargs):
+     def add_from_df(self, field: FieldAttr = Feature.name, organism: str | None = None):
          """Add features from DataFrame."""
          if isinstance(self._host, Artifact):
              assert self._host.accessor == "DataFrame"
@@ -235,7 +243,7 @@ class FeatureManager:
          # parse and register features
          registry = field.field.model
          df = self._host.load()
-         features = registry.from_values(df.columns, field=field, **kwargs)
+         features = registry.from_values(df.columns, field=field, organism=organism)
          if len(features) == 0:
              logger.error(
                  "no validated features found in DataFrame! please register features first!"
@@ -252,7 +260,8 @@ class FeatureManager:
          self,
          var_field: FieldAttr,
          obs_field: FieldAttr | None = Feature.name,
-         **kwargs,
+         mute: bool = False,
+         organism: str | Registry | None = None,
      ):
          """Add features from AnnData."""
          if isinstance(self._host, Artifact):
@@ -263,13 +272,53 @@ class FeatureManager:
          # parse and register features
          adata = self._host.load()
          feature_sets = parse_feature_sets_from_anndata(
-             adata, var_field=var_field, obs_field=obs_field, **kwargs
+             adata,
+             var_field=var_field,
+             obs_field=obs_field,
+             mute=mute,
+             organism=organism,
          )

          # link feature sets
          self._host._feature_sets = feature_sets
          self._host.save()

+     def add_from_mudata(
+         self,
+         var_fields: dict[str, FieldAttr],
+         obs_fields: dict[str, FieldAttr] = None,
+         mute: bool = False,
+         organism: str | Registry | None = None,
+     ):
+         """Add features from MuData."""
+         if obs_fields is None:
+             obs_fields = {}
+         if isinstance(self._host, Artifact):
+             assert self._host.accessor == "MuData"
+         else:
+             raise NotImplementedError()
+
+         # parse and register features
+         mdata = self._host.load()
+         feature_sets = {}
+         obs_features = features = Feature.from_values(mdata.obs.columns)
+         if len(obs_features) > 0:
+             feature_sets["obs"] = FeatureSet(features=features)
+         for modality, field in var_fields.items():
+             modality_fs = parse_feature_sets_from_anndata(
+                 mdata[modality],
+                 var_field=field,
+                 obs_field=obs_fields.get(modality, Feature.name),
+                 mute=mute,
+                 organism=organism,
+             )
+             for k, v in modality_fs.items():
+                 feature_sets[f"['{modality}'].{k}"] = v
+
+         # link feature sets
+         self._host._feature_sets = feature_sets
+         self._host.save()
+
      def _add_feature_set(self, feature_set: FeatureSet, slot: str):
          """Add new feature set to a slot.

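The new `add_from_mudata` mirrors `add_from_anndata` per modality and prefixes slots with the modality name (e.g. `['rna'].var`). A hedged usage sketch; the bionty registry fields, modality names, and file path below are illustrative assumptions, not prescribed by this diff:

```python
import bionty as bt
import lamindb as ln

# Hypothetical CITE-seq artifact; path and description are placeholders.
artifact = ln.Artifact("papalexi21_subset.h5mu", description="CITE-seq")
artifact.save()

# One var_field per modality; feature sets land in slots like "['rna'].var".
artifact.features.add_from_mudata(
    var_fields={
        "rna": bt.Gene.ensembl_gene_id,  # assumed bionty registries
        "adt": bt.CellMarker.name,
    },
    organism="human",
)
```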
@@ -49,7 +49,7 @@ def print_labels(self: Data):
          n = labels.count()
          field = get_default_str_field(labels)
          print_values = _print_values(labels.list(field), n=10)
-         labels_msg += f" 🏷️ {related_name} ({n}, {colors.italic(related_model)}): {print_values}\n"
+         labels_msg += f" 📎 {related_name} ({n}, {colors.italic(related_model)}): {print_values}\n"
      if len(labels_msg) > 0:
          return f"{colors.green('Labels')}:\n{labels_msg}"
      else:
@@ -11,7 +11,9 @@ from lamin_utils import logger
  from lamindb_setup.core.upath import UPath

  from .storage._backed_access import (
+     ArrayType,
      ArrayTypes,
+     GroupType,
      GroupTypes,
      StorageType,
      _safer_read_index,
@@ -55,6 +57,12 @@ class MappedCollection:
      If your `AnnData` collection is in the cloud, move them into a local cache
      first for faster access.

+     `__getitem__` of the `MappedCollection` object takes a single integer index
+     and returns a dictionary with the observation data sample for this index from
+     the `AnnData` objects in `path_list`. The dictionary has keys for `layers_keys`
+     (`.X` is in `"X"`), `obs_keys`, `obsm_keys` (under `f"obsm_{key}"`) and also `"_store_idx"`
+     for the index of the `AnnData` object containing this observation sample.
+
      .. note::

          For a guide, see :doc:`docs:scrna5`.
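A sketch of the documented `__getitem__` contract under assumed inputs (the file names are placeholders, and the `lamindb.core` import path is an assumption):

```python
from lamindb.core import MappedCollection

mc = MappedCollection(
    path_list=["batch1.h5ad", "batch2.h5ad"],
    layers_keys=["X", "counts"],  # "X" retrieves .X, other keys .layers[...]
    obs_keys=["cell_type"],
    obsm_keys=["X_pca"],
)
sample = mc[0]
sample["X"]            # row of .X
sample["counts"]       # row of .layers["counts"]
sample["obsm_X_pca"]   # row of .obsm["X_pca"]
sample["cell_type"]    # label value, integer-encoded by default
sample["_store_idx"]   # which AnnData in path_list the row came from
mc.close()
```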
@@ -70,23 +78,28 @@ class MappedCollection:

      Args:
          path_list: A list of paths to `AnnData` objects stored in `.h5ad` or `.zarr` formats.
-         label_keys: Columns of the ``.obs`` slot that store labels.
+         layers_keys: Keys from the ``.layers`` slot. ``layers_keys=None`` or ``"X"`` in the list
+             retrieves ``.X``.
+         obsm_keys: Keys from the ``.obsm`` slots.
+         obs_keys: Keys from the ``.obs`` slots.
          join: `"inner"` or `"outer"` virtual joins. If ``None`` is passed,
              does not join.
          encode_labels: Encode labels into integers.
-             Can be a list with elements from ``label_keys```.
+             Can be a list with elements from ``obs_keys``.
          unknown_label: Encode this label to -1.
-             Can be a dictionary with keys from ``label_keys`` if ``encode_labels=True```
+             Can be a dictionary with keys from ``obs_keys`` if ``encode_labels=True``
              or from ``encode_labels`` if it is a list.
-         cache_categories: Enable caching categories of ``label_keys`` for faster access.
+         cache_categories: Enable caching categories of ``obs_keys`` for faster access.
          parallel: Enable sampling with multiple processes.
-         dtype: Convert numpy arrays from ``.X`` to this dtype on selection.
+         dtype: Convert numpy arrays from ``.X``, ``.layers`` and ``.obsm``
      """

      def __init__(
          self,
          path_list: list[UPathStr],
-         label_keys: str | list[str] | None = None,
+         layers_keys: str | list[str] | None = None,
+         obs_keys: str | list[str] | None = None,
+         obsm_keys: str | list[str] | None = None,
          join: Literal["inner", "outer"] | None = "inner",
          encode_labels: bool | list[str] = True,
          unknown_label: str | dict[str, str] | None = None,
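For callers, the signature change amounts to a rename plus new knobs; a hedged before/after sketch (the paths are placeholders and must point to real `.h5ad` files):

```python
paths = ["batch1.h5ad", "batch2.h5ad"]

# 0.69.8: labels were requested via label_keys
# mc = MappedCollection(paths, label_keys="cell_type")

# 0.69.10: label_keys becomes obs_keys; layers and obsm are now selectable
mc = MappedCollection(
    path_list=paths,
    obs_keys="cell_type",         # was label_keys="cell_type"
    layers_keys="X",              # default is ["X"], i.e. .X only
    encode_labels=["cell_type"],  # must now be a subset of obs_keys
)
```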
@@ -96,27 +109,37 @@ class MappedCollection:
      ):
          assert join in {None, "inner", "outer"}

-         label_keys = [label_keys] if isinstance(label_keys, str) else label_keys
-         self.label_keys = label_keys
+         if layers_keys is None:
+             self.layers_keys = ["X"]
+         else:
+             self.layers_keys = (
+                 [layers_keys] if isinstance(layers_keys, str) else layers_keys
+             )
+
+         obsm_keys = [obsm_keys] if isinstance(obsm_keys, str) else obsm_keys
+         self.obsm_keys = obsm_keys
+
+         obs_keys = [obs_keys] if isinstance(obs_keys, str) else obs_keys
+         self.obs_keys = obs_keys

          if isinstance(encode_labels, list):
              if len(encode_labels) == 0:
                  encode_labels = False
-             elif label_keys is None or not all(
-                 enc_label in label_keys for enc_label in encode_labels
+             elif obs_keys is None or not all(
+                 enc_label in obs_keys for enc_label in encode_labels
              ):
                  raise ValueError(
-                     "All elements of `encode_labels` should be in `label_keys`."
+                     "All elements of `encode_labels` should be in `obs_keys`."
                  )
          else:
              if encode_labels:
-                 encode_labels = label_keys if label_keys is not None else False
+                 encode_labels = obs_keys if obs_keys is not None else False
          self.encode_labels = encode_labels

          if encode_labels and isinstance(unknown_label, dict):
              if not all(unkey in encode_labels for unkey in unknown_label):  # type: ignore
                  raise ValueError(
-                     "All keys of `unknown_label` should be in `encode_labels` and `label_keys`."
+                     "All keys of `unknown_label` should be in `encode_labels` and `obs_keys`."
                  )
          self.unknown_label = unknown_label

@@ -141,12 +164,16 @@ class MappedCollection:

          self.join_vars = join
          self.var_indices = None
+         self.var_joint = None
+         self.n_vars_list = None
+         self.n_vars = None
          if self.join_vars is not None:
              self._make_join_vars()
+             self.n_vars = len(self.var_joint)

-         if self.label_keys is not None:
+         if self.obs_keys is not None:
              if cache_categories:
-                 self._cache_categories(self.label_keys)
+                 self._cache_categories(self.obs_keys)
              else:
                  self._cache_cats: dict = {}
          self.encoders: dict = {}
@@ -169,10 +196,10 @@ class MappedCollection:
              self.conns.append(conn)
              self.storages.append(storage)

-     def _cache_categories(self, label_keys: list):
+     def _cache_categories(self, obs_keys: list):
          self._cache_cats = {}
          decode = np.frompyfunc(lambda x: x.decode("utf-8"), 1, 1)
-         for label in label_keys:
+         for label in obs_keys:
              self._cache_cats[label] = []
              for storage in self.storages:
                  with _Connect(storage) as store:
@@ -197,11 +224,13 @@ class MappedCollection:

      def _make_join_vars(self):
          var_list = []
+         self.n_vars_list = []
          for storage in self.storages:
              with _Connect(storage) as store:
-                 var_list.append(_safer_read_index(store["var"]))
+                 vars = _safer_read_index(store["var"])
+                 var_list.append(vars)
+                 self.n_vars_list.append(len(vars))

-         self.var_joint = None
          vars_eq = all(var_list[0].equals(vrs) for vrs in var_list[1:])
          if vars_eq:
              self.join_vars = None
@@ -223,6 +252,20 @@ class MappedCollection:
      def __len__(self):
          return self.n_obs

+     @property
+     def shape(self):
+         """Shape of the (virtually aligned) dataset."""
+         return (self.n_obs, self.n_vars)
+
+     @property
+     def original_shapes(self):
+         """Shapes of the underlying AnnData objects."""
+         if self.n_vars_list is None:
+             n_vars_list = [None] * len(self.n_obs_list)
+         else:
+             n_vars_list = self.n_vars_list
+         return list(zip(self.n_obs_list, n_vars_list))
+
      def __getitem__(self, idx: int):
          obs_idx = self.indices[idx]
          storage_idx = self.storage_idx[idx]
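Continuing the constructor sketch above, the two new properties expose the joined and per-file dimensions (values purely illustrative):

```python
mc.shape            # e.g. (20000, 36601): n_obs across files, n_vars of the join
mc.original_shapes  # e.g. [(12000, 36601), (8000, 36601)]: one tuple per AnnData
```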
@@ -232,17 +275,28 @@ class MappedCollection:
              var_idxs_join = None

          with _Connect(self.storages[storage_idx]) as store:
-             out = {"x": self._get_data_idx(store, obs_idx, var_idxs_join)}
-             out["_storage_idx"] = storage_idx
-             if self.label_keys is not None:
-                 for label in self.label_keys:
+             out = {}
+             for layers_key in self.layers_keys:
+                 lazy_data = (
+                     store["X"] if layers_key == "X" else store["layers"][layers_key]
+                 )
+                 out[layers_key] = self._get_data_idx(
+                     lazy_data, obs_idx, self.join_vars, var_idxs_join, self.n_vars
+                 )
+             if self.obsm_keys is not None:
+                 for obsm_key in self.obsm_keys:
+                     lazy_data = store["obsm"][obsm_key]
+                     out[f"obsm_{obsm_key}"] = self._get_data_idx(lazy_data, obs_idx)
+             out["_store_idx"] = storage_idx
+             if self.obs_keys is not None:
+                 for label in self.obs_keys:
                      if label in self._cache_cats:
                          cats = self._cache_cats[label][storage_idx]
                          if cats is None:
                              cats = []
                      else:
                          cats = None
-                     label_idx = self._get_label_idx(store, obs_idx, label, cats)
+                     label_idx = self._get_obs_idx(store, obs_idx, label, cats)
                      if label in self.encoders:
                          label_idx = self.encoders[label][label_idx]
                      out[label] = label_idx
@@ -250,46 +304,46 @@ class MappedCollection:

      def _get_data_idx(
          self,
-         storage: StorageType,  # type: ignore
+         lazy_data: ArrayType | GroupType,  # type: ignore
          idx: int,
+         join_vars: Literal["inner", "outer"] | None = None,
          var_idxs_join: list | None = None,
-         layer_key: str | None = None,
+         n_vars_out: int | None = None,
      ):
          """Get the index for the data."""
-         layer = storage["X"] if layer_key is None else storage["layers"][layer_key]  # type: ignore
-         if isinstance(layer, ArrayTypes):  # type: ignore
-             layer_idx = layer[idx]
-             if self.join_vars is None:
-                 result = layer_idx
+         if isinstance(lazy_data, ArrayTypes):  # type: ignore
+             lazy_data_idx = lazy_data[idx]  # type: ignore
+             if join_vars is None:
+                 result = lazy_data_idx
                  if self._dtype is not None:
                      result = result.astype(self._dtype, copy=False)
-             elif self.join_vars == "outer":
-                 dtype = layer_idx.dtype if self._dtype is None else self._dtype
-                 result = np.zeros(len(self.var_joint), dtype=dtype)
-                 result[var_idxs_join] = layer_idx
+             elif join_vars == "outer":
+                 dtype = lazy_data_idx.dtype if self._dtype is None else self._dtype
+                 result = np.zeros(n_vars_out, dtype=dtype)
+                 result[var_idxs_join] = lazy_data_idx
              else:  # inner join
-                 result = layer_idx[var_idxs_join]
+                 result = lazy_data_idx[var_idxs_join]
                  if self._dtype is not None:
                      result = result.astype(self._dtype, copy=False)
              return result
          else:  # assume csr_matrix here
-             data = layer["data"]
-             indices = layer["indices"]
-             indptr = layer["indptr"]
+             data = lazy_data["data"]  # type: ignore
+             indices = lazy_data["indices"]  # type: ignore
+             indptr = lazy_data["indptr"]  # type: ignore
              s = slice(*(indptr[idx : idx + 2]))
              data_s = data[s]
              dtype = data_s.dtype if self._dtype is None else self._dtype
-             if self.join_vars == "outer":
-                 layer_idx = np.zeros(len(self.var_joint), dtype=dtype)
-                 layer_idx[var_idxs_join[indices[s]]] = data_s
+             if join_vars == "outer":
+                 lazy_data_idx = np.zeros(n_vars_out, dtype=dtype)
+                 lazy_data_idx[var_idxs_join[indices[s]]] = data_s
              else:
-                 layer_idx = np.zeros(layer.attrs["shape"][1], dtype=dtype)
-                 layer_idx[indices[s]] = data_s
-                 if self.join_vars == "inner":
-                     layer_idx = layer_idx[var_idxs_join]
-             return layer_idx
+                 lazy_data_idx = np.zeros(lazy_data.attrs["shape"][1], dtype=dtype)  # type: ignore
+                 lazy_data_idx[indices[s]] = data_s
+                 if join_vars == "inner":
+                     lazy_data_idx = lazy_data_idx[var_idxs_join]
+             return lazy_data_idx

-     def _get_label_idx(
+     def _get_obs_idx(
          self,
          storage: StorageType,
          idx: int,
317
371
  label = label.decode("utf-8")
318
372
  return label
319
373
 
320
- def get_label_weights(self, label_keys: str | list[str]):
374
+ def get_label_weights(self, obs_keys: str | list[str]):
321
375
  """Get all weights for the given label keys."""
322
- if isinstance(label_keys, str):
323
- label_keys = [label_keys]
376
+ if isinstance(obs_keys, str):
377
+ obs_keys = [obs_keys]
324
378
  labels_list = []
325
- for label_key in label_keys:
379
+ for label_key in obs_keys:
326
380
  labels_to_str = self.get_merged_labels(label_key).astype(str).astype("O")
327
381
  labels_list.append(labels_to_str)
328
382
  if len(labels_list) > 1:
@@ -401,7 +401,47 @@ def mudata_papalexi21_subset():  # pragma: no cover
          "papalexi21_subset.h5mu",
      )

-     return md.read_h5mu(filepath)
+     mdata = md.read_h5mu(filepath)
+     for mod in ["rna", "adt", "hto", "gdo"]:
+         mdata[mod].obs.drop(
+             mdata[mod].obs.columns, axis=1, inplace=True
+         )  # Drop all columns
+     for col in mdata.obs.columns:
+         for mod in ["rna", "adt", "hto", "gdo"]:
+             if col.endswith(f"_{mod.upper()}"):
+                 new_col = col.replace(f"{mod}:", "")
+                 if new_col != col:
+                     mdata[mod].obs[new_col] = mdata.obs.pop(col)
+             else:
+                 new_col = col.replace(f"{mod}:", "")
+                 if new_col not in mdata.obs.columns and col in mdata.obs.columns:
+                     mdata.obs[new_col] = mdata.obs.pop(col)
+
+     for col in mdata.obs.columns:
+         for mod in ["rna", "adt", "hto", "gdo"]:
+             if col.endswith(f"_{mod.upper()}"):
+                 del mdata.obs[col]
+
+     for col in [
+         "orig.ident",
+         "MULTI_ID",
+         "NT",
+         "S.Score",
+         "G2M.Score",
+         "Phase",
+         "gene_target",
+         "guide_ID",
+         "HTO_classification",
+     ]:
+         del mdata.obs[col]
+     mdata.update()
+
+     mdata["rna"].obs["percent.mito"] = mdata.obs.pop("percent.mito")
+     mdata["hto"].obs["technique"] = "cell hashing"
+     mdata["hto"].obs["technique"] = mdata["hto"].obs["technique"].astype("category")
+     mdata.update()
+
+     return mdata


  def df_iris() -> pd.DataFrame:
@@ -100,7 +100,7 @@ def _records_to_df(obj):
      return obj


- class Registry:
+ class AccessRegistry:
      def __init__(self):
          self._registry = {}
          self._openers = {}
@@ -141,7 +141,7 @@ class Registry:


  # storage specific functions should be registered and called through the registry
- registry = Registry()
+ registry = AccessRegistry()


  @registry.register_open("h5py")
@@ -176,8 +176,12 @@ def safer_read_partial(elem, indices):
      indices_increasing = []
      indices_inverse = []
      for indices_dim in indices:
-         if isinstance(indices_dim, np.ndarray) and not np.all(
-             np.diff(indices_dim) > 0
+         # should be integer or bool
+         # ignore bool or increasing unique integers
+         if (
+             isinstance(indices_dim, np.ndarray)
+             and indices_dim.dtype != "bool"
+             and not np.all(np.diff(indices_dim) > 0)
          ):
              idx_unique, idx_inverse = np.unique(indices_dim, return_inverse=True)
              indices_increasing.append(idx_unique)
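The added `dtype != "bool"` guard matters because `np.diff` cannot be applied to boolean masks at all; a quick demonstration of both halves of the new condition:

```python
import numpy as np

mask = np.array([True, False, True])  # boolean mask: must bypass the diff test
try:
    np.diff(mask)
except TypeError as e:
    print(e)  # numpy boolean subtract, the `-` operator, is not supported ...

idx = np.array([2, 0, 1])        # non-increasing integer indices
print(np.all(np.diff(idx) > 0))  # False -> gets the unique/inverse re-sorting
```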
@@ -22,6 +22,7 @@ from lnschema_core.models import Artifact, Storage
  from lamindb.core._settings import settings

  if TYPE_CHECKING:
+     import mudata as md
      from lamindb_setup.core.types import UPathStr

  try:
@@ -136,6 +137,9 @@ def delete_storage_using_key(

  def delete_storage(storagepath: Path):
      """Delete arbitrary artifact."""
+     # TODO is_relative_to is not available in 3.8 and deprecated since 3.12
+     # replace with check_path_is_child_of_root but this needs to first be debugged
+     # if not check_path_is_child_of_root(storagepath, settings.storage):
      if not storagepath.is_relative_to(settings.storage):  # type: ignore
          logger.warning("couldn't delete files outside of default storage")
          return "did-not-delete"
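The TODO references `check_path_is_child_of_root`, a lamindb-internal helper; a minimal, version-portable sketch of the containment check it alludes to (the `is_child_of` name is hypothetical; `relative_to` works on Python 3.8+):

```python
from pathlib import Path


def is_child_of(path: Path, root: Path) -> bool:
    """Return True if path lies inside root (hypothetical stand-in)."""
    try:
        path.resolve().relative_to(root.resolve())  # raises ValueError if outside
        return True
    except ValueError:
        return False


assert is_child_of(Path("/data/store/a.parquet"), Path("/data/store"))
assert not is_child_of(Path("/etc/passwd"), Path("/data/store"))
```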
@@ -167,6 +171,13 @@ def read_tsv(path: UPathStr, **kwargs) -> pd.DataFrame:
      return pd.read_csv(path_sanitized, sep="\t", **kwargs)


+ def read_mdata_h5mu(filepath: UPathStr, **kwargs) -> md.MuData:
+     import mudata as md
+
+     path_sanitized = Path(filepath)
+     return md.read_h5mu(path_sanitized, **kwargs)
+
+
  def load_html(path: UPathStr):
      if is_run_from_ipython:
          with open(path, encoding="utf-8") as f:
@@ -221,6 +232,7 @@ def load_to_memory(filepath: UPathStr, stream: bool = False, **kwargs):
          ".zrad": read_adata_zarr,
          ".html": load_html,
          ".json": load_json,
+         ".h5mu": read_mdata_h5mu,
      }

      reader = READER_FUNCS.get(filepath.suffix)
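With `.h5mu` in the suffix map, `load_to_memory` (and hence artifact loading) should dispatch to `read_mdata_h5mu`; a hedged sketch, assuming an `.h5mu` artifact already exists in the instance:

```python
import lamindb as ln

artifact = ln.Artifact.filter(suffix=".h5mu").first()  # assumes one exists
mdata = artifact.load()  # dispatched via READER_FUNCS[".h5mu"]
print(type(mdata))       # mudata.MuData
```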
@@ -9,6 +9,14 @@ if TYPE_CHECKING:
      from lamindb_setup.core.types import UPathStr


+ def _mudata_is_installed():
+     try:
+         import mudata
+     except ImportError:
+         return False
+     return True
+
+
  def infer_suffix(dmem, adata_format: str | None = None):
      """Infer LaminDB storage file suffix from a data object."""
      if isinstance(dmem, AnnData):
@@ -25,6 +33,11 @@ def infer_suffix(dmem, adata_format: str | None = None):
      elif isinstance(dmem, DataFrame):
          return ".parquet"
      else:
+         if _mudata_is_installed():
+             from mudata import MuData
+
+             if isinstance(dmem, MuData):
+                 return ".h5mu"
          raise NotImplementedError
@@ -34,4 +47,10 @@ def write_to_file(dmem, filepath: UPathStr):
      elif isinstance(dmem, DataFrame):
          dmem.to_parquet(filepath)
      else:
+         if _mudata_is_installed():
+             from mudata import MuData
+
+             if isinstance(dmem, MuData):
+                 dmem.write(filepath)
+                 return
          raise NotImplementedError
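A hedged round-trip sketch for the two optional-MuData branches; requires `mudata` to be installed and uses a tiny in-memory object:

```python
import anndata as ad
import numpy as np
from mudata import MuData

mdata = MuData({"rna": ad.AnnData(np.ones((2, 3), dtype=np.float32))})
# infer_suffix(mdata) would now return ".h5mu";
# write_to_file(mdata, "tmp.h5mu") would call mdata.write(...)
mdata.write("tmp.h5mu")
```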
lamindb/core/types.py CHANGED
@@ -4,14 +4,12 @@
     :toctree: .

     UPathStr
-    DataLike
     StrField
     ListLike
     TransformType
  """
  from lamindb_setup.core.types import UPathStr
  from lnschema_core.types import (
-     DataLike,
      ListLike,
      StrField,
      TransformType,
@@ -1,15 +1,4 @@
- """Core setup library.
+ import lamindb_setup as _lamindb_setup
+ from lamindb_setup.core import *  # noqa: F403

- .. autosummary::
-    :toctree:
-
-    UserSettings
-    InstanceSettings
-    StorageSettings
-
- """
- from lamindb_setup.core import (  # pragma: no cover
-     InstanceSettings,
-     StorageSettings,
-     UserSettings,
- )
+ __doc__ = _lamindb_setup.core.__doc__.replace("lamindb_setup", "lamindb.setup")