smftools 0.2.4__py3-none-any.whl → 0.2.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (133)
  1. smftools/__init__.py +6 -8
  2. smftools/_settings.py +4 -6
  3. smftools/_version.py +1 -1
  4. smftools/cli/helpers.py +7 -1
  5. smftools/cli/hmm_adata.py +902 -244
  6. smftools/cli/load_adata.py +318 -198
  7. smftools/cli/preprocess_adata.py +285 -171
  8. smftools/cli/spatial_adata.py +137 -53
  9. smftools/cli_entry.py +94 -178
  10. smftools/config/__init__.py +1 -1
  11. smftools/config/conversion.yaml +5 -1
  12. smftools/config/deaminase.yaml +1 -1
  13. smftools/config/default.yaml +22 -17
  14. smftools/config/direct.yaml +8 -3
  15. smftools/config/discover_input_files.py +19 -5
  16. smftools/config/experiment_config.py +505 -276
  17. smftools/constants.py +37 -0
  18. smftools/datasets/__init__.py +2 -8
  19. smftools/datasets/datasets.py +32 -18
  20. smftools/hmm/HMM.py +2125 -1426
  21. smftools/hmm/__init__.py +2 -3
  22. smftools/hmm/archived/call_hmm_peaks.py +16 -1
  23. smftools/hmm/call_hmm_peaks.py +173 -193
  24. smftools/hmm/display_hmm.py +19 -6
  25. smftools/hmm/hmm_readwrite.py +13 -4
  26. smftools/hmm/nucleosome_hmm_refinement.py +102 -14
  27. smftools/informatics/__init__.py +30 -7
  28. smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +14 -1
  29. smftools/informatics/archived/helpers/archived/bam_qc.py +14 -1
  30. smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +8 -1
  31. smftools/informatics/archived/helpers/archived/load_adata.py +3 -3
  32. smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +3 -1
  33. smftools/informatics/archived/print_bam_query_seq.py +7 -1
  34. smftools/informatics/bam_functions.py +379 -156
  35. smftools/informatics/basecalling.py +51 -9
  36. smftools/informatics/bed_functions.py +90 -57
  37. smftools/informatics/binarize_converted_base_identities.py +18 -7
  38. smftools/informatics/complement_base_list.py +7 -6
  39. smftools/informatics/converted_BAM_to_adata.py +265 -122
  40. smftools/informatics/fasta_functions.py +161 -83
  41. smftools/informatics/h5ad_functions.py +195 -29
  42. smftools/informatics/modkit_extract_to_adata.py +609 -270
  43. smftools/informatics/modkit_functions.py +85 -44
  44. smftools/informatics/ohe.py +44 -21
  45. smftools/informatics/pod5_functions.py +112 -73
  46. smftools/informatics/run_multiqc.py +20 -14
  47. smftools/logging_utils.py +51 -0
  48. smftools/machine_learning/__init__.py +2 -7
  49. smftools/machine_learning/data/anndata_data_module.py +143 -50
  50. smftools/machine_learning/data/preprocessing.py +2 -1
  51. smftools/machine_learning/evaluation/__init__.py +1 -1
  52. smftools/machine_learning/evaluation/eval_utils.py +11 -14
  53. smftools/machine_learning/evaluation/evaluators.py +46 -33
  54. smftools/machine_learning/inference/__init__.py +1 -1
  55. smftools/machine_learning/inference/inference_utils.py +7 -4
  56. smftools/machine_learning/inference/lightning_inference.py +9 -13
  57. smftools/machine_learning/inference/sklearn_inference.py +6 -8
  58. smftools/machine_learning/inference/sliding_window_inference.py +35 -25
  59. smftools/machine_learning/models/__init__.py +10 -5
  60. smftools/machine_learning/models/base.py +28 -42
  61. smftools/machine_learning/models/cnn.py +15 -11
  62. smftools/machine_learning/models/lightning_base.py +71 -40
  63. smftools/machine_learning/models/mlp.py +13 -4
  64. smftools/machine_learning/models/positional.py +3 -2
  65. smftools/machine_learning/models/rnn.py +3 -2
  66. smftools/machine_learning/models/sklearn_models.py +39 -22
  67. smftools/machine_learning/models/transformer.py +68 -53
  68. smftools/machine_learning/models/wrappers.py +2 -1
  69. smftools/machine_learning/training/__init__.py +2 -2
  70. smftools/machine_learning/training/train_lightning_model.py +29 -20
  71. smftools/machine_learning/training/train_sklearn_model.py +9 -15
  72. smftools/machine_learning/utils/__init__.py +1 -1
  73. smftools/machine_learning/utils/device.py +7 -4
  74. smftools/machine_learning/utils/grl.py +3 -1
  75. smftools/metadata.py +443 -0
  76. smftools/plotting/__init__.py +19 -5
  77. smftools/plotting/autocorrelation_plotting.py +145 -44
  78. smftools/plotting/classifiers.py +162 -72
  79. smftools/plotting/general_plotting.py +347 -168
  80. smftools/plotting/hmm_plotting.py +42 -13
  81. smftools/plotting/position_stats.py +145 -85
  82. smftools/plotting/qc_plotting.py +20 -12
  83. smftools/preprocessing/__init__.py +8 -8
  84. smftools/preprocessing/append_base_context.py +105 -79
  85. smftools/preprocessing/append_binary_layer_by_base_context.py +75 -37
  86. smftools/preprocessing/{archives → archived}/calculate_complexity.py +3 -1
  87. smftools/preprocessing/{archives → archived}/preprocessing.py +8 -6
  88. smftools/preprocessing/binarize.py +21 -4
  89. smftools/preprocessing/binarize_on_Youden.py +127 -31
  90. smftools/preprocessing/binary_layers_to_ohe.py +17 -11
  91. smftools/preprocessing/calculate_complexity_II.py +86 -59
  92. smftools/preprocessing/calculate_consensus.py +28 -19
  93. smftools/preprocessing/calculate_coverage.py +44 -22
  94. smftools/preprocessing/calculate_pairwise_differences.py +2 -1
  95. smftools/preprocessing/calculate_pairwise_hamming_distances.py +4 -3
  96. smftools/preprocessing/calculate_position_Youden.py +103 -55
  97. smftools/preprocessing/calculate_read_length_stats.py +52 -23
  98. smftools/preprocessing/calculate_read_modification_stats.py +91 -57
  99. smftools/preprocessing/clean_NaN.py +38 -28
  100. smftools/preprocessing/filter_adata_by_nan_proportion.py +24 -12
  101. smftools/preprocessing/filter_reads_on_length_quality_mapping.py +70 -37
  102. smftools/preprocessing/filter_reads_on_modification_thresholds.py +181 -73
  103. smftools/preprocessing/flag_duplicate_reads.py +688 -271
  104. smftools/preprocessing/invert_adata.py +26 -11
  105. smftools/preprocessing/load_sample_sheet.py +40 -22
  106. smftools/preprocessing/make_dirs.py +8 -3
  107. smftools/preprocessing/min_non_diagonal.py +2 -1
  108. smftools/preprocessing/recipes.py +56 -23
  109. smftools/preprocessing/reindex_references_adata.py +93 -27
  110. smftools/preprocessing/subsample_adata.py +33 -16
  111. smftools/readwrite.py +264 -109
  112. smftools/schema/__init__.py +11 -0
  113. smftools/schema/anndata_schema_v1.yaml +227 -0
  114. smftools/tools/__init__.py +3 -4
  115. smftools/tools/archived/classifiers.py +163 -0
  116. smftools/tools/archived/subset_adata_v1.py +10 -1
  117. smftools/tools/archived/subset_adata_v2.py +12 -1
  118. smftools/tools/calculate_umap.py +54 -15
  119. smftools/tools/cluster_adata_on_methylation.py +115 -46
  120. smftools/tools/general_tools.py +70 -25
  121. smftools/tools/position_stats.py +229 -98
  122. smftools/tools/read_stats.py +50 -29
  123. smftools/tools/spatial_autocorrelation.py +365 -192
  124. smftools/tools/subset_adata.py +23 -21
  125. {smftools-0.2.4.dist-info → smftools-0.2.5.dist-info}/METADATA +15 -43
  126. smftools-0.2.5.dist-info/RECORD +181 -0
  127. smftools-0.2.4.dist-info/RECORD +0 -176
  128. /smftools/preprocessing/{archives → archived}/add_read_length_and_mapping_qc.py +0 -0
  129. /smftools/preprocessing/{archives → archived}/mark_duplicates.py +0 -0
  130. /smftools/preprocessing/{archives → archived}/remove_duplicates.py +0 -0
  131. {smftools-0.2.4.dist-info → smftools-0.2.5.dist-info}/WHEEL +0 -0
  132. {smftools-0.2.4.dist-info → smftools-0.2.5.dist-info}/entry_points.txt +0 -0
  133. {smftools-0.2.4.dist-info → smftools-0.2.5.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,11 @@
1
+ from __future__ import annotations
2
+
3
+ from importlib import resources
4
+ from pathlib import Path
5
+
6
+ SCHEMA_REGISTRY_VERSION = "1"
7
+ SCHEMA_REGISTRY_RESOURCE = "anndata_schema_v1.yaml"
8
+
9
+
10
+ def get_schema_registry_path() -> Path:
11
+ return resources.files(__package__).joinpath(SCHEMA_REGISTRY_RESOURCE)
@@ -0,0 +1,227 @@
1
+ schema_version: "1"
2
+ description: "smftools AnnData schema registry (v1)."
3
+ stages:
4
+ raw:
5
+ stage_requires: []
6
+ obs:
7
+ Experiment_name:
8
+ dtype: "category"
9
+ created_by: "smftools.cli.load_adata"
10
+ modified_by: []
11
+ notes: "Experiment identifier applied to all reads."
12
+ requires: []
13
+ optional_inputs: []
14
+ Experiment_name_and_barcode:
15
+ dtype: "category"
16
+ created_by: "smftools.cli.load_adata"
17
+ modified_by: []
18
+ notes: "Concatenated experiment name and barcode."
19
+ requires: [["obs.Experiment_name", "obs.Barcode"]]
20
+ optional_inputs: []
21
+ Barcode:
22
+ dtype: "category"
23
+ created_by: "smftools.informatics.modkit_extract_to_adata"
24
+ modified_by: []
25
+ notes: "Barcode assigned during demultiplexing or extraction."
26
+ requires: []
27
+ optional_inputs: []
28
+ read_length:
29
+ dtype: "float"
30
+ created_by: "smftools.informatics.h5ad_functions.add_read_length_and_mapping_qc"
31
+ modified_by: []
32
+ notes: "Read length in bases."
33
+ requires: []
34
+ optional_inputs: []
35
+ mapped_length:
36
+ dtype: "float"
37
+ created_by: "smftools.informatics.h5ad_functions.add_read_length_and_mapping_qc"
38
+ modified_by: []
39
+ notes: "Aligned length in bases."
40
+ requires: []
41
+ optional_inputs: []
42
+ reference_length:
43
+ dtype: "float"
44
+ created_by: "smftools.informatics.h5ad_functions.add_read_length_and_mapping_qc"
45
+ modified_by: []
46
+ notes: "Reference length for alignment target."
47
+ requires: []
48
+ optional_inputs: []
49
+ read_quality:
50
+ dtype: "float"
51
+ created_by: "smftools.informatics.h5ad_functions.add_read_length_and_mapping_qc"
52
+ modified_by: []
53
+ notes: "Per-read quality score."
54
+ requires: []
55
+ optional_inputs: []
56
+ mapping_quality:
57
+ dtype: "float"
58
+ created_by: "smftools.informatics.h5ad_functions.add_read_length_and_mapping_qc"
59
+ modified_by: []
60
+ notes: "Mapping quality score."
61
+ requires: []
62
+ optional_inputs: []
63
+ read_length_to_reference_length_ratio:
64
+ dtype: "float"
65
+ created_by: "smftools.informatics.h5ad_functions.add_read_length_and_mapping_qc"
66
+ modified_by: []
67
+ notes: "Read length divided by reference length."
68
+ requires: [["obs.read_length", "obs.reference_length"]]
69
+ optional_inputs: []
70
+ mapped_length_to_reference_length_ratio:
71
+ dtype: "float"
72
+ created_by: "smftools.informatics.h5ad_functions.add_read_length_and_mapping_qc"
73
+ modified_by: []
74
+ notes: "Mapped length divided by reference length."
75
+ requires: [["obs.mapped_length", "obs.reference_length"]]
76
+ optional_inputs: []
77
+ mapped_length_to_read_length_ratio:
78
+ dtype: "float"
79
+ created_by: "smftools.informatics.h5ad_functions.add_read_length_and_mapping_qc"
80
+ modified_by: []
81
+ notes: "Mapped length divided by read length."
82
+ requires: [["obs.mapped_length", "obs.read_length"]]
83
+ optional_inputs: []
84
+ Raw_modification_signal:
85
+ dtype: "float"
86
+ created_by: "smftools.informatics.h5ad_functions.add_read_length_and_mapping_qc"
87
+ modified_by:
88
+ - "smftools.cli.load_adata"
89
+ notes: "Summed modification signal per read."
90
+ requires: [["X"], ["layers.raw_mods"]]
91
+ optional_inputs: []
92
+ pod5_origin:
93
+ dtype: "string"
94
+ created_by: "smftools.informatics.h5ad_functions.annotate_pod5_origin"
95
+ modified_by: []
96
+ notes: "POD5 filename source for each read."
97
+ requires: [["obs_names"]]
98
+ optional_inputs: []
99
+ demux_type:
100
+ dtype: "category"
101
+ created_by: "smftools.informatics.h5ad_functions.add_demux_type_annotation"
102
+ modified_by: []
103
+ notes: "Classification of demultiplexing status."
104
+ requires: [["obs_names"]]
105
+ optional_inputs: []
106
+ var:
107
+ reference_position:
108
+ dtype: "int"
109
+ created_by: "smftools.informatics.modkit_extract_to_adata"
110
+ modified_by: []
111
+ notes: "Reference coordinate for each column."
112
+ requires: []
113
+ optional_inputs: []
114
+ reference_id:
115
+ dtype: "category"
116
+ created_by: "smftools.informatics.modkit_extract_to_adata"
117
+ modified_by: []
118
+ notes: "Reference contig or sequence name."
119
+ requires: []
120
+ optional_inputs: []
121
+ layers:
122
+ raw_mods:
123
+ dtype: "float"
124
+ created_by: "smftools.informatics.modkit_extract_to_adata"
125
+ modified_by: []
126
+ notes: "Raw modification scores (modality-dependent)."
127
+ requires: []
128
+ optional_inputs: []
129
+ obsm: {}
130
+ varm: {}
131
+ obsp: {}
132
+ uns:
133
+ smftools:
134
+ dtype: "mapping"
135
+ created_by: "smftools.metadata.record_smftools_metadata"
136
+ modified_by: []
137
+ notes: "smftools metadata including history, environment, provenance, schema snapshot."
138
+ requires: []
139
+ optional_inputs: []
140
+ preprocess:
141
+ stage_requires: ["raw"]
142
+ obs:
143
+ sequence__merged_cluster_id:
144
+ dtype: "category"
145
+ created_by: "smftools.preprocessing.flag_duplicate_reads"
146
+ modified_by: []
147
+ notes: "Cluster identifier for duplicate detection."
148
+ requires: [["layers.nan0_0minus1"]]
149
+ optional_inputs: ["obs.demux_type"]
150
+ layers:
151
+ nan0_0minus1:
152
+ dtype: "float"
153
+ created_by: "smftools.preprocessing.binarize"
154
+ modified_by:
155
+ - "smftools.preprocessing.clean_NaN"
156
+ notes: "Binarized methylation matrix (nan=0, 0=-1)."
157
+ requires: [["X"]]
158
+ optional_inputs: []
159
+ obsm:
160
+ X_umap:
161
+ dtype: "float"
162
+ created_by: "smftools.tools.calculate_umap"
163
+ modified_by: []
164
+ notes: "UMAP embedding for preprocessed reads."
165
+ requires: [["X"]]
166
+ optional_inputs: []
167
+ varm: {}
168
+ obsp: {}
169
+ uns:
170
+ duplicate_read_groups:
171
+ dtype: "mapping"
172
+ created_by: "smftools.preprocessing.flag_duplicate_reads"
173
+ modified_by: []
174
+ notes: "Duplicate read group metadata."
175
+ requires: [["obs.sequence__merged_cluster_id"]]
176
+ optional_inputs: []
177
+ spatial:
178
+ stage_requires: ["raw", "preprocess"]
179
+ obs:
180
+ leiden:
181
+ dtype: "category"
182
+ created_by: "smftools.tools.calculate_umap"
183
+ modified_by: []
184
+ notes: "Leiden cluster assignments."
185
+ requires: [["obsm.X_umap"]]
186
+ optional_inputs: []
187
+ obsm:
188
+ X_umap:
189
+ dtype: "float"
190
+ created_by: "smftools.tools.calculate_umap"
191
+ modified_by: []
192
+ notes: "UMAP embedding for spatial analyses."
193
+ requires: [["X"]]
194
+ optional_inputs: []
195
+ layers: {}
196
+ varm: {}
197
+ obsp: {}
198
+ uns:
199
+ positionwise_result:
200
+ dtype: "mapping"
201
+ created_by: "smftools.tools.position_stats.compute_positionwise_statistics"
202
+ modified_by: []
203
+ notes: "Positionwise correlation statistics for spatial analyses."
204
+ requires: [["X"]]
205
+ optional_inputs: ["obs.reference_column"]
206
+ hmm:
207
+ stage_requires: ["raw", "preprocess", "spatial"]
208
+ layers:
209
+ hmm_state_calls:
210
+ dtype: "int"
211
+ created_by: "smftools.hmm.call_hmm_peaks"
212
+ modified_by: []
213
+ notes: "HMM-derived state calls per read/position."
214
+ requires: [["layers.nan0_0minus1"]]
215
+ optional_inputs: []
216
+ obsm: {}
217
+ varm: {}
218
+ obsp: {}
219
+ obs: {}
220
+ uns:
221
+ hmm_annotated:
222
+ dtype: "bool"
223
+ created_by: "smftools.cli.hmm_adata"
224
+ modified_by: []
225
+ notes: "Flag indicating HMM annotations are present."
226
+ requires: [["layers.hmm_state_calls"]]
227
+ optional_inputs: []
@@ -1,12 +1,11 @@
1
- from .position_stats import calculate_relative_risk_on_activity, compute_positionwise_statistics
2
1
  from .calculate_umap import calculate_umap
3
2
  from .cluster_adata_on_methylation import cluster_adata_on_methylation
4
- from .general_tools import create_nan_mask_from_X, combine_layers, create_nan_or_non_gpc_mask
3
+ from .general_tools import combine_layers, create_nan_mask_from_X, create_nan_or_non_gpc_mask
4
+ from .position_stats import calculate_relative_risk_on_activity, compute_positionwise_statistics
5
5
  from .read_stats import calculate_row_entropy
6
6
  from .spatial_autocorrelation import *
7
7
  from .subset_adata import subset_adata
8
8
 
9
-
10
9
  __all__ = [
11
10
  "compute_positionwise_statistics",
12
11
  "calculate_row_entropy",
@@ -17,4 +16,4 @@ __all__ = [
17
16
  "create_nan_or_non_gpc_mask",
18
17
  "combine_layers",
19
18
  "subset_adata",
20
- ]
19
+ ]
@@ -21,13 +21,29 @@ device = (
21
21
 
22
22
  # ------------------------- Utilities -------------------------
23
23
  def random_fill_nans(X):
24
+ """Replace NaNs in an array with random values.
25
+
26
+ Args:
27
+ X: Input NumPy array.
28
+
29
+ Returns:
30
+ NumPy array with NaNs replaced.
31
+ """
24
32
  nan_mask = np.isnan(X)
25
33
  X[nan_mask] = np.random.rand(*X[nan_mask].shape)
26
34
  return X
27
35
 
28
36
  # ------------------------- Model Definitions -------------------------
29
37
  class CNNClassifier(nn.Module):
38
+ """Simple 1D CNN classifier for fixed-length inputs."""
39
+
30
40
  def __init__(self, input_size, num_classes):
41
+ """Initialize CNN classifier layers.
42
+
43
+ Args:
44
+ input_size: Length of the 1D input.
45
+ num_classes: Number of output classes.
46
+ """
31
47
  super().__init__()
32
48
  self.conv1 = nn.Conv1d(1, 16, kernel_size=3, padding=1)
33
49
  self.conv2 = nn.Conv1d(16, 32, kernel_size=3, padding=1)
@@ -39,11 +55,13 @@ class CNNClassifier(nn.Module):
39
55
  self.fc2 = nn.Linear(64, num_classes)
40
56
 
41
57
  def _forward_conv(self, x):
58
+ """Apply convolutional layers and activation."""
42
59
  x = self.relu(self.conv1(x))
43
60
  x = self.relu(self.conv2(x))
44
61
  return x
45
62
 
46
63
  def forward(self, x):
64
+ """Run the forward pass."""
47
65
  x = x.unsqueeze(1)
48
66
  x = self._forward_conv(x)
49
67
  x = x.view(x.size(0), -1)
@@ -51,7 +69,15 @@ class CNNClassifier(nn.Module):
51
69
  return self.fc2(x)
52
70
 
53
71
  class MLPClassifier(nn.Module):
72
+ """Simple MLP classifier."""
73
+
54
74
  def __init__(self, input_dim, num_classes):
75
+ """Initialize MLP layers.
76
+
77
+ Args:
78
+ input_dim: Input feature dimension.
79
+ num_classes: Number of output classes.
80
+ """
55
81
  super().__init__()
56
82
  self.model = nn.Sequential(
57
83
  nn.Linear(input_dim, 128),
@@ -64,10 +90,20 @@ class MLPClassifier(nn.Module):
64
90
  )
65
91
 
66
92
  def forward(self, x):
93
+ """Run the forward pass."""
67
94
  return self.model(x)
68
95
 
69
96
  class RNNClassifier(nn.Module):
97
+ """LSTM-based classifier for sequential inputs."""
98
+
70
99
  def __init__(self, input_size, hidden_dim, num_classes):
100
+ """Initialize RNN classifier layers.
101
+
102
+ Args:
103
+ input_size: Input feature dimension.
104
+ hidden_dim: Hidden state dimension.
105
+ num_classes: Number of output classes.
106
+ """
71
107
  super().__init__()
72
108
  # Define LSTM layer
73
109
  self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_dim, batch_first=True)
@@ -75,18 +111,29 @@ class RNNClassifier(nn.Module):
75
111
  self.fc = nn.Linear(hidden_dim, num_classes)
76
112
 
77
113
  def forward(self, x):
114
+ """Run the forward pass."""
78
115
  x = x.unsqueeze(1)
79
116
  _, (h_n, _) = self.lstm(x)
80
117
  return self.fc(h_n.squeeze(0))
81
118
 
82
119
  class AttentionRNNClassifier(nn.Module):
120
+ """LSTM classifier with simple attention."""
121
+
83
122
  def __init__(self, input_size, hidden_dim, num_classes):
123
+ """Initialize attention-based RNN layers.
124
+
125
+ Args:
126
+ input_size: Input feature dimension.
127
+ hidden_dim: Hidden state dimension.
128
+ num_classes: Number of output classes.
129
+ """
84
130
  super().__init__()
85
131
  self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_dim, batch_first=True)
86
132
  self.attn = nn.Linear(hidden_dim, 1) # Simple attention scores
87
133
  self.fc = nn.Linear(hidden_dim, num_classes)
88
134
 
89
135
  def forward(self, x):
136
+ """Run the forward pass."""
90
137
  x = x.unsqueeze(1) # shape: (batch, 1, seq_len)
91
138
  lstm_out, _ = self.lstm(x) # shape: (batch, 1, hidden_dim)
92
139
  attn_weights = torch.softmax(self.attn(lstm_out), dim=1) # (batch, 1, 1)
@@ -94,7 +141,15 @@ class AttentionRNNClassifier(nn.Module):
94
141
  return self.fc(context)
95
142
 
96
143
  class PositionalEncoding(nn.Module):
144
+ """Positional encoding module for transformer models."""
145
+
97
146
  def __init__(self, d_model, max_len=5000):
147
+ """Initialize positional encoding buffer.
148
+
149
+ Args:
150
+ d_model: Model embedding dimension.
151
+ max_len: Maximum sequence length.
152
+ """
98
153
  super().__init__()
99
154
  pe = torch.zeros(max_len, d_model)
100
155
  position = torch.arange(0, max_len).unsqueeze(1).float()
@@ -104,11 +159,23 @@ class PositionalEncoding(nn.Module):
104
159
  self.pe = pe.unsqueeze(0) # (1, max_len, d_model)
105
160
 
106
161
  def forward(self, x):
162
+ """Add positional encoding to inputs."""
107
163
  x = x + self.pe[:, :x.size(1)].to(x.device)
108
164
  return x
109
165
 
110
166
  class TransformerClassifier(nn.Module):
167
+ """Transformer encoder-based classifier."""
168
+
111
169
  def __init__(self, input_dim, model_dim, num_classes, num_heads=4, num_layers=2):
170
+ """Initialize transformer classifier layers.
171
+
172
+ Args:
173
+ input_dim: Input feature dimension.
174
+ model_dim: Transformer model dimension.
175
+ num_classes: Number of output classes.
176
+ num_heads: Number of attention heads.
177
+ num_layers: Number of encoder layers.
178
+ """
112
179
  super().__init__()
113
180
  self.input_fc = nn.Linear(input_dim, model_dim)
114
181
  self.pos_encoder = PositionalEncoding(model_dim)
@@ -119,6 +186,7 @@ class TransformerClassifier(nn.Module):
119
186
  self.cls_head = nn.Linear(model_dim, num_classes)
120
187
 
121
188
  def forward(self, x):
189
+ """Run the forward pass."""
122
190
  # x: [batch_size, input_dim]
123
191
  x = self.input_fc(x).unsqueeze(1) # -> [batch_size, 1, model_dim]
124
192
  x = self.pos_encoder(x)
@@ -128,6 +196,19 @@ class TransformerClassifier(nn.Module):
128
196
  return self.cls_head(pooled)
129
197
 
130
198
  def train_model(model, loader, optimizer, criterion, device, ref_name="", model_name="", epochs=20, patience=5):
199
+ """Train a model with early stopping.
200
+
201
+ Args:
202
+ model: PyTorch model.
203
+ loader: DataLoader for training data.
204
+ optimizer: Optimizer instance.
205
+ criterion: Loss function.
206
+ device: Torch device.
207
+ ref_name: Reference label for logging.
208
+ model_name: Model label for logging.
209
+ epochs: Maximum epochs.
210
+ patience: Early-stopping patience.
211
+ """
131
212
  model.train()
132
213
  best_loss = float('inf')
133
214
  trigger_times = 0
@@ -154,6 +235,17 @@ def train_model(model, loader, optimizer, criterion, device, ref_name="", model_
154
235
  break
155
236
 
156
237
  def evaluate_model(model, X_tensor, y_encoded, device):
238
+ """Evaluate a trained model and compute metrics.
239
+
240
+ Args:
241
+ model: Trained model.
242
+ X_tensor: Input tensor.
243
+ y_encoded: Encoded labels.
244
+ device: Torch device.
245
+
246
+ Returns:
247
+ Tuple of metrics dict, predicted labels, and probabilities.
248
+ """
157
249
  model.eval()
158
250
  with torch.no_grad():
159
251
  outputs = model(X_tensor.to(device))
@@ -176,6 +268,18 @@ def evaluate_model(model, X_tensor, y_encoded, device):
176
268
  }, preds, probs
177
269
 
178
270
  def train_rf(X_tensor, y_tensor, train_indices, test_indices, n_estimators=500):
271
+ """Train a random forest classifier.
272
+
273
+ Args:
274
+ X_tensor: Input tensor.
275
+ y_tensor: Label tensor.
276
+ train_indices: Indices for training.
277
+ test_indices: Indices for testing.
278
+ n_estimators: Number of trees.
279
+
280
+ Returns:
281
+ Tuple of (model, preds, probs).
282
+ """
179
283
  model = RandomForestClassifier(n_estimators=n_estimators, random_state=42, class_weight='balanced')
180
284
  model.fit(X_tensor[train_indices].numpy(), y_tensor[train_indices].numpy())
181
285
  probs = model.predict_proba(X_tensor[test_indices].cpu().numpy())[:, 1]
@@ -186,6 +290,25 @@ def train_rf(X_tensor, y_tensor, train_indices, test_indices, n_estimators=500):
186
290
  def run_training_loop(adata, site_config, layer_name=None,
187
291
  mlp=False, cnn=False, rnn=False, arnn=False, transformer=False, rf=False, nb=False, rr_bayes=False,
188
292
  max_epochs=10, max_patience=5, n_estimators=500, training_split=0.5):
293
+ """Train one or more classifier types on AnnData.
294
+
295
+ Args:
296
+ adata: AnnData object containing data and labels.
297
+ site_config: Mapping of reference to site list.
298
+ layer_name: Optional layer to use as input.
299
+ mlp: Whether to train an MLP model.
300
+ cnn: Whether to train a CNN model.
301
+ rnn: Whether to train an RNN model.
302
+ arnn: Whether to train an attention RNN model.
303
+ transformer: Whether to train a transformer model.
304
+ rf: Whether to train a random forest model.
305
+ nb: Whether to train a Naive Bayes model.
306
+ rr_bayes: Whether to train a ridge regression model.
307
+ max_epochs: Maximum training epochs.
308
+ max_patience: Early stopping patience.
309
+ n_estimators: Random forest estimator count.
310
+ training_split: Fraction of data used for training.
311
+ """
189
312
  device = (
190
313
  torch.device('cuda') if torch.cuda.is_available() else
191
314
  torch.device('mps') if torch.backends.mps.is_available() else
@@ -701,6 +824,20 @@ def evaluate_model_by_subgroups(
701
824
  label_col="activity_status",
702
825
  min_samples=10,
703
826
  exclude_training_data=True):
827
+ """Evaluate predictions within categorical subgroups.
828
+
829
+ Args:
830
+ adata: AnnData with prediction columns.
831
+ model_prefix: Prediction column prefix.
832
+ suffix: Prediction column suffix.
833
+ groupby_cols: Columns to group by.
834
+ label_col: Ground-truth label column.
835
+ min_samples: Minimum samples per group.
836
+ exclude_training_data: Whether to exclude training rows.
837
+
838
+ Returns:
839
+ DataFrame of subgroup-level metrics.
840
+ """
704
841
  import pandas as pd
705
842
  from sklearn.metrics import accuracy_score, f1_score, roc_auc_score
706
843
 
@@ -745,6 +882,18 @@ def evaluate_model_by_subgroups(
745
882
  return pd.DataFrame(results)
746
883
 
747
884
  def evaluate_models_by_subgroup(adata, model_prefixes, groupby_cols, label_col, exclude_training_data=True):
885
+ """Evaluate multiple model prefixes across subgroups.
886
+
887
+ Args:
888
+ adata: AnnData with prediction columns.
889
+ model_prefixes: Iterable of model prefixes.
890
+ groupby_cols: Columns to group by.
891
+ label_col: Ground-truth label column.
892
+ exclude_training_data: Whether to exclude training rows.
893
+
894
+ Returns:
895
+ Concatenated DataFrame of subgroup-level metrics.
896
+ """
748
897
  import pandas as pd
749
898
  all_metrics = []
750
899
  for model_prefix in model_prefixes:
@@ -758,6 +907,20 @@ def evaluate_models_by_subgroup(adata, model_prefixes, groupby_cols, label_col,
758
907
  return final_df
759
908
 
760
909
  def prepare_melted_model_data(adata, outkey='melted_model_df', groupby=['Enhancer_Open', 'Promoter_Open'], label_col='activity_status', model_names = ['cnn', 'mlp', 'rf'], suffix='GpC_site_CpG_site', omit_training=True):
910
+ """Prepare a long-format DataFrame for model performance plots.
911
+
912
+ Args:
913
+ adata: AnnData with prediction columns.
914
+ outkey: Key to store the melted DataFrame in ``adata.uns``.
915
+ groupby: Grouping columns to include.
916
+ label_col: Ground-truth label column.
917
+ model_names: Model prefixes to include.
918
+ suffix: Prediction column suffix.
919
+ omit_training: Whether to exclude training rows.
920
+
921
+ Returns:
922
+ Melted DataFrame of predictions.
923
+ """
761
924
  import pandas as pd
762
925
  import seaborn as sns
763
926
  import matplotlib.pyplot as plt
@@ -13,6 +13,15 @@ def subset_adata(adata, obs_columns):
13
13
  """
14
14
 
15
15
  def subset_recursive(adata_subset, columns):
16
+ """Recursively subset AnnData by categorical columns.
17
+
18
+ Args:
19
+ adata_subset: AnnData subset to split.
20
+ columns: Remaining columns to split on.
21
+
22
+ Returns:
23
+ Dictionary mapping category tuples to AnnData subsets.
24
+ """
16
25
  if not columns:
17
26
  return {(): adata_subset}
18
27
 
@@ -29,4 +38,4 @@ def subset_adata(adata, obs_columns):
29
38
  # Start the recursive subset process
30
39
  subsets_dict = subset_recursive(adata, obs_columns)
31
40
 
32
- return subsets_dict
41
+ return subsets_dict
@@ -14,6 +14,17 @@ def subset_adata(adata, columns, cat_type='obs'):
14
14
  """
15
15
 
16
16
  def subset_recursive(adata_subset, columns, cat_type, key_prefix=()):
17
+ """Recursively subset AnnData by categorical columns.
18
+
19
+ Args:
20
+ adata_subset: AnnData subset to split.
21
+ columns: Remaining columns to split on.
22
+ cat_type: Whether to use obs or var categories.
23
+ key_prefix: Tuple of previous category keys.
24
+
25
+ Returns:
26
+ Dictionary mapping category tuples to AnnData subsets.
27
+ """
17
28
  # Returns when the bottom of the stack is reached
18
29
  if not columns:
19
30
  # If there's only one column, return the key as a single value, not a tuple
@@ -43,4 +54,4 @@ def subset_adata(adata, columns, cat_type='obs'):
43
54
  # Start the recursive subset process
44
55
  subsets_dict = subset_recursive(adata, columns, cat_type)
45
56
 
46
- return subsets_dict
57
+ return subsets_dict