scdataloader 0.0.3__py3-none-any.whl → 1.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
scdataloader/data.py CHANGED
@@ -1,83 +1,57 @@
  from dataclasses import dataclass, field

  import lamindb as ln
- import lnschema_bionty as lb
+
+ # ln.connect("scprint")
+
+ import bionty as bt
  import pandas as pd
  from torch.utils.data import Dataset as torchDataset
- from typing import Union
+ from typing import Union, Optional, Literal
  from scdataloader import mapped
  import warnings

- # TODO: manage load gene embeddings to make
- # from scprint.dataloader.embedder import embed
+ from anndata import AnnData
+ from scipy.sparse import issparse
+
  from scdataloader.utils import get_ancestry_mapping, load_genes

- LABELS_TOADD = {
-     "assay_ontology_term_id": {
-         "10x transcription profiling": "EFO:0030003",
-         "spatial transcriptomics": "EFO:0008994",
-         "10x 3' transcription profiling": "EFO:0030003",
-         "10x 5' transcription profiling": "EFO:0030004",
-     },
-     "disease_ontology_term_id": {
-         "metabolic disease": "MONDO:0005066",
-         "chronic kidney disease": "MONDO:0005300",
-         "chromosomal disorder": "MONDO:0019040",
-         "infectious disease": "MONDO:0005550",
-         "inflammatory disease": "MONDO:0021166",
-         # "immune system disease",
-         "disorder of development or morphogenesis": "MONDO:0021147",
-         "mitochondrial disease": "MONDO:0044970",
-         "psychiatric disorder": "MONDO:0002025",
-         "cancer or benign tumor": "MONDO:0002025",
-         "neoplasm": "MONDO:0005070",
-     },
-     "cell_type_ontology_term_id": {
-         "progenitor cell": "CL:0011026",
-         "hematopoietic cell": "CL:0000988",
-         "myoblast": "CL:0000056",
-         "myeloid cell": "CL:0000763",
-         "neuron": "CL:0000540",
-         "electrically active cell": "CL:0000211",
-         "epithelial cell": "CL:0000066",
-         "secretory cell": "CL:0000151",
-         "stem cell": "CL:0000034",
-         "non-terminally differentiated cell": "CL:0000055",
-         "supporting cell": "CL:0000630",
-     },
- }
+ from .config import LABELS_TOADD


  @dataclass
  class Dataset(torchDataset):
      """
-     Dataset class to load a bunch of anndata from a lamin dataset in a memory efficient way.
+     Dataset class to load a bunch of anndata from a lamin dataset (Collection) in a memory-efficient way.

-     For an example, see :meth:`~lamindb.Dataset.mapped`.
+     This serves as a wrapper around lamin's mappedCollection, adding the management of
+     hierarchical labels, the encoding of labels, and the management of multiple species.
+
+     For an example of mappedDataset, see :meth:`~lamindb.Dataset.mapped`.

      .. note::

-         A similar data loader exists `here
+         A related data loader exists `here
          <https://github.com/Genentech/scimilarity>`__.

-     Attributes:
+     Args:
      ----
          lamin_dataset (lamindb.Dataset): lamin dataset to load
          genedf (pd.DataFrame): dataframe containing the genes to load
-         gene_embedding: dataframe containing the gene embeddings
          organisms (list[str]): list of organisms to load
-         obs (list[str]): list of observations to load
+             (for now, only validates that the genes map to these organisms)
+         obs (list[str]): list of observations to load from the Collection
          clss_to_pred (list[str]): list of observations to encode
-         hierarchical_clss: list of observations to map to a hierarchy
+         join_vars (str): how to join variables across datasets; see :meth:`~lamindb.Dataset.mapped`.
+         hierarchical_clss: list of observations to map to a hierarchy using lamin's bionty
      """

      lamin_dataset: ln.Collection
-     genedf: pd.DataFrame = None
-     # gene_embedding: pd.DataFrame = None  # TODO: make it part of a specialized dataset
-     organisms: Union[list[str], str] = field(
+     genedf: Optional[pd.DataFrame] = None
+     organisms: Optional[Union[list[str], str]] = field(
          default_factory=lambda: ["NCBITaxon:9606", "NCBITaxon:10090"]
      )
-     obs: list[str] = field(
+     obs: Optional[list[str]] = field(
          default_factory=lambda: [
              "self_reported_ethnicity_ontology_term_id",
              "assay_ontology_term_id",
@@ -88,29 +62,32 @@ class Dataset(torchDataset):
              "sex_ontology_term_id",
              #'dataset_id',
              #'cell_culture',
-             "dpt_group",
-             "heat_diff",
-             "nnz",
+             # "dpt_group",
+             # "heat_diff",
+             # "nnz",
          ]
      )
      # set of obs to prepare for prediction (encode)
-     clss_to_pred: list[str] = field(default_factory=list)
+     clss_to_pred: Optional[list[str]] = field(default_factory=list)
      # set of obs that need to be hierarchically prepared
-     hierarchical_clss: list[str] = field(default_factory=list)
-     join_vars: str = "None"
+     hierarchical_clss: Optional[list[str]] = field(default_factory=list)
+     join_vars: Literal["inner", "outer"] | None = None

      def __post_init__(self):
          self.mapped_dataset = mapped.mapped(
              self.lamin_dataset,
              label_keys=self.obs,
+             join=self.join_vars,
              encode_labels=self.clss_to_pred,
+             unknown_label="unknown",
              stream=True,
              parallel=True,
-             join_vars=self.join_vars,
          )
          print(
              "won't do any check, but we recommend loading your dataset from local storage"
          )
+         self.labels_groupings = {}
+         self.class_topred = {}
          # generate tree from ontologies
          if len(self.hierarchical_clss) > 0:
              self.define_hierarchies(self.hierarchical_clss)
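Taken together, the new `__post_init__` passes `join`, `encode_labels`, and `unknown_label` straight through to `mapped.mapped` and pre-seeds `labels_groupings`/`class_topred`. A minimal usage sketch of the reworked class (the collection name is hypothetical, and the exact lamindb query call may differ between versions):

```python
import lamindb as ln
from scdataloader.data import Dataset

# Hypothetical collection name; substitute a Collection from your own lamin instance.
collection = ln.Collection.filter(name="my-scrnaseq-collection").one()

dataset = Dataset(
    lamin_dataset=collection,
    clss_to_pred=["cell_type_ontology_term_id"],
    hierarchical_clss=["cell_type_ontology_term_id"],
)
print(len(dataset))          # total number of cells across the collection
print(dataset.class_topred)  # classes the model is expected to predict
```

With `hierarchical_clss` set, `define_hierarchies` (further down in this diff) rebuilds the encoders so that predictable leaf classes come first.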
@@ -121,24 +98,19 @@ class Dataset(torchDataset):
                  self.class_topred[clss] = self.mapped_dataset.get_merged_categories(
                      clss
                  )
-                 update = {}
-                 c = 0
-                 for k, v in self.mapped_dataset.encoders[clss].items():
-                     if k == self.mapped_dataset.unknown_class:
-                         update.update({k: v})
-                         c += 1
-                         self.class_topred[clss] -= set([k])
-                     else:
-                         update.update({k: v - c})
-                 self.mapped_dataset.encoders[clss] = update
+                 if (
+                     self.mapped_dataset.unknown_label
+                     in self.mapped_dataset.encoders[clss].keys()
+                 ):
+                     self.class_topred[clss] -= set(
+                         [self.mapped_dataset.unknown_label]
+                     )

          if self.genedf is None:
              self.genedf = load_genes(self.organisms)

          self.genedf.columns = self.genedf.columns.astype(str)
-         for organism in self.organisms:
-             ogenedf = self.genedf[self.genedf.organism == organism]
-             self.mapped_dataset._check_aligned_vars(ogenedf.index.tolist())
+         self.mapped_dataset._check_aligned_vars(self.genedf.index.tolist())

      def __len__(self, **kwargs):
          return self.mapped_dataset.__len__(**kwargs)
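The change above replaces the old encoder re-indexing with a simpler rule: when `unknown_label` is present, it stays in the encoder but is dropped from the set of predictable classes. A toy illustration with made-up values:

```python
# Toy values, not library data: one encoder and its predictable classes.
encoders = {"cell_type_ontology_term_id": {"CL:0000540": 0, "unknown": 1}}
class_topred = {"cell_type_ontology_term_id": {"CL:0000540", "unknown"}}

clss, unknown_label = "cell_type_ontology_term_id", "unknown"
if unknown_label in encoders[clss]:
    # "unknown" stays encodable but is no longer a prediction target
    class_topred[clss] -= {unknown_label}

print(class_topred)  # {'cell_type_ontology_term_id': {'CL:0000540'}}
```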
@@ -149,66 +121,66 @@ class Dataset(torchDataset):

      def __getitem__(self, *args, **kwargs):
          item = self.mapped_dataset.__getitem__(*args, **kwargs)
-         # item.update({"unseen_genes": self.get_unseen_mapped_dataset_elements(*args, **kwargs)})
-         # ret = {}
-         # ret["count"] = item[0]
-         # for i, val in enumerate(self.obs):
-         #     ret[val] = item[1][i]
-         ## mark unseen genes with a flag
-         ## send the associated
-         # print(item[0].shape)
          return item

      def __repr__(self):
-         print(
-             "total dataset size is {} Gb".format(
+         return (
+             "total dataset size is {} Gb\n".format(
                  sum([file.size for file in self.lamin_dataset.artifacts.all()]) / 1e9
              )
-         )
-         print("---")
-         print("dataset contains:")
-         print(" {} cells".format(self.mapped_dataset.__len__()))
-         print(" {} genes".format(self.genedf.shape[0]))
-         print(" {} labels".format(len(self.obs)))
-         print(" {} organisms".format(len(self.organisms)))
-         print(
-             "dataset contains {} classes to predict".format(
-                 sum([len(self.class_topred[i]) for i in self.class_topred])
+             + "---\n"
+             + "dataset contains:\n"
+             + " {} cells\n".format(self.mapped_dataset.__len__())
+             + " {} genes\n".format(self.genedf.shape[0])
+             + " {} labels\n".format(len(self.obs))
+             + " {} clss_to_pred\n".format(len(self.clss_to_pred))
+             + " {} hierarchical_clss\n".format(len(self.hierarchical_clss))
+             + " {} organisms\n".format(len(self.organisms))
+             + (
+                 "dataset contains {} classes to predict\n".format(
+                     sum([len(self.class_topred[i]) for i in self.class_topred])
+                 )
+                 if len(self.class_topred) > 0
+                 else ""
              )
          )
-         # print("embedding size is {}".format(self.gene_embedding.shape[1]))
-         return ""

      def get_label_weights(self, *args, **kwargs):
+         """
+         get_label_weights is a wrapper around mappedDataset.get_label_weights.
+
+         Returns:
+             dict: dictionary of weights for each label
+         """
          return self.mapped_dataset.get_label_weights(*args, **kwargs)

-     def get_unseen_mapped_dataset_elements(self, idx):
+     def get_unseen_mapped_dataset_elements(self, idx: int):
+         """
+         get_unseen_mapped_dataset_elements is a wrapper around mappedDataset.get_unseen_mapped_dataset_elements.
+
+         Args:
+             idx (int): index of the element to get
+
+         Returns:
+             list[str]: list of unseen genes
+         """
          return [str(i)[2:-1] for i in self.mapped_dataset.uns(idx, "unseen_genes")]

-     # def load_embeddings(self, genedfs, embedding_size=128, cache=True):
-     #     embeddings = []
-     #     for o in self.organisms:
-     #         genedf = genedfs[genedfs.organism == o]
-     #         org_name = lb.Organism.filter(ontology_id=o).one().scientific_name
-     #         embedding = embed(
-     #             genedf=genedf,
-     #             organism=org_name,
-     #             cache=cache,
-     #             fasta_path="/tmp/data/fasta/",
-     #             embedding_size=embedding_size,
-     #         )
-     #         genedf = pd.concat(
-     #             [genedf.set_index("ensembl_gene_id"), embedding], axis=1, join="inner"
-     #         )
-     #         genedf.columns = genedf.columns.astype(str)
-     #         embeddings.append(genedf)
-     #     return pd.concat(embeddings)
-
-     def define_hierarchies(self, labels):
-         self.class_groupings = {}
+     def define_hierarchies(self, clsses: list[str]):
+         """
+         define_hierarchies defines the label hierarchies for the given observation keys.
+
+         Args:
+             clsses (list[str]): list of observation keys (classes) to build hierarchies for
+
+         Raises:
+             ValueError: if a class is not among the accepted, bionty-backed classes
+         """
+         # TODO: use all possible hierarchies instead of just the ones for which we have an annotated sample
+         self.labels_groupings = {}
          self.class_topred = {}
-         for label in labels:
-             if label not in [
+         for clss in clsses:
+             if clss not in [
                  "cell_type_ontology_term_id",
                  "tissue_ontology_term_id",
                  "disease_ontology_term_id",
@@ -217,120 +189,139 @@ class Dataset(torchDataset):
                  "self_reported_ethnicity_ontology_term_id",
              ]:
                  raise ValueError(
-                     "label {} not in accepted labels, for now only supported from bionty sources".format(
-                         label
+                     "class {} not in accepted classes; for now, only bionty-backed sources are supported".format(
+                         clss
                      )
                  )
-             elif label == "cell_type_ontology_term_id":
+             elif clss == "cell_type_ontology_term_id":
                  parentdf = (
-                     lb.CellType.filter()
+                     bt.CellType.filter()
                      .df(include=["parents__ontology_id"])
                      .set_index("ontology_id")
                  )
-             elif label == "tissue_ontology_term_id":
+             elif clss == "tissue_ontology_term_id":
                  parentdf = (
-                     lb.Tissue.filter()
+                     bt.Tissue.filter()
                      .df(include=["parents__ontology_id"])
                      .set_index("ontology_id")
                  )
-             elif label == "disease_ontology_term_id":
+             elif clss == "disease_ontology_term_id":
                  parentdf = (
-                     lb.Disease.filter()
+                     bt.Disease.filter()
                      .df(include=["parents__ontology_id"])
                      .set_index("ontology_id")
                  )
-             elif label == "development_stage_ontology_term_id":
+             elif clss == "development_stage_ontology_term_id":
                  parentdf = (
-                     lb.DevelopmentalStage.filter()
+                     bt.DevelopmentalStage.filter()
                      .df(include=["parents__ontology_id"])
                      .set_index("ontology_id")
                  )
-             elif label == "assay_ontology_term_id":
+             elif clss == "assay_ontology_term_id":
                  parentdf = (
-                     lb.ExperimentalFactor.filter()
+                     bt.ExperimentalFactor.filter()
                      .df(include=["parents__ontology_id"])
                      .set_index("ontology_id")
                  )
-             elif label == "self_reported_ethnicity_ontology_term_id":
+             elif clss == "self_reported_ethnicity_ontology_term_id":
                  parentdf = (
-                     lb.Ethnicity.filter()
+                     bt.Ethnicity.filter()
                      .df(include=["parents__ontology_id"])
                      .set_index("ontology_id")
                  )

              else:
                  raise ValueError(
-                     "label {} not in accepted labels, for now only supported from bionty sources".format(
-                         label
+                     "class {} not in accepted classes; for now, only bionty-backed sources are supported".format(
+                         clss
                      )
                  )
-             cats = self.mapped_dataset.get_merged_categories(label)
-             addition = set(LABELS_TOADD.get(label, {}).values())
+             cats = self.mapped_dataset.get_merged_categories(clss)
+             addition = set(LABELS_TOADD.get(clss, {}).values())
              cats |= addition
-             groupings, _, lclass = get_ancestry_mapping(cats, parentdf)
+             groupings, _, leaf_labels = get_ancestry_mapping(cats, parentdf)
              for i, j in list(groupings.items()):
                  if len(j) == 0:
                      groupings.pop(i)
-             self.class_groupings[label] = groupings
-             if label in self.clss_to_pred:
-                 # if we have added new labels, we need to update the encoder with them too.
-                 mlength = len(self.mapped_dataset.encoders[label])
+             self.labels_groupings[clss] = groupings
+             if clss in self.clss_to_pred:
+                 # if we have added new classes, we need to update the encoder with them too.
+                 mlength = len(self.mapped_dataset.encoders[clss])
+
                  mlength -= (
                      1
-                     if self.mapped_dataset.unknown_class
-                     in self.mapped_dataset.encoders[label].keys()
+                     if self.mapped_dataset.unknown_label
+                     in self.mapped_dataset.encoders[clss].keys()
                      else 0
                  )

                  for i, v in enumerate(
-                     addition - set(self.mapped_dataset.encoders[label].keys())
+                     addition - set(self.mapped_dataset.encoders[clss].keys())
                  ):
-                     self.mapped_dataset.encoders[label].update({v: mlength + i})
+                     self.mapped_dataset.encoders[clss].update({v: mlength + i})
                  # we need to change the ordering so that the things that can't be predicted appear afterward

-                 self.class_topred[label] = lclass
+                 self.class_topred[clss] = leaf_labels
                  c = 0
-                 d = 0
                  update = {}
-                 mlength = len(lclass)
-                 # import pdb
-
-                 # pdb.set_trace()
+                 mlength = len(leaf_labels)
                  mlength -= (
                      1
-                     if self.mapped_dataset.unknown_class
-                     in self.mapped_dataset.encoders[label].keys()
+                     if self.mapped_dataset.unknown_label
+                     in self.mapped_dataset.encoders[clss].keys()
                      else 0
                  )
-                 for k, v in self.mapped_dataset.encoders[label].items():
-                     if k in self.class_groupings[label].keys():
+                 for k, v in self.mapped_dataset.encoders[clss].items():
+                     if k in self.labels_groupings[clss].keys():
                          update.update({k: mlength + c})
                          c += 1
-                     elif k == self.mapped_dataset.unknown_class:
+                     elif k == self.mapped_dataset.unknown_label:
                          update.update({k: v})
-                         d += 1
-                         self.class_topred[label] -= set([k])
+                         self.class_topred[clss] -= set([k])
                      else:
-                         update.update({k: (v - c) - d})
-                 self.mapped_dataset.encoders[label] = update
+                         update.update({k: v - c})
+                 self.mapped_dataset.encoders[clss] = update


- class SimpleAnnDataset:
-     def __init__(self, adata, obs_to_output=[], layer=None):
-         self.adata = adata
-         self.obs_to_output = obs_to_output
-         self.layer = layer
+ class SimpleAnnDataset(torchDataset):
+     def __init__(
+         self,
+         adata: AnnData,
+         obs_to_output: Optional[list[str]] = [],
+         layer: Optional[str] = None,
+     ):
+         """
+         SimpleAnnDataset is a simple dataloader for an AnnData dataset. This is meant to
+         interface nicely with the rest of scDataloader and with your model during inference.
+
+         Args:
+         ----
+             adata (anndata.AnnData): anndata object to use
+             obs_to_output (list[str]): list of observations to output from anndata.obs
+             layer (str): layer of the anndata to use
+         """
+         self.adataX = adata.layers[layer] if layer is not None else adata.X
+         self.adataX = self.adataX.toarray() if issparse(self.adataX) else self.adataX
+         self.obs_to_output = adata.obs[obs_to_output]

      def __len__(self):
-         return self.adata.shape[0]
+         return self.adataX.shape[0]
+
+     def __iter__(self):
+         for idx in range(self.adataX.shape[0]):
+             with warnings.catch_warnings():
+                 warnings.filterwarnings("ignore", category=DeprecationWarning)
+                 out = {"x": self.adataX[idx].reshape(-1)}
+                 out.update(
+                     {name: val for name, val in self.obs_to_output.iloc[idx].items()}
+                 )
+                 yield out

      def __getitem__(self, idx):
          with warnings.catch_warnings():
              warnings.filterwarnings("ignore", category=DeprecationWarning)
-             if self.layer is not None:
-                 out = {"x": self.adata.layers[self.layer][idx].toarray().reshape(-1)}
-             else:
-                 out = {"x": self.adata.X[idx].toarray().reshape(-1)}
-             for i in self.obs_to_output:
-                 out.update({i: self.adata.obs.iloc[idx][i]})
+             out = {"x": self.adataX[idx].reshape(-1)}
+             out.update(
+                 {name: val for name, val in self.obs_to_output.iloc[idx].items()}
+             )
          return out
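Finally, a short inference-side sketch of the rewritten `SimpleAnnDataset` (toy data; the obs column and values are illustrative). Because rows are densified up front and obs values are plain scalars, the output collates cleanly with a stock torch `DataLoader`:

```python
import anndata as ad
import numpy as np
from torch.utils.data import DataLoader

from scdataloader.data import SimpleAnnDataset

# Toy AnnData; a real use case would pass your own object and layer name.
adata = ad.AnnData(X=np.random.poisson(1.0, size=(8, 4)).astype(np.float32))
adata.obs["cell_type_ontology_term_id"] = ["CL:0000540"] * 8

ds = SimpleAnnDataset(adata, obs_to_output=["cell_type_ontology_term_id"])
loader = DataLoader(ds, batch_size=4)
for batch in loader:
    print(batch["x"].shape)  # torch.Size([4, 4])
```

Note the design trade-off in `__init__`: densifying the matrix once trades memory for per-item speed, whereas the 0.0.3 version densified each row on access; for very large matrices, keeping the layer sparse and densifying per row may still be preferable.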