scdataloader 1.6.4__py3-none-any.whl → 1.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
scdataloader/VERSION CHANGED
@@ -1 +1 @@
- 1.6.4
+ 1.8.0
scdataloader/__init__.py CHANGED
@@ -2,3 +2,5 @@ from .collator import Collator
  from .data import Dataset, SimpleAnnDataset
  from .datamodule import DataModule
  from .preprocess import Preprocessor
+
+ __version__ = "1.7.0"
scdataloader/__main__.py CHANGED
@@ -10,157 +10,218 @@ from scdataloader.preprocess import (
  )


- # scdataloader --instance="laminlabs/cellxgene" --name="cellxgene-census" --version="2023-12-15" --description="preprocessed for scprint" --new_name="scprint main" --start_at=39
+ # scdataloader --instance="laminlabs/cellxgene" --name="cellxgene-census" --version="2023-12-15" \
+ # --description="scPRINT-V2 datasets" --new_name="scprint v2" --n_hvg_for_postp=4000 --cache=False \
+ # --filter_gene_by_counts=0 --filter_cell_by_counts=300 --min_valid_genes_id=500 \
+ # --min_nnz_genes=120 --min_dataset_size=100 --maxdropamount=90 \
+ # --organisms=["NCBITaxon:9606","NCBITaxon:9544","NCBITaxon:9483","NCBITaxon:10090"] \
+ # --start_at=0
  def main():
  """
- main function to preprocess datasets in a given lamindb collection.
+ Main function to either preprocess datasets in a lamindb collection or populate ontologies.
  """
  parser = argparse.ArgumentParser(
- description="Preprocess datasets in a given lamindb collection."
+ description="Preprocess datasets or populate ontologies."
  )
- parser.add_argument(
+ subparsers = parser.add_subparsers(dest="command", help="Available commands")
+
+ # Preprocess command
+ preprocess_parser = subparsers.add_parser("preprocess", help="Preprocess datasets")
+ preprocess_parser.add_argument(
  "--name", type=str, required=True, help="Name of the input dataset"
  )
- parser.add_argument(
+ preprocess_parser.add_argument(
  "--new_name",
  type=str,
  default="preprocessed dataset",
  help="Name of the preprocessed dataset.",
  )
- parser.add_argument(
+ preprocess_parser.add_argument(
  "--description",
  type=str,
  default="preprocessed by scDataLoader",
  help="Description of the preprocessed dataset.",
  )
- parser.add_argument(
+ preprocess_parser.add_argument(
  "--start_at", type=int, default=0, help="Position to start preprocessing at."
  )
- parser.add_argument(
+ preprocess_parser.add_argument(
  "--new_version",
  type=str,
  default="2",
  help="Version of the output dataset and files.",
  )
- parser.add_argument(
+ preprocess_parser.add_argument(
  "--instance",
  type=str,
  default=None,
  help="Instance storing the input dataset, if not local",
  )
- parser.add_argument(
+ preprocess_parser.add_argument(
  "--version", type=str, default=None, help="Version of the input dataset."
  )
- parser.add_argument(
+ preprocess_parser.add_argument(
  "--filter_gene_by_counts",
  type=int,
  default=0,
  help="Determines whether to filter genes by counts.",
  )
- parser.add_argument(
+ preprocess_parser.add_argument(
  "--filter_cell_by_counts",
  type=int,
  default=0,
  help="Determines whether to filter cells by counts.",
  )
- parser.add_argument(
+ preprocess_parser.add_argument(
  "--normalize_sum",
  type=float,
  default=1e4,
  help="Determines whether to normalize the total counts of each cell to a specific value.",
  )
- parser.add_argument(
- "--subset_hvg",
+ preprocess_parser.add_argument(
+ "--n_hvg_for_postp",
  type=int,
  default=0,
  help="Determines whether to subset highly variable genes.",
  )
- parser.add_argument(
+ preprocess_parser.add_argument(
  "--hvg_flavor",
  type=str,
  default="seurat_v3",
  help="Specifies the flavor of highly variable genes selection.",
  )
- parser.add_argument(
+ preprocess_parser.add_argument(
  "--binning",
  type=Optional[int],
  default=None,
  help="Determines whether to bin the data into discrete values of number of bins provided.",
  )
- parser.add_argument(
+ preprocess_parser.add_argument(
  "--result_binned_key",
  type=str,
  default="X_binned",
  help="Specifies the key of AnnData to store the binned data.",
  )
- parser.add_argument(
+ preprocess_parser.add_argument(
  "--length_normalize",
  type=bool,
  default=False,
  help="Determines whether to normalize the length.",
  )
- parser.add_argument(
+ preprocess_parser.add_argument(
  "--force_preprocess",
  type=bool,
  default=False,
  help="Determines whether to force preprocessing.",
  )
- parser.add_argument(
+ preprocess_parser.add_argument(
  "--min_dataset_size",
  type=int,
  default=100,
  help="Specifies the minimum dataset size.",
  )
- parser.add_argument(
+ preprocess_parser.add_argument(
  "--min_valid_genes_id",
  type=int,
  default=10_000,
  help="Specifies the minimum valid genes id.",
  )
- parser.add_argument(
+ preprocess_parser.add_argument(
  "--min_nnz_genes",
  type=int,
- default=400,
+ default=200,
  help="Specifies the minimum non-zero genes.",
  )
- parser.add_argument(
+ preprocess_parser.add_argument(
  "--maxdropamount",
  type=int,
  default=50,
  help="Specifies the maximum drop amount.",
  )
- parser.add_argument(
+ preprocess_parser.add_argument(
  "--madoutlier", type=int, default=5, help="Specifies the MAD outlier."
  )
- parser.add_argument(
+ preprocess_parser.add_argument(
  "--pct_mt_outlier",
  type=int,
  default=8,
  help="Specifies the percentage of MT outlier.",
  )
- parser.add_argument(
- "--batch_key", type=Optional[str], default=None, help="Specifies the batch key."
+ preprocess_parser.add_argument(
+ "--batch_keys",
+ type=list[str],
+ default=[
+ "assay_ontology_term_id",
+ "self_reported_ethnicity_ontology_term_id",
+ "sex_ontology_term_id",
+ "donor_id",
+ "suspension_type",
+ ],
+ help="Specifies the batch keys.",
  )
- parser.add_argument(
+ preprocess_parser.add_argument(
  "--skip_validate",
  type=bool,
  default=False,
  help="Determines whether to skip validation.",
  )
- parser.add_argument(
+ preprocess_parser.add_argument(
  "--do_postp",
  type=bool,
- default=False,
+ default=True,
  help="Determines whether to do postprocessing.",
  )
- parser.add_argument(
+ preprocess_parser.add_argument(
  "--cache",
  type=bool,
- default=True,
+ default=False,
  help="Determines whether to cache the dataset.",
  )
+ preprocess_parser.add_argument(
+ "--organisms",
+ type=list,
+ default=[
+ "NCBITaxon:9606",
+ "NCBITaxon:10090",
+ ],
+ help="Determines the organisms to keep.",
+ )
+ preprocess_parser.add_argument(
+ "--force_preloaded",
+ type=bool,
+ default=False,
+ help="Determines whether the dataset is preloaded.",
+ )
+ # Populate command
+ populate_parser = subparsers.add_parser("populate", help="Populate ontologies")
+ populate_parser.add_argument(
+ "what",
+ nargs="?",
+ default="all",
+ choices=[
+ "all",
+ "organisms",
+ "celltypes",
+ "diseases",
+ "tissues",
+ "assays",
+ "ethnicities",
+ "sex",
+ "dev_stages",
+ ],
+ help="What ontologies to populate",
+ )
  args = parser.parse_args()

+ if args.command == "populate":
+ from scdataloader.utils import populate_my_ontology
+
+ if args.what != "all":
+ raise ValueError("Only 'all' is supported for now")
+ else:
+ populate_my_ontology()
+ return
+
  # Load the collection
  # if not args.preprocess:
  # print("Only preprocess is available for now")
@@ -182,7 +243,7 @@ def main():
  filter_gene_by_counts=args.filter_gene_by_counts,
  filter_cell_by_counts=args.filter_cell_by_counts,
  normalize_sum=args.normalize_sum,
- subset_hvg=args.subset_hvg,
+ n_hvg_for_postp=args.n_hvg_for_postp,
  hvg_flavor=args.hvg_flavor,
  cache=args.cache,
  binning=args.binning,
@@ -195,12 +256,13 @@ def main():
  maxdropamount=args.maxdropamount,
  madoutlier=args.madoutlier,
  pct_mt_outlier=args.pct_mt_outlier,
- batch_key=args.batch_key,
+ batch_keys=args.batch_keys,
  skip_validate=args.skip_validate,
  do_postp=args.do_postp,
  additional_preprocess=additional_preprocess,
  additional_postprocess=additional_postprocess,
  keep_files=False,
+ force_preloaded=args.force_preloaded,
  )

  # Preprocess the dataset
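For context on the CLI restructuring above: the entry point now dispatches on an argparse subcommand (`preprocess` or `populate`) instead of a single flat argument list, so invocations name the command first. A minimal sketch of the new shape, in the same spirit as the header comment (the dataset name and flag values below are illustrative, not taken from this diff):

    scdataloader preprocess --name="cellxgene-census" --new_name="my preprocessed set" --start_at=0
    scdataloader populate all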
scdataloader/collator.py CHANGED
@@ -23,7 +23,8 @@ class Collator:
  class_names: list[str] = [],
  genelist: list[str] = [],
  downsample: Optional[float] = None, # don't use it for training!
- save_output: bool = False,
+ save_output: Optional[str] = None,
+ metacell_mode: bool = False,
  ):
  """
  This class is responsible for collating data for the scPRINT model. It handles the
@@ -59,8 +60,9 @@ class Collator:
  If [] all genes will be considered
  downsample (float, optional): Downsample the profile to a certain number of cells. Defaults to None.
  This is usually done by the scPRINT model during training but this option allows you to do it directly from the collator
- save_output (bool, optional): If True, saves the output to a file. Defaults to False.
+ save_output (str, optional): If not None, saves the output to a file. Defaults to None.
  This is mainly for debugging purposes
+ metacell_mode (bool, optional): Whether to sample a metacell. Defaults to False.
  """
  self.organisms = organisms
  self.genedf = load_genes(organisms)
@@ -80,6 +82,7 @@ class Collator:
  self.accepted_genes = {}
  self.downsample = downsample
  self.to_subset = {}
+ self.metacell_mode = metacell_mode
  self._setup(org_to_id, valid_genes, genelist)

  def _setup(self, org_to_id=None, valid_genes=[], genelist=[]):
@@ -131,6 +134,7 @@ class Collator:
  tp = []
  dataset = []
  nnz_loc = []
+ is_meta = []
  for elem in batch:
  organism_id = elem[self.organism_name]
  if organism_id not in self.organism_ids:
@@ -193,16 +197,16 @@ class Collator:
  tp.append(elem[self.tp_name])
  else:
  tp.append(0)
-
+ if self.metacell_mode:
+ is_meta.append(elem["is_meta"])
  other_classes.append([elem[i] for i in self.class_names])
-
  expr = np.array(exprs)
  tp = np.array(tp)
  gene_locs = np.array(gene_locs)
  total_count = np.array(total_count)
  other_classes = np.array(other_classes)
  dataset = np.array(dataset)
-
+ is_meta = np.array(is_meta)
  # normalize counts
  if self.norm_to is not None:
  expr = (expr * self.norm_to) / total_count[:, None]
@@ -229,12 +233,14 @@ class Collator:
  "tp": Tensor(tp),
  "depth": Tensor(total_count),
  }
+ if self.metacell_mode:
+ ret.update({"is_meta": Tensor(is_meta).int()})
  if len(dataset) > 0:
  ret.update({"dataset": Tensor(dataset).to(long)})
  if self.downsample is not None:
  ret["x"] = downsample_profile(ret["x"], self.downsample)
- if self.save_output:
- with open("collator_output.txt", "a") as f:
+ if self.save_output is not None:
+ with open(self.save_output, "a") as f:
  np.savetxt(f, ret["x"].numpy())
  return ret
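Taken together, the collator changes mean `save_output` now carries a destination path rather than a bool, and `metacell_mode` adds an `is_meta` tensor to each collated batch. A minimal sketch of constructing the collator with the new arguments (values are illustrative; other constructor arguments are assumed to keep their defaults):

    from scdataloader import Collator

    # Illustrative sketch: save_output is now a file path (None disables saving),
    # and metacell_mode makes the collator emit an "is_meta" tensor per batch.
    collator = Collator(
        organisms=["NCBITaxon:9606"],
        save_output="collator_output.txt",  # was save_output=True with a hard-coded file name
        metacell_mode=True,
    )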
scdataloader/config.py CHANGED
@@ -110,3 +110,102 @@ COARSE_ASSAY = {
  "TruDrop": "",
  "Visium Spatial Gene Expression": "",
  }
+
+
+ MAIN_HUMAN_MOUSE_DEV_STAGE_MAP = {
+ "HsapDv:0010000": [
+ "MmusDv:0000092", # postnatal stage
+ ],
+ "HsapDv:0000258": [ # mature stage
+ "MmusDv:0000110", # mature stage
+ "HsapDv:0000204",
+ ],
+ "HsapDv:0000227": [ # late adult stage
+ "MmusDv:0000091", # 20 month-old stage
+ "MmusDv:0000089", # 18 month-old stage
+ ],
+ "HsapDv:0000272": [], # 60-79 year-old stage
+ "HsapDv:0000095": [], # 80 year-old and over stage
+ "HsapDv:0000267": [ # middle aged stage
+ "MmusDv:0000087", # 16 month-old stage
+ "UBERON:0018241", # prime adult stage
+ "MmusDv:0000083", # 12 month-old stage
+ "HsapDv:0000092", # same
+ ],
+ "HsapDv:0000266": [ # young adult stage
+ "MmusDv:0000050", # 6 weeks
+ "HsapDv:0000089", # same
+ "MmusDv:0000051", # 7 weeks
+ "MmusDv:0000052", # 8 weeks
+ "MmusDv:0000053", # 9 weeks
+ "MmusDv:0000054", # 10 weeks
+ "MmusDv:0000055", # 11 weeks
+ "MmusDv:0000056", # 12 weeks
+ "MmusDv:0000057", # 13 weeks
+ "MmusDv:0000058", # 14 weeks
+ "MmusDv:0000059", # 15 weeks
+ "MmusDv:0000061", # early adult stage
+ "MmusDv:0000062", # 2 month-old stage
+ "MmusDv:0000063", # 3 month-old stage
+ "MmusDv:0000064", # 4 month-old stage
+ "MmusDv:0000065", # 16 weeks
+ "MmusDv:0000066", # 17 weeks
+ "MmusDv:0000067", # 18 weeks
+ "MmusDv:0000068", # 19 weeks
+ "MmusDv:0000070", # 20 weeks
+ "MmusDv:0000071", # 21 weeks
+ "MmusDv:0000072", # 22 weeks
+ "MmusDv:0000073", # 23 weeks
+ "MmusDv:0000074", # 24 weeks
+ "MmusDv:0000077", # 6 month-old stage
+ "MmusDv:0000079", # 8 month-old stage
+ "MmusDv:0000098", # 25 weeks
+ "MmusDv:0000099", # 26 weeks
+ "MmusDv:0000102", # 29 weeks
+ ],
+ "HsapDv:0000265": [], # child stage (1-4 yo)
+ "HsapDv:0000271": [ # juvenile stage (5-14 yo)
+ "MmusDv:0000048", # 4 weeks
+ "MmusDv:0000049", # 5 weeks
+ ],
+ "HsapDv:0000260": [ # infant stage
+ "MmusDv:0000046", # 2 weeks
+ "MmusDv:0000045", # 1 week
+ "MmusDv:0000047", # 3 weeks
+ "HsapDv:0000083",
+ ],
+ "HsapDv:0000262": [ # newborn stage (0-28 days)
+ "MmusDv:0000036", # Theiler stage 27
+ "MmusDv:0000037", # Theiler stage 28
+ "MmusDv:0000113", # 4-7 days
+ ],
+ "HsapDv:0000007": [], # Carnegie stage 03
+ "HsapDv:0000008": [], # Carnegie stage 04
+ "HsapDv:0000009": [], # Carnegie stage 05
+ "HsapDv:0000003": [], # Carnegie stage 01
+ "HsapDv:0000005": [], # Carnegie stage 02
+ "HsapDv:0000010": [], # gastrula stage
+ "HsapDv:0000012": [], # neurula stage
+ "HsapDv:0000015": [ # organogenesis stage
+ "MmusDv:0000019", # Theiler stage 13
+ "MmusDv:0000020", # Theiler stage 12
+ "MmusDv:0000021", # Theiler stage 14
+ "MmusDv:0000022", # Theiler stage 15
+ "MmusDv:0000023", # Theiler stage 16
+ "MmusDv:0000024", # Theiler stage 17
+ "MmusDv:0000025", # Theiler stage 18
+ "MmusDv:0000026", # Theiler stage 19
+ "MmusDv:0000027", # Theiler stage 20
+ "MmusDv:0000028", # Theiler stage 21
+ "MmusDv:0000029", # Theiler stage 22
+ ],
+ "HsapDv:0000037": [ # fetal stage
+ "MmusDv:0000033", # Theiler stage 24
+ "MmusDv:0000034", # Theiler stage 25
+ "MmusDv:0000035", # Theiler stage 26
+ "MmusDv:0000032", # Theiler stage 23
+ ],
+ "unknown": [
+ "MmusDv:0000041", # unknown
+ ],
+ }
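The new MAIN_HUMAN_MOUSE_DEV_STAGE_MAP groups mouse (plus a few human and UBERON) developmental-stage terms under a coarse human stage. The diff does not show how it is consumed elsewhere; as a purely illustrative sketch, one way to use it is to invert it into a child-term to human-stage lookup:

    from scdataloader.config import MAIN_HUMAN_MOUSE_DEV_STAGE_MAP

    # Illustrative: map any grouped term back to its coarse human stage bucket.
    reverse_map = {
        child: human_stage
        for human_stage, children in MAIN_HUMAN_MOUSE_DEV_STAGE_MAP.items()
        for child in children
    }
    print(reverse_map.get("MmusDv:0000050"))  # "HsapDv:0000266" (young adult stage)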
scdataloader/data.py CHANGED
@@ -10,8 +10,6 @@ import lamindb as ln
  import numpy as np
  import pandas as pd
  from anndata import AnnData
- from lamindb.core import MappedCollection
- from lamindb.core._mapped_collection import _Connect
  from lamindb.core.storage._anndata_accessor import _safer_read_index
  from scipy.sparse import issparse
  from torch.utils.data import Dataset as torchDataset
@@ -19,6 +17,7 @@ from torch.utils.data import Dataset as torchDataset
  from scdataloader.utils import get_ancestry_mapping, load_genes

  from .config import LABELS_TOADD
+ from .mapped import MappedCollection, _Connect


  @dataclass
@@ -43,7 +42,7 @@ class Dataset(torchDataset):
  organisms (list[str]): list of organisms to load
  (for now only validates the the genes map to this organism)
  obs (list[str]): list of observations to load from the Collection
- clss_to_pred (list[str]): list of observations to encode
+ clss_to_predict (list[str]): list of observations to encode
  join_vars (flag): join variables @see :meth:`~lamindb.Dataset.mapped`.
  hierarchical_clss: list of observations to map to a hierarchy using lamin's bionty
  """
@@ -53,37 +52,23 @@ class Dataset(torchDataset):
  organisms: Optional[Union[list[str], str]] = field(
  default_factory=["NCBITaxon:9606", "NCBITaxon:10090"]
  )
- obs: Optional[list[str]] = field(
- default_factory=[
- "self_reported_ethnicity_ontology_term_id",
- "assay_ontology_term_id",
- "development_stage_ontology_term_id",
- "disease_ontology_term_id",
- "cell_type_ontology_term_id",
- "tissue_ontology_term_id",
- "sex_ontology_term_id",
- #'dataset_id',
- #'cell_culture',
- # "dpt_group",
- # "heat_diff",
- # "nnz",
- ]
- )
  # set of obs to prepare for prediction (encode)
- clss_to_pred: Optional[list[str]] = field(default_factory=list)
+ clss_to_predict: Optional[list[str]] = field(default_factory=list)
  # set of obs that need to be hierarchically prepared
  hierarchical_clss: Optional[list[str]] = field(default_factory=list)
  join_vars: Literal["inner", "outer"] | None = None
+ metacell_mode: float = 0.0

  def __post_init__(self):
  self.mapped_dataset = mapped(
  self.lamin_dataset,
- obs_keys=self.obs,
+ obs_keys=list(set(self.hierarchical_clss + self.clss_to_predict)),
  join=self.join_vars,
- encode_labels=self.clss_to_pred,
+ encode_labels=self.clss_to_predict,
  unknown_label="unknown",
  stream=True,
  parallel=True,
+ metacell_mode=self.metacell_mode,
  )
  print(
  "won't do any check but we recommend to have your dataset coming from local storage"
@@ -93,8 +78,8 @@ class Dataset(torchDataset):
  # generate tree from ontologies
  if len(self.hierarchical_clss) > 0:
  self.define_hierarchies(self.hierarchical_clss)
- if len(self.clss_to_pred) > 0:
- for clss in self.clss_to_pred:
+ if len(self.clss_to_predict) > 0:
+ for clss in self.clss_to_predict:
  if clss not in self.hierarchical_clss:
  # otherwise it's already been done
  self.class_topred[clss] = set(
@@ -143,8 +128,7 @@ class Dataset(torchDataset):
  + "dataset contains:\n"
  + " {} cells\n".format(self.mapped_dataset.__len__())
  + " {} genes\n".format(self.genedf.shape[0])
- + " {} labels\n".format(len(self.obs))
- + " {} clss_to_pred\n".format(len(self.clss_to_pred))
+ + " {} clss_to_predict\n".format(len(self.clss_to_predict))
  + " {} hierarchical_clss\n".format(len(self.hierarchical_clss))
  + " {} organisms\n".format(len(self.organisms))
  + (
@@ -154,9 +138,16 @@ class Dataset(torchDataset):
  if len(self.class_topred) > 0
  else ""
  )
+ + " {} metacell_mode\n".format(self.metacell_mode)
  )

- def get_label_weights(self, obs_keys: str | list[str], scaler: int = 10):
+ def get_label_weights(
+ self,
+ obs_keys: str | list[str],
+ scaler: int = 10,
+ return_categories=False,
+ bypass_label=["neuron"],
+ ):
  """Get all weights for the given label keys."""
  if isinstance(obs_keys, str):
  obs_keys = [obs_keys]
@@ -167,16 +158,24 @@ class Dataset(torchDataset):
  )
  labels_list.append(labels_to_str)
  if len(labels_list) > 1:
- labels = reduce(lambda a, b: a + b, labels_list)
+ labels = ["___".join(labels_obs) for labels_obs in zip(*labels_list)]
  else:
  labels = labels_list[0]

  counter = Counter(labels) # type: ignore
- rn = {n: i for i, n in enumerate(counter.keys())}
- labels = np.array([rn[label] for label in labels])
- counter = np.array(list(counter.values()))
- weights = scaler / (counter + scaler)
- return weights, labels
+ if return_categories:
+ rn = {n: i for i, n in enumerate(counter.keys())}
+ labels = np.array([rn[label] for label in labels])
+ counter = np.array(list(counter.values()))
+ weights = scaler / (counter + scaler)
+ return weights, labels
+ else:
+ counts = np.array([counter[label] for label in labels])
+ if scaler is None:
+ weights = 1.0 / counts
+ else:
+ weights = scaler / (counts + scaler)
+ return weights

  def get_unseen_mapped_dataset_elements(self, idx: int):
  """
@@ -209,6 +208,8 @@ class Dataset(torchDataset):
  "tissue_ontology_term_id",
  "disease_ontology_term_id",
  "development_stage_ontology_term_id",
+ "simplified_dev_stage",
+ "age_group",
  "assay_ontology_term_id",
  "self_reported_ethnicity_ontology_term_id",
  ]:
@@ -235,7 +236,11 @@ class Dataset(torchDataset):
  .df(include=["parents__ontology_id"])
  .set_index("ontology_id")
  )
- elif clss == "development_stage_ontology_term_id":
+ elif clss in [
+ "development_stage_ontology_term_id",
+ "simplified_dev_stage",
+ "age_group",
+ ]:
  parentdf = (
  bt.DevelopmentalStage.filter()
  .df(include=["parents__ontology_id"])
@@ -268,7 +273,7 @@ class Dataset(torchDataset):
  if len(j) == 0:
  groupings.pop(i)
  self.labels_groupings[clss] = groupings
- if clss in self.clss_to_pred:
+ if clss in self.clss_to_predict:
  # if we have added new clss, we need to update the encoder with them too.
  mlength = len(self.mapped_dataset.encoders[clss])

@@ -354,6 +359,8 @@ class SimpleAnnDataset(torchDataset):
  def mapped(
  dataset,
  obs_keys: list[str] | None = None,
+ obsm_keys: list[str] | None = None,
+ obs_filter: dict[str, str | tuple[str, ...]] | None = None,
  join: Literal["inner", "outer"] | None = "inner",
  encode_labels: bool | list[str] = True,
  unknown_label: str | dict[str, str] | None = None,
@@ -362,6 +369,8 @@ def mapped(
  dtype: str | None = None,
  stream: bool = False,
  is_run_input: bool | None = None,
+ metacell_mode: bool = False,
+ meta_assays: list[str] = ["EFO:0022857", "EFO:0010961"],
  ) -> MappedCollection:
  path_list = []
  for artifact in dataset.artifacts.all():
@@ -378,11 +387,15 @@ def mapped(
  ds = MappedCollection(
  path_list=path_list,
  obs_keys=obs_keys,
+ obsm_keys=obsm_keys,
+ obs_filter=obs_filter,
  join=join,
  encode_labels=encode_labels,
  unknown_label=unknown_label,
  cache_categories=cache_categories,
  parallel=parallel,
  dtype=dtype,
+ meta_assays=meta_assays,
+ metacell_mode=metacell_mode,
  )
  return ds
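Summing up the data.py changes: the `obs` and `clss_to_pred` fields are folded into `clss_to_predict` (the mapped obs_keys are now derived from it plus `hierarchical_clss`), `metacell_mode` is threaded through to the MappedCollection now imported from the package's own `mapped` module, and `get_label_weights` returns only the weights unless `return_categories=True`. A minimal sketch of the updated usage (the collection lookup and label keys are illustrative, not taken from this diff):

    import lamindb as ln
    from scdataloader import Dataset

    # Illustrative: fetch a lamindb Collection by name (hypothetical collection name).
    collection = ln.Collection.filter(name="my-collection").first()

    dataset = Dataset(
        collection,
        clss_to_predict=["cell_type_ontology_term_id"],   # replaces obs / clss_to_pred
        hierarchical_clss=["cell_type_ontology_term_id"],
    )

    # New default: only the per-cell weights are returned.
    weights = dataset.get_label_weights("cell_type_ontology_term_id", scaler=10)
    # The previous (weights, encoded labels) behavior is behind return_categories=True.
    weights, labels = dataset.get_label_weights(
        "cell_type_ontology_term_id", scaler=10, return_categories=True
    )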