smftools 0.2.4__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff compares the contents of two package versions that were publicly released to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
- smftools/__init__.py +43 -13
- smftools/_settings.py +6 -6
- smftools/_version.py +3 -1
- smftools/cli/__init__.py +1 -0
- smftools/cli/archived/cli_flows.py +2 -0
- smftools/cli/helpers.py +9 -1
- smftools/cli/hmm_adata.py +905 -242
- smftools/cli/load_adata.py +432 -280
- smftools/cli/preprocess_adata.py +287 -171
- smftools/cli/spatial_adata.py +141 -53
- smftools/cli_entry.py +119 -178
- smftools/config/__init__.py +3 -1
- smftools/config/conversion.yaml +5 -1
- smftools/config/deaminase.yaml +1 -1
- smftools/config/default.yaml +26 -18
- smftools/config/direct.yaml +8 -3
- smftools/config/discover_input_files.py +19 -5
- smftools/config/experiment_config.py +511 -276
- smftools/constants.py +37 -0
- smftools/datasets/__init__.py +4 -8
- smftools/datasets/datasets.py +32 -18
- smftools/hmm/HMM.py +2133 -1428
- smftools/hmm/__init__.py +24 -14
- smftools/hmm/archived/apply_hmm_batched.py +2 -0
- smftools/hmm/archived/calculate_distances.py +2 -0
- smftools/hmm/archived/call_hmm_peaks.py +18 -1
- smftools/hmm/archived/train_hmm.py +2 -0
- smftools/hmm/call_hmm_peaks.py +176 -193
- smftools/hmm/display_hmm.py +23 -7
- smftools/hmm/hmm_readwrite.py +20 -6
- smftools/hmm/nucleosome_hmm_refinement.py +104 -14
- smftools/informatics/__init__.py +55 -13
- smftools/informatics/archived/bam_conversion.py +2 -0
- smftools/informatics/archived/bam_direct.py +2 -0
- smftools/informatics/archived/basecall_pod5s.py +2 -0
- smftools/informatics/archived/basecalls_to_adata.py +2 -0
- smftools/informatics/archived/conversion_smf.py +2 -0
- smftools/informatics/archived/deaminase_smf.py +1 -0
- smftools/informatics/archived/direct_smf.py +2 -0
- smftools/informatics/archived/fast5_to_pod5.py +2 -0
- smftools/informatics/archived/helpers/archived/__init__.py +2 -0
- smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +16 -1
- smftools/informatics/archived/helpers/archived/aligned_BAM_to_bed.py +2 -0
- smftools/informatics/archived/helpers/archived/bam_qc.py +14 -1
- smftools/informatics/archived/helpers/archived/bed_to_bigwig.py +2 -0
- smftools/informatics/archived/helpers/archived/canoncall.py +2 -0
- smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +8 -1
- smftools/informatics/archived/helpers/archived/converted_BAM_to_adata.py +2 -0
- smftools/informatics/archived/helpers/archived/count_aligned_reads.py +2 -0
- smftools/informatics/archived/helpers/archived/demux_and_index_BAM.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_base_identities.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_mods.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_read_features_from_bam.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_read_lengths_from_bed.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_readnames_from_BAM.py +2 -0
- smftools/informatics/archived/helpers/archived/find_conversion_sites.py +2 -0
- smftools/informatics/archived/helpers/archived/generate_converted_FASTA.py +2 -0
- smftools/informatics/archived/helpers/archived/get_chromosome_lengths.py +2 -0
- smftools/informatics/archived/helpers/archived/get_native_references.py +2 -0
- smftools/informatics/archived/helpers/archived/index_fasta.py +2 -0
- smftools/informatics/archived/helpers/archived/informatics.py +2 -0
- smftools/informatics/archived/helpers/archived/load_adata.py +5 -3
- smftools/informatics/archived/helpers/archived/make_modbed.py +2 -0
- smftools/informatics/archived/helpers/archived/modQC.py +2 -0
- smftools/informatics/archived/helpers/archived/modcall.py +2 -0
- smftools/informatics/archived/helpers/archived/ohe_batching.py +2 -0
- smftools/informatics/archived/helpers/archived/ohe_layers_decode.py +2 -0
- smftools/informatics/archived/helpers/archived/one_hot_decode.py +2 -0
- smftools/informatics/archived/helpers/archived/one_hot_encode.py +2 -0
- smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +5 -1
- smftools/informatics/archived/helpers/archived/separate_bam_by_bc.py +2 -0
- smftools/informatics/archived/helpers/archived/split_and_index_BAM.py +2 -0
- smftools/informatics/archived/print_bam_query_seq.py +9 -1
- smftools/informatics/archived/subsample_fasta_from_bed.py +2 -0
- smftools/informatics/archived/subsample_pod5.py +2 -0
- smftools/informatics/bam_functions.py +1059 -269
- smftools/informatics/basecalling.py +53 -9
- smftools/informatics/bed_functions.py +357 -114
- smftools/informatics/binarize_converted_base_identities.py +21 -7
- smftools/informatics/complement_base_list.py +9 -6
- smftools/informatics/converted_BAM_to_adata.py +324 -137
- smftools/informatics/fasta_functions.py +251 -89
- smftools/informatics/h5ad_functions.py +202 -30
- smftools/informatics/modkit_extract_to_adata.py +623 -274
- smftools/informatics/modkit_functions.py +87 -44
- smftools/informatics/ohe.py +46 -21
- smftools/informatics/pod5_functions.py +114 -74
- smftools/informatics/run_multiqc.py +20 -14
- smftools/logging_utils.py +51 -0
- smftools/machine_learning/__init__.py +23 -12
- smftools/machine_learning/data/__init__.py +2 -0
- smftools/machine_learning/data/anndata_data_module.py +157 -50
- smftools/machine_learning/data/preprocessing.py +4 -1
- smftools/machine_learning/evaluation/__init__.py +3 -1
- smftools/machine_learning/evaluation/eval_utils.py +13 -14
- smftools/machine_learning/evaluation/evaluators.py +52 -34
- smftools/machine_learning/inference/__init__.py +3 -1
- smftools/machine_learning/inference/inference_utils.py +9 -4
- smftools/machine_learning/inference/lightning_inference.py +14 -13
- smftools/machine_learning/inference/sklearn_inference.py +8 -8
- smftools/machine_learning/inference/sliding_window_inference.py +37 -25
- smftools/machine_learning/models/__init__.py +12 -5
- smftools/machine_learning/models/base.py +34 -43
- smftools/machine_learning/models/cnn.py +22 -13
- smftools/machine_learning/models/lightning_base.py +78 -42
- smftools/machine_learning/models/mlp.py +18 -5
- smftools/machine_learning/models/positional.py +10 -4
- smftools/machine_learning/models/rnn.py +8 -3
- smftools/machine_learning/models/sklearn_models.py +46 -24
- smftools/machine_learning/models/transformer.py +75 -55
- smftools/machine_learning/models/wrappers.py +8 -3
- smftools/machine_learning/training/__init__.py +4 -2
- smftools/machine_learning/training/train_lightning_model.py +42 -23
- smftools/machine_learning/training/train_sklearn_model.py +11 -15
- smftools/machine_learning/utils/__init__.py +3 -1
- smftools/machine_learning/utils/device.py +12 -5
- smftools/machine_learning/utils/grl.py +8 -2
- smftools/metadata.py +443 -0
- smftools/optional_imports.py +31 -0
- smftools/plotting/__init__.py +32 -17
- smftools/plotting/autocorrelation_plotting.py +153 -48
- smftools/plotting/classifiers.py +175 -73
- smftools/plotting/general_plotting.py +350 -168
- smftools/plotting/hmm_plotting.py +53 -14
- smftools/plotting/position_stats.py +155 -87
- smftools/plotting/qc_plotting.py +25 -12
- smftools/preprocessing/__init__.py +35 -37
- smftools/preprocessing/append_base_context.py +105 -79
- smftools/preprocessing/append_binary_layer_by_base_context.py +75 -37
- smftools/preprocessing/{archives → archived}/add_read_length_and_mapping_qc.py +2 -0
- smftools/preprocessing/{archives → archived}/calculate_complexity.py +5 -1
- smftools/preprocessing/{archives → archived}/mark_duplicates.py +2 -0
- smftools/preprocessing/{archives → archived}/preprocessing.py +10 -6
- smftools/preprocessing/{archives → archived}/remove_duplicates.py +2 -0
- smftools/preprocessing/binarize.py +21 -4
- smftools/preprocessing/binarize_on_Youden.py +127 -31
- smftools/preprocessing/binary_layers_to_ohe.py +18 -11
- smftools/preprocessing/calculate_complexity_II.py +89 -59
- smftools/preprocessing/calculate_consensus.py +28 -19
- smftools/preprocessing/calculate_coverage.py +44 -22
- smftools/preprocessing/calculate_pairwise_differences.py +4 -1
- smftools/preprocessing/calculate_pairwise_hamming_distances.py +7 -3
- smftools/preprocessing/calculate_position_Youden.py +110 -55
- smftools/preprocessing/calculate_read_length_stats.py +52 -23
- smftools/preprocessing/calculate_read_modification_stats.py +91 -57
- smftools/preprocessing/clean_NaN.py +38 -28
- smftools/preprocessing/filter_adata_by_nan_proportion.py +24 -12
- smftools/preprocessing/filter_reads_on_length_quality_mapping.py +72 -37
- smftools/preprocessing/filter_reads_on_modification_thresholds.py +183 -73
- smftools/preprocessing/flag_duplicate_reads.py +708 -303
- smftools/preprocessing/invert_adata.py +26 -11
- smftools/preprocessing/load_sample_sheet.py +40 -22
- smftools/preprocessing/make_dirs.py +9 -3
- smftools/preprocessing/min_non_diagonal.py +4 -1
- smftools/preprocessing/recipes.py +58 -23
- smftools/preprocessing/reindex_references_adata.py +93 -27
- smftools/preprocessing/subsample_adata.py +33 -16
- smftools/readwrite.py +264 -109
- smftools/schema/__init__.py +11 -0
- smftools/schema/anndata_schema_v1.yaml +227 -0
- smftools/tools/__init__.py +25 -18
- smftools/tools/archived/apply_hmm.py +2 -0
- smftools/tools/archived/classifiers.py +165 -0
- smftools/tools/archived/classify_methylated_features.py +2 -0
- smftools/tools/archived/classify_non_methylated_features.py +2 -0
- smftools/tools/archived/subset_adata_v1.py +12 -1
- smftools/tools/archived/subset_adata_v2.py +14 -1
- smftools/tools/calculate_umap.py +56 -15
- smftools/tools/cluster_adata_on_methylation.py +122 -47
- smftools/tools/general_tools.py +70 -25
- smftools/tools/position_stats.py +220 -99
- smftools/tools/read_stats.py +50 -29
- smftools/tools/spatial_autocorrelation.py +365 -192
- smftools/tools/subset_adata.py +23 -21
- smftools-0.3.0.dist-info/METADATA +147 -0
- smftools-0.3.0.dist-info/RECORD +182 -0
- smftools-0.2.4.dist-info/METADATA +0 -141
- smftools-0.2.4.dist-info/RECORD +0 -176
- {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/WHEEL +0 -0
- {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/entry_points.txt +0 -0
- {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/licenses/LICENSE +0 -0
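Two of the new modules in this release, `smftools/optional_imports.py` (+31) and `smftools/logging_utils.py` (+51), are used throughout the rewritten CLI: the `hmm_adata.py` diff below opens with `torch = require("torch", extra="torch", purpose="HMM CLI")`. Below is a minimal, hypothetical sketch of what such a helper plausibly does; only the call signature is taken from the diff, and the actual `smftools.optional_imports.require` implementation may differ.

```python
# Hedged sketch of an optional-import helper with the signature used in the
# diff below: require("torch", extra="torch", purpose="HMM CLI").
# The real smftools implementation may differ; this is illustrative only.
import importlib


def require(module_name: str, *, extra: str, purpose: str):
    """Import an optional dependency, or fail with an actionable install hint."""
    try:
        return importlib.import_module(module_name)
    except ImportError as err:
        raise ImportError(
            f"{purpose} requires the optional dependency '{module_name}'. "
            f"Install it with: pip install 'smftools[{extra}]'"
        ) from err


# Module-level usage, mirroring the new hmm_adata.py:
# torch = require("torch", extra="torch", purpose="HMM CLI")
```

The largest single-file change is the rewrite of `smftools/cli/hmm_adata.py` (+905/−242), shown below.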
smftools/cli/hmm_adata.py
CHANGED
```diff
@@ -1,223 +1,860 @@
-… [old line 1 not shown]
+from __future__ import annotations
+
+import copy
+from dataclasses import dataclass
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, List, Optional, Sequence, Tuple, Union
+
+import numpy as np
+
+from smftools.logging_utils import get_logger
+from smftools.optional_imports import require
+
+# FIX: import _to_dense_np to avoid NameError
+from ..hmm.HMM import _safe_int_coords, _to_dense_np, create_hmm, normalize_hmm_feature_sets
+
+logger = get_logger(__name__)
+
+if TYPE_CHECKING:
+    import torch as torch_types
+
+torch = require("torch", extra="torch", purpose="HMM CLI")
+
+# =============================================================================
+# Helpers: extracting training arrays
+# =============================================================================
+
+
+def _get_training_matrix(
+    subset, cols_mask: np.ndarray, smf_modality: Optional[str], cfg
+) -> Tuple[np.ndarray, Optional[str]]:
+    """
+    Matches your existing behavior:
+      - direct -> uses cfg.output_binary_layer_name in .layers
+      - else -> uses .X
+    Returns (X, layer_name_or_None) where X is dense float array.
+    """
+    sub = subset[:, cols_mask]
+
+    if smf_modality == "direct":
+        hmm_layer = getattr(cfg, "output_binary_layer_name", None)
+        if hmm_layer is None or hmm_layer not in sub.layers:
+            raise KeyError(f"Missing HMM training layer '{hmm_layer}' in subset.")
+
+        logger.debug("Using direct modality HMM training layer: %s", hmm_layer)
+        mat = sub.layers[hmm_layer]
+    else:
+        logger.debug("Using .X for HMM training matrix")
+        hmm_layer = None
+        mat = sub.X
+
+    X = _to_dense_np(mat).astype(float)
+    if X.ndim != 2:
+        raise ValueError(f"Expected 2D training matrix; got {X.shape}")
+    return X, hmm_layer
+
+
+def _resolve_pos_mask_for_methbase(subset, ref: str, methbase: str) -> Optional[np.ndarray]:
     """
-… [old lines 3-4 not shown]
+    Reproduces your mask resolution, with compatibility for both *_any_C_site and *_C_site.
+    """
+    key = str(methbase).strip().lower()
+
+    logger.debug("Resolving position mask for methbase=%s on ref=%s", key, ref)
+
+    if key in ("a",):
+        col = f"{ref}_A_site"
+        if col not in subset.var:
+            return None
+        logger.debug("Using positions with A calls from column: %s", col)
+        return np.asarray(subset.var[col])
+
+    if key in ("c", "any_c", "anyc", "any-c"):
+        for col in (f"{ref}_any_C_site", f"{ref}_C_site"):
+            if col in subset.var:
+                logger.debug("Using positions with C calls from column: %s", col)
+                return np.asarray(subset.var[col])
+        return None
+
+    if key in ("gpc", "gpc_site", "gpc-site"):
+        col = f"{ref}_GpC_site"
+        if col not in subset.var:
+            return None
+        logger.debug("Using positions with GpC calls from column: %s", col)
+        return np.asarray(subset.var[col])
+
+    if key in ("cpg", "cpg_site", "cpg-site"):
+        col = f"{ref}_CpG_site"
+        if col not in subset.var:
+            return None
+        logger.debug("Using positions with CpG calls from column: %s", col)
+        return np.asarray(subset.var[col])
 
-… [old lines 6-7 not shown]
+    alt = f"{ref}_{methbase}_site"
+    if alt not in subset.var:
+        return None
 
+    logger.debug("Using positions from column: %s", alt)
+    return np.asarray(subset.var[alt])
+
+
+def build_single_channel(
+    subset, ref: str, methbase: str, smf_modality: Optional[str], cfg
+) -> Tuple[np.ndarray, np.ndarray]:
+    """
     Returns:
-… [old line 10 not shown]
+        X (N, Lmb) float with NaNs allowed
+        coords (Lmb,) int coords from var_names
+    """
+    pm = _resolve_pos_mask_for_methbase(subset, ref, methbase)
+    logger.debug(
+        "Position mask for methbase=%s on ref=%s has %d sites",
+        methbase,
+        ref,
+        int(np.sum(pm)) if pm is not None else 0,
+    )
+
+    if pm is None or int(np.sum(pm)) == 0:
+        raise ValueError(f"No columns for methbase={methbase} on ref={ref}")
+
+    X, _ = _get_training_matrix(subset, pm, smf_modality, cfg)
+    logger.debug("Training matrix for methbase=%s on ref=%s has shape %s", methbase, ref, X.shape)
+
+    coords, _ = _safe_int_coords(subset[:, pm].var_names)
+    logger.debug(
+        "Coordinates for methbase=%s on ref=%s have length %d", methbase, ref, coords.shape[0]
+    )
+
+    return X, coords
+
+
+def build_multi_channel_union(
+    subset, ref: str, methbases: Sequence[str], smf_modality: Optional[str], cfg
+) -> Tuple[np.ndarray, np.ndarray, List[str]]:
+    """
+    Build (N, Lunion, C) on union coordinate grid across methbases.
+
+    Returns:
+        X3d: (N, Lunion, C) float with NaN where methbase has no site
+        coords: (Lunion,) int union coords
+        used_methbases: list of methbases actually included (>=2)
+    """
+    per: List[Tuple[str, np.ndarray, np.ndarray, np.ndarray]] = []  # (mb, X, coords, pm)
+
+    for mb in methbases:
+        pm = _resolve_pos_mask_for_methbase(subset, ref, mb)
+        if pm is None or int(np.sum(pm)) == 0:
+            continue
+        Xmb, _ = _get_training_matrix(subset, pm, smf_modality, cfg)  # (N,Lmb)
+        cmb, _ = _safe_int_coords(subset[:, pm].var_names)
+        per.append((mb, Xmb.astype(float), cmb.astype(int), pm))
+
+    if len(per) < 2:
+        raise ValueError(f"Need >=2 methbases with columns for union multi-channel on ref={ref}")
+
+    # union coordinates
+    coords = np.unique(np.concatenate([c for _, _, c, _ in per], axis=0)).astype(int)
+    idx = {int(v): i for i, v in enumerate(coords.tolist())}
+
+    N = per[0][1].shape[0]
+    L = coords.shape[0]
+    C = len(per)
+    X3 = np.full((N, L, C), np.nan, dtype=float)
+
+    for ci, (mb, Xmb, cmb, _) in enumerate(per):
+        cols = np.array([idx[int(v)] for v in cmb.tolist()], dtype=int)
+        X3[:, cols, ci] = Xmb
+
+    used = [mb for (mb, _, _, _) in per]
+    return X3, coords, used
+
+
+@dataclass
+class HMMTask:
+    name: str
+    signals: List[str]  # e.g. ["GpC"] or ["GpC","CpG"] or ["CpG"]
+    feature_groups: List[str]  # e.g. ["footprint","accessible"] or ["cpg"]
+    output_prefix: Optional[str] = None  # force prefix (CpG task uses "CpG")
+
+
+def build_hmm_tasks(cfg: Union[dict, Any]) -> List[HMMTask]:
+    """
+    Accessibility signals come from cfg['hmm_methbases'].
+    CpG task is enabled by cfg['cpg']==True, independent of hmm_methbases.
+    """
+    if not isinstance(cfg, dict):
+        # best effort conversion
+        cfg = {k: getattr(cfg, k) for k in dir(cfg) if not k.startswith("_")}
+
+    tasks: List[HMMTask] = []
+
+    # accessibility task
+    methbases = list(cfg.get("hmm_methbases", []) or [])
+    if len(methbases) > 0:
+        tasks.append(
+            HMMTask(
+                name="accessibility",
+                signals=methbases,
+                feature_groups=["footprint", "accessible"],
+                output_prefix=None,
+            )
+        )
+
+    # CpG task (special case)
+    if bool(cfg.get("cpg", False)):
+        tasks.append(
+            HMMTask(
+                name="cpg",
+                signals=["CpG"],
+                feature_groups=["cpg"],
+                output_prefix="CpG",
+            )
+        )
+
+    return tasks
+
+
+def select_hmm_arch(cfg: dict, signals: Sequence[str]) -> str:
+    """
+    Simple, explicit model selection:
+      - distance-aware => 'single_distance_binned' (only meaningful for single-channel)
+      - multi-signal => 'multi'
+      - else => 'single'
+    """
+    if bool(cfg.get("hmm_distance_aware", False)) and len(signals) == 1:
+        return "single_distance_binned"
+    if len(signals) > 1:
+        return "multi"
+    return "single"
+
+
+def resolve_input_layer(adata, cfg: dict, layer_override: Optional[str]) -> Optional[str]:
+    """
+    If direct modality, prefer cfg.output_binary_layer_name.
+    Else use layer_override or None (meaning use .X).
+    """
+    smf_modality = cfg.get("smf_modality", None)
+    if smf_modality == "direct":
+        nm = cfg.get("output_binary_layer_name", None)
+        if nm is None:
+            raise KeyError("cfg.output_binary_layer_name missing for smf_modality='direct'")
+        if nm not in adata.layers:
+            raise KeyError(f"Direct modality expects layer '{nm}' in adata.layers")
+        return nm
+    return layer_override
+
+
+def _ensure_layer_and_assign_rows(adata, layer_name: str, row_mask: np.ndarray, subset_layer):
+    """
+    Writes subset_layer (n_subset_obs, n_vars) into adata.layers[layer_name] for rows where row_mask==True.
+    """
+    row_mask = np.asarray(row_mask, dtype=bool)
+    if row_mask.ndim != 1 or row_mask.size != adata.n_obs:
+        raise ValueError("row_mask must be length adata.n_obs")
+
+    arr = _to_dense_np(subset_layer)
+    if arr.shape != (int(row_mask.sum()), adata.n_vars):
+        raise ValueError(
+            f"subset layer '{layer_name}' shape {arr.shape} != ({int(row_mask.sum())}, {adata.n_vars})"
+        )
+
+    if layer_name not in adata.layers:
+        adata.layers[layer_name] = np.zeros((adata.n_obs, adata.n_vars), dtype=arr.dtype)
+
+    adata.layers[layer_name][row_mask, :] = arr
+
+
+def resolve_torch_device(device_str: str | None) -> torch.device:
+    d = (device_str or "auto").lower()
+    if d == "auto":
+        if torch.cuda.is_available():
+            return torch.device("cuda")
+        if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
+            return torch.device("mps")
+        return torch.device("cpu")
+    return torch.device(d)
+
+
+# =============================================================================
+# Model selection + fit strategy manager
+# =============================================================================
+@dataclass
+class HMMTrainer:
+    cfg: Any
+    models_dir: Path
+
+    def __post_init__(self):
+        self.models_dir = Path(self.models_dir)
+        self.models_dir.mkdir(parents=True, exist_ok=True)
+
+    def choose_arch(self, *, multichannel: bool) -> str:
+        use_dist = bool(getattr(self.cfg, "hmm_distance_aware", False))
+        if multichannel:
+            return "multi"
+        return "single_distance_binned" if use_dist else "single"
+
+    def _fit_scope(self) -> str:
+        return str(getattr(self.cfg, "hmm_fit_scope", "per_sample")).lower()
+        # "per_sample" | "global" | "global_then_adapt"
+
+    def _path(self, kind: str, sample: str, ref: str, label: str) -> Path:
+        # kind: "GLOBAL" | "PER" | "ADAPT"
+        def safe(s):
+            str(s).replace("/", "_")
+
+        return self.models_dir / f"{kind}_{safe(sample)}_{safe(ref)}_{safe(label)}.pt"
+
+    def _save(self, model, path: Path):
+        override = {}
+        if getattr(model, "hmm_name", None) == "multi":
+            override["hmm_n_channels"] = int(getattr(model, "n_channels", 2))
+        if getattr(model, "hmm_name", None) == "single_distance_binned":
+            override["hmm_distance_bins"] = list(
+                getattr(model, "distance_bins", [1, 5, 10, 25, 50, 100])
+            )
+
+        payload = {
+            "state_dict": model.state_dict(),
+            "hmm_arch": getattr(model, "hmm_name", None) or getattr(self.cfg, "hmm_arch", None),
+            "override": override,
+        }
+        torch.save(payload, path)
+
+    def _load(self, path: Path, arch: str, device):
+        payload = torch.load(path, map_location="cpu")
+        override = payload.get("override", None)
+        m = create_hmm(self.cfg, arch=arch, override=override, device=device)
+        sd = payload["state_dict"]
+
+        target_dtype = next(m.parameters()).dtype
+        for k, v in list(sd.items()):
+            if isinstance(v, torch.Tensor) and v.dtype != target_dtype:
+                sd[k] = v.to(dtype=target_dtype)
+
+        m.load_state_dict(sd)
+        m.to(device)
+        m.eval()
+        return m
+
+    def fit_or_load(
+        self,
+        *,
+        sample: str,
+        ref: str,
+        label: str,
+        arch: str,
+        X,
+        coords: Optional[np.ndarray],
+        device,
+    ):
+        force_fit = bool(getattr(self.cfg, "force_redo_hmm_fit", False))
+        scope = self._fit_scope()
+
+        max_iter = int(getattr(self.cfg, "hmm_max_iter", 50))
+        tol = float(getattr(self.cfg, "hmm_tol", 1e-4))
+        verbose = bool(getattr(self.cfg, "hmm_verbose", False))
+
+        # ---- global then adapt ----
+        if scope == "global_then_adapt":
+            p_global = self._path("GLOBAL", "ALL", ref, label)
+            if p_global.exists() and not force_fit:
+                base = self._load(p_global, arch=arch, device=device)
+            else:
+                base = create_hmm(self.cfg, arch=arch).to(device)
+                if arch == "single_distance_binned":
+                    base.fit(
+                        X, device=device, coords=coords, max_iter=max_iter, tol=tol, verbose=verbose
+                    )
+                else:
+                    base.fit(X, device=device, max_iter=max_iter, tol=tol, verbose=verbose)
+                self._save(base, p_global)
+
+            p_adapt = self._path("ADAPT", sample, ref, label)
+            if p_adapt.exists() and not force_fit:
+                return self._load(p_adapt, arch=arch, device=device)
+
+            # IMPORTANT: this assumes you added model.adapt_emissions(...)
+            adapted = copy.deepcopy(base).to(device)
+            if arch == "single_distance_binned":
+                adapted.adapt_emissions(
+                    X,
+                    coords,
+                    device=device,
+                    max_iter=int(getattr(self.cfg, "hmm_adapt_iters", 10)),
+                    verbose=verbose,
+                )
+
+            else:
+                adapted.adapt_emissions(
+                    X,
+                    coords,
+                    device=device,
+                    max_iter=int(getattr(self.cfg, "hmm_adapt_iters", 10)),
+                    verbose=verbose,
+                )
+
+            self._save(adapted, p_adapt)
+            return adapted
+
+        # ---- global only ----
+        if scope == "global":
+            p = self._path("GLOBAL", "ALL", ref, label)
+            if p.exists() and not force_fit:
+                return self._load(p, arch=arch, device=device)
+
+        # ---- per sample ----
+        else:
+            p = self._path("PER", sample, ref, label)
+            if p.exists() and not force_fit:
+                return self._load(p, arch=arch, device=device)
+
+        m = create_hmm(self.cfg, arch=arch, device=device)
+        if arch == "single_distance_binned":
+            m.fit(X, coords, device=device, max_iter=max_iter, tol=tol, verbose=verbose)
+        else:
+            m.fit(X, coords, device=device, max_iter=max_iter, tol=tol, verbose=verbose)
+        self._save(m, p)
+        return m
+
+
+def _fully_qualified_merge_layers(cfg, prefix: str) -> List[Tuple[str, int]]:
+    """
+    cfg.hmm_merge_layer_features is assumed to be a list of (core_layer_name, merge_distance),
+    where core_layer_name is like "all_accessible_features" (NOT prefixed with methbase).
+    We expand to f"{prefix}_{core_layer_name}".
+    """
+    out = []
+    for core_layer, dist in getattr(cfg, "hmm_merge_layer_features", []) or []:
+        if not core_layer:
+            continue
+        out.append((f"{prefix}_{core_layer}", int(dist)))
+    return out
+
+
+def hmm_adata(config_path: str):
+    """
+    CLI-facing wrapper for HMM analysis.
+
+    Command line entrypoint:
+        smftools hmm <config_path>
+
+    Responsibilities:
+      - Build cfg via load_adata()
+      - Ensure preprocess + spatial stages are run
+      - Decide which AnnData to start from (hmm > spatial > pp_dedup > pp > raw)
+      - Call hmm_adata_core(cfg, adata, paths)
     """
-    from ..readwrite import safe_read_h5ad
+    from ..readwrite import safe_read_h5ad
+    from .helpers import get_adata_paths
     from .load_adata import load_adata
     from .preprocess_adata import preprocess_adata
     from .spatial_adata import spatial_adata
 
-… [old lines 17-19 not shown]
-    import scanpy as sc
-… [old line 21 not shown]
-    import os
-    from importlib import resources
-    from pathlib import Path
+    # 1) load cfg / stage paths
+    _, _, cfg = load_adata(config_path)
+    paths = get_adata_paths(cfg)
 
-… [old lines 26-27 not shown]
+    # 2) make sure upstream stages are run (they have their own skipping logic)
+    preprocess_adata(config_path)
+    spatial_ad, spatial_path = spatial_adata(config_path)
 
-… [old lines 29-30 not shown]
-    #
-… [old lines 32-33 not shown]
+    # 3) choose starting AnnData
+    # Prefer:
+    #   - existing HMM h5ad if not forcing redo
+    #   - in-memory spatial_ad from wrapper call
+    #   - saved spatial / pp_dedup / pp / raw on disk
+    if paths.hmm.exists() and not (cfg.force_redo_hmm_fit or cfg.force_redo_hmm_apply):
+        adata, _ = safe_read_h5ad(paths.hmm)
+        return adata, paths.hmm
 
-… [old line 35 not shown]
-    make_dirs([output_directory])
-    ############################################### smftools load end ###############################################
-
-    ############################################### smftools preprocess start ###############################################
-    pp_adata, pp_adata_path, pp_dedup_adata, pp_dup_rem_adata_path = preprocess_adata(config_path)
-    ############################################### smftools preprocess end ###############################################
-
-    ############################################### smftools spatial start ###############################################
-    spatial_ad, spatial_adata_path = spatial_adata(config_path)
-    ############################################### smftools spatial end ###############################################
-
-    ############################################### smftools hmm start ###############################################
-    input_manager_df = pd.read_csv(cfg.summary_file)
-    initial_adata_path = Path(input_manager_df['load_adata'][0])
-    pp_adata_path = Path(input_manager_df['pp_adata'][0])
-    pp_dup_rem_adata_path = Path(input_manager_df['pp_dedup_adata'][0])
-    spatial_adata_path = Path(input_manager_df['spatial_adata'][0])
-    hmm_adata_path = Path(input_manager_df['hmm_adata'][0])
-
-    if spatial_ad:
-        # This happens on first run of the pipeline
+    if spatial_ad is not None:
         adata = spatial_ad
+        source_path = spatial_path
+    elif paths.spatial.exists():
+        adata, _ = safe_read_h5ad(paths.spatial)
+        source_path = paths.spatial
+    elif paths.pp_dedup.exists():
+        adata, _ = safe_read_h5ad(paths.pp_dedup)
+        source_path = paths.pp_dedup
+    elif paths.pp.exists():
+        adata, _ = safe_read_h5ad(paths.pp)
+        source_path = paths.pp
+    elif paths.raw.exists():
+        adata, _ = safe_read_h5ad(paths.raw)
+        source_path = paths.raw
     else:
-… [old lines 59-92 not shown]
+        raise FileNotFoundError(
+            "No AnnData available for HMM: expected at least raw or preprocessed h5ad."
+        )
+
+    # 4) delegate to core
+    adata, hmm_adata_path = hmm_adata_core(
+        cfg,
+        adata,
+        paths,
+        source_adata_path=source_path,
+        config_path=config_path,
+    )
+    return adata, hmm_adata_path
+
+
+def hmm_adata_core(
+    cfg,
+    adata,
+    paths,
+    source_adata_path: Path | None = None,
+    config_path: str | None = None,
+) -> Tuple["anndata.AnnData", Path]:
+    """
+    Core HMM analysis pipeline.
+
+    Assumes:
+      - cfg is an ExperimentConfig
+      - adata is the starting AnnData (typically spatial + dedup)
+      - paths is an AdataPaths object (with .raw/.pp/.pp_dedup/.spatial/.hmm)
+
+    Does NOT decide which h5ad to start from – that is the wrapper's job.
+    """
+
+    import numpy as np
+
+    from ..hmm import call_hmm_peaks
+    from ..metadata import record_smftools_metadata
+    from ..plotting import (
+        combined_hmm_raw_clustermap,
+        plot_hmm_layers_rolling_by_sample_ref,
+        plot_hmm_size_contours,
+    )
+    from ..readwrite import make_dirs
+    from .helpers import write_gz_h5ad
+
+    smf_modality = cfg.smf_modality
+    deaminase = smf_modality == "deaminase"
+
+    output_directory = Path(cfg.output_directory)
+    make_dirs([output_directory])
+
+    pp_dir = output_directory / "preprocessed" / "deduplicated"
+
+    # ---------------------------- HMM annotate stage ----------------------------
     if not (cfg.bypass_hmm_fit and cfg.bypass_hmm_apply):
-… [old lines 94-95 not shown]
-        import warnings
+        hmm_models_dir = pp_dir / "10_hmm_models"
+        make_dirs([pp_dir, hmm_models_dir])
 
-… [old lines 98-100 not shown]
+        # Standard bookkeeping
+        uns_key = "hmm_appended_layers"
+        if adata.uns.get(uns_key) is None:
+            adata.uns[uns_key] = []
+        global_appended = list(adata.uns.get(uns_key, []))
 
-… [old lines 102-105 not shown]
+        # Prepare trainer + feature config
+        trainer = HMMTrainer(cfg=cfg, models_dir=hmm_models_dir)
+
+        feature_sets = normalize_hmm_feature_sets(getattr(cfg, "hmm_feature_sets", None))
+        prob_thr = float(getattr(cfg, "hmm_feature_prob_threshold", 0.5))
+        decode = str(getattr(cfg, "hmm_decode", "marginal"))
+        write_post = bool(getattr(cfg, "hmm_write_posterior", True))
+        post_state = getattr(cfg, "hmm_posterior_state", "Modified")
+        merged_suffix = str(getattr(cfg, "hmm_merged_suffix", "_merged"))
+        force_apply = bool(getattr(cfg, "force_redo_hmm_apply", False))
+        bypass_apply = bool(getattr(cfg, "bypass_hmm_apply", False))
+        bypass_fit = bool(getattr(cfg, "bypass_hmm_fit", False))
 
         samples = adata.obs[cfg.sample_name_col_for_plotting].cat.categories
         references = adata.obs[cfg.reference_column].cat.categories
-… [old line 109 not shown]
+        methbases = list(getattr(cfg, "hmm_methbases", [])) or []
 
-… [old lines 111-112 not shown]
-        adata.uns[uns_key] = []
+        if not methbases:
+            raise ValueError("cfg.hmm_methbases is empty.")
 
-… [old line 115 not shown]
+        # Top-level skip
+        already = bool(adata.uns.get("hmm_annotated", False))
+        if already and not (bool(getattr(cfg, "force_redo_hmm_fit", False)) or force_apply):
            pass
+
        else:
+            logger.info("Starting HMM annotation over samples and references")
            for sample in samples:
                for ref in references:
-                    mask = (adata.obs[cfg.sample_name_col_for_plotting] == sample) & (
-… [old lines 121-122 not shown]
+                    mask = (adata.obs[cfg.sample_name_col_for_plotting] == sample) & (
+                        adata.obs[cfg.reference_column] == ref
+                    )
+                    if int(np.sum(mask)) == 0:
                        continue
 
-… [old lines 125-127 not shown]
+                    subset = adata[mask].copy()
+                    subset.uns[uns_key] = []  # isolate appended tracking per subset
+
+                    # ---- Decide which tasks to run ----
+                    methbases = list(getattr(cfg, "hmm_methbases", [])) or []
+                    run_multi = bool(getattr(cfg, "hmm_run_multichannel", True))
+                    run_cpg = bool(getattr(cfg, "cpg", False))
+                    device = resolve_torch_device(cfg.device)
+
+                    logger.info("HMM processing sample=%s ref=%s", sample, ref)
+
+                    # ---- split feature sets ----
+                    feature_sets_all = normalize_hmm_feature_sets(
+                        getattr(cfg, "hmm_feature_sets", None)
+                    )
+                    feature_sets_access = {
+                        k: v
+                        for k, v in feature_sets_all.items()
+                        if k in ("footprint", "accessible")
+                    }
+                    feature_sets_cpg = (
+                        {"cpg": feature_sets_all["cpg"]} if "cpg" in feature_sets_all else {}
+                    )
 
-… [old lines 129-132 not shown]
+                    # =========================
+                    # 1) Single-channel accessibility (per methbase)
+                    # =========================
+                    for mb in methbases:
+                        logger.info("HMM single-channel for methbase=%s", mb)
+
+                        try:
+                            X, coords = build_single_channel(
+                                subset,
+                                ref=str(ref),
+                                methbase=str(mb),
+                                smf_modality=smf_modality,
+                                cfg=cfg,
+                            )
+                        except Exception:
+                            logger.warning(
+                                "Skipping HMM single-channel for methbase=%s due to data error", mb
+                            )
                            continue
 
-… [old lines 135-164 not shown]
+                        arch = trainer.choose_arch(multichannel=False)
+
+                        logger.info("HMM fitting/loading for methbase=%s", mb)
+                        hmm = trainer.fit_or_load(
+                            sample=str(sample),
+                            ref=str(ref),
+                            label=str(mb),
+                            arch=arch,
+                            X=X,
+                            coords=coords,
+                            device=device,
+                        )
+
+                        if not bypass_apply:
+                            logger.info("HMM applying for methbase=%s", mb)
+                            pm = _resolve_pos_mask_for_methbase(subset, str(ref), str(mb))
+                            hmm.annotate_adata(
+                                subset,
+                                prefix=str(mb),
+                                X=X,
+                                coords=coords,
+                                var_mask=pm,
+                                span_fill=True,
+                                config=cfg,
+                                decode=decode,
+                                write_posterior=write_post,
+                                posterior_state=post_state,
+                                feature_sets=feature_sets_access,  # <--- ONLY accessibility feature sets
+                                prob_threshold=prob_thr,
+                                uns_key=uns_key,
+                                uns_flag=f"hmm_annotated_{mb}",
+                                force_redo=force_apply,
+                            )
+
+                            # merges for this mb
+                            for core_layer, dist in (
+                                getattr(cfg, "hmm_merge_layer_features", []) or []
+                            ):
+                                base_layer = f"{mb}_{core_layer}"
+                                logger.info("Merging intervals for layer=%s", base_layer)
+                                if base_layer in subset.layers:
+                                    merged_base = hmm.merge_intervals_to_new_layer(
+                                        subset,
+                                        base_layer,
+                                        distance_threshold=int(dist),
+                                        suffix=merged_suffix,
+                                        overwrite=True,
+                                    )
+                                    # write merged size classes based on whichever group core_layer corresponds to
+                                    for group, fs in feature_sets_access.items():
+                                        fmap = fs.get("features", {}) or {}
+                                        if fmap:
+                                            hmm.write_size_class_layers_from_binary(
+                                                subset,
+                                                merged_base,
+                                                out_prefix=str(mb),
+                                                feature_ranges=fmap,
+                                                suffix=merged_suffix,
+                                                overwrite=True,
                                            )
-
-                    if adata.uns.get('hmm_annotated', False) and not cfg.force_redo_hmm_apply:
-                        pass
-                    else:
-                        to_merge = cfg.hmm_merge_layer_features
-                        for layer_to_merge, merge_distance in to_merge:
-                            if layer_to_merge:
-                                hmm.merge_intervals_in_layer(subset,
-                                    layer=layer_to_merge,
-                                    distance_threshold=merge_distance,
-                                    overwrite=True
-                                )
-                            else:
-                                pass
-
-                    # collect appended layers from subset.uns
-                    appended = list(subset.uns.get(uns_key, []))
-                    print(appended)
-                    if len(appended) == 0:
-                        # nothing appended for this subset; continue
-                        continue
-
-                    # copy each appended layer into adata
-                    subset_mask_bool = mask.values if hasattr(mask, "values") else np.asarray(mask)
-                    for layer_name in appended:
-                        if layer_name not in subset.layers:
-                            # defensive: skip
-                            warnings.warn(f"Expected layer {layer_name} in subset but not found; skipping copy.")
-                            continue
-                        sub_layer = subset.layers[layer_name]
-                        # ensure final layer exists and assign rows
-                        try:
-                            hmm._ensure_final_layer_and_assign(adata, layer_name, subset_mask_bool, sub_layer)
-                        except Exception as e:
-                            warnings.warn(f"Failed to copy layer {layer_name} into adata: {e}", stacklevel=2)
-                            # fallback: if dense and small, try to coerce
-                            if issparse(sub_layer):
-                                arr = sub_layer.toarray()
-                            else:
-                                arr = np.asarray(sub_layer)
-                            adata.layers[layer_name] = adata.layers.get(layer_name, np.zeros((adata.shape[0], arr.shape[1]), dtype=arr.dtype))
-                            final_idx = np.nonzero(subset_mask_bool)[0]
-                            adata.layers[layer_name][final_idx, :] = arr
-
-                    # merge appended layer names into adata.uns
-                    existing = list(adata.uns.get(uns_key, []))
-                    for ln in appended:
-                        if ln not in existing:
-                            existing.append(ln)
-                    adata.uns[uns_key] = existing
 
-… [old lines 217-218 not shown]
+                    # =========================
+                    # 2) Multi-channel accessibility (Combined)
+                    # =========================
+                    if run_multi and len(methbases) >= 2:
+                        logger.info("HMM multi-channel for methbases=%s", ",".join(methbases))
+                        try:
+                            X3, coords_u, used_mbs = build_multi_channel_union(
+                                subset,
+                                ref=str(ref),
+                                methbases=methbases,
+                                smf_modality=smf_modality,
+                                cfg=cfg,
+                            )
+                        except Exception:
+                            X3, coords_u, used_mbs = None, None, []
+                            logger.warning(
+                                "Skipping HMM multi-channel due to data error or insufficient methbases"
+                            )
 
-… [old line 220 not shown]
+                        if X3 is not None and len(used_mbs) >= 2:
+                            union_mask = None
+                            for mb in used_mbs:
+                                pm = _resolve_pos_mask_for_methbase(subset, str(ref), str(mb))
+                                union_mask = pm if union_mask is None else (union_mask | pm)
+
+                            arch = trainer.choose_arch(multichannel=True)
+
+                            logger.info("HMM fitting/loading for multi-channel")
+                            hmmc = trainer.fit_or_load(
+                                sample=str(sample),
+                                ref=str(ref),
+                                label="Combined",
+                                arch=arch,
+                                X=X3,
+                                coords=coords_u,
+                                device=device,
+                            )
+
+                            if not bypass_apply:
+                                logger.info("HMM applying for multi-channel")
+                                hmmc.annotate_adata(
+                                    subset,
+                                    prefix="Combined",
+                                    X=X3,
+                                    coords=coords_u,
+                                    var_mask=union_mask,
+                                    span_fill=True,
+                                    config=cfg,
+                                    decode=decode,
+                                    write_posterior=write_post,
+                                    posterior_state=post_state,
+                                    feature_sets=feature_sets_access,  # <--- accessibility only
+                                    prob_threshold=prob_thr,
+                                    uns_key=uns_key,
+                                    uns_flag="hmm_annotated_combined",
+                                    force_redo=force_apply,
+                                )
+
+                                for core_layer, dist in (
+                                    getattr(cfg, "hmm_merge_layer_features", []) or []
+                                ):
+                                    base_layer = f"Combined_{core_layer}"
+                                    if base_layer in subset.layers:
+                                        merged_base = hmmc.merge_intervals_to_new_layer(
+                                            subset,
+                                            base_layer,
+                                            distance_threshold=int(dist),
+                                            suffix=merged_suffix,
+                                            overwrite=True,
+                                        )
+                                        for group, fs in feature_sets_access.items():
+                                            fmap = fs.get("features", {}) or {}
+                                            if fmap:
+                                                hmmc.write_size_class_layers_from_binary(
+                                                    subset,
+                                                    merged_base,
+                                                    out_prefix="Combined",
+                                                    feature_ranges=fmap,
+                                                    suffix=merged_suffix,
+                                                    overwrite=True,
+                                                )
+
+                    # =========================
+                    # 3) CpG-only single-channel task
+                    # =========================
+                    if run_cpg:
+                        logger.info("HMM single-channel for CpG")
+                        try:
+                            Xcpg, coordscpg = build_single_channel(
+                                subset,
+                                ref=str(ref),
+                                methbase="CpG",
+                                smf_modality=smf_modality,
+                                cfg=cfg,
+                            )
+                        except Exception:
+                            Xcpg, coordscpg = None, None
+                            logger.warning("Skipping HMM single-channel for CpG due to data error")
+
+                        if Xcpg is not None and Xcpg.size and feature_sets_cpg:
+                            arch = trainer.choose_arch(multichannel=False)
+
+                            logger.info("HMM fitting/loading for CpG")
+                            hmmg = trainer.fit_or_load(
+                                sample=str(sample),
+                                ref=str(ref),
+                                label="CpG",
+                                arch=arch,
+                                X=Xcpg,
+                                coords=coordscpg,
+                                device=device,
+                            )
+
+                            if not bypass_apply:
+                                logger.info("HMM applying for CpG")
+                                pm = _resolve_pos_mask_for_methbase(subset, str(ref), "CpG")
+                                hmmg.annotate_adata(
+                                    subset,
+                                    prefix="CpG",
+                                    X=Xcpg,
+                                    coords=coordscpg,
+                                    var_mask=pm,
+                                    span_fill=True,
+                                    config=cfg,
+                                    decode=decode,
+                                    write_posterior=write_post,
+                                    posterior_state=post_state,
+                                    feature_sets=feature_sets_cpg,  # <--- ONLY cpg group (cpg_patch)
+                                    prob_threshold=prob_thr,
+                                    uns_key=uns_key,
+                                    uns_flag="hmm_annotated_CpG",
+                                    force_redo=force_apply,
+                                )
+
+                    # ------------------------------------------------------------
+                    # Copy newly created subset layers back into the full adata
+                    # ------------------------------------------------------------
+                    appended = (
+                        list(subset.uns.get(uns_key, []))
+                        if subset.uns.get(uns_key) is not None
+                        else []
+                    )
+                    if appended:
+                        row_mask = np.asarray(
+                            mask.values if hasattr(mask, "values") else mask, dtype=bool
+                        )
+
+                        for ln in appended:
+                            if ln not in subset.layers:
+                                continue
+                            _ensure_layer_and_assign_rows(adata, ln, row_mask, subset.layers[ln])
+                            if ln not in global_appended:
+                                global_appended.append(ln)
+
+            adata.uns[uns_key] = global_appended
+
+        adata.uns["hmm_annotated"] = True
+
+        hmm_layers = list(adata.uns.get("hmm_appended_layers", []) or [])
+        # keep only real feature layers; drop lengths/states/posterior
+        hmm_layers = [
+            layer
+            for layer in hmm_layers
+            if not any(s in layer for s in ("_lengths", "_states", "_posterior"))
+        ]
+        logger.info(f"HMM appended layers: {hmm_layers}")
+
+        # ---------------------------- HMM peak calling stage ----------------------------
         hmm_dir = pp_dir / "11_hmm_peak_calling"
         if hmm_dir.is_dir():
             pass
@@ -225,29 +862,32 @@ def hmm_adata(config_path):
             make_dirs([pp_dir, hmm_dir])
 
             call_hmm_peaks(
-… [old lines 228-235 not shown]
-    ## Save HMM annotated adata
-    if not hmm_adata_path.exists():
-        print('Saving hmm analyzed adata post preprocessing and duplicate removal')
-        if ".gz" == hmm_adata_path.suffix:
-            safe_write_h5ad(adata, hmm_adata_path, compression='gzip', backup=True)
-        else:
-            hmm_adata_path = hmm_adata_path.with_name(hmm_adata_path.name + '.gz')
-            safe_write_h5ad(adata, hmm_adata_path, compression='gzip', backup=True)
+                adata,
+                feature_configs=cfg.hmm_peak_feature_configs,
+                ref_column=cfg.reference_column,
+                site_types=cfg.mod_target_bases,
+                save_plot=True,
+                output_dir=hmm_dir,
+                index_col_suffix=cfg.reindexed_var_suffix,
+            )
 
-… [old line 245 not shown]
+    ## Save HMM annotated adata
+    if not paths.hmm.exists():
+        logger.info("Saving hmm analyzed AnnData (post preprocessing and duplicate removal).")
+        record_smftools_metadata(
+            adata,
+            step_name="hmm",
+            cfg=cfg,
+            config_path=config_path,
+            input_paths=[source_adata_path] if source_adata_path else None,
+            output_path=paths.hmm,
+        )
+        write_gz_h5ad(adata, paths.hmm)
 
     ########################################################################################################################
 
-    ############################################### HMM based feature plotting ###############################################
-
+    ############################################### HMM based feature plotting ###############################################
+
     hmm_dir = pp_dir / "12_hmm_clustermaps"
     make_dirs([pp_dir, hmm_dir])
 
@@ -256,6 +896,9 @@ def hmm_adata(config_path):
     for base in cfg.hmm_methbases:
         layers.extend([f"{base}_{layer}" for layer in cfg.hmm_clustermap_feature_layers])
 
+    if getattr(cfg, "hmm_run_multichannel", True) and len(cfg.hmm_methbases) >= 2:
+        layers.extend([f"Combined_{layer}" for layer in cfg.hmm_clustermap_feature_layers])
+
     if cfg.cpg:
         layers.extend(["CpG_cpg_patch"])
 
@@ -273,40 +916,48 @@ def hmm_adata(config_path):
         make_dirs([hmm_cluster_save_dir])
 
         combined_hmm_raw_clustermap(
-… [old lines 276-299 not shown]
+            adata,
+            sample_col=cfg.sample_name_col_for_plotting,
+            reference_col=cfg.reference_column,
+            hmm_feature_layer=layer,
+            layer_gpc=cfg.layer_for_clustermap_plotting,
+            layer_cpg=cfg.layer_for_clustermap_plotting,
+            layer_c=cfg.layer_for_clustermap_plotting,
+            layer_a=cfg.layer_for_clustermap_plotting,
+            cmap_hmm=cfg.clustermap_cmap_hmm,
+            cmap_gpc=cfg.clustermap_cmap_gpc,
+            cmap_cpg=cfg.clustermap_cmap_cpg,
+            cmap_c=cfg.clustermap_cmap_c,
+            cmap_a=cfg.clustermap_cmap_a,
+            min_quality=cfg.read_quality_filter_thresholds[0],
+            min_length=cfg.read_len_filter_thresholds[0],
+            min_mapped_length_to_reference_length_ratio=cfg.read_len_to_ref_ratio_filter_thresholds[
+                0
+            ],
+            min_position_valid_fraction=1 - cfg.position_max_nan_threshold,
+            demux_types=("double", "already"),
+            save_path=hmm_cluster_save_dir,
+            normalize_hmm=False,
+            sort_by=cfg.hmm_clustermap_sortby,  # options: 'gpc', 'cpg', 'gpc_cpg', 'none', or 'obs:<column>'
+            bins=None,
+            deaminase=deaminase,
+            min_signal=0,
+            index_col_suffix=cfg.reindexed_var_suffix,
         )
 
     hmm_dir = pp_dir / "13_hmm_bulk_traces"
 
     if hmm_dir.is_dir():
-… [old line 305 not shown]
+        logger.debug(f"{hmm_dir} already exists.")
     else:
         make_dirs([pp_dir, hmm_dir])
         from ..plotting import plot_hmm_layers_rolling_by_sample_ref
-… [old line 309 not shown]
+
+        bulk_hmm_layers = [
+            layer
+            for layer in hmm_layers
+            if not any(s in layer for s in ("_lengths", "_states", "_posterior"))
+        ]
        saved = plot_hmm_layers_rolling_by_sample_ref(
            adata,
            layers=bulk_hmm_layers,
@@ -314,26 +965,38 @@ def hmm_adata(config_path):
            ref_col=cfg.reference_column,
            window=101,
            rows_per_page=4,
-            figsize_per_cell=(4,2.5),
+            figsize_per_cell=(4, 2.5),
            output_dir=hmm_dir,
            save=True,
-            show_raw=False
+            show_raw=False,
        )
 
     hmm_dir = pp_dir / "14_hmm_fragment_distributions"
 
     if hmm_dir.is_dir():
-… [old line 326 not shown]
+        logger.debug(f"{hmm_dir} already exists.")
     else:
         make_dirs([pp_dir, hmm_dir])
         from ..plotting import plot_hmm_size_contours
 
-        if smf_modality == …
-            fragments = [
-… [old lines 333-334 not shown]
+        if smf_modality == "deaminase":
+            fragments = [
+                ("C_all_accessible_features_lengths", 400),
+                ("C_all_footprint_features_lengths", 250),
+                ("C_all_accessible_features_merged_lengths", 800),
+            ]
+        elif smf_modality == "conversion":
+            fragments = [
+                ("GpC_all_accessible_features_lengths", 400),
+                ("GpC_all_footprint_features_lengths", 250),
+                ("GpC_all_accessible_features_merged_lengths", 800),
+            ]
         elif smf_modality == "direct":
-            fragments = [
+            fragments = [
+                ("A_all_accessible_features_lengths", 400),
+                ("A_all_footprint_features_lengths", 200),
+                ("A_all_accessible_features_merged_lengths", 800),
+            ]
 
         for layer, max in fragments:
             save_path = hmm_dir / layer
@@ -353,9 +1016,9 @@ def hmm_adata(config_path):
                 dpi=200,
                 smoothing_sigma=(10, 10),
                 normalize_after_smoothing=True,
-                cmap=…
-                log_scale_z=True
+                cmap="Greens",
+                log_scale_z=True,
             )
     ########################################################################################################################
 
-    return (adata,…
+    return (adata, paths.hmm)
```