scdataloader 0.0.3__tar.gz → 0.0.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: scdataloader
- Version: 0.0.3
+ Version: 0.0.4
  Summary: a dataloader for single cell data in lamindb
  Home-page: https://github.com/jkobject/scDataLoader
  License: GPL3
@@ -34,6 +34,8 @@ Description-Content-Type: text/markdown

  [![codecov](https://codecov.io/gh/jkobject/scDataLoader/branch/main/graph/badge.svg?token=scDataLoader_token_here)](https://codecov.io/gh/jkobject/scDataLoader)
  [![CI](https://github.com/jkobject/scDataLoader/actions/workflows/main.yml/badge.svg)](https://github.com/jkobject/scDataLoader/actions/workflows/main.yml)
+ [![DOI](https://zenodo.org/badge/731248665.svg)](https://zenodo.org/doi/10.5281/zenodo.10573143)
+

  Awesome single cell dataloader created by @jkobject

@@ -66,7 +68,7 @@ the idea is to use it to train models like scGPT / GeneFormer (and soon, scPrint

  Currently one would have to use the preprocess function to make the dataset fit for different tools like scGPT / Geneformer. But I would want to enable it through different Collators. This is still missing and a WIP... (please do contribute!)

- ![](docs/scdataloader.drawio.png)
+ ![docs/scdataloader.drawio.png](docs/scdataloader.drawio.png)

  ## Install it from PyPI

@@ -85,6 +87,49 @@ then run the notebooks with the poetry installed environment

  ## Usage

+ ```python
+ # initialize a local lamin database
+ # !lamin init --storage ~/scdataloader --schema bionty
+
+ import lamindb as ln
+ from scdataloader import utils
+ from scdataloader.preprocess import LaminPreprocessor, additional_postprocess, additional_preprocess
+
+ # preprocess datasets
+ DESCRIPTION = 'preprocessed by scDataLoader'
+
+ cx_dataset = ln.Collection.using(instance="laminlabs/cellxgene").filter(name="cellxgene-census", version='2023-12-15').one()
+ cx_dataset, len(cx_dataset.artifacts.all())
+
+
+ do_preprocess = LaminPreprocessor(additional_postprocess=additional_postprocess, additional_preprocess=additional_preprocess, skip_validate=True, subset_hvg=0)
+
+ preprocessed_dataset = do_preprocess(cx_dataset, name=DESCRIPTION, description=DESCRIPTION, start_at=6, version="2")
+
+ # create dataloaders
+ from scdataloader import DataModule
+ import tqdm
+
+ datamodule = DataModule(
+     collection_name="preprocessed dataset",
+     organisms=["NCBITaxon:9606"],  # organism that we will work on
+     how="most expr",  # for the collator (only the most expressed genes are selected)
+     max_len=1000,  # only the 1000 most expressed
+     batch_size=64,
+     num_workers=1,
+     validation_split=0.1,
+     test_split=0)
+
+ for i in tqdm.tqdm(datamodule.train_dataloader()):
+     # do something with the batch, or just pass
+     print(i)
+     break
+
+ # with lightning:
+ # Trainer(model, datamodule)
+
+ ```
+
  see the notebooks in [docs](https://jkobject.github.io/scDataLoader/):

  1. [load a dataset](https://jkobject.github.io/scDataLoader/notebooks/01_load_dataset.html)
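
Note: the usage snippet above ends with a commented Lightning hint. A minimal sketch of that wiring, assuming `model` is a `LightningModule` you define yourself (scDataLoader does not ship one) and a Lightning 2.x import path:

```python
# Sketch only: `model` is your own LightningModule; on older versions,
# import from `pytorch_lightning` instead of `lightning.pytorch`.
import lightning.pytorch as pl

trainer = pl.Trainer(max_epochs=1, accelerator="auto")
trainer.fit(model, datamodule=datamodule)  # `datamodule` from the snippet above
```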
@@ -1,6 +1,6 @@
  [tool.poetry]
  name = "scdataloader"
- version = "0.0.3"
+ version = "0.0.4"
  description = "a dataloader for single cell data in lamindb"
  authors = ["jkobject"]
  license = "GPL3"
@@ -0,0 +1 @@
+ 0.7.0
@@ -1,4 +1,4 @@
  from .data import Dataset
- from .dataloader import DataModule
+ from .datamodule import DataModule
  from .preprocess import Preprocessor
  from .collator import *
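
Note: this is a module rename only (the `.dataloader` module becomes `.datamodule`); since the package root re-exports the class here, the user-facing import is unchanged:

```python
# Works in both 0.0.3 and 0.0.4 despite the internal module rename.
from scdataloader import DataModule
```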
@@ -1,8 +1,14 @@
  import argparse
- from scdataloader.preprocess import LaminPreprocessor
+ from scdataloader.preprocess import (
+     LaminPreprocessor,
+     additional_preprocess,
+     additional_postprocess,
+ )
  import lamindb as ln
  from typing import Optional, Union

+
+ # scdataloader --instance="laminlabs/cellxgene" --name="cellxgene-census" --version="2023-12-15" --description="preprocessed for scprint" --new_name="scprint main" --start_at=39
  def main():
      parser = argparse.ArgumentParser(
          description="Preprocess datasets in a given lamindb collection."
@@ -11,22 +17,31 @@ def main():
          "--name", type=str, required=True, help="Name of the input dataset"
      )
      parser.add_argument(
-         "--new_name", type=str, required=True, help="Name of the preprocessed dataset."
+         "--new_name",
+         type=str,
+         default="preprocessed dataset",
+         help="Name of the preprocessed dataset.",
      )
      parser.add_argument(
          "--description",
          type=str,
-         default="preprocessed by scDataLoader"
+         default="preprocessed by scDataLoader",
          help="Description of the preprocessed dataset.",
      )
      parser.add_argument(
          "--start_at", type=int, default=0, help="Position to start preprocessing at."
      )
      parser.add_argument(
-         "--new_version", type=str, default="2", help="Version of the output dataset and files."
+         "--new_version",
+         type=str,
+         default="2",
+         help="Version of the output dataset and files.",
      )
      parser.add_argument(
-         "--instance", type=str, default=None, help="Instance storing the input dataset, if not local"
+         "--instance",
+         type=str,
+         default=None,
+         help="Instance storing the input dataset, if not local",
      )
      parser.add_argument(
          "--version", type=str, default=None, help="Version of the input dataset."
@@ -35,125 +50,127 @@ def main():
          "--filter_gene_by_counts",
          type=Union[int, bool],
          default=False,
-         help="Determines whether to filter genes by counts."
+         help="Determines whether to filter genes by counts.",
      )
      parser.add_argument(
          "--filter_cell_by_counts",
          type=Union[int, bool],
          default=False,
-         help="Determines whether to filter cells by counts."
+         help="Determines whether to filter cells by counts.",
      )
      parser.add_argument(
          "--normalize_sum",
          type=float,
          default=1e4,
-         help="Determines whether to normalize the total counts of each cell to a specific value."
-     )
-     parser.add_argument(
-         "--keep_norm_layer",
-         type=bool,
-         default=False,
-         help="Determines whether to keep the normalization layer."
+         help="Determines whether to normalize the total counts of each cell to a specific value.",
      )
      parser.add_argument(
          "--subset_hvg",
          type=int,
          default=0,
-         help="Determines whether to subset highly variable genes."
+         help="Determines whether to subset highly variable genes.",
      )
      parser.add_argument(
          "--hvg_flavor",
          type=str,
          default="seurat_v3",
-         help="Specifies the flavor of highly variable genes selection."
+         help="Specifies the flavor of highly variable genes selection.",
      )
      parser.add_argument(
          "--binning",
          type=Optional[int],
          default=None,
-         help="Determines whether to bin the data into discrete values of number of bins provided."
+         help="Determines whether to bin the data into discrete values of number of bins provided.",
      )
      parser.add_argument(
          "--result_binned_key",
          type=str,
          default="X_binned",
-         help="Specifies the key of AnnData to store the binned data."
+         help="Specifies the key of AnnData to store the binned data.",
      )
      parser.add_argument(
          "--length_normalize",
          type=bool,
          default=False,
-         help="Determines whether to normalize the length."
+         help="Determines whether to normalize the length.",
      )
      parser.add_argument(
          "--force_preprocess",
          type=bool,
          default=False,
-         help="Determines whether to force preprocessing."
+         help="Determines whether to force preprocessing.",
      )
      parser.add_argument(
          "--min_dataset_size",
          type=int,
          default=100,
-         help="Specifies the minimum dataset size."
+         help="Specifies the minimum dataset size.",
      )
      parser.add_argument(
          "--min_valid_genes_id",
          type=int,
          default=10_000,
-         help="Specifies the minimum valid genes id."
+         help="Specifies the minimum valid genes id.",
      )
      parser.add_argument(
          "--min_nnz_genes",
          type=int,
-         default=200,
-         help="Specifies the minimum non-zero genes."
+         default=400,
+         help="Specifies the minimum non-zero genes.",
      )
      parser.add_argument(
          "--maxdropamount",
          type=int,
-         default=2,
-         help="Specifies the maximum drop amount."
+         default=50,
+         help="Specifies the maximum drop amount.",
      )
      parser.add_argument(
-         "--madoutlier",
-         type=int,
-         default=5,
-         help="Specifies the MAD outlier."
+         "--madoutlier", type=int, default=5, help="Specifies the MAD outlier."
      )
      parser.add_argument(
          "--pct_mt_outlier",
          type=int,
          default=8,
-         help="Specifies the percentage of MT outlier."
+         help="Specifies the percentage of MT outlier.",
      )
      parser.add_argument(
-         "--batch_key",
-         type=Optional[str],
-         default=None,
-         help="Specifies the batch key."
+         "--batch_key", type=Optional[str], default=None, help="Specifies the batch key."
      )
      parser.add_argument(
          "--skip_validate",
          type=bool,
          default=False,
-         help="Determines whether to skip validation."
+         help="Determines whether to skip validation.",
+     )
+     parser.add_argument(
+         "--do_postp",
+         type=bool,
+         default=False,
+         help="Determines whether to do postprocessing.",
      )
      args = parser.parse_args()

      # Load the collection
+     # if not args.preprocess:
+     #     print("Only preprocess is available for now")
+     #     return
      if args.instance is not None:
-         collection = ln.Collection.using(instance=args.instance).filter(name=args.name, version=args.version).first()
-
-     collection = ln.Collection.filter(name=args.name, version=args.version).first()
+         collection = (
+             ln.Collection.using(instance=args.instance)
+             .filter(name=args.name, version=args.version)
+             .first()
+         )
+     else:
+         collection = ln.Collection.filter(name=args.name, version=args.version).first()

-     print("using the dataset ",collection, " of size ",len(collection.artifacts.all()))
+     print(
+         "using the dataset ", collection, " of size ", len(collection.artifacts.all())
+     )
      # Initialize the preprocessor
      preprocessor = LaminPreprocessor(
          filter_gene_by_counts=args.filter_gene_by_counts,
          filter_cell_by_counts=args.filter_cell_by_counts,
          normalize_sum=args.normalize_sum,
-         keep_norm_layer=args.keep_norm_layer,
          subset_hvg=args.subset_hvg,
          hvg_flavor=args.hvg_flavor,
          binning=args.binning,
@@ -168,10 +185,14 @@ def main():
          pct_mt_outlier=args.pct_mt_outlier,
          batch_key=args.batch_key,
          skip_validate=args.skip_validate,
+         do_postp=args.do_postp,
+         additional_preprocess=additional_preprocess,
+         additional_postprocess=additional_postprocess,
+         keep_files=False,
      )

      # Preprocess the dataset
-     preprocessed_dataset = preprocessor(
+     preprocessor(
          collection,
          name=args.new_name,
          description=args.description,
@@ -1,6 +1,6 @@
  import numpy as np
  from .utils import load_genes
- from torch import Tensor
+ from torch import Tensor, long

  # class SimpleCollator:

@@ -9,17 +9,18 @@ class Collator:
      def __init__(
          self,
          organisms: list,
+         how="all",
          org_to_id: dict = None,
          valid_genes: list = [],
          max_len=2000,
-         n_bins=0,
-         add_zero_genes=200,
+         add_zero_genes=0,
          logp1=False,
          norm_to=None,
+         n_bins=0,
          tp_name=None,
          organism_name="organism_ontology_term_id",
          class_names=[],
+         genelist=[],
      ):
          """
          This class is responsible for collating data for the scPRINT model. It handles the
@@ -27,20 +28,29 @@ class Collator:
          allowing for various configurations such as maximum gene list length, normalization,
          and selection method for gene expression.

+         This Collator should work with scVI's dataloader as well!
+
          Args:
              organisms (list): List of organisms to be considered for gene expression data.
+                 Cells from any other organism are dropped (which might lead to batches of different sizes!).
+             how (str, optional): Method for selecting gene expression. Defaults to "all".
+                 One of ["most expr", "random expr", "all", "some"]:
+                 "most expr": selects the max_len most expressed genes;
+                     if fewer genes are expressed, random unexpressed genes are sampled.
+                 "random expr": uses a random set of max_len expressed genes;
+                     if fewer genes are expressed, random unexpressed genes are sampled.
+                 "all": uses all genes.
+                 "some": uses only the genes provided through the genelist param.
              org_to_id (dict): Dictionary mapping organisms to their respective IDs.
-             labels (list, optional): List of labels for the data. Defaults to [].
              valid_genes (list, optional): List of genes from the datasets, to be considered. Defaults to [].
-             max_len (int, optional): Maximum length of the gene list. Defaults to 2000.
-             n_bins (int, optional): Number of bins for binning the data. Defaults to 0.
-             add_zero_genes (int, optional): Number of zero genes to add. Defaults to 200.
+                 Any other gene is dropped from the input expression data (useful when your model only works on some genes).
+             max_len (int, optional): Maximum number of genes to use (for "random expr" and "most expr"). Defaults to 2000.
+             n_bins (int, optional): Number of bins for binning the data. Defaults to 0, meaning no binning of expression.
+             add_zero_genes (int, optional): Number of additional unexpressed genes to add to the input data. Defaults to 0.
              logp1 (bool, optional): If True, log1p normalization is applied. Defaults to False.
              norm_to (str, optional): Normalization method to be applied. Defaults to None.
-             how (str, optional): Method for selecting gene expression. Defaults to "most expr".
          """
          self.organisms = organisms
-         self.valid_genes = valid_genes
          self.max_len = max_len
          self.n_bins = n_bins
          self.add_zero_genes = add_zero_genes
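
Note: with the new `how` and `genelist` options documented above, a hypothetical instantiation could look like the sketch below. The Ensembl IDs are placeholders; also, since this version derives `start_idx` from `valid_genes` (see the hunk two below), passing a non-empty `valid_genes` appears to be required:

```python
# Hypothetical sketch; gene IDs are illustrative, not taken from the package.
from scdataloader.collator import Collator

panel = ["ENSG00000000003", "ENSG00000000419"]  # genes the model consumes
collator = Collator(
    organisms=["NCBITaxon:9606"],             # cells from other organisms are dropped
    how="some",                               # restrict output to `genelist`
    valid_genes=panel + ["ENSG00000000457"],  # genes present in the dataset
    genelist=panel,                           # required when how="some"
)
```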
@@ -53,6 +63,8 @@ class Collator:
              if org_to_id is not None
              else set(organisms)
          )
+         if self.how == "some":
+             assert len(genelist) > 0, "if how is some, genelist must be provided"
          self.organism_name = organism_name
          self.tp_name = tp_name
          self.class_names = class_names
@@ -60,18 +72,21 @@ class Collator:
          self.start_idx = {}
          self.accepted_genes = {}
          self.genedf = load_genes(organisms)
+         self.to_subset = {}
          for organism in set(self.genedf.organism):
              ogenedf = self.genedf[self.genedf.organism == organism]
+             tot = self.genedf[self.genedf.index.isin(valid_genes)]
              org = org_to_id[organism] if org_to_id is not None else organism
-             self.start_idx.update(
-                 {org: np.where(self.genedf.organism == organism)[0][0]}
-             )
+             self.start_idx.update({org: np.where(tot.organism == organism)[0][0]})
              if len(valid_genes) > 0:
                  self.accepted_genes.update({org: ogenedf.index.isin(valid_genes)})
+             if len(genelist) > 0:
+                 df = ogenedf[ogenedf.index.isin(valid_genes)]
+                 self.to_subset.update({org: df.index.isin(genelist)})

      def __call__(self, batch):
          """
-         __call__ is a special method in Python that is called when an instance of the class is called.
+         __call__ applies the collator to a minibatch of data

          Args:
              batch (list[dict[str: array]]): List of dicts of arrays containing gene expression data.
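
Note: a sketch of the batch-element structure `__call__` consumes, inferred from the Args above; the gene-table length of 33_000 is an arbitrary placeholder:

```python
import numpy as np

# One synthetic cell, keyed the way Collator.__call__ reads it.
elem = {
    "x": np.random.poisson(0.5, size=33_000),       # counts over all genes
    "organism_ontology_term_id": "NCBITaxon:9606",  # matched against organism_ids
}
batch = [elem] * 64  # what a torch DataLoader hands to the collator
```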
@@ -92,33 +107,62 @@ class Collator:
          other_classes = []
          gene_locs = []
          tp = []
+         dataset = []
+         nnz_loc = []
          for elem in batch:
              organism_id = elem[self.organism_name]
              if organism_id not in self.organism_ids:
                  continue
+             if "dataset" in elem:
+                 dataset.append(elem["dataset"])
              expr = np.array(elem["x"])
              total_count.append(expr.sum())
              if len(self.accepted_genes) > 0:
                  expr = expr[self.accepted_genes[organism_id]]
              if self.how == "most expr":
-                 loc = np.argsort(expr)[-(self.max_len) :][::-1]
+                 nnz_loc = np.where(expr > 0)[0]
+                 ma = self.max_len if self.max_len < len(nnz_loc) else len(nnz_loc)
+                 loc = np.argsort(expr)[-(ma):][::-1]
+                 # nnz_loc = [1] * 30_000
+                 # loc = np.argsort(expr)[-(self.max_len) :][::-1]
              elif self.how == "random expr":
                  nnz_loc = np.where(expr > 0)[0]
                  loc = nnz_loc[
-                     np.random.choice(len(nnz_loc), self.max_len, replace=False)
+                     np.random.choice(
+                         len(nnz_loc),
+                         self.max_len if self.max_len < len(nnz_loc) else len(nnz_loc),
+                         replace=False,
+                         # p=(expr.max() + (expr[nnz_loc])*19) / expr.max(), # 20 at most times more likely to be selected
+                     )
                  ]
-             elif self.how == "all":
+             elif self.how in ["all", "some"]:
                  loc = np.arange(len(expr))
              else:
                  raise ValueError("how must be either most expr or random expr")
-             if self.add_zero_genes > 0 and self.how != "all":
+             if (
+                 (self.add_zero_genes > 0) or (self.max_len > len(nnz_loc))
+             ) and self.how not in ["all", "some"]:
                  zero_loc = np.where(expr == 0)[0]
-                 zero_loc = [
-                     np.random.choice(len(zero_loc), self.add_zero_genes, replace=False)
+                 zero_loc = zero_loc[
+                     np.random.choice(
+                         len(zero_loc),
+                         self.add_zero_genes
+                         + (
+                             0
+                             if self.max_len < len(nnz_loc)
+                             else self.max_len - len(nnz_loc)
+                         ),
+                         replace=False,
+                     )
                  ]
                  loc = np.concatenate((loc, zero_loc), axis=None)
-             exprs.append(expr[loc])
-             gene_locs.append(loc + self.start_idx[organism_id])
+             expr = expr[loc]
+             loc = loc + self.start_idx[organism_id]
+             if self.how == "some":
+                 expr = expr[self.to_subset[organism_id]]
+                 loc = loc[self.to_subset[organism_id]]
+             exprs.append(expr)
+             gene_locs.append(loc)

              if self.tp_name is not None:
                  tp.append(elem[self.tp_name])
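
Note: the reworked "most expr" branch now caps the selection at the number of expressed genes and pads with random unexpressed genes up to `max_len`. A standalone sketch of that selection logic:

```python
import numpy as np

def most_expr_selection(expr: np.ndarray, max_len: int, add_zero_genes: int = 0):
    # Cap the top-k at the number of expressed (nonzero) genes.
    nnz = np.where(expr > 0)[0]
    k = min(max_len, len(nnz))
    loc = np.argsort(expr)[-k:][::-1]  # indices of the k most expressed genes
    # Pad with random unexpressed genes up to max_len, plus add_zero_genes.
    pad = add_zero_genes + max(0, max_len - len(nnz))
    if pad > 0:
        zeros = np.where(expr == 0)[0]
        loc = np.concatenate((loc, np.random.choice(zeros, pad, replace=False)))
    return loc

expr = np.array([0.0, 5.0, 0.0, 2.0, 0.0, 9.0])
print(most_expr_selection(expr, max_len=4))  # 3 expressed genes + 1 random zero gene
```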
@@ -132,6 +176,7 @@ class Collator:
          gene_locs = np.array(gene_locs)
          total_count = np.array(total_count)
          other_classes = np.array(other_classes)
+         dataset = np.array(dataset)

          # normalize counts
          if self.norm_to is not None:
@@ -152,17 +197,26 @@ class Collator:
          # do encoding of graph location
          # encode all the edges in some sparse way
          # normalizing total counts between 0,1
-         return {
+         ret = {
              "x": Tensor(expr),
              "genes": Tensor(gene_locs).int(),
              "class": Tensor(other_classes).int(),
              "tp": Tensor(tp),
              "depth": Tensor(total_count),
          }
+         if len(dataset) > 0:
+             ret.update({"dataset": Tensor(dataset).to(long)})
+         return ret


  class AnnDataCollator(Collator):
      def __init__(self, *args, **kwargs):
+         """
+         AnnDataCollator Collator to use if working with AnnData's experimental dataloader (it is very slow!!!)
+
+         Args:
+             @see Collator
+         """
          super().__init__(*args, **kwargs)

      def __call__(self, batch):
@@ -218,28 +272,14 @@ class AnnDataCollator(Collator):
          }


- class SCVICollator(Collator):
-     def __init__(self, *args, **kwargs):
-         super().__init__(*args, **kwargs)
-
-     def __call__(self, batch):
-         expr = batch["x"]
-         total_count = expr.sum(axis=1)
-         if self.how == "most expr":
-             loc = np.argsort(expr)[:, -(self.max_len) :][:, ::-1]
-         else:
-             raise ValueError("how must be either most expr or random expr")
-         if self.logp1:
-             expr = np.log2(1 + expr)
-         return {
-             "x": Tensor(expr[np.arange(expr.shape[0])[:, None], loc]),
-             "genes": Tensor(loc.copy()).int(),
-             "depth": Tensor(total_count),
-         }
-
-
  class GeneformerCollator(Collator):
      def __init__(self, *args, gene_norm_list: list, **kwargs):
+         """
+         GeneformerCollator to finish
+
+         Args:
+             gene_norm_list (list): the normalization of expression through all datasets, per gene.
+         """
          super().__init__(*args, **kwargs)
          self.gene_norm_list = gene_norm_list

@@ -251,6 +291,10 @@ class GeneformerCollator(Collator):


  class scGPTCollator(Collator):
+     """
+     scGPTCollator to finish
+     """
+
      def __call__(self, batch):
          super().__call__(batch)
          # binning