scdataloader 1.6.3__py3-none-any.whl → 1.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- scdataloader/VERSION +1 -1
- scdataloader/__init__.py +2 -0
- scdataloader/__main__.py +38 -8
- scdataloader/collator.py +6 -2
- scdataloader/config.py +99 -0
- scdataloader/data.py +44 -37
- scdataloader/datamodule.py +124 -41
- scdataloader/mapped.py +700 -0
- scdataloader/preprocess.py +229 -86
- scdataloader/utils.py +212 -27
- {scdataloader-1.6.3.dist-info → scdataloader-1.7.0.dist-info}/METADATA +9 -6
- scdataloader-1.7.0.dist-info/RECORD +15 -0
- {scdataloader-1.6.3.dist-info → scdataloader-1.7.0.dist-info}/WHEEL +1 -1
- scdataloader-1.6.3.dist-info/RECORD +0 -14
- {scdataloader-1.6.3.dist-info → scdataloader-1.7.0.dist-info}/licenses/LICENSE +0 -0
scdataloader/VERSION
CHANGED
@@ -1 +1 @@
-1.6.3
+1.7.0
scdataloader/__init__.py
CHANGED
scdataloader/__main__.py
CHANGED
@@ -10,7 +10,12 @@ from scdataloader.preprocess import (
 )
 
 
-# scdataloader --instance="laminlabs/cellxgene" --name="cellxgene-census" --version="2023-12-15"
+# scdataloader --instance="laminlabs/cellxgene" --name="cellxgene-census" --version="2023-12-15" \
+# --description="scPRINT-V2 datasets" --new_name="scprint v2" --n_hvg_for_postp=4000 --cache=False \
+# --filter_gene_by_counts=0 --filter_cell_by_counts=300 --min_valid_genes_id=500 \
+# --min_nnz_genes=120 --min_dataset_size=100 --maxdropamount=90 \
+# --organisms=["NCBITaxon:9606","NCBITaxon:9544","NCBITaxon:9483","NCBITaxon:10090"] \
+# --start_at=0
 def main():
     """
     main function to preprocess datasets in a given lamindb collection.
@@ -70,7 +75,7 @@ def main():
         help="Determines whether to normalize the total counts of each cell to a specific value.",
     )
     parser.add_argument(
-        "--
+        "--n_hvg_for_postp",
         type=int,
         default=0,
         help="Determines whether to subset highly variable genes.",
@@ -120,7 +125,7 @@ def main():
     parser.add_argument(
         "--min_nnz_genes",
         type=int,
-        default=
+        default=200,
         help="Specifies the minimum non-zero genes.",
     )
     parser.add_argument(
@@ -139,7 +144,16 @@ def main():
         help="Specifies the percentage of MT outlier.",
    )
     parser.add_argument(
-        "--
+        "--batch_keys",
+        type=list[str],
+        default=[
+            "assay_ontology_term_id",
+            "self_reported_ethnicity_ontology_term_id",
+            "sex_ontology_term_id",
+            "donor_id",
+            "suspension_type",
+        ],
+        help="Specifies the batch keys.",
     )
     parser.add_argument(
         "--skip_validate",
@@ -150,15 +164,30 @@ def main():
     parser.add_argument(
         "--do_postp",
         type=bool,
-        default=
+        default=True,
         help="Determines whether to do postprocessing.",
     )
     parser.add_argument(
         "--cache",
         type=bool,
-        default=
+        default=False,
         help="Determines whether to cache the dataset.",
     )
+    parser.add_argument(
+        "--organisms",
+        type=list,
+        default=[
+            "NCBITaxon:9606",
+            "NCBITaxon:10090",
+        ],
+        help="Determines the organisms to keep.",
+    )
+    parser.add_argument(
+        "--force_preloaded",
+        type=bool,
+        default=False,
+        help="Determines whether the dataset is preloaded.",
+    )
     args = parser.parse_args()
 
     # Load the collection
@@ -182,7 +211,7 @@ def main():
         filter_gene_by_counts=args.filter_gene_by_counts,
         filter_cell_by_counts=args.filter_cell_by_counts,
         normalize_sum=args.normalize_sum,
-
+        n_hvg_for_postp=args.n_hvg_for_postp,
         hvg_flavor=args.hvg_flavor,
         cache=args.cache,
         binning=args.binning,
@@ -195,12 +224,13 @@ def main():
         maxdropamount=args.maxdropamount,
         madoutlier=args.madoutlier,
         pct_mt_outlier=args.pct_mt_outlier,
-
+        batch_keys=args.batch_keys,
         skip_validate=args.skip_validate,
         do_postp=args.do_postp,
         additional_preprocess=additional_preprocess,
         additional_postprocess=additional_postprocess,
         keep_files=False,
+        force_preloaded=args.force_preloaded,
     )
 
     # Preprocess the dataset
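Note on the new boolean flags: `--do_postp`, `--cache`, and `--force_preloaded` are declared with `type=bool`, and argparse applies `bool()` to the raw string, so any non-empty value (including `--cache=False`) parses as True; `type=list[str]` likewise splits the raw string into individual characters rather than parsing a list. A minimal sketch of a converter that avoids the boolean pitfall; the `str2bool` helper name is illustrative, not part of the package:

    import argparse

    def str2bool(value: str) -> bool:
        # bool("False") is True, so map common spellings explicitly
        if value.lower() in ("true", "1", "yes"):
            return True
        if value.lower() in ("false", "0", "no"):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")

    parser = argparse.ArgumentParser()
    parser.add_argument("--cache", type=str2bool, default=False)
    print(parser.parse_args(["--cache=False"]).cache)  # prints: False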
scdataloader/collator.py
CHANGED
@@ -131,6 +131,7 @@ class Collator:
         tp = []
         dataset = []
         nnz_loc = []
+        is_meta = []
         for elem in batch:
             organism_id = elem[self.organism_name]
             if organism_id not in self.organism_ids:
@@ -188,12 +189,12 @@ class Collator:
                 loc = loc[self.to_subset[organism_id]]
             exprs.append(expr)
             gene_locs.append(loc)
-
+            if "is_meta" in elem:
+                is_meta.append(elem["is_meta"])
             if self.tp_name is not None:
                 tp.append(elem[self.tp_name])
             else:
                 tp.append(0)
-
             other_classes.append([elem[i] for i in self.class_names])
 
         expr = np.array(exprs)
@@ -202,6 +203,7 @@ class Collator:
         total_count = np.array(total_count)
         other_classes = np.array(other_classes)
         dataset = np.array(dataset)
+        is_meta = np.array(is_meta)
 
         # normalize counts
         if self.norm_to is not None:
@@ -229,6 +231,8 @@ class Collator:
             "tp": Tensor(tp),
             "depth": Tensor(total_count),
         }
+        if len(is_meta) > 0:
+            ret.update({"is_meta": Tensor(is_meta)})
         if len(dataset) > 0:
             ret.update({"dataset": Tensor(dataset).to(long)})
         if self.downsample is not None:
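The `is_meta` change follows an optional-key collation pattern: the flag is gathered per batch element and only attached to the returned dict when at least one element carried it. A minimal sketch of the pattern under assumed batch elements (the keys here are illustrative, not the package's full schema):

    import numpy as np
    from torch import Tensor

    batch = [
        {"expr": [1.0, 0.0], "is_meta": 0},
        {"expr": [0.0, 2.0], "is_meta": 1},
    ]
    # collect the optional flag only from elements that have it
    is_meta = [elem["is_meta"] for elem in batch if "is_meta" in elem]
    ret = {"x": Tensor(np.array([e["expr"] for e in batch]))}
    if len(is_meta) > 0:  # emit the field only when it was present
        ret["is_meta"] = Tensor(np.array(is_meta))

As in the diff, elements missing the key are skipped rather than padded, so the resulting tensor only lines up with the batch when every element carries the flag.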
scdataloader/config.py
CHANGED
@@ -110,3 +110,102 @@ COARSE_ASSAY = {
     "TruDrop": "",
     "Visium Spatial Gene Expression": "",
 }
+
+
+MAIN_HUMAN_MOUSE_DEV_STAGE_MAP = {
+    "HsapDv:0010000": [
+        "MmusDv:0000092",  # postnatal stage
+    ],
+    "HsapDv:0000258": [  # mature stage
+        "MmusDv:0000110",  # mature stage
+        "HsapDv:0000204",
+    ],
+    "HsapDv:0000227": [  # late adult stage
+        "MmusDv:0000091",  # 20 month-old stage
+        "MmusDv:0000089",  # 18 month-old stage
+    ],
+    "HsapDv:0000272": [],  # 60-79 year-old stage
+    "HsapDv:0000095": [],  # 80 year-old and over stage
+    "HsapDv:0000267": [  # middle aged stage
+        "MmusDv:0000087",  # 16 month-old stage
+        "UBERON:0018241",  # prime adult stage
+        "MmusDv:0000083",  # 12 month-old stage
+        "HsapDv:0000092",  # same
+    ],
+    "HsapDv:0000266": [  # young adult stage
+        "MmusDv:0000050",  # 6 weeks
+        "HsapDv:0000089",  # same
+        "MmusDv:0000051",  # 7 weeks
+        "MmusDv:0000052",  # 8 weeks
+        "MmusDv:0000053",  # 9 weeks
+        "MmusDv:0000054",  # 10 weeks
+        "MmusDv:0000055",  # 11 weeks
+        "MmusDv:0000056",  # 12 weeks
+        "MmusDv:0000057",  # 13 weeks
+        "MmusDv:0000058",  # 14 weeks
+        "MmusDv:0000059",  # 15 weeks
+        "MmusDv:0000061",  # early adult stage
+        "MmusDv:0000062",  # 2 month-old stage
+        "MmusDv:0000063",  # 3 month-old stage
+        "MmusDv:0000064",  # 4 month-old stage
+        "MmusDv:0000065",  # 16 weeks
+        "MmusDv:0000066",  # 17 weeks
+        "MmusDv:0000067",  # 18 weeks
+        "MmusDv:0000068",  # 19 weeks
+        "MmusDv:0000070",  # 20 weeks
+        "MmusDv:0000071",  # 21 weeks
+        "MmusDv:0000072",  # 22 weeks
+        "MmusDv:0000073",  # 23 weeks
+        "MmusDv:0000074",  # 24 weeks
+        "MmusDv:0000077",  # 6 month-old stage
+        "MmusDv:0000079",  # 8 month-old stage
+        "MmusDv:0000098",  # 25 weeks
+        "MmusDv:0000099",  # 26 weeks
+        "MmusDv:0000102",  # 29 weeks
+    ],
+    "HsapDv:0000265": [],  # child stage (1-4 yo)
+    "HsapDv:0000271": [  # juvenile stage (5-14 yo)
+        "MmusDv:0000048",  # 4 weeks
+        "MmusDv:0000049",  # 5 weeks
+    ],
+    "HsapDv:0000260": [  # infant stage
+        "MmusDv:0000046",  # 2 weeks
+        "MmusDv:0000045",  # 1 week
+        "MmusDv:0000047",  # 3 weeks
+        "HsapDv:0000083",
+    ],
+    "HsapDv:0000262": [  # newborn stage (0-28 days)
+        "MmusDv:0000036",  # Theiler stage 27
+        "MmusDv:0000037",  # Theiler stage 28
+        "MmusDv:0000113",  # 4-7 days
+    ],
+    "HsapDv:0000007": [],  # Carnegie stage 03
+    "HsapDv:0000008": [],  # Carnegie stage 04
+    "HsapDv:0000009": [],  # Carnegie stage 05
+    "HsapDv:0000003": [],  # Carnegie stage 01
+    "HsapDv:0000005": [],  # Carnegie stage 02
+    "HsapDv:0000010": [],  # gastrula stage
+    "HsapDv:0000012": [],  # neurula stage
+    "HsapDv:0000015": [  # organogenesis stage
+        "MmusDv:0000019",  # Theiler stage 13
+        "MmusDv:0000020",  # Theiler stage 12
+        "MmusDv:0000021",  # Theiler stage 14
+        "MmusDv:0000022",  # Theiler stage 15
+        "MmusDv:0000023",  # Theiler stage 16
+        "MmusDv:0000024",  # Theiler stage 17
+        "MmusDv:0000025",  # Theiler stage 18
+        "MmusDv:0000026",  # Theiler stage 19
+        "MmusDv:0000027",  # Theiler stage 20
+        "MmusDv:0000028",  # Theiler stage 21
+        "MmusDv:0000029",  # Theiler stage 22
+    ],
+    "HsapDv:0000037": [  # fetal stage
+        "MmusDv:0000033",  # Theiler stage 24
+        "MmusDv:0000034",  # Theiler stage 25
+        "MmusDv:0000035",  # Theiler stage 26
+        "MmusDv:0000032",  # Theiler stage 23
+    ],
+    "unknown": [
+        "MmusDv:0000041",  # unknown
+    ],
+}
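The new `MAIN_HUMAN_MOUSE_DEV_STAGE_MAP` keys coarse human developmental stages (HsapDv terms) to lists of equivalent mouse (MmusDv), human, and UBERON stage terms. The diff does not add a reverse lookup; a minimal sketch of inverting the dict to bucket a mouse term, assuming it is imported from `scdataloader.config` where it is added above (`stage_of` is an illustrative name):

    from scdataloader.config import MAIN_HUMAN_MOUSE_DEV_STAGE_MAP

    # invert child term -> coarse human stage
    stage_of = {
        child: parent
        for parent, children in MAIN_HUMAN_MOUSE_DEV_STAGE_MAP.items()
        for child in children
    }
    print(stage_of.get("MmusDv:0000063", "unknown"))  # -> HsapDv:0000266 (young adult stage)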
scdataloader/data.py
CHANGED
@@ -10,8 +10,7 @@ import lamindb as ln
 import numpy as np
 import pandas as pd
 from anndata import AnnData
-from 
-from lamindb.core._mapped_collection import _Connect
+from .mapped import MappedCollection, _Connect
 from lamindb.core.storage._anndata_accessor import _safer_read_index
 from scipy.sparse import issparse
 from torch.utils.data import Dataset as torchDataset
@@ -43,7 +42,7 @@ class Dataset(torchDataset):
         organisms (list[str]): list of organisms to load
             (for now only validates the the genes map to this organism)
         obs (list[str]): list of observations to load from the Collection
-
+        clss_to_predict (list[str]): list of observations to encode
         join_vars (flag): join variables @see :meth:`~lamindb.Dataset.mapped`.
         hierarchical_clss: list of observations to map to a hierarchy using lamin's bionty
     """
@@ -53,37 +52,23 @@ class Dataset(torchDataset):
     organisms: Optional[Union[list[str], str]] = field(
         default_factory=["NCBITaxon:9606", "NCBITaxon:10090"]
     )
-    obs: Optional[list[str]] = field(
-        default_factory=[
-            "self_reported_ethnicity_ontology_term_id",
-            "assay_ontology_term_id",
-            "development_stage_ontology_term_id",
-            "disease_ontology_term_id",
-            "cell_type_ontology_term_id",
-            "tissue_ontology_term_id",
-            "sex_ontology_term_id",
-            #'dataset_id',
-            #'cell_culture',
-            # "dpt_group",
-            # "heat_diff",
-            # "nnz",
-        ]
-    )
     # set of obs to prepare for prediction (encode)
-
+    clss_to_predict: Optional[list[str]] = field(default_factory=list)
     # set of obs that need to be hierarchically prepared
     hierarchical_clss: Optional[list[str]] = field(default_factory=list)
     join_vars: Literal["inner", "outer"] | None = None
+    metacell_mode: float = 0.0
 
     def __post_init__(self):
         self.mapped_dataset = mapped(
             self.lamin_dataset,
-            obs_keys=self.
+            obs_keys=list(set(self.hierarchical_clss + self.clss_to_predict)),
             join=self.join_vars,
-            encode_labels=self.
+            encode_labels=self.clss_to_predict,
             unknown_label="unknown",
             stream=True,
             parallel=True,
+            metacell_mode=self.metacell_mode,
         )
         print(
             "won't do any check but we recommend to have your dataset coming from local storage"
@@ -93,12 +78,12 @@ class Dataset(torchDataset):
         # generate tree from ontologies
         if len(self.hierarchical_clss) > 0:
             self.define_hierarchies(self.hierarchical_clss)
-        if len(self.
-            for clss in self.
+        if len(self.clss_to_predict) > 0:
+            for clss in self.clss_to_predict:
                 if clss not in self.hierarchical_clss:
                     # otherwise it's already been done
-                    self.class_topred[clss] = 
-                        clss
+                    self.class_topred[clss] = set(
+                        self.mapped_dataset.get_merged_categories(clss)
                     )
                     if (
                         self.mapped_dataset.unknown_label
@@ -143,8 +128,7 @@ class Dataset(torchDataset):
             + "dataset contains:\n"
             + "  {} cells\n".format(self.mapped_dataset.__len__())
             + "  {} genes\n".format(self.genedf.shape[0])
-            + "  {}
-            + "  {} clss_to_pred\n".format(len(self.clss_to_pred))
+            + "  {} clss_to_predict\n".format(len(self.clss_to_predict))
             + "  {} hierarchical_clss\n".format(len(self.hierarchical_clss))
             + "  {} organisms\n".format(len(self.organisms))
             + (
@@ -154,9 +138,16 @@ class Dataset(torchDataset):
                 if len(self.class_topred) > 0
                 else ""
             )
+            + "  {} metacell_mode\n".format(self.metacell_mode)
         )
 
-    def get_label_weights(
+    def get_label_weights(
+        self,
+        obs_keys: str | list[str],
+        scaler: int = 10,
+        return_categories=False,
+        bypass_label=["neuron"],
+    ):
         """Get all weights for the given label keys."""
         if isinstance(obs_keys, str):
             obs_keys = [obs_keys]
@@ -167,16 +158,24 @@ class Dataset(torchDataset):
             )
             labels_list.append(labels_to_str)
         if len(labels_list) > 1:
-            labels = 
+            labels = ["___".join(labels_obs) for labels_obs in zip(*labels_list)]
         else:
             labels = labels_list[0]
 
         counter = Counter(labels)  # type: ignore
-
-
-
-
-
+        if return_categories:
+            rn = {n: i for i, n in enumerate(counter.keys())}
+            labels = np.array([rn[label] for label in labels])
+            counter = np.array(list(counter.values()))
+            weights = scaler / (counter + scaler)
+            return weights, labels
+        else:
+            counts = np.array([counter[label] for label in labels])
+            if scaler is None:
+                weights = 1.0 / counts
+            else:
+                weights = scaler / (counts + scaler)
+            return weights
 
     def get_unseen_mapped_dataset_elements(self, idx: int):
         """
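The rewritten `get_label_weights` above replaces plain inverse-frequency weighting with a smoothed form, `scaler / (count + scaler)`: with the default `scaler=10`, a label seen 10 times gets weight 0.5 while a label seen 1000 times gets roughly 0.01, so rare classes are upweighted without the weights exploding. A short worked sketch of the math on made-up labels:

    from collections import Counter

    import numpy as np

    labels = ["T cell"] * 90 + ["B cell"] * 10
    counter = Counter(labels)
    counts = np.array([counter[label] for label in labels])
    weights = 10 / (counts + 10)  # B cell samples get 0.5, T cell samples get 0.1
    # with scaler=None the diff falls back to plain inverse frequency: 1.0 / counts

Per-sample weights like these are the shape expected by torch.utils.data.WeightedRandomSampler, which is a natural consumer for balancing a dataloader over rare cell types.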
@@ -209,6 +208,8 @@ class Dataset(torchDataset):
                 "tissue_ontology_term_id",
                 "disease_ontology_term_id",
                 "development_stage_ontology_term_id",
+                "simplified_dev_stage",
+                "age_group",
                 "assay_ontology_term_id",
                 "self_reported_ethnicity_ontology_term_id",
             ]:
@@ -235,7 +236,11 @@ class Dataset(torchDataset):
                 .df(include=["parents__ontology_id"])
                 .set_index("ontology_id")
             )
-        elif clss
+        elif clss in [
+            "development_stage_ontology_term_id",
+            "simplified_dev_stage",
+            "age_group",
+        ]:
             parentdf = (
                 bt.DevelopmentalStage.filter()
                 .df(include=["parents__ontology_id"])
@@ -268,7 +273,7 @@ class Dataset(torchDataset):
             if len(j) == 0:
                 groupings.pop(i)
         self.labels_groupings[clss] = groupings
-        if clss in self.
+        if clss in self.clss_to_predict:
             # if we have added new clss, we need to update the encoder with them too.
             mlength = len(self.mapped_dataset.encoders[clss])
@@ -362,6 +367,7 @@ def mapped(
     dtype: str | None = None,
     stream: bool = False,
     is_run_input: bool | None = None,
+    metacell_mode: bool = False,
 ) -> MappedCollection:
     path_list = []
     for artifact in dataset.artifacts.all():
@@ -384,5 +390,6 @@ def mapped(
         cache_categories=cache_categories,
         parallel=parallel,
         dtype=dtype,
+        metacell_mode=metacell_mode,
     )
     return ds
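Taken together, the data.py changes thread `metacell_mode` from the `Dataset` dataclass into `mapped()` and on to the vendored `MappedCollection` (note the diff types it as a float on `Dataset` but as `bool` in `mapped()`'s signature). A hedged sketch of wiring it up, assuming `col` is an existing lamindb collection defined elsewhere:

    from scdataloader.data import Dataset

    ds = Dataset(
        lamin_dataset=col,  # a lamindb collection, assumed defined elsewhere
        clss_to_predict=["cell_type_ontology_term_id"],
        metacell_mode=0.3,  # forwarded to mapped(..., metacell_mode=...)
    )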