scdataloader 0.0.2__py3-none-any.whl → 0.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- scdataloader/VERSION +1 -1
- scdataloader/__init__.py +4 -0
- scdataloader/__main__.py +209 -0
- scdataloader/collator.py +307 -0
- scdataloader/config.py +106 -0
- scdataloader/data.py +181 -218
- scdataloader/datamodule.py +375 -0
- scdataloader/mapped.py +46 -32
- scdataloader/preprocess.py +524 -208
- scdataloader/utils.py +189 -123
- {scdataloader-0.0.2.dist-info → scdataloader-0.0.4.dist-info}/METADATA +77 -7
- scdataloader-0.0.4.dist-info/RECORD +16 -0
- {scdataloader-0.0.2.dist-info → scdataloader-0.0.4.dist-info}/WHEEL +1 -1
- scdataloader-0.0.2.dist-info/RECORD +0 -12
- {scdataloader-0.0.2.dist-info → scdataloader-0.0.4.dist-info}/LICENSE +0 -0
- {scdataloader-0.0.2.dist-info → scdataloader-0.0.4.dist-info}/entry_points.txt +0 -0
scdataloader/datamodule.py
ADDED
@@ -0,0 +1,375 @@
+import numpy as np
+import pandas as pd
+import lamindb as ln
+
+from torch.utils.data.sampler import (
+    WeightedRandomSampler,
+    SubsetRandomSampler,
+    SequentialSampler,
+)
+import torch
+from torch.utils.data import DataLoader, Sampler
+import lightning as L
+
+from typing import Optional, Union, Sequence
+
+from .data import Dataset
+from .collator import Collator
+from .utils import getBiomartTable
+
+
+class DataModule(L.LightningDataModule):
+    def __init__(
+        self,
+        collection_name: str,
+        label_to_weight: list = ["organism_ontology_term_id"],
+        organisms: list = ["NCBITaxon:9606"],
+        weight_scaler: int = 10,
+        train_oversampling_per_epoch: float = 0.1,
+        validation_split: float = 0.2,
+        test_split: float = 0,
+        gene_embeddings: str = "",
+        use_default_col: bool = True,
+        gene_position_tolerance: int = 10_000,
+        # this is for the mappedCollection
+        label_to_pred: list = ["organism_ontology_term_id"],
+        all_labels: list = ["organism_ontology_term_id"],
+        hierarchical_labels: list = [],
+        # this is for the collator
+        how: str = "random expr",
+        organism_name: str = "organism_ontology_term_id",
+        max_len: int = 1000,
+        add_zero_genes: int = 100,
+        do_gene_pos: Union[bool, str] = True,
+        tp_name: Optional[str] = None,  # "heat_diff"
+        assays_to_drop: list = [
+            "EFO:0008853",
+            "EFO:0010961",
+            "EFO:0030007",
+            "EFO:0030062",
+        ],
+        **kwargs,
+    ):
+        """
+        DataModule a pytorch lighting datamodule directly from a lamin Collection.
+        it can work with bare pytorch too
+
+        It implements train / val / test dataloaders. the train is weighted random, val is random, test is one to many separated datasets.
+        This is where the mappedCollection, dataset, and collator are combined to create the dataloaders.
+
+        Args:
+            collection_name (str): The lamindb collection to be used.
+            weight_scaler (int, optional): how much more you will see the most present vs less present category.
+            gene_position_tolerance (int, optional): The tolerance for gene position. Defaults to 10_000.
+                any genes within this distance of each other will be considered at the same position.
+            gene_embeddings (str, optional): The path to the gene embeddings file. Defaults to "".
+                the file must have ensembl_gene_id as index.
+                This is used to subset the available genes further to the ones that have embeddings in your model.
+            organisms (list, optional): The organisms to include in the dataset. Defaults to ["NCBITaxon:9606"].
+            label_to_weight (list, optional): List of labels to weight in the trainer's weighted random sampler. Defaults to [].
+            validation_split (float, optional): The proportion of the dataset to include in the validation split. Defaults to 0.2.
+            test_split (float, optional): The proportion of the dataset to include in the test split. Defaults to 0.
+                it will use a full dataset and will round to the nearest dataset's cell count.
+            **other args: see @file data.py and @file collator.py for more details
+            **kwargs: Additional keyword arguments passed to the pytorch DataLoader.
+        """
+        if collection_name is not None:
+            mdataset = Dataset(
+                ln.Collection.filter(name=collection_name).first(),
+                organisms=organisms,
+                obs=all_labels,
+                clss_to_pred=label_to_pred,
+                hierarchical_clss=hierarchical_labels,
+            )
+        print(mdataset)
+        # and location
+        if do_gene_pos:
+            if type(do_gene_pos) is str:
+                print("seeing a string: loading gene positions as biomart parquet file")
+                biomart = pd.read_parquet(do_gene_pos)
+            else:
+                # and annotations
+                biomart = getBiomartTable(
+                    attributes=["start_position", "chromosome_name"]
+                ).set_index("ensembl_gene_id")
+                biomart = biomart.loc[~biomart.index.duplicated(keep="first")]
+                biomart = biomart.sort_values(by=["chromosome_name", "start_position"])
+                c = []
+                i = 0
+                prev_position = -100000
+                prev_chromosome = None
+                for _, r in biomart.iterrows():
+                    if (
+                        r["chromosome_name"] != prev_chromosome
+                        or r["start_position"] - prev_position > gene_position_tolerance
+                    ):
+                        i += 1
+                    c.append(i)
+                    prev_position = r["start_position"]
+                    prev_chromosome = r["chromosome_name"]
+                print(f"reduced the size to {len(set(c))/len(biomart)}")
+                biomart["pos"] = c
+            mdataset.genedf = biomart.loc[mdataset.genedf.index]
+            self.gene_pos = mdataset.genedf["pos"].tolist()
+
+        if gene_embeddings != "":
+            mdataset.genedf = mdataset.genedf.join(
+                pd.read_parquet(gene_embeddings), how="inner"
+            )
+            if do_gene_pos:
+                self.gene_pos = mdataset.genedf["pos"].tolist()
+        self.labels = {k: len(v) for k, v in mdataset.class_topred.items()}
+        # we might want not to order the genes by expression (or do it?)
+        # we might want to not introduce zeros and
+        if use_default_col:
+            kwargs["collate_fn"] = Collator(
+                organisms=organisms,
+                how=how,
+                valid_genes=mdataset.genedf.index.tolist(),
+                max_len=max_len,
+                add_zero_genes=add_zero_genes,
+                org_to_id=mdataset.encoder[organism_name],
+                tp_name=tp_name,
+                organism_name=organism_name,
+                class_names=label_to_weight,
+            )
+        self.validation_split = validation_split
+        self.test_split = test_split
+        self.dataset = mdataset
+        self.kwargs = kwargs
+        self.assays_to_drop = assays_to_drop
+        self.n_samples = len(mdataset)
+        self.weight_scaler = weight_scaler
+        self.train_oversampling_per_epoch = train_oversampling_per_epoch
+        self.label_to_weight = label_to_weight
+        self.train_weights = None
+        self.train_labels = None
+        super().__init__()
+
+    def __repr__(self):
+        return (
+            f"DataLoader(\n"
+            f"\twith a dataset=({self.dataset.__repr__()}\n)\n"
+            f"\tvalidation_split={self.validation_split},\n"
+            f"\ttest_split={self.test_split},\n"
+            f"\tn_samples={self.n_samples},\n"
+            f"\tweight_scaler={self.weight_scaler},\n"
+            f"\train_oversampling_per_epoch={self.train_oversampling_per_epoch},\n"
+            f"\tlabel_to_weight={self.label_to_weight}\n"
+            + (
+                "\twith train_dataset size of=("
+                + str((self.train_weights != 0).sum())
+                + ")\n)"
+            )
+            if self.train_weights is not None
+            else ")"
+        )
+
+    @property
+    def decoders(self):
+        """
+        decoders the decoders for any labels that would have been encoded
+
+        Returns:
+            dict[str, dict[int, str]]
+        """
+        decoders = {}
+        for k, v in self.dataset.encoder.items():
+            decoders[k] = {va: ke for ke, va in v.items()}
+        return decoders
+
+    @property
+    def cls_hierarchy(self):
+        """
+        cls_hierarchy the hierarchy of labels for any cls that would have a hierarchy
+
+        Returns:
+            dict[str, dict[str, str]]
+        """
+        cls_hierarchy = {}
+        for k, dic in self.dataset.class_groupings.items():
+            rdic = {}
+            for sk, v in dic.items():
+                rdic[self.dataset.encoder[k][sk]] = [
+                    self.dataset.encoder[k][i] for i in list(v)
+                ]
+            cls_hierarchy[k] = rdic
+        return cls_hierarchy
+
+    @property
+    def genes(self):
+        """
+        genes the genes used in this datamodule
+
+        Returns:
+            list
+        """
+        return self.dataset.genedf.index.tolist()
+
+    @property
+    def num_datasets(self):
+        return len(self.dataset.mapped_dataset.storages)
+
+    def setup(self, stage=None):
+        """
+        setup method is used to prepare the data for the training, validation, and test sets.
+        It shuffles the data, calculates weights for each set, and creates samplers for each set.
+
+        Args:
+            stage (str, optional): The stage of the model training process.
+                It can be either 'fit' or 'test'. Defaults to None.
+        """
+        if len(self.label_to_weight) > 0:
+            weights, labels = self.dataset.get_label_weights(
+                self.label_to_weight, scaler=self.weight_scaler
+            )
+        else:
+            weights = np.ones(1)
+            labels = np.zeros(self.n_samples)
+        if isinstance(self.validation_split, int):
+            len_valid = self.validation_split
+        else:
+            len_valid = int(self.n_samples * self.validation_split)
+        if isinstance(self.test_split, int):
+            len_test = self.test_split
+        else:
+            len_test = int(self.n_samples * self.test_split)
+        assert (
+            len_test + len_valid < self.n_samples
+        ), "test set + valid set size is configured to be larger than entire dataset."
+
+        idx_full = []
+        if len(self.assays_to_drop) > 0:
+            for i, a in enumerate(
+                self.dataset.mapped_dataset.get_merged_labels("assay_ontology_term_id")
+            ):
+                if a not in self.assays_to_drop:
+                    idx_full.append(i)
+            idx_full = np.array(idx_full)
+        else:
+            idx_full = np.arange(self.n_samples)
+        test_datasets = []
+        if len_test > 0:
+            # this way we work on some never seen datasets
+            # keeping at least one
+            len_test = (
+                len_test
+                if len_test > self.dataset.mapped_dataset.n_obs_list[0]
+                else self.dataset.mapped_dataset.n_obs_list[0]
+            )
+            cs = 0
+            print("these files will be considered test datasets:")
+            for i, c in enumerate(self.dataset.mapped_dataset.n_obs_list):
+                if cs + c > len_test:
+                    break
+                else:
+                    print(" " + self.dataset.mapped_dataset.path_list[i].path)
+                    test_datasets.append(self.dataset.mapped_dataset.path_list[i].path)
+                    cs += c
+
+            len_test = cs
+            print("perc test: ", len_test / self.n_samples)
+            self.test_idx = idx_full[:len_test]
+            idx_full = idx_full[len_test:]
+        else:
+            self.test_idx = None
+
+        np.random.shuffle(idx_full)
+        if len_valid > 0:
+            self.valid_idx = idx_full[:len_valid].copy()
+            idx_full = idx_full[len_valid:]
+        else:
+            self.valid_idx = None
+        weights = np.concatenate([weights, np.zeros(1)])
+        labels[~np.isin(np.arange(self.n_samples), idx_full)] = len(weights) - 1
+
+        self.train_weights = weights
+        self.train_labels = labels
+        self.idx_full = idx_full
+
+        return test_datasets
+
+    def train_dataloader(self, **kwargs):
+        # train_sampler = WeightedRandomSampler(
+        #     self.train_weights[self.train_labels],
+        #     int(self.n_samples*self.train_oversampling_per_epoch),
+        #     replacement=True,
+        # )
+        train_sampler = LabelWeightedSampler(
+            self.train_weights,
+            self.train_labels,
+            num_samples=int(self.n_samples * self.train_oversampling_per_epoch),
+            # replacement=True,
+        )
+        return DataLoader(self.dataset, sampler=train_sampler, **self.kwargs, **kwargs)
+
+    def val_dataloader(self):
+        return (
+            DataLoader(
+                self.dataset, sampler=SubsetRandomSampler(self.valid_idx), **self.kwargs
+            )
+            if self.valid_idx is not None
+            else None
+        )
+
+    def test_dataloader(self):
+        return (
+            DataLoader(
+                self.dataset, sampler=SequentialSampler(self.test_idx), **self.kwargs
+            )
+            if self.test_idx is not None
+            else None
+        )
+
+    # def teardown(self):
+    # clean up state after the trainer stops, delete files...
+    # called on every process in DDP
+    # pass
+
+
+class LabelWeightedSampler(Sampler[int]):
+    label_weights: Sequence[float]
+    klass_indices: Sequence[Sequence[int]]
+    num_samples: int
+
+    # when we use, just set weights for each classes(here is: np.ones(num_classes)), and labels of a dataset.
+    # this will result a class-balanced sampling, no matter how imbalance the labels are.
+    # NOTE: here we use replacement=True, you can change it if you don't upsample a class.
+    def __init__(
+        self, label_weights: Sequence[float], labels: Sequence[int], num_samples: int
+    ) -> None:
+        """
+
+        :param label_weights: list(len=num_classes)[float], weights for each class.
+        :param labels: list(len=dataset_len)[int], labels of a dataset.
+        :param num_samples: number of samples.
+        """
+
+        super(LabelWeightedSampler, self).__init__(None)
+        # reweight labels from counter otherwsie same weight to labels that have many elements vs a few
+        label_weights = np.array(label_weights) * np.bincount(labels)
+
+        self.label_weights = torch.as_tensor(label_weights, dtype=torch.float32)
+        self.labels = torch.as_tensor(labels, dtype=torch.int)
+        self.num_samples = num_samples
+        # list of tensor.
+        self.klass_indices = [
+            (self.labels == i_klass).nonzero().squeeze(1)
+            for i_klass in range(len(label_weights))
+        ]
+
+    def __iter__(self):
+        sample_labels = torch.multinomial(
+            self.label_weights, num_samples=self.num_samples, replacement=True
+        )
+        sample_indices = torch.empty_like(sample_labels)
+        for i_klass, klass_index in enumerate(self.klass_indices):
+            if klass_index.numel() == 0:
+                continue
+            left_inds = (sample_labels == i_klass).nonzero().squeeze(1)
+            right_inds = torch.randint(len(klass_index), size=(len(left_inds),))
+            sample_indices[left_inds] = klass_index[right_inds]
+        yield from iter(sample_indices.tolist())
+
+    def __len__(self):
+        return self.num_samples
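For orientation, a minimal usage sketch of the DataModule added in 0.0.4, based only on the constructor and dataloader methods shown above; the collection name, batch size, and worker count are placeholder values, and the import path assumes the new scdataloader/datamodule.py module listed in this diff:

from scdataloader.datamodule import DataModule

# "my-collection" is a placeholder lamindb collection name; batch_size and
# num_workers are forwarded to torch.utils.data.DataLoader via **kwargs.
dm = DataModule(
    collection_name="my-collection",
    batch_size=64,
    num_workers=4,
)
test_files = dm.setup()           # splits train/val/test, returns the held-out test files
train_dl = dm.train_dataloader()  # label-weighted random sampling over the train split
val_dl = dm.val_dataloader()      # SubsetRandomSampler over the validation indices
# trainer.fit(model, datamodule=dm)  # or hand dm directly to a lightning.Trainer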
scdataloader/mapped.py
CHANGED
@@ -80,38 +80,51 @@ class MappedDataset:
         join_vars: Optional[Literal["auto", "inner", "None"]] = "auto",
         encode_labels: Optional[Union[bool, List[str]]] = False,
         parallel: bool = False,
+        unknown_class: str = "unknown",
     ):
         self.storages = []
         self.conns = []
         self.parallel = parallel
+        self.unknown_class = unknown_class
+        self.path_list = path_list
         self._make_connections(path_list, parallel)

         self.n_obs_list = []
         for storage in self.storages:
             with _Connect(storage) as store:
                 X = store["X"]
+                index = (
+                    store["var"]["ensembl_gene_id"]
+                    if "ensembl_gene_id" in store["var"]
+                    else store["var"]["_index"]
+                )
+                if join_vars == "None":
+                    if not all(
+                        [
+                            i <= j
+                            for i, j in zip(
+                                index[:99],
+                                index[1:100],
+                            )
+                        ]
+                    ):
+                        raise ValueError("The variables are not sorted.")
                 if isinstance(X, ArrayTypes):  # type: ignore
                     self.n_obs_list.append(X.shape[0])
                 else:
                     self.n_obs_list.append(X.attrs["shape"][0])
         self.n_obs = sum(self.n_obs_list)

-        self.indices = np.hstack(
-            [np.arange(n_obs) for n_obs in self.n_obs_list]
-        )
-        self.storage_idx = np.repeat(
-            np.arange(len(self.storages)), self.n_obs_list
-        )
+        self.indices = np.hstack([np.arange(n_obs) for n_obs in self.n_obs_list])
+        self.storage_idx = np.repeat(np.arange(len(self.storages)), self.n_obs_list)

         self.join_vars = join_vars if len(path_list) > 1 else None
         self.var_indices = None
-        if self.join_vars
+        if self.join_vars != "None":
             self._make_join_vars()

         self.encode_labels = encode_labels
-        self.label_keys = (
-            [label_keys] if isinstance(label_keys, str) else label_keys
-        )
+        self.label_keys = [label_keys] if isinstance(label_keys, str) else label_keys
         if isinstance(encode_labels, bool):
             if encode_labels:
                 encode_labels = label_keys
@@ -122,6 +135,8 @@ class MappedDataset:
             for label in encode_labels:
                 cats = self.get_merged_categories(label)
                 self.encoders[label] = {cat: i for i, cat in enumerate(cats)}
+                if unknown_class in self.encoders[label]:
+                    self.encoders[label][unknown_class] = -1
         else:
             self.encoders = {}
         self._closed = False
@@ -157,9 +172,15 @@ class MappedDataset:
             raise ValueError(
                 "The provided AnnData objects don't have shared varibales."
             )
-        self.var_indices = [
-            vrs.get_indexer(self.var_joint) for vrs in var_list
-        ]
+        self.var_indices = [vrs.get_indexer(self.var_joint) for vrs in var_list]
+
+    def _check_aligned_vars(self, vars: list):
+        i = 0
+        for storage in self.storages:
+            with _Connect(storage) as store:
+                if vars == _safer_read_index(store["var"]).tolist():
+                    i += 1
+        print("{}% are aligned".format(i * 100 / len(self.storages)))

     def __len__(self):
         return self.n_obs
@@ -172,20 +193,17 @@ class MappedDataset:
         else:
             var_idxs = None
         with _Connect(self.storages[storage_idx]) as store:
-            out =
+            out = {"x": self.get_data_idx(store, obs_idx, var_idxs)}
             if self.label_keys is not None:
-                for
+                for _, label in enumerate(self.label_keys):
                     label_idx = self.get_label_idx(store, obs_idx, label)
                     if label in self.encoders:
-                        out.
+                        out.update({label: self.encoders[label][label_idx]})
                     else:
-                        out.
+                        out.update({label: label_idx})
+            out.update({"dataset": storage_idx})
             return out

-    def uns(self, idx, key):
-        storage = self.storages[self.storage_idx[idx]]
-        return storage["uns"][key]
-
     def get_data_idx(
         self,
         storage: StorageType,
@@ -240,13 +258,13 @@ class MappedDataset:
             if i == 0:
                 labels = self.get_merged_labels(val)
             else:
-                labels += "_" + self.get_merged_labels(val).astype(str).astype(
-                    "O"
-                )
+                labels += "_" + self.get_merged_labels(val).astype(str).astype("O")
         counter = Counter(labels)  # type: ignore
-
+        rn = {n: i for i, n in enumerate(counter.keys())}
+        labels = np.array([rn[label] for label in labels])
+        counter = np.array(list(counter.values()))
         weights = scaler / (counter + scaler)
-        return weights
+        return weights, labels

     def get_merged_labels(self, label_key: str):
         """Get merged labels."""
@@ -255,9 +273,7 @@ class MappedDataset:
         for storage in self.storages:
             with _Connect(storage) as store:
                 codes = self.get_codes(store, label_key)
-                labels = (
-                    decode(codes) if isinstance(codes[0], bytes) else codes
-                )
+                labels = decode(codes) if isinstance(codes[0], bytes) else codes
                 cats = self.get_categories(store, label_key)
                 if cats is not None:
                     cats = decode(cats) if isinstance(cats[0], bytes) else cats
@@ -277,9 +293,7 @@ class MappedDataset:
                     cats_merge.update(cats)
                 else:
                     codes = self.get_codes(store, label_key)
-                    codes = (
-                        decode(codes) if isinstance(codes[0], bytes) else codes
-                    )
+                    codes = decode(codes) if isinstance(codes[0], bytes) else codes
                     cats_merge.update(codes)
         return cats_merge

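The get_label_weights hunk above changes the return value from weights alone to a (weights, labels) pair, which is what the new DataModule.setup() and LabelWeightedSampler consume. A self-contained sketch of that computation on made-up labels (no MappedDataset involved), mirroring the added lines:

import numpy as np
from collections import Counter

labels = np.array(["B cell"] * 90 + ["T cell"] * 10)  # toy stand-in for get_merged_labels() output
scaler = 10

counter = Counter(labels)
rn = {n: i for i, n in enumerate(counter.keys())}   # category -> integer code
labels = np.array([rn[label] for label in labels])  # per-cell integer labels
counter = np.array(list(counter.values()))          # counts per category
weights = scaler / (counter + scaler)               # rarer categories get larger weights

print(weights)     # [0.1 0.5]: "T cell" is upweighted relative to "B cell"
print(labels[:3])  # [0 0 0]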