smftools 0.2.3__py3-none-any.whl → 0.2.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- smftools/__init__.py +6 -8
- smftools/_settings.py +4 -6
- smftools/_version.py +1 -1
- smftools/cli/helpers.py +54 -0
- smftools/cli/hmm_adata.py +937 -256
- smftools/cli/load_adata.py +448 -268
- smftools/cli/preprocess_adata.py +469 -263
- smftools/cli/spatial_adata.py +536 -319
- smftools/cli_entry.py +97 -182
- smftools/config/__init__.py +1 -1
- smftools/config/conversion.yaml +17 -6
- smftools/config/deaminase.yaml +12 -10
- smftools/config/default.yaml +142 -33
- smftools/config/direct.yaml +11 -3
- smftools/config/discover_input_files.py +19 -5
- smftools/config/experiment_config.py +594 -264
- smftools/constants.py +37 -0
- smftools/datasets/__init__.py +2 -8
- smftools/datasets/datasets.py +32 -18
- smftools/hmm/HMM.py +2128 -1418
- smftools/hmm/__init__.py +2 -9
- smftools/hmm/archived/call_hmm_peaks.py +121 -0
- smftools/hmm/call_hmm_peaks.py +299 -91
- smftools/hmm/display_hmm.py +19 -6
- smftools/hmm/hmm_readwrite.py +13 -4
- smftools/hmm/nucleosome_hmm_refinement.py +102 -14
- smftools/informatics/__init__.py +30 -7
- smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +14 -1
- smftools/informatics/archived/helpers/archived/bam_qc.py +14 -1
- smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +8 -1
- smftools/informatics/archived/helpers/archived/load_adata.py +3 -3
- smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +3 -1
- smftools/informatics/archived/print_bam_query_seq.py +7 -1
- smftools/informatics/bam_functions.py +397 -175
- smftools/informatics/basecalling.py +51 -9
- smftools/informatics/bed_functions.py +90 -57
- smftools/informatics/binarize_converted_base_identities.py +18 -7
- smftools/informatics/complement_base_list.py +7 -6
- smftools/informatics/converted_BAM_to_adata.py +265 -122
- smftools/informatics/fasta_functions.py +161 -83
- smftools/informatics/h5ad_functions.py +196 -30
- smftools/informatics/modkit_extract_to_adata.py +609 -270
- smftools/informatics/modkit_functions.py +85 -44
- smftools/informatics/ohe.py +44 -21
- smftools/informatics/pod5_functions.py +112 -73
- smftools/informatics/run_multiqc.py +20 -14
- smftools/logging_utils.py +51 -0
- smftools/machine_learning/__init__.py +2 -7
- smftools/machine_learning/data/anndata_data_module.py +143 -50
- smftools/machine_learning/data/preprocessing.py +2 -1
- smftools/machine_learning/evaluation/__init__.py +1 -1
- smftools/machine_learning/evaluation/eval_utils.py +11 -14
- smftools/machine_learning/evaluation/evaluators.py +46 -33
- smftools/machine_learning/inference/__init__.py +1 -1
- smftools/machine_learning/inference/inference_utils.py +7 -4
- smftools/machine_learning/inference/lightning_inference.py +9 -13
- smftools/machine_learning/inference/sklearn_inference.py +6 -8
- smftools/machine_learning/inference/sliding_window_inference.py +35 -25
- smftools/machine_learning/models/__init__.py +10 -5
- smftools/machine_learning/models/base.py +28 -42
- smftools/machine_learning/models/cnn.py +15 -11
- smftools/machine_learning/models/lightning_base.py +71 -40
- smftools/machine_learning/models/mlp.py +13 -4
- smftools/machine_learning/models/positional.py +3 -2
- smftools/machine_learning/models/rnn.py +3 -2
- smftools/machine_learning/models/sklearn_models.py +39 -22
- smftools/machine_learning/models/transformer.py +68 -53
- smftools/machine_learning/models/wrappers.py +2 -1
- smftools/machine_learning/training/__init__.py +2 -2
- smftools/machine_learning/training/train_lightning_model.py +29 -20
- smftools/machine_learning/training/train_sklearn_model.py +9 -15
- smftools/machine_learning/utils/__init__.py +1 -1
- smftools/machine_learning/utils/device.py +7 -4
- smftools/machine_learning/utils/grl.py +3 -1
- smftools/metadata.py +443 -0
- smftools/plotting/__init__.py +19 -5
- smftools/plotting/autocorrelation_plotting.py +145 -44
- smftools/plotting/classifiers.py +162 -72
- smftools/plotting/general_plotting.py +422 -197
- smftools/plotting/hmm_plotting.py +42 -13
- smftools/plotting/position_stats.py +147 -87
- smftools/plotting/qc_plotting.py +20 -12
- smftools/preprocessing/__init__.py +10 -12
- smftools/preprocessing/append_base_context.py +115 -80
- smftools/preprocessing/append_binary_layer_by_base_context.py +77 -39
- smftools/preprocessing/{calculate_complexity.py → archived/calculate_complexity.py} +3 -1
- smftools/preprocessing/{archives → archived}/preprocessing.py +8 -6
- smftools/preprocessing/binarize.py +21 -4
- smftools/preprocessing/binarize_on_Youden.py +129 -31
- smftools/preprocessing/binary_layers_to_ohe.py +17 -11
- smftools/preprocessing/calculate_complexity_II.py +86 -59
- smftools/preprocessing/calculate_consensus.py +28 -19
- smftools/preprocessing/calculate_coverage.py +50 -25
- smftools/preprocessing/calculate_pairwise_differences.py +2 -1
- smftools/preprocessing/calculate_pairwise_hamming_distances.py +4 -3
- smftools/preprocessing/calculate_position_Youden.py +118 -54
- smftools/preprocessing/calculate_read_length_stats.py +52 -23
- smftools/preprocessing/calculate_read_modification_stats.py +91 -57
- smftools/preprocessing/clean_NaN.py +38 -28
- smftools/preprocessing/filter_adata_by_nan_proportion.py +24 -12
- smftools/preprocessing/filter_reads_on_length_quality_mapping.py +71 -38
- smftools/preprocessing/filter_reads_on_modification_thresholds.py +181 -73
- smftools/preprocessing/flag_duplicate_reads.py +689 -272
- smftools/preprocessing/invert_adata.py +26 -11
- smftools/preprocessing/load_sample_sheet.py +40 -22
- smftools/preprocessing/make_dirs.py +8 -3
- smftools/preprocessing/min_non_diagonal.py +2 -1
- smftools/preprocessing/recipes.py +56 -23
- smftools/preprocessing/reindex_references_adata.py +103 -0
- smftools/preprocessing/subsample_adata.py +33 -16
- smftools/readwrite.py +331 -82
- smftools/schema/__init__.py +11 -0
- smftools/schema/anndata_schema_v1.yaml +227 -0
- smftools/tools/__init__.py +3 -4
- smftools/tools/archived/classifiers.py +163 -0
- smftools/tools/archived/subset_adata_v1.py +10 -1
- smftools/tools/archived/subset_adata_v2.py +12 -1
- smftools/tools/calculate_umap.py +54 -15
- smftools/tools/cluster_adata_on_methylation.py +115 -46
- smftools/tools/general_tools.py +70 -25
- smftools/tools/position_stats.py +229 -98
- smftools/tools/read_stats.py +50 -29
- smftools/tools/spatial_autocorrelation.py +365 -192
- smftools/tools/subset_adata.py +23 -21
- {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/METADATA +17 -39
- smftools-0.2.5.dist-info/RECORD +181 -0
- smftools-0.2.3.dist-info/RECORD +0 -173
- /smftools/cli/{cli_flows.py → archived/cli_flows.py} +0 -0
- /smftools/hmm/{apply_hmm_batched.py → archived/apply_hmm_batched.py} +0 -0
- /smftools/hmm/{calculate_distances.py → archived/calculate_distances.py} +0 -0
- /smftools/hmm/{train_hmm.py → archived/train_hmm.py} +0 -0
- /smftools/preprocessing/{add_read_length_and_mapping_qc.py → archived/add_read_length_and_mapping_qc.py} +0 -0
- /smftools/preprocessing/{archives → archived}/mark_duplicates.py +0 -0
- /smftools/preprocessing/{archives → archived}/remove_duplicates.py +0 -0
- {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/WHEEL +0 -0
- {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/entry_points.txt +0 -0
- {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/licenses/LICENSE +0 -0
smftools/cli/hmm_adata.py
CHANGED
@@ -1,318 +1,999 @@
-
+from __future__ import annotations
+
+import copy
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any, List, Optional, Sequence, Tuple, Union
+
+import numpy as np
+import torch
+
+from smftools.logging_utils import get_logger
+
+# FIX: import _to_dense_np to avoid NameError
+from ..hmm.HMM import _safe_int_coords, _to_dense_np, create_hmm, normalize_hmm_feature_sets
+
+logger = get_logger(__name__)
+
+# =============================================================================
+# Helpers: extracting training arrays
+# =============================================================================
+
+
+def _get_training_matrix(
+    subset, cols_mask: np.ndarray, smf_modality: Optional[str], cfg
+) -> Tuple[np.ndarray, Optional[str]]:
+    """
+    Matches your existing behavior:
+      - direct -> uses cfg.output_binary_layer_name in .layers
+      - else -> uses .X
+    Returns (X, layer_name_or_None) where X is dense float array.
     """
-
-    Command line accesses this through smftools hmm <config_path>
+    sub = subset[:, cols_mask]

-
-
+    if smf_modality == "direct":
+        hmm_layer = getattr(cfg, "output_binary_layer_name", None)
+        if hmm_layer is None or hmm_layer not in sub.layers:
+            raise KeyError(f"Missing HMM training layer '{hmm_layer}' in subset.")

+        logger.debug("Using direct modality HMM training layer: %s", hmm_layer)
+        mat = sub.layers[hmm_layer]
+    else:
+        logger.debug("Using .X for HMM training matrix")
+        hmm_layer = None
+        mat = sub.X
+
+    X = _to_dense_np(mat).astype(float)
+    if X.ndim != 2:
+        raise ValueError(f"Expected 2D training matrix; got {X.shape}")
+    return X, hmm_layer
+
+
+def _resolve_pos_mask_for_methbase(subset, ref: str, methbase: str) -> Optional[np.ndarray]:
+    """
+    Reproduces your mask resolution, with compatibility for both *_any_C_site and *_C_site.
+    """
+    key = str(methbase).strip().lower()
+
+    logger.debug("Resolving position mask for methbase=%s on ref=%s", key, ref)
+
+    if key in ("a",):
+        col = f"{ref}_A_site"
+        if col not in subset.var:
+            return None
+        logger.debug("Using positions with A calls from column: %s", col)
+        return np.asarray(subset.var[col])
+
+    if key in ("c", "any_c", "anyc", "any-c"):
+        for col in (f"{ref}_any_C_site", f"{ref}_C_site"):
+            if col in subset.var:
+                logger.debug("Using positions with C calls from column: %s", col)
+                return np.asarray(subset.var[col])
+        return None
+
+    if key in ("gpc", "gpc_site", "gpc-site"):
+        col = f"{ref}_GpC_site"
+        if col not in subset.var:
+            return None
+        logger.debug("Using positions with GpC calls from column: %s", col)
+        return np.asarray(subset.var[col])
+
+    if key in ("cpg", "cpg_site", "cpg-site"):
+        col = f"{ref}_CpG_site"
+        if col not in subset.var:
+            return None
+        logger.debug("Using positions with CpG calls from column: %s", col)
+        return np.asarray(subset.var[col])
+
+    alt = f"{ref}_{methbase}_site"
+    if alt not in subset.var:
+        return None
+
+    logger.debug("Using positions from column: %s", alt)
+    return np.asarray(subset.var[alt])
+
+
+def build_single_channel(
+    subset, ref: str, methbase: str, smf_modality: Optional[str], cfg
+) -> Tuple[np.ndarray, np.ndarray]:
+    """
     Returns:
-
+        X (N, Lmb) float with NaNs allowed
+        coords (Lmb,) int coords from var_names
+    """
+    pm = _resolve_pos_mask_for_methbase(subset, ref, methbase)
+    logger.debug(
+        "Position mask for methbase=%s on ref=%s has %d sites",
+        methbase,
+        ref,
+        int(np.sum(pm)) if pm is not None else 0,
+    )
+
+    if pm is None or int(np.sum(pm)) == 0:
+        raise ValueError(f"No columns for methbase={methbase} on ref={ref}")
+
+    X, _ = _get_training_matrix(subset, pm, smf_modality, cfg)
+    logger.debug("Training matrix for methbase=%s on ref=%s has shape %s", methbase, ref, X.shape)
+
+    coords, _ = _safe_int_coords(subset[:, pm].var_names)
+    logger.debug(
+        "Coordinates for methbase=%s on ref=%s have length %d", methbase, ref, coords.shape[0]
+    )
+
+    return X, coords
+
+
+def build_multi_channel_union(
+    subset, ref: str, methbases: Sequence[str], smf_modality: Optional[str], cfg
+) -> Tuple[np.ndarray, np.ndarray, List[str]]:
+    """
+    Build (N, Lunion, C) on union coordinate grid across methbases.
+
+    Returns:
+        X3d: (N, Lunion, C) float with NaN where methbase has no site
+        coords: (Lunion,) int union coords
+        used_methbases: list of methbases actually included (>=2)
+    """
+    per: List[Tuple[str, np.ndarray, np.ndarray, np.ndarray]] = []  # (mb, X, coords, pm)
+
+    for mb in methbases:
+        pm = _resolve_pos_mask_for_methbase(subset, ref, mb)
+        if pm is None or int(np.sum(pm)) == 0:
+            continue
+        Xmb, _ = _get_training_matrix(subset, pm, smf_modality, cfg)  # (N,Lmb)
+        cmb, _ = _safe_int_coords(subset[:, pm].var_names)
+        per.append((mb, Xmb.astype(float), cmb.astype(int), pm))
+
+    if len(per) < 2:
+        raise ValueError(f"Need >=2 methbases with columns for union multi-channel on ref={ref}")
+
+    # union coordinates
+    coords = np.unique(np.concatenate([c for _, _, c, _ in per], axis=0)).astype(int)
+    idx = {int(v): i for i, v in enumerate(coords.tolist())}
+
+    N = per[0][1].shape[0]
+    L = coords.shape[0]
+    C = len(per)
+    X3 = np.full((N, L, C), np.nan, dtype=float)
+
+    for ci, (mb, Xmb, cmb, _) in enumerate(per):
+        cols = np.array([idx[int(v)] for v in cmb.tolist()], dtype=int)
+        X3[:, cols, ci] = Xmb
+
+    used = [mb for (mb, _, _, _) in per]
+    return X3, coords, used
+
+
+@dataclass
+class HMMTask:
+    name: str
+    signals: List[str]  # e.g. ["GpC"] or ["GpC","CpG"] or ["CpG"]
+    feature_groups: List[str]  # e.g. ["footprint","accessible"] or ["cpg"]
+    output_prefix: Optional[str] = None  # force prefix (CpG task uses "CpG")
+
+
+def build_hmm_tasks(cfg: Union[dict, Any]) -> List[HMMTask]:
+    """
+    Accessibility signals come from cfg['hmm_methbases'].
+    CpG task is enabled by cfg['cpg']==True, independent of hmm_methbases.
+    """
+    if not isinstance(cfg, dict):
+        # best effort conversion
+        cfg = {k: getattr(cfg, k) for k in dir(cfg) if not k.startswith("_")}
+
+    tasks: List[HMMTask] = []
+
+    # accessibility task
+    methbases = list(cfg.get("hmm_methbases", []) or [])
+    if len(methbases) > 0:
+        tasks.append(
+            HMMTask(
+                name="accessibility",
+                signals=methbases,
+                feature_groups=["footprint", "accessible"],
+                output_prefix=None,
+            )
+        )
+
+    # CpG task (special case)
+    if bool(cfg.get("cpg", False)):
+        tasks.append(
+            HMMTask(
+                name="cpg",
+                signals=["CpG"],
+                feature_groups=["cpg"],
+                output_prefix="CpG",
+            )
+        )
+
+    return tasks
+
+
+def select_hmm_arch(cfg: dict, signals: Sequence[str]) -> str:
+    """
+    Simple, explicit model selection:
+      - distance-aware => 'single_distance_binned' (only meaningful for single-channel)
+      - multi-signal => 'multi'
+      - else => 'single'
+    """
+    if bool(cfg.get("hmm_distance_aware", False)) and len(signals) == 1:
+        return "single_distance_binned"
+    if len(signals) > 1:
+        return "multi"
+    return "single"
+
+
+def resolve_input_layer(adata, cfg: dict, layer_override: Optional[str]) -> Optional[str]:
+    """
+    If direct modality, prefer cfg.output_binary_layer_name.
+    Else use layer_override or None (meaning use .X).
+    """
+    smf_modality = cfg.get("smf_modality", None)
+    if smf_modality == "direct":
+        nm = cfg.get("output_binary_layer_name", None)
+        if nm is None:
+            raise KeyError("cfg.output_binary_layer_name missing for smf_modality='direct'")
+        if nm not in adata.layers:
+            raise KeyError(f"Direct modality expects layer '{nm}' in adata.layers")
+        return nm
+    return layer_override
+
+
+def _ensure_layer_and_assign_rows(adata, layer_name: str, row_mask: np.ndarray, subset_layer):
     """
-
+    Writes subset_layer (n_subset_obs, n_vars) into adata.layers[layer_name] for rows where row_mask==True.
+    """
+    row_mask = np.asarray(row_mask, dtype=bool)
+    if row_mask.ndim != 1 or row_mask.size != adata.n_obs:
+        raise ValueError("row_mask must be length adata.n_obs")
+
+    arr = _to_dense_np(subset_layer)
+    if arr.shape != (int(row_mask.sum()), adata.n_vars):
+        raise ValueError(
+            f"subset layer '{layer_name}' shape {arr.shape} != ({int(row_mask.sum())}, {adata.n_vars})"
+        )
+
+    if layer_name not in adata.layers:
+        adata.layers[layer_name] = np.zeros((adata.n_obs, adata.n_vars), dtype=arr.dtype)
+
+    adata.layers[layer_name][row_mask, :] = arr
+
+
+def resolve_torch_device(device_str: str | None) -> torch.device:
+    d = (device_str or "auto").lower()
+    if d == "auto":
+        if torch.cuda.is_available():
+            return torch.device("cuda")
+        if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
+            return torch.device("mps")
+        return torch.device("cpu")
+    return torch.device(d)
+
+
+# =============================================================================
+# Model selection + fit strategy manager
+# =============================================================================
+@dataclass
+class HMMTrainer:
+    cfg: Any
+    models_dir: Path
+
+    def __post_init__(self):
+        self.models_dir = Path(self.models_dir)
+        self.models_dir.mkdir(parents=True, exist_ok=True)
+
+    def choose_arch(self, *, multichannel: bool) -> str:
+        use_dist = bool(getattr(self.cfg, "hmm_distance_aware", False))
+        if multichannel:
+            return "multi"
+        return "single_distance_binned" if use_dist else "single"
+
+    def _fit_scope(self) -> str:
+        return str(getattr(self.cfg, "hmm_fit_scope", "per_sample")).lower()
+        # "per_sample" | "global" | "global_then_adapt"
+
+    def _path(self, kind: str, sample: str, ref: str, label: str) -> Path:
+        # kind: "GLOBAL" | "PER" | "ADAPT"
+        def safe(s):
+            str(s).replace("/", "_")
+
+        return self.models_dir / f"{kind}_{safe(sample)}_{safe(ref)}_{safe(label)}.pt"
+
+    def _save(self, model, path: Path):
+        override = {}
+        if getattr(model, "hmm_name", None) == "multi":
+            override["hmm_n_channels"] = int(getattr(model, "n_channels", 2))
+        if getattr(model, "hmm_name", None) == "single_distance_binned":
+            override["hmm_distance_bins"] = list(
+                getattr(model, "distance_bins", [1, 5, 10, 25, 50, 100])
+            )
+
+        payload = {
+            "state_dict": model.state_dict(),
+            "hmm_arch": getattr(model, "hmm_name", None) or getattr(self.cfg, "hmm_arch", None),
+            "override": override,
+        }
+        torch.save(payload, path)
+
+    def _load(self, path: Path, arch: str, device):
+        payload = torch.load(path, map_location="cpu")
+        override = payload.get("override", None)
+        m = create_hmm(self.cfg, arch=arch, override=override, device=device)
+        sd = payload["state_dict"]
+
+        target_dtype = next(m.parameters()).dtype
+        for k, v in list(sd.items()):
+            if isinstance(v, torch.Tensor) and v.dtype != target_dtype:
+                sd[k] = v.to(dtype=target_dtype)
+
+        m.load_state_dict(sd)
+        m.to(device)
+        m.eval()
+        return m
+
+    def fit_or_load(
+        self,
+        *,
+        sample: str,
+        ref: str,
+        label: str,
+        arch: str,
+        X,
+        coords: Optional[np.ndarray],
+        device,
+    ):
+        force_fit = bool(getattr(self.cfg, "force_redo_hmm_fit", False))
+        scope = self._fit_scope()
+
+        max_iter = int(getattr(self.cfg, "hmm_max_iter", 50))
+        tol = float(getattr(self.cfg, "hmm_tol", 1e-4))
+        verbose = bool(getattr(self.cfg, "hmm_verbose", False))
+
+        # ---- global then adapt ----
+        if scope == "global_then_adapt":
+            p_global = self._path("GLOBAL", "ALL", ref, label)
+            if p_global.exists() and not force_fit:
+                base = self._load(p_global, arch=arch, device=device)
+            else:
+                base = create_hmm(self.cfg, arch=arch).to(device)
+                if arch == "single_distance_binned":
+                    base.fit(
+                        X, device=device, coords=coords, max_iter=max_iter, tol=tol, verbose=verbose
+                    )
+                else:
+                    base.fit(X, device=device, max_iter=max_iter, tol=tol, verbose=verbose)
+                self._save(base, p_global)
+
+            p_adapt = self._path("ADAPT", sample, ref, label)
+            if p_adapt.exists() and not force_fit:
+                return self._load(p_adapt, arch=arch, device=device)
+
+            # IMPORTANT: this assumes you added model.adapt_emissions(...)
+            adapted = copy.deepcopy(base).to(device)
+            if arch == "single_distance_binned":
+                adapted.adapt_emissions(
+                    X,
+                    coords,
+                    device=device,
+                    max_iter=int(getattr(self.cfg, "hmm_adapt_iters", 10)),
+                    verbose=verbose,
+                )
+
+            else:
+                adapted.adapt_emissions(
+                    X,
+                    coords,
+                    device=device,
+                    max_iter=int(getattr(self.cfg, "hmm_adapt_iters", 10)),
+                    verbose=verbose,
+                )
+
+            self._save(adapted, p_adapt)
+            return adapted
+
+        # ---- global only ----
+        if scope == "global":
+            p = self._path("GLOBAL", "ALL", ref, label)
+            if p.exists() and not force_fit:
+                return self._load(p, arch=arch, device=device)
+
+        # ---- per sample ----
+        else:
+            p = self._path("PER", sample, ref, label)
+            if p.exists() and not force_fit:
+                return self._load(p, arch=arch, device=device)
+
+        m = create_hmm(self.cfg, arch=arch, device=device)
+        if arch == "single_distance_binned":
+            m.fit(X, coords, device=device, max_iter=max_iter, tol=tol, verbose=verbose)
+        else:
+            m.fit(X, coords, device=device, max_iter=max_iter, tol=tol, verbose=verbose)
+        self._save(m, p)
+        return m
+
+
+def _fully_qualified_merge_layers(cfg, prefix: str) -> List[Tuple[str, int]]:
+    """
+    cfg.hmm_merge_layer_features is assumed to be a list of (core_layer_name, merge_distance),
+    where core_layer_name is like "all_accessible_features" (NOT prefixed with methbase).
+    We expand to f"{prefix}_{core_layer_name}".
+    """
+    out = []
+    for core_layer, dist in getattr(cfg, "hmm_merge_layer_features", []) or []:
+        if not core_layer:
+            continue
+        out.append((f"{prefix}_{core_layer}", int(dist)))
+    return out
+
+
+def hmm_adata(config_path: str):
+    """
+    CLI-facing wrapper for HMM analysis.
+
+    Command line entrypoint:
+        smftools hmm <config_path>
+
+    Responsibilities:
+      - Build cfg via load_adata()
+      - Ensure preprocess + spatial stages are run
+      - Decide which AnnData to start from (hmm > spatial > pp_dedup > pp > raw)
+      - Call hmm_adata_core(cfg, adata, paths)
+    """
+    from ..readwrite import safe_read_h5ad
+    from .helpers import get_adata_paths
     from .load_adata import load_adata
     from .preprocess_adata import preprocess_adata
     from .spatial_adata import spatial_adata

-
-
-
-    import scanpy as sc
+    # 1) load cfg / stage paths
+    _, _, cfg = load_adata(config_path)
+    paths = get_adata_paths(cfg)

-
-
-
+    # 2) make sure upstream stages are run (they have their own skipping logic)
+    preprocess_adata(config_path)
+    spatial_ad, spatial_path = spatial_adata(config_path)

-
-
+    # 3) choose starting AnnData
+    # Prefer:
+    #   - existing HMM h5ad if not forcing redo
+    #   - in-memory spatial_ad from wrapper call
+    #   - saved spatial / pp_dedup / pp / raw on disk
+    if paths.hmm.exists() and not (cfg.force_redo_hmm_fit or cfg.force_redo_hmm_apply):
+        adata, _ = safe_read_h5ad(paths.hmm)
+        return adata, paths.hmm

-
-    adata, adata_path, cfg = load_adata(config_path)
-    # General config variable init - Necessary user passed inputs
-    smf_modality = cfg.smf_modality # needed for specifying if the data is conversion SMF or direct methylation detection SMF. Or deaminase smf Necessary.
-    output_directory = Path(cfg.output_directory) # Path to the output directory to make for the analysis. Necessary.
-
-    # Make initial output directory
-    make_dirs([output_directory])
-    ############################################### smftools load end ###############################################
-
-    ############################################### smftools preprocess start ###############################################
-    pp_adata, pp_adata_path, pp_dedup_adata, pp_dup_rem_adata_path = preprocess_adata(config_path)
-    ############################################### smftools preprocess end ###############################################
-
-    ############################################### smftools spatial start ###############################################
-    spatial_ad, spatial_adata_path = spatial_adata(config_path)
-    ############################################### smftools spatial end ###############################################
-
-    ############################################### smftools hmm start ###############################################
-    input_manager_df = pd.read_csv(cfg.summary_file)
-    initial_adata_path = Path(input_manager_df['load_adata'][0])
-    pp_adata_path = Path(input_manager_df['pp_adata'][0])
-    pp_dup_rem_adata_path = Path(input_manager_df['pp_dedup_adata'][0])
-    spatial_adata_path = Path(input_manager_df['spatial_adata'][0])
-    hmm_adata_path = Path(input_manager_df['hmm_adata'][0])
-
-    if spatial_ad:
-        # This happens on first run of the pipeline
+    if spatial_ad is not None:
         adata = spatial_ad
+        source_path = spatial_path
+    elif paths.spatial.exists():
+        adata, _ = safe_read_h5ad(paths.spatial)
+        source_path = paths.spatial
+    elif paths.pp_dedup.exists():
+        adata, _ = safe_read_h5ad(paths.pp_dedup)
+        source_path = paths.pp_dedup
+    elif paths.pp.exists():
+        adata, _ = safe_read_h5ad(paths.pp)
+        source_path = paths.pp
+    elif paths.raw.exists():
+        adata, _ = safe_read_h5ad(paths.raw)
+        source_path = paths.raw
     else:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        raise FileNotFoundError(
+            "No AnnData available for HMM: expected at least raw or preprocessed h5ad."
+        )
+
+    # 4) delegate to core
+    adata, hmm_adata_path = hmm_adata_core(
+        cfg,
+        adata,
+        paths,
+        source_adata_path=source_path,
+        config_path=config_path,
+    )
+    return adata, hmm_adata_path
+
+
+def hmm_adata_core(
+    cfg,
+    adata,
+    paths,
+    source_adata_path: Path | None = None,
+    config_path: str | None = None,
+) -> Tuple["anndata.AnnData", Path]:
+    """
+    Core HMM analysis pipeline.
+
+    Assumes:
+      - cfg is an ExperimentConfig
+      - adata is the starting AnnData (typically spatial + dedup)
+      - paths is an AdataPaths object (with .raw/.pp/.pp_dedup/.spatial/.hmm)
+
+    Does NOT decide which h5ad to start from – that is the wrapper's job.
+    """
+
+    import numpy as np
+
+    from ..hmm import call_hmm_peaks
+    from ..metadata import record_smftools_metadata
+    from ..plotting import (
+        combined_hmm_raw_clustermap,
+        plot_hmm_layers_rolling_by_sample_ref,
+        plot_hmm_size_contours,
+    )
+    from ..readwrite import make_dirs
+    from .helpers import write_gz_h5ad
+
+    smf_modality = cfg.smf_modality
+    deaminase = smf_modality == "deaminase"
+
+    output_directory = Path(cfg.output_directory)
+    make_dirs([output_directory])
+
+    pp_dir = output_directory / "preprocessed" / "deduplicated"
+
+    # ---------------------------- HMM annotate stage ----------------------------
     if not (cfg.bypass_hmm_fit and cfg.bypass_hmm_apply):
-
-
-        import warnings
+        hmm_models_dir = pp_dir / "10_hmm_models"
+        make_dirs([pp_dir, hmm_models_dir])

-
-
-
+        # Standard bookkeeping
+        uns_key = "hmm_appended_layers"
+        if adata.uns.get(uns_key) is None:
+            adata.uns[uns_key] = []
+        global_appended = list(adata.uns.get(uns_key, []))

-
-
-
-
+        # Prepare trainer + feature config
+        trainer = HMMTrainer(cfg=cfg, models_dir=hmm_models_dir)
+
+        feature_sets = normalize_hmm_feature_sets(getattr(cfg, "hmm_feature_sets", None))
+        prob_thr = float(getattr(cfg, "hmm_feature_prob_threshold", 0.5))
+        decode = str(getattr(cfg, "hmm_decode", "marginal"))
+        write_post = bool(getattr(cfg, "hmm_write_posterior", True))
+        post_state = getattr(cfg, "hmm_posterior_state", "Modified")
+        merged_suffix = str(getattr(cfg, "hmm_merged_suffix", "_merged"))
+        force_apply = bool(getattr(cfg, "force_redo_hmm_apply", False))
+        bypass_apply = bool(getattr(cfg, "bypass_hmm_apply", False))
+        bypass_fit = bool(getattr(cfg, "bypass_hmm_fit", False))

         samples = adata.obs[cfg.sample_name_col_for_plotting].cat.categories
         references = adata.obs[cfg.reference_column].cat.categories
-
+        methbases = list(getattr(cfg, "hmm_methbases", [])) or []

-
-
-        adata.uns[uns_key] = []
+        if not methbases:
+            raise ValueError("cfg.hmm_methbases is empty.")

-
-
-
-
-
-
-
-
-
-
-
-
-
-                    if obsm_key not in subset.obsm:
-                        print(f"Skipping {sample} {ref} {mod_label}: missing obsm '{obsm_key}'")
+        # Top-level skip
+        already = bool(adata.uns.get("hmm_annotated", False))
+        if already and not (bool(getattr(cfg, "force_redo_hmm_fit", False)) or force_apply):
+            pass
+
+        else:
+            logger.info("Starting HMM annotation over samples and references")
+            for sample in samples:
+                for ref in references:
+                    mask = (adata.obs[cfg.sample_name_col_for_plotting] == sample) & (
+                        adata.obs[cfg.reference_column] == ref
+                    )
+                    if int(np.sum(mask)) == 0:
                         continue

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+                    subset = adata[mask].copy()
+                    subset.uns[uns_key] = []  # isolate appended tracking per subset
+
+                    # ---- Decide which tasks to run ----
+                    methbases = list(getattr(cfg, "hmm_methbases", [])) or []
+                    run_multi = bool(getattr(cfg, "hmm_run_multichannel", True))
+                    run_cpg = bool(getattr(cfg, "cpg", False))
+                    device = resolve_torch_device(cfg.device)
+
+                    logger.info("HMM processing sample=%s ref=%s", sample, ref)
+
+                    # ---- split feature sets ----
+                    feature_sets_all = normalize_hmm_feature_sets(
+                        getattr(cfg, "hmm_feature_sets", None)
+                    )
+                    feature_sets_access = {
+                        k: v
+                        for k, v in feature_sets_all.items()
+                        if k in ("footprint", "accessible")
+                    }
+                    feature_sets_cpg = (
+                        {"cpg": feature_sets_all["cpg"]} if "cpg" in feature_sets_all else {}
+                    )
+
+                    # =========================
+                    # 1) Single-channel accessibility (per methbase)
+                    # =========================
+                    for mb in methbases:
+                        logger.info("HMM single-channel for methbase=%s", mb)
+
+                        try:
+                            X, coords = build_single_channel(
+                                subset,
+                                ref=str(ref),
+                                methbase=str(mb),
+                                smf_modality=smf_modality,
+                                cfg=cfg,
+                            )
+                        except Exception:
+                            logger.warning(
+                                "Skipping HMM single-channel for methbase=%s due to data error", mb
+                            )
                             continue

-
-
-                        for
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+                        arch = trainer.choose_arch(multichannel=False)
+
+                        logger.info("HMM fitting/loading for methbase=%s", mb)
+                        hmm = trainer.fit_or_load(
+                            sample=str(sample),
+                            ref=str(ref),
+                            label=str(mb),
+                            arch=arch,
+                            X=X,
+                            coords=coords,
+                            device=device,
+                        )
+
+                        if not bypass_apply:
+                            logger.info("HMM applying for methbase=%s", mb)
+                            pm = _resolve_pos_mask_for_methbase(subset, str(ref), str(mb))
+                            hmm.annotate_adata(
+                                subset,
+                                prefix=str(mb),
+                                X=X,
+                                coords=coords,
+                                var_mask=pm,
+                                span_fill=True,
+                                config=cfg,
+                                decode=decode,
+                                write_posterior=write_post,
+                                posterior_state=post_state,
+                                feature_sets=feature_sets_access,  # <--- ONLY accessibility feature sets
+                                prob_threshold=prob_thr,
+                                uns_key=uns_key,
+                                uns_flag=f"hmm_annotated_{mb}",
+                                force_redo=force_apply,
+                            )
+
+                        # merges for this mb
+                        for core_layer, dist in (
+                            getattr(cfg, "hmm_merge_layer_features", []) or []
+                        ):
+                            base_layer = f"{mb}_{core_layer}"
+                            logger.info("Merging intervals for layer=%s", base_layer)
+                            if base_layer in subset.layers:
+                                merged_base = hmm.merge_intervals_to_new_layer(
+                                    subset,
+                                    base_layer,
+                                    distance_threshold=int(dist),
+                                    suffix=merged_suffix,
+                                    overwrite=True,
+                                )
+                                # write merged size classes based on whichever group core_layer corresponds to
+                                for group, fs in feature_sets_access.items():
+                                    fmap = fs.get("features", {}) or {}
+                                    if fmap:
+                                        hmm.write_size_class_layers_from_binary(
+                                            subset,
+                                            merged_base,
+                                            out_prefix=str(mb),
+                                            feature_ranges=fmap,
+                                            suffix=merged_suffix,
+                                            overwrite=True,
+                                        )
+
+                    # =========================
+                    # 2) Multi-channel accessibility (Combined)
+                    # =========================
+                    if run_multi and len(methbases) >= 2:
+                        logger.info("HMM multi-channel for methbases=%s", ",".join(methbases))
+                        try:
+                            X3, coords_u, used_mbs = build_multi_channel_union(
+                                subset,
+                                ref=str(ref),
+                                methbases=methbases,
+                                smf_modality=smf_modality,
+                                cfg=cfg,
+                            )
+                        except Exception:
+                            X3, coords_u, used_mbs = None, None, []
+                            logger.warning(
+                                "Skipping HMM multi-channel due to data error or insufficient methbases"
+                            )
+
+                        if X3 is not None and len(used_mbs) >= 2:
+                            union_mask = None
+                            for mb in used_mbs:
+                                pm = _resolve_pos_mask_for_methbase(subset, str(ref), str(mb))
+                                union_mask = pm if union_mask is None else (union_mask | pm)
+
+                            arch = trainer.choose_arch(multichannel=True)
+
+                            logger.info("HMM fitting/loading for multi-channel")
+                            hmmc = trainer.fit_or_load(
+                                sample=str(sample),
+                                ref=str(ref),
+                                label="Combined",
+                                arch=arch,
+                                X=X3,
+                                coords=coords_u,
+                                device=device,
+                            )
+
+                            if not bypass_apply:
+                                logger.info("HMM applying for multi-channel")
+                                hmmc.annotate_adata(
+                                    subset,
+                                    prefix="Combined",
+                                    X=X3,
+                                    coords=coords_u,
+                                    var_mask=union_mask,
+                                    span_fill=True,
+                                    config=cfg,
+                                    decode=decode,
+                                    write_posterior=write_post,
+                                    posterior_state=post_state,
+                                    feature_sets=feature_sets_access,  # <--- accessibility only
+                                    prob_threshold=prob_thr,
+                                    uns_key=uns_key,
+                                    uns_flag="hmm_annotated_combined",
+                                    force_redo=force_apply,
+                                )
+
+                            for core_layer, dist in (
+                                getattr(cfg, "hmm_merge_layer_features", []) or []
+                            ):
+                                base_layer = f"Combined_{core_layer}"
+                                if base_layer in subset.layers:
+                                    merged_base = hmmc.merge_intervals_to_new_layer(
+                                        subset,
+                                        base_layer,
+                                        distance_threshold=int(dist),
+                                        suffix=merged_suffix,
+                                        overwrite=True,
+                                    )
+                                    for group, fs in feature_sets_access.items():
+                                        fmap = fs.get("features", {}) or {}
+                                        if fmap:
+                                            hmmc.write_size_class_layers_from_binary(
+                                                subset,
+                                                merged_base,
+                                                out_prefix="Combined",
+                                                feature_ranges=fmap,
+                                                suffix=merged_suffix,
+                                                overwrite=True,
+                                            )
+
+                    # =========================
+                    # 3) CpG-only single-channel task
+                    # =========================
+                    if run_cpg:
+                        logger.info("HMM single-channel for CpG")
+                        try:
+                            Xcpg, coordscpg = build_single_channel(
+                                subset,
+                                ref=str(ref),
+                                methbase="CpG",
+                                smf_modality=smf_modality,
+                                cfg=cfg,
+                            )
+                        except Exception:
+                            Xcpg, coordscpg = None, None
+                            logger.warning("Skipping HMM single-channel for CpG due to data error")
+
+                        if Xcpg is not None and Xcpg.size and feature_sets_cpg:
+                            arch = trainer.choose_arch(multichannel=False)
+
+                            logger.info("HMM fitting/loading for CpG")
+                            hmmg = trainer.fit_or_load(
+                                sample=str(sample),
+                                ref=str(ref),
+                                label="CpG",
+                                arch=arch,
+                                X=Xcpg,
+                                coords=coordscpg,
+                                device=device,
+                            )
+
+                            if not bypass_apply:
+                                logger.info("HMM applying for CpG")
+                                pm = _resolve_pos_mask_for_methbase(subset, str(ref), "CpG")
+                                hmmg.annotate_adata(
+                                    subset,
+                                    prefix="CpG",
+                                    X=Xcpg,
+                                    coords=coordscpg,
+                                    var_mask=pm,
+                                    span_fill=True,
+                                    config=cfg,
+                                    decode=decode,
+                                    write_posterior=write_post,
+                                    posterior_state=post_state,
+                                    feature_sets=feature_sets_cpg,  # <--- ONLY cpg group (cpg_patch)
+                                    prob_threshold=prob_thr,
+                                    uns_key=uns_key,
+                                    uns_flag="hmm_annotated_CpG",
+                                    force_redo=force_apply,
+                                )
+
+                    # ------------------------------------------------------------
+                    # Copy newly created subset layers back into the full adata
+                    # ------------------------------------------------------------
+                    appended = (
+                        list(subset.uns.get(uns_key, []))
+                        if subset.uns.get(uns_key) is not None
+                        else []
+                    )
+                    if appended:
+                        row_mask = np.asarray(
+                            mask.values if hasattr(mask, "values") else mask, dtype=bool
+                        )

-
-
-
-
-
-
-        if ".gz" == hmm_adata_path.suffix:
-            safe_write_h5ad(adata, hmm_adata_path, compression='gzip', backup=True)
-        else:
-            hmm_adata_path = hmm_adata_path.with_name(hmm_adata_path.name + '.gz')
-            safe_write_h5ad(adata, hmm_adata_path, compression='gzip', backup=True)
+                        for ln in appended:
+                            if ln not in subset.layers:
+                                continue
+                            _ensure_layer_and_assign_rows(adata, ln, row_mask, subset.layers[ln])
+                            if ln not in global_appended:
+                                global_appended.append(ln)

-
+            adata.uns[uns_key] = global_appended

-
+        adata.uns["hmm_annotated"] = True

-
-
-
+        hmm_layers = list(adata.uns.get("hmm_appended_layers", []) or [])
+        # keep only real feature layers; drop lengths/states/posterior
+        hmm_layers = [
+            layer
+            for layer in hmm_layers
+            if not any(s in layer for s in ("_lengths", "_states", "_posterior"))
+        ]
+        logger.info(f"HMM appended layers: {hmm_layers}")

+        # ---------------------------- HMM peak calling stage ----------------------------
+        hmm_dir = pp_dir / "11_hmm_peak_calling"
         if hmm_dir.is_dir():
-
+            pass
         else:
             make_dirs([pp_dir, hmm_dir])
-        from ..plotting import combined_hmm_raw_clustermap
-        feature_layers = [
-            "all_accessible_features",
-            "large_accessible_patch",
-            "small_bound_stretch",
-            "medium_bound_stretch",
-            "putative_nucleosome",
-            "all_accessible_features_merged",
-        ]

-
+        call_hmm_peaks(
+            adata,
+            feature_configs=cfg.hmm_peak_feature_configs,
+            ref_column=cfg.reference_column,
+            site_types=cfg.mod_target_bases,
+            save_plot=True,
+            output_dir=hmm_dir,
+            index_col_suffix=cfg.reindexed_var_suffix,
+        )

-
-
-
-
-
+        ## Save HMM annotated adata
+        if not paths.hmm.exists():
+            logger.info("Saving hmm analyzed AnnData (post preprocessing and duplicate removal).")
+            record_smftools_metadata(
+                adata,
+                step_name="hmm",
+                cfg=cfg,
+                config_path=config_path,
+                input_paths=[source_adata_path] if source_adata_path else None,
+                output_path=paths.hmm,
+            )
+            write_gz_h5ad(adata, paths.hmm)

-
-        layers.extend([f"A_{layer}" for layer in feature_layers])
+    ########################################################################################################################

-
-        raise ValueError(
-            f"No HMM feature layers matched mod_target_bases={cfg.mod_target_bases} "
-            f"and smf_modality={smf_modality}"
-        )
-
-        if smf_modality == 'direct':
-            sort_by = "any_a"
-        else:
-            sort_by = 'gpc'
+    ############################################### HMM based feature plotting ###############################################

-
-
-
+    hmm_dir = pp_dir / "12_hmm_clustermaps"
+    make_dirs([pp_dir, hmm_dir])
+
+    layers: list[str] = []
+
+    for base in cfg.hmm_methbases:
+        layers.extend([f"{base}_{layer}" for layer in cfg.hmm_clustermap_feature_layers])
+
+    if getattr(cfg, "hmm_run_multichannel", True) and len(cfg.hmm_methbases) >= 2:
+        layers.extend([f"Combined_{layer}" for layer in cfg.hmm_clustermap_feature_layers])
+
+    if cfg.cpg:
+        layers.extend(["CpG_cpg_patch"])
+
+    if not layers:
+        raise ValueError(
+            f"No HMM feature layers matched mod_target_bases={cfg.mod_target_bases} "
+            f"and smf_modality={smf_modality}"
+        )
+
+    for layer in layers:
+        hmm_cluster_save_dir = hmm_dir / layer
+        if hmm_cluster_save_dir.is_dir():
+            pass
+        else:
+            make_dirs([hmm_cluster_save_dir])

         combined_hmm_raw_clustermap(
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            adata,
+            sample_col=cfg.sample_name_col_for_plotting,
+            reference_col=cfg.reference_column,
+            hmm_feature_layer=layer,
+            layer_gpc=cfg.layer_for_clustermap_plotting,
+            layer_cpg=cfg.layer_for_clustermap_plotting,
+            layer_c=cfg.layer_for_clustermap_plotting,
+            layer_a=cfg.layer_for_clustermap_plotting,
+            cmap_hmm=cfg.clustermap_cmap_hmm,
+            cmap_gpc=cfg.clustermap_cmap_gpc,
+            cmap_cpg=cfg.clustermap_cmap_cpg,
+            cmap_c=cfg.clustermap_cmap_c,
+            cmap_a=cfg.clustermap_cmap_a,
+            min_quality=cfg.read_quality_filter_thresholds[0],
+            min_length=cfg.read_len_filter_thresholds[0],
+            min_mapped_length_to_reference_length_ratio=cfg.read_len_to_ref_ratio_filter_thresholds[
+                0
+            ],
+            min_position_valid_fraction=1 - cfg.position_max_nan_threshold,
+            demux_types=("double", "already"),
+            save_path=hmm_cluster_save_dir,
+            normalize_hmm=False,
+            sort_by=cfg.hmm_clustermap_sortby,  # options: 'gpc', 'cpg', 'gpc_cpg', 'none', or 'obs:<column>'
+            bins=None,
+            deaminase=deaminase,
+            min_signal=0,
+            index_col_suffix=cfg.reindexed_var_suffix,
         )

-    hmm_dir = pp_dir / "
+    hmm_dir = pp_dir / "13_hmm_bulk_traces"

     if hmm_dir.is_dir():
-
+        logger.debug(f"{hmm_dir} already exists.")
     else:
         make_dirs([pp_dir, hmm_dir])
     from ..plotting import plot_hmm_layers_rolling_by_sample_ref
+
+    bulk_hmm_layers = [
+        layer
+        for layer in hmm_layers
+        if not any(s in layer for s in ("_lengths", "_states", "_posterior"))
+    ]
     saved = plot_hmm_layers_rolling_by_sample_ref(
         adata,
-        layers=
+        layers=bulk_hmm_layers,
         sample_col=cfg.sample_name_col_for_plotting,
         ref_col=cfg.reference_column,
         window=101,
         rows_per_page=4,
-        figsize_per_cell=(4,2.5),
+        figsize_per_cell=(4, 2.5),
         output_dir=hmm_dir,
         save=True,
-        show_raw=False
+        show_raw=False,
     )

-    hmm_dir = pp_dir / "
+    hmm_dir = pp_dir / "14_hmm_fragment_distributions"

     if hmm_dir.is_dir():
-
+        logger.debug(f"{hmm_dir} already exists.")
     else:
         make_dirs([pp_dir, hmm_dir])
     from ..plotting import plot_hmm_size_contours

-
+    if smf_modality == "deaminase":
+        fragments = [
+            ("C_all_accessible_features_lengths", 400),
+            ("C_all_footprint_features_lengths", 250),
+            ("C_all_accessible_features_merged_lengths", 800),
+        ]
+    elif smf_modality == "conversion":
+        fragments = [
+            ("GpC_all_accessible_features_lengths", 400),
+            ("GpC_all_footprint_features_lengths", 250),
+            ("GpC_all_accessible_features_merged_lengths", 800),
+        ]
+    elif smf_modality == "direct":
+        fragments = [
+            ("A_all_accessible_features_lengths", 400),
+            ("A_all_footprint_features_lengths", 200),
+            ("A_all_accessible_features_merged_lengths", 800),
+        ]
+
+    for layer, max in fragments:
         save_path = hmm_dir / layer
         make_dirs([save_path])

@@ -328,11 +1009,11 @@ def hmm_adata(config_path):
             save_pdf=False,
             save_each_page=True,
             dpi=200,
-            smoothing_sigma=
-            normalize_after_smoothing=
-            cmap=
-            log_scale_z=True
+            smoothing_sigma=(10, 10),
+            normalize_after_smoothing=True,
+            cmap="Greens",
+            log_scale_z=True,
         )
         ########################################################################################################################

-    return (adata,
+    return (adata, paths.hmm)
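
The heart of the new multi-channel path is build_multi_channel_union, which places each methbase's (N, Lmb) matrix onto a shared union coordinate grid, leaving NaN where a channel has no site. Below is a minimal, self-contained numpy sketch of that same indexing logic; the toy arrays and values are hypothetical and not part of the package's API.

import numpy as np

# Two hypothetical methbase channels observed at different coordinate sets.
coords_gpc = np.array([10, 20, 40])   # GpC sites
coords_cpg = np.array([20, 30])       # CpG sites
X_gpc = np.array([[1.0, 0.0, 1.0]])   # one read, values at GpC sites
X_cpg = np.array([[0.0, 1.0]])        # same read, values at CpG sites

# Union grid, as in build_multi_channel_union: NaN marks "no site for this channel".
coords = np.unique(np.concatenate([coords_gpc, coords_cpg]))
idx = {int(v): i for i, v in enumerate(coords)}
X3 = np.full((1, coords.size, 2), np.nan)
X3[:, [idx[int(v)] for v in coords_gpc], 0] = X_gpc
X3[:, [idx[int(v)] for v in coords_cpg], 1] = X_cpg

print(coords)   # [10 20 30 40]
print(X3[0])    # rows: positions; columns: (GpC, CpG), NaN where the channel is absent

The NaN fill appears to be load-bearing: per the docstrings, both the single- and multi-channel training matrices allow NaN, presumably so the emission model can mask missing observations rather than read them as unmodified calls.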
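Similarly, HMMTrainer._save/_load round-trip model checkpoints as a plain dict payload: the state_dict plus enough metadata ("hmm_arch", "override") to rebuild a compatible model before loading weights, with a dtype-cast pass so a checkpoint saved at one precision loads at another. Here is a stand-alone sketch of the same pattern using a stand-in torch.nn module; the model, file name, and metadata values are illustrative only, not the smftools HMM classes.

import torch

# Hypothetical stand-in for a fitted HMM.
model = torch.nn.Linear(4, 2)

# Payload shape mirrors HMMTrainer._save: weights plus architecture metadata.
payload = {
    "state_dict": model.state_dict(),
    "hmm_arch": "single",
    "override": {},
}
torch.save(payload, "GLOBAL_ALL_ref1_GpC.pt")

# Loading mirrors HMMTrainer._load: rebuild the model, cast checkpoint tensors
# to the live parameter dtype, then load weights and switch to eval mode.
restored = torch.nn.Linear(4, 2)
loaded = torch.load("GLOBAL_ALL_ref1_GpC.pt", map_location="cpu")
sd = loaded["state_dict"]
target_dtype = next(restored.parameters()).dtype
for k, v in list(sd.items()):
    if isinstance(v, torch.Tensor) and v.dtype != target_dtype:
        sd[k] = v.to(dtype=target_dtype)
restored.load_state_dict(sd)
restored.eval()

Storing the architecture and overrides alongside the weights is what lets fit_or_load reuse a "GLOBAL" checkpoint across samples (or adapt it per sample) without re-deriving the model configuration from the current cfg.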
|