smftools 0.2.3-py3-none-any.whl → 0.2.5-py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (137)
  1. smftools/__init__.py +6 -8
  2. smftools/_settings.py +4 -6
  3. smftools/_version.py +1 -1
  4. smftools/cli/helpers.py +54 -0
  5. smftools/cli/hmm_adata.py +937 -256
  6. smftools/cli/load_adata.py +448 -268
  7. smftools/cli/preprocess_adata.py +469 -263
  8. smftools/cli/spatial_adata.py +536 -319
  9. smftools/cli_entry.py +97 -182
  10. smftools/config/__init__.py +1 -1
  11. smftools/config/conversion.yaml +17 -6
  12. smftools/config/deaminase.yaml +12 -10
  13. smftools/config/default.yaml +142 -33
  14. smftools/config/direct.yaml +11 -3
  15. smftools/config/discover_input_files.py +19 -5
  16. smftools/config/experiment_config.py +594 -264
  17. smftools/constants.py +37 -0
  18. smftools/datasets/__init__.py +2 -8
  19. smftools/datasets/datasets.py +32 -18
  20. smftools/hmm/HMM.py +2128 -1418
  21. smftools/hmm/__init__.py +2 -9
  22. smftools/hmm/archived/call_hmm_peaks.py +121 -0
  23. smftools/hmm/call_hmm_peaks.py +299 -91
  24. smftools/hmm/display_hmm.py +19 -6
  25. smftools/hmm/hmm_readwrite.py +13 -4
  26. smftools/hmm/nucleosome_hmm_refinement.py +102 -14
  27. smftools/informatics/__init__.py +30 -7
  28. smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +14 -1
  29. smftools/informatics/archived/helpers/archived/bam_qc.py +14 -1
  30. smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +8 -1
  31. smftools/informatics/archived/helpers/archived/load_adata.py +3 -3
  32. smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +3 -1
  33. smftools/informatics/archived/print_bam_query_seq.py +7 -1
  34. smftools/informatics/bam_functions.py +397 -175
  35. smftools/informatics/basecalling.py +51 -9
  36. smftools/informatics/bed_functions.py +90 -57
  37. smftools/informatics/binarize_converted_base_identities.py +18 -7
  38. smftools/informatics/complement_base_list.py +7 -6
  39. smftools/informatics/converted_BAM_to_adata.py +265 -122
  40. smftools/informatics/fasta_functions.py +161 -83
  41. smftools/informatics/h5ad_functions.py +196 -30
  42. smftools/informatics/modkit_extract_to_adata.py +609 -270
  43. smftools/informatics/modkit_functions.py +85 -44
  44. smftools/informatics/ohe.py +44 -21
  45. smftools/informatics/pod5_functions.py +112 -73
  46. smftools/informatics/run_multiqc.py +20 -14
  47. smftools/logging_utils.py +51 -0
  48. smftools/machine_learning/__init__.py +2 -7
  49. smftools/machine_learning/data/anndata_data_module.py +143 -50
  50. smftools/machine_learning/data/preprocessing.py +2 -1
  51. smftools/machine_learning/evaluation/__init__.py +1 -1
  52. smftools/machine_learning/evaluation/eval_utils.py +11 -14
  53. smftools/machine_learning/evaluation/evaluators.py +46 -33
  54. smftools/machine_learning/inference/__init__.py +1 -1
  55. smftools/machine_learning/inference/inference_utils.py +7 -4
  56. smftools/machine_learning/inference/lightning_inference.py +9 -13
  57. smftools/machine_learning/inference/sklearn_inference.py +6 -8
  58. smftools/machine_learning/inference/sliding_window_inference.py +35 -25
  59. smftools/machine_learning/models/__init__.py +10 -5
  60. smftools/machine_learning/models/base.py +28 -42
  61. smftools/machine_learning/models/cnn.py +15 -11
  62. smftools/machine_learning/models/lightning_base.py +71 -40
  63. smftools/machine_learning/models/mlp.py +13 -4
  64. smftools/machine_learning/models/positional.py +3 -2
  65. smftools/machine_learning/models/rnn.py +3 -2
  66. smftools/machine_learning/models/sklearn_models.py +39 -22
  67. smftools/machine_learning/models/transformer.py +68 -53
  68. smftools/machine_learning/models/wrappers.py +2 -1
  69. smftools/machine_learning/training/__init__.py +2 -2
  70. smftools/machine_learning/training/train_lightning_model.py +29 -20
  71. smftools/machine_learning/training/train_sklearn_model.py +9 -15
  72. smftools/machine_learning/utils/__init__.py +1 -1
  73. smftools/machine_learning/utils/device.py +7 -4
  74. smftools/machine_learning/utils/grl.py +3 -1
  75. smftools/metadata.py +443 -0
  76. smftools/plotting/__init__.py +19 -5
  77. smftools/plotting/autocorrelation_plotting.py +145 -44
  78. smftools/plotting/classifiers.py +162 -72
  79. smftools/plotting/general_plotting.py +422 -197
  80. smftools/plotting/hmm_plotting.py +42 -13
  81. smftools/plotting/position_stats.py +147 -87
  82. smftools/plotting/qc_plotting.py +20 -12
  83. smftools/preprocessing/__init__.py +10 -12
  84. smftools/preprocessing/append_base_context.py +115 -80
  85. smftools/preprocessing/append_binary_layer_by_base_context.py +77 -39
  86. smftools/preprocessing/{calculate_complexity.py → archived/calculate_complexity.py} +3 -1
  87. smftools/preprocessing/{archives → archived}/preprocessing.py +8 -6
  88. smftools/preprocessing/binarize.py +21 -4
  89. smftools/preprocessing/binarize_on_Youden.py +129 -31
  90. smftools/preprocessing/binary_layers_to_ohe.py +17 -11
  91. smftools/preprocessing/calculate_complexity_II.py +86 -59
  92. smftools/preprocessing/calculate_consensus.py +28 -19
  93. smftools/preprocessing/calculate_coverage.py +50 -25
  94. smftools/preprocessing/calculate_pairwise_differences.py +2 -1
  95. smftools/preprocessing/calculate_pairwise_hamming_distances.py +4 -3
  96. smftools/preprocessing/calculate_position_Youden.py +118 -54
  97. smftools/preprocessing/calculate_read_length_stats.py +52 -23
  98. smftools/preprocessing/calculate_read_modification_stats.py +91 -57
  99. smftools/preprocessing/clean_NaN.py +38 -28
  100. smftools/preprocessing/filter_adata_by_nan_proportion.py +24 -12
  101. smftools/preprocessing/filter_reads_on_length_quality_mapping.py +71 -38
  102. smftools/preprocessing/filter_reads_on_modification_thresholds.py +181 -73
  103. smftools/preprocessing/flag_duplicate_reads.py +689 -272
  104. smftools/preprocessing/invert_adata.py +26 -11
  105. smftools/preprocessing/load_sample_sheet.py +40 -22
  106. smftools/preprocessing/make_dirs.py +8 -3
  107. smftools/preprocessing/min_non_diagonal.py +2 -1
  108. smftools/preprocessing/recipes.py +56 -23
  109. smftools/preprocessing/reindex_references_adata.py +103 -0
  110. smftools/preprocessing/subsample_adata.py +33 -16
  111. smftools/readwrite.py +331 -82
  112. smftools/schema/__init__.py +11 -0
  113. smftools/schema/anndata_schema_v1.yaml +227 -0
  114. smftools/tools/__init__.py +3 -4
  115. smftools/tools/archived/classifiers.py +163 -0
  116. smftools/tools/archived/subset_adata_v1.py +10 -1
  117. smftools/tools/archived/subset_adata_v2.py +12 -1
  118. smftools/tools/calculate_umap.py +54 -15
  119. smftools/tools/cluster_adata_on_methylation.py +115 -46
  120. smftools/tools/general_tools.py +70 -25
  121. smftools/tools/position_stats.py +229 -98
  122. smftools/tools/read_stats.py +50 -29
  123. smftools/tools/spatial_autocorrelation.py +365 -192
  124. smftools/tools/subset_adata.py +23 -21
  125. {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/METADATA +17 -39
  126. smftools-0.2.5.dist-info/RECORD +181 -0
  127. smftools-0.2.3.dist-info/RECORD +0 -173
  128. /smftools/cli/{cli_flows.py → archived/cli_flows.py} +0 -0
  129. /smftools/hmm/{apply_hmm_batched.py → archived/apply_hmm_batched.py} +0 -0
  130. /smftools/hmm/{calculate_distances.py → archived/calculate_distances.py} +0 -0
  131. /smftools/hmm/{train_hmm.py → archived/train_hmm.py} +0 -0
  132. /smftools/preprocessing/{add_read_length_and_mapping_qc.py → archived/add_read_length_and_mapping_qc.py} +0 -0
  133. /smftools/preprocessing/{archives → archived}/mark_duplicates.py +0 -0
  134. /smftools/preprocessing/{archives → archived}/remove_duplicates.py +0 -0
  135. {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/WHEEL +0 -0
  136. {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/entry_points.txt +0 -0
  137. {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/licenses/LICENSE +0 -0
smftools/cli/hmm_adata.py CHANGED
@@ -1,318 +1,999 @@
- def hmm_adata(config_path):
+ from __future__ import annotations
+
+ import copy
+ from dataclasses import dataclass
+ from pathlib import Path
+ from typing import Any, List, Optional, Sequence, Tuple, Union
+
+ import numpy as np
+ import torch
+
+ from smftools.logging_utils import get_logger
+
+ # FIX: import _to_dense_np to avoid NameError
+ from ..hmm.HMM import _safe_int_coords, _to_dense_np, create_hmm, normalize_hmm_feature_sets
+
+ logger = get_logger(__name__)
+
+ # =============================================================================
+ # Helpers: extracting training arrays
+ # =============================================================================
+
+
+ def _get_training_matrix(
+     subset, cols_mask: np.ndarray, smf_modality: Optional[str], cfg
+ ) -> Tuple[np.ndarray, Optional[str]]:
+     """
+     Matches your existing behavior:
+       - direct -> uses cfg.output_binary_layer_name in .layers
+       - else -> uses .X
+     Returns (X, layer_name_or_None) where X is dense float array.
      """
-     High-level function to call for hmm analysis of an adata object.
-     Command line accesses this through smftools hmm <config_path>
+     sub = subset[:, cols_mask]

-     Parameters:
-         config_path (str): A string representing the file path to the experiment configuration csv file.
+     if smf_modality == "direct":
+         hmm_layer = getattr(cfg, "output_binary_layer_name", None)
+         if hmm_layer is None or hmm_layer not in sub.layers:
+             raise KeyError(f"Missing HMM training layer '{hmm_layer}' in subset.")

+         logger.debug("Using direct modality HMM training layer: %s", hmm_layer)
+         mat = sub.layers[hmm_layer]
+     else:
+         logger.debug("Using .X for HMM training matrix")
+         hmm_layer = None
+         mat = sub.X
+
+     X = _to_dense_np(mat).astype(float)
+     if X.ndim != 2:
+         raise ValueError(f"Expected 2D training matrix; got {X.shape}")
+     return X, hmm_layer
+
+
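A minimal standalone sketch of the contract `_get_training_matrix` enforces (a dense 2-D float array regardless of sparse storage). Here `to_dense_np` and the `binary_calls` layer name are illustrative stand-ins, not package API:

    import anndata as ad
    import numpy as np
    from scipy import sparse

    def to_dense_np(mat):
        # densify sparse matrices, pass ndarrays through unchanged
        return mat.toarray() if sparse.issparse(mat) else np.asarray(mat)

    adata = ad.AnnData(X=np.array([[1.0, 0.0, np.nan], [0.0, 1.0, 1.0]]))
    adata.layers["binary_calls"] = sparse.csr_matrix(np.array([[1, 0, 0], [0, 1, 1]]))

    cols_mask = np.array([True, False, True])  # e.g. a *_GpC_site mask over var
    sub = adata[:, cols_mask]

    X = to_dense_np(sub.layers["binary_calls"]).astype(float)
    assert X.shape == (2, 2)  # (n_reads, n_masked_sites), always dense 2-D float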
+ def _resolve_pos_mask_for_methbase(subset, ref: str, methbase: str) -> Optional[np.ndarray]:
+     """
+     Reproduces your mask resolution, with compatibility for both *_any_C_site and *_C_site.
+     """
+     key = str(methbase).strip().lower()
+
+     logger.debug("Resolving position mask for methbase=%s on ref=%s", key, ref)
+
+     if key in ("a",):
+         col = f"{ref}_A_site"
+         if col not in subset.var:
+             return None
+         logger.debug("Using positions with A calls from column: %s", col)
+         return np.asarray(subset.var[col])
+
+     if key in ("c", "any_c", "anyc", "any-c"):
+         for col in (f"{ref}_any_C_site", f"{ref}_C_site"):
+             if col in subset.var:
+                 logger.debug("Using positions with C calls from column: %s", col)
+                 return np.asarray(subset.var[col])
+         return None
+
+     if key in ("gpc", "gpc_site", "gpc-site"):
+         col = f"{ref}_GpC_site"
+         if col not in subset.var:
+             return None
+         logger.debug("Using positions with GpC calls from column: %s", col)
+         return np.asarray(subset.var[col])
+
+     if key in ("cpg", "cpg_site", "cpg-site"):
+         col = f"{ref}_CpG_site"
+         if col not in subset.var:
+             return None
+         logger.debug("Using positions with CpG calls from column: %s", col)
+         return np.asarray(subset.var[col])
+
+     alt = f"{ref}_{methbase}_site"
+     if alt not in subset.var:
+         return None
+
+     logger.debug("Using positions from column: %s", alt)
+     return np.asarray(subset.var[alt])
+
+
+ def build_single_channel(
+     subset, ref: str, methbase: str, smf_modality: Optional[str], cfg
+ ) -> Tuple[np.ndarray, np.ndarray]:
+     """
      Returns:
-         (pp_dedup_spatial_hmm_adata, pp_dedup_spatial_hmm_adata_path)
+         X (N, Lmb) float with NaNs allowed
+         coords (Lmb,) int coords from var_names
+     """
+     pm = _resolve_pos_mask_for_methbase(subset, ref, methbase)
+     logger.debug(
+         "Position mask for methbase=%s on ref=%s has %d sites",
+         methbase,
+         ref,
+         int(np.sum(pm)) if pm is not None else 0,
+     )
+
+     if pm is None or int(np.sum(pm)) == 0:
+         raise ValueError(f"No columns for methbase={methbase} on ref={ref}")
+
+     X, _ = _get_training_matrix(subset, pm, smf_modality, cfg)
+     logger.debug("Training matrix for methbase=%s on ref=%s has shape %s", methbase, ref, X.shape)
+
+     coords, _ = _safe_int_coords(subset[:, pm].var_names)
+     logger.debug(
+         "Coordinates for methbase=%s on ref=%s have length %d", methbase, ref, coords.shape[0]
+     )
+
+     return X, coords
+
+
+ def build_multi_channel_union(
+     subset, ref: str, methbases: Sequence[str], smf_modality: Optional[str], cfg
+ ) -> Tuple[np.ndarray, np.ndarray, List[str]]:
+     """
+     Build (N, Lunion, C) on union coordinate grid across methbases.
+
+     Returns:
+         X3d: (N, Lunion, C) float with NaN where methbase has no site
+         coords: (Lunion,) int union coords
+         used_methbases: list of methbases actually included (>=2)
+     """
+     per: List[Tuple[str, np.ndarray, np.ndarray, np.ndarray]] = []  # (mb, X, coords, pm)
+
+     for mb in methbases:
+         pm = _resolve_pos_mask_for_methbase(subset, ref, mb)
+         if pm is None or int(np.sum(pm)) == 0:
+             continue
+         Xmb, _ = _get_training_matrix(subset, pm, smf_modality, cfg)  # (N,Lmb)
+         cmb, _ = _safe_int_coords(subset[:, pm].var_names)
+         per.append((mb, Xmb.astype(float), cmb.astype(int), pm))
+
+     if len(per) < 2:
+         raise ValueError(f"Need >=2 methbases with columns for union multi-channel on ref={ref}")
+
+     # union coordinates
+     coords = np.unique(np.concatenate([c for _, _, c, _ in per], axis=0)).astype(int)
+     idx = {int(v): i for i, v in enumerate(coords.tolist())}
+
+     N = per[0][1].shape[0]
+     L = coords.shape[0]
+     C = len(per)
+     X3 = np.full((N, L, C), np.nan, dtype=float)
+
+     for ci, (mb, Xmb, cmb, _) in enumerate(per):
+         cols = np.array([idx[int(v)] for v in cmb.tolist()], dtype=int)
+         X3[:, cols, ci] = Xmb
+
+     used = [mb for (mb, _, _, _) in per]
+     return X3, coords, used
+
+
167
+ @dataclass
168
+ class HMMTask:
169
+ name: str
170
+ signals: List[str] # e.g. ["GpC"] or ["GpC","CpG"] or ["CpG"]
171
+ feature_groups: List[str] # e.g. ["footprint","accessible"] or ["cpg"]
172
+ output_prefix: Optional[str] = None # force prefix (CpG task uses "CpG")
173
+
174
+
175
+ def build_hmm_tasks(cfg: Union[dict, Any]) -> List[HMMTask]:
176
+ """
177
+ Accessibility signals come from cfg['hmm_methbases'].
178
+ CpG task is enabled by cfg['cpg']==True, independent of hmm_methbases.
179
+ """
180
+ if not isinstance(cfg, dict):
181
+ # best effort conversion
182
+ cfg = {k: getattr(cfg, k) for k in dir(cfg) if not k.startswith("_")}
183
+
184
+ tasks: List[HMMTask] = []
185
+
186
+ # accessibility task
187
+ methbases = list(cfg.get("hmm_methbases", []) or [])
188
+ if len(methbases) > 0:
189
+ tasks.append(
190
+ HMMTask(
191
+ name="accessibility",
192
+ signals=methbases,
193
+ feature_groups=["footprint", "accessible"],
194
+ output_prefix=None,
195
+ )
196
+ )
197
+
198
+ # CpG task (special case)
199
+ if bool(cfg.get("cpg", False)):
200
+ tasks.append(
201
+ HMMTask(
202
+ name="cpg",
203
+ signals=["CpG"],
204
+ feature_groups=["cpg"],
205
+ output_prefix="CpG",
206
+ )
207
+ )
208
+
209
+ return tasks
210
+
211
+
212
+ def select_hmm_arch(cfg: dict, signals: Sequence[str]) -> str:
213
+ """
214
+ Simple, explicit model selection:
215
+ - distance-aware => 'single_distance_binned' (only meaningful for single-channel)
216
+ - multi-signal => 'multi'
217
+ - else => 'single'
218
+ """
219
+ if bool(cfg.get("hmm_distance_aware", False)) and len(signals) == 1:
220
+ return "single_distance_binned"
221
+ if len(signals) > 1:
222
+ return "multi"
223
+ return "single"
224
+
225
+
226
+ def resolve_input_layer(adata, cfg: dict, layer_override: Optional[str]) -> Optional[str]:
227
+ """
228
+ If direct modality, prefer cfg.output_binary_layer_name.
229
+ Else use layer_override or None (meaning use .X).
230
+ """
231
+ smf_modality = cfg.get("smf_modality", None)
232
+ if smf_modality == "direct":
233
+ nm = cfg.get("output_binary_layer_name", None)
234
+ if nm is None:
235
+ raise KeyError("cfg.output_binary_layer_name missing for smf_modality='direct'")
236
+ if nm not in adata.layers:
237
+ raise KeyError(f"Direct modality expects layer '{nm}' in adata.layers")
238
+ return nm
239
+ return layer_override
240
+
241
+
242
+ def _ensure_layer_and_assign_rows(adata, layer_name: str, row_mask: np.ndarray, subset_layer):
11
243
  """
12
- from ..readwrite import safe_read_h5ad, safe_write_h5ad, make_dirs, add_or_update_column_in_csv
244
+ Writes subset_layer (n_subset_obs, n_vars) into adata.layers[layer_name] for rows where row_mask==True.
245
+ """
246
+ row_mask = np.asarray(row_mask, dtype=bool)
247
+ if row_mask.ndim != 1 or row_mask.size != adata.n_obs:
248
+ raise ValueError("row_mask must be length adata.n_obs")
249
+
250
+ arr = _to_dense_np(subset_layer)
251
+ if arr.shape != (int(row_mask.sum()), adata.n_vars):
252
+ raise ValueError(
253
+ f"subset layer '{layer_name}' shape {arr.shape} != ({int(row_mask.sum())}, {adata.n_vars})"
254
+ )
255
+
256
+ if layer_name not in adata.layers:
257
+ adata.layers[layer_name] = np.zeros((adata.n_obs, adata.n_vars), dtype=arr.dtype)
258
+
259
+ adata.layers[layer_name][row_mask, :] = arr
260
+
261
+
262
+ def resolve_torch_device(device_str: str | None) -> torch.device:
263
+ d = (device_str or "auto").lower()
264
+ if d == "auto":
265
+ if torch.cuda.is_available():
266
+ return torch.device("cuda")
267
+ if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
268
+ return torch.device("mps")
269
+ return torch.device("cpu")
270
+ return torch.device(d)
271
+
272
+
273
+ # =============================================================================
274
+ # Model selection + fit strategy manager
275
+ # =============================================================================
276
+ @dataclass
277
+ class HMMTrainer:
278
+ cfg: Any
279
+ models_dir: Path
280
+
281
+ def __post_init__(self):
282
+ self.models_dir = Path(self.models_dir)
283
+ self.models_dir.mkdir(parents=True, exist_ok=True)
284
+
285
+ def choose_arch(self, *, multichannel: bool) -> str:
286
+ use_dist = bool(getattr(self.cfg, "hmm_distance_aware", False))
287
+ if multichannel:
288
+ return "multi"
289
+ return "single_distance_binned" if use_dist else "single"
290
+
291
+ def _fit_scope(self) -> str:
292
+ return str(getattr(self.cfg, "hmm_fit_scope", "per_sample")).lower()
293
+ # "per_sample" | "global" | "global_then_adapt"
294
+
295
+ def _path(self, kind: str, sample: str, ref: str, label: str) -> Path:
296
+ # kind: "GLOBAL" | "PER" | "ADAPT"
297
+ def safe(s):
298
+ str(s).replace("/", "_")
299
+
300
+ return self.models_dir / f"{kind}_{safe(sample)}_{safe(ref)}_{safe(label)}.pt"
301
+
302
+ def _save(self, model, path: Path):
303
+ override = {}
304
+ if getattr(model, "hmm_name", None) == "multi":
305
+ override["hmm_n_channels"] = int(getattr(model, "n_channels", 2))
306
+ if getattr(model, "hmm_name", None) == "single_distance_binned":
307
+ override["hmm_distance_bins"] = list(
308
+ getattr(model, "distance_bins", [1, 5, 10, 25, 50, 100])
309
+ )
310
+
311
+ payload = {
312
+ "state_dict": model.state_dict(),
313
+ "hmm_arch": getattr(model, "hmm_name", None) or getattr(self.cfg, "hmm_arch", None),
314
+ "override": override,
315
+ }
316
+ torch.save(payload, path)
317
+
318
+ def _load(self, path: Path, arch: str, device):
319
+ payload = torch.load(path, map_location="cpu")
320
+ override = payload.get("override", None)
321
+ m = create_hmm(self.cfg, arch=arch, override=override, device=device)
322
+ sd = payload["state_dict"]
323
+
324
+ target_dtype = next(m.parameters()).dtype
325
+ for k, v in list(sd.items()):
326
+ if isinstance(v, torch.Tensor) and v.dtype != target_dtype:
327
+ sd[k] = v.to(dtype=target_dtype)
328
+
329
+ m.load_state_dict(sd)
330
+ m.to(device)
331
+ m.eval()
332
+ return m
333
+
334
+ def fit_or_load(
335
+ self,
336
+ *,
337
+ sample: str,
338
+ ref: str,
339
+ label: str,
340
+ arch: str,
341
+ X,
342
+ coords: Optional[np.ndarray],
343
+ device,
344
+ ):
345
+ force_fit = bool(getattr(self.cfg, "force_redo_hmm_fit", False))
346
+ scope = self._fit_scope()
347
+
348
+ max_iter = int(getattr(self.cfg, "hmm_max_iter", 50))
349
+ tol = float(getattr(self.cfg, "hmm_tol", 1e-4))
350
+ verbose = bool(getattr(self.cfg, "hmm_verbose", False))
351
+
352
+ # ---- global then adapt ----
353
+ if scope == "global_then_adapt":
354
+ p_global = self._path("GLOBAL", "ALL", ref, label)
355
+ if p_global.exists() and not force_fit:
356
+ base = self._load(p_global, arch=arch, device=device)
357
+ else:
358
+ base = create_hmm(self.cfg, arch=arch).to(device)
359
+ if arch == "single_distance_binned":
360
+ base.fit(
361
+ X, device=device, coords=coords, max_iter=max_iter, tol=tol, verbose=verbose
362
+ )
363
+ else:
364
+ base.fit(X, device=device, max_iter=max_iter, tol=tol, verbose=verbose)
365
+ self._save(base, p_global)
366
+
367
+ p_adapt = self._path("ADAPT", sample, ref, label)
368
+ if p_adapt.exists() and not force_fit:
369
+ return self._load(p_adapt, arch=arch, device=device)
370
+
371
+ # IMPORTANT: this assumes you added model.adapt_emissions(...)
372
+ adapted = copy.deepcopy(base).to(device)
373
+ if arch == "single_distance_binned":
374
+ adapted.adapt_emissions(
375
+ X,
376
+ coords,
377
+ device=device,
378
+ max_iter=int(getattr(self.cfg, "hmm_adapt_iters", 10)),
379
+ verbose=verbose,
380
+ )
381
+
382
+ else:
383
+ adapted.adapt_emissions(
384
+ X,
385
+ coords,
386
+ device=device,
387
+ max_iter=int(getattr(self.cfg, "hmm_adapt_iters", 10)),
388
+ verbose=verbose,
389
+ )
390
+
391
+ self._save(adapted, p_adapt)
392
+ return adapted
393
+
394
+ # ---- global only ----
395
+ if scope == "global":
396
+ p = self._path("GLOBAL", "ALL", ref, label)
397
+ if p.exists() and not force_fit:
398
+ return self._load(p, arch=arch, device=device)
399
+
400
+ # ---- per sample ----
401
+ else:
402
+ p = self._path("PER", sample, ref, label)
403
+ if p.exists() and not force_fit:
404
+ return self._load(p, arch=arch, device=device)
405
+
406
+ m = create_hmm(self.cfg, arch=arch, device=device)
407
+ if arch == "single_distance_binned":
408
+ m.fit(X, coords, device=device, max_iter=max_iter, tol=tol, verbose=verbose)
409
+ else:
410
+ m.fit(X, coords, device=device, max_iter=max_iter, tol=tol, verbose=verbose)
411
+ self._save(m, p)
412
+ return m
413
+
414
+
415
+ def _fully_qualified_merge_layers(cfg, prefix: str) -> List[Tuple[str, int]]:
416
+ """
417
+ cfg.hmm_merge_layer_features is assumed to be a list of (core_layer_name, merge_distance),
418
+ where core_layer_name is like "all_accessible_features" (NOT prefixed with methbase).
419
+ We expand to f"{prefix}_{core_layer_name}".
420
+ """
421
+ out = []
422
+ for core_layer, dist in getattr(cfg, "hmm_merge_layer_features", []) or []:
423
+ if not core_layer:
424
+ continue
425
+ out.append((f"{prefix}_{core_layer}", int(dist)))
426
+ return out
427
+
428
+
429
+ def hmm_adata(config_path: str):
430
+ """
431
+ CLI-facing wrapper for HMM analysis.
432
+
433
+ Command line entrypoint:
434
+ smftools hmm <config_path>
435
+
436
+ Responsibilities:
437
+ - Build cfg via load_adata()
438
+ - Ensure preprocess + spatial stages are run
439
+ - Decide which AnnData to start from (hmm > spatial > pp_dedup > pp > raw)
440
+ - Call hmm_adata_core(cfg, adata, paths)
441
+ """
442
+ from ..readwrite import safe_read_h5ad
443
+ from .helpers import get_adata_paths
13
444
  from .load_adata import load_adata
14
445
  from .preprocess_adata import preprocess_adata
15
446
  from .spatial_adata import spatial_adata
16
447
 
17
- import numpy as np
18
- import pandas as pd
19
- import anndata as ad
20
- import scanpy as sc
448
+ # 1) load cfg / stage paths
449
+ _, _, cfg = load_adata(config_path)
450
+ paths = get_adata_paths(cfg)
21
451
 
22
- import os
23
- from importlib import resources
24
- from pathlib import Path
452
+ # 2) make sure upstream stages are run (they have their own skipping logic)
453
+ preprocess_adata(config_path)
454
+ spatial_ad, spatial_path = spatial_adata(config_path)
25
455
 
26
- from datetime import datetime
27
- date_str = datetime.today().strftime("%y%m%d")
456
+ # 3) choose starting AnnData
457
+ # Prefer:
458
+ # - existing HMM h5ad if not forcing redo
459
+ # - in-memory spatial_ad from wrapper call
460
+ # - saved spatial / pp_dedup / pp / raw on disk
461
+ if paths.hmm.exists() and not (cfg.force_redo_hmm_fit or cfg.force_redo_hmm_apply):
462
+ adata, _ = safe_read_h5ad(paths.hmm)
463
+ return adata, paths.hmm
28
464
 
29
- ############################################### smftools load start ###############################################
30
- adata, adata_path, cfg = load_adata(config_path)
31
- # General config variable init - Necessary user passed inputs
32
- smf_modality = cfg.smf_modality # needed for specifying if the data is conversion SMF or direct methylation detection SMF. Or deaminase smf Necessary.
33
- output_directory = Path(cfg.output_directory) # Path to the output directory to make for the analysis. Necessary.
34
-
35
- # Make initial output directory
36
- make_dirs([output_directory])
37
- ############################################### smftools load end ###############################################
38
-
39
- ############################################### smftools preprocess start ###############################################
40
- pp_adata, pp_adata_path, pp_dedup_adata, pp_dup_rem_adata_path = preprocess_adata(config_path)
41
- ############################################### smftools preprocess end ###############################################
42
-
43
- ############################################### smftools spatial start ###############################################
44
- spatial_ad, spatial_adata_path = spatial_adata(config_path)
45
- ############################################### smftools spatial end ###############################################
46
-
47
- ############################################### smftools hmm start ###############################################
48
- input_manager_df = pd.read_csv(cfg.summary_file)
49
- initial_adata_path = Path(input_manager_df['load_adata'][0])
50
- pp_adata_path = Path(input_manager_df['pp_adata'][0])
51
- pp_dup_rem_adata_path = Path(input_manager_df['pp_dedup_adata'][0])
52
- spatial_adata_path = Path(input_manager_df['spatial_adata'][0])
53
- hmm_adata_path = Path(input_manager_df['hmm_adata'][0])
54
-
55
- if spatial_ad:
56
- # This happens on first run of the pipeline
465
+ if spatial_ad is not None:
57
466
  adata = spatial_ad
467
+ source_path = spatial_path
468
+ elif paths.spatial.exists():
469
+ adata, _ = safe_read_h5ad(paths.spatial)
470
+ source_path = paths.spatial
471
+ elif paths.pp_dedup.exists():
472
+ adata, _ = safe_read_h5ad(paths.pp_dedup)
473
+ source_path = paths.pp_dedup
474
+ elif paths.pp.exists():
475
+ adata, _ = safe_read_h5ad(paths.pp)
476
+ source_path = paths.pp
477
+ elif paths.raw.exists():
478
+ adata, _ = safe_read_h5ad(paths.raw)
479
+ source_path = paths.raw
58
480
  else:
59
- # If an anndata is saved, check which stages of the anndata are available
60
- initial_version_available = initial_adata_path.exists()
61
- preprocessed_version_available = pp_adata_path.exists()
62
- preprocessed_dup_removed_version_available = pp_dup_rem_adata_path.exists()
63
- preprocessed_dedup_spatial_version_available = spatial_adata_path.exists()
64
- preprocessed_dedup_spatial_hmm_version_available = hmm_adata_path.exists()
65
-
66
- if cfg.force_redo_hmm_fit:
67
- print(f"Forcing redo of basic analysis workflow, starting from the preprocessed adata if available. Otherwise, will use the raw adata.")
68
- if preprocessed_dedup_spatial_version_available:
69
- adata, load_report = safe_read_h5ad(spatial_adata_path)
70
- elif preprocessed_dup_removed_version_available:
71
- adata, load_report = safe_read_h5ad(pp_dup_rem_adata_path)
72
- elif initial_version_available:
73
- adata, load_report = safe_read_h5ad(initial_adata_path)
74
- else:
75
- print(f"Can not redo duplicate detection when there is no compatible adata available: either raw or preprocessed are required")
76
- elif preprocessed_dedup_spatial_hmm_version_available:
77
- return (None, hmm_adata_path)
78
- else:
79
- if preprocessed_dedup_spatial_version_available:
80
- adata, load_report = safe_read_h5ad(spatial_adata_path)
81
- elif preprocessed_dup_removed_version_available:
82
- adata, load_report = safe_read_h5ad(pp_dup_rem_adata_path)
83
- elif initial_version_available:
84
- adata, load_report = safe_read_h5ad(initial_adata_path)
85
- else:
86
- print(f"No adata available.")
87
- return
88
- references = adata.obs[cfg.reference_column].cat.categories
89
- deaminase = smf_modality == 'deaminase'
90
- ############################################### HMM based feature annotations ###############################################
481
+ raise FileNotFoundError(
482
+ "No AnnData available for HMM: expected at least raw or preprocessed h5ad."
483
+ )
484
+
485
+ # 4) delegate to core
486
+ adata, hmm_adata_path = hmm_adata_core(
487
+ cfg,
488
+ adata,
489
+ paths,
490
+ source_adata_path=source_path,
491
+ config_path=config_path,
492
+ )
493
+ return adata, hmm_adata_path
494
+
495
+
496
+ def hmm_adata_core(
497
+ cfg,
498
+ adata,
499
+ paths,
500
+ source_adata_path: Path | None = None,
501
+ config_path: str | None = None,
502
+ ) -> Tuple["anndata.AnnData", Path]:
503
+ """
504
+ Core HMM analysis pipeline.
505
+
506
+ Assumes:
507
+ - cfg is an ExperimentConfig
508
+ - adata is the starting AnnData (typically spatial + dedup)
509
+ - paths is an AdataPaths object (with .raw/.pp/.pp_dedup/.spatial/.hmm)
510
+
511
+ Does NOT decide which h5ad to start from – that is the wrapper's job.
512
+ """
513
+
514
+ import numpy as np
515
+
516
+ from ..hmm import call_hmm_peaks
517
+ from ..metadata import record_smftools_metadata
518
+ from ..plotting import (
519
+ combined_hmm_raw_clustermap,
520
+ plot_hmm_layers_rolling_by_sample_ref,
521
+ plot_hmm_size_contours,
522
+ )
523
+ from ..readwrite import make_dirs
524
+ from .helpers import write_gz_h5ad
525
+
526
+ smf_modality = cfg.smf_modality
527
+ deaminase = smf_modality == "deaminase"
528
+
529
+ output_directory = Path(cfg.output_directory)
530
+ make_dirs([output_directory])
531
+
532
+ pp_dir = output_directory / "preprocessed" / "deduplicated"
533
+
534
+ # ---------------------------- HMM annotate stage ----------------------------
91
535
  if not (cfg.bypass_hmm_fit and cfg.bypass_hmm_apply):
92
- from ..hmm.HMM import HMM
93
- from scipy.sparse import issparse, csr_matrix
94
- import warnings
536
+ hmm_models_dir = pp_dir / "10_hmm_models"
537
+ make_dirs([pp_dir, hmm_models_dir])
95
538
 
96
- pp_dir = output_directory / "preprocessed"
97
- pp_dir = pp_dir / "deduplicated"
98
- hmm_dir = pp_dir / "10_hmm_models"
539
+ # Standard bookkeeping
540
+ uns_key = "hmm_appended_layers"
541
+ if adata.uns.get(uns_key) is None:
542
+ adata.uns[uns_key] = []
543
+ global_appended = list(adata.uns.get(uns_key, []))
99
544
 
100
- if hmm_dir.is_dir():
101
- print(f'{hmm_dir} already exists.')
102
- else:
103
- make_dirs([pp_dir, hmm_dir])
545
+ # Prepare trainer + feature config
546
+ trainer = HMMTrainer(cfg=cfg, models_dir=hmm_models_dir)
547
+
548
+ feature_sets = normalize_hmm_feature_sets(getattr(cfg, "hmm_feature_sets", None))
549
+ prob_thr = float(getattr(cfg, "hmm_feature_prob_threshold", 0.5))
550
+ decode = str(getattr(cfg, "hmm_decode", "marginal"))
551
+ write_post = bool(getattr(cfg, "hmm_write_posterior", True))
552
+ post_state = getattr(cfg, "hmm_posterior_state", "Modified")
553
+ merged_suffix = str(getattr(cfg, "hmm_merged_suffix", "_merged"))
554
+ force_apply = bool(getattr(cfg, "force_redo_hmm_apply", False))
555
+ bypass_apply = bool(getattr(cfg, "bypass_hmm_apply", False))
556
+ bypass_fit = bool(getattr(cfg, "bypass_hmm_fit", False))
104
557
 
105
558
  samples = adata.obs[cfg.sample_name_col_for_plotting].cat.categories
106
559
  references = adata.obs[cfg.reference_column].cat.categories
107
- uns_key = "hmm_appended_layers"
560
+ methbases = list(getattr(cfg, "hmm_methbases", [])) or []
108
561
 
109
- # ensure uns key exists (avoid KeyError later)
110
- if adata.uns.get(uns_key) is None:
111
- adata.uns[uns_key] = []
562
+ if not methbases:
563
+ raise ValueError("cfg.hmm_methbases is empty.")
112
564
 
113
- for sample in samples:
114
- for ref in references:
115
- mask = (adata.obs[cfg.sample_name_col_for_plotting] == sample) & (adata.obs[cfg.reference_column] == ref)
116
- subset = adata[mask].copy()
117
- if subset.shape[0] < 1:
118
- continue
119
-
120
- for mod_site in cfg.hmm_methbases:
121
- mod_label = {'C': 'C'}.get(mod_site, mod_site)
122
- hmm_path = hmm_dir / f"{sample}_{ref}_{mod_label}_hmm_model.pth"
123
-
124
- # ensure the input obsm exists
125
- obsm_key = f'{ref}_{mod_label}_site'
126
- if obsm_key not in subset.obsm:
127
- print(f"Skipping {sample} {ref} {mod_label}: missing obsm '{obsm_key}'")
565
+ # Top-level skip
566
+ already = bool(adata.uns.get("hmm_annotated", False))
567
+ if already and not (bool(getattr(cfg, "force_redo_hmm_fit", False)) or force_apply):
568
+ pass
569
+
570
+ else:
571
+ logger.info("Starting HMM annotation over samples and references")
572
+ for sample in samples:
573
+ for ref in references:
574
+ mask = (adata.obs[cfg.sample_name_col_for_plotting] == sample) & (
575
+ adata.obs[cfg.reference_column] == ref
576
+ )
577
+ if int(np.sum(mask)) == 0:
128
578
  continue
129
579
 
130
- # Fit or load model
131
- if os.path.exists(hmm_path) and not cfg.force_redo_hmm_fit:
132
- hmm = HMM.load(hmm_path)
133
- hmm.print_params()
134
- else:
135
- print(f"Fitting HMM for {sample} {ref} {mod_label}")
136
- hmm = HMM.from_config(cfg)
137
- # fit expects a list-of-seqs or 2D ndarray in the obsm
138
- seqs = subset.obsm[obsm_key]
139
- hmm.fit(seqs)
140
- hmm.print_params()
141
- hmm.save(hmm_path)
142
-
143
- # Apply / annotate on the subset, then copy layers back to final_adata
144
- if (not cfg.bypass_hmm_apply) or cfg.force_redo_hmm_apply:
145
- print(f"Applying HMM on subset for {sample} {ref} {mod_label}")
146
- # Use the new uns_key argument so subset will record appended layer names
147
- # (annotate_adata modifies subset.obs/layers in-place and should write subset.uns[uns_key])
148
- hmm.annotate_adata(subset,
149
- obs_column=cfg.reference_column,
150
- layer=cfg.layer_for_umap_plotting,
151
- config=cfg)
152
-
153
- #to_merge = [("C_all_accessible_features", 80)]
154
- to_merge = cfg.hmm_merge_layer_features
155
- for layer_to_merge, merge_distance in to_merge:
156
- if layer_to_merge:
157
- hmm.merge_intervals_in_layer(subset,
158
- layer=layer_to_merge,
159
- distance_threshold=merge_distance,
160
- overwrite=True
161
- )
162
- else:
163
- pass
164
-
165
- # collect appended layers from subset.uns
166
- appended = list(subset.uns.get(uns_key, []))
167
- print(appended)
168
- if len(appended) == 0:
169
- # nothing appended for this subset; continue
580
+ subset = adata[mask].copy()
581
+ subset.uns[uns_key] = [] # isolate appended tracking per subset
582
+
583
+ # ---- Decide which tasks to run ----
584
+ methbases = list(getattr(cfg, "hmm_methbases", [])) or []
585
+ run_multi = bool(getattr(cfg, "hmm_run_multichannel", True))
586
+ run_cpg = bool(getattr(cfg, "cpg", False))
587
+ device = resolve_torch_device(cfg.device)
588
+
589
+ logger.info("HMM processing sample=%s ref=%s", sample, ref)
590
+
591
+ # ---- split feature sets ----
592
+ feature_sets_all = normalize_hmm_feature_sets(
593
+ getattr(cfg, "hmm_feature_sets", None)
594
+ )
595
+ feature_sets_access = {
596
+ k: v
597
+ for k, v in feature_sets_all.items()
598
+ if k in ("footprint", "accessible")
599
+ }
600
+ feature_sets_cpg = (
601
+ {"cpg": feature_sets_all["cpg"]} if "cpg" in feature_sets_all else {}
602
+ )
603
+
604
+ # =========================
605
+ # 1) Single-channel accessibility (per methbase)
606
+ # =========================
607
+ for mb in methbases:
608
+ logger.info("HMM single-channel for methbase=%s", mb)
609
+
610
+ try:
611
+ X, coords = build_single_channel(
612
+ subset,
613
+ ref=str(ref),
614
+ methbase=str(mb),
615
+ smf_modality=smf_modality,
616
+ cfg=cfg,
617
+ )
618
+ except Exception:
619
+ logger.warning(
620
+ "Skipping HMM single-channel for methbase=%s due to data error", mb
621
+ )
170
622
  continue
171
623
 
172
- # copy each appended layer into adata
173
- subset_mask_bool = mask.values if hasattr(mask, "values") else np.asarray(mask)
174
- for layer_name in appended:
175
- if layer_name not in subset.layers:
176
- # defensive: skip
177
- warnings.warn(f"Expected layer {layer_name} in subset but not found; skipping copy.")
178
- continue
179
- sub_layer = subset.layers[layer_name]
180
- # ensure final layer exists and assign rows
181
- try:
182
- hmm._ensure_final_layer_and_assign(adata, layer_name, subset_mask_bool, sub_layer)
183
- except Exception as e:
184
- warnings.warn(f"Failed to copy layer {layer_name} into adata: {e}", stacklevel=2)
185
- # fallback: if dense and small, try to coerce
186
- if issparse(sub_layer):
187
- arr = sub_layer.toarray()
188
- else:
189
- arr = np.asarray(sub_layer)
190
- adata.layers[layer_name] = adata.layers.get(layer_name, np.zeros((adata.shape[0], arr.shape[1]), dtype=arr.dtype))
191
- final_idx = np.nonzero(subset_mask_bool)[0]
192
- adata.layers[layer_name][final_idx, :] = arr
193
-
194
- # merge appended layer names into adata.uns
195
- existing = list(adata.uns.get(uns_key, []))
196
- for ln in appended:
197
- if ln not in existing:
198
- existing.append(ln)
199
- adata.uns[uns_key] = existing
624
+ arch = trainer.choose_arch(multichannel=False)
625
+
626
+ logger.info("HMM fitting/loading for methbase=%s", mb)
627
+ hmm = trainer.fit_or_load(
628
+ sample=str(sample),
629
+ ref=str(ref),
630
+ label=str(mb),
631
+ arch=arch,
632
+ X=X,
633
+ coords=coords,
634
+ device=device,
635
+ )
636
+
637
+ if not bypass_apply:
638
+ logger.info("HMM applying for methbase=%s", mb)
639
+ pm = _resolve_pos_mask_for_methbase(subset, str(ref), str(mb))
640
+ hmm.annotate_adata(
641
+ subset,
642
+ prefix=str(mb),
643
+ X=X,
644
+ coords=coords,
645
+ var_mask=pm,
646
+ span_fill=True,
647
+ config=cfg,
648
+ decode=decode,
649
+ write_posterior=write_post,
650
+ posterior_state=post_state,
651
+ feature_sets=feature_sets_access, # <--- ONLY accessibility feature sets
652
+ prob_threshold=prob_thr,
653
+ uns_key=uns_key,
654
+ uns_flag=f"hmm_annotated_{mb}",
655
+ force_redo=force_apply,
656
+ )
657
+
658
+ # merges for this mb
659
+ for core_layer, dist in (
660
+ getattr(cfg, "hmm_merge_layer_features", []) or []
661
+ ):
662
+ base_layer = f"{mb}_{core_layer}"
663
+ logger.info("Merging intervals for layer=%s", base_layer)
664
+ if base_layer in subset.layers:
665
+ merged_base = hmm.merge_intervals_to_new_layer(
666
+ subset,
667
+ base_layer,
668
+ distance_threshold=int(dist),
669
+ suffix=merged_suffix,
670
+ overwrite=True,
671
+ )
672
+ # write merged size classes based on whichever group core_layer corresponds to
673
+ for group, fs in feature_sets_access.items():
674
+ fmap = fs.get("features", {}) or {}
675
+ if fmap:
676
+ hmm.write_size_class_layers_from_binary(
677
+ subset,
678
+ merged_base,
679
+ out_prefix=str(mb),
680
+ feature_ranges=fmap,
681
+ suffix=merged_suffix,
682
+ overwrite=True,
683
+ )
684
+
685
+ # =========================
686
+ # 2) Multi-channel accessibility (Combined)
687
+ # =========================
688
+ if run_multi and len(methbases) >= 2:
689
+ logger.info("HMM multi-channel for methbases=%s", ",".join(methbases))
690
+ try:
691
+ X3, coords_u, used_mbs = build_multi_channel_union(
692
+ subset,
693
+ ref=str(ref),
694
+ methbases=methbases,
695
+ smf_modality=smf_modality,
696
+ cfg=cfg,
697
+ )
698
+ except Exception:
699
+ X3, coords_u, used_mbs = None, None, []
700
+ logger.warning(
701
+ "Skipping HMM multi-channel due to data error or insufficient methbases"
702
+ )
703
+
704
+ if X3 is not None and len(used_mbs) >= 2:
705
+ union_mask = None
706
+ for mb in used_mbs:
707
+ pm = _resolve_pos_mask_for_methbase(subset, str(ref), str(mb))
708
+ union_mask = pm if union_mask is None else (union_mask | pm)
709
+
710
+ arch = trainer.choose_arch(multichannel=True)
711
+
712
+ logger.info("HMM fitting/loading for multi-channel")
713
+ hmmc = trainer.fit_or_load(
714
+ sample=str(sample),
715
+ ref=str(ref),
716
+ label="Combined",
717
+ arch=arch,
718
+ X=X3,
719
+ coords=coords_u,
720
+ device=device,
721
+ )
722
+
723
+ if not bypass_apply:
724
+ logger.info("HMM applying for multi-channel")
725
+ hmmc.annotate_adata(
726
+ subset,
727
+ prefix="Combined",
728
+ X=X3,
729
+ coords=coords_u,
730
+ var_mask=union_mask,
731
+ span_fill=True,
732
+ config=cfg,
733
+ decode=decode,
734
+ write_posterior=write_post,
735
+ posterior_state=post_state,
736
+ feature_sets=feature_sets_access, # <--- accessibility only
737
+ prob_threshold=prob_thr,
738
+ uns_key=uns_key,
739
+ uns_flag="hmm_annotated_combined",
740
+ force_redo=force_apply,
741
+ )
742
+
743
+ for core_layer, dist in (
744
+ getattr(cfg, "hmm_merge_layer_features", []) or []
745
+ ):
746
+ base_layer = f"Combined_{core_layer}"
747
+ if base_layer in subset.layers:
748
+ merged_base = hmmc.merge_intervals_to_new_layer(
749
+ subset,
750
+ base_layer,
751
+ distance_threshold=int(dist),
752
+ suffix=merged_suffix,
753
+ overwrite=True,
754
+ )
755
+ for group, fs in feature_sets_access.items():
756
+ fmap = fs.get("features", {}) or {}
757
+ if fmap:
758
+ hmmc.write_size_class_layers_from_binary(
759
+ subset,
760
+ merged_base,
761
+ out_prefix="Combined",
762
+ feature_ranges=fmap,
763
+ suffix=merged_suffix,
764
+ overwrite=True,
765
+ )
766
+
767
+ # =========================
768
+ # 3) CpG-only single-channel task
769
+ # =========================
770
+ if run_cpg:
771
+ logger.info("HMM single-channel for CpG")
772
+ try:
773
+ Xcpg, coordscpg = build_single_channel(
774
+ subset,
775
+ ref=str(ref),
776
+ methbase="CpG",
777
+ smf_modality=smf_modality,
778
+ cfg=cfg,
779
+ )
780
+ except Exception:
781
+ Xcpg, coordscpg = None, None
782
+ logger.warning("Skipping HMM single-channel for CpG due to data error")
783
+
784
+ if Xcpg is not None and Xcpg.size and feature_sets_cpg:
785
+ arch = trainer.choose_arch(multichannel=False)
786
+
787
+ logger.info("HMM fitting/loading for CpG")
788
+ hmmg = trainer.fit_or_load(
789
+ sample=str(sample),
790
+ ref=str(ref),
791
+ label="CpG",
792
+ arch=arch,
793
+ X=Xcpg,
794
+ coords=coordscpg,
795
+ device=device,
796
+ )
797
+
798
+ if not bypass_apply:
799
+ logger.info("HMM applying for CpG")
800
+ pm = _resolve_pos_mask_for_methbase(subset, str(ref), "CpG")
801
+ hmmg.annotate_adata(
802
+ subset,
803
+ prefix="CpG",
804
+ X=Xcpg,
805
+ coords=coordscpg,
806
+ var_mask=pm,
807
+ span_fill=True,
808
+ config=cfg,
809
+ decode=decode,
810
+ write_posterior=write_post,
811
+ posterior_state=post_state,
812
+ feature_sets=feature_sets_cpg, # <--- ONLY cpg group (cpg_patch)
813
+ prob_threshold=prob_thr,
814
+ uns_key=uns_key,
815
+ uns_flag="hmm_annotated_CpG",
816
+ force_redo=force_apply,
817
+ )
818
+
819
+ # ------------------------------------------------------------
820
+ # Copy newly created subset layers back into the full adata
821
+ # ------------------------------------------------------------
822
+ appended = (
823
+ list(subset.uns.get(uns_key, []))
824
+ if subset.uns.get(uns_key) is not None
825
+ else []
826
+ )
827
+ if appended:
828
+ row_mask = np.asarray(
829
+ mask.values if hasattr(mask, "values") else mask, dtype=bool
830
+ )
200
831
 
201
- else:
202
- pass
203
-
204
- ## Save HMM annotated adata
205
- if not hmm_adata_path.exists():
206
- print('Saving hmm analyzed adata post preprocessing and duplicate removal')
207
- if ".gz" == hmm_adata_path.suffix:
208
- safe_write_h5ad(adata, hmm_adata_path, compression='gzip', backup=True)
209
- else:
210
- hmm_adata_path = hmm_adata_path.with_name(hmm_adata_path.name + '.gz')
211
- safe_write_h5ad(adata, hmm_adata_path, compression='gzip', backup=True)
832
+ for ln in appended:
833
+ if ln not in subset.layers:
834
+ continue
835
+ _ensure_layer_and_assign_rows(adata, ln, row_mask, subset.layers[ln])
836
+ if ln not in global_appended:
837
+ global_appended.append(ln)
212
838
 
213
- add_or_update_column_in_csv(cfg.summary_file, "hmm_adata", hmm_adata_path)
839
+ adata.uns[uns_key] = global_appended
214
840
 
215
- ########################################################################################################################
841
+ adata.uns["hmm_annotated"] = True
216
842
 
217
- ############################################### HMM based feature plotting ###############################################
218
-
219
- hmm_dir = pp_dir / "11_hmm_clustermaps"
843
+ hmm_layers = list(adata.uns.get("hmm_appended_layers", []) or [])
844
+ # keep only real feature layers; drop lengths/states/posterior
845
+ hmm_layers = [
846
+ layer
847
+ for layer in hmm_layers
848
+ if not any(s in layer for s in ("_lengths", "_states", "_posterior"))
849
+ ]
850
+ logger.info(f"HMM appended layers: {hmm_layers}")
220
851
 
852
+ # ---------------------------- HMM peak calling stage ----------------------------
853
+ hmm_dir = pp_dir / "11_hmm_peak_calling"
221
854
  if hmm_dir.is_dir():
222
- print(f'{hmm_dir} already exists.')
855
+ pass
223
856
  else:
224
857
  make_dirs([pp_dir, hmm_dir])
225
- from ..plotting import combined_hmm_raw_clustermap
226
- feature_layers = [
227
- "all_accessible_features",
228
- "large_accessible_patch",
229
- "small_bound_stretch",
230
- "medium_bound_stretch",
231
- "putative_nucleosome",
232
- "all_accessible_features_merged",
233
- ]
234
858
 
235
- layers: list[str] = []
859
+ call_hmm_peaks(
860
+ adata,
861
+ feature_configs=cfg.hmm_peak_feature_configs,
862
+ ref_column=cfg.reference_column,
863
+ site_types=cfg.mod_target_bases,
864
+ save_plot=True,
865
+ output_dir=hmm_dir,
866
+ index_col_suffix=cfg.reindexed_var_suffix,
867
+ )
236
868
 
237
- if any(base in ["C", "CpG", "GpC"] for base in cfg.mod_target_bases):
238
- if smf_modality == 'deaminase':
239
- layers.extend([f"C_{layer}" for layer in feature_layers])
240
- elif smf_modality == 'conversion':
241
- layers.extend([f"GpC_{layer}" for layer in feature_layers])
869
+ ## Save HMM annotated adata
870
+ if not paths.hmm.exists():
871
+ logger.info("Saving hmm analyzed AnnData (post preprocessing and duplicate removal).")
872
+ record_smftools_metadata(
873
+ adata,
874
+ step_name="hmm",
875
+ cfg=cfg,
876
+ config_path=config_path,
877
+ input_paths=[source_adata_path] if source_adata_path else None,
878
+ output_path=paths.hmm,
879
+ )
880
+ write_gz_h5ad(adata, paths.hmm)
242
881
 
243
- if 'A' in cfg.mod_target_bases:
244
- layers.extend([f"A_{layer}" for layer in feature_layers])
882
+ ########################################################################################################################
245
883
 
246
- if not layers:
247
- raise ValueError(
248
- f"No HMM feature layers matched mod_target_bases={cfg.mod_target_bases} "
249
- f"and smf_modality={smf_modality}"
250
- )
251
-
252
- if smf_modality == 'direct':
253
- sort_by = "any_a"
254
- else:
255
- sort_by = 'gpc'
884
+ ############################################### HMM based feature plotting ###############################################
256
885
 
257
- for layer in layers:
258
- save_path = hmm_dir / layer
259
- make_dirs([save_path])
886
+ hmm_dir = pp_dir / "12_hmm_clustermaps"
887
+ make_dirs([pp_dir, hmm_dir])
888
+
889
+ layers: list[str] = []
890
+
891
+ for base in cfg.hmm_methbases:
892
+ layers.extend([f"{base}_{layer}" for layer in cfg.hmm_clustermap_feature_layers])
893
+
894
+ if getattr(cfg, "hmm_run_multichannel", True) and len(cfg.hmm_methbases) >= 2:
895
+ layers.extend([f"Combined_{layer}" for layer in cfg.hmm_clustermap_feature_layers])
896
+
897
+ if cfg.cpg:
898
+ layers.extend(["CpG_cpg_patch"])
899
+
900
+ if not layers:
901
+ raise ValueError(
902
+ f"No HMM feature layers matched mod_target_bases={cfg.mod_target_bases} "
903
+ f"and smf_modality={smf_modality}"
904
+ )
905
+
906
+ for layer in layers:
907
+ hmm_cluster_save_dir = hmm_dir / layer
908
+ if hmm_cluster_save_dir.is_dir():
909
+ pass
910
+ else:
911
+ make_dirs([hmm_cluster_save_dir])
260
912
 
261
913
  combined_hmm_raw_clustermap(
262
- adata,
263
- sample_col=cfg.sample_name_col_for_plotting,
264
- reference_col=cfg.reference_column,
265
- hmm_feature_layer=layer,
266
- layer_gpc="nan0_0minus1",
267
- layer_cpg="nan0_0minus1",
268
- layer_any_c="nan0_0minus1",
269
- layer_a= "nan0_0minus1",
270
- cmap_hmm="coolwarm",
271
- cmap_gpc="coolwarm",
272
- cmap_cpg="viridis",
273
- cmap_any_c='coolwarm',
274
- cmap_a= "coolwarm",
275
- min_quality=cfg.read_quality_filter_thresholds[0],
276
- min_length=cfg.read_len_filter_thresholds[0],
277
- min_mapped_length_to_reference_length_ratio=cfg.read_len_to_ref_ratio_filter_thresholds[0],
278
- min_position_valid_fraction=1-cfg.position_max_nan_threshold,
279
- save_path=save_path,
280
- normalize_hmm=False,
281
- sort_by=sort_by, # options: 'gpc', 'cpg', 'gpc_cpg', 'none', or 'obs:<column>'
282
- bins=None,
283
- deaminase=deaminase,
284
- min_signal=0
914
+ adata,
915
+ sample_col=cfg.sample_name_col_for_plotting,
916
+ reference_col=cfg.reference_column,
917
+ hmm_feature_layer=layer,
918
+ layer_gpc=cfg.layer_for_clustermap_plotting,
919
+ layer_cpg=cfg.layer_for_clustermap_plotting,
920
+ layer_c=cfg.layer_for_clustermap_plotting,
921
+ layer_a=cfg.layer_for_clustermap_plotting,
922
+ cmap_hmm=cfg.clustermap_cmap_hmm,
923
+ cmap_gpc=cfg.clustermap_cmap_gpc,
924
+ cmap_cpg=cfg.clustermap_cmap_cpg,
925
+ cmap_c=cfg.clustermap_cmap_c,
926
+ cmap_a=cfg.clustermap_cmap_a,
927
+ min_quality=cfg.read_quality_filter_thresholds[0],
928
+ min_length=cfg.read_len_filter_thresholds[0],
929
+ min_mapped_length_to_reference_length_ratio=cfg.read_len_to_ref_ratio_filter_thresholds[
930
+ 0
931
+ ],
932
+ min_position_valid_fraction=1 - cfg.position_max_nan_threshold,
933
+ demux_types=("double", "already"),
934
+ save_path=hmm_cluster_save_dir,
935
+ normalize_hmm=False,
936
+ sort_by=cfg.hmm_clustermap_sortby, # options: 'gpc', 'cpg', 'gpc_cpg', 'none', or 'obs:<column>'
937
+ bins=None,
938
+ deaminase=deaminase,
939
+ min_signal=0,
940
+ index_col_suffix=cfg.reindexed_var_suffix,
285
941
  )
286
942
 
287
- hmm_dir = pp_dir / "12_hmm_bulk_traces"
943
+ hmm_dir = pp_dir / "13_hmm_bulk_traces"
288
944
 
289
945
  if hmm_dir.is_dir():
290
- print(f'{hmm_dir} already exists.')
946
+ logger.debug(f"{hmm_dir} already exists.")
291
947
  else:
292
948
  make_dirs([pp_dir, hmm_dir])
293
949
  from ..plotting import plot_hmm_layers_rolling_by_sample_ref
950
+
951
+ bulk_hmm_layers = [
952
+ layer
953
+ for layer in hmm_layers
954
+ if not any(s in layer for s in ("_lengths", "_states", "_posterior"))
955
+ ]
294
956
  saved = plot_hmm_layers_rolling_by_sample_ref(
295
957
  adata,
296
- layers=adata.uns['hmm_appended_layers'],
958
+ layers=bulk_hmm_layers,
297
959
  sample_col=cfg.sample_name_col_for_plotting,
298
960
  ref_col=cfg.reference_column,
299
961
  window=101,
300
962
  rows_per_page=4,
301
- figsize_per_cell=(4,2.5),
963
+ figsize_per_cell=(4, 2.5),
302
964
  output_dir=hmm_dir,
303
965
  save=True,
304
- show_raw=False
966
+ show_raw=False,
305
967
  )
306
968
 
307
- hmm_dir = pp_dir / "13_hmm_fragment_distributions"
969
+ hmm_dir = pp_dir / "14_hmm_fragment_distributions"
308
970
 
309
971
  if hmm_dir.is_dir():
310
- print(f'{hmm_dir} already exists.')
972
+ logger.debug(f"{hmm_dir} already exists.")
311
973
  else:
312
974
  make_dirs([pp_dir, hmm_dir])
313
975
  from ..plotting import plot_hmm_size_contours
314
976
 
315
- for layer, max in [('C_all_accessible_features_lengths', 400), ('C_all_footprint_features_lengths', 160), ('C_all_accessible_features_merged_lengths', 800)]:
977
+ if smf_modality == "deaminase":
978
+ fragments = [
979
+ ("C_all_accessible_features_lengths", 400),
980
+ ("C_all_footprint_features_lengths", 250),
981
+ ("C_all_accessible_features_merged_lengths", 800),
982
+ ]
983
+ elif smf_modality == "conversion":
984
+ fragments = [
985
+ ("GpC_all_accessible_features_lengths", 400),
986
+ ("GpC_all_footprint_features_lengths", 250),
987
+ ("GpC_all_accessible_features_merged_lengths", 800),
988
+ ]
989
+ elif smf_modality == "direct":
990
+ fragments = [
991
+ ("A_all_accessible_features_lengths", 400),
992
+ ("A_all_footprint_features_lengths", 200),
993
+ ("A_all_accessible_features_merged_lengths", 800),
994
+ ]
995
+
996
+ for layer, max in fragments:
316
997
  save_path = hmm_dir / layer
317
998
  make_dirs([save_path])
318
999
 
@@ -328,11 +1009,11 @@ def hmm_adata(config_path):
328
1009
  save_pdf=False,
329
1010
  save_each_page=True,
330
1011
  dpi=200,
331
- smoothing_sigma=None,
332
- normalize_after_smoothing=False,
333
- cmap='viridis',
334
- log_scale_z=True
1012
+ smoothing_sigma=(10, 10),
1013
+ normalize_after_smoothing=True,
1014
+ cmap="Greens",
1015
+ log_scale_z=True,
335
1016
  )
336
1017
  ########################################################################################################################
337
1018
 
338
- return (adata, hmm_adata_path)
1019
+ return (adata, paths.hmm)