scdataloader 0.0.2__py3-none-any.whl → 0.0.3__py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
scdataloader/data.py CHANGED
@@ -4,46 +4,47 @@ import lamindb as ln
  import lnschema_bionty as lb
  import pandas as pd
  from torch.utils.data import Dataset as torchDataset
-
+ from typing import Union
  from scdataloader import mapped
+ import warnings

  # TODO: manage load gene embeddings to make
  # from scprint.dataloader.embedder import embed
- from scdataloader.utils import get_ancestry_mapping, pd_load_cached
+ from scdataloader.utils import get_ancestry_mapping, load_genes

  LABELS_TOADD = {
-     "assay_ontology_term_id": [
-         "10x transcription profiling",
-         "spatial transcriptomics",
-         "10x 3' transcription profiling",
-         "10x 5' transcription profiling",
-     ],
-     "disease_ontology_term_id": [
-         "metabolic disease",
-         "chronic kidney disease",
-         "chromosomal disorder",
-         "infectious disease",
-         "inflammatory disease",
+     "assay_ontology_term_id": {
+         "10x transcription profiling": "EFO:0030003",
+         "spatial transcriptomics": "EFO:0008994",
+         "10x 3' transcription profiling": "EFO:0030003",
+         "10x 5' transcription profiling": "EFO:0030004",
+     },
+     "disease_ontology_term_id": {
+         "metabolic disease": "MONDO:0005066",
+         "chronic kidney disease": "MONDO:0005300",
+         "chromosomal disorder": "MONDO:0019040",
+         "infectious disease": "MONDO:0005550",
+         "inflammatory disease": "MONDO:0021166",
          # "immune system disease",
-         "disorder of development or morphogenesis",
-         "mitochondrial disease",
-         "psychiatric disorder",
-         "cancer or benign tumor",
-         "neoplasm",
-     ],
-     "cell_type_ontology_term_id": [
-         "progenitor cell",
-         "hematopoietic cell",
-         "myoblast",
-         "myeloid cell",
-         "neuron",
-         "electrically active cell",
-         "epithelial cell",
-         "secretory cell",
-         "stem cell",
-         "non-terminally differentiated cell",
-         "supporting cell",
-     ],
+         "disorder of development or morphogenesis": "MONDO:0021147",
+         "mitochondrial disease": "MONDO:0044970",
+         "psychiatric disorder": "MONDO:0002025",
+         "cancer or benign tumor": "MONDO:0002025",
+         "neoplasm": "MONDO:0005070",
+     },
+     "cell_type_ontology_term_id": {
+         "progenitor cell": "CL:0011026",
+         "hematopoietic cell": "CL:0000988",
+         "myoblast": "CL:0000056",
+         "myeloid cell": "CL:0000763",
+         "neuron": "CL:0000540",
+         "electrically active cell": "CL:0000211",
+         "epithelial cell": "CL:0000066",
+         "secretory cell": "CL:0000151",
+         "stem cell": "CL:0000034",
+         "non-terminally differentiated cell": "CL:0000055",
+         "supporting cell": "CL:0000630",
+     },
  }

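The notable change in this hunk is that `LABELS_TOADD` moves from bare label names to name-to-ontology-ID mappings (EFO for assays, MONDO for diseases, CL for cell types), so downstream code can consume the IDs directly instead of re-resolving names. A minimal sketch of the new access pattern (`define_hierarchies`, further down, consumes exactly these dict values):

```python
from scdataloader.data import LABELS_TOADD

cell_type_ids = set(LABELS_TOADD["cell_type_ontology_term_id"].values())
assert "CL:0000540" in cell_type_ids  # the ID recorded for "neuron"
```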
@@ -66,16 +67,14 @@ class Dataset(torchDataset):
          gene_embedding: dataframe containing the gene embeddings
          organisms (list[str]): list of organisms to load
          obs (list[str]): list of observations to load
-         encode_obs (list[str]): list of observations to encode
-         map_hierarchy: list of observations to map to a hierarchy
+         clss_to_pred (list[str]): list of observations to encode
+         hierarchical_clss: list of observations to map to a hierarchy
      """

-     lamin_dataset: ln.Dataset
+     lamin_dataset: ln.Collection
      genedf: pd.DataFrame = None
-     gene_embedding: pd.DataFrame = (
-         None  # TODO: make it part of specialized dataset
-     )
-     organisms: list[str] = field(
+     # gene_embedding: pd.DataFrame = None  # TODO: make it part of specialized dataset
+     organisms: Union[list[str], str] = field(
          default_factory=["NCBITaxon:9606", "NCBITaxon:10090"]
      )
      obs: list[str] = field(
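For users upgrading from 0.0.2: `encode_obs` and `map_hierarchy` are renamed, and `lamin_dataset` is now typed as a lamindb `Collection` (following lamindb's own `Dataset` to `Collection` rename). A hedged construction sketch; the collection name is hypothetical, and a configured lamindb/bionty instance whose AnnData objects carry the default obs columns is assumed:

```python
import lamindb as ln
from scdataloader.data import Dataset

# Hypothetical collection name; adjust to whatever exists in your instance.
collection = ln.Collection.filter(name="my-preprocessed-set").one()
dataset = Dataset(
    lamin_dataset=collection,
    clss_to_pred=["cell_type_ontology_term_id"],       # was encode_obs in 0.0.2
    hierarchical_clss=["cell_type_ontology_term_id"],  # was map_hierarchy in 0.0.2
)
```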
@@ -94,43 +93,63 @@ class Dataset(torchDataset):
              "nnz",
          ]
      )
-     encode_obs: list[str] = field(default_factory=list)
-     map_hierarchy: list[str] = field(default_factory=list)
+     # set of obs to prepare for prediction (encode)
+     clss_to_pred: list[str] = field(default_factory=list)
+     # set of obs that need to be hierarchically prepared
+     hierarchical_clss: list[str] = field(default_factory=list)
+     join_vars: str = "None"

      def __post_init__(self):
          self.mapped_dataset = mapped.mapped(
              self.lamin_dataset,
              label_keys=self.obs,
-             encode_labels=self.encode_obs,
+             encode_labels=self.clss_to_pred,
              stream=True,
              parallel=True,
-             join_vars="None",
+             join_vars=self.join_vars,
          )
          print(
              "won't do any check but we recommend to have your dataset coming from local storage"
          )
          # generate tree from ontologies
-         if len(self.map_hierarchy) > 0:
-             self.define_hierarchies(self.map_hierarchy)
+         if len(self.hierarchical_clss) > 0:
+             self.define_hierarchies(self.hierarchical_clss)
+         if len(self.clss_to_pred) > 0:
+             for clss in self.clss_to_pred:
+                 if clss not in self.hierarchical_clss:
+                     # otherwise it's already been done
+                     self.class_topred[clss] = self.mapped_dataset.get_merged_categories(
+                         clss
+                     )
+                     update = {}
+                     c = 0
+                     for k, v in self.mapped_dataset.encoders[clss].items():
+                         if k == self.mapped_dataset.unknown_class:
+                             update.update({k: v})
+                             c += 1
+                             self.class_topred[clss] -= set([k])
+                         else:
+                             update.update({k: v - c})
+                     self.mapped_dataset.encoders[clss] = update

          if self.genedf is None:
-             self.genedf = self.load_genes(self.organisms)
+             self.genedf = load_genes(self.organisms)

-         if self.gene_embedding is None:
-             self.gene_embedding = self.load_embeddings(self.genedf)
-         else:
-             # self.genedf = pd.concat(
-             #     [self.genedf.set_index("ensembl_gene_id"), self.gene_embedding],
-             #     axis=1,
-             #     join="inner",
-             # )
-             self.genedf.columns = self.genedf.columns.astype(str)
+         self.genedf.columns = self.genedf.columns.astype(str)
+         for organism in self.organisms:
+             ogenedf = self.genedf[self.genedf.organism == organism]
+             self.mapped_dataset._check_aligned_vars(ogenedf.index.tolist())

      def __len__(self, **kwargs):
          return self.mapped_dataset.__len__(**kwargs)

+     @property
+     def encoder(self):
+         return self.mapped_dataset.encoders
+
      def __getitem__(self, *args, **kwargs):
          item = self.mapped_dataset.__getitem__(*args, **kwargs)
+         # item.update({"unseen_genes": self.get_unseen_mapped_dataset_elements(*args, **kwargs)})
          # ret = {}
          # ret["count"] = item[0]
          # for i, val in enumerate(self.obs):
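Beyond the renames, this hunk makes `__post_init__` strip the unknown class out of `class_topred` and compact the integer codes, and adds an `encoder` property surfacing the (re-indexed) label encoders. That makes decoding model outputs straightforward; a small sketch, assuming `dataset` from the earlier example:

```python
# Invert the label -> integer-code mapping exposed by the new property.
enc = dataset.encoder["cell_type_ontology_term_id"]
decode = {code: name for name, code in enc.items()}
print(decode[0])  # the label encoded as class 0
```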
@@ -143,8 +162,7 @@ class Dataset(torchDataset):
      def __repr__(self):
          print(
              "total dataset size is {} Gb".format(
-                 sum([file.size for file in self.lamin_dataset.artifacts.all()])
-                 / 1e9
+                 sum([file.size for file in self.lamin_dataset.artifacts.all()]) / 1e9
              )
          )
          print("---")
@@ -158,111 +176,14 @@ class Dataset(torchDataset):
                  sum([len(self.class_topred[i]) for i in self.class_topred])
              )
          )
-         print("embedding size is {}".format(self.gene_embedding.shape[1]))
+         # print("embedding size is {}".format(self.gene_embedding.shape[1]))
          return ""

      def get_label_weights(self, *args, **kwargs):
          return self.mapped_dataset.get_label_weights(*args, **kwargs)

      def get_unseen_mapped_dataset_elements(self, idx):
-         return [
-             str(i)[2:-1] for i in self.mapped_dataset.uns(idx, "unseen_genes")
-         ]
-
-     def use_prior_network(
-         self, name="collectri", organism="human", split_complexes=True
-     ):
-         """
-         use_prior_network loads a prior GRN from a list of available networks.
-
-         Args:
-             name (str, optional): name of the network to load. Defaults to "collectri".
-             organism (str, optional): organism to load the network for. Defaults to "human".
-             split_complexes (bool, optional): whether to split complexes into individual genes. Defaults to True.
-
-         Raises:
-             ValueError: if the provided name is not amongst the available names.
-         """
-         # TODO: use omnipath instead
-         if name == "tflink":
-             TFLINK = "https://cdn.netbiol.org/tflink/download_files/TFLink_Homo_sapiens_interactions_All_simpleFormat_v1.0.tsv.gz"
-             net = pd_load_cached(TFLINK)
-             net = net.rename(
-                 columns={"Name.TF": "regulator", "Name.Target": "target"}
-             )
-         elif name == "htftarget":
-             HTFTARGET = "http://bioinfo.life.hust.edu.cn/static/hTFtarget/file_download/tf-target-infomation.txt"
-             net = pd_load_cached(HTFTARGET)
-             net = net.rename(columns={"TF": "regulator"})
-         elif name == "collectri":
-             import decoupler as dc
-
-             net = dc.get_collectri(
-                 organism=organism, split_complexes=split_complexes
-             )
-             net = net.rename(columns={"source": "regulator"})
-         else:
-             raise ValueError(
-                 f"provided name: '{name}' is not amongst the available names."
-             )
-         self.add_prior_network(net)
-
-     def add_prior_network(self, prior_network: pd.DataFrame, init_len):
-         # validate the network dataframe
-         required_columns: list[str] = ["target", "regulators"]
-         optional_columns: list[str] = ["type", "weight"]
-
-         for column in required_columns:
-             assert (
-                 column in prior_network.columns
-             ), f"Column '{column}' is missing in the provided network dataframe."
-
-         for column in optional_columns:
-             if column not in prior_network.columns:
-                 print(
-                     f"Optional column '{column}' is not present in the provided network dataframe."
-                 )
-
-         assert (
-             prior_network["target"].dtype == "str"
-         ), "Column 'target' should be of dtype 'str'."
-         assert (
-             prior_network["regulators"].dtype == "str"
-         ), "Column 'regulators' should be of dtype 'str'."
-
-         if "type" in prior_network.columns:
-             assert (
-                 prior_network["type"].dtype == "str"
-             ), "Column 'type' should be of dtype 'str'."
-
-         if "weight" in prior_network.columns:
-             assert (
-                 prior_network["weight"].dtype == "float"
-             ), "Column 'weight' should be of dtype 'float'."
-
-         # TODO: check that we match the genes in the network to the genes in the dataset
-
-         print(
-             "loaded {:.2f}% of the edges".format(
-                 (len(prior_network) / init_len) * 100
-             )
-         )
-         # TODO: transform it into a sparse matrix
-         self.prior_network = prior_network
-         self.network_size = len(prior_network)
-         # self.overlap =
-         # self.edge_freq
-
-     def load_genes(self, organisms):
-         organismdf = []
-         for o in organisms:
-             organism = lb.Gene(
-                 organism=lb.Organism.filter(ontology_id=o).one()
-             ).df()
-             organism["organism"] = o
-             organismdf.append(organism)
-
-         return pd.concat(organismdf)
+         return [str(i)[2:-1] for i in self.mapped_dataset.uns(idx, "unseen_genes")]

      # def load_embeddings(self, genedfs, embedding_size=128, cache=True):
      #     embeddings = []
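This hunk removes the GRN helpers (`use_prior_network`, `add_prior_network`) outright, and lifts `load_genes` out of the class into `scdataloader.utils` (matching the import change in the first hunk). Judging by the removed method body and the new `self.genedf.organism` check in `__post_init__`, the relocated helper returns a concatenated per-organism gene dataframe; a sketch of the relocated call:

```python
from scdataloader.utils import load_genes

# Same NCBITaxon defaults as the Dataset field; requires a bionty-enabled instance.
genedf = load_genes(["NCBITaxon:9606", "NCBITaxon:10090"])
print(genedf["organism"].value_counts())
```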
@@ -344,10 +265,72 @@ class Dataset(torchDataset):
                  )
              )
              cats = self.mapped_dataset.get_merged_categories(label)
-             cats |= set(LABELS_TOADD.get(label, []))
+             addition = set(LABELS_TOADD.get(label, {}).values())
+             cats |= addition
              groupings, _, lclass = get_ancestry_mapping(cats, parentdf)
              for i, j in groupings.items():
                  if len(j) == 0:
                      groupings.pop(i)
              self.class_groupings[label] = groupings
-             self.class_topred[label] = lclass
+             if label in self.clss_to_pred:
+                 # if we have added new labels, we need to update the encoder with them too.
+                 mlength = len(self.mapped_dataset.encoders[label])
+                 mlength -= (
+                     1
+                     if self.mapped_dataset.unknown_class
+                     in self.mapped_dataset.encoders[label].keys()
+                     else 0
+                 )
+
+                 for i, v in enumerate(
+                     addition - set(self.mapped_dataset.encoders[label].keys())
+                 ):
+                     self.mapped_dataset.encoders[label].update({v: mlength + i})
+                 # we need to change the ordering so that the things that can't be predicted appear afterward
+
+                 self.class_topred[label] = lclass
+                 c = 0
+                 d = 0
+                 update = {}
+                 mlength = len(lclass)
+                 # import pdb
+
+                 # pdb.set_trace()
+                 mlength -= (
+                     1
+                     if self.mapped_dataset.unknown_class
+                     in self.mapped_dataset.encoders[label].keys()
+                     else 0
+                 )
+                 for k, v in self.mapped_dataset.encoders[label].items():
+                     if k in self.class_groupings[label].keys():
+                         update.update({k: mlength + c})
+                         c += 1
+                     elif k == self.mapped_dataset.unknown_class:
+                         update.update({k: v})
+                         d += 1
+                         self.class_topred[label] -= set([k])
+                     else:
+                         update.update({k: (v - c) - d})
+                 self.mapped_dataset.encoders[label] = update
+
+
+ class SimpleAnnDataset:
+     def __init__(self, adata, obs_to_output=[], layer=None):
+         self.adata = adata
+         self.obs_to_output = obs_to_output
+         self.layer = layer
+
+     def __len__(self):
+         return self.adata.shape[0]
+
+     def __getitem__(self, idx):
+         with warnings.catch_warnings():
+             warnings.filterwarnings("ignore", category=DeprecationWarning)
+             if self.layer is not None:
+                 out = {"x": self.adata.layers[self.layer][idx].toarray().reshape(-1)}
+             else:
+                 out = {"x": self.adata.X[idx].toarray().reshape(-1)}
+             for i in self.obs_to_output:
+                 out.update({i: self.adata.obs.iloc[idx][i]})
+             return out
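The new `SimpleAnnDataset` is a thin, map-style wrapper that serves one dense expression row per cell, which makes it directly usable for inference with a plain torch `DataLoader`. A usage sketch (hypothetical file path; assumes the matrix is stored sparse, since `__getitem__` calls `.toarray()`):

```python
import anndata as ad
from torch.utils.data import DataLoader

from scdataloader.data import SimpleAnnDataset

adata = ad.read_h5ad("cells.h5ad")  # hypothetical path
ds = SimpleAnnDataset(adata, obs_to_output=["cell_type_ontology_term_id"])
loader = DataLoader(ds, batch_size=64)

for batch in loader:
    x = batch["x"]  # float tensor of shape (batch_size, n_genes)
    labels = batch["cell_type_ontology_term_id"]  # list of strings
    break
```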