scdataloader 0.0.4__py3-none-any.whl → 1.0.5__py3-none-any.whl

This diff shows the content of publicly available package versions as released to their public registries; it is provided for informational purposes only.
scdataloader/VERSION CHANGED
@@ -1 +1 @@
- 0.7.0
+ 1.0.5
scdataloader/__init__.py CHANGED
@@ -1,4 +1,4 @@
- from .data import Dataset
+ from .data import Dataset, SimpleAnnDataset
  from .datamodule import DataModule
  from .preprocess import Preprocessor
- from .collator import *
+ from .collator import Collator
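With the wildcard import replaced by explicit names, the package's public surface is now fully spelled out. A minimal sketch of the resulting imports (assuming an installed package):

from scdataloader import Collator, DataModule, Preprocessor
from scdataloader.data import Dataset, SimpleAnnDataset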
scdataloader/__main__.py CHANGED
@@ -10,6 +10,9 @@ from typing import Optional, Union

  # scdataloader --instance="laminlabs/cellxgene" --name="cellxgene-census" --version="2023-12-15" --description="preprocessed for scprint" --new_name="scprint main" --start_at=39
  def main():
+     """
+     main function to preprocess datasets in a given lamindb collection.
+     """
      parser = argparse.ArgumentParser(
          description="Preprocess datasets in a given lamindb collection."
      )
scdataloader/collator.py CHANGED
@@ -1,26 +1,27 @@
  import numpy as np
- from .utils import load_genes
+ from .utils import load_genes, downsample_profile
  from torch import Tensor, long
-
- # class SimpleCollator:
+ from typing import Optional


  class Collator:
      def __init__(
          self,
-         organisms: list,
-         how="all",
-         org_to_id: dict = None,
-         valid_genes: list = [],
-         max_len=2000,
-         add_zero_genes=0,
-         logp1=False,
-         norm_to=None,
-         n_bins=0,
-         tp_name=None,
-         organism_name="organism_ontology_term_id",
-         class_names=[],
-         genelist=[],
+         organisms: list[str],
+         how: str = "all",
+         org_to_id: dict[str, int] = None,
+         valid_genes: list[str] = [],
+         max_len: int = 2000,
+         add_zero_genes: int = 0,
+         logp1: bool = False,
+         norm_to: Optional[float] = None,
+         n_bins: int = 0,
+         tp_name: Optional[str] = None,
+         organism_name: str = "organism_ontology_term_id",
+         class_names: list[str] = [],
+         genelist: list[str] = [],
+         downsample: Optional[float] = None,  # don't use it for training!
+         save_output: bool = False,
      ):
          """
          This class is responsible for collating data for the scPRINT model. It handles the
@@ -44,38 +45,57 @@ class Collator:
          org_to_id (dict): Dictionary mapping organisms to their respective IDs.
          valid_genes (list, optional): List of genes from the datasets, to be considered. Defaults to [].
              it will drop any other genes from the input expression data (useful when your model only works on some genes)
-         max_len (int, optional): Maximum number of genes to use (for random expr and most expr). Defaults to 2000.
+         max_len (int, optional): Total number of genes to use (for random expr and most expr). Defaults to 2000.
          n_bins (int, optional): Number of bins for binning the data. Defaults to 0, meaning no binning of expression.
          add_zero_genes (int, optional): Number of additional unexpressed genes to add to the input data. Defaults to 0.
          logp1 (bool, optional): If True, logp1 normalization is applied. Defaults to False.
-         norm_to (str, optional): Normalization method to be applied. Defaults to None.
+         norm_to (float, optional): Rescaling value of the normalization to be applied. Defaults to None.
+         organism_name (str, optional): Name of the organism ontology term id. Defaults to "organism_ontology_term_id".
+         tp_name (str, optional): Name of the heat diff. Defaults to None.
+         class_names (list, optional): List of other classes to be considered. Defaults to [].
+         genelist (list, optional): List of genes to be considered. Defaults to [].
+             If [], all genes will be considered.
+         downsample (float, optional): Fraction by which to downsample each expression profile. Defaults to None.
+             This is usually done by the scPRINT model during training, but this option allows you to do it directly from the collator.
+         save_output (bool, optional): If True, saves the output to a file. Defaults to False.
+             This is mainly for debugging purposes.
          """
          self.organisms = organisms
+         self.genedf = load_genes(organisms)
          self.max_len = max_len
          self.n_bins = n_bins
          self.add_zero_genes = add_zero_genes
          self.logp1 = logp1
          self.norm_to = norm_to
-         self.org_to_id = org_to_id
          self.how = how
-         self.organism_ids = (
-             set([org_to_id[k] for k in organisms])
-             if org_to_id is not None
-             else set(organisms)
-         )
          if self.how == "some":
              assert len(genelist) > 0, "if how is some, genelist must be provided"
          self.organism_name = organism_name
          self.tp_name = tp_name
          self.class_names = class_names
-
+         self.save_output = save_output
          self.start_idx = {}
          self.accepted_genes = {}
-         self.genedf = load_genes(organisms)
+         self.downsample = downsample
          self.to_subset = {}
-         for organism in set(self.genedf.organism):
+         self._setup(org_to_id, valid_genes, genelist)
+
+     def _setup(self, org_to_id=None, valid_genes=[], genelist=[]):
+         self.org_to_id = org_to_id
+         self.to_subset = {}
+         self.accepted_genes = {}
+         self.start_idx = {}
+         self.organism_ids = (
+             set([org_to_id[k] for k in self.organisms])
+             if org_to_id is not None
+             else set(self.organisms)
+         )
+         for organism in self.organisms:
              ogenedf = self.genedf[self.genedf.organism == organism]
-             tot = self.genedf[self.genedf.index.isin(valid_genes)]
+             if len(valid_genes) > 0:
+                 tot = self.genedf[self.genedf.index.isin(valid_genes)]
+             else:
+                 tot = self.genedf
              org = org_to_id[organism] if org_to_id is not None else organism
              self.start_idx.update({org: np.where(tot.organism == organism)[0][0]})
              if len(valid_genes) > 0:
@@ -84,14 +104,14 @@ class Collator:
                  df = ogenedf[ogenedf.index.isin(valid_genes)]
                  self.to_subset.update({org: df.index.isin(genelist)})

-     def __call__(self, batch):
+     def __call__(self, batch) -> dict[str, Tensor]:
          """
          __call__ applies the collator to a minibatch of data

          Args:
              batch (list[dict[str: array]]): List of dicts of arrays containing gene expression data,
                  one dict per sample, each containing the elements:
-                 elem["x"]: gene expression
+                 elem["X"]: gene expression
                  elem["organism_name"]: organism ontology term id
                  elem["tp_name"]: heat diff
                  elem["class_names.."]: other classes
@@ -113,9 +133,9 @@ class Collator:
              organism_id = elem[self.organism_name]
              if organism_id not in self.organism_ids:
                  continue
-             if "dataset" in elem:
-                 dataset.append(elem["dataset"])
-             expr = np.array(elem["x"])
+             if "_storage_idx" in elem:
+                 dataset.append(elem["_storage_idx"])
+             expr = np.array(elem["X"])
              total_count.append(expr.sum())
              if len(self.accepted_genes) > 0:
                  expr = expr[self.accepted_genes[organism_id]]
@@ -206,72 +226,17 @@ class Collator:
          }
          if len(dataset) > 0:
              ret.update({"dataset": Tensor(dataset).to(long)})
+         if self.downsample is not None:
+             ret["x"] = downsample_profile(ret["x"], self.downsample)
+         if self.save_output:
+             with open("collator_output.txt", "a") as f:
+                 np.savetxt(f, ret["x"].numpy())
          return ret


- class AnnDataCollator(Collator):
-     def __init__(self, *args, **kwargs):
-         """
-         AnnDataCollator Collator to use if working with AnnData's experimental dataloader (it is very slow!!!)
-
-         Args:
-             @see Collator
-         """
-         super().__init__(*args, **kwargs)
-
-     def __call__(self, batch):
-         exprs = []
-         total_count = []
-         other_classes = []
-         gene_locs = []
-         tp = []
-         for elem in batch:
-             organism_id = elem.obs[self.organism_name]
-             if organism_id.item() not in self.organism_ids:
-                 print(organism_id)
-             expr = np.array(elem.X[0])
-
-             total_count.append(expr.sum())
-             if len(self.accepted_genes) > 0:
-                 expr = expr[self.accepted_genes[organism_id]]
-             if self.how == "most expr":
-                 loc = np.argsort(expr)[-(self.max_len) :][::-1]
-             elif self.how == "random expr":
-                 nnz_loc = np.where(expr > 0)[0]
-                 loc = nnz_loc[
-                     np.random.choice(len(nnz_loc), self.max_len, replace=False)
-                 ]
-             else:
-                 raise ValueError("how must be either most expr or random expr")
-             if self.add_zero_genes > 0:
-                 zero_loc = np.where(expr == 0)[0]
-                 zero_loc = [
-                     np.random.choice(len(zero_loc), self.add_zero_genes, replace=False)
-                 ]
-                 loc = np.concatenate((loc, zero_loc), axis=None)
-             exprs.append(expr[loc])
-             gene_locs.append(loc + self.start_idx[organism_id.item()])
-
-             if self.tp_name is not None:
-                 tp.append(elem.obs[self.tp_name])
-             else:
-                 tp.append(0)
-
-             other_classes.append([elem.obs[i].values[0] for i in self.class_names])
-
-         expr = np.array(exprs)
-         tp = np.array(tp)
-         gene_locs = np.array(gene_locs)
-         total_count = np.array(total_count)
-         other_classes = np.array(other_classes)
-         return {
-             "x": Tensor(expr),
-             "genes": Tensor(gene_locs).int(),
-             "depth": Tensor(total_count),
-             "class": Tensor(other_classes),
-         }
-
-
+ #############
+ #### WIP ####
+ #############
  class GeneformerCollator(Collator):
      def __init__(self, *args, gene_norm_list: list, **kwargs):
          """
scdataloader/config.py CHANGED
@@ -1,3 +1,9 @@
+ """
+ Configuration file for scDataLoader
+
+ Missing labels are added to the dataset to build a more complete hierarchical tree
+ """
+
  LABELS_TOADD = {
      "assay_ontology_term_id": {
          "10x transcription profiling": "EFO:0030003",
scdataloader/data.py CHANGED
@@ -1,14 +1,18 @@
  from dataclasses import dataclass, field

  import lamindb as ln
+
+ # ln.connect("scprint")
+
  import bionty as bt
  import pandas as pd
  from torch.utils.data import Dataset as torchDataset
  from typing import Union, Optional, Literal
- from scdataloader import mapped
+ from scdataloader.mapped import MappedCollection
  import warnings

  from anndata import AnnData
+ from scipy.sparse import issparse

  from scdataloader.utils import get_ancestry_mapping, load_genes

@@ -58,30 +62,31 @@ class Dataset(torchDataset):
              "sex_ontology_term_id",
              #'dataset_id',
              #'cell_culture',
-             #"dpt_group",
-             #"heat_diff",
-             #"nnz",
+             # "dpt_group",
+             # "heat_diff",
+             # "nnz",
          ]
      )
      # set of obs to prepare for prediction (encode)
      clss_to_pred: Optional[list[str]] = field(default_factory=list)
      # set of obs that need to be hierarchically prepared
      hierarchical_clss: Optional[list[str]] = field(default_factory=list)
-     join_vars: Optional[Literal["auto", "inner", "None"]] = "None"
+     join_vars: Literal["inner", "outer"] | None = None

      def __post_init__(self):
-         self.mapped_dataset = mapped.mapped(
+         self.mapped_dataset = mapped(
              self.lamin_dataset,
-             label_keys=self.obs,
+             obs_keys=self.obs,
+             join=self.join_vars,
              encode_labels=self.clss_to_pred,
+             unknown_label="unknown",
              stream=True,
              parallel=True,
-             join_vars=self.join_vars,
          )
          print(
              "won't do any check but we recommend to have your dataset coming from local storage"
          )
-         self.class_groupings = {}
+         self.labels_groupings = {}
          self.class_topred = {}
          # generate tree from ontologies
          if len(self.hierarchical_clss) > 0:
@@ -93,24 +98,19 @@ class Dataset(torchDataset):
                  self.class_topred[clss] = self.mapped_dataset.get_merged_categories(
                      clss
                  )
-                 update = {}
-                 c = 0
-                 for k, v in self.mapped_dataset.encoders[clss].items():
-                     if k == self.mapped_dataset.unknown_class:
-                         update.update({k: v})
-                         c += 1
-                         self.class_topred[clss] -= set([k])
-                     else:
-                         update.update({k: v - c})
-                 self.mapped_dataset.encoders[clss] = update
+                 if (
+                     self.mapped_dataset.unknown_label
+                     in self.mapped_dataset.encoders[clss].keys()
+                 ):
+                     self.class_topred[clss] -= set(
+                         [self.mapped_dataset.unknown_label]
+                     )

          if self.genedf is None:
              self.genedf = load_genes(self.organisms)

          self.genedf.columns = self.genedf.columns.astype(str)
-         for organism in self.organisms:
-             ogenedf = self.genedf[self.genedf.organism == organism]
-             self.mapped_dataset._check_aligned_vars(ogenedf.index.tolist())
+         self.mapped_dataset._check_aligned_vars(self.genedf.index.tolist())

      def __len__(self, **kwargs):
          return self.mapped_dataset.__len__(**kwargs)
@@ -121,19 +121,6 @@ class Dataset(torchDataset):

      def __getitem__(self, *args, **kwargs):
          item = self.mapped_dataset.__getitem__(*args, **kwargs)
-         # import pdb
-
-         # pdb.set_trace()
-         # item.update(
-         #     {"unseen_genes": self.get_unseen_mapped_dataset_elements(*args, **kwargs)}
-         # )
-         # ret = {}
-         # ret["count"] = item[0]
-         # for i, val in enumerate(self.obs):
-         #     ret[val] = item[1][i]
-         ## mark unseen genes with a flag
-         ## send the associated
-         # print(item[0].shape)
          return item

      def __repr__(self):
@@ -148,7 +135,6 @@ class Dataset(torchDataset):
              + " {} labels\n".format(len(self.obs))
              + " {} clss_to_pred\n".format(len(self.clss_to_pred))
              + " {} hierarchical_clss\n".format(len(self.hierarchical_clss))
-             + " {} join_vars\n".format(len(self.join_vars))
              + " {} organisms\n".format(len(self.organisms))
              + (
                  "dataset contains {} classes to predict\n".format(
@@ -160,17 +146,41 @@ class Dataset(torchDataset):
          )

      def get_label_weights(self, *args, **kwargs):
+         """
+         get_label_weights is a wrapper around mappedDataset.get_label_weights
+
+         Returns:
+             dict: dictionary of weights for each label
+         """
          return self.mapped_dataset.get_label_weights(*args, **kwargs)

      def get_unseen_mapped_dataset_elements(self, idx: int):
+         """
+         get_unseen_mapped_dataset_elements is a wrapper around mappedDataset.get_unseen_mapped_dataset_elements
+
+         Args:
+             idx (int): index of the element to get
+
+         Returns:
+             list[str]: list of unseen genes
+         """
          return [str(i)[2:-1] for i in self.mapped_dataset.uns(idx, "unseen_genes")]

-     def define_hierarchies(self, labels: list[str]):
+     def define_hierarchies(self, clsses: list[str]):
+         """
+         define_hierarchies is a method to define the hierarchies for the classes to predict
+
+         Args:
+             clsses (list[str]): list of classes to predict
+
+         Raises:
+             ValueError: if the class is not in the accepted classes
+         """
          # TODO: use all possible hierarchies instead of just the ones for which we have a sample annotated with
-         self.class_groupings = {}
+         self.labels_groupings = {}
          self.class_topred = {}
-         for label in labels:
-             if label not in [
+         for clss in clsses:
+             if clss not in [
                  "cell_type_ontology_term_id",
                  "tissue_ontology_term_id",
                  "disease_ontology_term_id",
@@ -179,41 +189,41 @@ class Dataset(torchDataset):
                  "self_reported_ethnicity_ontology_term_id",
              ]:
                  raise ValueError(
-                     "label {} not in accepted labels, for now only supported from bionty sources".format(
-                         label
+                     "class {} not in accepted classes, for now only supported from bionty sources".format(
+                         clss
                      )
                  )
-             elif label == "cell_type_ontology_term_id":
+             elif clss == "cell_type_ontology_term_id":
                  parentdf = (
                      bt.CellType.filter()
                      .df(include=["parents__ontology_id"])
                      .set_index("ontology_id")
                  )
-             elif label == "tissue_ontology_term_id":
+             elif clss == "tissue_ontology_term_id":
                  parentdf = (
                      bt.Tissue.filter()
                      .df(include=["parents__ontology_id"])
                      .set_index("ontology_id")
                  )
-             elif label == "disease_ontology_term_id":
+             elif clss == "disease_ontology_term_id":
                  parentdf = (
                      bt.Disease.filter()
                      .df(include=["parents__ontology_id"])
                      .set_index("ontology_id")
                  )
-             elif label == "development_stage_ontology_term_id":
+             elif clss == "development_stage_ontology_term_id":
                  parentdf = (
                      bt.DevelopmentalStage.filter()
                      .df(include=["parents__ontology_id"])
                      .set_index("ontology_id")
                  )
-             elif label == "assay_ontology_term_id":
+             elif clss == "assay_ontology_term_id":
                  parentdf = (
                      bt.ExperimentalFactor.filter()
                      .df(include=["parents__ontology_id"])
                      .set_index("ontology_id")
                  )
-             elif label == "self_reported_ethnicity_ontology_term_id":
+             elif clss == "self_reported_ethnicity_ontology_term_id":
                  parentdf = (
                      bt.Ethnicity.filter()
                      .df(include=["parents__ontology_id"])
@@ -222,65 +232,58 @@ class Dataset(torchDataset):

              else:
                  raise ValueError(
-                     "label {} not in accepted labels, for now only supported from bionty sources".format(
-                         label
+                     "class {} not in accepted classes, for now only supported from bionty sources".format(
+                         clss
                      )
                  )
-             cats = self.mapped_dataset.get_merged_categories(label)
-             addition = set(LABELS_TOADD.get(label, {}).values())
+             cats = self.mapped_dataset.get_merged_categories(clss)
+             addition = set(LABELS_TOADD.get(clss, {}).values())
              cats |= addition
-             # import pdb
-
-             # pdb.set_trace()
-             groupings, _, lclass = get_ancestry_mapping(cats, parentdf)
+             groupings, _, leaf_labels = get_ancestry_mapping(cats, parentdf)
              for i, j in groupings.items():
                  if len(j) == 0:
                      groupings.pop(i)
-             self.class_groupings[label] = groupings
-             if label in self.clss_to_pred:
-                 # if we have added new labels, we need to update the encoder with them too.
-                 mlength = len(self.mapped_dataset.encoders[label])
+             self.labels_groupings[clss] = groupings
+             if clss in self.clss_to_pred:
+                 # if we have added new clss, we need to update the encoder with them too.
+                 mlength = len(self.mapped_dataset.encoders[clss])
+
                  mlength -= (
                      1
-                     if self.mapped_dataset.unknown_class
-                     in self.mapped_dataset.encoders[label].keys()
+                     if self.mapped_dataset.unknown_label
+                     in self.mapped_dataset.encoders[clss].keys()
                      else 0
                  )

                  for i, v in enumerate(
-                     addition - set(self.mapped_dataset.encoders[label].keys())
+                     addition - set(self.mapped_dataset.encoders[clss].keys())
                  ):
-                     self.mapped_dataset.encoders[label].update({v: mlength + i})
+                     self.mapped_dataset.encoders[clss].update({v: mlength + i})
                  # we need to change the ordering so that the things that can't be predicted appear afterward

-                 self.class_topred[label] = lclass
+                 self.class_topred[clss] = leaf_labels
                  c = 0
-                 d = 0
                  update = {}
-                 mlength = len(lclass)
-                 # import pdb
-
-                 # pdb.set_trace()
+                 mlength = len(leaf_labels)
                  mlength -= (
                      1
-                     if self.mapped_dataset.unknown_class
-                     in self.mapped_dataset.encoders[label].keys()
+                     if self.mapped_dataset.unknown_label
+                     in self.mapped_dataset.encoders[clss].keys()
                      else 0
                  )
-                 for k, v in self.mapped_dataset.encoders[label].items():
-                     if k in self.class_groupings[label].keys():
+                 for k, v in self.mapped_dataset.encoders[clss].items():
+                     if k in self.labels_groupings[clss].keys():
                          update.update({k: mlength + c})
                          c += 1
-                     elif k == self.mapped_dataset.unknown_class:
+                     elif k == self.mapped_dataset.unknown_label:
                          update.update({k: v})
-                         d += 1
-                         self.class_topred[label] -= set([k])
+                         self.class_topred[clss] -= set([k])
                      else:
-                         update.update({k: (v - c) - d})
-                 self.mapped_dataset.encoders[label] = update
+                         update.update({k: v - c})
+                 self.mapped_dataset.encoders[clss] = update


- class SimpleAnnDataset:
+ class SimpleAnnDataset(torchDataset):
      def __init__(
          self,
          adata: AnnData,
@@ -297,20 +300,65 @@ class SimpleAnnDataset:
              obs_to_output (list[str]): list of observations to output from anndata.obs
              layer (str): layer of the anndata to use
          """
-         self.adata = adata
-         self.obs_to_output = obs_to_output
-         self.layer = layer
+         self.adataX = adata.layers[layer] if layer is not None else adata.X
+         self.adataX = self.adataX.toarray() if issparse(self.adataX) else self.adataX
+         self.obs_to_output = adata.obs[obs_to_output]

      def __len__(self):
-         return self.adata.shape[0]
+         return self.adataX.shape[0]
+
+     def __iter__(self):
+         for idx in range(self.adataX.shape[0]):
+             with warnings.catch_warnings():
+                 warnings.filterwarnings("ignore", category=DeprecationWarning)
+                 out = {"X": self.adataX[idx].reshape(-1)}
+                 out.update(
+                     {name: val for name, val in self.obs_to_output.iloc[idx].items()}
+                 )
+             yield out

      def __getitem__(self, idx):
          with warnings.catch_warnings():
              warnings.filterwarnings("ignore", category=DeprecationWarning)
-             if self.layer is not None:
-                 out = {"x": self.adata.layers[self.layer][idx].toarray().reshape(-1)}
-             else:
-                 out = {"x": self.adata.X[idx].toarray().reshape(-1)}
-             for i in self.obs_to_output:
-                 out.update({i: self.adata.obs.iloc[idx][i]})
+             out = {"X": self.adataX[idx].reshape(-1)}
+             out.update(
+                 {name: val for name, val in self.obs_to_output.iloc[idx].items()}
+             )
          return out
+
+
+ def mapped(
+     dataset,
+     obs_keys: list[str] | None = None,
+     join: Literal["inner", "outer"] | None = "inner",
+     encode_labels: bool | list[str] = True,
+     unknown_label: str | dict[str, str] | None = None,
+     cache_categories: bool = True,
+     parallel: bool = False,
+     dtype: str | None = None,
+     stream: bool = False,
+     is_run_input: bool | None = None,
+ ) -> MappedCollection:
+     path_list = []
+     for artifact in dataset.artifacts.all():
+         if artifact.suffix not in {".h5ad", ".zrad", ".zarr"}:
+             print(f"Ignoring artifact with suffix {artifact.suffix}")
+             continue
+         elif not artifact.path.exists():
+             print(f"Path does not exist for artifact with suffix {artifact.suffix}")
+             continue
+         elif not stream:
+             path_list.append(artifact.stage())
+         else:
+             path_list.append(artifact.path)
+     ds = MappedCollection(
+         path_list=path_list,
+         obs_keys=obs_keys,
+         join=join,
+         encode_labels=encode_labels,
+         unknown_label=unknown_label,
+         cache_categories=cache_categories,
+         parallel=parallel,
+         dtype=dtype,
+     )
+     return ds
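To see how these pieces fit together, a minimal end-to-end sketch (the collection name is hypothetical, and the exact lamindb lookup call depends on your lamindb version):

import lamindb as ln
from scdataloader.data import Dataset

collection = ln.Collection.filter(name="scprint main").first()  # hypothetical collection
dataset = Dataset(
    lamin_dataset=collection,
    clss_to_pred=["cell_type_ontology_term_id"],
    hierarchical_clss=["cell_type_ontology_term_id"],
)
print(dataset)  # summary: labels, classes to predict, organisms
elem = dataset[0]  # one cell: expression plus the encoded obs fields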