lamindb 0.76.8__py3-none-any.whl → 0.76.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lamindb/__init__.py +113 -113
- lamindb/_artifact.py +1205 -1205
- lamindb/_can_validate.py +579 -579
- lamindb/_collection.py +389 -387
- lamindb/_curate.py +1601 -1601
- lamindb/_feature.py +155 -155
- lamindb/_feature_set.py +242 -242
- lamindb/_filter.py +23 -23
- lamindb/_finish.py +256 -256
- lamindb/_from_values.py +382 -382
- lamindb/_is_versioned.py +40 -40
- lamindb/_parents.py +476 -476
- lamindb/_query_manager.py +125 -125
- lamindb/_query_set.py +362 -362
- lamindb/_record.py +649 -649
- lamindb/_run.py +57 -57
- lamindb/_save.py +308 -308
- lamindb/_storage.py +14 -14
- lamindb/_transform.py +127 -127
- lamindb/_ulabel.py +56 -56
- lamindb/_utils.py +9 -9
- lamindb/_view.py +72 -72
- lamindb/core/__init__.py +94 -94
- lamindb/core/_context.py +574 -574
- lamindb/core/_data.py +438 -438
- lamindb/core/_feature_manager.py +867 -867
- lamindb/core/_label_manager.py +253 -253
- lamindb/core/_mapped_collection.py +631 -597
- lamindb/core/_settings.py +187 -187
- lamindb/core/_sync_git.py +138 -138
- lamindb/core/_track_environment.py +27 -27
- lamindb/core/datasets/__init__.py +59 -59
- lamindb/core/datasets/_core.py +581 -571
- lamindb/core/datasets/_fake.py +36 -36
- lamindb/core/exceptions.py +90 -90
- lamindb/core/fields.py +12 -12
- lamindb/core/loaders.py +164 -164
- lamindb/core/schema.py +56 -56
- lamindb/core/storage/__init__.py +25 -25
- lamindb/core/storage/_anndata_accessor.py +740 -740
- lamindb/core/storage/_anndata_sizes.py +41 -41
- lamindb/core/storage/_backed_access.py +98 -98
- lamindb/core/storage/_tiledbsoma.py +204 -204
- lamindb/core/storage/_valid_suffixes.py +21 -21
- lamindb/core/storage/_zarr.py +110 -110
- lamindb/core/storage/objects.py +62 -62
- lamindb/core/storage/paths.py +172 -172
- lamindb/core/subsettings/__init__.py +12 -12
- lamindb/core/subsettings/_creation_settings.py +38 -38
- lamindb/core/subsettings/_transform_settings.py +21 -21
- lamindb/core/types.py +19 -19
- lamindb/core/versioning.py +158 -158
- lamindb/integrations/__init__.py +12 -12
- lamindb/integrations/_vitessce.py +107 -107
- lamindb/setup/__init__.py +14 -14
- lamindb/setup/core/__init__.py +4 -4
- {lamindb-0.76.8.dist-info → lamindb-0.76.9.dist-info}/LICENSE +201 -201
- {lamindb-0.76.8.dist-info → lamindb-0.76.9.dist-info}/METADATA +4 -4
- lamindb-0.76.9.dist-info/RECORD +60 -0
- {lamindb-0.76.8.dist-info → lamindb-0.76.9.dist-info}/WHEEL +1 -1
- lamindb-0.76.8.dist-info/RECORD +0 -60
@@ -1,597 +1,631 @@
|
|
1
|
-
from __future__ import annotations
|
2
|
-
|
3
|
-
from collections import Counter
|
4
|
-
from functools import reduce
|
5
|
-
from pathlib import Path
|
6
|
-
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Union
|
7
|
-
|
8
|
-
import numpy as np
|
9
|
-
import pandas as pd
|
10
|
-
from lamin_utils import logger
|
11
|
-
from lamindb_setup.core.upath import UPath
|
12
|
-
|
13
|
-
from .storage._anndata_accessor import (
|
14
|
-
ArrayType,
|
15
|
-
ArrayTypes,
|
16
|
-
GroupType,
|
17
|
-
GroupTypes,
|
18
|
-
StorageType,
|
19
|
-
_safer_read_index,
|
20
|
-
get_spec,
|
21
|
-
registry,
|
22
|
-
)
|
23
|
-
|
24
|
-
if TYPE_CHECKING:
|
25
|
-
from lamindb_setup.core.types import UPathStr
|
26
|
-
|
27
|
-
|
28
|
-
class _Connect:
|
29
|
-
def __init__(self, storage):
|
30
|
-
if isinstance(storage, UPath):
|
31
|
-
self.conn, self.store = registry.open("h5py", storage)
|
32
|
-
self.to_close = True
|
33
|
-
else:
|
34
|
-
self.conn, self.store = None, storage
|
35
|
-
self.to_close = False
|
36
|
-
|
37
|
-
def __enter__(self):
|
38
|
-
return self.store
|
39
|
-
|
40
|
-
def __exit__(self, exc_type, exc_val, exc_tb):
|
41
|
-
self.close()
|
42
|
-
|
43
|
-
def close(self):
|
44
|
-
if not self.to_close:
|
45
|
-
return
|
46
|
-
if hasattr(self.store, "close"):
|
47
|
-
self.store.close()
|
48
|
-
if hasattr(self.conn, "close"):
|
49
|
-
self.conn.close()
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
`
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
see :
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
|
110
|
-
|
111
|
-
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
-
|
117
|
-
|
118
|
-
|
119
|
-
|
120
|
-
|
121
|
-
)
|
122
|
-
|
123
|
-
|
124
|
-
self.
|
125
|
-
|
126
|
-
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
|
141
|
-
self.
|
142
|
-
|
143
|
-
if
|
144
|
-
if
|
145
|
-
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
|
150
|
-
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
|
155
|
-
|
156
|
-
|
157
|
-
|
158
|
-
|
159
|
-
|
160
|
-
|
161
|
-
|
162
|
-
|
163
|
-
|
164
|
-
|
165
|
-
|
166
|
-
|
167
|
-
|
168
|
-
|
169
|
-
|
170
|
-
|
171
|
-
|
172
|
-
|
173
|
-
|
174
|
-
|
175
|
-
|
176
|
-
|
177
|
-
|
178
|
-
|
179
|
-
|
180
|
-
|
181
|
-
|
182
|
-
|
183
|
-
|
184
|
-
|
185
|
-
|
186
|
-
|
187
|
-
|
188
|
-
|
189
|
-
|
190
|
-
|
191
|
-
|
192
|
-
|
193
|
-
|
194
|
-
|
195
|
-
|
196
|
-
|
197
|
-
|
198
|
-
|
199
|
-
|
200
|
-
self.
|
201
|
-
|
202
|
-
|
203
|
-
|
204
|
-
|
205
|
-
|
206
|
-
|
207
|
-
|
208
|
-
|
209
|
-
|
210
|
-
|
211
|
-
|
212
|
-
|
213
|
-
|
214
|
-
|
215
|
-
|
216
|
-
|
217
|
-
|
218
|
-
|
219
|
-
|
220
|
-
|
221
|
-
|
222
|
-
self.
|
223
|
-
|
224
|
-
|
225
|
-
|
226
|
-
|
227
|
-
|
228
|
-
|
229
|
-
|
230
|
-
|
231
|
-
|
232
|
-
|
233
|
-
|
234
|
-
|
235
|
-
|
236
|
-
|
237
|
-
if
|
238
|
-
|
239
|
-
|
240
|
-
|
241
|
-
|
242
|
-
|
243
|
-
|
244
|
-
|
245
|
-
self.
|
246
|
-
|
247
|
-
def
|
248
|
-
self.
|
249
|
-
|
250
|
-
|
251
|
-
|
252
|
-
|
253
|
-
|
254
|
-
|
255
|
-
|
256
|
-
|
257
|
-
|
258
|
-
|
259
|
-
|
260
|
-
|
261
|
-
|
262
|
-
|
263
|
-
|
264
|
-
|
265
|
-
|
266
|
-
|
267
|
-
|
268
|
-
|
269
|
-
|
270
|
-
|
271
|
-
|
272
|
-
self.
|
273
|
-
|
274
|
-
|
275
|
-
|
276
|
-
|
277
|
-
|
278
|
-
|
279
|
-
|
280
|
-
|
281
|
-
|
282
|
-
|
283
|
-
|
284
|
-
|
285
|
-
|
286
|
-
|
287
|
-
|
288
|
-
|
289
|
-
|
290
|
-
|
291
|
-
|
292
|
-
|
293
|
-
|
294
|
-
|
295
|
-
|
296
|
-
|
297
|
-
|
298
|
-
|
299
|
-
|
300
|
-
|
301
|
-
|
302
|
-
|
303
|
-
|
304
|
-
|
305
|
-
|
306
|
-
|
307
|
-
|
308
|
-
|
309
|
-
|
310
|
-
|
311
|
-
|
312
|
-
|
313
|
-
|
314
|
-
|
315
|
-
|
316
|
-
|
317
|
-
|
318
|
-
def
|
319
|
-
"""
|
320
|
-
|
321
|
-
|
322
|
-
|
323
|
-
|
324
|
-
|
325
|
-
|
326
|
-
|
327
|
-
|
328
|
-
|
329
|
-
|
330
|
-
|
331
|
-
|
332
|
-
|
333
|
-
|
334
|
-
if
|
335
|
-
|
336
|
-
|
337
|
-
|
338
|
-
|
339
|
-
|
340
|
-
|
341
|
-
|
342
|
-
|
343
|
-
|
344
|
-
|
345
|
-
|
346
|
-
|
347
|
-
|
348
|
-
|
349
|
-
|
350
|
-
|
351
|
-
|
352
|
-
|
353
|
-
|
354
|
-
|
355
|
-
|
356
|
-
|
357
|
-
|
358
|
-
|
359
|
-
|
360
|
-
|
361
|
-
|
362
|
-
|
363
|
-
|
364
|
-
|
365
|
-
|
366
|
-
|
367
|
-
|
368
|
-
|
369
|
-
|
370
|
-
|
371
|
-
|
372
|
-
|
373
|
-
|
374
|
-
|
375
|
-
|
376
|
-
|
377
|
-
|
378
|
-
|
379
|
-
|
380
|
-
|
381
|
-
|
382
|
-
|
383
|
-
|
384
|
-
|
385
|
-
|
386
|
-
|
387
|
-
|
388
|
-
|
389
|
-
|
390
|
-
|
391
|
-
|
392
|
-
|
393
|
-
|
394
|
-
|
395
|
-
|
396
|
-
|
397
|
-
|
398
|
-
|
399
|
-
|
400
|
-
|
401
|
-
|
402
|
-
|
403
|
-
|
404
|
-
|
405
|
-
|
406
|
-
|
407
|
-
|
408
|
-
|
409
|
-
|
410
|
-
|
411
|
-
|
412
|
-
|
413
|
-
|
414
|
-
|
415
|
-
|
416
|
-
|
417
|
-
|
418
|
-
|
419
|
-
|
420
|
-
|
421
|
-
|
422
|
-
|
423
|
-
|
424
|
-
else
|
425
|
-
|
426
|
-
|
427
|
-
|
428
|
-
|
429
|
-
|
430
|
-
|
431
|
-
|
432
|
-
|
433
|
-
|
434
|
-
|
435
|
-
|
436
|
-
|
437
|
-
|
438
|
-
|
439
|
-
|
440
|
-
|
441
|
-
):
|
442
|
-
"""Get
|
443
|
-
|
444
|
-
|
445
|
-
|
446
|
-
|
447
|
-
|
448
|
-
|
449
|
-
|
450
|
-
|
451
|
-
|
452
|
-
|
453
|
-
|
454
|
-
|
455
|
-
|
456
|
-
|
457
|
-
if
|
458
|
-
|
459
|
-
|
460
|
-
|
461
|
-
|
462
|
-
|
463
|
-
|
464
|
-
|
465
|
-
|
466
|
-
|
467
|
-
|
468
|
-
|
469
|
-
|
470
|
-
|
471
|
-
|
472
|
-
|
473
|
-
|
474
|
-
|
475
|
-
|
476
|
-
|
477
|
-
|
478
|
-
|
479
|
-
|
480
|
-
|
481
|
-
|
482
|
-
|
483
|
-
|
484
|
-
|
485
|
-
|
486
|
-
|
487
|
-
|
488
|
-
|
489
|
-
|
490
|
-
|
491
|
-
|
492
|
-
|
493
|
-
|
494
|
-
|
495
|
-
|
496
|
-
|
497
|
-
|
498
|
-
|
499
|
-
|
500
|
-
|
501
|
-
|
502
|
-
|
503
|
-
|
504
|
-
|
505
|
-
|
506
|
-
|
507
|
-
|
508
|
-
|
509
|
-
|
510
|
-
|
511
|
-
|
512
|
-
|
513
|
-
|
514
|
-
|
515
|
-
|
516
|
-
|
517
|
-
|
518
|
-
|
519
|
-
|
520
|
-
|
521
|
-
|
522
|
-
|
523
|
-
|
524
|
-
|
525
|
-
|
526
|
-
|
527
|
-
|
528
|
-
|
529
|
-
|
530
|
-
|
531
|
-
|
532
|
-
|
533
|
-
|
534
|
-
|
535
|
-
|
536
|
-
|
537
|
-
|
538
|
-
|
539
|
-
|
540
|
-
|
541
|
-
|
542
|
-
|
543
|
-
|
544
|
-
|
545
|
-
|
546
|
-
|
547
|
-
|
548
|
-
|
549
|
-
|
550
|
-
|
551
|
-
|
552
|
-
|
553
|
-
if isinstance(
|
554
|
-
|
555
|
-
|
556
|
-
|
557
|
-
|
558
|
-
|
559
|
-
|
560
|
-
|
561
|
-
|
562
|
-
|
563
|
-
|
564
|
-
|
565
|
-
|
566
|
-
|
567
|
-
|
568
|
-
|
569
|
-
|
570
|
-
|
571
|
-
|
572
|
-
|
573
|
-
|
574
|
-
|
575
|
-
|
576
|
-
|
577
|
-
|
578
|
-
|
579
|
-
|
580
|
-
|
581
|
-
|
582
|
-
|
583
|
-
self.
|
584
|
-
|
585
|
-
|
586
|
-
|
587
|
-
|
588
|
-
|
589
|
-
|
590
|
-
|
591
|
-
|
592
|
-
|
593
|
-
|
594
|
-
|
595
|
-
|
596
|
-
|
597
|
-
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
from collections import Counter
|
4
|
+
from functools import reduce
|
5
|
+
from pathlib import Path
|
6
|
+
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Union
|
7
|
+
|
8
|
+
import numpy as np
|
9
|
+
import pandas as pd
|
10
|
+
from lamin_utils import logger
|
11
|
+
from lamindb_setup.core.upath import UPath
|
12
|
+
|
13
|
+
from .storage._anndata_accessor import (
|
14
|
+
ArrayType,
|
15
|
+
ArrayTypes,
|
16
|
+
GroupType,
|
17
|
+
GroupTypes,
|
18
|
+
StorageType,
|
19
|
+
_safer_read_index,
|
20
|
+
get_spec,
|
21
|
+
registry,
|
22
|
+
)
|
23
|
+
|
24
|
+
if TYPE_CHECKING:
|
25
|
+
from lamindb_setup.core.types import UPathStr
|
26
|
+
|
27
|
+
|
28
|
+
class _Connect:
    """Context manager that yields a store for a storage reference.

    A ``UPath`` is opened through ``registry.open("h5py", ...)`` and is closed
    again on exit. Any other object is used as the store as-is and is never
    closed by this class.
    """

    def __init__(self, storage):
        owns_handle = isinstance(storage, UPath)
        if owns_handle:
            opened = registry.open("h5py", storage)
        else:
            opened = (None, storage)
        self.conn, self.store = opened
        self.to_close = owns_handle

    def __enter__(self):
        return self.store

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    def close(self):
        """Close the store and the connection if this object opened them."""
        if self.to_close:
            # store first, then the underlying connection
            for handle in (self.store, self.conn):
                if hasattr(handle, "close"):
                    handle.close()
|
50
|
+
|
51
|
+
|
52
|
+
_decode = np.frompyfunc(lambda x: x.decode("utf-8"), 1, 1)
|
53
|
+
|
54
|
+
|
55
|
+
class MappedCollection:
|
56
|
+
"""Map-style collection for use in data loaders.
|
57
|
+
|
58
|
+
This class virtually concatenates `AnnData` arrays as a `pytorch map-style dataset
|
59
|
+
<https://pytorch.org/docs/stable/data.html#map-style-datasets>`__.
|
60
|
+
|
61
|
+
If your `AnnData` collection is in the cloud, move them into a local cache
|
62
|
+
first for faster access.
|
63
|
+
|
64
|
+
`__getitem__` of the `MappedCollection` object takes a single integer index
|
65
|
+
and returns a dictionary with the observation data sample for this index from
|
66
|
+
the `AnnData` objects in `path_list`. The dictionary has keys for `layers_keys`
|
67
|
+
(`.X` is in `"X"`), `obs_keys`, `obsm_keys` (under `f"obsm_{key}"`) and also `"_store_idx"`
|
68
|
+
for the index of the `AnnData` object containing this observation sample.
|
69
|
+
|
70
|
+
.. note::
|
71
|
+
|
72
|
+
For a guide, see :doc:`docs:scrna5`.
|
73
|
+
|
74
|
+
For more convenient use within :class:`~lamindb.core.MappedCollection`,
|
75
|
+
see :meth:`~lamindb.Collection.mapped`.
|
76
|
+
|
77
|
+
This currently only works for collections of `AnnData` objects.
|
78
|
+
|
79
|
+
The implementation was influenced by the `SCimilarity
|
80
|
+
<https://github.com/Genentech/scimilarity>`__ data loader.
|
81
|
+
|
82
|
+
|
83
|
+
Args:
|
84
|
+
path_list: A list of paths to `AnnData` objects stored in `.h5ad` or `.zarr` formats.
|
85
|
+
layers_keys: Keys from the ``.layers`` slot. ``layers_keys=None`` or ``"X"`` in the list
|
86
|
+
retrieves ``.X``.
|
87
|
+
obsm_keys: Keys from the ``.obsm`` slots.
|
88
|
+
obs_keys: Keys from the ``.obs`` slots.
|
89
|
+
obs_filter: Select only observations with these values for the given obs column.
|
90
|
+
Should be a tuple with an obs column name as the first element
|
91
|
+
and filtering values (a string or a tuple of strings) as the second element.
|
92
|
+
join: `"inner"` or `"outer"` virtual joins. If ``None`` is passed,
|
93
|
+
does not join.
|
94
|
+
encode_labels: Encode labels into integers.
|
95
|
+
Can be a list with elements from ``obs_keys``.
|
96
|
+
unknown_label: Encode this label to -1.
|
97
|
+
Can be a dictionary with keys from ``obs_keys`` if ``encode_labels=True``
|
98
|
+
or from ``encode_labels`` if it is a list.
|
99
|
+
cache_categories: Enable caching categories of ``obs_keys`` for faster access.
|
100
|
+
parallel: Enable sampling with multiple processes.
|
101
|
+
dtype: Convert numpy arrays from ``.X``, ``.layers`` and ``.obsm``
|
102
|
+
"""
|
103
|
+
|
104
|
+
def __init__(
|
105
|
+
self,
|
106
|
+
path_list: list[UPathStr],
|
107
|
+
layers_keys: str | list[str] | None = None,
|
108
|
+
obs_keys: str | list[str] | None = None,
|
109
|
+
obsm_keys: str | list[str] | None = None,
|
110
|
+
obs_filter: tuple[str, str | tuple[str, ...]] | None = None,
|
111
|
+
join: Literal["inner", "outer"] | None = "inner",
|
112
|
+
encode_labels: bool | list[str] = True,
|
113
|
+
unknown_label: str | dict[str, str] | None = None,
|
114
|
+
cache_categories: bool = True,
|
115
|
+
parallel: bool = False,
|
116
|
+
dtype: str | None = None,
|
117
|
+
):
|
118
|
+
if join not in {None, "inner", "outer"}: # pragma: nocover
|
119
|
+
raise ValueError(
|
120
|
+
f"join must be one of None, 'inner, or 'outer' but was {type(join)}"
|
121
|
+
)
|
122
|
+
|
123
|
+
self.filtered = obs_filter is not None
|
124
|
+
if self.filtered and len(obs_filter) != 2:
|
125
|
+
raise ValueError(
|
126
|
+
"obs_filter should be a tuple with obs column name "
|
127
|
+
"as the first element and filtering values as the second element"
|
128
|
+
)
|
129
|
+
|
130
|
+
if layers_keys is None:
|
131
|
+
self.layers_keys = ["X"]
|
132
|
+
else:
|
133
|
+
self.layers_keys = (
|
134
|
+
[layers_keys] if isinstance(layers_keys, str) else layers_keys
|
135
|
+
)
|
136
|
+
|
137
|
+
obsm_keys = [obsm_keys] if isinstance(obsm_keys, str) else obsm_keys
|
138
|
+
self.obsm_keys = obsm_keys
|
139
|
+
|
140
|
+
obs_keys = [obs_keys] if isinstance(obs_keys, str) else obs_keys
|
141
|
+
self.obs_keys = obs_keys
|
142
|
+
|
143
|
+
if isinstance(encode_labels, list):
|
144
|
+
if len(encode_labels) == 0:
|
145
|
+
encode_labels = False
|
146
|
+
elif obs_keys is None or not all(
|
147
|
+
enc_label in obs_keys for enc_label in encode_labels
|
148
|
+
):
|
149
|
+
raise ValueError(
|
150
|
+
"All elements of `encode_labels` should be in `obs_keys`."
|
151
|
+
)
|
152
|
+
else:
|
153
|
+
if encode_labels:
|
154
|
+
encode_labels = obs_keys if obs_keys is not None else False
|
155
|
+
self.encode_labels = encode_labels
|
156
|
+
|
157
|
+
if encode_labels and isinstance(unknown_label, dict):
|
158
|
+
if not all(unkey in encode_labels for unkey in unknown_label): # type: ignore
|
159
|
+
raise ValueError(
|
160
|
+
"All keys of `unknown_label` should be in `encode_labels` and `obs_keys`."
|
161
|
+
)
|
162
|
+
self.unknown_label = unknown_label
|
163
|
+
|
164
|
+
self.storages = [] # type: ignore
|
165
|
+
self.conns = [] # type: ignore
|
166
|
+
self.parallel = parallel
|
167
|
+
self.path_list = path_list
|
168
|
+
self._make_connections(path_list, parallel)
|
169
|
+
|
170
|
+
self._cache_cats: dict = {}
|
171
|
+
if self.obs_keys is not None:
|
172
|
+
if cache_categories:
|
173
|
+
self._cache_categories(self.obs_keys)
|
174
|
+
self.encoders: dict = {}
|
175
|
+
if self.encode_labels:
|
176
|
+
self._make_encoders(self.encode_labels) # type: ignore
|
177
|
+
|
178
|
+
self.n_obs_list = []
|
179
|
+
self.indices_list = []
|
180
|
+
for i, storage in enumerate(self.storages):
|
181
|
+
with _Connect(storage) as store:
|
182
|
+
X = store["X"]
|
183
|
+
store_path = self.path_list[i]
|
184
|
+
self._check_csc_raise_error(X, "X", store_path)
|
185
|
+
if self.filtered:
|
186
|
+
obs_filter_key, obs_filter_values = obs_filter
|
187
|
+
indices_storage = np.where(
|
188
|
+
np.isin(
|
189
|
+
self._get_labels(store, obs_filter_key), obs_filter_values
|
190
|
+
)
|
191
|
+
)[0]
|
192
|
+
n_obs_storage = len(indices_storage)
|
193
|
+
else:
|
194
|
+
if isinstance(X, ArrayTypes): # type: ignore
|
195
|
+
n_obs_storage = X.shape[0]
|
196
|
+
else:
|
197
|
+
n_obs_storage = X.attrs["shape"][0]
|
198
|
+
indices_storage = np.arange(n_obs_storage)
|
199
|
+
self.n_obs_list.append(n_obs_storage)
|
200
|
+
self.indices_list.append(indices_storage)
|
201
|
+
for layer_key in self.layers_keys:
|
202
|
+
if layer_key == "X":
|
203
|
+
continue
|
204
|
+
self._check_csc_raise_error(
|
205
|
+
store["layers"][layer_key],
|
206
|
+
f"layers/{layer_key}",
|
207
|
+
store_path,
|
208
|
+
)
|
209
|
+
if self.obsm_keys is not None:
|
210
|
+
for obsm_key in self.obsm_keys:
|
211
|
+
self._check_csc_raise_error(
|
212
|
+
store["obsm"][obsm_key],
|
213
|
+
f"obsm/{obsm_key}",
|
214
|
+
store_path,
|
215
|
+
)
|
216
|
+
self.n_obs = sum(self.n_obs_list)
|
217
|
+
|
218
|
+
self.indices = np.hstack(self.indices_list)
|
219
|
+
self.storage_idx = np.repeat(np.arange(len(self.storages)), self.n_obs_list)
|
220
|
+
|
221
|
+
self.join_vars: Literal["inner", "outer"] | None = join
|
222
|
+
self.var_indices: list | None = None
|
223
|
+
self.var_joint: pd.Index | None = None
|
224
|
+
self.n_vars_list: list | None = None
|
225
|
+
self.var_list: list | None = None
|
226
|
+
self.n_vars: int | None = None
|
227
|
+
if self.join_vars is not None:
|
228
|
+
self._make_join_vars()
|
229
|
+
self.n_vars = len(self.var_joint)
|
230
|
+
|
231
|
+
self._dtype = dtype
|
232
|
+
self._closed = False
|
233
|
+
|
234
|
+
def _make_connections(self, path_list: list, parallel: bool):
    """Append one connection/storage pair per path to `conns` and `storages`.

    An existing regular file is opened with the ``"h5py"`` backend, unless
    ``parallel`` is set, in which case only the path is stored and nothing is
    opened. Every other path is opened with the ``"zarr"`` backend.
    """
    for raw_path in path_list:
        upath = UPath(raw_path)
        is_h5_file = upath.exists() and upath.is_file()  # type: ignore
        if is_h5_file and parallel:
            opened = (None, upath)
        else:
            backend = "h5py" if is_h5_file else "zarr"
            opened = registry.open(backend, upath)
        conn, storage = opened
        self.conns.append(conn)
        self.storages.append(storage)
|
246
|
+
|
247
|
+
def _cache_categories(self, obs_keys: list):
    """Read the categories of every obs key from every store into memory.

    Fills ``self._cache_cats[label]`` with one entry per store: the decoded
    or materialized categories, or ``None`` when the store has no categories
    for that label.
    """
    self._cache_cats = {}
    for label in obs_keys:
        self._cache_cats[label] = []
        for storage in self.storages:
            with _Connect(storage) as store:
                cats = self._get_categories(store, label)
                if cats is not None:
                    # only probe the first element when there is one;
                    # an empty categories array is cached as-is
                    if len(cats) > 0 and isinstance(cats[0], bytes):
                        cats = _decode(cats)
                    else:
                        cats = cats[...]
                self._cache_cats[label].append(cats)
|
259
|
+
|
260
|
+
def _make_encoders(self, encode_labels: list):
|
261
|
+
for label in encode_labels:
|
262
|
+
cats = self.get_merged_categories(label)
|
263
|
+
encoder = {}
|
264
|
+
if isinstance(self.unknown_label, dict):
|
265
|
+
unknown_label = self.unknown_label.get(label, None)
|
266
|
+
else:
|
267
|
+
unknown_label = self.unknown_label
|
268
|
+
if unknown_label is not None and unknown_label in cats:
|
269
|
+
cats.remove(unknown_label)
|
270
|
+
encoder[unknown_label] = -1
|
271
|
+
encoder.update({cat: i for i, cat in enumerate(cats)})
|
272
|
+
self.encoders[label] = encoder
|
273
|
+
|
274
|
+
def _read_vars(self):
    """Read the `var` index of every store into `var_list` and `n_vars_list`."""
    var_list = []
    n_vars_list = []
    # publish the lists first so a failure midway leaves the partial result
    self.var_list = var_list
    self.n_vars_list = n_vars_list
    for storage in self.storages:
        with _Connect(storage) as store:
            var_index = _safer_read_index(store["var"])
            var_list.append(var_index)
            n_vars_list.append(len(var_index))
|
282
|
+
|
283
|
+
def _make_join_vars(self):
|
284
|
+
if self.var_list is None:
|
285
|
+
self._read_vars()
|
286
|
+
vars_eq = all(self.var_list[0].equals(vrs) for vrs in self.var_list[1:])
|
287
|
+
if vars_eq:
|
288
|
+
self.join_vars = None
|
289
|
+
self.var_joint = self.var_list[0]
|
290
|
+
return
|
291
|
+
|
292
|
+
if self.join_vars == "inner":
|
293
|
+
self.var_joint = reduce(pd.Index.intersection, self.var_list)
|
294
|
+
if len(self.var_joint) == 0:
|
295
|
+
raise ValueError(
|
296
|
+
"The provided AnnData objects don't have shared varibales.\n"
|
297
|
+
"Use join='outer'."
|
298
|
+
)
|
299
|
+
self.var_indices = [
|
300
|
+
vrs.get_indexer(self.var_joint) for vrs in self.var_list
|
301
|
+
]
|
302
|
+
elif self.join_vars == "outer":
|
303
|
+
self.var_joint = reduce(pd.Index.union, self.var_list)
|
304
|
+
self.var_indices = [
|
305
|
+
self.var_joint.get_indexer(vrs) for vrs in self.var_list
|
306
|
+
]
|
307
|
+
|
308
|
+
def check_vars_sorted(self, ascending: bool = True) -> bool:
|
309
|
+
"""Returns `True` if all variables are sorted in all objects."""
|
310
|
+
if self.var_list is None:
|
311
|
+
self._read_vars()
|
312
|
+
if ascending:
|
313
|
+
vrs_sort_status = (vrs.is_monotonic_increasing for vrs in self.var_list)
|
314
|
+
else:
|
315
|
+
vrs_sort_status = (vrs.is_monotonic_decreasing for vrs in self.var_list)
|
316
|
+
return all(vrs_sort_status)
|
317
|
+
|
318
|
+
def check_vars_non_aligned(self, vars: pd.Index | list) -> list[int]:
|
319
|
+
"""Returns indices of objects with non-aligned variables.
|
320
|
+
|
321
|
+
Args:
|
322
|
+
vars: Check alignment against these variables.
|
323
|
+
"""
|
324
|
+
if self.var_list is None:
|
325
|
+
self._read_vars()
|
326
|
+
vars = pd.Index(vars)
|
327
|
+
return [i for i, vrs in enumerate(self.var_list) if not vrs.equals(vars)]
|
328
|
+
|
329
|
+
def _check_csc_raise_error(
    self, elem: GroupType | ArrayType, key: str, path: UPathStr
):
    """Raise if `elem` is encoded as a csc matrix.

    Plain arrays pass without inspection. Before raising, open connections
    are closed unless the collection runs in parallel mode.

    Raises:
        ValueError: If the encoding type of `elem` is ``"csc_matrix"``.
    """
    if isinstance(elem, ArrayTypes):  # type: ignore
        return
    if get_spec(elem).encoding_type != "csc_matrix":
        return
    if not self.parallel:
        self.close()
    raise ValueError(
        f"{key} in {path} is a csc matrix, `MappedCollection` doesn't support this format yet."
    )
|
340
|
+
|
341
|
+
def __len__(self):
|
342
|
+
return self.n_obs
|
343
|
+
|
344
|
+
@property
|
345
|
+
def shape(self) -> tuple[int, int]:
|
346
|
+
"""Shape of the (virtually aligned) dataset."""
|
347
|
+
return (self.n_obs, self.n_vars)
|
348
|
+
|
349
|
+
@property
|
350
|
+
def original_shapes(self) -> list[tuple[int, int]]:
|
351
|
+
"""Shapes of the underlying AnnData objects."""
|
352
|
+
if self.n_vars_list is None:
|
353
|
+
n_vars_list = [None] * len(self.n_obs_list)
|
354
|
+
else:
|
355
|
+
n_vars_list = self.n_vars_list
|
356
|
+
return list(zip(self.n_obs_list, n_vars_list))
|
357
|
+
|
358
|
+
def __getitem__(self, idx: int):
    """Return the observation sample at global position `idx` as a dictionary.

    Keys are added in this order: one per entry of `layers_keys`,
    ``f"obsm_{key}"`` per obsm key, ``"_store_idx"``, then one per obs key.
    Obs values are passed through the matching encoder when one exists.
    """
    row = self.indices[idx]
    store_no = self.storage_idx[idx]
    var_idxs_join = None
    if self.var_indices is not None:
        var_idxs_join = self.var_indices[store_no]

    with _Connect(self.storages[store_no]) as store:
        sample = {}
        for layers_key in self.layers_keys:
            if layers_key == "X":
                source = store["X"]
            else:
                source = store["layers"][layers_key]
            sample[layers_key] = self._get_data_idx(
                source, row, self.join_vars, var_idxs_join, self.n_vars
            )
        if self.obsm_keys is not None:
            for obsm_key in self.obsm_keys:
                sample[f"obsm_{obsm_key}"] = self._get_data_idx(
                    store["obsm"][obsm_key], row
                )
        sample["_store_idx"] = store_no
        if self.obs_keys is not None:
            for label in self.obs_keys:
                cats = None
                if label in self._cache_cats:
                    cats = self._cache_cats[label][store_no]
                    if cats is None:
                        # cached "no categories": pass an empty list so the
                        # lookup is not repeated against the store
                        cats = []
                value = self._get_obs_idx(store, row, label, cats)
                if label in self.encoders:
                    value = self.encoders[label][value]
                sample[label] = value
    return sample
|
393
|
+
|
394
|
+
def _get_data_idx(
    self,
    lazy_data: ArrayType | GroupType,
    idx: int,
    join_vars: Literal["inner", "outer"] | None = None,
    var_idxs_join: list | None = None,
    n_vars_out: int | None = None,
):
    """Read row `idx` of `lazy_data` and align it to the joint variables.

    Array inputs are indexed directly. Any other input is read through its
    ``"data"``, ``"indices"`` and ``"indptr"`` members and expanded into a
    dense row (assumed to be a csr matrix). The result is cast to
    `self._dtype` when that is set.
    """
    if isinstance(lazy_data, ArrayTypes):  # type: ignore
        row = lazy_data[idx]  # type: ignore
        if join_vars == "outer":
            out_dtype = row.dtype if self._dtype is None else self._dtype
            padded = np.zeros(n_vars_out, dtype=out_dtype)
            padded[var_idxs_join] = row
            return padded
        # no join keeps the row as is, inner join selects the shared columns
        picked = row if join_vars is None else row[var_idxs_join]
        if self._dtype is not None:
            picked = picked.astype(self._dtype, copy=False)
        return picked

    # assume csr_matrix here
    values = lazy_data["data"]  # type: ignore
    columns = lazy_data["indices"]  # type: ignore
    row_ptr = lazy_data["indptr"]  # type: ignore
    row_slice = slice(*(row_ptr[idx : idx + 2]))
    nonzero = values[row_slice]
    out_dtype = nonzero.dtype if self._dtype is None else self._dtype
    if join_vars == "outer":
        dense_row = np.zeros(n_vars_out, dtype=out_dtype)
        dense_row[var_idxs_join[columns[row_slice]]] = nonzero
        return dense_row
    dense_row = np.zeros(lazy_data.attrs["shape"][1], dtype=out_dtype)  # type: ignore
    dense_row[columns[row_slice]] = nonzero
    if join_vars == "inner":
        dense_row = dense_row[var_idxs_join]
    return dense_row
|
434
|
+
|
435
|
+
def _get_obs_idx(
    self,
    storage: StorageType,
    idx: int,
    label_key: str,
    categories: list | None = None,
):
    """Return the `label_key` value of observation `idx` from `storage`.

    The raw value is looked up in `categories` when given, otherwise in the
    categories read from `storage`; the lookup is skipped when no or empty
    categories are available. Bytes are decoded as UTF-8.
    """
    obs = storage["obs"]  # type: ignore
    # how backwards compatible do we want to be here actually?
    if isinstance(obs, ArrayTypes):  # type: ignore
        value = obs[idx][obs.dtype.names.index(label_key)]
    else:
        column = obs[label_key]
        if isinstance(column, ArrayTypes):  # type: ignore
            value = column[idx]
        else:
            value = column["codes"][idx]

    lookup = categories
    if lookup is None:
        lookup = self._get_categories(storage, label_key)
    if lookup is not None and len(lookup) > 0:
        value = lookup[value]
    return value.decode("utf-8") if isinstance(value, bytes) else value
|
462
|
+
|
463
|
+
def get_label_weights(
|
464
|
+
self,
|
465
|
+
obs_keys: str | list[str],
|
466
|
+
scaler: float | None = None,
|
467
|
+
return_categories: bool = False,
|
468
|
+
):
|
469
|
+
"""Get all weights for the given label keys.
|
470
|
+
|
471
|
+
This counts the number of labels for each label and returns
|
472
|
+
weights for each obs label accoding to the formula `1 / num of this label in the data`.
|
473
|
+
If `scaler` is provided, then `scaler / (scaler + num of this label in the data)`.
|
474
|
+
|
475
|
+
Args:
|
476
|
+
obs_keys: A key in the ``.obs`` slots or a list of keys. If a list is provided,
|
477
|
+
the labels from the obs keys will be concatenated with ``"__"`` delimeter
|
478
|
+
scaler: Use this number to scale the provided weights.
|
479
|
+
return_categories: If `False`, returns weights for each observation,
|
480
|
+
can be directly passed to a sampler. If `True`, returns a dictionary with
|
481
|
+
unique categories for labels (concatenated if `obs_keys` is a list)
|
482
|
+
and their weights.
|
483
|
+
"""
|
484
|
+
if isinstance(obs_keys, str):
|
485
|
+
obs_keys = [obs_keys]
|
486
|
+
labels_list = []
|
487
|
+
for label_key in obs_keys:
|
488
|
+
labels_to_str = self.get_merged_labels(label_key).astype(str).astype("O")
|
489
|
+
labels_list.append(labels_to_str)
|
490
|
+
if len(labels_list) > 1:
|
491
|
+
labels = ["__".join(labels_obs) for labels_obs in zip(*labels_list)]
|
492
|
+
else:
|
493
|
+
labels = labels_list[0]
|
494
|
+
counter = Counter(labels)
|
495
|
+
if return_categories:
|
496
|
+
return {
|
497
|
+
k: 1.0 / v if scaler is None else scaler / (v + scaler)
|
498
|
+
for k, v in counter.items()
|
499
|
+
}
|
500
|
+
counts = np.array([counter[label] for label in labels])
|
501
|
+
if scaler is None:
|
502
|
+
weights = 1.0 / counts
|
503
|
+
else:
|
504
|
+
weights = scaler / (counts + scaler)
|
505
|
+
return weights
|
506
|
+
|
507
|
+
def get_merged_labels(self, label_key: str):
    """Get merged labels for `label_key` from all `.obs`.

    Labels are read store by store, restricted to the filtered observations
    when an obs filter is active, and stacked into one array.
    """
    merged = []
    for store_no, storage in enumerate(self.storages):
        with _Connect(storage) as store:
            store_labels = self._get_labels(store, label_key, storage_idx=store_no)
            if self.filtered:
                store_labels = store_labels[self.indices_list[store_no]]
            merged.append(store_labels)
    return np.hstack(merged)
|
517
|
+
|
518
|
+
def get_merged_categories(self, label_key: str):
    """Get merged categories for `label_key` from all `.obs`.

    For each store the cached categories are used when available, otherwise
    they are read from the store. Stores without categories contribute their
    raw column values instead. Returns the sorted union as a list.
    """
    cats_merge = set()
    for i, storage in enumerate(self.storages):
        with _Connect(storage) as store:
            if label_key in self._cache_cats:
                cats = self._cache_cats[label_key][i]
            else:
                cats = self._get_categories(store, label_key)
            if cats is not None:
                values = cats
            else:
                values = self._get_codes(store, label_key)
            # only probe the first element when there is one, so an empty
            # store or empty categories contribute nothing instead of raising
            if len(values) > 0 and isinstance(values[0], bytes):
                values = _decode(values)
            cats_merge.update(values)
    return sorted(cats_merge)
|
535
|
+
|
536
|
+
def _get_categories(self, storage: StorageType, label_key: str):
    """Return the stored categories for `label_key`, or `None` if there are none.

    Lookup order: ``uns[f"{label_key}_categories"]`` when `obs` is an array;
    otherwise ``obs["__categories"][label_key]`` when that group exists;
    otherwise the ``"categories"`` member of the column group, or the
    ``"categories"`` attribute of the column array.
    """
    obs = storage["obs"]  # type: ignore
    if isinstance(obs, ArrayTypes):  # type: ignore
        cat_key_uns = f"{label_key}_categories"
        if cat_key_uns not in storage["uns"]:  # type: ignore
            return None
        return storage["uns"][cat_key_uns]  # type: ignore

    if "__categories" in obs:
        shared_cats = obs["__categories"]
        return shared_cats[label_key] if label_key in shared_cats else None

    column = obs[label_key]
    if isinstance(column, GroupTypes):  # type: ignore
        return column["categories"] if "categories" in column else None
    if "categories" in column.attrs:
        return column.attrs["categories"]
    return None
|
564
|
+
|
565
|
+
def _get_codes(self, storage: StorageType, label_key: str):
    """Read all stored values of the `label_key` obs column.

    An array column is read directly; any other column is read from its
    ``"codes"`` member.
    """
    # the column is reached the same way whether `obs` is an array or a group
    column = storage["obs"][label_key]  # type: ignore
    if not isinstance(column, ArrayTypes):  # type: ignore
        column = column["codes"]
    return column[...]
|
576
|
+
|
577
|
+
def _get_labels(
    self, storage: StorageType, label_key: str, storage_idx: int | None = None
):
    """Get the labels of all observations in `storage` for `label_key`.

    Reads the raw column values and, when categories exist, maps them
    through the categories. Cached categories are used when `storage_idx`
    is given and `label_key` is cached. Bytes are decoded as UTF-8.
    """
    codes = self._get_codes(storage, label_key)
    # only probe the first element when there is one (empty store)
    if len(codes) > 0 and isinstance(codes[0], bytes):
        labels = _decode(codes)
    else:
        labels = codes
    if storage_idx is not None and label_key in self._cache_cats:
        cats = self._cache_cats[label_key][storage_idx]
    else:
        cats = self._get_categories(storage, label_key)
    if cats is not None:
        # same guard for empty categories
        if len(cats) > 0 and isinstance(cats[0], bytes):
            cats = _decode(cats)
        labels = cats[labels]
    return labels
|
591
|
+
|
592
|
+
def close(self):
|
593
|
+
"""Close connections to array streaming backend.
|
594
|
+
|
595
|
+
No effect if `parallel=True`.
|
596
|
+
"""
|
597
|
+
for storage in self.storages:
|
598
|
+
if hasattr(storage, "close"):
|
599
|
+
storage.close()
|
600
|
+
for conn in self.conns:
|
601
|
+
if hasattr(conn, "close"):
|
602
|
+
conn.close()
|
603
|
+
self._closed = True
|
604
|
+
|
605
|
+
@property
|
606
|
+
def closed(self) -> bool:
|
607
|
+
"""Check if connections to array streaming backend are closed.
|
608
|
+
|
609
|
+
Does not matter if `parallel=True`.
|
610
|
+
"""
|
611
|
+
return self._closed
|
612
|
+
|
613
|
+
def __enter__(self):
|
614
|
+
return self
|
615
|
+
|
616
|
+
def __exit__(self, exc_type, exc_val, exc_tb):
|
617
|
+
self.close()
|
618
|
+
|
619
|
+
@staticmethod
def torch_worker_init_fn(worker_id):
    """`worker_init_fn` for `torch.utils.data.DataLoader`.

    Takes the dataset from the worker info, switches off its parallel mode
    and rebuilds its connections from `path_list`. `worker_id` is not used.

    Improves performance for `num_workers > 1`.
    """
    from torch.utils.data import get_worker_info

    dataset = get_worker_info().dataset
    dataset.parallel = False
    # drop the inherited handles before opening fresh ones
    dataset.storages = []
    dataset.conns = []
    dataset._make_connections(dataset.path_list, parallel=False)
|