scdataloader 1.6.3__py3-none-any.whl → 1.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
scdataloader/VERSION CHANGED
@@ -1 +1 @@
- 1.6.3
+ 1.7.0
scdataloader/__init__.py CHANGED
@@ -2,3 +2,5 @@ from .collator import Collator
  from .data import Dataset, SimpleAnnDataset
  from .datamodule import DataModule
  from .preprocess import Preprocessor
+
+ __version__ = "1.6.5"
scdataloader/__main__.py CHANGED
@@ -10,7 +10,12 @@ from scdataloader.preprocess import (
  )


- # scdataloader --instance="laminlabs/cellxgene" --name="cellxgene-census" --version="2023-12-15" --description="preprocessed for scprint" --new_name="scprint main" --start_at=39
+ # scdataloader --instance="laminlabs/cellxgene" --name="cellxgene-census" --version="2023-12-15" \
+ # --description="scPRINT-V2 datasets" --new_name="scprint v2" --n_hvg_for_postp=4000 --cache=False \
+ # --filter_gene_by_counts=0 --filter_cell_by_counts=300 --min_valid_genes_id=500 \
+ # --min_nnz_genes=120 --min_dataset_size=100 --maxdropamount=90 \
+ # --organisms=["NCBITaxon:9606","NCBITaxon:9544","NCBITaxon:9483","NCBITaxon:10090"] \
+ # --start_at=0
  def main():
  """
  main function to preprocess datasets in a given lamindb collection.
@@ -70,7 +75,7 @@ def main():
  help="Determines whether to normalize the total counts of each cell to a specific value.",
  )
  parser.add_argument(
- "--subset_hvg",
+ "--n_hvg_for_postp",
  type=int,
  default=0,
  help="Determines whether to subset highly variable genes.",
@@ -120,7 +125,7 @@ def main():
  parser.add_argument(
  "--min_nnz_genes",
  type=int,
- default=400,
+ default=200,
  help="Specifies the minimum non-zero genes.",
  )
  parser.add_argument(
@@ -139,7 +144,16 @@ def main():
  help="Specifies the percentage of MT outlier.",
  )
  parser.add_argument(
- "--batch_key", type=Optional[str], default=None, help="Specifies the batch key."
+ "--batch_keys",
+ type=list[str],
+ default=[
+ "assay_ontology_term_id",
+ "self_reported_ethnicity_ontology_term_id",
+ "sex_ontology_term_id",
+ "donor_id",
+ "suspension_type",
+ ],
+ help="Specifies the batch keys.",
  )
  parser.add_argument(
  "--skip_validate",
@@ -150,15 +164,30 @@ def main():
  parser.add_argument(
  "--do_postp",
  type=bool,
- default=False,
+ default=True,
  help="Determines whether to do postprocessing.",
  )
  parser.add_argument(
  "--cache",
  type=bool,
- default=True,
+ default=False,
  help="Determines whether to cache the dataset.",
  )
+ parser.add_argument(
+ "--organisms",
+ type=list,
+ default=[
+ "NCBITaxon:9606",
+ "NCBITaxon:10090",
+ ],
+ help="Determines the organisms to keep.",
+ )
+ parser.add_argument(
+ "--force_preloaded",
+ type=bool,
+ default=False,
+ help="Determines whether the dataset is preloaded.",
+ )
  args = parser.parse_args()

  # Load the collection
@@ -182,7 +211,7 @@ def main():
  filter_gene_by_counts=args.filter_gene_by_counts,
  filter_cell_by_counts=args.filter_cell_by_counts,
  normalize_sum=args.normalize_sum,
- subset_hvg=args.subset_hvg,
+ n_hvg_for_postp=args.n_hvg_for_postp,
  hvg_flavor=args.hvg_flavor,
  cache=args.cache,
  binning=args.binning,
@@ -195,12 +224,13 @@ def main():
  maxdropamount=args.maxdropamount,
  madoutlier=args.madoutlier,
  pct_mt_outlier=args.pct_mt_outlier,
- batch_key=args.batch_key,
+ batch_keys=args.batch_keys,
  skip_validate=args.skip_validate,
  do_postp=args.do_postp,
  additional_preprocess=additional_preprocess,
  additional_postprocess=additional_postprocess,
  keep_files=False,
+ force_preloaded=args.force_preloaded,
  )

  # Preprocess the dataset
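A caveat worth noting about the boolean flags above: argparse applies `type=bool` to the raw string, and `bool("False")` is `True`, so a spelling like `--cache=False` (as in the usage comment) parses as `True` rather than `False`. A minimal sketch of the usual workaround, purely illustrative and not part of the package:

```python
import argparse

def str2bool(value: str) -> bool:
    # argparse hands us the raw string; bool("False") would be True, so parse it explicitly.
    if value.lower() in ("true", "1", "yes"):
        return True
    if value.lower() in ("false", "0", "no"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")

parser = argparse.ArgumentParser()
parser.add_argument("--cache", type=str2bool, default=False)
print(parser.parse_args(["--cache=False"]).cache)  # -> False, as intended
```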
scdataloader/collator.py CHANGED
@@ -131,6 +131,7 @@ class Collator:
  tp = []
  dataset = []
  nnz_loc = []
+ is_meta = []
  for elem in batch:
  organism_id = elem[self.organism_name]
  if organism_id not in self.organism_ids:
@@ -188,12 +189,12 @@ class Collator:
  loc = loc[self.to_subset[organism_id]]
  exprs.append(expr)
  gene_locs.append(loc)
-
+ if "is_meta" in elem:
+ is_meta.append(elem["is_meta"])
  if self.tp_name is not None:
  tp.append(elem[self.tp_name])
  else:
  tp.append(0)
-
  other_classes.append([elem[i] for i in self.class_names])

  expr = np.array(exprs)
@@ -202,6 +203,7 @@ class Collator:
  total_count = np.array(total_count)
  other_classes = np.array(other_classes)
  dataset = np.array(dataset)
+ is_meta = np.array(is_meta)

  # normalize counts
  if self.norm_to is not None:
@@ -229,6 +231,8 @@ class Collator:
  "tp": Tensor(tp),
  "depth": Tensor(total_count),
  }
+ if len(is_meta) > 0:
+ ret.update({"is_meta": Tensor(is_meta)})
  if len(dataset) > 0:
  ret.update({"dataset": Tensor(dataset).to(long)})
  if self.downsample is not None:
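The net effect for consumers of the collator is that the returned batch dict may now include an `is_meta` tensor alongside the existing keys, but only when the incoming batch elements carry an `is_meta` field. A small, hypothetical helper for downstream code; the function name and usage are illustrative, not part of scdataloader:

```python
import torch

def metacell_fraction(batch: dict[str, torch.Tensor]) -> float:
    # "is_meta" is only emitted by the collator when the batch elements carried the flag,
    # so guard for its absence before touching it.
    mask = batch.get("is_meta")
    if mask is None:
        return 0.0
    return mask.bool().float().mean().item()
```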
scdataloader/config.py CHANGED
@@ -110,3 +110,102 @@ COARSE_ASSAY = {
  "TruDrop": "",
  "Visium Spatial Gene Expression": "",
  }
+
+
+ MAIN_HUMAN_MOUSE_DEV_STAGE_MAP = {
+ "HsapDv:0010000": [
+ "MmusDv:0000092", # postnatal stage
+ ],
+ "HsapDv:0000258": [ # mature stage
+ "MmusDv:0000110", # mature stage
+ "HsapDv:0000204",
+ ],
+ "HsapDv:0000227": [ # late adult stage
+ "MmusDv:0000091", # 20 month-old stage
+ "MmusDv:0000089", # 18 month-old stage
+ ],
+ "HsapDv:0000272": [], # 60-79 year-old stage
+ "HsapDv:0000095": [], # 80 year-old and over stage
+ "HsapDv:0000267": [ # middle aged stage
+ "MmusDv:0000087", # 16 month-old stage
+ "UBERON:0018241", # prime adult stage
+ "MmusDv:0000083", # 12 month-old stage
+ "HsapDv:0000092", # same
+ ],
+ "HsapDv:0000266": [ # young adult stage
+ "MmusDv:0000050", # 6 weeks
+ "HsapDv:0000089", # same
+ "MmusDv:0000051", # 7 weeks
+ "MmusDv:0000052", # 8 weeks
+ "MmusDv:0000053", # 9 weeks
+ "MmusDv:0000054", # 10 weeks
+ "MmusDv:0000055", # 11 weeks
+ "MmusDv:0000056", # 12 weeks
+ "MmusDv:0000057", # 13 weeks
+ "MmusDv:0000058", # 14 weeks
+ "MmusDv:0000059", # 15 weeks
+ "MmusDv:0000061", # early adult stage
+ "MmusDv:0000062", # 2 month-old stage
+ "MmusDv:0000063", # 3 month-old stage
+ "MmusDv:0000064", # 4 month-old stage
+ "MmusDv:0000065", # 16 weeks
+ "MmusDv:0000066", # 17 weeks
+ "MmusDv:0000067", # 18 weeks
+ "MmusDv:0000068", # 19 weeks
+ "MmusDv:0000070", # 20 weeks
+ "MmusDv:0000071", # 21 weeks
+ "MmusDv:0000072", # 22 weeks
+ "MmusDv:0000073", # 23 weeks
+ "MmusDv:0000074", # 24 weeks
+ "MmusDv:0000077", # 6 month-old stage
+ "MmusDv:0000079", # 8 month-old stage
+ "MmusDv:0000098", # 25 weeks
+ "MmusDv:0000099", # 26 weeks
+ "MmusDv:0000102", # 29 weeks
+ ],
+ "HsapDv:0000265": [], # child stage (1-4 yo)
+ "HsapDv:0000271": [ # juvenile stage (5-14 yo)
+ "MmusDv:0000048", # 4 weeks
+ "MmusDv:0000049", # 5 weeks
+ ],
+ "HsapDv:0000260": [ # infant stage
+ "MmusDv:0000046", # 2 weeks
+ "MmusDv:0000045", # 1 week
+ "MmusDv:0000047", # 3 weeks
+ "HsapDv:0000083",
+ ],
+ "HsapDv:0000262": [ # newborn stage (0-28 days)
+ "MmusDv:0000036", # Theiler stage 27
+ "MmusDv:0000037", # Theiler stage 28
+ "MmusDv:0000113", # 4-7 days
+ ],
+ "HsapDv:0000007": [], # Carnegie stage 03
+ "HsapDv:0000008": [], # Carnegie stage 04
+ "HsapDv:0000009": [], # Carnegie stage 05
+ "HsapDv:0000003": [], # Carnegie stage 01
+ "HsapDv:0000005": [], # Carnegie stage 02
+ "HsapDv:0000010": [], # gastrula stage
+ "HsapDv:0000012": [], # neurula stage
+ "HsapDv:0000015": [ # organogenesis stage
+ "MmusDv:0000019", # Theiler stage 13
+ "MmusDv:0000020", # Theiler stage 12
+ "MmusDv:0000021", # Theiler stage 14
+ "MmusDv:0000022", # Theiler stage 15
+ "MmusDv:0000023", # Theiler stage 16
+ "MmusDv:0000024", # Theiler stage 17
+ "MmusDv:0000025", # Theiler stage 18
+ "MmusDv:0000026", # Theiler stage 19
+ "MmusDv:0000027", # Theiler stage 20
+ "MmusDv:0000028", # Theiler stage 21
+ "MmusDv:0000029", # Theiler stage 22
+ ],
+ "HsapDv:0000037": [ # fetal stage
+ "MmusDv:0000033", # Theiler stage 24
+ "MmusDv:0000034", # Theiler stage 25
+ "MmusDv:0000035", # Theiler stage 26
+ "MmusDv:0000032", # Theiler stage 23
+ ],
+ "unknown": [
+ "MmusDv:0000041", # unknown
+ ],
+ }
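The new `MAIN_HUMAN_MOUSE_DEV_STAGE_MAP` keys are coarse human (HsapDv) developmental stages, each mapped to the mouse (MmusDv) and other ontology terms that should collapse onto it; keys with empty lists are human-only stages. A hedged sketch of how such a mapping could be inverted into a per-term lookup; `simplify_dev_stage` is a hypothetical helper, not something the package provides:

```python
from scdataloader.config import MAIN_HUMAN_MOUSE_DEV_STAGE_MAP

# Invert the mapping: every listed MmusDv/HsapDv/UBERON term points back to its coarse human stage key.
TERM_TO_COARSE_STAGE = {
    term: coarse
    for coarse, terms in MAIN_HUMAN_MOUSE_DEV_STAGE_MAP.items()
    for term in terms
}

def simplify_dev_stage(term: str) -> str:
    # Terms not listed under any key (including the coarse keys themselves) fall through unchanged.
    return TERM_TO_COARSE_STAGE.get(term, term)

print(simplify_dev_stage("MmusDv:0000062"))  # -> "HsapDv:0000266" (young adult stage)
```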
scdataloader/data.py CHANGED
@@ -10,8 +10,7 @@ import lamindb as ln
  import numpy as np
  import pandas as pd
  from anndata import AnnData
- from lamindb.core import MappedCollection
- from lamindb.core._mapped_collection import _Connect
+ from .mapped import MappedCollection, _Connect
  from lamindb.core.storage._anndata_accessor import _safer_read_index
  from scipy.sparse import issparse
  from torch.utils.data import Dataset as torchDataset
@@ -43,7 +42,7 @@ class Dataset(torchDataset):
  organisms (list[str]): list of organisms to load
  (for now only validates the the genes map to this organism)
  obs (list[str]): list of observations to load from the Collection
- clss_to_pred (list[str]): list of observations to encode
+ clss_to_predict (list[str]): list of observations to encode
  join_vars (flag): join variables @see :meth:`~lamindb.Dataset.mapped`.
  hierarchical_clss: list of observations to map to a hierarchy using lamin's bionty
  """
@@ -53,37 +52,23 @@ class Dataset(torchDataset):
  organisms: Optional[Union[list[str], str]] = field(
  default_factory=["NCBITaxon:9606", "NCBITaxon:10090"]
  )
- obs: Optional[list[str]] = field(
- default_factory=[
- "self_reported_ethnicity_ontology_term_id",
- "assay_ontology_term_id",
- "development_stage_ontology_term_id",
- "disease_ontology_term_id",
- "cell_type_ontology_term_id",
- "tissue_ontology_term_id",
- "sex_ontology_term_id",
- #'dataset_id',
- #'cell_culture',
- # "dpt_group",
- # "heat_diff",
- # "nnz",
- ]
- )
  # set of obs to prepare for prediction (encode)
- clss_to_pred: Optional[list[str]] = field(default_factory=list)
+ clss_to_predict: Optional[list[str]] = field(default_factory=list)
  # set of obs that need to be hierarchically prepared
  hierarchical_clss: Optional[list[str]] = field(default_factory=list)
  join_vars: Literal["inner", "outer"] | None = None
+ metacell_mode: float = 0.0

  def __post_init__(self):
  self.mapped_dataset = mapped(
  self.lamin_dataset,
- obs_keys=self.obs,
+ obs_keys=list(set(self.hierarchical_clss + self.clss_to_predict)),
  join=self.join_vars,
- encode_labels=self.clss_to_pred,
+ encode_labels=self.clss_to_predict,
  unknown_label="unknown",
  stream=True,
  parallel=True,
+ metacell_mode=self.metacell_mode,
  )
  print(
  "won't do any check but we recommend to have your dataset coming from local storage"
@@ -93,12 +78,12 @@ class Dataset(torchDataset):
  # generate tree from ontologies
  if len(self.hierarchical_clss) > 0:
  self.define_hierarchies(self.hierarchical_clss)
- if len(self.clss_to_pred) > 0:
- for clss in self.clss_to_pred:
+ if len(self.clss_to_predict) > 0:
+ for clss in self.clss_to_predict:
  if clss not in self.hierarchical_clss:
  # otherwise it's already been done
- self.class_topred[clss] = self.mapped_dataset.get_merged_categories(
- clss
+ self.class_topred[clss] = set(
+ self.mapped_dataset.get_merged_categories(clss)
  )
  if (
  self.mapped_dataset.unknown_label
@@ -143,8 +128,7 @@ class Dataset(torchDataset):
  + "dataset contains:\n"
  + " {} cells\n".format(self.mapped_dataset.__len__())
  + " {} genes\n".format(self.genedf.shape[0])
- + " {} labels\n".format(len(self.obs))
- + " {} clss_to_pred\n".format(len(self.clss_to_pred))
+ + " {} clss_to_predict\n".format(len(self.clss_to_predict))
  + " {} hierarchical_clss\n".format(len(self.hierarchical_clss))
  + " {} organisms\n".format(len(self.organisms))
  + (
@@ -154,9 +138,16 @@ class Dataset(torchDataset):
  if len(self.class_topred) > 0
  else ""
  )
+ + " {} metacell_mode\n".format(self.metacell_mode)
  )

- def get_label_weights(self, obs_keys: str | list[str], scaler: int = 10):
+ def get_label_weights(
+ self,
+ obs_keys: str | list[str],
+ scaler: int = 10,
+ return_categories=False,
+ bypass_label=["neuron"],
+ ):
  """Get all weights for the given label keys."""
  if isinstance(obs_keys, str):
  obs_keys = [obs_keys]
@@ -167,16 +158,24 @@ class Dataset(torchDataset):
  )
  labels_list.append(labels_to_str)
  if len(labels_list) > 1:
- labels = reduce(lambda a, b: a + b, labels_list)
+ labels = ["___".join(labels_obs) for labels_obs in zip(*labels_list)]
  else:
  labels = labels_list[0]

  counter = Counter(labels) # type: ignore
- rn = {n: i for i, n in enumerate(counter.keys())}
- labels = np.array([rn[label] for label in labels])
- counter = np.array(list(counter.values()))
- weights = scaler / (counter + scaler)
- return weights, labels
+ if return_categories:
+ rn = {n: i for i, n in enumerate(counter.keys())}
+ labels = np.array([rn[label] for label in labels])
+ counter = np.array(list(counter.values()))
+ weights = scaler / (counter + scaler)
+ return weights, labels
+ else:
+ counts = np.array([counter[label] for label in labels])
+ if scaler is None:
+ weights = 1.0 / counts
+ else:
+ weights = scaler / (counts + scaler)
+ return weights

  def get_unseen_mapped_dataset_elements(self, idx: int):
  """
@@ -209,6 +208,8 @@ class Dataset(torchDataset):
  "tissue_ontology_term_id",
  "disease_ontology_term_id",
  "development_stage_ontology_term_id",
+ "simplified_dev_stage",
+ "age_group",
  "assay_ontology_term_id",
  "self_reported_ethnicity_ontology_term_id",
  ]:
@@ -235,7 +236,11 @@ class Dataset(torchDataset):
  .df(include=["parents__ontology_id"])
  .set_index("ontology_id")
  )
- elif clss == "development_stage_ontology_term_id":
+ elif clss in [
+ "development_stage_ontology_term_id",
+ "simplified_dev_stage",
+ "age_group",
+ ]:
  parentdf = (
  bt.DevelopmentalStage.filter()
  .df(include=["parents__ontology_id"])
@@ -268,7 +273,7 @@ class Dataset(torchDataset):
  if len(j) == 0:
  groupings.pop(i)
  self.labels_groupings[clss] = groupings
- if clss in self.clss_to_pred:
+ if clss in self.clss_to_predict:
  # if we have added new clss, we need to update the encoder with them too.
  mlength = len(self.mapped_dataset.encoders[clss])

@@ -362,6 +367,7 @@ def mapped(
  dtype: str | None = None,
  stream: bool = False,
  is_run_input: bool | None = None,
+ metacell_mode: bool = False,
  ) -> MappedCollection:
  path_list = []
  for artifact in dataset.artifacts.all():
@@ -384,5 +390,6 @@ def mapped(
  cache_categories=cache_categories,
  parallel=parallel,
  dtype=dtype,
+ metacell_mode=metacell_mode,
  )
  return ds
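Among the data.py changes, `get_label_weights` now returns just a per-cell weight array by default (the previous `(weights, encoded_labels)` pair is kept behind `return_categories=True`), and multiple `obs_keys` are combined per cell with a `"___"` join rather than concatenated across keys. A hedged usage sketch for oversampling rare label groups; the helper name and the choice of `WeightedRandomSampler` are illustrative assumptions, not something the package prescribes:

```python
import torch
from torch.utils.data import WeightedRandomSampler

def make_balanced_sampler(dataset, obs_keys, scaler=10):
    # Hypothetical helper: `dataset` is assumed to be a scdataloader.data.Dataset.
    # Without return_categories, get_label_weights returns one weight per cell,
    # computed as scaler / (group_count + scaler) over the joined label keys.
    weights = dataset.get_label_weights(obs_keys, scaler=scaler)
    return WeightedRandomSampler(
        weights=torch.as_tensor(weights, dtype=torch.double),
        num_samples=len(weights),
        replacement=True,
    )
```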