smftools 0.2.4__py3-none-any.whl → 0.2.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- smftools/__init__.py +6 -8
- smftools/_settings.py +4 -6
- smftools/_version.py +1 -1
- smftools/cli/helpers.py +7 -1
- smftools/cli/hmm_adata.py +902 -244
- smftools/cli/load_adata.py +318 -198
- smftools/cli/preprocess_adata.py +285 -171
- smftools/cli/spatial_adata.py +137 -53
- smftools/cli_entry.py +94 -178
- smftools/config/__init__.py +1 -1
- smftools/config/conversion.yaml +5 -1
- smftools/config/deaminase.yaml +1 -1
- smftools/config/default.yaml +22 -17
- smftools/config/direct.yaml +8 -3
- smftools/config/discover_input_files.py +19 -5
- smftools/config/experiment_config.py +505 -276
- smftools/constants.py +37 -0
- smftools/datasets/__init__.py +2 -8
- smftools/datasets/datasets.py +32 -18
- smftools/hmm/HMM.py +2125 -1426
- smftools/hmm/__init__.py +2 -3
- smftools/hmm/archived/call_hmm_peaks.py +16 -1
- smftools/hmm/call_hmm_peaks.py +173 -193
- smftools/hmm/display_hmm.py +19 -6
- smftools/hmm/hmm_readwrite.py +13 -4
- smftools/hmm/nucleosome_hmm_refinement.py +102 -14
- smftools/informatics/__init__.py +30 -7
- smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +14 -1
- smftools/informatics/archived/helpers/archived/bam_qc.py +14 -1
- smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +8 -1
- smftools/informatics/archived/helpers/archived/load_adata.py +3 -3
- smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +3 -1
- smftools/informatics/archived/print_bam_query_seq.py +7 -1
- smftools/informatics/bam_functions.py +379 -156
- smftools/informatics/basecalling.py +51 -9
- smftools/informatics/bed_functions.py +90 -57
- smftools/informatics/binarize_converted_base_identities.py +18 -7
- smftools/informatics/complement_base_list.py +7 -6
- smftools/informatics/converted_BAM_to_adata.py +265 -122
- smftools/informatics/fasta_functions.py +161 -83
- smftools/informatics/h5ad_functions.py +195 -29
- smftools/informatics/modkit_extract_to_adata.py +609 -270
- smftools/informatics/modkit_functions.py +85 -44
- smftools/informatics/ohe.py +44 -21
- smftools/informatics/pod5_functions.py +112 -73
- smftools/informatics/run_multiqc.py +20 -14
- smftools/logging_utils.py +51 -0
- smftools/machine_learning/__init__.py +2 -7
- smftools/machine_learning/data/anndata_data_module.py +143 -50
- smftools/machine_learning/data/preprocessing.py +2 -1
- smftools/machine_learning/evaluation/__init__.py +1 -1
- smftools/machine_learning/evaluation/eval_utils.py +11 -14
- smftools/machine_learning/evaluation/evaluators.py +46 -33
- smftools/machine_learning/inference/__init__.py +1 -1
- smftools/machine_learning/inference/inference_utils.py +7 -4
- smftools/machine_learning/inference/lightning_inference.py +9 -13
- smftools/machine_learning/inference/sklearn_inference.py +6 -8
- smftools/machine_learning/inference/sliding_window_inference.py +35 -25
- smftools/machine_learning/models/__init__.py +10 -5
- smftools/machine_learning/models/base.py +28 -42
- smftools/machine_learning/models/cnn.py +15 -11
- smftools/machine_learning/models/lightning_base.py +71 -40
- smftools/machine_learning/models/mlp.py +13 -4
- smftools/machine_learning/models/positional.py +3 -2
- smftools/machine_learning/models/rnn.py +3 -2
- smftools/machine_learning/models/sklearn_models.py +39 -22
- smftools/machine_learning/models/transformer.py +68 -53
- smftools/machine_learning/models/wrappers.py +2 -1
- smftools/machine_learning/training/__init__.py +2 -2
- smftools/machine_learning/training/train_lightning_model.py +29 -20
- smftools/machine_learning/training/train_sklearn_model.py +9 -15
- smftools/machine_learning/utils/__init__.py +1 -1
- smftools/machine_learning/utils/device.py +7 -4
- smftools/machine_learning/utils/grl.py +3 -1
- smftools/metadata.py +443 -0
- smftools/plotting/__init__.py +19 -5
- smftools/plotting/autocorrelation_plotting.py +145 -44
- smftools/plotting/classifiers.py +162 -72
- smftools/plotting/general_plotting.py +347 -168
- smftools/plotting/hmm_plotting.py +42 -13
- smftools/plotting/position_stats.py +145 -85
- smftools/plotting/qc_plotting.py +20 -12
- smftools/preprocessing/__init__.py +8 -8
- smftools/preprocessing/append_base_context.py +105 -79
- smftools/preprocessing/append_binary_layer_by_base_context.py +75 -37
- smftools/preprocessing/{archives → archived}/calculate_complexity.py +3 -1
- smftools/preprocessing/{archives → archived}/preprocessing.py +8 -6
- smftools/preprocessing/binarize.py +21 -4
- smftools/preprocessing/binarize_on_Youden.py +127 -31
- smftools/preprocessing/binary_layers_to_ohe.py +17 -11
- smftools/preprocessing/calculate_complexity_II.py +86 -59
- smftools/preprocessing/calculate_consensus.py +28 -19
- smftools/preprocessing/calculate_coverage.py +44 -22
- smftools/preprocessing/calculate_pairwise_differences.py +2 -1
- smftools/preprocessing/calculate_pairwise_hamming_distances.py +4 -3
- smftools/preprocessing/calculate_position_Youden.py +103 -55
- smftools/preprocessing/calculate_read_length_stats.py +52 -23
- smftools/preprocessing/calculate_read_modification_stats.py +91 -57
- smftools/preprocessing/clean_NaN.py +38 -28
- smftools/preprocessing/filter_adata_by_nan_proportion.py +24 -12
- smftools/preprocessing/filter_reads_on_length_quality_mapping.py +70 -37
- smftools/preprocessing/filter_reads_on_modification_thresholds.py +181 -73
- smftools/preprocessing/flag_duplicate_reads.py +688 -271
- smftools/preprocessing/invert_adata.py +26 -11
- smftools/preprocessing/load_sample_sheet.py +40 -22
- smftools/preprocessing/make_dirs.py +8 -3
- smftools/preprocessing/min_non_diagonal.py +2 -1
- smftools/preprocessing/recipes.py +56 -23
- smftools/preprocessing/reindex_references_adata.py +93 -27
- smftools/preprocessing/subsample_adata.py +33 -16
- smftools/readwrite.py +264 -109
- smftools/schema/__init__.py +11 -0
- smftools/schema/anndata_schema_v1.yaml +227 -0
- smftools/tools/__init__.py +3 -4
- smftools/tools/archived/classifiers.py +163 -0
- smftools/tools/archived/subset_adata_v1.py +10 -1
- smftools/tools/archived/subset_adata_v2.py +12 -1
- smftools/tools/calculate_umap.py +54 -15
- smftools/tools/cluster_adata_on_methylation.py +115 -46
- smftools/tools/general_tools.py +70 -25
- smftools/tools/position_stats.py +229 -98
- smftools/tools/read_stats.py +50 -29
- smftools/tools/spatial_autocorrelation.py +365 -192
- smftools/tools/subset_adata.py +23 -21
- {smftools-0.2.4.dist-info → smftools-0.2.5.dist-info}/METADATA +15 -43
- smftools-0.2.5.dist-info/RECORD +181 -0
- smftools-0.2.4.dist-info/RECORD +0 -176
- /smftools/preprocessing/{archives → archived}/add_read_length_and_mapping_qc.py +0 -0
- /smftools/preprocessing/{archives → archived}/mark_duplicates.py +0 -0
- /smftools/preprocessing/{archives → archived}/remove_duplicates.py +0 -0
- {smftools-0.2.4.dist-info → smftools-0.2.5.dist-info}/WHEEL +0 -0
- {smftools-0.2.4.dist-info → smftools-0.2.5.dist-info}/entry_points.txt +0 -0
- {smftools-0.2.4.dist-info → smftools-0.2.5.dist-info}/licenses/LICENSE +0 -0
smftools/hmm/__init__.py
CHANGED
@@ -1,8 +1,7 @@
 from .call_hmm_peaks import call_hmm_peaks
 from .display_hmm import display_hmm
 from .hmm_readwrite import load_hmm, save_hmm
-from .nucleosome_hmm_refinement import
-
+from .nucleosome_hmm_refinement import infer_nucleosomes_in_large_bound, refine_nucleosome_calls
 
 __all__ = [
     "call_hmm_peaks",
@@ -11,4 +10,4 @@ __all__ = [
     "refine_nucleosome_calls",
     "infer_nucleosomes_in_large_bound",
     "save_hmm",
-]
+]
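With this change both refinement helpers are re-exported from the subpackage root, so downstream code can import them in one line; a minimal sketch of the intended import path:

    from smftools.hmm import infer_nucleosomes_in_large_bound, refine_nucleosome_calls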
smftools/hmm/archived/call_hmm_peaks.py
CHANGED
@@ -8,6 +8,21 @@ def call_hmm_peaks(
     date_tag=None,
     inplace=False
 ):
+    """Call peaks from HMM feature layers and annotate AnnData.
+
+    Args:
+        adata: AnnData containing feature layers.
+        feature_configs: Mapping of layer name to peak config.
+        obs_column: Obs column for reference categories.
+        site_types: Site types to summarize around peaks.
+        save_plot: Whether to save peak plots.
+        output_dir: Output directory for plots.
+        date_tag: Optional tag for plot filenames.
+        inplace: Whether to modify AnnData in place.
+
+    Returns:
+        Annotated AnnData with peak masks and summary columns.
+    """
     import numpy as np
     import pandas as pd
     import matplotlib.pyplot as plt
@@ -103,4 +118,4 @@ def call_hmm_peaks(
     adata.var['is_in_any_peak'] = adata.var[peak_columns].any(axis=1)
     adata.obs = pd.concat([adata.obs, pd.DataFrame(obs_updates, index=adata.obs.index)], axis=1)
 
-    return adata if not inplace else None
+    return adata if not inplace else None
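The docstring added above pins down the archived function's contract: it annotates the AnnData and returns it unless inplace=True, in which case it returns None. A minimal usage sketch of that contract; the config key and input path below are hypothetical, not taken from the package:

    # hypothetical usage of the documented signature
    import anndata as ad

    adata = ad.read_h5ad("experiment.h5ad")                         # hypothetical input
    configs = {"GpC_all_accessible_features": {"peak_width": 200}}  # hypothetical config
    annotated = call_hmm_peaks(adata, configs, inplace=False)       # returns annotated AnnData
    call_hmm_peaks(annotated, configs, inplace=True)                # annotates in place, returns None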
smftools/hmm/call_hmm_peaks.py
CHANGED
@@ -1,5 +1,12 @@
-
+# FILE: smftools/hmm/call_hmm_peaks.py
+
 from pathlib import Path
+from typing import Any, Dict, Optional, Sequence, Union
+
+from smftools.logging_utils import get_logger
+
+logger = get_logger(__name__)
+
 
 def call_hmm_peaks(
     adata,
@@ -14,96 +21,76 @@ def call_hmm_peaks(
     alternate_labels: bool = False,
 ):
     """
-
-
-
-
-
-
-
-
-
-
-
-
-
-    }
-
-    Keys are usually *feature types* like "all_accessible_features" or
-    "small_bound_stretch". These are matched against existing HMM layers
-    (e.g. "GpC_all_accessible_features", "Combined_small_bound_stretch")
-    using a suffix match. You can also pass full layer names if you wish.
-    ref_column : str
-        Column in adata.obs defining reference groups (e.g. "Reference_strand").
-    site_types : sequence of str
-        Site types (without "_site"); expects var columns like f"{ref}_{site_type}_site".
-        e.g. ("GpC", "CpG") -> "6B6_top_GpC_site", etc.
-    save_plot : bool
-        If True, save peak diagnostic plots instead of just showing them.
-    output_dir : path-like or None
-        Directory for saved plots (created if needed).
-    date_tag : str or None
-        Optional tag to prefix plot filenames.
-    inplace : bool
-        If False, operate on a copy and return it. If True, modify adata and return None.
-    index_col_suffix : str or None
-        If None, coordinates come from adata.var_names (cast to int when possible).
-        If set, for each ref we use adata.var[f"{ref}_{index_col_suffix}"] as the
-        coordinate system (e.g. a reindexed coordinate).
-
-    Returns
-    -------
-    None or AnnData
+    Peak calling over HMM (or other) layers, per reference group and per layer.
+    Writes:
+      - adata.uns["{layer}_{ref}_peak_centers"] = list of centers
+      - adata.var["{layer}_{ref}_peak_{center}"] boolean window masks
+      - adata.obs per-read summaries for each peak window:
+          mean_{layer}_{ref}_around_{center}
+          sum_{layer}_{ref}_around_{center}
+          {layer}_{ref}_present_at_{center} (bool)
+        and per site-type:
+          sum_{layer}_{site}_{ref}_around_{center}
+          mean_{layer}_{site}_{ref}_around_{center}
+      - adata.var["is_in_any_{layer}_peak_{ref}"]
+      - adata.var["is_in_any_peak"] (global)
     """
+    import matplotlib.pyplot as plt
     import numpy as np
     import pandas as pd
-    import matplotlib.pyplot as plt
     from scipy.signal import find_peaks
     from scipy.sparse import issparse
 
     if not inplace:
         adata = adata.copy()
 
-
+    if ref_column not in adata.obs:
+        raise KeyError(f"obs column '{ref_column}' not found")
+
+    # Ensure categorical for predictable ref iteration
     if not pd.api.types.is_categorical_dtype(adata.obs[ref_column]):
         adata.obs[ref_column] = adata.obs[ref_column].astype("category")
 
-    #
+    # Optional: drop duplicate obs columns once to avoid Pandas/AnnData view quirks
+    if getattr(adata.obs.columns, "duplicated", None) is not None:
+        if adata.obs.columns.duplicated().any():
+            adata.obs = adata.obs.loc[:, ~adata.obs.columns.duplicated(keep="first")].copy()
+
+    # Fallback coordinates from var_names
     try:
         base_coordinates = adata.var_names.astype(int).values
     except Exception:
         base_coordinates = np.arange(adata.n_vars, dtype=int)
 
+    # Output dir
     if output_dir is not None:
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)
 
-    # HMM layers
-
-
-
-
-
-
+    # Build search pool = union of declared HMM layers and actual layers; exclude helper suffixes
+    declared = list(adata.uns.get("hmm_appended_layers", []) or [])
+    search_pool = [
+        layer
+        for layer in declared
+        if not any(s in layer for s in ("_lengths", "_states", "_posterior"))
+    ]
 
     all_peak_var_cols = []
 
-    # Iterate
+    # Iterate per reference
     for ref in adata.obs[ref_column].cat.categories:
         ref_mask = (adata.obs[ref_column] == ref).values
         if not ref_mask.any():
             continue
 
-        # Per-ref
+        # Per-ref coordinate system
         if index_col_suffix is not None:
             coord_col = f"{ref}_{index_col_suffix}"
             if coord_col not in adata.var:
                 raise KeyError(
-                    f"index_col_suffix='{index_col_suffix}' requested, "
-                    f"but var column '{coord_col}' is missing for ref '{ref}'."
+                    f"index_col_suffix='{index_col_suffix}' requested, missing var column '{coord_col}' for ref '{ref}'."
                 )
             coord_vals = adata.var[coord_col].values
-            # Try to coerce to numeric
             try:
                 coordinates = coord_vals.astype(int)
             except Exception:
@@ -111,184 +98,159 @@ def call_hmm_peaks(
         else:
             coordinates = base_coordinates
 
-
+        if coordinates.shape[0] != adata.n_vars:
+            raise ValueError(f"Coordinate length {coordinates.shape[0]} != n_vars {adata.n_vars}")
+
+        # Feature keys to consider
         for feature_key, config in feature_configs.items():
-            #
-
-
-
-
-
-
-
-
-
-
-
-            if not candidate_layers:
-                for lname in search_layers:
-                    if lname.endswith(feature_key):
-                        candidate_layers.append(lname)
-
-            # Third: if user passed a full layer name that wasn't in hmm_layers,
-            # but does exist in adata.layers, allow it.
-            if not candidate_layers and feature_key in adata.layers:
-                candidate_layers.append(feature_key)
-
-            if not candidate_layers:
-                print(
-                    f"[call_hmm_peaks] WARNING: no layers found matching feature key "
-                    f"'{feature_key}' in ref '{ref}'. Skipping."
+            # Resolve candidate layers: exact → suffix → direct present
+            candidates = [ln for ln in search_pool if ln == feature_key]
+            if not candidates:
+                candidates = [ln for ln in search_pool if str(ln).endswith(feature_key)]
+            if not candidates and feature_key in adata.layers:
+                candidates = [feature_key]
+
+            if not candidates:
+                logger.warning(
+                    "[call_hmm_peaks] No layers found matching '%s' in ref '%s'. Skipping.",
+                    feature_key,
+                    ref,
                 )
                 continue
 
-            #
-
+            # Hyperparams (sanitized)
+            min_distance = max(1, int(config.get("min_distance", 200)))
+            peak_width = max(1, int(config.get("peak_width", 200)))
+            peak_prom = float(config.get("peak_prominence", 0.2))
+            peak_threshold = float(config.get("peak_threshold", 0.8))
+            rolling_window = max(1, int(config.get("rolling_window", 1)))
+
+            for layer_name in candidates:
                 if layer_name not in adata.layers:
-
-
-
+                    logger.warning(
+                        "[call_hmm_peaks] Layer '%s' not in adata.layers; skipping.",
+                        layer_name,
                     )
                     continue
 
-
-
-
-
-
-
-
-
-
-
+                # Dense layer data
+                L = adata.layers[layer_name]
+                L = L.toarray() if issparse(L) else np.asarray(L)
+                if L.shape != (adata.n_obs, adata.n_vars):
+                    logger.warning(
+                        "[call_hmm_peaks] Layer '%s' has shape %s, expected (%s, %s); skipping.",
+                        layer_name,
+                        L.shape,
+                        adata.n_obs,
+                        adata.n_vars,
+                    )
+                    continue
 
-                #
-                matrix =
-                if matrix.shape[0] == 0:
+                # Ref subset
+                matrix = L[ref_mask, :]
+                if matrix.size == 0 or matrix.shape[0] == 0:
                     continue
 
-                # Mean signal along positions (within this ref only)
                 means = np.nanmean(matrix, axis=0)
+                means = np.nan_to_num(means, nan=0.0)
 
-                # Optional rolling-mean smoothing before peak detection
-                rolling_window = int(config.get("rolling_window", 1))
                 if rolling_window > 1:
-                    # Simple centered rolling mean via convolution
                     kernel = np.ones(rolling_window, dtype=float) / float(rolling_window)
-
-                    peak_metric = smoothed
+                    peak_metric = np.convolve(means, kernel, mode="same")
                 else:
                     peak_metric = means
 
                 # Peak detection
                 peak_indices, _ = find_peaks(
-                    peak_metric, prominence=
+                    peak_metric, prominence=peak_prom, distance=min_distance
                 )
                 if peak_indices.size == 0:
-
-
-
+                    logger.info(
+                        "[call_hmm_peaks] No peaks for layer '%s' in ref '%s'.",
+                        layer_name,
+                        ref,
                    )
                    continue
 
                 peak_centers = coordinates[peak_indices]
-                # Store per-ref peak centers
                 adata.uns[f"{layer_name}_{ref}_peak_centers"] = peak_centers.tolist()
 
-                #
-                plt.
-
-
-
-
-
+                # Plot once per layer/ref
+                fig, ax = plt.subplots(figsize=(6, 3))
+                ax.plot(coordinates, peak_metric, linewidth=1)
+                ax.set_title(f"{layer_name} peaks in {ref}")
+                ax.set_xlabel("Coordinate")
+                ax.set_ylabel(f"Rolling Mean (win={rolling_window})")
                 for i, center in enumerate(peak_centers):
                     start = center - peak_width // 2
                     end = center + peak_width // 2
                     height = peak_metric[peak_indices[i]]
-
-
-
-
-
-
-
-                        else:
-                            x_text, ha = end, "left"
-                    else:
-                        x_text, ha = start, "right"
-
-                    plt.text(
-                        x_text,
-                        height * 0.8,
-                        f"Peak {i}\n{center}",
-                        color="red",
-                        ha=ha,
-                        va="bottom",
-                        fontsize=8,
+                    ax.axvspan(start, end, alpha=0.2)
+                    ax.axvline(center, linestyle="--", linewidth=0.8)
+                    x_text, ha = (
+                        (start, "right") if (not alternate_labels or i % 2 == 0) else (end, "left")
+                    )
+                    ax.text(
+                        x_text, height * 0.8, f"Peak {i}\n{center}", ha=ha, va="bottom", fontsize=8
                     )
 
                 if save_plot and output_dir is not None:
                     tag = date_tag or "output"
-                    # include ref in filename
                     safe_ref = str(ref).replace("/", "_")
                     safe_layer = str(layer_name).replace("/", "_")
                     fname = output_dir / f"{tag}_{safe_layer}_{safe_ref}_peaks.png"
-
-
-                    plt.close()
+                    fig.savefig(fname, bbox_inches="tight", dpi=200)
+                    logger.info("[call_hmm_peaks] Saved plot to %s", fname)
+                    plt.close(fig)
                 else:
-
+                    fig.tight_layout()
                     plt.show()
 
+                # Collect new obs columns; assign once per layer/ref
+                new_obs_cols: Dict[str, np.ndarray] = {}
                 feature_peak_cols = []
 
-
-                for center in peak_centers:
+                for center in np.asarray(peak_centers).tolist():
                     start = center - peak_width // 2
                     end = center + peak_width // 2
 
-                    #
+                    # var window mask
                     colname = f"{layer_name}_{ref}_peak_{center}"
                     feature_peak_cols.append(colname)
                     all_peak_var_cols.append(colname)
-
-                    # Var-level mask: is this position in the window?
                     peak_mask = (coordinates >= start) & (coordinates <= end)
                     adata.var[colname] = peak_mask
 
-                    #
-                    region = matrix[:, peak_mask]  # (
+                    # feature-layer summaries for reads in this ref
+                    region = matrix[:, peak_mask]  # (n_ref, n_window)
 
-                    # Per-read summary in this window for the feature layer itself
                    mean_col = f"mean_{layer_name}_{ref}_around_{center}"
                    sum_col = f"sum_{layer_name}_{ref}_around_{center}"
                    present_col = f"{layer_name}_{ref}_present_at_{center}"
 
-
-
-
-
-
-
-
-
-
-
-
-
+                    for nm, default, dt in (
+                        (mean_col, np.nan, float),
+                        (sum_col, 0.0, float),
+                        (present_col, False, bool),
+                    ):
+                        if nm not in new_obs_cols:
+                            new_obs_cols[nm] = np.full(adata.n_obs, default, dtype=dt)
+
+                    if region.shape[1] > 0:
+                        means_per_read = np.nanmean(region, axis=1)
+                        sums_per_read = np.nansum(region, axis=1)
+                    else:
+                        means_per_read = np.full(matrix.shape[0], np.nan, dtype=float)
+                        sums_per_read = np.zeros(matrix.shape[0], dtype=float)
+
+                    new_obs_cols[mean_col][ref_mask] = means_per_read
+                    new_obs_cols[sum_col][ref_mask] = sums_per_read
+                    new_obs_cols[present_col][ref_mask] = (
+                        np.nan_to_num(means_per_read, nan=0.0) > peak_threshold
                    )
 
-                    #
-
-                    sum_site_col = f"{site_type}_{ref}_sum_around_{center}"
-                    mean_site_col = f"{site_type}_{ref}_mean_around_{center}"
-                    if sum_site_col not in adata.obs:
-                        adata.obs[sum_site_col] = 0.0
-                    if mean_site_col not in adata.obs:
-                        adata.obs[mean_site_col] = np.nan
-
-                    # Per-site-type summaries for this ref
+                    # site-type summaries from adata.X, not an AnnData view
+                    Xmat = adata.X
                    for site_type in site_types:
                        mask_key = f"{ref}_{site_type}_site"
                        if mask_key not in adata.var:
@@ -299,35 +261,53 @@ def call_hmm_peaks(
                             continue
 
                         site_coords = coordinates[site_mask]
-
-
+                        site_region_mask = (site_coords >= start) & (site_coords <= end)
+                        sum_site_col = f"sum_{layer_name}_{site_type}_{ref}_around_{center}"
+                        mean_site_col = f"mean_{layer_name}_{site_type}_{ref}_around_{center}"
+
+                        if sum_site_col not in new_obs_cols:
+                            new_obs_cols[sum_site_col] = np.zeros(adata.n_obs, dtype=float)
+                        if mean_site_col not in new_obs_cols:
+                            new_obs_cols[mean_site_col] = np.full(adata.n_obs, np.nan, dtype=float)
+
+                        if not site_region_mask.any():
                             continue
 
                         full_mask = np.zeros_like(site_mask, dtype=bool)
-                        full_mask[site_mask] =
+                        full_mask[site_mask] = site_region_mask
 
-
-
-                        site_region = site_region.
-
-
-
+                        if issparse(Xmat):
+                            site_region = Xmat[ref_mask][:, full_mask]
+                            site_region = site_region.toarray()
+                        else:
+                            Xnp = np.asarray(Xmat)
+                            site_region = Xnp[np.asarray(ref_mask), :][:, np.asarray(full_mask)]
 
-
-
+                        if site_region.shape[1] > 0:
+                            new_obs_cols[sum_site_col][ref_mask] = np.nansum(site_region, axis=1)
+                            new_obs_cols[mean_site_col][ref_mask] = np.nanmean(site_region, axis=1)
 
-
-
+                # one-shot assignment to avoid fragmentation
+                if new_obs_cols:
+                    adata.obs = adata.obs.assign(
+                        **{k: pd.Series(v, index=adata.obs.index) for k, v in new_obs_cols.items()}
+                    )
 
-                #
+                # per (layer, ref) any-peak
                 any_col = f"is_in_any_{layer_name}_peak_{ref}"
-
-
-
-
+                if feature_peak_cols:
+                    adata.var[any_col] = adata.var[feature_peak_cols].any(axis=1)
+                else:
+                    adata.var[any_col] = False
+
+                logger.info(
+                    "[call_hmm_peaks] Annotated %s peaks for layer '%s' in ref '%s'.",
+                    len(peak_centers),
+                    layer_name,
+                    ref,
                )
 
-    #
+    # global any-peak across all layers/refs
     if all_peak_var_cols:
         adata.var["is_in_any_peak"] = adata.var[all_peak_var_cols].any(axis=1)
 
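The rewrite also consolidates hyperparameter handling: each feature config is read once and sanitized (min_distance, peak_width, peak_prominence, peak_threshold, rolling_window, per the hunk above). A hedged usage sketch; the feature key, reference column value, and output directory below are hypothetical:

    from smftools.hmm import call_hmm_peaks

    feature_configs = {
        # keys are matched against declared HMM layers by exact name, then suffix, then raw layer name
        "all_accessible_features": {
            "min_distance": 150,
            "peak_width": 200,
            "peak_prominence": 0.2,
            "peak_threshold": 0.8,
            "rolling_window": 5,
        },
    }
    adata = call_hmm_peaks(
        adata,
        feature_configs,
        ref_column="Reference_strand",  # must exist in adata.obs or a KeyError is raised
        site_types=("GpC", "CpG"),
        save_plot=True,
        output_dir="peak_plots",
        inplace=False,                  # operate on a copy and return it
    )
    # peak centers land in adata.uns["{layer}_{ref}_peak_centers"]; window masks land in adata.var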
smftools/hmm/display_hmm.py
CHANGED
@@ -1,18 +1,31 @@
+from smftools.logging_utils import get_logger
+
+logger = get_logger(__name__)
+
+
 def display_hmm(hmm, state_labels=["Non-Methylated", "Methylated"], obs_labels=["0", "1"]):
+    """Log a summary of HMM transition and emission parameters.
+
+    Args:
+        hmm: HMM object with edges and distributions.
+        state_labels: Optional labels for states.
+        obs_labels: Optional labels for observations.
+    """
     import torch
-    print("\n**HMM Model Overview**")
-    print(hmm)
 
-
+    logger.info("**HMM Model Overview**")
+    logger.info("%s", hmm)
+
+    logger.info("**Transition Matrix**")
     transition_matrix = torch.exp(hmm.edges).detach().cpu().numpy()
     for i, row in enumerate(transition_matrix):
         label = state_labels[i] if state_labels else f"State {i}"
         formatted_row = ", ".join(f"{p:.6f}" for p in row)
-
+        logger.info("%s: [%s]", label, formatted_row)
 
-
+    logger.info("**Emission Probabilities**")
     for i, dist in enumerate(hmm.distributions):
         label = state_labels[i] if state_labels else f"State {i}"
         probs = dist.probs.detach().cpu().numpy()
         formatted_emissions = {obs_labels[j]: probs[j] for j in range(len(probs))}
-
+        logger.info("%s: %s", label, formatted_emissions)
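Because display_hmm now routes through smftools.logging_utils instead of print, nothing is displayed unless logging is configured in the calling session. A minimal sketch, assuming get_logger wraps the standard-library logging module:

    import logging

    from smftools.hmm import display_hmm

    logging.basicConfig(level=logging.INFO)  # assumption: package loggers honor stdlib logging config
    display_hmm(hmm, state_labels=["Non-Methylated", "Methylated"], obs_labels=["0", "1"])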
smftools/hmm/hmm_readwrite.py
CHANGED
@@ -1,16 +1,25 @@
-def load_hmm(model_path, device=
+def load_hmm(model_path, device="cpu"):
     """
     Reads in a pretrained HMM.
-
+
     Parameters:
         model_path (str): Path to a pretrained HMM
     """
     import torch
+
     # Load model using PyTorch
     hmm = torch.load(model_path)
-    hmm.to(device)
+    hmm.to(device)
     return hmm
 
+
 def save_hmm(model, model_path):
+    """Save a pretrained HMM to disk.
+
+    Args:
+        model: HMM model instance.
+        model_path: Output path for the model.
+    """
     import torch
-
+
+    torch.save(model, model_path)