smftools 0.2.4__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (181)
  1. smftools/__init__.py +43 -13
  2. smftools/_settings.py +6 -6
  3. smftools/_version.py +3 -1
  4. smftools/cli/__init__.py +1 -0
  5. smftools/cli/archived/cli_flows.py +2 -0
  6. smftools/cli/helpers.py +9 -1
  7. smftools/cli/hmm_adata.py +905 -242
  8. smftools/cli/load_adata.py +432 -280
  9. smftools/cli/preprocess_adata.py +287 -171
  10. smftools/cli/spatial_adata.py +141 -53
  11. smftools/cli_entry.py +119 -178
  12. smftools/config/__init__.py +3 -1
  13. smftools/config/conversion.yaml +5 -1
  14. smftools/config/deaminase.yaml +1 -1
  15. smftools/config/default.yaml +26 -18
  16. smftools/config/direct.yaml +8 -3
  17. smftools/config/discover_input_files.py +19 -5
  18. smftools/config/experiment_config.py +511 -276
  19. smftools/constants.py +37 -0
  20. smftools/datasets/__init__.py +4 -8
  21. smftools/datasets/datasets.py +32 -18
  22. smftools/hmm/HMM.py +2133 -1428
  23. smftools/hmm/__init__.py +24 -14
  24. smftools/hmm/archived/apply_hmm_batched.py +2 -0
  25. smftools/hmm/archived/calculate_distances.py +2 -0
  26. smftools/hmm/archived/call_hmm_peaks.py +18 -1
  27. smftools/hmm/archived/train_hmm.py +2 -0
  28. smftools/hmm/call_hmm_peaks.py +176 -193
  29. smftools/hmm/display_hmm.py +23 -7
  30. smftools/hmm/hmm_readwrite.py +20 -6
  31. smftools/hmm/nucleosome_hmm_refinement.py +104 -14
  32. smftools/informatics/__init__.py +55 -13
  33. smftools/informatics/archived/bam_conversion.py +2 -0
  34. smftools/informatics/archived/bam_direct.py +2 -0
  35. smftools/informatics/archived/basecall_pod5s.py +2 -0
  36. smftools/informatics/archived/basecalls_to_adata.py +2 -0
  37. smftools/informatics/archived/conversion_smf.py +2 -0
  38. smftools/informatics/archived/deaminase_smf.py +1 -0
  39. smftools/informatics/archived/direct_smf.py +2 -0
  40. smftools/informatics/archived/fast5_to_pod5.py +2 -0
  41. smftools/informatics/archived/helpers/archived/__init__.py +2 -0
  42. smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +16 -1
  43. smftools/informatics/archived/helpers/archived/aligned_BAM_to_bed.py +2 -0
  44. smftools/informatics/archived/helpers/archived/bam_qc.py +14 -1
  45. smftools/informatics/archived/helpers/archived/bed_to_bigwig.py +2 -0
  46. smftools/informatics/archived/helpers/archived/canoncall.py +2 -0
  47. smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +8 -1
  48. smftools/informatics/archived/helpers/archived/converted_BAM_to_adata.py +2 -0
  49. smftools/informatics/archived/helpers/archived/count_aligned_reads.py +2 -0
  50. smftools/informatics/archived/helpers/archived/demux_and_index_BAM.py +2 -0
  51. smftools/informatics/archived/helpers/archived/extract_base_identities.py +2 -0
  52. smftools/informatics/archived/helpers/archived/extract_mods.py +2 -0
  53. smftools/informatics/archived/helpers/archived/extract_read_features_from_bam.py +2 -0
  54. smftools/informatics/archived/helpers/archived/extract_read_lengths_from_bed.py +2 -0
  55. smftools/informatics/archived/helpers/archived/extract_readnames_from_BAM.py +2 -0
  56. smftools/informatics/archived/helpers/archived/find_conversion_sites.py +2 -0
  57. smftools/informatics/archived/helpers/archived/generate_converted_FASTA.py +2 -0
  58. smftools/informatics/archived/helpers/archived/get_chromosome_lengths.py +2 -0
  59. smftools/informatics/archived/helpers/archived/get_native_references.py +2 -0
  60. smftools/informatics/archived/helpers/archived/index_fasta.py +2 -0
  61. smftools/informatics/archived/helpers/archived/informatics.py +2 -0
  62. smftools/informatics/archived/helpers/archived/load_adata.py +5 -3
  63. smftools/informatics/archived/helpers/archived/make_modbed.py +2 -0
  64. smftools/informatics/archived/helpers/archived/modQC.py +2 -0
  65. smftools/informatics/archived/helpers/archived/modcall.py +2 -0
  66. smftools/informatics/archived/helpers/archived/ohe_batching.py +2 -0
  67. smftools/informatics/archived/helpers/archived/ohe_layers_decode.py +2 -0
  68. smftools/informatics/archived/helpers/archived/one_hot_decode.py +2 -0
  69. smftools/informatics/archived/helpers/archived/one_hot_encode.py +2 -0
  70. smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +5 -1
  71. smftools/informatics/archived/helpers/archived/separate_bam_by_bc.py +2 -0
  72. smftools/informatics/archived/helpers/archived/split_and_index_BAM.py +2 -0
  73. smftools/informatics/archived/print_bam_query_seq.py +9 -1
  74. smftools/informatics/archived/subsample_fasta_from_bed.py +2 -0
  75. smftools/informatics/archived/subsample_pod5.py +2 -0
  76. smftools/informatics/bam_functions.py +1059 -269
  77. smftools/informatics/basecalling.py +53 -9
  78. smftools/informatics/bed_functions.py +357 -114
  79. smftools/informatics/binarize_converted_base_identities.py +21 -7
  80. smftools/informatics/complement_base_list.py +9 -6
  81. smftools/informatics/converted_BAM_to_adata.py +324 -137
  82. smftools/informatics/fasta_functions.py +251 -89
  83. smftools/informatics/h5ad_functions.py +202 -30
  84. smftools/informatics/modkit_extract_to_adata.py +623 -274
  85. smftools/informatics/modkit_functions.py +87 -44
  86. smftools/informatics/ohe.py +46 -21
  87. smftools/informatics/pod5_functions.py +114 -74
  88. smftools/informatics/run_multiqc.py +20 -14
  89. smftools/logging_utils.py +51 -0
  90. smftools/machine_learning/__init__.py +23 -12
  91. smftools/machine_learning/data/__init__.py +2 -0
  92. smftools/machine_learning/data/anndata_data_module.py +157 -50
  93. smftools/machine_learning/data/preprocessing.py +4 -1
  94. smftools/machine_learning/evaluation/__init__.py +3 -1
  95. smftools/machine_learning/evaluation/eval_utils.py +13 -14
  96. smftools/machine_learning/evaluation/evaluators.py +52 -34
  97. smftools/machine_learning/inference/__init__.py +3 -1
  98. smftools/machine_learning/inference/inference_utils.py +9 -4
  99. smftools/machine_learning/inference/lightning_inference.py +14 -13
  100. smftools/machine_learning/inference/sklearn_inference.py +8 -8
  101. smftools/machine_learning/inference/sliding_window_inference.py +37 -25
  102. smftools/machine_learning/models/__init__.py +12 -5
  103. smftools/machine_learning/models/base.py +34 -43
  104. smftools/machine_learning/models/cnn.py +22 -13
  105. smftools/machine_learning/models/lightning_base.py +78 -42
  106. smftools/machine_learning/models/mlp.py +18 -5
  107. smftools/machine_learning/models/positional.py +10 -4
  108. smftools/machine_learning/models/rnn.py +8 -3
  109. smftools/machine_learning/models/sklearn_models.py +46 -24
  110. smftools/machine_learning/models/transformer.py +75 -55
  111. smftools/machine_learning/models/wrappers.py +8 -3
  112. smftools/machine_learning/training/__init__.py +4 -2
  113. smftools/machine_learning/training/train_lightning_model.py +42 -23
  114. smftools/machine_learning/training/train_sklearn_model.py +11 -15
  115. smftools/machine_learning/utils/__init__.py +3 -1
  116. smftools/machine_learning/utils/device.py +12 -5
  117. smftools/machine_learning/utils/grl.py +8 -2
  118. smftools/metadata.py +443 -0
  119. smftools/optional_imports.py +31 -0
  120. smftools/plotting/__init__.py +32 -17
  121. smftools/plotting/autocorrelation_plotting.py +153 -48
  122. smftools/plotting/classifiers.py +175 -73
  123. smftools/plotting/general_plotting.py +350 -168
  124. smftools/plotting/hmm_plotting.py +53 -14
  125. smftools/plotting/position_stats.py +155 -87
  126. smftools/plotting/qc_plotting.py +25 -12
  127. smftools/preprocessing/__init__.py +35 -37
  128. smftools/preprocessing/append_base_context.py +105 -79
  129. smftools/preprocessing/append_binary_layer_by_base_context.py +75 -37
  130. smftools/preprocessing/{archives → archived}/add_read_length_and_mapping_qc.py +2 -0
  131. smftools/preprocessing/{archives → archived}/calculate_complexity.py +5 -1
  132. smftools/preprocessing/{archives → archived}/mark_duplicates.py +2 -0
  133. smftools/preprocessing/{archives → archived}/preprocessing.py +10 -6
  134. smftools/preprocessing/{archives → archived}/remove_duplicates.py +2 -0
  135. smftools/preprocessing/binarize.py +21 -4
  136. smftools/preprocessing/binarize_on_Youden.py +127 -31
  137. smftools/preprocessing/binary_layers_to_ohe.py +18 -11
  138. smftools/preprocessing/calculate_complexity_II.py +89 -59
  139. smftools/preprocessing/calculate_consensus.py +28 -19
  140. smftools/preprocessing/calculate_coverage.py +44 -22
  141. smftools/preprocessing/calculate_pairwise_differences.py +4 -1
  142. smftools/preprocessing/calculate_pairwise_hamming_distances.py +7 -3
  143. smftools/preprocessing/calculate_position_Youden.py +110 -55
  144. smftools/preprocessing/calculate_read_length_stats.py +52 -23
  145. smftools/preprocessing/calculate_read_modification_stats.py +91 -57
  146. smftools/preprocessing/clean_NaN.py +38 -28
  147. smftools/preprocessing/filter_adata_by_nan_proportion.py +24 -12
  148. smftools/preprocessing/filter_reads_on_length_quality_mapping.py +72 -37
  149. smftools/preprocessing/filter_reads_on_modification_thresholds.py +183 -73
  150. smftools/preprocessing/flag_duplicate_reads.py +708 -303
  151. smftools/preprocessing/invert_adata.py +26 -11
  152. smftools/preprocessing/load_sample_sheet.py +40 -22
  153. smftools/preprocessing/make_dirs.py +9 -3
  154. smftools/preprocessing/min_non_diagonal.py +4 -1
  155. smftools/preprocessing/recipes.py +58 -23
  156. smftools/preprocessing/reindex_references_adata.py +93 -27
  157. smftools/preprocessing/subsample_adata.py +33 -16
  158. smftools/readwrite.py +264 -109
  159. smftools/schema/__init__.py +11 -0
  160. smftools/schema/anndata_schema_v1.yaml +227 -0
  161. smftools/tools/__init__.py +25 -18
  162. smftools/tools/archived/apply_hmm.py +2 -0
  163. smftools/tools/archived/classifiers.py +165 -0
  164. smftools/tools/archived/classify_methylated_features.py +2 -0
  165. smftools/tools/archived/classify_non_methylated_features.py +2 -0
  166. smftools/tools/archived/subset_adata_v1.py +12 -1
  167. smftools/tools/archived/subset_adata_v2.py +14 -1
  168. smftools/tools/calculate_umap.py +56 -15
  169. smftools/tools/cluster_adata_on_methylation.py +122 -47
  170. smftools/tools/general_tools.py +70 -25
  171. smftools/tools/position_stats.py +220 -99
  172. smftools/tools/read_stats.py +50 -29
  173. smftools/tools/spatial_autocorrelation.py +365 -192
  174. smftools/tools/subset_adata.py +23 -21
  175. smftools-0.3.0.dist-info/METADATA +147 -0
  176. smftools-0.3.0.dist-info/RECORD +182 -0
  177. smftools-0.2.4.dist-info/METADATA +0 -141
  178. smftools-0.2.4.dist-info/RECORD +0 -176
  179. {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/WHEEL +0 -0
  180. {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/entry_points.txt +0 -0
  181. {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/licenses/LICENSE +0 -0
smftools/cli/hmm_adata.py CHANGED
@@ -1,223 +1,860 @@
- def hmm_adata(config_path):
+ from __future__ import annotations
+
+ import copy
+ from dataclasses import dataclass
+ from pathlib import Path
+ from typing import TYPE_CHECKING, Any, List, Optional, Sequence, Tuple, Union
+
+ import numpy as np
+
+ from smftools.logging_utils import get_logger
+ from smftools.optional_imports import require
+
+ # FIX: import _to_dense_np to avoid NameError
+ from ..hmm.HMM import _safe_int_coords, _to_dense_np, create_hmm, normalize_hmm_feature_sets
+
+ logger = get_logger(__name__)
+
+ if TYPE_CHECKING:
+     import torch as torch_types
+
+ torch = require("torch", extra="torch", purpose="HMM CLI")
+
+ # =============================================================================
+ # Helpers: extracting training arrays
+ # =============================================================================
+
+
+ def _get_training_matrix(
+     subset, cols_mask: np.ndarray, smf_modality: Optional[str], cfg
+ ) -> Tuple[np.ndarray, Optional[str]]:
+     """
+     Matches your existing behavior:
+       - direct -> uses cfg.output_binary_layer_name in .layers
+       - else -> uses .X
+     Returns (X, layer_name_or_None) where X is dense float array.
+     """
+     sub = subset[:, cols_mask]
+
+     if smf_modality == "direct":
+         hmm_layer = getattr(cfg, "output_binary_layer_name", None)
+         if hmm_layer is None or hmm_layer not in sub.layers:
+             raise KeyError(f"Missing HMM training layer '{hmm_layer}' in subset.")
+
+         logger.debug("Using direct modality HMM training layer: %s", hmm_layer)
+         mat = sub.layers[hmm_layer]
+     else:
+         logger.debug("Using .X for HMM training matrix")
+         hmm_layer = None
+         mat = sub.X
+
+     X = _to_dense_np(mat).astype(float)
+     if X.ndim != 2:
+         raise ValueError(f"Expected 2D training matrix; got {X.shape}")
+     return X, hmm_layer
+
+
+ def _resolve_pos_mask_for_methbase(subset, ref: str, methbase: str) -> Optional[np.ndarray]:
      """
-     High-level function to call for hmm analysis of an adata object.
-     Command line accesses this through smftools hmm <config_path>
+     Reproduces your mask resolution, with compatibility for both *_any_C_site and *_C_site.
+     """
+     key = str(methbase).strip().lower()
+
+     logger.debug("Resolving position mask for methbase=%s on ref=%s", key, ref)
+
+     if key in ("a",):
+         col = f"{ref}_A_site"
+         if col not in subset.var:
+             return None
+         logger.debug("Using positions with A calls from column: %s", col)
+         return np.asarray(subset.var[col])
+
+     if key in ("c", "any_c", "anyc", "any-c"):
+         for col in (f"{ref}_any_C_site", f"{ref}_C_site"):
+             if col in subset.var:
+                 logger.debug("Using positions with C calls from column: %s", col)
+                 return np.asarray(subset.var[col])
+         return None
+
+     if key in ("gpc", "gpc_site", "gpc-site"):
+         col = f"{ref}_GpC_site"
+         if col not in subset.var:
+             return None
+         logger.debug("Using positions with GpC calls from column: %s", col)
+         return np.asarray(subset.var[col])
+
+     if key in ("cpg", "cpg_site", "cpg-site"):
+         col = f"{ref}_CpG_site"
+         if col not in subset.var:
+             return None
+         logger.debug("Using positions with CpG calls from column: %s", col)
+         return np.asarray(subset.var[col])

-     Parameters:
-         config_path (str): A string representing the file path to the experiment configuration csv file.
+     alt = f"{ref}_{methbase}_site"
+     if alt not in subset.var:
+         return None

+     logger.debug("Using positions from column: %s", alt)
+     return np.asarray(subset.var[alt])
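
Note: the resolver above keys off boolean site columns in .var named "{ref}_A_site", "{ref}_any_C_site" / "{ref}_C_site", "{ref}_GpC_site", and "{ref}_CpG_site", with "{ref}_{methbase}_site" as the fallback. A minimal sketch of the contract it assumes (toy AnnData; the reference name "amplicon1" is hypothetical):

    import anndata as ad
    import numpy as np
    import pandas as pd

    var = pd.DataFrame(
        {
            "amplicon1_GpC_site": [True, False, True, False],
            "amplicon1_CpG_site": [False, True, False, False],
        },
        index=["10", "11", "12", "13"],  # integer-like coordinates in var_names
    )
    adata = ad.AnnData(X=np.random.rand(3, 4), var=var)
    # _resolve_pos_mask_for_methbase(adata, "amplicon1", "GpC")
    # would return array([ True, False,  True, False])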
+
+
+ def build_single_channel(
+     subset, ref: str, methbase: str, smf_modality: Optional[str], cfg
+ ) -> Tuple[np.ndarray, np.ndarray]:
+     """
      Returns:
-         (pp_dedup_spatial_hmm_adata, pp_dedup_spatial_hmm_adata_path)
+         X (N, Lmb) float with NaNs allowed
+         coords (Lmb,) int coords from var_names
+     """
+     pm = _resolve_pos_mask_for_methbase(subset, ref, methbase)
+     logger.debug(
+         "Position mask for methbase=%s on ref=%s has %d sites",
+         methbase,
+         ref,
+         int(np.sum(pm)) if pm is not None else 0,
+     )
+
+     if pm is None or int(np.sum(pm)) == 0:
+         raise ValueError(f"No columns for methbase={methbase} on ref={ref}")
+
+     X, _ = _get_training_matrix(subset, pm, smf_modality, cfg)
+     logger.debug("Training matrix for methbase=%s on ref=%s has shape %s", methbase, ref, X.shape)
+
+     coords, _ = _safe_int_coords(subset[:, pm].var_names)
+     logger.debug(
+         "Coordinates for methbase=%s on ref=%s have length %d", methbase, ref, coords.shape[0]
+     )
+
+     return X, coords
+
+
+ def build_multi_channel_union(
+     subset, ref: str, methbases: Sequence[str], smf_modality: Optional[str], cfg
+ ) -> Tuple[np.ndarray, np.ndarray, List[str]]:
+     """
+     Build (N, Lunion, C) on union coordinate grid across methbases.
+
+     Returns:
+         X3d: (N, Lunion, C) float with NaN where methbase has no site
+         coords: (Lunion,) int union coords
+         used_methbases: list of methbases actually included (>=2)
+     """
+     per: List[Tuple[str, np.ndarray, np.ndarray, np.ndarray]] = []  # (mb, X, coords, pm)
+
+     for mb in methbases:
+         pm = _resolve_pos_mask_for_methbase(subset, ref, mb)
+         if pm is None or int(np.sum(pm)) == 0:
+             continue
+         Xmb, _ = _get_training_matrix(subset, pm, smf_modality, cfg)  # (N,Lmb)
+         cmb, _ = _safe_int_coords(subset[:, pm].var_names)
+         per.append((mb, Xmb.astype(float), cmb.astype(int), pm))
+
+     if len(per) < 2:
+         raise ValueError(f"Need >=2 methbases with columns for union multi-channel on ref={ref}")
+
+     # union coordinates
+     coords = np.unique(np.concatenate([c for _, _, c, _ in per], axis=0)).astype(int)
+     idx = {int(v): i for i, v in enumerate(coords.tolist())}
+
+     N = per[0][1].shape[0]
+     L = coords.shape[0]
+     C = len(per)
+     X3 = np.full((N, L, C), np.nan, dtype=float)
+
+     for ci, (mb, Xmb, cmb, _) in enumerate(per):
+         cols = np.array([idx[int(v)] for v in cmb.tolist()], dtype=int)
+         X3[:, cols, ci] = Xmb
+
+     used = [mb for (mb, _, _, _) in per]
+     return X3, coords, used
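
Note: a worked example of the union-grid construction above, assuming two methbases with partially overlapping coordinates; positions a channel never covers stay NaN:

    import numpy as np

    # channel 0 (e.g. GpC) has sites at [10, 20]; channel 1 (e.g. CpG) at [20, 30]
    coords = np.unique(np.concatenate([[10, 20], [20, 30]]))  # -> array([10, 20, 30])
    idx = {int(v): i for i, v in enumerate(coords.tolist())}

    n_reads = 2
    X3 = np.full((n_reads, coords.size, 2), np.nan)
    X3[:, [idx[10], idx[20]], 0] = [[1.0, 0.0], [0.0, 1.0]]  # GpC calls per read
    X3[:, [idx[20], idx[30]], 1] = [[0.0, 1.0], [1.0, 0.0]]  # CpG calls per read
    # X3[:, idx[30], 0] and X3[:, idx[10], 1] remain NaN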
+
+
+ @dataclass
+ class HMMTask:
+     name: str
+     signals: List[str]  # e.g. ["GpC"] or ["GpC","CpG"] or ["CpG"]
+     feature_groups: List[str]  # e.g. ["footprint","accessible"] or ["cpg"]
+     output_prefix: Optional[str] = None  # force prefix (CpG task uses "CpG")
+
+
+ def build_hmm_tasks(cfg: Union[dict, Any]) -> List[HMMTask]:
+     """
+     Accessibility signals come from cfg['hmm_methbases'].
+     CpG task is enabled by cfg['cpg']==True, independent of hmm_methbases.
+     """
+     if not isinstance(cfg, dict):
+         # best effort conversion
+         cfg = {k: getattr(cfg, k) for k in dir(cfg) if not k.startswith("_")}
+
+     tasks: List[HMMTask] = []
+
+     # accessibility task
+     methbases = list(cfg.get("hmm_methbases", []) or [])
+     if len(methbases) > 0:
+         tasks.append(
+             HMMTask(
+                 name="accessibility",
+                 signals=methbases,
+                 feature_groups=["footprint", "accessible"],
+                 output_prefix=None,
+             )
+         )
+
+     # CpG task (special case)
+     if bool(cfg.get("cpg", False)):
+         tasks.append(
+             HMMTask(
+                 name="cpg",
+                 signals=["CpG"],
+                 feature_groups=["cpg"],
+                 output_prefix="CpG",
+             )
+         )
+
+     return tasks
+
+
+ def select_hmm_arch(cfg: dict, signals: Sequence[str]) -> str:
+     """
+     Simple, explicit model selection:
+       - distance-aware => 'single_distance_binned' (only meaningful for single-channel)
+       - multi-signal => 'multi'
+       - else => 'single'
+     """
+     if bool(cfg.get("hmm_distance_aware", False)) and len(signals) == 1:
+         return "single_distance_binned"
+     if len(signals) > 1:
+         return "multi"
+     return "single"
+
+
+ def resolve_input_layer(adata, cfg: dict, layer_override: Optional[str]) -> Optional[str]:
+     """
+     If direct modality, prefer cfg.output_binary_layer_name.
+     Else use layer_override or None (meaning use .X).
+     """
+     smf_modality = cfg.get("smf_modality", None)
+     if smf_modality == "direct":
+         nm = cfg.get("output_binary_layer_name", None)
+         if nm is None:
+             raise KeyError("cfg.output_binary_layer_name missing for smf_modality='direct'")
+         if nm not in adata.layers:
+             raise KeyError(f"Direct modality expects layer '{nm}' in adata.layers")
+         return nm
+     return layer_override
+
+
+ def _ensure_layer_and_assign_rows(adata, layer_name: str, row_mask: np.ndarray, subset_layer):
+     """
+     Writes subset_layer (n_subset_obs, n_vars) into adata.layers[layer_name] for rows where row_mask==True.
+     """
+     row_mask = np.asarray(row_mask, dtype=bool)
+     if row_mask.ndim != 1 or row_mask.size != adata.n_obs:
+         raise ValueError("row_mask must be length adata.n_obs")
+
+     arr = _to_dense_np(subset_layer)
+     if arr.shape != (int(row_mask.sum()), adata.n_vars):
+         raise ValueError(
+             f"subset layer '{layer_name}' shape {arr.shape} != ({int(row_mask.sum())}, {adata.n_vars})"
+         )
+
+     if layer_name not in adata.layers:
+         adata.layers[layer_name] = np.zeros((adata.n_obs, adata.n_vars), dtype=arr.dtype)
+
+     adata.layers[layer_name][row_mask, :] = arr
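
Note: the helper above is a masked row write-back. A minimal numpy sketch of the shape contract it enforces:

    import numpy as np

    n_obs, n_vars = 5, 4
    full = np.zeros((n_obs, n_vars))              # stand-in for adata.layers[name]
    row_mask = np.array([True, False, True, False, False])
    sub = np.ones((int(row_mask.sum()), n_vars))  # must be (mask.sum(), n_vars)
    full[row_mask, :] = sub                       # rows outside the mask stay zero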
+
+
+ def resolve_torch_device(device_str: str | None) -> torch.device:
+     d = (device_str or "auto").lower()
+     if d == "auto":
+         if torch.cuda.is_available():
+             return torch.device("cuda")
+         if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
+             return torch.device("mps")
+         return torch.device("cpu")
+     return torch.device(d)
+
+
+ # =============================================================================
+ # Model selection + fit strategy manager
+ # =============================================================================
+ @dataclass
+ class HMMTrainer:
+     cfg: Any
+     models_dir: Path
+
+     def __post_init__(self):
+         self.models_dir = Path(self.models_dir)
+         self.models_dir.mkdir(parents=True, exist_ok=True)
+
+     def choose_arch(self, *, multichannel: bool) -> str:
+         use_dist = bool(getattr(self.cfg, "hmm_distance_aware", False))
+         if multichannel:
+             return "multi"
+         return "single_distance_binned" if use_dist else "single"
+
+     def _fit_scope(self) -> str:
+         # "per_sample" | "global" | "global_then_adapt"
+         return str(getattr(self.cfg, "hmm_fit_scope", "per_sample")).lower()
+
+     def _path(self, kind: str, sample: str, ref: str, label: str) -> Path:
+         # kind: "GLOBAL" | "PER" | "ADAPT"
+         def safe(s):
+             return str(s).replace("/", "_")
+
+         return self.models_dir / f"{kind}_{safe(sample)}_{safe(ref)}_{safe(label)}.pt"
+
+     def _save(self, model, path: Path):
+         override = {}
+         if getattr(model, "hmm_name", None) == "multi":
+             override["hmm_n_channels"] = int(getattr(model, "n_channels", 2))
+         if getattr(model, "hmm_name", None) == "single_distance_binned":
+             override["hmm_distance_bins"] = list(
+                 getattr(model, "distance_bins", [1, 5, 10, 25, 50, 100])
+             )
+
+         payload = {
+             "state_dict": model.state_dict(),
+             "hmm_arch": getattr(model, "hmm_name", None) or getattr(self.cfg, "hmm_arch", None),
+             "override": override,
+         }
+         torch.save(payload, path)
+
+     def _load(self, path: Path, arch: str, device):
+         payload = torch.load(path, map_location="cpu")
+         override = payload.get("override", None)
+         m = create_hmm(self.cfg, arch=arch, override=override, device=device)
+         sd = payload["state_dict"]
+
+         target_dtype = next(m.parameters()).dtype
+         for k, v in list(sd.items()):
+             if isinstance(v, torch.Tensor) and v.dtype != target_dtype:
+                 sd[k] = v.to(dtype=target_dtype)
+
+         m.load_state_dict(sd)
+         m.to(device)
+         m.eval()
+         return m
+
+     def fit_or_load(
+         self,
+         *,
+         sample: str,
+         ref: str,
+         label: str,
+         arch: str,
+         X,
+         coords: Optional[np.ndarray],
+         device,
+     ):
+         force_fit = bool(getattr(self.cfg, "force_redo_hmm_fit", False))
+         scope = self._fit_scope()
+
+         max_iter = int(getattr(self.cfg, "hmm_max_iter", 50))
+         tol = float(getattr(self.cfg, "hmm_tol", 1e-4))
+         verbose = bool(getattr(self.cfg, "hmm_verbose", False))
+
+         # ---- global then adapt ----
+         if scope == "global_then_adapt":
+             p_global = self._path("GLOBAL", "ALL", ref, label)
+             if p_global.exists() and not force_fit:
+                 base = self._load(p_global, arch=arch, device=device)
+             else:
+                 base = create_hmm(self.cfg, arch=arch).to(device)
+                 if arch == "single_distance_binned":
+                     base.fit(
+                         X, device=device, coords=coords, max_iter=max_iter, tol=tol, verbose=verbose
+                     )
+                 else:
+                     base.fit(X, device=device, max_iter=max_iter, tol=tol, verbose=verbose)
+                 self._save(base, p_global)
+
+             p_adapt = self._path("ADAPT", sample, ref, label)
+             if p_adapt.exists() and not force_fit:
+                 return self._load(p_adapt, arch=arch, device=device)
+
+             # IMPORTANT: this assumes you added model.adapt_emissions(...)
+             adapted = copy.deepcopy(base).to(device)
+             adapted.adapt_emissions(
+                 X,
+                 coords,
+                 device=device,
+                 max_iter=int(getattr(self.cfg, "hmm_adapt_iters", 10)),
+                 verbose=verbose,
+             )
+
+             self._save(adapted, p_adapt)
+             return adapted
+
+         # ---- global only ----
+         if scope == "global":
+             p = self._path("GLOBAL", "ALL", ref, label)
+             if p.exists() and not force_fit:
+                 return self._load(p, arch=arch, device=device)
+
+         # ---- per sample ----
+         else:
+             p = self._path("PER", sample, ref, label)
+             if p.exists() and not force_fit:
+                 return self._load(p, arch=arch, device=device)
+
+         m = create_hmm(self.cfg, arch=arch, device=device)
+         m.fit(X, coords, device=device, max_iter=max_iter, tol=tol, verbose=verbose)
+         self._save(m, p)
+         return m
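
Note: the three fit scopes read from cfg.hmm_fit_scope map onto the model files that _path() builds under 10_hmm_models/: "per_sample" fits PER_<sample>_<ref>_<label>.pt per subset, "global" fits one GLOBAL_ALL_<ref>_<label>.pt, and "global_then_adapt" additionally deep-copies the global model and calls the (assumed, per the comment in the code) adapt_emissions() to produce ADAPT_<sample>_<ref>_<label>.pt. A usage sketch, assuming cfg, models_dir, X, and coords are already in hand:

    trainer = HMMTrainer(cfg=cfg, models_dir=models_dir)
    model = trainer.fit_or_load(
        sample="sample1",
        ref="amplicon1",
        label="GpC",
        arch=trainer.choose_arch(multichannel=False),
        X=X,
        coords=coords,
        device=resolve_torch_device("auto"),
    )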
+
+
+ def _fully_qualified_merge_layers(cfg, prefix: str) -> List[Tuple[str, int]]:
+     """
+     cfg.hmm_merge_layer_features is assumed to be a list of (core_layer_name, merge_distance),
+     where core_layer_name is like "all_accessible_features" (NOT prefixed with methbase).
+     We expand to f"{prefix}_{core_layer_name}".
+     """
+     out = []
+     for core_layer, dist in getattr(cfg, "hmm_merge_layer_features", []) or []:
+         if not core_layer:
+             continue
+         out.append((f"{prefix}_{core_layer}", int(dist)))
+     return out
+
+
+ def hmm_adata(config_path: str):
+     """
+     CLI-facing wrapper for HMM analysis.
+
+     Command line entrypoint:
+         smftools hmm <config_path>
+
+     Responsibilities:
+       - Build cfg via load_adata()
+       - Ensure preprocess + spatial stages are run
+       - Decide which AnnData to start from (hmm > spatial > pp_dedup > pp > raw)
+       - Call hmm_adata_core(cfg, adata, paths)
      """
-     from ..readwrite import safe_read_h5ad, safe_write_h5ad, make_dirs, add_or_update_column_in_csv
+     from ..readwrite import safe_read_h5ad
+     from .helpers import get_adata_paths
      from .load_adata import load_adata
      from .preprocess_adata import preprocess_adata
      from .spatial_adata import spatial_adata

-     import numpy as np
-     import pandas as pd
-     import anndata as ad
-     import scanpy as sc
-
-     import os
-     from importlib import resources
-     from pathlib import Path
+     # 1) load cfg / stage paths
+     _, _, cfg = load_adata(config_path)
+     paths = get_adata_paths(cfg)

-     from datetime import datetime
-     date_str = datetime.today().strftime("%y%m%d")
+     # 2) make sure upstream stages are run (they have their own skipping logic)
+     preprocess_adata(config_path)
+     spatial_ad, spatial_path = spatial_adata(config_path)

-     ############################################### smftools load start ###############################################
-     adata, adata_path, cfg = load_adata(config_path)
-     # General config variable init - Necessary user passed inputs
-     smf_modality = cfg.smf_modality # needed for specifying if the data is conversion SMF or direct methylation detection SMF. Or deaminase smf Necessary.
-     output_directory = Path(cfg.output_directory) # Path to the output directory to make for the analysis. Necessary.
+     # 3) choose starting AnnData
+     # Prefer:
+     #   - existing HMM h5ad if not forcing redo
+     #   - in-memory spatial_ad from wrapper call
+     #   - saved spatial / pp_dedup / pp / raw on disk
+     if paths.hmm.exists() and not (cfg.force_redo_hmm_fit or cfg.force_redo_hmm_apply):
+         adata, _ = safe_read_h5ad(paths.hmm)
+         return adata, paths.hmm

-     # Make initial output directory
-     make_dirs([output_directory])
-     ############################################### smftools load end ###############################################
-
-     ############################################### smftools preprocess start ###############################################
-     pp_adata, pp_adata_path, pp_dedup_adata, pp_dup_rem_adata_path = preprocess_adata(config_path)
-     ############################################### smftools preprocess end ###############################################
-
-     ############################################### smftools spatial start ###############################################
-     spatial_ad, spatial_adata_path = spatial_adata(config_path)
-     ############################################### smftools spatial end ###############################################
-
-     ############################################### smftools hmm start ###############################################
-     input_manager_df = pd.read_csv(cfg.summary_file)
-     initial_adata_path = Path(input_manager_df['load_adata'][0])
-     pp_adata_path = Path(input_manager_df['pp_adata'][0])
-     pp_dup_rem_adata_path = Path(input_manager_df['pp_dedup_adata'][0])
-     spatial_adata_path = Path(input_manager_df['spatial_adata'][0])
-     hmm_adata_path = Path(input_manager_df['hmm_adata'][0])
-
-     if spatial_ad:
-         # This happens on first run of the pipeline
+     if spatial_ad is not None:
          adata = spatial_ad
+         source_path = spatial_path
+     elif paths.spatial.exists():
+         adata, _ = safe_read_h5ad(paths.spatial)
+         source_path = paths.spatial
+     elif paths.pp_dedup.exists():
+         adata, _ = safe_read_h5ad(paths.pp_dedup)
+         source_path = paths.pp_dedup
+     elif paths.pp.exists():
+         adata, _ = safe_read_h5ad(paths.pp)
+         source_path = paths.pp
+     elif paths.raw.exists():
+         adata, _ = safe_read_h5ad(paths.raw)
+         source_path = paths.raw
      else:
-         # If an anndata is saved, check which stages of the anndata are available
-         initial_version_available = initial_adata_path.exists()
-         preprocessed_version_available = pp_adata_path.exists()
-         preprocessed_dup_removed_version_available = pp_dup_rem_adata_path.exists()
-         preprocessed_dedup_spatial_version_available = spatial_adata_path.exists()
-         preprocessed_dedup_spatial_hmm_version_available = hmm_adata_path.exists()
-
-         if cfg.force_redo_hmm_fit or cfg.force_redo_hmm_apply:
-             print(f"Forcing redo of hmm analysis workflow.")
-             if preprocessed_dedup_spatial_hmm_version_available:
-                 adata, load_report = safe_read_h5ad(hmm_adata_path)
-             elif preprocessed_dedup_spatial_version_available:
-                 adata, load_report = safe_read_h5ad(spatial_adata_path)
-             elif preprocessed_dup_removed_version_available:
-                 adata, load_report = safe_read_h5ad(pp_dup_rem_adata_path)
-             elif initial_version_available:
-                 adata, load_report = safe_read_h5ad(initial_adata_path)
-             else:
-                 print(f"Can not redo duplicate detection when there is no compatible adata available: either raw or preprocessed are required")
-         elif preprocessed_dedup_spatial_hmm_version_available:
-             adata, load_report = safe_read_h5ad(hmm_adata_path)
-         else:
-             if preprocessed_dedup_spatial_version_available:
-                 adata, load_report = safe_read_h5ad(spatial_adata_path)
-             elif preprocessed_dup_removed_version_available:
-                 adata, load_report = safe_read_h5ad(pp_dup_rem_adata_path)
-             elif initial_version_available:
-                 adata, load_report = safe_read_h5ad(initial_adata_path)
-             else:
-                 print(f"No adata available.")
-                 return
-     references = adata.obs[cfg.reference_column].cat.categories
-     deaminase = smf_modality == 'deaminase'
-     ############################################### HMM based feature annotations ###############################################
+         raise FileNotFoundError(
+             "No AnnData available for HMM: expected at least raw or preprocessed h5ad."
+         )
+
+     # 4) delegate to core
+     adata, hmm_adata_path = hmm_adata_core(
+         cfg,
+         adata,
+         paths,
+         source_adata_path=source_path,
+         config_path=config_path,
+     )
+     return adata, hmm_adata_path
+
+
+ def hmm_adata_core(
+     cfg,
+     adata,
+     paths,
+     source_adata_path: Path | None = None,
+     config_path: str | None = None,
+ ) -> Tuple["anndata.AnnData", Path]:
+     """
+     Core HMM analysis pipeline.
+
+     Assumes:
+       - cfg is an ExperimentConfig
+       - adata is the starting AnnData (typically spatial + dedup)
+       - paths is an AdataPaths object (with .raw/.pp/.pp_dedup/.spatial/.hmm)
+
+     Does NOT decide which h5ad to start from – that is the wrapper's job.
+     """
+
+     import numpy as np
+
+     from ..hmm import call_hmm_peaks
+     from ..metadata import record_smftools_metadata
+     from ..plotting import (
+         combined_hmm_raw_clustermap,
+         plot_hmm_layers_rolling_by_sample_ref,
+         plot_hmm_size_contours,
+     )
+     from ..readwrite import make_dirs
+     from .helpers import write_gz_h5ad
+
+     smf_modality = cfg.smf_modality
+     deaminase = smf_modality == "deaminase"
+
+     output_directory = Path(cfg.output_directory)
+     make_dirs([output_directory])
+
+     pp_dir = output_directory / "preprocessed" / "deduplicated"
+
+     # ---------------------------- HMM annotate stage ----------------------------
      if not (cfg.bypass_hmm_fit and cfg.bypass_hmm_apply):
-         from ..hmm.HMM import HMM
-         from scipy.sparse import issparse, csr_matrix
-         import warnings
+         hmm_models_dir = pp_dir / "10_hmm_models"
+         make_dirs([pp_dir, hmm_models_dir])

-         pp_dir = output_directory / "preprocessed"
-         pp_dir = pp_dir / "deduplicated"
-         hmm_dir = pp_dir / "10_hmm_models"
+         # Standard bookkeeping
+         uns_key = "hmm_appended_layers"
+         if adata.uns.get(uns_key) is None:
+             adata.uns[uns_key] = []
+         global_appended = list(adata.uns.get(uns_key, []))

-         if hmm_dir.is_dir():
-             print(f'{hmm_dir} already exists.')
-         else:
-             make_dirs([pp_dir, hmm_dir])
+         # Prepare trainer + feature config
+         trainer = HMMTrainer(cfg=cfg, models_dir=hmm_models_dir)
+
+         feature_sets = normalize_hmm_feature_sets(getattr(cfg, "hmm_feature_sets", None))
+         prob_thr = float(getattr(cfg, "hmm_feature_prob_threshold", 0.5))
+         decode = str(getattr(cfg, "hmm_decode", "marginal"))
+         write_post = bool(getattr(cfg, "hmm_write_posterior", True))
+         post_state = getattr(cfg, "hmm_posterior_state", "Modified")
+         merged_suffix = str(getattr(cfg, "hmm_merged_suffix", "_merged"))
+         force_apply = bool(getattr(cfg, "force_redo_hmm_apply", False))
+         bypass_apply = bool(getattr(cfg, "bypass_hmm_apply", False))
+         bypass_fit = bool(getattr(cfg, "bypass_hmm_fit", False))

          samples = adata.obs[cfg.sample_name_col_for_plotting].cat.categories
          references = adata.obs[cfg.reference_column].cat.categories
-         uns_key = "hmm_appended_layers"
+         methbases = list(getattr(cfg, "hmm_methbases", [])) or []

-         # ensure uns key exists (avoid KeyError later)
-         if adata.uns.get(uns_key) is None:
-             adata.uns[uns_key] = []
+         if not methbases:
+             raise ValueError("cfg.hmm_methbases is empty.")

-         if adata.uns.get('hmm_annotated', False) and not cfg.force_redo_hmm_fit and not cfg.force_redo_hmm_apply:
+         # Top-level skip
+         already = bool(adata.uns.get("hmm_annotated", False))
+         if already and not (bool(getattr(cfg, "force_redo_hmm_fit", False)) or force_apply):
              pass
+
          else:
+             logger.info("Starting HMM annotation over samples and references")
              for sample in samples:
                  for ref in references:
-                     mask = (adata.obs[cfg.sample_name_col_for_plotting] == sample) & (adata.obs[cfg.reference_column] == ref)
-                     subset = adata[mask].copy()
-                     if subset.shape[0] < 1:
+                     mask = (adata.obs[cfg.sample_name_col_for_plotting] == sample) & (
+                         adata.obs[cfg.reference_column] == ref
+                     )
+                     if int(np.sum(mask)) == 0:
                          continue

-                     for mod_site in cfg.hmm_methbases:
-                         mod_label = {'C': 'C'}.get(mod_site, mod_site)
-                         hmm_path = hmm_dir / f"{sample}_{ref}_{mod_label}_hmm_model.pth"
+                     subset = adata[mask].copy()
+                     subset.uns[uns_key] = []  # isolate appended tracking per subset
+
+                     # ---- Decide which tasks to run ----
+                     methbases = list(getattr(cfg, "hmm_methbases", [])) or []
+                     run_multi = bool(getattr(cfg, "hmm_run_multichannel", True))
+                     run_cpg = bool(getattr(cfg, "cpg", False))
+                     device = resolve_torch_device(cfg.device)
+
+                     logger.info("HMM processing sample=%s ref=%s", sample, ref)
+
+                     # ---- split feature sets ----
+                     feature_sets_all = normalize_hmm_feature_sets(
+                         getattr(cfg, "hmm_feature_sets", None)
+                     )
+                     feature_sets_access = {
+                         k: v
+                         for k, v in feature_sets_all.items()
+                         if k in ("footprint", "accessible")
+                     }
+                     feature_sets_cpg = (
+                         {"cpg": feature_sets_all["cpg"]} if "cpg" in feature_sets_all else {}
+                     )

-                         # ensure the input obsm exists
-                         obsm_key = f'{ref}_{mod_label}_site'
-                         if obsm_key not in subset.obsm:
-                             print(f"Skipping {sample} {ref} {mod_label}: missing obsm '{obsm_key}'")
+                     # =========================
+                     # 1) Single-channel accessibility (per methbase)
+                     # =========================
+                     for mb in methbases:
+                         logger.info("HMM single-channel for methbase=%s", mb)
+
+                         try:
+                             X, coords = build_single_channel(
+                                 subset,
+                                 ref=str(ref),
+                                 methbase=str(mb),
+                                 smf_modality=smf_modality,
+                                 cfg=cfg,
+                             )
+                         except Exception:
+                             logger.warning(
+                                 "Skipping HMM single-channel for methbase=%s due to data error", mb
+                             )
                              continue

-                         # Fit or load model
-                         if hmm_path.exists() and not cfg.force_redo_hmm_fit:
-                             hmm = HMM.load(hmm_path)
-                             hmm.print_params()
-                         else:
-                             print(f"Fitting HMM for {sample} {ref} {mod_label}")
-                             hmm = HMM.from_config(cfg)
-                             # fit expects a list-of-seqs or 2D ndarray in the obsm
-                             seqs = subset.obsm[obsm_key]
-                             hmm.fit(seqs)
-                             hmm.print_params()
-                             hmm.save(hmm_path)
-
-                         # Apply / annotate on the subset, then copy layers back to final_adata
-                         if cfg.bypass_hmm_apply:
-                             pass
-                         else:
-                             print(f"Applying HMM on subset for {sample} {ref} {mod_label}")
-                             # Use the new uns_key argument so subset will record appended layer names
-                             # (annotate_adata modifies subset.obs/layers in-place and should write subset.uns[uns_key])
-                             if smf_modality == "direct":
-                                 hmm_layer = cfg.output_binary_layer_name
-                             else:
-                                 hmm_layer = None
-
-                             hmm.annotate_adata(subset,
-                                                obs_column=cfg.reference_column,
-                                                layer=hmm_layer,
-                                                config=cfg,
-                                                force_redo=cfg.force_redo_hmm_apply
+                         arch = trainer.choose_arch(multichannel=False)
+
+                         logger.info("HMM fitting/loading for methbase=%s", mb)
+                         hmm = trainer.fit_or_load(
+                             sample=str(sample),
+                             ref=str(ref),
+                             label=str(mb),
+                             arch=arch,
+                             X=X,
+                             coords=coords,
+                             device=device,
+                         )
+
+                         if not bypass_apply:
+                             logger.info("HMM applying for methbase=%s", mb)
+                             pm = _resolve_pos_mask_for_methbase(subset, str(ref), str(mb))
+                             hmm.annotate_adata(
+                                 subset,
+                                 prefix=str(mb),
+                                 X=X,
+                                 coords=coords,
+                                 var_mask=pm,
+                                 span_fill=True,
+                                 config=cfg,
+                                 decode=decode,
+                                 write_posterior=write_post,
+                                 posterior_state=post_state,
+                                 feature_sets=feature_sets_access,  # <--- ONLY accessibility feature sets
+                                 prob_threshold=prob_thr,
+                                 uns_key=uns_key,
+                                 uns_flag=f"hmm_annotated_{mb}",
+                                 force_redo=force_apply,
+                             )
+
+                             # merges for this mb
+                             for core_layer, dist in (
+                                 getattr(cfg, "hmm_merge_layer_features", []) or []
+                             ):
+                                 base_layer = f"{mb}_{core_layer}"
+                                 logger.info("Merging intervals for layer=%s", base_layer)
+                                 if base_layer in subset.layers:
+                                     merged_base = hmm.merge_intervals_to_new_layer(
+                                         subset,
+                                         base_layer,
+                                         distance_threshold=int(dist),
+                                         suffix=merged_suffix,
+                                         overwrite=True,
+                                     )
+                                     # write merged size classes based on whichever group core_layer corresponds to
+                                     for group, fs in feature_sets_access.items():
+                                         fmap = fs.get("features", {}) or {}
+                                         if fmap:
+                                             hmm.write_size_class_layers_from_binary(
+                                                 subset,
+                                                 merged_base,
+                                                 out_prefix=str(mb),
+                                                 feature_ranges=fmap,
+                                                 suffix=merged_suffix,
+                                                 overwrite=True,
                                              )
-
-                         if adata.uns.get('hmm_annotated', False) and not cfg.force_redo_hmm_apply:
-                             pass
-                         else:
-                             to_merge = cfg.hmm_merge_layer_features
-                             for layer_to_merge, merge_distance in to_merge:
-                                 if layer_to_merge:
-                                     hmm.merge_intervals_in_layer(subset,
-                                                                  layer=layer_to_merge,
-                                                                  distance_threshold=merge_distance,
-                                                                  overwrite=True
-                                                                  )
-                                 else:
-                                     pass
-
-                     # collect appended layers from subset.uns
-                     appended = list(subset.uns.get(uns_key, []))
-                     print(appended)
-                     if len(appended) == 0:
-                         # nothing appended for this subset; continue
-                         continue
-
-                     # copy each appended layer into adata
-                     subset_mask_bool = mask.values if hasattr(mask, "values") else np.asarray(mask)
-                     for layer_name in appended:
-                         if layer_name not in subset.layers:
-                             # defensive: skip
-                             warnings.warn(f"Expected layer {layer_name} in subset but not found; skipping copy.")
-                             continue
-                         sub_layer = subset.layers[layer_name]
-                         # ensure final layer exists and assign rows
-                         try:
-                             hmm._ensure_final_layer_and_assign(adata, layer_name, subset_mask_bool, sub_layer)
-                         except Exception as e:
-                             warnings.warn(f"Failed to copy layer {layer_name} into adata: {e}", stacklevel=2)
-                             # fallback: if dense and small, try to coerce
-                             if issparse(sub_layer):
-                                 arr = sub_layer.toarray()
-                             else:
-                                 arr = np.asarray(sub_layer)
-                             adata.layers[layer_name] = adata.layers.get(layer_name, np.zeros((adata.shape[0], arr.shape[1]), dtype=arr.dtype))
-                             final_idx = np.nonzero(subset_mask_bool)[0]
-                             adata.layers[layer_name][final_idx, :] = arr
-
-                     # merge appended layer names into adata.uns
-                     existing = list(adata.uns.get(uns_key, []))
-                     for ln in appended:
-                         if ln not in existing:
-                             existing.append(ln)
-                     adata.uns[uns_key] = existing

-         else:
-             pass
+                     # =========================
+                     # 2) Multi-channel accessibility (Combined)
+                     # =========================
+                     if run_multi and len(methbases) >= 2:
+                         logger.info("HMM multi-channel for methbases=%s", ",".join(methbases))
+                         try:
+                             X3, coords_u, used_mbs = build_multi_channel_union(
+                                 subset,
+                                 ref=str(ref),
+                                 methbases=methbases,
+                                 smf_modality=smf_modality,
+                                 cfg=cfg,
+                             )
+                         except Exception:
+                             X3, coords_u, used_mbs = None, None, []
+                             logger.warning(
+                                 "Skipping HMM multi-channel due to data error or insufficient methbases"
+                             )

-     from ..hmm import call_hmm_peaks
+                         if X3 is not None and len(used_mbs) >= 2:
+                             union_mask = None
+                             for mb in used_mbs:
+                                 pm = _resolve_pos_mask_for_methbase(subset, str(ref), str(mb))
+                                 union_mask = pm if union_mask is None else (union_mask | pm)
+
+                             arch = trainer.choose_arch(multichannel=True)
+
+                             logger.info("HMM fitting/loading for multi-channel")
+                             hmmc = trainer.fit_or_load(
+                                 sample=str(sample),
+                                 ref=str(ref),
+                                 label="Combined",
+                                 arch=arch,
+                                 X=X3,
+                                 coords=coords_u,
+                                 device=device,
+                             )
+
+                             if not bypass_apply:
+                                 logger.info("HMM applying for multi-channel")
+                                 hmmc.annotate_adata(
+                                     subset,
+                                     prefix="Combined",
+                                     X=X3,
+                                     coords=coords_u,
+                                     var_mask=union_mask,
+                                     span_fill=True,
+                                     config=cfg,
+                                     decode=decode,
+                                     write_posterior=write_post,
+                                     posterior_state=post_state,
+                                     feature_sets=feature_sets_access,  # <--- accessibility only
+                                     prob_threshold=prob_thr,
+                                     uns_key=uns_key,
+                                     uns_flag="hmm_annotated_combined",
+                                     force_redo=force_apply,
+                                 )
+
+                                 for core_layer, dist in (
+                                     getattr(cfg, "hmm_merge_layer_features", []) or []
+                                 ):
+                                     base_layer = f"Combined_{core_layer}"
+                                     if base_layer in subset.layers:
+                                         merged_base = hmmc.merge_intervals_to_new_layer(
+                                             subset,
+                                             base_layer,
+                                             distance_threshold=int(dist),
+                                             suffix=merged_suffix,
+                                             overwrite=True,
+                                         )
+                                         for group, fs in feature_sets_access.items():
+                                             fmap = fs.get("features", {}) or {}
+                                             if fmap:
+                                                 hmmc.write_size_class_layers_from_binary(
+                                                     subset,
+                                                     merged_base,
+                                                     out_prefix="Combined",
+                                                     feature_ranges=fmap,
+                                                     suffix=merged_suffix,
+                                                     overwrite=True,
+                                                 )
+
+                     # =========================
+                     # 3) CpG-only single-channel task
+                     # =========================
+                     if run_cpg:
+                         logger.info("HMM single-channel for CpG")
+                         try:
+                             Xcpg, coordscpg = build_single_channel(
+                                 subset,
+                                 ref=str(ref),
+                                 methbase="CpG",
+                                 smf_modality=smf_modality,
+                                 cfg=cfg,
+                             )
+                         except Exception:
+                             Xcpg, coordscpg = None, None
+                             logger.warning("Skipping HMM single-channel for CpG due to data error")
+
+                         if Xcpg is not None and Xcpg.size and feature_sets_cpg:
+                             arch = trainer.choose_arch(multichannel=False)
+
+                             logger.info("HMM fitting/loading for CpG")
+                             hmmg = trainer.fit_or_load(
+                                 sample=str(sample),
+                                 ref=str(ref),
+                                 label="CpG",
+                                 arch=arch,
+                                 X=Xcpg,
+                                 coords=coordscpg,
+                                 device=device,
+                             )
+
+                             if not bypass_apply:
+                                 logger.info("HMM applying for CpG")
+                                 pm = _resolve_pos_mask_for_methbase(subset, str(ref), "CpG")
+                                 hmmg.annotate_adata(
+                                     subset,
+                                     prefix="CpG",
+                                     X=Xcpg,
+                                     coords=coordscpg,
+                                     var_mask=pm,
+                                     span_fill=True,
+                                     config=cfg,
+                                     decode=decode,
+                                     write_posterior=write_post,
+                                     posterior_state=post_state,
+                                     feature_sets=feature_sets_cpg,  # <--- ONLY cpg group (cpg_patch)
+                                     prob_threshold=prob_thr,
+                                     uns_key=uns_key,
+                                     uns_flag="hmm_annotated_CpG",
+                                     force_redo=force_apply,
+                                 )
+
+                     # ------------------------------------------------------------
+                     # Copy newly created subset layers back into the full adata
+                     # ------------------------------------------------------------
+                     appended = (
+                         list(subset.uns.get(uns_key, []))
+                         if subset.uns.get(uns_key) is not None
+                         else []
+                     )
+                     if appended:
+                         row_mask = np.asarray(
+                             mask.values if hasattr(mask, "values") else mask, dtype=bool
+                         )
+
+                         for ln in appended:
+                             if ln not in subset.layers:
+                                 continue
+                             _ensure_layer_and_assign_rows(adata, ln, row_mask, subset.layers[ln])
+                             if ln not in global_appended:
+                                 global_appended.append(ln)
+
+                         adata.uns[uns_key] = global_appended
+
+         adata.uns["hmm_annotated"] = True
+
+         hmm_layers = list(adata.uns.get("hmm_appended_layers", []) or [])
+         # keep only real feature layers; drop lengths/states/posterior
+         hmm_layers = [
+             layer
+             for layer in hmm_layers
+             if not any(s in layer for s in ("_lengths", "_states", "_posterior"))
+         ]
+         logger.info(f"HMM appended layers: {hmm_layers}")
+
+         # ---------------------------- HMM peak calling stage ----------------------------
          hmm_dir = pp_dir / "11_hmm_peak_calling"
          if hmm_dir.is_dir():
              pass
@@ -225,29 +862,32 @@ def hmm_adata(config_path):
          make_dirs([pp_dir, hmm_dir])

          call_hmm_peaks(
-                        adata,
-                        feature_configs=cfg.hmm_peak_feature_configs,
-                        ref_column=cfg.reference_column,
-                        site_types=cfg.mod_target_bases,
-                        save_plot=True,
-                        output_dir=hmm_dir,
-                        index_col_suffix=cfg.reindexed_var_suffix)
-
-         ## Save HMM annotated adata
-         if not hmm_adata_path.exists():
-             print('Saving hmm analyzed adata post preprocessing and duplicate removal')
-             if ".gz" == hmm_adata_path.suffix:
-                 safe_write_h5ad(adata, hmm_adata_path, compression='gzip', backup=True)
-             else:
-                 hmm_adata_path = hmm_adata_path.with_name(hmm_adata_path.name + '.gz')
-                 safe_write_h5ad(adata, hmm_adata_path, compression='gzip', backup=True)
+             adata,
+             feature_configs=cfg.hmm_peak_feature_configs,
+             ref_column=cfg.reference_column,
+             site_types=cfg.mod_target_bases,
+             save_plot=True,
+             output_dir=hmm_dir,
+             index_col_suffix=cfg.reindexed_var_suffix,
+         )

-         add_or_update_column_in_csv(cfg.summary_file, "hmm_adata", hmm_adata_path)
+         ## Save HMM annotated adata
+         if not paths.hmm.exists():
+             logger.info("Saving hmm analyzed AnnData (post preprocessing and duplicate removal).")
+             record_smftools_metadata(
+                 adata,
+                 step_name="hmm",
+                 cfg=cfg,
+                 config_path=config_path,
+                 input_paths=[source_adata_path] if source_adata_path else None,
+                 output_path=paths.hmm,
+             )
+             write_gz_h5ad(adata, paths.hmm)

          ########################################################################################################################

-         ############################################### HMM based feature plotting ###############################################
-         from ..plotting import combined_hmm_raw_clustermap
+         ############################################### HMM based feature plotting ###############################################
+
          hmm_dir = pp_dir / "12_hmm_clustermaps"
          make_dirs([pp_dir, hmm_dir])

@@ -256,6 +896,9 @@ def hmm_adata(config_path):
          for base in cfg.hmm_methbases:
              layers.extend([f"{base}_{layer}" for layer in cfg.hmm_clustermap_feature_layers])

+         if getattr(cfg, "hmm_run_multichannel", True) and len(cfg.hmm_methbases) >= 2:
+             layers.extend([f"Combined_{layer}" for layer in cfg.hmm_clustermap_feature_layers])
+
          if cfg.cpg:
              layers.extend(["CpG_cpg_patch"])

@@ -273,40 +916,48 @@ def hmm_adata(config_path):
              make_dirs([hmm_cluster_save_dir])

              combined_hmm_raw_clustermap(
-                                         adata,
-                                         sample_col=cfg.sample_name_col_for_plotting,
-                                         reference_col=cfg.reference_column,
-                                         hmm_feature_layer=layer,
-                                         layer_gpc=cfg.layer_for_clustermap_plotting,
-                                         layer_cpg=cfg.layer_for_clustermap_plotting,
-                                         layer_c=cfg.layer_for_clustermap_plotting,
-                                         layer_a=cfg.layer_for_clustermap_plotting,
-                                         cmap_hmm=cfg.clustermap_cmap_hmm,
-                                         cmap_gpc=cfg.clustermap_cmap_gpc,
-                                         cmap_cpg=cfg.clustermap_cmap_cpg,
-                                         cmap_c=cfg.clustermap_cmap_c,
-                                         cmap_a=cfg.clustermap_cmap_a,
-                                         min_quality=cfg.read_quality_filter_thresholds[0],
-                                         min_length=cfg.read_len_filter_thresholds[0],
-                                         min_mapped_length_to_reference_length_ratio=cfg.read_len_to_ref_ratio_filter_thresholds[0],
-                                         min_position_valid_fraction=1-cfg.position_max_nan_threshold,
-                                         save_path=hmm_cluster_save_dir,
-                                         normalize_hmm=False,
-                                         sort_by=cfg.hmm_clustermap_sortby, # options: 'gpc', 'cpg', 'gpc_cpg', 'none', or 'obs:<column>'
-                                         bins=None,
-                                         deaminase=deaminase,
-                                         min_signal=0,
-                                         index_col_suffix=cfg.reindexed_var_suffix
+                 adata,
+                 sample_col=cfg.sample_name_col_for_plotting,
+                 reference_col=cfg.reference_column,
+                 hmm_feature_layer=layer,
+                 layer_gpc=cfg.layer_for_clustermap_plotting,
+                 layer_cpg=cfg.layer_for_clustermap_plotting,
+                 layer_c=cfg.layer_for_clustermap_plotting,
+                 layer_a=cfg.layer_for_clustermap_plotting,
+                 cmap_hmm=cfg.clustermap_cmap_hmm,
+                 cmap_gpc=cfg.clustermap_cmap_gpc,
+                 cmap_cpg=cfg.clustermap_cmap_cpg,
+                 cmap_c=cfg.clustermap_cmap_c,
+                 cmap_a=cfg.clustermap_cmap_a,
+                 min_quality=cfg.read_quality_filter_thresholds[0],
+                 min_length=cfg.read_len_filter_thresholds[0],
+                 min_mapped_length_to_reference_length_ratio=cfg.read_len_to_ref_ratio_filter_thresholds[
+                     0
+                 ],
+                 min_position_valid_fraction=1 - cfg.position_max_nan_threshold,
+                 demux_types=("double", "already"),
+                 save_path=hmm_cluster_save_dir,
+                 normalize_hmm=False,
+                 sort_by=cfg.hmm_clustermap_sortby,  # options: 'gpc', 'cpg', 'gpc_cpg', 'none', or 'obs:<column>'
+                 bins=None,
+                 deaminase=deaminase,
+                 min_signal=0,
+                 index_col_suffix=cfg.reindexed_var_suffix,
              )

          hmm_dir = pp_dir / "13_hmm_bulk_traces"

          if hmm_dir.is_dir():
-             print(f'{hmm_dir} already exists.')
+             logger.debug(f"{hmm_dir} already exists.")
          else:
              make_dirs([pp_dir, hmm_dir])
              from ..plotting import plot_hmm_layers_rolling_by_sample_ref
-             bulk_hmm_layers = [layer for layer in adata.uns['hmm_appended_layers'] if "_lengths" not in layer]
+
+             bulk_hmm_layers = [
+                 layer
+                 for layer in hmm_layers
+                 if not any(s in layer for s in ("_lengths", "_states", "_posterior"))
+             ]
              saved = plot_hmm_layers_rolling_by_sample_ref(
                  adata,
                  layers=bulk_hmm_layers,
@@ -314,26 +965,38 @@ def hmm_adata(config_path):
                  ref_col=cfg.reference_column,
                  window=101,
                  rows_per_page=4,
-                 figsize_per_cell=(4,2.5),
+                 figsize_per_cell=(4, 2.5),
                  output_dir=hmm_dir,
                  save=True,
-                 show_raw=False
+                 show_raw=False,
              )

          hmm_dir = pp_dir / "14_hmm_fragment_distributions"

          if hmm_dir.is_dir():
-             print(f'{hmm_dir} already exists.')
+             logger.debug(f"{hmm_dir} already exists.")
          else:
              make_dirs([pp_dir, hmm_dir])
              from ..plotting import plot_hmm_size_contours

-             if smf_modality == 'deaminase':
-                 fragments = [('C_all_accessible_features_lengths', 400), ('C_all_footprint_features_lengths', 250), ('C_all_accessible_features_merged_lengths', 800)]
-             elif smf_modality == 'conversion':
-                 fragments = [('GpC_all_accessible_features_lengths', 400), ('GpC_all_footprint_features_lengths', 250), ('GpC_all_accessible_features_merged_lengths', 800)]
+             if smf_modality == "deaminase":
+                 fragments = [
+                     ("C_all_accessible_features_lengths", 400),
+                     ("C_all_footprint_features_lengths", 250),
+                     ("C_all_accessible_features_merged_lengths", 800),
+                 ]
+             elif smf_modality == "conversion":
+                 fragments = [
+                     ("GpC_all_accessible_features_lengths", 400),
+                     ("GpC_all_footprint_features_lengths", 250),
+                     ("GpC_all_accessible_features_merged_lengths", 800),
+                 ]
              elif smf_modality == "direct":
-                 fragments = [('A_all_accessible_features_lengths', 400), ('A_all_footprint_features_lengths', 200), ('A_all_accessible_features_merged_lengths', 800)]
+                 fragments = [
+                     ("A_all_accessible_features_lengths", 400),
+                     ("A_all_footprint_features_lengths", 200),
+                     ("A_all_accessible_features_merged_lengths", 800),
+                 ]

              for layer, max in fragments:
                  save_path = hmm_dir / layer
@@ -353,9 +1016,9 @@ def hmm_adata(config_path):
                      dpi=200,
                      smoothing_sigma=(10, 10),
                      normalize_after_smoothing=True,
-                     cmap='Greens',
-                     log_scale_z=True
+                     cmap="Greens",
+                     log_scale_z=True,
                  )
          ########################################################################################################################

-     return (adata, hmm_adata_path)
+     return (adata, paths.hmm)
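
Note: end-to-end, the CLI entry point is unchanged (smftools hmm <config_path>), and the refactored wrapper can also be driven from Python; it returns the annotated AnnData plus the stage output path (written gzip-compressed via write_gz_h5ad). A sketch, assuming an experiment configuration file as described in the docstrings:

    from smftools.cli.hmm_adata import hmm_adata

    adata, hmm_path = hmm_adata("experiment_config.csv")
    print(hmm_path)
    print(adata.uns.get("hmm_appended_layers", []))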