smftools 0.2.4__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- smftools/__init__.py +43 -13
- smftools/_settings.py +6 -6
- smftools/_version.py +3 -1
- smftools/cli/__init__.py +1 -0
- smftools/cli/archived/cli_flows.py +2 -0
- smftools/cli/helpers.py +9 -1
- smftools/cli/hmm_adata.py +905 -242
- smftools/cli/load_adata.py +432 -280
- smftools/cli/preprocess_adata.py +287 -171
- smftools/cli/spatial_adata.py +141 -53
- smftools/cli_entry.py +119 -178
- smftools/config/__init__.py +3 -1
- smftools/config/conversion.yaml +5 -1
- smftools/config/deaminase.yaml +1 -1
- smftools/config/default.yaml +26 -18
- smftools/config/direct.yaml +8 -3
- smftools/config/discover_input_files.py +19 -5
- smftools/config/experiment_config.py +511 -276
- smftools/constants.py +37 -0
- smftools/datasets/__init__.py +4 -8
- smftools/datasets/datasets.py +32 -18
- smftools/hmm/HMM.py +2133 -1428
- smftools/hmm/__init__.py +24 -14
- smftools/hmm/archived/apply_hmm_batched.py +2 -0
- smftools/hmm/archived/calculate_distances.py +2 -0
- smftools/hmm/archived/call_hmm_peaks.py +18 -1
- smftools/hmm/archived/train_hmm.py +2 -0
- smftools/hmm/call_hmm_peaks.py +176 -193
- smftools/hmm/display_hmm.py +23 -7
- smftools/hmm/hmm_readwrite.py +20 -6
- smftools/hmm/nucleosome_hmm_refinement.py +104 -14
- smftools/informatics/__init__.py +55 -13
- smftools/informatics/archived/bam_conversion.py +2 -0
- smftools/informatics/archived/bam_direct.py +2 -0
- smftools/informatics/archived/basecall_pod5s.py +2 -0
- smftools/informatics/archived/basecalls_to_adata.py +2 -0
- smftools/informatics/archived/conversion_smf.py +2 -0
- smftools/informatics/archived/deaminase_smf.py +1 -0
- smftools/informatics/archived/direct_smf.py +2 -0
- smftools/informatics/archived/fast5_to_pod5.py +2 -0
- smftools/informatics/archived/helpers/archived/__init__.py +2 -0
- smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +16 -1
- smftools/informatics/archived/helpers/archived/aligned_BAM_to_bed.py +2 -0
- smftools/informatics/archived/helpers/archived/bam_qc.py +14 -1
- smftools/informatics/archived/helpers/archived/bed_to_bigwig.py +2 -0
- smftools/informatics/archived/helpers/archived/canoncall.py +2 -0
- smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +8 -1
- smftools/informatics/archived/helpers/archived/converted_BAM_to_adata.py +2 -0
- smftools/informatics/archived/helpers/archived/count_aligned_reads.py +2 -0
- smftools/informatics/archived/helpers/archived/demux_and_index_BAM.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_base_identities.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_mods.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_read_features_from_bam.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_read_lengths_from_bed.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_readnames_from_BAM.py +2 -0
- smftools/informatics/archived/helpers/archived/find_conversion_sites.py +2 -0
- smftools/informatics/archived/helpers/archived/generate_converted_FASTA.py +2 -0
- smftools/informatics/archived/helpers/archived/get_chromosome_lengths.py +2 -0
- smftools/informatics/archived/helpers/archived/get_native_references.py +2 -0
- smftools/informatics/archived/helpers/archived/index_fasta.py +2 -0
- smftools/informatics/archived/helpers/archived/informatics.py +2 -0
- smftools/informatics/archived/helpers/archived/load_adata.py +5 -3
- smftools/informatics/archived/helpers/archived/make_modbed.py +2 -0
- smftools/informatics/archived/helpers/archived/modQC.py +2 -0
- smftools/informatics/archived/helpers/archived/modcall.py +2 -0
- smftools/informatics/archived/helpers/archived/ohe_batching.py +2 -0
- smftools/informatics/archived/helpers/archived/ohe_layers_decode.py +2 -0
- smftools/informatics/archived/helpers/archived/one_hot_decode.py +2 -0
- smftools/informatics/archived/helpers/archived/one_hot_encode.py +2 -0
- smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +5 -1
- smftools/informatics/archived/helpers/archived/separate_bam_by_bc.py +2 -0
- smftools/informatics/archived/helpers/archived/split_and_index_BAM.py +2 -0
- smftools/informatics/archived/print_bam_query_seq.py +9 -1
- smftools/informatics/archived/subsample_fasta_from_bed.py +2 -0
- smftools/informatics/archived/subsample_pod5.py +2 -0
- smftools/informatics/bam_functions.py +1059 -269
- smftools/informatics/basecalling.py +53 -9
- smftools/informatics/bed_functions.py +357 -114
- smftools/informatics/binarize_converted_base_identities.py +21 -7
- smftools/informatics/complement_base_list.py +9 -6
- smftools/informatics/converted_BAM_to_adata.py +324 -137
- smftools/informatics/fasta_functions.py +251 -89
- smftools/informatics/h5ad_functions.py +202 -30
- smftools/informatics/modkit_extract_to_adata.py +623 -274
- smftools/informatics/modkit_functions.py +87 -44
- smftools/informatics/ohe.py +46 -21
- smftools/informatics/pod5_functions.py +114 -74
- smftools/informatics/run_multiqc.py +20 -14
- smftools/logging_utils.py +51 -0
- smftools/machine_learning/__init__.py +23 -12
- smftools/machine_learning/data/__init__.py +2 -0
- smftools/machine_learning/data/anndata_data_module.py +157 -50
- smftools/machine_learning/data/preprocessing.py +4 -1
- smftools/machine_learning/evaluation/__init__.py +3 -1
- smftools/machine_learning/evaluation/eval_utils.py +13 -14
- smftools/machine_learning/evaluation/evaluators.py +52 -34
- smftools/machine_learning/inference/__init__.py +3 -1
- smftools/machine_learning/inference/inference_utils.py +9 -4
- smftools/machine_learning/inference/lightning_inference.py +14 -13
- smftools/machine_learning/inference/sklearn_inference.py +8 -8
- smftools/machine_learning/inference/sliding_window_inference.py +37 -25
- smftools/machine_learning/models/__init__.py +12 -5
- smftools/machine_learning/models/base.py +34 -43
- smftools/machine_learning/models/cnn.py +22 -13
- smftools/machine_learning/models/lightning_base.py +78 -42
- smftools/machine_learning/models/mlp.py +18 -5
- smftools/machine_learning/models/positional.py +10 -4
- smftools/machine_learning/models/rnn.py +8 -3
- smftools/machine_learning/models/sklearn_models.py +46 -24
- smftools/machine_learning/models/transformer.py +75 -55
- smftools/machine_learning/models/wrappers.py +8 -3
- smftools/machine_learning/training/__init__.py +4 -2
- smftools/machine_learning/training/train_lightning_model.py +42 -23
- smftools/machine_learning/training/train_sklearn_model.py +11 -15
- smftools/machine_learning/utils/__init__.py +3 -1
- smftools/machine_learning/utils/device.py +12 -5
- smftools/machine_learning/utils/grl.py +8 -2
- smftools/metadata.py +443 -0
- smftools/optional_imports.py +31 -0
- smftools/plotting/__init__.py +32 -17
- smftools/plotting/autocorrelation_plotting.py +153 -48
- smftools/plotting/classifiers.py +175 -73
- smftools/plotting/general_plotting.py +350 -168
- smftools/plotting/hmm_plotting.py +53 -14
- smftools/plotting/position_stats.py +155 -87
- smftools/plotting/qc_plotting.py +25 -12
- smftools/preprocessing/__init__.py +35 -37
- smftools/preprocessing/append_base_context.py +105 -79
- smftools/preprocessing/append_binary_layer_by_base_context.py +75 -37
- smftools/preprocessing/{archives → archived}/add_read_length_and_mapping_qc.py +2 -0
- smftools/preprocessing/{archives → archived}/calculate_complexity.py +5 -1
- smftools/preprocessing/{archives → archived}/mark_duplicates.py +2 -0
- smftools/preprocessing/{archives → archived}/preprocessing.py +10 -6
- smftools/preprocessing/{archives → archived}/remove_duplicates.py +2 -0
- smftools/preprocessing/binarize.py +21 -4
- smftools/preprocessing/binarize_on_Youden.py +127 -31
- smftools/preprocessing/binary_layers_to_ohe.py +18 -11
- smftools/preprocessing/calculate_complexity_II.py +89 -59
- smftools/preprocessing/calculate_consensus.py +28 -19
- smftools/preprocessing/calculate_coverage.py +44 -22
- smftools/preprocessing/calculate_pairwise_differences.py +4 -1
- smftools/preprocessing/calculate_pairwise_hamming_distances.py +7 -3
- smftools/preprocessing/calculate_position_Youden.py +110 -55
- smftools/preprocessing/calculate_read_length_stats.py +52 -23
- smftools/preprocessing/calculate_read_modification_stats.py +91 -57
- smftools/preprocessing/clean_NaN.py +38 -28
- smftools/preprocessing/filter_adata_by_nan_proportion.py +24 -12
- smftools/preprocessing/filter_reads_on_length_quality_mapping.py +72 -37
- smftools/preprocessing/filter_reads_on_modification_thresholds.py +183 -73
- smftools/preprocessing/flag_duplicate_reads.py +708 -303
- smftools/preprocessing/invert_adata.py +26 -11
- smftools/preprocessing/load_sample_sheet.py +40 -22
- smftools/preprocessing/make_dirs.py +9 -3
- smftools/preprocessing/min_non_diagonal.py +4 -1
- smftools/preprocessing/recipes.py +58 -23
- smftools/preprocessing/reindex_references_adata.py +93 -27
- smftools/preprocessing/subsample_adata.py +33 -16
- smftools/readwrite.py +264 -109
- smftools/schema/__init__.py +11 -0
- smftools/schema/anndata_schema_v1.yaml +227 -0
- smftools/tools/__init__.py +25 -18
- smftools/tools/archived/apply_hmm.py +2 -0
- smftools/tools/archived/classifiers.py +165 -0
- smftools/tools/archived/classify_methylated_features.py +2 -0
- smftools/tools/archived/classify_non_methylated_features.py +2 -0
- smftools/tools/archived/subset_adata_v1.py +12 -1
- smftools/tools/archived/subset_adata_v2.py +14 -1
- smftools/tools/calculate_umap.py +56 -15
- smftools/tools/cluster_adata_on_methylation.py +122 -47
- smftools/tools/general_tools.py +70 -25
- smftools/tools/position_stats.py +220 -99
- smftools/tools/read_stats.py +50 -29
- smftools/tools/spatial_autocorrelation.py +365 -192
- smftools/tools/subset_adata.py +23 -21
- smftools-0.3.0.dist-info/METADATA +147 -0
- smftools-0.3.0.dist-info/RECORD +182 -0
- smftools-0.2.4.dist-info/METADATA +0 -141
- smftools-0.2.4.dist-info/RECORD +0 -176
- {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/WHEEL +0 -0
- {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/entry_points.txt +0 -0
- {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/licenses/LICENSE +0 -0
smftools/hmm/__init__.py
CHANGED
|
@@ -1,14 +1,24 @@
|
|
|
1
|
-
from
|
|
2
|
-
|
|
3
|
-
from
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
"
|
|
9
|
-
"
|
|
10
|
-
"
|
|
11
|
-
"refine_nucleosome_calls",
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from importlib import import_module
|
|
4
|
+
|
|
5
|
+
_LAZY_ATTRS = {
|
|
6
|
+
"call_hmm_peaks": "smftools.hmm.call_hmm_peaks",
|
|
7
|
+
"display_hmm": "smftools.hmm.display_hmm",
|
|
8
|
+
"load_hmm": "smftools.hmm.hmm_readwrite",
|
|
9
|
+
"save_hmm": "smftools.hmm.hmm_readwrite",
|
|
10
|
+
"infer_nucleosomes_in_large_bound": "smftools.hmm.nucleosome_hmm_refinement",
|
|
11
|
+
"refine_nucleosome_calls": "smftools.hmm.nucleosome_hmm_refinement",
|
|
12
|
+
}
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def __getattr__(name: str):
|
|
16
|
+
if name in _LAZY_ATTRS:
|
|
17
|
+
module = import_module(_LAZY_ATTRS[name])
|
|
18
|
+
attr = getattr(module, name)
|
|
19
|
+
globals()[name] = attr
|
|
20
|
+
return attr
|
|
21
|
+
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
__all__ = list(_LAZY_ATTRS.keys())
|
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
def call_hmm_peaks(
|
|
2
4
|
adata,
|
|
3
5
|
feature_configs,
|
|
@@ -8,6 +10,21 @@ def call_hmm_peaks(
|
|
|
8
10
|
date_tag=None,
|
|
9
11
|
inplace=False
|
|
10
12
|
):
|
|
13
|
+
"""Call peaks from HMM feature layers and annotate AnnData.
|
|
14
|
+
|
|
15
|
+
Args:
|
|
16
|
+
adata: AnnData containing feature layers.
|
|
17
|
+
feature_configs: Mapping of layer name to peak config.
|
|
18
|
+
obs_column: Obs column for reference categories.
|
|
19
|
+
site_types: Site types to summarize around peaks.
|
|
20
|
+
save_plot: Whether to save peak plots.
|
|
21
|
+
output_dir: Output directory for plots.
|
|
22
|
+
date_tag: Optional tag for plot filenames.
|
|
23
|
+
inplace: Whether to modify AnnData in place.
|
|
24
|
+
|
|
25
|
+
Returns:
|
|
26
|
+
Annotated AnnData with peak masks and summary columns.
|
|
27
|
+
"""
|
|
11
28
|
import numpy as np
|
|
12
29
|
import pandas as pd
|
|
13
30
|
import matplotlib.pyplot as plt
|
|
@@ -103,4 +120,4 @@ def call_hmm_peaks(
|
|
|
103
120
|
adata.var['is_in_any_peak'] = adata.var[peak_columns].any(axis=1)
|
|
104
121
|
adata.obs = pd.concat([adata.obs, pd.DataFrame(obs_updates, index=adata.obs.index)], axis=1)
|
|
105
122
|
|
|
106
|
-
return adata if not inplace else None
|
|
123
|
+
return adata if not inplace else None
|
smftools/hmm/call_hmm_peaks.py
CHANGED
|
@@ -1,5 +1,14 @@
|
|
|
1
|
-
from
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
# FILE: smftools/hmm/call_hmm_peaks.py
|
|
2
4
|
from pathlib import Path
|
|
5
|
+
from typing import Any, Dict, Optional, Sequence, Union
|
|
6
|
+
|
|
7
|
+
from smftools.logging_utils import get_logger
|
|
8
|
+
from smftools.optional_imports import require
|
|
9
|
+
|
|
10
|
+
logger = get_logger(__name__)
|
|
11
|
+
|
|
3
12
|
|
|
4
13
|
def call_hmm_peaks(
|
|
5
14
|
adata,
|
|
@@ -14,96 +23,77 @@ def call_hmm_peaks(
|
|
|
14
23
|
alternate_labels: bool = False,
|
|
15
24
|
):
|
|
16
25
|
"""
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
}
|
|
31
|
-
|
|
32
|
-
Keys are usually *feature types* like "all_accessible_features" or
|
|
33
|
-
"small_bound_stretch". These are matched against existing HMM layers
|
|
34
|
-
(e.g. "GpC_all_accessible_features", "Combined_small_bound_stretch")
|
|
35
|
-
using a suffix match. You can also pass full layer names if you wish.
|
|
36
|
-
ref_column : str
|
|
37
|
-
Column in adata.obs defining reference groups (e.g. "Reference_strand").
|
|
38
|
-
site_types : sequence of str
|
|
39
|
-
Site types (without "_site"); expects var columns like f"{ref}_{site_type}_site".
|
|
40
|
-
e.g. ("GpC", "CpG") -> "6B6_top_GpC_site", etc.
|
|
41
|
-
save_plot : bool
|
|
42
|
-
If True, save peak diagnostic plots instead of just showing them.
|
|
43
|
-
output_dir : path-like or None
|
|
44
|
-
Directory for saved plots (created if needed).
|
|
45
|
-
date_tag : str or None
|
|
46
|
-
Optional tag to prefix plot filenames.
|
|
47
|
-
inplace : bool
|
|
48
|
-
If False, operate on a copy and return it. If True, modify adata and return None.
|
|
49
|
-
index_col_suffix : str or None
|
|
50
|
-
If None, coordinates come from adata.var_names (cast to int when possible).
|
|
51
|
-
If set, for each ref we use adata.var[f"{ref}_{index_col_suffix}"] as the
|
|
52
|
-
coordinate system (e.g. a reindexed coordinate).
|
|
53
|
-
|
|
54
|
-
Returns
|
|
55
|
-
-------
|
|
56
|
-
None or AnnData
|
|
26
|
+
Peak calling over HMM (or other) layers, per reference group and per layer.
|
|
27
|
+
Writes:
|
|
28
|
+
- adata.uns["{layer}_{ref}_peak_centers"] = list of centers
|
|
29
|
+
- adata.var["{layer}_{ref}_peak_{center}"] boolean window masks
|
|
30
|
+
- adata.obs per-read summaries for each peak window:
|
|
31
|
+
mean_{layer}_{ref}_around_{center}
|
|
32
|
+
sum_{layer}_{ref}_around_{center}
|
|
33
|
+
{layer}_{ref}_present_at_{center} (bool)
|
|
34
|
+
and per site-type:
|
|
35
|
+
sum_{layer}_{site}_{ref}_around_{center}
|
|
36
|
+
mean_{layer}_{site}_{ref}_around_{center}
|
|
37
|
+
- adata.var["is_in_any_{layer}_peak_{ref}"]
|
|
38
|
+
- adata.var["is_in_any_peak"] (global)
|
|
57
39
|
"""
|
|
58
40
|
import numpy as np
|
|
59
41
|
import pandas as pd
|
|
60
|
-
import matplotlib.pyplot as plt
|
|
61
42
|
from scipy.signal import find_peaks
|
|
62
43
|
from scipy.sparse import issparse
|
|
63
44
|
|
|
45
|
+
plt = require("matplotlib.pyplot", extra="plotting", purpose="HMM peak plots")
|
|
46
|
+
|
|
64
47
|
if not inplace:
|
|
65
48
|
adata = adata.copy()
|
|
66
49
|
|
|
67
|
-
|
|
50
|
+
if ref_column not in adata.obs:
|
|
51
|
+
raise KeyError(f"obs column '{ref_column}' not found")
|
|
52
|
+
|
|
53
|
+
# Ensure categorical for predictable ref iteration
|
|
68
54
|
if not pd.api.types.is_categorical_dtype(adata.obs[ref_column]):
|
|
69
55
|
adata.obs[ref_column] = adata.obs[ref_column].astype("category")
|
|
70
56
|
|
|
71
|
-
#
|
|
57
|
+
# Optional: drop duplicate obs columns once to avoid Pandas/AnnData view quirks
|
|
58
|
+
if getattr(adata.obs.columns, "duplicated", None) is not None:
|
|
59
|
+
if adata.obs.columns.duplicated().any():
|
|
60
|
+
adata.obs = adata.obs.loc[:, ~adata.obs.columns.duplicated(keep="first")].copy()
|
|
61
|
+
|
|
62
|
+
# Fallback coordinates from var_names
|
|
72
63
|
try:
|
|
73
64
|
base_coordinates = adata.var_names.astype(int).values
|
|
74
65
|
except Exception:
|
|
75
66
|
base_coordinates = np.arange(adata.n_vars, dtype=int)
|
|
76
67
|
|
|
68
|
+
# Output dir
|
|
77
69
|
if output_dir is not None:
|
|
78
70
|
output_dir = Path(output_dir)
|
|
79
71
|
output_dir.mkdir(parents=True, exist_ok=True)
|
|
80
72
|
|
|
81
|
-
# HMM layers
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
73
|
+
# Build search pool = union of declared HMM layers and actual layers; exclude helper suffixes
|
|
74
|
+
declared = list(adata.uns.get("hmm_appended_layers", []) or [])
|
|
75
|
+
search_pool = [
|
|
76
|
+
layer
|
|
77
|
+
for layer in declared
|
|
78
|
+
if not any(s in layer for s in ("_lengths", "_states", "_posterior"))
|
|
79
|
+
]
|
|
88
80
|
|
|
89
81
|
all_peak_var_cols = []
|
|
90
82
|
|
|
91
|
-
# Iterate
|
|
83
|
+
# Iterate per reference
|
|
92
84
|
for ref in adata.obs[ref_column].cat.categories:
|
|
93
85
|
ref_mask = (adata.obs[ref_column] == ref).values
|
|
94
86
|
if not ref_mask.any():
|
|
95
87
|
continue
|
|
96
88
|
|
|
97
|
-
# Per-ref
|
|
89
|
+
# Per-ref coordinate system
|
|
98
90
|
if index_col_suffix is not None:
|
|
99
91
|
coord_col = f"{ref}_{index_col_suffix}"
|
|
100
92
|
if coord_col not in adata.var:
|
|
101
93
|
raise KeyError(
|
|
102
|
-
f"index_col_suffix='{index_col_suffix}' requested, "
|
|
103
|
-
f"but var column '{coord_col}' is missing for ref '{ref}'."
|
|
94
|
+
f"index_col_suffix='{index_col_suffix}' requested, missing var column '{coord_col}' for ref '{ref}'."
|
|
104
95
|
)
|
|
105
96
|
coord_vals = adata.var[coord_col].values
|
|
106
|
-
# Try to coerce to numeric
|
|
107
97
|
try:
|
|
108
98
|
coordinates = coord_vals.astype(int)
|
|
109
99
|
except Exception:
|
|
@@ -111,184 +101,159 @@ def call_hmm_peaks(
|
|
|
111
101
|
else:
|
|
112
102
|
coordinates = base_coordinates
|
|
113
103
|
|
|
114
|
-
|
|
104
|
+
if coordinates.shape[0] != adata.n_vars:
|
|
105
|
+
raise ValueError(f"Coordinate length {coordinates.shape[0]} != n_vars {adata.n_vars}")
|
|
106
|
+
|
|
107
|
+
# Feature keys to consider
|
|
115
108
|
for feature_key, config in feature_configs.items():
|
|
116
|
-
#
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
if not candidate_layers:
|
|
129
|
-
for lname in search_layers:
|
|
130
|
-
if lname.endswith(feature_key):
|
|
131
|
-
candidate_layers.append(lname)
|
|
132
|
-
|
|
133
|
-
# Third: if user passed a full layer name that wasn't in hmm_layers,
|
|
134
|
-
# but does exist in adata.layers, allow it.
|
|
135
|
-
if not candidate_layers and feature_key in adata.layers:
|
|
136
|
-
candidate_layers.append(feature_key)
|
|
137
|
-
|
|
138
|
-
if not candidate_layers:
|
|
139
|
-
print(
|
|
140
|
-
f"[call_hmm_peaks] WARNING: no layers found matching feature key "
|
|
141
|
-
f"'{feature_key}' in ref '{ref}'. Skipping."
|
|
109
|
+
# Resolve candidate layers: exact → suffix → direct present
|
|
110
|
+
candidates = [ln for ln in search_pool if ln == feature_key]
|
|
111
|
+
if not candidates:
|
|
112
|
+
candidates = [ln for ln in search_pool if str(ln).endswith(feature_key)]
|
|
113
|
+
if not candidates and feature_key in adata.layers:
|
|
114
|
+
candidates = [feature_key]
|
|
115
|
+
|
|
116
|
+
if not candidates:
|
|
117
|
+
logger.warning(
|
|
118
|
+
"[call_hmm_peaks] No layers found matching '%s' in ref '%s'. Skipping.",
|
|
119
|
+
feature_key,
|
|
120
|
+
ref,
|
|
142
121
|
)
|
|
143
122
|
continue
|
|
144
123
|
|
|
145
|
-
#
|
|
146
|
-
|
|
124
|
+
# Hyperparams (sanitized)
|
|
125
|
+
min_distance = max(1, int(config.get("min_distance", 200)))
|
|
126
|
+
peak_width = max(1, int(config.get("peak_width", 200)))
|
|
127
|
+
peak_prom = float(config.get("peak_prominence", 0.2))
|
|
128
|
+
peak_threshold = float(config.get("peak_threshold", 0.8))
|
|
129
|
+
rolling_window = max(1, int(config.get("rolling_window", 1)))
|
|
130
|
+
|
|
131
|
+
for layer_name in candidates:
|
|
147
132
|
if layer_name not in adata.layers:
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
133
|
+
logger.warning(
|
|
134
|
+
"[call_hmm_peaks] Layer '%s' not in adata.layers; skipping.",
|
|
135
|
+
layer_name,
|
|
151
136
|
)
|
|
152
137
|
continue
|
|
153
138
|
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
139
|
+
# Dense layer data
|
|
140
|
+
L = adata.layers[layer_name]
|
|
141
|
+
L = L.toarray() if issparse(L) else np.asarray(L)
|
|
142
|
+
if L.shape != (adata.n_obs, adata.n_vars):
|
|
143
|
+
logger.warning(
|
|
144
|
+
"[call_hmm_peaks] Layer '%s' has shape %s, expected (%s, %s); skipping.",
|
|
145
|
+
layer_name,
|
|
146
|
+
L.shape,
|
|
147
|
+
adata.n_obs,
|
|
148
|
+
adata.n_vars,
|
|
149
|
+
)
|
|
150
|
+
continue
|
|
164
151
|
|
|
165
|
-
#
|
|
166
|
-
matrix =
|
|
167
|
-
if matrix.shape[0] == 0:
|
|
152
|
+
# Ref subset
|
|
153
|
+
matrix = L[ref_mask, :]
|
|
154
|
+
if matrix.size == 0 or matrix.shape[0] == 0:
|
|
168
155
|
continue
|
|
169
156
|
|
|
170
|
-
# Mean signal along positions (within this ref only)
|
|
171
157
|
means = np.nanmean(matrix, axis=0)
|
|
158
|
+
means = np.nan_to_num(means, nan=0.0)
|
|
172
159
|
|
|
173
|
-
# Optional rolling-mean smoothing before peak detection
|
|
174
|
-
rolling_window = int(config.get("rolling_window", 1))
|
|
175
160
|
if rolling_window > 1:
|
|
176
|
-
# Simple centered rolling mean via convolution
|
|
177
161
|
kernel = np.ones(rolling_window, dtype=float) / float(rolling_window)
|
|
178
|
-
|
|
179
|
-
peak_metric = smoothed
|
|
162
|
+
peak_metric = np.convolve(means, kernel, mode="same")
|
|
180
163
|
else:
|
|
181
164
|
peak_metric = means
|
|
182
165
|
|
|
183
166
|
# Peak detection
|
|
184
167
|
peak_indices, _ = find_peaks(
|
|
185
|
-
peak_metric, prominence=
|
|
168
|
+
peak_metric, prominence=peak_prom, distance=min_distance
|
|
186
169
|
)
|
|
187
170
|
if peak_indices.size == 0:
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
171
|
+
logger.info(
|
|
172
|
+
"[call_hmm_peaks] No peaks for layer '%s' in ref '%s'.",
|
|
173
|
+
layer_name,
|
|
174
|
+
ref,
|
|
191
175
|
)
|
|
192
176
|
continue
|
|
193
177
|
|
|
194
178
|
peak_centers = coordinates[peak_indices]
|
|
195
|
-
# Store per-ref peak centers
|
|
196
179
|
adata.uns[f"{layer_name}_{ref}_peak_centers"] = peak_centers.tolist()
|
|
197
180
|
|
|
198
|
-
#
|
|
199
|
-
plt.
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
181
|
+
# Plot once per layer/ref
|
|
182
|
+
fig, ax = plt.subplots(figsize=(6, 3))
|
|
183
|
+
ax.plot(coordinates, peak_metric, linewidth=1)
|
|
184
|
+
ax.set_title(f"{layer_name} peaks in {ref}")
|
|
185
|
+
ax.set_xlabel("Coordinate")
|
|
186
|
+
ax.set_ylabel(f"Rolling Mean (win={rolling_window})")
|
|
205
187
|
for i, center in enumerate(peak_centers):
|
|
206
188
|
start = center - peak_width // 2
|
|
207
189
|
end = center + peak_width // 2
|
|
208
190
|
height = peak_metric[peak_indices[i]]
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
else:
|
|
217
|
-
x_text, ha = end, "left"
|
|
218
|
-
else:
|
|
219
|
-
x_text, ha = start, "right"
|
|
220
|
-
|
|
221
|
-
plt.text(
|
|
222
|
-
x_text,
|
|
223
|
-
height * 0.8,
|
|
224
|
-
f"Peak {i}\n{center}",
|
|
225
|
-
color="red",
|
|
226
|
-
ha=ha,
|
|
227
|
-
va="bottom",
|
|
228
|
-
fontsize=8,
|
|
191
|
+
ax.axvspan(start, end, alpha=0.2)
|
|
192
|
+
ax.axvline(center, linestyle="--", linewidth=0.8)
|
|
193
|
+
x_text, ha = (
|
|
194
|
+
(start, "right") if (not alternate_labels or i % 2 == 0) else (end, "left")
|
|
195
|
+
)
|
|
196
|
+
ax.text(
|
|
197
|
+
x_text, height * 0.8, f"Peak {i}\n{center}", ha=ha, va="bottom", fontsize=8
|
|
229
198
|
)
|
|
230
199
|
|
|
231
200
|
if save_plot and output_dir is not None:
|
|
232
201
|
tag = date_tag or "output"
|
|
233
|
-
# include ref in filename
|
|
234
202
|
safe_ref = str(ref).replace("/", "_")
|
|
235
203
|
safe_layer = str(layer_name).replace("/", "_")
|
|
236
204
|
fname = output_dir / f"{tag}_{safe_layer}_{safe_ref}_peaks.png"
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
plt.close()
|
|
205
|
+
fig.savefig(fname, bbox_inches="tight", dpi=200)
|
|
206
|
+
logger.info("[call_hmm_peaks] Saved plot to %s", fname)
|
|
207
|
+
plt.close(fig)
|
|
240
208
|
else:
|
|
241
|
-
|
|
209
|
+
fig.tight_layout()
|
|
242
210
|
plt.show()
|
|
243
211
|
|
|
212
|
+
# Collect new obs columns; assign once per layer/ref
|
|
213
|
+
new_obs_cols: Dict[str, np.ndarray] = {}
|
|
244
214
|
feature_peak_cols = []
|
|
245
215
|
|
|
246
|
-
|
|
247
|
-
for center in peak_centers:
|
|
216
|
+
for center in np.asarray(peak_centers).tolist():
|
|
248
217
|
start = center - peak_width // 2
|
|
249
218
|
end = center + peak_width // 2
|
|
250
219
|
|
|
251
|
-
#
|
|
220
|
+
# var window mask
|
|
252
221
|
colname = f"{layer_name}_{ref}_peak_{center}"
|
|
253
222
|
feature_peak_cols.append(colname)
|
|
254
223
|
all_peak_var_cols.append(colname)
|
|
255
|
-
|
|
256
|
-
# Var-level mask: is this position in the window?
|
|
257
224
|
peak_mask = (coordinates >= start) & (coordinates <= end)
|
|
258
225
|
adata.var[colname] = peak_mask
|
|
259
226
|
|
|
260
|
-
#
|
|
261
|
-
region = matrix[:, peak_mask] # (
|
|
227
|
+
# feature-layer summaries for reads in this ref
|
|
228
|
+
region = matrix[:, peak_mask] # (n_ref, n_window)
|
|
262
229
|
|
|
263
|
-
# Per-read summary in this window for the feature layer itself
|
|
264
230
|
mean_col = f"mean_{layer_name}_{ref}_around_{center}"
|
|
265
231
|
sum_col = f"sum_{layer_name}_{ref}_around_{center}"
|
|
266
232
|
present_col = f"{layer_name}_{ref}_present_at_{center}"
|
|
267
233
|
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
234
|
+
for nm, default, dt in (
|
|
235
|
+
(mean_col, np.nan, float),
|
|
236
|
+
(sum_col, 0.0, float),
|
|
237
|
+
(present_col, False, bool),
|
|
238
|
+
):
|
|
239
|
+
if nm not in new_obs_cols:
|
|
240
|
+
new_obs_cols[nm] = np.full(adata.n_obs, default, dtype=dt)
|
|
241
|
+
|
|
242
|
+
if region.shape[1] > 0:
|
|
243
|
+
means_per_read = np.nanmean(region, axis=1)
|
|
244
|
+
sums_per_read = np.nansum(region, axis=1)
|
|
245
|
+
else:
|
|
246
|
+
means_per_read = np.full(matrix.shape[0], np.nan, dtype=float)
|
|
247
|
+
sums_per_read = np.zeros(matrix.shape[0], dtype=float)
|
|
248
|
+
|
|
249
|
+
new_obs_cols[mean_col][ref_mask] = means_per_read
|
|
250
|
+
new_obs_cols[sum_col][ref_mask] = sums_per_read
|
|
251
|
+
new_obs_cols[present_col][ref_mask] = (
|
|
252
|
+
np.nan_to_num(means_per_read, nan=0.0) > peak_threshold
|
|
280
253
|
)
|
|
281
254
|
|
|
282
|
-
#
|
|
283
|
-
|
|
284
|
-
sum_site_col = f"{site_type}_{ref}_sum_around_{center}"
|
|
285
|
-
mean_site_col = f"{site_type}_{ref}_mean_around_{center}"
|
|
286
|
-
if sum_site_col not in adata.obs:
|
|
287
|
-
adata.obs[sum_site_col] = 0.0
|
|
288
|
-
if mean_site_col not in adata.obs:
|
|
289
|
-
adata.obs[mean_site_col] = np.nan
|
|
290
|
-
|
|
291
|
-
# Per-site-type summaries for this ref
|
|
255
|
+
# site-type summaries from adata.X, not an AnnData view
|
|
256
|
+
Xmat = adata.X
|
|
292
257
|
for site_type in site_types:
|
|
293
258
|
mask_key = f"{ref}_{site_type}_site"
|
|
294
259
|
if mask_key not in adata.var:
|
|
@@ -299,35 +264,53 @@ def call_hmm_peaks(
|
|
|
299
264
|
continue
|
|
300
265
|
|
|
301
266
|
site_coords = coordinates[site_mask]
|
|
302
|
-
|
|
303
|
-
|
|
267
|
+
site_region_mask = (site_coords >= start) & (site_coords <= end)
|
|
268
|
+
sum_site_col = f"sum_{layer_name}_{site_type}_{ref}_around_{center}"
|
|
269
|
+
mean_site_col = f"mean_{layer_name}_{site_type}_{ref}_around_{center}"
|
|
270
|
+
|
|
271
|
+
if sum_site_col not in new_obs_cols:
|
|
272
|
+
new_obs_cols[sum_site_col] = np.zeros(adata.n_obs, dtype=float)
|
|
273
|
+
if mean_site_col not in new_obs_cols:
|
|
274
|
+
new_obs_cols[mean_site_col] = np.full(adata.n_obs, np.nan, dtype=float)
|
|
275
|
+
|
|
276
|
+
if not site_region_mask.any():
|
|
304
277
|
continue
|
|
305
278
|
|
|
306
279
|
full_mask = np.zeros_like(site_mask, dtype=bool)
|
|
307
|
-
full_mask[site_mask] =
|
|
308
|
-
|
|
309
|
-
site_region = adata[ref_mask, full_mask].X
|
|
310
|
-
if hasattr(site_region, "A"):
|
|
311
|
-
site_region = site_region.A # sparse -> dense
|
|
280
|
+
full_mask[site_mask] = site_region_mask
|
|
312
281
|
|
|
313
|
-
if
|
|
314
|
-
|
|
282
|
+
if issparse(Xmat):
|
|
283
|
+
site_region = Xmat[ref_mask][:, full_mask]
|
|
284
|
+
site_region = site_region.toarray()
|
|
285
|
+
else:
|
|
286
|
+
Xnp = np.asarray(Xmat)
|
|
287
|
+
site_region = Xnp[np.asarray(ref_mask), :][:, np.asarray(full_mask)]
|
|
315
288
|
|
|
316
|
-
|
|
317
|
-
|
|
289
|
+
if site_region.shape[1] > 0:
|
|
290
|
+
new_obs_cols[sum_site_col][ref_mask] = np.nansum(site_region, axis=1)
|
|
291
|
+
new_obs_cols[mean_site_col][ref_mask] = np.nanmean(site_region, axis=1)
|
|
318
292
|
|
|
319
|
-
|
|
320
|
-
|
|
293
|
+
# one-shot assignment to avoid fragmentation
|
|
294
|
+
if new_obs_cols:
|
|
295
|
+
adata.obs = adata.obs.assign(
|
|
296
|
+
**{k: pd.Series(v, index=adata.obs.index) for k, v in new_obs_cols.items()}
|
|
297
|
+
)
|
|
321
298
|
|
|
322
|
-
#
|
|
299
|
+
# per (layer, ref) any-peak
|
|
323
300
|
any_col = f"is_in_any_{layer_name}_peak_{ref}"
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
301
|
+
if feature_peak_cols:
|
|
302
|
+
adata.var[any_col] = adata.var[feature_peak_cols].any(axis=1)
|
|
303
|
+
else:
|
|
304
|
+
adata.var[any_col] = False
|
|
305
|
+
|
|
306
|
+
logger.info(
|
|
307
|
+
"[call_hmm_peaks] Annotated %s peaks for layer '%s' in ref '%s'.",
|
|
308
|
+
len(peak_centers),
|
|
309
|
+
layer_name,
|
|
310
|
+
ref,
|
|
328
311
|
)
|
|
329
312
|
|
|
330
|
-
#
|
|
313
|
+
# global any-peak across all layers/refs
|
|
331
314
|
if all_peak_var_cols:
|
|
332
315
|
adata.var["is_in_any_peak"] = adata.var[all_peak_var_cols].any(axis=1)
|
|
333
316
|
|
smftools/hmm/display_hmm.py
CHANGED
|
@@ -1,18 +1,34 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from smftools.logging_utils import get_logger
|
|
4
|
+
from smftools.optional_imports import require
|
|
5
|
+
|
|
6
|
+
logger = get_logger(__name__)
|
|
7
|
+
|
|
8
|
+
|
|
1
9
|
def display_hmm(hmm, state_labels=["Non-Methylated", "Methylated"], obs_labels=["0", "1"]):
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
|
|
10
|
+
"""Log a summary of HMM transition and emission parameters.
|
|
11
|
+
|
|
12
|
+
Args:
|
|
13
|
+
hmm: HMM object with edges and distributions.
|
|
14
|
+
state_labels: Optional labels for states.
|
|
15
|
+
obs_labels: Optional labels for observations.
|
|
16
|
+
"""
|
|
17
|
+
torch = require("torch", extra="torch", purpose="HMM display")
|
|
18
|
+
|
|
19
|
+
logger.info("**HMM Model Overview**")
|
|
20
|
+
logger.info("%s", hmm)
|
|
5
21
|
|
|
6
|
-
|
|
22
|
+
logger.info("**Transition Matrix**")
|
|
7
23
|
transition_matrix = torch.exp(hmm.edges).detach().cpu().numpy()
|
|
8
24
|
for i, row in enumerate(transition_matrix):
|
|
9
25
|
label = state_labels[i] if state_labels else f"State {i}"
|
|
10
26
|
formatted_row = ", ".join(f"{p:.6f}" for p in row)
|
|
11
|
-
|
|
27
|
+
logger.info("%s: [%s]", label, formatted_row)
|
|
12
28
|
|
|
13
|
-
|
|
29
|
+
logger.info("**Emission Probabilities**")
|
|
14
30
|
for i, dist in enumerate(hmm.distributions):
|
|
15
31
|
label = state_labels[i] if state_labels else f"State {i}"
|
|
16
32
|
probs = dist.probs.detach().cpu().numpy()
|
|
17
33
|
formatted_emissions = {obs_labels[j]: probs[j] for j in range(len(probs))}
|
|
18
|
-
|
|
34
|
+
logger.info("%s: %s", label, formatted_emissions)
|