smftools 0.2.3__py3-none-any.whl → 0.2.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- smftools/_version.py +1 -1
- smftools/cli/helpers.py +48 -0
- smftools/cli/hmm_adata.py +168 -145
- smftools/cli/load_adata.py +155 -95
- smftools/cli/preprocess_adata.py +222 -130
- smftools/cli/spatial_adata.py +441 -308
- smftools/cli_entry.py +4 -5
- smftools/config/conversion.yaml +12 -5
- smftools/config/deaminase.yaml +11 -9
- smftools/config/default.yaml +123 -19
- smftools/config/direct.yaml +3 -0
- smftools/config/experiment_config.py +120 -19
- smftools/hmm/HMM.py +12 -1
- smftools/hmm/__init__.py +0 -6
- smftools/hmm/archived/call_hmm_peaks.py +106 -0
- smftools/hmm/call_hmm_peaks.py +318 -90
- smftools/informatics/bam_functions.py +28 -29
- smftools/informatics/h5ad_functions.py +1 -1
- smftools/plotting/general_plotting.py +97 -51
- smftools/plotting/position_stats.py +3 -3
- smftools/preprocessing/__init__.py +2 -4
- smftools/preprocessing/append_base_context.py +34 -25
- smftools/preprocessing/append_binary_layer_by_base_context.py +2 -2
- smftools/preprocessing/binarize_on_Youden.py +10 -8
- smftools/preprocessing/calculate_complexity_II.py +1 -1
- smftools/preprocessing/calculate_coverage.py +16 -13
- smftools/preprocessing/calculate_position_Youden.py +41 -25
- smftools/preprocessing/calculate_read_modification_stats.py +1 -1
- smftools/preprocessing/filter_reads_on_length_quality_mapping.py +1 -1
- smftools/preprocessing/filter_reads_on_modification_thresholds.py +1 -1
- smftools/preprocessing/flag_duplicate_reads.py +1 -1
- smftools/preprocessing/invert_adata.py +1 -1
- smftools/preprocessing/load_sample_sheet.py +1 -1
- smftools/preprocessing/reindex_references_adata.py +37 -0
- smftools/readwrite.py +94 -0
- {smftools-0.2.3.dist-info → smftools-0.2.4.dist-info}/METADATA +18 -12
- {smftools-0.2.3.dist-info → smftools-0.2.4.dist-info}/RECORD +46 -43
- /smftools/cli/{cli_flows.py → archived/cli_flows.py} +0 -0
- /smftools/hmm/{apply_hmm_batched.py → archived/apply_hmm_batched.py} +0 -0
- /smftools/hmm/{calculate_distances.py → archived/calculate_distances.py} +0 -0
- /smftools/hmm/{train_hmm.py → archived/train_hmm.py} +0 -0
- /smftools/preprocessing/{add_read_length_and_mapping_qc.py → archives/add_read_length_and_mapping_qc.py} +0 -0
- /smftools/preprocessing/{calculate_complexity.py → archives/calculate_complexity.py} +0 -0
- {smftools-0.2.3.dist-info → smftools-0.2.4.dist-info}/WHEEL +0 -0
- {smftools-0.2.3.dist-info → smftools-0.2.4.dist-info}/entry_points.txt +0 -0
- {smftools-0.2.3.dist-info → smftools-0.2.4.dist-info}/licenses/LICENSE +0 -0
smftools/hmm/HMM.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import math
|
|
2
|
-
from typing import List, Optional, Tuple, Union, Any, Dict
|
|
2
|
+
from typing import List, Optional, Tuple, Union, Any, Dict, Sequence
|
|
3
3
|
import ast
|
|
4
4
|
import json
|
|
5
5
|
|
|
@@ -772,6 +772,8 @@ class HMM(nn.Module):
|
|
|
772
772
|
verbose: bool = True,
|
|
773
773
|
uns_key: str = "hmm_appended_layers",
|
|
774
774
|
config: Optional[Union[dict, "ExperimentConfig"]] = None, # NEW: config/dict accepted
|
|
775
|
+
uns_flag: str = "hmm_annotated",
|
|
776
|
+
force_redo: bool = False
|
|
775
777
|
):
|
|
776
778
|
"""
|
|
777
779
|
Annotate an AnnData with HMM-derived features (in adata.obs and adata.layers).
|
|
@@ -793,6 +795,12 @@ class HMM(nn.Module):
|
|
|
793
795
|
import torch as _torch
|
|
794
796
|
from tqdm import trange, tqdm as _tqdm
|
|
795
797
|
|
|
798
|
+
# Only run if not already performed
|
|
799
|
+
already = bool(adata.uns.get(uns_flag, False))
|
|
800
|
+
if (already and not force_redo):
|
|
801
|
+
# QC already performed; nothing to do
|
|
802
|
+
return None if in_place else adata
|
|
803
|
+
|
|
796
804
|
# small helpers
|
|
797
805
|
def _try_json_or_literal(s):
|
|
798
806
|
if s is None:
|
|
@@ -1298,6 +1306,9 @@ class HMM(nn.Module):
|
|
|
1298
1306
|
new_list = existing + [l for l in appended_layers if l not in existing]
|
|
1299
1307
|
adata.uns[uns_key] = new_list
|
|
1300
1308
|
|
|
1309
|
+
# Mark that the annotation has been completed
|
|
1310
|
+
adata.uns[uns_flag] = True
|
|
1311
|
+
|
|
1301
1312
|
return None if in_place else adata
|
|
1302
1313
|
|
|
1303
1314
|
def merge_intervals_in_layer(
|
smftools/hmm/__init__.py
CHANGED
|
@@ -1,20 +1,14 @@
|
|
|
1
|
-
from .apply_hmm_batched import apply_hmm_batched
|
|
2
|
-
from .calculate_distances import calculate_distances
|
|
3
1
|
from .call_hmm_peaks import call_hmm_peaks
|
|
4
2
|
from .display_hmm import display_hmm
|
|
5
3
|
from .hmm_readwrite import load_hmm, save_hmm
|
|
6
4
|
from .nucleosome_hmm_refinement import refine_nucleosome_calls, infer_nucleosomes_in_large_bound
|
|
7
|
-
from .train_hmm import train_hmm
|
|
8
5
|
|
|
9
6
|
|
|
10
7
|
__all__ = [
|
|
11
|
-
"apply_hmm_batched",
|
|
12
|
-
"calculate_distances",
|
|
13
8
|
"call_hmm_peaks",
|
|
14
9
|
"display_hmm",
|
|
15
10
|
"load_hmm",
|
|
16
11
|
"refine_nucleosome_calls",
|
|
17
12
|
"infer_nucleosomes_in_large_bound",
|
|
18
13
|
"save_hmm",
|
|
19
|
-
"train_hmm"
|
|
20
14
|
]
|
|
@@ -0,0 +1,106 @@
|
|
|
1
|
+
def call_hmm_peaks(
    adata,
    feature_configs,
    obs_column='Reference_strand',
    site_types=['GpC_site', 'CpG_site'],
    save_plot=False,
    output_dir=None,
    date_tag=None,
    inplace=False
):
    """
    Call peaks on the column-wise mean of each feature layer and annotate
    adata.var / adata.obs with per-peak window columns.

    Parameters
    ----------
    adata : AnnData
        Input AnnData whose .layers contain the feature tracks to scan.
    feature_configs : dict
        Mapping layer_name -> config dict with optional keys
        'min_distance' (default 200), 'peak_width' (200),
        'peak_prominence' (0.2), 'peak_threshold' (0.8).
    obs_column : str
        Obs column defining reference groups (made categorical if needed).
    site_types : list of str
        Site-type names; expects var columns named f"{ref}_{site_type}".
    save_plot : bool
        If True (and output_dir is set), save the diagnostic plot
        instead of showing it.
    output_dir : str or None
        Directory for saved plots.
    date_tag : str or None
        Optional filename prefix for saved plots.
    inplace : bool
        If False, operate on a copy and return it; if True, modify
        adata and return None.

    Returns
    -------
    AnnData or None
    """
    import numpy as np
    import pandas as pd
    import matplotlib.pyplot as plt
    from scipy.signal import find_peaks

    if not inplace:
        adata = adata.copy()

    # Ensure obs_column is categorical so .cat.categories is available below.
    if not isinstance(adata.obs[obs_column].dtype, pd.CategoricalDtype):
        adata.obs[obs_column] = pd.Categorical(adata.obs[obs_column])

    # Coordinates come from var_names; assumes they are integer-like — TODO confirm.
    coordinates = adata.var_names.astype(int).values
    peak_columns = []

    # Buffer new obs columns and concat once at the end (avoids repeated inserts).
    obs_updates = {}

    for feature_layer, config in feature_configs.items():
        min_distance = config.get('min_distance', 200)
        peak_width = config.get('peak_width', 200)
        peak_prominence = config.get('peak_prominence', 0.2)
        peak_threshold = config.get('peak_threshold', 0.8)

        matrix = adata.layers[feature_layer]
        means = np.mean(matrix, axis=0)
        peak_indices, _ = find_peaks(means, prominence=peak_prominence, distance=min_distance)
        peak_centers = coordinates[peak_indices]
        adata.uns[f'{feature_layer} peak_centers'] = peak_centers.tolist()

        # Diagnostic plot of the mean track with called peak windows.
        plt.figure(figsize=(6, 3))
        plt.plot(coordinates, means)
        plt.title(f"{feature_layer} with peak calls")
        plt.xlabel("Genomic position")
        plt.ylabel("Mean intensity")
        for i, center in enumerate(peak_centers):
            start, end = center - peak_width // 2, center + peak_width // 2
            plt.axvspan(start, end, color='purple', alpha=0.2)
            plt.axvline(center, color='red', linestyle='--')
            # Alternate label anchor left/right so adjacent labels overlap less.
            aligned = [end if i % 2 else start, 'left' if i % 2 else 'right']
            plt.text(aligned[0], 0, f"Peak {i}\n{center}", color='red', ha=aligned[1])
        if save_plot and output_dir:
            filename = f"{output_dir}/{date_tag or 'output'}_{feature_layer}_peaks.png"
            plt.savefig(filename, bbox_inches='tight')
            # BUGFIX: report the actual saved path (message previously had no placeholder).
            print(f"Saved plot to {filename}")
        else:
            plt.show()

        feature_peak_columns = []
        for center in peak_centers:
            start, end = center - peak_width // 2, center + peak_width // 2
            colname = f'{feature_layer}_peak_{center}'
            peak_columns.append(colname)
            feature_peak_columns.append(colname)

            # Var-level mask: positions inside this peak window.
            peak_mask = (coordinates >= start) & (coordinates <= end)
            adata.var[colname] = peak_mask

            region = matrix[:, peak_mask]
            obs_updates[f'mean_{feature_layer}_around_{center}'] = np.mean(region, axis=1)
            obs_updates[f'sum_{feature_layer}_around_{center}'] = np.sum(region, axis=1)
            obs_updates[f'{feature_layer}_present_at_{center}'] = np.mean(region, axis=1) > peak_threshold

            # Initialize per-site-type summary columns; filled per reference below.
            for site_type in site_types:
                adata.obs[f'{site_type}_sum_around_{center}'] = 0
                adata.obs[f'{site_type}_mean_around_{center}'] = np.nan

            for ref in adata.obs[obs_column].cat.categories:
                ref_idx = adata.obs[obs_column] == ref
                for site_type in site_types:
                    # BUGFIX: build the mask key per site_type *inside* this loop.
                    # Previously it was computed once per ref from a stale
                    # site_type left over from the initialization loop, so every
                    # site type was checked against the same var column.
                    mask_key = f"{ref}_{site_type}"
                    if mask_key not in adata.var:
                        continue
                    site_mask = adata.var[mask_key].values
                    site_coords = coordinates[site_mask]
                    region_mask = (site_coords >= start) & (site_coords <= end)
                    if not region_mask.any():
                        continue
                    # Expand the window-restricted site selection back to full var length.
                    full_mask = site_mask.copy()
                    full_mask[site_mask] = region_mask
                    site_region = adata[ref_idx, full_mask].X
                    if hasattr(site_region, "A"):
                        site_region = site_region.A  # sparse -> dense
                    if site_region.shape[1] > 0:
                        adata.obs.loc[ref_idx, f'{site_type}_sum_around_{center}'] = np.nansum(site_region, axis=1)
                        adata.obs.loc[ref_idx, f'{site_type}_mean_around_{center}'] = np.nanmean(site_region, axis=1)

        adata.var[f'is_in_any_{feature_layer}_peak'] = adata.var[feature_peak_columns].any(axis=1)
        print(f"Annotated {len(peak_centers)} peaks for {feature_layer}")

    adata.var['is_in_any_peak'] = adata.var[peak_columns].any(axis=1)
    adata.obs = pd.concat([adata.obs, pd.DataFrame(obs_updates, index=adata.obs.index)], axis=1)

    return adata if not inplace else None
|
smftools/hmm/call_hmm_peaks.py
CHANGED
|
@@ -1,106 +1,334 @@
|
|
|
1
|
+
from typing import Dict, Optional, Any, Union, Sequence
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
|
|
1
4
|
def call_hmm_peaks(
    adata,
    feature_configs: Dict[str, Dict[str, Any]],
    ref_column: str = "Reference_strand",
    site_types: Sequence[str] = ("GpC", "CpG"),
    save_plot: bool = False,
    output_dir: Optional[Union[str, "Path"]] = None,
    date_tag: Optional[str] = None,
    inplace: bool = True,
    index_col_suffix: Optional[str] = None,
    alternate_labels: bool = False,
):
    """
    Call peaks on one or more HMM-derived (or other) layers and annotate adata.var / adata.obs,
    doing peak calling *within each reference subset*.

    Parameters
    ----------
    adata : AnnData
        Input AnnData with layers already containing feature tracks (e.g. HMM-derived masks).
    feature_configs : dict
        Mapping: feature_type_or_layer_suffix -> {
            "min_distance": int (default 200),
            "peak_width": int (default 200),
            "peak_prominence": float (default 0.2),
            "peak_threshold": float (default 0.8),
            "rolling_window": int (default 1; >1 enables rolling-mean smoothing),
        }

        Keys are usually *feature types* like "all_accessible_features" or
        "small_bound_stretch". These are matched against existing HMM layers
        (e.g. "GpC_all_accessible_features", "Combined_small_bound_stretch")
        using a suffix match. You can also pass full layer names if you wish.
    ref_column : str
        Column in adata.obs defining reference groups (e.g. "Reference_strand").
    site_types : sequence of str
        Site types (without "_site"); expects var columns like f"{ref}_{site_type}_site".
        e.g. ("GpC", "CpG") -> "6B6_top_GpC_site", etc.
    save_plot : bool
        If True, save peak diagnostic plots instead of just showing them.
    output_dir : path-like or None
        Directory for saved plots (created if needed).
    date_tag : str or None
        Optional tag to prefix plot filenames.
    inplace : bool
        If False, operate on a copy and return it. If True, modify adata and return None.
    index_col_suffix : str or None
        If None, coordinates come from adata.var_names (cast to int when possible).
        If set, for each ref we use adata.var[f"{ref}_{index_col_suffix}"] as the
        coordinate system (e.g. a reindexed coordinate).
    alternate_labels : bool
        If True, alternate peak-label anchoring left/right to reduce overlap.

    Returns
    -------
    None or AnnData
    """
    import numpy as np
    import pandas as pd
    import matplotlib.pyplot as plt
    from scipy.signal import find_peaks
    from scipy.sparse import issparse

    if not inplace:
        adata = adata.copy()

    # Ensure ref_column is categorical.
    # NOTE: pd.api.types.is_categorical_dtype is deprecated (pandas >= 2.0);
    # the isinstance check is the supported replacement and matches usage
    # elsewhere in this package.
    if not isinstance(adata.obs[ref_column].dtype, pd.CategoricalDtype):
        adata.obs[ref_column] = adata.obs[ref_column].astype("category")

    # Base coordinates (fallback when no per-ref coordinate column is requested)
    try:
        base_coordinates = adata.var_names.astype(int).values
    except Exception:
        base_coordinates = np.arange(adata.n_vars, dtype=int)

    if output_dir is not None:
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

    # HMM layers known to the object (if present)
    hmm_layers = list(adata.uns.get("hmm_appended_layers", []))
    # keep only the binary masks, not *_lengths
    hmm_layers = [layer for layer in hmm_layers if not layer.endswith("_lengths")]

    # Fallback: use all layer names if hmm_appended_layers is empty/missing
    all_layer_names = list(adata.layers.keys())

    all_peak_var_cols = []

    # Iterate over each reference separately
    for ref in adata.obs[ref_column].cat.categories:
        ref_mask = (adata.obs[ref_column] == ref).values
        if not ref_mask.any():
            continue

        # Per-ref coordinates: either from a reindexed column or global fallback
        if index_col_suffix is not None:
            coord_col = f"{ref}_{index_col_suffix}"
            if coord_col not in adata.var:
                raise KeyError(
                    f"index_col_suffix='{index_col_suffix}' requested, "
                    f"but var column '{coord_col}' is missing for ref '{ref}'."
                )
            coord_vals = adata.var[coord_col].values
            # Try to coerce to numeric
            try:
                coordinates = coord_vals.astype(int)
            except Exception:
                coordinates = np.asarray(coord_vals, dtype=float)
        else:
            coordinates = base_coordinates

        # Resolve each feature_config key to one or more actual layer names
        for feature_key, config in feature_configs.items():
            # Candidate search space: HMM layers if present, else all layers
            search_layers = hmm_layers if hmm_layers else all_layer_names

            candidate_layers = []

            # First: exact match
            for lname in search_layers:
                if lname == feature_key:
                    candidate_layers.append(lname)

            # Second: suffix match (e.g. "all_accessible_features" ->
            # "GpC_all_accessible_features", "Combined_all_accessible_features", etc.)
            if not candidate_layers:
                for lname in search_layers:
                    if lname.endswith(feature_key):
                        candidate_layers.append(lname)

            # Third: if user passed a full layer name that wasn't in hmm_layers,
            # but does exist in adata.layers, allow it.
            if not candidate_layers and feature_key in adata.layers:
                candidate_layers.append(feature_key)

            if not candidate_layers:
                print(
                    f"[call_hmm_peaks] WARNING: no layers found matching feature key "
                    f"'{feature_key}' in ref '{ref}'. Skipping."
                )
                continue

            # Run peak calling on each resolved layer for this ref
            for layer_name in candidate_layers:
                if layer_name not in adata.layers:
                    print(
                        f"[call_hmm_peaks] WARNING: resolved layer '{layer_name}' "
                        f"not found in adata.layers; skipping."
                    )
                    continue

                min_distance = int(config.get("min_distance", 200))
                peak_width = int(config.get("peak_width", 200))
                peak_prominence = float(config.get("peak_prominence", 0.2))
                peak_threshold = float(config.get("peak_threshold", 0.8))

                layer_data = adata.layers[layer_name]
                if issparse(layer_data):
                    layer_data = layer_data.toarray()
                else:
                    layer_data = np.asarray(layer_data)

                # Subset rows for this ref
                matrix = layer_data[ref_mask, :]  # (n_ref_reads, n_vars)
                if matrix.shape[0] == 0:
                    continue

                # Mean signal along positions (within this ref only)
                means = np.nanmean(matrix, axis=0)

                # Optional rolling-mean smoothing before peak detection
                rolling_window = int(config.get("rolling_window", 1))
                if rolling_window > 1:
                    # Simple centered rolling mean via convolution
                    kernel = np.ones(rolling_window, dtype=float) / float(rolling_window)
                    peak_metric = np.convolve(means, kernel, mode="same")
                else:
                    peak_metric = means

                # Peak detection
                peak_indices, _ = find_peaks(
                    peak_metric, prominence=peak_prominence, distance=min_distance
                )
                if peak_indices.size == 0:
                    print(
                        f"[call_hmm_peaks] No peaks found for layer '{layer_name}' "
                        f"in ref '{ref}'."
                    )
                    continue

                peak_centers = coordinates[peak_indices]
                # Store per-ref peak centers
                adata.uns[f"{layer_name}_{ref}_peak_centers"] = peak_centers.tolist()

                # ---- Plot ----
                plt.figure(figsize=(6, 3))
                plt.plot(coordinates, peak_metric, linewidth=1)
                plt.title(f"{layer_name} peaks in {ref}")
                plt.xlabel("Coordinate")
                plt.ylabel(f"Rolling Mean - roll size {rolling_window}")

                for i, center in enumerate(peak_centers):
                    start = center - peak_width // 2
                    end = center + peak_width // 2
                    height = peak_metric[peak_indices[i]]
                    plt.axvspan(start, end, color="purple", alpha=0.2)
                    plt.axvline(center, color="red", linestyle="--", linewidth=0.8)

                    # alternate label placement a bit left/right
                    if alternate_labels:
                        if i % 2 == 0:
                            x_text, ha = start, "right"
                        else:
                            x_text, ha = end, "left"
                    else:
                        x_text, ha = start, "right"

                    plt.text(
                        x_text,
                        height * 0.8,
                        f"Peak {i}\n{center}",
                        color="red",
                        ha=ha,
                        va="bottom",
                        fontsize=8,
                    )

                if save_plot and output_dir is not None:
                    tag = date_tag or "output"
                    # include ref in filename
                    safe_ref = str(ref).replace("/", "_")
                    safe_layer = str(layer_name).replace("/", "_")
                    fname = output_dir / f"{tag}_{safe_layer}_{safe_ref}_peaks.png"
                    plt.savefig(fname, bbox_inches="tight", dpi=200)
                    print(f"[call_hmm_peaks] Saved plot to {fname}")
                    plt.close()
                else:
                    plt.tight_layout()
                    plt.show()

                feature_peak_cols = []

                # ---- Per-peak annotations (within this ref) ----
                for center in peak_centers:
                    start = center - peak_width // 2
                    end = center + peak_width // 2

                    # Make column names ref- and layer-specific so they don't collide
                    colname = f"{layer_name}_{ref}_peak_{center}"
                    feature_peak_cols.append(colname)
                    all_peak_var_cols.append(colname)

                    # Var-level mask: is this position in the window?
                    peak_mask = (coordinates >= start) & (coordinates <= end)
                    adata.var[colname] = peak_mask

                    # Extract signal in that window from the *ref subset* matrix
                    region = matrix[:, peak_mask]  # (n_ref_reads, n_positions_in_window)

                    # Per-read summary in this window for the feature layer itself
                    mean_col = f"mean_{layer_name}_{ref}_around_{center}"
                    sum_col = f"sum_{layer_name}_{ref}_around_{center}"
                    present_col = f"{layer_name}_{ref}_present_at_{center}"

                    # Create columns if missing, then fill only the ref rows
                    if mean_col not in adata.obs:
                        adata.obs[mean_col] = np.nan
                    if sum_col not in adata.obs:
                        adata.obs[sum_col] = 0.0
                    if present_col not in adata.obs:
                        adata.obs[present_col] = False

                    adata.obs.loc[ref_mask, mean_col] = np.nanmean(region, axis=1)
                    adata.obs.loc[ref_mask, sum_col] = np.nansum(region, axis=1)
                    adata.obs.loc[ref_mask, present_col] = (
                        adata.obs.loc[ref_mask, mean_col].values > peak_threshold
                    )

                    # Initialize site-type summaries (global columns; filled per ref)
                    for site_type in site_types:
                        sum_site_col = f"{site_type}_{ref}_sum_around_{center}"
                        mean_site_col = f"{site_type}_{ref}_mean_around_{center}"
                        if sum_site_col not in adata.obs:
                            adata.obs[sum_site_col] = 0.0
                        if mean_site_col not in adata.obs:
                            adata.obs[mean_site_col] = np.nan

                    # Per-site-type summaries for this ref
                    for site_type in site_types:
                        mask_key = f"{ref}_{site_type}_site"
                        if mask_key not in adata.var:
                            continue

                        site_mask = adata.var[mask_key].values.astype(bool)
                        if not site_mask.any():
                            continue

                        site_coords = coordinates[site_mask]
                        region_mask = (site_coords >= start) & (site_coords <= end)
                        if not region_mask.any():
                            continue

                        full_mask = np.zeros_like(site_mask, dtype=bool)
                        full_mask[site_mask] = region_mask

                        site_region = adata[ref_mask, full_mask].X
                        # MODERNIZED: the sparse ".A" shortcut is deprecated/removed
                        # in recent SciPy; use issparse + toarray() instead.
                        if issparse(site_region):
                            site_region = site_region.toarray()  # sparse -> dense

                        if site_region.shape[1] == 0:
                            continue

                        sum_site_col = f"{site_type}_{ref}_sum_around_{center}"
                        mean_site_col = f"{site_type}_{ref}_mean_around_{center}"

                        adata.obs.loc[ref_mask, sum_site_col] = np.nansum(site_region, axis=1)
                        adata.obs.loc[ref_mask, mean_site_col] = np.nanmean(site_region, axis=1)

                # Mark "any peak" for this (layer, ref)
                any_col = f"is_in_any_{layer_name}_peak_{ref}"
                adata.var[any_col] = adata.var[feature_peak_cols].any(axis=1)
                print(
                    f"[call_hmm_peaks] Annotated {len(peak_centers)} peaks "
                    f"for layer '{layer_name}' in ref '{ref}'."
                )

    # Global any-peak flag across all feature layers and references
    if all_peak_var_cols:
        adata.var["is_in_any_peak"] = adata.var[all_peak_var_cols].any(axis=1)

    return None if inplace else adata
|