smftools 0.2.3__py3-none-any.whl → 0.2.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. smftools/_version.py +1 -1
  2. smftools/cli/helpers.py +48 -0
  3. smftools/cli/hmm_adata.py +168 -145
  4. smftools/cli/load_adata.py +155 -95
  5. smftools/cli/preprocess_adata.py +222 -130
  6. smftools/cli/spatial_adata.py +441 -308
  7. smftools/cli_entry.py +4 -5
  8. smftools/config/conversion.yaml +12 -5
  9. smftools/config/deaminase.yaml +11 -9
  10. smftools/config/default.yaml +123 -19
  11. smftools/config/direct.yaml +3 -0
  12. smftools/config/experiment_config.py +120 -19
  13. smftools/hmm/HMM.py +12 -1
  14. smftools/hmm/__init__.py +0 -6
  15. smftools/hmm/archived/call_hmm_peaks.py +106 -0
  16. smftools/hmm/call_hmm_peaks.py +318 -90
  17. smftools/informatics/bam_functions.py +28 -29
  18. smftools/informatics/h5ad_functions.py +1 -1
  19. smftools/plotting/general_plotting.py +97 -51
  20. smftools/plotting/position_stats.py +3 -3
  21. smftools/preprocessing/__init__.py +2 -4
  22. smftools/preprocessing/append_base_context.py +34 -25
  23. smftools/preprocessing/append_binary_layer_by_base_context.py +2 -2
  24. smftools/preprocessing/binarize_on_Youden.py +10 -8
  25. smftools/preprocessing/calculate_complexity_II.py +1 -1
  26. smftools/preprocessing/calculate_coverage.py +16 -13
  27. smftools/preprocessing/calculate_position_Youden.py +41 -25
  28. smftools/preprocessing/calculate_read_modification_stats.py +1 -1
  29. smftools/preprocessing/filter_reads_on_length_quality_mapping.py +1 -1
  30. smftools/preprocessing/filter_reads_on_modification_thresholds.py +1 -1
  31. smftools/preprocessing/flag_duplicate_reads.py +1 -1
  32. smftools/preprocessing/invert_adata.py +1 -1
  33. smftools/preprocessing/load_sample_sheet.py +1 -1
  34. smftools/preprocessing/reindex_references_adata.py +37 -0
  35. smftools/readwrite.py +94 -0
  36. {smftools-0.2.3.dist-info → smftools-0.2.4.dist-info}/METADATA +18 -12
  37. {smftools-0.2.3.dist-info → smftools-0.2.4.dist-info}/RECORD +46 -43
  38. /smftools/cli/{cli_flows.py → archived/cli_flows.py} +0 -0
  39. /smftools/hmm/{apply_hmm_batched.py → archived/apply_hmm_batched.py} +0 -0
  40. /smftools/hmm/{calculate_distances.py → archived/calculate_distances.py} +0 -0
  41. /smftools/hmm/{train_hmm.py → archived/train_hmm.py} +0 -0
  42. /smftools/preprocessing/{add_read_length_and_mapping_qc.py → archives/add_read_length_and_mapping_qc.py} +0 -0
  43. /smftools/preprocessing/{calculate_complexity.py → archives/calculate_complexity.py} +0 -0
  44. {smftools-0.2.3.dist-info → smftools-0.2.4.dist-info}/WHEEL +0 -0
  45. {smftools-0.2.3.dist-info → smftools-0.2.4.dist-info}/entry_points.txt +0 -0
  46. {smftools-0.2.3.dist-info → smftools-0.2.4.dist-info}/licenses/LICENSE +0 -0
smftools/hmm/HMM.py CHANGED
@@ -1,5 +1,5 @@
1
1
  import math
2
- from typing import List, Optional, Tuple, Union, Any, Dict
2
+ from typing import List, Optional, Tuple, Union, Any, Dict, Sequence
3
3
  import ast
4
4
  import json
5
5
 
@@ -772,6 +772,8 @@ class HMM(nn.Module):
772
772
  verbose: bool = True,
773
773
  uns_key: str = "hmm_appended_layers",
774
774
  config: Optional[Union[dict, "ExperimentConfig"]] = None, # NEW: config/dict accepted
775
+ uns_flag: str = "hmm_annotated",
776
+ force_redo: bool = False
775
777
  ):
776
778
  """
777
779
  Annotate an AnnData with HMM-derived features (in adata.obs and adata.layers).
@@ -793,6 +795,12 @@ class HMM(nn.Module):
793
795
  import torch as _torch
794
796
  from tqdm import trange, tqdm as _tqdm
795
797
 
798
+ # Only run if not already performed
799
+ already = bool(adata.uns.get(uns_flag, False))
800
+ if (already and not force_redo):
801
+ # QC already performed; nothing to do
802
+ return None if in_place else adata
803
+
796
804
  # small helpers
797
805
  def _try_json_or_literal(s):
798
806
  if s is None:
@@ -1298,6 +1306,9 @@ class HMM(nn.Module):
1298
1306
  new_list = existing + [l for l in appended_layers if l not in existing]
1299
1307
  adata.uns[uns_key] = new_list
1300
1308
 
1309
+ # Mark that the annotation has been completed
1310
+ adata.uns[uns_flag] = True
1311
+
1301
1312
  return None if in_place else adata
1302
1313
 
1303
1314
  def merge_intervals_in_layer(
smftools/hmm/__init__.py CHANGED
@@ -1,20 +1,14 @@
# Public API of the smftools.hmm subpackage.
#
# Only the actively maintained entry points are re-exported here; the
# batched-apply, distance, and training helpers now live under hmm/archived
# and are intentionally no longer part of the package surface.
from .call_hmm_peaks import call_hmm_peaks
from .display_hmm import display_hmm
from .hmm_readwrite import load_hmm, save_hmm
from .nucleosome_hmm_refinement import refine_nucleosome_calls, infer_nucleosomes_in_large_bound


__all__ = [
    "call_hmm_peaks",
    "display_hmm",
    "load_hmm",
    "refine_nucleosome_calls",
    "infer_nucleosomes_in_large_bound",
    "save_hmm",
]
@@ -0,0 +1,106 @@
def call_hmm_peaks(
    adata,
    feature_configs,
    obs_column='Reference_strand',
    site_types=('GpC_site', 'CpG_site'),
    save_plot=False,
    output_dir=None,
    date_tag=None,
    inplace=False
):
    """
    Call peaks on the mean signal of each configured feature layer and
    annotate ``adata.var`` / ``adata.obs`` with per-peak masks and
    per-read window summaries.

    Parameters
    ----------
    adata : AnnData
        Input AnnData containing the layers named in ``feature_configs``.
    feature_configs : dict
        Mapping layer name -> config dict with optional keys
        ``min_distance`` (default 200), ``peak_width`` (default 200),
        ``peak_prominence`` (default 0.2), ``peak_threshold`` (default 0.8).
    obs_column : str
        obs column defining reference groups; coerced to categorical.
    site_types : sequence of str
        Site-type names; expects var columns named ``f"{ref}_{site_type}"``.
    save_plot : bool
        If True (and ``output_dir`` given), save diagnostic plots instead
        of showing them interactively.
    output_dir : str or None
        Directory for saved plots.
    date_tag : str or None
        Optional filename prefix for saved plots.
    inplace : bool
        If False, operate on a copy and return it; if True, modify
        ``adata`` and return None.

    Returns
    -------
    AnnData or None
    """
    import numpy as np
    import pandas as pd
    import matplotlib.pyplot as plt
    from scipy.signal import find_peaks

    if not inplace:
        adata = adata.copy()

    # Ensure obs_column is categorical
    if not isinstance(adata.obs[obs_column].dtype, pd.CategoricalDtype):
        adata.obs[obs_column] = pd.Categorical(adata.obs[obs_column])

    coordinates = adata.var_names.astype(int).values
    peak_columns = []

    obs_updates = {}

    for feature_layer, config in feature_configs.items():
        min_distance = config.get('min_distance', 200)
        peak_width = config.get('peak_width', 200)
        peak_prominence = config.get('peak_prominence', 0.2)
        peak_threshold = config.get('peak_threshold', 0.8)

        matrix = adata.layers[feature_layer]
        means = np.mean(matrix, axis=0)
        peak_indices, _ = find_peaks(means, prominence=peak_prominence, distance=min_distance)
        peak_centers = coordinates[peak_indices]
        adata.uns[f'{feature_layer} peak_centers'] = peak_centers.tolist()

        # Diagnostic plot of the mean track with called peak windows
        plt.figure(figsize=(6, 3))
        plt.plot(coordinates, means)
        plt.title(f"{feature_layer} with peak calls")
        plt.xlabel("Genomic position")
        plt.ylabel("Mean intensity")
        for i, center in enumerate(peak_centers):
            start, end = center - peak_width // 2, center + peak_width // 2
            plt.axvspan(start, end, color='purple', alpha=0.2)
            plt.axvline(center, color='red', linestyle='--')
            # Alternate label anchors left/right so adjacent labels overlap less
            aligned = [end if i % 2 else start, 'left' if i % 2 else 'right']
            plt.text(aligned[0], 0, f"Peak {i}\n{center}", color='red', ha=aligned[1])
        if save_plot and output_dir:
            filename = f"{output_dir}/{date_tag or 'output'}_{feature_layer}_peaks.png"
            plt.savefig(filename, bbox_inches='tight')
            # BUGFIX: report the actual saved path (was a placeholder string)
            print(f"Saved plot to {filename}")
            # BUGFIX: release the figure when it is not shown interactively
            plt.close()
        else:
            plt.show()

        feature_peak_columns = []
        for center in peak_centers:
            start, end = center - peak_width // 2, center + peak_width // 2
            colname = f'{feature_layer}_peak_{center}'
            peak_columns.append(colname)
            feature_peak_columns.append(colname)

            # Var-level mask: positions inside the peak window
            peak_mask = (coordinates >= start) & (coordinates <= end)
            adata.var[colname] = peak_mask

            region = matrix[:, peak_mask]
            obs_updates[f'mean_{feature_layer}_around_{center}'] = np.mean(region, axis=1)
            obs_updates[f'sum_{feature_layer}_around_{center}'] = np.sum(region, axis=1)
            obs_updates[f'{feature_layer}_present_at_{center}'] = np.mean(region, axis=1) > peak_threshold

            for site_type in site_types:
                adata.obs[f'{site_type}_sum_around_{center}'] = 0
                adata.obs[f'{site_type}_mean_around_{center}'] = np.nan

            for ref in adata.obs[obs_column].cat.categories:
                ref_idx = adata.obs[obs_column] == ref
                for site_type in site_types:
                    # BUGFIX: build the mask key per site_type. It was previously
                    # computed once before this loop using a stale `site_type`
                    # leaked from the initialization loop above, so every
                    # site type was summarized against the same (last) mask.
                    mask_key = f"{ref}_{site_type}"
                    if mask_key not in adata.var:
                        continue
                    site_mask = adata.var[mask_key].values
                    site_coords = coordinates[site_mask]
                    region_mask = (site_coords >= start) & (site_coords <= end)
                    if not region_mask.any():
                        continue
                    # Expand the window-restricted mask back to full var length
                    full_mask = site_mask.copy()
                    full_mask[site_mask] = region_mask
                    site_region = adata[ref_idx, full_mask].X
                    if hasattr(site_region, "A"):
                        site_region = site_region.A  # sparse -> dense
                    if site_region.shape[1] > 0:
                        adata.obs.loc[ref_idx, f'{site_type}_sum_around_{center}'] = np.nansum(site_region, axis=1)
                        adata.obs.loc[ref_idx, f'{site_type}_mean_around_{center}'] = np.nanmean(site_region, axis=1)

        adata.var[f'is_in_any_{feature_layer}_peak'] = adata.var[feature_peak_columns].any(axis=1)
        print(f"Annotated {len(peak_centers)} peaks for {feature_layer}")

    adata.var['is_in_any_peak'] = adata.var[peak_columns].any(axis=1)
    adata.obs = pd.concat([adata.obs, pd.DataFrame(obs_updates, index=adata.obs.index)], axis=1)

    return adata if not inplace else None
@@ -1,106 +1,334 @@
from typing import Dict, Optional, Any, Union, Sequence
from pathlib import Path


def call_hmm_peaks(
    adata,
    feature_configs: Dict[str, Dict[str, Any]],
    ref_column: str = "Reference_strand",
    site_types: Sequence[str] = ("GpC", "CpG"),
    save_plot: bool = False,
    output_dir: Optional[Union[str, Path]] = None,
    date_tag: Optional[str] = None,
    inplace: bool = True,
    index_col_suffix: Optional[str] = None,
    alternate_labels: bool = False,
):
    """
    Call peaks on one or more HMM-derived (or other) layers and annotate adata.var / adata.obs,
    doing peak calling *within each reference subset*.

    Parameters
    ----------
    adata : AnnData
        Input AnnData with layers already containing feature tracks (e.g. HMM-derived masks).
    feature_configs : dict
        Mapping: feature_type_or_layer_suffix -> {
            "min_distance": int (default 200),
            "peak_width": int (default 200),
            "peak_prominence": float (default 0.2),
            "peak_threshold": float (default 0.8),
            "rolling_window": int (default 1; >1 enables smoothing),
        }

        Keys are usually *feature types* like "all_accessible_features" or
        "small_bound_stretch". These are matched against existing HMM layers
        (e.g. "GpC_all_accessible_features", "Combined_small_bound_stretch")
        using a suffix match. You can also pass full layer names if you wish.
    ref_column : str
        Column in adata.obs defining reference groups (e.g. "Reference_strand").
    site_types : sequence of str
        Site types (without "_site"); expects var columns like f"{ref}_{site_type}_site".
        e.g. ("GpC", "CpG") -> "6B6_top_GpC_site", etc.
    save_plot : bool
        If True, save peak diagnostic plots instead of just showing them.
    output_dir : path-like or None
        Directory for saved plots (created if needed).
    date_tag : str or None
        Optional tag to prefix plot filenames.
    inplace : bool
        If False, operate on a copy and return it. If True, modify adata and return None.
    index_col_suffix : str or None
        If None, coordinates come from adata.var_names (cast to int when possible).
        If set, for each ref we use adata.var[f"{ref}_{index_col_suffix}"] as the
        coordinate system (e.g. a reindexed coordinate).
    alternate_labels : bool
        If True, alternate peak-label placement left/right to reduce overlap.

    Returns
    -------
    None or AnnData
    """
    import numpy as np
    import pandas as pd
    import matplotlib.pyplot as plt
    from scipy.signal import find_peaks
    from scipy.sparse import issparse

    if not inplace:
        adata = adata.copy()

    # Ensure ref_column is categorical.
    # BUGFIX: pd.api.types.is_categorical_dtype is deprecated in pandas 2.x;
    # check the dtype class directly instead.
    if not isinstance(adata.obs[ref_column].dtype, pd.CategoricalDtype):
        adata.obs[ref_column] = adata.obs[ref_column].astype("category")

    # Base coordinates (fallback when var_names are not integer-like)
    try:
        base_coordinates = adata.var_names.astype(int).values
    except Exception:
        base_coordinates = np.arange(adata.n_vars, dtype=int)

    if output_dir is not None:
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

    # HMM layers known to the object (if present)
    hmm_layers = list(adata.uns.get("hmm_appended_layers", [])) or []
    # keep only the binary masks, not *_lengths
    hmm_layers = [layer for layer in hmm_layers if not layer.endswith("_lengths")]

    # Fallback: use all layer names if hmm_appended_layers is empty/missing
    all_layer_names = list(adata.layers.keys())

    all_peak_var_cols = []

    # Iterate over each reference separately
    for ref in adata.obs[ref_column].cat.categories:
        ref_mask = (adata.obs[ref_column] == ref).values
        if not ref_mask.any():
            continue

        # Per-ref coordinates: either from a reindexed column or global fallback
        if index_col_suffix is not None:
            coord_col = f"{ref}_{index_col_suffix}"
            if coord_col not in adata.var:
                raise KeyError(
                    f"index_col_suffix='{index_col_suffix}' requested, "
                    f"but var column '{coord_col}' is missing for ref '{ref}'."
                )
            coord_vals = adata.var[coord_col].values
            # Try to coerce to numeric
            try:
                coordinates = coord_vals.astype(int)
            except Exception:
                coordinates = np.asarray(coord_vals, dtype=float)
        else:
            coordinates = base_coordinates

        # Resolve each feature_config key to one or more actual layer names
        for feature_key, config in feature_configs.items():
            # Candidate search space: HMM layers if present, else all layers
            search_layers = hmm_layers if hmm_layers else all_layer_names

            candidate_layers = []

            # First: exact match
            for lname in search_layers:
                if lname == feature_key:
                    candidate_layers.append(lname)

            # Second: suffix match (e.g. "all_accessible_features" ->
            # "GpC_all_accessible_features", "Combined_all_accessible_features", etc.)
            if not candidate_layers:
                for lname in search_layers:
                    if lname.endswith(feature_key):
                        candidate_layers.append(lname)

            # Third: if user passed a full layer name that wasn't in hmm_layers,
            # but does exist in adata.layers, allow it.
            if not candidate_layers and feature_key in adata.layers:
                candidate_layers.append(feature_key)

            if not candidate_layers:
                print(
                    f"[call_hmm_peaks] WARNING: no layers found matching feature key "
                    f"'{feature_key}' in ref '{ref}'. Skipping."
                )
                continue

            # Run peak calling on each resolved layer for this ref
            for layer_name in candidate_layers:
                if layer_name not in adata.layers:
                    print(
                        f"[call_hmm_peaks] WARNING: resolved layer '{layer_name}' "
                        f"not found in adata.layers; skipping."
                    )
                    continue

                min_distance = int(config.get("min_distance", 200))
                peak_width = int(config.get("peak_width", 200))
                peak_prominence = float(config.get("peak_prominence", 0.2))
                peak_threshold = float(config.get("peak_threshold", 0.8))

                layer_data = adata.layers[layer_name]
                if issparse(layer_data):
                    layer_data = layer_data.toarray()
                else:
                    layer_data = np.asarray(layer_data)

                # Subset rows for this ref
                matrix = layer_data[ref_mask, :]  # (n_ref_reads, n_vars)
                if matrix.shape[0] == 0:
                    continue

                # Mean signal along positions (within this ref only)
                means = np.nanmean(matrix, axis=0)

                # Optional rolling-mean smoothing before peak detection
                rolling_window = int(config.get("rolling_window", 1))
                if rolling_window > 1:
                    # Simple centered rolling mean via convolution
                    kernel = np.ones(rolling_window, dtype=float) / float(rolling_window)
                    smoothed = np.convolve(means, kernel, mode="same")
                    peak_metric = smoothed
                else:
                    peak_metric = means

                # Peak detection
                peak_indices, _ = find_peaks(
                    peak_metric, prominence=peak_prominence, distance=min_distance
                )
                if peak_indices.size == 0:
                    print(
                        f"[call_hmm_peaks] No peaks found for layer '{layer_name}' "
                        f"in ref '{ref}'."
                    )
                    continue

                peak_centers = coordinates[peak_indices]
                # Store per-ref peak centers
                adata.uns[f"{layer_name}_{ref}_peak_centers"] = peak_centers.tolist()

                # ---- Plot ----
                plt.figure(figsize=(6, 3))
                plt.plot(coordinates, peak_metric, linewidth=1)
                plt.title(f"{layer_name} peaks in {ref}")
                plt.xlabel("Coordinate")
                # BUGFIX: only claim a rolling mean when smoothing was applied
                if rolling_window > 1:
                    plt.ylabel(f"Rolling Mean - roll size {rolling_window}")
                else:
                    plt.ylabel("Mean signal")

                for i, center in enumerate(peak_centers):
                    start = center - peak_width // 2
                    end = center + peak_width // 2
                    height = peak_metric[peak_indices[i]]
                    plt.axvspan(start, end, color="purple", alpha=0.2)
                    plt.axvline(center, color="red", linestyle="--", linewidth=0.8)

                    # alternate label placement a bit left/right
                    if alternate_labels:
                        if i % 2 == 0:
                            x_text, ha = start, "right"
                        else:
                            x_text, ha = end, "left"
                    else:
                        x_text, ha = start, "right"

                    plt.text(
                        x_text,
                        height * 0.8,
                        f"Peak {i}\n{center}",
                        color="red",
                        ha=ha,
                        va="bottom",
                        fontsize=8,
                    )

                if save_plot and output_dir is not None:
                    tag = date_tag or "output"
                    # include ref in filename
                    safe_ref = str(ref).replace("/", "_")
                    safe_layer = str(layer_name).replace("/", "_")
                    fname = output_dir / f"{tag}_{safe_layer}_{safe_ref}_peaks.png"
                    plt.savefig(fname, bbox_inches="tight", dpi=200)
                    print(f"[call_hmm_peaks] Saved plot to {fname}")
                    plt.close()
                else:
                    plt.tight_layout()
                    plt.show()

                feature_peak_cols = []

                # ---- Per-peak annotations (within this ref) ----
                for center in peak_centers:
                    start = center - peak_width // 2
                    end = center + peak_width // 2

                    # Make column names ref- and layer-specific so they don't collide
                    colname = f"{layer_name}_{ref}_peak_{center}"
                    feature_peak_cols.append(colname)
                    all_peak_var_cols.append(colname)

                    # Var-level mask: is this position in the window?
                    peak_mask = (coordinates >= start) & (coordinates <= end)
                    adata.var[colname] = peak_mask

                    # Extract signal in that window from the *ref subset* matrix
                    region = matrix[:, peak_mask]  # (n_ref_reads, n_positions_in_window)

                    # Per-read summary in this window for the feature layer itself
                    mean_col = f"mean_{layer_name}_{ref}_around_{center}"
                    sum_col = f"sum_{layer_name}_{ref}_around_{center}"
                    present_col = f"{layer_name}_{ref}_present_at_{center}"

                    # Create columns if missing, then fill only the ref rows
                    if mean_col not in adata.obs:
                        adata.obs[mean_col] = np.nan
                    if sum_col not in adata.obs:
                        adata.obs[sum_col] = 0.0
                    if present_col not in adata.obs:
                        adata.obs[present_col] = False

                    adata.obs.loc[ref_mask, mean_col] = np.nanmean(region, axis=1)
                    adata.obs.loc[ref_mask, sum_col] = np.nansum(region, axis=1)
                    adata.obs.loc[ref_mask, present_col] = (
                        adata.obs.loc[ref_mask, mean_col].values > peak_threshold
                    )

                    # Initialize site-type summaries (global columns; filled per ref)
                    for site_type in site_types:
                        sum_site_col = f"{site_type}_{ref}_sum_around_{center}"
                        mean_site_col = f"{site_type}_{ref}_mean_around_{center}"
                        if sum_site_col not in adata.obs:
                            adata.obs[sum_site_col] = 0.0
                        if mean_site_col not in adata.obs:
                            adata.obs[mean_site_col] = np.nan

                    # Per-site-type summaries for this ref
                    for site_type in site_types:
                        mask_key = f"{ref}_{site_type}_site"
                        if mask_key not in adata.var:
                            continue

                        site_mask = adata.var[mask_key].values.astype(bool)
                        if not site_mask.any():
                            continue

                        site_coords = coordinates[site_mask]
                        region_mask = (site_coords >= start) & (site_coords <= end)
                        if not region_mask.any():
                            continue

                        full_mask = np.zeros_like(site_mask, dtype=bool)
                        full_mask[site_mask] = region_mask

                        site_region = adata[ref_mask, full_mask].X
                        if hasattr(site_region, "A"):
                            site_region = site_region.A  # sparse -> dense

                        if site_region.shape[1] == 0:
                            continue

                        sum_site_col = f"{site_type}_{ref}_sum_around_{center}"
                        mean_site_col = f"{site_type}_{ref}_mean_around_{center}"

                        adata.obs.loc[ref_mask, sum_site_col] = np.nansum(site_region, axis=1)
                        adata.obs.loc[ref_mask, mean_site_col] = np.nanmean(site_region, axis=1)

                # Mark "any peak" for this (layer, ref)
                any_col = f"is_in_any_{layer_name}_peak_{ref}"
                adata.var[any_col] = adata.var[feature_peak_cols].any(axis=1)
                print(
                    f"[call_hmm_peaks] Annotated {len(peak_centers)} peaks "
                    f"for layer '{layer_name}' in ref '{ref}'."
                )

    # Global any-peak flag across all feature layers and references
    if all_peak_var_cols:
        adata.var["is_in_any_peak"] = adata.var[all_peak_var_cols].any(axis=1)

    return None if inplace else adata