scdataloader 0.0.4__py3-none-any.whl → 1.0.5__py3-none-any.whl
This diff compares publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their public registries.
- scdataloader/VERSION +1 -1
- scdataloader/__init__.py +2 -2
- scdataloader/__main__.py +3 -0
- scdataloader/collator.py +61 -96
- scdataloader/config.py +6 -0
- scdataloader/data.py +138 -90
- scdataloader/datamodule.py +67 -39
- scdataloader/mapped.py +302 -120
- scdataloader/preprocess.py +4 -213
- scdataloader/utils.py +128 -92
- {scdataloader-0.0.4.dist-info → scdataloader-1.0.5.dist-info}/METADATA +82 -26
- scdataloader-1.0.5.dist-info/RECORD +16 -0
- scdataloader-0.0.4.dist-info/RECORD +0 -16
- {scdataloader-0.0.4.dist-info → scdataloader-1.0.5.dist-info}/LICENSE +0 -0
- {scdataloader-0.0.4.dist-info → scdataloader-1.0.5.dist-info}/WHEEL +0 -0
- {scdataloader-0.0.4.dist-info → scdataloader-1.0.5.dist-info}/entry_points.txt +0 -0
scdataloader/mapped.py
CHANGED
@@ -1,20 +1,26 @@
+from __future__ import annotations
+
 from collections import Counter
 from functools import reduce
-from
-from typing import List, Literal, Optional, Union
+from pathlib import Path
+from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Union

 import numpy as np
 import pandas as pd
-from
-
-from lamindb.
+from lamindb_setup.core.upath import UPath
+
+from lamindb.core.storage._anndata_accessor import (
+    ArrayType,
     ArrayTypes,
+    GroupType,
     GroupTypes,
     StorageType,
     _safer_read_index,
     registry,
 )
-
+
+if TYPE_CHECKING:
+    from lamindb_setup.core.types import UPathStr


 class _Connect:
@@ -41,52 +47,108 @@ class _Connect:
         self.conn.close()


-
-
-    stream: bool = False,
-    is_run_input: Optional[bool] = None,
-    **kwargs,
-) -> "MappedDataset":
-    _track_run_input(dataset, is_run_input)
-    path_list = []
-    for file in dataset.artifacts.all():
-        if file.suffix not in {".h5ad", ".zrad", ".zarr"}:
-            logger.warning(f"Ignoring file with suffix {file.suffix}")
-            continue
-        elif not stream and file.suffix == ".h5ad":
-            path_list.append(file.stage())
-        else:
-            path_list.append(file.path)
-    return MappedDataset(path_list, **kwargs)
-
+class MappedCollection:
+    """Map-style collection for use in data loaders.

-class
-
+    This class virtually concatenates `AnnData` arrays as a `pytorch map-style dataset
+    <https://pytorch.org/docs/stable/data.html#map-style-datasets>`__.

-
+    If your `AnnData` collection is in the cloud, move them into a local cache
+    first for faster access.

-
+    `__getitem__` of the `MappedCollection` object takes a single integer index
+    and returns a dictionary with the observation data sample for this index from
+    the `AnnData` objects in `path_list`. The dictionary has keys for `layers_keys`
+    (`.X` is in `"X"`), `obs_keys`, `obsm_keys` (under `f"obsm_{key}"`) and also `"_store_idx"`
+    for the index of the `AnnData` object containing this observation sample.

     .. note::

-
-
+        For a guide, see :doc:`docs:scrna5`.
+
+        For more convenient use within :class:`~lamindb.core.MappedCollection`,
+        see :meth:`~lamindb.Collection.mapped`.
+
+    This currently only works for collections of `AnnData` objects.
+
+    The implementation was influenced by the `SCimilarity
+    <https://github.com/Genentech/scimilarity>`__ data loader.
+
+
+    Args:
+        path_list: A list of paths to `AnnData` objects stored in `.h5ad` or `.zarr` formats.
+        layers_keys: Keys from the ``.layers`` slot. ``layers_keys=None`` or ``"X"`` in the list
+            retrieves ``.X``.
+        obsm_keys: Keys from the ``.obsm`` slots.
+        obs_keys: Keys from the ``.obs`` slots.
+        join: `"inner"` or `"outer"` virtual joins. If ``None`` is passed,
+            does not join.
+        encode_labels: Encode labels into integers.
+            Can be a list with elements from ``obs_keys``.
+        unknown_label: Encode this label to -1.
+            Can be a dictionary with keys from ``obs_keys`` if ``encode_labels=True``
+            or from ``encode_labels`` if it is a list.
+        cache_categories: Enable caching categories of ``obs_keys`` for faster access.
+        parallel: Enable sampling with multiple processes.
+        dtype: Convert numpy arrays from ``.X``, ``.layers`` and ``.obsm``
     """

     def __init__(
         self,
-        path_list:
-
-
-
+        path_list: list[UPathStr],
+        layers_keys: str | list[str] | None = None,
+        obs_keys: str | list[str] | None = None,
+        obsm_keys: str | list[str] | None = None,
+        join: Literal["inner", "outer"] | None = "inner",
+        encode_labels: bool | list[str] = True,
+        unknown_label: str | dict[str, str] | None = None,
+        cache_categories: bool = True,
         parallel: bool = False,
-
+        dtype: str | None = None,
     ):
-
-
+        if join not in {None, "inner", "outer"}:  # pragma: nocover
+            raise ValueError(
+                f"join must be one of None, 'inner, or 'outer' but was {type(join)}"
+            )
+
+        if layers_keys is None:
+            self.layers_keys = ["X"]
+        else:
+            self.layers_keys = (
+                [layers_keys] if isinstance(layers_keys, str) else layers_keys
+            )
+
+        obsm_keys = [obsm_keys] if isinstance(obsm_keys, str) else obsm_keys
+        self.obsm_keys = obsm_keys
+
+        obs_keys = [obs_keys] if isinstance(obs_keys, str) else obs_keys
+        self.obs_keys = obs_keys
+
+        if isinstance(encode_labels, list):
+            if len(encode_labels) == 0:
+                encode_labels = False
+            elif obs_keys is None or not all(
+                enc_label in obs_keys for enc_label in encode_labels
+            ):
+                raise ValueError(
+                    "All elements of `encode_labels` should be in `obs_keys`."
+                )
+        else:
+            if encode_labels:
+                encode_labels = obs_keys if obs_keys is not None else False
+        self.encode_labels = encode_labels
+
+        if encode_labels and isinstance(unknown_label, dict):
+            if not all(unkey in encode_labels for unkey in unknown_label):  # type: ignore
+                raise ValueError(
+                    "All keys of `unknown_label` should be in `encode_labels` and `obs_keys`."
+                )
+        self.unknown_label = unknown_label
+
+        self.storages = []  # type: ignore
+        self.conns = []  # type: ignore
         self.parallel = parallel
-        self.
-        self.path_list = path_list
+        self._path_list = path_list
         self._make_connections(path_list, parallel)

         self.n_obs_list = []
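The hunk above replaces the old module-level `mapped()` helper with the `MappedCollection` class and its new constructor. As a rough usage sketch, assuming the class is importable as `scdataloader.mapped.MappedCollection` (the file names and the `cell_type` column below are hypothetical placeholders):

```python
# Minimal construction sketch based on the constructor signature in the hunk above.
# "data0.h5ad", "data1.h5ad" and the "cell_type" obs column are made-up examples.
from scdataloader.mapped import MappedCollection

mc = MappedCollection(
    path_list=["data0.h5ad", "data1.h5ad"],  # .h5ad or .zarr files
    layers_keys=None,                        # None -> .X is returned under the "X" key
    obs_keys=["cell_type"],                  # obs columns returned with each sample
    join="inner",                            # virtual inner join on shared variables
    encode_labels=True,                      # encode obs labels as integers
    unknown_label="unknown",                 # this category is encoded as -1
)
print(len(mc), mc.shape)  # n_obs and the (n_obs, n_vars) shape of the virtual join
```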
@@ -98,7 +160,7 @@ class MappedDataset:
                     if "ensembl_gene_id" in store["var"]
                     else store["var"]["_index"]
                 )
-        if
+        if join is None:
             if not all(
                 [
                     i <= j
@@ -118,27 +180,25 @@ class MappedDataset:
         self.indices = np.hstack([np.arange(n_obs) for n_obs in self.n_obs_list])
         self.storage_idx = np.repeat(np.arange(len(self.storages)), self.n_obs_list)

-        self.join_vars =
+        self.join_vars = join
         self.var_indices = None
-
+        self.var_joint = None
+        self.n_vars_list = None
+        self.n_vars = None
+        if self.join_vars is not None:
             self._make_join_vars()
+            self.n_vars = len(self.var_joint)

-        self.
-
-
-        if encode_labels:
-            encode_labels = label_keys
+        if self.obs_keys is not None:
+            if cache_categories:
+                self._cache_categories(self.obs_keys)
             else:
-
-
-                self.
-
-
-
-                if unknown_class in self.encoders[label]:
-                    self.encoders[label][unknown_class] = -1
-        else:
-            self.encoders = {}
+                self._cache_cats: dict = {}
+        self.encoders: dict = {}
+        if self.encode_labels:
+            self._make_encoders(self.encode_labels)  # type: ignore
+
+        self._dtype = dtype
         self._closed = False

     def _make_connections(self, path_list: list, parallel: bool):
@@ -154,81 +214,171 @@ class MappedDataset:
             self.conns.append(conn)
             self.storages.append(storage)

+    def _cache_categories(self, obs_keys: list):
+        self._cache_cats = {}
+        decode = np.frompyfunc(lambda x: x.decode("utf-8"), 1, 1)
+        for label in obs_keys:
+            self._cache_cats[label] = []
+            for storage in self.storages:
+                with _Connect(storage) as store:
+                    cats = self._get_categories(store, label)
+                    if cats is not None:
+                        cats = decode(cats) if isinstance(cats[0], bytes) else cats[...]
+                    self._cache_cats[label].append(cats)
+
+    def _make_encoders(self, encode_labels: list):
+        for label in encode_labels:
+            cats = self.get_merged_categories(label)
+            encoder = {}
+            if isinstance(self.unknown_label, dict):
+                unknown_label = self.unknown_label.get(label, None)
+            else:
+                unknown_label = self.unknown_label
+            if unknown_label is not None and unknown_label in cats:
+                cats.remove(unknown_label)
+                encoder[unknown_label] = -1
+            cats = list(cats)
+            cats.sort()
+            encoder.update({cat: i for i, cat in enumerate(cats)})
+            self.encoders[label] = encoder
+
     def _make_join_vars(self):
         var_list = []
+        self.n_vars_list = []
         for storage in self.storages:
             with _Connect(storage) as store:
-
-
-
-
-
-
-
-
+                vars = _safer_read_index(store["var"])
+                var_list.append(vars)
+                self.n_vars_list.append(len(vars))
+
+        self.var_joint = None
+        vars_eq = all(var_list[0].equals(vrs) for vrs in var_list[1:])
+        if vars_eq:
+            self.join_vars = None
+            self.var_joint = var_list[0]
+            return
+
         if self.join_vars == "inner":
             self.var_joint = reduce(pd.Index.intersection, var_list)
             if len(self.var_joint) == 0:
                 raise ValueError(
-                    "The provided AnnData objects don't have shared varibales
+                    "The provided AnnData objects don't have shared varibales.\n"
+                    "Use join='outer'."
                 )
             self.var_indices = [vrs.get_indexer(self.var_joint) for vrs in var_list]
+        elif self.join_vars == "outer":
+            self.var_joint = reduce(pd.Index.union, var_list)
+            self.var_indices = [self.var_joint.get_indexer(vrs) for vrs in var_list]

     def _check_aligned_vars(self, vars: list):
         i = 0
         for storage in self.storages:
             with _Connect(storage) as store:
-                if
+                if len(set(_safer_read_index(store["var"]).tolist()) - set(vars)) == 0:
                     i += 1
         print("{}% are aligned".format(i * 100 / len(self.storages)))

     def __len__(self):
         return self.n_obs

+    @property
+    def shape(self):
+        """Shape of the (virtually aligned) dataset."""
+        return (self.n_obs, self.n_vars)
+
+    @property
+    def original_shapes(self):
+        """Shapes of the underlying AnnData objects."""
+        if self.n_vars_list is None:
+            n_vars_list = [None] * len(self.n_obs_list)
+        else:
+            n_vars_list = self.n_vars_list
+        return list(zip(self.n_obs_list, n_vars_list))
+
     def __getitem__(self, idx: int):
         obs_idx = self.indices[idx]
         storage_idx = self.storage_idx[idx]
         if self.var_indices is not None:
-
+            var_idxs_join = self.var_indices[storage_idx]
         else:
-
+            var_idxs_join = None
+
         with _Connect(self.storages[storage_idx]) as store:
-            out = {
-
-
-
-
-
+            out = {}
+            for layers_key in self.layers_keys:
+                lazy_data = (
+                    store["X"] if layers_key == "X" else store["layers"][layers_key]
+                )
+                out[layers_key] = self._get_data_idx(
+                    lazy_data, obs_idx, self.join_vars, var_idxs_join, self.n_vars
+                )
+            if self.obsm_keys is not None:
+                for obsm_key in self.obsm_keys:
+                    lazy_data = store["obsm"][obsm_key]
+                    out[f"obsm_{obsm_key}"] = self._get_data_idx(lazy_data, obs_idx)
+            out["_store_idx"] = storage_idx
+            if self.obs_keys is not None:
+                for label in self.obs_keys:
+                    if label in self._cache_cats:
+                        cats = self._cache_cats[label][storage_idx]
+                        if cats is None:
+                            cats = []
                     else:
-
-
+                        cats = None
+                    label_idx = self._get_obs_idx(store, obs_idx, label, cats)
+                    if label in self.encoders:
+                        label_idx = self.encoders[label][label_idx]
+                    out[label] = label_idx
         return out

-    def
+    def _get_data_idx(
         self,
-
+        lazy_data: ArrayType | GroupType,  # type: ignore
         idx: int,
-
-
+        join_vars: Literal["inner", "outer"] | None = None,
+        var_idxs_join: list | None = None,
+        n_vars_out: int | None = None,
     ):
         """Get the index for the data."""
-
-
-
-
-
+        if isinstance(lazy_data, ArrayTypes):  # type: ignore
+            lazy_data_idx = lazy_data[idx]  # type: ignore
+            if join_vars is None:
+                result = lazy_data_idx
+                if self._dtype is not None:
+                    result = result.astype(self._dtype, copy=False)
+            elif join_vars == "outer":
+                dtype = lazy_data_idx.dtype if self._dtype is None else self._dtype
+                result = np.zeros(n_vars_out, dtype=dtype)
+                result[var_idxs_join] = lazy_data_idx
+            else:  # inner join
+                result = lazy_data_idx[var_idxs_join]
+                if self._dtype is not None:
+                    result = result.astype(self._dtype, copy=False)
+            return result
         else:  # assume csr_matrix here
-            data =
-            indices =
-            indptr =
+            data = lazy_data["data"]  # type: ignore
+            indices = lazy_data["indices"]  # type: ignore
+            indptr = lazy_data["indptr"]  # type: ignore
             s = slice(*(indptr[idx : idx + 2]))
-
-
-
-
-
+            data_s = data[s]
+            dtype = data_s.dtype if self._dtype is None else self._dtype
+            if join_vars == "outer":
+                lazy_data_idx = np.zeros(n_vars_out, dtype=dtype)
+                lazy_data_idx[var_idxs_join[indices[s]]] = data_s
+            else:
+                lazy_data_idx = np.zeros(lazy_data.attrs["shape"][1], dtype=dtype)  # type: ignore
+                lazy_data_idx[indices[s]] = data_s
+            if join_vars == "inner":
+                lazy_data_idx = lazy_data_idx[var_idxs_join]
+            return lazy_data_idx

-    def
+    def _get_obs_idx(
+        self,
+        storage: StorageType,
+        idx: int,
+        label_key: str,
+        categories: list | None = None,
+    ):
         """Get the index for the label by key."""
         obs = storage["obs"]  # type: ignore
         # how backwards compatible do we want to be here actually?
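Following the rewritten `__getitem__` above, indexing the collection returns a plain dictionary per observation. A small sketch of reading one sample, continuing the hypothetical `mc` object and `cell_type` key from the earlier example:

```python
sample = mc[0]                    # dict for a single observation
x = sample["X"]                   # dense row of .X, aligned to the joined variables
store_idx = sample["_store_idx"]  # index of the AnnData file this row came from
label = sample["cell_type"]       # integer code, since encode_labels=True above
```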
@@ -240,25 +390,29 @@ class MappedDataset:
             label = labels[idx]
         else:
             label = labels["codes"][idx]
-
-
-
+        if categories is not None:
+            cats = categories
+        else:
+            cats = self._get_categories(storage, label_key)
+        if cats is not None and len(cats) > 0:
             label = cats[label]
         if isinstance(label, bytes):
             label = label.decode("utf-8")
         return label

-    def get_label_weights(self,
-        """Get all weights for
-        if
-
-
-
-
-
-
-
-
+    def get_label_weights(self, obs_keys: str | list[str], scaler: int = 10):
+        """Get all weights for the given label keys."""
+        if isinstance(obs_keys, str):
+            obs_keys = [obs_keys]
+        labels_list = []
+        for label_key in obs_keys:
+            labels_to_str = self.get_merged_labels(label_key).astype(str).astype("O")
+            labels_list.append(labels_to_str)
+        if len(labels_list) > 1:
+            labels = reduce(lambda a, b: a + b, labels_list)
+        else:
+            labels = labels_list[0]
+
         counter = Counter(labels)  # type: ignore
         rn = {n: i for i, n in enumerate(counter.keys())}
         labels = np.array([rn[label] for label in labels])
@@ -267,14 +421,17 @@ class MappedDataset:
         return weights, labels

     def get_merged_labels(self, label_key: str):
-        """Get merged labels
+        """Get merged labels for `label_key` from all `.obs`."""
         labels_merge = []
         decode = np.frompyfunc(lambda x: x.decode("utf-8"), 1, 1)
-        for storage in self.storages:
+        for i, storage in enumerate(self.storages):
             with _Connect(storage) as store:
-                codes = self.
+                codes = self._get_codes(store, label_key)
                 labels = decode(codes) if isinstance(codes[0], bytes) else codes
-
+                if label_key in self._cache_cats:
+                    cats = self._cache_cats[label_key][i]
+                else:
+                    cats = self._get_categories(store, label_key)
                 if cats is not None:
                     cats = decode(cats) if isinstance(cats[0], bytes) else cats
                     labels = cats[labels]
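`get_label_weights` now accepts one or several `obs_keys` and returns a `(weights, labels)` pair. A hedged sketch of class-balanced sampling built on it; it assumes the returned `weights` contain one value per observation (the `Counter`-based pattern above suggests inverse-frequency weighting, but the exact formula and the role of `scaler` are not visible in this diff):

```python
import torch
from torch.utils.data import WeightedRandomSampler

# "cell_type" is the same hypothetical obs column as in the earlier examples.
weights, labels = mc.get_label_weights("cell_type", scaler=10)
sampler = WeightedRandomSampler(
    weights=torch.as_tensor(weights, dtype=torch.double),  # assumed per-observation weights
    num_samples=len(mc),
    replacement=True,
)
# pass sampler=sampler to a torch DataLoader to oversample rare labels
```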
@@ -282,22 +439,25 @@ class MappedDataset:
         return np.hstack(labels_merge)

     def get_merged_categories(self, label_key: str):
-        """Get merged categories
+        """Get merged categories for `label_key` from all `.obs`."""
         cats_merge = set()
         decode = np.frompyfunc(lambda x: x.decode("utf-8"), 1, 1)
-        for storage in self.storages:
+        for i, storage in enumerate(self.storages):
             with _Connect(storage) as store:
-
+                if label_key in self._cache_cats:
+                    cats = self._cache_cats[label_key][i]
+                else:
+                    cats = self._get_categories(store, label_key)
                 if cats is not None:
                     cats = decode(cats) if isinstance(cats[0], bytes) else cats
                     cats_merge.update(cats)
                 else:
-                    codes = self.
+                    codes = self._get_codes(store, label_key)
                     codes = decode(codes) if isinstance(codes[0], bytes) else codes
                     cats_merge.update(codes)
         return cats_merge

-    def
+    def _get_categories(self, storage: StorageType, label_key: str):  # type: ignore
         """Get categories."""
         obs = storage["obs"]  # type: ignore
         if isinstance(obs, ArrayTypes):  # type: ignore
@@ -324,8 +484,9 @@ class MappedDataset:
                 return labels.attrs["categories"]
             else:
                 return None
+        return None

-    def
+    def _get_codes(self, storage: StorageType, label_key: str):  # type: ignore
         """Get codes."""
         obs = storage["obs"]  # type: ignore
         if isinstance(obs, ArrayTypes):  # type: ignore
@@ -338,7 +499,10 @@ class MappedDataset:
                 return label["codes"][...]

     def close(self):
-        """Close
+        """Close connections to array streaming backend.
+
+        No effect if `parallel=True`.
+        """
         for storage in self.storages:
             if hasattr(storage, "close"):
                 storage.close()
@@ -349,6 +513,10 @@ class MappedDataset:

     @property
     def closed(self):
+        """Check if connections to array streaming backend are closed.
+
+        Does not matter if `parallel=True`.
+        """
         return self._closed

     def __enter__(self):
@@ -356,3 +524,17 @@ class MappedDataset:

     def __exit__(self, exc_type, exc_val, exc_tb):
         self.close()
+
+    @staticmethod
+    def torch_worker_init_fn(worker_id):
+        """`worker_init_fn` for `torch.utils.data.DataLoader`.
+
+        Improves performance for `num_workers > 1`.
+        """
+        from torch.utils.data import get_worker_info
+
+        mapped = get_worker_info().dataset
+        mapped.parallel = False
+        mapped.storages = []
+        mapped.conns = []
+        mapped._make_connections(mapped._path_list, parallel=False)