scdataloader 0.0.3__py3-none-any.whl → 1.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- scdataloader/VERSION +1 -1
- scdataloader/__init__.py +1 -1
- scdataloader/__main__.py +66 -42
- scdataloader/collator.py +136 -67
- scdataloader/config.py +112 -0
- scdataloader/data.py +160 -169
- scdataloader/datamodule.py +403 -0
- scdataloader/mapped.py +285 -109
- scdataloader/preprocess.py +240 -109
- scdataloader/utils.py +162 -70
- {scdataloader-0.0.3.dist-info → scdataloader-1.0.1.dist-info}/METADATA +87 -18
- scdataloader-1.0.1.dist-info/RECORD +16 -0
- scdataloader/dataloader.py +0 -318
- scdataloader-0.0.3.dist-info/RECORD +0 -15
- {scdataloader-0.0.3.dist-info → scdataloader-1.0.1.dist-info}/LICENSE +0 -0
- {scdataloader-0.0.3.dist-info → scdataloader-1.0.1.dist-info}/WHEEL +0 -0
- {scdataloader-0.0.3.dist-info → scdataloader-1.0.1.dist-info}/entry_points.txt +0 -0
scdataloader/mapped.py
CHANGED
@@ -1,20 +1,26 @@
+from __future__ import annotations
+
 from collections import Counter
 from functools import reduce
-from
-from typing import List, Literal, Optional, Union
+from pathlib import Path
+from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Union
 
 import numpy as np
 import pandas as pd
 from lamin_utils import logger
-from
-from lamindb.
+from lamindb_setup.core.upath import UPath
+from lamindb.core._data import _track_run_input
+
+from lamindb.core.storage._backed_access import (
     ArrayTypes,
     GroupTypes,
     StorageType,
     _safer_read_index,
     registry,
 )
-
+
+if TYPE_CHECKING:
+    from lamindb_setup.core.types import UPathStr
 
 
 class _Connect:
@@ -43,56 +49,140 @@ class _Connect:
 
 def mapped(
     dataset,
+    label_keys: str | list[str] | None = None,
+    join: Literal["inner", "outer"] | None = "inner",
+    encode_labels: bool | list[str] = True,
+    unknown_label: str | dict[str, str] | None = None,
+    cache_categories: bool = True,
+    parallel: bool = False,
+    dtype: str | None = None,
     stream: bool = False,
-    is_run_input:
-
-) -> "MappedDataset":
-    _track_run_input(dataset, is_run_input)
+    is_run_input: bool | None = None,
+) -> MappedCollection:
     path_list = []
-    for
-        if
-            logger.warning(f"Ignoring
+    for artifact in dataset.artifacts.all():
+        if artifact.suffix not in {".h5ad", ".zrad", ".zarr"}:
+            logger.warning(f"Ignoring artifact with suffix {artifact.suffix}")
             continue
-        elif not stream
-            path_list.append(
+        elif not stream:
+            path_list.append(artifact.stage())
         else:
-            path_list.append(
-
+            path_list.append(artifact.path)
+    ds = MappedCollection(
+        path_list,
+        label_keys,
+        join,
+        encode_labels,
+        unknown_label,
+        cache_categories,
+        parallel,
+        dtype,
+    )
+    # track only if successful
+    _track_run_input(dataset, is_run_input)
+    return ds
 
 
-class
-    """Map-style
+class MappedCollection:
+    """Map-style collection for use in data loaders.
 
-    This
+    This class virtually concatenates `AnnData` arrays as a `pytorch map-style dataset
+    <https://pytorch.org/docs/stable/data.html#map-style-datasets>`__.
 
-
+    If your `AnnData` collection is in the cloud, move them into a local cache
+    first for faster access.
 
     .. note::
 
-
-
+        For a guide, see :doc:`docs:scrna5`.
+
+    For more convenient use within :class:`~lamindb.core.MappedCollection`,
+    see :meth:`~lamindb.Collection.mapped`.
+
+    This currently only works for collections of `AnnData` objects.
+
+    The implementation was influenced by the `SCimilarity
+    <https://github.com/Genentech/scimilarity>`__ data loader.
+
+
+    Args:
+        path_list: A list of paths to `AnnData` objects stored in `.h5ad` or `.zarr` formats.
+        label_keys: Columns of the ``.obs`` slot that store labels.
+        join: `"inner"` or `"outer"` virtual joins. If ``None`` is passed,
+            does not join.
+        encode_labels: Encode labels into integers.
+            Can be a list with elements from ``label_keys``.
+        unknown_label: Encode this label to -1.
+            Can be a dictionary with keys from ``label_keys`` if ``encode_labels=True``
+            or from ``encode_labels`` if it is a list.
+        cache_categories: Enable caching categories of ``label_keys`` for faster access.
+        parallel: Enable sampling with multiple processes.
+        dtype: Convert numpy arrays from ``.X`` to this dtype on selection.
     """
 
     def __init__(
         self,
-        path_list:
-        label_keys:
-
-        encode_labels:
+        path_list: list[UPathStr],
+        label_keys: str | list[str] | None = None,
+        join: Literal["inner", "outer", "auto"] | None = "inner",
+        encode_labels: bool | list[str] = True,
+        unknown_label: str | dict[str, str] | None = None,
+        cache_categories: bool = True,
         parallel: bool = False,
-
+        dtype: str | None = None,
     ):
-
-
+        assert join in {None, "inner", "outer", "auto"}
+
+        label_keys = [label_keys] if isinstance(label_keys, str) else label_keys
+        self.label_keys = label_keys
+
+        if isinstance(encode_labels, list):
+            if len(encode_labels) == 0:
+                encode_labels = False
+            elif label_keys is None or not all(
+                enc_label in label_keys for enc_label in encode_labels
+            ):
+                raise ValueError(
+                    "All elements of `encode_labels` should be in `label_keys`."
+                )
+        else:
+            if encode_labels:
+                encode_labels = label_keys if label_keys is not None else False
+        self.encode_labels = encode_labels
+
+        if encode_labels and isinstance(unknown_label, dict):
+            if not all(unkey in encode_labels for unkey in unknown_label):  # type: ignore
+                raise ValueError(
+                    "All keys of `unknown_label` should be in `encode_labels` and `label_keys`."
+                )
+        self.unknown_label = unknown_label
+
+        self.storages = []  # type: ignore
+        self.conns = []  # type: ignore
         self.parallel = parallel
-        self.
-        self.path_list = path_list
+        self._path_list = path_list
         self._make_connections(path_list, parallel)
 
         self.n_obs_list = []
         for storage in self.storages:
             with _Connect(storage) as store:
                 X = store["X"]
+                index = (
+                    store["var"]["ensembl_gene_id"]
+                    if "ensembl_gene_id" in store["var"]
+                    else store["var"]["_index"]
+                )
+                if join is None:
+                    if not all(
+                        [
+                            i <= j
+                            for i, j in zip(
+                                index[:99],
+                                index[1:100],
+                            )
+                        ]
+                    ):
+                        raise ValueError("The variables are not sorted.")
                 if isinstance(X, ArrayTypes):  # type: ignore
                     self.n_obs_list.append(X.shape[0])
                 else:
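The new `MappedCollection` is a map-style dataset, so the most common pattern is to wrap it directly in a `torch.utils.data.DataLoader`. A minimal sketch, assuming PyTorch is installed and two local `.h5ad` files exist (the file names and the `cell_type` column are placeholders):

```python
from torch.utils.data import DataLoader

from scdataloader.mapped import MappedCollection

# Hypothetical local files; any .h5ad or .zarr AnnData stores work.
collection = MappedCollection(
    path_list=["tumor.h5ad", "normal.h5ad"],
    label_keys=["cell_type"],
    join="inner",              # virtual inner join on the var index
    encode_labels=True,        # encode cell_type strings as integers
    unknown_label="unknown",   # this category is encoded as -1
)

loader = DataLoader(collection, batch_size=128, shuffle=True)
batch = next(iter(loader))
print(batch["x"].shape, batch["cell_type"][:5])
collection.close()
```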
@@ -102,27 +192,21 @@ class MappedDataset:
         self.indices = np.hstack([np.arange(n_obs) for n_obs in self.n_obs_list])
         self.storage_idx = np.repeat(np.arange(len(self.storages)), self.n_obs_list)
 
-        self.join_vars =
+        self.join_vars = join
         self.var_indices = None
-        if self.join_vars
+        if self.join_vars is not None:
             self._make_join_vars()
 
-        self.
-
-
-        if encode_labels:
-            encode_labels = label_keys
+        if self.label_keys is not None:
+            if cache_categories:
+                self._cache_categories(self.label_keys)
             else:
-
-
-                self.
-
-
-
-                    if unknown_class in self.encoders[label]:
-                        self.encoders[label][unknown_class] = -1
-        else:
-            self.encoders = {}
+                self._cache_cats: dict = {}
+        self.encoders: dict = {}
+        if self.encode_labels:
+            self._make_encoders(self.encode_labels)  # type: ignore
+
+        self._dtype = dtype
         self._closed = False
 
     def _make_connections(self, path_list: list, parallel: bool):
@@ -138,31 +222,63 @@ class MappedDataset:
             self.conns.append(conn)
             self.storages.append(storage)
 
+    def _cache_categories(self, label_keys: list):
+        self._cache_cats = {}
+        decode = np.frompyfunc(lambda x: x.decode("utf-8"), 1, 1)
+        for label in label_keys:
+            self._cache_cats[label] = []
+            for storage in self.storages:
+                with _Connect(storage) as store:
+                    cats = self._get_categories(store, label)
+                    if cats is not None:
+                        cats = decode(cats) if isinstance(cats[0], bytes) else cats[...]
+                    self._cache_cats[label].append(cats)
+
+    def _make_encoders(self, encode_labels: list):
+        for label in encode_labels:
+            cats = self.get_merged_categories(label)
+            encoder = {}
+            if isinstance(self.unknown_label, dict):
+                unknown_label = self.unknown_label.get(label, None)
+            else:
+                unknown_label = self.unknown_label
+            if unknown_label is not None and unknown_label in cats:
+                cats.remove(unknown_label)
+                encoder[unknown_label] = -1
+            cats = list(cats)
+            cats.sort()
+            encoder.update({cat: i for i, cat in enumerate(cats)})
+            self.encoders[label] = encoder
+
     def _make_join_vars(self):
         var_list = []
         for storage in self.storages:
             with _Connect(storage) as store:
                 var_list.append(_safer_read_index(store["var"]))
-
-
-
-
-
-
-
+
+        self.var_joint = None
+        vars_eq = all(var_list[0].equals(vrs) for vrs in var_list[1:])
+        if vars_eq:
+            self.join_vars = None
+            self.var_joint = var_list[0]
+            return
         if self.join_vars == "inner":
             self.var_joint = reduce(pd.Index.intersection, var_list)
             if len(self.var_joint) == 0:
                 raise ValueError(
-                    "The provided AnnData objects don't have shared varibales
+                    "The provided AnnData objects don't have shared varibales.\n"
+                    "Use join='outer'."
                 )
             self.var_indices = [vrs.get_indexer(self.var_joint) for vrs in var_list]
+        elif self.join_vars == "outer":
+            self.var_joint = reduce(pd.Index.union, var_list)
+            self.var_indices = [self.var_joint.get_indexer(vrs) for vrs in var_list]
 
     def _check_aligned_vars(self, vars: list):
         i = 0
         for storage in self.storages:
             with _Connect(storage) as store:
-                if
+                if len(set(_safer_read_index(store["var"]).tolist()) - set(vars)) == 0:
                     i += 1
         print("{}% are aligned".format(i * 100 / len(self.storages)))
 
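The encoder layout built by `_make_encoders` is easy to reason about in isolation: categories are merged across all stores, the optional `unknown_label` is pinned to -1, and the remaining categories are sorted and numbered consecutively. A stand-alone sketch with made-up categories:

```python
# Stand-alone sketch of the encoder produced by _make_encoders, assuming the
# merged categories {"T cell", "B cell", "unknown"} and unknown_label="unknown".
cats = {"T cell", "B cell", "unknown"}
unknown_label = "unknown"

encoder = {}
if unknown_label in cats:
    cats.remove(unknown_label)
    encoder[unknown_label] = -1              # unknown category pinned to -1
encoder.update({cat: i for i, cat in enumerate(sorted(cats))})

print(encoder)  # {'unknown': -1, 'B cell': 0, 'T cell': 1}
```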
@@ -173,49 +289,75 @@ class MappedDataset:
         obs_idx = self.indices[idx]
         storage_idx = self.storage_idx[idx]
         if self.var_indices is not None:
-
+            var_idxs_join = self.var_indices[storage_idx]
         else:
-
+            var_idxs_join = None
+
         with _Connect(self.storages[storage_idx]) as store:
-            out = {"x": self.
+            out = {"x": self._get_data_idx(store, obs_idx, var_idxs_join)}
+            out["_storage_idx"] = storage_idx
             if self.label_keys is not None:
-                for
-
-
-
+                for label in self.label_keys:
+                    if label in self._cache_cats:
+                        cats = self._cache_cats[label][storage_idx]
+                        if cats is None:
+                            cats = []
                     else:
-
+                        cats = None
+                    label_idx = self._get_label_idx(store, obs_idx, label, cats)
+                    if label in self.encoders:
+                        label_idx = self.encoders[label][label_idx]
+                    out[label] = label_idx
         return out
 
-    def
-        storage = self.storages[self.storage_idx[idx]]
-        return storage["uns"][key]
-
-    def get_data_idx(
+    def _get_data_idx(
         self,
-        storage: StorageType,
+        storage: StorageType,  # type: ignore
         idx: int,
-
-        layer_key:
+        var_idxs_join: list | None = None,
+        layer_key: str | None = None,
     ):
         """Get the index for the data."""
-        layer = storage["X"] if layer_key is None else storage["layers"][layer_key] # type: ignore
+        layer = storage["X"] if layer_key is None else storage["layers"][layer_key]  # type: ignore
         if isinstance(layer, ArrayTypes):  # type: ignore
-
-
-
+            layer_idx = layer[idx]
+            if self.join_vars is None:
+                result = layer_idx
+                if self._dtype is not None:
+                    result = result.astype(self._dtype, copy=False)
+            elif self.join_vars == "outer":
+                dtype = layer_idx.dtype if self._dtype is None else self._dtype
+                result = np.zeros(len(self.var_joint), dtype=dtype)
+                result[var_idxs_join] = layer_idx
+            else:  # inner join
+                result = layer_idx[var_idxs_join]
+                if self._dtype is not None:
+                    result = result.astype(self._dtype, copy=False)
+            return result
         else:  # assume csr_matrix here
             data = layer["data"]
             indices = layer["indices"]
             indptr = layer["indptr"]
             s = slice(*(indptr[idx : idx + 2]))
-
-
-
-
-
+            data_s = data[s]
+            dtype = data_s.dtype if self._dtype is None else self._dtype
+            if self.join_vars == "outer":
+                layer_idx = np.zeros(len(self.var_joint), dtype=dtype)
+                layer_idx[var_idxs_join[indices[s]]] = data_s
+            else:
+                layer_idx = np.zeros(layer.attrs["shape"][1], dtype=dtype)
+                layer_idx[indices[s]] = data_s
+                if self.join_vars == "inner":
+                    layer_idx = layer_idx[var_idxs_join]
+            return layer_idx
 
-    def
+    def _get_label_idx(
+        self,
+        storage: StorageType,
+        idx: int,
+        label_key: str,
+        categories: list | None = None,
+    ):
         """Get the index for the label by key."""
         obs = storage["obs"]  # type: ignore
         # how backwards compatible do we want to be here actually?
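For backed sparse layers, `_get_data_idx` never materializes the full matrix: it slices `indptr` to find the stored values of one row and scatters them into a dense vector, either of the original width or of the outer-joined variable space. A stand-alone NumPy sketch with toy values illustrating the same indexing:

```python
import numpy as np

# Toy CSR components for a 2-cell x 5-gene matrix (made-up values), mirroring
# how _get_data_idx densifies a single row without loading the whole matrix.
data = np.array([3.0, 1.0, 2.0, 5.0])
indices = np.array([0, 3, 1, 4])        # column index of each stored value
indptr = np.array([0, 2, 4])            # row boundaries into data/indices

idx = 1                                  # second cell
s = slice(*indptr[idx : idx + 2])        # stored entries of that row
row = np.zeros(5, dtype=data.dtype)      # dense vector in the original gene space
row[indices[s]] = data[s]
print(row)                               # [0. 2. 0. 0. 5.]

# Outer join: scatter through a hypothetical mapping from local to joint positions.
var_idxs_join = np.array([2, 4, 5, 7, 9])
joint = np.zeros(10, dtype=data.dtype)
joint[var_idxs_join[indices[s]]] = data[s]
```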
@@ -227,39 +369,48 @@ class MappedDataset:
             label = labels[idx]
         else:
             label = labels["codes"][idx]
-
-
-
+        if categories is not None:
+            cats = categories
+        else:
+            cats = self._get_categories(storage, label_key)
+        if cats is not None and len(cats) > 0:
             label = cats[label]
             if isinstance(label, bytes):
                 label = label.decode("utf-8")
         return label
 
-    def get_label_weights(self, label_keys:
-        """Get all weights for
-        if
+    def get_label_weights(self, label_keys: str | list[str], scaler: int = 10):
+        """Get all weights for the given label keys."""
+        if isinstance(label_keys, str):
             label_keys = [label_keys]
-
-
-
-
-
-
-
+        labels_list = []
+        for label_key in label_keys:
+            labels_to_str = self.get_merged_labels(label_key).astype(str).astype("O")
+            labels_list.append(labels_to_str)
+        if len(labels_list) > 1:
+            labels = reduce(lambda a, b: a + b, labels_list)
+        else:
+            labels = labels_list[0]
+
         counter = Counter(labels)  # type: ignore
-
+        rn = {n: i for i, n in enumerate(counter.keys())}
+        labels = np.array([rn[label] for label in labels])
+        counter = np.array(list(counter.values()))
         weights = scaler / (counter + scaler)
-        return weights
+        return weights, labels
 
     def get_merged_labels(self, label_key: str):
-        """Get merged labels
+        """Get merged labels for `label_key` from all `.obs`."""
         labels_merge = []
         decode = np.frompyfunc(lambda x: x.decode("utf-8"), 1, 1)
-        for storage in self.storages:
+        for i, storage in enumerate(self.storages):
             with _Connect(storage) as store:
-                codes = self.
+                codes = self._get_codes(store, label_key)
                 labels = decode(codes) if isinstance(codes[0], bytes) else codes
-
+                if label_key in self._cache_cats:
+                    cats = self._cache_cats[label_key][i]
+                else:
+                    cats = self._get_categories(store, label_key)
                 if cats is not None:
                     cats = decode(cats) if isinstance(cats[0], bytes) else cats
                     labels = cats[labels]
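`get_label_weights` now returns the per-class weights (`scaler / (count + scaler)`) together with an integer class index per observation, which maps directly onto PyTorch's `WeightedRandomSampler`. A sketch, assuming `collection` is a `MappedCollection` built with `label_keys=["cell_type"]` as in the earlier example:

```python
from torch.utils.data import DataLoader, WeightedRandomSampler

# `collection` is assumed to be a MappedCollection with label_keys=["cell_type"].
weights, labels = collection.get_label_weights("cell_type", scaler=10)

sampler = WeightedRandomSampler(
    weights=weights[labels],                 # per-observation weight = its class weight
    num_samples=sum(collection.n_obs_list),  # total number of observations
    replacement=True,
)
loader = DataLoader(collection, batch_size=128, sampler=sampler)
```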
@@ -267,22 +418,25 @@ class MappedDataset:
         return np.hstack(labels_merge)
 
     def get_merged_categories(self, label_key: str):
-        """Get merged categories
+        """Get merged categories for `label_key` from all `.obs`."""
         cats_merge = set()
         decode = np.frompyfunc(lambda x: x.decode("utf-8"), 1, 1)
-        for storage in self.storages:
+        for i, storage in enumerate(self.storages):
             with _Connect(storage) as store:
-
+                if label_key in self._cache_cats:
+                    cats = self._cache_cats[label_key][i]
+                else:
+                    cats = self._get_categories(store, label_key)
                 if cats is not None:
                     cats = decode(cats) if isinstance(cats[0], bytes) else cats
                     cats_merge.update(cats)
                 else:
-                    codes = self.
+                    codes = self._get_codes(store, label_key)
                     codes = decode(codes) if isinstance(codes[0], bytes) else codes
                     cats_merge.update(codes)
         return cats_merge
 
-    def
+    def _get_categories(self, storage: StorageType, label_key: str):  # type: ignore
         """Get categories."""
         obs = storage["obs"]  # type: ignore
         if isinstance(obs, ArrayTypes):  # type: ignore
@@ -309,8 +463,9 @@ class MappedDataset:
                 return labels.attrs["categories"]
             else:
                 return None
+        return None
 
-    def
+    def _get_codes(self, storage: StorageType, label_key: str):  # type: ignore
         """Get codes."""
         obs = storage["obs"]  # type: ignore
         if isinstance(obs, ArrayTypes):  # type: ignore
@@ -323,7 +478,10 @@ class MappedDataset:
             return label["codes"][...]
 
     def close(self):
-        """Close
+        """Close connections to array streaming backend.
+
+        No effect if `parallel=True`.
+        """
         for storage in self.storages:
             if hasattr(storage, "close"):
                 storage.close()
@@ -334,6 +492,10 @@ class MappedDataset:
 
     @property
     def closed(self):
+        """Check if connections to array streaming backend are closed.
+
+        Does not matter if `parallel=True`.
+        """
         return self._closed
 
     def __enter__(self):
@@ -341,3 +503,17 @@ class MappedDataset:
 
     def __exit__(self, exc_type, exc_val, exc_tb):
         self.close()
+
+    @staticmethod
+    def torch_worker_init_fn(worker_id):
+        """`worker_init_fn` for `torch.utils.data.DataLoader`.
+
+        Improves performance for `num_workers > 1`.
+        """
+        from torch.utils.data import get_worker_info
+
+        mapped = get_worker_info().dataset
+        mapped.parallel = False
+        mapped.storages = []
+        mapped.conns = []
+        mapped._make_connections(mapped._path_list, parallel=False)