smftools 0.2.4__py3-none-any.whl → 0.2.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- smftools/__init__.py +6 -8
- smftools/_settings.py +4 -6
- smftools/_version.py +1 -1
- smftools/cli/helpers.py +7 -1
- smftools/cli/hmm_adata.py +902 -244
- smftools/cli/load_adata.py +318 -198
- smftools/cli/preprocess_adata.py +285 -171
- smftools/cli/spatial_adata.py +137 -53
- smftools/cli_entry.py +94 -178
- smftools/config/__init__.py +1 -1
- smftools/config/conversion.yaml +5 -1
- smftools/config/deaminase.yaml +1 -1
- smftools/config/default.yaml +22 -17
- smftools/config/direct.yaml +8 -3
- smftools/config/discover_input_files.py +19 -5
- smftools/config/experiment_config.py +505 -276
- smftools/constants.py +37 -0
- smftools/datasets/__init__.py +2 -8
- smftools/datasets/datasets.py +32 -18
- smftools/hmm/HMM.py +2125 -1426
- smftools/hmm/__init__.py +2 -3
- smftools/hmm/archived/call_hmm_peaks.py +16 -1
- smftools/hmm/call_hmm_peaks.py +173 -193
- smftools/hmm/display_hmm.py +19 -6
- smftools/hmm/hmm_readwrite.py +13 -4
- smftools/hmm/nucleosome_hmm_refinement.py +102 -14
- smftools/informatics/__init__.py +30 -7
- smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +14 -1
- smftools/informatics/archived/helpers/archived/bam_qc.py +14 -1
- smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +8 -1
- smftools/informatics/archived/helpers/archived/load_adata.py +3 -3
- smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +3 -1
- smftools/informatics/archived/print_bam_query_seq.py +7 -1
- smftools/informatics/bam_functions.py +379 -156
- smftools/informatics/basecalling.py +51 -9
- smftools/informatics/bed_functions.py +90 -57
- smftools/informatics/binarize_converted_base_identities.py +18 -7
- smftools/informatics/complement_base_list.py +7 -6
- smftools/informatics/converted_BAM_to_adata.py +265 -122
- smftools/informatics/fasta_functions.py +161 -83
- smftools/informatics/h5ad_functions.py +195 -29
- smftools/informatics/modkit_extract_to_adata.py +609 -270
- smftools/informatics/modkit_functions.py +85 -44
- smftools/informatics/ohe.py +44 -21
- smftools/informatics/pod5_functions.py +112 -73
- smftools/informatics/run_multiqc.py +20 -14
- smftools/logging_utils.py +51 -0
- smftools/machine_learning/__init__.py +2 -7
- smftools/machine_learning/data/anndata_data_module.py +143 -50
- smftools/machine_learning/data/preprocessing.py +2 -1
- smftools/machine_learning/evaluation/__init__.py +1 -1
- smftools/machine_learning/evaluation/eval_utils.py +11 -14
- smftools/machine_learning/evaluation/evaluators.py +46 -33
- smftools/machine_learning/inference/__init__.py +1 -1
- smftools/machine_learning/inference/inference_utils.py +7 -4
- smftools/machine_learning/inference/lightning_inference.py +9 -13
- smftools/machine_learning/inference/sklearn_inference.py +6 -8
- smftools/machine_learning/inference/sliding_window_inference.py +35 -25
- smftools/machine_learning/models/__init__.py +10 -5
- smftools/machine_learning/models/base.py +28 -42
- smftools/machine_learning/models/cnn.py +15 -11
- smftools/machine_learning/models/lightning_base.py +71 -40
- smftools/machine_learning/models/mlp.py +13 -4
- smftools/machine_learning/models/positional.py +3 -2
- smftools/machine_learning/models/rnn.py +3 -2
- smftools/machine_learning/models/sklearn_models.py +39 -22
- smftools/machine_learning/models/transformer.py +68 -53
- smftools/machine_learning/models/wrappers.py +2 -1
- smftools/machine_learning/training/__init__.py +2 -2
- smftools/machine_learning/training/train_lightning_model.py +29 -20
- smftools/machine_learning/training/train_sklearn_model.py +9 -15
- smftools/machine_learning/utils/__init__.py +1 -1
- smftools/machine_learning/utils/device.py +7 -4
- smftools/machine_learning/utils/grl.py +3 -1
- smftools/metadata.py +443 -0
- smftools/plotting/__init__.py +19 -5
- smftools/plotting/autocorrelation_plotting.py +145 -44
- smftools/plotting/classifiers.py +162 -72
- smftools/plotting/general_plotting.py +347 -168
- smftools/plotting/hmm_plotting.py +42 -13
- smftools/plotting/position_stats.py +145 -85
- smftools/plotting/qc_plotting.py +20 -12
- smftools/preprocessing/__init__.py +8 -8
- smftools/preprocessing/append_base_context.py +105 -79
- smftools/preprocessing/append_binary_layer_by_base_context.py +75 -37
- smftools/preprocessing/{archives → archived}/calculate_complexity.py +3 -1
- smftools/preprocessing/{archives → archived}/preprocessing.py +8 -6
- smftools/preprocessing/binarize.py +21 -4
- smftools/preprocessing/binarize_on_Youden.py +127 -31
- smftools/preprocessing/binary_layers_to_ohe.py +17 -11
- smftools/preprocessing/calculate_complexity_II.py +86 -59
- smftools/preprocessing/calculate_consensus.py +28 -19
- smftools/preprocessing/calculate_coverage.py +44 -22
- smftools/preprocessing/calculate_pairwise_differences.py +2 -1
- smftools/preprocessing/calculate_pairwise_hamming_distances.py +4 -3
- smftools/preprocessing/calculate_position_Youden.py +103 -55
- smftools/preprocessing/calculate_read_length_stats.py +52 -23
- smftools/preprocessing/calculate_read_modification_stats.py +91 -57
- smftools/preprocessing/clean_NaN.py +38 -28
- smftools/preprocessing/filter_adata_by_nan_proportion.py +24 -12
- smftools/preprocessing/filter_reads_on_length_quality_mapping.py +70 -37
- smftools/preprocessing/filter_reads_on_modification_thresholds.py +181 -73
- smftools/preprocessing/flag_duplicate_reads.py +688 -271
- smftools/preprocessing/invert_adata.py +26 -11
- smftools/preprocessing/load_sample_sheet.py +40 -22
- smftools/preprocessing/make_dirs.py +8 -3
- smftools/preprocessing/min_non_diagonal.py +2 -1
- smftools/preprocessing/recipes.py +56 -23
- smftools/preprocessing/reindex_references_adata.py +93 -27
- smftools/preprocessing/subsample_adata.py +33 -16
- smftools/readwrite.py +264 -109
- smftools/schema/__init__.py +11 -0
- smftools/schema/anndata_schema_v1.yaml +227 -0
- smftools/tools/__init__.py +3 -4
- smftools/tools/archived/classifiers.py +163 -0
- smftools/tools/archived/subset_adata_v1.py +10 -1
- smftools/tools/archived/subset_adata_v2.py +12 -1
- smftools/tools/calculate_umap.py +54 -15
- smftools/tools/cluster_adata_on_methylation.py +115 -46
- smftools/tools/general_tools.py +70 -25
- smftools/tools/position_stats.py +229 -98
- smftools/tools/read_stats.py +50 -29
- smftools/tools/spatial_autocorrelation.py +365 -192
- smftools/tools/subset_adata.py +23 -21
- {smftools-0.2.4.dist-info → smftools-0.2.5.dist-info}/METADATA +15 -43
- smftools-0.2.5.dist-info/RECORD +181 -0
- smftools-0.2.4.dist-info/RECORD +0 -176
- /smftools/preprocessing/{archives → archived}/add_read_length_and_mapping_qc.py +0 -0
- /smftools/preprocessing/{archives → archived}/mark_duplicates.py +0 -0
- /smftools/preprocessing/{archives → archived}/remove_duplicates.py +0 -0
- {smftools-0.2.4.dist-info → smftools-0.2.5.dist-info}/WHEEL +0 -0
- {smftools-0.2.4.dist-info → smftools-0.2.5.dist-info}/entry_points.txt +0 -0
- {smftools-0.2.4.dist-info → smftools-0.2.5.dist-info}/licenses/LICENSE +0 -0
smftools/schema/__init__.py
ADDED

@@ -0,0 +1,11 @@
+from __future__ import annotations
+
+from importlib import resources
+from pathlib import Path
+
+SCHEMA_REGISTRY_VERSION = "1"
+SCHEMA_REGISTRY_RESOURCE = "anndata_schema_v1.yaml"
+
+
+def get_schema_registry_path() -> Path:
+    return resources.files(__package__).joinpath(SCHEMA_REGISTRY_RESOURCE)
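For context, the new `smftools.schema` package bundles the registry YAML as package data next to this `__init__.py`; a minimal sketch of reading it back, assuming PyYAML is installed (the `load_registry` helper is illustrative, not part of the package):

```python
import yaml  # assumption: PyYAML is available in the environment

from smftools.schema import SCHEMA_REGISTRY_VERSION, get_schema_registry_path


def load_registry() -> dict:
    """Illustrative helper: parse the bundled schema registry YAML."""
    with get_schema_registry_path().open("r", encoding="utf-8") as handle:
        registry = yaml.safe_load(handle)
    # The bundled file should agree with the module-level version constant.
    assert registry["schema_version"] == SCHEMA_REGISTRY_VERSION
    return registry


print(sorted(load_registry()["stages"]))  # ['hmm', 'preprocess', 'raw', 'spatial']
```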
smftools/schema/anndata_schema_v1.yaml
ADDED

@@ -0,0 +1,227 @@
+schema_version: "1"
+description: "smftools AnnData schema registry (v1)."
+stages:
+  raw:
+    stage_requires: []
+    obs:
+      Experiment_name:
+        dtype: "category"
+        created_by: "smftools.cli.load_adata"
+        modified_by: []
+        notes: "Experiment identifier applied to all reads."
+        requires: []
+        optional_inputs: []
+      Experiment_name_and_barcode:
+        dtype: "category"
+        created_by: "smftools.cli.load_adata"
+        modified_by: []
+        notes: "Concatenated experiment name and barcode."
+        requires: [["obs.Experiment_name", "obs.Barcode"]]
+        optional_inputs: []
+      Barcode:
+        dtype: "category"
+        created_by: "smftools.informatics.modkit_extract_to_adata"
+        modified_by: []
+        notes: "Barcode assigned during demultiplexing or extraction."
+        requires: []
+        optional_inputs: []
+      read_length:
+        dtype: "float"
+        created_by: "smftools.informatics.h5ad_functions.add_read_length_and_mapping_qc"
+        modified_by: []
+        notes: "Read length in bases."
+        requires: []
+        optional_inputs: []
+      mapped_length:
+        dtype: "float"
+        created_by: "smftools.informatics.h5ad_functions.add_read_length_and_mapping_qc"
+        modified_by: []
+        notes: "Aligned length in bases."
+        requires: []
+        optional_inputs: []
+      reference_length:
+        dtype: "float"
+        created_by: "smftools.informatics.h5ad_functions.add_read_length_and_mapping_qc"
+        modified_by: []
+        notes: "Reference length for alignment target."
+        requires: []
+        optional_inputs: []
+      read_quality:
+        dtype: "float"
+        created_by: "smftools.informatics.h5ad_functions.add_read_length_and_mapping_qc"
+        modified_by: []
+        notes: "Per-read quality score."
+        requires: []
+        optional_inputs: []
+      mapping_quality:
+        dtype: "float"
+        created_by: "smftools.informatics.h5ad_functions.add_read_length_and_mapping_qc"
+        modified_by: []
+        notes: "Mapping quality score."
+        requires: []
+        optional_inputs: []
+      read_length_to_reference_length_ratio:
+        dtype: "float"
+        created_by: "smftools.informatics.h5ad_functions.add_read_length_and_mapping_qc"
+        modified_by: []
+        notes: "Read length divided by reference length."
+        requires: [["obs.read_length", "obs.reference_length"]]
+        optional_inputs: []
+      mapped_length_to_reference_length_ratio:
+        dtype: "float"
+        created_by: "smftools.informatics.h5ad_functions.add_read_length_and_mapping_qc"
+        modified_by: []
+        notes: "Mapped length divided by reference length."
+        requires: [["obs.mapped_length", "obs.reference_length"]]
+        optional_inputs: []
+      mapped_length_to_read_length_ratio:
+        dtype: "float"
+        created_by: "smftools.informatics.h5ad_functions.add_read_length_and_mapping_qc"
+        modified_by: []
+        notes: "Mapped length divided by read length."
+        requires: [["obs.mapped_length", "obs.read_length"]]
+        optional_inputs: []
+      Raw_modification_signal:
+        dtype: "float"
+        created_by: "smftools.informatics.h5ad_functions.add_read_length_and_mapping_qc"
+        modified_by:
+          - "smftools.cli.load_adata"
+        notes: "Summed modification signal per read."
+        requires: [["X"], ["layers.raw_mods"]]
+        optional_inputs: []
+      pod5_origin:
+        dtype: "string"
+        created_by: "smftools.informatics.h5ad_functions.annotate_pod5_origin"
+        modified_by: []
+        notes: "POD5 filename source for each read."
+        requires: [["obs_names"]]
+        optional_inputs: []
+      demux_type:
+        dtype: "category"
+        created_by: "smftools.informatics.h5ad_functions.add_demux_type_annotation"
+        modified_by: []
+        notes: "Classification of demultiplexing status."
+        requires: [["obs_names"]]
+        optional_inputs: []
+    var:
+      reference_position:
+        dtype: "int"
+        created_by: "smftools.informatics.modkit_extract_to_adata"
+        modified_by: []
+        notes: "Reference coordinate for each column."
+        requires: []
+        optional_inputs: []
+      reference_id:
+        dtype: "category"
+        created_by: "smftools.informatics.modkit_extract_to_adata"
+        modified_by: []
+        notes: "Reference contig or sequence name."
+        requires: []
+        optional_inputs: []
+    layers:
+      raw_mods:
+        dtype: "float"
+        created_by: "smftools.informatics.modkit_extract_to_adata"
+        modified_by: []
+        notes: "Raw modification scores (modality-dependent)."
+        requires: []
+        optional_inputs: []
+    obsm: {}
+    varm: {}
+    obsp: {}
+    uns:
+      smftools:
+        dtype: "mapping"
+        created_by: "smftools.metadata.record_smftools_metadata"
+        modified_by: []
+        notes: "smftools metadata including history, environment, provenance, schema snapshot."
+        requires: []
+        optional_inputs: []
+  preprocess:
+    stage_requires: ["raw"]
+    obs:
+      sequence__merged_cluster_id:
+        dtype: "category"
+        created_by: "smftools.preprocessing.flag_duplicate_reads"
+        modified_by: []
+        notes: "Cluster identifier for duplicate detection."
+        requires: [["layers.nan0_0minus1"]]
+        optional_inputs: ["obs.demux_type"]
+    layers:
+      nan0_0minus1:
+        dtype: "float"
+        created_by: "smftools.preprocessing.binarize"
+        modified_by:
+          - "smftools.preprocessing.clean_NaN"
+        notes: "Binarized methylation matrix (nan=0, 0=-1)."
+        requires: [["X"]]
+        optional_inputs: []
+    obsm:
+      X_umap:
+        dtype: "float"
+        created_by: "smftools.tools.calculate_umap"
+        modified_by: []
+        notes: "UMAP embedding for preprocessed reads."
+        requires: [["X"]]
+        optional_inputs: []
+    varm: {}
+    obsp: {}
+    uns:
+      duplicate_read_groups:
+        dtype: "mapping"
+        created_by: "smftools.preprocessing.flag_duplicate_reads"
+        modified_by: []
+        notes: "Duplicate read group metadata."
+        requires: [["obs.sequence__merged_cluster_id"]]
+        optional_inputs: []
+  spatial:
+    stage_requires: ["raw", "preprocess"]
+    obs:
+      leiden:
+        dtype: "category"
+        created_by: "smftools.tools.calculate_umap"
+        modified_by: []
+        notes: "Leiden cluster assignments."
+        requires: [["obsm.X_umap"]]
+        optional_inputs: []
+    obsm:
+      X_umap:
+        dtype: "float"
+        created_by: "smftools.tools.calculate_umap"
+        modified_by: []
+        notes: "UMAP embedding for spatial analyses."
+        requires: [["X"]]
+        optional_inputs: []
+    layers: {}
+    varm: {}
+    obsp: {}
+    uns:
+      positionwise_result:
+        dtype: "mapping"
+        created_by: "smftools.tools.position_stats.compute_positionwise_statistics"
+        modified_by: []
+        notes: "Positionwise correlation statistics for spatial analyses."
+        requires: [["X"]]
+        optional_inputs: ["obs.reference_column"]
+  hmm:
+    stage_requires: ["raw", "preprocess", "spatial"]
+    layers:
+      hmm_state_calls:
+        dtype: "int"
+        created_by: "smftools.hmm.call_hmm_peaks"
+        modified_by: []
+        notes: "HMM-derived state calls per read/position."
+        requires: [["layers.nan0_0minus1"]]
+        optional_inputs: []
+    obsm: {}
+    varm: {}
+    obsp: {}
+    obs: {}
+    uns:
+      hmm_annotated:
+        dtype: "bool"
+        created_by: "smftools.cli.hmm_adata"
+        modified_by: []
+        notes: "Flag indicating HMM annotations are present."
+        requires: [["layers.hmm_state_calls"]]
+        optional_inputs: []
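Each stage entry follows the same shape: slot maps (`obs`, `var`, `layers`, `obsm`, `varm`, `obsp`, `uns`) whose fields carry a dtype, provenance (`created_by`/`modified_by`), free-text notes, and `requires`/`optional_inputs` dependency lists, while `stage_requires` orders the stages themselves (raw → preprocess → spatial → hmm). A hedged sketch of how the stage metadata could drive a presence check on an AnnData object (this validator is an illustration, not a function shipped in smftools):

```python
import anndata as ad
import yaml  # assumption: PyYAML is available

from smftools.schema import get_schema_registry_path


def missing_obs_fields(adata: ad.AnnData, stage: str) -> list[str]:
    """Illustrative check: schema obs columns absent from adata.obs."""
    with get_schema_registry_path().open("r", encoding="utf-8") as handle:
        registry = yaml.safe_load(handle)
    expected = registry["stages"][stage].get("obs") or {}
    return [name for name in expected if name not in adata.obs.columns]


# e.g. missing_obs_fields(adata, "raw") lists columns such as 'demux_type'
# for an object that has not yet been through the full load pipeline.
```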
smftools/tools/__init__.py
CHANGED

@@ -1,12 +1,11 @@
-from .position_stats import calculate_relative_risk_on_activity, compute_positionwise_statistics
 from .calculate_umap import calculate_umap
 from .cluster_adata_on_methylation import cluster_adata_on_methylation
-from .general_tools import
+from .general_tools import combine_layers, create_nan_mask_from_X, create_nan_or_non_gpc_mask
+from .position_stats import calculate_relative_risk_on_activity, compute_positionwise_statistics
 from .read_stats import calculate_row_entropy
 from .spatial_autocorrelation import *
 from .subset_adata import subset_adata
 
-
 __all__ = [
     "compute_positionwise_statistics",
     "calculate_row_entropy",
@@ -17,4 +16,4 @@ __all__ = [
     "create_nan_or_non_gpc_mask",
     "combine_layers",
     "subset_adata",
-]
+]
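As rendered above, the 0.2.4 wheel's `from .general_tools import` line was incomplete, which is a SyntaxError if shipped as-is, so even `import smftools.tools` would have failed; with the completed line the re-exports resolve again:

```python
# Works on 0.2.5; would raise SyntaxError at import time under 0.2.4
# if the truncated import line shipped as rendered in this diff.
from smftools.tools import combine_layers, create_nan_mask_from_X, subset_adata
```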
smftools/tools/archived/classifiers.py
CHANGED

@@ -21,13 +21,29 @@ device = (
 
 # ------------------------- Utilities -------------------------
 def random_fill_nans(X):
+    """Replace NaNs in an array with random values.
+
+    Args:
+        X: Input NumPy array.
+
+    Returns:
+        NumPy array with NaNs replaced.
+    """
     nan_mask = np.isnan(X)
     X[nan_mask] = np.random.rand(*X[nan_mask].shape)
     return X
 
 # ------------------------- Model Definitions -------------------------
 class CNNClassifier(nn.Module):
+    """Simple 1D CNN classifier for fixed-length inputs."""
+
     def __init__(self, input_size, num_classes):
+        """Initialize CNN classifier layers.
+
+        Args:
+            input_size: Length of the 1D input.
+            num_classes: Number of output classes.
+        """
         super().__init__()
         self.conv1 = nn.Conv1d(1, 16, kernel_size=3, padding=1)
         self.conv2 = nn.Conv1d(16, 32, kernel_size=3, padding=1)
@@ -39,11 +55,13 @@ class CNNClassifier(nn.Module):
         self.fc2 = nn.Linear(64, num_classes)
 
     def _forward_conv(self, x):
+        """Apply convolutional layers and activation."""
         x = self.relu(self.conv1(x))
         x = self.relu(self.conv2(x))
         return x
 
     def forward(self, x):
+        """Run the forward pass."""
         x = x.unsqueeze(1)
         x = self._forward_conv(x)
         x = x.view(x.size(0), -1)
@@ -51,7 +69,15 @@ class CNNClassifier(nn.Module):
         return self.fc2(x)
 
 class MLPClassifier(nn.Module):
+    """Simple MLP classifier."""
+
     def __init__(self, input_dim, num_classes):
+        """Initialize MLP layers.
+
+        Args:
+            input_dim: Input feature dimension.
+            num_classes: Number of output classes.
+        """
         super().__init__()
         self.model = nn.Sequential(
             nn.Linear(input_dim, 128),
@@ -64,10 +90,20 @@ class MLPClassifier(nn.Module):
         )
 
     def forward(self, x):
+        """Run the forward pass."""
         return self.model(x)
 
 class RNNClassifier(nn.Module):
+    """LSTM-based classifier for sequential inputs."""
+
     def __init__(self, input_size, hidden_dim, num_classes):
+        """Initialize RNN classifier layers.
+
+        Args:
+            input_size: Input feature dimension.
+            hidden_dim: Hidden state dimension.
+            num_classes: Number of output classes.
+        """
         super().__init__()
         # Define LSTM layer
         self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_dim, batch_first=True)
@@ -75,18 +111,29 @@ class RNNClassifier(nn.Module):
         self.fc = nn.Linear(hidden_dim, num_classes)
 
     def forward(self, x):
+        """Run the forward pass."""
         x = x.unsqueeze(1)
         _, (h_n, _) = self.lstm(x)
         return self.fc(h_n.squeeze(0))
 
 class AttentionRNNClassifier(nn.Module):
+    """LSTM classifier with simple attention."""
+
     def __init__(self, input_size, hidden_dim, num_classes):
+        """Initialize attention-based RNN layers.
+
+        Args:
+            input_size: Input feature dimension.
+            hidden_dim: Hidden state dimension.
+            num_classes: Number of output classes.
+        """
         super().__init__()
         self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_dim, batch_first=True)
         self.attn = nn.Linear(hidden_dim, 1) # Simple attention scores
         self.fc = nn.Linear(hidden_dim, num_classes)
 
     def forward(self, x):
+        """Run the forward pass."""
         x = x.unsqueeze(1)  # shape: (batch, 1, seq_len)
         lstm_out, _ = self.lstm(x)  # shape: (batch, 1, hidden_dim)
         attn_weights = torch.softmax(self.attn(lstm_out), dim=1)  # (batch, 1, 1)
@@ -94,7 +141,15 @@ class AttentionRNNClassifier(nn.Module):
         return self.fc(context)
 
 class PositionalEncoding(nn.Module):
+    """Positional encoding module for transformer models."""
+
     def __init__(self, d_model, max_len=5000):
+        """Initialize positional encoding buffer.
+
+        Args:
+            d_model: Model embedding dimension.
+            max_len: Maximum sequence length.
+        """
         super().__init__()
         pe = torch.zeros(max_len, d_model)
         position = torch.arange(0, max_len).unsqueeze(1).float()
@@ -104,11 +159,23 @@ class PositionalEncoding(nn.Module):
         self.pe = pe.unsqueeze(0)  # (1, max_len, d_model)
 
     def forward(self, x):
+        """Add positional encoding to inputs."""
         x = x + self.pe[:, :x.size(1)].to(x.device)
         return x
 
 class TransformerClassifier(nn.Module):
+    """Transformer encoder-based classifier."""
+
     def __init__(self, input_dim, model_dim, num_classes, num_heads=4, num_layers=2):
+        """Initialize transformer classifier layers.
+
+        Args:
+            input_dim: Input feature dimension.
+            model_dim: Transformer model dimension.
+            num_classes: Number of output classes.
+            num_heads: Number of attention heads.
+            num_layers: Number of encoder layers.
+        """
         super().__init__()
         self.input_fc = nn.Linear(input_dim, model_dim)
         self.pos_encoder = PositionalEncoding(model_dim)
@@ -119,6 +186,7 @@ class TransformerClassifier(nn.Module):
         self.cls_head = nn.Linear(model_dim, num_classes)
 
     def forward(self, x):
+        """Run the forward pass."""
         # x: [batch_size, input_dim]
         x = self.input_fc(x).unsqueeze(1)  # -> [batch_size, 1, model_dim]
         x = self.pos_encoder(x)
@@ -128,6 +196,19 @@ class TransformerClassifier(nn.Module):
         return self.cls_head(pooled)
 
 def train_model(model, loader, optimizer, criterion, device, ref_name="", model_name="", epochs=20, patience=5):
+    """Train a model with early stopping.
+
+    Args:
+        model: PyTorch model.
+        loader: DataLoader for training data.
+        optimizer: Optimizer instance.
+        criterion: Loss function.
+        device: Torch device.
+        ref_name: Reference label for logging.
+        model_name: Model label for logging.
+        epochs: Maximum epochs.
+        patience: Early-stopping patience.
+    """
     model.train()
     best_loss = float('inf')
     trigger_times = 0
@@ -154,6 +235,17 @@ def train_model(model, loader, optimizer, criterion, device, ref_name="", model_
             break
 
 def evaluate_model(model, X_tensor, y_encoded, device):
+    """Evaluate a trained model and compute metrics.
+
+    Args:
+        model: Trained model.
+        X_tensor: Input tensor.
+        y_encoded: Encoded labels.
+        device: Torch device.
+
+    Returns:
+        Tuple of metrics dict, predicted labels, and probabilities.
+    """
     model.eval()
     with torch.no_grad():
         outputs = model(X_tensor.to(device))
@@ -176,6 +268,18 @@ def evaluate_model(model, X_tensor, y_encoded, device):
     }, preds, probs
 
 def train_rf(X_tensor, y_tensor, train_indices, test_indices, n_estimators=500):
+    """Train a random forest classifier.
+
+    Args:
+        X_tensor: Input tensor.
+        y_tensor: Label tensor.
+        train_indices: Indices for training.
+        test_indices: Indices for testing.
+        n_estimators: Number of trees.
+
+    Returns:
+        Tuple of (model, preds, probs).
+    """
     model = RandomForestClassifier(n_estimators=n_estimators, random_state=42, class_weight='balanced')
     model.fit(X_tensor[train_indices].numpy(), y_tensor[train_indices].numpy())
     probs = model.predict_proba(X_tensor[test_indices].cpu().numpy())[:, 1]
@@ -186,6 +290,25 @@ def train_rf(X_tensor, y_tensor, train_indices, test_indices, n_estimators=500):
 def run_training_loop(adata, site_config, layer_name=None,
                       mlp=False, cnn=False, rnn=False, arnn=False, transformer=False, rf=False, nb=False, rr_bayes=False,
                       max_epochs=10, max_patience=5, n_estimators=500, training_split=0.5):
+    """Train one or more classifier types on AnnData.
+
+    Args:
+        adata: AnnData object containing data and labels.
+        site_config: Mapping of reference to site list.
+        layer_name: Optional layer to use as input.
+        mlp: Whether to train an MLP model.
+        cnn: Whether to train a CNN model.
+        rnn: Whether to train an RNN model.
+        arnn: Whether to train an attention RNN model.
+        transformer: Whether to train a transformer model.
+        rf: Whether to train a random forest model.
+        nb: Whether to train a Naive Bayes model.
+        rr_bayes: Whether to train a ridge regression model.
+        max_epochs: Maximum training epochs.
+        max_patience: Early stopping patience.
+        n_estimators: Random forest estimator count.
+        training_split: Fraction of data used for training.
+    """
     device = (
         torch.device('cuda') if torch.cuda.is_available() else
         torch.device('mps') if torch.backends.mps.is_available() else
@@ -701,6 +824,20 @@ def evaluate_model_by_subgroups(
         label_col="activity_status",
         min_samples=10,
         exclude_training_data=True):
+    """Evaluate predictions within categorical subgroups.
+
+    Args:
+        adata: AnnData with prediction columns.
+        model_prefix: Prediction column prefix.
+        suffix: Prediction column suffix.
+        groupby_cols: Columns to group by.
+        label_col: Ground-truth label column.
+        min_samples: Minimum samples per group.
+        exclude_training_data: Whether to exclude training rows.
+
+    Returns:
+        DataFrame of subgroup-level metrics.
+    """
     import pandas as pd
     from sklearn.metrics import accuracy_score, f1_score, roc_auc_score
 
@@ -745,6 +882,18 @@ def evaluate_model_by_subgroups(
     return pd.DataFrame(results)
 
 def evaluate_models_by_subgroup(adata, model_prefixes, groupby_cols, label_col, exclude_training_data=True):
+    """Evaluate multiple model prefixes across subgroups.
+
+    Args:
+        adata: AnnData with prediction columns.
+        model_prefixes: Iterable of model prefixes.
+        groupby_cols: Columns to group by.
+        label_col: Ground-truth label column.
+        exclude_training_data: Whether to exclude training rows.
+
+    Returns:
+        Concatenated DataFrame of subgroup-level metrics.
+    """
     import pandas as pd
     all_metrics = []
     for model_prefix in model_prefixes:
@@ -758,6 +907,20 @@ def evaluate_models_by_subgroup(adata, model_prefixes, groupby_cols, label_col,
     return final_df
 
 def prepare_melted_model_data(adata, outkey='melted_model_df', groupby=['Enhancer_Open', 'Promoter_Open'], label_col='activity_status', model_names = ['cnn', 'mlp', 'rf'], suffix='GpC_site_CpG_site', omit_training=True):
+    """Prepare a long-format DataFrame for model performance plots.
+
+    Args:
+        adata: AnnData with prediction columns.
+        outkey: Key to store the melted DataFrame in ``adata.uns``.
+        groupby: Grouping columns to include.
+        label_col: Ground-truth label column.
+        model_names: Model prefixes to include.
+        suffix: Prediction column suffix.
+        omit_training: Whether to exclude training rows.
+
+    Returns:
+        Melted DataFrame of predictions.
+    """
     import pandas as pd
     import seaborn as sns
     import matplotlib.pyplot as plt
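The docstrings above pin down the training helpers' contracts; a hedged sketch of wiring them together on synthetic data, based only on the signatures shown in this diff (the import path, tensor shapes, and label encoding are assumptions, and this archived module may not be importable in every install):

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# Assumption: the archived module is importable at this path.
from smftools.tools.archived.classifiers import CNNClassifier, evaluate_model, train_model

device = torch.device("cpu")
X = torch.rand(64, 100)           # 64 reads x 100 positions (shape assumed)
y = torch.randint(0, 2, (64,))    # binary labels

model = CNNClassifier(input_size=100, num_classes=2).to(device)
loader = DataLoader(TensorDataset(X, y), batch_size=16, shuffle=True)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

train_model(model, loader, optimizer, criterion, device,
            ref_name="demo", model_name="cnn", epochs=5, patience=2)
metrics, preds, probs = evaluate_model(model, X, y.numpy(), device)
```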
smftools/tools/archived/subset_adata_v1.py
CHANGED

@@ -13,6 +13,15 @@ def subset_adata(adata, obs_columns):
     """
 
     def subset_recursive(adata_subset, columns):
+        """Recursively subset AnnData by categorical columns.
+
+        Args:
+            adata_subset: AnnData subset to split.
+            columns: Remaining columns to split on.
+
+        Returns:
+            Dictionary mapping category tuples to AnnData subsets.
+        """
         if not columns:
             return {(): adata_subset}
 
@@ -29,4 +38,4 @@ def subset_adata(adata, obs_columns):
     # Start the recursive subset process
     subsets_dict = subset_recursive(adata, obs_columns)
 
-    return subsets_dict
+    return subsets_dict
smftools/tools/archived/subset_adata_v2.py
CHANGED

@@ -14,6 +14,17 @@ def subset_adata(adata, columns, cat_type='obs'):
     """
 
    def subset_recursive(adata_subset, columns, cat_type, key_prefix=()):
+        """Recursively subset AnnData by categorical columns.
+
+        Args:
+            adata_subset: AnnData subset to split.
+            columns: Remaining columns to split on.
+            cat_type: Whether to use obs or var categories.
+            key_prefix: Tuple of previous category keys.
+
+        Returns:
+            Dictionary mapping category tuples to AnnData subsets.
+        """
         # Returns when the bottom of the stack is reached
         if not columns:
             # If there's only one column, return the key as a single value, not a tuple
@@ -43,4 +54,4 @@ def subset_adata(adata, columns, cat_type='obs'):
     # Start the recursive subset process
     subsets_dict = subset_recursive(adata, columns, cat_type)
 
-    return subsets_dict
+    return subsets_dict
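Per its own comment, the v2 variant returns a dict keyed by a bare category value when splitting on a single column and by tuples otherwise; a small usage sketch against the v2 signature shown above (the import path is an assumption):

```python
import anndata as ad
import numpy as np
import pandas as pd

# Assumption: the archived module is importable at this path.
from smftools.tools.archived.subset_adata_v2 import subset_adata

adata = ad.AnnData(
    X=np.random.rand(6, 4),
    obs=pd.DataFrame({"Barcode": pd.Categorical(["bc1", "bc1", "bc2", "bc2", "bc3", "bc3"])}),
)

subsets = subset_adata(adata, ["Barcode"], cat_type="obs")
for key, subset in subsets.items():
    print(key, subset.shape)  # one AnnData per Barcode category
```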