smftools 0.2.4__py3-none-any.whl → 0.2.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (133)
  1. smftools/__init__.py +6 -8
  2. smftools/_settings.py +4 -6
  3. smftools/_version.py +1 -1
  4. smftools/cli/helpers.py +7 -1
  5. smftools/cli/hmm_adata.py +902 -244
  6. smftools/cli/load_adata.py +318 -198
  7. smftools/cli/preprocess_adata.py +285 -171
  8. smftools/cli/spatial_adata.py +137 -53
  9. smftools/cli_entry.py +94 -178
  10. smftools/config/__init__.py +1 -1
  11. smftools/config/conversion.yaml +5 -1
  12. smftools/config/deaminase.yaml +1 -1
  13. smftools/config/default.yaml +22 -17
  14. smftools/config/direct.yaml +8 -3
  15. smftools/config/discover_input_files.py +19 -5
  16. smftools/config/experiment_config.py +505 -276
  17. smftools/constants.py +37 -0
  18. smftools/datasets/__init__.py +2 -8
  19. smftools/datasets/datasets.py +32 -18
  20. smftools/hmm/HMM.py +2125 -1426
  21. smftools/hmm/__init__.py +2 -3
  22. smftools/hmm/archived/call_hmm_peaks.py +16 -1
  23. smftools/hmm/call_hmm_peaks.py +173 -193
  24. smftools/hmm/display_hmm.py +19 -6
  25. smftools/hmm/hmm_readwrite.py +13 -4
  26. smftools/hmm/nucleosome_hmm_refinement.py +102 -14
  27. smftools/informatics/__init__.py +30 -7
  28. smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +14 -1
  29. smftools/informatics/archived/helpers/archived/bam_qc.py +14 -1
  30. smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +8 -1
  31. smftools/informatics/archived/helpers/archived/load_adata.py +3 -3
  32. smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +3 -1
  33. smftools/informatics/archived/print_bam_query_seq.py +7 -1
  34. smftools/informatics/bam_functions.py +379 -156
  35. smftools/informatics/basecalling.py +51 -9
  36. smftools/informatics/bed_functions.py +90 -57
  37. smftools/informatics/binarize_converted_base_identities.py +18 -7
  38. smftools/informatics/complement_base_list.py +7 -6
  39. smftools/informatics/converted_BAM_to_adata.py +265 -122
  40. smftools/informatics/fasta_functions.py +161 -83
  41. smftools/informatics/h5ad_functions.py +195 -29
  42. smftools/informatics/modkit_extract_to_adata.py +609 -270
  43. smftools/informatics/modkit_functions.py +85 -44
  44. smftools/informatics/ohe.py +44 -21
  45. smftools/informatics/pod5_functions.py +112 -73
  46. smftools/informatics/run_multiqc.py +20 -14
  47. smftools/logging_utils.py +51 -0
  48. smftools/machine_learning/__init__.py +2 -7
  49. smftools/machine_learning/data/anndata_data_module.py +143 -50
  50. smftools/machine_learning/data/preprocessing.py +2 -1
  51. smftools/machine_learning/evaluation/__init__.py +1 -1
  52. smftools/machine_learning/evaluation/eval_utils.py +11 -14
  53. smftools/machine_learning/evaluation/evaluators.py +46 -33
  54. smftools/machine_learning/inference/__init__.py +1 -1
  55. smftools/machine_learning/inference/inference_utils.py +7 -4
  56. smftools/machine_learning/inference/lightning_inference.py +9 -13
  57. smftools/machine_learning/inference/sklearn_inference.py +6 -8
  58. smftools/machine_learning/inference/sliding_window_inference.py +35 -25
  59. smftools/machine_learning/models/__init__.py +10 -5
  60. smftools/machine_learning/models/base.py +28 -42
  61. smftools/machine_learning/models/cnn.py +15 -11
  62. smftools/machine_learning/models/lightning_base.py +71 -40
  63. smftools/machine_learning/models/mlp.py +13 -4
  64. smftools/machine_learning/models/positional.py +3 -2
  65. smftools/machine_learning/models/rnn.py +3 -2
  66. smftools/machine_learning/models/sklearn_models.py +39 -22
  67. smftools/machine_learning/models/transformer.py +68 -53
  68. smftools/machine_learning/models/wrappers.py +2 -1
  69. smftools/machine_learning/training/__init__.py +2 -2
  70. smftools/machine_learning/training/train_lightning_model.py +29 -20
  71. smftools/machine_learning/training/train_sklearn_model.py +9 -15
  72. smftools/machine_learning/utils/__init__.py +1 -1
  73. smftools/machine_learning/utils/device.py +7 -4
  74. smftools/machine_learning/utils/grl.py +3 -1
  75. smftools/metadata.py +443 -0
  76. smftools/plotting/__init__.py +19 -5
  77. smftools/plotting/autocorrelation_plotting.py +145 -44
  78. smftools/plotting/classifiers.py +162 -72
  79. smftools/plotting/general_plotting.py +347 -168
  80. smftools/plotting/hmm_plotting.py +42 -13
  81. smftools/plotting/position_stats.py +145 -85
  82. smftools/plotting/qc_plotting.py +20 -12
  83. smftools/preprocessing/__init__.py +8 -8
  84. smftools/preprocessing/append_base_context.py +105 -79
  85. smftools/preprocessing/append_binary_layer_by_base_context.py +75 -37
  86. smftools/preprocessing/{archives → archived}/calculate_complexity.py +3 -1
  87. smftools/preprocessing/{archives → archived}/preprocessing.py +8 -6
  88. smftools/preprocessing/binarize.py +21 -4
  89. smftools/preprocessing/binarize_on_Youden.py +127 -31
  90. smftools/preprocessing/binary_layers_to_ohe.py +17 -11
  91. smftools/preprocessing/calculate_complexity_II.py +86 -59
  92. smftools/preprocessing/calculate_consensus.py +28 -19
  93. smftools/preprocessing/calculate_coverage.py +44 -22
  94. smftools/preprocessing/calculate_pairwise_differences.py +2 -1
  95. smftools/preprocessing/calculate_pairwise_hamming_distances.py +4 -3
  96. smftools/preprocessing/calculate_position_Youden.py +103 -55
  97. smftools/preprocessing/calculate_read_length_stats.py +52 -23
  98. smftools/preprocessing/calculate_read_modification_stats.py +91 -57
  99. smftools/preprocessing/clean_NaN.py +38 -28
  100. smftools/preprocessing/filter_adata_by_nan_proportion.py +24 -12
  101. smftools/preprocessing/filter_reads_on_length_quality_mapping.py +70 -37
  102. smftools/preprocessing/filter_reads_on_modification_thresholds.py +181 -73
  103. smftools/preprocessing/flag_duplicate_reads.py +688 -271
  104. smftools/preprocessing/invert_adata.py +26 -11
  105. smftools/preprocessing/load_sample_sheet.py +40 -22
  106. smftools/preprocessing/make_dirs.py +8 -3
  107. smftools/preprocessing/min_non_diagonal.py +2 -1
  108. smftools/preprocessing/recipes.py +56 -23
  109. smftools/preprocessing/reindex_references_adata.py +93 -27
  110. smftools/preprocessing/subsample_adata.py +33 -16
  111. smftools/readwrite.py +264 -109
  112. smftools/schema/__init__.py +11 -0
  113. smftools/schema/anndata_schema_v1.yaml +227 -0
  114. smftools/tools/__init__.py +3 -4
  115. smftools/tools/archived/classifiers.py +163 -0
  116. smftools/tools/archived/subset_adata_v1.py +10 -1
  117. smftools/tools/archived/subset_adata_v2.py +12 -1
  118. smftools/tools/calculate_umap.py +54 -15
  119. smftools/tools/cluster_adata_on_methylation.py +115 -46
  120. smftools/tools/general_tools.py +70 -25
  121. smftools/tools/position_stats.py +229 -98
  122. smftools/tools/read_stats.py +50 -29
  123. smftools/tools/spatial_autocorrelation.py +365 -192
  124. smftools/tools/subset_adata.py +23 -21
  125. {smftools-0.2.4.dist-info → smftools-0.2.5.dist-info}/METADATA +15 -43
  126. smftools-0.2.5.dist-info/RECORD +181 -0
  127. smftools-0.2.4.dist-info/RECORD +0 -176
  128. /smftools/preprocessing/{archives → archived}/add_read_length_and_mapping_qc.py +0 -0
  129. /smftools/preprocessing/{archives → archived}/mark_duplicates.py +0 -0
  130. /smftools/preprocessing/{archives → archived}/remove_duplicates.py +0 -0
  131. {smftools-0.2.4.dist-info → smftools-0.2.5.dist-info}/WHEEL +0 -0
  132. {smftools-0.2.4.dist-info → smftools-0.2.5.dist-info}/entry_points.txt +0 -0
  133. {smftools-0.2.4.dist-info → smftools-0.2.5.dist-info}/licenses/LICENSE +0 -0
smftools/cli/hmm_adata.py CHANGED
@@ -1,223 +1,855 @@
-def hmm_adata(config_path):
+from __future__ import annotations
+
+import copy
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any, List, Optional, Sequence, Tuple, Union
+
+import numpy as np
+import torch
+
+from smftools.logging_utils import get_logger
+
+# FIX: import _to_dense_np to avoid NameError
+from ..hmm.HMM import _safe_int_coords, _to_dense_np, create_hmm, normalize_hmm_feature_sets
+
+logger = get_logger(__name__)
+
+# =============================================================================
+# Helpers: extracting training arrays
+# =============================================================================
+
+
+def _get_training_matrix(
+    subset, cols_mask: np.ndarray, smf_modality: Optional[str], cfg
+) -> Tuple[np.ndarray, Optional[str]]:
+    """
+    Matches your existing behavior:
+      - direct -> uses cfg.output_binary_layer_name in .layers
+      - else -> uses .X
+    Returns (X, layer_name_or_None) where X is dense float array.
+    """
+    sub = subset[:, cols_mask]
+
+    if smf_modality == "direct":
+        hmm_layer = getattr(cfg, "output_binary_layer_name", None)
+        if hmm_layer is None or hmm_layer not in sub.layers:
+            raise KeyError(f"Missing HMM training layer '{hmm_layer}' in subset.")
+
+        logger.debug("Using direct modality HMM training layer: %s", hmm_layer)
+        mat = sub.layers[hmm_layer]
+    else:
+        logger.debug("Using .X for HMM training matrix")
+        hmm_layer = None
+        mat = sub.X
+
+    X = _to_dense_np(mat).astype(float)
+    if X.ndim != 2:
+        raise ValueError(f"Expected 2D training matrix; got {X.shape}")
+    return X, hmm_layer
+
+
+def _resolve_pos_mask_for_methbase(subset, ref: str, methbase: str) -> Optional[np.ndarray]:
     """
-    High-level function to call for hmm analysis of an adata object.
-    Command line accesses this through smftools hmm <config_path>
+    Reproduces your mask resolution, with compatibility for both *_any_C_site and *_C_site.
+    """
+    key = str(methbase).strip().lower()
+
+    logger.debug("Resolving position mask for methbase=%s on ref=%s", key, ref)
+
+    if key in ("a",):
+        col = f"{ref}_A_site"
+        if col not in subset.var:
+            return None
+        logger.debug("Using positions with A calls from column: %s", col)
+        return np.asarray(subset.var[col])
+
+    if key in ("c", "any_c", "anyc", "any-c"):
+        for col in (f"{ref}_any_C_site", f"{ref}_C_site"):
+            if col in subset.var:
+                logger.debug("Using positions with C calls from column: %s", col)
+                return np.asarray(subset.var[col])
+        return None
+
+    if key in ("gpc", "gpc_site", "gpc-site"):
+        col = f"{ref}_GpC_site"
+        if col not in subset.var:
+            return None
+        logger.debug("Using positions with GpC calls from column: %s", col)
+        return np.asarray(subset.var[col])
+
+    if key in ("cpg", "cpg_site", "cpg-site"):
+        col = f"{ref}_CpG_site"
+        if col not in subset.var:
+            return None
+        logger.debug("Using positions with CpG calls from column: %s", col)
+        return np.asarray(subset.var[col])
+
+    alt = f"{ref}_{methbase}_site"
+    if alt not in subset.var:
+        return None
+
+    logger.debug("Using positions from column: %s", alt)
+    return np.asarray(subset.var[alt])
+
+
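For orientation, the mask resolution above keys off boolean site columns in subset.var named "{ref}_{base}_site". A minimal sketch of the lookup behavior, using a hypothetical reference name "ref1" and made-up site flags:

    import pandas as pd

    # Hypothetical var table; in real data these flags are produced upstream.
    var = pd.DataFrame({
        "ref1_GpC_site": [True, False, True],
        "ref1_any_C_site": [True, True, True],
    })

    # methbase "GpC" -> column "ref1_GpC_site" (present, mask returned)
    # methbase "C"   -> tries "ref1_any_C_site" first, "ref1_C_site" as fallback
    # methbase "A"   -> column "ref1_A_site" (absent here, so the mask is None)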
+def build_single_channel(
+    subset, ref: str, methbase: str, smf_modality: Optional[str], cfg
+) -> Tuple[np.ndarray, np.ndarray]:
+    """
+    Returns:
+      X (N, Lmb) float with NaNs allowed
+      coords (Lmb,) int coords from var_names
+    """
+    pm = _resolve_pos_mask_for_methbase(subset, ref, methbase)
+    logger.debug(
+        "Position mask for methbase=%s on ref=%s has %d sites",
+        methbase,
+        ref,
+        int(np.sum(pm)) if pm is not None else 0,
+    )
+
+    if pm is None or int(np.sum(pm)) == 0:
+        raise ValueError(f"No columns for methbase={methbase} on ref={ref}")
+
+    X, _ = _get_training_matrix(subset, pm, smf_modality, cfg)
+    logger.debug("Training matrix for methbase=%s on ref=%s has shape %s", methbase, ref, X.shape)
+
+    coords, _ = _safe_int_coords(subset[:, pm].var_names)
+    logger.debug(
+        "Coordinates for methbase=%s on ref=%s have length %d", methbase, ref, coords.shape[0]
+    )
+
+    return X, coords
 
-    Parameters:
-        config_path (str): A string representing the file path to the experiment configuration csv file.
+
+def build_multi_channel_union(
+    subset, ref: str, methbases: Sequence[str], smf_modality: Optional[str], cfg
+) -> Tuple[np.ndarray, np.ndarray, List[str]]:
+    """
+    Build (N, Lunion, C) on union coordinate grid across methbases.
 
     Returns:
-        (pp_dedup_spatial_hmm_adata, pp_dedup_spatial_hmm_adata_path)
+        X3d: (N, Lunion, C) float with NaN where methbase has no site
+        coords: (Lunion,) int union coords
+        used_methbases: list of methbases actually included (>=2)
+    """
+    per: List[Tuple[str, np.ndarray, np.ndarray, np.ndarray]] = []  # (mb, X, coords, pm)
+
+    for mb in methbases:
+        pm = _resolve_pos_mask_for_methbase(subset, ref, mb)
+        if pm is None or int(np.sum(pm)) == 0:
+            continue
+        Xmb, _ = _get_training_matrix(subset, pm, smf_modality, cfg)  # (N,Lmb)
+        cmb, _ = _safe_int_coords(subset[:, pm].var_names)
+        per.append((mb, Xmb.astype(float), cmb.astype(int), pm))
+
+    if len(per) < 2:
+        raise ValueError(f"Need >=2 methbases with columns for union multi-channel on ref={ref}")
+
+    # union coordinates
+    coords = np.unique(np.concatenate([c for _, _, c, _ in per], axis=0)).astype(int)
+    idx = {int(v): i for i, v in enumerate(coords.tolist())}
+
+    N = per[0][1].shape[0]
+    L = coords.shape[0]
+    C = len(per)
+    X3 = np.full((N, L, C), np.nan, dtype=float)
+
+    for ci, (mb, Xmb, cmb, _) in enumerate(per):
+        cols = np.array([idx[int(v)] for v in cmb.tolist()], dtype=int)
+        X3[:, cols, ci] = Xmb
+
+    used = [mb for (mb, _, _, _) in per]
+    return X3, coords, used
+
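The union-grid construction in build_multi_channel_union is easiest to see on toy data: each channel's per-site matrix is scattered into a shared coordinate axis, with NaN wherever that channel has no site. A minimal NumPy sketch with two hypothetical channels:

    import numpy as np

    coords_gpc = np.array([10, 20, 40])   # toy GpC site coordinates
    coords_cpg = np.array([20, 30])       # toy CpG site coordinates
    x_gpc = np.array([[1.0, 0.0, 1.0]])   # one read, GpC channel
    x_cpg = np.array([[0.0, 1.0]])        # same read, CpG channel

    coords = np.unique(np.concatenate([coords_gpc, coords_cpg]))  # [10 20 30 40]
    idx = {int(v): i for i, v in enumerate(coords)}

    x3 = np.full((1, coords.size, 2), np.nan)
    x3[:, [idx[int(v)] for v in coords_gpc], 0] = x_gpc
    x3[:, [idx[int(v)] for v in coords_cpg], 1] = x_cpg
    # channel 0 is NaN at coord 30; channel 1 is NaN at coords 10 and 40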
+
+@dataclass
+class HMMTask:
+    name: str
+    signals: List[str]  # e.g. ["GpC"] or ["GpC","CpG"] or ["CpG"]
+    feature_groups: List[str]  # e.g. ["footprint","accessible"] or ["cpg"]
+    output_prefix: Optional[str] = None  # force prefix (CpG task uses "CpG")
+
+
+def build_hmm_tasks(cfg: Union[dict, Any]) -> List[HMMTask]:
+    """
+    Accessibility signals come from cfg['hmm_methbases'].
+    CpG task is enabled by cfg['cpg']==True, independent of hmm_methbases.
+    """
+    if not isinstance(cfg, dict):
+        # best effort conversion
+        cfg = {k: getattr(cfg, k) for k in dir(cfg) if not k.startswith("_")}
+
+    tasks: List[HMMTask] = []
+
+    # accessibility task
+    methbases = list(cfg.get("hmm_methbases", []) or [])
+    if len(methbases) > 0:
+        tasks.append(
+            HMMTask(
+                name="accessibility",
+                signals=methbases,
+                feature_groups=["footprint", "accessible"],
+                output_prefix=None,
+            )
+        )
+
+    # CpG task (special case)
+    if bool(cfg.get("cpg", False)):
+        tasks.append(
+            HMMTask(
+                name="cpg",
+                signals=["CpG"],
+                feature_groups=["cpg"],
+                output_prefix="CpG",
+            )
+        )
+
+    return tasks
+
+
+def select_hmm_arch(cfg: dict, signals: Sequence[str]) -> str:
+    """
+    Simple, explicit model selection:
+      - distance-aware => 'single_distance_binned' (only meaningful for single-channel)
+      - multi-signal => 'multi'
+      - else => 'single'
+    """
+    if bool(cfg.get("hmm_distance_aware", False)) and len(signals) == 1:
+        return "single_distance_binned"
+    if len(signals) > 1:
+        return "multi"
+    return "single"
+
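Taken together, build_hmm_tasks and select_hmm_arch reduce to a small decision table. A sketch with a hypothetical config dict (keys as read by the two functions above):

    cfg = {"hmm_methbases": ["GpC", "CpG"], "cpg": True, "hmm_distance_aware": True}

    for task in build_hmm_tasks(cfg):
        print(task.name, task.signals, select_hmm_arch(cfg, task.signals))
    # accessibility ['GpC', 'CpG'] multi
    # cpg ['CpG'] single_distance_binned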
+
+def resolve_input_layer(adata, cfg: dict, layer_override: Optional[str]) -> Optional[str]:
+    """
+    If direct modality, prefer cfg.output_binary_layer_name.
+    Else use layer_override or None (meaning use .X).
+    """
+    smf_modality = cfg.get("smf_modality", None)
+    if smf_modality == "direct":
+        nm = cfg.get("output_binary_layer_name", None)
+        if nm is None:
+            raise KeyError("cfg.output_binary_layer_name missing for smf_modality='direct'")
+        if nm not in adata.layers:
+            raise KeyError(f"Direct modality expects layer '{nm}' in adata.layers")
+        return nm
+    return layer_override
+
+
+def _ensure_layer_and_assign_rows(adata, layer_name: str, row_mask: np.ndarray, subset_layer):
+    """
+    Writes subset_layer (n_subset_obs, n_vars) into adata.layers[layer_name] for rows where row_mask==True.
+    """
+    row_mask = np.asarray(row_mask, dtype=bool)
+    if row_mask.ndim != 1 or row_mask.size != adata.n_obs:
+        raise ValueError("row_mask must be length adata.n_obs")
+
+    arr = _to_dense_np(subset_layer)
+    if arr.shape != (int(row_mask.sum()), adata.n_vars):
+        raise ValueError(
+            f"subset layer '{layer_name}' shape {arr.shape} != ({int(row_mask.sum())}, {adata.n_vars})"
+        )
+
+    if layer_name not in adata.layers:
+        adata.layers[layer_name] = np.zeros((adata.n_obs, adata.n_vars), dtype=arr.dtype)
+
+    adata.layers[layer_name][row_mask, :] = arr
+
+
+def resolve_torch_device(device_str: str | None) -> torch.device:
+    d = (device_str or "auto").lower()
+    if d == "auto":
+        if torch.cuda.is_available():
+            return torch.device("cuda")
+        if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
+            return torch.device("mps")
+        return torch.device("cpu")
+    return torch.device(d)
+
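The row scatter in _ensure_layer_and_assign_rows is plain boolean-mask assignment into a full-size layer. A minimal NumPy sketch of the same pattern:

    import numpy as np

    n_obs, n_vars = 5, 4
    layer = np.zeros((n_obs, n_vars))                       # destination layer
    row_mask = np.array([True, False, True, False, False])  # rows owned by the subset
    subset_rows = np.ones((int(row_mask.sum()), n_vars))    # (2, 4) subset values

    layer[row_mask, :] = subset_rows  # rows 0 and 2 now hold the subset's values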
+
+# =============================================================================
+# Model selection + fit strategy manager
+# =============================================================================
+@dataclass
+class HMMTrainer:
+    cfg: Any
+    models_dir: Path
+
+    def __post_init__(self):
+        self.models_dir = Path(self.models_dir)
+        self.models_dir.mkdir(parents=True, exist_ok=True)
+
+    def choose_arch(self, *, multichannel: bool) -> str:
+        use_dist = bool(getattr(self.cfg, "hmm_distance_aware", False))
+        if multichannel:
+            return "multi"
+        return "single_distance_binned" if use_dist else "single"
+
+    def _fit_scope(self) -> str:
+        return str(getattr(self.cfg, "hmm_fit_scope", "per_sample")).lower()
+        # "per_sample" | "global" | "global_then_adapt"
+
+    def _path(self, kind: str, sample: str, ref: str, label: str) -> Path:
+        # kind: "GLOBAL" | "PER" | "ADAPT"
+        def safe(s):
+            return str(s).replace("/", "_")  # FIX: return the sanitized name (was missing)
+
+        return self.models_dir / f"{kind}_{safe(sample)}_{safe(ref)}_{safe(label)}.pt"
+
+    def _save(self, model, path: Path):
+        override = {}
+        if getattr(model, "hmm_name", None) == "multi":
+            override["hmm_n_channels"] = int(getattr(model, "n_channels", 2))
+        if getattr(model, "hmm_name", None) == "single_distance_binned":
+            override["hmm_distance_bins"] = list(
+                getattr(model, "distance_bins", [1, 5, 10, 25, 50, 100])
+            )
+
+        payload = {
+            "state_dict": model.state_dict(),
+            "hmm_arch": getattr(model, "hmm_name", None) or getattr(self.cfg, "hmm_arch", None),
+            "override": override,
+        }
+        torch.save(payload, path)
+
+    def _load(self, path: Path, arch: str, device):
+        payload = torch.load(path, map_location="cpu")
+        override = payload.get("override", None)
+        m = create_hmm(self.cfg, arch=arch, override=override, device=device)
+        sd = payload["state_dict"]
+
+        target_dtype = next(m.parameters()).dtype
+        for k, v in list(sd.items()):
+            if isinstance(v, torch.Tensor) and v.dtype != target_dtype:
+                sd[k] = v.to(dtype=target_dtype)
+
+        m.load_state_dict(sd)
+        m.to(device)
+        m.eval()
+        return m
+
+    def fit_or_load(
+        self,
+        *,
+        sample: str,
+        ref: str,
+        label: str,
+        arch: str,
+        X,
+        coords: Optional[np.ndarray],
+        device,
+    ):
+        force_fit = bool(getattr(self.cfg, "force_redo_hmm_fit", False))
+        scope = self._fit_scope()
+
+        max_iter = int(getattr(self.cfg, "hmm_max_iter", 50))
+        tol = float(getattr(self.cfg, "hmm_tol", 1e-4))
+        verbose = bool(getattr(self.cfg, "hmm_verbose", False))
+
+        # ---- global then adapt ----
+        if scope == "global_then_adapt":
+            p_global = self._path("GLOBAL", "ALL", ref, label)
+            if p_global.exists() and not force_fit:
+                base = self._load(p_global, arch=arch, device=device)
+            else:
+                base = create_hmm(self.cfg, arch=arch).to(device)
+                if arch == "single_distance_binned":
+                    base.fit(
+                        X, device=device, coords=coords, max_iter=max_iter, tol=tol, verbose=verbose
+                    )
+                else:
+                    base.fit(X, device=device, max_iter=max_iter, tol=tol, verbose=verbose)
+                self._save(base, p_global)
+
+            p_adapt = self._path("ADAPT", sample, ref, label)
+            if p_adapt.exists() and not force_fit:
+                return self._load(p_adapt, arch=arch, device=device)
+
+            # IMPORTANT: this assumes you added model.adapt_emissions(...)
+            adapted = copy.deepcopy(base).to(device)
+            if arch == "single_distance_binned":
+                adapted.adapt_emissions(
+                    X,
+                    coords,
+                    device=device,
+                    max_iter=int(getattr(self.cfg, "hmm_adapt_iters", 10)),
+                    verbose=verbose,
+                )
+
+            else:
+                adapted.adapt_emissions(
+                    X,
+                    coords,
+                    device=device,
+                    max_iter=int(getattr(self.cfg, "hmm_adapt_iters", 10)),
+                    verbose=verbose,
+                )
+
+            self._save(adapted, p_adapt)
+            return adapted
+
+        # ---- global only ----
+        if scope == "global":
+            p = self._path("GLOBAL", "ALL", ref, label)
+            if p.exists() and not force_fit:
+                return self._load(p, arch=arch, device=device)
+
+        # ---- per sample ----
+        else:
+            p = self._path("PER", sample, ref, label)
+            if p.exists() and not force_fit:
+                return self._load(p, arch=arch, device=device)
+
+        m = create_hmm(self.cfg, arch=arch, device=device)
+        if arch == "single_distance_binned":
+            m.fit(X, coords, device=device, max_iter=max_iter, tol=tol, verbose=verbose)
+        else:
+            m.fit(X, coords, device=device, max_iter=max_iter, tol=tol, verbose=verbose)
+        self._save(m, p)
+        return m
+
+
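Checkpoints written by _save are ordinary torch.save payloads: a state_dict plus the small "override" dict needed to rebuild non-default architectures. A minimal inspection sketch (the checkpoint filename is hypothetical):

    import torch

    payload = torch.load("GLOBAL_ALL_ref1_GpC.pt", map_location="cpu")
    print(payload["hmm_arch"])            # e.g. "single" or "multi"
    print(payload["override"])            # e.g. {"hmm_n_channels": 2} for a multi model
    print(sorted(payload["state_dict"]))  # parameter tensor names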
+def _fully_qualified_merge_layers(cfg, prefix: str) -> List[Tuple[str, int]]:
+    """
+    cfg.hmm_merge_layer_features is assumed to be a list of (core_layer_name, merge_distance),
+    where core_layer_name is like "all_accessible_features" (NOT prefixed with methbase).
+    We expand to f"{prefix}_{core_layer_name}".
+    """
+    out = []
+    for core_layer, dist in getattr(cfg, "hmm_merge_layer_features", []) or []:
+        if not core_layer:
+            continue
+        out.append((f"{prefix}_{core_layer}", int(dist)))
+    return out
+
+
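_fully_qualified_merge_layers only prepends the channel prefix to each configured (layer, distance) pair. A sketch with a hypothetical config value:

    from types import SimpleNamespace

    cfg = SimpleNamespace(hmm_merge_layer_features=[("all_accessible_features", 80)])
    print(_fully_qualified_merge_layers(cfg, "GpC"))
    # [('GpC_all_accessible_features', 80)]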
+def hmm_adata(config_path: str):
+    """
+    CLI-facing wrapper for HMM analysis.
+
+    Command line entrypoint:
+        smftools hmm <config_path>
+
+    Responsibilities:
+      - Build cfg via load_adata()
+      - Ensure preprocess + spatial stages are run
+      - Decide which AnnData to start from (hmm > spatial > pp_dedup > pp > raw)
+      - Call hmm_adata_core(cfg, adata, paths)
     """
-    from ..readwrite import safe_read_h5ad, safe_write_h5ad, make_dirs, add_or_update_column_in_csv
+    from ..readwrite import safe_read_h5ad
+    from .helpers import get_adata_paths
     from .load_adata import load_adata
     from .preprocess_adata import preprocess_adata
     from .spatial_adata import spatial_adata
 
-    import numpy as np
-    import pandas as pd
-    import anndata as ad
-    import scanpy as sc
-
-    import os
-    from importlib import resources
-    from pathlib import Path
+    # 1) load cfg / stage paths
+    _, _, cfg = load_adata(config_path)
+    paths = get_adata_paths(cfg)
 
-    from datetime import datetime
-    date_str = datetime.today().strftime("%y%m%d")
+    # 2) make sure upstream stages are run (they have their own skipping logic)
+    preprocess_adata(config_path)
+    spatial_ad, spatial_path = spatial_adata(config_path)
 
-    ############################################### smftools load start ###############################################
-    adata, adata_path, cfg = load_adata(config_path)
-    # General config variable init - Necessary user passed inputs
-    smf_modality = cfg.smf_modality # needed for specifying if the data is conversion SMF or direct methylation detection SMF. Or deaminase smf Necessary.
-    output_directory = Path(cfg.output_directory) # Path to the output directory to make for the analysis. Necessary.
+    # 3) choose starting AnnData
+    #    Prefer:
+    #      - existing HMM h5ad if not forcing redo
+    #      - in-memory spatial_ad from wrapper call
+    #      - saved spatial / pp_dedup / pp / raw on disk
+    if paths.hmm.exists() and not (cfg.force_redo_hmm_fit or cfg.force_redo_hmm_apply):
+        adata, _ = safe_read_h5ad(paths.hmm)
+        return adata, paths.hmm
 
-    # Make initial output directory
-    make_dirs([output_directory])
-    ############################################### smftools load end ###############################################
-
-    ############################################### smftools preprocess start ###############################################
-    pp_adata, pp_adata_path, pp_dedup_adata, pp_dup_rem_adata_path = preprocess_adata(config_path)
-    ############################################### smftools preprocess end ###############################################
-
-    ############################################### smftools spatial start ###############################################
-    spatial_ad, spatial_adata_path = spatial_adata(config_path)
-    ############################################### smftools spatial end ###############################################
-
-    ############################################### smftools hmm start ###############################################
-    input_manager_df = pd.read_csv(cfg.summary_file)
-    initial_adata_path = Path(input_manager_df['load_adata'][0])
-    pp_adata_path = Path(input_manager_df['pp_adata'][0])
-    pp_dup_rem_adata_path = Path(input_manager_df['pp_dedup_adata'][0])
-    spatial_adata_path = Path(input_manager_df['spatial_adata'][0])
-    hmm_adata_path = Path(input_manager_df['hmm_adata'][0])
-
-    if spatial_ad:
-        # This happens on first run of the pipeline
+    if spatial_ad is not None:
         adata = spatial_ad
+        source_path = spatial_path
+    elif paths.spatial.exists():
+        adata, _ = safe_read_h5ad(paths.spatial)
+        source_path = paths.spatial
+    elif paths.pp_dedup.exists():
+        adata, _ = safe_read_h5ad(paths.pp_dedup)
+        source_path = paths.pp_dedup
+    elif paths.pp.exists():
+        adata, _ = safe_read_h5ad(paths.pp)
+        source_path = paths.pp
+    elif paths.raw.exists():
+        adata, _ = safe_read_h5ad(paths.raw)
+        source_path = paths.raw
     else:
-        # If an anndata is saved, check which stages of the anndata are available
-        initial_version_available = initial_adata_path.exists()
-        preprocessed_version_available = pp_adata_path.exists()
-        preprocessed_dup_removed_version_available = pp_dup_rem_adata_path.exists()
-        preprocessed_dedup_spatial_version_available = spatial_adata_path.exists()
-        preprocessed_dedup_spatial_hmm_version_available = hmm_adata_path.exists()
-
-        if cfg.force_redo_hmm_fit or cfg.force_redo_hmm_apply:
-            print(f"Forcing redo of hmm analysis workflow.")
-            if preprocessed_dedup_spatial_hmm_version_available:
-                adata, load_report = safe_read_h5ad(hmm_adata_path)
-            elif preprocessed_dedup_spatial_version_available:
-                adata, load_report = safe_read_h5ad(spatial_adata_path)
-            elif preprocessed_dup_removed_version_available:
-                adata, load_report = safe_read_h5ad(pp_dup_rem_adata_path)
-            elif initial_version_available:
-                adata, load_report = safe_read_h5ad(initial_adata_path)
-            else:
-                print(f"Can not redo duplicate detection when there is no compatible adata available: either raw or preprocessed are required")
-        elif preprocessed_dedup_spatial_hmm_version_available:
-            adata, load_report = safe_read_h5ad(hmm_adata_path)
-        else:
-            if preprocessed_dedup_spatial_version_available:
-                adata, load_report = safe_read_h5ad(spatial_adata_path)
-            elif preprocessed_dup_removed_version_available:
-                adata, load_report = safe_read_h5ad(pp_dup_rem_adata_path)
-            elif initial_version_available:
-                adata, load_report = safe_read_h5ad(initial_adata_path)
-            else:
-                print(f"No adata available.")
-                return
-    references = adata.obs[cfg.reference_column].cat.categories
-    deaminase = smf_modality == 'deaminase'
-    ############################################### HMM based feature annotations ###############################################
+        raise FileNotFoundError(
+            "No AnnData available for HMM: expected at least raw or preprocessed h5ad."
+        )
+
+    # 4) delegate to core
+    adata, hmm_adata_path = hmm_adata_core(
+        cfg,
+        adata,
+        paths,
+        source_adata_path=source_path,
+        config_path=config_path,
+    )
+    return adata, hmm_adata_path
+
+
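Per the docstring, the wrapper is reachable from both the CLI and Python; a minimal usage sketch (the config filename is hypothetical):

    # Shell:
    #   smftools hmm experiment_config.yaml

    # Python:
    from smftools.cli.hmm_adata import hmm_adata

    adata, hmm_path = hmm_adata("experiment_config.yaml")
    print(hmm_path)  # paths.hmm, written via write_gz_h5ad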
+def hmm_adata_core(
+    cfg,
+    adata,
+    paths,
+    source_adata_path: Path | None = None,
+    config_path: str | None = None,
+) -> Tuple["anndata.AnnData", Path]:
+    """
+    Core HMM analysis pipeline.
+
+    Assumes:
+      - cfg is an ExperimentConfig
+      - adata is the starting AnnData (typically spatial + dedup)
+      - paths is an AdataPaths object (with .raw/.pp/.pp_dedup/.spatial/.hmm)
+
+    Does NOT decide which h5ad to start from – that is the wrapper's job.
+    """
+
+    import numpy as np
+
+    from ..hmm import call_hmm_peaks
+    from ..metadata import record_smftools_metadata
+    from ..plotting import (
+        combined_hmm_raw_clustermap,
+        plot_hmm_layers_rolling_by_sample_ref,
+        plot_hmm_size_contours,
+    )
+    from ..readwrite import make_dirs
+    from .helpers import write_gz_h5ad
+
+    smf_modality = cfg.smf_modality
+    deaminase = smf_modality == "deaminase"
+
+    output_directory = Path(cfg.output_directory)
+    make_dirs([output_directory])
+
+    pp_dir = output_directory / "preprocessed" / "deduplicated"
+
+    # ---------------------------- HMM annotate stage ----------------------------
     if not (cfg.bypass_hmm_fit and cfg.bypass_hmm_apply):
-        from ..hmm.HMM import HMM
-        from scipy.sparse import issparse, csr_matrix
-        import warnings
+        hmm_models_dir = pp_dir / "10_hmm_models"
+        make_dirs([pp_dir, hmm_models_dir])
 
-        pp_dir = output_directory / "preprocessed"
-        pp_dir = pp_dir / "deduplicated"
-        hmm_dir = pp_dir / "10_hmm_models"
+        # Standard bookkeeping
+        uns_key = "hmm_appended_layers"
+        if adata.uns.get(uns_key) is None:
+            adata.uns[uns_key] = []
+        global_appended = list(adata.uns.get(uns_key, []))
 
-        if hmm_dir.is_dir():
-            print(f'{hmm_dir} already exists.')
-        else:
-            make_dirs([pp_dir, hmm_dir])
+        # Prepare trainer + feature config
+        trainer = HMMTrainer(cfg=cfg, models_dir=hmm_models_dir)
+
+        feature_sets = normalize_hmm_feature_sets(getattr(cfg, "hmm_feature_sets", None))
+        prob_thr = float(getattr(cfg, "hmm_feature_prob_threshold", 0.5))
+        decode = str(getattr(cfg, "hmm_decode", "marginal"))
+        write_post = bool(getattr(cfg, "hmm_write_posterior", True))
+        post_state = getattr(cfg, "hmm_posterior_state", "Modified")
+        merged_suffix = str(getattr(cfg, "hmm_merged_suffix", "_merged"))
+        force_apply = bool(getattr(cfg, "force_redo_hmm_apply", False))
+        bypass_apply = bool(getattr(cfg, "bypass_hmm_apply", False))
+        bypass_fit = bool(getattr(cfg, "bypass_hmm_fit", False))
 
         samples = adata.obs[cfg.sample_name_col_for_plotting].cat.categories
         references = adata.obs[cfg.reference_column].cat.categories
-        uns_key = "hmm_appended_layers"
+        methbases = list(getattr(cfg, "hmm_methbases", [])) or []
 
-        # ensure uns key exists (avoid KeyError later)
-        if adata.uns.get(uns_key) is None:
-            adata.uns[uns_key] = []
+        if not methbases:
+            raise ValueError("cfg.hmm_methbases is empty.")
 
-        if adata.uns.get('hmm_annotated', False) and not cfg.force_redo_hmm_fit and not cfg.force_redo_hmm_apply:
+        # Top-level skip
+        already = bool(adata.uns.get("hmm_annotated", False))
+        if already and not (bool(getattr(cfg, "force_redo_hmm_fit", False)) or force_apply):
            pass
+
        else:
+            logger.info("Starting HMM annotation over samples and references")
            for sample in samples:
                for ref in references:
-                    mask = (adata.obs[cfg.sample_name_col_for_plotting] == sample) & (adata.obs[cfg.reference_column] == ref)
-                    subset = adata[mask].copy()
-                    if subset.shape[0] < 1:
+                    mask = (adata.obs[cfg.sample_name_col_for_plotting] == sample) & (
+                        adata.obs[cfg.reference_column] == ref
+                    )
+                    if int(np.sum(mask)) == 0:
                        continue
 
-                    for mod_site in cfg.hmm_methbases:
-                        mod_label = {'C': 'C'}.get(mod_site, mod_site)
-                        hmm_path = hmm_dir / f"{sample}_{ref}_{mod_label}_hmm_model.pth"
-
-                        # ensure the input obsm exists
-                        obsm_key = f'{ref}_{mod_label}_site'
-                        if obsm_key not in subset.obsm:
-                            print(f"Skipping {sample} {ref} {mod_label}: missing obsm '{obsm_key}'")
+                    subset = adata[mask].copy()
+                    subset.uns[uns_key] = []  # isolate appended tracking per subset
+
+                    # ---- Decide which tasks to run ----
+                    methbases = list(getattr(cfg, "hmm_methbases", [])) or []
+                    run_multi = bool(getattr(cfg, "hmm_run_multichannel", True))
+                    run_cpg = bool(getattr(cfg, "cpg", False))
+                    device = resolve_torch_device(cfg.device)
+
+                    logger.info("HMM processing sample=%s ref=%s", sample, ref)
+
+                    # ---- split feature sets ----
+                    feature_sets_all = normalize_hmm_feature_sets(
+                        getattr(cfg, "hmm_feature_sets", None)
+                    )
+                    feature_sets_access = {
+                        k: v
+                        for k, v in feature_sets_all.items()
+                        if k in ("footprint", "accessible")
+                    }
+                    feature_sets_cpg = (
+                        {"cpg": feature_sets_all["cpg"]} if "cpg" in feature_sets_all else {}
+                    )
+
+                    # =========================
+                    # 1) Single-channel accessibility (per methbase)
+                    # =========================
+                    for mb in methbases:
+                        logger.info("HMM single-channel for methbase=%s", mb)
+
+                        try:
+                            X, coords = build_single_channel(
+                                subset,
+                                ref=str(ref),
+                                methbase=str(mb),
+                                smf_modality=smf_modality,
+                                cfg=cfg,
+                            )
+                        except Exception:
+                            logger.warning(
+                                "Skipping HMM single-channel for methbase=%s due to data error", mb
+                            )
                            continue
 
-                        # Fit or load model
-                        if hmm_path.exists() and not cfg.force_redo_hmm_fit:
-                            hmm = HMM.load(hmm_path)
-                            hmm.print_params()
-                        else:
-                            print(f"Fitting HMM for {sample} {ref} {mod_label}")
-                            hmm = HMM.from_config(cfg)
-                            # fit expects a list-of-seqs or 2D ndarray in the obsm
-                            seqs = subset.obsm[obsm_key]
-                            hmm.fit(seqs)
-                            hmm.print_params()
-                            hmm.save(hmm_path)
-
-                        # Apply / annotate on the subset, then copy layers back to final_adata
-                        if cfg.bypass_hmm_apply:
-                            pass
-                        else:
-                            print(f"Applying HMM on subset for {sample} {ref} {mod_label}")
-                            # Use the new uns_key argument so subset will record appended layer names
-                            # (annotate_adata modifies subset.obs/layers in-place and should write subset.uns[uns_key])
-                            if smf_modality == "direct":
-                                hmm_layer = cfg.output_binary_layer_name
-                            else:
-                                hmm_layer = None
-
-                            hmm.annotate_adata(subset,
-                                obs_column=cfg.reference_column,
-                                layer=hmm_layer,
-                                config=cfg,
-                                force_redo=cfg.force_redo_hmm_apply
+                        arch = trainer.choose_arch(multichannel=False)
+
+                        logger.info("HMM fitting/loading for methbase=%s", mb)
+                        hmm = trainer.fit_or_load(
+                            sample=str(sample),
+                            ref=str(ref),
+                            label=str(mb),
+                            arch=arch,
+                            X=X,
+                            coords=coords,
+                            device=device,
+                        )
+
+                        if not bypass_apply:
+                            logger.info("HMM applying for methbase=%s", mb)
+                            pm = _resolve_pos_mask_for_methbase(subset, str(ref), str(mb))
+                            hmm.annotate_adata(
+                                subset,
+                                prefix=str(mb),
+                                X=X,
+                                coords=coords,
+                                var_mask=pm,
+                                span_fill=True,
+                                config=cfg,
+                                decode=decode,
+                                write_posterior=write_post,
+                                posterior_state=post_state,
+                                feature_sets=feature_sets_access,  # <--- ONLY accessibility feature sets
+                                prob_threshold=prob_thr,
+                                uns_key=uns_key,
+                                uns_flag=f"hmm_annotated_{mb}",
+                                force_redo=force_apply,
+                            )
+
+                            # merges for this mb
+                            for core_layer, dist in (
+                                getattr(cfg, "hmm_merge_layer_features", []) or []
+                            ):
+                                base_layer = f"{mb}_{core_layer}"
+                                logger.info("Merging intervals for layer=%s", base_layer)
+                                if base_layer in subset.layers:
+                                    merged_base = hmm.merge_intervals_to_new_layer(
+                                        subset,
+                                        base_layer,
+                                        distance_threshold=int(dist),
+                                        suffix=merged_suffix,
+                                        overwrite=True,
+                                    )
+                                    # write merged size classes based on whichever group core_layer corresponds to
+                                    for group, fs in feature_sets_access.items():
+                                        fmap = fs.get("features", {}) or {}
+                                        if fmap:
+                                            hmm.write_size_class_layers_from_binary(
+                                                subset,
+                                                merged_base,
+                                                out_prefix=str(mb),
+                                                feature_ranges=fmap,
+                                                suffix=merged_suffix,
+                                                overwrite=True,
                                            )
-
-                            if adata.uns.get('hmm_annotated', False) and not cfg.force_redo_hmm_apply:
-                                pass
-                            else:
-                                to_merge = cfg.hmm_merge_layer_features
-                                for layer_to_merge, merge_distance in to_merge:
-                                    if layer_to_merge:
-                                        hmm.merge_intervals_in_layer(subset,
-                                            layer=layer_to_merge,
-                                            distance_threshold=merge_distance,
-                                            overwrite=True
-                                            )
-                                    else:
-                                        pass
-
-                        # collect appended layers from subset.uns
-                        appended = list(subset.uns.get(uns_key, []))
-                        print(appended)
-                        if len(appended) == 0:
-                            # nothing appended for this subset; continue
-                            continue
-
-                        # copy each appended layer into adata
-                        subset_mask_bool = mask.values if hasattr(mask, "values") else np.asarray(mask)
-                        for layer_name in appended:
-                            if layer_name not in subset.layers:
-                                # defensive: skip
-                                warnings.warn(f"Expected layer {layer_name} in subset but not found; skipping copy.")
-                                continue
-                            sub_layer = subset.layers[layer_name]
-                            # ensure final layer exists and assign rows
-                            try:
-                                hmm._ensure_final_layer_and_assign(adata, layer_name, subset_mask_bool, sub_layer)
-                            except Exception as e:
-                                warnings.warn(f"Failed to copy layer {layer_name} into adata: {e}", stacklevel=2)
-                                # fallback: if dense and small, try to coerce
-                                if issparse(sub_layer):
-                                    arr = sub_layer.toarray()
-                                else:
-                                    arr = np.asarray(sub_layer)
-                                adata.layers[layer_name] = adata.layers.get(layer_name, np.zeros((adata.shape[0], arr.shape[1]), dtype=arr.dtype))
-                                final_idx = np.nonzero(subset_mask_bool)[0]
-                                adata.layers[layer_name][final_idx, :] = arr
-
-                        # merge appended layer names into adata.uns
-                        existing = list(adata.uns.get(uns_key, []))
-                        for ln in appended:
-                            if ln not in existing:
-                                existing.append(ln)
-                        adata.uns[uns_key] = existing
-
-    else:
-        pass
 
-    from ..hmm import call_hmm_peaks
+                    # =========================
+                    # 2) Multi-channel accessibility (Combined)
+                    # =========================
+                    if run_multi and len(methbases) >= 2:
+                        logger.info("HMM multi-channel for methbases=%s", ",".join(methbases))
+                        try:
+                            X3, coords_u, used_mbs = build_multi_channel_union(
+                                subset,
+                                ref=str(ref),
+                                methbases=methbases,
+                                smf_modality=smf_modality,
+                                cfg=cfg,
+                            )
+                        except Exception:
+                            X3, coords_u, used_mbs = None, None, []
+                            logger.warning(
+                                "Skipping HMM multi-channel due to data error or insufficient methbases"
+                            )
+
+                        if X3 is not None and len(used_mbs) >= 2:
+                            union_mask = None
+                            for mb in used_mbs:
+                                pm = _resolve_pos_mask_for_methbase(subset, str(ref), str(mb))
+                                union_mask = pm if union_mask is None else (union_mask | pm)
+
+                            arch = trainer.choose_arch(multichannel=True)
+
+                            logger.info("HMM fitting/loading for multi-channel")
+                            hmmc = trainer.fit_or_load(
+                                sample=str(sample),
+                                ref=str(ref),
+                                label="Combined",
+                                arch=arch,
+                                X=X3,
+                                coords=coords_u,
+                                device=device,
+                            )
+
+                            if not bypass_apply:
+                                logger.info("HMM applying for multi-channel")
+                                hmmc.annotate_adata(
+                                    subset,
+                                    prefix="Combined",
+                                    X=X3,
+                                    coords=coords_u,
+                                    var_mask=union_mask,
+                                    span_fill=True,
+                                    config=cfg,
+                                    decode=decode,
+                                    write_posterior=write_post,
+                                    posterior_state=post_state,
+                                    feature_sets=feature_sets_access,  # <--- accessibility only
+                                    prob_threshold=prob_thr,
+                                    uns_key=uns_key,
+                                    uns_flag="hmm_annotated_combined",
+                                    force_redo=force_apply,
+                                )
+
+                                for core_layer, dist in (
+                                    getattr(cfg, "hmm_merge_layer_features", []) or []
+                                ):
+                                    base_layer = f"Combined_{core_layer}"
+                                    if base_layer in subset.layers:
+                                        merged_base = hmmc.merge_intervals_to_new_layer(
+                                            subset,
+                                            base_layer,
+                                            distance_threshold=int(dist),
+                                            suffix=merged_suffix,
+                                            overwrite=True,
+                                        )
+                                        for group, fs in feature_sets_access.items():
+                                            fmap = fs.get("features", {}) or {}
+                                            if fmap:
+                                                hmmc.write_size_class_layers_from_binary(
+                                                    subset,
+                                                    merged_base,
+                                                    out_prefix="Combined",
+                                                    feature_ranges=fmap,
+                                                    suffix=merged_suffix,
+                                                    overwrite=True,
+                                                )
+
+                    # =========================
+                    # 3) CpG-only single-channel task
+                    # =========================
+                    if run_cpg:
+                        logger.info("HMM single-channel for CpG")
+                        try:
+                            Xcpg, coordscpg = build_single_channel(
+                                subset,
+                                ref=str(ref),
+                                methbase="CpG",
+                                smf_modality=smf_modality,
+                                cfg=cfg,
+                            )
+                        except Exception:
+                            Xcpg, coordscpg = None, None
+                            logger.warning("Skipping HMM single-channel for CpG due to data error")
+
+                        if Xcpg is not None and Xcpg.size and feature_sets_cpg:
+                            arch = trainer.choose_arch(multichannel=False)
+
+                            logger.info("HMM fitting/loading for CpG")
+                            hmmg = trainer.fit_or_load(
+                                sample=str(sample),
+                                ref=str(ref),
+                                label="CpG",
+                                arch=arch,
+                                X=Xcpg,
+                                coords=coordscpg,
+                                device=device,
+                            )
+
+                            if not bypass_apply:
+                                logger.info("HMM applying for CpG")
+                                pm = _resolve_pos_mask_for_methbase(subset, str(ref), "CpG")
+                                hmmg.annotate_adata(
+                                    subset,
+                                    prefix="CpG",
+                                    X=Xcpg,
+                                    coords=coordscpg,
+                                    var_mask=pm,
+                                    span_fill=True,
+                                    config=cfg,
+                                    decode=decode,
+                                    write_posterior=write_post,
+                                    posterior_state=post_state,
+                                    feature_sets=feature_sets_cpg,  # <--- ONLY cpg group (cpg_patch)
+                                    prob_threshold=prob_thr,
+                                    uns_key=uns_key,
+                                    uns_flag="hmm_annotated_CpG",
+                                    force_redo=force_apply,
+                                )
+
+                    # ------------------------------------------------------------
+                    # Copy newly created subset layers back into the full adata
+                    # ------------------------------------------------------------
+                    appended = (
+                        list(subset.uns.get(uns_key, []))
+                        if subset.uns.get(uns_key) is not None
+                        else []
+                    )
+                    if appended:
+                        row_mask = np.asarray(
+                            mask.values if hasattr(mask, "values") else mask, dtype=bool
+                        )
+
+                        for ln in appended:
+                            if ln not in subset.layers:
+                                continue
+                            _ensure_layer_and_assign_rows(adata, ln, row_mask, subset.layers[ln])
+                            if ln not in global_appended:
+                                global_appended.append(ln)
+
+                        adata.uns[uns_key] = global_appended
+
+            adata.uns["hmm_annotated"] = True
+
+    hmm_layers = list(adata.uns.get("hmm_appended_layers", []) or [])
+    # keep only real feature layers; drop lengths/states/posterior
+    hmm_layers = [
+        layer
+        for layer in hmm_layers
+        if not any(s in layer for s in ("_lengths", "_states", "_posterior"))
+    ]
+    logger.info(f"HMM appended layers: {hmm_layers}")
+
+    # ---------------------------- HMM peak calling stage ----------------------------
     hmm_dir = pp_dir / "11_hmm_peak_calling"
     if hmm_dir.is_dir():
        pass
@@ -225,29 +857,32 @@ def hmm_adata(config_path):
        make_dirs([pp_dir, hmm_dir])
 
    call_hmm_peaks(
-    adata,
-    feature_configs=cfg.hmm_peak_feature_configs,
-    ref_column=cfg.reference_column,
-    site_types=cfg.mod_target_bases,
-    save_plot=True,
-    output_dir=hmm_dir,
-    index_col_suffix=cfg.reindexed_var_suffix)
-
-    ## Save HMM annotated adata
-    if not hmm_adata_path.exists():
-        print('Saving hmm analyzed adata post preprocessing and duplicate removal')
-        if ".gz" == hmm_adata_path.suffix:
-            safe_write_h5ad(adata, hmm_adata_path, compression='gzip', backup=True)
-        else:
-            hmm_adata_path = hmm_adata_path.with_name(hmm_adata_path.name + '.gz')
-            safe_write_h5ad(adata, hmm_adata_path, compression='gzip', backup=True)
+        adata,
+        feature_configs=cfg.hmm_peak_feature_configs,
+        ref_column=cfg.reference_column,
+        site_types=cfg.mod_target_bases,
+        save_plot=True,
+        output_dir=hmm_dir,
+        index_col_suffix=cfg.reindexed_var_suffix,
+    )
 
-        add_or_update_column_in_csv(cfg.summary_file, "hmm_adata", hmm_adata_path)
+    ## Save HMM annotated adata
+    if not paths.hmm.exists():
+        logger.info("Saving hmm analyzed AnnData (post preprocessing and duplicate removal).")
+        record_smftools_metadata(
+            adata,
+            step_name="hmm",
+            cfg=cfg,
+            config_path=config_path,
+            input_paths=[source_adata_path] if source_adata_path else None,
+            output_path=paths.hmm,
+        )
+        write_gz_h5ad(adata, paths.hmm)
 
    ########################################################################################################################
 
-    ############################################### HMM based feature plotting ###############################################
-    from ..plotting import combined_hmm_raw_clustermap
+    ############################################### HMM based feature plotting ###############################################
+
    hmm_dir = pp_dir / "12_hmm_clustermaps"
    make_dirs([pp_dir, hmm_dir])
 
@@ -256,6 +891,9 @@ def hmm_adata(config_path):
    for base in cfg.hmm_methbases:
        layers.extend([f"{base}_{layer}" for layer in cfg.hmm_clustermap_feature_layers])
 
+    if getattr(cfg, "hmm_run_multichannel", True) and len(cfg.hmm_methbases) >= 2:
+        layers.extend([f"Combined_{layer}" for layer in cfg.hmm_clustermap_feature_layers])
+
    if cfg.cpg:
        layers.extend(["CpG_cpg_patch"])
 
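With the new Combined branch above, the clustermap layer list is assembled from per-methbase, Combined, and CpG sources. A sketch assuming hypothetical config values hmm_methbases = ["GpC", "CpG"], hmm_clustermap_feature_layers = ["all_accessible_features"], and cpg enabled:

    layers = []
    for base in ["GpC", "CpG"]:
        layers.extend([f"{base}_all_accessible_features"])
    layers.extend(["Combined_all_accessible_features"])  # new multichannel branch
    layers.extend(["CpG_cpg_patch"])                     # cfg.cpg branch
    # ['GpC_all_accessible_features', 'CpG_all_accessible_features',
    #  'Combined_all_accessible_features', 'CpG_cpg_patch']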
@@ -273,40 +911,48 @@ def hmm_adata(config_path):
        make_dirs([hmm_cluster_save_dir])
 
        combined_hmm_raw_clustermap(
-        adata,
-        sample_col=cfg.sample_name_col_for_plotting,
-        reference_col=cfg.reference_column,
-        hmm_feature_layer=layer,
-        layer_gpc=cfg.layer_for_clustermap_plotting,
-        layer_cpg=cfg.layer_for_clustermap_plotting,
-        layer_c=cfg.layer_for_clustermap_plotting,
-        layer_a=cfg.layer_for_clustermap_plotting,
-        cmap_hmm=cfg.clustermap_cmap_hmm,
-        cmap_gpc=cfg.clustermap_cmap_gpc,
-        cmap_cpg=cfg.clustermap_cmap_cpg,
-        cmap_c=cfg.clustermap_cmap_c,
-        cmap_a=cfg.clustermap_cmap_a,
-        min_quality=cfg.read_quality_filter_thresholds[0],
-        min_length=cfg.read_len_filter_thresholds[0],
-        min_mapped_length_to_reference_length_ratio=cfg.read_len_to_ref_ratio_filter_thresholds[0],
-        min_position_valid_fraction=1-cfg.position_max_nan_threshold,
-        save_path=hmm_cluster_save_dir,
-        normalize_hmm=False,
-        sort_by=cfg.hmm_clustermap_sortby, # options: 'gpc', 'cpg', 'gpc_cpg', 'none', or 'obs:<column>'
-        bins=None,
-        deaminase=deaminase,
-        min_signal=0,
-        index_col_suffix=cfg.reindexed_var_suffix
+            adata,
+            sample_col=cfg.sample_name_col_for_plotting,
+            reference_col=cfg.reference_column,
+            hmm_feature_layer=layer,
+            layer_gpc=cfg.layer_for_clustermap_plotting,
+            layer_cpg=cfg.layer_for_clustermap_plotting,
+            layer_c=cfg.layer_for_clustermap_plotting,
+            layer_a=cfg.layer_for_clustermap_plotting,
+            cmap_hmm=cfg.clustermap_cmap_hmm,
+            cmap_gpc=cfg.clustermap_cmap_gpc,
+            cmap_cpg=cfg.clustermap_cmap_cpg,
+            cmap_c=cfg.clustermap_cmap_c,
+            cmap_a=cfg.clustermap_cmap_a,
+            min_quality=cfg.read_quality_filter_thresholds[0],
+            min_length=cfg.read_len_filter_thresholds[0],
+            min_mapped_length_to_reference_length_ratio=cfg.read_len_to_ref_ratio_filter_thresholds[
+                0
+            ],
+            min_position_valid_fraction=1 - cfg.position_max_nan_threshold,
+            demux_types=("double", "already"),
+            save_path=hmm_cluster_save_dir,
+            normalize_hmm=False,
+            sort_by=cfg.hmm_clustermap_sortby,  # options: 'gpc', 'cpg', 'gpc_cpg', 'none', or 'obs:<column>'
+            bins=None,
+            deaminase=deaminase,
+            min_signal=0,
+            index_col_suffix=cfg.reindexed_var_suffix,
        )
 
    hmm_dir = pp_dir / "13_hmm_bulk_traces"
 
    if hmm_dir.is_dir():
-        print(f'{hmm_dir} already exists.')
+        logger.debug(f"{hmm_dir} already exists.")
    else:
        make_dirs([pp_dir, hmm_dir])
        from ..plotting import plot_hmm_layers_rolling_by_sample_ref
-        bulk_hmm_layers = [layer for layer in adata.uns['hmm_appended_layers'] if "_lengths" not in layer]
+
+        bulk_hmm_layers = [
+            layer
+            for layer in hmm_layers
+            if not any(s in layer for s in ("_lengths", "_states", "_posterior"))
+        ]
        saved = plot_hmm_layers_rolling_by_sample_ref(
            adata,
            layers=bulk_hmm_layers,
@@ -314,26 +960,38 @@ def hmm_adata(config_path):
            ref_col=cfg.reference_column,
            window=101,
            rows_per_page=4,
-            figsize_per_cell=(4,2.5),
+            figsize_per_cell=(4, 2.5),
            output_dir=hmm_dir,
            save=True,
-            show_raw=False
+            show_raw=False,
        )
 
    hmm_dir = pp_dir / "14_hmm_fragment_distributions"
 
    if hmm_dir.is_dir():
-        print(f'{hmm_dir} already exists.')
+        logger.debug(f"{hmm_dir} already exists.")
    else:
        make_dirs([pp_dir, hmm_dir])
        from ..plotting import plot_hmm_size_contours
 
-        if smf_modality == 'deaminase':
-            fragments = [('C_all_accessible_features_lengths', 400), ('C_all_footprint_features_lengths', 250), ('C_all_accessible_features_merged_lengths', 800)]
-        elif smf_modality == 'conversion':
-            fragments = [('GpC_all_accessible_features_lengths', 400), ('GpC_all_footprint_features_lengths', 250), ('GpC_all_accessible_features_merged_lengths', 800)]
+        if smf_modality == "deaminase":
+            fragments = [
+                ("C_all_accessible_features_lengths", 400),
+                ("C_all_footprint_features_lengths", 250),
+                ("C_all_accessible_features_merged_lengths", 800),
+            ]
+        elif smf_modality == "conversion":
+            fragments = [
+                ("GpC_all_accessible_features_lengths", 400),
+                ("GpC_all_footprint_features_lengths", 250),
+                ("GpC_all_accessible_features_merged_lengths", 800),
+            ]
        elif smf_modality == "direct":
-            fragments = [('A_all_accessible_features_lengths', 400), ('A_all_footprint_features_lengths', 200), ('A_all_accessible_features_merged_lengths', 800)]
+            fragments = [
+                ("A_all_accessible_features_lengths", 400),
+                ("A_all_footprint_features_lengths", 200),
+                ("A_all_accessible_features_merged_lengths", 800),
+            ]
 
        for layer, max in fragments:
            save_path = hmm_dir / layer
@@ -353,9 +1011,9 @@ def hmm_adata(config_path):
                dpi=200,
                smoothing_sigma=(10, 10),
                normalize_after_smoothing=True,
-                cmap='Greens',
-                log_scale_z=True
+                cmap="Greens",
+                log_scale_z=True,
            )
    ########################################################################################################################
 
-    return (adata, hmm_adata_path)
+    return (adata, paths.hmm)
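After a run, the HMM-derived layers are discoverable from the saved object through the hmm_appended_layers bookkeeping key used throughout this module. A minimal sketch (the output filename is hypothetical):

    import anndata as ad

    adata = ad.read_h5ad("hmm_adata.h5ad.gz")     # hypothetical output path
    print(adata.uns.get("hmm_annotated", False))  # True once annotation has run
    for name in adata.uns.get("hmm_appended_layers", []):
        print(name, adata.layers[name].shape)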