scdataloader 2.0.0__py3-none-any.whl → 2.0.3__py3-none-any.whl

This diff compares publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.
scdataloader/__main__.py CHANGED
@@ -1,5 +1,5 @@
  import argparse
- from typing import Optional, Union
+ from typing import List, Optional, Union

  import lamindb as ln

@@ -149,7 +149,7 @@ def main():
  )
  preprocess_parser.add_argument(
  "--batch_keys",
- type=list[str],
+ type=List[str],
  default=[
  "assay_ontology_term_id",
  "self_reported_ethnicity_ontology_term_id",
@@ -229,11 +229,11 @@ def main():
  if args.instance is not None:
  collection = (
  ln.Collection.using(instance=args.instance)
- .filter(name=args.name, version=args.version)
+ .filter(key=args.name, version=args.version)
  .first()
  )
  else:
- collection = ln.Collection.filter(name=args.name, version=args.version).first()
+ collection = ln.Collection.filter(key=args.name, version=args.version).first()

  print(
  "using the dataset ", collection, " of size ", len(collection.artifacts.all())
@@ -262,7 +262,6 @@ def main():
  additional_preprocess=additional_preprocess,
  additional_postprocess=additional_postprocess,
  keep_files=False,
- force_preloaded=args.force_preloaded,
  )

  # Preprocess the dataset
scdataloader/collator.py CHANGED
@@ -1,18 +1,20 @@
- from typing import Optional
+ from typing import List, Optional

  import numpy as np
+ import pandas as pd
  from torch import Tensor, long

+ from .preprocess import _digitize
  from .utils import load_genes


  class Collator:
  def __init__(
  self,
- organisms: list[str],
+ organisms: List[str],
  how: str = "all",
  org_to_id: dict[str, int] = None,
- valid_genes: list[str] = [],
+ valid_genes: Optional[List[str]] = None,
  max_len: int = 2000,
  add_zero_genes: int = 0,
  logp1: bool = False,
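Replacing the `valid_genes: list[str] = []` default with `Optional[List[str]] = None` removes a mutable-default hazard: Python evaluates default values once, so a default list is shared across every call and instance that omits the argument. A self-contained illustration:

```python
# Why a None default is safer than a mutable [] default: the empty list is
# created once at definition time and shared by every call that omits it.
from typing import List, Optional

def bad(genes: List[str] = []) -> List[str]:
    genes.append("ENSG00000141510")  # mutates the one shared default list
    return genes

def good(genes: Optional[List[str]] = None) -> List[str]:
    genes = [] if genes is None else genes  # fresh list on each call
    genes.append("ENSG00000141510")
    return genes

print(bad())   # ['ENSG00000141510']
print(bad())   # ['ENSG00000141510', 'ENSG00000141510'] — state leaked
print(good())  # ['ENSG00000141510']
print(good())  # ['ENSG00000141510'] — calls stay independent
```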
@@ -20,8 +22,9 @@ class Collator:
  n_bins: int = 0,
  tp_name: Optional[str] = None,
  organism_name: str = "organism_ontology_term_id",
- class_names: list[str] = [],
- genelist: list[str] = [],
+ class_names: List[str] = [],
+ genelist: List[str] = [],
+ genedf: Optional[pd.DataFrame] = None,
  ):
  """
  This class is responsible for collating data for the scPRINT model. It handles the
@@ -71,21 +74,22 @@ class Collator:
  self.start_idx = {}
  self.accepted_genes = {}
  self.to_subset = {}
- self._setup(None, org_to_id, valid_genes, genelist)
+ self._setup(genedf, org_to_id, valid_genes, genelist)

  def _setup(self, genedf=None, org_to_id=None, valid_genes=[], genelist=[]):
  if genedf is None:
  genedf = load_genes(self.organisms)
+ self.organism_ids = (
+ set([org_to_id[k] for k in self.organisms])
+ if org_to_id is not None
+ else set(self.organisms)
+ )
  self.org_to_id = org_to_id
  self.to_subset = {}
  self.accepted_genes = {}
  self.start_idx = {}
- self.organism_ids = (
- set([org_to_id[k] for k in self.organisms])
- if org_to_id is not None
- else set(self.organisms)
- )
- if len(valid_genes) > 0:
+
+ if valid_genes is not None:
  if len(set(valid_genes) - set(genedf.index)) > 0:
  print("Some valid genes are not in the genedf!!!")
  tot = genedf[genedf.index.isin(valid_genes)]
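Note the semantic shift in the gate: under 2.0.0 an empty `valid_genes` list meant "no filtering", whereas under 2.0.3 the branch runs whenever the argument is not None, so an explicitly passed empty list now filters the gene set down to nothing. A toy comparison:

```python
# Toy comparison of the two gates for an explicitly passed empty list.
valid_genes = []

if len(valid_genes) > 0:     # 2.0.0 gate: empty list skips filtering
    print("2.0.0: filtering")
if valid_genes is not None:  # 2.0.3 gate: empty list still enters the branch
    print("2.0.3: filtering (down to zero genes)")
```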
@@ -96,7 +100,7 @@ class Collator:
  self.start_idx.update({org: np.where(tot.organism == organism)[0][0]})

  ogenedf = genedf[genedf.organism == organism]
- if len(valid_genes) > 0:
+ if valid_genes is not None:
  self.accepted_genes.update({org: ogenedf.index.isin(valid_genes)})
  if len(genelist) > 0:
  df = ogenedf[ogenedf.index.isin(valid_genes)]
@@ -107,7 +111,7 @@ class Collator:
  __call__ applies the collator to a minibatch of data

  Args:
- batch (list[dict[str: array]]): List of dicts of arrays containing gene expression data.
+ batch (List[dict[str: array]]): List of dicts of arrays containing gene expression data.
  the first list is for the different samples, the second list is for the different elements with
  elem["X"]: gene expression
  elem["organism_name"]: organism ontology term id
@@ -115,7 +119,7 @@ class Collator:
  elem["class_names.."]: other classes

  Returns:
- list[Tensor]: List of tensors containing the collated data.
+ List[Tensor]: List of tensors containing the collated data.
  """
  # do count selection
  # get the unseen info and don't add any unseen
@@ -129,6 +133,7 @@ class Collator:
  nnz_loc = []
  is_meta = []
  knn_cells = []
+ knn_cells_info = []
  for elem in batch:
  organism_id = elem[self.organism_name]
  if organism_id not in self.organism_ids:
@@ -184,7 +189,14 @@ class Collator:
  if "knn_cells" in elem:
  # we complete with genes expressed in the knn
  # which is not a zero_loc in this context
- zero_loc = np.argsort(elem["knn_cells"].sum(0))[-ma:][::-1]
+ knn_expr = elem["knn_cells"].sum(0)
+ mask = np.ones(len(knn_expr), dtype=bool)
+ mask[loc] = False
+ available_indices = np.where(mask)[0]
+ available_knn_expr = knn_expr[available_indices]
+ sorted_indices = np.argsort(available_knn_expr)[::-1]
+ selected = min(ma, len(available_indices))
+ zero_loc = available_indices[sorted_indices[:selected]]
  else:
  zero_loc = np.where(expr == 0)[0]
  zero_loc = zero_loc[
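The rewritten branch fixes an overlap in the old one-liner, which ranked all genes by summed knn expression and could re-select genes already present in `loc`; the new code masks those indices out before ranking and caps the take at the number of remaining genes. A standalone numpy sketch of the same masking-and-ranking idea, with toy inputs:

```python
# Rank genes by summed expression across knn cells, excluding indices that
# were already selected in `loc`, then take at most `ma` of the rest.
import numpy as np

knn_cells = np.array([[0, 5, 1, 3], [0, 2, 4, 3]])  # (n_neighbors, n_genes)
loc = np.array([1])                                 # genes already selected
ma = 2                                              # extra genes to take

knn_expr = knn_cells.sum(0)                # [0, 7, 5, 6]
mask = np.ones(len(knn_expr), dtype=bool)
mask[loc] = False                          # exclude already-selected genes
available = np.where(mask)[0]              # [0, 2, 3]
order = np.argsort(knn_expr[available])[::-1]
zero_loc = available[order[: min(ma, len(available))]]
print(zero_loc)                            # [3 2] — top remaining genes
```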
@@ -208,6 +220,8 @@ class Collator:
  exprs.append(expr)
  if "knn_cells" in elem:
  knn_cells.append(elem["knn_cells"])
+ if "knn_cells_info" in elem:
+ knn_cells_info.append(elem["knn_cells_info"])
  # then we need to add the start_idx to the loc to give it the correct index
  # according to the model
  gene_locs.append(loc + self.start_idx[organism_id])
@@ -227,15 +241,46 @@ class Collator:
  dataset = np.array(dataset)
  is_meta = np.array(is_meta)
  knn_cells = np.array(knn_cells)
+ knn_cells_info = np.array(knn_cells_info)
+
  # normalize counts
  if self.norm_to is not None:
  expr = (expr * self.norm_to) / total_count[:, None]
+ # TODO: solve issue here
+ knn_cells = (knn_cells * self.norm_to) / total_count[:, None]
  if self.logp1:
  expr = np.log2(1 + expr)
+ knn_cells = np.log2(1 + knn_cells)

  # do binning of counts
- if self.n_bins:
- pass
+ if self.n_bins > 0:
+ binned_rows = []
+ bin_edges = []
+ for row in expr:
+ if row.max() == 0:
+ print(
+ "The input data contains all zero rows. Please make sure "
+ "this is expected. You can use the `filter_cell_by_counts` "
+ "arg to filter out all zero rows."
+ )
+ binned_rows.append(np.zeros_like(row, dtype=np.int64))
+ bin_edges.append(np.array([0] * self.n_bins))
+ continue
+ non_zero_ids = row.nonzero()
+ non_zero_row = row[non_zero_ids]
+ bins = np.quantile(non_zero_row, np.linspace(0, 1, self.n_bins - 1))
+ # bins = np.sort(np.unique(bins))
+ # NOTE: comment this line for now, since this will make the each category
+ # has different relative meaning across datasets
+ non_zero_digits = _digitize(non_zero_row, bins)
+ assert non_zero_digits.min() >= 1
+ assert non_zero_digits.max() <= self.n_bins - 1
+ binned_row = np.zeros_like(row, dtype=np.int64)
+ binned_row[non_zero_ids] = non_zero_digits
+ binned_rows.append(binned_row)
+ bin_edges.append(np.concatenate([[0], bins]))
+ expr = np.stack(binned_rows)
+ # expr = np.digitize(expr, bins=self.bins)

  ret = {
  "x": Tensor(expr),
@@ -248,44 +293,8 @@ class Collator:
  ret.update({"is_meta": Tensor(is_meta).int()})
  if len(knn_cells) > 0:
  ret.update({"knn_cells": Tensor(knn_cells)})
+ if len(knn_cells_info) > 0:
+ ret.update({"knn_cells_info": Tensor(knn_cells_info)})
  if len(dataset) > 0:
  ret.update({"dataset": Tensor(dataset).to(long)})
  return ret
-
-
- #############
- #### WIP ####
- #############
- class GeneformerCollator(Collator):
- def __init__(self, *args, gene_norm_list: list, **kwargs):
- """
- GeneformerCollator to finish
-
- Args:
- gene_norm_list (list): the normalization of expression through all datasets, per gene.
- """
- super().__init__(*args, **kwargs)
- self.gene_norm_list = gene_norm_list
-
- def __call__(self, batch):
- super().__call__(batch)
- # normlization per gene
-
- # tokenize the empty locations
-
-
- class scGPTCollator(Collator):
- """
- scGPTCollator to finish
- """
-
- def __call__(self, batch):
- super().__call__(batch)
- # binning
-
- # tokenize the empty locations
-
-
- class scPRINTCollator(Collator):
- def __call__(self, batch):
- super().__call__(batch)
scdataloader/data.py CHANGED
@@ -2,7 +2,7 @@ import warnings
  from collections import Counter
  from dataclasses import dataclass, field
  from functools import reduce
- from typing import Literal, Optional, Union
+ from typing import List, Literal, Optional, Union

  # ln.connect("scprint")
  import bionty as bt
@@ -38,8 +38,8 @@ class Dataset(torchDataset):
  ----
  lamin_dataset (lamindb.Dataset): lamin dataset to load
  genedf (pd.Dataframe): dataframe containing the genes to load
- obs (list[str]): list of observations to load from the Collection
- clss_to_predict (list[str]): list of observations to encode
+ obs (List[str]): list of observations to load from the Collection
+ clss_to_predict (List[str]): list of observations to encode
  join_vars (flag): join variables @see :meth:`~lamindb.Dataset.mapped`.
  hierarchical_clss: list of observations to map to a hierarchy using lamin's bionty
  metacell_mode (float, optional): The mode to use for metacell sampling. Defaults to 0.0.
@@ -51,9 +51,9 @@ class Dataset(torchDataset):
  lamin_dataset: ln.Collection
  genedf: Optional[pd.DataFrame] = None
  # set of obs to prepare for prediction (encode)
- clss_to_predict: Optional[list[str]] = field(default_factory=list)
+ clss_to_predict: Optional[List[str]] = field(default_factory=list)
  # set of obs that need to be hierarchically prepared
- hierarchical_clss: Optional[list[str]] = field(default_factory=list)
+ hierarchical_clss: Optional[List[str]] = field(default_factory=list)
  join_vars: Literal["inner", "outer"] | None = None
  metacell_mode: float = 0.0
  get_knn_cells: bool = False
@@ -61,6 +61,7 @@ class Dataset(torchDataset):
  force_recompute_indices: bool = False

  def __post_init__(self):
+ # see at the end of the file for the mapped function
  self.mapped_dataset = mapped(
  self.lamin_dataset,
  obs_keys=list(set(self.hierarchical_clss + self.clss_to_predict)),
@@ -102,10 +103,10 @@ class Dataset(torchDataset):
  "need 'organism_ontology_term_id' in the set of classes if you don't provide a genedf"
  )
  self.organisms = list(self.class_topred["organism_ontology_term_id"])
- self.organisms.sort()
  self.genedf = load_genes(self.organisms)
  else:
- self.organisms = None
+ self.organisms = self.genedf["organism"].unique().tolist()
+ self.organisms.sort()

  self.genedf.columns = self.genedf.columns.astype(str)
  # self.check_aligned_vars()
@@ -160,13 +161,11 @@ class Dataset(torchDataset):
  + " {} metacell_mode\n".format(self.metacell_mode)
  )

- def get_label_weights(
+ def get_label_cats(
  self,
- obs_keys: str | list[str],
- scaler: int = 10,
- return_categories=False,
+ obs_keys: Union[str, List[str]],
  ):
- """Get all weights for the given label keys."""
+ """Get all categories for the given label keys."""
  if isinstance(obs_keys, str):
  obs_keys = [obs_keys]
  labels = None
@@ -176,18 +175,7 @@ class Dataset(torchDataset):
  labels = labels_to_str
  else:
  labels = concat_categorical_codes([labels, labels_to_str])
- counter = Counter(labels.codes) # type: ignore
- if return_categories:
- counter = np.array(list(counter.values()))
- weights = scaler / (counter + scaler)
- return weights, np.array(labels.codes)
- else:
- counts = np.array([counter[label] for label in labels.codes])
- if scaler is None:
- weights = 1.0 / counts
- else:
- weights = scaler / (counts + scaler)
- return weights
+ return np.array(labels.codes)

  def get_unseen_mapped_dataset_elements(self, idx: int):
  """
@@ -197,16 +185,16 @@ class Dataset(torchDataset):
  idx (int): index of the element to get

  Returns:
- list[str]: list of unseen genes
+ List[str]: list of unseen genes
  """
  return [str(i)[2:-1] for i in self.mapped_dataset.uns(idx, "unseen_genes")]

- def define_hierarchies(self, clsses: list[str]):
+ def define_hierarchies(self, clsses: List[str]):
  """
  define_hierarchies is a method to define the hierarchies for the classes to predict

  Args:
- clsses (list[str]): list of classes to predict
+ clsses (List[str]): list of classes to predict

  Raises:
  ValueError: if the class is not in the accepted classes
@@ -233,19 +221,19 @@ class Dataset(torchDataset):
  elif clss == "cell_type_ontology_term_id":
  parentdf = (
  bt.CellType.filter()
- .df(include=["parents__ontology_id"])
+ .df(include=["parents__ontology_id", "ontology_id"])
  .set_index("ontology_id")
  )
  elif clss == "tissue_ontology_term_id":
  parentdf = (
  bt.Tissue.filter()
- .df(include=["parents__ontology_id"])
+ .df(include=["parents__ontology_id", "ontology_id"])
  .set_index("ontology_id")
  )
  elif clss == "disease_ontology_term_id":
  parentdf = (
  bt.Disease.filter()
- .df(include=["parents__ontology_id"])
+ .df(include=["parents__ontology_id", "ontology_id"])
  .set_index("ontology_id")
  )
  elif clss in [
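Every ontology lookup in this hunk and the next adds `"ontology_id"` to the `include` list, presumably so the column is materialized before the chained `.set_index("ontology_id")`; that reading is an inference from the code shape. A sketch of the pattern (assumes a configured lamindb/bionty instance with the relevant ontology source loaded):

```python
import bionty as bt

# Request the future index column explicitly so set_index() can find it.
parentdf = (
    bt.CellType.filter()
    .df(include=["parents__ontology_id", "ontology_id"])
    .set_index("ontology_id")
)
print(parentdf.head())
```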
@@ -255,19 +243,19 @@ class Dataset(torchDataset):
  ]:
  parentdf = (
  bt.DevelopmentalStage.filter()
- .df(include=["parents__ontology_id"])
+ .df(include=["parents__ontology_id", "ontology_id"])
  .set_index("ontology_id")
  )
  elif clss == "assay_ontology_term_id":
  parentdf = (
  bt.ExperimentalFactor.filter()
- .df(include=["parents__ontology_id"])
+ .df(include=["parents__ontology_id", "ontology_id"])
  .set_index("ontology_id")
  )
  elif clss == "self_reported_ethnicity_ontology_term_id":
  parentdf = (
  bt.Ethnicity.filter()
- .df(include=["parents__ontology_id"])
+ .df(include=["parents__ontology_id", "ontology_id"])
  .set_index("ontology_id")
  )

@@ -279,6 +267,7 @@ class Dataset(torchDataset):
  )
  cats = set(self.mapped_dataset.encoders[clss].keys())
  groupings, _, leaf_labels = get_ancestry_mapping(cats, parentdf)
+ groupings.pop(None, None)
  for i, j in groupings.items():
  if len(j) == 0:
  # that should not happen
@@ -286,6 +275,7 @@ class Dataset(torchDataset):

  pdb.set_trace()
  groupings.pop(i)
+
  self.labels_groupings[clss] = groupings
  if clss in self.clss_to_predict:
  # if we have added new clss, we need to update the encoder with them too.
@@ -331,9 +321,10 @@ class SimpleAnnDataset(torchDataset):
  def __init__(
  self,
  adata: AnnData,
- obs_to_output: Optional[list[str]] = [],
+ obs_to_output: Optional[List[str]] = [],
  layer: Optional[str] = None,
  get_knn_cells: bool = False,
+ encoder: Optional[dict[str, dict]] = None,
  ):
  """
  SimpleAnnDataset is a simple dataloader for an AnnData dataset. this is to interface nicely with the rest of
@@ -342,12 +333,14 @@ class SimpleAnnDataset(torchDataset):
  Args:
  ----
  adata (anndata.AnnData): anndata object to use
- obs_to_output (list[str]): list of observations to output from anndata.obs
+ obs_to_output (List[str]): list of observations to output from anndata.obs
  layer (str): layer of the anndata to use
  get_knn_cells (bool): whether to get the knn cells
+ encoder (dict[str, dict]): dictionary of encoders for the observations.
  """
  self.adataX = adata.layers[layer] if layer is not None else adata.X
  self.adataX = self.adataX.toarray() if issparse(self.adataX) else self.adataX
+ self.encoder = encoder if encoder is not None else {}

  self.obs_to_output = adata.obs[obs_to_output]
  self.get_knn_cells = get_knn_cells
@@ -361,23 +354,14 @@ class SimpleAnnDataset(torchDataset):

  def __iter__(self):
  for idx in range(self.adataX.shape[0]):
- out = {"X": self.adataX[idx].reshape(-1)}
- out.update(
- {name: val for name, val in self.obs_to_output.iloc[idx].items()}
- )
- if self.get_knn_cells:
- distances = self.distances[idx].toarray()[0]
- nn_idx = np.argsort(-1 / (distances - 1e-6))[:6]
- out["knn_cells"] = np.array(
- [self.adataX[i].reshape(-1) for i in nn_idx],
- dtype=int,
- )
- out["distances"] = distances[nn_idx]
+ out = self.__getitem__(idx)
  yield out

  def __getitem__(self, idx):
  out = {"X": self.adataX[idx].reshape(-1)}
- out.update({name: val for name, val in self.obs_to_output.iloc[idx].items()})
+ # put the observation into the output and encode if needed
+ for name, val in self.obs_to_output.iloc[idx].items():
+ out.update({name: self.encoder[name][val] if name in self.encoder else val})
  if self.get_knn_cells:
  distances = self.distances[idx].toarray()[0]
  nn_idx = np.argsort(-1 / (distances - 1e-6))[:6]
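Two things change here: `__iter__` now delegates to `__getitem__`, so iteration and random access share one code path (the duplicated knn block disappears), and the neighbor distances are returned under `knn_cells_info`, matching the key the Collator now collates. The new `encoder` hook maps raw obs values to integer ids at access time. A hedged usage sketch with a toy AnnData and hypothetical labels:

```python
import numpy as np
import anndata as ad
from scdataloader.data import SimpleAnnDataset

adata = ad.AnnData(
    X=np.random.poisson(1.0, size=(4, 10)).astype(np.float32),
    obs={"cell_type_ontology_term_id": ["CL:0000236"] * 2 + ["CL:0000084"] * 2},
)
ds = SimpleAnnDataset(
    adata,
    obs_to_output=["cell_type_ontology_term_id"],
    encoder={"cell_type_ontology_term_id": {"CL:0000236": 0, "CL:0000084": 1}},
)
print(ds[0]["cell_type_ontology_term_id"])  # 0 — the encoded label
```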
@@ -385,17 +369,17 @@ class SimpleAnnDataset(torchDataset):
  [self.adataX[i].reshape(-1) for i in nn_idx],
  dtype=int,
  )
- out["distances"] = distances[nn_idx]
+ out["knn_cells_info"] = distances[nn_idx]
  return out


  def mapped(
- obs_keys: list[str] | None = None,
- obsm_keys: list[str] | None = None,
+ obs_keys: List[str] | None = None,
+ obsm_keys: List[str] | None = None,
  obs_filter: dict[str, str | tuple[str, ...]] | None = None,
  join: Literal["inner", "outer"] | None = "inner",
- encode_labels: bool | list[str] = True,
+ encode_labels: bool | List[str] = True,
  unknown_label: str | dict[str, str] | None = None,
  cache_categories: bool = True,
  parallel: bool = False,
401
385
  parallel: bool = False,
@@ -403,7 +387,7 @@ def mapped(
403
387
  stream: bool = False,
404
388
  is_run_input: bool | None = None,
405
389
  metacell_mode: bool = False,
406
- meta_assays: list[str] = ["EFO:0022857", "EFO:0010961"],
390
+ meta_assays: List[str] = ["EFO:0022857", "EFO:0010961"],
407
391
  get_knn_cells: bool = False,
408
392
  store_location: str | None = None,
409
393
  force_recompute_indices: bool = False,
@@ -440,7 +424,7 @@ def mapped(
  return ds


- def concat_categorical_codes(series_list: list[pd.Categorical]) -> pd.Categorical:
+ def concat_categorical_codes(series_list: List[pd.Categorical]) -> pd.Categorical:
  """Efficiently combine multiple categorical data using their codes,
  only creating categories for combinations that exist in the data.