scdataloader 1.9.2__py3-none-any.whl → 2.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
scdataloader/__main__.py CHANGED
@@ -1,5 +1,5 @@
1
1
  import argparse
2
- from typing import Optional, Union
2
+ from typing import List, Optional, Union
3
3
 
4
4
  import lamindb as ln
5
5
 
@@ -149,7 +149,7 @@ def main():
149
149
  )
150
150
  preprocess_parser.add_argument(
151
151
  "--batch_keys",
152
- type=list[str],
152
+ type=List[str],
153
153
  default=[
154
154
  "assay_ontology_term_id",
155
155
  "self_reported_ethnicity_ontology_term_id",
@@ -229,11 +229,11 @@ def main():
229
229
  if args.instance is not None:
230
230
  collection = (
231
231
  ln.Collection.using(instance=args.instance)
232
- .filter(name=args.name, version=args.version)
232
+ .filter(key=args.name, version=args.version)
233
233
  .first()
234
234
  )
235
235
  else:
236
- collection = ln.Collection.filter(name=args.name, version=args.version).first()
236
+ collection = ln.Collection.filter(key=args.name, version=args.version).first()
237
237
 
238
238
  print(
239
239
  "using the dataset ", collection, " of size ", len(collection.artifacts.all())
@@ -262,7 +262,6 @@ def main():
262
262
  additional_preprocess=additional_preprocess,
263
263
  additional_postprocess=additional_postprocess,
264
264
  keep_files=False,
265
- force_preloaded=args.force_preloaded,
266
265
  )
267
266
 
268
267
  # Preprocess the dataset
scdataloader/collator.py CHANGED
@@ -1,18 +1,20 @@
1
- from typing import Optional
1
+ from typing import List, Optional
2
2
 
3
3
  import numpy as np
4
+ import pandas as pd
4
5
  from torch import Tensor, long
5
6
 
6
- from .utils import downsample_profile, load_genes
7
+ from .preprocess import _digitize
8
+ from .utils import load_genes
7
9
 
8
10
 
9
11
  class Collator:
10
12
  def __init__(
11
13
  self,
12
- organisms: list[str],
14
+ organisms: List[str],
13
15
  how: str = "all",
14
16
  org_to_id: dict[str, int] = None,
15
- valid_genes: list[str] = [],
17
+ valid_genes: Optional[List[str]] = None,
16
18
  max_len: int = 2000,
17
19
  add_zero_genes: int = 0,
18
20
  logp1: bool = False,
@@ -20,10 +22,9 @@ class Collator:
20
22
  n_bins: int = 0,
21
23
  tp_name: Optional[str] = None,
22
24
  organism_name: str = "organism_ontology_term_id",
23
- class_names: list[str] = [],
24
- genelist: list[str] = [],
25
- downsample: Optional[float] = None, # don't use it for training!
26
- save_output: Optional[str] = None,
25
+ class_names: List[str] = [],
26
+ genelist: List[str] = [],
27
+ genedf: Optional[pd.DataFrame] = None,
27
28
  ):
28
29
  """
29
30
  This class is responsible for collating data for the scPRINT model. It handles the
@@ -57,13 +58,8 @@ class Collator:
57
58
  class_names (list, optional): List of other classes to be considered. Defaults to [].
58
59
  genelist (list, optional): List of genes to be considered. Defaults to [].
59
60
  If [] all genes will be considered
60
- downsample (float, optional): Downsample the profile to a certain number of cells. Defaults to None.
61
- This is usually done by the scPRINT model during training but this option allows you to do it directly from the collator
62
- save_output (str, optional): If not None, saves the output to a file. Defaults to None.
63
- This is mainly for debugging purposes
64
61
  """
65
62
  self.organisms = organisms
66
- self.genedf = load_genes(organisms)
67
63
  self.max_len = max_len
68
64
  self.n_bins = n_bins
69
65
  self.add_zero_genes = add_zero_genes
@@ -75,32 +71,36 @@ class Collator:
75
71
  self.organism_name = organism_name
76
72
  self.tp_name = tp_name
77
73
  self.class_names = class_names
78
- self.save_output = save_output
79
74
  self.start_idx = {}
80
75
  self.accepted_genes = {}
81
- self.downsample = downsample
82
76
  self.to_subset = {}
83
- self._setup(org_to_id, valid_genes, genelist)
77
+ self._setup(genedf, org_to_id, valid_genes, genelist)
84
78
 
85
- def _setup(self, org_to_id=None, valid_genes=[], genelist=[]):
79
+ def _setup(self, genedf=None, org_to_id=None, valid_genes=[], genelist=[]):
80
+ if genedf is None:
81
+ genedf = load_genes(self.organisms)
82
+ self.organism_ids = (
83
+ set([org_to_id[k] for k in self.organisms])
84
+ if org_to_id is not None
85
+ else set(self.organisms)
86
+ )
86
87
  self.org_to_id = org_to_id
87
88
  self.to_subset = {}
88
89
  self.accepted_genes = {}
89
90
  self.start_idx = {}
90
- self.organism_ids = (
91
- set([org_to_id[k] for k in self.organisms])
92
- if org_to_id is not None
93
- else set(self.organisms)
94
- )
91
+
92
+ if valid_genes is not None:
93
+ if len(set(valid_genes) - set(genedf.index)) > 0:
94
+ print("Some valid genes are not in the genedf!!!")
95
+ tot = genedf[genedf.index.isin(valid_genes)]
96
+ else:
97
+ tot = genedf
95
98
  for organism in self.organisms:
96
- ogenedf = self.genedf[self.genedf.organism == organism]
97
- if len(valid_genes) > 0:
98
- tot = self.genedf[self.genedf.index.isin(valid_genes)]
99
- else:
100
- tot = self.genedf
101
99
  org = org_to_id[organism] if org_to_id is not None else organism
102
100
  self.start_idx.update({org: np.where(tot.organism == organism)[0][0]})
103
- if len(valid_genes) > 0:
101
+
102
+ ogenedf = genedf[genedf.organism == organism]
103
+ if valid_genes is not None:
104
104
  self.accepted_genes.update({org: ogenedf.index.isin(valid_genes)})
105
105
  if len(genelist) > 0:
106
106
  df = ogenedf[ogenedf.index.isin(valid_genes)]
@@ -111,7 +111,7 @@ class Collator:
111
111
  __call__ applies the collator to a minibatch of data
112
112
 
113
113
  Args:
114
- batch (list[dict[str: array]]): List of dicts of arrays containing gene expression data.
114
+ batch (List[dict[str: array]]): List of dicts of arrays containing gene expression data.
115
115
  the first list is for the different samples, the second list is for the different elements with
116
116
  elem["X"]: gene expression
117
117
  elem["organism_name"]: organism ontology term id
@@ -119,7 +119,7 @@ class Collator:
119
119
  elem["class_names.."]: other classes
120
120
 
121
121
  Returns:
122
- list[Tensor]: List of tensors containing the collated data.
122
+ List[Tensor]: List of tensors containing the collated data.
123
123
  """
124
124
  # do count selection
125
125
  # get the unseen info and don't add any unseen
@@ -133,6 +133,7 @@ class Collator:
133
133
  nnz_loc = []
134
134
  is_meta = []
135
135
  knn_cells = []
136
+ knn_cells_info = []
136
137
  for elem in batch:
137
138
  organism_id = elem[self.organism_name]
138
139
  if organism_id not in self.organism_ids:
@@ -188,7 +189,14 @@ class Collator:
188
189
  if "knn_cells" in elem:
189
190
  # we complete with genes expressed in the knn
190
191
  # which is not a zero_loc in this context
191
- zero_loc = np.argsort(elem["knn_cells"].sum(0))[-ma:][::-1]
192
+ knn_expr = elem["knn_cells"].sum(0)
193
+ mask = np.ones(len(knn_expr), dtype=bool)
194
+ mask[loc] = False
195
+ available_indices = np.where(mask)[0]
196
+ available_knn_expr = knn_expr[available_indices]
197
+ sorted_indices = np.argsort(available_knn_expr)[::-1]
198
+ selected = min(ma, len(available_indices))
199
+ zero_loc = available_indices[sorted_indices[:selected]]
192
200
  else:
193
201
  zero_loc = np.where(expr == 0)[0]
194
202
  zero_loc = zero_loc[
@@ -212,6 +220,8 @@ class Collator:
212
220
  exprs.append(expr)
213
221
  if "knn_cells" in elem:
214
222
  knn_cells.append(elem["knn_cells"])
223
+ if "knn_cells_info" in elem:
224
+ knn_cells_info.append(elem["knn_cells_info"])
215
225
  # then we need to add the start_idx to the loc to give it the correct index
216
226
  # according to the model
217
227
  gene_locs.append(loc + self.start_idx[organism_id])
@@ -231,15 +241,46 @@ class Collator:
231
241
  dataset = np.array(dataset)
232
242
  is_meta = np.array(is_meta)
233
243
  knn_cells = np.array(knn_cells)
244
+ knn_cells_info = np.array(knn_cells_info)
245
+
234
246
  # normalize counts
235
247
  if self.norm_to is not None:
236
248
  expr = (expr * self.norm_to) / total_count[:, None]
249
+ # TODO: solve issue here
250
+ knn_cells = (knn_cells * self.norm_to) / total_count[:, None]
237
251
  if self.logp1:
238
252
  expr = np.log2(1 + expr)
253
+ knn_cells = np.log2(1 + knn_cells)
239
254
 
240
255
  # do binning of counts
241
- if self.n_bins:
242
- pass
256
+ if self.n_bins > 0:
257
+ binned_rows = []
258
+ bin_edges = []
259
+ for row in expr:
260
+ if row.max() == 0:
261
+ print(
262
+ "The input data contains all zero rows. Please make sure "
263
+ "this is expected. You can use the `filter_cell_by_counts` "
264
+ "arg to filter out all zero rows."
265
+ )
266
+ binned_rows.append(np.zeros_like(row, dtype=np.int64))
267
+ bin_edges.append(np.array([0] * self.n_bins))
268
+ continue
269
+ non_zero_ids = row.nonzero()
270
+ non_zero_row = row[non_zero_ids]
271
+ bins = np.quantile(non_zero_row, np.linspace(0, 1, self.n_bins - 1))
272
+ # bins = np.sort(np.unique(bins))
273
+ # NOTE: the line above is commented out for now, since it would make each category
274
+ # have a different relative meaning across datasets
275
+ non_zero_digits = _digitize(non_zero_row, bins)
276
+ assert non_zero_digits.min() >= 1
277
+ assert non_zero_digits.max() <= self.n_bins - 1
278
+ binned_row = np.zeros_like(row, dtype=np.int64)
279
+ binned_row[non_zero_ids] = non_zero_digits
280
+ binned_rows.append(binned_row)
281
+ bin_edges.append(np.concatenate([[0], bins]))
282
+ expr = np.stack(binned_rows)
283
+ # expr = np.digitize(expr, bins=self.bins)
243
284
 
244
285
  ret = {
245
286
  "x": Tensor(expr),
@@ -252,51 +293,8 @@ class Collator:
252
293
  ret.update({"is_meta": Tensor(is_meta).int()})
253
294
  if len(knn_cells) > 0:
254
295
  ret.update({"knn_cells": Tensor(knn_cells)})
296
+ if len(knn_cells_info) > 0:
297
+ ret.update({"knn_cells_info": Tensor(knn_cells_info)})
255
298
  if len(dataset) > 0:
256
299
  ret.update({"dataset": Tensor(dataset).to(long)})
257
- if self.downsample is not None:
258
- ret["x"] = downsample_profile(ret["x"], self.downsample)
259
- if self.save_output is not None:
260
- with open(self.save_output, "a") as f:
261
- np.savetxt(f, ret["x"].numpy())
262
- with open(self.save_output + "_loc", "a") as f:
263
- np.savetxt(f, gene_locs)
264
300
  return ret
265
-
266
-
267
- #############
268
- #### WIP ####
269
- #############
270
- class GeneformerCollator(Collator):
271
- def __init__(self, *args, gene_norm_list: list, **kwargs):
272
- """
273
- GeneformerCollator to finish
274
-
275
- Args:
276
- gene_norm_list (list): the normalization of expression through all datasets, per gene.
277
- """
278
- super().__init__(*args, **kwargs)
279
- self.gene_norm_list = gene_norm_list
280
-
281
- def __call__(self, batch):
282
- super().__call__(batch)
283
- # normlization per gene
284
-
285
- # tokenize the empty locations
286
-
287
-
288
- class scGPTCollator(Collator):
289
- """
290
- scGPTCollator to finish
291
- """
292
-
293
- def __call__(self, batch):
294
- super().__call__(batch)
295
- # binning
296
-
297
- # tokenize the empty locations
298
-
299
-
300
- class scPRINTCollator(Collator):
301
- def __call__(self, batch):
302
- super().__call__(batch)
scdataloader/config.py CHANGED
@@ -113,26 +113,34 @@ COARSE_ASSAY = {
113
113
 
114
114
 
115
115
  MAIN_HUMAN_MOUSE_DEV_STAGE_MAP = {
116
- "HsapDv:0010000": [
116
+ "HsapDv:0010000": [ # postnatal stage
117
117
  "MmusDv:0000092", # postnatal stage
118
118
  ],
119
- "HsapDv:0000258": [ # mature stage
119
+ "HsapDv:0000258": [ # mature stage >15
120
120
  "MmusDv:0000110", # mature stage
121
- "HsapDv:0000204", #
121
+ "HsapDv:0000204", #
122
122
  ],
123
- "HsapDv:0000227": [ # late adult stage
123
+ "HsapDv:0000087": [], # adult stage >19
124
+ "HsapDv:0000227": [ # late adult stage > 40
124
125
  "MmusDv:0000091", # 20 month-old stage
125
126
  "MmusDv:0000089", # 18 month-old stage
127
+ "HsapDv:0000091", # > 45
128
+ "HsapDv:0000093", # > 65
129
+ ],
130
+ "HsapDv:0000272": [ # 60-79 year-old stage
131
+ "HsapDv:0000094", # 60-79 year-old stage
126
132
  ],
127
- "HsapDv:0000272": [], # 60-79 year-old stage
128
133
  "HsapDv:0000095": [], # 80 year-old and over stage
129
- "HsapDv:0000267": [ # middle aged stage
134
+ "HsapDv:0000267": [ # middle aged stage >40 <60
130
135
  "MmusDv:0000087", # 16 month-old stage
131
136
  "UBERON:0018241", # prime adult stage
132
137
  "MmusDv:0000083", # 12 month-old stage
133
138
  "HsapDv:0000092", # same
134
139
  ],
135
- "HsapDv:0000266": [ # young adult stage
140
+ "HsapDv:0000266": [ # young adult stage <40
141
+ "HsapDv:0000088", # mature stage
142
+ "HsapDv:0000090", # 25 - 44
143
+ "HsapDv:0000086", # adolescent stage
136
144
  "MmusDv:0000050", # 6 weeks
137
145
  "HsapDv:0000089", # same
138
146
  "MmusDv:0000051", # 7 weeks
@@ -163,22 +171,30 @@ MAIN_HUMAN_MOUSE_DEV_STAGE_MAP = {
163
171
  "MmusDv:0000099", # 26 weeks
164
172
  "MmusDv:0000102", # 29 weeks
165
173
  ],
166
- "HsapDv:0000265": [], # child stage (1-4 yo)
174
+ "HsapDv:0000265": [ # child stage (1-4 yo)
175
+ "HsapDv:0000084", # 2-5 yo
176
+ ],
167
177
  "HsapDv:0000271": [ # juvenile stage (5-14 yo)
168
178
  "MmusDv:0000048", # 4 weeks
169
179
  "MmusDv:0000049", # 5 weeks
180
+ "HsapDv:0000081", # child
181
+ "HsapDv:0000085", # 6-11 yo
170
182
  ],
171
- "HsapDv:0000260": [ # infant stage
183
+ "HsapDv:0000260": [ # infant stage <2
172
184
  "MmusDv:0000046", # 2 weeks
173
185
  "MmusDv:0000045", # 1 week
174
186
  "MmusDv:0000047", # 3 weeks
175
187
  "HsapDv:0000083",
188
+ "HsapDv:0000256", # under 1 yo
176
189
  ],
177
190
  "HsapDv:0000262": [ # newborn stage (0-28 days)
178
191
  "MmusDv:0000036", # Theiler stage 27
179
192
  "MmusDv:0000037", # Theiler stage 28
180
193
  "MmusDv:0000113", # 4-7 days
194
+ "HsapDv:0000174", # 1 month-old stage
195
+ "HsapDv:0000082", # newborn stage
181
196
  ],
197
+ "HsapDv:0000002": [], # embryonic stage
182
198
  "HsapDv:0000007": [], # Carnegie stage 03
183
199
  "HsapDv:0000008": [], # Carnegie stage 04
184
200
  "HsapDv:0000009": [], # Carnegie stage 05