scdataloader 0.0.3__py3-none-any.whl → 1.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,403 @@
+ import numpy as np
+ import pandas as pd
+ import lamindb as ln
+
+ from torch.utils.data.sampler import (
+     WeightedRandomSampler,
+     SubsetRandomSampler,
+     SequentialSampler,
+     RandomSampler,
+ )
+ import torch
+ from torch.utils.data import DataLoader, Sampler
+ import lightning as L
+
+ from typing import Optional, Union, Sequence
+
+ from .data import Dataset
+ from .collator import Collator
+ from .utils import getBiomartTable
+
+
+ class DataModule(L.LightningDataModule):
+     def __init__(
+         self,
+         collection_name: str,
+         clss_to_weight: list = ["organism_ontology_term_id"],
+         organisms: list = ["NCBITaxon:9606"],
+         weight_scaler: int = 10,
+         train_oversampling_per_epoch: float = 0.1,
+         validation_split: float = 0.2,
+         test_split: float = 0,
+         gene_embeddings: str = "",
+         use_default_col: bool = True,
+         gene_position_tolerance: int = 10_000,
+         # this is for the mappedCollection
+         clss_to_pred: list = ["organism_ontology_term_id"],
+         all_clss: list = ["organism_ontology_term_id"],
+         hierarchical_clss: list = [],
+         # this is for the collator
+         how: str = "random expr",
+         organism_name: str = "organism_ontology_term_id",
+         max_len: int = 1000,
+         add_zero_genes: int = 100,
+         do_gene_pos: Union[bool, str] = True,
+         tp_name: Optional[str] = None,  # "heat_diff"
+         assays_to_drop: list = [
+             "EFO:0008853",
+             "EFO:0010961",
+             "EFO:0030007",
+             "EFO:0030062",
+         ],
+         **kwargs,
+     ):
+         """
+         DataModule: a PyTorch Lightning DataModule built directly from a lamindb Collection.
+         It can also be used with plain PyTorch.
+
+         It implements the train / val / test dataloaders: train uses weighted random sampling, val uses random sampling, and test is one or more fully held-out datasets.
+         This is where the mappedCollection, dataset, and collator are combined to create the dataloaders.
+
+         Args:
+             collection_name (str): The lamindb collection to be used.
+             clss_to_weight (list, optional): The classes to weight in the trainer's weighted random sampler. Defaults to ["organism_ontology_term_id"].
+             organisms (list, optional): The organisms to include in the dataset. Defaults to ["NCBITaxon:9606"].
+             weight_scaler (int, optional): How much more often the most frequent category is seen compared to the least frequent one. Defaults to 10.
+             train_oversampling_per_epoch (float, optional): The proportion of the dataset to include in the training set for each epoch. Defaults to 0.1.
+             validation_split (float, optional): The proportion of the dataset to include in the validation split. Defaults to 0.2.
+             test_split (float, optional): The proportion of the dataset to include in the test split. Defaults to 0.
+                 Whole datasets are held out, so the split is rounded to the nearest dataset's cell count.
+             gene_embeddings (str, optional): The path to the gene embeddings file. Defaults to "".
+                 The file must have ensembl_gene_id as index.
+                 This is used to further subset the available genes to the ones that have embeddings in your model.
+             use_default_col (bool, optional): Whether to use the default collator. Defaults to True.
+             gene_position_tolerance (int, optional): The tolerance for gene position. Defaults to 10_000.
+                 Any genes within this distance of each other are considered to be at the same position.
+             assays_to_drop (list, optional): List of assay ontology term IDs to drop from the dataset.
+                 Defaults to the four assays listed in the signature above.
+             do_gene_pos (Union[bool, str], optional): Whether to use gene positions. Defaults to True.
+             max_len (int, optional): The maximum length of the input tensor. Defaults to 1000.
+             add_zero_genes (int, optional): The number of zero genes to add to the input tensor. Defaults to 100.
+             how (str, optional): The method to use for the collator. Defaults to "random expr".
+             organism_name (str, optional): The name of the organism label. Defaults to "organism_ontology_term_id".
+             tp_name (Optional[str], optional): The name of the timepoint label. Defaults to None.
+             hierarchical_clss (list, optional): List of hierarchical classes. Defaults to [].
+             all_clss (list, optional): List of all classes. Defaults to ["organism_ontology_term_id"].
+             clss_to_pred (list, optional): List of classes to predict. Defaults to ["organism_ontology_term_id"].
+             **kwargs: Additional keyword arguments passed to the PyTorch DataLoader.
+
+         See data.py and collator.py for more details about some of the parameters.
+         """
+         if collection_name is not None:
+             mdataset = Dataset(
+                 ln.Collection.filter(name=collection_name).first(),
+                 organisms=organisms,
+                 obs=all_clss,
+                 clss_to_pred=clss_to_pred,
+                 hierarchical_clss=hierarchical_clss,
+             )
+         # print(mdataset)
+         # gene positions
+         self.gene_pos = None
+         if do_gene_pos:
+             if type(do_gene_pos) is str:
+                 print("do_gene_pos is a string: loading gene positions from the given biomart parquet file")
+                 biomart = pd.read_parquet(do_gene_pos)
+             else:
+                 # build the gene annotation table
+                 if organisms != ["NCBITaxon:9606"]:
+                     raise ValueError(
+                         "need to provide your own table as this automated function only works for humans for now"
+                     )
+                 biomart = getBiomartTable(
+                     attributes=["start_position", "chromosome_name"]
+                 ).set_index("ensembl_gene_id")
+                 biomart = biomart.loc[~biomart.index.duplicated(keep="first")]
+                 biomart = biomart.sort_values(by=["chromosome_name", "start_position"])
+                 c = []
+                 i = 0
+                 prev_position = -100000
+                 prev_chromosome = None
+                 for _, r in biomart.iterrows():
+                     if (
+                         r["chromosome_name"] != prev_chromosome
+                         or r["start_position"] - prev_position > gene_position_tolerance
+                     ):
+                         i += 1
+                     c.append(i)
+                     prev_position = r["start_position"]
+                     prev_chromosome = r["chromosome_name"]
+                 print(f"reduced the number of distinct positions to {len(set(c)) / len(biomart)} of the genes")
+                 biomart["pos"] = c
+             mdataset.genedf = biomart.loc[mdataset.genedf.index]
+             self.gene_pos = mdataset.genedf["pos"].tolist()
+
+         if gene_embeddings != "":
+             mdataset.genedf = mdataset.genedf.join(
+                 pd.read_parquet(gene_embeddings), how="inner"
+             )
+             if do_gene_pos:
+                 self.gene_pos = mdataset.genedf["pos"].tolist()
+         self.classes = {k: len(v) for k, v in mdataset.class_topred.items()}
+         # we might want to not order the genes by expression (or should we?)
+         # we might want to not introduce zeros
+         if use_default_col:
+             kwargs["collate_fn"] = Collator(
+                 organisms=organisms,
+                 how=how,
+                 valid_genes=mdataset.genedf.index.tolist(),
+                 max_len=max_len,
+                 add_zero_genes=add_zero_genes,
+                 org_to_id=mdataset.encoder[organism_name],
+                 tp_name=tp_name,
+                 organism_name=organism_name,
+                 class_names=clss_to_weight,
+             )
+         self.validation_split = validation_split
+         self.test_split = test_split
+         self.dataset = mdataset
+         self.kwargs = kwargs
+         if "sampler" in self.kwargs:
+             self.kwargs.pop("sampler")
+         self.assays_to_drop = assays_to_drop
+         self.n_samples = len(mdataset)
+         self.weight_scaler = weight_scaler
+         self.train_oversampling_per_epoch = train_oversampling_per_epoch
+         self.clss_to_weight = clss_to_weight
+         self.train_weights = None
+         self.train_labels = None
+         self.test_datasets = []
+         self.test_idx = []
+         super().__init__()
+
+     def __repr__(self):
+         return (
+             f"DataLoader(\n"
+             f"\twith a dataset=({self.dataset.__repr__()}\n)\n"
+             f"\tvalidation_split={self.validation_split},\n"
+             f"\ttest_split={self.test_split},\n"
+             f"\tn_samples={self.n_samples},\n"
+             f"\tweight_scaler={self.weight_scaler},\n"
+             f"\ttrain_oversampling_per_epoch={self.train_oversampling_per_epoch},\n"
+             f"\tassays_to_drop={self.assays_to_drop},\n"
+             f"\tnum_datasets={len(self.dataset.mapped_dataset.storages)},\n"
+             f"\ttest datasets={str(self.test_datasets)},\n"
+             f"perc test: {str(len(self.test_idx) / self.n_samples)},\n"
+             f"\tclss_to_weight={self.clss_to_weight}\n"
+             + (
+                 "\twith train_dataset size of=("
+                 + str((self.train_weights != 0).sum())
+                 + ")\n)"
+             )
+             if self.train_weights is not None
+             else ")"
+         )
+
+     @property
+     def decoders(self):
+         """
+         decoders returns the decoders for any labels that have been encoded.
+
+         Returns:
+             dict[str, dict[int, str]]
+         """
+         decoders = {}
+         for k, v in self.dataset.encoder.items():
+             decoders[k] = {va: ke for ke, va in v.items()}
+         return decoders
+
+     @property
+     def labels_hierarchy(self):
+         """
+         labels_hierarchy returns the hierarchy of labels for any class that has one, using encoded values.
+
+         Returns:
+             dict[str, dict[int, list[int]]]
+         """
+         labels_hierarchy = {}
+         for k, dic in self.dataset.labels_groupings.items():
+             rdic = {}
+             for sk, v in dic.items():
+                 rdic[self.dataset.encoder[k][sk]] = [
+                     self.dataset.encoder[k][i] for i in list(v)
+                 ]
+             labels_hierarchy[k] = rdic
+         return labels_hierarchy
+
+     @property
+     def genes(self):
+         """
+         genes returns the genes used in this datamodule.
+
+         Returns:
+             list
+         """
+         return self.dataset.genedf.index.tolist()
+
+     @property
+     def num_datasets(self):
+         return len(self.dataset.mapped_dataset.storages)
+
+     def setup(self, stage=None):
+         """
+         setup prepares the data for the training, validation, and test sets.
+         It drops unwanted assays, splits off whole test datasets, shuffles the remaining cells, and computes the weights used by the training sampler.
+
+         Args:
+             stage (str, optional): The stage of the model training process.
+                 It can be either 'fit' or 'test'. Defaults to None.
+         """
+         if len(self.clss_to_weight) > 0 and self.weight_scaler > 0:
+             weights, labels = self.dataset.get_label_weights(
+                 self.clss_to_weight, scaler=self.weight_scaler
+             )
+         else:
+             weights = np.ones(1)
+             labels = np.zeros(self.n_samples)
+         if isinstance(self.validation_split, int):
+             len_valid = self.validation_split
+         else:
+             len_valid = int(self.n_samples * self.validation_split)
+         if isinstance(self.test_split, int):
+             len_test = self.test_split
+         else:
+             len_test = int(self.n_samples * self.test_split)
+         assert (
+             len_test + len_valid < self.n_samples
+         ), "test set + valid set size is configured to be larger than the entire dataset."
+
+         idx_full = []
+         if len(self.assays_to_drop) > 0:
+             for i, a in enumerate(
+                 self.dataset.mapped_dataset.get_merged_labels("assay_ontology_term_id")
+             ):
+                 if a not in self.assays_to_drop:
+                     idx_full.append(i)
+             idx_full = np.array(idx_full)
+         else:
+             idx_full = np.arange(self.n_samples)
+         if len_test > 0:
+             # hold out whole datasets so the test set contains never-seen datasets
+             # keeping at least one dataset
+             len_test = (
+                 len_test
+                 if len_test > self.dataset.mapped_dataset.n_obs_list[0]
+                 else self.dataset.mapped_dataset.n_obs_list[0]
+             )
+             cs = 0
+             for i, c in enumerate(self.dataset.mapped_dataset.n_obs_list):
+                 if cs + c > len_test:
+                     break
+                 else:
+                     self.test_datasets.append(
+                         self.dataset.mapped_dataset._path_list[i].path
+                     )
+                     cs += c
+             len_test = cs
+             self.test_idx = idx_full[:len_test]
+             idx_full = idx_full[len_test:]
+         else:
+             self.test_idx = None
+
+         np.random.shuffle(idx_full)
+         if len_valid > 0:
+             self.valid_idx = idx_full[:len_valid].copy()
+             idx_full = idx_full[len_valid:]
+         else:
+             self.valid_idx = None
+         weights = np.concatenate([weights, np.zeros(1)])
+         labels[~np.isin(np.arange(self.n_samples), idx_full)] = len(weights) - 1
+
+         self.train_weights = weights
+         self.train_labels = labels
+         self.idx_full = idx_full
+         return self.test_datasets
+
+     def train_dataloader(self, **kwargs):
+         # train_sampler = WeightedRandomSampler(
+         #     self.train_weights[self.train_labels],
+         #     int(self.n_samples * self.train_oversampling_per_epoch),
+         #     replacement=True,
+         # )
+         train_sampler = LabelWeightedSampler(
+             self.train_weights,
+             self.train_labels,
+             num_samples=int(self.n_samples * self.train_oversampling_per_epoch),
+         )
+         return DataLoader(self.dataset, sampler=train_sampler, **self.kwargs, **kwargs)
+
+     def val_dataloader(self):
+         return (
+             DataLoader(
+                 self.dataset, sampler=SubsetRandomSampler(self.valid_idx), **self.kwargs
+             )
+             if self.valid_idx is not None
+             else None
+         )
+
+     def test_dataloader(self):
+         return (
+             DataLoader(
+                 self.dataset, sampler=SequentialSampler(self.test_idx), **self.kwargs
+             )
+             if self.test_idx is not None
+             else None
+         )
+
+     def predict_dataloader(self):
+         return DataLoader(
+             self.dataset, sampler=SubsetRandomSampler(self.idx_full), **self.kwargs
+         )
+
+     # def teardown(self):
+     #     clean up state after the trainer stops, delete files...
+     #     called on every process in DDP
+     #     pass
+
+
+ class LabelWeightedSampler(Sampler[int]):
+     label_weights: Sequence[float]
+     klass_indices: Sequence[Sequence[int]]
+     num_samples: int
+
+     # To use, just set the weights for each class (e.g. np.ones(num_classes)) and the labels of the dataset.
+     # This results in class-balanced sampling, no matter how imbalanced the labels are.
+     # NOTE: replacement=True is used here; change it if you don't want to upsample a class.
+     def __init__(
+         self, label_weights: Sequence[float], labels: Sequence[int], num_samples: int
+     ) -> None:
+         """
+
+         :param label_weights: list(len=num_classes)[float], weights for each class.
+         :param labels: list(len=dataset_len)[int], labels of a dataset.
+         :param num_samples: number of samples to draw.
+         """
+
+         super(LabelWeightedSampler, self).__init__(None)
+         # reweight by label counts, otherwise labels with many elements get the same total weight as labels with few
+         label_weights = np.array(label_weights) * np.bincount(labels)
+
+         self.label_weights = torch.as_tensor(label_weights, dtype=torch.float32)
+         self.labels = torch.as_tensor(labels, dtype=torch.int)
+         self.num_samples = num_samples
+         # list of index tensors, one per class.
+         self.klass_indices = [
+             (self.labels == i_klass).nonzero().squeeze(1)
+             for i_klass in range(len(label_weights))
+         ]
+
+     def __iter__(self):
+         sample_labels = torch.multinomial(
+             self.label_weights, num_samples=self.num_samples, replacement=True
+         )
+         sample_indices = torch.empty_like(sample_labels)
+         for i_klass, klass_index in enumerate(self.klass_indices):
+             if klass_index.numel() == 0:
+                 continue
+             left_inds = (sample_labels == i_klass).nonzero().squeeze(1)
+             right_inds = torch.randint(len(klass_index), size=(len(left_inds),))
+             sample_indices[left_inds] = klass_index[right_inds]
+         yield from iter(sample_indices.tolist())
+
+     def __len__(self):
+         return self.num_samples
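
For orientation, a minimal usage sketch of the new DataModule follows. It is not part of the diff itself: the collection name, batch size, and split values are placeholders, and the import assumes DataModule is re-exported at the package root (otherwise import it from the module shown above).

# illustrative sketch only, not shipped with the package
from scdataloader import DataModule

datamodule = DataModule(
    collection_name="preprocessed dataset",  # placeholder lamindb Collection name
    organisms=["NCBITaxon:9606"],
    validation_split=0.2,
    test_split=0.1,
    max_len=1000,
    batch_size=64,       # extra kwargs are forwarded to the torch DataLoader
    num_workers=4,
)

datamodule.setup()                        # computes splits and sampler weights
train_dl = datamodule.train_dataloader()  # weighted random sampling via LabelWeightedSampler
val_dl = datamodule.val_dataloader()      # SubsetRandomSampler over the validation indices

# or hand it directly to a Lightning Trainer, which calls setup() itself:
# import lightning as L
# trainer = L.Trainer(max_epochs=10)
# trainer.fit(model, datamodule=datamodule)

LabelWeightedSampler can also be used on its own with any torch Dataset: pass per-class weights (e.g. np.ones(num_classes)), per-sample integer labels, and the number of samples to draw, then hand it to a DataLoader as sampler=.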