smftools 0.2.4__py3-none-any.whl → 0.2.5__py3-none-any.whl

Files changed (133)
  1. smftools/__init__.py +6 -8
  2. smftools/_settings.py +4 -6
  3. smftools/_version.py +1 -1
  4. smftools/cli/helpers.py +7 -1
  5. smftools/cli/hmm_adata.py +902 -244
  6. smftools/cli/load_adata.py +318 -198
  7. smftools/cli/preprocess_adata.py +285 -171
  8. smftools/cli/spatial_adata.py +137 -53
  9. smftools/cli_entry.py +94 -178
  10. smftools/config/__init__.py +1 -1
  11. smftools/config/conversion.yaml +5 -1
  12. smftools/config/deaminase.yaml +1 -1
  13. smftools/config/default.yaml +22 -17
  14. smftools/config/direct.yaml +8 -3
  15. smftools/config/discover_input_files.py +19 -5
  16. smftools/config/experiment_config.py +505 -276
  17. smftools/constants.py +37 -0
  18. smftools/datasets/__init__.py +2 -8
  19. smftools/datasets/datasets.py +32 -18
  20. smftools/hmm/HMM.py +2125 -1426
  21. smftools/hmm/__init__.py +2 -3
  22. smftools/hmm/archived/call_hmm_peaks.py +16 -1
  23. smftools/hmm/call_hmm_peaks.py +173 -193
  24. smftools/hmm/display_hmm.py +19 -6
  25. smftools/hmm/hmm_readwrite.py +13 -4
  26. smftools/hmm/nucleosome_hmm_refinement.py +102 -14
  27. smftools/informatics/__init__.py +30 -7
  28. smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +14 -1
  29. smftools/informatics/archived/helpers/archived/bam_qc.py +14 -1
  30. smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +8 -1
  31. smftools/informatics/archived/helpers/archived/load_adata.py +3 -3
  32. smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +3 -1
  33. smftools/informatics/archived/print_bam_query_seq.py +7 -1
  34. smftools/informatics/bam_functions.py +379 -156
  35. smftools/informatics/basecalling.py +51 -9
  36. smftools/informatics/bed_functions.py +90 -57
  37. smftools/informatics/binarize_converted_base_identities.py +18 -7
  38. smftools/informatics/complement_base_list.py +7 -6
  39. smftools/informatics/converted_BAM_to_adata.py +265 -122
  40. smftools/informatics/fasta_functions.py +161 -83
  41. smftools/informatics/h5ad_functions.py +195 -29
  42. smftools/informatics/modkit_extract_to_adata.py +609 -270
  43. smftools/informatics/modkit_functions.py +85 -44
  44. smftools/informatics/ohe.py +44 -21
  45. smftools/informatics/pod5_functions.py +112 -73
  46. smftools/informatics/run_multiqc.py +20 -14
  47. smftools/logging_utils.py +51 -0
  48. smftools/machine_learning/__init__.py +2 -7
  49. smftools/machine_learning/data/anndata_data_module.py +143 -50
  50. smftools/machine_learning/data/preprocessing.py +2 -1
  51. smftools/machine_learning/evaluation/__init__.py +1 -1
  52. smftools/machine_learning/evaluation/eval_utils.py +11 -14
  53. smftools/machine_learning/evaluation/evaluators.py +46 -33
  54. smftools/machine_learning/inference/__init__.py +1 -1
  55. smftools/machine_learning/inference/inference_utils.py +7 -4
  56. smftools/machine_learning/inference/lightning_inference.py +9 -13
  57. smftools/machine_learning/inference/sklearn_inference.py +6 -8
  58. smftools/machine_learning/inference/sliding_window_inference.py +35 -25
  59. smftools/machine_learning/models/__init__.py +10 -5
  60. smftools/machine_learning/models/base.py +28 -42
  61. smftools/machine_learning/models/cnn.py +15 -11
  62. smftools/machine_learning/models/lightning_base.py +71 -40
  63. smftools/machine_learning/models/mlp.py +13 -4
  64. smftools/machine_learning/models/positional.py +3 -2
  65. smftools/machine_learning/models/rnn.py +3 -2
  66. smftools/machine_learning/models/sklearn_models.py +39 -22
  67. smftools/machine_learning/models/transformer.py +68 -53
  68. smftools/machine_learning/models/wrappers.py +2 -1
  69. smftools/machine_learning/training/__init__.py +2 -2
  70. smftools/machine_learning/training/train_lightning_model.py +29 -20
  71. smftools/machine_learning/training/train_sklearn_model.py +9 -15
  72. smftools/machine_learning/utils/__init__.py +1 -1
  73. smftools/machine_learning/utils/device.py +7 -4
  74. smftools/machine_learning/utils/grl.py +3 -1
  75. smftools/metadata.py +443 -0
  76. smftools/plotting/__init__.py +19 -5
  77. smftools/plotting/autocorrelation_plotting.py +145 -44
  78. smftools/plotting/classifiers.py +162 -72
  79. smftools/plotting/general_plotting.py +347 -168
  80. smftools/plotting/hmm_plotting.py +42 -13
  81. smftools/plotting/position_stats.py +145 -85
  82. smftools/plotting/qc_plotting.py +20 -12
  83. smftools/preprocessing/__init__.py +8 -8
  84. smftools/preprocessing/append_base_context.py +105 -79
  85. smftools/preprocessing/append_binary_layer_by_base_context.py +75 -37
  86. smftools/preprocessing/{archives → archived}/calculate_complexity.py +3 -1
  87. smftools/preprocessing/{archives → archived}/preprocessing.py +8 -6
  88. smftools/preprocessing/binarize.py +21 -4
  89. smftools/preprocessing/binarize_on_Youden.py +127 -31
  90. smftools/preprocessing/binary_layers_to_ohe.py +17 -11
  91. smftools/preprocessing/calculate_complexity_II.py +86 -59
  92. smftools/preprocessing/calculate_consensus.py +28 -19
  93. smftools/preprocessing/calculate_coverage.py +44 -22
  94. smftools/preprocessing/calculate_pairwise_differences.py +2 -1
  95. smftools/preprocessing/calculate_pairwise_hamming_distances.py +4 -3
  96. smftools/preprocessing/calculate_position_Youden.py +103 -55
  97. smftools/preprocessing/calculate_read_length_stats.py +52 -23
  98. smftools/preprocessing/calculate_read_modification_stats.py +91 -57
  99. smftools/preprocessing/clean_NaN.py +38 -28
  100. smftools/preprocessing/filter_adata_by_nan_proportion.py +24 -12
  101. smftools/preprocessing/filter_reads_on_length_quality_mapping.py +70 -37
  102. smftools/preprocessing/filter_reads_on_modification_thresholds.py +181 -73
  103. smftools/preprocessing/flag_duplicate_reads.py +688 -271
  104. smftools/preprocessing/invert_adata.py +26 -11
  105. smftools/preprocessing/load_sample_sheet.py +40 -22
  106. smftools/preprocessing/make_dirs.py +8 -3
  107. smftools/preprocessing/min_non_diagonal.py +2 -1
  108. smftools/preprocessing/recipes.py +56 -23
  109. smftools/preprocessing/reindex_references_adata.py +93 -27
  110. smftools/preprocessing/subsample_adata.py +33 -16
  111. smftools/readwrite.py +264 -109
  112. smftools/schema/__init__.py +11 -0
  113. smftools/schema/anndata_schema_v1.yaml +227 -0
  114. smftools/tools/__init__.py +3 -4
  115. smftools/tools/archived/classifiers.py +163 -0
  116. smftools/tools/archived/subset_adata_v1.py +10 -1
  117. smftools/tools/archived/subset_adata_v2.py +12 -1
  118. smftools/tools/calculate_umap.py +54 -15
  119. smftools/tools/cluster_adata_on_methylation.py +115 -46
  120. smftools/tools/general_tools.py +70 -25
  121. smftools/tools/position_stats.py +229 -98
  122. smftools/tools/read_stats.py +50 -29
  123. smftools/tools/spatial_autocorrelation.py +365 -192
  124. smftools/tools/subset_adata.py +23 -21
  125. {smftools-0.2.4.dist-info → smftools-0.2.5.dist-info}/METADATA +15 -43
  126. smftools-0.2.5.dist-info/RECORD +181 -0
  127. smftools-0.2.4.dist-info/RECORD +0 -176
  128. /smftools/preprocessing/{archives → archived}/add_read_length_and_mapping_qc.py +0 -0
  129. /smftools/preprocessing/{archives → archived}/mark_duplicates.py +0 -0
  130. /smftools/preprocessing/{archives → archived}/remove_duplicates.py +0 -0
  131. {smftools-0.2.4.dist-info → smftools-0.2.5.dist-info}/WHEEL +0 -0
  132. {smftools-0.2.4.dist-info → smftools-0.2.5.dist-info}/entry_points.txt +0 -0
  133. {smftools-0.2.4.dist-info → smftools-0.2.5.dist-info}/licenses/LICENSE +0 -0
smftools/hmm/__init__.py CHANGED
```diff
@@ -1,8 +1,7 @@
 from .call_hmm_peaks import call_hmm_peaks
 from .display_hmm import display_hmm
 from .hmm_readwrite import load_hmm, save_hmm
-from .nucleosome_hmm_refinement import refine_nucleosome_calls, infer_nucleosomes_in_large_bound
-
+from .nucleosome_hmm_refinement import infer_nucleosomes_in_large_bound, refine_nucleosome_calls
 
 __all__ = [
     "call_hmm_peaks",
@@ -11,4 +10,4 @@ __all__ = [
     "refine_nucleosome_calls",
     "infer_nucleosomes_in_large_bound",
     "save_hmm",
-]
+]
```
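The change only re-sorts the `nucleosome_hmm_refinement` import, so the public surface of `smftools.hmm` is unchanged. A quick import check, assuming the 0.2.5 wheel is installed:

```python
# All names re-exported by smftools/hmm/__init__.py remain importable.
from smftools.hmm import (
    call_hmm_peaks,
    display_hmm,
    infer_nucleosomes_in_large_bound,
    load_hmm,
    refine_nucleosome_calls,
    save_hmm,
)
```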
smftools/hmm/archived/call_hmm_peaks.py CHANGED
```diff
@@ -8,6 +8,21 @@ def call_hmm_peaks(
     date_tag=None,
     inplace=False
 ):
+    """Call peaks from HMM feature layers and annotate AnnData.
+
+    Args:
+        adata: AnnData containing feature layers.
+        feature_configs: Mapping of layer name to peak config.
+        obs_column: Obs column for reference categories.
+        site_types: Site types to summarize around peaks.
+        save_plot: Whether to save peak plots.
+        output_dir: Output directory for plots.
+        date_tag: Optional tag for plot filenames.
+        inplace: Whether to modify AnnData in place.
+
+    Returns:
+        Annotated AnnData with peak masks and summary columns.
+    """
     import numpy as np
     import pandas as pd
     import matplotlib.pyplot as plt
@@ -103,4 +118,4 @@ def call_hmm_peaks(
     adata.var['is_in_any_peak'] = adata.var[peak_columns].any(axis=1)
     adata.obs = pd.concat([adata.obs, pd.DataFrame(obs_updates, index=adata.obs.index)], axis=1)
 
-    return adata if not inplace else None
+    return adata if not inplace else None
```
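The archived variant gains a Google-style docstring but keeps its behavior. A minimal call sketch matching that docstring; the input file and layer key are placeholders, and the config defaults are the ones documented in the rewritten module below:

```python
# Sketch only: the archived module's import path follows the file list above,
# but whether `archived` is an importable subpackage is not shown in this diff.
from smftools.hmm.archived.call_hmm_peaks import call_hmm_peaks

import anndata as ad

adata = ad.read_h5ad("smf.h5ad")  # placeholder input

feature_configs = {
    "GpC_all_accessible_features": {  # placeholder layer name
        "min_distance": 200,      # minimum spacing between called peaks
        "peak_width": 200,        # window annotated around each peak center
        "peak_prominence": 0.2,   # passed to scipy.signal.find_peaks
        "peak_threshold": 0.8,    # per-read mean required to mark a peak present
    },
}

annotated = call_hmm_peaks(adata, feature_configs, inplace=False)  # None if inplace=True
```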
smftools/hmm/call_hmm_peaks.py CHANGED
```diff
@@ -1,5 +1,12 @@
-from typing import Dict, Optional, Any, Union, Sequence
+# FILE: smftools/hmm/call_hmm_peaks.py
+
 from pathlib import Path
+from typing import Any, Dict, Optional, Sequence, Union
+
+from smftools.logging_utils import get_logger
+
+logger = get_logger(__name__)
+
 
 def call_hmm_peaks(
     adata,
@@ -14,96 +21,76 @@ def call_hmm_peaks(
     alternate_labels: bool = False,
 ):
     """
-    Call peaks on one or more HMM-derived (or other) layers and annotate adata.var / adata.obs,
-    doing peak calling *within each reference subset*.
-
-    Parameters
-    ----------
-    adata : AnnData
-        Input AnnData with layers already containing feature tracks (e.g. HMM-derived masks).
-    feature_configs : dict
-        Mapping: feature_type_or_layer_suffix -> {
-            "min_distance": int (default 200),
-            "peak_width": int (default 200),
-            "peak_prominence": float (default 0.2),
-            "peak_threshold": float (default 0.8),
-        }
-
-        Keys are usually *feature types* like "all_accessible_features" or
-        "small_bound_stretch". These are matched against existing HMM layers
-        (e.g. "GpC_all_accessible_features", "Combined_small_bound_stretch")
-        using a suffix match. You can also pass full layer names if you wish.
-    ref_column : str
-        Column in adata.obs defining reference groups (e.g. "Reference_strand").
-    site_types : sequence of str
-        Site types (without "_site"); expects var columns like f"{ref}_{site_type}_site".
-        e.g. ("GpC", "CpG") -> "6B6_top_GpC_site", etc.
-    save_plot : bool
-        If True, save peak diagnostic plots instead of just showing them.
-    output_dir : path-like or None
-        Directory for saved plots (created if needed).
-    date_tag : str or None
-        Optional tag to prefix plot filenames.
-    inplace : bool
-        If False, operate on a copy and return it. If True, modify adata and return None.
-    index_col_suffix : str or None
-        If None, coordinates come from adata.var_names (cast to int when possible).
-        If set, for each ref we use adata.var[f"{ref}_{index_col_suffix}"] as the
-        coordinate system (e.g. a reindexed coordinate).
-
-    Returns
-    -------
-    None or AnnData
+    Peak calling over HMM (or other) layers, per reference group and per layer.
+    Writes:
+      - adata.uns["{layer}_{ref}_peak_centers"] = list of centers
+      - adata.var["{layer}_{ref}_peak_{center}"] boolean window masks
+      - adata.obs per-read summaries for each peak window:
+            mean_{layer}_{ref}_around_{center}
+            sum_{layer}_{ref}_around_{center}
+            {layer}_{ref}_present_at_{center} (bool)
+        and per site-type:
+            sum_{layer}_{site}_{ref}_around_{center}
+            mean_{layer}_{site}_{ref}_around_{center}
+      - adata.var["is_in_any_{layer}_peak_{ref}"]
+      - adata.var["is_in_any_peak"] (global)
     """
+    import matplotlib.pyplot as plt
     import numpy as np
     import pandas as pd
-    import matplotlib.pyplot as plt
     from scipy.signal import find_peaks
     from scipy.sparse import issparse
 
     if not inplace:
         adata = adata.copy()
 
-    # Ensure ref_column is categorical
+    if ref_column not in adata.obs:
+        raise KeyError(f"obs column '{ref_column}' not found")
+
+    # Ensure categorical for predictable ref iteration
     if not pd.api.types.is_categorical_dtype(adata.obs[ref_column]):
         adata.obs[ref_column] = adata.obs[ref_column].astype("category")
 
-    # Base coordinates (fallback)
+    # Optional: drop duplicate obs columns once to avoid Pandas/AnnData view quirks
+    if getattr(adata.obs.columns, "duplicated", None) is not None:
+        if adata.obs.columns.duplicated().any():
+            adata.obs = adata.obs.loc[:, ~adata.obs.columns.duplicated(keep="first")].copy()
+
+    # Fallback coordinates from var_names
     try:
         base_coordinates = adata.var_names.astype(int).values
     except Exception:
         base_coordinates = np.arange(adata.n_vars, dtype=int)
 
+    # Output dir
     if output_dir is not None:
         output_dir = Path(output_dir)
         output_dir.mkdir(parents=True, exist_ok=True)
 
-    # HMM layers known to the object (if present)
-    hmm_layers = list(adata.uns.get("hmm_appended_layers", [])) or []
-    # keep only the binary masks, not *_lengths
-    hmm_layers = [layer for layer in hmm_layers if not layer.endswith("_lengths")]
-
-    # Fallback: use all layer names if hmm_appended_layers is empty/missing
-    all_layer_names = list(adata.layers.keys())
+    # Build search pool = union of declared HMM layers and actual layers; exclude helper suffixes
+    declared = list(adata.uns.get("hmm_appended_layers", []) or [])
+    search_pool = [
+        layer
+        for layer in declared
+        if not any(s in layer for s in ("_lengths", "_states", "_posterior"))
+    ]
 
     all_peak_var_cols = []
 
-    # Iterate over each reference separately
+    # Iterate per reference
     for ref in adata.obs[ref_column].cat.categories:
         ref_mask = (adata.obs[ref_column] == ref).values
         if not ref_mask.any():
             continue
 
-        # Per-ref coordinates: either from a reindexed column or global fallback
+        # Per-ref coordinate system
        if index_col_suffix is not None:
             coord_col = f"{ref}_{index_col_suffix}"
             if coord_col not in adata.var:
                 raise KeyError(
-                    f"index_col_suffix='{index_col_suffix}' requested, "
-                    f"but var column '{coord_col}' is missing for ref '{ref}'."
+                    f"index_col_suffix='{index_col_suffix}' requested, missing var column '{coord_col}' for ref '{ref}'."
                 )
             coord_vals = adata.var[coord_col].values
-            # Try to coerce to numeric
             try:
                 coordinates = coord_vals.astype(int)
             except Exception:
@@ -111,184 +98,159 @@ def call_hmm_peaks(
         else:
             coordinates = base_coordinates
 
-        # Resolve each feature_config key to one or more actual layer names
+        if coordinates.shape[0] != adata.n_vars:
+            raise ValueError(f"Coordinate length {coordinates.shape[0]} != n_vars {adata.n_vars}")
+
+        # Feature keys to consider
         for feature_key, config in feature_configs.items():
-            # Candidate search space: HMM layers if present, else all layers
-            search_layers = hmm_layers if hmm_layers else all_layer_names
-
-            candidate_layers = []
-
-            # First: exact match
-            for lname in search_layers:
-                if lname == feature_key:
-                    candidate_layers.append(lname)
-
-            # Second: suffix match (e.g. "all_accessible_features" ->
-            # "GpC_all_accessible_features", "Combined_all_accessible_features", etc.)
-            if not candidate_layers:
-                for lname in search_layers:
-                    if lname.endswith(feature_key):
-                        candidate_layers.append(lname)
-
-            # Third: if user passed a full layer name that wasn't in hmm_layers,
-            # but does exist in adata.layers, allow it.
-            if not candidate_layers and feature_key in adata.layers:
-                candidate_layers.append(feature_key)
-
-            if not candidate_layers:
-                print(
-                    f"[call_hmm_peaks] WARNING: no layers found matching feature key "
-                    f"'{feature_key}' in ref '{ref}'. Skipping."
+            # Resolve candidate layers: exact → suffix → direct presence
+            candidates = [ln for ln in search_pool if ln == feature_key]
+            if not candidates:
+                candidates = [ln for ln in search_pool if str(ln).endswith(feature_key)]
+            if not candidates and feature_key in adata.layers:
+                candidates = [feature_key]
+
+            if not candidates:
+                logger.warning(
+                    "[call_hmm_peaks] No layers found matching '%s' in ref '%s'. Skipping.",
+                    feature_key,
+                    ref,
                 )
                 continue
 
-            # Run peak calling on each resolved layer for this ref
-            for layer_name in candidate_layers:
+            # Hyperparams (sanitized)
+            min_distance = max(1, int(config.get("min_distance", 200)))
+            peak_width = max(1, int(config.get("peak_width", 200)))
+            peak_prom = float(config.get("peak_prominence", 0.2))
+            peak_threshold = float(config.get("peak_threshold", 0.8))
+            rolling_window = max(1, int(config.get("rolling_window", 1)))
+
+            for layer_name in candidates:
                 if layer_name not in adata.layers:
-                    print(
-                        f"[call_hmm_peaks] WARNING: resolved layer '{layer_name}' "
-                        f"not found in adata.layers; skipping."
+                    logger.warning(
+                        "[call_hmm_peaks] Layer '%s' not in adata.layers; skipping.",
+                        layer_name,
                     )
                     continue
 
-                min_distance = int(config.get("min_distance", 200))
-                peak_width = int(config.get("peak_width", 200))
-                peak_prominence = float(config.get("peak_prominence", 0.2))
-                peak_threshold = float(config.get("peak_threshold", 0.8))
-
-                layer_data = adata.layers[layer_name]
-                if issparse(layer_data):
-                    layer_data = layer_data.toarray()
-                else:
-                    layer_data = np.asarray(layer_data)
+                # Dense layer data
+                L = adata.layers[layer_name]
+                L = L.toarray() if issparse(L) else np.asarray(L)
+                if L.shape != (adata.n_obs, adata.n_vars):
+                    logger.warning(
+                        "[call_hmm_peaks] Layer '%s' has shape %s, expected (%s, %s); skipping.",
+                        layer_name,
+                        L.shape,
+                        adata.n_obs,
+                        adata.n_vars,
+                    )
+                    continue
 
-                # Subset rows for this ref
-                matrix = layer_data[ref_mask, :]  # (n_ref_reads, n_vars)
-                if matrix.shape[0] == 0:
+                # Ref subset
+                matrix = L[ref_mask, :]
+                if matrix.size == 0 or matrix.shape[0] == 0:
                     continue
 
-                # Mean signal along positions (within this ref only)
                 means = np.nanmean(matrix, axis=0)
+                means = np.nan_to_num(means, nan=0.0)
 
-                # Optional rolling-mean smoothing before peak detection
-                rolling_window = int(config.get("rolling_window", 1))
                 if rolling_window > 1:
-                    # Simple centered rolling mean via convolution
                     kernel = np.ones(rolling_window, dtype=float) / float(rolling_window)
-                    smoothed = np.convolve(means, kernel, mode="same")
-                    peak_metric = smoothed
+                    peak_metric = np.convolve(means, kernel, mode="same")
                 else:
                     peak_metric = means
 
                 # Peak detection
                 peak_indices, _ = find_peaks(
-                    peak_metric, prominence=peak_prominence, distance=min_distance
+                    peak_metric, prominence=peak_prom, distance=min_distance
                 )
                 if peak_indices.size == 0:
-                    print(
-                        f"[call_hmm_peaks] No peaks found for layer '{layer_name}' "
-                        f"in ref '{ref}'."
+                    logger.info(
+                        "[call_hmm_peaks] No peaks for layer '%s' in ref '%s'.",
+                        layer_name,
+                        ref,
                    )
                    continue
 
                 peak_centers = coordinates[peak_indices]
-                # Store per-ref peak centers
                 adata.uns[f"{layer_name}_{ref}_peak_centers"] = peak_centers.tolist()
 
-                # ---- Plot ----
-                plt.figure(figsize=(6, 3))
-                plt.plot(coordinates, peak_metric, linewidth=1)
-                plt.title(f"{layer_name} peaks in {ref}")
-                plt.xlabel("Coordinate")
-                plt.ylabel(f"Rolling Mean - roll size {rolling_window}")
-
+                # Plot once per layer/ref
+                fig, ax = plt.subplots(figsize=(6, 3))
+                ax.plot(coordinates, peak_metric, linewidth=1)
+                ax.set_title(f"{layer_name} peaks in {ref}")
+                ax.set_xlabel("Coordinate")
+                ax.set_ylabel(f"Rolling Mean (win={rolling_window})")
                 for i, center in enumerate(peak_centers):
                     start = center - peak_width // 2
                     end = center + peak_width // 2
                     height = peak_metric[peak_indices[i]]
-                    plt.axvspan(start, end, color="purple", alpha=0.2)
-                    plt.axvline(center, color="red", linestyle="--", linewidth=0.8)
-
-                    # alternate label placement a bit left/right
-                    if alternate_labels:
-                        if i % 2 == 0:
-                            x_text, ha = start, "right"
-                        else:
-                            x_text, ha = end, "left"
-                    else:
-                        x_text, ha = start, "right"
-
-                    plt.text(
-                        x_text,
-                        height * 0.8,
-                        f"Peak {i}\n{center}",
-                        color="red",
-                        ha=ha,
-                        va="bottom",
-                        fontsize=8,
+                    ax.axvspan(start, end, alpha=0.2)
+                    ax.axvline(center, linestyle="--", linewidth=0.8)
+                    x_text, ha = (
+                        (start, "right") if (not alternate_labels or i % 2 == 0) else (end, "left")
+                    )
+                    ax.text(
+                        x_text, height * 0.8, f"Peak {i}\n{center}", ha=ha, va="bottom", fontsize=8
                     )
 
                 if save_plot and output_dir is not None:
                     tag = date_tag or "output"
-                    # include ref in filename
                     safe_ref = str(ref).replace("/", "_")
                     safe_layer = str(layer_name).replace("/", "_")
                     fname = output_dir / f"{tag}_{safe_layer}_{safe_ref}_peaks.png"
-                    plt.savefig(fname, bbox_inches="tight", dpi=200)
-                    print(f"[call_hmm_peaks] Saved plot to {fname}")
-                    plt.close()
+                    fig.savefig(fname, bbox_inches="tight", dpi=200)
+                    logger.info("[call_hmm_peaks] Saved plot to %s", fname)
+                    plt.close(fig)
                 else:
-                    plt.tight_layout()
+                    fig.tight_layout()
                     plt.show()
 
+                # Collect new obs columns; assign once per layer/ref
+                new_obs_cols: Dict[str, np.ndarray] = {}
                 feature_peak_cols = []
 
-                # ---- Per-peak annotations (within this ref) ----
-                for center in peak_centers:
+                for center in np.asarray(peak_centers).tolist():
                     start = center - peak_width // 2
                     end = center + peak_width // 2
 
-                    # Make column names ref- and layer-specific so they don't collide
+                    # var window mask
                     colname = f"{layer_name}_{ref}_peak_{center}"
                     feature_peak_cols.append(colname)
                     all_peak_var_cols.append(colname)
-
-                    # Var-level mask: is this position in the window?
                     peak_mask = (coordinates >= start) & (coordinates <= end)
                     adata.var[colname] = peak_mask
 
-                    # Extract signal in that window from the *ref subset* matrix
-                    region = matrix[:, peak_mask]  # (n_ref_reads, n_positions_in_window)
+                    # feature-layer summaries for reads in this ref
+                    region = matrix[:, peak_mask]  # (n_ref, n_window)
 
-                    # Per-read summary in this window for the feature layer itself
                     mean_col = f"mean_{layer_name}_{ref}_around_{center}"
                     sum_col = f"sum_{layer_name}_{ref}_around_{center}"
                     present_col = f"{layer_name}_{ref}_present_at_{center}"
 
-                    # Create columns if missing, then fill only the ref rows
-                    if mean_col not in adata.obs:
-                        adata.obs[mean_col] = np.nan
-                    if sum_col not in adata.obs:
-                        adata.obs[sum_col] = 0.0
-                    if present_col not in adata.obs:
-                        adata.obs[present_col] = False
-
-                    adata.obs.loc[ref_mask, mean_col] = np.nanmean(region, axis=1)
-                    adata.obs.loc[ref_mask, sum_col] = np.nansum(region, axis=1)
-                    adata.obs.loc[ref_mask, present_col] = (
-                        adata.obs.loc[ref_mask, mean_col].values > peak_threshold
+                    for nm, default, dt in (
+                        (mean_col, np.nan, float),
+                        (sum_col, 0.0, float),
+                        (present_col, False, bool),
+                    ):
+                        if nm not in new_obs_cols:
+                            new_obs_cols[nm] = np.full(adata.n_obs, default, dtype=dt)
+
+                    if region.shape[1] > 0:
+                        means_per_read = np.nanmean(region, axis=1)
+                        sums_per_read = np.nansum(region, axis=1)
+                    else:
+                        means_per_read = np.full(matrix.shape[0], np.nan, dtype=float)
+                        sums_per_read = np.zeros(matrix.shape[0], dtype=float)
+
+                    new_obs_cols[mean_col][ref_mask] = means_per_read
+                    new_obs_cols[sum_col][ref_mask] = sums_per_read
+                    new_obs_cols[present_col][ref_mask] = (
+                        np.nan_to_num(means_per_read, nan=0.0) > peak_threshold
                     )
 
-                    # Initialize site-type summaries (global columns; filled per ref)
-                    for site_type in site_types:
-                        sum_site_col = f"{site_type}_{ref}_sum_around_{center}"
-                        mean_site_col = f"{site_type}_{ref}_mean_around_{center}"
-                        if sum_site_col not in adata.obs:
-                            adata.obs[sum_site_col] = 0.0
-                        if mean_site_col not in adata.obs:
-                            adata.obs[mean_site_col] = np.nan
-
-                    # Per-site-type summaries for this ref
+                    # site-type summaries from adata.X, not an AnnData view
+                    Xmat = adata.X
                     for site_type in site_types:
                         mask_key = f"{ref}_{site_type}_site"
                         if mask_key not in adata.var:
@@ -299,35 +261,53 @@ def call_hmm_peaks(
                             continue
 
                         site_coords = coordinates[site_mask]
-                        region_mask = (site_coords >= start) & (site_coords <= end)
-                        if not region_mask.any():
+                        site_region_mask = (site_coords >= start) & (site_coords <= end)
+                        sum_site_col = f"sum_{layer_name}_{site_type}_{ref}_around_{center}"
+                        mean_site_col = f"mean_{layer_name}_{site_type}_{ref}_around_{center}"
+
+                        if sum_site_col not in new_obs_cols:
+                            new_obs_cols[sum_site_col] = np.zeros(adata.n_obs, dtype=float)
+                        if mean_site_col not in new_obs_cols:
+                            new_obs_cols[mean_site_col] = np.full(adata.n_obs, np.nan, dtype=float)
+
+                        if not site_region_mask.any():
                             continue
 
                         full_mask = np.zeros_like(site_mask, dtype=bool)
-                        full_mask[site_mask] = region_mask
+                        full_mask[site_mask] = site_region_mask
 
-                        site_region = adata[ref_mask, full_mask].X
-                        if hasattr(site_region, "A"):
-                            site_region = site_region.A  # sparse -> dense
-
-                        if site_region.shape[1] == 0:
-                            continue
+                        if issparse(Xmat):
+                            site_region = Xmat[ref_mask][:, full_mask]
+                            site_region = site_region.toarray()
+                        else:
+                            Xnp = np.asarray(Xmat)
+                            site_region = Xnp[np.asarray(ref_mask), :][:, np.asarray(full_mask)]
 
-                        sum_site_col = f"{site_type}_{ref}_sum_around_{center}"
-                        mean_site_col = f"{site_type}_{ref}_mean_around_{center}"
+                        if site_region.shape[1] > 0:
+                            new_obs_cols[sum_site_col][ref_mask] = np.nansum(site_region, axis=1)
+                            new_obs_cols[mean_site_col][ref_mask] = np.nanmean(site_region, axis=1)
 
-                        adata.obs.loc[ref_mask, sum_site_col] = np.nansum(site_region, axis=1)
-                        adata.obs.loc[ref_mask, mean_site_col] = np.nanmean(site_region, axis=1)
+                # one-shot assignment to avoid fragmentation
+                if new_obs_cols:
+                    adata.obs = adata.obs.assign(
+                        **{k: pd.Series(v, index=adata.obs.index) for k, v in new_obs_cols.items()}
+                    )
 
-                # Mark "any peak" for this (layer, ref)
+                # per (layer, ref) any-peak
                 any_col = f"is_in_any_{layer_name}_peak_{ref}"
-                adata.var[any_col] = adata.var[feature_peak_cols].any(axis=1)
-                print(
-                    f"[call_hmm_peaks] Annotated {len(peak_centers)} peaks "
-                    f"for layer '{layer_name}' in ref '{ref}'."
+                if feature_peak_cols:
+                    adata.var[any_col] = adata.var[feature_peak_cols].any(axis=1)
+                else:
+                    adata.var[any_col] = False
+
+                logger.info(
+                    "[call_hmm_peaks] Annotated %s peaks for layer '%s' in ref '%s'.",
+                    len(peak_centers),
+                    layer_name,
+                    ref,
                )
 
-    # Global any-peak flag across all feature layers and references
+    # global any-peak across all layers/refs
     if all_peak_var_cols:
         adata.var["is_in_any_peak"] = adata.var[all_peak_var_cols].any(axis=1)
 
```
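The rewrite batches all new `adata.obs` columns into a single `assign` call and routes messages through the package logger. A usage sketch against the new signature; the input file, reference name, and layer names are illustrative, and the config keys are the ones read via `config.get(...)` above:

```python
import anndata as ad

from smftools.hmm import call_hmm_peaks

adata = ad.read_h5ad("experiment.h5ad")  # placeholder input

feature_configs = {
    # Suffix-matched against adata.uns["hmm_appended_layers"].
    "all_accessible_features": {
        "min_distance": 200,
        "peak_width": 200,
        "peak_prominence": 0.2,
        "peak_threshold": 0.8,
        "rolling_window": 25,  # >1 smooths the mean track before find_peaks
    },
}

result = call_hmm_peaks(
    adata,
    feature_configs,
    ref_column="Reference_strand",
    site_types=("GpC", "CpG"),
    save_plot=True,
    output_dir="peak_plots",
    inplace=False,
)

# Outputs are keyed by layer, reference, and peak center, e.g. for a layer
# "GpC_all_accessible_features" and a reference "6B6_top" (both illustrative):
centers = result.uns.get("GpC_all_accessible_features_6B6_top_peak_centers", [])
peak_cols = [c for c in result.var.columns if "_peak_" in c]
```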
smftools/hmm/display_hmm.py CHANGED
```diff
@@ -1,18 +1,31 @@
+from smftools.logging_utils import get_logger
+
+logger = get_logger(__name__)
+
+
 def display_hmm(hmm, state_labels=["Non-Methylated", "Methylated"], obs_labels=["0", "1"]):
+    """Log a summary of HMM transition and emission parameters.
+
+    Args:
+        hmm: HMM object with edges and distributions.
+        state_labels: Optional labels for states.
+        obs_labels: Optional labels for observations.
+    """
     import torch
-    print("\n**HMM Model Overview**")
-    print(hmm)
 
-    print("\n**Transition Matrix**")
+    logger.info("**HMM Model Overview**")
+    logger.info("%s", hmm)
+
+    logger.info("**Transition Matrix**")
     transition_matrix = torch.exp(hmm.edges).detach().cpu().numpy()
     for i, row in enumerate(transition_matrix):
         label = state_labels[i] if state_labels else f"State {i}"
         formatted_row = ", ".join(f"{p:.6f}" for p in row)
-        print(f"{label}: [{formatted_row}]")
+        logger.info("%s: [%s]", label, formatted_row)
 
-    print("\n**Emission Probabilities**")
+    logger.info("**Emission Probabilities**")
     for i, dist in enumerate(hmm.distributions):
         label = state_labels[i] if state_labels else f"State {i}"
         probs = dist.probs.detach().cpu().numpy()
         formatted_emissions = {obs_labels[j]: probs[j] for j in range(len(probs))}
-        print(f"{label}: {formatted_emissions}")
+        logger.info("%s: %s", label, formatted_emissions)
```
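`display_hmm` now logs through `smftools.logging_utils.get_logger`, whose implementation (new in this release, +51 lines) is not shown in this diff. A minimal sketch of what a helper with this interface typically looks like, offered as an assumption rather than the package's actual code:

```python
import logging


def get_logger(name: str) -> logging.Logger:
    """Return a module-level logger with a single stream handler."""
    logger = logging.getLogger(name)
    if not logger.handlers:  # avoid stacking handlers on repeated imports
        handler = logging.StreamHandler()
        handler.setFormatter(
            logging.Formatter("%(asctime)s %(levelname)s %(name)s: %(message)s")
        )
        logger.addHandler(handler)
        logger.setLevel(logging.INFO)
    return logger
```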
smftools/hmm/hmm_readwrite.py CHANGED
```diff
@@ -1,16 +1,25 @@
-def load_hmm(model_path, device='cpu'):
+def load_hmm(model_path, device="cpu"):
     """
     Reads in a pretrained HMM.
-    
+
     Parameters:
         model_path (str): Path to a pretrained HMM
     """
     import torch
+
     # Load model using PyTorch
     hmm = torch.load(model_path)
-    hmm.to(device)
+    hmm.to(device)
     return hmm
 
+
 def save_hmm(model, model_path):
+    """Save a pretrained HMM to disk.
+
+    Args:
+        model: HMM model instance.
+        model_path: Output path for the model.
+    """
     import torch
-    torch.save(model, model_path)
+
+    torch.save(model, model_path)
```
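`save_hmm`/`load_hmm` pickle the entire model object via `torch.save`/`torch.load`. A round-trip sketch; note that PyTorch 2.6+ defaults `torch.load` to `weights_only=True`, which rejects full pickled objects, so the bare `torch.load(model_path)` call above assumes an older default or a trusted checkpoint:

```python
import torch

from smftools.hmm import load_hmm, save_hmm

# `trained_hmm` stands in for a model fit earlier in the pipeline.
save_hmm(trained_hmm, "hmm_checkpoint.pt")  # wraps torch.save(model, path)

hmm = load_hmm("hmm_checkpoint.pt", device="cpu")

# Manual equivalent on newer PyTorch, only for checkpoints you trust:
hmm = torch.load("hmm_checkpoint.pt", weights_only=False)
hmm.to("cpu")
```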