scdataloader 2.0.3__py3-none-any.whl → 2.0.5__py3-none-any.whl

scdataloader/collator.py CHANGED
@@ -27,37 +27,60 @@ class Collator:
  genedf: Optional[pd.DataFrame] = None,
  ):
  """
- This class is responsible for collating data for the scPRINT model. It handles the
- organization and preparation of gene expression data from different organisms,
- allowing for various configurations such as maximum gene list length, normalization,
- and selection method for gene expression.
+ Collator for preparing gene expression data batches for the scPRINT model.
 
- This Collator should work with scVI's dataloader as well!
+ This class handles the organization and preparation of gene expression data from
+ different organisms, allowing for various configurations such as maximum gene list
+ length, normalization, binning, and gene selection strategies.
+
+ Compatible with scVI's dataloader and other PyTorch data loading pipelines.
 
  Args:
- organisms (list): List of organisms to be considered for gene expression data.
- it will drop any other organism it sees (might lead to batches of different sizes!)
- how (flag, optional): Method for selecting gene expression. Defaults to "most expr".
- one of ["most expr", "random expr", "all", "some"]:
- "most expr": selects the max_len most expressed genes,
- if less genes are expressed, will sample random unexpressed genes,
- "random expr": uses a random set of max_len expressed genes.
- if less genes are expressed, will sample random unexpressed genes
- "all": uses all genes
- "some": uses only the genes provided through the genelist param
- org_to_id (dict): Dictionary mapping organisms to their respective IDs.
- valid_genes (list, optional): List of genes from the datasets, to be considered. Defaults to [].
- it will drop any other genes from the input expression data (usefull when your model only works on some genes)
- max_len (int, optional): Total number of genes to use (for random expr and most expr). Defaults to 2000.
- n_bins (int, optional): Number of bins for binning the data. Defaults to 0. meaning, no binning of expression.
- add_zero_genes (int, optional): Number of additional unexpressed genes to add to the input data. Defaults to 0.
- logp1 (bool, optional): If True, logp1 normalization is applied. Defaults to False.
- norm_to (float, optional): Rescaling value of the normalization to be applied. Defaults to None.
- organism_name (str, optional): Name of the organism ontology term id. Defaults to "organism_ontology_term_id".
- tp_name (str, optional): Name of the heat diff. Defaults to None.
- class_names (list, optional): List of other classes to be considered. Defaults to [].
- genelist (list, optional): List of genes to be considered. Defaults to [].
- If [] all genes will be considered
+ organisms (List[str]): List of organism ontology term IDs to include.
+ Samples from other organisms will be dropped (may lead to variable batch sizes).
+ how (str, optional): Gene selection strategy. Defaults to "all".
+ - "most expr": Select the `max_len` most expressed genes. If fewer genes
+ are expressed, randomly sample unexpressed genes to fill.
+ - "random expr": Randomly select `max_len` expressed genes. If fewer genes
+ are expressed, randomly sample unexpressed genes to fill.
+ - "all": Use all genes without filtering.
+ - "some": Use only genes specified in the `genelist` parameter.
+ org_to_id (dict[str, int], optional): Mapping from organism names to integer IDs.
+ If None, organism names are used directly. Defaults to None.
+ valid_genes (List[str], optional): List of gene names to consider from input data.
+ Genes not in this list will be dropped. Useful when the model only supports
+ specific genes. Defaults to None (use all genes).
+ max_len (int, optional): Maximum number of genes to include when using "most expr"
+ or "random expr" selection methods. Defaults to 2000.
+ add_zero_genes (int, optional): Number of additional unexpressed genes to include
+ in the output. Only applies when `how` is "most expr" or "random expr".
+ Defaults to 0.
+ logp1 (bool, optional): Apply log2(1 + x) transformation to expression values.
+ Applied after normalization if both are enabled. Defaults to False.
+ norm_to (float, optional): Target sum for count normalization. Expression values
+ are scaled so that total counts equal this value. Defaults to None (no normalization).
+ n_bins (int, optional): Number of bins for expression value binning. If 0, no
+ binning is applied. Binning uses quantile-based discretization. Defaults to 0.
+ tp_name (str, optional): Column name in batch data for time point or heat diffusion
+ values. If None, time point values default to 0. Defaults to None.
+ organism_name (str, optional): Column name in batch data for organism ontology
+ term ID. Defaults to "organism_ontology_term_id".
+ class_names (List[str], optional): List of additional metadata column names to
+ include in the output. Defaults to [].
+ genelist (List[str], optional): List of specific genes to use when `how="some"`.
+ Required if `how="some"`. Defaults to [].
+ genedf (pd.DataFrame, optional): DataFrame containing gene information indexed by
+ gene name with an 'organism' column. If None, loaded automatically using
+ `load_genes()`. Defaults to None.
+
+ Attributes:
+ organism_ids (set): Set of organism IDs being processed.
+ start_idx (dict): Mapping from organism ID to starting gene index in the model.
+ accepted_genes (dict): Boolean masks for valid genes per organism.
+ to_subset (dict): Boolean masks for genelist filtering per organism.
+
+ Raises:
+ AssertionError: If `how="some"` but `genelist` is empty.
  """
  self.organisms = organisms
  self.max_len = max_len
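The "most expr" strategy described in the new docstring can be illustrated with a minimal NumPy sketch. This is not scdataloader's implementation: the function name `select_most_expressed`, its signature, the `rng` argument, and the tie-breaking order are all assumptions made for illustration.

```python
import numpy as np


def select_most_expressed(expr, max_len, rng=None):
    """Hypothetical helper: indices of up to `max_len` genes for one cell.

    Expressed genes come first, ranked from highest to lowest expression.
    If fewer than `max_len` genes are expressed, the remainder is filled
    with randomly chosen unexpressed genes.
    """
    rng = np.random.default_rng() if rng is None else rng
    expr = np.asarray(expr)
    expressed = np.flatnonzero(expr > 0)
    # Rank expressed genes from highest to lowest expression.
    ranked = expressed[np.argsort(expr[expressed])[::-1]]
    if len(ranked) >= max_len:
        return ranked[:max_len]
    # Too few expressed genes: pad with randomly chosen unexpressed ones.
    unexpressed = np.flatnonzero(expr <= 0)
    n_fill = min(max_len - len(ranked), len(unexpressed))
    fill = rng.permutation(unexpressed)[:n_fill]
    return np.concatenate([ranked, fill])
```

With `expr = [0, 5, 0, 3, 9, 0]` and `max_len=2`, the sketch returns the indices of the two largest values; with `max_len=5` it returns the three expressed genes followed by two randomly drawn unexpressed ones.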
@@ -77,6 +100,19 @@ class Collator:
  self._setup(genedf, org_to_id, valid_genes, genelist)
 
  def _setup(self, genedf=None, org_to_id=None, valid_genes=[], genelist=[]):
+ """
+ Initialize gene mappings and indices for each organism.
+
+ Sets up internal data structures for gene filtering, organism-specific
+ gene indices, and gene subsetting based on the provided configuration.
+
+ Args:
+ genedf (pd.DataFrame, optional): Gene information DataFrame. If None,
+ loaded via `load_genes()`. Defaults to None.
+ org_to_id (dict, optional): Organism name to ID mapping. Defaults to None.
+ valid_genes (List[str], optional): Genes to accept from input. Defaults to [].
+ genelist (List[str], optional): Genes to subset to when `how="some"`. Defaults to [].
+ """
  if genedf is None:
  genedf = load_genes(self.organisms)
  self.organism_ids = (
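The docstrings above describe `start_idx` as a per-organism starting gene index and `to_subset` as a per-organism boolean mask for `genelist`. One way such structures could be derived from a gene table is sketched below. The helper `build_gene_index` is hypothetical, and treating `start_idx` as a cumulative offset over organisms is an assumption; the body of `_setup` is not shown in this diff and may differ.

```python
import pandas as pd


def build_gene_index(genedf, organisms, genelist=None):
    """Hypothetical helper: per-organism start offsets and genelist masks.

    Assumes `genedf` is indexed by gene name and has an 'organism' column,
    as stated in the Collator docstring.
    """
    start_idx, to_subset = {}, {}
    offset = 0
    for org in organisms:
        org_genes = genedf[genedf["organism"] == org].index
        # Assumption: organisms occupy consecutive blocks of the gene vocabulary.
        start_idx[org] = offset
        offset += len(org_genes)
        if genelist:
            # Boolean NumPy array, True where the gene is in `genelist`.
            to_subset[org] = org_genes.isin(genelist)
    return start_idx, to_subset
```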
@@ -108,18 +144,45 @@ class Collator:
 
  def __call__(self, batch) -> dict[str, Tensor]:
  """
- __call__ applies the collator to a minibatch of data
+ Collate a minibatch of gene expression data.
+
+ Processes a list of sample dictionaries, applying gene selection, normalization,
+ log transformation, and binning as configured. Filters out samples from organisms
+ not in the configured organism list.
 
  Args:
- batch (List[dict[str: array]]): List of dicts of arrays containing gene expression data.
- the first list is for the different samples, the second list is for the different elements with
- elem["X"]: gene expression
- elem["organism_name"]: organism ontology term id
- elem["tp_name"]: heat diff
- elem["class_names.."]: other classes
+ batch (List[dict]): List of sample dictionaries, each containing:
+ - "X" (array): Gene expression values.
+ - organism_name (any): Organism identifier (column name set by `organism_name`).
+ - tp_name (float, optional): Time point value (column name set by `tp_name`).
+ - class_names... (any, optional): Additional class labels.
+ - "_storage_idx" (int, optional): Dataset storage index.
+ - "is_meta" (int, optional): Metadata flag.
+ - "knn_cells" (array, optional): KNN neighbor expression data.
+ - "knn_cells_info" (array, optional): KNN neighbor metadata.
 
  Returns:
- List[Tensor]: List of tensors containing the collated data.
+ dict[str, Tensor]: Dictionary containing collated tensors:
+ - "x" (Tensor): Gene expression matrix of shape (batch_size, n_genes).
+ Values may be raw counts, normalized, log-transformed, or binned
+ depending on configuration.
+ - "genes" (Tensor): Gene indices of shape (batch_size, n_genes) as int32.
+ Indices correspond to positions in the model's gene vocabulary.
+ - "class" (Tensor): Class labels of shape (batch_size, n_classes) as int32.
+ - "tp" (Tensor): Time point values of shape (batch_size,).
+ - "depth" (Tensor): Total counts per cell of shape (batch_size,).
+ - "is_meta" (Tensor, optional): Metadata flags as int32. Present if input
+ contains "is_meta".
+ - "knn_cells" (Tensor, optional): KNN expression data. Present if input
+ contains "knn_cells".
+ - "knn_cells_info" (Tensor, optional): KNN metadata. Present if input
+ contains "knn_cells_info".
+ - "dataset" (Tensor, optional): Dataset indices as int64. Present if input
+ contains "_storage_idx".
+
+ Note:
+ Batch size in output may be smaller than input if some samples are filtered
+ out due to organism mismatch.
  """
  # do count selection
  # get the unseen info and don't add any unseen
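The filtering and transformation order documented for `__call__` (drop samples from unconfigured organisms, normalize to `norm_to`, then apply log2(1 + x)) can be sketched as follows. This is a simplified illustration, not the package's code: `collate_sketch` is a hypothetical name, it returns NumPy arrays rather than tensors, it omits gene selection and binning, and it assumes "depth" is computed before rescaling, that at least one sample survives filtering, and that every kept cell has a nonzero total count.

```python
import numpy as np


def collate_sketch(batch, organisms,
                   organism_name="organism_ontology_term_id",
                   norm_to=None, logp1=False):
    """Hypothetical, simplified collate step returning "x" and "depth"."""
    # Drop samples whose organism is not configured, so the output
    # batch may be smaller than the input batch.
    kept = [elem for elem in batch if elem[organism_name] in organisms]
    x = np.stack([np.asarray(elem["X"], dtype=np.float64) for elem in kept])
    # Assumption: depth is the total raw count per cell, taken before rescaling.
    depth = x.sum(axis=1)
    if norm_to is not None:
        # Scale each cell so its counts sum to `norm_to`.
        x = x * (norm_to / depth)[:, None]
    if logp1:
        # Per the docstring, the log transform follows normalization.
        x = np.log2(1.0 + x)
    return {"x": x, "depth": depth}
```

Passing three samples of which one belongs to an unconfigured organism yields an "x" array with two rows, which matches the Note in the docstring about reduced batch sizes.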