scdataloader 1.9.1__py3-none-any.whl → 2.0.0__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as they appear in the supported public registries. It is provided for informational purposes only.
- scdataloader/__init__.py +2 -1
- scdataloader/collator.py +30 -42
- scdataloader/config.py +25 -9
- scdataloader/data.json +384 -0
- scdataloader/data.py +116 -43
- scdataloader/datamodule.py +555 -225
- scdataloader/mapped.py +84 -18
- scdataloader/preprocess.py +108 -94
- scdataloader/utils.py +39 -33
- {scdataloader-1.9.1.dist-info → scdataloader-2.0.0.dist-info}/METADATA +13 -5
- scdataloader-2.0.0.dist-info/RECORD +16 -0
- scdataloader-2.0.0.dist-info/licenses/LICENSE +21 -0
- scdataloader/VERSION +0 -1
- scdataloader-1.9.1.dist-info/RECORD +0 -16
- scdataloader-1.9.1.dist-info/licenses/LICENSE +0 -674
- {scdataloader-1.9.1.dist-info → scdataloader-2.0.0.dist-info}/WHEEL +0 -0
- {scdataloader-1.9.1.dist-info → scdataloader-2.0.0.dist-info}/entry_points.txt +0 -0
scdataloader/data.py
CHANGED
```diff
@@ -16,7 +16,6 @@ from torch.utils.data import Dataset as torchDataset
 
 from scdataloader.utils import get_ancestry_mapping, load_genes
 
-from .config import LABELS_TOADD
 from .mapped import MappedCollection, _Connect
 
 
```
```diff
@@ -39,19 +38,18 @@ class Dataset(torchDataset):
     ----
         lamin_dataset (lamindb.Dataset): lamin dataset to load
         genedf (pd.Dataframe): dataframe containing the genes to load
-        organisms (list[str]): list of organisms to load
-            (for now only validates the the genes map to this organism)
         obs (list[str]): list of observations to load from the Collection
         clss_to_predict (list[str]): list of observations to encode
         join_vars (flag): join variables @see :meth:`~lamindb.Dataset.mapped`.
         hierarchical_clss: list of observations to map to a hierarchy using lamin's bionty
+        metacell_mode (float, optional): The mode to use for metacell sampling. Defaults to 0.0.
+        get_knn_cells (bool, optional): Whether to get the k-nearest neighbors of each cell. Defaults to False.
+        store_location (str, optional): The location to store the sampler indices. Defaults to None.
+        force_recompute_indices (bool, optional): Whether to force recompute the sampler indices. Defaults to False.
     """
 
     lamin_dataset: ln.Collection
     genedf: Optional[pd.DataFrame] = None
-    organisms: Optional[Union[list[str], str]] = field(
-        default_factory=["NCBITaxon:9606", "NCBITaxon:10090"]
-    )
     # set of obs to prepare for prediction (encode)
     clss_to_predict: Optional[list[str]] = field(default_factory=list)
     # set of obs that need to be hierarchically prepared
```
```diff
@@ -59,6 +57,8 @@ class Dataset(torchDataset):
     join_vars: Literal["inner", "outer"] | None = None
     metacell_mode: float = 0.0
     get_knn_cells: bool = False
+    store_location: str | None = None
+    force_recompute_indices: bool = False
 
     def __post_init__(self):
        self.mapped_dataset = mapped(
```
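The two new `store_location` / `force_recompute_indices` fields are forwarded straight to `mapped()` (next hunk), so sampler indices can be cached on disk between runs. A minimal usage sketch, assuming `Dataset` is still importable from the package root as in 1.x and that a lamindb `Collection` named `"my-collection"` exists:

```python
import lamindb as ln

from scdataloader import Dataset  # assumed top-level export, as in 1.x

col = ln.Collection.filter(name="my-collection").one()  # hypothetical collection
ds = Dataset(
    lamin_dataset=col,
    clss_to_predict=["organism_ontology_term_id"],
    store_location="/tmp/scdataloader_indices",  # new in 2.0.0: cache sampler indices here
    force_recompute_indices=False,  # new in 2.0.0: set True to ignore a stale cache
)
```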
```diff
@@ -71,6 +71,8 @@ class Dataset(torchDataset):
             parallel=True,
             metacell_mode=self.metacell_mode,
             get_knn_cells=self.get_knn_cells,
+            store_location=self.store_location,
+            force_recompute_indices=self.force_recompute_indices,
         )
         print(
             "won't do any check but we recommend to have your dataset coming from local storage"
```
```diff
@@ -85,7 +87,7 @@ class Dataset(torchDataset):
             if clss not in self.hierarchical_clss:
                 # otherwise it's already been done
                 self.class_topred[clss] = set(
-                    self.mapped_dataset.
+                    self.mapped_dataset.encoders[clss].keys()
                 )
                 if (
                     self.mapped_dataset.unknown_label
```
```diff
@@ -94,12 +96,19 @@ class Dataset(torchDataset):
                     self.class_topred[clss] -= set(
                         [self.mapped_dataset.unknown_label]
                     )
-
         if self.genedf is None:
+            if "organism_ontology_term_id" not in self.clss_to_predict:
+                raise ValueError(
+                    "need 'organism_ontology_term_id' in the set of classes if you don't provide a genedf"
+                )
+            self.organisms = list(self.class_topred["organism_ontology_term_id"])
+            self.organisms.sort()
             self.genedf = load_genes(self.organisms)
+        else:
+            self.organisms = None
 
         self.genedf.columns = self.genedf.columns.astype(str)
-        self.check_aligned_vars()
+        # self.check_aligned_vars()
 
     def check_aligned_vars(self):
         vars = self.genedf.index.tolist()
```
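This hunk changes behavior, not just signatures: the `organisms` field is gone, and when `genedf` is `None` the organism list is now derived from the encoded `organism_ontology_term_id` class. A sketch of the new contract, reusing the hypothetical `col` from above:

```python
# 2.0.0: works, organisms are inferred from the data itself
ds = Dataset(
    lamin_dataset=col,
    genedf=None,
    clss_to_predict=["organism_ontology_term_id"],  # now required when genedf is None
)

# 2.0.0: raises
# ValueError: need 'organism_ontology_term_id' in the set of classes if you don't provide a genedf
ds = Dataset(lamin_dataset=col, genedf=None, clss_to_predict=[])
```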
```diff
@@ -117,6 +126,10 @@ class Dataset(torchDataset):
     def encoder(self):
         return self.mapped_dataset.encoders
 
+    @encoder.setter
+    def encoder(self, encoder):
+        self.mapped_dataset.encoders = encoder
+
     def __getitem__(self, *args, **kwargs):
         item = self.mapped_dataset.__getitem__(*args, **kwargs)
         return item
```
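The new setter lets callers swap in label encoders from elsewhere, for instance the mapping saved alongside a pretrained model, instead of the ones built from the current collection. A small sketch with made-up ontology terms:

```python
# hypothetical encoders saved from a previous run
pretrained = {"cell_type_ontology_term_id": {"CL:0000057": 0, "CL:0000236": 1}}

ds.encoder = pretrained  # forwards to ds.mapped_dataset.encoders
assert ds.encoder is pretrained
```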
```diff
@@ -132,7 +145,11 @@ class Dataset(torchDataset):
             + " {} genes\n".format(self.genedf.shape[0])
             + " {} clss_to_predict\n".format(len(self.clss_to_predict))
             + " {} hierarchical_clss\n".format(len(self.hierarchical_clss))
-            +
+            + (
+                " {} organisms\n".format(len(self.organisms))
+                if self.organisms is not None
+                else ""
+            )
             + (
                 "dataset contains {} classes to predict\n".format(
                     sum([len(self.class_topred[i]) for i in self.class_topred])
```
```diff
@@ -148,31 +165,24 @@ class Dataset(torchDataset):
         obs_keys: str | list[str],
         scaler: int = 10,
         return_categories=False,
-        bypass_label=["neuron"],
     ):
         """Get all weights for the given label keys."""
         if isinstance(obs_keys, str):
             obs_keys = [obs_keys]
-
+        labels = None
         for label_key in obs_keys:
-            labels_to_str = (
-
-
-
-
-
-            else:
-                labels = labels_list[0]
-
-            counter = Counter(labels)  # type: ignore
+            labels_to_str = self.mapped_dataset.get_merged_labels(label_key)
+            if labels is None:
+                labels = labels_to_str
+            else:
+                labels = concat_categorical_codes([labels, labels_to_str])
+        counter = Counter(labels.codes)  # type: ignore
         if return_categories:
-            rn = {n: i for i, n in enumerate(counter.keys())}
-            labels = np.array([rn[label] for label in labels])
             counter = np.array(list(counter.values()))
             weights = scaler / (counter + scaler)
-            return weights, labels
+            return weights, np.array(labels.codes)
         else:
-            counts = np.array([counter[label] for label in labels])
+            counts = np.array([counter[label] for label in labels.codes])
             if scaler is None:
                 weights = 1.0 / counts
             else:
```
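`get_label_weights` now merges all requested obs keys into a single categorical via `concat_categorical_codes` (defined at the bottom of this file) and weights each combination by smoothed inverse frequency, `scaler / (count + scaler)`. A toy illustration of that formula on pre-merged codes:

```python
from collections import Counter

import numpy as np

codes = np.array([0, 0, 0, 0, 1, 1, 2])  # combined category codes, one per cell
counter = Counter(codes)
scaler = 10
counts = np.array([counter[c] for c in codes])
weights = scaler / (counts + scaler)  # rarer combinations get weights closer to 1
print(weights.round(2))  # [0.71 0.71 0.71 0.71 0.83 0.83 0.91]
```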
```diff
@@ -267,12 +277,14 @@ class Dataset(torchDataset):
                     clss
                 )
             )
-            cats = set(self.mapped_dataset.
-            addition = set(LABELS_TOADD.get(clss, {}).values())
-            cats |= addition
+            cats = set(self.mapped_dataset.encoders[clss].keys())
             groupings, _, leaf_labels = get_ancestry_mapping(cats, parentdf)
             for i, j in groupings.items():
                 if len(j) == 0:
+                    # that should not happen
+                    import pdb
+
+                    pdb.set_trace()
                     groupings.pop(i)
             self.labels_groupings[clss] = groupings
             if clss in self.clss_to_predict:
```
```diff
@@ -287,11 +299,12 @@ class Dataset(torchDataset):
                 )
 
                 for i, v in enumerate(
-
+                    set(groupings.keys())
+                    - set(self.mapped_dataset.encoders[clss].keys())
                 ):
                     self.mapped_dataset.encoders[clss].update({v: mlength + i})
-                # we need to change the ordering so that the things that can't be predicted appear afterward
 
+                # we need to change the ordering so that the things that can't be predicted appear afterward
                 self.class_topred[clss] = leaf_labels
                 c = 0
                 update = {}
```
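The rewritten loop only assigns codes to ancestor labels that are not already in the encoder, appending them after the `mlength` existing entries so that predictable leaf labels keep the low indices. A toy trace of those lines, with made-up names:

```python
encoders = {"cell_type": {"CL:leaf_a": 0, "CL:leaf_b": 1}}
groupings = {"CL:ancestor": {"CL:leaf_a", "CL:leaf_b"}}  # ancestor -> descendants
mlength = len(encoders["cell_type"])

for i, v in enumerate(set(groupings.keys()) - set(encoders["cell_type"].keys())):
    encoders["cell_type"].update({v: mlength + i})

print(encoders)  # {'cell_type': {'CL:leaf_a': 0, 'CL:leaf_b': 1, 'CL:ancestor': 2}}
```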
```diff
@@ -320,6 +333,7 @@ class SimpleAnnDataset(torchDataset):
         adata: AnnData,
         obs_to_output: Optional[list[str]] = [],
         layer: Optional[str] = None,
+        get_knn_cells: bool = False,
     ):
         """
         SimpleAnnDataset is a simple dataloader for an AnnData dataset. this is to interface nicely with the rest of
```
```diff
@@ -330,31 +344,48 @@ class SimpleAnnDataset(torchDataset):
             adata (anndata.AnnData): anndata object to use
             obs_to_output (list[str]): list of observations to output from anndata.obs
             layer (str): layer of the anndata to use
+            get_knn_cells (bool): whether to get the knn cells
         """
         self.adataX = adata.layers[layer] if layer is not None else adata.X
         self.adataX = self.adataX.toarray() if issparse(self.adataX) else self.adataX
+
         self.obs_to_output = adata.obs[obs_to_output]
+        self.get_knn_cells = get_knn_cells
+        if get_knn_cells and "connectivities" not in adata.obsp:
+            raise ValueError("neighbors key not found in adata.obsm")
+        if get_knn_cells:
+            self.distances = adata.obsp["distances"]
 
     def __len__(self):
         return self.adataX.shape[0]
 
     def __iter__(self):
-        for idx
-            with warnings.catch_warnings():
-                warnings.filterwarnings("ignore", category=DeprecationWarning)
-                out = {"X": self.adataX[idx].reshape(-1)}
-                out.update(
-                    {name: val for name, val in self.obs_to_output.iloc[idx].items()}
-                )
-                yield out
-
-    def __getitem__(self, idx):
-        with warnings.catch_warnings():
-            warnings.filterwarnings("ignore", category=DeprecationWarning)
+        for idx in range(self.adataX.shape[0]):
             out = {"X": self.adataX[idx].reshape(-1)}
             out.update(
                 {name: val for name, val in self.obs_to_output.iloc[idx].items()}
             )
+            if self.get_knn_cells:
+                distances = self.distances[idx].toarray()[0]
+                nn_idx = np.argsort(-1 / (distances - 1e-6))[:6]
+                out["knn_cells"] = np.array(
+                    [self.adataX[i].reshape(-1) for i in nn_idx],
+                    dtype=int,
+                )
+                out["distances"] = distances[nn_idx]
+            yield out
+
+    def __getitem__(self, idx):
+        out = {"X": self.adataX[idx].reshape(-1)}
+        out.update({name: val for name, val in self.obs_to_output.iloc[idx].items()})
+        if self.get_knn_cells:
+            distances = self.distances[idx].toarray()[0]
+            nn_idx = np.argsort(-1 / (distances - 1e-6))[:6]
+            out["knn_cells"] = np.array(
+                [self.adataX[i].reshape(-1) for i in nn_idx],
+                dtype=int,
+            )
+            out["distances"] = distances[nn_idx]
         return out
 
 
```
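The `-1 / (distances - 1e-6)` transform is what makes `np.argsort` pick true neighbors out of a sparse distance row: zero entries (non-neighbors) map to a huge positive value and sort last, while real neighbors map to negative values that grow more negative the closer they are. A quick check of that claim:

```python
import numpy as np

distances = np.array([0.0, 0.2, 0.0, 0.05, 0.4])  # one densified row of adata.obsp["distances"]
keys = -1 / (distances - 1e-6)  # zeros -> +1e6; neighbors -> negative, nearest most negative
nn_idx = np.argsort(keys)[:2]  # the diff keeps 6; 2 here for brevity
print(nn_idx)  # [3 1] -> nearest first, the zero (non-neighbor) entries sort last
```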
```diff
@@ -374,6 +405,8 @@ def mapped(
     metacell_mode: bool = False,
     meta_assays: list[str] = ["EFO:0022857", "EFO:0010961"],
     get_knn_cells: bool = False,
+    store_location: str | None = None,
+    force_recompute_indices: bool = False,
 ) -> MappedCollection:
     path_list = []
     for artifact in dataset.artifacts.all():
```
```diff
@@ -401,5 +434,45 @@ def mapped(
         meta_assays=meta_assays,
         metacell_mode=metacell_mode,
         get_knn_cells=get_knn_cells,
+        store_location=store_location,
+        force_recompute_indices=force_recompute_indices,
     )
     return ds
+
+
+def concat_categorical_codes(series_list: list[pd.Categorical]) -> pd.Categorical:
+    """Efficiently combine multiple categorical data using their codes,
+    only creating categories for combinations that exist in the data.
+
+    Args:
+        series_list: List of pandas Categorical data
+
+    Returns:
+        Combined Categorical with only existing combinations
+    """
+    # Get the codes for each categorical
+    codes_list = [s.codes.astype(np.int32) for s in series_list]
+    n_cats = [len(s.categories) for s in series_list]
+
+    # Calculate combined codes
+    combined_codes = codes_list[0]
+    multiplier = n_cats[0]
+    for codes, n_cat in zip(codes_list[1:], n_cats[1:]):
+        combined_codes = (combined_codes * n_cat) + codes
+        multiplier *= n_cat
+
+    # Find unique combinations that actually exist in the data
+    unique_existing_codes = np.unique(combined_codes)
+
+    # Create a mapping from old codes to new compressed codes
+    code_mapping = {old: new for new, old in enumerate(unique_existing_codes)}
+
+    # Map the combined codes to their new compressed values
+    combined_codes = np.array([code_mapping[code] for code in combined_codes])
+
+    # Create final categorical with only existing combinations
+    return pd.Categorical.from_codes(
+        codes=combined_codes,
+        categories=np.arange(len(unique_existing_codes)),
+        ordered=False,
+    )
```
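`concat_categorical_codes` packs several categoricals into one mixed-radix code (`code_a * n_b + code_b`, generalized left to right) and then re-indexes so that only combinations observed in the data become categories. A quick check, assuming the helper is importable from `scdataloader.data`:

```python
import pandas as pd

from scdataloader.data import concat_categorical_codes  # assumed import path

a = pd.Categorical(["human", "mouse", "human"])  # codes [0, 1, 0]
b = pd.Categorical(["B cell", "B cell", "T cell"])  # codes [0, 0, 1]

combined = concat_categorical_codes([a, b])
print(combined.codes.tolist())  # [0, 2, 1]: mixed-radix 0*2+0, 1*2+0, 0*2+1
print(len(combined.categories))  # 3 observed combinations, not 2 * 2 = 4
```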