smftools 0.2.4-py3-none-any.whl → 0.3.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- smftools/__init__.py +43 -13
- smftools/_settings.py +6 -6
- smftools/_version.py +3 -1
- smftools/cli/__init__.py +1 -0
- smftools/cli/archived/cli_flows.py +2 -0
- smftools/cli/helpers.py +9 -1
- smftools/cli/hmm_adata.py +905 -242
- smftools/cli/load_adata.py +432 -280
- smftools/cli/preprocess_adata.py +287 -171
- smftools/cli/spatial_adata.py +141 -53
- smftools/cli_entry.py +119 -178
- smftools/config/__init__.py +3 -1
- smftools/config/conversion.yaml +5 -1
- smftools/config/deaminase.yaml +1 -1
- smftools/config/default.yaml +26 -18
- smftools/config/direct.yaml +8 -3
- smftools/config/discover_input_files.py +19 -5
- smftools/config/experiment_config.py +511 -276
- smftools/constants.py +37 -0
- smftools/datasets/__init__.py +4 -8
- smftools/datasets/datasets.py +32 -18
- smftools/hmm/HMM.py +2133 -1428
- smftools/hmm/__init__.py +24 -14
- smftools/hmm/archived/apply_hmm_batched.py +2 -0
- smftools/hmm/archived/calculate_distances.py +2 -0
- smftools/hmm/archived/call_hmm_peaks.py +18 -1
- smftools/hmm/archived/train_hmm.py +2 -0
- smftools/hmm/call_hmm_peaks.py +176 -193
- smftools/hmm/display_hmm.py +23 -7
- smftools/hmm/hmm_readwrite.py +20 -6
- smftools/hmm/nucleosome_hmm_refinement.py +104 -14
- smftools/informatics/__init__.py +55 -13
- smftools/informatics/archived/bam_conversion.py +2 -0
- smftools/informatics/archived/bam_direct.py +2 -0
- smftools/informatics/archived/basecall_pod5s.py +2 -0
- smftools/informatics/archived/basecalls_to_adata.py +2 -0
- smftools/informatics/archived/conversion_smf.py +2 -0
- smftools/informatics/archived/deaminase_smf.py +1 -0
- smftools/informatics/archived/direct_smf.py +2 -0
- smftools/informatics/archived/fast5_to_pod5.py +2 -0
- smftools/informatics/archived/helpers/archived/__init__.py +2 -0
- smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +16 -1
- smftools/informatics/archived/helpers/archived/aligned_BAM_to_bed.py +2 -0
- smftools/informatics/archived/helpers/archived/bam_qc.py +14 -1
- smftools/informatics/archived/helpers/archived/bed_to_bigwig.py +2 -0
- smftools/informatics/archived/helpers/archived/canoncall.py +2 -0
- smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +8 -1
- smftools/informatics/archived/helpers/archived/converted_BAM_to_adata.py +2 -0
- smftools/informatics/archived/helpers/archived/count_aligned_reads.py +2 -0
- smftools/informatics/archived/helpers/archived/demux_and_index_BAM.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_base_identities.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_mods.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_read_features_from_bam.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_read_lengths_from_bed.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_readnames_from_BAM.py +2 -0
- smftools/informatics/archived/helpers/archived/find_conversion_sites.py +2 -0
- smftools/informatics/archived/helpers/archived/generate_converted_FASTA.py +2 -0
- smftools/informatics/archived/helpers/archived/get_chromosome_lengths.py +2 -0
- smftools/informatics/archived/helpers/archived/get_native_references.py +2 -0
- smftools/informatics/archived/helpers/archived/index_fasta.py +2 -0
- smftools/informatics/archived/helpers/archived/informatics.py +2 -0
- smftools/informatics/archived/helpers/archived/load_adata.py +5 -3
- smftools/informatics/archived/helpers/archived/make_modbed.py +2 -0
- smftools/informatics/archived/helpers/archived/modQC.py +2 -0
- smftools/informatics/archived/helpers/archived/modcall.py +2 -0
- smftools/informatics/archived/helpers/archived/ohe_batching.py +2 -0
- smftools/informatics/archived/helpers/archived/ohe_layers_decode.py +2 -0
- smftools/informatics/archived/helpers/archived/one_hot_decode.py +2 -0
- smftools/informatics/archived/helpers/archived/one_hot_encode.py +2 -0
- smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +5 -1
- smftools/informatics/archived/helpers/archived/separate_bam_by_bc.py +2 -0
- smftools/informatics/archived/helpers/archived/split_and_index_BAM.py +2 -0
- smftools/informatics/archived/print_bam_query_seq.py +9 -1
- smftools/informatics/archived/subsample_fasta_from_bed.py +2 -0
- smftools/informatics/archived/subsample_pod5.py +2 -0
- smftools/informatics/bam_functions.py +1059 -269
- smftools/informatics/basecalling.py +53 -9
- smftools/informatics/bed_functions.py +357 -114
- smftools/informatics/binarize_converted_base_identities.py +21 -7
- smftools/informatics/complement_base_list.py +9 -6
- smftools/informatics/converted_BAM_to_adata.py +324 -137
- smftools/informatics/fasta_functions.py +251 -89
- smftools/informatics/h5ad_functions.py +202 -30
- smftools/informatics/modkit_extract_to_adata.py +623 -274
- smftools/informatics/modkit_functions.py +87 -44
- smftools/informatics/ohe.py +46 -21
- smftools/informatics/pod5_functions.py +114 -74
- smftools/informatics/run_multiqc.py +20 -14
- smftools/logging_utils.py +51 -0
- smftools/machine_learning/__init__.py +23 -12
- smftools/machine_learning/data/__init__.py +2 -0
- smftools/machine_learning/data/anndata_data_module.py +157 -50
- smftools/machine_learning/data/preprocessing.py +4 -1
- smftools/machine_learning/evaluation/__init__.py +3 -1
- smftools/machine_learning/evaluation/eval_utils.py +13 -14
- smftools/machine_learning/evaluation/evaluators.py +52 -34
- smftools/machine_learning/inference/__init__.py +3 -1
- smftools/machine_learning/inference/inference_utils.py +9 -4
- smftools/machine_learning/inference/lightning_inference.py +14 -13
- smftools/machine_learning/inference/sklearn_inference.py +8 -8
- smftools/machine_learning/inference/sliding_window_inference.py +37 -25
- smftools/machine_learning/models/__init__.py +12 -5
- smftools/machine_learning/models/base.py +34 -43
- smftools/machine_learning/models/cnn.py +22 -13
- smftools/machine_learning/models/lightning_base.py +78 -42
- smftools/machine_learning/models/mlp.py +18 -5
- smftools/machine_learning/models/positional.py +10 -4
- smftools/machine_learning/models/rnn.py +8 -3
- smftools/machine_learning/models/sklearn_models.py +46 -24
- smftools/machine_learning/models/transformer.py +75 -55
- smftools/machine_learning/models/wrappers.py +8 -3
- smftools/machine_learning/training/__init__.py +4 -2
- smftools/machine_learning/training/train_lightning_model.py +42 -23
- smftools/machine_learning/training/train_sklearn_model.py +11 -15
- smftools/machine_learning/utils/__init__.py +3 -1
- smftools/machine_learning/utils/device.py +12 -5
- smftools/machine_learning/utils/grl.py +8 -2
- smftools/metadata.py +443 -0
- smftools/optional_imports.py +31 -0
- smftools/plotting/__init__.py +32 -17
- smftools/plotting/autocorrelation_plotting.py +153 -48
- smftools/plotting/classifiers.py +175 -73
- smftools/plotting/general_plotting.py +350 -168
- smftools/plotting/hmm_plotting.py +53 -14
- smftools/plotting/position_stats.py +155 -87
- smftools/plotting/qc_plotting.py +25 -12
- smftools/preprocessing/__init__.py +35 -37
- smftools/preprocessing/append_base_context.py +105 -79
- smftools/preprocessing/append_binary_layer_by_base_context.py +75 -37
- smftools/preprocessing/{archives → archived}/add_read_length_and_mapping_qc.py +2 -0
- smftools/preprocessing/{archives → archived}/calculate_complexity.py +5 -1
- smftools/preprocessing/{archives → archived}/mark_duplicates.py +2 -0
- smftools/preprocessing/{archives → archived}/preprocessing.py +10 -6
- smftools/preprocessing/{archives → archived}/remove_duplicates.py +2 -0
- smftools/preprocessing/binarize.py +21 -4
- smftools/preprocessing/binarize_on_Youden.py +127 -31
- smftools/preprocessing/binary_layers_to_ohe.py +18 -11
- smftools/preprocessing/calculate_complexity_II.py +89 -59
- smftools/preprocessing/calculate_consensus.py +28 -19
- smftools/preprocessing/calculate_coverage.py +44 -22
- smftools/preprocessing/calculate_pairwise_differences.py +4 -1
- smftools/preprocessing/calculate_pairwise_hamming_distances.py +7 -3
- smftools/preprocessing/calculate_position_Youden.py +110 -55
- smftools/preprocessing/calculate_read_length_stats.py +52 -23
- smftools/preprocessing/calculate_read_modification_stats.py +91 -57
- smftools/preprocessing/clean_NaN.py +38 -28
- smftools/preprocessing/filter_adata_by_nan_proportion.py +24 -12
- smftools/preprocessing/filter_reads_on_length_quality_mapping.py +72 -37
- smftools/preprocessing/filter_reads_on_modification_thresholds.py +183 -73
- smftools/preprocessing/flag_duplicate_reads.py +708 -303
- smftools/preprocessing/invert_adata.py +26 -11
- smftools/preprocessing/load_sample_sheet.py +40 -22
- smftools/preprocessing/make_dirs.py +9 -3
- smftools/preprocessing/min_non_diagonal.py +4 -1
- smftools/preprocessing/recipes.py +58 -23
- smftools/preprocessing/reindex_references_adata.py +93 -27
- smftools/preprocessing/subsample_adata.py +33 -16
- smftools/readwrite.py +264 -109
- smftools/schema/__init__.py +11 -0
- smftools/schema/anndata_schema_v1.yaml +227 -0
- smftools/tools/__init__.py +25 -18
- smftools/tools/archived/apply_hmm.py +2 -0
- smftools/tools/archived/classifiers.py +165 -0
- smftools/tools/archived/classify_methylated_features.py +2 -0
- smftools/tools/archived/classify_non_methylated_features.py +2 -0
- smftools/tools/archived/subset_adata_v1.py +12 -1
- smftools/tools/archived/subset_adata_v2.py +14 -1
- smftools/tools/calculate_umap.py +56 -15
- smftools/tools/cluster_adata_on_methylation.py +122 -47
- smftools/tools/general_tools.py +70 -25
- smftools/tools/position_stats.py +220 -99
- smftools/tools/read_stats.py +50 -29
- smftools/tools/spatial_autocorrelation.py +365 -192
- smftools/tools/subset_adata.py +23 -21
- smftools-0.3.0.dist-info/METADATA +147 -0
- smftools-0.3.0.dist-info/RECORD +182 -0
- smftools-0.2.4.dist-info/METADATA +0 -141
- smftools-0.2.4.dist-info/RECORD +0 -176
- {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/WHEEL +0 -0
- {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/entry_points.txt +0 -0
- {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/licenses/LICENSE +0 -0
smftools/hmm/HMM.py
CHANGED
|
@@ -1,1587 +1,2292 @@
|
|
|
1
|
-
import
|
|
2
|
-
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
3
|
import ast
|
|
4
4
|
import json
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple, Union
|
|
5
7
|
|
|
6
8
|
import numpy as np
|
|
7
|
-
|
|
8
|
-
import torch
|
|
9
|
-
import torch.nn as nn
|
|
9
|
+
from scipy.sparse import issparse
|
|
10
10
|
|
|
11
|
-
|
|
12
|
-
|
|
11
|
+
from smftools.logging_utils import get_logger
|
|
12
|
+
from smftools.optional_imports import require
|
|
13
13
|
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
Methods:
|
|
19
|
-
- fit(data, ...) -> trains params in-place
|
|
20
|
-
- predict(data, ...) -> list of (L, K) posterior marginals (gamma) numpy arrays
|
|
21
|
-
- viterbi(seq, ...) -> (path_list, score)
|
|
22
|
-
- batch_viterbi(data, ...) -> list of (path_list, score)
|
|
23
|
-
- score(seq_or_list, ...) -> float or list of floats
|
|
24
|
-
|
|
25
|
-
Notes:
|
|
26
|
-
- data: list of sequences (each sequence is iterable of {0,1,np.nan}).
|
|
27
|
-
- impute_strategy: "ignore" (NaN treated as missing), "random" (fill NaNs randomly with 0/1).
|
|
28
|
-
"""
|
|
14
|
+
if TYPE_CHECKING:
|
|
15
|
+
import torch as torch_types
|
|
16
|
+
import torch.nn as nn_types
|
|
29
17
|
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
n_states: int = 2,
|
|
33
|
-
init_start: Optional[List[float]] = None,
|
|
34
|
-
init_trans: Optional[List[List[float]]] = None,
|
|
35
|
-
init_emission: Optional[List[float]] = None,
|
|
36
|
-
dtype: torch.dtype = torch.float64,
|
|
37
|
-
eps: float = 1e-8,
|
|
38
|
-
smf_modality: Optional[str] = None,
|
|
39
|
-
):
|
|
40
|
-
super().__init__()
|
|
41
|
-
if n_states < 2:
|
|
42
|
-
raise ValueError("n_states must be >= 2")
|
|
43
|
-
self.n_states = n_states
|
|
44
|
-
self.eps = float(eps)
|
|
45
|
-
self.dtype = dtype
|
|
46
|
-
self.smf_modality = smf_modality
|
|
18
|
+
torch = require("torch", extra="torch", purpose="HMM modeling")
|
|
19
|
+
nn = torch.nn
|
|
47
20
|
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
start = np.asarray(init_start, dtype=float)
|
|
53
|
-
if init_trans is None:
|
|
54
|
-
trans = np.full((n_states, n_states), 1.0 / n_states, dtype=float)
|
|
55
|
-
else:
|
|
56
|
-
trans = np.asarray(init_trans, dtype=float)
|
|
57
|
-
# --- sanitize init_emission so it's a 1-D list of P(obs==1 | state) ---
|
|
58
|
-
if init_emission is None:
|
|
59
|
-
emission = np.full((n_states,), 0.5, dtype=float)
|
|
60
|
-
else:
|
|
61
|
-
em_arr = np.asarray(init_emission, dtype=float)
|
|
62
|
-
# case: (K,2) -> pick P(obs==1) from second column
|
|
63
|
-
if em_arr.ndim == 2 and em_arr.shape[1] == 2 and em_arr.shape[0] == n_states:
|
|
64
|
-
emission = em_arr[:, 1].astype(float)
|
|
65
|
-
# case: maybe shape (1,K,2) etc. -> try to collapse trailing axis of length 2
|
|
66
|
-
elif em_arr.ndim >= 2 and em_arr.shape[-1] == 2:
|
|
67
|
-
emission = em_arr.reshape(-1, 2)[:n_states, 1].astype(float)
|
|
68
|
-
else:
|
|
69
|
-
emission = em_arr.reshape(-1)[:n_states].astype(float)
|
|
21
|
+
logger = get_logger(__name__)
|
|
22
|
+
# =============================================================================
|
|
23
|
+
# Registry / Factory
|
|
24
|
+
# =============================================================================
|
|
70
25
|
|
|
71
|
-
|
|
72
|
-
self.start = nn.Parameter(torch.tensor(start, dtype=self.dtype), requires_grad=False)
|
|
73
|
-
self.trans = nn.Parameter(torch.tensor(trans, dtype=self.dtype), requires_grad=False)
|
|
74
|
-
self.emission = nn.Parameter(torch.tensor(emission, dtype=self.dtype), requires_grad=False)
|
|
26
|
+
_HMM_REGISTRY: Dict[str, type] = {}
|
|
75
27
|
|
|
76
|
-
self._normalize_params()
|
|
77
28
|
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
# coerce shapes
|
|
81
|
-
K = self.n_states
|
|
82
|
-
self.start.data = self.start.data.squeeze()
|
|
83
|
-
if self.start.data.numel() != K:
|
|
84
|
-
self.start.data = torch.full((K,), 1.0 / K, dtype=self.dtype)
|
|
29
|
+
def register_hmm(name: str):
|
|
30
|
+
"""Decorator to register an HMM backend under a string key."""
|
|
85
31
|
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
self.trans.data = torch.full((K, K), 1.0 / K, dtype=self.dtype)
|
|
32
|
+
def deco(cls):
|
|
33
|
+
"""Register the provided class in the HMM registry."""
|
|
34
|
+
_HMM_REGISTRY[name] = cls
|
|
35
|
+
cls.hmm_name = name
|
|
36
|
+
return cls
|
|
92
37
|
|
|
93
|
-
|
|
94
|
-
if self.emission.data.numel() != K:
|
|
95
|
-
self.emission.data = torch.full((K,), 0.5, dtype=self.dtype)
|
|
38
|
+
return deco
|
|
96
39
|
|
|
97
|
-
# now perform smoothing/normalization
|
|
98
|
-
self.start.data = (self.start.data + self.eps)
|
|
99
|
-
self.start.data = self.start.data / self.start.data.sum()
|
|
100
40
|
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
41
|
+
def create_hmm(cfg: Union[dict, Any, None], arch: Optional[str] = None, **kwargs):
|
|
42
|
+
"""
|
|
43
|
+
Factory: creates an HMM from cfg + arch (override).
|
|
44
|
+
"""
|
|
45
|
+
key = (
|
|
46
|
+
arch
|
|
47
|
+
or getattr(cfg, "hmm_arch", None)
|
|
48
|
+
or (cfg.get("hmm_arch") if isinstance(cfg, dict) else None)
|
|
49
|
+
or "single"
|
|
50
|
+
)
|
|
51
|
+
if key not in _HMM_REGISTRY:
|
|
52
|
+
raise KeyError(f"Unknown hmm_arch={key!r}. Known: {sorted(_HMM_REGISTRY.keys())}")
|
|
53
|
+
return _HMM_REGISTRY[key].from_config(cfg, **kwargs)
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
# =============================================================================
|
|
57
|
+
# Small utilities
|
|
58
|
+
# =============================================================================
|
|
59
|
+
def _coerce_dtype_for_device(
|
|
60
|
+
dtype: torch.dtype, device: Optional[Union[str, torch.device]]
|
|
61
|
+
) -> torch.dtype:
|
|
62
|
+
"""MPS does not support float64. When targeting MPS, coerce to float32."""
|
|
63
|
+
dev = torch.device(device) if isinstance(device, str) else device
|
|
64
|
+
if dev is not None and getattr(dev, "type", None) == "mps" and dtype == torch.float64:
|
|
65
|
+
return torch.float32
|
|
66
|
+
return dtype
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def _try_json_or_literal(x: Any) -> Any:
|
|
70
|
+
"""Parse a string value as JSON or a Python literal when possible.
|
|
71
|
+
|
|
72
|
+
Args:
|
|
73
|
+
x: Value to parse.
|
|
74
|
+
|
|
75
|
+
Returns:
|
|
76
|
+
The parsed value if possible, otherwise the original value.
|
|
77
|
+
"""
|
|
78
|
+
if x is None:
|
|
79
|
+
return None
|
|
80
|
+
if not isinstance(x, str):
|
|
81
|
+
return x
|
|
82
|
+
s = x.strip()
|
|
83
|
+
if not s:
|
|
84
|
+
return None
|
|
85
|
+
try:
|
|
86
|
+
return json.loads(s)
|
|
87
|
+
except Exception:
|
|
88
|
+
pass
|
|
89
|
+
try:
|
|
90
|
+
return ast.literal_eval(s)
|
|
91
|
+
except Exception:
|
|
92
|
+
return x
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def _coerce_bool(x: Any) -> bool:
|
|
96
|
+
"""Coerce a value into a boolean using common truthy strings.
|
|
97
|
+
|
|
98
|
+
Args:
|
|
99
|
+
x: Value to coerce.
|
|
100
|
+
|
|
101
|
+
Returns:
|
|
102
|
+
Boolean interpretation of the input.
|
|
103
|
+
"""
|
|
104
|
+
if x is None:
|
|
105
|
+
return False
|
|
106
|
+
if isinstance(x, bool):
|
|
107
|
+
return x
|
|
108
|
+
if isinstance(x, (int, float)):
|
|
109
|
+
return bool(x)
|
|
110
|
+
s = str(x).strip().lower()
|
|
111
|
+
return s in ("1", "true", "t", "yes", "y", "on")
|
|
112
|
+
|
|
105
113
|
|
|
106
|
-
|
|
114
|
+
def _resolve_dtype(dtype_entry: Any) -> torch.dtype:
|
|
115
|
+
"""Resolve a torch dtype from a config entry.
|
|
107
116
|
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
s = str(dtype_entry).lower()
|
|
116
|
-
if "32" in s:
|
|
117
|
-
return torch.float32
|
|
118
|
-
if "16" in s:
|
|
119
|
-
return torch.float16
|
|
117
|
+
Args:
|
|
118
|
+
dtype_entry: Config value (string or torch.dtype).
|
|
119
|
+
|
|
120
|
+
Returns:
|
|
121
|
+
Resolved torch dtype.
|
|
122
|
+
"""
|
|
123
|
+
if dtype_entry is None:
|
|
120
124
|
return torch.float64
|
|
125
|
+
if isinstance(dtype_entry, torch.dtype):
|
|
126
|
+
return dtype_entry
|
|
127
|
+
s = str(dtype_entry).lower()
|
|
128
|
+
if "16" in s:
|
|
129
|
+
return torch.float16
|
|
130
|
+
if "32" in s:
|
|
131
|
+
return torch.float32
|
|
132
|
+
return torch.float64
|
|
121
133
|
|
|
122
|
-
@classmethod
|
|
123
|
-
def from_config(cls, cfg: Union[dict, "ExperimentConfig", None], *,
|
|
124
|
-
override: Optional[dict] = None,
|
|
125
|
-
device: Optional[Union[str, torch.device]] = None) -> "HMM":
|
|
126
|
-
"""
|
|
127
|
-
Construct an HMM using keys from an ExperimentConfig instance or a plain dict.
|
|
128
134
|
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
135
|
+
def _safe_int_coords(var_names) -> Tuple[np.ndarray, bool]:
|
|
136
|
+
"""
|
|
137
|
+
Try to cast var_names to int coordinates. If not possible,
|
|
138
|
+
fall back to 0..L-1 index coordinates.
|
|
139
|
+
"""
|
|
140
|
+
try:
|
|
141
|
+
coords = np.asarray(var_names, dtype=int)
|
|
142
|
+
return coords, True
|
|
143
|
+
except Exception:
|
|
144
|
+
return np.arange(len(var_names), dtype=int), False
|
|
133
145
|
|
|
134
|
-
override: optional dict to override resolved keys (handy for tests).
|
|
135
|
-
device: optional device string or torch.device to move model to.
|
|
136
|
-
"""
|
|
137
|
-
# Accept ExperimentConfig dataclass
|
|
138
|
-
if cfg is None:
|
|
139
|
-
merged = {}
|
|
140
|
-
elif hasattr(cfg, "to_dict") and callable(getattr(cfg, "to_dict")):
|
|
141
|
-
merged = dict(cfg.to_dict())
|
|
142
|
-
elif isinstance(cfg, dict):
|
|
143
|
-
merged = dict(cfg)
|
|
144
|
-
else:
|
|
145
|
-
# try attr access as fallback
|
|
146
|
-
try:
|
|
147
|
-
merged = {k: getattr(cfg, k) for k in dir(cfg) if k.startswith("hmm_")}
|
|
148
|
-
except Exception:
|
|
149
|
-
merged = {}
|
|
150
146
|
|
|
151
|
-
|
|
152
|
-
|
|
147
|
+
def _logsumexp(x: torch.Tensor, dim: int) -> torch.Tensor:
|
|
148
|
+
"""Compute log-sum-exp in a numerically stable way.
|
|
153
149
|
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
init_trans = merged.get("hmm_init_transition_probs", merged.get("hmm_init_trans", None))
|
|
158
|
-
init_emission = merged.get("hmm_init_emission_probs", merged.get("hmm_init_emission", None))
|
|
159
|
-
eps = float(merged.get("hmm_eps", merged.get("eps", 1e-8)))
|
|
160
|
-
dtype = cls._resolve_dtype(merged.get("hmm_dtype", merged.get("dtype", None)))
|
|
150
|
+
Args:
|
|
151
|
+
x: Input tensor.
|
|
152
|
+
dim: Dimension to reduce.
|
|
161
153
|
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
return np.asarray(x, dtype=float)
|
|
154
|
+
Returns:
|
|
155
|
+
Reduced tensor.
|
|
156
|
+
"""
|
|
157
|
+
return torch.logsumexp(x, dim=dim)
|
|
167
158
|
|
|
168
|
-
init_start = _coerce_np(init_start)
|
|
169
|
-
init_trans = _coerce_np(init_trans)
|
|
170
|
-
init_emission = _coerce_np(init_emission)
|
|
171
159
|
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
)
|
|
160
|
+
def _ensure_layer_full_shape(adata, name: str, dtype, fill_value=0):
|
|
161
|
+
"""
|
|
162
|
+
Ensure adata.layers[name] exists with shape (n_obs, n_vars).
|
|
163
|
+
"""
|
|
164
|
+
if name not in adata.layers:
|
|
165
|
+
arr = np.full((adata.n_obs, adata.n_vars), fill_value=fill_value, dtype=dtype)
|
|
166
|
+
adata.layers[name] = arr
|
|
167
|
+
else:
|
|
168
|
+
arr = _to_dense_np(adata.layers[name])
|
|
169
|
+
if arr.shape != (adata.n_obs, adata.n_vars):
|
|
170
|
+
raise ValueError(
|
|
171
|
+
f"Layer '{name}' exists but has shape {arr.shape}; expected {(adata.n_obs, adata.n_vars)}"
|
|
172
|
+
)
|
|
173
|
+
return adata.layers[name]
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
def _assign_back_obs(final_adata, sub_adata, cols: List[str]):
|
|
177
|
+
"""
|
|
178
|
+
Assign obs columns from sub_adata back into final_adata for the matching obs_names.
|
|
179
|
+
Works for list/object columns too.
|
|
180
|
+
"""
|
|
181
|
+
idx = final_adata.obs_names.get_indexer(sub_adata.obs_names)
|
|
182
|
+
if (idx < 0).any():
|
|
183
|
+
raise ValueError("Some sub_adata.obs_names not found in final_adata.obs_names")
|
|
181
184
|
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
if isinstance(device, str):
|
|
185
|
-
device = torch.device(device)
|
|
186
|
-
model.to(device)
|
|
185
|
+
for c in cols:
|
|
186
|
+
final_adata.obs.iloc[idx, final_adata.obs.columns.get_loc(c)] = sub_adata.obs[c].values
|
|
187
187
|
|
|
188
|
-
# persist the config to the hmm class
|
|
189
|
-
cls.config = cfg
|
|
190
188
|
|
|
191
|
-
|
|
189
|
+
def _to_dense_np(x):
|
|
190
|
+
"""Convert sparse or array-like input to a dense NumPy array.
|
|
192
191
|
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
"""
|
|
196
|
-
Update existing model parameters from a config or dict (in-place).
|
|
197
|
-
This will normalize / reinitialize start/trans/emission using same logic as constructor.
|
|
198
|
-
"""
|
|
199
|
-
if cfg is None:
|
|
200
|
-
merged = {}
|
|
201
|
-
elif hasattr(cfg, "to_dict") and callable(getattr(cfg, "to_dict")):
|
|
202
|
-
merged = dict(cfg.to_dict())
|
|
203
|
-
elif isinstance(cfg, dict):
|
|
204
|
-
merged = dict(cfg)
|
|
205
|
-
else:
|
|
206
|
-
try:
|
|
207
|
-
merged = {k: getattr(cfg, k) for k in dir(cfg) if k.startswith("hmm_")}
|
|
208
|
-
except Exception:
|
|
209
|
-
merged = {}
|
|
192
|
+
Args:
|
|
193
|
+
x: Input array or sparse matrix.
|
|
210
194
|
|
|
211
|
-
|
|
212
|
-
|
|
195
|
+
Returns:
|
|
196
|
+
Dense NumPy array or None.
|
|
197
|
+
"""
|
|
198
|
+
if x is None:
|
|
199
|
+
return None
|
|
200
|
+
if issparse(x):
|
|
201
|
+
return x.toarray()
|
|
202
|
+
return np.asarray(x)
|
|
213
203
|
|
|
214
|
-
# extract same keys as from_config
|
|
215
|
-
n_states = int(merged.get("hmm_n_states", self.n_states))
|
|
216
|
-
init_start = merged.get("hmm_init_start_probs", None)
|
|
217
|
-
init_trans = merged.get("hmm_init_transition_probs", None)
|
|
218
|
-
init_emission = merged.get("hmm_init_emission_probs", None)
|
|
219
|
-
eps = merged.get("hmm_eps", None)
|
|
220
|
-
dtype = merged.get("hmm_dtype", None)
|
|
221
|
-
|
|
222
|
-
# apply dtype/eps if present
|
|
223
|
-
if eps is not None:
|
|
224
|
-
self.eps = float(eps)
|
|
225
|
-
if dtype is not None:
|
|
226
|
-
self.dtype = self._resolve_dtype(dtype)
|
|
227
|
-
|
|
228
|
-
# if n_states changes we need a fresh re-init (easy approach: reconstruct)
|
|
229
|
-
if int(n_states) != int(self.n_states):
|
|
230
|
-
# rebuild self in-place: create a new model and copy tensors
|
|
231
|
-
new_model = HMM.from_config(merged)
|
|
232
|
-
# copy content
|
|
233
|
-
with torch.no_grad():
|
|
234
|
-
self.n_states = new_model.n_states
|
|
235
|
-
self.eps = new_model.eps
|
|
236
|
-
self.dtype = new_model.dtype
|
|
237
|
-
self.start.data = new_model.start.data.clone().to(self.start.device, dtype=self.dtype)
|
|
238
|
-
self.trans.data = new_model.trans.data.clone().to(self.trans.device, dtype=self.dtype)
|
|
239
|
-
self.emission.data = new_model.emission.data.clone().to(self.emission.device, dtype=self.dtype)
|
|
240
|
-
return
|
|
241
204
|
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
if obj is None:
|
|
245
|
-
return None
|
|
246
|
-
arr = np.asarray(obj, dtype=float)
|
|
247
|
-
if shape_expected is not None:
|
|
248
|
-
try:
|
|
249
|
-
arr = arr.reshape(shape_expected)
|
|
250
|
-
except Exception:
|
|
251
|
-
# try to free-form slice/reshape (keep best-effort)
|
|
252
|
-
arr = np.reshape(arr, shape_expected) if arr.size >= np.prod(shape_expected) else arr
|
|
253
|
-
return torch.tensor(arr, dtype=self.dtype, device=self.start.device)
|
|
205
|
+
def _ensure_2d_np(x):
|
|
206
|
+
"""Ensure an array is 2D, reshaping 1D inputs.
|
|
254
207
|
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
t = _to_tensor(init_start, (self.n_states,))
|
|
258
|
-
if t.numel() == self.n_states:
|
|
259
|
-
self.start.data = t.clone()
|
|
260
|
-
if init_trans is not None:
|
|
261
|
-
t = _to_tensor(init_trans, (self.n_states, self.n_states))
|
|
262
|
-
if t.shape == (self.n_states, self.n_states):
|
|
263
|
-
self.trans.data = t.clone()
|
|
264
|
-
if init_emission is not None:
|
|
265
|
-
# attempt to extract P(obs==1) if shaped (K,2)
|
|
266
|
-
arr = np.asarray(init_emission, dtype=float)
|
|
267
|
-
if arr.ndim == 2 and arr.shape[1] == 2 and arr.shape[0] >= self.n_states:
|
|
268
|
-
em = arr[: self.n_states, 1]
|
|
269
|
-
else:
|
|
270
|
-
em = arr.reshape(-1)[: self.n_states]
|
|
271
|
-
t = torch.tensor(em, dtype=self.dtype, device=self.start.device)
|
|
272
|
-
if t.numel() == self.n_states:
|
|
273
|
-
self.emission.data = t.clone()
|
|
208
|
+
Args:
|
|
209
|
+
x: Input array-like.
|
|
274
210
|
|
|
275
|
-
|
|
276
|
-
|
|
211
|
+
Returns:
|
|
212
|
+
2D NumPy array.
|
|
213
|
+
"""
|
|
214
|
+
x = _to_dense_np(x)
|
|
215
|
+
if x.ndim == 1:
|
|
216
|
+
x = x.reshape(1, -1)
|
|
217
|
+
if x.ndim != 2:
|
|
218
|
+
raise ValueError(f"Expected 2D array; got shape {x.shape}")
|
|
219
|
+
return x
|
|
277
220
|
|
|
278
|
-
def _ensure_device_dtype(self, device: Optional[torch.device]):
|
|
279
|
-
if device is None:
|
|
280
|
-
device = next(self.parameters()).device
|
|
281
|
-
self.start.data = self.start.data.to(device=device, dtype=self.dtype)
|
|
282
|
-
self.trans.data = self.trans.data.to(device=device, dtype=self.dtype)
|
|
283
|
-
self.emission.data = self.emission.data.to(device=device, dtype=self.dtype)
|
|
284
|
-
return device
|
|
285
221
|
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
222
|
+
# =============================================================================
|
|
223
|
+
# Feature-set normalization
|
|
224
|
+
# =============================================================================
|
|
225
|
+
|
|
226
|
+
|
|
227
|
+
def normalize_hmm_feature_sets(raw: Any) -> Dict[str, Dict[str, Any]]:
|
|
228
|
+
"""
|
|
229
|
+
Canonical format:
|
|
230
|
+
{
|
|
231
|
+
"footprints": {"state": "Non-Modified", "features": {"small_bound_stretch": [0,50], ...}},
|
|
232
|
+
"accessible": {"state": "Modified", "features": {"all_accessible_features": [0, inf], ...}},
|
|
233
|
+
...
|
|
234
|
+
}
|
|
235
|
+
Each feature range is [lo, hi) in genomic bp (or index units if coords aren't ints).
|
|
236
|
+
"""
|
|
237
|
+
parsed = _try_json_or_literal(raw)
|
|
238
|
+
if not isinstance(parsed, dict):
|
|
239
|
+
return {}
|
|
240
|
+
|
|
241
|
+
def _coerce_bound(v):
|
|
242
|
+
"""Coerce a bound value into a float or sentinel.
|
|
243
|
+
|
|
244
|
+
Args:
|
|
245
|
+
v: Bound value.
|
|
246
|
+
|
|
247
|
+
Returns:
|
|
248
|
+
Float, np.inf, or None.
|
|
293
249
|
"""
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
250
|
+
if v is None:
|
|
251
|
+
return None
|
|
252
|
+
if isinstance(v, (int, float)):
|
|
253
|
+
return float(v)
|
|
254
|
+
s = str(v).strip().lower()
|
|
255
|
+
if s in ("inf", "infty", "np.inf", "infinite"):
|
|
256
|
+
return np.inf
|
|
257
|
+
if s in ("none", ""):
|
|
258
|
+
return None
|
|
259
|
+
try:
|
|
260
|
+
return float(v)
|
|
261
|
+
except Exception:
|
|
262
|
+
return None
|
|
263
|
+
|
|
264
|
+
def _coerce_map(feats):
|
|
265
|
+
"""Coerce feature ranges into (lo, hi) tuples.
|
|
266
|
+
|
|
267
|
+
Args:
|
|
268
|
+
feats: Mapping of feature names to ranges.
|
|
269
|
+
|
|
270
|
+
Returns:
|
|
271
|
+
Mapping of feature names to numeric bounds.
|
|
298
272
|
"""
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
warned_collapse = False
|
|
313
|
-
|
|
314
|
-
for i, seq in enumerate(data):
|
|
315
|
-
# seq may be list/ndarray of scalars OR list/ndarray of per-timestep arrays
|
|
316
|
-
arr = np.asarray(seq, dtype=float)
|
|
317
|
-
|
|
318
|
-
# If arr is shape (L,1,1,...) squeeze trailing singletons
|
|
319
|
-
while arr.ndim > 1 and arr.shape[-1] == 1:
|
|
320
|
-
arr = np.squeeze(arr, axis=-1)
|
|
321
|
-
|
|
322
|
-
# If arr is still >1D (e.g., (L, F)), collapse the last axis by mean
|
|
323
|
-
if arr.ndim > 1:
|
|
324
|
-
if not warned_collapse:
|
|
325
|
-
warnings.warn(
|
|
326
|
-
"HMM._pad_and_mask: collapsing per-timestep feature axis by mean "
|
|
327
|
-
"(arr had shape {}). If you prefer a different reduction, "
|
|
328
|
-
"preprocess your data.".format(arr.shape),
|
|
329
|
-
stacklevel=2,
|
|
330
|
-
)
|
|
331
|
-
warned_collapse = True
|
|
332
|
-
# collapse features -> scalar per timestep
|
|
333
|
-
arr = np.asarray(arr, dtype=float).mean(axis=-1)
|
|
334
|
-
|
|
335
|
-
# now arr should be 1D (T,)
|
|
336
|
-
if arr.ndim == 0:
|
|
337
|
-
# single scalar: treat as length-1 sequence
|
|
338
|
-
arr = np.atleast_1d(arr)
|
|
339
|
-
|
|
340
|
-
nan_mask = np.isnan(arr)
|
|
341
|
-
if impute_strategy == "random" and nan_mask.any():
|
|
342
|
-
arr[nan_mask] = np.random.choice([0, 1], size=nan_mask.sum())
|
|
343
|
-
local_mask = np.ones_like(arr, dtype=bool)
|
|
273
|
+
out = {}
|
|
274
|
+
if not isinstance(feats, dict):
|
|
275
|
+
return out
|
|
276
|
+
for name, rng in feats.items():
|
|
277
|
+
if rng is None:
|
|
278
|
+
out[name] = (0.0, np.inf)
|
|
279
|
+
continue
|
|
280
|
+
if isinstance(rng, (list, tuple)) and len(rng) >= 2:
|
|
281
|
+
lo = _coerce_bound(rng[0])
|
|
282
|
+
hi = _coerce_bound(rng[1])
|
|
283
|
+
lo = 0.0 if lo is None else float(lo)
|
|
284
|
+
hi = np.inf if hi is None else float(hi)
|
|
285
|
+
out[name] = (lo, hi)
|
|
344
286
|
else:
|
|
345
|
-
|
|
346
|
-
|
|
287
|
+
hi = _coerce_bound(rng)
|
|
288
|
+
hi = np.inf if hi is None else float(hi)
|
|
289
|
+
out[name] = (0.0, hi)
|
|
290
|
+
return out
|
|
291
|
+
|
|
292
|
+
out: Dict[str, Dict[str, Any]] = {}
|
|
293
|
+
for group, info in parsed.items():
|
|
294
|
+
if isinstance(info, dict):
|
|
295
|
+
feats = _coerce_map(info.get("features", info.get("ranges", {})))
|
|
296
|
+
state = info.get("state", info.get("label", "Modified"))
|
|
297
|
+
else:
|
|
298
|
+
feats = _coerce_map(info)
|
|
299
|
+
state = "Modified"
|
|
300
|
+
out[group] = {"features": feats, "state": state}
|
|
301
|
+
return out
|
|
347
302
|
|
|
348
|
-
L_i = arr.shape[0]
|
|
349
|
-
obs[i, :L_i] = torch.tensor(arr, dtype=dtype, device=device)
|
|
350
|
-
mask[i, :L_i] = torch.tensor(local_mask, dtype=torch.bool, device=device)
|
|
351
303
|
|
|
352
|
-
|
|
304
|
+
# =============================================================================
|
|
305
|
+
# BaseHMM: shared decoding + annotation pipeline
|
|
306
|
+
# =============================================================================
|
|
353
307
|
|
|
354
|
-
def _log_emission(self, obs: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
|
|
355
|
-
"""
|
|
356
|
-
obs: (B, L)
|
|
357
|
-
mask: (B, L) bool
|
|
358
|
-
returns logB (B, L, K)
|
|
359
|
-
"""
|
|
360
|
-
B, L = obs.shape
|
|
361
|
-
p = self.emission # (K,)
|
|
362
|
-
logp = torch.log(p + self.eps)
|
|
363
|
-
log1mp = torch.log1p(-p + self.eps)
|
|
364
|
-
obs_expand = obs.unsqueeze(-1) # (B, L, 1)
|
|
365
|
-
logB = obs_expand * logp.unsqueeze(0).unsqueeze(0) + (1.0 - obs_expand) * log1mp.unsqueeze(0).unsqueeze(0)
|
|
366
|
-
logB = torch.where(mask.unsqueeze(-1), logB, torch.zeros_like(logB))
|
|
367
|
-
return logB
|
|
368
308
|
|
|
369
|
-
|
|
370
|
-
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
"""
|
|
382
|
-
if device is None:
|
|
383
|
-
device = next(self.parameters()).device
|
|
384
|
-
elif isinstance(device, str):
|
|
385
|
-
device = torch.device(device)
|
|
386
|
-
device = self._ensure_device_dtype(device)
|
|
309
|
+
class BaseHMM(nn.Module):
|
|
310
|
+
"""
|
|
311
|
+
BaseHMM responsibilities:
|
|
312
|
+
- config resolution (from_config)
|
|
313
|
+
- EM fit wrapper (fit / fit_em)
|
|
314
|
+
- decoding (gamma / viterbi)
|
|
315
|
+
- AnnData annotation from provided arrays (X + coords)
|
|
316
|
+
- save/load registry aware
|
|
317
|
+
Subclasses implement:
|
|
318
|
+
- _log_emission(...) -> logB
|
|
319
|
+
- optional distance-aware transition handling
|
|
320
|
+
"""
|
|
387
321
|
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
# rows are sequences: convert to list of 1D arrays
|
|
391
|
-
data = data.tolist()
|
|
392
|
-
elif data.ndim == 1:
|
|
393
|
-
# single sequence
|
|
394
|
-
data = [data.tolist()]
|
|
395
|
-
else:
|
|
396
|
-
raise ValueError(f"Expected data to be 1D or 2D ndarray; got array with ndim={data.ndim}")
|
|
322
|
+
def __init__(self, n_states: int = 2, eps: float = 1e-8, dtype: torch.dtype = torch.float64):
|
|
323
|
+
"""Initialize the base HMM with shared parameters.
|
|
397
324
|
|
|
398
|
-
|
|
399
|
-
|
|
400
|
-
|
|
401
|
-
|
|
325
|
+
Args:
|
|
326
|
+
n_states: Number of hidden states.
|
|
327
|
+
eps: Smoothing epsilon for probabilities.
|
|
328
|
+
dtype: Torch dtype for parameters.
|
|
329
|
+
"""
|
|
330
|
+
super().__init__()
|
|
331
|
+
if n_states < 2:
|
|
332
|
+
raise ValueError("n_states must be >= 2")
|
|
333
|
+
self.n_states = int(n_states)
|
|
334
|
+
self.eps = float(eps)
|
|
335
|
+
self.dtype = dtype
|
|
402
336
|
|
|
403
|
-
|
|
404
|
-
|
|
337
|
+
# start probs + transitions (shared across backends)
|
|
338
|
+
start = np.full((self.n_states,), 1.0 / self.n_states, dtype=float)
|
|
339
|
+
trans = np.full((self.n_states, self.n_states), 1.0 / self.n_states, dtype=float)
|
|
405
340
|
|
|
406
|
-
|
|
341
|
+
self.start = nn.Parameter(torch.tensor(start, dtype=self.dtype), requires_grad=False)
|
|
342
|
+
self.trans = nn.Parameter(torch.tensor(trans, dtype=self.dtype), requires_grad=False)
|
|
343
|
+
self._normalize_params()
|
|
407
344
|
|
|
408
|
-
|
|
409
|
-
if verbose:
|
|
410
|
-
print(f"[HMM.fit] EM iter {it}")
|
|
345
|
+
# ------------------------- config -------------------------
|
|
411
346
|
|
|
412
|
-
|
|
413
|
-
|
|
347
|
+
@classmethod
|
|
348
|
+
def from_config(
|
|
349
|
+
cls, cfg: Union[dict, Any, None], *, override: Optional[dict] = None, device=None
|
|
350
|
+
):
|
|
351
|
+
"""Create a model from config with optional overrides.
|
|
414
352
|
|
|
415
|
-
|
|
416
|
-
|
|
417
|
-
|
|
353
|
+
Args:
|
|
354
|
+
cfg: Configuration mapping or object.
|
|
355
|
+
override: Override values to apply.
|
|
356
|
+
device: Device specifier.
|
|
418
357
|
|
|
419
|
-
|
|
420
|
-
|
|
421
|
-
|
|
422
|
-
|
|
423
|
-
|
|
424
|
-
|
|
425
|
-
alpha[:, t, :] = _logsumexp(prev, dim=1) + logB[:, t, :]
|
|
358
|
+
Returns:
|
|
359
|
+
Initialized HMM instance.
|
|
360
|
+
"""
|
|
361
|
+
merged = cls._cfg_to_dict(cfg)
|
|
362
|
+
if override:
|
|
363
|
+
merged.update(override)
|
|
426
364
|
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
|
|
430
|
-
|
|
431
|
-
# temp (B, i, j) = logA[i,j] + logB[:,t+1,j] + beta[:,t+1,j]
|
|
432
|
-
temp = logA.unsqueeze(0) + (logB[:, t + 1, :].unsqueeze(1) + beta[:, t + 1, :].unsqueeze(1))
|
|
433
|
-
beta[:, t, :] = _logsumexp(temp, dim=2)
|
|
365
|
+
n_states = int(merged.get("hmm_n_states", merged.get("n_states", 2)))
|
|
366
|
+
eps = float(merged.get("hmm_eps", merged.get("eps", 1e-8)))
|
|
367
|
+
dtype = _resolve_dtype(merged.get("hmm_dtype", merged.get("dtype", None)))
|
|
368
|
+
dtype = _coerce_dtype_for_device(dtype, device) # <<< NEW
|
|
434
369
|
|
|
435
|
-
|
|
436
|
-
|
|
437
|
-
|
|
438
|
-
|
|
439
|
-
|
|
440
|
-
total_loglike = float(seq_loglikes.sum().item())
|
|
370
|
+
model = cls(n_states=n_states, eps=eps, dtype=dtype)
|
|
371
|
+
if device is not None:
|
|
372
|
+
model.to(torch.device(device) if isinstance(device, str) else device)
|
|
373
|
+
model._persisted_cfg = merged
|
|
374
|
+
return model
|
|
441
375
|
|
|
442
|
-
|
|
443
|
-
|
|
444
|
-
|
|
445
|
-
gamma = (log_gamma - logZ_time).exp() # (B, L, K)
|
|
376
|
+
@staticmethod
|
|
377
|
+
def _cfg_to_dict(cfg: Union[dict, Any, None]) -> dict:
|
|
378
|
+
"""Normalize a config object into a dictionary.
|
|
446
379
|
|
|
447
|
-
|
|
448
|
-
|
|
380
|
+
Args:
|
|
381
|
+
cfg: Config mapping or object.
|
|
449
382
|
|
|
450
|
-
|
|
451
|
-
|
|
452
|
-
|
|
453
|
-
|
|
383
|
+
Returns:
|
|
384
|
+
Dictionary of HMM-related config values.
|
|
385
|
+
"""
|
|
386
|
+
if cfg is None:
|
|
387
|
+
return {}
|
|
388
|
+
if isinstance(cfg, dict):
|
|
389
|
+
return dict(cfg)
|
|
390
|
+
if hasattr(cfg, "to_dict") and callable(getattr(cfg, "to_dict")):
|
|
391
|
+
return dict(cfg.to_dict())
|
|
392
|
+
out = {}
|
|
393
|
+
for k in dir(cfg):
|
|
394
|
+
if k.startswith("hmm_") or k in ("smf_modality", "cpg"):
|
|
395
|
+
try:
|
|
396
|
+
out[k] = getattr(cfg, k)
|
|
397
|
+
except Exception:
|
|
398
|
+
pass
|
|
399
|
+
return out
|
|
454
400
|
|
|
455
|
-
|
|
456
|
-
trans_accum = torch.zeros((K, K), dtype=self.dtype, device=device)
|
|
457
|
-
if L >= 2:
|
|
458
|
-
time_idx = torch.arange(L - 1, device=device).unsqueeze(0).expand(B, L - 1) # (B, L-1)
|
|
459
|
-
valid = time_idx < (lengths.unsqueeze(1) - 1) # (B, L-1) bool
|
|
460
|
-
for t in range(L - 1):
|
|
461
|
-
a_t = alpha[:, t, :].unsqueeze(2) # (B, i, 1)
|
|
462
|
-
b_next = (logB[:, t + 1, :] + beta[:, t + 1, :]).unsqueeze(1) # (B, 1, j)
|
|
463
|
-
log_xi_unnorm = a_t + logA.unsqueeze(0) + b_next # (B, i, j)
|
|
464
|
-
log_xi_flat = log_xi_unnorm.view(B, -1) # (B, i*j)
|
|
465
|
-
log_norm = _logsumexp(log_xi_flat, dim=1).unsqueeze(1).unsqueeze(2) # (B,1,1)
|
|
466
|
-
xi = (log_xi_unnorm - log_norm).exp() # (B,i,j)
|
|
467
|
-
valid_t = valid[:, t].float().unsqueeze(1).unsqueeze(2) # (B,1,1)
|
|
468
|
-
xi_masked = xi * valid_t
|
|
469
|
-
trans_accum += xi_masked.sum(dim=0) # (i,j)
|
|
470
|
-
|
|
471
|
-
# M-step: update parameters with smoothing
|
|
472
|
-
with torch.no_grad():
|
|
473
|
-
new_start = gamma_start_accum + eps
|
|
474
|
-
new_start = new_start / new_start.sum()
|
|
401
|
+
# ------------------------- params -------------------------
|
|
475
402
|
|
|
476
|
-
|
|
477
|
-
|
|
478
|
-
|
|
479
|
-
|
|
403
|
+
def _normalize_params(self):
|
|
404
|
+
"""Normalize start and transition probabilities in-place."""
|
|
405
|
+
with torch.no_grad():
|
|
406
|
+
K = self.n_states
|
|
480
407
|
|
|
481
|
-
|
|
482
|
-
|
|
408
|
+
# start
|
|
409
|
+
self.start.data = self.start.data.reshape(-1)
|
|
410
|
+
if self.start.data.numel() != K:
|
|
411
|
+
self.start.data = torch.full(
|
|
412
|
+
(K,), 1.0 / K, dtype=self.dtype, device=self.start.device
|
|
413
|
+
)
|
|
414
|
+
self.start.data = self.start.data + self.eps
|
|
415
|
+
self.start.data = self.start.data / self.start.data.sum()
|
|
483
416
|
|
|
484
|
-
|
|
485
|
-
|
|
486
|
-
|
|
417
|
+
# trans
|
|
418
|
+
self.trans.data = self.trans.data.reshape(K, K)
|
|
419
|
+
self.trans.data = self.trans.data + self.eps
|
|
420
|
+
rs = self.trans.data.sum(dim=1, keepdim=True)
|
|
421
|
+
rs[rs == 0.0] = 1.0
|
|
422
|
+
self.trans.data = self.trans.data / rs
|
|
487
423
|
|
|
488
|
-
|
|
489
|
-
|
|
490
|
-
|
|
424
|
+
def _ensure_device_dtype(
|
|
425
|
+
self, device: Optional[Union[str, torch.device]] = None
|
|
426
|
+
) -> torch.device:
|
|
427
|
+
"""Move parameters to the requested device/dtype.
|
|
491
428
|
|
|
492
|
-
|
|
493
|
-
|
|
494
|
-
print(f"[HMM.fit] converged (Δll < {tol}) at iter {it}")
|
|
495
|
-
break
|
|
429
|
+
Args:
|
|
430
|
+
device: Device specifier or None to use current device.
|
|
496
431
|
|
|
497
|
-
|
|
498
|
-
|
|
499
|
-
def get_params(self) -> dict:
|
|
432
|
+
Returns:
|
|
433
|
+
Resolved torch device.
|
|
500
434
|
"""
|
|
501
|
-
|
|
502
|
-
|
|
503
|
-
|
|
504
|
-
|
|
505
|
-
|
|
506
|
-
|
|
507
|
-
"trans": self.trans.detach().cpu().numpy().astype(float),
|
|
508
|
-
"emission": self.emission.detach().cpu().numpy().astype(float).reshape(-1),
|
|
509
|
-
}
|
|
435
|
+
if device is None:
|
|
436
|
+
device = next(self.parameters()).device
|
|
437
|
+
device = torch.device(device) if isinstance(device, str) else device
|
|
438
|
+
self.start.data = self.start.data.to(device=device, dtype=self.dtype)
|
|
439
|
+
self.trans.data = self.trans.data.to(device=device, dtype=self.dtype)
|
|
440
|
+
return device
|
|
510
441
|
|
|
511
|
-
|
|
442
|
+
# ------------------------- state labeling -------------------------
|
|
443
|
+
|
|
444
|
+
def _state_modified_score(self) -> torch.Tensor:
|
|
445
|
+
"""Subclasses return (K,) score; higher => more “Modified/Accessible”."""
|
|
446
|
+
raise NotImplementedError
|
|
447
|
+
|
|
448
|
+
def modified_state_index(self) -> int:
|
|
449
|
+
"""Return the index of the most modified/accessible state."""
|
|
450
|
+
scores = self._state_modified_score()
|
|
451
|
+
return int(torch.argmax(scores).item())
|
|
452
|
+
|
|
453
|
+
def resolve_target_state_index(self, state_target: Any) -> int:
|
|
512
454
|
"""
|
|
513
|
-
|
|
455
|
+
Accept:
|
|
456
|
+
- int -> explicit state index
|
|
457
|
+
- "Modified" / "Non-Modified" and aliases
|
|
514
458
|
"""
|
|
515
|
-
|
|
516
|
-
|
|
517
|
-
|
|
518
|
-
|
|
519
|
-
|
|
520
|
-
|
|
521
|
-
|
|
522
|
-
|
|
523
|
-
|
|
524
|
-
|
|
525
|
-
|
|
526
|
-
|
|
527
|
-
|
|
528
|
-
|
|
459
|
+
if isinstance(state_target, (int, np.integer)):
|
|
460
|
+
idx = int(state_target)
|
|
461
|
+
return max(0, min(idx, self.n_states - 1))
|
|
462
|
+
|
|
463
|
+
s = str(state_target).strip().lower()
|
|
464
|
+
if s in ("modified", "open", "accessible", "1", "pos", "positive"):
|
|
465
|
+
return self.modified_state_index()
|
|
466
|
+
if s in ("non-modified", "closed", "inaccessible", "0", "neg", "negative"):
|
|
467
|
+
scores = self._state_modified_score()
|
|
468
|
+
return int(torch.argmin(scores).item())
|
|
469
|
+
return self.modified_state_index()
|
|
470
|
+
|
|
471
|
+
# ------------------------- emissions -------------------------
|
|
472
|
+
|
|
473
|
+
def _log_emission(self, obs: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
|
|
529
474
|
"""
|
|
530
|
-
Return
|
|
475
|
+
Return logB:
|
|
476
|
+
- single: obs (N,L), mask (N,L) -> logB (N,L,K)
|
|
477
|
+
- multi : obs (N,L,C), mask (N,L,C) -> logB (N,L,K)
|
|
531
478
|
"""
|
|
532
|
-
|
|
533
|
-
|
|
534
|
-
|
|
535
|
-
|
|
536
|
-
|
|
537
|
-
|
|
538
|
-
|
|
539
|
-
|
|
540
|
-
|
|
479
|
+
raise NotImplementedError
|
|
480
|
+
|
|
481
|
+
# ------------------------- decoding core -------------------------
|
|
482
|
+
|
|
483
|
+
def _forward_backward(
|
|
484
|
+
self,
|
|
485
|
+
obs: torch.Tensor,
|
|
486
|
+
mask: torch.Tensor,
|
|
487
|
+
*,
|
|
488
|
+
coords: Optional[np.ndarray] = None,
|
|
489
|
+
) -> torch.Tensor:
|
|
541
490
|
"""
|
|
542
|
-
|
|
491
|
+
Returns gamma (N,L,K) in probability space.
|
|
492
|
+
Subclasses can override for distance-aware transitions.
|
|
543
493
|
"""
|
|
544
|
-
|
|
545
|
-
device = next(self.parameters()).device
|
|
546
|
-
elif isinstance(device, str):
|
|
547
|
-
device = torch.device(device)
|
|
548
|
-
device = self._ensure_device_dtype(device)
|
|
549
|
-
|
|
550
|
-
obs, mask, lengths = self._pad_and_mask(data, device=device, dtype=self.dtype, impute_strategy=impute_strategy)
|
|
551
|
-
B, L = obs.shape
|
|
552
|
-
K = self.n_states
|
|
494
|
+
device = obs.device
|
|
553
495
|
eps = float(self.eps)
|
|
496
|
+
K = self.n_states
|
|
554
497
|
|
|
555
|
-
logB = self._log_emission(obs, mask) # (
|
|
556
|
-
logA = torch.log(self.trans + eps)
|
|
557
|
-
logstart = torch.log(self.start + eps)
|
|
498
|
+
logB = self._log_emission(obs, mask) # (N,L,K)
|
|
499
|
+
logA = torch.log(self.trans + eps) # (K,K)
|
|
500
|
+
logstart = torch.log(self.start + eps) # (K,)
|
|
501
|
+
|
|
502
|
+
N, L, _ = logB.shape
|
|
558
503
|
|
|
559
|
-
|
|
560
|
-
alpha = torch.empty((B, L, K), dtype=self.dtype, device=device)
|
|
504
|
+
alpha = torch.empty((N, L, K), dtype=self.dtype, device=device)
|
|
561
505
|
alpha[:, 0, :] = logstart.unsqueeze(0) + logB[:, 0, :]
|
|
506
|
+
|
|
562
507
|
for t in range(1, L):
|
|
563
|
-
prev = alpha[:, t - 1, :].unsqueeze(2) + logA.unsqueeze(0)
|
|
508
|
+
prev = alpha[:, t - 1, :].unsqueeze(2) + logA.unsqueeze(0) # (N,K,K)
|
|
564
509
|
alpha[:, t, :] = _logsumexp(prev, dim=1) + logB[:, t, :]
|
|
565
510
|
|
|
566
|
-
|
|
567
|
-
beta
|
|
568
|
-
beta[:, L - 1, :] = torch.zeros((K,), dtype=self.dtype, device=device).unsqueeze(0).expand(B, K)
|
|
511
|
+
beta = torch.empty((N, L, K), dtype=self.dtype, device=device)
|
|
512
|
+
beta[:, L - 1, :] = 0.0
|
|
569
513
|
for t in range(L - 2, -1, -1):
|
|
570
|
-
temp = logA.unsqueeze(0) + (logB[:, t + 1, :]
|
|
514
|
+
temp = logA.unsqueeze(0) + (logB[:, t + 1, :] + beta[:, t + 1, :]).unsqueeze(1)
|
|
571
515
|
beta[:, t, :] = _logsumexp(temp, dim=2)
|
|
572
516
|
|
|
573
|
-
# gamma
|
|
574
517
|
log_gamma = alpha + beta
|
|
575
|
-
|
|
576
|
-
gamma = (log_gamma -
|
|
518
|
+
logZ = _logsumexp(log_gamma, dim=2).unsqueeze(2)
|
|
519
|
+
gamma = (log_gamma - logZ).exp()
|
|
520
|
+
return gamma
|
|
577
521
|
|
|
578
|
-
|
|
579
|
-
|
|
580
|
-
|
|
581
|
-
|
|
582
|
-
|
|
583
|
-
|
|
584
|
-
|
|
522
|
+
def _viterbi(
|
|
523
|
+
self,
|
|
524
|
+
obs: torch.Tensor,
|
|
525
|
+
mask: torch.Tensor,
|
|
526
|
+
*,
|
|
527
|
+
coords: Optional[np.ndarray] = None,
|
|
528
|
+
) -> torch.Tensor:
|
|
585
529
|
"""
|
|
586
|
-
|
|
587
|
-
|
|
530
|
+
Returns states (N,L) int64. Missing positions (mask False for all channels)
|
|
531
|
+
are still decoded, but you’ll overwrite them to -1 during writing.
|
|
532
|
+
Subclasses can override for distance-aware transitions.
|
|
588
533
|
"""
|
|
589
|
-
|
|
590
|
-
|
|
591
|
-
|
|
592
|
-
single = True
|
|
593
|
-
else:
|
|
594
|
-
seqs = seq_or_list
|
|
534
|
+
device = obs.device
|
|
535
|
+
eps = float(self.eps)
|
|
536
|
+
K = self.n_states
|
|
595
537
|
|
|
596
|
-
|
|
597
|
-
|
|
598
|
-
|
|
599
|
-
device = torch.device(device)
|
|
600
|
-
device = self._ensure_device_dtype(device)
|
|
538
|
+
logB = self._log_emission(obs, mask) # (N,L,K)
|
|
539
|
+
logA = torch.log(self.trans + eps) # (K,K)
|
|
540
|
+
logstart = torch.log(self.start + eps) # (K,)
|
|
601
541
|
|
|
602
|
-
|
|
603
|
-
|
|
604
|
-
K =
|
|
605
|
-
eps = float(self.eps)
|
|
542
|
+
N, L, _ = logB.shape
|
|
543
|
+
delta = torch.empty((N, L, K), dtype=self.dtype, device=device)
|
|
544
|
+
psi = torch.empty((N, L, K), dtype=torch.long, device=device)
|
|
606
545
|
|
|
607
|
-
|
|
608
|
-
|
|
609
|
-
logstart = torch.log(self.start + eps)
|
|
546
|
+
delta[:, 0, :] = logstart.unsqueeze(0) + logB[:, 0, :]
|
|
547
|
+
psi[:, 0, :] = -1
|
|
610
548
|
|
|
611
|
-
alpha = torch.empty((B, L, K), dtype=self.dtype, device=device)
|
|
612
|
-
alpha[:, 0, :] = logstart.unsqueeze(0) + logB[:, 0, :]
|
|
613
549
|
for t in range(1, L):
|
|
614
|
-
|
|
615
|
-
|
|
550
|
+
cand = delta[:, t - 1, :].unsqueeze(2) + logA.unsqueeze(0) # (N,K,K)
|
|
551
|
+
best_val, best_idx = cand.max(dim=1)
|
|
552
|
+
delta[:, t, :] = best_val + logB[:, t, :]
|
|
553
|
+
psi[:, t, :] = best_idx
|
|
616
554
|
|
|
617
|
-
|
|
618
|
-
|
|
619
|
-
final_alpha = alpha[idx_range, last_idx, :] # (B, K)
|
|
620
|
-
seq_loglikes = _logsumexp(final_alpha, dim=1) # (B,)
|
|
621
|
-
seq_loglikes = seq_loglikes.detach().cpu().numpy().tolist()
|
|
622
|
-
return seq_loglikes[0] if single else seq_loglikes
|
|
555
|
+
last_state = torch.argmax(delta[:, L - 1, :], dim=1) # (N,)
|
|
556
|
+
states = torch.empty((N, L), dtype=torch.long, device=device)
|
|
623
557
|
|
|
624
|
-
|
|
625
|
-
|
|
626
|
-
|
|
627
|
-
"""
|
|
628
|
-
paths, scores = self.batch_viterbi([seq], impute_strategy=impute_strategy, device=device)
|
|
629
|
-
return paths[0], scores[0]
|
|
558
|
+
states[:, L - 1] = last_state
|
|
559
|
+
for t in range(L - 2, -1, -1):
|
|
560
|
+
states[:, t] = psi[torch.arange(N, device=device), t + 1, states[:, t + 1]]
|
|
630
561
|
|
|
631
|
-
|
|
562
|
+
return states
|
|
563
|
+
|
|
564
|
+
def decode(
|
|
565
|
+
self,
|
|
566
|
+
X: np.ndarray,
|
|
567
|
+
coords: Optional[np.ndarray] = None,
|
|
568
|
+
*,
|
|
569
|
+
decode: str = "marginal",
|
|
570
|
+
device: Optional[Union[str, torch.device]] = None,
|
|
571
|
+
) -> Tuple[np.ndarray, np.ndarray]:
|
|
572
|
+
"""Decode observations into state calls and posterior probabilities.
|
|
573
|
+
|
|
574
|
+
Args:
|
|
575
|
+
X: Observations array (N, L) or (N, L, C).
|
|
576
|
+
coords: Optional coordinates aligned to L.
|
|
577
|
+
decode: Decoding strategy ("marginal" or "viterbi").
|
|
578
|
+
device: Device specifier.
|
|
579
|
+
|
|
580
|
+
Returns:
|
|
581
|
+
Tuple of (states, posterior probabilities).
|
|
632
582
|
"""
|
|
633
|
-
|
|
634
|
-
|
|
583
|
+
device = self._ensure_device_dtype(device)
|
|
584
|
+
|
|
585
|
+
X = np.asarray(X, dtype=float)
|
|
586
|
+
if X.ndim == 2:
|
|
587
|
+
L = X.shape[1]
|
|
588
|
+
elif X.ndim == 3:
|
|
589
|
+
L = X.shape[1]
|
|
590
|
+
else:
|
|
591
|
+
raise ValueError(f"X must be 2D or 3D; got shape {X.shape}")
|
|
592
|
+
|
|
593
|
+
if coords is None:
|
|
594
|
+
coords = np.arange(L, dtype=int)
|
|
595
|
+
coords = np.asarray(coords, dtype=int)
|
|
596
|
+
|
|
597
|
+
if X.ndim == 2:
|
|
598
|
+
obs = torch.tensor(np.nan_to_num(X, nan=0.0), dtype=self.dtype, device=device)
|
|
599
|
+
mask = torch.tensor(~np.isnan(X), dtype=torch.bool, device=device)
|
|
600
|
+
else:
|
|
601
|
+
obs = torch.tensor(np.nan_to_num(X, nan=0.0), dtype=self.dtype, device=device)
|
|
602
|
+
mask = torch.tensor(~np.isnan(X), dtype=torch.bool, device=device)
|
|
603
|
+
|
|
604
|
+
gamma = self._forward_backward(obs, mask, coords=coords)
|
|
605
|
+
|
|
606
|
+
if str(decode).lower() == "viterbi":
|
|
607
|
+
st = self._viterbi(obs, mask, coords=coords)
|
|
608
|
+
else:
|
|
609
|
+
st = torch.argmax(gamma, dim=2)
|
|
610
|
+
|
|
611
|
+
return st.detach().cpu().numpy(), gamma.detach().cpu().numpy()
|
|
612
|
+
|
|
613
|
+
# ------------------------- EM fit -------------------------
|
|
614
|
+
|
|
615
|
+
def fit(
|
|
616
|
+
self,
|
|
617
|
+
X: np.ndarray,
|
|
618
|
+
coords: Optional[np.ndarray] = None,
|
|
619
|
+
*,
|
|
620
|
+
max_iter: int = 50,
|
|
621
|
+
tol: float = 1e-4,
|
|
622
|
+
device: Optional[Union[str, torch.device]] = None,
|
|
623
|
+
update_start: bool = True,
|
|
624
|
+
update_trans: bool = True,
|
|
625
|
+
update_emission: bool = True,
|
|
626
|
+
verbose: bool = False,
|
|
627
|
+
**kwargs,
|
|
628
|
+
) -> List[float]:
|
|
629
|
+
"""Fit HMM parameters using EM.
|
|
630
|
+
|
|
631
|
+
Args:
|
|
632
|
+
X: Observations array.
|
|
633
|
+
coords: Optional coordinate array.
|
|
634
|
+
max_iter: Maximum EM iterations.
|
|
635
|
+
tol: Convergence tolerance.
|
|
636
|
+
device: Device specifier.
|
|
637
|
+
update_start: Whether to update start probabilities.
|
|
638
|
+
update_trans: Whether to update transition probabilities.
|
|
639
|
+
update_emission: Whether to update emission parameters.
|
|
640
|
+
verbose: Whether to log progress.
|
|
641
|
+
**kwargs: Additional implementation-specific kwargs.
|
|
642
|
+
|
|
643
|
+
Returns:
|
|
644
|
+
List of log-likelihood values across iterations.
|
|
635
645
|
"""
|
|
636
|
-
|
|
637
|
-
|
|
638
|
-
|
|
639
|
-
|
|
646
|
+
X = np.asarray(X, dtype=float)
|
|
647
|
+
if X.ndim not in (2, 3):
|
|
648
|
+
raise ValueError(f"X must be 2D or 3D; got {X.shape}")
|
|
649
|
+
L = X.shape[1]
|
|
650
|
+
|
|
651
|
+
if coords is None:
|
|
652
|
+
coords = np.arange(L, dtype=int)
|
|
653
|
+
coords = np.asarray(coords, dtype=int)
|
|
654
|
+
|
|
640
655
|
device = self._ensure_device_dtype(device)
|
|
656
|
+
return self.fit_em(
|
|
657
|
+
X,
|
|
658
|
+
coords,
|
|
659
|
+
device=device,
|
|
660
|
+
max_iter=max_iter,
|
|
661
|
+
tol=tol,
|
|
662
|
+
update_start=update_start,
|
|
663
|
+
update_trans=update_trans,
|
|
664
|
+
update_emission=update_emission,
|
|
665
|
+
verbose=verbose,
|
|
666
|
+
**kwargs,
|
|
667
|
+
)
|
|
641
668
|
|
|
642
|
-
|
|
643
|
-
|
|
644
|
-
|
|
645
|
-
|
|
669
|
+
def adapt_emissions(
|
|
670
|
+
self,
|
|
671
|
+
X: np.ndarray,
|
|
672
|
+
coords: np.ndarray,
|
|
673
|
+
*,
|
|
674
|
+
iters: int = 5,
|
|
675
|
+
device: Optional[Union[str, torch.device]] = None,
|
|
676
|
+
freeze_start: bool = True,
|
|
677
|
+
freeze_trans: bool = True,
|
|
678
|
+
verbose: bool = False,
|
|
679
|
+
**kwargs,
|
|
680
|
+
) -> List[float]:
|
|
681
|
+
"""Adapt emission parameters while keeping shared structure fixed.
|
|
682
|
+
|
|
683
|
+
Args:
|
|
684
|
+
X: Observations array.
|
|
685
|
+
coords: Coordinate array aligned to X.
|
|
686
|
+
iters: Number of EM iterations.
|
|
687
|
+
device: Device specifier.
|
|
688
|
+
freeze_start: Whether to freeze start probabilities.
|
|
689
|
+
freeze_trans: Whether to freeze transitions.
|
|
690
|
+
verbose: Whether to log progress.
|
|
691
|
+
**kwargs: Additional implementation-specific kwargs.
|
|
692
|
+
|
|
693
|
+
Returns:
|
|
694
|
+
List of log-likelihood values across iterations.
|
|
695
|
+
"""
|
|
696
|
+
return self.fit(
|
|
697
|
+
X,
|
|
698
|
+
coords,
|
|
699
|
+
max_iter=int(iters),
|
|
700
|
+
tol=0.0,
|
|
701
|
+
device=device,
|
|
702
|
+
update_start=not freeze_start,
|
|
703
|
+
update_trans=not freeze_trans,
|
|
704
|
+
update_emission=True,
|
|
705
|
+
verbose=verbose,
|
|
706
|
+
**kwargs,
|
|
707
|
+
)
|
|
646
708
|
|
|
647
|
-
|
|
648
|
-
|
|
649
|
-
|
|
650
|
-
|
|
651
|
-
|
|
709
|
+
def fit_em(
|
|
710
|
+
self,
|
|
711
|
+
X: np.ndarray,
|
|
712
|
+
coords: np.ndarray,
|
|
713
|
+
*,
|
|
714
|
+
device: torch.device,
|
|
715
|
+
max_iter: int,
|
|
716
|
+
tol: float,
|
|
717
|
+
update_start: bool,
|
|
718
|
+
update_trans: bool,
|
|
719
|
+
update_emission: bool,
|
|
720
|
+
verbose: bool,
|
|
721
|
+
**kwargs,
|
|
722
|
+
) -> List[float]:
|
|
723
|
+
"""Run the core EM update loop (subclasses implement).
|
|
724
|
+
|
|
725
|
+
Args:
|
|
726
|
+
X: Observations array.
|
|
727
|
+
coords: Coordinate array aligned to X.
|
|
728
|
+
device: Torch device.
|
|
729
|
+
max_iter: Maximum iterations.
|
|
730
|
+
tol: Convergence tolerance.
|
|
731
|
+
update_start: Whether to update start probabilities.
|
|
732
|
+
update_trans: Whether to update transitions.
|
|
733
|
+
update_emission: Whether to update emission parameters.
|
|
734
|
+
verbose: Whether to log progress.
|
|
735
|
+
**kwargs: Additional subclass-specific kwargs.
|
|
736
|
+
|
|
737
|
+
Returns:
|
|
738
|
+
List of log-likelihood values across iterations.
|
|
739
|
+
"""
|
|
740
|
+
raise NotImplementedError
|
|
652
741
|
|
|
-
-        logA = torch.log(self.trans + eps)
+    # ------------------------- save/load -------------------------
 
-
-
-
+    def _extra_save_payload(self) -> dict:
+        """Return extra model state to include when saving."""
+        return {}
 
-
-
+    def _load_extra_payload(self, payload: dict, *, device: torch.device):
+        """Load extra model state saved by subclasses.
 
-
-
-
-            best_val, best_idx = cand.max(dim=1)  # best over previous i: results (B, j)
-            delta[:, t, :] = best_val + logB[:, t, :]
-            psi[:, t, :] = best_idx  # best previous state index for each (B, j)
-
-        # backtrack
-        last_idx = (lengths - 1).clamp(min=0)
-        idx_range = torch.arange(B, device=device)
-        final_delta = delta[idx_range, last_idx, :]  # (B, K)
-        best_last_val, best_last_state = final_delta.max(dim=1)  # (B,), (B,)
-        paths = []
-        scores = []
-        for b in range(B):
-            Lb = int(lengths[b].item())
-            if Lb == 0:
-                paths.append([])
-                scores.append(float("-inf"))
-                continue
-            s = int(best_last_state[b].item())
-            path = [s]
-            for t in range(Lb - 1, 0, -1):
-                s = int(psi[b, t, s].item())
-                path.append(s)
-            path.reverse()
-            paths.append(path)
-            scores.append(float(best_last_val[b].item()))
-        return paths, scores
-
-    def save(self, path: str) -> None:
+        Args:
+            payload: Serialized model payload.
+            device: Torch device for tensors.
         """
-
-
-
+        return
+
+    def save(self, path: Union[str, Path]) -> None:
+        """Serialize the model to disk.
+
+        Args:
+            path: Output path for the serialized model.
         """
+        path = str(path)
         payload = {
+            "hmm_name": getattr(self, "hmm_name", self.__class__.__name__),
+            "class": self.__class__.__name__,
             "n_states": int(self.n_states),
             "eps": float(self.eps),
-            # store dtype as a string like "torch.float64" (portable)
             "dtype": str(self.dtype),
             "start": self.start.detach().cpu(),
             "trans": self.trans.detach().cpu(),
-            "emission": self.emission.detach().cpu(),
         }
+        payload.update(self._extra_save_payload())
         torch.save(payload, path)
 
     @classmethod
-    def load(cls, path: str, device: Optional[Union[torch.device
-        """
-        Load model from `path`. If `device` is provided (str or torch.device),
-        parameters will be moved to that device; otherwise they remain on CPU.
-        Example: model = HMM.load('hmm.pt', device='cuda')
-        """
-        payload = torch.load(path, map_location="cpu")
+    def load(cls, path: Union[str, Path], device: Optional[Union[str, torch.device]] = None):
+        """Load a serialized model from disk.
 
-
-
-
+        Args:
+            path: Path to the serialized model.
+            device: Optional device specifier.
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        model
-
-        # Determine target device
-        if device is None:
-            device = torch.device("cpu")
-        elif isinstance(device, str):
-            device = torch.device(device)
+        Returns:
+            Loaded HMM instance.
+        """
+        payload = torch.load(str(path), map_location="cpu")
+        hmm_name = payload.get("hmm_name", None)
+        klass = _HMM_REGISTRY.get(hmm_name, cls)
+
+        dtype_str = str(payload.get("dtype", "torch.float64"))
+        torch_dtype = getattr(torch, dtype_str.split(".")[-1], torch.float64)
+        torch_dtype = _coerce_dtype_for_device(torch_dtype, device)  # <<< NEW
+
+        model = klass(
+            n_states=int(payload["n_states"]),
+            eps=float(payload.get("eps", 1e-8)),
+            dtype=torch_dtype,
+        )
+        dev = torch.device(device) if isinstance(device, str) else (device or torch.device("cpu"))
+        model.to(dev)
 
-        # Load params (they were saved on CPU) and cast to model dtype/device
         with torch.no_grad():
-            model.start.data = payload["start"].to(device=
-            model.trans.data = payload["trans"].to(device=
-            model.emission.data = payload["emission"].to(device=device, dtype=model.dtype)
+            model.start.data = payload["start"].to(device=dev, dtype=model.dtype)
+            model.trans.data = payload["trans"].to(device=dev, dtype=model.dtype)
 
-
+        model._load_extra_payload(payload, device=dev)
         model._normalize_params()
         return model
-
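A short round-trip sketch of the save/load path above (the checkpoint filename is arbitrary, and it assumes the @register_hmm decorator records the name stored under the payload's "hmm_name" key, as the load logic suggests):

import torch

model = SingleBernoulliHMM(n_states=2, init_emission=[0.1, 0.9])
model.save("hmm_checkpoint.pt")

# load() resolves the concrete class via the registry entry stored in the payload,
# then restores start/trans plus any subclass payload (here, the emission vector).
restored = SingleBernoulliHMM.load("hmm_checkpoint.pt", device="cpu")
print(restored.emission)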
+
+    # ------------------------- interval helpers -------------------------
+
+    @staticmethod
+    def _runs_from_bool(mask_1d: np.ndarray) -> List[Tuple[int, int]]:
+        """
+        Return runs as (start_idx, end_idx_exclusive) for True segments.
+        """
+        idx = np.nonzero(mask_1d)[0]
+        if idx.size == 0:
+            return []
+        breaks = np.where(np.diff(idx) > 1)[0]
+        starts = np.r_[idx[0], idx[breaks + 1]]
+        ends = np.r_[idx[breaks] + 1, idx[-1] + 1]
+        return list(zip(starts, ends))
+
+    @staticmethod
+    def _interval_length(coords: np.ndarray, s: int, e: int) -> int:
+        """Genomic length for [s,e) on coords."""
+        if e <= s:
+            return 0
+        return int(coords[e - 1]) - int(coords[s]) + 1
+
+    @staticmethod
+    def _write_lengths_for_binary_layer(bin_mat: np.ndarray) -> np.ndarray:
+        """
+        For each row, each True-run gets its run-length assigned across that run.
+        Output same shape as bin_mat, int32.
+        """
+        n, L = bin_mat.shape
+        out = np.zeros((n, L), dtype=np.int32)
+        for i in range(n):
+            runs = BaseHMM._runs_from_bool(bin_mat[i].astype(bool))
+            for s, e in runs:
+                out[i, s:e] = e - s
+        return out
+
+    @staticmethod
+    def _write_lengths_for_state_layer(states: np.ndarray) -> np.ndarray:
+        """
+        For each row, each constant-state run gets run-length assigned across run.
+        Missing values should be -1 and will get 0 length.
+        """
+        n, L = states.shape
+        out = np.zeros((n, L), dtype=np.int32)
+        for i in range(n):
+            row = states[i]
+            valid = row >= 0
+            if not np.any(valid):
+                continue
+            # scan runs
+            s = 0
+            while s < L:
+                if row[s] < 0:
+                    s += 1
+                    continue
+                v = row[s]
+                e = s + 1
+                while e < L and row[e] == v:
+                    e += 1
+                out[i, s:e] = e - s
+                s = e
+        return out
+
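To make the run encoding concrete, here is a standalone NumPy sketch of the same start/end computation used by _runs_from_bool and the *_lengths layers (toy values, independent of the package):

import numpy as np

mask = np.array([0, 1, 1, 0, 1, 1, 1, 0], dtype=bool)

idx = np.nonzero(mask)[0]                   # [1, 2, 4, 5, 6]
breaks = np.where(np.diff(idx) > 1)[0]      # one gap, between index 2 and 4
starts = np.r_[idx[0], idx[breaks + 1]]     # [1, 4]
ends = np.r_[idx[breaks] + 1, idx[-1] + 1]  # [3, 7]
print(list(zip(starts, ends)))              # [(1, 3), (4, 7)]

# The corresponding "_lengths" row assigns each run its length across the run:
lengths = np.zeros_like(mask, dtype=np.int32)
for s, e in zip(starts, ends):
    lengths[s:e] = e - s
print(lengths)                              # [0 2 2 0 3 3 3 0]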
+    # ------------------------- merging -------------------------
+
+    def merge_intervals_to_new_layer(
+        self,
+        adata,
+        base_layer: str,
+        *,
+        distance_threshold: int,
+        suffix: str = "_merged",
+        overwrite: bool = True,
+    ) -> str:
+        """
+        Merge adjacent 1-intervals in a binary layer if gaps <= distance_threshold (in coords space),
+        writing:
+          - {base_layer}{suffix}
+          - {base_layer}{suffix}_lengths (run-length in index units)
+        """
+        if base_layer not in adata.layers:
+            raise KeyError(f"Layer '{base_layer}' not found.")
+
+        coords, coords_are_ints = _safe_int_coords(adata.var_names)
+        arr = np.asarray(adata.layers[base_layer])
+        arr = (arr > 0).astype(np.uint8)
+
+        merged_name = f"{base_layer}{suffix}"
+        merged_len_name = f"{merged_name}_lengths"
+
+        if (merged_name in adata.layers or merged_len_name in adata.layers) and not overwrite:
+            raise KeyError(f"Merged outputs exist (use overwrite=True): {merged_name}")
+
+        n, L = arr.shape
+        out = np.zeros_like(arr, dtype=np.uint8)
+
+        dt = int(distance_threshold)
+
+        for i in range(n):
+            ones = np.nonzero(arr[i] != 0)[0]
+            runs = self._runs_from_bool(arr[i] != 0)
+            if not runs:
+                continue
+            ms, me = runs[0]
+            merged_runs = []
+            for s, e in runs[1:]:
+                if coords_are_ints:
+                    gap = int(coords[s]) - int(coords[me - 1]) - 1
+                else:
+                    gap = s - me
+                if gap <= dt:
+                    me = e
+                else:
+                    merged_runs.append((ms, me))
+                    ms, me = s, e
+            merged_runs.append((ms, me))
+
+            for s, e in merged_runs:
+                out[i, s:e] = 1
+
+        adata.layers[merged_name] = out
+        adata.layers[merged_len_name] = self._write_lengths_for_binary_layer(out)
+
+        # bookkeeping
+        key = "hmm_appended_layers"
+        if adata.uns.get(key) is None:
+            adata.uns[key] = []
+        for nm in (merged_name, merged_len_name):
+            if nm not in adata.uns[key]:
+                adata.uns[key].append(nm)
+
+        return merged_name
+
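A minimal usage sketch of the merging helper on a toy AnnData (the layer name and threshold are invented for illustration, `model` is any BaseHMM subclass instance, and integer-like var_names are assumed so that gaps are measured in coordinate space):

import numpy as np
import anndata as ad

# Toy object: 2 reads x 10 positions, with genomic coordinates as var_names.
adata = ad.AnnData(np.zeros((2, 10)))
adata.var_names = [str(100 + i) for i in range(10)]
adata.layers["hmm_footprint"] = np.array(
    [[0, 1, 1, 0, 0, 1, 1, 0, 0, 0],
     [1, 1, 0, 0, 0, 0, 0, 0, 1, 1]], dtype=np.uint8
)

merged = model.merge_intervals_to_new_layer(
    adata, "hmm_footprint", distance_threshold=2
)
print(merged)                    # "hmm_footprint_merged"
print(adata.layers[merged][0])   # the two runs in read 0 are joined (2 bp gap <= threshold)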
944
|
+
def write_size_class_layers_from_binary(
|
|
945
|
+
self,
|
|
946
|
+
adata,
|
|
947
|
+
base_layer: str,
|
|
948
|
+
*,
|
|
949
|
+
out_prefix: str,
|
|
950
|
+
feature_ranges: Dict[str, Tuple[float, float]],
|
|
951
|
+
suffix: str = "",
|
|
952
|
+
overwrite: bool = True,
|
|
953
|
+
) -> List[str]:
|
|
954
|
+
"""
|
|
955
|
+
Take an existing binary layer (runs represent features) and write size-class layers:
|
|
956
|
+
- {out_prefix}_{feature}{suffix}
|
|
957
|
+
- plus lengths layers
|
|
958
|
+
|
|
959
|
+
feature_ranges: name -> (lo, hi) in genomic bp.
|
|
960
|
+
"""
|
|
961
|
+
if base_layer not in adata.layers:
|
|
962
|
+
raise KeyError(f"Layer '{base_layer}' not found.")
|
|
963
|
+
|
|
964
|
+
coords, coords_are_ints = _safe_int_coords(adata.var_names)
|
|
965
|
+
bin_arr = (np.asarray(adata.layers[base_layer]) > 0).astype(np.uint8)
|
|
966
|
+
n, L = bin_arr.shape
|
|
967
|
+
|
|
968
|
+
created: List[str] = []
|
|
969
|
+
for feat_name in feature_ranges.keys():
|
|
970
|
+
nm = f"{out_prefix}_{feat_name}{suffix}"
|
|
971
|
+
ln = f"{nm}_lengths"
|
|
972
|
+
if (nm in adata.layers or ln in adata.layers) and not overwrite:
|
|
973
|
+
continue
|
|
974
|
+
adata.layers[nm] = np.zeros((n, L), dtype=np.uint8)
|
|
975
|
+
adata.layers[ln] = np.zeros((n, L), dtype=np.int32)
|
|
976
|
+
created.extend([nm, ln])
|
|
977
|
+
|
|
978
|
+
for i in range(n):
|
|
979
|
+
runs = self._runs_from_bool(bin_arr[i] != 0)
|
|
980
|
+
for s, e in runs:
|
|
981
|
+
length_bp = self._interval_length(coords, s, e) if coords_are_ints else (e - s)
|
|
982
|
+
for feat_name, (lo, hi) in feature_ranges.items():
|
|
983
|
+
if float(lo) <= float(length_bp) < float(hi):
|
|
984
|
+
nm = f"{out_prefix}_{feat_name}{suffix}"
|
|
985
|
+
adata.layers[nm][i, s:e] = 1
|
|
986
|
+
adata.layers[f"{nm}_lengths"][i, s:e] = e - s
|
|
987
|
+
break
|
|
988
|
+
|
|
989
|
+
# fill lengths for each size layer (consistent, even if overlaps)
|
|
990
|
+
for feat_name in feature_ranges.keys():
|
|
991
|
+
nm = f"{out_prefix}_{feat_name}{suffix}"
|
|
992
|
+
adata.layers[f"{nm}_lengths"] = self._write_lengths_for_binary_layer(
|
|
993
|
+
np.asarray(adata.layers[nm])
|
|
994
|
+
)
|
|
995
|
+
|
|
996
|
+
key = "hmm_appended_layers"
|
|
997
|
+
if adata.uns.get(key) is None:
|
|
998
|
+
adata.uns[key] = []
|
|
999
|
+
for nm in created:
|
|
1000
|
+
if nm not in adata.uns[key]:
|
|
1001
|
+
adata.uns[key].append(nm)
|
|
1002
|
+
|
|
1003
|
+
return created
|
|
1004
|
+
|
|
1005
|
+
# ------------------------- AnnData annotation -------------------------
|
|
1006
|
+
|
|
1007
|
+
@staticmethod
|
|
1008
|
+
def _resolve_pos_mask_for_methbase(subset, ref: str, methbase: str) -> Optional[np.ndarray]:
|
|
1009
|
+
"""
|
|
1010
|
+
Local helper to resolve per-base masks from subset.var.* columns.
|
|
1011
|
+
Returns a boolean np.ndarray of length subset.n_vars or None.
|
|
1012
|
+
"""
|
|
1013
|
+
key = str(methbase).strip().lower()
|
|
1014
|
+
var = subset.var
|
|
1015
|
+
|
|
1016
|
+
def _has(col: str) -> bool:
|
|
1017
|
+
"""Return True when a column exists on subset.var."""
|
|
1018
|
+
return col in var.columns
|
|
1019
|
+
|
|
1020
|
+
if key in ("a",):
|
|
1021
|
+
col = f"{ref}_strand_FASTA_base"
|
|
1022
|
+
if not _has(col):
|
|
1023
|
+
return None
|
|
1024
|
+
return np.asarray(var[col] == "A")
|
|
1025
|
+
|
|
1026
|
+
if key in ("c", "any_c", "anyc", "any-c"):
|
|
1027
|
+
for col in (f"{ref}_any_C_site", f"{ref}_C_site"):
|
|
1028
|
+
if _has(col):
|
|
1029
|
+
return np.asarray(var[col])
|
|
1030
|
+
return None
|
|
1031
|
+
|
|
1032
|
+
if key in ("gpc", "gpc_site", "gpc-site"):
|
|
1033
|
+
col = f"{ref}_GpC_site"
|
|
1034
|
+
if not _has(col):
|
|
1035
|
+
return None
|
|
1036
|
+
return np.asarray(var[col])
|
|
1037
|
+
|
|
1038
|
+
if key in ("cpg", "cpg_site", "cpg-site"):
|
|
1039
|
+
col = f"{ref}_CpG_site"
|
|
1040
|
+
if not _has(col):
|
|
1041
|
+
return None
|
|
1042
|
+
return np.asarray(var[col])
|
|
1043
|
+
|
|
1044
|
+
alt = f"{ref}_{methbase}_site"
|
|
1045
|
+
if not _has(alt):
|
|
1046
|
+
return None
|
|
1047
|
+
return np.asarray(var[alt])
|
|
1048
|
+
|
|
758
1049
|
def annotate_adata(
|
|
759
1050
|
self,
|
|
760
1051
|
adata,
|
|
761
|
-
|
|
762
|
-
|
|
763
|
-
|
|
764
|
-
|
|
765
|
-
|
|
766
|
-
|
|
767
|
-
|
|
768
|
-
|
|
769
|
-
|
|
770
|
-
|
|
771
|
-
|
|
772
|
-
|
|
1052
|
+
*,
|
|
1053
|
+
prefix: str,
|
|
1054
|
+
X: np.ndarray,
|
|
1055
|
+
coords: np.ndarray,
|
|
1056
|
+
var_mask: np.ndarray,
|
|
1057
|
+
span_fill: bool = True,
|
|
1058
|
+
config=None,
|
|
1059
|
+
decode: str = "marginal",
|
|
1060
|
+
write_posterior: bool = True,
|
|
1061
|
+
posterior_state: str = "Modified",
|
|
1062
|
+
feature_sets: Optional[Dict[str, Dict[str, Any]]] = None,
|
|
1063
|
+
prob_threshold: float = 0.5,
|
|
773
1064
|
uns_key: str = "hmm_appended_layers",
|
|
774
|
-
config: Optional[Union[dict, "ExperimentConfig"]] = None, # NEW: config/dict accepted
|
|
775
1065
|
uns_flag: str = "hmm_annotated",
|
|
776
|
-
force_redo: bool = False
|
|
1066
|
+
force_redo: bool = False,
|
|
1067
|
+
device: Optional[Union[str, torch.device]] = None,
|
|
1068
|
+
**kwargs,
|
|
777
1069
|
):
|
|
1070
|
+
"""Decode and annotate an AnnData object with HMM-derived layers.
|
|
1071
|
+
|
|
1072
|
+
Args:
|
|
1073
|
+
adata: AnnData to annotate.
|
|
1074
|
+
prefix: Prefix for newly written layers.
|
|
1075
|
+
X: Observations array for decoding.
|
|
1076
|
+
coords: Coordinate array aligned to X.
|
|
1077
|
+
var_mask: Boolean mask for positions in adata.var.
|
|
1078
|
+
span_fill: Whether to fill missing spans.
|
|
1079
|
+
config: Optional config for naming and state selection.
|
|
1080
|
+
decode: Decode method ("marginal" or "viterbi").
|
|
1081
|
+
write_posterior: Whether to write posterior probabilities.
|
|
1082
|
+
posterior_state: State label to write posterior for.
|
|
1083
|
+
feature_sets: Optional feature set definition for size classes.
|
|
1084
|
+
prob_threshold: Posterior probability threshold for binary calls.
|
|
1085
|
+
uns_key: .uns key to track appended layers.
|
|
1086
|
+
uns_flag: .uns flag to mark annotations.
|
|
1087
|
+
force_redo: Whether to overwrite existing layers.
|
|
1088
|
+
device: Device specifier.
|
|
1089
|
+
**kwargs: Additional parameters for specialized workflows.
|
|
1090
|
+
|
|
1091
|
+
Returns:
|
|
1092
|
+
List of created layer names or None if skipped.
|
|
778
1093
|
"""
|
|
779
|
-
|
|
780
|
-
|
|
781
|
-
|
|
782
|
-
|
|
783
|
-
|
|
784
|
-
|
|
785
|
-
|
|
786
|
-
|
|
787
|
-
|
|
788
|
-
|
|
789
|
-
|
|
790
|
-
|
|
791
|
-
|
|
792
|
-
|
|
793
|
-
|
|
794
|
-
|
|
795
|
-
|
|
796
|
-
|
|
797
|
-
|
|
798
|
-
|
|
799
|
-
|
|
800
|
-
|
|
801
|
-
|
|
802
|
-
|
|
803
|
-
|
|
804
|
-
#
|
|
805
|
-
|
|
806
|
-
|
|
807
|
-
|
|
808
|
-
|
|
809
|
-
|
|
810
|
-
|
|
811
|
-
|
|
812
|
-
|
|
813
|
-
|
|
814
|
-
|
|
815
|
-
|
|
816
|
-
|
|
817
|
-
|
|
818
|
-
|
|
819
|
-
|
|
820
|
-
|
|
821
|
-
|
|
822
|
-
|
|
823
|
-
|
|
824
|
-
|
|
825
|
-
|
|
826
|
-
|
|
827
|
-
|
|
828
|
-
|
|
829
|
-
|
|
830
|
-
|
|
831
|
-
|
|
832
|
-
|
|
833
|
-
if
|
|
834
|
-
|
|
835
|
-
|
|
836
|
-
|
|
837
|
-
|
|
838
|
-
|
|
839
|
-
|
|
840
|
-
|
|
841
|
-
|
|
842
|
-
|
|
843
|
-
|
|
844
|
-
|
|
845
|
-
|
|
846
|
-
|
|
847
|
-
|
|
848
|
-
|
|
849
|
-
|
|
850
|
-
|
|
851
|
-
|
|
852
|
-
|
|
853
|
-
|
|
854
|
-
|
|
855
|
-
|
|
856
|
-
|
|
857
|
-
|
|
858
|
-
|
|
859
|
-
|
|
860
|
-
|
|
861
|
-
|
|
862
|
-
|
|
1094
|
+
# skip logic
|
|
1095
|
+
if bool(adata.uns.get(uns_flag, False)) and not force_redo:
|
|
1096
|
+
return None
|
|
1097
|
+
|
|
1098
|
+
if adata.uns.get(uns_key) is None:
|
|
1099
|
+
adata.uns[uns_key] = []
|
|
1100
|
+
appended = list(adata.uns.get(uns_key, [])) if adata.uns.get(uns_key) is not None else []
|
|
1101
|
+
|
|
1102
|
+
X = np.asarray(X, dtype=float)
|
|
1103
|
+
coords = np.asarray(coords, dtype=int)
|
|
1104
|
+
var_mask = np.asarray(var_mask, dtype=bool)
|
|
1105
|
+
if var_mask.shape[0] != adata.n_vars:
|
|
1106
|
+
raise ValueError(f"var_mask length {var_mask.shape[0]} != adata.n_vars {adata.n_vars}")
|
|
1107
|
+
|
|
1108
|
+
# decode
|
|
1109
|
+
states, gamma = self.decode(
|
|
1110
|
+
X, coords, decode=decode, device=device
|
|
1111
|
+
) # states (N,L), gamma (N,L,K)
|
|
1112
|
+
N, L = states.shape
|
|
1113
|
+
if N != adata.n_obs:
|
|
1114
|
+
raise ValueError(f"X has N={N} rows but adata.n_obs={adata.n_obs}")
|
|
1115
|
+
|
|
1116
|
+
# map coords -> full-var indices for span_fill
|
|
1117
|
+
full_coords, full_int = _safe_int_coords(adata.var_names)
|
|
1118
|
+
|
|
1119
|
+
# ---- write posterior + states on masked columns only ----
|
|
1120
|
+
masked_idx = np.nonzero(var_mask)[0]
|
|
1121
|
+
masked_coords, _ = _safe_int_coords(adata.var_names[var_mask])
|
|
1122
|
+
|
|
1123
|
+
# build mapping from coords order -> masked column order
|
|
1124
|
+
coord_to_pos_in_decoded = {int(c): i for i, c in enumerate(coords.tolist())}
|
|
1125
|
+
take = np.array(
|
|
1126
|
+
[coord_to_pos_in_decoded.get(int(c), -1) for c in masked_coords.tolist()], dtype=int
|
|
1127
|
+
)
|
|
1128
|
+
good = take >= 0
|
|
1129
|
+
masked_idx = masked_idx[good]
|
|
1130
|
+
take = take[good]
|
|
1131
|
+
|
|
1132
|
+
# states layer
|
|
1133
|
+
states_name = f"{prefix}_states"
|
|
1134
|
+
if states_name not in adata.layers:
|
|
1135
|
+
adata.layers[states_name] = np.full((adata.n_obs, adata.n_vars), -1, dtype=np.int8)
|
|
1136
|
+
adata.layers[states_name][:, masked_idx] = states[:, take].astype(np.int8)
|
|
1137
|
+
if states_name not in appended:
|
|
1138
|
+
appended.append(states_name)
|
|
1139
|
+
|
|
1140
|
+
# posterior layer (requested state)
|
|
1141
|
+
if write_posterior:
|
|
1142
|
+
t_idx = self.resolve_target_state_index(posterior_state)
|
|
1143
|
+
post = gamma[:, :, t_idx].astype(np.float32)
|
|
1144
|
+
post_name = f"{prefix}_posterior_{str(posterior_state).strip().lower().replace(' ', '_').replace('-', '_')}"
|
|
1145
|
+
if post_name not in adata.layers:
|
|
1146
|
+
adata.layers[post_name] = np.zeros((adata.n_obs, adata.n_vars), dtype=np.float32)
|
|
1147
|
+
adata.layers[post_name][:, masked_idx] = post[:, take]
|
|
1148
|
+
if post_name not in appended:
|
|
1149
|
+
appended.append(post_name)
|
|
1150
|
+
|
|
1151
|
+
# ---- feature layers ----
|
|
1152
|
+
if feature_sets is None:
|
|
1153
|
+
cfgd = self._cfg_to_dict(config)
|
|
1154
|
+
feature_sets = normalize_hmm_feature_sets(cfgd.get("hmm_feature_sets", None))
|
|
1155
|
+
|
|
1156
|
+
if not feature_sets:
|
|
1157
|
+
adata.uns[uns_key] = appended
|
|
1158
|
+
adata.uns[uns_flag] = True
|
|
1159
|
+
return None
|
|
1160
|
+
|
|
1161
|
+
# allocate outputs
|
|
1162
|
+
for group, fs in feature_sets.items():
|
|
1163
|
+
fmap = fs.get("features", {}) or {}
|
|
1164
|
+
if not fmap:
|
|
1165
|
+
continue
|
|
1166
|
+
|
|
1167
|
+
all_layer = f"{prefix}_all_{group}_features"
|
|
1168
|
+
if all_layer not in adata.layers:
|
|
1169
|
+
adata.layers[all_layer] = np.zeros((adata.n_obs, adata.n_vars), dtype=np.uint8)
|
|
1170
|
+
if f"{all_layer}_lengths" not in adata.layers:
|
|
1171
|
+
adata.layers[f"{all_layer}_lengths"] = np.zeros(
|
|
1172
|
+
(adata.n_obs, adata.n_vars), dtype=np.int32
|
|
1173
|
+
)
|
|
1174
|
+
for nm in (all_layer, f"{all_layer}_lengths"):
|
|
1175
|
+
if nm not in appended:
|
|
1176
|
+
appended.append(nm)
|
|
1177
|
+
|
|
1178
|
+
for feat in fmap.keys():
|
|
1179
|
+
nm = f"{prefix}_{feat}"
|
|
1180
|
+
if nm not in adata.layers:
|
|
1181
|
+
adata.layers[nm] = np.zeros(
|
|
1182
|
+
(adata.n_obs, adata.n_vars),
|
|
1183
|
+
dtype=np.int32 if nm.endswith("_lengths") else np.uint8,
|
|
1184
|
+
)
|
|
1185
|
+
if f"{nm}_lengths" not in adata.layers:
|
|
1186
|
+
adata.layers[f"{nm}_lengths"] = np.zeros(
|
|
1187
|
+
(adata.n_obs, adata.n_vars), dtype=np.int32
|
|
1188
|
+
)
|
|
1189
|
+
for outnm in (nm, f"{nm}_lengths"):
|
|
1190
|
+
if outnm not in appended:
|
|
1191
|
+
appended.append(outnm)
|
|
1192
|
+
|
|
1193
|
+
# classify runs per row
|
|
1194
|
+
target_idx = self.resolve_target_state_index(fs.get("state", "Modified"))
|
|
1195
|
+
membership = (
|
|
1196
|
+
(states == target_idx)
|
|
1197
|
+
if str(decode).lower() == "viterbi"
|
|
1198
|
+
else (gamma[:, :, target_idx] >= float(prob_threshold))
|
|
1199
|
+
)
|
|
1200
|
+
|
|
1201
|
+
for i in range(N):
|
|
1202
|
+
runs = self._runs_from_bool(membership[i].astype(bool))
|
|
1203
|
+
for s, e in runs:
|
|
1204
|
+
# genomic length in coords space
|
|
1205
|
+
glen = int(coords[e - 1]) - int(coords[s]) + 1 if e > s else 0
|
|
1206
|
+
if glen <= 0:
|
|
1207
|
+
continue
|
|
1208
|
+
|
|
1209
|
+
# pick feature bin
|
|
1210
|
+
chosen = None
|
|
1211
|
+
for feat_name, (lo, hi) in fmap.items():
|
|
1212
|
+
if float(lo) <= float(glen) < float(hi):
|
|
1213
|
+
chosen = feat_name
|
|
1214
|
+
break
|
|
1215
|
+
if chosen is None:
|
|
863
1216
|
continue
|
|
864
|
-
|
|
865
|
-
|
|
866
|
-
|
|
867
|
-
|
|
868
|
-
|
|
869
|
-
|
|
1217
|
+
|
|
1218
|
+
# convert span to indices in full var grid
|
|
1219
|
+
if span_fill and full_int:
|
|
1220
|
+
left = int(np.searchsorted(full_coords, int(coords[s]), side="left"))
|
|
1221
|
+
right = int(np.searchsorted(full_coords, int(coords[e - 1]), side="right"))
|
|
1222
|
+
if left >= right:
|
|
1223
|
+
continue
|
|
1224
|
+
adata.layers[f"{prefix}_{chosen}"][i, left:right] = 1
|
|
1225
|
+
adata.layers[f"{prefix}_all_{group}_features"][i, left:right] = 1
|
|
870
1226
|
else:
|
|
871
|
-
|
|
872
|
-
|
|
873
|
-
|
|
874
|
-
|
|
875
|
-
|
|
876
|
-
|
|
877
|
-
|
|
878
|
-
|
|
879
|
-
|
|
880
|
-
|
|
881
|
-
|
|
882
|
-
|
|
883
|
-
|
|
884
|
-
|
|
885
|
-
|
|
886
|
-
|
|
887
|
-
|
|
888
|
-
|
|
889
|
-
|
|
890
|
-
|
|
891
|
-
|
|
892
|
-
|
|
893
|
-
|
|
894
|
-
|
|
895
|
-
|
|
896
|
-
|
|
897
|
-
|
|
898
|
-
|
|
899
|
-
|
|
900
|
-
|
|
901
|
-
|
|
902
|
-
|
|
903
|
-
|
|
904
|
-
|
|
905
|
-
|
|
906
|
-
|
|
907
|
-
|
|
908
|
-
|
|
909
|
-
|
|
910
|
-
|
|
911
|
-
|
|
912
|
-
|
|
913
|
-
|
|
914
|
-
|
|
915
|
-
|
|
916
|
-
|
|
917
|
-
|
|
918
|
-
|
|
919
|
-
|
|
920
|
-
|
|
921
|
-
|
|
922
|
-
|
|
923
|
-
|
|
924
|
-
|
|
925
|
-
|
|
926
|
-
|
|
927
|
-
|
|
928
|
-
|
|
929
|
-
|
|
930
|
-
|
|
1227
|
+
# only fill at masked indices
|
|
1228
|
+
cols = masked_idx[
|
|
1229
|
+
(masked_coords >= coords[s]) & (masked_coords <= coords[e - 1])
|
|
1230
|
+
]
|
|
1231
|
+
if cols.size == 0:
|
|
1232
|
+
continue
|
|
1233
|
+
adata.layers[f"{prefix}_{chosen}"][i, cols] = 1
|
|
1234
|
+
adata.layers[f"{prefix}_all_{group}_features"][i, cols] = 1
|
|
1235
|
+
|
|
1236
|
+
# lengths derived from binary
|
|
1237
|
+
adata.layers[f"{prefix}_all_{group}_features_lengths"] = (
|
|
1238
|
+
self._write_lengths_for_binary_layer(
|
|
1239
|
+
np.asarray(adata.layers[f"{prefix}_all_{group}_features"])
|
|
1240
|
+
)
|
|
1241
|
+
)
|
|
1242
|
+
for feat in fmap.keys():
|
|
1243
|
+
nm = f"{prefix}_{feat}"
|
|
1244
|
+
adata.layers[f"{nm}_lengths"] = self._write_lengths_for_binary_layer(
|
|
1245
|
+
np.asarray(adata.layers[nm])
|
|
1246
|
+
)
|
|
1247
|
+
|
|
1248
|
+
adata.uns[uns_key] = appended
|
|
1249
|
+
adata.uns[uns_flag] = True
|
|
1250
|
+
return None
|
|
1251
|
+
|
|
1252
|
+
# ------------------------- row-copy helper (workflow uses it) -------------------------
|
|
1253
|
+
|
|
1254
|
+
def _ensure_final_layer_and_assign(
|
|
1255
|
+
self, final_adata, layer_name: str, subset_idx_mask: np.ndarray, sub_data
|
|
1256
|
+
):
|
|
1257
|
+
"""
|
|
1258
|
+
Assign rows from sub_data into final_adata.layers[layer_name] for rows where subset_idx_mask is True.
|
|
1259
|
+
Handles dense arrays. If you want sparse support, add it here.
|
|
1260
|
+
"""
|
|
1261
|
+
n_final_obs, n_vars = final_adata.shape
|
|
1262
|
+
final_rows = np.nonzero(np.asarray(subset_idx_mask).astype(bool))[0]
|
|
1263
|
+
sub_arr = np.asarray(sub_data)
|
|
1264
|
+
|
|
1265
|
+
if layer_name not in final_adata.layers:
|
|
1266
|
+
final_adata.layers[layer_name] = np.zeros((n_final_obs, n_vars), dtype=sub_arr.dtype)
|
|
1267
|
+
|
|
1268
|
+
final_arr = np.asarray(final_adata.layers[layer_name])
|
|
1269
|
+
if sub_arr.shape[0] != final_rows.size:
|
|
1270
|
+
raise ValueError(f"Sub rows {sub_arr.shape[0]} != mask sum {final_rows.size}")
|
|
1271
|
+
final_arr[final_rows, :] = sub_arr
|
|
1272
|
+
final_adata.layers[layer_name] = final_arr
|
|
1273
|
+
|
|
1274
|
+
|
|
1275
|
+
# =============================================================================
|
|
1276
|
+
# Single-channel Bernoulli HMM
|
|
1277
|
+
# =============================================================================
|
|
1278
|
+
|
|
1279
|
+
|
|
1280
|
+
@register_hmm("single")
|
|
1281
|
+
class SingleBernoulliHMM(BaseHMM):
|
|
1282
|
+
"""
|
|
1283
|
+
Bernoulli emission per state:
|
|
1284
|
+
emission[k] = P(obs==1 | state=k)
|
|
1285
|
+
"""
|
|
1286
|
+
|
|
1287
|
+
def __init__(
|
|
1288
|
+
self,
|
|
1289
|
+
n_states: int = 2,
|
|
1290
|
+
init_emission: Optional[Sequence[float]] = None,
|
|
1291
|
+
eps: float = 1e-8,
|
|
1292
|
+
dtype: torch.dtype = torch.float64,
|
|
1293
|
+
):
|
|
1294
|
+
"""Initialize a single-channel Bernoulli HMM.
|
|
1295
|
+
|
|
1296
|
+
Args:
|
|
1297
|
+
n_states: Number of hidden states.
|
|
1298
|
+
init_emission: Initial emission probabilities per state.
|
|
1299
|
+
eps: Smoothing epsilon for probabilities.
|
|
1300
|
+
dtype: Torch dtype for parameters.
|
|
1301
|
+
"""
|
|
1302
|
+
super().__init__(n_states=n_states, eps=eps, dtype=dtype)
|
|
1303
|
+
if init_emission is None:
|
|
1304
|
+
em = np.full((self.n_states,), 0.5, dtype=float)
|
|
1305
|
+
else:
|
|
1306
|
+
em = np.asarray(init_emission, dtype=float).reshape(-1)[: self.n_states]
|
|
1307
|
+
if em.size != self.n_states:
|
|
1308
|
+
em = np.full((self.n_states,), 0.5, dtype=float)
|
|
1309
|
+
|
|
1310
|
+
self.emission = nn.Parameter(torch.tensor(em, dtype=self.dtype), requires_grad=False)
|
|
1311
|
+
self._normalize_emission()
|
|
1312
|
+
|
|
1313
|
+
@classmethod
|
|
1314
|
+
def from_config(cls, cfg, *, override=None, device=None):
|
|
1315
|
+
"""Create a single-channel Bernoulli HMM from config.
|
|
1316
|
+
|
|
1317
|
+
Args:
|
|
1318
|
+
cfg: Configuration mapping or object.
|
|
1319
|
+
override: Override values to apply.
|
|
1320
|
+
device: Optional device specifier.
|
|
1321
|
+
|
|
1322
|
+
Returns:
|
|
1323
|
+
Initialized SingleBernoulliHMM instance.
|
|
1324
|
+
"""
|
|
1325
|
+
merged = cls._cfg_to_dict(cfg)
|
|
1326
|
+
if override:
|
|
1327
|
+
merged.update(override)
|
|
1328
|
+
n_states = int(merged.get("hmm_n_states", 2))
|
|
1329
|
+
eps = float(merged.get("hmm_eps", 1e-8))
|
|
1330
|
+
dtype = _resolve_dtype(merged.get("hmm_dtype", None))
|
|
1331
|
+
dtype = _coerce_dtype_for_device(dtype, device) # <<< NEW
|
|
1332
|
+
init_em = merged.get("hmm_init_emission_probs", merged.get("hmm_init_emission", None))
|
|
1333
|
+
model = cls(n_states=n_states, init_emission=init_em, eps=eps, dtype=dtype)
|
|
1334
|
+
if device is not None:
|
|
1335
|
+
model.to(torch.device(device) if isinstance(device, str) else device)
|
|
1336
|
+
model._persisted_cfg = merged
|
|
1337
|
+
return model
|
|
1338
|
+
|
|
1339
|
+
def _normalize_emission(self):
|
|
1340
|
+
"""Normalize and clamp emission probabilities in-place."""
|
|
1341
|
+
with torch.no_grad():
|
|
1342
|
+
self.emission.data = self.emission.data.reshape(-1)
|
|
1343
|
+
if self.emission.data.numel() != self.n_states:
|
|
1344
|
+
self.emission.data = torch.full(
|
|
1345
|
+
(self.n_states,), 0.5, dtype=self.dtype, device=self.emission.device
|
|
1346
|
+
)
|
|
1347
|
+
self.emission.data = self.emission.data.clamp(min=self.eps, max=1.0 - self.eps)
|
|
1348
|
+
|
|
1349
|
+
def _ensure_device_dtype(self, device=None) -> torch.device:
|
|
1350
|
+
"""Move emission parameters to the requested device/dtype."""
|
|
1351
|
+
device = super()._ensure_device_dtype(device)
|
|
1352
|
+
self.emission.data = self.emission.data.to(device=device, dtype=self.dtype)
|
|
1353
|
+
return device
|
|
1354
|
+
|
|
1355
|
+
def _state_modified_score(self) -> torch.Tensor:
|
|
1356
|
+
"""Return per-state modified scores for ranking."""
|
|
1357
|
+
return self.emission.detach()
|
|
1358
|
+
|
|
1359
|
+
def _log_emission(self, obs: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
|
|
1360
|
+
"""
|
|
1361
|
+
obs: (N,L), mask: (N,L) -> logB: (N,L,K)
|
|
1362
|
+
"""
|
|
1363
|
+
p = self.emission # (K,)
|
|
1364
|
+
logp = torch.log(p + self.eps)
|
|
1365
|
+
log1mp = torch.log1p(-p + self.eps)
|
|
1366
|
+
|
|
1367
|
+
o = obs.unsqueeze(-1) # (N,L,1)
|
|
1368
|
+
logB = o * logp.view(1, 1, -1) + (1.0 - o) * log1mp.view(1, 1, -1)
|
|
1369
|
+
logB = torch.where(mask.unsqueeze(-1), logB, torch.zeros_like(logB))
|
|
1370
|
+
return logB
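The per-position term is just the Bernoulli log-likelihood, with masked (missing) positions contributing zero. An equivalent standalone NumPy sketch with toy values:

import numpy as np

eps = 1e-8
p = np.array([0.1, 0.9])            # P(obs == 1 | state k), one entry per state
obs = np.array([[1.0, 0.0, 1.0]])   # (N=1, L=3)
mask = np.array([[True, True, False]])  # last position treated as missing

logp, log1mp = np.log(p + eps), np.log1p(-p + eps)
o = obs[..., None]                       # (N, L, 1)
logB = o * logp + (1.0 - o) * log1mp     # broadcasts to (N, L, K)
logB = np.where(mask[..., None], logB, 0.0)
print(logB.shape)                        # (1, 3, 2)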
|
|
1371
|
+
|
|
1372
|
+
def _extra_save_payload(self) -> dict:
|
|
1373
|
+
"""Return extra payload data for serialization."""
|
|
1374
|
+
return {"emission": self.emission.detach().cpu()}
|
|
1375
|
+
|
|
1376
|
+
def _load_extra_payload(self, payload: dict, *, device: torch.device):
|
|
1377
|
+
"""Load serialized emission parameters.
|
|
1378
|
+
|
|
1379
|
+
Args:
|
|
1380
|
+
payload: Serialized payload dictionary.
|
|
1381
|
+
device: Target torch device.
|
|
1382
|
+
"""
|
|
1383
|
+
with torch.no_grad():
|
|
1384
|
+
self.emission.data = payload["emission"].to(device=device, dtype=self.dtype)
|
|
1385
|
+
self._normalize_emission()
|
|
1386
|
+
|
|
1387
|
+
def fit_em(
|
|
1388
|
+
self,
|
|
1389
|
+
X: np.ndarray,
|
|
1390
|
+
coords: np.ndarray,
|
|
1391
|
+
*,
|
|
1392
|
+
device: torch.device,
|
|
1393
|
+
max_iter: int,
|
|
1394
|
+
tol: float,
|
|
1395
|
+
update_start: bool,
|
|
1396
|
+
update_trans: bool,
|
|
1397
|
+
update_emission: bool,
|
|
1398
|
+
verbose: bool,
|
|
1399
|
+
**kwargs,
|
|
1400
|
+
) -> List[float]:
|
|
1401
|
+
"""Run EM updates for a single-channel Bernoulli HMM.
|
|
1402
|
+
|
|
1403
|
+
Args:
|
|
1404
|
+
X: Observations array (N, L).
|
|
1405
|
+
coords: Coordinate array aligned to X.
|
|
1406
|
+
device: Torch device.
|
|
1407
|
+
max_iter: Maximum iterations.
|
|
1408
|
+
tol: Convergence tolerance.
|
|
1409
|
+
update_start: Whether to update start probabilities.
|
|
1410
|
+
update_trans: Whether to update transitions.
|
|
1411
|
+
update_emission: Whether to update emission parameters.
|
|
1412
|
+
verbose: Whether to log progress.
|
|
1413
|
+
**kwargs: Additional implementation-specific kwargs.
|
|
1414
|
+
|
|
1415
|
+
Returns:
|
|
1416
|
+
List of log-likelihood proxy values.
|
|
1417
|
+
"""
|
|
1418
|
+
X = np.asarray(X, dtype=float)
|
|
1419
|
+
if X.ndim != 2:
|
|
1420
|
+
raise ValueError("SingleBernoulliHMM expects X shape (N,L).")
|
|
1421
|
+
obs = torch.tensor(np.nan_to_num(X, nan=0.0), dtype=self.dtype, device=device)
|
|
1422
|
+
mask = torch.tensor(~np.isnan(X), dtype=torch.bool, device=device)
|
|
1423
|
+
|
|
1424
|
+
eps = float(self.eps)
|
|
1425
|
+
K = self.n_states
|
|
1426
|
+
N, L = obs.shape
|
|
1427
|
+
|
|
1428
|
+
hist: List[float] = []
|
|
1429
|
+
for it in range(1, int(max_iter) + 1):
|
|
1430
|
+
gamma = self._forward_backward(obs, mask) # (N,L,K)
|
|
1431
|
+
|
|
1432
|
+
# log-likelihood proxy
|
|
1433
|
+
ll_proxy = float(torch.sum(torch.log(torch.clamp(gamma.sum(dim=2), min=eps))).item())
|
|
1434
|
+
hist.append(ll_proxy)
|
|
1435
|
+
|
|
1436
|
+
# expected start
|
|
1437
|
+
start_acc = gamma[:, 0, :].sum(dim=0) # (K,)
|
|
1438
|
+
|
|
1439
|
+
# expected transitions xi
|
|
1440
|
+
logB = self._log_emission(obs, mask)
|
|
1441
|
+
logA = torch.log(self.trans + eps)
|
|
1442
|
+
alpha = torch.empty((N, L, K), dtype=self.dtype, device=device)
|
|
1443
|
+
alpha[:, 0, :] = torch.log(self.start + eps).unsqueeze(0) + logB[:, 0, :]
|
|
1444
|
+
for t in range(1, L):
|
|
1445
|
+
prev = alpha[:, t - 1, :].unsqueeze(2) + logA.unsqueeze(0)
|
|
1446
|
+
alpha[:, t, :] = _logsumexp(prev, dim=1) + logB[:, t, :]
|
|
1447
|
+
beta = torch.empty((N, L, K), dtype=self.dtype, device=device)
|
|
1448
|
+
beta[:, L - 1, :] = 0.0
|
|
1449
|
+
for t in range(L - 2, -1, -1):
|
|
1450
|
+
temp = logA.unsqueeze(0) + (logB[:, t + 1, :] + beta[:, t + 1, :]).unsqueeze(1)
|
|
1451
|
+
beta[:, t, :] = _logsumexp(temp, dim=2)
|
|
1452
|
+
|
|
1453
|
+
trans_acc = torch.zeros((K, K), dtype=self.dtype, device=device)
|
|
1454
|
+
for t in range(L - 1):
|
|
1455
|
+
valid_t = (mask[:, t] & mask[:, t + 1]).float().view(N, 1, 1)
|
|
1456
|
+
log_xi = (
|
|
1457
|
+
alpha[:, t, :].unsqueeze(2)
|
|
1458
|
+
+ logA.unsqueeze(0)
|
|
1459
|
+
+ (logB[:, t + 1, :] + beta[:, t + 1, :]).unsqueeze(1)
|
|
1460
|
+
)
|
|
1461
|
+
log_norm = _logsumexp(log_xi.view(N, -1), dim=1).view(N, 1, 1)
|
|
1462
|
+
xi = (log_xi - log_norm).exp() * valid_t
|
|
1463
|
+
trans_acc += xi.sum(dim=0)
|
|
1464
|
+
|
|
1465
|
+
# emission update
|
|
1466
|
+
mask_f = mask.float().unsqueeze(-1) # (N,L,1)
|
|
1467
|
+
emit_num = (gamma * obs.unsqueeze(-1) * mask_f).sum(dim=(0, 1)) # (K,)
|
|
1468
|
+
emit_den = (gamma * mask_f).sum(dim=(0, 1)) # (K,)
|
|
1469
|
+
|
|
1470
|
+
with torch.no_grad():
|
|
1471
|
+
if update_start:
|
|
1472
|
+
new_start = start_acc + eps
|
|
1473
|
+
self.start.data = new_start / new_start.sum()
|
|
1474
|
+
|
|
1475
|
+
if update_trans:
|
|
1476
|
+
new_trans = trans_acc + eps
|
|
1477
|
+
rs = new_trans.sum(dim=1, keepdim=True)
|
|
1478
|
+
rs[rs == 0.0] = 1.0
|
|
1479
|
+
self.trans.data = new_trans / rs
|
|
1480
|
+
|
|
1481
|
+
if update_emission:
|
|
1482
|
+
new_em = (emit_num + eps) / (emit_den + 2.0 * eps)
|
|
1483
|
+
self.emission.data = new_em.clamp(min=eps, max=1.0 - eps)
|
|
1484
|
+
|
|
1485
|
+
self._normalize_params()
|
|
1486
|
+
self._normalize_emission()
|
|
931
1487
|
|
|
932
|
-
if not feature_sets:
|
|
933
1488
|
if verbose:
|
|
934
|
-
|
|
935
|
-
|
|
936
|
-
|
|
937
|
-
|
|
938
|
-
|
|
939
|
-
|
|
940
|
-
|
|
941
|
-
|
|
942
|
-
|
|
943
|
-
|
|
944
|
-
|
|
945
|
-
|
|
946
|
-
|
|
947
|
-
|
|
948
|
-
|
|
949
|
-
|
|
950
|
-
|
|
951
|
-
|
|
952
|
-
|
|
953
|
-
|
|
954
|
-
|
|
955
|
-
|
|
956
|
-
|
|
957
|
-
|
|
958
|
-
|
|
959
|
-
|
|
960
|
-
|
|
961
|
-
|
|
962
|
-
|
|
963
|
-
|
|
964
|
-
|
|
965
|
-
|
|
966
|
-
|
|
967
|
-
|
|
968
|
-
|
|
969
|
-
|
|
970
|
-
|
|
971
|
-
|
|
972
|
-
|
|
973
|
-
|
|
974
|
-
|
|
975
|
-
|
|
976
|
-
device
|
|
977
|
-
|
|
1489
|
+
logger.info(
|
|
1490
|
+
"[SingleBernoulliHMM.fit] iter=%s ll_proxy=%.6f",
|
|
1491
|
+
it,
|
|
1492
|
+
hist[-1],
|
|
1493
|
+
)
|
|
1494
|
+
|
|
1495
|
+
if len(hist) > 1 and abs(hist[-1] - hist[-2]) < float(tol):
|
|
1496
|
+
break
|
|
1497
|
+
|
|
1498
|
+
return hist
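Putting the loop above to work on synthetic data (a standalone sketch with made-up sizes; fit_em is called directly here for illustration, whereas fit() normally handles device placement first):

import numpy as np
import torch

rng = np.random.default_rng(0)
X = rng.integers(0, 2, size=(40, 300)).astype(float)
X[rng.random(X.shape) < 0.05] = np.nan   # sprinkle in missing observations
coords = np.arange(300)

model = SingleBernoulliHMM(n_states=2, init_emission=[0.2, 0.8])
history = model.fit_em(
    X,
    coords,
    device=torch.device("cpu"),
    max_iter=10,
    tol=1e-4,
    update_start=True,
    update_trans=True,
    update_emission=True,
    verbose=False,
)
print(history[-1], model.emission)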
|
|
1499
|
+
|
|
1500
|
+
def adapt_emissions(
|
|
1501
|
+
self,
|
|
1502
|
+
X: np.ndarray,
|
|
1503
|
+
coords: Optional[np.ndarray] = None,
|
|
1504
|
+
*,
|
|
1505
|
+
device: Optional[Union[str, torch.device]] = None,
|
|
1506
|
+
iters: Optional[int] = None,
|
|
1507
|
+
max_iter: Optional[int] = None, # alias for your trainer
|
|
1508
|
+
verbose: bool = False,
|
|
1509
|
+
**kwargs,
|
|
1510
|
+
):
|
|
1511
|
+
"""Adapt emissions with legacy parameter names.
|
|
1512
|
+
|
|
1513
|
+
Args:
|
|
1514
|
+
X: Observations array.
|
|
1515
|
+
coords: Optional coordinate array.
|
|
1516
|
+
device: Device specifier.
|
|
1517
|
+
iters: Number of iterations.
|
|
1518
|
+
max_iter: Alias for iters.
|
|
1519
|
+
verbose: Whether to log progress.
|
|
1520
|
+
**kwargs: Additional kwargs forwarded to BaseHMM.adapt_emissions.
|
|
1521
|
+
|
|
1522
|
+
Returns:
|
|
1523
|
+
List of log-likelihood values.
|
|
1524
|
+
"""
|
|
1525
|
+
if iters is None:
|
|
1526
|
+
iters = int(max_iter) if max_iter is not None else int(kwargs.pop("iters", 5))
|
|
1527
|
+
return super().adapt_emissions(
|
|
1528
|
+
np.asarray(X, dtype=float),
|
|
1529
|
+
coords if coords is not None else None,
|
|
1530
|
+
iters=int(iters),
|
|
1531
|
+
device=device,
|
|
1532
|
+
verbose=verbose,
|
|
1533
|
+
)
|
|
1534
|
+
|
|
1535
|
+
|
|
1536
|
+
# =============================================================================
|
|
1537
|
+
# Multi-channel Bernoulli HMM (union coordinate grid)
|
|
1538
|
+
# =============================================================================
|
|
1539
|
+
|
|
1540
|
+
|
|
1541
|
+
@register_hmm("multi")
|
|
1542
|
+
class MultiBernoulliHMM(BaseHMM):
|
|
1543
|
+
"""
|
|
1544
|
+
Multi-channel independent Bernoulli:
|
|
1545
|
+
emission[k,c] = P(obs_c==1 | state=k)
|
|
1546
|
+
X must be (N,L,C) on a union coordinate grid; NaN per-channel allowed.
|
|
1547
|
+
"""
|
|
1548
|
+
|
|
1549
|
+
def __init__(
|
|
1550
|
+
self,
|
|
1551
|
+
n_states: int = 2,
|
|
1552
|
+
n_channels: int = 2,
|
|
1553
|
+
init_emission: Optional[Any] = None,
|
|
1554
|
+
eps: float = 1e-8,
|
|
1555
|
+
dtype: torch.dtype = torch.float64,
|
|
1556
|
+
):
|
|
1557
|
+
"""Initialize a multi-channel Bernoulli HMM.
|
|
1558
|
+
|
|
1559
|
+
Args:
|
|
1560
|
+
n_states: Number of hidden states.
|
|
1561
|
+
n_channels: Number of observed channels.
|
|
1562
|
+
init_emission: Initial emission probabilities.
|
|
1563
|
+
eps: Smoothing epsilon for probabilities.
|
|
1564
|
+
dtype: Torch dtype for parameters.
|
|
1565
|
+
"""
|
|
1566
|
+
super().__init__(n_states=n_states, eps=eps, dtype=dtype)
|
|
1567
|
+
self.n_channels = int(n_channels)
|
|
1568
|
+
if self.n_channels < 1:
|
|
1569
|
+
raise ValueError("n_channels must be >=1")
|
|
978
1570
|
|
|
979
|
-
|
|
980
|
-
|
|
981
|
-
|
|
1571
|
+
if init_emission is None:
|
|
1572
|
+
em = np.full((self.n_states, self.n_channels), 0.5, dtype=float)
|
|
1573
|
+
else:
|
|
1574
|
+
arr = np.asarray(init_emission, dtype=float)
|
|
982
1575
|
if arr.ndim == 1:
|
|
983
|
-
arr = arr
|
|
984
|
-
|
|
985
|
-
# squeeze trailing singletons
|
|
986
|
-
while arr.ndim > 2 and arr.shape[-1] == 1:
|
|
987
|
-
arr = _np.squeeze(arr, axis=-1)
|
|
988
|
-
if arr.ndim != 2:
|
|
989
|
-
raise ValueError(f"Expected 2D sequence matrix; got array with shape {arr.shape}")
|
|
990
|
-
return arr
|
|
991
|
-
|
|
992
|
-
def calculate_batch_distances(intervals_list, threshold_local=0.9):
|
|
993
|
-
results_local = []
|
|
994
|
-
for intervals in intervals_list:
|
|
995
|
-
if not isinstance(intervals, list) or len(intervals) == 0:
|
|
996
|
-
results_local.append([])
|
|
997
|
-
continue
|
|
998
|
-
valid = [iv for iv in intervals if iv[2] > threshold_local]
|
|
999
|
-
if len(valid) <= 1:
|
|
1000
|
-
results_local.append([])
|
|
1001
|
-
continue
|
|
1002
|
-
valid = sorted(valid, key=lambda x: x[0])
|
|
1003
|
-
dists = [(valid[i + 1][0] - (valid[i][0] + valid[i][1])) for i in range(len(valid) - 1)]
|
|
1004
|
-
results_local.append(dists)
|
|
1005
|
-
return results_local
|
|
1006
|
-
|
|
1007
|
-
def classify_batch_local(predicted_states_batch, probabilities_batch, coordinates, classification_mapping, target_state="Modified"):
|
|
1008
|
-
# Accept numpy arrays or torch tensors
|
|
1009
|
-
if isinstance(predicted_states_batch, _torch.Tensor):
|
|
1010
|
-
pred_np = predicted_states_batch.detach().cpu().numpy()
|
|
1576
|
+
arr = arr.reshape(-1, 1)
|
|
1577
|
+
em = np.repeat(arr[: self.n_states, :], self.n_channels, axis=1)
|
|
1011
1578
|
else:
|
|
1012
|
-
|
|
1013
|
-
if
|
|
1014
|
-
|
|
1015
|
-
else:
|
|
1016
|
-
probs_np = _np.asarray(probabilities_batch)
|
|
1017
|
-
|
|
1018
|
-
batch_size, L = pred_np.shape
|
|
1019
|
-
all_classifications_local = []
|
|
1020
|
-
# allow caller to pass arbitrary state labels mapping; default two-state mapping:
|
|
1021
|
-
state_labels = ["Non-Modified", "Modified"]
|
|
1022
|
-
try:
|
|
1023
|
-
target_idx = state_labels.index(target_state)
|
|
1024
|
-
except ValueError:
|
|
1025
|
-
target_idx = 1 # fallback
|
|
1026
|
-
|
|
1027
|
-
for b in range(batch_size):
|
|
1028
|
-
predicted_states = pred_np[b]
|
|
1029
|
-
probabilities = probs_np[b]
|
|
1030
|
-
regions = []
|
|
1031
|
-
current_start, current_length, current_probs = None, 0, []
|
|
1032
|
-
for i, state_index in enumerate(predicted_states):
|
|
1033
|
-
state_prob = float(probabilities[i][state_index])
|
|
1034
|
-
if state_index == target_idx:
|
|
1035
|
-
if current_start is None:
|
|
1036
|
-
current_start = i
|
|
1037
|
-
current_length += 1
|
|
1038
|
-
current_probs.append(state_prob)
|
|
1039
|
-
elif current_start is not None:
|
|
1040
|
-
regions.append((current_start, current_length, float(_np.mean(current_probs))))
|
|
1041
|
-
current_start, current_length, current_probs = None, 0, []
|
|
1042
|
-
if current_start is not None:
|
|
1043
|
-
regions.append((current_start, current_length, float(_np.mean(current_probs))))
|
|
1044
|
-
|
|
1045
|
-
final = []
|
|
1046
|
-
for start, length, prob in regions:
|
|
1047
|
-
# compute genomic length try/catch
|
|
1048
|
-
try:
|
|
1049
|
-
feature_length = int(coordinates[start + length - 1]) - int(coordinates[start]) + 1
|
|
1050
|
-
except Exception:
|
|
1051
|
-
feature_length = int(length)
|
|
1052
|
-
|
|
1053
|
-
# classification_mapping values are (lo, hi) tuples or lists
|
|
1054
|
-
label = None
|
|
1055
|
-
for ftype, rng in classification_mapping.items():
|
|
1056
|
-
lo, hi = rng[0], rng[1]
|
|
1057
|
-
try:
|
|
1058
|
-
if lo <= feature_length < hi:
|
|
1059
|
-
label = ftype
|
|
1060
|
-
break
|
|
1061
|
-
except Exception:
|
|
1062
|
-
continue
|
|
1063
|
-
if label is None:
|
|
1064
|
-
# fallback to first mapping key or 'unknown'
|
|
1065
|
-
label = next(iter(classification_mapping.keys()), "feature")
|
|
1066
|
-
|
|
1067
|
-
# Store reported start coordinate in same coordinate system as `coordinates`.
|
|
1068
|
-
try:
|
|
1069
|
-
genomic_start = int(coordinates[start])
|
|
1070
|
-
except Exception:
|
|
1071
|
-
genomic_start = int(start)
|
|
1072
|
-
final.append((genomic_start, feature_length, label, prob))
|
|
1073
|
-
all_classifications_local.append(final)
|
|
1074
|
-
return all_classifications_local
|
|
1075
|
-
|
|
1076
|
-
# -----------------------------------------------------------------------
|
|
1077
|
-
|
|
1078
|
-
# Ensure obs_column is categorical-like for iteration
|
|
1079
|
-
sseries = adata.obs[obs_column]
|
|
1080
|
-
if not pd.api.types.is_categorical_dtype(sseries):
|
|
1081
|
-
sseries = sseries.astype("category")
|
|
1082
|
-
references = list(sseries.cat.categories)
|
|
1083
|
-
|
|
1084
|
-
ref_iter = references if not verbose else _tqdm(references, desc="Processing References")
|
|
1085
|
-
for ref in ref_iter:
|
|
1086
|
-
# subset reads with this obs_column value
|
|
1087
|
-
ref_mask = adata.obs[obs_column] == ref
|
|
1088
|
-
ref_subset = adata[ref_mask].copy()
|
|
1089
|
-
combined_mask = None
|
|
1090
|
-
|
|
1091
|
-
# per-methbase processing
|
|
1092
|
-
for methbase in methbases:
|
|
1093
|
-
key_lower = methbase.strip().lower()
|
|
1094
|
-
|
|
1095
|
-
# map several common synonyms -> canonical lookup
|
|
1096
|
-
if key_lower in ("a",):
|
|
1097
|
-
pos_mask = ref_subset.var.get(f"{ref}_strand_FASTA_base") == "A"
|
|
1098
|
-
elif key_lower in ("c", "any_c", "anyc", "any-c"):
|
|
1099
|
-
# unify 'C' or 'any_C' names to the any_C var column
|
|
1100
|
-
pos_mask = ref_subset.var.get(f"{ref}_any_C_site") == True
|
|
1101
|
-
elif key_lower in ("gpc", "gpc_site", "gpc-site"):
|
|
1102
|
-
pos_mask = ref_subset.var.get(f"{ref}_GpC_site") == True
|
|
1103
|
-
elif key_lower in ("cpg", "cpg_site", "cpg-site"):
|
|
1104
|
-
pos_mask = ref_subset.var.get(f"{ref}_CpG_site") == True
|
|
1105
|
-
else:
|
|
1106
|
-
# try a best-effort: if a column named f"{ref}_{methbase}_site" exists, use it
|
|
1107
|
-
alt_col = f"{ref}_{methbase}_site"
|
|
1108
|
-
pos_mask = ref_subset.var.get(alt_col, None)
|
|
1579
|
+
em = arr[: self.n_states, : self.n_channels]
|
|
1580
|
+
if em.shape != (self.n_states, self.n_channels):
|
|
1581
|
+
em = np.full((self.n_states, self.n_channels), 0.5, dtype=float)
|
|
1109
1582
|
|
|
1110
|
-
|
|
1111
|
-
|
|
1112
|
-
combined_mask = pos_mask if combined_mask is None else (combined_mask | pos_mask)
|
|
1583
|
+
self.emission = nn.Parameter(torch.tensor(em, dtype=self.dtype), requires_grad=False)
|
|
1584
|
+
self._normalize_emission()
|
|
1113
1585
|
|
|
1114
|
-
|
|
1115
|
-
|
|
1586
|
+
@classmethod
|
|
1587
|
+
def from_config(cls, cfg, *, override=None, device=None):
|
|
1588
|
+
"""Create a multi-channel Bernoulli HMM from config.
|
|
1116
1589
|
|
|
1117
|
-
|
|
1118
|
-
|
|
1119
|
-
|
|
1120
|
-
|
|
1121
|
-
n_reads = matrix.shape[0]
|
|
1590
|
+
Args:
|
|
1591
|
+
cfg: Configuration mapping or object.
|
|
1592
|
+
override: Override values to apply.
|
|
1593
|
+
device: Optional device specifier.
|
|
1122
1594
|
|
|
1123
|
-
|
|
1124
|
-
|
|
1125
|
-
|
|
1126
|
-
|
|
1127
|
-
|
|
1128
|
-
|
|
1129
|
-
|
|
1130
|
-
|
|
1131
|
-
|
|
1132
|
-
|
|
1133
|
-
|
|
1134
|
-
|
|
1135
|
-
|
|
1136
|
-
|
|
1137
|
-
|
|
1138
|
-
|
|
1139
|
-
|
|
1140
|
-
|
|
1141
|
-
|
|
1142
|
-
if use_viterbi:
|
|
1143
|
-
paths, _scores = self.batch_viterbi(seqs, impute_strategy="ignore", device=device)
|
|
1144
|
-
pred_states = _np.asarray(paths)
|
|
1145
|
-
else:
|
|
1146
|
-
pred_states = _np.argmax(probs_batch, axis=2)
|
|
1595
|
+
Returns:
|
|
1596
|
+
Initialized MultiBernoulliHMM instance.
|
|
1597
|
+
"""
|
|
1598
|
+
merged = cls._cfg_to_dict(cfg)
|
|
1599
|
+
if override:
|
|
1600
|
+
merged.update(override)
|
|
1601
|
+
n_states = int(merged.get("hmm_n_states", 2))
|
|
1602
|
+
eps = float(merged.get("hmm_eps", 1e-8))
|
|
1603
|
+
dtype = _resolve_dtype(merged.get("hmm_dtype", None))
|
|
1604
|
+
dtype = _coerce_dtype_for_device(dtype, device) # <<< NEW
|
|
1605
|
+
n_channels = int(merged.get("hmm_n_channels", merged.get("n_channels", 2)))
|
|
1606
|
+
init_em = merged.get("hmm_init_emission_probs", None)
|
|
1607
|
+
model = cls(
|
|
1608
|
+
n_states=n_states, n_channels=n_channels, init_emission=init_em, eps=eps, dtype=dtype
|
|
1609
|
+
)
|
|
1610
|
+
if device is not None:
|
|
1611
|
+
model.to(torch.device(device) if isinstance(device, str) else device)
|
|
1612
|
+
model._persisted_cfg = merged
|
|
1613
|
+
return model
|
|
1147
1614
|
|
|
1148
|
-
|
|
1149
|
-
|
|
1150
|
-
|
|
1151
|
-
|
|
1152
|
-
|
|
1153
|
-
|
|
1154
|
-
|
|
1155
|
-
|
|
1156
|
-
|
|
1157
|
-
|
|
1158
|
-
|
|
1159
|
-
|
|
1160
|
-
|
|
1161
|
-
|
|
1162
|
-
|
|
1163
|
-
|
|
1164
|
-
|
|
1165
|
-
|
|
1166
|
-
|
|
1167
|
-
|
|
1168
|
-
|
|
1169
|
-
|
|
1170
|
-
|
|
1171
|
-
|
|
1172
|
-
|
|
1173
|
-
|
|
1174
|
-
|
|
1175
|
-
|
|
1176
|
-
|
|
1177
|
-
|
|
1178
|
-
if verbose:
|
|
1179
|
-
chunk_iter = _tqdm(list(chunk_iter), desc=f"{ref}:Combined chunks")
|
|
1180
|
-
for start_idx in chunk_iter:
|
|
1181
|
-
stop_idx = min(n_reads_comb, start_idx + batch_size)
|
|
1182
|
-
chunk = matrix[start_idx:stop_idx]
|
|
1183
|
-
seqs = chunk.tolist()
|
|
1184
|
-
gammas = self.predict(seqs, impute_strategy="ignore", device=device)
|
|
1185
|
-
if len(gammas) == 0:
|
|
1186
|
-
continue
|
|
1187
|
-
probs_batch = _np.stack(gammas, axis=0)
|
|
1188
|
-
if use_viterbi:
|
|
1189
|
-
paths, _scores = self.batch_viterbi(seqs, impute_strategy="ignore", device=device)
|
|
1190
|
-
pred_states = _np.asarray(paths)
|
|
1191
|
-
else:
|
|
1192
|
-
pred_states = _np.argmax(probs_batch, axis=2)
|
|
1193
|
-
|
|
1194
|
-
for key, fs in feature_sets.items():
|
|
1195
|
-
if key == "cpg":
|
|
1196
|
-
continue
|
|
1197
|
-
state_target = fs.get("state", "Modified")
|
|
1198
|
-
feature_map = fs.get("features", {})
|
|
1199
|
-
classifications = classify_batch_local(pred_states, probs_batch, coords_comb, feature_map, target_state=state_target)
|
|
1200
|
-
row_indices = list(comb.obs.index[start_idx:stop_idx])
|
|
1201
|
-
for i_local, idx in enumerate(row_indices):
|
|
1202
|
-
for start, length, label, prob in classifications[i_local]:
|
|
1203
|
-
adata.obs.at[idx, f"{combined_prefix}_{label}"].append([start, length, prob])
|
|
1204
|
-
adata.obs.at[idx, f"{combined_prefix}_all_{key}_features"].append([start, length, prob])
|
|
1205
|
-
|
|
1206
|
-
# CpG special handling
|
|
1207
|
-
if "cpg" in feature_sets and feature_sets.get("cpg") is not None:
|
|
1208
|
-
cpg_iter = references if not verbose else _tqdm(references, desc="Processing CpG")
|
|
1209
|
-
for ref in cpg_iter:
|
|
1210
|
-
ref_mask = adata.obs[obs_column] == ref
|
|
1211
|
-
ref_subset = adata[ref_mask].copy()
|
|
1212
|
-
pos_mask = ref_subset.var[f"{ref}_CpG_site"] == True
|
|
1213
|
-
if pos_mask.sum() == 0:
|
|
1214
|
-
continue
|
|
1215
|
-
cpg_sub = ref_subset[:, pos_mask]
|
|
1216
|
-
matrix = cpg_sub.layers[layer] if (layer and layer in cpg_sub.layers) else cpg_sub.X
|
|
1217
|
-
matrix = _ensure_2d_array_like(matrix)
|
|
1218
|
-
n_reads = matrix.shape[0]
|
|
1219
|
-
try:
|
|
1220
|
-
coords_cpg = _np.asarray(cpg_sub.var_names, dtype=int)
|
|
1221
|
-
except Exception:
|
|
1222
|
-
coords_cpg = _np.arange(cpg_sub.shape[1], dtype=int)
|
|
1223
|
-
|
|
1224
|
-
chunk_iter = range(0, n_reads, batch_size)
|
|
1225
|
-
if verbose:
|
|
1226
|
-
chunk_iter = _tqdm(list(chunk_iter), desc=f"{ref}:CpG chunks")
|
|
1227
|
-
for start_idx in chunk_iter:
|
|
1228
|
-
stop_idx = min(n_reads, start_idx + batch_size)
|
|
1229
|
-
chunk = matrix[start_idx:stop_idx]
|
|
1230
|
-
seqs = chunk.tolist()
|
|
1231
|
-
gammas = self.predict(seqs, impute_strategy="ignore", device=device)
|
|
1232
|
-
if len(gammas) == 0:
|
|
1233
|
-
continue
|
|
1234
|
-
probs_batch = _np.stack(gammas, axis=0)
|
|
1235
|
-
if use_viterbi:
|
|
1236
|
-
paths, _scores = self.batch_viterbi(seqs, impute_strategy="ignore", device=device)
|
|
1237
|
-
pred_states = _np.asarray(paths)
|
|
1238
|
-
else:
|
|
1239
|
-
pred_states = _np.argmax(probs_batch, axis=2)
|
|
1240
|
-
|
|
1241
|
-
fs = feature_sets["cpg"]
|
|
1242
|
-
state_target = fs.get("state", "Modified")
|
|
1243
|
-
feature_map = fs.get("features", {})
|
|
1244
|
-
classifications = classify_batch_local(pred_states, probs_batch, coords_cpg, feature_map, target_state=state_target)
|
|
1245
|
-
row_indices = list(cpg_sub.obs.index[start_idx:stop_idx])
|
|
1246
|
-
for i_local, idx in enumerate(row_indices):
|
|
1247
|
-
for start, length, label, prob in classifications[i_local]:
|
|
1248
|
-
adata.obs.at[idx, f"CpG_{label}"].append([start, length, prob])
|
|
1249
|
-
adata.obs.at[idx, f"CpG_all_cpg_features"].append([start, length, prob])
|
|
1250
|
-
|
|
1251
|
-
# finalize: convert intervals into binary layers and distances
|
|
1252
|
-
try:
|
|
1253
|
-
coordinates = _np.asarray(adata.var_names, dtype=int)
|
|
1254
|
-
coords_are_ints = True
|
|
1255
|
-
except Exception:
|
|
1256
|
-
coordinates = _np.arange(adata.shape[1], dtype=int)
|
|
1257
|
-
coords_are_ints = False
|
|
1258
|
-
|
|
1259
|
-
features_iter = all_features if not verbose else _tqdm(all_features, desc="Finalizing Layers")
|
|
1260
|
-
for feature in features_iter:
|
|
1261
|
-
bin_matrix = _np.zeros((adata.shape[0], adata.shape[1]), dtype=int)
|
|
1262
|
-
counts = _np.zeros(adata.shape[0], dtype=int)
|
|
1263
|
-
|
|
1264
|
-
# new: integer-length layer (0 where not inside a feature)
|
|
1265
|
-
len_matrix = _np.zeros((adata.shape[0], adata.shape[1]), dtype=int)
|
|
1266
|
-
|
|
1267
|
-
for row_idx, intervals in enumerate(adata.obs[feature]):
|
|
1268
|
-
if not isinstance(intervals, list):
|
|
1269
|
-
intervals = []
|
|
1270
|
-
for start, length, prob in intervals:
|
|
1271
|
-
if prob > threshold:
|
|
1272
|
-
if coords_are_ints:
|
|
1273
|
-
# map genomic start/length into index interval [start_idx, end_idx)
|
|
1274
|
-
start_idx = _np.searchsorted(coordinates, int(start), side="left")
|
|
1275
|
-
end_idx = _np.searchsorted(coordinates, int(start) + int(length) - 1, side="right")
|
|
1276
|
-
else:
|
|
1277
|
-
start_idx = int(start)
|
|
1278
|
-
end_idx = start_idx + int(length)
|
|
1279
|
-
|
|
1280
|
-
start_idx = max(0, min(start_idx, adata.shape[1]))
|
|
1281
|
-
end_idx = max(0, min(end_idx, adata.shape[1]))
|
|
1282
|
-
|
|
1283
|
-
if start_idx < end_idx:
|
|
1284
|
-
span = end_idx - start_idx # number of positions covered
|
|
1285
|
-
# set binary mask
|
|
1286
|
-
bin_matrix[row_idx, start_idx:end_idx] = 1
|
|
1287
|
-
# set length mask: use maximum in case of overlaps
|
|
1288
|
-
existing = len_matrix[row_idx, start_idx:end_idx]
|
|
1289
|
-
len_matrix[row_idx, start_idx:end_idx] = _np.maximum(existing, span)
|
|
1290
|
-
counts[row_idx] += 1
|
|
1291
|
-
|
|
1292
|
-
# write binary layer and length layer, track appended names
|
|
1293
|
-
adata.layers[feature] = bin_matrix
|
|
1294
|
-
appended_layers.append(feature)
|
|
1295
|
-
|
|
1296
|
-
# name the integer-length layer (choose suffix you like)
|
|
1297
|
-
length_layer_name = f"{feature}_lengths"
|
|
1298
|
-
adata.layers[length_layer_name] = len_matrix
|
|
1299
|
-
appended_layers.append(length_layer_name)
|
|
1300
|
-
|
|
1301
|
-
adata.obs[f"n_{feature}"] = counts
|
|
1302
|
-
adata.obs[f"{feature}_distances"] = calculate_batch_distances(adata.obs[feature].tolist(), threshold)
|
|
1303
|
-
|
|
1304
|
-
# Merge appended_layers into adata.uns[uns_key] (preserve pre-existing and avoid duplicates)
|
|
1305
|
-
existing = list(adata.uns.get(uns_key, [])) if adata.uns.get(uns_key) is not None else []
|
|
1306
|
-
new_list = existing + [l for l in appended_layers if l not in existing]
|
|
1307
|
-
adata.uns[uns_key] = new_list
|
|
1308
|
-
|
|
1309
|
-
# Mark that the annotation has been completed
|
|
1310
|
-
adata.uns[uns_flag] = True
+    def _normalize_emission(self):
+        """Normalize and clamp emission probabilities in-place."""
+        with torch.no_grad():
+            self.emission.data = self.emission.data.reshape(self.n_states, self.n_channels)
+            self.emission.data = self.emission.data.clamp(min=self.eps, max=1.0 - self.eps)
+
+    def _ensure_device_dtype(self, device=None) -> torch.device:
+        """Move emission parameters to the requested device/dtype."""
+        device = super()._ensure_device_dtype(device)
+        self.emission.data = self.emission.data.to(device=device, dtype=self.dtype)
+        return device
+
+    def _state_modified_score(self) -> torch.Tensor:
+        """Return per-state modified scores for ranking."""
+        # more “modified” = higher mean P(1) across channels
+        return self.emission.detach().mean(dim=1)
+
+    def _log_emission(self, obs: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
+        """
+        obs: (N,L,C), mask: (N,L,C) -> logB: (N,L,K)
+        """
+        N, L, C = obs.shape
+        K = self.n_states
+
+        p = self.emission  # (K,C)
+        logp = torch.log(p + self.eps).view(1, 1, K, C)
+        log1mp = torch.log1p(-p + self.eps).view(1, 1, K, C)
+
+        o = obs.unsqueeze(2)   # (N,L,1,C)
+        m = mask.unsqueeze(2)  # (N,L,1,C)

-
+        logBC = o * logp + (1.0 - o) * log1mp
+        logBC = torch.where(m, logBC, torch.zeros_like(logBC))
+        return logBC.sum(dim=3)  # sum channels -> (N,L,K)
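Note (not part of the diff): per state k and position t, `_log_emission` sums `obs * log p[k, c] + (1 - obs) * log(1 - p[k, c])` over observed channels only; masked channels contribute zero, so fully missing positions yield a flat emission term. A minimal NumPy sketch of the same computation, with made-up emission values:

    import numpy as np

    p = np.array([[0.1, 0.2], [0.8, 0.9]])  # hypothetical emissions: K=2 states, C=2 channels
    obs = np.array([1.0, 0.0])              # one position, two channels
    mask = np.array([True, False])          # second channel unobserved

    logB = np.where(mask, obs * np.log(p) + (1 - obs) * np.log1p(-p), 0.0).sum(axis=-1)
    # logB[k] is the per-position log-likelihood under state k, observed channels only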

-def
+    def _extra_save_payload(self) -> dict:
+        """Return extra payload data for serialization."""
+        return {"n_channels": int(self.n_channels), "emission": self.emission.detach().cpu()}
+
+    def _load_extra_payload(self, payload: dict, *, device: torch.device):
+        """Load serialized emission parameters.
+
+        Args:
+            payload: Serialized payload dictionary.
+            device: Target torch device.
+        """
+        self.n_channels = int(payload.get("n_channels", self.n_channels))
+        with torch.no_grad():
+            self.emission.data = payload["emission"].to(device=device, dtype=self.dtype)
+        self._normalize_emission()
+
+    def fit_em(
         self,
-[old lines 1316-1324 removed; content not rendered in this diff view]
+        X: np.ndarray,
+        coords: np.ndarray,
+        *,
+        device: torch.device,
+        max_iter: int,
+        tol: float,
+        update_start: bool,
+        update_trans: bool,
+        update_emission: bool,
+        verbose: bool,
+        **kwargs,
+    ) -> List[float]:
+        """Run EM updates for a multi-channel Bernoulli HMM.
+
+        Args:
+            X: Observations array (N, L, C).
+            coords: Coordinate array aligned to X.
+            device: Torch device.
+            max_iter: Maximum iterations.
+            tol: Convergence tolerance.
+            update_start: Whether to update start probabilities.
+            update_trans: Whether to update transitions.
+            update_emission: Whether to update emission parameters.
+            verbose: Whether to log progress.
+            **kwargs: Additional implementation-specific kwargs.
+
+        Returns:
+            List of log-likelihood proxy values.
+        """
+        X = np.asarray(X, dtype=float)
+        if X.ndim != 3:
+            raise ValueError("MultiBernoulliHMM expects X shape (N,L,C).")
+
+        obs = torch.tensor(np.nan_to_num(X, nan=0.0), dtype=self.dtype, device=device)
+        mask = torch.tensor(~np.isnan(X), dtype=torch.bool, device=device)
+
+        eps = float(self.eps)
+        K = self.n_states
+        N, L, C = obs.shape
+
+        self._ensure_n_channels(C, device)
+
+        hist: List[float] = []
+        for it in range(1, int(max_iter) + 1):
+            gamma = self._forward_backward(obs, mask)  # (N,L,K)
+
+            ll_proxy = float(torch.sum(torch.log(torch.clamp(gamma.sum(dim=2), min=eps))).item())
+            hist.append(ll_proxy)
+
+            # expected start
+            start_acc = gamma[:, 0, :].sum(dim=0)  # (K,)
+
+            # transitions xi
+            logB = self._log_emission(obs, mask)
+            logA = torch.log(self.trans + eps)
+            alpha = torch.empty((N, L, K), dtype=self.dtype, device=device)
+            alpha[:, 0, :] = torch.log(self.start + eps).unsqueeze(0) + logB[:, 0, :]
+            for t in range(1, L):
+                prev = alpha[:, t - 1, :].unsqueeze(2) + logA.unsqueeze(0)
+                alpha[:, t, :] = _logsumexp(prev, dim=1) + logB[:, t, :]
+            beta = torch.empty((N, L, K), dtype=self.dtype, device=device)
+            beta[:, L - 1, :] = 0.0
+            for t in range(L - 2, -1, -1):
+                temp = logA.unsqueeze(0) + (logB[:, t + 1, :] + beta[:, t + 1, :]).unsqueeze(1)
+                beta[:, t, :] = _logsumexp(temp, dim=2)
+
+            trans_acc = torch.zeros((K, K), dtype=self.dtype, device=device)
+            # valid timestep if at least one channel observed at both positions
+            valid_pos = mask.any(dim=2)  # (N,L)
+            for t in range(L - 1):
+                valid_t = (valid_pos[:, t] & valid_pos[:, t + 1]).float().view(N, 1, 1)
+                log_xi = (
+                    alpha[:, t, :].unsqueeze(2)
+                    + logA.unsqueeze(0)
+                    + (logB[:, t + 1, :] + beta[:, t + 1, :]).unsqueeze(1)
+                )
+                log_norm = _logsumexp(log_xi.view(N, -1), dim=1).view(N, 1, 1)
+                xi = (log_xi - log_norm).exp() * valid_t
+                trans_acc += xi.sum(dim=0)
+
+            # emission update per channel
+            gamma_k = gamma.unsqueeze(-1)       # (N,L,K,1)
+            obs_c = obs.unsqueeze(2)            # (N,L,1,C)
+            mask_c = mask.unsqueeze(2).float()  # (N,L,1,C)
+
+            emit_num = (gamma_k * obs_c * mask_c).sum(dim=(0, 1))  # (K,C)
+            emit_den = (gamma_k * mask_c).sum(dim=(0, 1))          # (K,C)
+
+            with torch.no_grad():
+                if update_start:
+                    new_start = start_acc + eps
+                    self.start.data = new_start / new_start.sum()
+
+                if update_trans:
+                    new_trans = trans_acc + eps
+                    rs = new_trans.sum(dim=1, keepdim=True)
+                    rs[rs == 0.0] = 1.0
+                    self.trans.data = new_trans / rs
+
+                if update_emission:
+                    new_em = (emit_num + eps) / (emit_den + 2.0 * eps)
+                    self.emission.data = new_em.clamp(min=eps, max=1.0 - eps)
+
+            self._normalize_params()
+            self._normalize_emission()
+
+            if verbose:
+                logger.info(
+                    "[MultiBernoulliHMM.fit] iter=%s ll_proxy=%.6f",
+                    it,
+                    hist[-1],
+                )
+
+            if len(hist) > 1 and abs(hist[-1] - hist[-2]) < float(tol):
+                break
+
+        return hist
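Note (not part of the diff): the M-step for the emissions above is the mask-weighted posterior mean of the observations, p_hat[k, c] = sum(gamma * obs * mask) / sum(gamma * mask), smoothed by eps. A NumPy sketch of that update, written as a standalone helper for illustration:

    import numpy as np

    def m_step_emission(gamma, obs, mask, eps=1e-8):
        # gamma: (N, L, K) posteriors; obs, mask: (N, L, C)
        num = np.einsum("ntk,ntc->kc", gamma, obs * mask)          # expected count of 1s per state/channel
        den = np.einsum("ntk,ntc->kc", gamma, mask.astype(float))  # expected number of observations
        return np.clip((num + eps) / (den + 2.0 * eps), eps, 1.0 - eps)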
+
+    def adapt_emissions(
+        self,
+        X: np.ndarray,
+        coords: Optional[np.ndarray] = None,
+        *,
+        device: Optional[Union[str, torch.device]] = None,
+        iters: Optional[int] = None,
+        max_iter: Optional[int] = None,  # alias for your trainer
         verbose: bool = False,
+        **kwargs,
     ):
+        """Adapt emissions with legacy parameter names.
+
+        Args:
+            X: Observations array.
+            coords: Optional coordinate array.
+            device: Device specifier.
+            iters: Number of iterations.
+            max_iter: Alias for iters.
+            verbose: Whether to log progress.
+            **kwargs: Additional kwargs forwarded to BaseHMM.adapt_emissions.
+
+        Returns:
+            List of log-likelihood values.
         """
-[old lines 1328-1336 removed; content not rendered in this diff view]
-Merge intervals whose gap <= this threshold (genomic coords if adata.var_names are ints).
-merged_suffix : str
-Suffix appended to original layer for the merged binary layer (default "_merged").
-length_layer_suffix : str
-Suffix appended after merged suffix for the lengths layer (default "_lengths").
-update_obs : bool
-If True, create/update adata.obs[f"{layer}{merged_suffix}"] with merged intervals.
-prob_strategy : str
-How to combine probs when merging ('mean', 'max', 'orig_first').
-inplace : bool
-If False, returns a new AnnData with changes (original untouched).
-overwrite : bool
-If True, will overwrite existing merged layers / obs entries; otherwise will error if they exist.
-"""
-import numpy as _np
-from scipy.sparse import issparse
+        if iters is None:
+            iters = int(max_iter) if max_iter is not None else int(kwargs.pop("iters", 5))
+        return super().adapt_emissions(
+            np.asarray(X, dtype=float),
+            coords if coords is not None else None,
+            iters=int(iters),
+            device=device,
+            verbose=verbose,
+        )
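Note (not part of the diff): `adapt_emissions` accepts both the new `iters` keyword and the legacy `max_iter` alias, so the two calls below are equivalent; the `hmm` instance and the `X`/`coords` arrays are assumed to exist already:

    hmm.adapt_emissions(X, coords, iters=10, device="cpu", verbose=True)
    hmm.adapt_emissions(X, coords, max_iter=10, device="cpu")  # legacy alias, same effect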

-
-
+    def _ensure_n_channels(self, C: int, device: torch.device):
+        """Expand emission parameters when channel count changes.

-
-
+        Args:
+            C: Target channel count.
+            device: Torch device for the new parameters.
+        """
+        C = int(C)
+        if C == self.n_channels:
+            return
+        with torch.no_grad():
+            old = self.emission.detach().cpu().numpy()  # (K, Cold)
+            K = old.shape[0]
+            new = np.full((K, C), 0.5, dtype=float)
+            m = min(old.shape[1], C)
+            new[:, :m] = old[:, :m]
+            if C > old.shape[1]:
+                fill = old.mean(axis=1, keepdims=True)
+                new[:, m:] = fill
+            self.n_channels = C
+            self.emission = nn.Parameter(
+                torch.tensor(new, dtype=self.dtype, device=device), requires_grad=False
+            )
+        self._normalize_emission()
+
+
+# =============================================================================
+# Distance-binned transitions (single-channel only)
+# =============================================================================
+
+
+@register_hmm("single_distance_binned")
+class DistanceBinnedSingleBernoulliHMM(SingleBernoulliHMM):
+    """
+    Transition matrix depends on binned distances between consecutive coords.

-
-
-
+    Config keys:
+      hmm_distance_bins: list[int] edges (bp)
+      hmm_init_transitions_by_bin: optional (n_bins,K,K)
+    """

-
-
+    def __init__(
+        self,
+        n_states: int = 2,
+        init_emission: Optional[Sequence[float]] = None,
+        distance_bins: Optional[Sequence[int]] = None,
+        init_trans_by_bin: Optional[Any] = None,
+        eps: float = 1e-8,
+        dtype: torch.dtype = torch.float64,
+    ):
+        """Initialize a distance-binned transition HMM.
+
+        Args:
+            n_states: Number of hidden states.
+            init_emission: Initial emission probabilities per state.
+            distance_bins: Distance bin edges in base pairs.
+            init_trans_by_bin: Initial transition matrices per bin.
+            eps: Smoothing epsilon for probabilities.
+            dtype: Torch dtype for parameters.
+        """
+        super().__init__(n_states=n_states, init_emission=init_emission, eps=eps, dtype=dtype)
+
+        self.distance_bins = np.asarray(
+            distance_bins if distance_bins is not None else [1, 5, 10, 25, 50, 100], dtype=int
+        )
+        self.n_bins = int(len(self.distance_bins) + 1)

-
-
-
+        if init_trans_by_bin is None:
+            base = self.trans.detach().cpu().numpy()
+            tb = np.stack([base for _ in range(self.n_bins)], axis=0)
         else:
-
+            tb = np.asarray(init_trans_by_bin, dtype=float)
+            if tb.shape != (self.n_bins, self.n_states, self.n_states):
+                base = self.trans.detach().cpu().numpy()
+                tb = np.stack([base for _ in range(self.n_bins)], axis=0)

-
+        self.trans_by_bin = nn.Parameter(torch.tensor(tb, dtype=self.dtype), requires_grad=False)
+        self._normalize_trans_by_bin()

-
-
-
-coords_are_ints = True
-except Exception:
-coords = _np.arange(n_cols, dtype=int)
-coords_are_ints = False
-
-# helper: contiguous runs of 1s -> list of (start_idx, end_idx) (end exclusive)
-def _runs_from_mask(mask_1d):
-idx = _np.nonzero(mask_1d)[0]
-if idx.size == 0:
-return []
-runs = []
-start = idx[0]
-prev = idx[0]
-for i in idx[1:]:
-if i == prev + 1:
-prev = i
-continue
-runs.append((start, prev + 1))
-start = i
-prev = i
-runs.append((start, prev + 1))
-return runs
-
-# read original obs intervals/probs if available (for combining probs)
-orig_obs = None
-if update_obs and (layer in adata.obs.columns):
-orig_obs = list(adata.obs[layer])  # might be non-list entries
-
-# prepare outputs
-merged_bin = _np.zeros_like(bin_arr, dtype=int)
-merged_len = _np.zeros_like(bin_arr, dtype=int)
-merged_obs_col = [[] for _ in range(n_rows)]
-merged_counts = _np.zeros(n_rows, dtype=int)
-
-for r in range(n_rows):
-mask = bin_arr[r, :] != 0
-runs = _runs_from_mask(mask)
-if not runs:
-merged_obs_col[r] = []
-continue
+    @classmethod
+    def from_config(cls, cfg, *, override=None, device=None):
+        """Create a distance-binned HMM from config.

-[old lines 1419-1485 removed; content not rendered in this diff view]
-start_coord = int(coords[s_idx]) if coords_are_ints else int(s_idx)
-row_entries.append((start_coord, int(length_val), float(prob_val)))
-
-merged_obs_col[r] = row_entries
-merged_counts[r] = len(row_entries)
-
-# write merged layers (do not overwrite originals unless overwrite=True was set above)
-adata.layers[merged_bin_name] = merged_bin
-adata.layers[merged_len_name] = merged_len
-
-if update_obs:
-adata.obs[merged_bin_name] = merged_obs_col
-adata.obs[f"n_{merged_bin_name}"] = merged_counts
-
-# recompute distances list per-row (gaps between adjacent merged intervals)
-def _calc_distances(obs_list):
-out = []
-for intervals in obs_list:
-if not intervals:
-out.append([])
-continue
-iv = sorted(intervals, key=lambda x: int(x[0]))
-if len(iv) <= 1:
-out.append([])
-continue
-dlist = []
-for i in range(len(iv) - 1):
-endi = int(iv[i][0]) + int(iv[i][1]) - 1
-startn = int(iv[i + 1][0])
-dlist.append(startn - endi - 1)
-out.append(dlist)
-return out
-
-adata.obs[f"{merged_bin_name}_distances"] = _calc_distances(merged_obs_col)
-
-# update uns appended list
-uns_key = "hmm_appended_layers"
-existing = list(adata.uns.get(uns_key, [])) if adata.uns.get(uns_key, None) is not None else []
-for nm in (merged_bin_name, merged_len_name):
-if nm not in existing:
-existing.append(nm)
-adata.uns[uns_key] = existing
-
-if verbose:
-print(f"Created merged binary layer: {merged_bin_name}")
-print(f"Created merged length layer: {merged_len_name}")
-if update_obs:
-print(f"Updated adata.obs columns: {merged_bin_name}, n_{merged_bin_name}, {merged_bin_name}_distances")
-
-return None if inplace else adata
-
-def _ensure_final_layer_and_assign(self, final_adata, layer_name: str, subset_idx_mask: np.ndarray, sub_data):
+        Args:
+            cfg: Configuration mapping or object.
+            override: Override values to apply.
+            device: Optional device specifier.
+
+        Returns:
+            Initialized DistanceBinnedSingleBernoulliHMM instance.
+        """
+        merged = cls._cfg_to_dict(cfg)
+        if override:
+            merged.update(override)
+
+        n_states = int(merged.get("hmm_n_states", 2))
+        eps = float(merged.get("hmm_eps", 1e-8))
+        dtype = _resolve_dtype(merged.get("hmm_dtype", None))
+        dtype = _coerce_dtype_for_device(dtype, device)  # <<< NEW
+        init_em = merged.get("hmm_init_emission_probs", None)
+
+        bins = merged.get("hmm_distance_bins", [1, 5, 10, 25, 50, 100])
+        init_tb = merged.get("hmm_init_transitions_by_bin", None)
+
+        model = cls(
+            n_states=n_states,
+            init_emission=init_em,
+            distance_bins=bins,
+            init_trans_by_bin=init_tb,
+            eps=eps,
+            dtype=dtype,
+        )
+        if device is not None:
+            model.to(torch.device(device) if isinstance(device, str) else device)
+        model._persisted_cfg = merged
+        return model
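Note (not part of the diff): `from_config` reads flat `hmm_*` keys, so a minimal configuration for this architecture could look like the sketch below; the concrete values are illustrative rather than package defaults, and a plain dict is assumed to be acceptable to `_cfg_to_dict`:

    cfg = {
        "hmm_n_states": 2,
        "hmm_eps": 1e-8,
        "hmm_init_emission_probs": [0.1, 0.9],        # per-state P(observing a 1), illustrative
        "hmm_distance_bins": [1, 5, 10, 25, 50, 100],
        "hmm_init_transitions_by_bin": None,          # fall back to copies of the base transition matrix
    }
    model = DistanceBinnedSingleBernoulliHMM.from_config(cfg, device="cpu")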
+
+    def _ensure_device_dtype(self, device=None) -> torch.device:
+        """Move transition-by-bin parameters to the requested device/dtype."""
+        device = super()._ensure_device_dtype(device)
+        self.trans_by_bin.data = self.trans_by_bin.data.to(device=device, dtype=self.dtype)
+        return device
+
+    def _normalize_trans_by_bin(self):
+        """Normalize transition matrices per distance bin in-place."""
+        with torch.no_grad():
+            tb = self.trans_by_bin.data.reshape(self.n_bins, self.n_states, self.n_states)
+            tb = tb + self.eps
+            rs = tb.sum(dim=2, keepdim=True)
+            rs[rs == 0.0] = 1.0
+            self.trans_by_bin.data = tb / rs
+
+    def _extra_save_payload(self) -> dict:
+        """Return extra payload data for serialization."""
+        p = super()._extra_save_payload()
+        p.update(
+            {
+                "distance_bins": torch.tensor(self.distance_bins, dtype=torch.long),
+                "trans_by_bin": self.trans_by_bin.detach().cpu(),
+            }
+        )
+        return p
+
+    def _load_extra_payload(self, payload: dict, *, device: torch.device):
+        """Load serialized distance-bin parameters.
+
+        Args:
+            payload: Serialized payload dictionary.
+            device: Target torch device.
         """
-
-
-
+        super()._load_extra_payload(payload, device=device)
+        self.distance_bins = (
+            payload.get("distance_bins", torch.tensor([1, 5, 10, 25, 50, 100]))
+            .cpu()
+            .numpy()
+            .astype(int)
+        )
+        self.n_bins = int(len(self.distance_bins) + 1)
+        with torch.no_grad():
+            self.trans_by_bin.data = payload["trans_by_bin"].to(device=device, dtype=self.dtype)
+        self._normalize_trans_by_bin()
+
+    def _bin_index(self, coords: np.ndarray) -> np.ndarray:
+        """Return per-step distance bin indices for coordinates.
+
+        Args:
+            coords: Coordinate array.
+
+        Returns:
+            Array of bin indices (length L-1).
         """
-
-
+        d = np.diff(np.asarray(coords, dtype=int))
+        return np.digitize(d, self.distance_bins, right=True)  # length L-1
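Note (not part of the diff): with the default edges [1, 5, 10, 25, 50, 100], `np.digitize(..., right=True)` assigns each gap between consecutive coordinates to one of n_bins = 7 indices, with gaps above the last edge falling into the overflow bin:

    import numpy as np

    bins = np.array([1, 5, 10, 25, 50, 100])
    gaps = np.array([1, 2, 7, 150])
    np.digitize(gaps, bins, right=True)  # -> array([0, 1, 2, 6])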

-[old lines 1546-1587 removed; content not rendered in this diff view]
+    def _forward_backward(
+        self, obs: torch.Tensor, mask: torch.Tensor, *, coords: Optional[np.ndarray] = None
+    ) -> torch.Tensor:
+        """Run forward-backward using distance-binned transitions.
+
+        Args:
+            obs: Observation tensor.
+            mask: Observation mask.
+            coords: Coordinate array.
+
+        Returns:
+            Posterior probabilities (gamma).
+        """
+        if coords is None:
+            raise ValueError("Distance-binned HMM requires coords.")
+        device = obs.device
+        eps = float(self.eps)
+        K = self.n_states
+
+        coords = np.asarray(coords, dtype=int)
+        bins = torch.tensor(self._bin_index(coords), dtype=torch.long, device=device)  # (L-1,)
+
+        logB = self._log_emission(obs, mask)  # (N,L,K)
+        logstart = torch.log(self.start + eps)
+        logA_by_bin = torch.log(self.trans_by_bin + eps)  # (nb,K,K)
+
+        N, L, _ = logB.shape
+        alpha = torch.empty((N, L, K), dtype=self.dtype, device=device)
+        alpha[:, 0, :] = logstart.unsqueeze(0) + logB[:, 0, :]
+
+        for t in range(1, L):
+            b = int(bins[t - 1].item()) if (t - 1) < bins.numel() else 0
+            logA = logA_by_bin[b]
+            prev = alpha[:, t - 1, :].unsqueeze(2) + logA.unsqueeze(0)
+            alpha[:, t, :] = _logsumexp(prev, dim=1) + logB[:, t, :]
+
+        beta = torch.empty((N, L, K), dtype=self.dtype, device=device)
+        beta[:, L - 1, :] = 0.0
+        for t in range(L - 2, -1, -1):
+            b = int(bins[t].item()) if t < bins.numel() else 0
+            logA = logA_by_bin[b]
+            temp = logA.unsqueeze(0) + (logB[:, t + 1, :] + beta[:, t + 1, :]).unsqueeze(1)
+            beta[:, t, :] = _logsumexp(temp, dim=2)
+
+        log_gamma = alpha + beta
+        logZ = _logsumexp(log_gamma, dim=2).unsqueeze(2)
+        return (log_gamma - logZ).exp()
+
+    def _viterbi(
+        self, obs: torch.Tensor, mask: torch.Tensor, *, coords: Optional[np.ndarray] = None
+    ) -> torch.Tensor:
+        """Run Viterbi decoding using distance-binned transitions.
+
+        Args:
+            obs: Observation tensor.
+            mask: Observation mask.
+            coords: Coordinate array.
+
+        Returns:
+            Decoded state sequence tensor.
+        """
+        if coords is None:
+            raise ValueError("Distance-binned HMM requires coords.")
+        device = obs.device
+        eps = float(self.eps)
+        K = self.n_states
+
+        coords = np.asarray(coords, dtype=int)
+        bins = torch.tensor(self._bin_index(coords), dtype=torch.long, device=device)  # (L-1,)
+
+        logB = self._log_emission(obs, mask)
+        logstart = torch.log(self.start + eps)
+        logA_by_bin = torch.log(self.trans_by_bin + eps)
+
+        N, L, _ = logB.shape
+        delta = torch.empty((N, L, K), dtype=self.dtype, device=device)
+        psi = torch.empty((N, L, K), dtype=torch.long, device=device)
+
+        delta[:, 0, :] = logstart.unsqueeze(0) + logB[:, 0, :]
+        psi[:, 0, :] = -1
+
+        for t in range(1, L):
+            b = int(bins[t - 1].item()) if (t - 1) < bins.numel() else 0
+            logA = logA_by_bin[b]
+            cand = delta[:, t - 1, :].unsqueeze(2) + logA.unsqueeze(0)
+            best_val, best_idx = cand.max(dim=1)
+            delta[:, t, :] = best_val + logB[:, t, :]
+            psi[:, t, :] = best_idx
+
+        last_state = torch.argmax(delta[:, L - 1, :], dim=1)
+        states = torch.empty((N, L), dtype=torch.long, device=device)
+        states[:, L - 1] = last_state
+        for t in range(L - 2, -1, -1):
+            states[:, t] = psi[torch.arange(N, device=device), t + 1, states[:, t + 1]]
+        return states
+
+    def fit_em(
+        self,
+        X: np.ndarray,
+        coords: np.ndarray,
+        *,
+        device: torch.device,
+        max_iter: int,
+        tol: float,
+        update_start: bool,
+        update_trans: bool,
+        update_emission: bool,
+        verbose: bool,
+        **kwargs,
+    ) -> List[float]:
+        """Run EM updates for distance-binned transitions.
+
+        Args:
+            X: Observations array (N, L).
+            coords: Coordinate array aligned to X.
+            device: Torch device.
+            max_iter: Maximum iterations.
+            tol: Convergence tolerance.
+            update_start: Whether to update start probabilities.
+            update_trans: Whether to update transitions.
+            update_emission: Whether to update emission parameters.
+            verbose: Whether to log progress.
+            **kwargs: Additional implementation-specific kwargs.
+
+        Returns:
+            List of log-likelihood proxy values.
+        """
+        # Keep this simple: use gamma for emissions; transitions-by-bin updated via xi (same pattern).
+        X = np.asarray(X, dtype=float)
+        if X.ndim != 2:
+            raise ValueError("DistanceBinnedSingleBernoulliHMM expects X shape (N,L).")
+
+        coords = np.asarray(coords, dtype=int)
+        bins_np = self._bin_index(coords)  # (L-1,)
+
+        obs = torch.tensor(np.nan_to_num(X, nan=0.0), dtype=self.dtype, device=device)
+        mask = torch.tensor(~np.isnan(X), dtype=torch.bool, device=device)
+
+        eps = float(self.eps)
+        K = self.n_states
+        N, L = obs.shape
+
+        hist: List[float] = []
+        for it in range(1, int(max_iter) + 1):
+            gamma = self._forward_backward(obs, mask, coords=coords)  # (N,L,K)
+            ll_proxy = float(torch.sum(torch.log(torch.clamp(gamma.sum(dim=2), min=eps))).item())
+            hist.append(ll_proxy)
+
+            # expected start
+            start_acc = gamma[:, 0, :].sum(dim=0)
+
+            # compute alpha/beta for xi
+            logB = self._log_emission(obs, mask)
+            logstart = torch.log(self.start + eps)
+            logA_by_bin = torch.log(self.trans_by_bin + eps)
+
+            alpha = torch.empty((N, L, K), dtype=self.dtype, device=device)
+            alpha[:, 0, :] = logstart.unsqueeze(0) + logB[:, 0, :]
+
+            for t in range(1, L):
+                b = int(bins_np[t - 1])
+                logA = logA_by_bin[b]
+                prev = alpha[:, t - 1, :].unsqueeze(2) + logA.unsqueeze(0)
+                alpha[:, t, :] = _logsumexp(prev, dim=1) + logB[:, t, :]
+
+            beta = torch.empty((N, L, K), dtype=self.dtype, device=device)
+            beta[:, L - 1, :] = 0.0
+            for t in range(L - 2, -1, -1):
+                b = int(bins_np[t])
+                logA = logA_by_bin[b]
+                temp = logA.unsqueeze(0) + (logB[:, t + 1, :] + beta[:, t + 1, :]).unsqueeze(1)
+                beta[:, t, :] = _logsumexp(temp, dim=2)
+
+            trans_acc_by_bin = torch.zeros((self.n_bins, K, K), dtype=self.dtype, device=device)
+            for t in range(L - 1):
+                b = int(bins_np[t])
+                logA = logA_by_bin[b]
+                valid_t = (mask[:, t] & mask[:, t + 1]).float().view(N, 1, 1)
+                log_xi = (
+                    alpha[:, t, :].unsqueeze(2)
+                    + logA.unsqueeze(0)
+                    + (logB[:, t + 1, :] + beta[:, t + 1, :]).unsqueeze(1)
+                )
+                log_norm = _logsumexp(log_xi.view(N, -1), dim=1).view(N, 1, 1)
+                xi = (log_xi - log_norm).exp() * valid_t
+                trans_acc_by_bin[b] += xi.sum(dim=0)
+
+            mask_f = mask.float().unsqueeze(-1)
+            emit_num = (gamma * obs.unsqueeze(-1) * mask_f).sum(dim=(0, 1))
+            emit_den = (gamma * mask_f).sum(dim=(0, 1))
+
+            with torch.no_grad():
+                if update_start:
+                    new_start = start_acc + eps
+                    self.start.data = new_start / new_start.sum()
+
+                if update_trans:
+                    tb = trans_acc_by_bin + eps
+                    rs = tb.sum(dim=2, keepdim=True)
+                    rs[rs == 0.0] = 1.0
+                    self.trans_by_bin.data = tb / rs
+
+                if update_emission:
+                    new_em = (emit_num + eps) / (emit_den + 2.0 * eps)
+                    self.emission.data = new_em.clamp(min=eps, max=1.0 - eps)
+
+            self._normalize_params()
+            self._normalize_emission()
+            self._normalize_trans_by_bin()
+
+            if verbose:
+                logger.info(
+                    "[DistanceBinnedSingle.fit] iter=%s ll_proxy=%.6f",
+                    it,
+                    hist[-1],
+                )
+
+            if len(hist) > 1 and abs(hist[-1] - hist[-2]) < float(tol):
+                break
+
+        return hist
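Note (not part of the diff): unlike the single-matrix EM above, the expected transition counts here are pooled per distance bin before row-normalization, so each bin keeps its own transition matrix. A compact NumPy sketch of that per-bin normalization:

    import numpy as np

    def normalize_trans_by_bin(trans_acc_by_bin, eps=1e-8):
        # trans_acc_by_bin: (n_bins, K, K) expected transition counts per distance bin
        tb = trans_acc_by_bin + eps
        rs = tb.sum(axis=2, keepdims=True)
        rs[rs == 0.0] = 1.0
        return tb / rs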
+
+    def adapt_emissions(
+        self,
+        X: np.ndarray,
+        coords: Optional[np.ndarray] = None,
+        *,
+        device: Optional[Union[str, torch.device]] = None,
+        iters: Optional[int] = None,
+        max_iter: Optional[int] = None,
+        verbose: bool = False,
+        **kwargs,
+    ):
+        """Adapt emissions with legacy parameter names.
+
+        Args:
+            X: Observations array.
+            coords: Optional coordinate array.
+            device: Device specifier.
+            iters: Number of iterations.
+            max_iter: Alias for iters.
+            verbose: Whether to log progress.
+            **kwargs: Additional kwargs forwarded to BaseHMM.adapt_emissions.
+
+        Returns:
+            List of log-likelihood values.
+        """
+        if iters is None:
+            iters = int(max_iter) if max_iter is not None else int(kwargs.pop("iters", 5))
+        return super().adapt_emissions(
+            np.asarray(X, dtype=float),
+            coords if coords is not None else None,
+            iters=int(iters),
+            device=device,
+            verbose=verbose,
+        )
+
+
+# =============================================================================
+# Facade class to match workflow import style
+# =============================================================================
+
+
+class HMM:
+    """
+    Facade so workflow can do:
+        from ..hmm.HMM import HMM
+        hmm = HMM.from_config(cfg, arch="single")
+        hmm.save(...)
+        hmm = HMM.load(...)
+    """
+
+    @staticmethod
+    def from_config(cfg, arch: Optional[str] = None, **kwargs) -> BaseHMM:
+        """Create an HMM instance from configuration.
+
+        Args:
+            cfg: Configuration mapping or object.
+            arch: Optional HMM architecture name.
+            **kwargs: Additional parameters passed to the factory.
+
+        Returns:
+            Initialized HMM instance.
+        """
+        return create_hmm(cfg, arch=arch, **kwargs)
+
+    @staticmethod
+    def load(path: Union[str, Path], device: Optional[Union[str, torch.device]] = None) -> BaseHMM:
+        """Load an HMM instance from disk.
+
+        Args:
+            path: Path to the serialized model.
+            device: Optional device specifier.
+
+        Returns:
+            Loaded HMM instance.
+        """
+        return BaseHMM.load(path, device=device)