smftools 0.2.3__py3-none-any.whl → 0.2.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- smftools/__init__.py +6 -8
- smftools/_settings.py +4 -6
- smftools/_version.py +1 -1
- smftools/cli/helpers.py +54 -0
- smftools/cli/hmm_adata.py +937 -256
- smftools/cli/load_adata.py +448 -268
- smftools/cli/preprocess_adata.py +469 -263
- smftools/cli/spatial_adata.py +536 -319
- smftools/cli_entry.py +97 -182
- smftools/config/__init__.py +1 -1
- smftools/config/conversion.yaml +17 -6
- smftools/config/deaminase.yaml +12 -10
- smftools/config/default.yaml +142 -33
- smftools/config/direct.yaml +11 -3
- smftools/config/discover_input_files.py +19 -5
- smftools/config/experiment_config.py +594 -264
- smftools/constants.py +37 -0
- smftools/datasets/__init__.py +2 -8
- smftools/datasets/datasets.py +32 -18
- smftools/hmm/HMM.py +2128 -1418
- smftools/hmm/__init__.py +2 -9
- smftools/hmm/archived/call_hmm_peaks.py +121 -0
- smftools/hmm/call_hmm_peaks.py +299 -91
- smftools/hmm/display_hmm.py +19 -6
- smftools/hmm/hmm_readwrite.py +13 -4
- smftools/hmm/nucleosome_hmm_refinement.py +102 -14
- smftools/informatics/__init__.py +30 -7
- smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +14 -1
- smftools/informatics/archived/helpers/archived/bam_qc.py +14 -1
- smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +8 -1
- smftools/informatics/archived/helpers/archived/load_adata.py +3 -3
- smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +3 -1
- smftools/informatics/archived/print_bam_query_seq.py +7 -1
- smftools/informatics/bam_functions.py +397 -175
- smftools/informatics/basecalling.py +51 -9
- smftools/informatics/bed_functions.py +90 -57
- smftools/informatics/binarize_converted_base_identities.py +18 -7
- smftools/informatics/complement_base_list.py +7 -6
- smftools/informatics/converted_BAM_to_adata.py +265 -122
- smftools/informatics/fasta_functions.py +161 -83
- smftools/informatics/h5ad_functions.py +196 -30
- smftools/informatics/modkit_extract_to_adata.py +609 -270
- smftools/informatics/modkit_functions.py +85 -44
- smftools/informatics/ohe.py +44 -21
- smftools/informatics/pod5_functions.py +112 -73
- smftools/informatics/run_multiqc.py +20 -14
- smftools/logging_utils.py +51 -0
- smftools/machine_learning/__init__.py +2 -7
- smftools/machine_learning/data/anndata_data_module.py +143 -50
- smftools/machine_learning/data/preprocessing.py +2 -1
- smftools/machine_learning/evaluation/__init__.py +1 -1
- smftools/machine_learning/evaluation/eval_utils.py +11 -14
- smftools/machine_learning/evaluation/evaluators.py +46 -33
- smftools/machine_learning/inference/__init__.py +1 -1
- smftools/machine_learning/inference/inference_utils.py +7 -4
- smftools/machine_learning/inference/lightning_inference.py +9 -13
- smftools/machine_learning/inference/sklearn_inference.py +6 -8
- smftools/machine_learning/inference/sliding_window_inference.py +35 -25
- smftools/machine_learning/models/__init__.py +10 -5
- smftools/machine_learning/models/base.py +28 -42
- smftools/machine_learning/models/cnn.py +15 -11
- smftools/machine_learning/models/lightning_base.py +71 -40
- smftools/machine_learning/models/mlp.py +13 -4
- smftools/machine_learning/models/positional.py +3 -2
- smftools/machine_learning/models/rnn.py +3 -2
- smftools/machine_learning/models/sklearn_models.py +39 -22
- smftools/machine_learning/models/transformer.py +68 -53
- smftools/machine_learning/models/wrappers.py +2 -1
- smftools/machine_learning/training/__init__.py +2 -2
- smftools/machine_learning/training/train_lightning_model.py +29 -20
- smftools/machine_learning/training/train_sklearn_model.py +9 -15
- smftools/machine_learning/utils/__init__.py +1 -1
- smftools/machine_learning/utils/device.py +7 -4
- smftools/machine_learning/utils/grl.py +3 -1
- smftools/metadata.py +443 -0
- smftools/plotting/__init__.py +19 -5
- smftools/plotting/autocorrelation_plotting.py +145 -44
- smftools/plotting/classifiers.py +162 -72
- smftools/plotting/general_plotting.py +422 -197
- smftools/plotting/hmm_plotting.py +42 -13
- smftools/plotting/position_stats.py +147 -87
- smftools/plotting/qc_plotting.py +20 -12
- smftools/preprocessing/__init__.py +10 -12
- smftools/preprocessing/append_base_context.py +115 -80
- smftools/preprocessing/append_binary_layer_by_base_context.py +77 -39
- smftools/preprocessing/{calculate_complexity.py → archived/calculate_complexity.py} +3 -1
- smftools/preprocessing/{archives → archived}/preprocessing.py +8 -6
- smftools/preprocessing/binarize.py +21 -4
- smftools/preprocessing/binarize_on_Youden.py +129 -31
- smftools/preprocessing/binary_layers_to_ohe.py +17 -11
- smftools/preprocessing/calculate_complexity_II.py +86 -59
- smftools/preprocessing/calculate_consensus.py +28 -19
- smftools/preprocessing/calculate_coverage.py +50 -25
- smftools/preprocessing/calculate_pairwise_differences.py +2 -1
- smftools/preprocessing/calculate_pairwise_hamming_distances.py +4 -3
- smftools/preprocessing/calculate_position_Youden.py +118 -54
- smftools/preprocessing/calculate_read_length_stats.py +52 -23
- smftools/preprocessing/calculate_read_modification_stats.py +91 -57
- smftools/preprocessing/clean_NaN.py +38 -28
- smftools/preprocessing/filter_adata_by_nan_proportion.py +24 -12
- smftools/preprocessing/filter_reads_on_length_quality_mapping.py +71 -38
- smftools/preprocessing/filter_reads_on_modification_thresholds.py +181 -73
- smftools/preprocessing/flag_duplicate_reads.py +689 -272
- smftools/preprocessing/invert_adata.py +26 -11
- smftools/preprocessing/load_sample_sheet.py +40 -22
- smftools/preprocessing/make_dirs.py +8 -3
- smftools/preprocessing/min_non_diagonal.py +2 -1
- smftools/preprocessing/recipes.py +56 -23
- smftools/preprocessing/reindex_references_adata.py +103 -0
- smftools/preprocessing/subsample_adata.py +33 -16
- smftools/readwrite.py +331 -82
- smftools/schema/__init__.py +11 -0
- smftools/schema/anndata_schema_v1.yaml +227 -0
- smftools/tools/__init__.py +3 -4
- smftools/tools/archived/classifiers.py +163 -0
- smftools/tools/archived/subset_adata_v1.py +10 -1
- smftools/tools/archived/subset_adata_v2.py +12 -1
- smftools/tools/calculate_umap.py +54 -15
- smftools/tools/cluster_adata_on_methylation.py +115 -46
- smftools/tools/general_tools.py +70 -25
- smftools/tools/position_stats.py +229 -98
- smftools/tools/read_stats.py +50 -29
- smftools/tools/spatial_autocorrelation.py +365 -192
- smftools/tools/subset_adata.py +23 -21
- {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/METADATA +17 -39
- smftools-0.2.5.dist-info/RECORD +181 -0
- smftools-0.2.3.dist-info/RECORD +0 -173
- /smftools/cli/{cli_flows.py → archived/cli_flows.py} +0 -0
- /smftools/hmm/{apply_hmm_batched.py → archived/apply_hmm_batched.py} +0 -0
- /smftools/hmm/{calculate_distances.py → archived/calculate_distances.py} +0 -0
- /smftools/hmm/{train_hmm.py → archived/train_hmm.py} +0 -0
- /smftools/preprocessing/{add_read_length_and_mapping_qc.py → archived/add_read_length_and_mapping_qc.py} +0 -0
- /smftools/preprocessing/{archives → archived}/mark_duplicates.py +0 -0
- /smftools/preprocessing/{archives → archived}/remove_duplicates.py +0 -0
- {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/WHEEL +0 -0
- {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/entry_points.txt +0 -0
- {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/licenses/LICENSE +0 -0
smftools/hmm/__init__.py
CHANGED
|
@@ -1,20 +1,13 @@
|
|
|
1
|
-
from .apply_hmm_batched import apply_hmm_batched
|
|
2
|
-
from .calculate_distances import calculate_distances
|
|
3
1
|
from .call_hmm_peaks import call_hmm_peaks
|
|
4
2
|
from .display_hmm import display_hmm
|
|
5
3
|
from .hmm_readwrite import load_hmm, save_hmm
|
|
6
|
-
from .nucleosome_hmm_refinement import
|
|
7
|
-
from .train_hmm import train_hmm
|
|
8
|
-
|
|
4
|
+
from .nucleosome_hmm_refinement import infer_nucleosomes_in_large_bound, refine_nucleosome_calls
|
|
9
5
|
|
|
10
6
|
__all__ = [
|
|
11
|
-
"apply_hmm_batched",
|
|
12
|
-
"calculate_distances",
|
|
13
7
|
"call_hmm_peaks",
|
|
14
8
|
"display_hmm",
|
|
15
9
|
"load_hmm",
|
|
16
10
|
"refine_nucleosome_calls",
|
|
17
11
|
"infer_nucleosomes_in_large_bound",
|
|
18
12
|
"save_hmm",
|
|
19
|
-
|
|
20
|
-
]
|
|
13
|
+
]
|
|
@@ -0,0 +1,121 @@
|
|
|
1
|
+
def call_hmm_peaks(
    adata,
    feature_configs,
    obs_column='Reference_strand',
    site_types=('GpC_site', 'CpG_site'),
    save_plot=False,
    output_dir=None,
    date_tag=None,
    inplace=False
):
    """Call peaks from HMM feature layers and annotate AnnData.

    For each layer named in ``feature_configs`` the per-position mean over all
    reads is passed to ``scipy.signal.find_peaks``. Every detected peak gets a
    boolean window mask in ``adata.var`` and per-read summary columns in
    ``adata.obs``.

    Args:
        adata: AnnData containing feature layers. ``var_names`` must be
            castable to ``int``; they are used as coordinates.
        feature_configs: Mapping of layer name to peak config. Recognised keys
            (with defaults): ``min_distance`` (200), ``peak_width`` (200),
            ``peak_prominence`` (0.2), ``peak_threshold`` (0.8).
        obs_column: Obs column for reference categories.
        site_types: Site types to summarize around peaks. For each reference
            ``ref`` the boolean var column ``f"{ref}_{site_type}"`` is used
            when present.
        save_plot: Whether to save peak plots.
        output_dir: Output directory for plots.
        date_tag: Optional tag for plot filenames.
        inplace: Whether to modify AnnData in place.

    Returns:
        Annotated AnnData with peak masks and summary columns when
        ``inplace`` is False, otherwise ``None``.
    """
    import matplotlib.pyplot as plt
    import numpy as np
    import pandas as pd
    from scipy.signal import find_peaks
    from scipy.sparse import issparse

    if not inplace:
        adata = adata.copy()

    # Ensure obs_column is categorical
    if not isinstance(adata.obs[obs_column].dtype, pd.CategoricalDtype):
        adata.obs[obs_column] = pd.Categorical(adata.obs[obs_column])

    coordinates = adata.var_names.astype(int).values
    peak_columns = []

    # Per-read layer summaries are collected here and concatenated once at the end.
    obs_updates = {}

    for feature_layer, config in feature_configs.items():
        min_distance = config.get('min_distance', 200)
        peak_width = config.get('peak_width', 200)
        peak_prominence = config.get('peak_prominence', 0.2)
        peak_threshold = config.get('peak_threshold', 0.8)
        half_width = peak_width // 2

        matrix = adata.layers[feature_layer]
        means = np.mean(matrix, axis=0)
        peak_indices, _ = find_peaks(means, prominence=peak_prominence, distance=min_distance)
        peak_centers = coordinates[peak_indices]
        adata.uns[f'{feature_layer} peak_centers'] = peak_centers.tolist()

        # Plot
        fig = plt.figure(figsize=(6, 3))
        plt.plot(coordinates, means)
        plt.title(f"{feature_layer} with peak calls")
        plt.xlabel("Genomic position")
        plt.ylabel("Mean intensity")
        for i, center in enumerate(peak_centers):
            start, end = center - half_width, center + half_width
            plt.axvspan(start, end, color='purple', alpha=0.2)
            plt.axvline(center, color='red', linestyle='--')
            # Alternate label side so neighbouring peak labels do not overlap.
            text_x, text_ha = (end, 'left') if i % 2 else (start, 'right')
            plt.text(text_x, 0, f"Peak {i}\n{center}", color='red', ha=text_ha)
        if save_plot and output_dir:
            filename = f"{output_dir}/{date_tag or 'output'}_{feature_layer}_peaks.png"
            plt.savefig(filename, bbox_inches='tight')
            print(f"Saved plot to {filename}")
            # Release the figure; otherwise one figure per layer stays open.
            plt.close(fig)
        else:
            plt.show()

        feature_peak_columns = []
        for center in peak_centers:
            start, end = center - half_width, center + half_width
            colname = f'{feature_layer}_peak_{center}'
            peak_columns.append(colname)
            feature_peak_columns.append(colname)

            peak_mask = (coordinates >= start) & (coordinates <= end)
            adata.var[colname] = peak_mask

            region = matrix[:, peak_mask]
            region_mean = np.mean(region, axis=1)
            obs_updates[f'mean_{feature_layer}_around_{center}'] = region_mean
            obs_updates[f'sum_{feature_layer}_around_{center}'] = np.sum(region, axis=1)
            obs_updates[f'{feature_layer}_present_at_{center}'] = region_mean > peak_threshold

            # Initialise as float so the later per-reference float assignment
            # does not have to upcast an integer column.
            for site_type in site_types:
                adata.obs[f'{site_type}_sum_around_{center}'] = 0.0
                adata.obs[f'{site_type}_mean_around_{center}'] = np.nan

            for ref in adata.obs[obs_column].cat.categories:
                ref_idx = adata.obs[obs_column] == ref
                for site_type in site_types:
                    # The mask key must be rebuilt for every site type; building
                    # it outside this loop reuses whichever site type was last
                    # visited by the initialisation loop above.
                    mask_key = f"{ref}_{site_type}"
                    if mask_key not in adata.var:
                        continue
                    site_mask = adata.var[mask_key].values.astype(bool)
                    site_coords = coordinates[site_mask]
                    region_mask = (site_coords >= start) & (site_coords <= end)
                    if not region_mask.any():
                        continue
                    # Expand the within-site mask back to full var length.
                    full_mask = site_mask.copy()
                    full_mask[site_mask] = region_mask
                    site_region = adata[ref_idx, full_mask].X
                    if issparse(site_region):
                        site_region = site_region.toarray()
                    if site_region.shape[1] > 0:
                        adata.obs.loc[ref_idx, f'{site_type}_sum_around_{center}'] = np.nansum(site_region, axis=1)
                        adata.obs.loc[ref_idx, f'{site_type}_mean_around_{center}'] = np.nanmean(site_region, axis=1)

        adata.var[f'is_in_any_{feature_layer}_peak'] = adata.var[feature_peak_columns].any(axis=1)
        print(f"Annotated {len(peak_centers)} peaks for {feature_layer}")

    adata.var['is_in_any_peak'] = adata.var[peak_columns].any(axis=1)
    adata.obs = pd.concat([adata.obs, pd.DataFrame(obs_updates, index=adata.obs.index)], axis=1)

    return adata if not inplace else None
|
smftools/hmm/call_hmm_peaks.py
CHANGED
|
@@ -1,106 +1,314 @@
|
|
|
1
|
+
# FILE: smftools/hmm/call_hmm_peaks.py
|
|
2
|
+
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import Any, Dict, Optional, Sequence, Union
|
|
5
|
+
|
|
6
|
+
from smftools.logging_utils import get_logger
|
|
7
|
+
|
|
8
|
+
logger = get_logger(__name__)
|
|
9
|
+
|
|
10
|
+
|
|
1
11
|
def call_hmm_peaks(
    adata,
    feature_configs: Dict[str, Dict[str, Any]],
    ref_column: str = "Reference_strand",
    site_types: Sequence[str] = ("GpC", "CpG"),
    save_plot: bool = False,
    output_dir: Optional[Union[str, "Path"]] = None,
    date_tag: Optional[str] = None,
    inplace: bool = True,
    index_col_suffix: Optional[str] = None,
    alternate_labels: bool = False,
):
    """
    Peak calling over HMM (or other) layers, per reference group and per layer.

    For every category of ``adata.obs[ref_column]`` and every layer resolved
    from ``feature_configs``, the per-position NaN-aware mean over that
    reference's reads is optionally smoothed with a rolling mean and passed to
    ``scipy.signal.find_peaks``. One plot is produced per (layer, reference)
    that has at least one peak.

    Layer resolution for each ``feature_configs`` key: exact match among
    ``adata.uns["hmm_appended_layers"]`` (entries containing ``_lengths``,
    ``_states`` or ``_posterior`` are ignored), then suffix match in that same
    list, then the key itself if it is present in ``adata.layers``.

    Args:
        adata: AnnData-like object providing ``obs``, ``var``, ``uns``,
            ``layers`` and ``X``.
        feature_configs: Mapping of feature key to a config dict. Recognised
            keys (with defaults): ``min_distance`` (200), ``peak_width`` (200),
            ``peak_prominence`` (0.2), ``peak_threshold`` (0.8),
            ``rolling_window`` (1).
        ref_column: Obs column holding the reference label of each read.
        site_types: Site types to summarise from ``adata.X``; the boolean var
            column ``f"{ref}_{site_type}_site"`` is used when present.
        save_plot: Save plots to ``output_dir`` instead of showing them.
        output_dir: Directory for saved plots; created if missing.
        date_tag: Filename prefix for saved plots (``"output"`` when omitted).
        inplace: Modify ``adata`` in place (returns ``None``) or work on a copy.
        index_col_suffix: When given, coordinates for reference ``ref`` are
            read from ``adata.var[f"{ref}_{index_col_suffix}"]`` instead of
            ``adata.var_names``.
        alternate_labels: Alternate peak labels between the left and right
            window edge in the plot.

    Returns:
        ``None`` when ``inplace`` is True, otherwise the annotated copy.

    Raises:
        KeyError: If ``ref_column`` is missing from ``adata.obs`` or a requested
            per-reference coordinate column is missing from ``adata.var``.
        ValueError: If the coordinate vector length differs from ``n_vars``.

    Writes:
      - adata.uns["{layer}_{ref}_peak_centers"] = list of centers
      - adata.var["{layer}_{ref}_peak_{center}"] boolean window masks
      - adata.obs per-read summaries for each peak window:
            mean_{layer}_{ref}_around_{center}
            sum_{layer}_{ref}_around_{center}
            {layer}_{ref}_present_at_{center} (bool)
        and per site-type:
            sum_{layer}_{site}_{ref}_around_{center}
            mean_{layer}_{site}_{ref}_around_{center}
      - adata.var["is_in_any_{layer}_peak_{ref}"]
      - adata.var["is_in_any_peak"] (global)
    """
    import matplotlib.pyplot as plt
    import numpy as np
    import pandas as pd
    from scipy.signal import find_peaks
    from scipy.sparse import issparse

    if not inplace:
        adata = adata.copy()

    if ref_column not in adata.obs:
        raise KeyError(f"obs column '{ref_column}' not found")

    # Ensure categorical for predictable ref iteration.
    # (isinstance check replaces the deprecated pd.api.types.is_categorical_dtype.)
    if not isinstance(adata.obs[ref_column].dtype, pd.CategoricalDtype):
        adata.obs[ref_column] = adata.obs[ref_column].astype("category")

    # Drop duplicate obs columns once to avoid Pandas/AnnData view quirks
    duplicated_cols = adata.obs.columns.duplicated(keep="first")
    if duplicated_cols.any():
        adata.obs = adata.obs.loc[:, ~duplicated_cols].copy()

    # Fallback coordinates from var_names; positional indices if they are not integers
    try:
        base_coordinates = adata.var_names.astype(int).values
    except (TypeError, ValueError, OverflowError):
        base_coordinates = np.arange(adata.n_vars, dtype=int)

    # Output dir
    if output_dir is not None:
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

    # Search pool = declared HMM layers, excluding helper suffixes. Layers that
    # exist in adata.layers but were not declared are only picked up through the
    # direct feature_key fallback further down.
    declared = list(adata.uns.get("hmm_appended_layers", []) or [])
    helper_markers = ("_lengths", "_states", "_posterior")
    search_pool = [
        layer for layer in declared if not any(marker in layer for marker in helper_markers)
    ]

    all_peak_var_cols = []

    # Iterate per reference
    for ref in adata.obs[ref_column].cat.categories:
        ref_mask = (adata.obs[ref_column] == ref).values
        if not ref_mask.any():
            continue

        # Per-ref coordinate system
        if index_col_suffix is not None:
            coord_col = f"{ref}_{index_col_suffix}"
            if coord_col not in adata.var:
                raise KeyError(
                    f"index_col_suffix='{index_col_suffix}' requested, missing var column '{coord_col}' for ref '{ref}'."
                )
            coord_vals = adata.var[coord_col].values
            try:
                coordinates = coord_vals.astype(int)
            except (TypeError, ValueError, OverflowError):
                coordinates = np.asarray(coord_vals, dtype=float)
        else:
            coordinates = base_coordinates

        if coordinates.shape[0] != adata.n_vars:
            raise ValueError(f"Coordinate length {coordinates.shape[0]} != n_vars {adata.n_vars}")

        # Rows of adata.X for this reference; sliced lazily on first use and then
        # reused for every layer/peak/site type instead of re-slicing each time.
        x_ref = None

        # Feature keys to consider
        for feature_key, config in feature_configs.items():
            # Resolve candidate layers: exact → suffix → direct present
            candidates = [ln for ln in search_pool if ln == feature_key]
            if not candidates:
                candidates = [ln for ln in search_pool if str(ln).endswith(feature_key)]
            if not candidates and feature_key in adata.layers:
                candidates = [feature_key]

            if not candidates:
                logger.warning(
                    "[call_hmm_peaks] No layers found matching '%s' in ref '%s'. Skipping.",
                    feature_key,
                    ref,
                )
                continue

            # Hyperparams (sanitized)
            min_distance = max(1, int(config.get("min_distance", 200)))
            peak_width = max(1, int(config.get("peak_width", 200)))
            peak_prom = float(config.get("peak_prominence", 0.2))
            peak_threshold = float(config.get("peak_threshold", 0.8))
            rolling_window = max(1, int(config.get("rolling_window", 1)))
            half_width = peak_width // 2

            for layer_name in candidates:
                if layer_name not in adata.layers:
                    logger.warning(
                        "[call_hmm_peaks] Layer '%s' not in adata.layers; skipping.",
                        layer_name,
                    )
                    continue

                # Dense layer data
                L = adata.layers[layer_name]
                L = L.toarray() if issparse(L) else np.asarray(L)
                if L.shape != (adata.n_obs, adata.n_vars):
                    logger.warning(
                        "[call_hmm_peaks] Layer '%s' has shape %s, expected (%s, %s); skipping.",
                        layer_name,
                        L.shape,
                        adata.n_obs,
                        adata.n_vars,
                    )
                    continue

                # Ref subset
                matrix = L[ref_mask, :]
                if matrix.size == 0 or matrix.shape[0] == 0:
                    continue

                means = np.nanmean(matrix, axis=0)
                means = np.nan_to_num(means, nan=0.0)

                if rolling_window > 1:
                    kernel = np.ones(rolling_window, dtype=float) / float(rolling_window)
                    peak_metric = np.convolve(means, kernel, mode="same")
                else:
                    peak_metric = means

                # Peak detection
                peak_indices, _ = find_peaks(
                    peak_metric, prominence=peak_prom, distance=min_distance
                )
                if peak_indices.size == 0:
                    logger.info(
                        "[call_hmm_peaks] No peaks for layer '%s' in ref '%s'.",
                        layer_name,
                        ref,
                    )
                    continue

                peak_centers = coordinates[peak_indices]
                adata.uns[f"{layer_name}_{ref}_peak_centers"] = peak_centers.tolist()

                # Plot once per layer/ref
                fig, ax = plt.subplots(figsize=(6, 3))
                ax.plot(coordinates, peak_metric, linewidth=1)
                ax.set_title(f"{layer_name} peaks in {ref}")
                ax.set_xlabel("Coordinate")
                ax.set_ylabel(f"Rolling Mean (win={rolling_window})")
                for i, center in enumerate(peak_centers):
                    start = center - half_width
                    end = center + half_width
                    height = peak_metric[peak_indices[i]]
                    ax.axvspan(start, end, alpha=0.2)
                    ax.axvline(center, linestyle="--", linewidth=0.8)
                    x_text, ha = (
                        (start, "right") if (not alternate_labels or i % 2 == 0) else (end, "left")
                    )
                    ax.text(
                        x_text, height * 0.8, f"Peak {i}\n{center}", ha=ha, va="bottom", fontsize=8
                    )

                if save_plot and output_dir is not None:
                    tag = date_tag or "output"
                    safe_ref = str(ref).replace("/", "_")
                    safe_layer = str(layer_name).replace("/", "_")
                    fname = output_dir / f"{tag}_{safe_layer}_{safe_ref}_peaks.png"
                    fig.savefig(fname, bbox_inches="tight", dpi=200)
                    logger.info("[call_hmm_peaks] Saved plot to %s", fname)
                    plt.close(fig)
                else:
                    fig.tight_layout()
                    plt.show()
                    # With a non-interactive backend show() does not consume the
                    # figure, so close it to avoid accumulating one open figure
                    # per layer/ref. Interactive sessions keep their window.
                    if not plt.isinteractive():
                        plt.close(fig)

                # Collect new obs columns; assign once per layer/ref
                new_obs_cols: Dict[str, np.ndarray] = {}
                feature_peak_cols = []

                for center in np.asarray(peak_centers).tolist():
                    start = center - half_width
                    end = center + half_width

                    # var window mask
                    colname = f"{layer_name}_{ref}_peak_{center}"
                    feature_peak_cols.append(colname)
                    all_peak_var_cols.append(colname)
                    peak_mask = (coordinates >= start) & (coordinates <= end)
                    adata.var[colname] = peak_mask

                    # feature-layer summaries for reads in this ref
                    region = matrix[:, peak_mask]  # (n_ref, n_window)

                    mean_col = f"mean_{layer_name}_{ref}_around_{center}"
                    sum_col = f"sum_{layer_name}_{ref}_around_{center}"
                    present_col = f"{layer_name}_{ref}_present_at_{center}"

                    # Reads from other references keep these defaults.
                    for nm, default, dt in (
                        (mean_col, np.nan, float),
                        (sum_col, 0.0, float),
                        (present_col, False, bool),
                    ):
                        if nm not in new_obs_cols:
                            new_obs_cols[nm] = np.full(adata.n_obs, default, dtype=dt)

                    if region.shape[1] > 0:
                        means_per_read = np.nanmean(region, axis=1)
                        sums_per_read = np.nansum(region, axis=1)
                    else:
                        means_per_read = np.full(matrix.shape[0], np.nan, dtype=float)
                        sums_per_read = np.zeros(matrix.shape[0], dtype=float)

                    new_obs_cols[mean_col][ref_mask] = means_per_read
                    new_obs_cols[sum_col][ref_mask] = sums_per_read
                    new_obs_cols[present_col][ref_mask] = (
                        np.nan_to_num(means_per_read, nan=0.0) > peak_threshold
                    )

                    # site-type summaries from adata.X, not an AnnData view
                    for site_type in site_types:
                        mask_key = f"{ref}_{site_type}_site"
                        if mask_key not in adata.var:
                            continue

                        site_mask = adata.var[mask_key].values.astype(bool)
                        if not site_mask.any():
                            continue

                        site_coords = coordinates[site_mask]
                        site_region_mask = (site_coords >= start) & (site_coords <= end)
                        sum_site_col = f"sum_{layer_name}_{site_type}_{ref}_around_{center}"
                        mean_site_col = f"mean_{layer_name}_{site_type}_{ref}_around_{center}"

                        if sum_site_col not in new_obs_cols:
                            new_obs_cols[sum_site_col] = np.zeros(adata.n_obs, dtype=float)
                        if mean_site_col not in new_obs_cols:
                            new_obs_cols[mean_site_col] = np.full(adata.n_obs, np.nan, dtype=float)

                        if not site_region_mask.any():
                            continue

                        # Expand the within-site mask back to full var length.
                        full_mask = np.zeros_like(site_mask, dtype=bool)
                        full_mask[site_mask] = site_region_mask

                        if x_ref is None:
                            Xmat = adata.X
                            if issparse(Xmat):
                                x_ref = Xmat[ref_mask]
                            else:
                                x_ref = np.asarray(Xmat)[ref_mask, :]

                        site_region = x_ref[:, full_mask]
                        if issparse(site_region):
                            site_region = site_region.toarray()

                        if site_region.shape[1] > 0:
                            new_obs_cols[sum_site_col][ref_mask] = np.nansum(site_region, axis=1)
                            new_obs_cols[mean_site_col][ref_mask] = np.nanmean(site_region, axis=1)

                # one-shot assignment to avoid fragmentation
                if new_obs_cols:
                    adata.obs = adata.obs.assign(
                        **{k: pd.Series(v, index=adata.obs.index) for k, v in new_obs_cols.items()}
                    )

                # per (layer, ref) any-peak; at least one peak exists at this point
                any_col = f"is_in_any_{layer_name}_peak_{ref}"
                adata.var[any_col] = adata.var[feature_peak_cols].any(axis=1)

                logger.info(
                    "[call_hmm_peaks] Annotated %s peaks for layer '%s' in ref '%s'.",
                    len(peak_centers),
                    layer_name,
                    ref,
                )

    # global any-peak across all layers/refs
    if all_peak_var_cols:
        adata.var["is_in_any_peak"] = adata.var[all_peak_var_cols].any(axis=1)

    return None if inplace else adata
|
smftools/hmm/display_hmm.py
CHANGED
|
@@ -1,18 +1,31 @@
|
|
|
1
|
+
from smftools.logging_utils import get_logger
|
|
2
|
+
|
|
3
|
+
logger = get_logger(__name__)
|
|
4
|
+
|
|
5
|
+
|
|
1
6
|
def display_hmm(hmm, state_labels=("Non-Methylated", "Methylated"), obs_labels=("0", "1")):
    """Log a summary of HMM transition and emission parameters.

    Args:
        hmm: HMM object with edges and distributions. ``hmm.edges`` is passed
            through ``torch.exp`` before display, and each entry of
            ``hmm.distributions`` must expose a ``probs`` tensor.
        state_labels: Optional labels for states. ``None``, an empty sequence,
            or a sequence shorter than the number of states falls back to
            ``"State {i}"`` for the unlabeled states.
        obs_labels: Optional labels for observations. Missing labels fall back
            to ``"Obs {j}"``.
    """
    import torch

    def _label(labels, index, prefix):
        # Use the caller-supplied label when one exists for this index.
        if labels and index < len(labels):
            return labels[index]
        return f"{prefix} {index}"

    logger.info("**HMM Model Overview**")
    logger.info("%s", hmm)

    logger.info("**Transition Matrix**")
    transition_matrix = torch.exp(hmm.edges).detach().cpu().numpy()
    for i, row in enumerate(transition_matrix):
        label = _label(state_labels, i, "State")
        formatted_row = ", ".join(f"{p:.6f}" for p in row)
        logger.info("%s: [%s]", label, formatted_row)

    logger.info("**Emission Probabilities**")
    for i, dist in enumerate(hmm.distributions):
        label = _label(state_labels, i, "State")
        probs = dist.probs.detach().cpu().numpy()
        formatted_emissions = {
            _label(obs_labels, j, "Obs"): prob for j, prob in enumerate(probs)
        }
        logger.info("%s: %s", label, formatted_emissions)
|
smftools/hmm/hmm_readwrite.py
CHANGED
|
@@ -1,16 +1,25 @@
|
|
|
1
|
-
def load_hmm(model_path, device="cpu"):
    """
    Reads in a pretrained HMM.

    Parameters:
        model_path (str): Path to a pretrained HMM
        device (str or torch.device): Device the loaded model is mapped and
            moved to. Defaults to "cpu".

    Returns:
        The deserialized model object, after calling ``.to(device)`` on it.
    """
    import torch

    # Load model using PyTorch.
    # map_location lets a model that was saved on an accelerator be loaded on a
    # host that lacks it. weights_only=False is needed because a whole pickled
    # model object is loaded, not just a state dict; newer torch releases reject
    # that by default.
    # NOTE(security): full unpickling can execute arbitrary code, so only load
    # model files from trusted sources.
    hmm = torch.load(model_path, map_location=device, weights_only=False)
    hmm.to(device)
    return hmm
|
|
13
14
|
|
|
15
|
+
|
|
14
16
|
def save_hmm(model, model_path):
    """Save a pretrained HMM to disk.

    Args:
        model: HMM model instance.
        model_path: Output path for the model. When this is a string or
            path-like object, missing parent directories are created first.
            Any other value (e.g. an open binary file object) is handed to
            ``torch.save`` unchanged.
    """
    import os
    from pathlib import Path

    import torch

    # Create the destination folder up front so saving into a fresh output
    # directory does not fail with FileNotFoundError.
    if isinstance(model_path, (str, os.PathLike)):
        Path(model_path).parent.mkdir(parents=True, exist_ok=True)

    torch.save(model, model_path)
|