scdataloader 0.0.2__py3-none-any.whl → 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,318 @@
+ import numpy as np
+ import pandas as pd
+ import lamindb as ln
+
+ from torch.utils.data.sampler import (
+     WeightedRandomSampler,
+     SubsetRandomSampler,
+     SequentialSampler,
+ )
+ from torch.utils.data import DataLoader
+ import lightning as L
+
+ from typing import Optional
+
+ from .data import Dataset
+ from .collator import Collator
+ from .mapped import MappedDataset
+ from .utils import getBiomartTable
+
+ # TODO: put in config
+ COARSE_TISSUE = {
+     "adipose tissue": "",
+     "bladder organ": "",
+     "blood": "",
+     "bone marrow": "",
+     "brain": "",
+     "breast": "",
+     "esophagus": "",
+     "eye": "",
+     "embryo": "",
+     "fallopian tube": "",
+     "gall bladder": "",
+     "heart": "",
+     "intestine": "",
+     "kidney": "",
+     "liver": "",
+     "lung": "",
+     "lymph node": "",
+     "musculature of body": "",
+     "nose": "",
+     "ovary": "",
+     "pancreas": "",
+     "placenta": "",
+     "skin of body": "",
+     "spinal cord": "",
+     "spleen": "",
+     "stomach": "",
+     "thymus": "",
+     "thyroid gland": "",
+     "tongue": "",
+     "uterus": "",
+ }
+
+ COARSE_ANCESTRY = {
+     "African": "",
+     "Chinese": "",
+     "East Asian": "",
+     "Eskimo": "",
+     "European": "",
+     "Greater Middle Eastern (Middle Eastern, North African or Persian)": "",
+     "Hispanic or Latin American": "",
+     "Native American": "",
+     "Oceanian": "",
+     "South Asian": "",
+ }
+
+ COARSE_DEVELOPMENT_STAGE = {
+     "Embryonic human": "",
+     "Fetal": "",
+     "Immature": "",
+     "Mature": "",
+ }
+
+ COARSE_ASSAY = {
+     "10x 3'": "",
+     "10x 5'": "",
+     "10x multiome": "",
+     "CEL-seq2": "",
+     "Drop-seq": "",
+     "GEXSCOPE technology": "",
+     "inDrop": "",
+     "microwell-seq": "",
+     "sci-Plex": "",
+     "sci-RNA-seq": "",
+     "Seq-Well": "",
+     "Slide-seq": "",
+     "Smart-seq": "",
+     "SPLiT-seq": "",
+     "TruDrop": "",
+     "Visium Spatial Gene Expression": "",
+ }
+
+
+ class DataModule(L.LightningDataModule):
+     """
+     Base class for all data loaders
+     """
+
+     def __init__(
+         self,
+         mdataset: Optional[MappedDataset] = None,
+         collection_name=None,
+         organisms: list = ["NCBITaxon:9606"],
+         weight_scaler: int = 30,
+         train_oversampling=1,
+         label_to_weight: list = [],
+         label_to_pred: list = [],
+         validation_split: float = 0.2,
+         test_split: float = 0,
+         use_default_col=True,
+         all_labels=[],
+         hierarchical_labels=[],
+         how="most expr",
+         organism_name="organism_ontology_term_id",
+         max_len=1000,
+         add_zero_genes=100,
+         do_gene_pos=True,
+         gene_embeddings="",
+         gene_position_tolerance=10_000,
+         **kwargs,
+     ):
+         """
+         Initializes the DataModule.
+
+         Args:
+             dataset (MappedDataset): The dataset to be used.
+             weight_scaler (int, optional): The weight scaler for weighted random sampling. Defaults to 30.
+             label_to_weight (list, optional): List of labels to weight. Defaults to [].
+             validation_split (float, optional): The proportion of the dataset to include in the validation split. Defaults to 0.2.
+             test_split (float, optional): The proportion of the dataset to include in the test split. Defaults to 0.
+             **kwargs: Additional keyword arguments passed to the pytorch DataLoader.
+         """
+         if collection_name is not None:
+             mdataset = Dataset(
+                 ln.Collection.filter(name=collection_name).first(),
+                 organisms=organisms,
+                 obs=all_labels,
+                 clss_to_pred=label_to_pred,
+                 hierarchical_clss=hierarchical_labels,
+             )
+             print(mdataset)
+         # and location
+         if do_gene_pos:
+             # and annotations
+             biomart = getBiomartTable(
+                 attributes=["start_position", "chromosome_name"]
+             ).set_index("ensembl_gene_id")
+             biomart = biomart.loc[~biomart.index.duplicated(keep="first")]
+             biomart = biomart.sort_values(by=["chromosome_name", "start_position"])
+             c = []
+             i = 0
+             prev_position = -100000
+             prev_chromosome = None
+             for _, r in biomart.iterrows():
+                 if (
+                     r["chromosome_name"] != prev_chromosome
+                     or r["start_position"] - prev_position > gene_position_tolerance
+                 ):
+                     i += 1
+                 c.append(i)
+                 prev_position = r["start_position"]
+                 prev_chromosome = r["chromosome_name"]
+             print(f"reduced the size to {len(set(c))/len(biomart)}")
+             biomart["pos"] = c
+             mdataset.genedf = biomart.loc[
+                 mdataset.genedf[mdataset.genedf.index.isin(biomart.index)].index
+             ]
+             self.gene_pos = mdataset.genedf["pos"].tolist()
+
+         if gene_embeddings != "":
+             mdataset.genedf = mdataset.genedf.join(
+                 pd.read_parquet(gene_embeddings), how="inner"
+             )
+             if do_gene_pos:
+                 self.gene_pos = mdataset.genedf["pos"].tolist()
+         self.labels = {k: len(v) for k, v in mdataset.class_topred.items()}
+         # we might want not to order the genes by expression (or do it?)
+         # we might want to not introduce zeros and
+         if use_default_col:
+             kwargs["collate_fn"] = Collator(
+                 organisms=organisms,
+                 how=how,
+                 valid_genes=mdataset.genedf.index.tolist(),
+                 max_len=max_len,
+                 add_zero_genes=add_zero_genes,
+                 org_to_id=mdataset.encoder[organism_name],
+                 tp_name="heat_diff",
+                 organism_name=organism_name,
+                 class_names=label_to_weight,
+             )
+         self.validation_split = validation_split
+         self.test_split = test_split
+         self.dataset = mdataset
+         self.kwargs = kwargs
+         self.n_samples = len(mdataset)
+         self.weight_scaler = weight_scaler
+         self.train_oversampling = train_oversampling
+         self.label_to_weight = label_to_weight
+         super().__init__()
+
+     @property
+     def decoders(self):
+         decoders = {}
+         for k, v in self.dataset.encoder.items():
+             decoders[k] = {va: ke for ke, va in v.items()}
+         return decoders
+
+     @property
+     def cls_hierarchy(self):
+         cls_hierarchy = {}
+         for k, dic in self.dataset.class_groupings.items():
+             rdic = {}
+             for sk, v in dic.items():
+                 rdic[self.dataset.encoder[k][sk]] = [
+                     self.dataset.encoder[k][i] for i in list(v)
+                 ]
+             cls_hierarchy[k] = rdic
+         return cls_hierarchy
+
+     @property
+     def genes(self):
+         return self.dataset.genedf.index.tolist()
+
+     def setup(self, stage=None):
+         """
+         setup method is used to prepare the data for the training, validation, and test sets.
+         It shuffles the data, calculates weights for each set, and creates samplers for each set.
+
+         Args:
+             stage (str, optional): The stage of the model training process.
+                 It can be either 'fit' or 'test'. Defaults to None.
+         """
+
+         if len(self.label_to_weight) > 0:
+             weights = self.dataset.get_label_weights(
+                 self.label_to_weight, scaler=self.weight_scaler
+             )
+         else:
+             weights = np.ones(self.n_samples)
+         if isinstance(self.validation_split, int):
+             len_valid = self.validation_split
+         else:
+             len_valid = int(self.n_samples * self.validation_split)
+         if isinstance(self.test_split, int):
+             len_test = self.test_split
+         else:
+             len_test = int(self.n_samples * self.test_split)
+         assert (
+             len_test + len_valid < self.n_samples
+         ), "test set + valid set size is configured to be larger than entire dataset."
+         idx_full = np.arange(self.n_samples)
+         if len_test > 0:
+             # this way we work on some never seen datasets
+             # keeping at least one
+             len_test = (
+                 len_test
+                 if len_test > self.dataset.mapped_dataset.n_obs_list[0]
+                 else self.dataset.mapped_dataset.n_obs_list[0]
+             )
+             cs = 0
+             test_datasets = []
+             print("these files will be considered test datasets:")
+             for i, c in enumerate(self.dataset.mapped_dataset.n_obs_list):
+                 if cs + c > len_test:
+                     break
+                 else:
+                     print(" " + self.dataset.mapped_dataset.path_list[i].path)
+                     test_datasets.append(self.dataset.mapped_dataset.path_list[i].path)
+                     cs += c
+
+             len_test = cs
+             print("perc test: ", len_test / self.n_samples)
+             test_idx = idx_full[:len_test]
+             idx_full = idx_full[len_test:]
+             self.test_sampler = SequentialSampler(test_idx)
+         else:
+             self.test_sampler = None
+             test_datasets = None
+
+         np.random.shuffle(idx_full)
+         if len_valid > 0:
+             valid_idx = idx_full[:len_valid]
+             idx_full = idx_full[len_valid:]
+             self.valid_sampler = SubsetRandomSampler(valid_idx)
+         else:
+             self.valid_sampler = None
+
+         weights[~idx_full] = 0
+         self.train_sampler = WeightedRandomSampler(
+             weights,
+             int(len(idx_full) * self.train_oversampling),
+             replacement=True,
+         )
+         return test_datasets
+
+     def train_dataloader(self, **kwargs):
+         return DataLoader(
+             self.dataset, sampler=self.train_sampler, **self.kwargs, **kwargs
+         )
+
+     def val_dataloader(self):
+         return (
+             DataLoader(self.dataset, sampler=self.valid_sampler, **self.kwargs)
+             if self.valid_sampler is not None
+             else None
+         )
+
+     def test_dataloader(self):
+         return (
+             DataLoader(self.dataset, sampler=self.test_sampler, **self.kwargs)
+             if self.test_sampler is not None
+             else None
+         )
+
+     # def teardown(self):
+     # clean up state after the trainer stops, delete files...
+     # called on every process in DDP
+     # pass
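
For orientation (not part of the released files): a minimal usage sketch of the new DataModule added above. The collection name, label column, DataLoader options, and the exact import path below are illustrative assumptions, not values shipped with the package.

# Hypothetical usage sketch -- "my-collection" and "cell_type_ontology_term_id"
# are placeholder names, not defined by this release.
from scdataloader import DataModule  # import path assumed

datamodule = DataModule(
    collection_name="my-collection",                  # LaminDB Collection to load
    label_to_weight=["cell_type_ontology_term_id"],   # obs column used for weighted sampling
    validation_split=0.2,
    test_split=0.1,
    batch_size=64,      # extra kwargs are forwarded to the torch DataLoader
    num_workers=4,
)
test_files = datamodule.setup()            # returns the files held out as the test set
train_dl = datamodule.train_dataloader()
val_dl = datamodule.val_dataloader()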
scdataloader/mapped.py CHANGED
@@ -80,10 +80,13 @@ class MappedDataset:
          join_vars: Optional[Literal["auto", "inner", "None"]] = "auto",
          encode_labels: Optional[Union[bool, List[str]]] = False,
          parallel: bool = False,
+         unknown_class: str = "unknown",
      ):
          self.storages = []
          self.conns = []
          self.parallel = parallel
+         self.unknown_class = unknown_class
+         self.path_list = path_list
          self._make_connections(path_list, parallel)

          self.n_obs_list = []
@@ -96,22 +99,16 @@ class MappedDataset:
                  self.n_obs_list.append(X.attrs["shape"][0])
          self.n_obs = sum(self.n_obs_list)

-         self.indices = np.hstack(
-             [np.arange(n_obs) for n_obs in self.n_obs_list]
-         )
-         self.storage_idx = np.repeat(
-             np.arange(len(self.storages)), self.n_obs_list
-         )
+         self.indices = np.hstack([np.arange(n_obs) for n_obs in self.n_obs_list])
+         self.storage_idx = np.repeat(np.arange(len(self.storages)), self.n_obs_list)

          self.join_vars = join_vars if len(path_list) > 1 else None
          self.var_indices = None
-         if self.join_vars is not None:
+         if self.join_vars != "None":
              self._make_join_vars()

          self.encode_labels = encode_labels
-         self.label_keys = (
-             [label_keys] if isinstance(label_keys, str) else label_keys
-         )
+         self.label_keys = [label_keys] if isinstance(label_keys, str) else label_keys
          if isinstance(encode_labels, bool):
              if encode_labels:
                  encode_labels = label_keys
@@ -122,6 +119,8 @@ class MappedDataset:
              for label in encode_labels:
                  cats = self.get_merged_categories(label)
                  self.encoders[label] = {cat: i for i, cat in enumerate(cats)}
+                 if unknown_class in self.encoders[label]:
+                     self.encoders[label][unknown_class] = -1
          else:
              self.encoders = {}
          self._closed = False
@@ -157,9 +156,15 @@ class MappedDataset:
                  raise ValueError(
                      "The provided AnnData objects don't have shared varibales."
                  )
-             self.var_indices = [
-                 vrs.get_indexer(self.var_joint) for vrs in var_list
-             ]
+             self.var_indices = [vrs.get_indexer(self.var_joint) for vrs in var_list]
+
+     def _check_aligned_vars(self, vars: list):
+         i = 0
+         for storage in self.storages:
+             with _Connect(storage) as store:
+                 if vars == _safer_read_index(store["var"]).tolist():
+                     i += 1
+         print("{}% are aligned".format(i * 100 / len(self.storages)))

      def __len__(self):
          return self.n_obs
@@ -172,14 +177,14 @@ class MappedDataset:
          else:
              var_idxs = None
          with _Connect(self.storages[storage_idx]) as store:
-             out = [self.get_data_idx(store, obs_idx, var_idxs)]
+             out = {"x": self.get_data_idx(store, obs_idx, var_idxs)}
              if self.label_keys is not None:
                  for i, label in enumerate(self.label_keys):
                      label_idx = self.get_label_idx(store, obs_idx, label)
                      if label in self.encoders:
-                         out.append(self.encoders[label][label_idx])
+                         out.update({label: self.encoders[label][label_idx]})
                      else:
-                         out.append(label_idx)
+                         out.update({label: label_idx})
          return out

      def uns(self, idx, key):
@@ -240,9 +245,7 @@ class MappedDataset:
              if i == 0:
                  labels = self.get_merged_labels(val)
              else:
-                 labels += "_" + self.get_merged_labels(val).astype(str).astype(
-                     "O"
-                 )
+                 labels += "_" + self.get_merged_labels(val).astype(str).astype("O")
          counter = Counter(labels)  # type: ignore
          counter = np.array([counter[label] for label in labels])
          weights = scaler / (counter + scaler)
@@ -255,9 +258,7 @@ class MappedDataset:
          for storage in self.storages:
              with _Connect(storage) as store:
                  codes = self.get_codes(store, label_key)
-                 labels = (
-                     decode(codes) if isinstance(codes[0], bytes) else codes
-                 )
+                 labels = decode(codes) if isinstance(codes[0], bytes) else codes
                  cats = self.get_categories(store, label_key)
                  if cats is not None:
                      cats = decode(cats) if isinstance(cats[0], bytes) else cats
@@ -277,9 +278,7 @@ class MappedDataset:
                      cats_merge.update(cats)
                  else:
                      codes = self.get_codes(store, label_key)
-                     codes = (
-                         decode(codes) if isinstance(codes[0], bytes) else codes
-                     )
+                     codes = decode(codes) if isinstance(codes[0], bytes) else codes
                      cats_merge.update(codes)
          return cats_merge
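
One behavioral note on the __getitem__ change above: items are now returned as a dict keyed by "x" plus the label names, and any label equal to unknown_class is encoded as -1. A rough sketch of how downstream code might exploit that encoding (the dataset construction, label key, and model below are assumptions, not part of this diff):

import torch
import torch.nn.functional as F

# Assumed setup: a MappedDataset built with label_keys=["cell_type"],
# encode_labels=True and the default unknown_class="unknown"; "classifier" is a placeholder model.
item = dataset[0]                      # e.g. {"x": <expression values>, "cell_type": <int code>}
x = torch.as_tensor(item["x"]).float().unsqueeze(0)
target = torch.tensor([item["cell_type"]])
logits = classifier(x)
# cells whose label was "unknown" carry the code -1 and are skipped by the loss
loss = F.cross_entropy(logits, target, ignore_index=-1)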