scdataloader 0.0.2__py3-none-any.whl → 0.0.4__py3-none-any.whl

This diff shows the content of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
scdataloader/data.py CHANGED
@@ -1,84 +1,53 @@
 from dataclasses import dataclass, field

 import lamindb as ln
-import lnschema_bionty as lb
+import bionty as bt
 import pandas as pd
 from torch.utils.data import Dataset as torchDataset
-
+from typing import Union, Optional, Literal
 from scdataloader import mapped
+import warnings
+
+from anndata import AnnData
+
+from scdataloader.utils import get_ancestry_mapping, load_genes

-# TODO: manage load gene embeddings to make
-# from scprint.dataloader.embedder import embed
-from scdataloader.utils import get_ancestry_mapping, pd_load_cached
-
-LABELS_TOADD = {
-    "assay_ontology_term_id": [
-        "10x transcription profiling",
-        "spatial transcriptomics",
-        "10x 3' transcription profiling",
-        "10x 5' transcription profiling",
-    ],
-    "disease_ontology_term_id": [
-        "metabolic disease",
-        "chronic kidney disease",
-        "chromosomal disorder",
-        "infectious disease",
-        "inflammatory disease",
-        # "immune system disease",
-        "disorder of development or morphogenesis",
-        "mitochondrial disease",
-        "psychiatric disorder",
-        "cancer or benign tumor",
-        "neoplasm",
-    ],
-    "cell_type_ontology_term_id": [
-        "progenitor cell",
-        "hematopoietic cell",
-        "myoblast",
-        "myeloid cell",
-        "neuron",
-        "electrically active cell",
-        "epithelial cell",
-        "secretory cell",
-        "stem cell",
-        "non-terminally differentiated cell",
-        "supporting cell",
-    ],
-}
+from .config import LABELS_TOADD


 @dataclass
 class Dataset(torchDataset):
     """
-    Dataset class to load a bunch of anndata from a lamin dataset in a memory efficient way.
+    Dataset class to load a bunch of anndata from a lamin dataset (Collection) in a memory efficient way.

-    For an example, see :meth:`~lamindb.Dataset.mapped`.
+    This serves as a wrapper around lamin's mappedCollection to provide more features,
+    mostly, the management of hierarchical labels, the encoding of labels, the management of multiple species
+
+    For an example of mappedDataset, see :meth:`~lamindb.Dataset.mapped`.

     .. note::

-        A similar data loader exists `here
+        A related data loader exists `here
         <https://github.com/Genentech/scimilarity>`__.

-    Attributes:
+    Args:
     ----
         lamin_dataset (lamindb.Dataset): lamin dataset to load
         genedf (pd.Dataframe): dataframe containing the genes to load
-        gene_embedding: dataframe containing the gene embeddings
         organisms (list[str]): list of organisms to load
-        obs (list[str]): list of observations to load
-        encode_obs (list[str]): list of observations to encode
-        map_hierarchy: list of observations to map to a hierarchy
+            (for now only validates the the genes map to this organism)
+        obs (list[str]): list of observations to load from the Collection
+        clss_to_pred (list[str]): list of observations to encode
+        join_vars (flag): join variables @see :meth:`~lamindb.Dataset.mapped`.
+        hierarchical_clss: list of observations to map to a hierarchy using lamin's bionty
     """

-    lamin_dataset: ln.Dataset
-    genedf: pd.DataFrame = None
-    gene_embedding: pd.DataFrame = (
-        None  # TODO: make it part of specialized dataset
-    )
-    organisms: list[str] = field(
+    lamin_dataset: ln.Collection
+    genedf: Optional[pd.DataFrame] = None
+    organisms: Optional[Union[list[str], str]] = field(
         default_factory=["NCBITaxon:9606", "NCBITaxon:10090"]
     )
-    obs: list[str] = field(
+    obs: Optional[list[str]] = field(
         default_factory=[
             "self_reported_ethnicity_ontology_term_id",
             "assay_ontology_term_id",
@@ -89,48 +58,75 @@ class Dataset(torchDataset):
             "sex_ontology_term_id",
             #'dataset_id',
             #'cell_culture',
-            "dpt_group",
-            "heat_diff",
-            "nnz",
+            #"dpt_group",
+            #"heat_diff",
+            #"nnz",
         ]
     )
-    encode_obs: list[str] = field(default_factory=list)
-    map_hierarchy: list[str] = field(default_factory=list)
+    # set of obs to prepare for prediction (encode)
+    clss_to_pred: Optional[list[str]] = field(default_factory=list)
+    # set of obs that need to be hierarchically prepared
+    hierarchical_clss: Optional[list[str]] = field(default_factory=list)
+    join_vars: Optional[Literal["auto", "inner", "None"]] = "None"

     def __post_init__(self):
         self.mapped_dataset = mapped.mapped(
             self.lamin_dataset,
             label_keys=self.obs,
-            encode_labels=self.encode_obs,
+            encode_labels=self.clss_to_pred,
             stream=True,
             parallel=True,
-            join_vars="None",
+            join_vars=self.join_vars,
         )
         print(
             "won't do any check but we recommend to have your dataset coming from local storage"
         )
+        self.class_groupings = {}
+        self.class_topred = {}
         # generate tree from ontologies
-        if len(self.map_hierarchy) > 0:
-            self.define_hierarchies(self.map_hierarchy)
+        if len(self.hierarchical_clss) > 0:
+            self.define_hierarchies(self.hierarchical_clss)
+        if len(self.clss_to_pred) > 0:
+            for clss in self.clss_to_pred:
+                if clss not in self.hierarchical_clss:
+                    # otherwise it's already been done
+                    self.class_topred[clss] = self.mapped_dataset.get_merged_categories(
+                        clss
+                    )
+                    update = {}
+                    c = 0
+                    for k, v in self.mapped_dataset.encoders[clss].items():
+                        if k == self.mapped_dataset.unknown_class:
+                            update.update({k: v})
+                            c += 1
+                            self.class_topred[clss] -= set([k])
+                        else:
+                            update.update({k: v - c})
+                    self.mapped_dataset.encoders[clss] = update

         if self.genedf is None:
-            self.genedf = self.load_genes(self.organisms)
-
-        if self.gene_embedding is None:
-            self.gene_embedding = self.load_embeddings(self.genedf)
-        else:
-            # self.genedf = pd.concat(
-            #     [self.genedf.set_index("ensembl_gene_id"), self.gene_embedding],
-            #     axis=1,
-            #     join="inner",
-            # )
-            self.genedf.columns = self.genedf.columns.astype(str)
+            self.genedf = load_genes(self.organisms)
+
+        self.genedf.columns = self.genedf.columns.astype(str)
+        for organism in self.organisms:
+            ogenedf = self.genedf[self.genedf.organism == organism]
+            self.mapped_dataset._check_aligned_vars(ogenedf.index.tolist())

     def __len__(self, **kwargs):
         return self.mapped_dataset.__len__(**kwargs)

+    @property
+    def encoder(self):
+        return self.mapped_dataset.encoders
+
     def __getitem__(self, *args, **kwargs):
         item = self.mapped_dataset.__getitem__(*args, **kwargs)
+        # import pdb
+
+        # pdb.set_trace()
+        # item.update(
+        #     {"unseen_genes": self.get_unseen_mapped_dataset_elements(*args, **kwargs)}
+        # )
         # ret = {}
         # ret["count"] = item[0]
         # for i, val in enumerate(self.obs):
@@ -141,149 +137,36 @@ class Dataset(torchDataset):
         return item

     def __repr__(self):
-        print(
-            "total dataset size is {} Gb".format(
-                sum([file.size for file in self.lamin_dataset.artifacts.all()])
-                / 1e9
+        return (
+            "total dataset size is {} Gb\n".format(
+                sum([file.size for file in self.lamin_dataset.artifacts.all()]) / 1e9
             )
-        )
-        print("---")
-        print("dataset contains:")
-        print(" {} cells".format(self.mapped_dataset.__len__()))
-        print(" {} genes".format(self.genedf.shape[0]))
-        print(" {} labels".format(len(self.obs)))
-        print(" {} organisms".format(len(self.organisms)))
-        print(
-            "dataset contains {} classes to predict".format(
-                sum([len(self.class_topred[i]) for i in self.class_topred])
+            + "---\n"
+            + "dataset contains:\n"
+            + " {} cells\n".format(self.mapped_dataset.__len__())
+            + " {} genes\n".format(self.genedf.shape[0])
+            + " {} labels\n".format(len(self.obs))
+            + " {} clss_to_pred\n".format(len(self.clss_to_pred))
+            + " {} hierarchical_clss\n".format(len(self.hierarchical_clss))
+            + " {} join_vars\n".format(len(self.join_vars))
+            + " {} organisms\n".format(len(self.organisms))
+            + (
+                "dataset contains {} classes to predict\n".format(
+                    sum([len(self.class_topred[i]) for i in self.class_topred])
+                )
+                if len(self.class_topred) > 0
+                else ""
             )
         )
-        print("embedding size is {}".format(self.gene_embedding.shape[1]))
-        return ""

     def get_label_weights(self, *args, **kwargs):
         return self.mapped_dataset.get_label_weights(*args, **kwargs)

-    def get_unseen_mapped_dataset_elements(self, idx):
-        return [
-            str(i)[2:-1] for i in self.mapped_dataset.uns(idx, "unseen_genes")
-        ]
-
-    def use_prior_network(
-        self, name="collectri", organism="human", split_complexes=True
-    ):
-        """
-        use_prior_network loads a prior GRN from a list of available networks.
-
-        Args:
-            name (str, optional): name of the network to load. Defaults to "collectri".
-            organism (str, optional): organism to load the network for. Defaults to "human".
-            split_complexes (bool, optional): whether to split complexes into individual genes. Defaults to True.
-
-        Raises:
-            ValueError: if the provided name is not amongst the available names.
-        """
-        # TODO: use omnipath instead
-        if name == "tflink":
-            TFLINK = "https://cdn.netbiol.org/tflink/download_files/TFLink_Homo_sapiens_interactions_All_simpleFormat_v1.0.tsv.gz"
-            net = pd_load_cached(TFLINK)
-            net = net.rename(
-                columns={"Name.TF": "regulator", "Name.Target": "target"}
-            )
-        elif name == "htftarget":
-            HTFTARGET = "http://bioinfo.life.hust.edu.cn/static/hTFtarget/file_download/tf-target-infomation.txt"
-            net = pd_load_cached(HTFTARGET)
-            net = net.rename(columns={"TF": "regulator"})
-        elif name == "collectri":
-            import decoupler as dc
-
-            net = dc.get_collectri(
-                organism=organism, split_complexes=split_complexes
-            )
-            net = net.rename(columns={"source": "regulator"})
-        else:
-            raise ValueError(
-                f"provided name: '{name}' is not amongst the available names."
-            )
-        self.add_prior_network(net)
-
-    def add_prior_network(self, prior_network: pd.DataFrame, init_len):
-        # validate the network dataframe
-        required_columns: list[str] = ["target", "regulators"]
-        optional_columns: list[str] = ["type", "weight"]
-
-        for column in required_columns:
-            assert (
-                column in prior_network.columns
-            ), f"Column '{column}' is missing in the provided network dataframe."
-
-        for column in optional_columns:
-            if column not in prior_network.columns:
-                print(
-                    f"Optional column '{column}' is not present in the provided network dataframe."
-                )
-
-        assert (
-            prior_network["target"].dtype == "str"
-        ), "Column 'target' should be of dtype 'str'."
-        assert (
-            prior_network["regulators"].dtype == "str"
-        ), "Column 'regulators' should be of dtype 'str'."
+    def get_unseen_mapped_dataset_elements(self, idx: int):
+        return [str(i)[2:-1] for i in self.mapped_dataset.uns(idx, "unseen_genes")]

-        if "type" in prior_network.columns:
-            assert (
-                prior_network["type"].dtype == "str"
-            ), "Column 'type' should be of dtype 'str'."
-
-        if "weight" in prior_network.columns:
-            assert (
-                prior_network["weight"].dtype == "float"
-            ), "Column 'weight' should be of dtype 'float'."
-
-        # TODO: check that we match the genes in the network to the genes in the dataset
-
-        print(
-            "loaded {:.2f}% of the edges".format(
-                (len(prior_network) / init_len) * 100
-            )
-        )
-        # TODO: transform it into a sparse matrix
-        self.prior_network = prior_network
-        self.network_size = len(prior_network)
-        # self.overlap =
-        # self.edge_freq
-
-    def load_genes(self, organisms):
-        organismdf = []
-        for o in organisms:
-            organism = lb.Gene(
-                organism=lb.Organism.filter(ontology_id=o).one()
-            ).df()
-            organism["organism"] = o
-            organismdf.append(organism)
-
-        return pd.concat(organismdf)
-
-    # def load_embeddings(self, genedfs, embedding_size=128, cache=True):
-    #     embeddings = []
-    #     for o in self.organisms:
-    #         genedf = genedfs[genedfs.organism == o]
-    #         org_name = lb.Organism.filter(ontology_id=o).one().scientific_name
-    #         embedding = embed(
-    #             genedf=genedf,
-    #             organism=org_name,
-    #             cache=cache,
-    #             fasta_path="/tmp/data/fasta/",
-    #             embedding_size=embedding_size,
-    #         )
-    #         genedf = pd.concat(
-    #             [genedf.set_index("ensembl_gene_id"), embedding], axis=1, join="inner"
-    #         )
-    #         genedf.columns = genedf.columns.astype(str)
-    #         embeddings.append(genedf)
-    #     return pd.concat(embeddings)
-
-    def define_hierarchies(self, labels):
+    def define_hierarchies(self, labels: list[str]):
+        # TODO: use all possible hierarchies instead of just the ones for which we have a sample annotated with
         self.class_groupings = {}
         self.class_topred = {}
         for label in labels:
@@ -302,37 +185,37 @@ class Dataset(torchDataset):
                 )
             elif label == "cell_type_ontology_term_id":
                 parentdf = (
-                    lb.CellType.filter()
+                    bt.CellType.filter()
                     .df(include=["parents__ontology_id"])
                     .set_index("ontology_id")
                 )
             elif label == "tissue_ontology_term_id":
                 parentdf = (
-                    lb.Tissue.filter()
+                    bt.Tissue.filter()
                     .df(include=["parents__ontology_id"])
                     .set_index("ontology_id")
                 )
             elif label == "disease_ontology_term_id":
                 parentdf = (
-                    lb.Disease.filter()
+                    bt.Disease.filter()
                     .df(include=["parents__ontology_id"])
                     .set_index("ontology_id")
                 )
             elif label == "development_stage_ontology_term_id":
                 parentdf = (
-                    lb.DevelopmentalStage.filter()
+                    bt.DevelopmentalStage.filter()
                     .df(include=["parents__ontology_id"])
                     .set_index("ontology_id")
                 )
             elif label == "assay_ontology_term_id":
                 parentdf = (
-                    lb.ExperimentalFactor.filter()
+                    bt.ExperimentalFactor.filter()
                     .df(include=["parents__ontology_id"])
                     .set_index("ontology_id")
                 )
             elif label == "self_reported_ethnicity_ontology_term_id":
                 parentdf = (
-                    lb.Ethnicity.filter()
+                    bt.Ethnicity.filter()
                     .df(include=["parents__ontology_id"])
                     .set_index("ontology_id")
                 )
@@ -344,10 +227,90 @@ class Dataset(torchDataset):
                     )
                 )
             cats = self.mapped_dataset.get_merged_categories(label)
-            cats |= set(LABELS_TOADD.get(label, []))
+            addition = set(LABELS_TOADD.get(label, {}).values())
+            cats |= addition
+            # import pdb
+
+            # pdb.set_trace()
             groupings, _, lclass = get_ancestry_mapping(cats, parentdf)
             for i, j in groupings.items():
                 if len(j) == 0:
                     groupings.pop(i)
             self.class_groupings[label] = groupings
-            self.class_topred[label] = lclass
+            if label in self.clss_to_pred:
+                # if we have added new labels, we need to update the encoder with them too.
+                mlength = len(self.mapped_dataset.encoders[label])
+                mlength -= (
+                    1
+                    if self.mapped_dataset.unknown_class
+                    in self.mapped_dataset.encoders[label].keys()
+                    else 0
+                )
+
+                for i, v in enumerate(
+                    addition - set(self.mapped_dataset.encoders[label].keys())
+                ):
+                    self.mapped_dataset.encoders[label].update({v: mlength + i})
+                # we need to change the ordering so that the things that can't be predicted appear afterward
+
+                self.class_topred[label] = lclass
+                c = 0
+                d = 0
+                update = {}
+                mlength = len(lclass)
+                # import pdb
+
+                # pdb.set_trace()
+                mlength -= (
+                    1
+                    if self.mapped_dataset.unknown_class
+                    in self.mapped_dataset.encoders[label].keys()
+                    else 0
+                )
+                for k, v in self.mapped_dataset.encoders[label].items():
+                    if k in self.class_groupings[label].keys():
+                        update.update({k: mlength + c})
+                        c += 1
+                    elif k == self.mapped_dataset.unknown_class:
+                        update.update({k: v})
+                        d += 1
+                        self.class_topred[label] -= set([k])
+                    else:
+                        update.update({k: (v - c) - d})
+                self.mapped_dataset.encoders[label] = update
+
+
+class SimpleAnnDataset:
+    def __init__(
+        self,
+        adata: AnnData,
+        obs_to_output: Optional[list[str]] = [],
+        layer: Optional[str] = None,
+    ):
+        """
+        SimpleAnnDataset is a simple dataloader for an AnnData dataset. this is to interface nicely with the rest of
+        scDataloader and with your model during inference.
+
+        Args:
+        ----
+            adata (anndata.AnnData): anndata object to use
+            obs_to_output (list[str]): list of observations to output from anndata.obs
+            layer (str): layer of the anndata to use
+        """
+        self.adata = adata
+        self.obs_to_output = obs_to_output
+        self.layer = layer
+
+    def __len__(self):
+        return self.adata.shape[0]
+
+    def __getitem__(self, idx):
+        with warnings.catch_warnings():
+            warnings.filterwarnings("ignore", category=DeprecationWarning)
+            if self.layer is not None:
+                out = {"x": self.adata.layers[self.layer][idx].toarray().reshape(-1)}
+            else:
+                out = {"x": self.adata.X[idx].toarray().reshape(-1)}
+            for i in self.obs_to_output:
+                out.update({i: self.adata.obs.iloc[idx][i]})
+        return out
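
As a quick orientation to the 0.0.4 API shown above, here is a minimal usage sketch. It is not part of the released package: the Collection name is a hypothetical placeholder, it assumes an initialized lamindb instance, and it only exercises names introduced in this diff (clss_to_pred and hierarchical_clss replace encode_obs and map_hierarchy, the encoder property and SimpleAnnDataset are new).

    import lamindb as ln
    import numpy as np
    from anndata import AnnData
    from scipy.sparse import csr_matrix
    from torch.utils.data import DataLoader

    from scdataloader.data import Dataset, SimpleAnnDataset

    # Wrap a lamin Collection; "my-collection" is a hypothetical placeholder
    # and requires a populated lamindb instance.
    collection = ln.Collection.filter(name="my-collection").one()
    dataset = Dataset(
        lamin_dataset=collection,
        obs=["cell_type_ontology_term_id", "assay_ontology_term_id"],
        clss_to_pred=["cell_type_ontology_term_id"],       # was encode_obs in 0.0.2
        hierarchical_clss=["cell_type_ontology_term_id"],  # was map_hierarchy in 0.0.2
    )
    print(len(dataset))
    print(dataset.encoder)  # per-label {category: code} maps, via the new property

    # For inference on an in-memory AnnData, SimpleAnnDataset yields one dict
    # per cell with a dense expression vector under "x".
    adata = AnnData(X=csr_matrix(np.random.poisson(1.0, size=(8, 4)).astype(np.float32)))
    simple = SimpleAnnDataset(adata, obs_to_output=[])
    loader = DataLoader(simple, batch_size=4)
    batch = next(iter(loader))  # batch["x"] has shape (4, 4)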