smftools 0.2.4__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (181)
  1. smftools/__init__.py +43 -13
  2. smftools/_settings.py +6 -6
  3. smftools/_version.py +3 -1
  4. smftools/cli/__init__.py +1 -0
  5. smftools/cli/archived/cli_flows.py +2 -0
  6. smftools/cli/helpers.py +9 -1
  7. smftools/cli/hmm_adata.py +905 -242
  8. smftools/cli/load_adata.py +432 -280
  9. smftools/cli/preprocess_adata.py +287 -171
  10. smftools/cli/spatial_adata.py +141 -53
  11. smftools/cli_entry.py +119 -178
  12. smftools/config/__init__.py +3 -1
  13. smftools/config/conversion.yaml +5 -1
  14. smftools/config/deaminase.yaml +1 -1
  15. smftools/config/default.yaml +26 -18
  16. smftools/config/direct.yaml +8 -3
  17. smftools/config/discover_input_files.py +19 -5
  18. smftools/config/experiment_config.py +511 -276
  19. smftools/constants.py +37 -0
  20. smftools/datasets/__init__.py +4 -8
  21. smftools/datasets/datasets.py +32 -18
  22. smftools/hmm/HMM.py +2133 -1428
  23. smftools/hmm/__init__.py +24 -14
  24. smftools/hmm/archived/apply_hmm_batched.py +2 -0
  25. smftools/hmm/archived/calculate_distances.py +2 -0
  26. smftools/hmm/archived/call_hmm_peaks.py +18 -1
  27. smftools/hmm/archived/train_hmm.py +2 -0
  28. smftools/hmm/call_hmm_peaks.py +176 -193
  29. smftools/hmm/display_hmm.py +23 -7
  30. smftools/hmm/hmm_readwrite.py +20 -6
  31. smftools/hmm/nucleosome_hmm_refinement.py +104 -14
  32. smftools/informatics/__init__.py +55 -13
  33. smftools/informatics/archived/bam_conversion.py +2 -0
  34. smftools/informatics/archived/bam_direct.py +2 -0
  35. smftools/informatics/archived/basecall_pod5s.py +2 -0
  36. smftools/informatics/archived/basecalls_to_adata.py +2 -0
  37. smftools/informatics/archived/conversion_smf.py +2 -0
  38. smftools/informatics/archived/deaminase_smf.py +1 -0
  39. smftools/informatics/archived/direct_smf.py +2 -0
  40. smftools/informatics/archived/fast5_to_pod5.py +2 -0
  41. smftools/informatics/archived/helpers/archived/__init__.py +2 -0
  42. smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +16 -1
  43. smftools/informatics/archived/helpers/archived/aligned_BAM_to_bed.py +2 -0
  44. smftools/informatics/archived/helpers/archived/bam_qc.py +14 -1
  45. smftools/informatics/archived/helpers/archived/bed_to_bigwig.py +2 -0
  46. smftools/informatics/archived/helpers/archived/canoncall.py +2 -0
  47. smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +8 -1
  48. smftools/informatics/archived/helpers/archived/converted_BAM_to_adata.py +2 -0
  49. smftools/informatics/archived/helpers/archived/count_aligned_reads.py +2 -0
  50. smftools/informatics/archived/helpers/archived/demux_and_index_BAM.py +2 -0
  51. smftools/informatics/archived/helpers/archived/extract_base_identities.py +2 -0
  52. smftools/informatics/archived/helpers/archived/extract_mods.py +2 -0
  53. smftools/informatics/archived/helpers/archived/extract_read_features_from_bam.py +2 -0
  54. smftools/informatics/archived/helpers/archived/extract_read_lengths_from_bed.py +2 -0
  55. smftools/informatics/archived/helpers/archived/extract_readnames_from_BAM.py +2 -0
  56. smftools/informatics/archived/helpers/archived/find_conversion_sites.py +2 -0
  57. smftools/informatics/archived/helpers/archived/generate_converted_FASTA.py +2 -0
  58. smftools/informatics/archived/helpers/archived/get_chromosome_lengths.py +2 -0
  59. smftools/informatics/archived/helpers/archived/get_native_references.py +2 -0
  60. smftools/informatics/archived/helpers/archived/index_fasta.py +2 -0
  61. smftools/informatics/archived/helpers/archived/informatics.py +2 -0
  62. smftools/informatics/archived/helpers/archived/load_adata.py +5 -3
  63. smftools/informatics/archived/helpers/archived/make_modbed.py +2 -0
  64. smftools/informatics/archived/helpers/archived/modQC.py +2 -0
  65. smftools/informatics/archived/helpers/archived/modcall.py +2 -0
  66. smftools/informatics/archived/helpers/archived/ohe_batching.py +2 -0
  67. smftools/informatics/archived/helpers/archived/ohe_layers_decode.py +2 -0
  68. smftools/informatics/archived/helpers/archived/one_hot_decode.py +2 -0
  69. smftools/informatics/archived/helpers/archived/one_hot_encode.py +2 -0
  70. smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +5 -1
  71. smftools/informatics/archived/helpers/archived/separate_bam_by_bc.py +2 -0
  72. smftools/informatics/archived/helpers/archived/split_and_index_BAM.py +2 -0
  73. smftools/informatics/archived/print_bam_query_seq.py +9 -1
  74. smftools/informatics/archived/subsample_fasta_from_bed.py +2 -0
  75. smftools/informatics/archived/subsample_pod5.py +2 -0
  76. smftools/informatics/bam_functions.py +1059 -269
  77. smftools/informatics/basecalling.py +53 -9
  78. smftools/informatics/bed_functions.py +357 -114
  79. smftools/informatics/binarize_converted_base_identities.py +21 -7
  80. smftools/informatics/complement_base_list.py +9 -6
  81. smftools/informatics/converted_BAM_to_adata.py +324 -137
  82. smftools/informatics/fasta_functions.py +251 -89
  83. smftools/informatics/h5ad_functions.py +202 -30
  84. smftools/informatics/modkit_extract_to_adata.py +623 -274
  85. smftools/informatics/modkit_functions.py +87 -44
  86. smftools/informatics/ohe.py +46 -21
  87. smftools/informatics/pod5_functions.py +114 -74
  88. smftools/informatics/run_multiqc.py +20 -14
  89. smftools/logging_utils.py +51 -0
  90. smftools/machine_learning/__init__.py +23 -12
  91. smftools/machine_learning/data/__init__.py +2 -0
  92. smftools/machine_learning/data/anndata_data_module.py +157 -50
  93. smftools/machine_learning/data/preprocessing.py +4 -1
  94. smftools/machine_learning/evaluation/__init__.py +3 -1
  95. smftools/machine_learning/evaluation/eval_utils.py +13 -14
  96. smftools/machine_learning/evaluation/evaluators.py +52 -34
  97. smftools/machine_learning/inference/__init__.py +3 -1
  98. smftools/machine_learning/inference/inference_utils.py +9 -4
  99. smftools/machine_learning/inference/lightning_inference.py +14 -13
  100. smftools/machine_learning/inference/sklearn_inference.py +8 -8
  101. smftools/machine_learning/inference/sliding_window_inference.py +37 -25
  102. smftools/machine_learning/models/__init__.py +12 -5
  103. smftools/machine_learning/models/base.py +34 -43
  104. smftools/machine_learning/models/cnn.py +22 -13
  105. smftools/machine_learning/models/lightning_base.py +78 -42
  106. smftools/machine_learning/models/mlp.py +18 -5
  107. smftools/machine_learning/models/positional.py +10 -4
  108. smftools/machine_learning/models/rnn.py +8 -3
  109. smftools/machine_learning/models/sklearn_models.py +46 -24
  110. smftools/machine_learning/models/transformer.py +75 -55
  111. smftools/machine_learning/models/wrappers.py +8 -3
  112. smftools/machine_learning/training/__init__.py +4 -2
  113. smftools/machine_learning/training/train_lightning_model.py +42 -23
  114. smftools/machine_learning/training/train_sklearn_model.py +11 -15
  115. smftools/machine_learning/utils/__init__.py +3 -1
  116. smftools/machine_learning/utils/device.py +12 -5
  117. smftools/machine_learning/utils/grl.py +8 -2
  118. smftools/metadata.py +443 -0
  119. smftools/optional_imports.py +31 -0
  120. smftools/plotting/__init__.py +32 -17
  121. smftools/plotting/autocorrelation_plotting.py +153 -48
  122. smftools/plotting/classifiers.py +175 -73
  123. smftools/plotting/general_plotting.py +350 -168
  124. smftools/plotting/hmm_plotting.py +53 -14
  125. smftools/plotting/position_stats.py +155 -87
  126. smftools/plotting/qc_plotting.py +25 -12
  127. smftools/preprocessing/__init__.py +35 -37
  128. smftools/preprocessing/append_base_context.py +105 -79
  129. smftools/preprocessing/append_binary_layer_by_base_context.py +75 -37
  130. smftools/preprocessing/{archives → archived}/add_read_length_and_mapping_qc.py +2 -0
  131. smftools/preprocessing/{archives → archived}/calculate_complexity.py +5 -1
  132. smftools/preprocessing/{archives → archived}/mark_duplicates.py +2 -0
  133. smftools/preprocessing/{archives → archived}/preprocessing.py +10 -6
  134. smftools/preprocessing/{archives → archived}/remove_duplicates.py +2 -0
  135. smftools/preprocessing/binarize.py +21 -4
  136. smftools/preprocessing/binarize_on_Youden.py +127 -31
  137. smftools/preprocessing/binary_layers_to_ohe.py +18 -11
  138. smftools/preprocessing/calculate_complexity_II.py +89 -59
  139. smftools/preprocessing/calculate_consensus.py +28 -19
  140. smftools/preprocessing/calculate_coverage.py +44 -22
  141. smftools/preprocessing/calculate_pairwise_differences.py +4 -1
  142. smftools/preprocessing/calculate_pairwise_hamming_distances.py +7 -3
  143. smftools/preprocessing/calculate_position_Youden.py +110 -55
  144. smftools/preprocessing/calculate_read_length_stats.py +52 -23
  145. smftools/preprocessing/calculate_read_modification_stats.py +91 -57
  146. smftools/preprocessing/clean_NaN.py +38 -28
  147. smftools/preprocessing/filter_adata_by_nan_proportion.py +24 -12
  148. smftools/preprocessing/filter_reads_on_length_quality_mapping.py +72 -37
  149. smftools/preprocessing/filter_reads_on_modification_thresholds.py +183 -73
  150. smftools/preprocessing/flag_duplicate_reads.py +708 -303
  151. smftools/preprocessing/invert_adata.py +26 -11
  152. smftools/preprocessing/load_sample_sheet.py +40 -22
  153. smftools/preprocessing/make_dirs.py +9 -3
  154. smftools/preprocessing/min_non_diagonal.py +4 -1
  155. smftools/preprocessing/recipes.py +58 -23
  156. smftools/preprocessing/reindex_references_adata.py +93 -27
  157. smftools/preprocessing/subsample_adata.py +33 -16
  158. smftools/readwrite.py +264 -109
  159. smftools/schema/__init__.py +11 -0
  160. smftools/schema/anndata_schema_v1.yaml +227 -0
  161. smftools/tools/__init__.py +25 -18
  162. smftools/tools/archived/apply_hmm.py +2 -0
  163. smftools/tools/archived/classifiers.py +165 -0
  164. smftools/tools/archived/classify_methylated_features.py +2 -0
  165. smftools/tools/archived/classify_non_methylated_features.py +2 -0
  166. smftools/tools/archived/subset_adata_v1.py +12 -1
  167. smftools/tools/archived/subset_adata_v2.py +14 -1
  168. smftools/tools/calculate_umap.py +56 -15
  169. smftools/tools/cluster_adata_on_methylation.py +122 -47
  170. smftools/tools/general_tools.py +70 -25
  171. smftools/tools/position_stats.py +220 -99
  172. smftools/tools/read_stats.py +50 -29
  173. smftools/tools/spatial_autocorrelation.py +365 -192
  174. smftools/tools/subset_adata.py +23 -21
  175. smftools-0.3.0.dist-info/METADATA +147 -0
  176. smftools-0.3.0.dist-info/RECORD +182 -0
  177. smftools-0.2.4.dist-info/METADATA +0 -141
  178. smftools-0.2.4.dist-info/RECORD +0 -176
  179. {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/WHEEL +0 -0
  180. {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/entry_points.txt +0 -0
  181. {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/licenses/LICENSE +0 -0
smftools/hmm/__init__.py CHANGED
@@ -1,14 +1,24 @@
-from .call_hmm_peaks import call_hmm_peaks
-from .display_hmm import display_hmm
-from .hmm_readwrite import load_hmm, save_hmm
-from .nucleosome_hmm_refinement import refine_nucleosome_calls, infer_nucleosomes_in_large_bound
-
-
-__all__ = [
-    "call_hmm_peaks",
-    "display_hmm",
-    "load_hmm",
-    "refine_nucleosome_calls",
-    "infer_nucleosomes_in_large_bound",
-    "save_hmm",
-]
+from __future__ import annotations
+
+from importlib import import_module
+
+_LAZY_ATTRS = {
+    "call_hmm_peaks": "smftools.hmm.call_hmm_peaks",
+    "display_hmm": "smftools.hmm.display_hmm",
+    "load_hmm": "smftools.hmm.hmm_readwrite",
+    "save_hmm": "smftools.hmm.hmm_readwrite",
+    "infer_nucleosomes_in_large_bound": "smftools.hmm.nucleosome_hmm_refinement",
+    "refine_nucleosome_calls": "smftools.hmm.nucleosome_hmm_refinement",
+}
+
+
+def __getattr__(name: str):
+    if name in _LAZY_ATTRS:
+        module = import_module(_LAZY_ATTRS[name])
+        attr = getattr(module, name)
+        globals()[name] = attr
+        return attr
+    raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
+
+
+__all__ = list(_LAZY_ATTRS.keys())
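
Note: the rewritten `__init__.py` swaps eager imports for the module-level `__getattr__` hook from PEP 562, so `import smftools.hmm` no longer pays for torch-backed submodules up front. A minimal standalone sketch of the same pattern (package and attribute names here are illustrative, not from smftools):

# lazy_pkg/__init__.py - PEP 562 lazy re-exports (illustrative names)
from __future__ import annotations

from importlib import import_module

# public attribute -> defining module; nothing is imported until first access
_LAZY_ATTRS = {"heavy_fn": "lazy_pkg.heavy"}


def __getattr__(name: str):
    # invoked only when `name` is not already in the module globals
    if name in _LAZY_ATTRS:
        attr = getattr(import_module(_LAZY_ATTRS[name]), name)
        globals()[name] = attr  # cache: __getattr__ fires at most once per name
        return attr
    raise AttributeError(f"module '{__name__}' has no attribute '{name}'")


__all__ = list(_LAZY_ATTRS.keys())

Call sites are unaffected: `from smftools.hmm import call_hmm_peaks` resolves through `__getattr__` on first access and through ordinary globals afterwards.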
smftools/hmm/archived/apply_hmm_batched.py CHANGED
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import numpy as np
 import pandas as pd
 import torch
smftools/hmm/archived/calculate_distances.py CHANGED
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 # calculate_distances

 def calculate_distances(intervals, threshold=0.9):
smftools/hmm/archived/call_hmm_peaks.py CHANGED
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 def call_hmm_peaks(
     adata,
     feature_configs,
@@ -8,6 +10,21 @@ def call_hmm_peaks(
     date_tag=None,
     inplace=False
 ):
+    """Call peaks from HMM feature layers and annotate AnnData.
+
+    Args:
+        adata: AnnData containing feature layers.
+        feature_configs: Mapping of layer name to peak config.
+        obs_column: Obs column for reference categories.
+        site_types: Site types to summarize around peaks.
+        save_plot: Whether to save peak plots.
+        output_dir: Output directory for plots.
+        date_tag: Optional tag for plot filenames.
+        inplace: Whether to modify AnnData in place.
+
+    Returns:
+        Annotated AnnData with peak masks and summary columns.
+    """
     import numpy as np
     import pandas as pd
     import matplotlib.pyplot as plt
@@ -103,4 +120,4 @@ def call_hmm_peaks(
     adata.var['is_in_any_peak'] = adata.var[peak_columns].any(axis=1)
     adata.obs = pd.concat([adata.obs, pd.DataFrame(obs_updates, index=adata.obs.index)], axis=1)

-    return adata if not inplace else None
+    return adata if not inplace else None
smftools/hmm/archived/train_hmm.py CHANGED
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 def train_hmm(
     data,
     emission_probs=[[0.8, 0.2], [0.2, 0.8]],
smftools/hmm/call_hmm_peaks.py CHANGED
@@ -1,5 +1,14 @@
-from typing import Dict, Optional, Any, Union, Sequence
+from __future__ import annotations
+
+# FILE: smftools/hmm/call_hmm_peaks.py
 from pathlib import Path
+from typing import Any, Dict, Optional, Sequence, Union
+
+from smftools.logging_utils import get_logger
+from smftools.optional_imports import require
+
+logger = get_logger(__name__)
+

 def call_hmm_peaks(
     adata,
@@ -14,96 +23,77 @@ def call_hmm_peaks(
     alternate_labels: bool = False,
 ):
     """
-    Call peaks on one or more HMM-derived (or other) layers and annotate adata.var / adata.obs,
-    doing peak calling *within each reference subset*.
-
-    Parameters
-    ----------
-    adata : AnnData
-        Input AnnData with layers already containing feature tracks (e.g. HMM-derived masks).
-    feature_configs : dict
-        Mapping: feature_type_or_layer_suffix -> {
-            "min_distance": int (default 200),
-            "peak_width": int (default 200),
-            "peak_prominence": float (default 0.2),
-            "peak_threshold": float (default 0.8),
-        }
-
-        Keys are usually *feature types* like "all_accessible_features" or
-        "small_bound_stretch". These are matched against existing HMM layers
-        (e.g. "GpC_all_accessible_features", "Combined_small_bound_stretch")
-        using a suffix match. You can also pass full layer names if you wish.
-    ref_column : str
-        Column in adata.obs defining reference groups (e.g. "Reference_strand").
-    site_types : sequence of str
-        Site types (without "_site"); expects var columns like f"{ref}_{site_type}_site".
-        e.g. ("GpC", "CpG") -> "6B6_top_GpC_site", etc.
-    save_plot : bool
-        If True, save peak diagnostic plots instead of just showing them.
-    output_dir : path-like or None
-        Directory for saved plots (created if needed).
-    date_tag : str or None
-        Optional tag to prefix plot filenames.
-    inplace : bool
-        If False, operate on a copy and return it. If True, modify adata and return None.
-    index_col_suffix : str or None
-        If None, coordinates come from adata.var_names (cast to int when possible).
-        If set, for each ref we use adata.var[f"{ref}_{index_col_suffix}"] as the
-        coordinate system (e.g. a reindexed coordinate).
-
-    Returns
-    -------
-    None or AnnData
+    Peak calling over HMM (or other) layers, per reference group and per layer.
+    Writes:
+      - adata.uns["{layer}_{ref}_peak_centers"] = list of centers
+      - adata.var["{layer}_{ref}_peak_{center}"] boolean window masks
+      - adata.obs per-read summaries for each peak window:
+          mean_{layer}_{ref}_around_{center}
+          sum_{layer}_{ref}_around_{center}
+          {layer}_{ref}_present_at_{center} (bool)
+        and per site-type:
+          sum_{layer}_{site}_{ref}_around_{center}
+          mean_{layer}_{site}_{ref}_around_{center}
+      - adata.var["is_in_any_{layer}_peak_{ref}"]
+      - adata.var["is_in_any_peak"] (global)
     """
     import numpy as np
     import pandas as pd
-    import matplotlib.pyplot as plt
     from scipy.signal import find_peaks
     from scipy.sparse import issparse

+    plt = require("matplotlib.pyplot", extra="plotting", purpose="HMM peak plots")
+
     if not inplace:
         adata = adata.copy()

-    # Ensure ref_column is categorical
+    if ref_column not in adata.obs:
+        raise KeyError(f"obs column '{ref_column}' not found")
+
+    # Ensure categorical for predictable ref iteration
     if not pd.api.types.is_categorical_dtype(adata.obs[ref_column]):
         adata.obs[ref_column] = adata.obs[ref_column].astype("category")

-    # Base coordinates (fallback)
+    # Optional: drop duplicate obs columns once to avoid Pandas/AnnData view quirks
+    if getattr(adata.obs.columns, "duplicated", None) is not None:
+        if adata.obs.columns.duplicated().any():
+            adata.obs = adata.obs.loc[:, ~adata.obs.columns.duplicated(keep="first")].copy()
+
+    # Fallback coordinates from var_names
     try:
         base_coordinates = adata.var_names.astype(int).values
     except Exception:
         base_coordinates = np.arange(adata.n_vars, dtype=int)

+    # Output dir
     if output_dir is not None:
         output_dir = Path(output_dir)
         output_dir.mkdir(parents=True, exist_ok=True)

-    # HMM layers known to the object (if present)
-    hmm_layers = list(adata.uns.get("hmm_appended_layers", [])) or []
-    # keep only the binary masks, not *_lengths
-    hmm_layers = [layer for layer in hmm_layers if not layer.endswith("_lengths")]
-
-    # Fallback: use all layer names if hmm_appended_layers is empty/missing
-    all_layer_names = list(adata.layers.keys())
+    # Build search pool = union of declared HMM layers and actual layers; exclude helper suffixes
+    declared = list(adata.uns.get("hmm_appended_layers", []) or [])
+    search_pool = [
+        layer
+        for layer in declared
+        if not any(s in layer for s in ("_lengths", "_states", "_posterior"))
+    ]

     all_peak_var_cols = []

-    # Iterate over each reference separately
+    # Iterate per reference
     for ref in adata.obs[ref_column].cat.categories:
         ref_mask = (adata.obs[ref_column] == ref).values
         if not ref_mask.any():
             continue

-        # Per-ref coordinates: either from a reindexed column or global fallback
+        # Per-ref coordinate system
         if index_col_suffix is not None:
             coord_col = f"{ref}_{index_col_suffix}"
             if coord_col not in adata.var:
                 raise KeyError(
-                    f"index_col_suffix='{index_col_suffix}' requested, "
-                    f"but var column '{coord_col}' is missing for ref '{ref}'."
+                    f"index_col_suffix='{index_col_suffix}' requested, missing var column '{coord_col}' for ref '{ref}'."
                 )
             coord_vals = adata.var[coord_col].values
-            # Try to coerce to numeric
             try:
                 coordinates = coord_vals.astype(int)
             except Exception:
@@ -111,184 +101,159 @@ def call_hmm_peaks(
         else:
             coordinates = base_coordinates

-        # Resolve each feature_config key to one or more actual layer names
+        if coordinates.shape[0] != adata.n_vars:
+            raise ValueError(f"Coordinate length {coordinates.shape[0]} != n_vars {adata.n_vars}")
+
+        # Feature keys to consider
         for feature_key, config in feature_configs.items():
-            # Candidate search space: HMM layers if present, else all layers
-            search_layers = hmm_layers if hmm_layers else all_layer_names
-
-            candidate_layers = []
-
-            # First: exact match
-            for lname in search_layers:
-                if lname == feature_key:
-                    candidate_layers.append(lname)
-
-            # Second: suffix match (e.g. "all_accessible_features" ->
-            # "GpC_all_accessible_features", "Combined_all_accessible_features", etc.)
-            if not candidate_layers:
-                for lname in search_layers:
-                    if lname.endswith(feature_key):
-                        candidate_layers.append(lname)
-
-            # Third: if user passed a full layer name that wasn't in hmm_layers,
-            # but does exist in adata.layers, allow it.
-            if not candidate_layers and feature_key in adata.layers:
-                candidate_layers.append(feature_key)
-
-            if not candidate_layers:
-                print(
-                    f"[call_hmm_peaks] WARNING: no layers found matching feature key "
-                    f"'{feature_key}' in ref '{ref}'. Skipping."
+            # Resolve candidate layers: exact, suffix, direct present
+            candidates = [ln for ln in search_pool if ln == feature_key]
+            if not candidates:
+                candidates = [ln for ln in search_pool if str(ln).endswith(feature_key)]
+            if not candidates and feature_key in adata.layers:
+                candidates = [feature_key]
+
+            if not candidates:
+                logger.warning(
+                    "[call_hmm_peaks] No layers found matching '%s' in ref '%s'. Skipping.",
+                    feature_key,
+                    ref,
                 )
                 continue

-            # Run peak calling on each resolved layer for this ref
-            for layer_name in candidate_layers:
+            # Hyperparams (sanitized)
+            min_distance = max(1, int(config.get("min_distance", 200)))
+            peak_width = max(1, int(config.get("peak_width", 200)))
+            peak_prom = float(config.get("peak_prominence", 0.2))
+            peak_threshold = float(config.get("peak_threshold", 0.8))
+            rolling_window = max(1, int(config.get("rolling_window", 1)))
+
+            for layer_name in candidates:
                 if layer_name not in adata.layers:
-                    print(
-                        f"[call_hmm_peaks] WARNING: resolved layer '{layer_name}' "
-                        f"not found in adata.layers; skipping."
+                    logger.warning(
+                        "[call_hmm_peaks] Layer '%s' not in adata.layers; skipping.",
+                        layer_name,
                     )
                     continue

-                min_distance = int(config.get("min_distance", 200))
-                peak_width = int(config.get("peak_width", 200))
-                peak_prominence = float(config.get("peak_prominence", 0.2))
-                peak_threshold = float(config.get("peak_threshold", 0.8))
-
-                layer_data = adata.layers[layer_name]
-                if issparse(layer_data):
-                    layer_data = layer_data.toarray()
-                else:
-                    layer_data = np.asarray(layer_data)
+                # Dense layer data
+                L = adata.layers[layer_name]
+                L = L.toarray() if issparse(L) else np.asarray(L)
+                if L.shape != (adata.n_obs, adata.n_vars):
+                    logger.warning(
+                        "[call_hmm_peaks] Layer '%s' has shape %s, expected (%s, %s); skipping.",
+                        layer_name,
+                        L.shape,
+                        adata.n_obs,
+                        adata.n_vars,
+                    )
+                    continue

-                # Subset rows for this ref
-                matrix = layer_data[ref_mask, :]  # (n_ref_reads, n_vars)
-                if matrix.shape[0] == 0:
+                # Ref subset
+                matrix = L[ref_mask, :]
+                if matrix.size == 0 or matrix.shape[0] == 0:
                     continue

-                # Mean signal along positions (within this ref only)
                 means = np.nanmean(matrix, axis=0)
+                means = np.nan_to_num(means, nan=0.0)

-                # Optional rolling-mean smoothing before peak detection
-                rolling_window = int(config.get("rolling_window", 1))
                 if rolling_window > 1:
-                    # Simple centered rolling mean via convolution
                     kernel = np.ones(rolling_window, dtype=float) / float(rolling_window)
-                    smoothed = np.convolve(means, kernel, mode="same")
-                    peak_metric = smoothed
+                    peak_metric = np.convolve(means, kernel, mode="same")
                 else:
                     peak_metric = means

                 # Peak detection
                 peak_indices, _ = find_peaks(
-                    peak_metric, prominence=peak_prominence, distance=min_distance
+                    peak_metric, prominence=peak_prom, distance=min_distance
                 )
                 if peak_indices.size == 0:
-                    print(
-                        f"[call_hmm_peaks] No peaks found for layer '{layer_name}' "
-                        f"in ref '{ref}'."
+                    logger.info(
+                        "[call_hmm_peaks] No peaks for layer '%s' in ref '%s'.",
+                        layer_name,
+                        ref,
                     )
                     continue

                 peak_centers = coordinates[peak_indices]
-                # Store per-ref peak centers
                 adata.uns[f"{layer_name}_{ref}_peak_centers"] = peak_centers.tolist()

-                # ---- Plot ----
-                plt.figure(figsize=(6, 3))
-                plt.plot(coordinates, peak_metric, linewidth=1)
-                plt.title(f"{layer_name} peaks in {ref}")
-                plt.xlabel("Coordinate")
-                plt.ylabel(f"Rolling Mean - roll size {rolling_window}")
-
+                # Plot once per layer/ref
+                fig, ax = plt.subplots(figsize=(6, 3))
+                ax.plot(coordinates, peak_metric, linewidth=1)
+                ax.set_title(f"{layer_name} peaks in {ref}")
+                ax.set_xlabel("Coordinate")
+                ax.set_ylabel(f"Rolling Mean (win={rolling_window})")
                 for i, center in enumerate(peak_centers):
                     start = center - peak_width // 2
                     end = center + peak_width // 2
                     height = peak_metric[peak_indices[i]]
-                    plt.axvspan(start, end, color="purple", alpha=0.2)
-                    plt.axvline(center, color="red", linestyle="--", linewidth=0.8)
-
-                    # alternate label placement a bit left/right
-                    if alternate_labels:
-                        if i % 2 == 0:
-                            x_text, ha = start, "right"
-                        else:
-                            x_text, ha = end, "left"
-                    else:
-                        x_text, ha = start, "right"
-
-                    plt.text(
-                        x_text,
-                        height * 0.8,
-                        f"Peak {i}\n{center}",
-                        color="red",
-                        ha=ha,
-                        va="bottom",
-                        fontsize=8,
+                    ax.axvspan(start, end, alpha=0.2)
+                    ax.axvline(center, linestyle="--", linewidth=0.8)
+                    x_text, ha = (
+                        (start, "right") if (not alternate_labels or i % 2 == 0) else (end, "left")
+                    )
+                    ax.text(
+                        x_text, height * 0.8, f"Peak {i}\n{center}", ha=ha, va="bottom", fontsize=8
                     )

                 if save_plot and output_dir is not None:
                     tag = date_tag or "output"
-                    # include ref in filename
                     safe_ref = str(ref).replace("/", "_")
                     safe_layer = str(layer_name).replace("/", "_")
                     fname = output_dir / f"{tag}_{safe_layer}_{safe_ref}_peaks.png"
-                    plt.savefig(fname, bbox_inches="tight", dpi=200)
-                    print(f"[call_hmm_peaks] Saved plot to {fname}")
-                    plt.close()
+                    fig.savefig(fname, bbox_inches="tight", dpi=200)
+                    logger.info("[call_hmm_peaks] Saved plot to %s", fname)
+                    plt.close(fig)
                 else:
-                    plt.tight_layout()
+                    fig.tight_layout()
                     plt.show()

+                # Collect new obs columns; assign once per layer/ref
+                new_obs_cols: Dict[str, np.ndarray] = {}
                 feature_peak_cols = []

-                # ---- Per-peak annotations (within this ref) ----
-                for center in peak_centers:
+                for center in np.asarray(peak_centers).tolist():
                     start = center - peak_width // 2
                     end = center + peak_width // 2

-                    # Make column names ref- and layer-specific so they don't collide
+                    # var window mask
                     colname = f"{layer_name}_{ref}_peak_{center}"
                     feature_peak_cols.append(colname)
                     all_peak_var_cols.append(colname)
-
-                    # Var-level mask: is this position in the window?
                     peak_mask = (coordinates >= start) & (coordinates <= end)
                     adata.var[colname] = peak_mask

-                    # Extract signal in that window from the *ref subset* matrix
-                    region = matrix[:, peak_mask]  # (n_ref_reads, n_positions_in_window)
+                    # feature-layer summaries for reads in this ref
+                    region = matrix[:, peak_mask]  # (n_ref, n_window)

-                    # Per-read summary in this window for the feature layer itself
                     mean_col = f"mean_{layer_name}_{ref}_around_{center}"
                     sum_col = f"sum_{layer_name}_{ref}_around_{center}"
                     present_col = f"{layer_name}_{ref}_present_at_{center}"

-                    # Create columns if missing, then fill only the ref rows
-                    if mean_col not in adata.obs:
-                        adata.obs[mean_col] = np.nan
-                    if sum_col not in adata.obs:
-                        adata.obs[sum_col] = 0.0
-                    if present_col not in adata.obs:
-                        adata.obs[present_col] = False
-
-                    adata.obs.loc[ref_mask, mean_col] = np.nanmean(region, axis=1)
-                    adata.obs.loc[ref_mask, sum_col] = np.nansum(region, axis=1)
-                    adata.obs.loc[ref_mask, present_col] = (
-                        adata.obs.loc[ref_mask, mean_col].values > peak_threshold
+                    for nm, default, dt in (
+                        (mean_col, np.nan, float),
+                        (sum_col, 0.0, float),
+                        (present_col, False, bool),
+                    ):
+                        if nm not in new_obs_cols:
+                            new_obs_cols[nm] = np.full(adata.n_obs, default, dtype=dt)
+
+                    if region.shape[1] > 0:
+                        means_per_read = np.nanmean(region, axis=1)
+                        sums_per_read = np.nansum(region, axis=1)
+                    else:
+                        means_per_read = np.full(matrix.shape[0], np.nan, dtype=float)
+                        sums_per_read = np.zeros(matrix.shape[0], dtype=float)
+
+                    new_obs_cols[mean_col][ref_mask] = means_per_read
+                    new_obs_cols[sum_col][ref_mask] = sums_per_read
+                    new_obs_cols[present_col][ref_mask] = (
+                        np.nan_to_num(means_per_read, nan=0.0) > peak_threshold
                     )

-                    # Initialize site-type summaries (global columns; filled per ref)
-                    for site_type in site_types:
-                        sum_site_col = f"{site_type}_{ref}_sum_around_{center}"
-                        mean_site_col = f"{site_type}_{ref}_mean_around_{center}"
-                        if sum_site_col not in adata.obs:
-                            adata.obs[sum_site_col] = 0.0
-                        if mean_site_col not in adata.obs:
-                            adata.obs[mean_site_col] = np.nan
-
-                    # Per-site-type summaries for this ref
+                    # site-type summaries from adata.X, not an AnnData view
+                    Xmat = adata.X
                     for site_type in site_types:
                         mask_key = f"{ref}_{site_type}_site"
                         if mask_key not in adata.var:
@@ -299,35 +264,53 @@ def call_hmm_peaks(
                            continue

                        site_coords = coordinates[site_mask]
-                        region_mask = (site_coords >= start) & (site_coords <= end)
-                        if not region_mask.any():
+                        site_region_mask = (site_coords >= start) & (site_coords <= end)
+                        sum_site_col = f"sum_{layer_name}_{site_type}_{ref}_around_{center}"
+                        mean_site_col = f"mean_{layer_name}_{site_type}_{ref}_around_{center}"
+
+                        if sum_site_col not in new_obs_cols:
+                            new_obs_cols[sum_site_col] = np.zeros(adata.n_obs, dtype=float)
+                        if mean_site_col not in new_obs_cols:
+                            new_obs_cols[mean_site_col] = np.full(adata.n_obs, np.nan, dtype=float)
+
+                        if not site_region_mask.any():
                            continue

                        full_mask = np.zeros_like(site_mask, dtype=bool)
-                        full_mask[site_mask] = region_mask
-
-                        site_region = adata[ref_mask, full_mask].X
-                        if hasattr(site_region, "A"):
-                            site_region = site_region.A  # sparse -> dense
+                        full_mask[site_mask] = site_region_mask

-                        if site_region.shape[1] == 0:
-                            continue
+                        if issparse(Xmat):
+                            site_region = Xmat[ref_mask][:, full_mask]
+                            site_region = site_region.toarray()
+                        else:
+                            Xnp = np.asarray(Xmat)
+                            site_region = Xnp[np.asarray(ref_mask), :][:, np.asarray(full_mask)]

-                        sum_site_col = f"{site_type}_{ref}_sum_around_{center}"
-                        mean_site_col = f"{site_type}_{ref}_mean_around_{center}"
+                        if site_region.shape[1] > 0:
+                            new_obs_cols[sum_site_col][ref_mask] = np.nansum(site_region, axis=1)
+                            new_obs_cols[mean_site_col][ref_mask] = np.nanmean(site_region, axis=1)

-                        adata.obs.loc[ref_mask, sum_site_col] = np.nansum(site_region, axis=1)
-                        adata.obs.loc[ref_mask, mean_site_col] = np.nanmean(site_region, axis=1)
+                # one-shot assignment to avoid fragmentation
+                if new_obs_cols:
+                    adata.obs = adata.obs.assign(
+                        **{k: pd.Series(v, index=adata.obs.index) for k, v in new_obs_cols.items()}
+                    )

-                # Mark "any peak" for this (layer, ref)
+                # per (layer, ref) any-peak
                any_col = f"is_in_any_{layer_name}_peak_{ref}"
-                adata.var[any_col] = adata.var[feature_peak_cols].any(axis=1)
-                print(
-                    f"[call_hmm_peaks] Annotated {len(peak_centers)} peaks "
-                    f"for layer '{layer_name}' in ref '{ref}'."
+                if feature_peak_cols:
+                    adata.var[any_col] = adata.var[feature_peak_cols].any(axis=1)
+                else:
+                    adata.var[any_col] = False
+
+                logger.info(
+                    "[call_hmm_peaks] Annotated %s peaks for layer '%s' in ref '%s'.",
+                    len(peak_centers),
+                    layer_name,
+                    ref,
                )

-    # Global any-peak flag across all feature layers and references
+    # global any-peak across all layers/refs
    if all_peak_var_cols:
        adata.var["is_in_any_peak"] = adata.var[all_peak_var_cols].any(axis=1)

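Note: with the reworked signature above, a call might look like the following sketch (layer, reference, and path names are hypothetical; the config keys and defaults mirror the `config.get(...)` reads in the diff):

# Hypothetical invocation of the new call_hmm_peaks (illustrative names)
from smftools.hmm import call_hmm_peaks

feature_configs = {
    # suffix-matched against layers declared in adata.uns["hmm_appended_layers"]
    "all_accessible_features": {
        "min_distance": 200,     # minimum spacing between called peaks
        "peak_width": 200,       # annotation window around each peak center
        "peak_prominence": 0.2,  # forwarded to scipy.signal.find_peaks
        "peak_threshold": 0.8,   # per-read mean above this => present_at_{center}
        "rolling_window": 5,     # centered rolling-mean smoothing before detection
    },
}

adata = call_hmm_peaks(
    adata,
    feature_configs,
    ref_column="Reference_strand",
    site_types=("GpC", "CpG"),
    save_plot=True,
    output_dir="peak_plots",
    inplace=False,
)

Per the new docstring, peak centers then land in `adata.uns["{layer}_{ref}_peak_centers"]` and per-read summaries in `adata.obs` columns such as `mean_{layer}_{ref}_around_{center}`.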
smftools/hmm/display_hmm.py CHANGED
@@ -1,18 +1,34 @@
+from __future__ import annotations
+
+from smftools.logging_utils import get_logger
+from smftools.optional_imports import require
+
+logger = get_logger(__name__)
+
+
 def display_hmm(hmm, state_labels=["Non-Methylated", "Methylated"], obs_labels=["0", "1"]):
-    import torch
-    print("\n**HMM Model Overview**")
-    print(hmm)
+    """Log a summary of HMM transition and emission parameters.
+
+    Args:
+        hmm: HMM object with edges and distributions.
+        state_labels: Optional labels for states.
+        obs_labels: Optional labels for observations.
+    """
+    torch = require("torch", extra="torch", purpose="HMM display")
+
+    logger.info("**HMM Model Overview**")
+    logger.info("%s", hmm)

-    print("\n**Transition Matrix**")
+    logger.info("**Transition Matrix**")
     transition_matrix = torch.exp(hmm.edges).detach().cpu().numpy()
     for i, row in enumerate(transition_matrix):
         label = state_labels[i] if state_labels else f"State {i}"
         formatted_row = ", ".join(f"{p:.6f}" for p in row)
-        print(f"{label}: [{formatted_row}]")
+        logger.info("%s: [%s]", label, formatted_row)

-    print("\n**Emission Probabilities**")
+    logger.info("**Emission Probabilities**")
     for i, dist in enumerate(hmm.distributions):
         label = state_labels[i] if state_labels else f"State {i}"
         probs = dist.probs.detach().cpu().numpy()
         formatted_emissions = {obs_labels[j]: probs[j] for j in range(len(probs))}
-        print(f"{label}: {formatted_emissions}")
+        logger.info("%s: %s", label, formatted_emissions)
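
Note: both files route optional dependencies through `require` from the new `smftools/optional_imports.py` (+31 lines per the file list), whose body is not shown in this diff. A plausible minimal sketch, consistent with the call sites `require("torch", extra="torch", purpose="HMM display")` and `require("matplotlib.pyplot", extra="plotting", purpose="HMM peak plots")`, would be:

# Sketch only - the real smftools.optional_imports.require is not shown in this diff
from __future__ import annotations

from importlib import import_module
from types import ModuleType


def require(module: str, *, extra: str, purpose: str) -> ModuleType:
    """Import an optional dependency, or raise with an actionable install hint."""
    try:
        return import_module(module)
    except ImportError as exc:
        raise ImportError(
            f"{purpose} requires '{module}'. Try: pip install 'smftools[{extra}]'"
        ) from exc

Deferring these imports to call time keeps `import smftools.hmm` lightweight and lets the failure message name the extra that provides the missing dependency.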