scdataloader 0.0.2__py3-none-any.whl → 0.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,375 @@
+ import numpy as np
+ import pandas as pd
+ import lamindb as ln
+
+ from torch.utils.data.sampler import (
+     WeightedRandomSampler,
+     SubsetRandomSampler,
+     SequentialSampler,
+ )
+ import torch
+ from torch.utils.data import DataLoader, Sampler
+ import lightning as L
+
+ from typing import Optional, Union, Sequence
+
+ from .data import Dataset
+ from .collator import Collator
+ from .utils import getBiomartTable
+
+
+ class DataModule(L.LightningDataModule):
+     def __init__(
+         self,
+         collection_name: str,
+         label_to_weight: list = ["organism_ontology_term_id"],
+         organisms: list = ["NCBITaxon:9606"],
+         weight_scaler: int = 10,
+         train_oversampling_per_epoch: float = 0.1,
+         validation_split: float = 0.2,
+         test_split: float = 0,
+         gene_embeddings: str = "",
+         use_default_col: bool = True,
+         gene_position_tolerance: int = 10_000,
+         # these are for the mappedCollection
+         label_to_pred: list = ["organism_ontology_term_id"],
+         all_labels: list = ["organism_ontology_term_id"],
+         hierarchical_labels: list = [],
+         # these are for the collator
+         how: str = "random expr",
+         organism_name: str = "organism_ontology_term_id",
+         max_len: int = 1000,
+         add_zero_genes: int = 100,
+         do_gene_pos: Union[bool, str] = True,
+         tp_name: Optional[str] = None,  # "heat_diff"
+         assays_to_drop: list = [
+             "EFO:0008853",
+             "EFO:0010961",
+             "EFO:0030007",
+             "EFO:0030062",
+         ],
+         **kwargs,
+     ):
+         """
+         DataModule: a PyTorch Lightning DataModule built directly from a lamin Collection.
+         It can also be used with bare PyTorch.
+
+         It implements train / val / test dataloaders: train uses weighted random sampling,
+         val uses random sampling, and test is made of one or more entirely held-out datasets.
+         This is where the mappedCollection, dataset, and collator are combined to create the dataloaders.
+
+         Args:
+             collection_name (str): The lamindb collection to be used.
+             weight_scaler (int, optional): How much more often you will see the most frequent category
+                 compared to the least frequent one. Defaults to 10.
+             gene_position_tolerance (int, optional): The tolerance for gene position. Defaults to 10_000.
+                 Any genes within this distance of each other are considered to be at the same position.
+             gene_embeddings (str, optional): The path to the gene embeddings file. Defaults to "".
+                 The file must have ensembl_gene_id as index.
+                 This is used to further subset the available genes to the ones that have embeddings in your model.
+             organisms (list, optional): The organisms to include in the dataset. Defaults to ["NCBITaxon:9606"].
+             label_to_weight (list, optional): List of labels to weight in the trainer's weighted random sampler.
+                 Defaults to ["organism_ontology_term_id"].
+             validation_split (float, optional): The proportion of the dataset to include in the validation split. Defaults to 0.2.
+             test_split (float, optional): The proportion of the dataset to include in the test split. Defaults to 0.
+                 Full datasets are used for the test set, rounded to the nearest dataset's cell count.
+             **other args: see @file data.py and @file collator.py for more details
+             **kwargs: Additional keyword arguments passed to the pytorch DataLoader.
+         """
+         if collection_name is not None:
+             mdataset = Dataset(
+                 ln.Collection.filter(name=collection_name).first(),
+                 organisms=organisms,
+                 obs=all_labels,
+                 clss_to_pred=label_to_pred,
+                 hierarchical_clss=hierarchical_labels,
+             )
+             print(mdataset)
+         # gene location
+         if do_gene_pos:
+             if type(do_gene_pos) is str:
+                 print("received a string: loading gene positions from a biomart parquet file")
+                 biomart = pd.read_parquet(do_gene_pos)
+             else:
+                 # gene annotations
+                 biomart = getBiomartTable(
+                     attributes=["start_position", "chromosome_name"]
+                 ).set_index("ensembl_gene_id")
+                 biomart = biomart.loc[~biomart.index.duplicated(keep="first")]
+                 biomart = biomart.sort_values(by=["chromosome_name", "start_position"])
+                 c = []
+                 i = 0
+                 prev_position = -100000
+                 prev_chromosome = None
+                 for _, r in biomart.iterrows():
+                     if (
+                         r["chromosome_name"] != prev_chromosome
+                         or r["start_position"] - prev_position > gene_position_tolerance
+                     ):
+                         i += 1
+                     c.append(i)
+                     prev_position = r["start_position"]
+                     prev_chromosome = r["chromosome_name"]
+                 print(f"reduced the number of positions to {len(set(c))/len(biomart)} of the genes")
+                 biomart["pos"] = c
+             mdataset.genedf = biomart.loc[mdataset.genedf.index]
+             self.gene_pos = mdataset.genedf["pos"].tolist()
+
+         if gene_embeddings != "":
+             mdataset.genedf = mdataset.genedf.join(
+                 pd.read_parquet(gene_embeddings), how="inner"
+             )
+             if do_gene_pos:
+                 self.gene_pos = mdataset.genedf["pos"].tolist()
+         self.labels = {k: len(v) for k, v in mdataset.class_topred.items()}
+         # we might want not to order the genes by expression (or do it?)
+         # we might want to not introduce zeros
+         if use_default_col:
+             kwargs["collate_fn"] = Collator(
+                 organisms=organisms,
+                 how=how,
+                 valid_genes=mdataset.genedf.index.tolist(),
+                 max_len=max_len,
+                 add_zero_genes=add_zero_genes,
+                 org_to_id=mdataset.encoder[organism_name],
+                 tp_name=tp_name,
+                 organism_name=organism_name,
+                 class_names=label_to_weight,
+             )
+         self.validation_split = validation_split
+         self.test_split = test_split
+         self.dataset = mdataset
+         self.kwargs = kwargs
+         self.assays_to_drop = assays_to_drop
+         self.n_samples = len(mdataset)
+         self.weight_scaler = weight_scaler
+         self.train_oversampling_per_epoch = train_oversampling_per_epoch
+         self.label_to_weight = label_to_weight
+         self.train_weights = None
+         self.train_labels = None
+         super().__init__()
+
+     def __repr__(self):
+         return (
+             f"DataLoader(\n"
+             f"\twith a dataset=({self.dataset.__repr__()}\n)\n"
+             f"\tvalidation_split={self.validation_split},\n"
+             f"\ttest_split={self.test_split},\n"
+             f"\tn_samples={self.n_samples},\n"
+             f"\tweight_scaler={self.weight_scaler},\n"
+             f"\ttrain_oversampling_per_epoch={self.train_oversampling_per_epoch},\n"
+             f"\tlabel_to_weight={self.label_to_weight}\n"
+             + (
+                 "\twith train_dataset size of=("
+                 + str((self.train_weights != 0).sum())
+                 + ")\n)"
+             )
+             if self.train_weights is not None
+             else ")"
+         )
+
+     @property
+     def decoders(self):
+         """
+         decoders: the decoders for any labels that have been encoded.
+
+         Returns:
+             dict[str, dict[int, str]]
+         """
+         decoders = {}
+         for k, v in self.dataset.encoder.items():
+             decoders[k] = {va: ke for ke, va in v.items()}
+         return decoders
+
+     @property
+     def cls_hierarchy(self):
+         """
+         cls_hierarchy: the hierarchy of labels for any class that has a hierarchy.
+
+         Returns:
+             dict[str, dict[str, str]]
+         """
+         cls_hierarchy = {}
+         for k, dic in self.dataset.class_groupings.items():
+             rdic = {}
+             for sk, v in dic.items():
+                 rdic[self.dataset.encoder[k][sk]] = [
+                     self.dataset.encoder[k][i] for i in list(v)
+                 ]
+             cls_hierarchy[k] = rdic
+         return cls_hierarchy
+
+     @property
+     def genes(self):
+         """
+         genes: the genes used in this datamodule.
+
+         Returns:
+             list
+         """
+         return self.dataset.genedf.index.tolist()
+
+     @property
+     def num_datasets(self):
+         return len(self.dataset.mapped_dataset.storages)
+
+     def setup(self, stage=None):
+         """
+         setup prepares the data for the training, validation, and test sets.
+         It shuffles the data, computes the weights for each set, and creates the samplers for each set.
+
+         Args:
+             stage (str, optional): The stage of the model training process.
+                 It can be either 'fit' or 'test'. Defaults to None.
+         """
+         if len(self.label_to_weight) > 0:
+             weights, labels = self.dataset.get_label_weights(
+                 self.label_to_weight, scaler=self.weight_scaler
+             )
+         else:
+             weights = np.ones(1)
+             labels = np.zeros(self.n_samples)
+         if isinstance(self.validation_split, int):
+             len_valid = self.validation_split
+         else:
+             len_valid = int(self.n_samples * self.validation_split)
+         if isinstance(self.test_split, int):
+             len_test = self.test_split
+         else:
+             len_test = int(self.n_samples * self.test_split)
+         assert (
+             len_test + len_valid < self.n_samples
+         ), "test set + valid set size is configured to be larger than the entire dataset."
+
+         idx_full = []
+         if len(self.assays_to_drop) > 0:
+             for i, a in enumerate(
+                 self.dataset.mapped_dataset.get_merged_labels("assay_ontology_term_id")
+             ):
+                 if a not in self.assays_to_drop:
+                     idx_full.append(i)
+             idx_full = np.array(idx_full)
+         else:
+             idx_full = np.arange(self.n_samples)
+         test_datasets = []
+         if len_test > 0:
+             # this way we test on datasets that were never seen during training,
+             # keeping at least one full dataset
+             len_test = (
+                 len_test
+                 if len_test > self.dataset.mapped_dataset.n_obs_list[0]
+                 else self.dataset.mapped_dataset.n_obs_list[0]
+             )
+             cs = 0
+             print("these files will be considered test datasets:")
+             for i, c in enumerate(self.dataset.mapped_dataset.n_obs_list):
+                 if cs + c > len_test:
+                     break
+                 else:
+                     print(" " + self.dataset.mapped_dataset.path_list[i].path)
+                     test_datasets.append(self.dataset.mapped_dataset.path_list[i].path)
+                     cs += c
+
+             len_test = cs
+             print("perc test: ", len_test / self.n_samples)
+             self.test_idx = idx_full[:len_test]
+             idx_full = idx_full[len_test:]
+         else:
+             self.test_idx = None
+
+         np.random.shuffle(idx_full)
+         if len_valid > 0:
+             self.valid_idx = idx_full[:len_valid].copy()
+             idx_full = idx_full[len_valid:]
+         else:
+             self.valid_idx = None
+         weights = np.concatenate([weights, np.zeros(1)])
+         labels[~np.isin(np.arange(self.n_samples), idx_full)] = len(weights) - 1
+
+         self.train_weights = weights
+         self.train_labels = labels
+         self.idx_full = idx_full
+
+         return test_datasets
+
+     def train_dataloader(self, **kwargs):
+         # train_sampler = WeightedRandomSampler(
+         #     self.train_weights[self.train_labels],
+         #     int(self.n_samples * self.train_oversampling_per_epoch),
+         #     replacement=True,
+         # )
+         train_sampler = LabelWeightedSampler(
+             self.train_weights,
+             self.train_labels,
+             num_samples=int(self.n_samples * self.train_oversampling_per_epoch),
+             # replacement=True,
+         )
+         return DataLoader(self.dataset, sampler=train_sampler, **self.kwargs, **kwargs)
+
+     def val_dataloader(self):
+         return (
+             DataLoader(
+                 self.dataset, sampler=SubsetRandomSampler(self.valid_idx), **self.kwargs
+             )
+             if self.valid_idx is not None
+             else None
+         )
+
+     def test_dataloader(self):
+         return (
+             DataLoader(
+                 self.dataset, sampler=SequentialSampler(self.test_idx), **self.kwargs
+             )
+             if self.test_idx is not None
+             else None
+         )
+
+     # def teardown(self):
+     #     # clean up state after the trainer stops, delete files...
+     #     # called on every process in DDP
+     #     pass
+
+
+ class LabelWeightedSampler(Sampler[int]):
+     label_weights: Sequence[float]
+     klass_indices: Sequence[Sequence[int]]
+     num_samples: int
+
+     # In typical use, just set a weight for each class (e.g. np.ones(num_classes)) and pass the labels of a dataset.
+     # This results in class-balanced sampling, no matter how imbalanced the labels are.
+     # NOTE: replacement=True is used here; change it if you don't want to upsample a class.
+     def __init__(
+         self, label_weights: Sequence[float], labels: Sequence[int], num_samples: int
+     ) -> None:
+         """
+         :param label_weights: list(len=num_classes)[float], weights for each class.
+         :param labels: list(len=dataset_len)[int], labels of a dataset.
+         :param num_samples: number of samples to draw.
+         """
+         super(LabelWeightedSampler, self).__init__(None)
+         # reweight the class weights by their counts, otherwise classes with many elements
+         # get the same weight as classes with only a few
+         label_weights = np.array(label_weights) * np.bincount(labels)
+
+         self.label_weights = torch.as_tensor(label_weights, dtype=torch.float32)
+         self.labels = torch.as_tensor(labels, dtype=torch.int)
+         self.num_samples = num_samples
+         # list of tensors: for each class, the dataset indices belonging to that class
+         self.klass_indices = [
+             (self.labels == i_klass).nonzero().squeeze(1)
+             for i_klass in range(len(label_weights))
+         ]
+
+     def __iter__(self):
+         # draw a class for every sample, then a random element within that class
+         sample_labels = torch.multinomial(
+             self.label_weights, num_samples=self.num_samples, replacement=True
+         )
+         sample_indices = torch.empty_like(sample_labels)
+         for i_klass, klass_index in enumerate(self.klass_indices):
+             if klass_index.numel() == 0:
+                 continue
+             left_inds = (sample_labels == i_klass).nonzero().squeeze(1)
+             right_inds = torch.randint(len(klass_index), size=(len(left_inds),))
+             sample_indices[left_inds] = klass_index[right_inds]
+         yield from iter(sample_indices.tolist())
+
+     def __len__(self):
+         return self.num_samples
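
For orientation, here is a minimal usage sketch of the new DataModule (illustrative only, not part of the package diff; the import path, collection name, and DataLoader keyword arguments below are assumptions):

import numpy as np
from scdataloader import DataModule  # import path assumed; adjust to the package layout

# hypothetical lamindb collection name; batch_size / num_workers are forwarded
# to the torch DataLoader via **kwargs, as the docstring describes
datamodule = DataModule(
    collection_name="my-preprocessed-collection",
    validation_split=0.2,
    test_split=0.1,
    batch_size=64,
    num_workers=4,
)
test_files = datamodule.setup()  # returns the paths of the held-out test datasets
train_dl = datamodule.train_dataloader()
val_dl = datamodule.val_dataloader()

Train batches are drawn by LabelWeightedSampler according to the weights computed in setup(), while validation uses a SubsetRandomSampler over the held-out indices.
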
scdataloader/mapped.py CHANGED
@@ -80,38 +80,51 @@ class MappedDataset:
          join_vars: Optional[Literal["auto", "inner", "None"]] = "auto",
          encode_labels: Optional[Union[bool, List[str]]] = False,
          parallel: bool = False,
+         unknown_class: str = "unknown",
      ):
          self.storages = []
          self.conns = []
          self.parallel = parallel
+         self.unknown_class = unknown_class
+         self.path_list = path_list
          self._make_connections(path_list, parallel)

          self.n_obs_list = []
          for storage in self.storages:
              with _Connect(storage) as store:
                  X = store["X"]
+                 index = (
+                     store["var"]["ensembl_gene_id"]
+                     if "ensembl_gene_id" in store["var"]
+                     else store["var"]["_index"]
+                 )
+                 if join_vars == "None":
+                     if not all(
+                         [
+                             i <= j
+                             for i, j in zip(
+                                 index[:99],
+                                 index[1:100],
+                             )
+                         ]
+                     ):
+                         raise ValueError("The variables are not sorted.")
                  if isinstance(X, ArrayTypes):  # type: ignore
                      self.n_obs_list.append(X.shape[0])
                  else:
                      self.n_obs_list.append(X.attrs["shape"][0])
          self.n_obs = sum(self.n_obs_list)

-         self.indices = np.hstack(
-             [np.arange(n_obs) for n_obs in self.n_obs_list]
-         )
-         self.storage_idx = np.repeat(
-             np.arange(len(self.storages)), self.n_obs_list
-         )
+         self.indices = np.hstack([np.arange(n_obs) for n_obs in self.n_obs_list])
+         self.storage_idx = np.repeat(np.arange(len(self.storages)), self.n_obs_list)

          self.join_vars = join_vars if len(path_list) > 1 else None
          self.var_indices = None
-         if self.join_vars is not None:
+         if self.join_vars != "None":
              self._make_join_vars()

          self.encode_labels = encode_labels
-         self.label_keys = (
-             [label_keys] if isinstance(label_keys, str) else label_keys
-         )
+         self.label_keys = [label_keys] if isinstance(label_keys, str) else label_keys
          if isinstance(encode_labels, bool):
              if encode_labels:
                  encode_labels = label_keys
@@ -122,6 +135,8 @@ class MappedDataset:
              for label in encode_labels:
                  cats = self.get_merged_categories(label)
                  self.encoders[label] = {cat: i for i, cat in enumerate(cats)}
+                 if unknown_class in self.encoders[label]:
+                     self.encoders[label][unknown_class] = -1
          else:
              self.encoders = {}
          self._closed = False
@@ -157,9 +172,15 @@ class MappedDataset:
                  raise ValueError(
                      "The provided AnnData objects don't have shared variables."
                  )
-         self.var_indices = [
-             vrs.get_indexer(self.var_joint) for vrs in var_list
-         ]
+         self.var_indices = [vrs.get_indexer(self.var_joint) for vrs in var_list]
+
+     def _check_aligned_vars(self, vars: list):
+         i = 0
+         for storage in self.storages:
+             with _Connect(storage) as store:
+                 if vars == _safer_read_index(store["var"]).tolist():
+                     i += 1
+         print("{}% are aligned".format(i * 100 / len(self.storages)))

      def __len__(self):
          return self.n_obs
@@ -172,20 +193,17 @@ class MappedDataset:
          else:
              var_idxs = None
          with _Connect(self.storages[storage_idx]) as store:
-             out = [self.get_data_idx(store, obs_idx, var_idxs)]
+             out = {"x": self.get_data_idx(store, obs_idx, var_idxs)}
              if self.label_keys is not None:
-                 for i, label in enumerate(self.label_keys):
+                 for _, label in enumerate(self.label_keys):
                      label_idx = self.get_label_idx(store, obs_idx, label)
                      if label in self.encoders:
-                         out.append(self.encoders[label][label_idx])
+                         out.update({label: self.encoders[label][label_idx]})
                      else:
-                         out.append(label_idx)
+                         out.update({label: label_idx})
+             out.update({"dataset": storage_idx})
          return out

-     def uns(self, idx, key):
-         storage = self.storages[self.storage_idx[idx]]
-         return storage["uns"][key]
-
      def get_data_idx(
          self,
          storage: StorageType,
@@ -240,13 +258,13 @@ class MappedDataset:
              if i == 0:
                  labels = self.get_merged_labels(val)
              else:
-                 labels += "_" + self.get_merged_labels(val).astype(str).astype(
-                     "O"
-                 )
+                 labels += "_" + self.get_merged_labels(val).astype(str).astype("O")
          counter = Counter(labels)  # type: ignore
-         counter = np.array([counter[label] for label in labels])
+         rn = {n: i for i, n in enumerate(counter.keys())}
+         labels = np.array([rn[label] for label in labels])
+         counter = np.array(list(counter.values()))
          weights = scaler / (counter + scaler)
-         return weights
+         return weights, labels

      def get_merged_labels(self, label_key: str):
          """Get merged labels."""
@@ -255,9 +273,7 @@ class MappedDataset:
          for storage in self.storages:
              with _Connect(storage) as store:
                  codes = self.get_codes(store, label_key)
-                 labels = (
-                     decode(codes) if isinstance(codes[0], bytes) else codes
-                 )
+                 labels = decode(codes) if isinstance(codes[0], bytes) else codes
                  cats = self.get_categories(store, label_key)
                  if cats is not None:
                      cats = decode(cats) if isinstance(cats[0], bytes) else cats
@@ -277,9 +293,7 @@ class MappedDataset:
                      cats_merge.update(cats)
                  else:
                      codes = self.get_codes(store, label_key)
-                     codes = (
-                         decode(codes) if isinstance(codes[0], bytes) else codes
-                     )
+                     codes = decode(codes) if isinstance(codes[0], bytes) else codes
                      cats_merge.update(codes)
          return cats_merge

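
For reference, the reworked get_label_weights now returns per-class weights together with integer-encoded per-cell labels, rather than one weight per cell. A small self-contained sketch of the same computation on toy labels, mirroring the changed lines above (the label values are made up):

import numpy as np
from collections import Counter

# toy per-cell labels, standing in for get_merged_labels(...) output
labels = np.array(["B cell", "B cell", "B cell", "T cell", "NK cell"], dtype=object)
scaler = 10

counter = Counter(labels)                          # class -> count
rn = {n: i for i, n in enumerate(counter.keys())}  # class -> integer code
codes = np.array([rn[label] for label in labels])  # per-cell integer labels
counts = np.array(list(counter.values()))
weights = scaler / (counts + scaler)               # rarer classes get relatively larger weights

print(codes)    # [0 0 0 1 2]
print(weights)  # [0.769... 0.909... 0.909...]

These weights and codes are what DataModule.setup() feeds to LabelWeightedSampler, and each MappedDataset item is now returned as a dict holding the expression vector under "x", one entry per label key, and the index of the originating "dataset".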