scdataloader-0.0.2.tar.gz → scdataloader-0.0.3.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,31 +1,33 @@
  Metadata-Version: 2.1
  Name: scdataloader
- Version: 0.0.2
+ Version: 0.0.3
  Summary: a dataloader for single cell data in lamindb
- Home-page: https://github.com/jkobject/scPrint
+ Home-page: https://github.com/jkobject/scDataLoader
  License: GPL3
  Keywords: scRNAseq,dataloader,pytorch,lamindb,scPrint
  Author: jkobject
- Requires-Python: >=3.10,<4.0
+ Requires-Python: ==3.10.*
  Classifier: License :: Other/Proprietary License
  Classifier: Programming Language :: Python :: 3
  Classifier: Programming Language :: Python :: 3.10
- Classifier: Programming Language :: Python :: 3.11
- Classifier: Programming Language :: Python :: 3.12
  Requires-Dist: anndata
  Requires-Dist: biomart
+ Requires-Dist: bionty
  Requires-Dist: cellxgene-census
  Requires-Dist: decoupler
  Requires-Dist: django
  Requires-Dist: ipykernel
  Requires-Dist: lamindb
  Requires-Dist: leidenalg
+ Requires-Dist: lightning
+ Requires-Dist: lnschema-bionty
  Requires-Dist: matplotlib
  Requires-Dist: pandas (>=2.0.0)
+ Requires-Dist: scikit-misc
  Requires-Dist: seaborn
  Requires-Dist: torch
  Requires-Dist: torchdata
- Project-URL: Repository, https://github.com/jkobject/scPrint
+ Project-URL: Repository, https://github.com/jkobject/scDataLoader
  Description-Content-Type: text/markdown

  # scdataloader
@@ -33,7 +35,9 @@ Description-Content-Type: text/markdown
  [![codecov](https://codecov.io/gh/jkobject/scDataLoader/branch/main/graph/badge.svg?token=scDataLoader_token_here)](https://codecov.io/gh/jkobject/scDataLoader)
  [![CI](https://github.com/jkobject/scDataLoader/actions/workflows/main.yml/badge.svg)](https://github.com/jkobject/scDataLoader/actions/workflows/main.yml)

- Awesome single cell dataloader created by @jkobject
+ Awesome single cell dataloader created by @jkobject
+
+ Built on top of `lamindb` and the `.mapped()` function by Sergey: https://github.com/Koncopd

  This data loader is designed to be used with:

@@ -51,12 +55,34 @@ It allows you to:
  3. create a more complex single cell dataset
  4. extend it to your need

+ ## About
+
+ The idea is to use it to train models like scGPT / Geneformer (and soon, scPrint ;)). It:
+
+ 1. loads data from lamin
+ 2. does some dataset-specific preprocessing if needed
+ 3. creates a dataset object on top of `.mapped()` (which is needed for mapping genes, cell labels, etc.)
+ 4. passes it to a dataloader object that can work with it correctly
+
+ Currently, one has to use the preprocess function to make the dataset fit different tools like scGPT / Geneformer. The goal is to enable this through different Collators instead; that part is still a WIP (please do contribute!).
+
+ ![](docs/scdataloader.drawio.png)
+
  ## Install it from PyPI

  ```bash
  pip install scdataloader
  ```

+ ### Install it locally and run the notebooks:
+
+ ```bash
+ git clone https://github.com/jkobject/scDataLoader.git
+ cd scDataLoader
+ poetry install
+ ```
+ Then run the notebooks with the poetry-installed environment.
+
  ## Usage

  see the notebooks in [docs](https://jkobject.github.io/scDataLoader/):
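The About section's four-step flow leans on lamindb's `.mapped()` accessor mentioned above. As a rough sketch of steps 1 and 3, assuming a configured local lamin instance (the collection name is a placeholder, and keyword arguments may differ across lamindb versions):

```python
# Rough sketch, not the package's documented API: fetch a collection from lamin
# and expose it as an indexable, torch-compatible dataset via .mapped().
import lamindb as ln

collection = ln.Collection.filter(name="my-dataset").first()  # placeholder name
mapped = collection.mapped()  # MappedCollection; handles gene / cell-label mapping
print(len(mapped))            # behaves like a dataset of single cells
```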
@@ -3,7 +3,9 @@
  [![codecov](https://codecov.io/gh/jkobject/scDataLoader/branch/main/graph/badge.svg?token=scDataLoader_token_here)](https://codecov.io/gh/jkobject/scDataLoader)
  [![CI](https://github.com/jkobject/scDataLoader/actions/workflows/main.yml/badge.svg)](https://github.com/jkobject/scDataLoader/actions/workflows/main.yml)

- Awesome single cell dataloader created by @jkobject
+ Awesome single cell dataloader created by @jkobject
+
+ Built on top of `lamindb` and the `.mapped()` function by Sergey: https://github.com/Koncopd

  This data loader is designed to be used with:

@@ -21,12 +23,34 @@ It allows you to:
  3. create a more complex single cell dataset
  4. extend it to your need

+ ## About
+
+ The idea is to use it to train models like scGPT / Geneformer (and soon, scPrint ;)). It:
+
+ 1. loads data from lamin
+ 2. does some dataset-specific preprocessing if needed
+ 3. creates a dataset object on top of `.mapped()` (which is needed for mapping genes, cell labels, etc.)
+ 4. passes it to a dataloader object that can work with it correctly
+
+ Currently, one has to use the preprocess function to make the dataset fit different tools like scGPT / Geneformer. The goal is to enable this through different Collators instead; that part is still a WIP (please do contribute!).
+
+ ![](docs/scdataloader.drawio.png)
+
  ## Install it from PyPI

  ```bash
  pip install scdataloader
  ```

+ ### Install it locally and run the notebooks:
+
+ ```bash
+ git clone https://github.com/jkobject/scDataLoader.git
+ cd scDataLoader
+ poetry install
+ ```
+ Then run the notebooks with the poetry-installed environment.
+
  ## Usage

  see the notebooks in [docs](https://jkobject.github.io/scDataLoader/):
@@ -1,24 +1,19 @@
  [tool.poetry]
  name = "scdataloader"
- version = "0.0.2"
+ version = "0.0.3"
  description = "a dataloader for single cell data in lamindb"
  authors = ["jkobject"]
  license = "GPL3"
  readme = ["README.md", "LICENSE"]
- repository = "https://github.com/jkobject/scPrint"
- keywords = [
-     "scRNAseq",
-     "dataloader",
-     "pytorch",
-     "lamindb",
-     "scPrint",
- ]
+ repository = "https://github.com/jkobject/scDataLoader"
+ keywords = ["scRNAseq", "dataloader", "pytorch", "lamindb", "scPrint"]

  [tool.poetry.dependencies]
- python = "^3.10"
+ python = "3.10.*"
  lamindb = "*"
  cellxgene-census = "*"
  torch = "*"
+ lightning = "*"
  anndata = "*"
  matplotlib = "*"
  seaborn = "*"
@@ -29,6 +24,9 @@ pandas = ">=2.0.0"
  leidenalg = "*"
  decoupler = "*"
  django = "*"
+ lnschema-bionty = "*"
+ bionty = "*"
+ scikit-misc = "*"

  [tool.poetry.group.dev.dependencies]
  pytest = "^7.4.3"
@@ -46,6 +44,7 @@ mkdocs-git-authors-plugin = "*"
  mkdocs-jupyter = "*"
  mkdocstrings-python = "*"

+
  [build-system]
  requires = ["poetry-core"]
  build-backend = "poetry.core.masonry.api"
@@ -0,0 +1,4 @@
+ from .data import Dataset
+ from .dataloader import DataModule
+ from .preprocess import Preprocessor
+ from .collator import *
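What appears to be the new package `__init__.py` exposes the three building blocks directly. A hedged sketch of how they compose per the README's four-step flow; the constructor arguments below are illustrative assumptions, not documented signatures (see the notebooks for real usage):

```python
# Hedged sketch only: constructor arguments are assumptions, not documented signatures.
import lamindb as ln
from scdataloader import DataModule, Dataset, Preprocessor

collection = ln.Collection.filter(name="my-dataset").first()  # placeholder name

preprocessor = Preprocessor()     # assumed defaults; cleans / validates the collection
dataset = Dataset(collection)     # assumed signature; wraps lamindb's .mapped()
datamodule = DataModule(dataset)  # assumed signature; the new `lightning` dependency
                                  # suggests a Lightning-style datamodule
```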
@@ -0,0 +1,188 @@
+ import argparse
+ from scdataloader.preprocess import LaminPreprocessor
+ import lamindb as ln
+
+
+ def main():
+     parser = argparse.ArgumentParser(
+         description="Preprocess datasets in a given lamindb collection."
+     )
+     parser.add_argument(
+         "--name", type=str, required=True, help="Name of the input dataset."
+     )
+     parser.add_argument(
+         "--new_name", type=str, required=True, help="Name of the preprocessed dataset."
+     )
+     parser.add_argument(
+         "--description",
+         type=str,
+         default="preprocessed by scDataLoader",
+         help="Description of the preprocessed dataset.",
+     )
+     parser.add_argument(
+         "--start_at", type=int, default=0, help="Position to start preprocessing at."
+     )
+     parser.add_argument(
+         "--new_version", type=str, default="2", help="Version of the output dataset and files."
+     )
+     parser.add_argument(
+         "--instance", type=str, default=None, help="Instance storing the input dataset, if not local."
+     )
+     parser.add_argument(
+         "--version", type=str, default=None, help="Version of the input dataset."
+     )
+     parser.add_argument(
+         "--filter_gene_by_counts",
+         type=int,  # count threshold; 0 disables filtering (argparse cannot call typing.Union)
+         default=False,
+         help="Determines whether to filter genes by counts.",
+     )
+     parser.add_argument(
+         "--filter_cell_by_counts",
+         type=int,  # count threshold; 0 disables filtering
+         default=False,
+         help="Determines whether to filter cells by counts.",
+     )
+     parser.add_argument(
+         "--normalize_sum",
+         type=float,
+         default=1e4,
+         help="Determines whether to normalize the total counts of each cell to a specific value.",
+     )
+     parser.add_argument(
+         "--keep_norm_layer",
+         type=bool,  # note: bool() treats any non-empty string as True
+         default=False,
+         help="Determines whether to keep the normalization layer.",
+     )
+     parser.add_argument(
+         "--subset_hvg",
+         type=int,
+         default=0,
+         help="Determines whether to subset highly variable genes.",
+     )
+     parser.add_argument(
+         "--hvg_flavor",
+         type=str,
+         default="seurat_v3",
+         help="Specifies the flavor of highly variable genes selection.",
+     )
+     parser.add_argument(
+         "--binning",
+         type=int,  # argparse cannot call typing.Optional; None stays the default
+         default=None,
+         help="Determines whether to bin the data into the given number of discrete values.",
+     )
+     parser.add_argument(
+         "--result_binned_key",
+         type=str,
+         default="X_binned",
+         help="Specifies the key of AnnData to store the binned data.",
+     )
+     parser.add_argument(
+         "--length_normalize",
+         type=bool,
+         default=False,
+         help="Determines whether to length-normalize the counts.",
+     )
+     parser.add_argument(
+         "--force_preprocess",
+         type=bool,
+         default=False,
+         help="Determines whether to force preprocessing.",
+     )
+     parser.add_argument(
+         "--min_dataset_size",
+         type=int,
+         default=100,
+         help="Specifies the minimum dataset size.",
+     )
+     parser.add_argument(
+         "--min_valid_genes_id",
+         type=int,
+         default=10_000,
+         help="Specifies the minimum number of valid gene ids.",
+     )
+     parser.add_argument(
+         "--min_nnz_genes",
+         type=int,
+         default=200,
+         help="Specifies the minimum number of non-zero genes.",
+     )
+     parser.add_argument(
+         "--maxdropamount",
+         type=int,
+         default=2,
+         help="Specifies the maximum drop amount.",
+     )
+     parser.add_argument(
+         "--madoutlier",
+         type=int,
+         default=5,
+         help="Specifies the MAD outlier threshold.",
+     )
+     parser.add_argument(
+         "--pct_mt_outlier",
+         type=int,
+         default=8,
+         help="Specifies the mitochondrial-percentage outlier threshold.",
+     )
+     parser.add_argument(
+         "--batch_key",
+         type=str,  # argparse cannot call typing.Optional; None stays the default
+         default=None,
+         help="Specifies the batch key.",
+     )
+     parser.add_argument(
+         "--skip_validate",
+         type=bool,
+         default=False,
+         help="Determines whether to skip validation.",
+     )
+     args = parser.parse_args()
+
+     # Load the collection, from a remote instance if one was given
+     if args.instance is not None:
+         collection = ln.Collection.using(instance=args.instance).filter(name=args.name, version=args.version).first()
+     else:
+         collection = ln.Collection.filter(name=args.name, version=args.version).first()
+
+     print("using the dataset", collection, "of size", len(collection.artifacts.all()))
+     # Initialize the preprocessor
+     preprocessor = LaminPreprocessor(
+         filter_gene_by_counts=args.filter_gene_by_counts,
+         filter_cell_by_counts=args.filter_cell_by_counts,
+         normalize_sum=args.normalize_sum,
+         keep_norm_layer=args.keep_norm_layer,
+         subset_hvg=args.subset_hvg,
+         hvg_flavor=args.hvg_flavor,
+         binning=args.binning,
+         result_binned_key=args.result_binned_key,
+         length_normalize=args.length_normalize,
+         force_preprocess=args.force_preprocess,
+         min_dataset_size=args.min_dataset_size,
+         min_valid_genes_id=args.min_valid_genes_id,
+         min_nnz_genes=args.min_nnz_genes,
+         maxdropamount=args.maxdropamount,
+         madoutlier=args.madoutlier,
+         pct_mt_outlier=args.pct_mt_outlier,
+         batch_key=args.batch_key,
+         skip_validate=args.skip_validate,
+     )
+
+     # Preprocess the dataset
+     preprocessed_dataset = preprocessor(
+         collection,
+         name=args.new_name,
+         description=args.description,
+         start_at=args.start_at,
+         version=args.new_version,
+     )
+
+     print(
+         f"Preprocessed dataset saved with version {args.new_version} and name {args.new_name}."
+     )
+
+
+ if __name__ == "__main__":
+     main()
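Every CLI flag above maps one-to-one onto a `LaminPreprocessor` argument, so the same run can be scripted directly. A minimal programmatic equivalent using only the names shown in the script (dataset names and versions are placeholders):

```python
# Programmatic equivalent of the CLI above; dataset names/versions are placeholders.
import lamindb as ln
from scdataloader.preprocess import LaminPreprocessor

collection = ln.Collection.filter(name="my-dataset", version="1").first()  # placeholders

preprocessor = LaminPreprocessor(
    filter_gene_by_counts=False,  # CLI defaults from above
    filter_cell_by_counts=False,
    normalize_sum=1e4,
    subset_hvg=0,
    hvg_flavor="seurat_v3",
    min_dataset_size=100,
    min_valid_genes_id=10_000,
    min_nnz_genes=200,
)

preprocessed = preprocessor(
    collection,
    name="my-dataset-preprocessed",  # placeholder
    description="preprocessed by scDataLoader",
    start_at=0,
    version="2",
)
```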
@@ -0,0 +1,263 @@
+ import numpy as np
+ from .utils import load_genes
+ from torch import Tensor
+
+ # class SimpleCollator:
+
+
+ class Collator:
+     def __init__(
+         self,
+         organisms: list,
+         org_to_id: dict = None,
+         valid_genes: list = [],
+         max_len=2000,
+         n_bins=0,
+         add_zero_genes=200,
+         logp1=False,
+         norm_to=None,
+         how="all",
+         tp_name=None,
+         organism_name="organism_ontology_term_id",
+         class_names=[],
+     ):
+         """
+         This class is responsible for collating data for the scPRINT model. It handles the
+         organization and preparation of gene expression data from different organisms,
+         allowing for various configurations such as maximum gene list length, normalization,
+         and selection method for gene expression.
+
+         Args:
+             organisms (list): List of organisms to be considered for gene expression data.
+             org_to_id (dict): Dictionary mapping organisms to their respective IDs.
+             valid_genes (list, optional): List of genes from the datasets to be considered. Defaults to [].
+             max_len (int, optional): Maximum length of the gene list. Defaults to 2000.
+             n_bins (int, optional): Number of bins for binning the data. Defaults to 0.
+             add_zero_genes (int, optional): Number of zero genes to add. Defaults to 200.
+             logp1 (bool, optional): If True, log-p1 normalization is applied. Defaults to False.
+             norm_to (float, optional): Total count to normalize each cell to. Defaults to None.
+             how (str, optional): Method for selecting gene expression ("most expr", "random expr" or "all"). Defaults to "all".
+             tp_name (str, optional): Name of the field containing the heat diff value. Defaults to None.
+         """
+         self.organisms = organisms
+         self.valid_genes = valid_genes
+         self.max_len = max_len
+         self.n_bins = n_bins
+         self.add_zero_genes = add_zero_genes
+         self.logp1 = logp1
+         self.norm_to = norm_to
+         self.org_to_id = org_to_id
+         self.how = how
+         self.organism_ids = (
+             set([org_to_id[k] for k in organisms])
+             if org_to_id is not None
+             else set(organisms)
+         )
+         self.organism_name = organism_name
+         self.tp_name = tp_name
+         self.class_names = class_names
+
+         self.start_idx = {}
+         self.accepted_genes = {}
+         self.genedf = load_genes(organisms)
+         for organism in set(self.genedf.organism):
+             ogenedf = self.genedf[self.genedf.organism == organism]
+             org = org_to_id[organism] if org_to_id is not None else organism
+             self.start_idx.update(
+                 {org: np.where(self.genedf.organism == organism)[0][0]}
+             )
+             if len(valid_genes) > 0:
+                 self.accepted_genes.update({org: ogenedf.index.isin(valid_genes)})
+
+     def __call__(self, batch):
+         """
+         Collate a list of samples into batched tensors; invoked when the instance itself is called.
+
+         Args:
+             batch (list[dict[str, array]]): List of dicts of arrays containing gene expression data,
+                 one dict per sample, with
+                 elem["x"]: gene expression
+                 elem["organism_name"]: organism ontology term id
+                 elem["tp_name"]: heat diff
+                 elem["class_names.."]: other classes
+
+         Returns:
+             dict[str, Tensor]: Dict of tensors containing the collated data.
+         """
+         # do count selection
+         # get the unseen info and don't add any unseen
+         # get the I most expressed genes, add randomly some unexpressed genes that are not unseen
+         exprs = []
+         total_count = []
+         other_classes = []
+         gene_locs = []
+         tp = []
+         for elem in batch:
+             organism_id = elem[self.organism_name]
+             if organism_id not in self.organism_ids:
+                 continue
+             expr = np.array(elem["x"])
+             total_count.append(expr.sum())
+             if len(self.accepted_genes) > 0:
+                 expr = expr[self.accepted_genes[organism_id]]
+             if self.how == "most expr":
+                 loc = np.argsort(expr)[-(self.max_len) :][::-1]
+             elif self.how == "random expr":
+                 nnz_loc = np.where(expr > 0)[0]
+                 loc = nnz_loc[
+                     np.random.choice(len(nnz_loc), self.max_len, replace=False)
+                 ]
+             elif self.how == "all":
+                 loc = np.arange(len(expr))
+             else:
+                 raise ValueError("how must be 'most expr', 'random expr' or 'all'")
+             if self.add_zero_genes > 0 and self.how != "all":
+                 zero_loc = np.where(expr == 0)[0]
+                 zero_loc = zero_loc[
+                     np.random.choice(len(zero_loc), self.add_zero_genes, replace=False)
+                 ]
+                 loc = np.concatenate((loc, zero_loc), axis=None)
+             exprs.append(expr[loc])
+             gene_locs.append(loc + self.start_idx[organism_id])
+
+             if self.tp_name is not None:
+                 tp.append(elem[self.tp_name])
+             else:
+                 tp.append(0)
+
+             other_classes.append([elem[i] for i in self.class_names])
+
+         expr = np.array(exprs)
+         tp = np.array(tp)
+         gene_locs = np.array(gene_locs)
+         total_count = np.array(total_count)
+         other_classes = np.array(other_classes)
+
+         # normalize counts
+         if self.norm_to is not None:
+             expr = (expr * self.norm_to) / total_count[:, None]
+         if self.logp1:
+             expr = np.log2(1 + expr)
+
+         # do binning of counts (WIP, not implemented yet)
+         if self.n_bins:
+             pass
+
+         # find the associated gene ids (given the species)
+
+         # get the NN cells
+
+         # do encoding / selection a la scGPT
+
+         # do encoding of graph location
+         # encode all the edges in some sparse way
+         # normalizing total counts between 0,1
+         return {
+             "x": Tensor(expr),
+             "genes": Tensor(gene_locs).int(),
+             "class": Tensor(other_classes).int(),
+             "tp": Tensor(tp),
+             "depth": Tensor(total_count),
+         }
+
+
+ class AnnDataCollator(Collator):
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+
+     def __call__(self, batch):
+         exprs = []
+         total_count = []
+         other_classes = []
+         gene_locs = []
+         tp = []
+         for elem in batch:
+             organism_id = elem.obs[self.organism_name]
+             if organism_id.item() not in self.organism_ids:
+                 print(organism_id)
+             expr = np.array(elem.X[0])
+
+             total_count.append(expr.sum())
+             if len(self.accepted_genes) > 0:
+                 expr = expr[self.accepted_genes[organism_id]]
+             if self.how == "most expr":
+                 loc = np.argsort(expr)[-(self.max_len) :][::-1]
+             elif self.how == "random expr":
+                 nnz_loc = np.where(expr > 0)[0]
+                 loc = nnz_loc[
+                     np.random.choice(len(nnz_loc), self.max_len, replace=False)
+                 ]
+             else:
+                 raise ValueError("how must be either 'most expr' or 'random expr'")
+             if self.add_zero_genes > 0:
+                 zero_loc = np.where(expr == 0)[0]
+                 zero_loc = zero_loc[
+                     np.random.choice(len(zero_loc), self.add_zero_genes, replace=False)
+                 ]
+                 loc = np.concatenate((loc, zero_loc), axis=None)
+             exprs.append(expr[loc])
+             gene_locs.append(loc + self.start_idx[organism_id.item()])
+
+             if self.tp_name is not None:
+                 tp.append(elem.obs[self.tp_name])
+             else:
+                 tp.append(0)
+
+             other_classes.append([elem.obs[i].values[0] for i in self.class_names])
+
+         expr = np.array(exprs)
+         tp = np.array(tp)
+         gene_locs = np.array(gene_locs)
+         total_count = np.array(total_count)
+         other_classes = np.array(other_classes)
+         return {
+             "x": Tensor(expr),
+             "genes": Tensor(gene_locs).int(),
+             "depth": Tensor(total_count),
+             "class": Tensor(other_classes),
+         }
+
+
+ class SCVICollator(Collator):
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+
+     def __call__(self, batch):
+         expr = batch["x"]
+         total_count = expr.sum(axis=1)
+         if self.how == "most expr":
+             loc = np.argsort(expr)[:, -(self.max_len) :][:, ::-1]
+         else:
+             raise ValueError("only how='most expr' is supported by this collator")
+         if self.logp1:
+             expr = np.log2(1 + expr)
+         return {
+             "x": Tensor(expr[np.arange(expr.shape[0])[:, None], loc]),
+             "genes": Tensor(loc.copy()).int(),
+             "depth": Tensor(total_count),
+         }
+
+
+ class GeneformerCollator(Collator):
+     def __init__(self, *args, gene_norm_list: list, **kwargs):
+         super().__init__(*args, **kwargs)
+         self.gene_norm_list = gene_norm_list
+
+     def __call__(self, batch):
+         out = super().__call__(batch)
+         # WIP: normalization per gene
+         # WIP: tokenize the empty locations
+         return out
+
+
+ class scGPTCollator(Collator):
+     def __call__(self, batch):
+         out = super().__call__(batch)
+         # WIP: binning
+         # WIP: tokenize the empty locations
+         return out
+
+
+ class scPRINTCollator(Collator):
+     def __call__(self, batch):
+         return super().__call__(batch)
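Because `Collator.__call__` consumes a list of per-cell dicts and emits batched tensors, it slots into `torch.utils.data.DataLoader` as a `collate_fn`. A hedged sketch with dummy in-memory data (constructing a `Collator` calls `load_genes`, so a configured bionty/lamin environment is assumed; `NCBITaxon:9606` is the standard ontology term for human, and the gene count below is a placeholder that must match the species' gene table in real use):

```python
# Sketch: plugging the Collator into a torch DataLoader as collate_fn.
import numpy as np
from torch.utils.data import DataLoader
from scdataloader.collator import Collator

collator = Collator(
    organisms=["NCBITaxon:9606"],  # human ontology term id
    how="most expr",               # keep the max_len most-expressed genes per cell...
    max_len=2000,
    add_zero_genes=200,            # ...plus 200 randomly drawn unexpressed genes
    logp1=True,
    norm_to=1e4,
)

# Dummy stand-in for a scdataloader Dataset: dicts with an expression vector and
# the organism field the collator reads (placeholder gene count of 2500).
cells = [
    {"x": np.random.poisson(0.5, 2500), "organism_ontology_term_id": "NCBITaxon:9606"}
    for _ in range(64)
]

loader = DataLoader(cells, batch_size=16, collate_fn=collator)
batch = next(iter(loader))
print(batch["x"].shape, batch["genes"].shape)  # (16, 2200) expression values and gene indices
```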