smftools 0.2.4__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- smftools/__init__.py +43 -13
- smftools/_settings.py +6 -6
- smftools/_version.py +3 -1
- smftools/cli/__init__.py +1 -0
- smftools/cli/archived/cli_flows.py +2 -0
- smftools/cli/helpers.py +9 -1
- smftools/cli/hmm_adata.py +905 -242
- smftools/cli/load_adata.py +432 -280
- smftools/cli/preprocess_adata.py +287 -171
- smftools/cli/spatial_adata.py +141 -53
- smftools/cli_entry.py +119 -178
- smftools/config/__init__.py +3 -1
- smftools/config/conversion.yaml +5 -1
- smftools/config/deaminase.yaml +1 -1
- smftools/config/default.yaml +26 -18
- smftools/config/direct.yaml +8 -3
- smftools/config/discover_input_files.py +19 -5
- smftools/config/experiment_config.py +511 -276
- smftools/constants.py +37 -0
- smftools/datasets/__init__.py +4 -8
- smftools/datasets/datasets.py +32 -18
- smftools/hmm/HMM.py +2133 -1428
- smftools/hmm/__init__.py +24 -14
- smftools/hmm/archived/apply_hmm_batched.py +2 -0
- smftools/hmm/archived/calculate_distances.py +2 -0
- smftools/hmm/archived/call_hmm_peaks.py +18 -1
- smftools/hmm/archived/train_hmm.py +2 -0
- smftools/hmm/call_hmm_peaks.py +176 -193
- smftools/hmm/display_hmm.py +23 -7
- smftools/hmm/hmm_readwrite.py +20 -6
- smftools/hmm/nucleosome_hmm_refinement.py +104 -14
- smftools/informatics/__init__.py +55 -13
- smftools/informatics/archived/bam_conversion.py +2 -0
- smftools/informatics/archived/bam_direct.py +2 -0
- smftools/informatics/archived/basecall_pod5s.py +2 -0
- smftools/informatics/archived/basecalls_to_adata.py +2 -0
- smftools/informatics/archived/conversion_smf.py +2 -0
- smftools/informatics/archived/deaminase_smf.py +1 -0
- smftools/informatics/archived/direct_smf.py +2 -0
- smftools/informatics/archived/fast5_to_pod5.py +2 -0
- smftools/informatics/archived/helpers/archived/__init__.py +2 -0
- smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +16 -1
- smftools/informatics/archived/helpers/archived/aligned_BAM_to_bed.py +2 -0
- smftools/informatics/archived/helpers/archived/bam_qc.py +14 -1
- smftools/informatics/archived/helpers/archived/bed_to_bigwig.py +2 -0
- smftools/informatics/archived/helpers/archived/canoncall.py +2 -0
- smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +8 -1
- smftools/informatics/archived/helpers/archived/converted_BAM_to_adata.py +2 -0
- smftools/informatics/archived/helpers/archived/count_aligned_reads.py +2 -0
- smftools/informatics/archived/helpers/archived/demux_and_index_BAM.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_base_identities.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_mods.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_read_features_from_bam.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_read_lengths_from_bed.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_readnames_from_BAM.py +2 -0
- smftools/informatics/archived/helpers/archived/find_conversion_sites.py +2 -0
- smftools/informatics/archived/helpers/archived/generate_converted_FASTA.py +2 -0
- smftools/informatics/archived/helpers/archived/get_chromosome_lengths.py +2 -0
- smftools/informatics/archived/helpers/archived/get_native_references.py +2 -0
- smftools/informatics/archived/helpers/archived/index_fasta.py +2 -0
- smftools/informatics/archived/helpers/archived/informatics.py +2 -0
- smftools/informatics/archived/helpers/archived/load_adata.py +5 -3
- smftools/informatics/archived/helpers/archived/make_modbed.py +2 -0
- smftools/informatics/archived/helpers/archived/modQC.py +2 -0
- smftools/informatics/archived/helpers/archived/modcall.py +2 -0
- smftools/informatics/archived/helpers/archived/ohe_batching.py +2 -0
- smftools/informatics/archived/helpers/archived/ohe_layers_decode.py +2 -0
- smftools/informatics/archived/helpers/archived/one_hot_decode.py +2 -0
- smftools/informatics/archived/helpers/archived/one_hot_encode.py +2 -0
- smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +5 -1
- smftools/informatics/archived/helpers/archived/separate_bam_by_bc.py +2 -0
- smftools/informatics/archived/helpers/archived/split_and_index_BAM.py +2 -0
- smftools/informatics/archived/print_bam_query_seq.py +9 -1
- smftools/informatics/archived/subsample_fasta_from_bed.py +2 -0
- smftools/informatics/archived/subsample_pod5.py +2 -0
- smftools/informatics/bam_functions.py +1059 -269
- smftools/informatics/basecalling.py +53 -9
- smftools/informatics/bed_functions.py +357 -114
- smftools/informatics/binarize_converted_base_identities.py +21 -7
- smftools/informatics/complement_base_list.py +9 -6
- smftools/informatics/converted_BAM_to_adata.py +324 -137
- smftools/informatics/fasta_functions.py +251 -89
- smftools/informatics/h5ad_functions.py +202 -30
- smftools/informatics/modkit_extract_to_adata.py +623 -274
- smftools/informatics/modkit_functions.py +87 -44
- smftools/informatics/ohe.py +46 -21
- smftools/informatics/pod5_functions.py +114 -74
- smftools/informatics/run_multiqc.py +20 -14
- smftools/logging_utils.py +51 -0
- smftools/machine_learning/__init__.py +23 -12
- smftools/machine_learning/data/__init__.py +2 -0
- smftools/machine_learning/data/anndata_data_module.py +157 -50
- smftools/machine_learning/data/preprocessing.py +4 -1
- smftools/machine_learning/evaluation/__init__.py +3 -1
- smftools/machine_learning/evaluation/eval_utils.py +13 -14
- smftools/machine_learning/evaluation/evaluators.py +52 -34
- smftools/machine_learning/inference/__init__.py +3 -1
- smftools/machine_learning/inference/inference_utils.py +9 -4
- smftools/machine_learning/inference/lightning_inference.py +14 -13
- smftools/machine_learning/inference/sklearn_inference.py +8 -8
- smftools/machine_learning/inference/sliding_window_inference.py +37 -25
- smftools/machine_learning/models/__init__.py +12 -5
- smftools/machine_learning/models/base.py +34 -43
- smftools/machine_learning/models/cnn.py +22 -13
- smftools/machine_learning/models/lightning_base.py +78 -42
- smftools/machine_learning/models/mlp.py +18 -5
- smftools/machine_learning/models/positional.py +10 -4
- smftools/machine_learning/models/rnn.py +8 -3
- smftools/machine_learning/models/sklearn_models.py +46 -24
- smftools/machine_learning/models/transformer.py +75 -55
- smftools/machine_learning/models/wrappers.py +8 -3
- smftools/machine_learning/training/__init__.py +4 -2
- smftools/machine_learning/training/train_lightning_model.py +42 -23
- smftools/machine_learning/training/train_sklearn_model.py +11 -15
- smftools/machine_learning/utils/__init__.py +3 -1
- smftools/machine_learning/utils/device.py +12 -5
- smftools/machine_learning/utils/grl.py +8 -2
- smftools/metadata.py +443 -0
- smftools/optional_imports.py +31 -0
- smftools/plotting/__init__.py +32 -17
- smftools/plotting/autocorrelation_plotting.py +153 -48
- smftools/plotting/classifiers.py +175 -73
- smftools/plotting/general_plotting.py +350 -168
- smftools/plotting/hmm_plotting.py +53 -14
- smftools/plotting/position_stats.py +155 -87
- smftools/plotting/qc_plotting.py +25 -12
- smftools/preprocessing/__init__.py +35 -37
- smftools/preprocessing/append_base_context.py +105 -79
- smftools/preprocessing/append_binary_layer_by_base_context.py +75 -37
- smftools/preprocessing/{archives → archived}/add_read_length_and_mapping_qc.py +2 -0
- smftools/preprocessing/{archives → archived}/calculate_complexity.py +5 -1
- smftools/preprocessing/{archives → archived}/mark_duplicates.py +2 -0
- smftools/preprocessing/{archives → archived}/preprocessing.py +10 -6
- smftools/preprocessing/{archives → archived}/remove_duplicates.py +2 -0
- smftools/preprocessing/binarize.py +21 -4
- smftools/preprocessing/binarize_on_Youden.py +127 -31
- smftools/preprocessing/binary_layers_to_ohe.py +18 -11
- smftools/preprocessing/calculate_complexity_II.py +89 -59
- smftools/preprocessing/calculate_consensus.py +28 -19
- smftools/preprocessing/calculate_coverage.py +44 -22
- smftools/preprocessing/calculate_pairwise_differences.py +4 -1
- smftools/preprocessing/calculate_pairwise_hamming_distances.py +7 -3
- smftools/preprocessing/calculate_position_Youden.py +110 -55
- smftools/preprocessing/calculate_read_length_stats.py +52 -23
- smftools/preprocessing/calculate_read_modification_stats.py +91 -57
- smftools/preprocessing/clean_NaN.py +38 -28
- smftools/preprocessing/filter_adata_by_nan_proportion.py +24 -12
- smftools/preprocessing/filter_reads_on_length_quality_mapping.py +72 -37
- smftools/preprocessing/filter_reads_on_modification_thresholds.py +183 -73
- smftools/preprocessing/flag_duplicate_reads.py +708 -303
- smftools/preprocessing/invert_adata.py +26 -11
- smftools/preprocessing/load_sample_sheet.py +40 -22
- smftools/preprocessing/make_dirs.py +9 -3
- smftools/preprocessing/min_non_diagonal.py +4 -1
- smftools/preprocessing/recipes.py +58 -23
- smftools/preprocessing/reindex_references_adata.py +93 -27
- smftools/preprocessing/subsample_adata.py +33 -16
- smftools/readwrite.py +264 -109
- smftools/schema/__init__.py +11 -0
- smftools/schema/anndata_schema_v1.yaml +227 -0
- smftools/tools/__init__.py +25 -18
- smftools/tools/archived/apply_hmm.py +2 -0
- smftools/tools/archived/classifiers.py +165 -0
- smftools/tools/archived/classify_methylated_features.py +2 -0
- smftools/tools/archived/classify_non_methylated_features.py +2 -0
- smftools/tools/archived/subset_adata_v1.py +12 -1
- smftools/tools/archived/subset_adata_v2.py +14 -1
- smftools/tools/calculate_umap.py +56 -15
- smftools/tools/cluster_adata_on_methylation.py +122 -47
- smftools/tools/general_tools.py +70 -25
- smftools/tools/position_stats.py +220 -99
- smftools/tools/read_stats.py +50 -29
- smftools/tools/spatial_autocorrelation.py +365 -192
- smftools/tools/subset_adata.py +23 -21
- smftools-0.3.0.dist-info/METADATA +147 -0
- smftools-0.3.0.dist-info/RECORD +182 -0
- smftools-0.2.4.dist-info/METADATA +0 -141
- smftools-0.2.4.dist-info/RECORD +0 -176
- {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/WHEEL +0 -0
- {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/entry_points.txt +0 -0
- {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from importlib import resources
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
|
|
6
|
+
SCHEMA_REGISTRY_VERSION = "1"
|
|
7
|
+
SCHEMA_REGISTRY_RESOURCE = "anndata_schema_v1.yaml"
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def get_schema_registry_path() -> Path:
|
|
11
|
+
return resources.files(__package__).joinpath(SCHEMA_REGISTRY_RESOURCE)
|
|
@@ -0,0 +1,227 @@
|
|
|
1
|
+
schema_version: "1"
|
|
2
|
+
description: "smftools AnnData schema registry (v1)."
|
|
3
|
+
stages:
|
|
4
|
+
raw:
|
|
5
|
+
stage_requires: []
|
|
6
|
+
obs:
|
|
7
|
+
Experiment_name:
|
|
8
|
+
dtype: "category"
|
|
9
|
+
created_by: "smftools.cli.load_adata"
|
|
10
|
+
modified_by: []
|
|
11
|
+
notes: "Experiment identifier applied to all reads."
|
|
12
|
+
requires: []
|
|
13
|
+
optional_inputs: []
|
|
14
|
+
Experiment_name_and_barcode:
|
|
15
|
+
dtype: "category"
|
|
16
|
+
created_by: "smftools.cli.load_adata"
|
|
17
|
+
modified_by: []
|
|
18
|
+
notes: "Concatenated experiment name and barcode."
|
|
19
|
+
requires: [["obs.Experiment_name", "obs.Barcode"]]
|
|
20
|
+
optional_inputs: []
|
|
21
|
+
Barcode:
|
|
22
|
+
dtype: "category"
|
|
23
|
+
created_by: "smftools.informatics.modkit_extract_to_adata"
|
|
24
|
+
modified_by: []
|
|
25
|
+
notes: "Barcode assigned during demultiplexing or extraction."
|
|
26
|
+
requires: []
|
|
27
|
+
optional_inputs: []
|
|
28
|
+
read_length:
|
|
29
|
+
dtype: "float"
|
|
30
|
+
created_by: "smftools.informatics.h5ad_functions.add_read_length_and_mapping_qc"
|
|
31
|
+
modified_by: []
|
|
32
|
+
notes: "Read length in bases."
|
|
33
|
+
requires: []
|
|
34
|
+
optional_inputs: []
|
|
35
|
+
mapped_length:
|
|
36
|
+
dtype: "float"
|
|
37
|
+
created_by: "smftools.informatics.h5ad_functions.add_read_length_and_mapping_qc"
|
|
38
|
+
modified_by: []
|
|
39
|
+
notes: "Aligned length in bases."
|
|
40
|
+
requires: []
|
|
41
|
+
optional_inputs: []
|
|
42
|
+
reference_length:
|
|
43
|
+
dtype: "float"
|
|
44
|
+
created_by: "smftools.informatics.h5ad_functions.add_read_length_and_mapping_qc"
|
|
45
|
+
modified_by: []
|
|
46
|
+
notes: "Reference length for alignment target."
|
|
47
|
+
requires: []
|
|
48
|
+
optional_inputs: []
|
|
49
|
+
read_quality:
|
|
50
|
+
dtype: "float"
|
|
51
|
+
created_by: "smftools.informatics.h5ad_functions.add_read_length_and_mapping_qc"
|
|
52
|
+
modified_by: []
|
|
53
|
+
notes: "Per-read quality score."
|
|
54
|
+
requires: []
|
|
55
|
+
optional_inputs: []
|
|
56
|
+
mapping_quality:
|
|
57
|
+
dtype: "float"
|
|
58
|
+
created_by: "smftools.informatics.h5ad_functions.add_read_length_and_mapping_qc"
|
|
59
|
+
modified_by: []
|
|
60
|
+
notes: "Mapping quality score."
|
|
61
|
+
requires: []
|
|
62
|
+
optional_inputs: []
|
|
63
|
+
read_length_to_reference_length_ratio:
|
|
64
|
+
dtype: "float"
|
|
65
|
+
created_by: "smftools.informatics.h5ad_functions.add_read_length_and_mapping_qc"
|
|
66
|
+
modified_by: []
|
|
67
|
+
notes: "Read length divided by reference length."
|
|
68
|
+
requires: [["obs.read_length", "obs.reference_length"]]
|
|
69
|
+
optional_inputs: []
|
|
70
|
+
mapped_length_to_reference_length_ratio:
|
|
71
|
+
dtype: "float"
|
|
72
|
+
created_by: "smftools.informatics.h5ad_functions.add_read_length_and_mapping_qc"
|
|
73
|
+
modified_by: []
|
|
74
|
+
notes: "Mapped length divided by reference length."
|
|
75
|
+
requires: [["obs.mapped_length", "obs.reference_length"]]
|
|
76
|
+
optional_inputs: []
|
|
77
|
+
mapped_length_to_read_length_ratio:
|
|
78
|
+
dtype: "float"
|
|
79
|
+
created_by: "smftools.informatics.h5ad_functions.add_read_length_and_mapping_qc"
|
|
80
|
+
modified_by: []
|
|
81
|
+
notes: "Mapped length divided by read length."
|
|
82
|
+
requires: [["obs.mapped_length", "obs.read_length"]]
|
|
83
|
+
optional_inputs: []
|
|
84
|
+
Raw_modification_signal:
|
|
85
|
+
dtype: "float"
|
|
86
|
+
created_by: "smftools.informatics.h5ad_functions.add_read_length_and_mapping_qc"
|
|
87
|
+
modified_by:
|
|
88
|
+
- "smftools.cli.load_adata"
|
|
89
|
+
notes: "Summed modification signal per read."
|
|
90
|
+
requires: [["X"], ["layers.raw_mods"]]
|
|
91
|
+
optional_inputs: []
|
|
92
|
+
pod5_origin:
|
|
93
|
+
dtype: "string"
|
|
94
|
+
created_by: "smftools.informatics.h5ad_functions.annotate_pod5_origin"
|
|
95
|
+
modified_by: []
|
|
96
|
+
notes: "POD5 filename source for each read."
|
|
97
|
+
requires: [["obs_names"]]
|
|
98
|
+
optional_inputs: []
|
|
99
|
+
demux_type:
|
|
100
|
+
dtype: "category"
|
|
101
|
+
created_by: "smftools.informatics.h5ad_functions.add_demux_type_annotation"
|
|
102
|
+
modified_by: []
|
|
103
|
+
notes: "Classification of demultiplexing status."
|
|
104
|
+
requires: [["obs_names"]]
|
|
105
|
+
optional_inputs: []
|
|
106
|
+
var:
|
|
107
|
+
reference_position:
|
|
108
|
+
dtype: "int"
|
|
109
|
+
created_by: "smftools.informatics.modkit_extract_to_adata"
|
|
110
|
+
modified_by: []
|
|
111
|
+
notes: "Reference coordinate for each column."
|
|
112
|
+
requires: []
|
|
113
|
+
optional_inputs: []
|
|
114
|
+
reference_id:
|
|
115
|
+
dtype: "category"
|
|
116
|
+
created_by: "smftools.informatics.modkit_extract_to_adata"
|
|
117
|
+
modified_by: []
|
|
118
|
+
notes: "Reference contig or sequence name."
|
|
119
|
+
requires: []
|
|
120
|
+
optional_inputs: []
|
|
121
|
+
layers:
|
|
122
|
+
raw_mods:
|
|
123
|
+
dtype: "float"
|
|
124
|
+
created_by: "smftools.informatics.modkit_extract_to_adata"
|
|
125
|
+
modified_by: []
|
|
126
|
+
notes: "Raw modification scores (modality-dependent)."
|
|
127
|
+
requires: []
|
|
128
|
+
optional_inputs: []
|
|
129
|
+
obsm: {}
|
|
130
|
+
varm: {}
|
|
131
|
+
obsp: {}
|
|
132
|
+
uns:
|
|
133
|
+
smftools:
|
|
134
|
+
dtype: "mapping"
|
|
135
|
+
created_by: "smftools.metadata.record_smftools_metadata"
|
|
136
|
+
modified_by: []
|
|
137
|
+
notes: "smftools metadata including history, environment, provenance, schema snapshot."
|
|
138
|
+
requires: []
|
|
139
|
+
optional_inputs: []
|
|
140
|
+
preprocess:
|
|
141
|
+
stage_requires: ["raw"]
|
|
142
|
+
obs:
|
|
143
|
+
sequence__merged_cluster_id:
|
|
144
|
+
dtype: "category"
|
|
145
|
+
created_by: "smftools.preprocessing.flag_duplicate_reads"
|
|
146
|
+
modified_by: []
|
|
147
|
+
notes: "Cluster identifier for duplicate detection."
|
|
148
|
+
requires: [["layers.nan0_0minus1"]]
|
|
149
|
+
optional_inputs: ["obs.demux_type"]
|
|
150
|
+
layers:
|
|
151
|
+
nan0_0minus1:
|
|
152
|
+
dtype: "float"
|
|
153
|
+
created_by: "smftools.preprocessing.binarize"
|
|
154
|
+
modified_by:
|
|
155
|
+
- "smftools.preprocessing.clean_NaN"
|
|
156
|
+
notes: "Binarized methylation matrix (nan=0, 0=-1)."
|
|
157
|
+
requires: [["X"]]
|
|
158
|
+
optional_inputs: []
|
|
159
|
+
obsm:
|
|
160
|
+
X_umap:
|
|
161
|
+
dtype: "float"
|
|
162
|
+
created_by: "smftools.tools.calculate_umap"
|
|
163
|
+
modified_by: []
|
|
164
|
+
notes: "UMAP embedding for preprocessed reads."
|
|
165
|
+
requires: [["X"]]
|
|
166
|
+
optional_inputs: []
|
|
167
|
+
varm: {}
|
|
168
|
+
obsp: {}
|
|
169
|
+
uns:
|
|
170
|
+
duplicate_read_groups:
|
|
171
|
+
dtype: "mapping"
|
|
172
|
+
created_by: "smftools.preprocessing.flag_duplicate_reads"
|
|
173
|
+
modified_by: []
|
|
174
|
+
notes: "Duplicate read group metadata."
|
|
175
|
+
requires: [["obs.sequence__merged_cluster_id"]]
|
|
176
|
+
optional_inputs: []
|
|
177
|
+
spatial:
|
|
178
|
+
stage_requires: ["raw", "preprocess"]
|
|
179
|
+
obs:
|
|
180
|
+
leiden:
|
|
181
|
+
dtype: "category"
|
|
182
|
+
created_by: "smftools.tools.calculate_umap"
|
|
183
|
+
modified_by: []
|
|
184
|
+
notes: "Leiden cluster assignments."
|
|
185
|
+
requires: [["obsm.X_umap"]]
|
|
186
|
+
optional_inputs: []
|
|
187
|
+
obsm:
|
|
188
|
+
X_umap:
|
|
189
|
+
dtype: "float"
|
|
190
|
+
created_by: "smftools.tools.calculate_umap"
|
|
191
|
+
modified_by: []
|
|
192
|
+
notes: "UMAP embedding for spatial analyses."
|
|
193
|
+
requires: [["X"]]
|
|
194
|
+
optional_inputs: []
|
|
195
|
+
layers: {}
|
|
196
|
+
varm: {}
|
|
197
|
+
obsp: {}
|
|
198
|
+
uns:
|
|
199
|
+
positionwise_result:
|
|
200
|
+
dtype: "mapping"
|
|
201
|
+
created_by: "smftools.tools.position_stats.compute_positionwise_statistics"
|
|
202
|
+
modified_by: []
|
|
203
|
+
notes: "Positionwise correlation statistics for spatial analyses."
|
|
204
|
+
requires: [["X"]]
|
|
205
|
+
optional_inputs: ["obs.reference_column"]
|
|
206
|
+
hmm:
|
|
207
|
+
stage_requires: ["raw", "preprocess", "spatial"]
|
|
208
|
+
layers:
|
|
209
|
+
hmm_state_calls:
|
|
210
|
+
dtype: "int"
|
|
211
|
+
created_by: "smftools.hmm.call_hmm_peaks"
|
|
212
|
+
modified_by: []
|
|
213
|
+
notes: "HMM-derived state calls per read/position."
|
|
214
|
+
requires: [["layers.nan0_0minus1"]]
|
|
215
|
+
optional_inputs: []
|
|
216
|
+
obsm: {}
|
|
217
|
+
varm: {}
|
|
218
|
+
obsp: {}
|
|
219
|
+
obs: {}
|
|
220
|
+
uns:
|
|
221
|
+
hmm_annotated:
|
|
222
|
+
dtype: "bool"
|
|
223
|
+
created_by: "smftools.cli.hmm_adata"
|
|
224
|
+
modified_by: []
|
|
225
|
+
notes: "Flag indicating HMM annotations are present."
|
|
226
|
+
requires: [["layers.hmm_state_calls"]]
|
|
227
|
+
optional_inputs: []
|
smftools/tools/__init__.py
CHANGED
|
@@ -1,20 +1,27 @@
|
|
|
1
|
-
from
|
|
2
|
-
from .calculate_umap import calculate_umap
|
|
3
|
-
from .cluster_adata_on_methylation import cluster_adata_on_methylation
|
|
4
|
-
from .general_tools import create_nan_mask_from_X, combine_layers, create_nan_or_non_gpc_mask
|
|
5
|
-
from .read_stats import calculate_row_entropy
|
|
6
|
-
from .spatial_autocorrelation import *
|
|
7
|
-
from .subset_adata import subset_adata
|
|
1
|
+
from __future__ import annotations
|
|
8
2
|
|
|
3
|
+
from importlib import import_module
|
|
9
4
|
|
|
10
|
-
|
|
11
|
-
"
|
|
12
|
-
"
|
|
13
|
-
"
|
|
14
|
-
"
|
|
15
|
-
"
|
|
16
|
-
"
|
|
17
|
-
"
|
|
18
|
-
"
|
|
19
|
-
"subset_adata",
|
|
20
|
-
|
|
5
|
+
_LAZY_ATTRS = {
|
|
6
|
+
"calculate_umap": "smftools.tools.calculate_umap",
|
|
7
|
+
"cluster_adata_on_methylation": "smftools.tools.cluster_adata_on_methylation",
|
|
8
|
+
"combine_layers": "smftools.tools.general_tools",
|
|
9
|
+
"create_nan_mask_from_X": "smftools.tools.general_tools",
|
|
10
|
+
"create_nan_or_non_gpc_mask": "smftools.tools.general_tools",
|
|
11
|
+
"calculate_relative_risk_on_activity": "smftools.tools.position_stats",
|
|
12
|
+
"compute_positionwise_statistics": "smftools.tools.position_stats",
|
|
13
|
+
"calculate_row_entropy": "smftools.tools.read_stats",
|
|
14
|
+
"subset_adata": "smftools.tools.subset_adata",
|
|
15
|
+
}
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def __getattr__(name: str):
|
|
19
|
+
if name in _LAZY_ATTRS:
|
|
20
|
+
module = import_module(_LAZY_ATTRS[name])
|
|
21
|
+
attr = getattr(module, name)
|
|
22
|
+
globals()[name] = attr
|
|
23
|
+
return attr
|
|
24
|
+
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
__all__ = list(_LAZY_ATTRS.keys())
|
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
## Train CNN, RNN, Random Forest models on double barcoded, low contamination datasets
|
|
2
4
|
import torch
|
|
3
5
|
import torch.nn as nn
|
|
@@ -21,13 +23,29 @@ device = (
|
|
|
21
23
|
|
|
22
24
|
# ------------------------- Utilities -------------------------
|
|
23
25
|
def random_fill_nans(X):
    """Fill NaN entries of *X* in place with uniform random values.

    Args:
        X: Input NumPy array; modified in place.

    Returns:
        The same array, with every NaN replaced by a draw from U[0, 1).
    """
    missing = np.isnan(X)
    X[missing] = np.random.rand(int(missing.sum()))
    return X
|
|
27
37
|
|
|
28
38
|
# ------------------------- Model Definitions -------------------------
|
|
29
39
|
class CNNClassifier(nn.Module):
|
|
40
|
+
"""Simple 1D CNN classifier for fixed-length inputs."""
|
|
41
|
+
|
|
30
42
|
def __init__(self, input_size, num_classes):
|
|
43
|
+
"""Initialize CNN classifier layers.
|
|
44
|
+
|
|
45
|
+
Args:
|
|
46
|
+
input_size: Length of the 1D input.
|
|
47
|
+
num_classes: Number of output classes.
|
|
48
|
+
"""
|
|
31
49
|
super().__init__()
|
|
32
50
|
self.conv1 = nn.Conv1d(1, 16, kernel_size=3, padding=1)
|
|
33
51
|
self.conv2 = nn.Conv1d(16, 32, kernel_size=3, padding=1)
|
|
@@ -39,11 +57,13 @@ class CNNClassifier(nn.Module):
|
|
|
39
57
|
self.fc2 = nn.Linear(64, num_classes)
|
|
40
58
|
|
|
41
59
|
def _forward_conv(self, x):
|
|
60
|
+
"""Apply convolutional layers and activation."""
|
|
42
61
|
x = self.relu(self.conv1(x))
|
|
43
62
|
x = self.relu(self.conv2(x))
|
|
44
63
|
return x
|
|
45
64
|
|
|
46
65
|
def forward(self, x):
|
|
66
|
+
"""Run the forward pass."""
|
|
47
67
|
x = x.unsqueeze(1)
|
|
48
68
|
x = self._forward_conv(x)
|
|
49
69
|
x = x.view(x.size(0), -1)
|
|
@@ -51,7 +71,15 @@ class CNNClassifier(nn.Module):
|
|
|
51
71
|
return self.fc2(x)
|
|
52
72
|
|
|
53
73
|
class MLPClassifier(nn.Module):
|
|
74
|
+
"""Simple MLP classifier."""
|
|
75
|
+
|
|
54
76
|
def __init__(self, input_dim, num_classes):
|
|
77
|
+
"""Initialize MLP layers.
|
|
78
|
+
|
|
79
|
+
Args:
|
|
80
|
+
input_dim: Input feature dimension.
|
|
81
|
+
num_classes: Number of output classes.
|
|
82
|
+
"""
|
|
55
83
|
super().__init__()
|
|
56
84
|
self.model = nn.Sequential(
|
|
57
85
|
nn.Linear(input_dim, 128),
|
|
@@ -64,10 +92,20 @@ class MLPClassifier(nn.Module):
|
|
|
64
92
|
)
|
|
65
93
|
|
|
66
94
|
def forward(self, x):
|
|
95
|
+
"""Run the forward pass."""
|
|
67
96
|
return self.model(x)
|
|
68
97
|
|
|
69
98
|
class RNNClassifier(nn.Module):
|
|
99
|
+
"""LSTM-based classifier for sequential inputs."""
|
|
100
|
+
|
|
70
101
|
def __init__(self, input_size, hidden_dim, num_classes):
|
|
102
|
+
"""Initialize RNN classifier layers.
|
|
103
|
+
|
|
104
|
+
Args:
|
|
105
|
+
input_size: Input feature dimension.
|
|
106
|
+
hidden_dim: Hidden state dimension.
|
|
107
|
+
num_classes: Number of output classes.
|
|
108
|
+
"""
|
|
71
109
|
super().__init__()
|
|
72
110
|
# Define LSTM layer
|
|
73
111
|
self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_dim, batch_first=True)
|
|
@@ -75,18 +113,29 @@ class RNNClassifier(nn.Module):
|
|
|
75
113
|
self.fc = nn.Linear(hidden_dim, num_classes)
|
|
76
114
|
|
|
77
115
|
def forward(self, x):
|
|
116
|
+
"""Run the forward pass."""
|
|
78
117
|
x = x.unsqueeze(1)
|
|
79
118
|
_, (h_n, _) = self.lstm(x)
|
|
80
119
|
return self.fc(h_n.squeeze(0))
|
|
81
120
|
|
|
82
121
|
class AttentionRNNClassifier(nn.Module):
|
|
122
|
+
"""LSTM classifier with simple attention."""
|
|
123
|
+
|
|
83
124
|
    def __init__(self, input_size, hidden_dim, num_classes):
        """Initialize attention-based RNN layers.

        Args:
            input_size: Input feature dimension.
            hidden_dim: Hidden state dimension.
            num_classes: Number of output classes.
        """
        super().__init__()
        self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_dim, batch_first=True)
        # One scalar score per timestep; softmax over these yields attention weights.
        self.attn = nn.Linear(hidden_dim, 1)  # Simple attention scores
        # Projects the attention-weighted context vector to class logits.
        self.fc = nn.Linear(hidden_dim, num_classes)
|
|
88
136
|
|
|
89
137
|
def forward(self, x):
|
|
138
|
+
"""Run the forward pass."""
|
|
90
139
|
x = x.unsqueeze(1) # shape: (batch, 1, seq_len)
|
|
91
140
|
lstm_out, _ = self.lstm(x) # shape: (batch, 1, hidden_dim)
|
|
92
141
|
attn_weights = torch.softmax(self.attn(lstm_out), dim=1) # (batch, 1, 1)
|
|
@@ -94,7 +143,15 @@ class AttentionRNNClassifier(nn.Module):
|
|
|
94
143
|
return self.fc(context)
|
|
95
144
|
|
|
96
145
|
class PositionalEncoding(nn.Module):
|
|
146
|
+
"""Positional encoding module for transformer models."""
|
|
147
|
+
|
|
97
148
|
def __init__(self, d_model, max_len=5000):
|
|
149
|
+
"""Initialize positional encoding buffer.
|
|
150
|
+
|
|
151
|
+
Args:
|
|
152
|
+
d_model: Model embedding dimension.
|
|
153
|
+
max_len: Maximum sequence length.
|
|
154
|
+
"""
|
|
98
155
|
super().__init__()
|
|
99
156
|
pe = torch.zeros(max_len, d_model)
|
|
100
157
|
position = torch.arange(0, max_len).unsqueeze(1).float()
|
|
@@ -104,11 +161,23 @@ class PositionalEncoding(nn.Module):
|
|
|
104
161
|
self.pe = pe.unsqueeze(0) # (1, max_len, d_model)
|
|
105
162
|
|
|
106
163
|
def forward(self, x):
|
|
164
|
+
"""Add positional encoding to inputs."""
|
|
107
165
|
x = x + self.pe[:, :x.size(1)].to(x.device)
|
|
108
166
|
return x
|
|
109
167
|
|
|
110
168
|
class TransformerClassifier(nn.Module):
|
|
169
|
+
"""Transformer encoder-based classifier."""
|
|
170
|
+
|
|
111
171
|
def __init__(self, input_dim, model_dim, num_classes, num_heads=4, num_layers=2):
|
|
172
|
+
"""Initialize transformer classifier layers.
|
|
173
|
+
|
|
174
|
+
Args:
|
|
175
|
+
input_dim: Input feature dimension.
|
|
176
|
+
model_dim: Transformer model dimension.
|
|
177
|
+
num_classes: Number of output classes.
|
|
178
|
+
num_heads: Number of attention heads.
|
|
179
|
+
num_layers: Number of encoder layers.
|
|
180
|
+
"""
|
|
112
181
|
super().__init__()
|
|
113
182
|
self.input_fc = nn.Linear(input_dim, model_dim)
|
|
114
183
|
self.pos_encoder = PositionalEncoding(model_dim)
|
|
@@ -119,6 +188,7 @@ class TransformerClassifier(nn.Module):
|
|
|
119
188
|
self.cls_head = nn.Linear(model_dim, num_classes)
|
|
120
189
|
|
|
121
190
|
def forward(self, x):
|
|
191
|
+
"""Run the forward pass."""
|
|
122
192
|
# x: [batch_size, input_dim]
|
|
123
193
|
x = self.input_fc(x).unsqueeze(1) # -> [batch_size, 1, model_dim]
|
|
124
194
|
x = self.pos_encoder(x)
|
|
@@ -128,6 +198,19 @@ class TransformerClassifier(nn.Module):
|
|
|
128
198
|
return self.cls_head(pooled)
|
|
129
199
|
|
|
130
200
|
def train_model(model, loader, optimizer, criterion, device, ref_name="", model_name="", epochs=20, patience=5):
|
|
201
|
+
"""Train a model with early stopping.
|
|
202
|
+
|
|
203
|
+
Args:
|
|
204
|
+
model: PyTorch model.
|
|
205
|
+
loader: DataLoader for training data.
|
|
206
|
+
optimizer: Optimizer instance.
|
|
207
|
+
criterion: Loss function.
|
|
208
|
+
device: Torch device.
|
|
209
|
+
ref_name: Reference label for logging.
|
|
210
|
+
model_name: Model label for logging.
|
|
211
|
+
epochs: Maximum epochs.
|
|
212
|
+
patience: Early-stopping patience.
|
|
213
|
+
"""
|
|
131
214
|
model.train()
|
|
132
215
|
best_loss = float('inf')
|
|
133
216
|
trigger_times = 0
|
|
@@ -154,6 +237,17 @@ def train_model(model, loader, optimizer, criterion, device, ref_name="", model_
|
|
|
154
237
|
break
|
|
155
238
|
|
|
156
239
|
def evaluate_model(model, X_tensor, y_encoded, device):
|
|
240
|
+
"""Evaluate a trained model and compute metrics.
|
|
241
|
+
|
|
242
|
+
Args:
|
|
243
|
+
model: Trained model.
|
|
244
|
+
X_tensor: Input tensor.
|
|
245
|
+
y_encoded: Encoded labels.
|
|
246
|
+
device: Torch device.
|
|
247
|
+
|
|
248
|
+
Returns:
|
|
249
|
+
Tuple of metrics dict, predicted labels, and probabilities.
|
|
250
|
+
"""
|
|
157
251
|
model.eval()
|
|
158
252
|
with torch.no_grad():
|
|
159
253
|
outputs = model(X_tensor.to(device))
|
|
@@ -176,6 +270,18 @@ def evaluate_model(model, X_tensor, y_encoded, device):
|
|
|
176
270
|
}, preds, probs
|
|
177
271
|
|
|
178
272
|
def train_rf(X_tensor, y_tensor, train_indices, test_indices, n_estimators=500):
|
|
273
|
+
"""Train a random forest classifier.
|
|
274
|
+
|
|
275
|
+
Args:
|
|
276
|
+
X_tensor: Input tensor.
|
|
277
|
+
y_tensor: Label tensor.
|
|
278
|
+
train_indices: Indices for training.
|
|
279
|
+
test_indices: Indices for testing.
|
|
280
|
+
n_estimators: Number of trees.
|
|
281
|
+
|
|
282
|
+
Returns:
|
|
283
|
+
Tuple of (model, preds, probs).
|
|
284
|
+
"""
|
|
179
285
|
model = RandomForestClassifier(n_estimators=n_estimators, random_state=42, class_weight='balanced')
|
|
180
286
|
model.fit(X_tensor[train_indices].numpy(), y_tensor[train_indices].numpy())
|
|
181
287
|
probs = model.predict_proba(X_tensor[test_indices].cpu().numpy())[:, 1]
|
|
@@ -186,6 +292,25 @@ def train_rf(X_tensor, y_tensor, train_indices, test_indices, n_estimators=500):
|
|
|
186
292
|
def run_training_loop(adata, site_config, layer_name=None,
|
|
187
293
|
mlp=False, cnn=False, rnn=False, arnn=False, transformer=False, rf=False, nb=False, rr_bayes=False,
|
|
188
294
|
max_epochs=10, max_patience=5, n_estimators=500, training_split=0.5):
|
|
295
|
+
"""Train one or more classifier types on AnnData.
|
|
296
|
+
|
|
297
|
+
Args:
|
|
298
|
+
adata: AnnData object containing data and labels.
|
|
299
|
+
site_config: Mapping of reference to site list.
|
|
300
|
+
layer_name: Optional layer to use as input.
|
|
301
|
+
mlp: Whether to train an MLP model.
|
|
302
|
+
cnn: Whether to train a CNN model.
|
|
303
|
+
rnn: Whether to train an RNN model.
|
|
304
|
+
arnn: Whether to train an attention RNN model.
|
|
305
|
+
transformer: Whether to train a transformer model.
|
|
306
|
+
rf: Whether to train a random forest model.
|
|
307
|
+
nb: Whether to train a Naive Bayes model.
|
|
308
|
+
rr_bayes: Whether to train a ridge regression model.
|
|
309
|
+
max_epochs: Maximum training epochs.
|
|
310
|
+
max_patience: Early stopping patience.
|
|
311
|
+
n_estimators: Random forest estimator count.
|
|
312
|
+
training_split: Fraction of data used for training.
|
|
313
|
+
"""
|
|
189
314
|
device = (
|
|
190
315
|
torch.device('cuda') if torch.cuda.is_available() else
|
|
191
316
|
torch.device('mps') if torch.backends.mps.is_available() else
|
|
@@ -701,6 +826,20 @@ def evaluate_model_by_subgroups(
|
|
|
701
826
|
label_col="activity_status",
|
|
702
827
|
min_samples=10,
|
|
703
828
|
exclude_training_data=True):
|
|
829
|
+
"""Evaluate predictions within categorical subgroups.
|
|
830
|
+
|
|
831
|
+
Args:
|
|
832
|
+
adata: AnnData with prediction columns.
|
|
833
|
+
model_prefix: Prediction column prefix.
|
|
834
|
+
suffix: Prediction column suffix.
|
|
835
|
+
groupby_cols: Columns to group by.
|
|
836
|
+
label_col: Ground-truth label column.
|
|
837
|
+
min_samples: Minimum samples per group.
|
|
838
|
+
exclude_training_data: Whether to exclude training rows.
|
|
839
|
+
|
|
840
|
+
Returns:
|
|
841
|
+
DataFrame of subgroup-level metrics.
|
|
842
|
+
"""
|
|
704
843
|
import pandas as pd
|
|
705
844
|
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score
|
|
706
845
|
|
|
@@ -745,6 +884,18 @@ def evaluate_model_by_subgroups(
|
|
|
745
884
|
return pd.DataFrame(results)
|
|
746
885
|
|
|
747
886
|
def evaluate_models_by_subgroup(adata, model_prefixes, groupby_cols, label_col, exclude_training_data=True):
|
|
887
|
+
"""Evaluate multiple model prefixes across subgroups.
|
|
888
|
+
|
|
889
|
+
Args:
|
|
890
|
+
adata: AnnData with prediction columns.
|
|
891
|
+
model_prefixes: Iterable of model prefixes.
|
|
892
|
+
groupby_cols: Columns to group by.
|
|
893
|
+
label_col: Ground-truth label column.
|
|
894
|
+
exclude_training_data: Whether to exclude training rows.
|
|
895
|
+
|
|
896
|
+
Returns:
|
|
897
|
+
Concatenated DataFrame of subgroup-level metrics.
|
|
898
|
+
"""
|
|
748
899
|
import pandas as pd
|
|
749
900
|
all_metrics = []
|
|
750
901
|
for model_prefix in model_prefixes:
|
|
@@ -758,6 +909,20 @@ def evaluate_models_by_subgroup(adata, model_prefixes, groupby_cols, label_col,
|
|
|
758
909
|
return final_df
|
|
759
910
|
|
|
760
911
|
def prepare_melted_model_data(adata, outkey='melted_model_df', groupby=['Enhancer_Open', 'Promoter_Open'], label_col='activity_status', model_names = ['cnn', 'mlp', 'rf'], suffix='GpC_site_CpG_site', omit_training=True):
|
|
912
|
+
"""Prepare a long-format DataFrame for model performance plots.
|
|
913
|
+
|
|
914
|
+
Args:
|
|
915
|
+
adata: AnnData with prediction columns.
|
|
916
|
+
outkey: Key to store the melted DataFrame in ``adata.uns``.
|
|
917
|
+
groupby: Grouping columns to include.
|
|
918
|
+
label_col: Ground-truth label column.
|
|
919
|
+
model_names: Model prefixes to include.
|
|
920
|
+
suffix: Prediction column suffix.
|
|
921
|
+
omit_training: Whether to exclude training rows.
|
|
922
|
+
|
|
923
|
+
Returns:
|
|
924
|
+
Melted DataFrame of predictions.
|
|
925
|
+
"""
|
|
761
926
|
import pandas as pd
|
|
762
927
|
import seaborn as sns
|
|
763
928
|
import matplotlib.pyplot as plt
|
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
# subset_adata
|
|
2
4
|
|
|
3
5
|
def subset_adata(adata, obs_columns):
|
|
@@ -13,6 +15,15 @@ def subset_adata(adata, obs_columns):
|
|
|
13
15
|
"""
|
|
14
16
|
|
|
15
17
|
def subset_recursive(adata_subset, columns):
|
|
18
|
+
"""Recursively subset AnnData by categorical columns.
|
|
19
|
+
|
|
20
|
+
Args:
|
|
21
|
+
adata_subset: AnnData subset to split.
|
|
22
|
+
columns: Remaining columns to split on.
|
|
23
|
+
|
|
24
|
+
Returns:
|
|
25
|
+
Dictionary mapping category tuples to AnnData subsets.
|
|
26
|
+
"""
|
|
16
27
|
if not columns:
|
|
17
28
|
return {(): adata_subset}
|
|
18
29
|
|
|
@@ -29,4 +40,4 @@ def subset_adata(adata, obs_columns):
|
|
|
29
40
|
# Start the recursive subset process
|
|
30
41
|
subsets_dict = subset_recursive(adata, obs_columns)
|
|
31
42
|
|
|
32
|
-
return subsets_dict
|
|
43
|
+
return subsets_dict
|
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
# subset_adata
|
|
2
4
|
|
|
3
5
|
def subset_adata(adata, columns, cat_type='obs'):
|
|
@@ -14,6 +16,17 @@ def subset_adata(adata, columns, cat_type='obs'):
|
|
|
14
16
|
"""
|
|
15
17
|
|
|
16
18
|
def subset_recursive(adata_subset, columns, cat_type, key_prefix=()):
|
|
19
|
+
"""Recursively subset AnnData by categorical columns.
|
|
20
|
+
|
|
21
|
+
Args:
|
|
22
|
+
adata_subset: AnnData subset to split.
|
|
23
|
+
columns: Remaining columns to split on.
|
|
24
|
+
cat_type: Whether to use obs or var categories.
|
|
25
|
+
key_prefix: Tuple of previous category keys.
|
|
26
|
+
|
|
27
|
+
Returns:
|
|
28
|
+
Dictionary mapping category tuples to AnnData subsets.
|
|
29
|
+
"""
|
|
17
30
|
# Returns when the bottom of the stack is reached
|
|
18
31
|
if not columns:
|
|
19
32
|
# If there's only one column, return the key as a single value, not a tuple
|
|
@@ -43,4 +56,4 @@ def subset_adata(adata, columns, cat_type='obs'):
|
|
|
43
56
|
# Start the recursive subset process
|
|
44
57
|
subsets_dict = subset_recursive(adata, columns, cat_type)
|
|
45
58
|
|
|
46
|
-
return subsets_dict
|
|
59
|
+
return subsets_dict
|