smftools 0.2.3__py3-none-any.whl → 0.2.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- smftools/__init__.py +6 -8
- smftools/_settings.py +4 -6
- smftools/_version.py +1 -1
- smftools/cli/helpers.py +54 -0
- smftools/cli/hmm_adata.py +937 -256
- smftools/cli/load_adata.py +448 -268
- smftools/cli/preprocess_adata.py +469 -263
- smftools/cli/spatial_adata.py +536 -319
- smftools/cli_entry.py +97 -182
- smftools/config/__init__.py +1 -1
- smftools/config/conversion.yaml +17 -6
- smftools/config/deaminase.yaml +12 -10
- smftools/config/default.yaml +142 -33
- smftools/config/direct.yaml +11 -3
- smftools/config/discover_input_files.py +19 -5
- smftools/config/experiment_config.py +594 -264
- smftools/constants.py +37 -0
- smftools/datasets/__init__.py +2 -8
- smftools/datasets/datasets.py +32 -18
- smftools/hmm/HMM.py +2128 -1418
- smftools/hmm/__init__.py +2 -9
- smftools/hmm/archived/call_hmm_peaks.py +121 -0
- smftools/hmm/call_hmm_peaks.py +299 -91
- smftools/hmm/display_hmm.py +19 -6
- smftools/hmm/hmm_readwrite.py +13 -4
- smftools/hmm/nucleosome_hmm_refinement.py +102 -14
- smftools/informatics/__init__.py +30 -7
- smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +14 -1
- smftools/informatics/archived/helpers/archived/bam_qc.py +14 -1
- smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +8 -1
- smftools/informatics/archived/helpers/archived/load_adata.py +3 -3
- smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +3 -1
- smftools/informatics/archived/print_bam_query_seq.py +7 -1
- smftools/informatics/bam_functions.py +397 -175
- smftools/informatics/basecalling.py +51 -9
- smftools/informatics/bed_functions.py +90 -57
- smftools/informatics/binarize_converted_base_identities.py +18 -7
- smftools/informatics/complement_base_list.py +7 -6
- smftools/informatics/converted_BAM_to_adata.py +265 -122
- smftools/informatics/fasta_functions.py +161 -83
- smftools/informatics/h5ad_functions.py +196 -30
- smftools/informatics/modkit_extract_to_adata.py +609 -270
- smftools/informatics/modkit_functions.py +85 -44
- smftools/informatics/ohe.py +44 -21
- smftools/informatics/pod5_functions.py +112 -73
- smftools/informatics/run_multiqc.py +20 -14
- smftools/logging_utils.py +51 -0
- smftools/machine_learning/__init__.py +2 -7
- smftools/machine_learning/data/anndata_data_module.py +143 -50
- smftools/machine_learning/data/preprocessing.py +2 -1
- smftools/machine_learning/evaluation/__init__.py +1 -1
- smftools/machine_learning/evaluation/eval_utils.py +11 -14
- smftools/machine_learning/evaluation/evaluators.py +46 -33
- smftools/machine_learning/inference/__init__.py +1 -1
- smftools/machine_learning/inference/inference_utils.py +7 -4
- smftools/machine_learning/inference/lightning_inference.py +9 -13
- smftools/machine_learning/inference/sklearn_inference.py +6 -8
- smftools/machine_learning/inference/sliding_window_inference.py +35 -25
- smftools/machine_learning/models/__init__.py +10 -5
- smftools/machine_learning/models/base.py +28 -42
- smftools/machine_learning/models/cnn.py +15 -11
- smftools/machine_learning/models/lightning_base.py +71 -40
- smftools/machine_learning/models/mlp.py +13 -4
- smftools/machine_learning/models/positional.py +3 -2
- smftools/machine_learning/models/rnn.py +3 -2
- smftools/machine_learning/models/sklearn_models.py +39 -22
- smftools/machine_learning/models/transformer.py +68 -53
- smftools/machine_learning/models/wrappers.py +2 -1
- smftools/machine_learning/training/__init__.py +2 -2
- smftools/machine_learning/training/train_lightning_model.py +29 -20
- smftools/machine_learning/training/train_sklearn_model.py +9 -15
- smftools/machine_learning/utils/__init__.py +1 -1
- smftools/machine_learning/utils/device.py +7 -4
- smftools/machine_learning/utils/grl.py +3 -1
- smftools/metadata.py +443 -0
- smftools/plotting/__init__.py +19 -5
- smftools/plotting/autocorrelation_plotting.py +145 -44
- smftools/plotting/classifiers.py +162 -72
- smftools/plotting/general_plotting.py +422 -197
- smftools/plotting/hmm_plotting.py +42 -13
- smftools/plotting/position_stats.py +147 -87
- smftools/plotting/qc_plotting.py +20 -12
- smftools/preprocessing/__init__.py +10 -12
- smftools/preprocessing/append_base_context.py +115 -80
- smftools/preprocessing/append_binary_layer_by_base_context.py +77 -39
- smftools/preprocessing/{calculate_complexity.py → archived/calculate_complexity.py} +3 -1
- smftools/preprocessing/{archives → archived}/preprocessing.py +8 -6
- smftools/preprocessing/binarize.py +21 -4
- smftools/preprocessing/binarize_on_Youden.py +129 -31
- smftools/preprocessing/binary_layers_to_ohe.py +17 -11
- smftools/preprocessing/calculate_complexity_II.py +86 -59
- smftools/preprocessing/calculate_consensus.py +28 -19
- smftools/preprocessing/calculate_coverage.py +50 -25
- smftools/preprocessing/calculate_pairwise_differences.py +2 -1
- smftools/preprocessing/calculate_pairwise_hamming_distances.py +4 -3
- smftools/preprocessing/calculate_position_Youden.py +118 -54
- smftools/preprocessing/calculate_read_length_stats.py +52 -23
- smftools/preprocessing/calculate_read_modification_stats.py +91 -57
- smftools/preprocessing/clean_NaN.py +38 -28
- smftools/preprocessing/filter_adata_by_nan_proportion.py +24 -12
- smftools/preprocessing/filter_reads_on_length_quality_mapping.py +71 -38
- smftools/preprocessing/filter_reads_on_modification_thresholds.py +181 -73
- smftools/preprocessing/flag_duplicate_reads.py +689 -272
- smftools/preprocessing/invert_adata.py +26 -11
- smftools/preprocessing/load_sample_sheet.py +40 -22
- smftools/preprocessing/make_dirs.py +8 -3
- smftools/preprocessing/min_non_diagonal.py +2 -1
- smftools/preprocessing/recipes.py +56 -23
- smftools/preprocessing/reindex_references_adata.py +103 -0
- smftools/preprocessing/subsample_adata.py +33 -16
- smftools/readwrite.py +331 -82
- smftools/schema/__init__.py +11 -0
- smftools/schema/anndata_schema_v1.yaml +227 -0
- smftools/tools/__init__.py +3 -4
- smftools/tools/archived/classifiers.py +163 -0
- smftools/tools/archived/subset_adata_v1.py +10 -1
- smftools/tools/archived/subset_adata_v2.py +12 -1
- smftools/tools/calculate_umap.py +54 -15
- smftools/tools/cluster_adata_on_methylation.py +115 -46
- smftools/tools/general_tools.py +70 -25
- smftools/tools/position_stats.py +229 -98
- smftools/tools/read_stats.py +50 -29
- smftools/tools/spatial_autocorrelation.py +365 -192
- smftools/tools/subset_adata.py +23 -21
- {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/METADATA +17 -39
- smftools-0.2.5.dist-info/RECORD +181 -0
- smftools-0.2.3.dist-info/RECORD +0 -173
- /smftools/cli/{cli_flows.py → archived/cli_flows.py} +0 -0
- /smftools/hmm/{apply_hmm_batched.py → archived/apply_hmm_batched.py} +0 -0
- /smftools/hmm/{calculate_distances.py → archived/calculate_distances.py} +0 -0
- /smftools/hmm/{train_hmm.py → archived/train_hmm.py} +0 -0
- /smftools/preprocessing/{add_read_length_and_mapping_qc.py → archived/add_read_length_and_mapping_qc.py} +0 -0
- /smftools/preprocessing/{archives → archived}/mark_duplicates.py +0 -0
- /smftools/preprocessing/{archives → archived}/remove_duplicates.py +0 -0
- {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/WHEEL +0 -0
- {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/entry_points.txt +0 -0
- {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/licenses/LICENSE +0 -0
smftools/hmm/HMM.py
CHANGED
|
@@ -1,1576 +1,2286 @@
|
|
|
1
|
-
import
|
|
2
|
-
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
3
|
import ast
|
|
4
4
|
import json
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
|
|
5
7
|
|
|
6
8
|
import numpy as np
|
|
7
|
-
import pandas as pd
|
|
8
9
|
import torch
|
|
9
10
|
import torch.nn as nn
|
|
11
|
+
from scipy.sparse import issparse
|
|
10
12
|
|
|
11
|
-
|
|
12
|
-
return torch.logsumexp(vec, dim=dim, keepdim=keepdim)
|
|
13
|
+
from smftools.logging_utils import get_logger
|
|
13
14
|
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
Methods:
|
|
19
|
-
- fit(data, ...) -> trains params in-place
|
|
20
|
-
- predict(data, ...) -> list of (L, K) posterior marginals (gamma) numpy arrays
|
|
21
|
-
- viterbi(seq, ...) -> (path_list, score)
|
|
22
|
-
- batch_viterbi(data, ...) -> list of (path_list, score)
|
|
23
|
-
- score(seq_or_list, ...) -> float or list of floats
|
|
24
|
-
|
|
25
|
-
Notes:
|
|
26
|
-
- data: list of sequences (each sequence is iterable of {0,1,np.nan}).
|
|
27
|
-
- impute_strategy: "ignore" (NaN treated as missing), "random" (fill NaNs randomly with 0/1).
|
|
28
|
-
"""
|
|
15
|
+
logger = get_logger(__name__)
|
|
16
|
+
# =============================================================================
|
|
17
|
+
# Registry / Factory
|
|
18
|
+
# =============================================================================
|
|
29
19
|
|
|
30
|
-
|
|
31
|
-
self,
|
|
32
|
-
n_states: int = 2,
|
|
33
|
-
init_start: Optional[List[float]] = None,
|
|
34
|
-
init_trans: Optional[List[List[float]]] = None,
|
|
35
|
-
init_emission: Optional[List[float]] = None,
|
|
36
|
-
dtype: torch.dtype = torch.float64,
|
|
37
|
-
eps: float = 1e-8,
|
|
38
|
-
smf_modality: Optional[str] = None,
|
|
39
|
-
):
|
|
40
|
-
super().__init__()
|
|
41
|
-
if n_states < 2:
|
|
42
|
-
raise ValueError("n_states must be >= 2")
|
|
43
|
-
self.n_states = n_states
|
|
44
|
-
self.eps = float(eps)
|
|
45
|
-
self.dtype = dtype
|
|
46
|
-
self.smf_modality = smf_modality
|
|
20
|
+
_HMM_REGISTRY: Dict[str, type] = {}
|
|
47
21
|
|
|
48
|
-
# initialize params (probabilities)
|
|
49
|
-
if init_start is None:
|
|
50
|
-
start = np.full((n_states,), 1.0 / n_states, dtype=float)
|
|
51
|
-
else:
|
|
52
|
-
start = np.asarray(init_start, dtype=float)
|
|
53
|
-
if init_trans is None:
|
|
54
|
-
trans = np.full((n_states, n_states), 1.0 / n_states, dtype=float)
|
|
55
|
-
else:
|
|
56
|
-
trans = np.asarray(init_trans, dtype=float)
|
|
57
|
-
# --- sanitize init_emission so it's a 1-D list of P(obs==1 | state) ---
|
|
58
|
-
if init_emission is None:
|
|
59
|
-
emission = np.full((n_states,), 0.5, dtype=float)
|
|
60
|
-
else:
|
|
61
|
-
em_arr = np.asarray(init_emission, dtype=float)
|
|
62
|
-
# case: (K,2) -> pick P(obs==1) from second column
|
|
63
|
-
if em_arr.ndim == 2 and em_arr.shape[1] == 2 and em_arr.shape[0] == n_states:
|
|
64
|
-
emission = em_arr[:, 1].astype(float)
|
|
65
|
-
# case: maybe shape (1,K,2) etc. -> try to collapse trailing axis of length 2
|
|
66
|
-
elif em_arr.ndim >= 2 and em_arr.shape[-1] == 2:
|
|
67
|
-
emission = em_arr.reshape(-1, 2)[:n_states, 1].astype(float)
|
|
68
|
-
else:
|
|
69
|
-
emission = em_arr.reshape(-1)[:n_states].astype(float)
|
|
70
22
|
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
self.trans = nn.Parameter(torch.tensor(trans, dtype=self.dtype), requires_grad=False)
|
|
74
|
-
self.emission = nn.Parameter(torch.tensor(emission, dtype=self.dtype), requires_grad=False)
|
|
23
|
+
def register_hmm(name: str):
|
|
24
|
+
"""Decorator to register an HMM backend under a string key."""
|
|
75
25
|
|
|
76
|
-
|
|
26
|
+
def deco(cls):
|
|
27
|
+
"""Register the provided class in the HMM registry."""
|
|
28
|
+
_HMM_REGISTRY[name] = cls
|
|
29
|
+
cls.hmm_name = name
|
|
30
|
+
return cls
|
|
77
31
|
|
|
78
|
-
|
|
79
|
-
with torch.no_grad():
|
|
80
|
-
# coerce shapes
|
|
81
|
-
K = self.n_states
|
|
82
|
-
self.start.data = self.start.data.squeeze()
|
|
83
|
-
if self.start.data.numel() != K:
|
|
84
|
-
self.start.data = torch.full((K,), 1.0 / K, dtype=self.dtype)
|
|
32
|
+
return deco
|
|
85
33
|
|
|
86
|
-
self.trans.data = self.trans.data.squeeze()
|
|
87
|
-
if not (self.trans.data.ndim == 2 and self.trans.data.shape == (K, K)):
|
|
88
|
-
if K == 2:
|
|
89
|
-
self.trans.data = torch.tensor([[0.9,0.1],[0.1,0.9]], dtype=self.dtype)
|
|
90
|
-
else:
|
|
91
|
-
self.trans.data = torch.full((K, K), 1.0 / K, dtype=self.dtype)
|
|
92
34
|
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
35
|
+
def create_hmm(cfg: Union[dict, Any, None], arch: Optional[str] = None, **kwargs):
|
|
36
|
+
"""
|
|
37
|
+
Factory: creates an HMM from cfg + arch (override).
|
|
38
|
+
"""
|
|
39
|
+
key = (
|
|
40
|
+
arch
|
|
41
|
+
or getattr(cfg, "hmm_arch", None)
|
|
42
|
+
or (cfg.get("hmm_arch") if isinstance(cfg, dict) else None)
|
|
43
|
+
or "single"
|
|
44
|
+
)
|
|
45
|
+
if key not in _HMM_REGISTRY:
|
|
46
|
+
raise KeyError(f"Unknown hmm_arch={key!r}. Known: {sorted(_HMM_REGISTRY.keys())}")
|
|
47
|
+
return _HMM_REGISTRY[key].from_config(cfg, **kwargs)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
# =============================================================================
|
|
51
|
+
# Small utilities
|
|
52
|
+
# =============================================================================
|
|
53
|
+
def _coerce_dtype_for_device(
|
|
54
|
+
dtype: torch.dtype, device: Optional[Union[str, torch.device]]
|
|
55
|
+
) -> torch.dtype:
|
|
56
|
+
"""MPS does not support float64. When targeting MPS, coerce to float32."""
|
|
57
|
+
dev = torch.device(device) if isinstance(device, str) else device
|
|
58
|
+
if dev is not None and getattr(dev, "type", None) == "mps" and dtype == torch.float64:
|
|
59
|
+
return torch.float32
|
|
60
|
+
return dtype
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def _try_json_or_literal(x: Any) -> Any:
|
|
64
|
+
"""Parse a string value as JSON or a Python literal when possible.
|
|
65
|
+
|
|
66
|
+
Args:
|
|
67
|
+
x: Value to parse.
|
|
68
|
+
|
|
69
|
+
Returns:
|
|
70
|
+
The parsed value if possible, otherwise the original value.
|
|
71
|
+
"""
|
|
72
|
+
if x is None:
|
|
73
|
+
return None
|
|
74
|
+
if not isinstance(x, str):
|
|
75
|
+
return x
|
|
76
|
+
s = x.strip()
|
|
77
|
+
if not s:
|
|
78
|
+
return None
|
|
79
|
+
try:
|
|
80
|
+
return json.loads(s)
|
|
81
|
+
except Exception:
|
|
82
|
+
pass
|
|
83
|
+
try:
|
|
84
|
+
return ast.literal_eval(s)
|
|
85
|
+
except Exception:
|
|
86
|
+
return x
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def _coerce_bool(x: Any) -> bool:
|
|
90
|
+
"""Coerce a value into a boolean using common truthy strings.
|
|
91
|
+
|
|
92
|
+
Args:
|
|
93
|
+
x: Value to coerce.
|
|
94
|
+
|
|
95
|
+
Returns:
|
|
96
|
+
Boolean interpretation of the input.
|
|
97
|
+
"""
|
|
98
|
+
if x is None:
|
|
99
|
+
return False
|
|
100
|
+
if isinstance(x, bool):
|
|
101
|
+
return x
|
|
102
|
+
if isinstance(x, (int, float)):
|
|
103
|
+
return bool(x)
|
|
104
|
+
s = str(x).strip().lower()
|
|
105
|
+
return s in ("1", "true", "t", "yes", "y", "on")
|
|
96
106
|
|
|
97
|
-
# now perform smoothing/normalization
|
|
98
|
-
self.start.data = (self.start.data + self.eps)
|
|
99
|
-
self.start.data = self.start.data / self.start.data.sum()
|
|
100
107
|
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
row_sums[row_sums == 0.0] = 1.0
|
|
104
|
-
self.trans.data = self.trans.data / row_sums
|
|
108
|
+
def _resolve_dtype(dtype_entry: Any) -> torch.dtype:
|
|
109
|
+
"""Resolve a torch dtype from a config entry.
|
|
105
110
|
|
|
106
|
-
|
|
111
|
+
Args:
|
|
112
|
+
dtype_entry: Config value (string or torch.dtype).
|
|
107
113
|
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
return torch.float64
|
|
113
|
-
if isinstance(dtype_entry, torch.dtype):
|
|
114
|
-
return dtype_entry
|
|
115
|
-
s = str(dtype_entry).lower()
|
|
116
|
-
if "32" in s:
|
|
117
|
-
return torch.float32
|
|
118
|
-
if "16" in s:
|
|
119
|
-
return torch.float16
|
|
114
|
+
Returns:
|
|
115
|
+
Resolved torch dtype.
|
|
116
|
+
"""
|
|
117
|
+
if dtype_entry is None:
|
|
120
118
|
return torch.float64
|
|
119
|
+
if isinstance(dtype_entry, torch.dtype):
|
|
120
|
+
return dtype_entry
|
|
121
|
+
s = str(dtype_entry).lower()
|
|
122
|
+
if "16" in s:
|
|
123
|
+
return torch.float16
|
|
124
|
+
if "32" in s:
|
|
125
|
+
return torch.float32
|
|
126
|
+
return torch.float64
|
|
121
127
|
|
|
122
|
-
@classmethod
|
|
123
|
-
def from_config(cls, cfg: Union[dict, "ExperimentConfig", None], *,
|
|
124
|
-
override: Optional[dict] = None,
|
|
125
|
-
device: Optional[Union[str, torch.device]] = None) -> "HMM":
|
|
126
|
-
"""
|
|
127
|
-
Construct an HMM using keys from an ExperimentConfig instance or a plain dict.
|
|
128
128
|
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
129
|
+
def _safe_int_coords(var_names) -> Tuple[np.ndarray, bool]:
|
|
130
|
+
"""
|
|
131
|
+
Try to cast var_names to int coordinates. If not possible,
|
|
132
|
+
fall back to 0..L-1 index coordinates.
|
|
133
|
+
"""
|
|
134
|
+
try:
|
|
135
|
+
coords = np.asarray(var_names, dtype=int)
|
|
136
|
+
return coords, True
|
|
137
|
+
except Exception:
|
|
138
|
+
return np.arange(len(var_names), dtype=int), False
|
|
133
139
|
|
|
134
|
-
override: optional dict to override resolved keys (handy for tests).
|
|
135
|
-
device: optional device string or torch.device to move model to.
|
|
136
|
-
"""
|
|
137
|
-
# Accept ExperimentConfig dataclass
|
|
138
|
-
if cfg is None:
|
|
139
|
-
merged = {}
|
|
140
|
-
elif hasattr(cfg, "to_dict") and callable(getattr(cfg, "to_dict")):
|
|
141
|
-
merged = dict(cfg.to_dict())
|
|
142
|
-
elif isinstance(cfg, dict):
|
|
143
|
-
merged = dict(cfg)
|
|
144
|
-
else:
|
|
145
|
-
# try attr access as fallback
|
|
146
|
-
try:
|
|
147
|
-
merged = {k: getattr(cfg, k) for k in dir(cfg) if k.startswith("hmm_")}
|
|
148
|
-
except Exception:
|
|
149
|
-
merged = {}
|
|
150
140
|
|
|
151
|
-
|
|
152
|
-
|
|
141
|
+
def _logsumexp(x: torch.Tensor, dim: int) -> torch.Tensor:
|
|
142
|
+
"""Compute log-sum-exp in a numerically stable way.
|
|
153
143
|
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
init_trans = merged.get("hmm_init_transition_probs", merged.get("hmm_init_trans", None))
|
|
158
|
-
init_emission = merged.get("hmm_init_emission_probs", merged.get("hmm_init_emission", None))
|
|
159
|
-
eps = float(merged.get("hmm_eps", merged.get("eps", 1e-8)))
|
|
160
|
-
dtype = cls._resolve_dtype(merged.get("hmm_dtype", merged.get("dtype", None)))
|
|
144
|
+
Args:
|
|
145
|
+
x: Input tensor.
|
|
146
|
+
dim: Dimension to reduce.
|
|
161
147
|
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
return np.asarray(x, dtype=float)
|
|
148
|
+
Returns:
|
|
149
|
+
Reduced tensor.
|
|
150
|
+
"""
|
|
151
|
+
return torch.logsumexp(x, dim=dim)
|
|
167
152
|
|
|
168
|
-
init_start = _coerce_np(init_start)
|
|
169
|
-
init_trans = _coerce_np(init_trans)
|
|
170
|
-
init_emission = _coerce_np(init_emission)
|
|
171
153
|
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
)
|
|
154
|
+
def _ensure_layer_full_shape(adata, name: str, dtype, fill_value=0):
|
|
155
|
+
"""
|
|
156
|
+
Ensure adata.layers[name] exists with shape (n_obs, n_vars).
|
|
157
|
+
"""
|
|
158
|
+
if name not in adata.layers:
|
|
159
|
+
arr = np.full((adata.n_obs, adata.n_vars), fill_value=fill_value, dtype=dtype)
|
|
160
|
+
adata.layers[name] = arr
|
|
161
|
+
else:
|
|
162
|
+
arr = _to_dense_np(adata.layers[name])
|
|
163
|
+
if arr.shape != (adata.n_obs, adata.n_vars):
|
|
164
|
+
raise ValueError(
|
|
165
|
+
f"Layer '{name}' exists but has shape {arr.shape}; expected {(adata.n_obs, adata.n_vars)}"
|
|
166
|
+
)
|
|
167
|
+
return adata.layers[name]
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
def _assign_back_obs(final_adata, sub_adata, cols: List[str]):
|
|
171
|
+
"""
|
|
172
|
+
Assign obs columns from sub_adata back into final_adata for the matching obs_names.
|
|
173
|
+
Works for list/object columns too.
|
|
174
|
+
"""
|
|
175
|
+
idx = final_adata.obs_names.get_indexer(sub_adata.obs_names)
|
|
176
|
+
if (idx < 0).any():
|
|
177
|
+
raise ValueError("Some sub_adata.obs_names not found in final_adata.obs_names")
|
|
181
178
|
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
if isinstance(device, str):
|
|
185
|
-
device = torch.device(device)
|
|
186
|
-
model.to(device)
|
|
179
|
+
for c in cols:
|
|
180
|
+
final_adata.obs.iloc[idx, final_adata.obs.columns.get_loc(c)] = sub_adata.obs[c].values
|
|
187
181
|
|
|
188
|
-
# persist the config to the hmm class
|
|
189
|
-
cls.config = cfg
|
|
190
182
|
|
|
191
|
-
|
|
183
|
+
def _to_dense_np(x):
|
|
184
|
+
"""Convert sparse or array-like input to a dense NumPy array.
|
|
192
185
|
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
"""
|
|
196
|
-
Update existing model parameters from a config or dict (in-place).
|
|
197
|
-
This will normalize / reinitialize start/trans/emission using same logic as constructor.
|
|
198
|
-
"""
|
|
199
|
-
if cfg is None:
|
|
200
|
-
merged = {}
|
|
201
|
-
elif hasattr(cfg, "to_dict") and callable(getattr(cfg, "to_dict")):
|
|
202
|
-
merged = dict(cfg.to_dict())
|
|
203
|
-
elif isinstance(cfg, dict):
|
|
204
|
-
merged = dict(cfg)
|
|
205
|
-
else:
|
|
206
|
-
try:
|
|
207
|
-
merged = {k: getattr(cfg, k) for k in dir(cfg) if k.startswith("hmm_")}
|
|
208
|
-
except Exception:
|
|
209
|
-
merged = {}
|
|
186
|
+
Args:
|
|
187
|
+
x: Input array or sparse matrix.
|
|
210
188
|
|
|
211
|
-
|
|
212
|
-
|
|
189
|
+
Returns:
|
|
190
|
+
Dense NumPy array or None.
|
|
191
|
+
"""
|
|
192
|
+
if x is None:
|
|
193
|
+
return None
|
|
194
|
+
if issparse(x):
|
|
195
|
+
return x.toarray()
|
|
196
|
+
return np.asarray(x)
|
|
213
197
|
|
|
214
|
-
# extract same keys as from_config
|
|
215
|
-
n_states = int(merged.get("hmm_n_states", self.n_states))
|
|
216
|
-
init_start = merged.get("hmm_init_start_probs", None)
|
|
217
|
-
init_trans = merged.get("hmm_init_transition_probs", None)
|
|
218
|
-
init_emission = merged.get("hmm_init_emission_probs", None)
|
|
219
|
-
eps = merged.get("hmm_eps", None)
|
|
220
|
-
dtype = merged.get("hmm_dtype", None)
|
|
221
|
-
|
|
222
|
-
# apply dtype/eps if present
|
|
223
|
-
if eps is not None:
|
|
224
|
-
self.eps = float(eps)
|
|
225
|
-
if dtype is not None:
|
|
226
|
-
self.dtype = self._resolve_dtype(dtype)
|
|
227
|
-
|
|
228
|
-
# if n_states changes we need a fresh re-init (easy approach: reconstruct)
|
|
229
|
-
if int(n_states) != int(self.n_states):
|
|
230
|
-
# rebuild self in-place: create a new model and copy tensors
|
|
231
|
-
new_model = HMM.from_config(merged)
|
|
232
|
-
# copy content
|
|
233
|
-
with torch.no_grad():
|
|
234
|
-
self.n_states = new_model.n_states
|
|
235
|
-
self.eps = new_model.eps
|
|
236
|
-
self.dtype = new_model.dtype
|
|
237
|
-
self.start.data = new_model.start.data.clone().to(self.start.device, dtype=self.dtype)
|
|
238
|
-
self.trans.data = new_model.trans.data.clone().to(self.trans.device, dtype=self.dtype)
|
|
239
|
-
self.emission.data = new_model.emission.data.clone().to(self.emission.device, dtype=self.dtype)
|
|
240
|
-
return
|
|
241
198
|
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
if obj is None:
|
|
245
|
-
return None
|
|
246
|
-
arr = np.asarray(obj, dtype=float)
|
|
247
|
-
if shape_expected is not None:
|
|
248
|
-
try:
|
|
249
|
-
arr = arr.reshape(shape_expected)
|
|
250
|
-
except Exception:
|
|
251
|
-
# try to free-form slice/reshape (keep best-effort)
|
|
252
|
-
arr = np.reshape(arr, shape_expected) if arr.size >= np.prod(shape_expected) else arr
|
|
253
|
-
return torch.tensor(arr, dtype=self.dtype, device=self.start.device)
|
|
199
|
+
def _ensure_2d_np(x):
|
|
200
|
+
"""Ensure an array is 2D, reshaping 1D inputs.
|
|
254
201
|
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
t = _to_tensor(init_start, (self.n_states,))
|
|
258
|
-
if t.numel() == self.n_states:
|
|
259
|
-
self.start.data = t.clone()
|
|
260
|
-
if init_trans is not None:
|
|
261
|
-
t = _to_tensor(init_trans, (self.n_states, self.n_states))
|
|
262
|
-
if t.shape == (self.n_states, self.n_states):
|
|
263
|
-
self.trans.data = t.clone()
|
|
264
|
-
if init_emission is not None:
|
|
265
|
-
# attempt to extract P(obs==1) if shaped (K,2)
|
|
266
|
-
arr = np.asarray(init_emission, dtype=float)
|
|
267
|
-
if arr.ndim == 2 and arr.shape[1] == 2 and arr.shape[0] >= self.n_states:
|
|
268
|
-
em = arr[: self.n_states, 1]
|
|
269
|
-
else:
|
|
270
|
-
em = arr.reshape(-1)[: self.n_states]
|
|
271
|
-
t = torch.tensor(em, dtype=self.dtype, device=self.start.device)
|
|
272
|
-
if t.numel() == self.n_states:
|
|
273
|
-
self.emission.data = t.clone()
|
|
202
|
+
Args:
|
|
203
|
+
x: Input array-like.
|
|
274
204
|
|
|
275
|
-
|
|
276
|
-
|
|
205
|
+
Returns:
|
|
206
|
+
2D NumPy array.
|
|
207
|
+
"""
|
|
208
|
+
x = _to_dense_np(x)
|
|
209
|
+
if x.ndim == 1:
|
|
210
|
+
x = x.reshape(1, -1)
|
|
211
|
+
if x.ndim != 2:
|
|
212
|
+
raise ValueError(f"Expected 2D array; got shape {x.shape}")
|
|
213
|
+
return x
|
|
277
214
|
|
|
278
|
-
def _ensure_device_dtype(self, device: Optional[torch.device]):
|
|
279
|
-
if device is None:
|
|
280
|
-
device = next(self.parameters()).device
|
|
281
|
-
self.start.data = self.start.data.to(device=device, dtype=self.dtype)
|
|
282
|
-
self.trans.data = self.trans.data.to(device=device, dtype=self.dtype)
|
|
283
|
-
self.emission.data = self.emission.data.to(device=device, dtype=self.dtype)
|
|
284
|
-
return device
|
|
285
215
|
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
216
|
+
# =============================================================================
|
|
217
|
+
# Feature-set normalization
|
|
218
|
+
# =============================================================================
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
def normalize_hmm_feature_sets(raw: Any) -> Dict[str, Dict[str, Any]]:
|
|
222
|
+
"""
|
|
223
|
+
Canonical format:
|
|
224
|
+
{
|
|
225
|
+
"footprints": {"state": "Non-Modified", "features": {"small_bound_stretch": [0,50], ...}},
|
|
226
|
+
"accessible": {"state": "Modified", "features": {"all_accessible_features": [0, inf], ...}},
|
|
227
|
+
...
|
|
228
|
+
}
|
|
229
|
+
Each feature range is [lo, hi) in genomic bp (or index units if coords aren't ints).
|
|
230
|
+
"""
|
|
231
|
+
parsed = _try_json_or_literal(raw)
|
|
232
|
+
if not isinstance(parsed, dict):
|
|
233
|
+
return {}
|
|
234
|
+
|
|
235
|
+
def _coerce_bound(v):
|
|
236
|
+
"""Coerce a bound value into a float or sentinel.
|
|
237
|
+
|
|
238
|
+
Args:
|
|
239
|
+
v: Bound value.
|
|
240
|
+
|
|
241
|
+
Returns:
|
|
242
|
+
Float, np.inf, or None.
|
|
293
243
|
"""
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
244
|
+
if v is None:
|
|
245
|
+
return None
|
|
246
|
+
if isinstance(v, (int, float)):
|
|
247
|
+
return float(v)
|
|
248
|
+
s = str(v).strip().lower()
|
|
249
|
+
if s in ("inf", "infty", "np.inf", "infinite"):
|
|
250
|
+
return np.inf
|
|
251
|
+
if s in ("none", ""):
|
|
252
|
+
return None
|
|
253
|
+
try:
|
|
254
|
+
return float(v)
|
|
255
|
+
except Exception:
|
|
256
|
+
return None
|
|
257
|
+
|
|
258
|
+
def _coerce_map(feats):
|
|
259
|
+
"""Coerce feature ranges into (lo, hi) tuples.
|
|
260
|
+
|
|
261
|
+
Args:
|
|
262
|
+
feats: Mapping of feature names to ranges.
|
|
263
|
+
|
|
264
|
+
Returns:
|
|
265
|
+
Mapping of feature names to numeric bounds.
|
|
298
266
|
"""
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
warned_collapse = False
|
|
313
|
-
|
|
314
|
-
for i, seq in enumerate(data):
|
|
315
|
-
# seq may be list/ndarray of scalars OR list/ndarray of per-timestep arrays
|
|
316
|
-
arr = np.asarray(seq, dtype=float)
|
|
317
|
-
|
|
318
|
-
# If arr is shape (L,1,1,...) squeeze trailing singletons
|
|
319
|
-
while arr.ndim > 1 and arr.shape[-1] == 1:
|
|
320
|
-
arr = np.squeeze(arr, axis=-1)
|
|
321
|
-
|
|
322
|
-
# If arr is still >1D (e.g., (L, F)), collapse the last axis by mean
|
|
323
|
-
if arr.ndim > 1:
|
|
324
|
-
if not warned_collapse:
|
|
325
|
-
warnings.warn(
|
|
326
|
-
"HMM._pad_and_mask: collapsing per-timestep feature axis by mean "
|
|
327
|
-
"(arr had shape {}). If you prefer a different reduction, "
|
|
328
|
-
"preprocess your data.".format(arr.shape),
|
|
329
|
-
stacklevel=2,
|
|
330
|
-
)
|
|
331
|
-
warned_collapse = True
|
|
332
|
-
# collapse features -> scalar per timestep
|
|
333
|
-
arr = np.asarray(arr, dtype=float).mean(axis=-1)
|
|
334
|
-
|
|
335
|
-
# now arr should be 1D (T,)
|
|
336
|
-
if arr.ndim == 0:
|
|
337
|
-
# single scalar: treat as length-1 sequence
|
|
338
|
-
arr = np.atleast_1d(arr)
|
|
339
|
-
|
|
340
|
-
nan_mask = np.isnan(arr)
|
|
341
|
-
if impute_strategy == "random" and nan_mask.any():
|
|
342
|
-
arr[nan_mask] = np.random.choice([0, 1], size=nan_mask.sum())
|
|
343
|
-
local_mask = np.ones_like(arr, dtype=bool)
|
|
267
|
+
out = {}
|
|
268
|
+
if not isinstance(feats, dict):
|
|
269
|
+
return out
|
|
270
|
+
for name, rng in feats.items():
|
|
271
|
+
if rng is None:
|
|
272
|
+
out[name] = (0.0, np.inf)
|
|
273
|
+
continue
|
|
274
|
+
if isinstance(rng, (list, tuple)) and len(rng) >= 2:
|
|
275
|
+
lo = _coerce_bound(rng[0])
|
|
276
|
+
hi = _coerce_bound(rng[1])
|
|
277
|
+
lo = 0.0 if lo is None else float(lo)
|
|
278
|
+
hi = np.inf if hi is None else float(hi)
|
|
279
|
+
out[name] = (lo, hi)
|
|
344
280
|
else:
|
|
345
|
-
|
|
346
|
-
|
|
281
|
+
hi = _coerce_bound(rng)
|
|
282
|
+
hi = np.inf if hi is None else float(hi)
|
|
283
|
+
out[name] = (0.0, hi)
|
|
284
|
+
return out
|
|
285
|
+
|
|
286
|
+
out: Dict[str, Dict[str, Any]] = {}
|
|
287
|
+
for group, info in parsed.items():
|
|
288
|
+
if isinstance(info, dict):
|
|
289
|
+
feats = _coerce_map(info.get("features", info.get("ranges", {})))
|
|
290
|
+
state = info.get("state", info.get("label", "Modified"))
|
|
291
|
+
else:
|
|
292
|
+
feats = _coerce_map(info)
|
|
293
|
+
state = "Modified"
|
|
294
|
+
out[group] = {"features": feats, "state": state}
|
|
295
|
+
return out
|
|
347
296
|
|
|
348
|
-
L_i = arr.shape[0]
|
|
349
|
-
obs[i, :L_i] = torch.tensor(arr, dtype=dtype, device=device)
|
|
350
|
-
mask[i, :L_i] = torch.tensor(local_mask, dtype=torch.bool, device=device)
|
|
351
297
|
|
|
352
|
-
|
|
298
|
+
# =============================================================================
|
|
299
|
+
# BaseHMM: shared decoding + annotation pipeline
|
|
300
|
+
# =============================================================================
|
|
353
301
|
|
|
354
|
-
def _log_emission(self, obs: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
|
|
355
|
-
"""
|
|
356
|
-
obs: (B, L)
|
|
357
|
-
mask: (B, L) bool
|
|
358
|
-
returns logB (B, L, K)
|
|
359
|
-
"""
|
|
360
|
-
B, L = obs.shape
|
|
361
|
-
p = self.emission # (K,)
|
|
362
|
-
logp = torch.log(p + self.eps)
|
|
363
|
-
log1mp = torch.log1p(-p + self.eps)
|
|
364
|
-
obs_expand = obs.unsqueeze(-1) # (B, L, 1)
|
|
365
|
-
logB = obs_expand * logp.unsqueeze(0).unsqueeze(0) + (1.0 - obs_expand) * log1mp.unsqueeze(0).unsqueeze(0)
|
|
366
|
-
logB = torch.where(mask.unsqueeze(-1), logB, torch.zeros_like(logB))
|
|
367
|
-
return logB
|
|
368
302
|
|
|
369
|
-
|
|
370
|
-
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
"""
|
|
382
|
-
if device is None:
|
|
383
|
-
device = next(self.parameters()).device
|
|
384
|
-
elif isinstance(device, str):
|
|
385
|
-
device = torch.device(device)
|
|
386
|
-
device = self._ensure_device_dtype(device)
|
|
303
|
+
class BaseHMM(nn.Module):
|
|
304
|
+
"""
|
|
305
|
+
BaseHMM responsibilities:
|
|
306
|
+
- config resolution (from_config)
|
|
307
|
+
- EM fit wrapper (fit / fit_em)
|
|
308
|
+
- decoding (gamma / viterbi)
|
|
309
|
+
- AnnData annotation from provided arrays (X + coords)
|
|
310
|
+
- save/load registry aware
|
|
311
|
+
Subclasses implement:
|
|
312
|
+
- _log_emission(...) -> logB
|
|
313
|
+
- optional distance-aware transition handling
|
|
314
|
+
"""
|
|
387
315
|
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
# rows are sequences: convert to list of 1D arrays
|
|
391
|
-
data = data.tolist()
|
|
392
|
-
elif data.ndim == 1:
|
|
393
|
-
# single sequence
|
|
394
|
-
data = [data.tolist()]
|
|
395
|
-
else:
|
|
396
|
-
raise ValueError(f"Expected data to be 1D or 2D ndarray; got array with ndim={data.ndim}")
|
|
316
|
+
def __init__(self, n_states: int = 2, eps: float = 1e-8, dtype: torch.dtype = torch.float64):
|
|
317
|
+
"""Initialize the base HMM with shared parameters.
|
|
397
318
|
|
|
398
|
-
|
|
399
|
-
|
|
400
|
-
|
|
401
|
-
|
|
319
|
+
Args:
|
|
320
|
+
n_states: Number of hidden states.
|
|
321
|
+
eps: Smoothing epsilon for probabilities.
|
|
322
|
+
dtype: Torch dtype for parameters.
|
|
323
|
+
"""
|
|
324
|
+
super().__init__()
|
|
325
|
+
if n_states < 2:
|
|
326
|
+
raise ValueError("n_states must be >= 2")
|
|
327
|
+
self.n_states = int(n_states)
|
|
328
|
+
self.eps = float(eps)
|
|
329
|
+
self.dtype = dtype
|
|
402
330
|
|
|
403
|
-
|
|
404
|
-
|
|
331
|
+
# start probs + transitions (shared across backends)
|
|
332
|
+
start = np.full((self.n_states,), 1.0 / self.n_states, dtype=float)
|
|
333
|
+
trans = np.full((self.n_states, self.n_states), 1.0 / self.n_states, dtype=float)
|
|
405
334
|
|
|
406
|
-
|
|
335
|
+
self.start = nn.Parameter(torch.tensor(start, dtype=self.dtype), requires_grad=False)
|
|
336
|
+
self.trans = nn.Parameter(torch.tensor(trans, dtype=self.dtype), requires_grad=False)
|
|
337
|
+
self._normalize_params()
|
|
407
338
|
|
|
408
|
-
|
|
409
|
-
if verbose:
|
|
410
|
-
print(f"[HMM.fit] EM iter {it}")
|
|
339
|
+
# ------------------------- config -------------------------
|
|
411
340
|
|
|
412
|
-
|
|
413
|
-
|
|
341
|
+
@classmethod
|
|
342
|
+
def from_config(
|
|
343
|
+
cls, cfg: Union[dict, Any, None], *, override: Optional[dict] = None, device=None
|
|
344
|
+
):
|
|
345
|
+
"""Create a model from config with optional overrides.
|
|
414
346
|
|
|
415
|
-
|
|
416
|
-
|
|
417
|
-
|
|
347
|
+
Args:
|
|
348
|
+
cfg: Configuration mapping or object.
|
|
349
|
+
override: Override values to apply.
|
|
350
|
+
device: Device specifier.
|
|
418
351
|
|
|
419
|
-
|
|
420
|
-
|
|
421
|
-
|
|
422
|
-
|
|
423
|
-
|
|
424
|
-
|
|
425
|
-
alpha[:, t, :] = _logsumexp(prev, dim=1) + logB[:, t, :]
|
|
352
|
+
Returns:
|
|
353
|
+
Initialized HMM instance.
|
|
354
|
+
"""
|
|
355
|
+
merged = cls._cfg_to_dict(cfg)
|
|
356
|
+
if override:
|
|
357
|
+
merged.update(override)
|
|
426
358
|
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
|
|
430
|
-
|
|
431
|
-
# temp (B, i, j) = logA[i,j] + logB[:,t+1,j] + beta[:,t+1,j]
|
|
432
|
-
temp = logA.unsqueeze(0) + (logB[:, t + 1, :].unsqueeze(1) + beta[:, t + 1, :].unsqueeze(1))
|
|
433
|
-
beta[:, t, :] = _logsumexp(temp, dim=2)
|
|
359
|
+
n_states = int(merged.get("hmm_n_states", merged.get("n_states", 2)))
|
|
360
|
+
eps = float(merged.get("hmm_eps", merged.get("eps", 1e-8)))
|
|
361
|
+
dtype = _resolve_dtype(merged.get("hmm_dtype", merged.get("dtype", None)))
|
|
362
|
+
dtype = _coerce_dtype_for_device(dtype, device) # <<< NEW
|
|
434
363
|
|
|
435
|
-
|
|
436
|
-
|
|
437
|
-
|
|
438
|
-
|
|
439
|
-
|
|
440
|
-
total_loglike = float(seq_loglikes.sum().item())
|
|
364
|
+
model = cls(n_states=n_states, eps=eps, dtype=dtype)
|
|
365
|
+
if device is not None:
|
|
366
|
+
model.to(torch.device(device) if isinstance(device, str) else device)
|
|
367
|
+
model._persisted_cfg = merged
|
|
368
|
+
return model
|
|
441
369
|
|
|
442
|
-
|
|
443
|
-
|
|
444
|
-
|
|
445
|
-
gamma = (log_gamma - logZ_time).exp() # (B, L, K)
|
|
370
|
+
@staticmethod
|
|
371
|
+
def _cfg_to_dict(cfg: Union[dict, Any, None]) -> dict:
|
|
372
|
+
"""Normalize a config object into a dictionary.
|
|
446
373
|
|
|
447
|
-
|
|
448
|
-
|
|
374
|
+
Args:
|
|
375
|
+
cfg: Config mapping or object.
|
|
449
376
|
|
|
450
|
-
|
|
451
|
-
|
|
452
|
-
|
|
453
|
-
|
|
377
|
+
Returns:
|
|
378
|
+
Dictionary of HMM-related config values.
|
|
379
|
+
"""
|
|
380
|
+
if cfg is None:
|
|
381
|
+
return {}
|
|
382
|
+
if isinstance(cfg, dict):
|
|
383
|
+
return dict(cfg)
|
|
384
|
+
if hasattr(cfg, "to_dict") and callable(getattr(cfg, "to_dict")):
|
|
385
|
+
return dict(cfg.to_dict())
|
|
386
|
+
out = {}
|
|
387
|
+
for k in dir(cfg):
|
|
388
|
+
if k.startswith("hmm_") or k in ("smf_modality", "cpg"):
|
|
389
|
+
try:
|
|
390
|
+
out[k] = getattr(cfg, k)
|
|
391
|
+
except Exception:
|
|
392
|
+
pass
|
|
393
|
+
return out
|
|
454
394
|
|
|
455
|
-
|
|
456
|
-
trans_accum = torch.zeros((K, K), dtype=self.dtype, device=device)
|
|
457
|
-
if L >= 2:
|
|
458
|
-
time_idx = torch.arange(L - 1, device=device).unsqueeze(0).expand(B, L - 1) # (B, L-1)
|
|
459
|
-
valid = time_idx < (lengths.unsqueeze(1) - 1) # (B, L-1) bool
|
|
460
|
-
for t in range(L - 1):
|
|
461
|
-
a_t = alpha[:, t, :].unsqueeze(2) # (B, i, 1)
|
|
462
|
-
b_next = (logB[:, t + 1, :] + beta[:, t + 1, :]).unsqueeze(1) # (B, 1, j)
|
|
463
|
-
log_xi_unnorm = a_t + logA.unsqueeze(0) + b_next # (B, i, j)
|
|
464
|
-
log_xi_flat = log_xi_unnorm.view(B, -1) # (B, i*j)
|
|
465
|
-
log_norm = _logsumexp(log_xi_flat, dim=1).unsqueeze(1).unsqueeze(2) # (B,1,1)
|
|
466
|
-
xi = (log_xi_unnorm - log_norm).exp() # (B,i,j)
|
|
467
|
-
valid_t = valid[:, t].float().unsqueeze(1).unsqueeze(2) # (B,1,1)
|
|
468
|
-
xi_masked = xi * valid_t
|
|
469
|
-
trans_accum += xi_masked.sum(dim=0) # (i,j)
|
|
470
|
-
|
|
471
|
-
# M-step: update parameters with smoothing
|
|
472
|
-
with torch.no_grad():
|
|
473
|
-
new_start = gamma_start_accum + eps
|
|
474
|
-
new_start = new_start / new_start.sum()
|
|
395
|
+
# ------------------------- params -------------------------
|
|
475
396
|
|
|
476
|
-
|
|
477
|
-
|
|
478
|
-
|
|
479
|
-
|
|
397
|
+
def _normalize_params(self):
|
|
398
|
+
"""Normalize start and transition probabilities in-place."""
|
|
399
|
+
with torch.no_grad():
|
|
400
|
+
K = self.n_states
|
|
480
401
|
|
|
481
|
-
|
|
482
|
-
|
|
402
|
+
# start
|
|
403
|
+
self.start.data = self.start.data.reshape(-1)
|
|
404
|
+
if self.start.data.numel() != K:
|
|
405
|
+
self.start.data = torch.full(
|
|
406
|
+
(K,), 1.0 / K, dtype=self.dtype, device=self.start.device
|
|
407
|
+
)
|
|
408
|
+
self.start.data = self.start.data + self.eps
|
|
409
|
+
self.start.data = self.start.data / self.start.data.sum()
|
|
483
410
|
|
|
484
|
-
|
|
485
|
-
|
|
486
|
-
|
|
411
|
+
# trans
|
|
412
|
+
self.trans.data = self.trans.data.reshape(K, K)
|
|
413
|
+
self.trans.data = self.trans.data + self.eps
|
|
414
|
+
rs = self.trans.data.sum(dim=1, keepdim=True)
|
|
415
|
+
rs[rs == 0.0] = 1.0
|
|
416
|
+
self.trans.data = self.trans.data / rs
|
|
487
417
|
|
|
488
|
-
|
|
489
|
-
|
|
490
|
-
|
|
418
|
+
def _ensure_device_dtype(
|
|
419
|
+
self, device: Optional[Union[str, torch.device]] = None
|
|
420
|
+
) -> torch.device:
|
|
421
|
+
"""Move parameters to the requested device/dtype.
|
|
491
422
|
|
|
492
|
-
|
|
493
|
-
|
|
494
|
-
print(f"[HMM.fit] converged (Δll < {tol}) at iter {it}")
|
|
495
|
-
break
|
|
423
|
+
Args:
|
|
424
|
+
device: Device specifier or None to use current device.
|
|
496
425
|
|
|
497
|
-
|
|
498
|
-
|
|
499
|
-
def get_params(self) -> dict:
|
|
426
|
+
Returns:
|
|
427
|
+
Resolved torch device.
|
|
500
428
|
"""
|
|
501
|
-
|
|
502
|
-
|
|
503
|
-
|
|
504
|
-
|
|
505
|
-
|
|
506
|
-
|
|
507
|
-
|
|
508
|
-
|
|
509
|
-
|
|
429
|
+
if device is None:
|
|
430
|
+
device = next(self.parameters()).device
|
|
431
|
+
device = torch.device(device) if isinstance(device, str) else device
|
|
432
|
+
self.start.data = self.start.data.to(device=device, dtype=self.dtype)
|
|
433
|
+
self.trans.data = self.trans.data.to(device=device, dtype=self.dtype)
|
|
434
|
+
return device
|
|
435
|
+
|
|
436
|
+
# ------------------------- state labeling -------------------------
|
|
437
|
+
|
|
438
|
+
def _state_modified_score(self) -> torch.Tensor:
|
|
439
|
+
"""Subclasses return (K,) score; higher => more “Modified/Accessible”."""
|
|
440
|
+
raise NotImplementedError
|
|
510
441
|
|
|
511
|
-
def
|
|
442
|
+
def modified_state_index(self) -> int:
|
|
443
|
+
"""Return the index of the most modified/accessible state."""
|
|
444
|
+
scores = self._state_modified_score()
|
|
445
|
+
return int(torch.argmax(scores).item())
|
|
446
|
+
|
|
447
|
+
def resolve_target_state_index(self, state_target: Any) -> int:
|
|
512
448
|
"""
|
|
513
|
-
|
|
449
|
+
Accept:
|
|
450
|
+
- int -> explicit state index
|
|
451
|
+
- "Modified" / "Non-Modified" and aliases
|
|
514
452
|
"""
|
|
515
|
-
|
|
516
|
-
|
|
517
|
-
|
|
518
|
-
|
|
519
|
-
|
|
520
|
-
|
|
521
|
-
|
|
522
|
-
|
|
523
|
-
|
|
524
|
-
|
|
525
|
-
|
|
526
|
-
|
|
527
|
-
|
|
528
|
-
|
|
453
|
+
if isinstance(state_target, (int, np.integer)):
|
|
454
|
+
idx = int(state_target)
|
|
455
|
+
return max(0, min(idx, self.n_states - 1))
|
|
456
|
+
|
|
457
|
+
s = str(state_target).strip().lower()
|
|
458
|
+
if s in ("modified", "open", "accessible", "1", "pos", "positive"):
|
|
459
|
+
return self.modified_state_index()
|
|
460
|
+
if s in ("non-modified", "closed", "inaccessible", "0", "neg", "negative"):
|
|
461
|
+
scores = self._state_modified_score()
|
|
462
|
+
return int(torch.argmin(scores).item())
|
|
463
|
+
return self.modified_state_index()
|
|
464
|
+
|
|
465
|
+
# ------------------------- emissions -------------------------
|
|
466
|
+
|
|
467
|
+
def _log_emission(self, obs: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
|
|
529
468
|
"""
|
|
530
|
-
Return
|
|
469
|
+
Return logB:
|
|
470
|
+
- single: obs (N,L), mask (N,L) -> logB (N,L,K)
|
|
471
|
+
- multi : obs (N,L,C), mask (N,L,C) -> logB (N,L,K)
|
|
531
472
|
"""
|
|
532
|
-
|
|
533
|
-
|
|
534
|
-
|
|
535
|
-
|
|
536
|
-
|
|
537
|
-
|
|
538
|
-
|
|
539
|
-
|
|
540
|
-
|
|
473
|
+
raise NotImplementedError
|
|
474
|
+
|
|
475
|
+
# ------------------------- decoding core -------------------------
|
|
476
|
+
|
|
477
|
+
def _forward_backward(
|
|
478
|
+
self,
|
|
479
|
+
obs: torch.Tensor,
|
|
480
|
+
mask: torch.Tensor,
|
|
481
|
+
*,
|
|
482
|
+
coords: Optional[np.ndarray] = None,
|
|
483
|
+
) -> torch.Tensor:
|
|
541
484
|
"""
|
|
542
|
-
|
|
485
|
+
Returns gamma (N,L,K) in probability space.
|
|
486
|
+
Subclasses can override for distance-aware transitions.
|
|
543
487
|
"""
|
|
544
|
-
|
|
545
|
-
device = next(self.parameters()).device
|
|
546
|
-
elif isinstance(device, str):
|
|
547
|
-
device = torch.device(device)
|
|
548
|
-
device = self._ensure_device_dtype(device)
|
|
549
|
-
|
|
550
|
-
obs, mask, lengths = self._pad_and_mask(data, device=device, dtype=self.dtype, impute_strategy=impute_strategy)
|
|
551
|
-
B, L = obs.shape
|
|
552
|
-
K = self.n_states
|
|
488
|
+
device = obs.device
|
|
553
489
|
eps = float(self.eps)
|
|
490
|
+
K = self.n_states
|
|
554
491
|
|
|
555
|
-
logB = self._log_emission(obs, mask) # (
|
|
556
|
-
logA = torch.log(self.trans + eps)
|
|
557
|
-
logstart = torch.log(self.start + eps)
|
|
492
|
+
logB = self._log_emission(obs, mask) # (N,L,K)
|
|
493
|
+
logA = torch.log(self.trans + eps) # (K,K)
|
|
494
|
+
logstart = torch.log(self.start + eps) # (K,)
|
|
495
|
+
|
|
496
|
+
N, L, _ = logB.shape
|
|
558
497
|
|
|
559
|
-
|
|
560
|
-
alpha = torch.empty((B, L, K), dtype=self.dtype, device=device)
|
|
498
|
+
alpha = torch.empty((N, L, K), dtype=self.dtype, device=device)
|
|
561
499
|
alpha[:, 0, :] = logstart.unsqueeze(0) + logB[:, 0, :]
|
|
500
|
+
|
|
562
501
|
for t in range(1, L):
|
|
563
|
-
prev = alpha[:, t - 1, :].unsqueeze(2) + logA.unsqueeze(0)
|
|
502
|
+
prev = alpha[:, t - 1, :].unsqueeze(2) + logA.unsqueeze(0) # (N,K,K)
|
|
564
503
|
alpha[:, t, :] = _logsumexp(prev, dim=1) + logB[:, t, :]
|
|
565
504
|
|
|
566
|
-
|
|
567
|
-
beta
|
|
568
|
-
beta[:, L - 1, :] = torch.zeros((K,), dtype=self.dtype, device=device).unsqueeze(0).expand(B, K)
|
|
505
|
+
beta = torch.empty((N, L, K), dtype=self.dtype, device=device)
|
|
506
|
+
beta[:, L - 1, :] = 0.0
|
|
569
507
|
for t in range(L - 2, -1, -1):
|
|
570
|
-
temp = logA.unsqueeze(0) + (logB[:, t + 1, :]
|
|
508
|
+
temp = logA.unsqueeze(0) + (logB[:, t + 1, :] + beta[:, t + 1, :]).unsqueeze(1)
|
|
571
509
|
beta[:, t, :] = _logsumexp(temp, dim=2)
|
|
572
510
|
|
|
573
|
-
# gamma
|
|
574
511
|
log_gamma = alpha + beta
|
|
575
|
-
|
|
576
|
-
gamma = (log_gamma -
|
|
512
|
+
logZ = _logsumexp(log_gamma, dim=2).unsqueeze(2)
|
|
513
|
+
gamma = (log_gamma - logZ).exp()
|
|
514
|
+
return gamma
|
|
577
515
|
|
|
578
|
-
|
|
579
|
-
|
|
580
|
-
|
|
581
|
-
|
|
582
|
-
|
|
583
|
-
|
|
584
|
-
|
|
516
|
+
def _viterbi(
|
|
517
|
+
self,
|
|
518
|
+
obs: torch.Tensor,
|
|
519
|
+
mask: torch.Tensor,
|
|
520
|
+
*,
|
|
521
|
+
coords: Optional[np.ndarray] = None,
|
|
522
|
+
) -> torch.Tensor:
|
|
585
523
|
"""
|
|
586
|
-
|
|
587
|
-
|
|
524
|
+
Returns states (N,L) int64. Missing positions (mask False for all channels)
|
|
525
|
+
are still decoded, but you’ll overwrite them to -1 during writing.
|
|
526
|
+
Subclasses can override for distance-aware transitions.
|
|
588
527
|
"""
|
|
589
|
-
|
|
590
|
-
|
|
591
|
-
|
|
592
|
-
single = True
|
|
593
|
-
else:
|
|
594
|
-
seqs = seq_or_list
|
|
528
|
+
device = obs.device
|
|
529
|
+
eps = float(self.eps)
|
|
530
|
+
K = self.n_states
|
|
595
531
|
|
|
596
|
-
|
|
597
|
-
|
|
598
|
-
|
|
599
|
-
device = torch.device(device)
|
|
600
|
-
device = self._ensure_device_dtype(device)
|
|
532
|
+
logB = self._log_emission(obs, mask) # (N,L,K)
|
|
533
|
+
logA = torch.log(self.trans + eps) # (K,K)
|
|
534
|
+
logstart = torch.log(self.start + eps) # (K,)
|
|
601
535
|
|
|
602
|
-
|
|
603
|
-
|
|
604
|
-
K =
|
|
605
|
-
eps = float(self.eps)
|
|
536
|
+
N, L, _ = logB.shape
|
|
537
|
+
delta = torch.empty((N, L, K), dtype=self.dtype, device=device)
|
|
538
|
+
psi = torch.empty((N, L, K), dtype=torch.long, device=device)
|
|
606
539
|
|
|
607
|
-
|
|
608
|
-
|
|
609
|
-
logstart = torch.log(self.start + eps)
|
|
540
|
+
delta[:, 0, :] = logstart.unsqueeze(0) + logB[:, 0, :]
|
|
541
|
+
psi[:, 0, :] = -1
|
|
610
542
|
|
|
611
|
-
alpha = torch.empty((B, L, K), dtype=self.dtype, device=device)
|
|
612
|
-
alpha[:, 0, :] = logstart.unsqueeze(0) + logB[:, 0, :]
|
|
613
543
|
for t in range(1, L):
|
|
614
|
-
|
|
615
|
-
|
|
544
|
+
cand = delta[:, t - 1, :].unsqueeze(2) + logA.unsqueeze(0) # (N,K,K)
|
|
545
|
+
best_val, best_idx = cand.max(dim=1)
|
|
546
|
+
delta[:, t, :] = best_val + logB[:, t, :]
|
|
547
|
+
psi[:, t, :] = best_idx
|
|
616
548
|
|
|
617
|
-
|
|
618
|
-
|
|
619
|
-
final_alpha = alpha[idx_range, last_idx, :] # (B, K)
|
|
620
|
-
seq_loglikes = _logsumexp(final_alpha, dim=1) # (B,)
|
|
621
|
-
seq_loglikes = seq_loglikes.detach().cpu().numpy().tolist()
|
|
622
|
-
return seq_loglikes[0] if single else seq_loglikes
|
|
549
|
+
last_state = torch.argmax(delta[:, L - 1, :], dim=1) # (N,)
|
|
550
|
+
states = torch.empty((N, L), dtype=torch.long, device=device)
|
|
623
551
|
|
|
624
|
-
|
|
625
|
-
|
|
626
|
-
|
|
627
|
-
|
|
628
|
-
|
|
629
|
-
|
|
552
|
+
states[:, L - 1] = last_state
|
|
553
|
+
for t in range(L - 2, -1, -1):
|
|
554
|
+
states[:, t] = psi[torch.arange(N, device=device), t + 1, states[:, t + 1]]
|
|
555
|
+
|
|
556
|
+
return states
|
|
557
|
+
|
|
558
|
+
def decode(
|
|
559
|
+
self,
|
|
560
|
+
X: np.ndarray,
|
|
561
|
+
coords: Optional[np.ndarray] = None,
|
|
562
|
+
*,
|
|
563
|
+
decode: str = "marginal",
|
|
564
|
+
device: Optional[Union[str, torch.device]] = None,
|
|
565
|
+
) -> Tuple[np.ndarray, np.ndarray]:
|
|
566
|
+
"""Decode observations into state calls and posterior probabilities.
|
|
630
567
|
|
|
631
|
-
|
|
568
|
+
Args:
|
|
569
|
+
X: Observations array (N, L) or (N, L, C).
|
|
570
|
+
coords: Optional coordinates aligned to L.
|
|
571
|
+
decode: Decoding strategy ("marginal" or "viterbi").
|
|
572
|
+
device: Device specifier.
|
|
573
|
+
|
|
574
|
+
Returns:
|
|
575
|
+
Tuple of (states, posterior probabilities).
|
|
632
576
|
"""
|
|
633
|
-
|
|
634
|
-
|
|
577
|
+
device = self._ensure_device_dtype(device)
|
|
578
|
+
|
|
579
|
+
X = np.asarray(X, dtype=float)
|
|
580
|
+
if X.ndim == 2:
|
|
581
|
+
L = X.shape[1]
|
|
582
|
+
elif X.ndim == 3:
|
|
583
|
+
L = X.shape[1]
|
|
584
|
+
else:
|
|
585
|
+
raise ValueError(f"X must be 2D or 3D; got shape {X.shape}")
|
|
586
|
+
|
|
587
|
+
if coords is None:
|
|
588
|
+
coords = np.arange(L, dtype=int)
|
|
589
|
+
coords = np.asarray(coords, dtype=int)
|
|
590
|
+
|
|
591
|
+
if X.ndim == 2:
|
|
592
|
+
obs = torch.tensor(np.nan_to_num(X, nan=0.0), dtype=self.dtype, device=device)
|
|
593
|
+
mask = torch.tensor(~np.isnan(X), dtype=torch.bool, device=device)
|
|
594
|
+
else:
|
|
595
|
+
obs = torch.tensor(np.nan_to_num(X, nan=0.0), dtype=self.dtype, device=device)
|
|
596
|
+
mask = torch.tensor(~np.isnan(X), dtype=torch.bool, device=device)
|
|
597
|
+
|
|
598
|
+
gamma = self._forward_backward(obs, mask, coords=coords)
|
|
599
|
+
|
|
600
|
+
if str(decode).lower() == "viterbi":
|
|
601
|
+
st = self._viterbi(obs, mask, coords=coords)
|
|
602
|
+
else:
|
|
603
|
+
st = torch.argmax(gamma, dim=2)
|
|
604
|
+
|
|
605
|
+
return st.detach().cpu().numpy(), gamma.detach().cpu().numpy()
|
|
606
|
+
|
|
607
|
+
# ------------------------- EM fit -------------------------
|
|
608
|
+
|
|
609
|
+
def fit(
|
|
610
|
+
self,
|
|
611
|
+
X: np.ndarray,
|
|
612
|
+
coords: Optional[np.ndarray] = None,
|
|
613
|
+
*,
|
|
614
|
+
max_iter: int = 50,
|
|
615
|
+
tol: float = 1e-4,
|
|
616
|
+
device: Optional[Union[str, torch.device]] = None,
|
|
617
|
+
update_start: bool = True,
|
|
618
|
+
update_trans: bool = True,
|
|
619
|
+
update_emission: bool = True,
|
|
620
|
+
verbose: bool = False,
|
|
621
|
+
**kwargs,
|
|
622
|
+
) -> List[float]:
|
|
623
|
+
"""Fit HMM parameters using EM.
|
|
624
|
+
|
|
625
|
+
Args:
|
|
626
|
+
X: Observations array.
|
|
627
|
+
coords: Optional coordinate array.
|
|
628
|
+
max_iter: Maximum EM iterations.
|
|
629
|
+
tol: Convergence tolerance.
|
|
630
|
+
device: Device specifier.
|
|
631
|
+
update_start: Whether to update start probabilities.
|
|
632
|
+
update_trans: Whether to update transition probabilities.
|
|
633
|
+
update_emission: Whether to update emission parameters.
|
|
634
|
+
verbose: Whether to log progress.
|
|
635
|
+
**kwargs: Additional implementation-specific kwargs.
|
|
636
|
+
|
|
637
|
+
Returns:
|
|
638
|
+
List of log-likelihood values across iterations.
|
|
635
639
|
"""
|
|
636
|
-
|
|
637
|
-
|
|
638
|
-
|
|
639
|
-
|
|
640
|
+
X = np.asarray(X, dtype=float)
|
|
641
|
+
if X.ndim not in (2, 3):
|
|
642
|
+
raise ValueError(f"X must be 2D or 3D; got {X.shape}")
|
|
643
|
+
L = X.shape[1]
|
|
644
|
+
|
|
645
|
+
if coords is None:
|
|
646
|
+
coords = np.arange(L, dtype=int)
|
|
647
|
+
coords = np.asarray(coords, dtype=int)
|
|
648
|
+
|
|
640
649
|
device = self._ensure_device_dtype(device)
|
|
650
|
+
return self.fit_em(
|
|
651
|
+
X,
|
|
652
|
+
coords,
|
|
653
|
+
device=device,
|
|
654
|
+
max_iter=max_iter,
|
|
655
|
+
tol=tol,
|
|
656
|
+
update_start=update_start,
|
|
657
|
+
update_trans=update_trans,
|
|
658
|
+
update_emission=update_emission,
|
|
659
|
+
verbose=verbose,
|
|
660
|
+
**kwargs,
|
|
661
|
+
)
|
|
641
662
|
|
|
642
|
-
|
|
643
|
-
|
|
644
|
-
|
|
645
|
-
|
|
663
|
+
def adapt_emissions(
|
|
664
|
+
self,
|
|
665
|
+
X: np.ndarray,
|
|
666
|
+
coords: np.ndarray,
|
|
667
|
+
*,
|
|
668
|
+
iters: int = 5,
|
|
669
|
+
device: Optional[Union[str, torch.device]] = None,
|
|
670
|
+
freeze_start: bool = True,
|
|
671
|
+
freeze_trans: bool = True,
|
|
672
|
+
verbose: bool = False,
|
|
673
|
+
**kwargs,
|
|
674
|
+
) -> List[float]:
|
|
675
|
+
"""Adapt emission parameters while keeping shared structure fixed.
|
|
676
|
+
|
|
677
|
+
Args:
|
|
678
|
+
X: Observations array.
|
|
679
|
+
coords: Coordinate array aligned to X.
|
|
680
|
+
iters: Number of EM iterations.
|
|
681
|
+
device: Device specifier.
|
|
682
|
+
freeze_start: Whether to freeze start probabilities.
|
|
683
|
+
freeze_trans: Whether to freeze transitions.
|
|
684
|
+
verbose: Whether to log progress.
|
|
685
|
+
**kwargs: Additional implementation-specific kwargs.
|
|
686
|
+
|
|
687
|
+
Returns:
|
|
688
|
+
List of log-likelihood values across iterations.
|
|
689
|
+
"""
|
|
690
|
+
return self.fit(
|
|
691
|
+
X,
|
|
692
|
+
coords,
|
|
693
|
+
max_iter=int(iters),
|
|
694
|
+
tol=0.0,
|
|
695
|
+
device=device,
|
|
696
|
+
update_start=not freeze_start,
|
|
697
|
+
update_trans=not freeze_trans,
|
|
698
|
+
update_emission=True,
|
|
699
|
+
verbose=verbose,
|
|
700
|
+
**kwargs,
|
|
701
|
+
)
|
|
646
702
|
|
|
647
|
-
|
|
648
|
-
|
|
649
|
-
|
|
650
|
-
|
|
651
|
-
|
|
703
|
+
def fit_em(
|
|
704
|
+
self,
|
|
705
|
+
X: np.ndarray,
|
|
706
|
+
coords: np.ndarray,
|
|
707
|
+
*,
|
|
708
|
+
device: torch.device,
|
|
709
|
+
max_iter: int,
|
|
710
|
+
tol: float,
|
|
711
|
+
update_start: bool,
|
|
712
|
+
update_trans: bool,
|
|
713
|
+
update_emission: bool,
|
|
714
|
+
verbose: bool,
|
|
715
|
+
**kwargs,
|
|
716
|
+
) -> List[float]:
|
|
717
|
+
"""Run the core EM update loop (subclasses implement).
|
|
718
|
+
|
|
719
|
+
Args:
|
|
720
|
+
X: Observations array.
|
|
721
|
+
coords: Coordinate array aligned to X.
|
|
722
|
+
device: Torch device.
|
|
723
|
+
max_iter: Maximum iterations.
|
|
724
|
+
tol: Convergence tolerance.
|
|
725
|
+
update_start: Whether to update start probabilities.
|
|
726
|
+
update_trans: Whether to update transitions.
|
|
727
|
+
update_emission: Whether to update emission parameters.
|
|
728
|
+
verbose: Whether to log progress.
|
|
729
|
+
**kwargs: Additional subclass-specific kwargs.
|
|
730
|
+
|
|
731
|
+
Returns:
|
|
732
|
+
List of log-likelihood values across iterations.
|
|
733
|
+
"""
|
|
734
|
+
raise NotImplementedError
|
|
652
735
|
|
|
653
|
-
|
|
654
|
-
logA = torch.log(self.trans + eps)
|
|
736
|
+
# ------------------------- save/load -------------------------
|
|
655
737
|
|
|
656
|
-
|
|
657
|
-
|
|
658
|
-
|
|
738
|
+
def _extra_save_payload(self) -> dict:
|
|
739
|
+
"""Return extra model state to include when saving."""
|
|
740
|
+
return {}
|
|
659
741
|
|
|
660
|
-
|
|
661
|
-
|
|
742
|
+
def _load_extra_payload(self, payload: dict, *, device: torch.device):
|
|
743
|
+
"""Load extra model state saved by subclasses.
|
|
662
744
|
|
|
663
|
-
|
|
664
|
-
|
|
665
|
-
|
|
666
|
-
best_val, best_idx = cand.max(dim=1) # best over previous i: results (B, j)
|
|
667
|
-
delta[:, t, :] = best_val + logB[:, t, :]
|
|
668
|
-
psi[:, t, :] = best_idx # best previous state index for each (B, j)
|
|
669
|
-
|
|
670
|
-
# backtrack
|
|
671
|
-
last_idx = (lengths - 1).clamp(min=0)
|
|
672
|
-
idx_range = torch.arange(B, device=device)
|
|
673
|
-
final_delta = delta[idx_range, last_idx, :] # (B, K)
|
|
674
|
-
best_last_val, best_last_state = final_delta.max(dim=1) # (B,), (B,)
|
|
675
|
-
paths = []
|
|
676
|
-
scores = []
|
|
677
|
-
for b in range(B):
|
|
678
|
-
Lb = int(lengths[b].item())
|
|
679
|
-
if Lb == 0:
|
|
680
|
-
paths.append([])
|
|
681
|
-
scores.append(float("-inf"))
|
|
682
|
-
continue
|
|
683
|
-
s = int(best_last_state[b].item())
|
|
684
|
-
path = [s]
|
|
685
|
-
for t in range(Lb - 1, 0, -1):
|
|
686
|
-
s = int(psi[b, t, s].item())
|
|
687
|
-
path.append(s)
|
|
688
|
-
path.reverse()
|
|
689
|
-
paths.append(path)
|
|
690
|
-
scores.append(float(best_last_val[b].item()))
|
|
691
|
-
return paths, scores
|
|
692
|
-
|
|
693
|
-
def save(self, path: str) -> None:
|
|
+        Args:
+            payload: Serialized model payload.
+            device: Torch device for tensors.
         """
+        return
+
+    def save(self, path: Union[str, Path]) -> None:
+        """Serialize the model to disk.
+
+        Args:
+            path: Output path for the serialized model.
         """
+        path = str(path)
         payload = {
+            "hmm_name": getattr(self, "hmm_name", self.__class__.__name__),
+            "class": self.__class__.__name__,
             "n_states": int(self.n_states),
             "eps": float(self.eps),
-            # store dtype as a string like "torch.float64" (portable)
             "dtype": str(self.dtype),
             "start": self.start.detach().cpu(),
             "trans": self.trans.detach().cpu(),
-            "emission": self.emission.detach().cpu(),
         }
+        payload.update(self._extra_save_payload())
         torch.save(payload, path)

     @classmethod
-    def load(cls, path: str, device: Optional[Union[torch.device
-        """
+    def load(cls, path: Union[str, Path], device: Optional[Union[str, torch.device]] = None):
+        """Load a serialized model from disk.
+
+        Args:
+            path: Path to the serialized model.
+            device: Optional device specifier.
+
+        Returns:
+            Loaded HMM instance.
         """
-        payload = torch.load(path, map_location="cpu")
+        payload = torch.load(str(path), map_location="cpu")
+        hmm_name = payload.get("hmm_name", None)
+        klass = _HMM_REGISTRY.get(hmm_name, cls)
+
+        dtype_str = str(payload.get("dtype", "torch.float64"))
+        torch_dtype = getattr(torch, dtype_str.split(".")[-1], torch.float64)
+        torch_dtype = _coerce_dtype_for_device(torch_dtype, device)  # <<< NEW
+
+        model = klass(
+            n_states=int(payload["n_states"]),
+            eps=float(payload.get("eps", 1e-8)),
+            dtype=torch_dtype,
+        )
+        dev = torch.device(device) if isinstance(device, str) else (device or torch.device("cpu"))
+        model.to(dev)

+        with torch.no_grad():
+            model.start.data = payload["start"].to(device=dev, dtype=model.dtype)
+            model.trans.data = payload["trans"].to(device=dev, dtype=model.dtype)

-            torch_dtype = dtype_entry
-        else:
-            # dtype_entry expected to be a string
-            dtype_str = str(dtype_entry)
-            # take last part after dot if present: "torch.float64" -> "float64"
-            name = dtype_str.split(".")[-1]
-            # map to torch dtype if available, else fallback mapping
-            if hasattr(torch, name):
-                torch_dtype = getattr(torch, name)
-            else:
-                fallback = {"float64": torch.float64, "float32": torch.float32, "float16": torch.float16}
-                torch_dtype = fallback.get(name, torch.float64)
+        model._load_extra_payload(payload, device=dev)
+        model._normalize_params()
+        return model

-        model = cls(n_states=n_states, dtype=torch_dtype, eps=eps)
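For orientation, here is a minimal standalone sketch of the payload round-trip that the save/load pair above implements, using plain torch only; `_HMM_REGISTRY` and `_coerce_dtype_for_device` are internal helpers of this module and are not exercised here.

# Hypothetical standalone example of the torch.save / torch.load payload pattern;
# it is not the smftools API itself.
import torch

payload = {
    "hmm_name": "single",
    "n_states": 2,
    "eps": 1e-8,
    "dtype": "torch.float64",
    "start": torch.tensor([0.5, 0.5], dtype=torch.float64),
    "trans": torch.tensor([[0.9, 0.1], [0.1, 0.9]], dtype=torch.float64),
}
torch.save(payload, "hmm_demo.pt")

restored = torch.load("hmm_demo.pt", map_location="cpu")
# Recover the dtype from its string form, as the loader above does.
dtype = getattr(torch, str(restored["dtype"]).split(".")[-1], torch.float64)
assert restored["trans"].shape == (restored["n_states"], restored["n_states"])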
+    # ------------------------- interval helpers -------------------------

+    @staticmethod
+    def _runs_from_bool(mask_1d: np.ndarray) -> List[Tuple[int, int]]:
+        """
+        Return runs as (start_idx, end_idx_exclusive) for True segments.
+        """
+        idx = np.nonzero(mask_1d)[0]
+        if idx.size == 0:
+            return []
+        breaks = np.where(np.diff(idx) > 1)[0]
+        starts = np.r_[idx[0], idx[breaks + 1]]
+        ends = np.r_[idx[breaks] + 1, idx[-1] + 1]
+        return list(zip(starts, ends))

+    @staticmethod
+    def _interval_length(coords: np.ndarray, s: int, e: int) -> int:
+        """Genomic length for [s,e) on coords."""
+        if e <= s:
+            return 0
+        return int(coords[e - 1]) - int(coords[s]) + 1

+    @staticmethod
+    def _write_lengths_for_binary_layer(bin_mat: np.ndarray) -> np.ndarray:
+        """
+        For each row, each True-run gets its run-length assigned across that run.
+        Output same shape as bin_mat, int32.
+        """
+        n, L = bin_mat.shape
+        out = np.zeros((n, L), dtype=np.int32)
+        for i in range(n):
+            runs = BaseHMM._runs_from_bool(bin_mat[i].astype(bool))
+            for s, e in runs:
+                out[i, s:e] = e - s
+        return out
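A standalone sketch of the run extraction performed by `_runs_from_bool`, using the same half-open (start, end) convention; the helper below is illustrative, not the class method itself.

# Extract maximal runs of True positions from a boolean mask.
import numpy as np

def runs_from_bool(mask_1d):
    idx = np.nonzero(mask_1d)[0]
    if idx.size == 0:
        return []
    breaks = np.where(np.diff(idx) > 1)[0]          # indices where a gap starts
    starts = np.r_[idx[0], idx[breaks + 1]]
    ends = np.r_[idx[breaks] + 1, idx[-1] + 1]
    return list(zip(starts, ends))

mask = np.array([0, 1, 1, 0, 1, 0, 1, 1, 1], dtype=bool)
print(runs_from_bool(mask))   # runs at [1,3), [4,5), [6,9)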
|
|
840
|
+
|
|
+    @staticmethod
+    def _write_lengths_for_state_layer(states: np.ndarray) -> np.ndarray:
+        """
+        For each row, each constant-state run gets run-length assigned across run.
+        Missing values should be -1 and will get 0 length.
+        """
+        n, L = states.shape
+        out = np.zeros((n, L), dtype=np.int32)
+        for i in range(n):
+            row = states[i]
+            valid = row >= 0
+            if not np.any(valid):
+                continue
+            # scan runs
+            s = 0
+            while s < L:
+                if row[s] < 0:
+                    s += 1
+                    continue
+                v = row[s]
+                e = s + 1
+                while e < L and row[e] == v:
+                    e += 1
+                out[i, s:e] = e - s
+                s = e
+        return out

+    # ------------------------- merging -------------------------

+    def merge_intervals_to_new_layer(
|
|
         self,
         adata,
+        base_layer: str,
+        *,
+        distance_threshold: int,
+        suffix: str = "_merged",
+        overwrite: bool = True,
+    ) -> str:
+        """
+        Merge adjacent 1-intervals in a binary layer if gaps <= distance_threshold (in coords space),
+        writing:
+          - {base_layer}{suffix}
+          - {base_layer}{suffix}_lengths (run-length in index units)
+        """
+        if base_layer not in adata.layers:
+            raise KeyError(f"Layer '{base_layer}' not found.")
+
+        coords, coords_are_ints = _safe_int_coords(adata.var_names)
+        arr = np.asarray(adata.layers[base_layer])
+        arr = (arr > 0).astype(np.uint8)
+
+        merged_name = f"{base_layer}{suffix}"
+        merged_len_name = f"{merged_name}_lengths"
+
+        if (merged_name in adata.layers or merged_len_name in adata.layers) and not overwrite:
+            raise KeyError(f"Merged outputs exist (use overwrite=True): {merged_name}")
+
+        n, L = arr.shape
+        out = np.zeros_like(arr, dtype=np.uint8)
+
+        dt = int(distance_threshold)
+
+        for i in range(n):
+            ones = np.nonzero(arr[i] != 0)[0]
+            runs = self._runs_from_bool(arr[i] != 0)
+            if not runs:
+                continue
+            ms, me = runs[0]
+            merged_runs = []
+            for s, e in runs[1:]:
+                if coords_are_ints:
+                    gap = int(coords[s]) - int(coords[me - 1]) - 1
+                else:
+                    gap = s - me
+                if gap <= dt:
+                    me = e
+                else:
+                    merged_runs.append((ms, me))
+                    ms, me = s, e
+            merged_runs.append((ms, me))
+
+            for s, e in merged_runs:
+                out[i, s:e] = 1
+
+        adata.layers[merged_name] = out
+        adata.layers[merged_len_name] = self._write_lengths_for_binary_layer(out)
+
+        # bookkeeping
+        key = "hmm_appended_layers"
+        if adata.uns.get(key) is None:
+            adata.uns[key] = []
+        for nm in (merged_name, merged_len_name):
+            if nm not in adata.uns[key]:
+                adata.uns[key].append(nm)
+
+        return merged_name
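A minimal sketch of the gap-merging rule applied per read above, with hypothetical runs and coordinates; the real method additionally writes the merged binary and length layers back into `adata`.

# Fuse adjacent 1-runs when the genomic gap between them is <= distance_threshold.
def merge_runs(runs, coords, distance_threshold):
    if not runs:
        return []
    merged = []
    ms, me = runs[0]
    for s, e in runs[1:]:
        gap = coords[s] - coords[me - 1] - 1    # bases strictly between the runs
        if gap <= distance_threshold:
            me = e
        else:
            merged.append((ms, me))
            ms, me = s, e
    merged.append((ms, me))
    return merged

coords = list(range(100, 200))                   # 1-bp spaced positions
print(merge_runs([(2, 5), (7, 9)], coords, distance_threshold=2))   # [(2, 9)]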
|
|
937
|
+
|
|
938
|
+
def write_size_class_layers_from_binary(
|
|
939
|
+
self,
|
|
940
|
+
adata,
|
|
941
|
+
base_layer: str,
|
|
942
|
+
*,
|
|
943
|
+
out_prefix: str,
|
|
944
|
+
feature_ranges: Dict[str, Tuple[float, float]],
|
|
945
|
+
suffix: str = "",
|
|
946
|
+
overwrite: bool = True,
|
|
947
|
+
) -> List[str]:
|
|
948
|
+
"""
|
|
949
|
+
Take an existing binary layer (runs represent features) and write size-class layers:
|
|
950
|
+
- {out_prefix}_{feature}{suffix}
|
|
951
|
+
- plus lengths layers
|
|
952
|
+
|
|
953
|
+
feature_ranges: name -> (lo, hi) in genomic bp.
|
|
954
|
+
"""
|
|
955
|
+
if base_layer not in adata.layers:
|
|
956
|
+
raise KeyError(f"Layer '{base_layer}' not found.")
|
|
957
|
+
|
|
958
|
+
coords, coords_are_ints = _safe_int_coords(adata.var_names)
|
|
959
|
+
bin_arr = (np.asarray(adata.layers[base_layer]) > 0).astype(np.uint8)
|
|
960
|
+
n, L = bin_arr.shape
|
|
961
|
+
|
|
962
|
+
created: List[str] = []
|
|
963
|
+
for feat_name in feature_ranges.keys():
|
|
964
|
+
nm = f"{out_prefix}_{feat_name}{suffix}"
|
|
965
|
+
ln = f"{nm}_lengths"
|
|
966
|
+
if (nm in adata.layers or ln in adata.layers) and not overwrite:
|
|
967
|
+
continue
|
|
968
|
+
adata.layers[nm] = np.zeros((n, L), dtype=np.uint8)
|
|
969
|
+
adata.layers[ln] = np.zeros((n, L), dtype=np.int32)
|
|
970
|
+
created.extend([nm, ln])
|
|
971
|
+
|
|
972
|
+
for i in range(n):
|
|
973
|
+
runs = self._runs_from_bool(bin_arr[i] != 0)
|
|
974
|
+
for s, e in runs:
|
|
975
|
+
length_bp = self._interval_length(coords, s, e) if coords_are_ints else (e - s)
|
|
976
|
+
for feat_name, (lo, hi) in feature_ranges.items():
|
|
977
|
+
if float(lo) <= float(length_bp) < float(hi):
|
|
978
|
+
nm = f"{out_prefix}_{feat_name}{suffix}"
|
|
979
|
+
adata.layers[nm][i, s:e] = 1
|
|
980
|
+
adata.layers[f"{nm}_lengths"][i, s:e] = e - s
|
|
981
|
+
break
|
|
982
|
+
|
|
983
|
+
# fill lengths for each size layer (consistent, even if overlaps)
|
|
984
|
+
for feat_name in feature_ranges.keys():
|
|
985
|
+
nm = f"{out_prefix}_{feat_name}{suffix}"
|
|
986
|
+
adata.layers[f"{nm}_lengths"] = self._write_lengths_for_binary_layer(
|
|
987
|
+
np.asarray(adata.layers[nm])
|
|
988
|
+
)
|
|
989
|
+
|
|
990
|
+
key = "hmm_appended_layers"
|
|
991
|
+
if adata.uns.get(key) is None:
|
|
992
|
+
adata.uns[key] = []
|
|
993
|
+
for nm in created:
|
|
994
|
+
if nm not in adata.uns[key]:
|
|
995
|
+
adata.uns[key].append(nm)
|
|
996
|
+
|
|
997
|
+
return created
|
|
998
|
+
|
|
999
|
+
# ------------------------- AnnData annotation -------------------------
|
|
1000
|
+
|
|
1001
|
+
@staticmethod
|
|
1002
|
+
def _resolve_pos_mask_for_methbase(subset, ref: str, methbase: str) -> Optional[np.ndarray]:
|
|
         """
-        Parameters
-        ----------
-        config : optional ExperimentConfig instance or plain dict
-            When provided, the following keys (if present) are used to override defaults:
-              - hmm_feature_sets : dict (canonical feature set structure) OR a JSON/string repr
-              - hmm_annotation_threshold : float
-              - hmm_batch_size : int
-              - hmm_use_viterbi : bool
-              - hmm_methbases : list
-              - footprints / accessible_patches / cpg (booleans)
-            Other keyword args override config values if explicitly provided.
+        Local helper to resolve per-base masks from subset.var.* columns.
+        Returns a boolean np.ndarray of length subset.n_vars or None.
         """
|
|
791
|
-
|
|
792
|
-
|
|
793
|
-
|
|
794
|
-
|
|
795
|
-
|
|
796
|
-
|
|
797
|
-
|
|
798
|
-
|
|
1007
|
+
key = str(methbase).strip().lower()
|
|
1008
|
+
var = subset.var
|
|
1009
|
+
|
|
1010
|
+
def _has(col: str) -> bool:
|
|
1011
|
+
"""Return True when a column exists on subset.var."""
|
|
1012
|
+
return col in var.columns
|
|
1013
|
+
|
|
1014
|
+
if key in ("a",):
|
|
1015
|
+
col = f"{ref}_strand_FASTA_base"
|
|
1016
|
+
if not _has(col):
|
|
799
1017
|
return None
|
|
800
|
-
|
|
801
|
-
|
|
802
|
-
|
|
803
|
-
|
|
1018
|
+
return np.asarray(var[col] == "A")
|
|
1019
|
+
|
|
1020
|
+
if key in ("c", "any_c", "anyc", "any-c"):
|
|
1021
|
+
for col in (f"{ref}_any_C_site", f"{ref}_C_site"):
|
|
1022
|
+
if _has(col):
|
|
1023
|
+
return np.asarray(var[col])
|
|
1024
|
+
return None
|
|
1025
|
+
|
|
1026
|
+
if key in ("gpc", "gpc_site", "gpc-site"):
|
|
1027
|
+
col = f"{ref}_GpC_site"
|
|
1028
|
+
if not _has(col):
|
|
804
1029
|
return None
|
|
805
|
-
|
|
806
|
-
|
|
807
|
-
|
|
808
|
-
|
|
809
|
-
|
|
810
|
-
return
|
|
811
|
-
|
|
812
|
-
|
|
813
|
-
|
|
814
|
-
|
|
815
|
-
|
|
816
|
-
|
|
817
|
-
|
|
818
|
-
|
|
819
|
-
|
|
820
|
-
|
|
821
|
-
|
|
822
|
-
|
|
823
|
-
|
|
824
|
-
|
|
825
|
-
|
|
826
|
-
|
|
827
|
-
|
|
828
|
-
|
|
829
|
-
|
|
830
|
-
|
|
831
|
-
|
|
832
|
-
|
|
833
|
-
|
|
834
|
-
|
|
835
|
-
|
|
836
|
-
|
|
837
|
-
|
|
838
|
-
|
|
839
|
-
|
|
840
|
-
|
|
841
|
-
|
|
842
|
-
|
|
843
|
-
|
|
844
|
-
|
|
845
|
-
|
|
846
|
-
|
|
847
|
-
|
|
848
|
-
|
|
849
|
-
|
|
850
|
-
|
|
851
|
-
|
|
852
|
-
|
|
853
|
-
|
|
854
|
-
|
|
1030
|
+
return np.asarray(var[col])
|
|
1031
|
+
|
|
1032
|
+
if key in ("cpg", "cpg_site", "cpg-site"):
|
|
1033
|
+
col = f"{ref}_CpG_site"
|
|
1034
|
+
if not _has(col):
|
|
1035
|
+
return None
|
|
1036
|
+
return np.asarray(var[col])
|
|
1037
|
+
|
|
1038
|
+
alt = f"{ref}_{methbase}_site"
|
|
1039
|
+
if not _has(alt):
|
|
1040
|
+
return None
|
|
1041
|
+
return np.asarray(var[alt])
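A small illustration of the column-lookup pattern this helper uses: a boolean var column such as "{ref}_GpC_site" selects the positions the HMM should see. The frame and reference name below are made up.

import numpy as np
import pandas as pd

var = pd.DataFrame({
    "chr1_GpC_site": [True, False, True, True],
    "chr1_CpG_site": [False, True, False, False],
})

def resolve_mask(var, ref, methbase):
    # Fall-back style lookup: return None when the expected column is missing.
    col = f"{ref}_{methbase}_site"
    if col not in var.columns:
        return None
    return np.asarray(var[col], dtype=bool)

print(resolve_mask(var, "chr1", "GpC"))   # [ True False  True  True]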
|
|
1042
|
+
|
|
1043
|
+
def annotate_adata(
|
|
1044
|
+
self,
|
|
1045
|
+
adata,
|
|
1046
|
+
*,
|
|
1047
|
+
prefix: str,
|
|
1048
|
+
X: np.ndarray,
|
|
1049
|
+
coords: np.ndarray,
|
|
1050
|
+
var_mask: np.ndarray,
|
|
1051
|
+
span_fill: bool = True,
|
|
1052
|
+
config=None,
|
|
1053
|
+
decode: str = "marginal",
|
|
1054
|
+
write_posterior: bool = True,
|
|
1055
|
+
posterior_state: str = "Modified",
|
|
1056
|
+
feature_sets: Optional[Dict[str, Dict[str, Any]]] = None,
|
|
1057
|
+
prob_threshold: float = 0.5,
|
|
1058
|
+
uns_key: str = "hmm_appended_layers",
|
|
1059
|
+
uns_flag: str = "hmm_annotated",
|
|
1060
|
+
force_redo: bool = False,
|
|
1061
|
+
device: Optional[Union[str, torch.device]] = None,
|
|
1062
|
+
**kwargs,
|
|
1063
|
+
):
|
|
1064
|
+
"""Decode and annotate an AnnData object with HMM-derived layers.
|
|
1065
|
+
|
|
1066
|
+
Args:
|
|
1067
|
+
adata: AnnData to annotate.
|
|
1068
|
+
prefix: Prefix for newly written layers.
|
|
1069
|
+
X: Observations array for decoding.
|
|
1070
|
+
coords: Coordinate array aligned to X.
|
|
1071
|
+
var_mask: Boolean mask for positions in adata.var.
|
|
1072
|
+
span_fill: Whether to fill missing spans.
|
|
1073
|
+
config: Optional config for naming and state selection.
|
|
1074
|
+
decode: Decode method ("marginal" or "viterbi").
|
|
1075
|
+
write_posterior: Whether to write posterior probabilities.
|
|
1076
|
+
posterior_state: State label to write posterior for.
|
|
1077
|
+
feature_sets: Optional feature set definition for size classes.
|
|
1078
|
+
prob_threshold: Posterior probability threshold for binary calls.
|
|
1079
|
+
uns_key: .uns key to track appended layers.
|
|
1080
|
+
uns_flag: .uns flag to mark annotations.
|
|
1081
|
+
force_redo: Whether to overwrite existing layers.
|
|
1082
|
+
device: Device specifier.
|
|
1083
|
+
**kwargs: Additional parameters for specialized workflows.
|
|
1084
|
+
|
|
1085
|
+
Returns:
|
|
1086
|
+
List of created layer names or None if skipped.
|
|
1087
|
+
"""
|
|
1088
|
+
# skip logic
|
|
1089
|
+
if bool(adata.uns.get(uns_flag, False)) and not force_redo:
|
|
1090
|
+
return None
|
|
1091
|
+
|
|
1092
|
+
if adata.uns.get(uns_key) is None:
|
|
1093
|
+
adata.uns[uns_key] = []
|
|
1094
|
+
appended = list(adata.uns.get(uns_key, [])) if adata.uns.get(uns_key) is not None else []
|
|
1095
|
+
|
|
1096
|
+
X = np.asarray(X, dtype=float)
|
|
1097
|
+
coords = np.asarray(coords, dtype=int)
|
|
1098
|
+
var_mask = np.asarray(var_mask, dtype=bool)
|
|
1099
|
+
if var_mask.shape[0] != adata.n_vars:
|
|
1100
|
+
raise ValueError(f"var_mask length {var_mask.shape[0]} != adata.n_vars {adata.n_vars}")
|
|
1101
|
+
|
|
1102
|
+
# decode
|
|
1103
|
+
states, gamma = self.decode(
|
|
1104
|
+
X, coords, decode=decode, device=device
|
|
1105
|
+
) # states (N,L), gamma (N,L,K)
|
|
1106
|
+
N, L = states.shape
|
|
1107
|
+
if N != adata.n_obs:
|
|
1108
|
+
raise ValueError(f"X has N={N} rows but adata.n_obs={adata.n_obs}")
|
|
1109
|
+
|
|
1110
|
+
# map coords -> full-var indices for span_fill
|
|
1111
|
+
full_coords, full_int = _safe_int_coords(adata.var_names)
|
|
1112
|
+
|
|
1113
|
+
# ---- write posterior + states on masked columns only ----
|
|
1114
|
+
masked_idx = np.nonzero(var_mask)[0]
|
|
1115
|
+
masked_coords, _ = _safe_int_coords(adata.var_names[var_mask])
|
|
1116
|
+
|
|
1117
|
+
# build mapping from coords order -> masked column order
|
|
1118
|
+
coord_to_pos_in_decoded = {int(c): i for i, c in enumerate(coords.tolist())}
|
|
1119
|
+
take = np.array(
|
|
1120
|
+
[coord_to_pos_in_decoded.get(int(c), -1) for c in masked_coords.tolist()], dtype=int
|
|
1121
|
+
)
|
|
1122
|
+
good = take >= 0
|
|
1123
|
+
masked_idx = masked_idx[good]
|
|
1124
|
+
take = take[good]
|
|
1125
|
+
|
|
1126
|
+
# states layer
|
|
1127
|
+
states_name = f"{prefix}_states"
|
|
1128
|
+
if states_name not in adata.layers:
|
|
1129
|
+
adata.layers[states_name] = np.full((adata.n_obs, adata.n_vars), -1, dtype=np.int8)
|
|
1130
|
+
adata.layers[states_name][:, masked_idx] = states[:, take].astype(np.int8)
|
|
1131
|
+
if states_name not in appended:
|
|
1132
|
+
appended.append(states_name)
|
|
1133
|
+
|
|
1134
|
+
# posterior layer (requested state)
|
|
1135
|
+
if write_posterior:
|
|
1136
|
+
t_idx = self.resolve_target_state_index(posterior_state)
|
|
1137
|
+
post = gamma[:, :, t_idx].astype(np.float32)
|
|
1138
|
+
post_name = f"{prefix}_posterior_{str(posterior_state).strip().lower().replace(' ', '_').replace('-', '_')}"
|
|
1139
|
+
if post_name not in adata.layers:
|
|
1140
|
+
adata.layers[post_name] = np.zeros((adata.n_obs, adata.n_vars), dtype=np.float32)
|
|
1141
|
+
adata.layers[post_name][:, masked_idx] = post[:, take]
|
|
1142
|
+
if post_name not in appended:
|
|
1143
|
+
appended.append(post_name)
|
|
1144
|
+
|
|
1145
|
+
# ---- feature layers ----
|
|
1146
|
+
if feature_sets is None:
|
|
1147
|
+
cfgd = self._cfg_to_dict(config)
|
|
1148
|
+
feature_sets = normalize_hmm_feature_sets(cfgd.get("hmm_feature_sets", None))
|
|
1149
|
+
|
|
1150
|
+
if not feature_sets:
|
|
1151
|
+
adata.uns[uns_key] = appended
|
|
1152
|
+
adata.uns[uns_flag] = True
|
|
1153
|
+
return None
|
|
1154
|
+
|
|
1155
|
+
# allocate outputs
|
|
1156
|
+
for group, fs in feature_sets.items():
|
|
1157
|
+
fmap = fs.get("features", {}) or {}
|
|
1158
|
+
if not fmap:
|
|
1159
|
+
continue
|
|
1160
|
+
|
|
1161
|
+
all_layer = f"{prefix}_all_{group}_features"
|
|
1162
|
+
if all_layer not in adata.layers:
|
|
1163
|
+
adata.layers[all_layer] = np.zeros((adata.n_obs, adata.n_vars), dtype=np.uint8)
|
|
1164
|
+
if f"{all_layer}_lengths" not in adata.layers:
|
|
1165
|
+
adata.layers[f"{all_layer}_lengths"] = np.zeros(
|
|
1166
|
+
(adata.n_obs, adata.n_vars), dtype=np.int32
|
|
1167
|
+
)
|
|
1168
|
+
for nm in (all_layer, f"{all_layer}_lengths"):
|
|
1169
|
+
if nm not in appended:
|
|
1170
|
+
appended.append(nm)
|
|
1171
|
+
|
|
1172
|
+
for feat in fmap.keys():
|
|
1173
|
+
nm = f"{prefix}_{feat}"
|
|
1174
|
+
if nm not in adata.layers:
|
|
1175
|
+
adata.layers[nm] = np.zeros(
|
|
1176
|
+
(adata.n_obs, adata.n_vars),
|
|
1177
|
+
dtype=np.int32 if nm.endswith("_lengths") else np.uint8,
|
|
1178
|
+
)
|
|
1179
|
+
if f"{nm}_lengths" not in adata.layers:
|
|
1180
|
+
adata.layers[f"{nm}_lengths"] = np.zeros(
|
|
1181
|
+
(adata.n_obs, adata.n_vars), dtype=np.int32
|
|
1182
|
+
)
|
|
1183
|
+
for outnm in (nm, f"{nm}_lengths"):
|
|
1184
|
+
if outnm not in appended:
|
|
1185
|
+
appended.append(outnm)
|
|
1186
|
+
|
|
1187
|
+
# classify runs per row
|
|
1188
|
+
target_idx = self.resolve_target_state_index(fs.get("state", "Modified"))
|
|
1189
|
+
membership = (
|
|
1190
|
+
(states == target_idx)
|
|
1191
|
+
if str(decode).lower() == "viterbi"
|
|
1192
|
+
else (gamma[:, :, target_idx] >= float(prob_threshold))
|
|
1193
|
+
)
|
|
1194
|
+
|
|
1195
|
+
for i in range(N):
|
|
1196
|
+
runs = self._runs_from_bool(membership[i].astype(bool))
|
|
1197
|
+
for s, e in runs:
|
|
1198
|
+
# genomic length in coords space
|
|
1199
|
+
glen = int(coords[e - 1]) - int(coords[s]) + 1 if e > s else 0
|
|
1200
|
+
if glen <= 0:
|
|
855
1201
|
continue
|
|
1202
|
+
|
|
1203
|
+
# pick feature bin
|
|
1204
|
+
chosen = None
|
|
1205
|
+
for feat_name, (lo, hi) in fmap.items():
|
|
1206
|
+
if float(lo) <= float(glen) < float(hi):
|
|
1207
|
+
chosen = feat_name
|
|
1208
|
+
break
|
|
1209
|
+
if chosen is None:
|
|
1210
|
+
continue
|
|
1211
|
+
|
|
1212
|
+
# convert span to indices in full var grid
|
|
1213
|
+
if span_fill and full_int:
|
|
1214
|
+
left = int(np.searchsorted(full_coords, int(coords[s]), side="left"))
|
|
1215
|
+
right = int(np.searchsorted(full_coords, int(coords[e - 1]), side="right"))
|
|
1216
|
+
if left >= right:
|
|
1217
|
+
continue
|
|
1218
|
+
adata.layers[f"{prefix}_{chosen}"][i, left:right] = 1
|
|
1219
|
+
adata.layers[f"{prefix}_all_{group}_features"][i, left:right] = 1
|
|
862
1220
|
else:
|
|
1221
|
+
# only fill at masked indices
|
|
1222
|
+
cols = masked_idx[
|
|
1223
|
+
(masked_coords >= coords[s]) & (masked_coords <= coords[e - 1])
|
|
1224
|
+
]
|
|
1225
|
+
if cols.size == 0:
|
|
1226
|
+
continue
|
|
1227
|
+
adata.layers[f"{prefix}_{chosen}"][i, cols] = 1
|
|
1228
|
+
adata.layers[f"{prefix}_all_{group}_features"][i, cols] = 1
|
|
1229
|
+
|
|
1230
|
+
# lengths derived from binary
|
|
1231
|
+
adata.layers[f"{prefix}_all_{group}_features_lengths"] = (
|
|
1232
|
+
self._write_lengths_for_binary_layer(
|
|
1233
|
+
np.asarray(adata.layers[f"{prefix}_all_{group}_features"])
|
|
1234
|
+
)
|
|
1235
|
+
)
|
|
1236
|
+
for feat in fmap.keys():
|
|
1237
|
+
nm = f"{prefix}_{feat}"
|
|
1238
|
+
adata.layers[f"{nm}_lengths"] = self._write_lengths_for_binary_layer(
|
|
1239
|
+
np.asarray(adata.layers[nm])
|
|
1240
|
+
)
|
|
1241
|
+
|
|
1242
|
+
adata.uns[uns_key] = appended
|
|
1243
|
+
adata.uns[uns_flag] = True
|
|
1244
|
+
return None
|
|
1245
|
+
|
|
1246
|
+
# ------------------------- row-copy helper (workflow uses it) -------------------------
|
|
1247
|
+
|
|
1248
|
+
def _ensure_final_layer_and_assign(
|
|
1249
|
+
self, final_adata, layer_name: str, subset_idx_mask: np.ndarray, sub_data
|
|
1250
|
+
):
|
|
1251
|
+
"""
|
|
1252
|
+
Assign rows from sub_data into final_adata.layers[layer_name] for rows where subset_idx_mask is True.
|
|
1253
|
+
Handles dense arrays. If you want sparse support, add it here.
|
|
1254
|
+
"""
|
|
1255
|
+
n_final_obs, n_vars = final_adata.shape
|
|
1256
|
+
final_rows = np.nonzero(np.asarray(subset_idx_mask).astype(bool))[0]
|
|
1257
|
+
sub_arr = np.asarray(sub_data)
|
|
1258
|
+
|
|
1259
|
+
if layer_name not in final_adata.layers:
|
|
1260
|
+
final_adata.layers[layer_name] = np.zeros((n_final_obs, n_vars), dtype=sub_arr.dtype)
|
|
1261
|
+
|
|
1262
|
+
final_arr = np.asarray(final_adata.layers[layer_name])
|
|
1263
|
+
if sub_arr.shape[0] != final_rows.size:
|
|
1264
|
+
raise ValueError(f"Sub rows {sub_arr.shape[0]} != mask sum {final_rows.size}")
|
|
1265
|
+
final_arr[final_rows, :] = sub_arr
|
|
1266
|
+
final_adata.layers[layer_name] = final_arr
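The row-copy pattern above, shown standalone with made-up shapes; dense NumPy only, since the docstring notes that sparse input is not handled here.

import numpy as np

n_obs, n_vars = 5, 8
layer = np.zeros((n_obs, n_vars), dtype=np.uint8)

subset_mask = np.array([True, False, True, False, False])
sub_data = np.ones((subset_mask.sum(), n_vars), dtype=np.uint8)

# Rows decoded on the subset are written back into the full-size layer by mask.
rows = np.nonzero(subset_mask)[0]
assert sub_data.shape[0] == rows.size
layer[rows, :] = sub_data
print(layer.sum(axis=1))   # [8 0 8 0 0]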
|
|
1267
|
+
|
|
1268
|
+
|
|
1269
|
+
# =============================================================================
|
|
1270
|
+
# Single-channel Bernoulli HMM
|
|
1271
|
+
# =============================================================================
|
|
1272
|
+
|
|
1273
|
+
|
|
1274
|
+
@register_hmm("single")
|
|
1275
|
+
class SingleBernoulliHMM(BaseHMM):
|
|
1276
|
+
"""
|
|
1277
|
+
Bernoulli emission per state:
|
|
1278
|
+
emission[k] = P(obs==1 | state=k)
|
|
1279
|
+
"""
|
|
1280
|
+
|
|
1281
|
+
def __init__(
|
|
1282
|
+
self,
|
|
1283
|
+
n_states: int = 2,
|
|
1284
|
+
init_emission: Optional[Sequence[float]] = None,
|
|
1285
|
+
eps: float = 1e-8,
|
|
1286
|
+
dtype: torch.dtype = torch.float64,
|
|
1287
|
+
):
|
|
1288
|
+
"""Initialize a single-channel Bernoulli HMM.
|
|
1289
|
+
|
|
1290
|
+
Args:
|
|
1291
|
+
n_states: Number of hidden states.
|
|
1292
|
+
init_emission: Initial emission probabilities per state.
|
|
1293
|
+
eps: Smoothing epsilon for probabilities.
|
|
1294
|
+
dtype: Torch dtype for parameters.
|
|
1295
|
+
"""
|
|
1296
|
+
super().__init__(n_states=n_states, eps=eps, dtype=dtype)
|
|
1297
|
+
if init_emission is None:
|
|
1298
|
+
em = np.full((self.n_states,), 0.5, dtype=float)
|
|
1299
|
+
else:
|
|
1300
|
+
em = np.asarray(init_emission, dtype=float).reshape(-1)[: self.n_states]
|
|
1301
|
+
if em.size != self.n_states:
|
|
1302
|
+
em = np.full((self.n_states,), 0.5, dtype=float)
|
|
1303
|
+
|
|
1304
|
+
self.emission = nn.Parameter(torch.tensor(em, dtype=self.dtype), requires_grad=False)
|
|
1305
|
+
self._normalize_emission()
|
|
1306
|
+
|
|
1307
|
+
@classmethod
|
|
1308
|
+
def from_config(cls, cfg, *, override=None, device=None):
|
|
1309
|
+
"""Create a single-channel Bernoulli HMM from config.
|
|
1310
|
+
|
|
1311
|
+
Args:
|
|
1312
|
+
cfg: Configuration mapping or object.
|
|
1313
|
+
override: Override values to apply.
|
|
1314
|
+
device: Optional device specifier.
|
|
1315
|
+
|
|
1316
|
+
Returns:
|
|
1317
|
+
Initialized SingleBernoulliHMM instance.
|
|
1318
|
+
"""
|
|
1319
|
+
merged = cls._cfg_to_dict(cfg)
|
|
1320
|
+
if override:
|
|
1321
|
+
merged.update(override)
|
|
1322
|
+
n_states = int(merged.get("hmm_n_states", 2))
|
|
1323
|
+
eps = float(merged.get("hmm_eps", 1e-8))
|
|
1324
|
+
dtype = _resolve_dtype(merged.get("hmm_dtype", None))
|
|
1325
|
+
dtype = _coerce_dtype_for_device(dtype, device) # <<< NEW
|
|
1326
|
+
init_em = merged.get("hmm_init_emission_probs", merged.get("hmm_init_emission", None))
|
|
1327
|
+
model = cls(n_states=n_states, init_emission=init_em, eps=eps, dtype=dtype)
|
|
1328
|
+
if device is not None:
|
|
1329
|
+
model.to(torch.device(device) if isinstance(device, str) else device)
|
|
1330
|
+
model._persisted_cfg = merged
|
|
1331
|
+
return model
|
|
1332
|
+
|
|
1333
|
+
def _normalize_emission(self):
|
|
1334
|
+
"""Normalize and clamp emission probabilities in-place."""
|
|
1335
|
+
with torch.no_grad():
|
|
1336
|
+
self.emission.data = self.emission.data.reshape(-1)
|
|
1337
|
+
if self.emission.data.numel() != self.n_states:
|
|
1338
|
+
self.emission.data = torch.full(
|
|
1339
|
+
(self.n_states,), 0.5, dtype=self.dtype, device=self.emission.device
|
|
1340
|
+
)
|
|
1341
|
+
self.emission.data = self.emission.data.clamp(min=self.eps, max=1.0 - self.eps)
|
|
1342
|
+
|
|
1343
|
+
def _ensure_device_dtype(self, device=None) -> torch.device:
|
|
1344
|
+
"""Move emission parameters to the requested device/dtype."""
|
|
1345
|
+
device = super()._ensure_device_dtype(device)
|
|
1346
|
+
self.emission.data = self.emission.data.to(device=device, dtype=self.dtype)
|
|
1347
|
+
return device
|
|
1348
|
+
|
|
1349
|
+
def _state_modified_score(self) -> torch.Tensor:
|
|
1350
|
+
"""Return per-state modified scores for ranking."""
|
|
1351
|
+
return self.emission.detach()
|
|
1352
|
+
|
|
1353
|
+
def _log_emission(self, obs: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
|
|
1354
|
+
"""
|
|
1355
|
+
obs: (N,L), mask: (N,L) -> logB: (N,L,K)
|
|
1356
|
+
"""
|
|
1357
|
+
p = self.emission # (K,)
|
|
1358
|
+
logp = torch.log(p + self.eps)
|
|
1359
|
+
log1mp = torch.log1p(-p + self.eps)
|
|
1360
|
+
|
|
1361
|
+
o = obs.unsqueeze(-1) # (N,L,1)
|
|
1362
|
+
logB = o * logp.view(1, 1, -1) + (1.0 - o) * log1mp.view(1, 1, -1)
|
|
1363
|
+
logB = torch.where(mask.unsqueeze(-1), logB, torch.zeros_like(logB))
|
|
1364
|
+
return logB
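A worked example of this masked Bernoulli log-emission with illustrative numbers; positions where the mask is False contribute zero log-probability, so missing calls do not penalize any state.

import torch

p = torch.tensor([0.1, 0.9], dtype=torch.float64)     # P(obs == 1 | state k)
obs = torch.tensor([[1.0, 0.0, 0.0]])                  # (N=1, L=3)
mask = torch.tensor([[True, True, False]])             # last position unobserved

logp, log1mp = torch.log(p), torch.log1p(-p)
# logB[n, l, k] = obs*log(p_k) + (1 - obs)*log(1 - p_k)
logB = obs.unsqueeze(-1) * logp + (1.0 - obs.unsqueeze(-1)) * log1mp
logB = torch.where(mask.unsqueeze(-1), logB, torch.zeros_like(logB))
print(logB.shape)   # torch.Size([1, 3, 2]); the masked column is all zeros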
|
|
1365
|
+
|
|
1366
|
+
def _extra_save_payload(self) -> dict:
|
|
1367
|
+
"""Return extra payload data for serialization."""
|
|
1368
|
+
return {"emission": self.emission.detach().cpu()}
|
|
1369
|
+
|
|
1370
|
+
def _load_extra_payload(self, payload: dict, *, device: torch.device):
|
|
1371
|
+
"""Load serialized emission parameters.
|
|
1372
|
+
|
|
1373
|
+
Args:
|
|
1374
|
+
payload: Serialized payload dictionary.
|
|
1375
|
+
device: Target torch device.
|
|
1376
|
+
"""
|
|
1377
|
+
with torch.no_grad():
|
|
1378
|
+
self.emission.data = payload["emission"].to(device=device, dtype=self.dtype)
|
|
1379
|
+
self._normalize_emission()
|
|
1380
|
+
|
|
1381
|
+
def fit_em(
|
|
1382
|
+
self,
|
|
1383
|
+
X: np.ndarray,
|
|
1384
|
+
coords: np.ndarray,
|
|
1385
|
+
*,
|
|
1386
|
+
device: torch.device,
|
|
1387
|
+
max_iter: int,
|
|
1388
|
+
tol: float,
|
|
1389
|
+
update_start: bool,
|
|
1390
|
+
update_trans: bool,
|
|
1391
|
+
update_emission: bool,
|
|
1392
|
+
verbose: bool,
|
|
1393
|
+
**kwargs,
|
|
1394
|
+
) -> List[float]:
|
|
1395
|
+
"""Run EM updates for a single-channel Bernoulli HMM.
|
|
1396
|
+
|
|
1397
|
+
Args:
|
|
1398
|
+
X: Observations array (N, L).
|
|
1399
|
+
coords: Coordinate array aligned to X.
|
|
1400
|
+
device: Torch device.
|
|
1401
|
+
max_iter: Maximum iterations.
|
|
1402
|
+
tol: Convergence tolerance.
|
|
1403
|
+
update_start: Whether to update start probabilities.
|
|
1404
|
+
update_trans: Whether to update transitions.
|
|
1405
|
+
update_emission: Whether to update emission parameters.
|
|
1406
|
+
verbose: Whether to log progress.
|
|
1407
|
+
**kwargs: Additional implementation-specific kwargs.
|
|
1408
|
+
|
|
1409
|
+
Returns:
|
|
1410
|
+
List of log-likelihood proxy values.
|
|
1411
|
+
"""
|
|
1412
|
+
X = np.asarray(X, dtype=float)
|
|
1413
|
+
if X.ndim != 2:
|
|
1414
|
+
raise ValueError("SingleBernoulliHMM expects X shape (N,L).")
|
|
1415
|
+
obs = torch.tensor(np.nan_to_num(X, nan=0.0), dtype=self.dtype, device=device)
|
|
1416
|
+
mask = torch.tensor(~np.isnan(X), dtype=torch.bool, device=device)
|
|
1417
|
+
|
|
1418
|
+
eps = float(self.eps)
|
|
1419
|
+
K = self.n_states
|
|
1420
|
+
N, L = obs.shape
|
|
1421
|
+
|
|
1422
|
+
hist: List[float] = []
|
|
1423
|
+
for it in range(1, int(max_iter) + 1):
|
|
1424
|
+
gamma = self._forward_backward(obs, mask) # (N,L,K)
|
|
1425
|
+
|
|
1426
|
+
# log-likelihood proxy
|
|
1427
|
+
ll_proxy = float(torch.sum(torch.log(torch.clamp(gamma.sum(dim=2), min=eps))).item())
|
|
1428
|
+
hist.append(ll_proxy)
|
|
1429
|
+
|
|
1430
|
+
# expected start
|
|
1431
|
+
start_acc = gamma[:, 0, :].sum(dim=0) # (K,)
|
|
1432
|
+
|
|
1433
|
+
# expected transitions xi
|
|
1434
|
+
logB = self._log_emission(obs, mask)
|
|
1435
|
+
logA = torch.log(self.trans + eps)
|
|
1436
|
+
alpha = torch.empty((N, L, K), dtype=self.dtype, device=device)
|
|
1437
|
+
alpha[:, 0, :] = torch.log(self.start + eps).unsqueeze(0) + logB[:, 0, :]
|
|
1438
|
+
for t in range(1, L):
|
|
1439
|
+
prev = alpha[:, t - 1, :].unsqueeze(2) + logA.unsqueeze(0)
|
|
1440
|
+
alpha[:, t, :] = _logsumexp(prev, dim=1) + logB[:, t, :]
|
|
1441
|
+
beta = torch.empty((N, L, K), dtype=self.dtype, device=device)
|
|
1442
|
+
beta[:, L - 1, :] = 0.0
|
|
1443
|
+
for t in range(L - 2, -1, -1):
|
|
1444
|
+
temp = logA.unsqueeze(0) + (logB[:, t + 1, :] + beta[:, t + 1, :]).unsqueeze(1)
|
|
1445
|
+
beta[:, t, :] = _logsumexp(temp, dim=2)
|
|
1446
|
+
|
|
1447
|
+
trans_acc = torch.zeros((K, K), dtype=self.dtype, device=device)
|
|
1448
|
+
for t in range(L - 1):
|
|
1449
|
+
valid_t = (mask[:, t] & mask[:, t + 1]).float().view(N, 1, 1)
|
|
1450
|
+
log_xi = (
|
|
1451
|
+
alpha[:, t, :].unsqueeze(2)
|
|
1452
|
+
+ logA.unsqueeze(0)
|
|
1453
|
+
+ (logB[:, t + 1, :] + beta[:, t + 1, :]).unsqueeze(1)
|
|
1454
|
+
)
|
|
1455
|
+
log_norm = _logsumexp(log_xi.view(N, -1), dim=1).view(N, 1, 1)
|
|
1456
|
+
xi = (log_xi - log_norm).exp() * valid_t
|
|
1457
|
+
trans_acc += xi.sum(dim=0)
|
|
1458
|
+
|
|
1459
|
+
# emission update
|
|
1460
|
+
mask_f = mask.float().unsqueeze(-1) # (N,L,1)
|
|
1461
|
+
emit_num = (gamma * obs.unsqueeze(-1) * mask_f).sum(dim=(0, 1)) # (K,)
|
|
1462
|
+
emit_den = (gamma * mask_f).sum(dim=(0, 1)) # (K,)
|
|
1463
|
+
|
|
1464
|
+
with torch.no_grad():
|
|
1465
|
+
if update_start:
|
|
1466
|
+
new_start = start_acc + eps
|
|
1467
|
+
self.start.data = new_start / new_start.sum()
|
|
1468
|
+
|
|
1469
|
+
if update_trans:
|
|
1470
|
+
new_trans = trans_acc + eps
|
|
1471
|
+
rs = new_trans.sum(dim=1, keepdim=True)
|
|
1472
|
+
rs[rs == 0.0] = 1.0
|
|
1473
|
+
self.trans.data = new_trans / rs
|
|
1474
|
+
|
|
1475
|
+
if update_emission:
|
|
1476
|
+
new_em = (emit_num + eps) / (emit_den + 2.0 * eps)
|
|
1477
|
+
self.emission.data = new_em.clamp(min=eps, max=1.0 - eps)
|
|
1478
|
+
|
|
1479
|
+
self._normalize_params()
|
|
1480
|
+
self._normalize_emission()
|
|
923
1481
|
|
|
924
|
-
if not feature_sets:
|
|
925
1482
|
if verbose:
|
|
926
|
-
|
|
927
|
-
|
|
928
|
-
|
|
929
|
-
|
|
930
|
-
|
|
931
|
-
|
|
932
|
-
|
|
933
|
-
|
|
934
|
-
adata = adata.copy()
|
|
935
|
-
|
|
936
|
-
# prepare column names
|
|
937
|
-
all_features = []
|
|
938
|
-
combined_prefix = "Combined"
|
|
939
|
-
for key, fs in feature_sets.items():
|
|
940
|
-
feats = fs.get("features", {})
|
|
941
|
-
if key == "cpg":
|
|
942
|
-
all_features += [f"CpG_{f}" for f in feats]
|
|
943
|
-
all_features.append(f"CpG_all_{key}_features")
|
|
944
|
-
else:
|
|
945
|
-
for methbase in methbases:
|
|
946
|
-
all_features += [f"{methbase}_{f}" for f in feats]
|
|
947
|
-
all_features.append(f"{methbase}_all_{key}_features")
|
|
948
|
-
if len(methbases) > 1:
|
|
949
|
-
all_features += [f"{combined_prefix}_{f}" for f in feats]
|
|
950
|
-
all_features.append(f"{combined_prefix}_all_{key}_features")
|
|
951
|
-
|
|
952
|
-
# initialize obs columns (unique lists per row)
|
|
953
|
-
n_rows = adata.shape[0]
|
|
954
|
-
for feature in all_features:
|
|
955
|
-
if feature not in adata.obs.columns:
|
|
956
|
-
adata.obs[feature] = [[] for _ in range(n_rows)]
|
|
957
|
-
if f"{feature}_distances" not in adata.obs.columns:
|
|
958
|
-
adata.obs[f"{feature}_distances"] = [None] * n_rows
|
|
959
|
-
if f"n_{feature}" not in adata.obs.columns:
|
|
960
|
-
adata.obs[f"n_{feature}"] = -1
|
|
961
|
-
|
|
962
|
-
appended_layers: List[str] = []
|
|
963
|
-
|
|
964
|
-
# device management
|
|
965
|
-
if device is None:
|
|
966
|
-
device = next(self.parameters()).device
|
|
967
|
-
elif isinstance(device, str):
|
|
968
|
-
device = _torch.device(device)
|
|
969
|
-
self.to(device)
|
|
1483
|
+
logger.info(
|
|
1484
|
+
"[SingleBernoulliHMM.fit] iter=%s ll_proxy=%.6f",
|
|
1485
|
+
it,
|
|
1486
|
+
hist[-1],
|
|
1487
|
+
)
|
|
1488
|
+
|
|
1489
|
+
if len(hist) > 1 and abs(hist[-1] - hist[-2]) < float(tol):
|
|
1490
|
+
break
|
|
970
1491
|
|
|
971
|
-
|
|
972
|
-
|
|
973
|
-
|
|
1492
|
+
return hist
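A usage sketch for this EM fit. The array preparation below runs as-is; the model call is shown commented because the import path is assumed rather than taken from this diff, while the keyword arguments mirror the fit_em signature above.

import numpy as np
import torch

rng = np.random.default_rng(0)
X = (rng.random((20, 50)) > 0.5).astype(float)
X[rng.random(X.shape) < 0.1] = np.nan          # missing calls stay NaN
coords = np.arange(50)

# Assumed import path for illustration only:
# from smftools.hmm.HMM import SingleBernoulliHMM
# model = SingleBernoulliHMM(n_states=2)
# history = model.fit_em(
#     X, coords,
#     device=torch.device("cpu"), max_iter=20, tol=1e-4,
#     update_start=True, update_trans=True, update_emission=True, verbose=False,
# )
# assert all(np.isfinite(history))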
|
|
1493
|
+
|
|
1494
|
+
def adapt_emissions(
|
|
1495
|
+
self,
|
|
1496
|
+
X: np.ndarray,
|
|
1497
|
+
coords: Optional[np.ndarray] = None,
|
|
1498
|
+
*,
|
|
1499
|
+
device: Optional[Union[str, torch.device]] = None,
|
|
1500
|
+
iters: Optional[int] = None,
|
|
1501
|
+
max_iter: Optional[int] = None, # alias for your trainer
|
|
1502
|
+
verbose: bool = False,
|
|
1503
|
+
**kwargs,
|
|
1504
|
+
):
|
|
1505
|
+
"""Adapt emissions with legacy parameter names.
|
|
1506
|
+
|
|
1507
|
+
Args:
|
|
1508
|
+
X: Observations array.
|
|
1509
|
+
coords: Optional coordinate array.
|
|
1510
|
+
device: Device specifier.
|
|
1511
|
+
iters: Number of iterations.
|
|
1512
|
+
max_iter: Alias for iters.
|
|
1513
|
+
verbose: Whether to log progress.
|
|
1514
|
+
**kwargs: Additional kwargs forwarded to BaseHMM.adapt_emissions.
|
|
1515
|
+
|
|
1516
|
+
Returns:
|
|
1517
|
+
List of log-likelihood values.
|
|
1518
|
+
"""
|
|
1519
|
+
if iters is None:
|
|
1520
|
+
iters = int(max_iter) if max_iter is not None else int(kwargs.pop("iters", 5))
|
|
1521
|
+
return super().adapt_emissions(
|
|
1522
|
+
np.asarray(X, dtype=float),
|
|
1523
|
+
coords if coords is not None else None,
|
|
1524
|
+
iters=int(iters),
|
|
1525
|
+
device=device,
|
|
1526
|
+
verbose=verbose,
|
|
1527
|
+
)
|
|
1528
|
+
|
|
1529
|
+
|
|
1530
|
+
# =============================================================================
|
|
1531
|
+
# Multi-channel Bernoulli HMM (union coordinate grid)
|
|
1532
|
+
# =============================================================================
|
|
1533
|
+
|
|
1534
|
+
|
|
1535
|
+
@register_hmm("multi")
|
|
1536
|
+
class MultiBernoulliHMM(BaseHMM):
|
|
1537
|
+
"""
|
|
1538
|
+
Multi-channel independent Bernoulli:
|
|
1539
|
+
emission[k,c] = P(obs_c==1 | state=k)
|
|
1540
|
+
X must be (N,L,C) on a union coordinate grid; NaN per-channel allowed.
|
|
1541
|
+
"""
|
|
1542
|
+
|
|
1543
|
+
def __init__(
|
|
1544
|
+
self,
|
|
1545
|
+
n_states: int = 2,
|
|
1546
|
+
n_channels: int = 2,
|
|
1547
|
+
init_emission: Optional[Any] = None,
|
|
1548
|
+
eps: float = 1e-8,
|
|
1549
|
+
dtype: torch.dtype = torch.float64,
|
|
1550
|
+
):
|
|
1551
|
+
"""Initialize a multi-channel Bernoulli HMM.
|
|
1552
|
+
|
|
1553
|
+
Args:
|
|
1554
|
+
n_states: Number of hidden states.
|
|
1555
|
+
n_channels: Number of observed channels.
|
|
1556
|
+
init_emission: Initial emission probabilities.
|
|
1557
|
+
eps: Smoothing epsilon for probabilities.
|
|
1558
|
+
dtype: Torch dtype for parameters.
|
|
1559
|
+
"""
|
|
1560
|
+
super().__init__(n_states=n_states, eps=eps, dtype=dtype)
|
|
1561
|
+
self.n_channels = int(n_channels)
|
|
1562
|
+
if self.n_channels < 1:
|
|
1563
|
+
raise ValueError("n_channels must be >=1")
|
|
1564
|
+
|
|
1565
|
+
if init_emission is None:
|
|
1566
|
+
em = np.full((self.n_states, self.n_channels), 0.5, dtype=float)
|
|
1567
|
+
else:
|
|
1568
|
+
arr = np.asarray(init_emission, dtype=float)
|
|
974
1569
|
if arr.ndim == 1:
|
|
975
|
-
arr = arr
|
|
976
|
-
|
|
977
|
-
# squeeze trailing singletons
|
|
978
|
-
while arr.ndim > 2 and arr.shape[-1] == 1:
|
|
979
|
-
arr = _np.squeeze(arr, axis=-1)
|
|
980
|
-
if arr.ndim != 2:
|
|
981
|
-
raise ValueError(f"Expected 2D sequence matrix; got array with shape {arr.shape}")
|
|
982
|
-
return arr
|
|
983
|
-
|
|
984
|
-
def calculate_batch_distances(intervals_list, threshold_local=0.9):
|
|
985
|
-
results_local = []
|
|
986
|
-
for intervals in intervals_list:
|
|
987
|
-
if not isinstance(intervals, list) or len(intervals) == 0:
|
|
988
|
-
results_local.append([])
|
|
989
|
-
continue
|
|
990
|
-
valid = [iv for iv in intervals if iv[2] > threshold_local]
|
|
991
|
-
if len(valid) <= 1:
|
|
992
|
-
results_local.append([])
|
|
993
|
-
continue
|
|
994
|
-
valid = sorted(valid, key=lambda x: x[0])
|
|
995
|
-
dists = [(valid[i + 1][0] - (valid[i][0] + valid[i][1])) for i in range(len(valid) - 1)]
|
|
996
|
-
results_local.append(dists)
|
|
997
|
-
return results_local
|
|
998
|
-
|
|
999
|
-
def classify_batch_local(predicted_states_batch, probabilities_batch, coordinates, classification_mapping, target_state="Modified"):
|
|
1000
|
-
# Accept numpy arrays or torch tensors
|
|
1001
|
-
if isinstance(predicted_states_batch, _torch.Tensor):
|
|
1002
|
-
pred_np = predicted_states_batch.detach().cpu().numpy()
|
|
1570
|
+
arr = arr.reshape(-1, 1)
|
|
1571
|
+
em = np.repeat(arr[: self.n_states, :], self.n_channels, axis=1)
|
|
1003
1572
|
else:
|
|
1004
|
-
|
|
1005
|
-
if
|
|
1006
|
-
|
|
1007
|
-
else:
|
|
1008
|
-
probs_np = _np.asarray(probabilities_batch)
|
|
1009
|
-
|
|
1010
|
-
batch_size, L = pred_np.shape
|
|
1011
|
-
all_classifications_local = []
|
|
1012
|
-
# allow caller to pass arbitrary state labels mapping; default two-state mapping:
|
|
1013
|
-
state_labels = ["Non-Modified", "Modified"]
|
|
1014
|
-
try:
|
|
1015
|
-
target_idx = state_labels.index(target_state)
|
|
1016
|
-
except ValueError:
|
|
1017
|
-
target_idx = 1 # fallback
|
|
1018
|
-
|
|
1019
|
-
for b in range(batch_size):
|
|
1020
|
-
predicted_states = pred_np[b]
|
|
1021
|
-
probabilities = probs_np[b]
|
|
1022
|
-
regions = []
|
|
1023
|
-
current_start, current_length, current_probs = None, 0, []
|
|
1024
|
-
for i, state_index in enumerate(predicted_states):
|
|
1025
|
-
state_prob = float(probabilities[i][state_index])
|
|
1026
|
-
if state_index == target_idx:
|
|
1027
|
-
if current_start is None:
|
|
1028
|
-
current_start = i
|
|
1029
|
-
current_length += 1
|
|
1030
|
-
current_probs.append(state_prob)
|
|
1031
|
-
elif current_start is not None:
|
|
1032
|
-
regions.append((current_start, current_length, float(_np.mean(current_probs))))
|
|
1033
|
-
current_start, current_length, current_probs = None, 0, []
|
|
1034
|
-
if current_start is not None:
|
|
1035
|
-
regions.append((current_start, current_length, float(_np.mean(current_probs))))
|
|
1036
|
-
|
|
1037
|
-
final = []
|
|
1038
|
-
for start, length, prob in regions:
|
|
1039
|
-
# compute genomic length try/catch
|
|
1040
|
-
try:
|
|
1041
|
-
feature_length = int(coordinates[start + length - 1]) - int(coordinates[start]) + 1
|
|
1042
|
-
except Exception:
|
|
1043
|
-
feature_length = int(length)
|
|
1044
|
-
|
|
1045
|
-
# classification_mapping values are (lo, hi) tuples or lists
|
|
1046
|
-
label = None
|
|
1047
|
-
for ftype, rng in classification_mapping.items():
|
|
1048
|
-
lo, hi = rng[0], rng[1]
|
|
1049
|
-
try:
|
|
1050
|
-
if lo <= feature_length < hi:
|
|
1051
|
-
label = ftype
|
|
1052
|
-
break
|
|
1053
|
-
except Exception:
|
|
1054
|
-
continue
|
|
1055
|
-
if label is None:
|
|
1056
|
-
# fallback to first mapping key or 'unknown'
|
|
1057
|
-
label = next(iter(classification_mapping.keys()), "feature")
|
|
1058
|
-
|
|
1059
|
-
# Store reported start coordinate in same coordinate system as `coordinates`.
|
|
1060
|
-
try:
|
|
1061
|
-
genomic_start = int(coordinates[start])
|
|
1062
|
-
except Exception:
|
|
1063
|
-
genomic_start = int(start)
|
|
1064
|
-
final.append((genomic_start, feature_length, label, prob))
|
|
1065
|
-
all_classifications_local.append(final)
|
|
1066
|
-
return all_classifications_local
|
|
1067
|
-
|
|
1068
|
-
# -----------------------------------------------------------------------
|
|
1069
|
-
|
|
1070
|
-
# Ensure obs_column is categorical-like for iteration
|
|
1071
|
-
sseries = adata.obs[obs_column]
|
|
1072
|
-
if not pd.api.types.is_categorical_dtype(sseries):
|
|
1073
|
-
sseries = sseries.astype("category")
|
|
1074
|
-
references = list(sseries.cat.categories)
|
|
1075
|
-
|
|
1076
|
-
ref_iter = references if not verbose else _tqdm(references, desc="Processing References")
|
|
1077
|
-
for ref in ref_iter:
|
|
1078
|
-
# subset reads with this obs_column value
|
|
1079
|
-
ref_mask = adata.obs[obs_column] == ref
|
|
1080
|
-
ref_subset = adata[ref_mask].copy()
|
|
1081
|
-
combined_mask = None
|
|
1082
|
-
|
|
1083
|
-
# per-methbase processing
|
|
1084
|
-
for methbase in methbases:
|
|
1085
|
-
key_lower = methbase.strip().lower()
|
|
1086
|
-
|
|
1087
|
-
# map several common synonyms -> canonical lookup
|
|
1088
|
-
if key_lower in ("a",):
|
|
1089
|
-
pos_mask = ref_subset.var.get(f"{ref}_strand_FASTA_base") == "A"
|
|
1090
|
-
elif key_lower in ("c", "any_c", "anyc", "any-c"):
|
|
1091
|
-
# unify 'C' or 'any_C' names to the any_C var column
|
|
1092
|
-
pos_mask = ref_subset.var.get(f"{ref}_any_C_site") == True
|
|
1093
|
-
elif key_lower in ("gpc", "gpc_site", "gpc-site"):
|
|
1094
|
-
pos_mask = ref_subset.var.get(f"{ref}_GpC_site") == True
|
|
1095
|
-
elif key_lower in ("cpg", "cpg_site", "cpg-site"):
|
|
1096
|
-
pos_mask = ref_subset.var.get(f"{ref}_CpG_site") == True
|
|
1097
|
-
else:
|
|
1098
|
-
# try a best-effort: if a column named f"{ref}_{methbase}_site" exists, use it
|
|
1099
|
-
alt_col = f"{ref}_{methbase}_site"
|
|
1100
|
-
pos_mask = ref_subset.var.get(alt_col, None)
|
|
1573
|
+
em = arr[: self.n_states, : self.n_channels]
|
|
1574
|
+
if em.shape != (self.n_states, self.n_channels):
|
|
1575
|
+
em = np.full((self.n_states, self.n_channels), 0.5, dtype=float)
|
|
1101
1576
|
|
|
1102
|
-
|
|
1103
|
-
|
|
1104
|
-
combined_mask = pos_mask if combined_mask is None else (combined_mask | pos_mask)
|
|
1577
|
+
self.emission = nn.Parameter(torch.tensor(em, dtype=self.dtype), requires_grad=False)
|
|
1578
|
+
self._normalize_emission()
|
|
1105
1579
|
|
|
1106
|
-
|
|
1107
|
-
|
|
1580
|
+
@classmethod
|
|
1581
|
+
def from_config(cls, cfg, *, override=None, device=None):
|
|
1582
|
+
"""Create a multi-channel Bernoulli HMM from config.
|
|
1108
1583
|
|
|
1109
|
-
|
|
1110
|
-
|
|
1111
|
-
|
|
1112
|
-
|
|
1113
|
-
n_reads = matrix.shape[0]
|
|
1584
|
+
Args:
|
|
1585
|
+
cfg: Configuration mapping or object.
|
|
1586
|
+
override: Override values to apply.
|
|
1587
|
+
device: Optional device specifier.
|
|
1114
1588
|
|
|
1115
|
-
|
|
1116
|
-
|
|
1117
|
-
|
|
1118
|
-
|
|
1119
|
-
|
|
1120
|
-
|
|
1121
|
-
|
|
1122
|
-
|
|
1123
|
-
|
|
1124
|
-
|
|
1125
|
-
|
|
1126
|
-
|
|
1127
|
-
|
|
1128
|
-
|
|
1129
|
-
|
|
1130
|
-
|
|
1131
|
-
|
|
1132
|
-
|
|
1133
|
-
|
|
1134
|
-
if use_viterbi:
|
|
1135
|
-
paths, _scores = self.batch_viterbi(seqs, impute_strategy="ignore", device=device)
|
|
1136
|
-
pred_states = _np.asarray(paths)
|
|
1137
|
-
else:
|
|
1138
|
-
pred_states = _np.argmax(probs_batch, axis=2)
|
|
1589
|
+
Returns:
|
|
1590
|
+
Initialized MultiBernoulliHMM instance.
|
|
1591
|
+
"""
|
|
1592
|
+
merged = cls._cfg_to_dict(cfg)
|
|
1593
|
+
if override:
|
|
1594
|
+
merged.update(override)
|
|
1595
|
+
n_states = int(merged.get("hmm_n_states", 2))
|
|
1596
|
+
eps = float(merged.get("hmm_eps", 1e-8))
|
|
1597
|
+
dtype = _resolve_dtype(merged.get("hmm_dtype", None))
|
|
1598
|
+
dtype = _coerce_dtype_for_device(dtype, device) # <<< NEW
|
|
1599
|
+
n_channels = int(merged.get("hmm_n_channels", merged.get("n_channels", 2)))
|
|
1600
|
+
init_em = merged.get("hmm_init_emission_probs", None)
|
|
1601
|
+
model = cls(
|
|
1602
|
+
n_states=n_states, n_channels=n_channels, init_emission=init_em, eps=eps, dtype=dtype
|
|
1603
|
+
)
|
|
1604
|
+
if device is not None:
|
|
1605
|
+
model.to(torch.device(device) if isinstance(device, str) else device)
|
|
1606
|
+
model._persisted_cfg = merged
|
|
1607
|
+
return model
|
|
1139
1608
|
|
|
1140
|
-
|
|
1141
|
-
|
|
1142
|
-
|
|
1143
|
-
|
|
1144
|
-
|
|
1145
|
-
|
|
1146
|
-
|
|
1147
|
-
|
|
1148
|
-
|
|
1149
|
-
|
|
1150
|
-
|
|
1151
|
-
|
|
1152
|
-
|
|
1153
|
-
|
|
1154
|
-
|
|
1155
|
-
|
|
1156
|
-
|
|
1157
|
-
|
|
1158
|
-
|
|
1159
|
-
|
|
1160
|
-
|
|
1161
|
-
|
|
1162
|
-
|
|
1163
|
-
|
|
1164
|
-
|
|
1165
|
-
|
|
1166
|
-
|
|
1167
|
-
|
|
1168
|
-
|
|
1169
|
-
|
|
1170
|
-
|
|
1171
|
-
|
|
1172
|
-
|
|
1173
|
-
|
|
1174
|
-
|
|
1175
|
-
|
|
1176
|
-
|
|
1177
|
-
|
|
1178
|
-
|
|
1179
|
-
|
|
1180
|
-
|
|
1181
|
-
|
|
1182
|
-
|
|
1183
|
-
|
|
1184
|
-
|
|
1185
|
-
|
|
1186
|
-
|
|
1187
|
-
|
|
1188
|
-
|
|
1189
|
-
|
|
1190
|
-
|
|
1191
|
-
|
|
1192
|
-
row_indices = list(comb.obs.index[start_idx:stop_idx])
|
|
1193
|
-
for i_local, idx in enumerate(row_indices):
|
|
1194
|
-
for start, length, label, prob in classifications[i_local]:
|
|
1195
|
-
adata.obs.at[idx, f"{combined_prefix}_{label}"].append([start, length, prob])
|
|
1196
|
-
adata.obs.at[idx, f"{combined_prefix}_all_{key}_features"].append([start, length, prob])
|
|
1197
|
-
|
|
1198
|
-
# CpG special handling
|
|
1199
|
-
if "cpg" in feature_sets and feature_sets.get("cpg") is not None:
|
|
1200
|
-
cpg_iter = references if not verbose else _tqdm(references, desc="Processing CpG")
|
|
1201
|
-
for ref in cpg_iter:
|
|
1202
|
-
ref_mask = adata.obs[obs_column] == ref
|
|
1203
|
-
ref_subset = adata[ref_mask].copy()
|
|
1204
|
-
pos_mask = ref_subset.var[f"{ref}_CpG_site"] == True
|
|
1205
|
-
if pos_mask.sum() == 0:
|
|
1206
|
-
continue
|
|
1207
|
-
cpg_sub = ref_subset[:, pos_mask]
|
|
1208
|
-
matrix = cpg_sub.layers[layer] if (layer and layer in cpg_sub.layers) else cpg_sub.X
|
|
1209
|
-
matrix = _ensure_2d_array_like(matrix)
|
|
1210
|
-
n_reads = matrix.shape[0]
|
|
1211
|
-
try:
|
|
1212
|
-
coords_cpg = _np.asarray(cpg_sub.var_names, dtype=int)
|
|
1213
|
-
except Exception:
|
|
1214
|
-
coords_cpg = _np.arange(cpg_sub.shape[1], dtype=int)
|
|
1215
|
-
|
|
1216
|
-
chunk_iter = range(0, n_reads, batch_size)
|
|
1217
|
-
if verbose:
|
|
1218
|
-
chunk_iter = _tqdm(list(chunk_iter), desc=f"{ref}:CpG chunks")
|
|
1219
|
-
for start_idx in chunk_iter:
|
|
1220
|
-
stop_idx = min(n_reads, start_idx + batch_size)
|
|
1221
|
-
chunk = matrix[start_idx:stop_idx]
|
|
1222
|
-
seqs = chunk.tolist()
|
|
1223
|
-
gammas = self.predict(seqs, impute_strategy="ignore", device=device)
|
|
1224
|
-
if len(gammas) == 0:
|
|
1225
|
-
continue
|
|
1226
|
-
probs_batch = _np.stack(gammas, axis=0)
|
|
1227
|
-
if use_viterbi:
|
|
1228
|
-
paths, _scores = self.batch_viterbi(seqs, impute_strategy="ignore", device=device)
|
|
1229
|
-
pred_states = _np.asarray(paths)
|
|
1230
|
-
else:
|
|
1231
|
-
pred_states = _np.argmax(probs_batch, axis=2)
|
|
1232
|
-
|
|
1233
|
-
fs = feature_sets["cpg"]
|
|
1234
|
-
state_target = fs.get("state", "Modified")
|
|
1235
|
-
feature_map = fs.get("features", {})
|
|
1236
|
-
classifications = classify_batch_local(pred_states, probs_batch, coords_cpg, feature_map, target_state=state_target)
|
|
1237
|
-
row_indices = list(cpg_sub.obs.index[start_idx:stop_idx])
|
|
1238
|
-
for i_local, idx in enumerate(row_indices):
|
|
1239
|
-
for start, length, label, prob in classifications[i_local]:
|
|
1240
|
-
adata.obs.at[idx, f"CpG_{label}"].append([start, length, prob])
|
|
1241
|
-
adata.obs.at[idx, f"CpG_all_cpg_features"].append([start, length, prob])
|
|
1242
|
-
|
|
1243
|
-
# finalize: convert intervals into binary layers and distances
|
|
1244
|
-
try:
|
|
1245
|
-
coordinates = _np.asarray(adata.var_names, dtype=int)
|
|
1246
|
-
coords_are_ints = True
|
|
1247
|
-
except Exception:
|
|
1248
|
-
coordinates = _np.arange(adata.shape[1], dtype=int)
|
|
1249
|
-
coords_are_ints = False
|
|
1250
|
-
|
|
1251
|
-
features_iter = all_features if not verbose else _tqdm(all_features, desc="Finalizing Layers")
|
|
1252
|
-
for feature in features_iter:
|
|
1253
|
-
bin_matrix = _np.zeros((adata.shape[0], adata.shape[1]), dtype=int)
|
|
1254
|
-
counts = _np.zeros(adata.shape[0], dtype=int)
|
|
1255
|
-
|
|
1256
|
-
# new: integer-length layer (0 where not inside a feature)
|
|
1257
|
-
len_matrix = _np.zeros((adata.shape[0], adata.shape[1]), dtype=int)
|
|
1258
|
-
|
|
1259
|
-
for row_idx, intervals in enumerate(adata.obs[feature]):
|
|
1260
|
-
if not isinstance(intervals, list):
|
|
1261
|
-
intervals = []
|
|
1262
|
-
for start, length, prob in intervals:
|
|
1263
|
-
if prob > threshold:
|
|
1264
|
-
if coords_are_ints:
|
|
1265
|
-
# map genomic start/length into index interval [start_idx, end_idx)
|
|
1266
|
-
start_idx = _np.searchsorted(coordinates, int(start), side="left")
|
|
1267
|
-
end_idx = _np.searchsorted(coordinates, int(start) + int(length) - 1, side="right")
|
|
1268
|
-
else:
|
|
1269
|
-
start_idx = int(start)
|
|
1270
|
-
end_idx = start_idx + int(length)
|
|
1271
|
-
|
|
1272
|
-
start_idx = max(0, min(start_idx, adata.shape[1]))
|
|
1273
|
-
end_idx = max(0, min(end_idx, adata.shape[1]))
|
|
1274
|
-
|
|
1275
|
-
if start_idx < end_idx:
|
|
1276
|
-
span = end_idx - start_idx # number of positions covered
|
|
1277
|
-
# set binary mask
|
|
1278
|
-
bin_matrix[row_idx, start_idx:end_idx] = 1
|
|
1279
|
-
# set length mask: use maximum in case of overlaps
|
|
1280
|
-
existing = len_matrix[row_idx, start_idx:end_idx]
|
|
1281
|
-
len_matrix[row_idx, start_idx:end_idx] = _np.maximum(existing, span)
|
|
1282
|
-
counts[row_idx] += 1
|
|
1283
|
-
|
|
1284
|
-
# write binary layer and length layer, track appended names
|
|
1285
|
-
adata.layers[feature] = bin_matrix
|
|
1286
|
-
appended_layers.append(feature)
|
|
1287
|
-
|
|
1288
|
-
# name the integer-length layer (choose suffix you like)
|
|
1289
|
-
length_layer_name = f"{feature}_lengths"
|
|
1290
|
-
adata.layers[length_layer_name] = len_matrix
|
|
1291
|
-
appended_layers.append(length_layer_name)
|
|
1292
|
-
|
|
1293
|
-
adata.obs[f"n_{feature}"] = counts
|
|
1294
|
-
adata.obs[f"{feature}_distances"] = calculate_batch_distances(adata.obs[feature].tolist(), threshold)
|
|
1295
|
-
|
|
1296
|
-
# Merge appended_layers into adata.uns[uns_key] (preserve pre-existing and avoid duplicates)
|
|
1297
|
-
existing = list(adata.uns.get(uns_key, [])) if adata.uns.get(uns_key) is not None else []
|
|
1298
|
-
new_list = existing + [l for l in appended_layers if l not in existing]
|
|
1299
|
-
adata.uns[uns_key] = new_list
|
|
1300
|
-
|
|
1301
|
-
return None if in_place else adata
|
|
1302
|
-
|
|
1303
|
-
def merge_intervals_in_layer(
|
|
1609
|
+
def _normalize_emission(self):
|
|
1610
|
+
"""Normalize and clamp emission probabilities in-place."""
|
|
1611
|
+
with torch.no_grad():
|
|
1612
|
+
self.emission.data = self.emission.data.reshape(self.n_states, self.n_channels)
|
|
1613
|
+
self.emission.data = self.emission.data.clamp(min=self.eps, max=1.0 - self.eps)
|
|
1614
|
+
|
|
1615
|
+
def _ensure_device_dtype(self, device=None) -> torch.device:
|
|
1616
|
+
"""Move emission parameters to the requested device/dtype."""
|
|
1617
|
+
device = super()._ensure_device_dtype(device)
|
|
1618
|
+
self.emission.data = self.emission.data.to(device=device, dtype=self.dtype)
|
|
1619
|
+
return device
|
|
1620
|
+
|
|
1621
|
+
def _state_modified_score(self) -> torch.Tensor:
|
|
1622
|
+
"""Return per-state modified scores for ranking."""
|
|
1623
|
+
# more “modified” = higher mean P(1) across channels
|
|
1624
|
+
return self.emission.detach().mean(dim=1)
|
|
1625
|
+
|
|
1626
|
+
def _log_emission(self, obs: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
|
|
1627
|
+
"""
|
|
1628
|
+
obs: (N,L,C), mask: (N,L,C) -> logB: (N,L,K)
|
|
1629
|
+
"""
|
|
1630
|
+
N, L, C = obs.shape
|
|
1631
|
+
K = self.n_states
|
|
1632
|
+
|
|
1633
|
+
p = self.emission # (K,C)
|
|
1634
|
+
logp = torch.log(p + self.eps).view(1, 1, K, C)
|
|
1635
|
+
log1mp = torch.log1p(-p + self.eps).view(1, 1, K, C)
|
|
1636
|
+
|
|
1637
|
+
o = obs.unsqueeze(2) # (N,L,1,C)
|
|
1638
|
+
m = mask.unsqueeze(2) # (N,L,1,C)
|
|
1639
|
+
|
|
1640
|
+
logBC = o * logp + (1.0 - o) * log1mp
|
|
1641
|
+
logBC = torch.where(m, logBC, torch.zeros_like(logBC))
|
|
1642
|
+
return logBC.sum(dim=3) # sum channels -> (N,L,K)
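A shape-level sketch of the per-channel factorization above: channel log-terms are computed on a (N, L, K, C) grid, unobserved channels are zeroed, and the result is summed over C.

import torch

N, L, K, C = 2, 4, 2, 3
p = torch.full((K, C), 0.7, dtype=torch.float64)       # P(obs_c == 1 | state k)
obs = torch.zeros((N, L, C), dtype=torch.float64)
mask = torch.ones((N, L, C), dtype=torch.bool)
mask[0, 0, 1] = False                                   # one missing channel value

logp = torch.log(p).view(1, 1, K, C)
log1mp = torch.log1p(-p).view(1, 1, K, C)
o, m = obs.unsqueeze(2), mask.unsqueeze(2)              # broadcast over states
logBC = torch.where(m, o * logp + (1.0 - o) * log1mp, torch.zeros_like(o * logp))
print(logBC.sum(dim=3).shape)   # torch.Size([2, 4, 2]) -> (N, L, K)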
|
|
1643
|
+
|
|
1644
|
+
def _extra_save_payload(self) -> dict:
|
|
1645
|
+
"""Return extra payload data for serialization."""
|
|
1646
|
+
return {"n_channels": int(self.n_channels), "emission": self.emission.detach().cpu()}
|
|
1647
|
+
|
|
1648
|
+
def _load_extra_payload(self, payload: dict, *, device: torch.device):
|
|
1649
|
+
"""Load serialized emission parameters.
|
|
1650
|
+
|
|
1651
|
+
Args:
|
|
1652
|
+
payload: Serialized payload dictionary.
|
|
1653
|
+
device: Target torch device.
|
|
1654
|
+
"""
|
|
1655
|
+
self.n_channels = int(payload.get("n_channels", self.n_channels))
|
|
1656
|
+
with torch.no_grad():
|
|
1657
|
+
self.emission.data = payload["emission"].to(device=device, dtype=self.dtype)
|
|
1658
|
+
self._normalize_emission()
|
|
1659
|
+
|
|
1660
|
+
def fit_em(
|
|
1304
1661
|
self,
|
|
1662
|
+
X: np.ndarray,
|
|
1663
|
+
coords: np.ndarray,
|
|
1664
|
+
*,
|
|
1665
|
+
device: torch.device,
|
|
1666
|
+
max_iter: int,
|
|
1667
|
+
tol: float,
|
|
1668
|
+
update_start: bool,
|
|
1669
|
+
update_trans: bool,
|
|
1670
|
+
update_emission: bool,
|
|
1671
|
+
verbose: bool,
|
|
1672
|
+
**kwargs,
|
|
1673
|
+
) -> List[float]:
|
|
1674
|
+
"""Run EM updates for a multi-channel Bernoulli HMM.
|
|
1675
|
+
|
|
1676
|
+
Args:
|
|
1677
|
+
X: Observations array (N, L, C).
|
|
1678
|
+
coords: Coordinate array aligned to X.
|
|
1679
|
+
device: Torch device.
|
|
1680
|
+
max_iter: Maximum iterations.
|
|
1681
|
+
tol: Convergence tolerance.
|
|
1682
|
+
update_start: Whether to update start probabilities.
|
|
1683
|
+
update_trans: Whether to update transitions.
|
|
1684
|
+
update_emission: Whether to update emission parameters.
|
|
1685
|
+
verbose: Whether to log progress.
|
|
1686
|
+
**kwargs: Additional implementation-specific kwargs.
|
|
1687
|
+
|
|
1688
|
+
Returns:
|
|
1689
|
+
List of log-likelihood proxy values.
|
|
1690
|
+
"""
|
|
1691
|
+
X = np.asarray(X, dtype=float)
|
|
1692
|
+
if X.ndim != 3:
|
|
1693
|
+
raise ValueError("MultiBernoulliHMM expects X shape (N,L,C).")
|
|
1694
|
+
|
|
1695
|
+
obs = torch.tensor(np.nan_to_num(X, nan=0.0), dtype=self.dtype, device=device)
|
|
1696
|
+
mask = torch.tensor(~np.isnan(X), dtype=torch.bool, device=device)
|
|
1697
|
+
|
|
1698
|
+
eps = float(self.eps)
|
|
1699
|
+
K = self.n_states
|
|
1700
|
+
N, L, C = obs.shape
|
|
1701
|
+
|
|
1702
|
+
self._ensure_n_channels(C, device)
|
|
1703
|
+
|
|
1704
|
+
hist: List[float] = []
|
|
1705
|
+
for it in range(1, int(max_iter) + 1):
|
|
1706
|
+
gamma = self._forward_backward(obs, mask) # (N,L,K)
|
|
1707
|
+
|
|
1708
|
+
ll_proxy = float(torch.sum(torch.log(torch.clamp(gamma.sum(dim=2), min=eps))).item())
|
|
1709
|
+
hist.append(ll_proxy)
|
|
1710
|
+
|
|
1711
|
+
# expected start
|
|
1712
|
+
start_acc = gamma[:, 0, :].sum(dim=0) # (K,)
|
|
1713
|
+
|
|
1714
|
+
# transitions xi
|
|
1715
|
+
logB = self._log_emission(obs, mask)
|
|
1716
|
+
logA = torch.log(self.trans + eps)
|
|
1717
|
+
alpha = torch.empty((N, L, K), dtype=self.dtype, device=device)
|
|
1718
|
+
alpha[:, 0, :] = torch.log(self.start + eps).unsqueeze(0) + logB[:, 0, :]
|
|
1719
|
+
for t in range(1, L):
|
|
1720
|
+
prev = alpha[:, t - 1, :].unsqueeze(2) + logA.unsqueeze(0)
|
|
1721
|
+
alpha[:, t, :] = _logsumexp(prev, dim=1) + logB[:, t, :]
|
|
1722
|
+
beta = torch.empty((N, L, K), dtype=self.dtype, device=device)
|
|
1723
|
+
beta[:, L - 1, :] = 0.0
|
|
1724
|
+
for t in range(L - 2, -1, -1):
|
|
1725
|
+
temp = logA.unsqueeze(0) + (logB[:, t + 1, :] + beta[:, t + 1, :]).unsqueeze(1)
|
|
1726
|
+
beta[:, t, :] = _logsumexp(temp, dim=2)
|
|
1727
|
+
|
|
1728
|
+
trans_acc = torch.zeros((K, K), dtype=self.dtype, device=device)
|
|
1729
|
+
# valid timestep if at least one channel observed at both positions
|
|
1730
|
+
valid_pos = mask.any(dim=2) # (N,L)
|
|
1731
|
+
for t in range(L - 1):
|
|
1732
|
+
valid_t = (valid_pos[:, t] & valid_pos[:, t + 1]).float().view(N, 1, 1)
|
|
1733
|
+
log_xi = (
|
|
1734
|
+
alpha[:, t, :].unsqueeze(2)
|
|
1735
|
+
+ logA.unsqueeze(0)
|
|
1736
|
+
+ (logB[:, t + 1, :] + beta[:, t + 1, :]).unsqueeze(1)
|
|
1737
|
+
)
|
|
1738
|
+
log_norm = _logsumexp(log_xi.view(N, -1), dim=1).view(N, 1, 1)
|
|
1739
|
+
xi = (log_xi - log_norm).exp() * valid_t
|
|
1740
|
+
trans_acc += xi.sum(dim=0)
|
|
1741
|
+
|
|
1742
|
+
# emission update per channel
|
|
1743
|
+
gamma_k = gamma.unsqueeze(-1) # (N,L,K,1)
|
|
1744
|
+
obs_c = obs.unsqueeze(2) # (N,L,1,C)
|
|
1745
|
+
mask_c = mask.unsqueeze(2).float() # (N,L,1,C)
|
|
1746
|
+
|
|
1747
|
+
emit_num = (gamma_k * obs_c * mask_c).sum(dim=(0, 1)) # (K,C)
|
|
1748
|
+
emit_den = (gamma_k * mask_c).sum(dim=(0, 1)) # (K,C)
|
|
1749
|
+
|
|
1750
|
+
with torch.no_grad():
|
|
1751
|
+
if update_start:
|
|
1752
|
+
new_start = start_acc + eps
|
|
1753
|
+
self.start.data = new_start / new_start.sum()
|
|
1754
|
+
|
|
1755
|
+
if update_trans:
|
|
1756
|
+
new_trans = trans_acc + eps
|
|
1757
|
+
rs = new_trans.sum(dim=1, keepdim=True)
|
|
1758
|
+
rs[rs == 0.0] = 1.0
|
|
1759
|
+
self.trans.data = new_trans / rs
|
|
1760
|
+
|
|
1761
|
+
if update_emission:
|
|
1762
|
+
new_em = (emit_num + eps) / (emit_den + 2.0 * eps)
|
|
1763
|
+
self.emission.data = new_em.clamp(min=eps, max=1.0 - eps)
|
|
1764
|
+
|
|
1765
|
+
self._normalize_params()
|
|
1766
|
+
self._normalize_emission()
|
|
1767
|
+
|
|
1768
|
+
if verbose:
|
|
1769
|
+
logger.info(
|
|
1770
|
+
"[MultiBernoulliHMM.fit] iter=%s ll_proxy=%.6f",
|
|
1771
|
+
it,
|
|
1772
|
+
hist[-1],
|
|
1773
|
+
)
|
|
1774
|
+
|
|
1775
|
+
if len(hist) > 1 and abs(hist[-1] - hist[-2]) < float(tol):
|
|
1776
|
+
break
|
|
1777
|
+
|
|
1778
|
+
return hist
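A minimal usage sketch for the method above (not from the package): `model` is assumed to be a MultiBernoulliHMM built elsewhere, NaN marks missing calls, and the keyword arguments follow the signature shown in the diff.

import numpy as np
import torch

# toy input: 4 reads, 10 positions, 2 methylation channels; NaN = unobserved
X = np.random.rand(4, 10, 2).round()
X[0, 3, :] = np.nan
coords = np.arange(10)

history = model.fit_em(
    X, coords,
    device=torch.device("cpu"),
    max_iter=20, tol=1e-4,
    update_start=True, update_trans=True, update_emission=True,
    verbose=False,
)
print(history[-1])  # last log-likelihood proxy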
+
+    def adapt_emissions(
+        self,
+        X: np.ndarray,
+        coords: Optional[np.ndarray] = None,
+        *,
+        device: Optional[Union[str, torch.device]] = None,
+        iters: Optional[int] = None,
+        max_iter: Optional[int] = None,  # alias for your trainer
         verbose: bool = False,
+        **kwargs,
     ):
+        """Adapt emissions with legacy parameter names.
+
+        Args:
+            X: Observations array.
+            coords: Optional coordinate array.
+            device: Device specifier.
+            iters: Number of iterations.
+            max_iter: Alias for iters.
+            verbose: Whether to log progress.
+            **kwargs: Additional kwargs forwarded to BaseHMM.adapt_emissions.
+
+        Returns:
+            List of log-likelihood values.
         """
-            Merge intervals whose gap <= this threshold (genomic coords if adata.var_names are ints).
-        merged_suffix : str
-            Suffix appended to original layer for the merged binary layer (default "_merged").
-        length_layer_suffix : str
-            Suffix appended after merged suffix for the lengths layer (default "_lengths").
-        update_obs : bool
-            If True, create/update adata.obs[f"{layer}{merged_suffix}"] with merged intervals.
-        prob_strategy : str
-            How to combine probs when merging ('mean', 'max', 'orig_first').
-        inplace : bool
-            If False, returns a new AnnData with changes (original untouched).
-        overwrite : bool
-            If True, will overwrite existing merged layers / obs entries; otherwise will error if they exist.
-        """
-        import numpy as _np
-        from scipy.sparse import issparse
+        if iters is None:
+            iters = int(max_iter) if max_iter is not None else int(kwargs.pop("iters", 5))
+        return super().adapt_emissions(
+            np.asarray(X, dtype=float),
+            coords if coords is not None else None,
+            iters=int(iters),
+            device=device,
+            verbose=verbose,
+        )
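A quick illustration of the alias handling above (a sketch; `model` and `X` are assumed to exist): passing max_iter behaves the same as passing iters, and when neither is given the method falls back to 5 iterations.

# both calls run the same number of adaptation iterations
model.adapt_emissions(X, iters=10, device="cpu")
model.adapt_emissions(X, max_iter=10, device="cpu")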
 
+    def _ensure_n_channels(self, C: int, device: torch.device):
+        """Expand emission parameters when channel count changes.
 
+        Args:
+            C: Target channel count.
+            device: Torch device for the new parameters.
+        """
+        C = int(C)
+        if C == self.n_channels:
+            return
+        with torch.no_grad():
+            old = self.emission.detach().cpu().numpy()  # (K, Cold)
+            K = old.shape[0]
+            new = np.full((K, C), 0.5, dtype=float)
+            m = min(old.shape[1], C)
+            new[:, :m] = old[:, :m]
+            if C > old.shape[1]:
+                fill = old.mean(axis=1, keepdims=True)
+                new[:, m:] = fill
+            self.n_channels = C
+            self.emission = nn.Parameter(
+                torch.tensor(new, dtype=self.dtype, device=device), requires_grad=False
+            )
+        self._normalize_emission()
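A small standalone numpy sketch (not package code) of the expansion rule above: existing channels are copied over, and any newly added channels are filled with the per-state mean of the old channels.

import numpy as np

old = np.array([[0.1, 0.3],    # state 0, 2 channels
                [0.8, 0.6]])   # state 1
C = 3                          # grow to 3 channels
new = np.full((old.shape[0], C), 0.5)
m = min(old.shape[1], C)
new[:, :m] = old[:, :m]
new[:, m:] = old.mean(axis=1, keepdims=True)   # new channel gets per-state mean
print(new)  # [[0.1 0.3 0.2], [0.8 0.6 0.7]]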
+
+
+# =============================================================================
+# Distance-binned transitions (single-channel only)
+# =============================================================================
+
+
+@register_hmm("single_distance_binned")
+class DistanceBinnedSingleBernoulliHMM(SingleBernoulliHMM):
+    """
+    Transition matrix depends on binned distances between consecutive coords.
 
+    Config keys:
+      hmm_distance_bins: list[int] edges (bp)
+      hmm_init_transitions_by_bin: optional (n_bins,K,K)
+    """
 
+    def __init__(
+        self,
+        n_states: int = 2,
+        init_emission: Optional[Sequence[float]] = None,
+        distance_bins: Optional[Sequence[int]] = None,
+        init_trans_by_bin: Optional[Any] = None,
+        eps: float = 1e-8,
+        dtype: torch.dtype = torch.float64,
+    ):
+        """Initialize a distance-binned transition HMM.
+
+        Args:
+            n_states: Number of hidden states.
+            init_emission: Initial emission probabilities per state.
+            distance_bins: Distance bin edges in base pairs.
+            init_trans_by_bin: Initial transition matrices per bin.
+            eps: Smoothing epsilon for probabilities.
+            dtype: Torch dtype for parameters.
+        """
+        super().__init__(n_states=n_states, init_emission=init_emission, eps=eps, dtype=dtype)
 
+        self.distance_bins = np.asarray(
+            distance_bins if distance_bins is not None else [1, 5, 10, 25, 50, 100], dtype=int
+        )
+        self.n_bins = int(len(self.distance_bins) + 1)
+
+        if init_trans_by_bin is None:
+            base = self.trans.detach().cpu().numpy()
+            tb = np.stack([base for _ in range(self.n_bins)], axis=0)
         else:
+            tb = np.asarray(init_trans_by_bin, dtype=float)
+            if tb.shape != (self.n_bins, self.n_states, self.n_states):
+                base = self.trans.detach().cpu().numpy()
+                tb = np.stack([base for _ in range(self.n_bins)], axis=0)
 
+        self.trans_by_bin = nn.Parameter(torch.tensor(tb, dtype=self.dtype), requires_grad=False)
+        self._normalize_trans_by_bin()
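A minimal construction sketch for this class (values are illustrative; the init_emission format of one P(1) per state is an assumption consistent with the docstring above):

import torch

hmm = DistanceBinnedSingleBernoulliHMM(
    n_states=2,
    init_emission=[0.1, 0.9],               # assumed: P(1) for each state
    distance_bins=[1, 5, 10, 25, 50, 100],  # bp edges -> 7 bins including overflow
    dtype=torch.float64,
)
print(hmm.n_bins)  # 7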
 
-            coords_are_ints = True
-        except Exception:
-            coords = _np.arange(n_cols, dtype=int)
-            coords_are_ints = False
-
-        # helper: contiguous runs of 1s -> list of (start_idx, end_idx) (end exclusive)
-        def _runs_from_mask(mask_1d):
-            idx = _np.nonzero(mask_1d)[0]
-            if idx.size == 0:
-                return []
-            runs = []
-            start = idx[0]
-            prev = idx[0]
-            for i in idx[1:]:
-                if i == prev + 1:
-                    prev = i
-                    continue
-                runs.append((start, prev + 1))
-                start = i
-                prev = i
-            runs.append((start, prev + 1))
-            return runs
-
-        # read original obs intervals/probs if available (for combining probs)
-        orig_obs = None
-        if update_obs and (layer in adata.obs.columns):
-            orig_obs = list(adata.obs[layer])  # might be non-list entries
-
-        # prepare outputs
-        merged_bin = _np.zeros_like(bin_arr, dtype=int)
-        merged_len = _np.zeros_like(bin_arr, dtype=int)
-        merged_obs_col = [[] for _ in range(n_rows)]
-        merged_counts = _np.zeros(n_rows, dtype=int)
-
-        for r in range(n_rows):
-            mask = bin_arr[r, :] != 0
-            runs = _runs_from_mask(mask)
-            if not runs:
-                merged_obs_col[r] = []
-                continue
+    @classmethod
+    def from_config(cls, cfg, *, override=None, device=None):
+        """Create a distance-binned HMM from config.
 
-                start_coord = int(coords[s_idx]) if coords_are_ints else int(s_idx)
-                row_entries.append((start_coord, int(length_val), float(prob_val)))
-
-            merged_obs_col[r] = row_entries
-            merged_counts[r] = len(row_entries)
-
-        # write merged layers (do not overwrite originals unless overwrite=True was set above)
-        adata.layers[merged_bin_name] = merged_bin
-        adata.layers[merged_len_name] = merged_len
-
-        if update_obs:
-            adata.obs[merged_bin_name] = merged_obs_col
-            adata.obs[f"n_{merged_bin_name}"] = merged_counts
-
-            # recompute distances list per-row (gaps between adjacent merged intervals)
-            def _calc_distances(obs_list):
-                out = []
-                for intervals in obs_list:
-                    if not intervals:
-                        out.append([])
-                        continue
-                    iv = sorted(intervals, key=lambda x: int(x[0]))
-                    if len(iv) <= 1:
-                        out.append([])
-                        continue
-                    dlist = []
-                    for i in range(len(iv) - 1):
-                        endi = int(iv[i][0]) + int(iv[i][1]) - 1
-                        startn = int(iv[i + 1][0])
-                        dlist.append(startn - endi - 1)
-                    out.append(dlist)
-                return out
-
-            adata.obs[f"{merged_bin_name}_distances"] = _calc_distances(merged_obs_col)
-
-        # update uns appended list
-        uns_key = "hmm_appended_layers"
-        existing = list(adata.uns.get(uns_key, [])) if adata.uns.get(uns_key, None) is not None else []
-        for nm in (merged_bin_name, merged_len_name):
-            if nm not in existing:
-                existing.append(nm)
-        adata.uns[uns_key] = existing
-
-        if verbose:
-            print(f"Created merged binary layer: {merged_bin_name}")
-            print(f"Created merged length layer: {merged_len_name}")
-            if update_obs:
-                print(f"Updated adata.obs columns: {merged_bin_name}, n_{merged_bin_name}, {merged_bin_name}_distances")
-
-        return None if inplace else adata
-
-    def _ensure_final_layer_and_assign(self, final_adata, layer_name: str, subset_idx_mask: np.ndarray, sub_data):
+        Args:
+            cfg: Configuration mapping or object.
+            override: Override values to apply.
+            device: Optional device specifier.
+
+        Returns:
+            Initialized DistanceBinnedSingleBernoulliHMM instance.
+        """
+        merged = cls._cfg_to_dict(cfg)
+        if override:
+            merged.update(override)
+
+        n_states = int(merged.get("hmm_n_states", 2))
+        eps = float(merged.get("hmm_eps", 1e-8))
+        dtype = _resolve_dtype(merged.get("hmm_dtype", None))
+        dtype = _coerce_dtype_for_device(dtype, device)  # <<< NEW
+        init_em = merged.get("hmm_init_emission_probs", None)
+
+        bins = merged.get("hmm_distance_bins", [1, 5, 10, 25, 50, 100])
+        init_tb = merged.get("hmm_init_transitions_by_bin", None)
+
+        model = cls(
+            n_states=n_states,
+            init_emission=init_em,
+            distance_bins=bins,
+            init_trans_by_bin=init_tb,
+            eps=eps,
+            dtype=dtype,
+        )
+        if device is not None:
+            model.to(torch.device(device) if isinstance(device, str) else device)
+        model._persisted_cfg = merged
+        return model
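For orientation, a minimal from_config sketch using a plain dict; whether a bare dict is accepted depends on `_cfg_to_dict`, so treat the cfg shape as an assumption. The key names are exactly the ones read above.

cfg = {
    "hmm_n_states": 2,
    "hmm_eps": 1e-8,
    "hmm_init_emission_probs": [0.1, 0.9],
    "hmm_distance_bins": [1, 5, 10, 25, 50, 100],
}
hmm = DistanceBinnedSingleBernoulliHMM.from_config(cfg, device="cpu")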
+
+    def _ensure_device_dtype(self, device=None) -> torch.device:
+        """Move transition-by-bin parameters to the requested device/dtype."""
+        device = super()._ensure_device_dtype(device)
+        self.trans_by_bin.data = self.trans_by_bin.data.to(device=device, dtype=self.dtype)
+        return device
+
+    def _normalize_trans_by_bin(self):
+        """Normalize transition matrices per distance bin in-place."""
+        with torch.no_grad():
+            tb = self.trans_by_bin.data.reshape(self.n_bins, self.n_states, self.n_states)
+            tb = tb + self.eps
+            rs = tb.sum(dim=2, keepdim=True)
+            rs[rs == 0.0] = 1.0
+            self.trans_by_bin.data = tb / rs
+
+    def _extra_save_payload(self) -> dict:
+        """Return extra payload data for serialization."""
+        p = super()._extra_save_payload()
+        p.update(
+            {
+                "distance_bins": torch.tensor(self.distance_bins, dtype=torch.long),
+                "trans_by_bin": self.trans_by_bin.detach().cpu(),
+            }
+        )
+        return p
+
+    def _load_extra_payload(self, payload: dict, *, device: torch.device):
+        """Load serialized distance-bin parameters.
+
+        Args:
+            payload: Serialized payload dictionary.
+            device: Target torch device.
         """
+        super()._load_extra_payload(payload, device=device)
+        self.distance_bins = (
+            payload.get("distance_bins", torch.tensor([1, 5, 10, 25, 50, 100]))
+            .cpu()
+            .numpy()
+            .astype(int)
+        )
+        self.n_bins = int(len(self.distance_bins) + 1)
+        with torch.no_grad():
+            self.trans_by_bin.data = payload["trans_by_bin"].to(device=device, dtype=self.dtype)
+        self._normalize_trans_by_bin()
+
+    def _bin_index(self, coords: np.ndarray) -> np.ndarray:
+        """Return per-step distance bin indices for coordinates.
+
+        Args:
+            coords: Coordinate array.
+
+        Returns:
+            Array of bin indices (length L-1).
         """
+        d = np.diff(np.asarray(coords, dtype=int))
+        return np.digitize(d, self.distance_bins, right=True)  # length L-1
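To make the binning concrete, a small standalone example of the np.digitize(..., right=True) call above with the default edges [1, 5, 10, 25, 50, 100]: a gap of 1 bp falls in bin 0, gaps of 2 to 5 bp in bin 1, and anything above 100 bp in the overflow bin 6.

import numpy as np

bins = np.array([1, 5, 10, 25, 50, 100])
coords = np.array([100, 101, 104, 120, 400])            # adjacent gaps: 1, 3, 16, 280
print(np.digitize(np.diff(coords), bins, right=True))   # -> [0 1 3 6]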
 
+    def _forward_backward(
+        self, obs: torch.Tensor, mask: torch.Tensor, *, coords: Optional[np.ndarray] = None
+    ) -> torch.Tensor:
+        """Run forward-backward using distance-binned transitions.
+
+        Args:
+            obs: Observation tensor.
+            mask: Observation mask.
+            coords: Coordinate array.
+
+        Returns:
+            Posterior probabilities (gamma).
+        """
+        if coords is None:
+            raise ValueError("Distance-binned HMM requires coords.")
+        device = obs.device
+        eps = float(self.eps)
+        K = self.n_states
+
+        coords = np.asarray(coords, dtype=int)
+        bins = torch.tensor(self._bin_index(coords), dtype=torch.long, device=device)  # (L-1,)
+
+        logB = self._log_emission(obs, mask)  # (N,L,K)
+        logstart = torch.log(self.start + eps)
+        logA_by_bin = torch.log(self.trans_by_bin + eps)  # (nb,K,K)
+
+        N, L, _ = logB.shape
+        alpha = torch.empty((N, L, K), dtype=self.dtype, device=device)
+        alpha[:, 0, :] = logstart.unsqueeze(0) + logB[:, 0, :]
+
+        for t in range(1, L):
+            b = int(bins[t - 1].item()) if (t - 1) < bins.numel() else 0
+            logA = logA_by_bin[b]
+            prev = alpha[:, t - 1, :].unsqueeze(2) + logA.unsqueeze(0)
+            alpha[:, t, :] = _logsumexp(prev, dim=1) + logB[:, t, :]
+
+        beta = torch.empty((N, L, K), dtype=self.dtype, device=device)
+        beta[:, L - 1, :] = 0.0
+        for t in range(L - 2, -1, -1):
+            b = int(bins[t].item()) if t < bins.numel() else 0
+            logA = logA_by_bin[b]
+            temp = logA.unsqueeze(0) + (logB[:, t + 1, :] + beta[:, t + 1, :]).unsqueeze(1)
+            beta[:, t, :] = _logsumexp(temp, dim=2)
+
+        log_gamma = alpha + beta
+        logZ = _logsumexp(log_gamma, dim=2).unsqueeze(2)
+        return (log_gamma - logZ).exp()
+
+    def _viterbi(
+        self, obs: torch.Tensor, mask: torch.Tensor, *, coords: Optional[np.ndarray] = None
+    ) -> torch.Tensor:
+        """Run Viterbi decoding using distance-binned transitions.
+
+        Args:
+            obs: Observation tensor.
+            mask: Observation mask.
+            coords: Coordinate array.
+
+        Returns:
+            Decoded state sequence tensor.
+        """
+        if coords is None:
+            raise ValueError("Distance-binned HMM requires coords.")
+        device = obs.device
+        eps = float(self.eps)
+        K = self.n_states
+
+        coords = np.asarray(coords, dtype=int)
+        bins = torch.tensor(self._bin_index(coords), dtype=torch.long, device=device)  # (L-1,)
+
+        logB = self._log_emission(obs, mask)
+        logstart = torch.log(self.start + eps)
+        logA_by_bin = torch.log(self.trans_by_bin + eps)
+
+        N, L, _ = logB.shape
+        delta = torch.empty((N, L, K), dtype=self.dtype, device=device)
+        psi = torch.empty((N, L, K), dtype=torch.long, device=device)
+
+        delta[:, 0, :] = logstart.unsqueeze(0) + logB[:, 0, :]
+        psi[:, 0, :] = -1
+
+        for t in range(1, L):
+            b = int(bins[t - 1].item()) if (t - 1) < bins.numel() else 0
+            logA = logA_by_bin[b]
+            cand = delta[:, t - 1, :].unsqueeze(2) + logA.unsqueeze(0)
+            best_val, best_idx = cand.max(dim=1)
+            delta[:, t, :] = best_val + logB[:, t, :]
+            psi[:, t, :] = best_idx
+
+        last_state = torch.argmax(delta[:, L - 1, :], dim=1)
+        states = torch.empty((N, L), dtype=torch.long, device=device)
+        states[:, L - 1] = last_state
+        for t in range(L - 2, -1, -1):
+            states[:, t] = psi[torch.arange(N, device=device), t + 1, states[:, t + 1]]
+        return states
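The key difference from the flat-transition model is visible in both recursions above: at step t the transition matrix is looked up per step from logA_by_bin using the bin of the gap between coords[t-1] and coords[t]. A tiny standalone illustration of that lookup (shapes only, values arbitrary):

import torch

n_bins, K = 7, 2
logA_by_bin = torch.log(torch.full((n_bins, K, K), 0.5))  # placeholder per-bin matrices
bins = torch.tensor([0, 1, 3, 6])                         # as returned by _bin_index(coords)
for t in range(1, 5):
    logA = logA_by_bin[int(bins[t - 1])]                  # (K,K) matrix used at step t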
+
+    def fit_em(
+        self,
+        X: np.ndarray,
+        coords: np.ndarray,
+        *,
+        device: torch.device,
+        max_iter: int,
+        tol: float,
+        update_start: bool,
+        update_trans: bool,
+        update_emission: bool,
+        verbose: bool,
+        **kwargs,
+    ) -> List[float]:
+        """Run EM updates for distance-binned transitions.
+
+        Args:
+            X: Observations array (N, L).
+            coords: Coordinate array aligned to X.
+            device: Torch device.
+            max_iter: Maximum iterations.
+            tol: Convergence tolerance.
+            update_start: Whether to update start probabilities.
+            update_trans: Whether to update transitions.
+            update_emission: Whether to update emission parameters.
+            verbose: Whether to log progress.
+            **kwargs: Additional implementation-specific kwargs.
+
+        Returns:
+            List of log-likelihood proxy values.
+        """
+        # Keep this simple: use gamma for emissions; transitions-by-bin updated via xi (same pattern).
+        X = np.asarray(X, dtype=float)
+        if X.ndim != 2:
+            raise ValueError("DistanceBinnedSingleBernoulliHMM expects X shape (N,L).")
+
+        coords = np.asarray(coords, dtype=int)
+        bins_np = self._bin_index(coords)  # (L-1,)
+
+        obs = torch.tensor(np.nan_to_num(X, nan=0.0), dtype=self.dtype, device=device)
+        mask = torch.tensor(~np.isnan(X), dtype=torch.bool, device=device)
+
+        eps = float(self.eps)
+        K = self.n_states
+        N, L = obs.shape
+
+        hist: List[float] = []
+        for it in range(1, int(max_iter) + 1):
+            gamma = self._forward_backward(obs, mask, coords=coords)  # (N,L,K)
+            ll_proxy = float(torch.sum(torch.log(torch.clamp(gamma.sum(dim=2), min=eps))).item())
+            hist.append(ll_proxy)
+
+            # expected start
+            start_acc = gamma[:, 0, :].sum(dim=0)
+
+            # compute alpha/beta for xi
+            logB = self._log_emission(obs, mask)
+            logstart = torch.log(self.start + eps)
+            logA_by_bin = torch.log(self.trans_by_bin + eps)
+
+            alpha = torch.empty((N, L, K), dtype=self.dtype, device=device)
+            alpha[:, 0, :] = logstart.unsqueeze(0) + logB[:, 0, :]
+
+            for t in range(1, L):
+                b = int(bins_np[t - 1])
+                logA = logA_by_bin[b]
+                prev = alpha[:, t - 1, :].unsqueeze(2) + logA.unsqueeze(0)
+                alpha[:, t, :] = _logsumexp(prev, dim=1) + logB[:, t, :]
+
+            beta = torch.empty((N, L, K), dtype=self.dtype, device=device)
+            beta[:, L - 1, :] = 0.0
+            for t in range(L - 2, -1, -1):
+                b = int(bins_np[t])
+                logA = logA_by_bin[b]
+                temp = logA.unsqueeze(0) + (logB[:, t + 1, :] + beta[:, t + 1, :]).unsqueeze(1)
+                beta[:, t, :] = _logsumexp(temp, dim=2)
+
+            trans_acc_by_bin = torch.zeros((self.n_bins, K, K), dtype=self.dtype, device=device)
+            for t in range(L - 1):
+                b = int(bins_np[t])
+                logA = logA_by_bin[b]
+                valid_t = (mask[:, t] & mask[:, t + 1]).float().view(N, 1, 1)
+                log_xi = (
+                    alpha[:, t, :].unsqueeze(2)
+                    + logA.unsqueeze(0)
+                    + (logB[:, t + 1, :] + beta[:, t + 1, :]).unsqueeze(1)
+                )
+                log_norm = _logsumexp(log_xi.view(N, -1), dim=1).view(N, 1, 1)
+                xi = (log_xi - log_norm).exp() * valid_t
+                trans_acc_by_bin[b] += xi.sum(dim=0)
+
+            mask_f = mask.float().unsqueeze(-1)
+            emit_num = (gamma * obs.unsqueeze(-1) * mask_f).sum(dim=(0, 1))
+            emit_den = (gamma * mask_f).sum(dim=(0, 1))
+
+            with torch.no_grad():
+                if update_start:
+                    new_start = start_acc + eps
+                    self.start.data = new_start / new_start.sum()
+
+                if update_trans:
+                    tb = trans_acc_by_bin + eps
+                    rs = tb.sum(dim=2, keepdim=True)
+                    rs[rs == 0.0] = 1.0
+                    self.trans_by_bin.data = tb / rs
+
+                if update_emission:
+                    new_em = (emit_num + eps) / (emit_den + 2.0 * eps)
+                    self.emission.data = new_em.clamp(min=eps, max=1.0 - eps)
+
+                self._normalize_params()
+                self._normalize_emission()
+                self._normalize_trans_by_bin()
+
+            if verbose:
+                logger.info(
+                    "[DistanceBinnedSingle.fit] iter=%s ll_proxy=%.6f",
+                    it,
+                    hist[-1],
+                )
+
+            if len(hist) > 1 and abs(hist[-1] - hist[-2]) < float(tol):
+                break
+
+        return hist
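A usage sketch for the distance-binned variant (the `dist_hmm` instance is assumed to exist and the data is toy): here X is 2-D and coords are genomic positions, so uneven gaps select different transition bins.

import numpy as np
import torch

X = np.array([[1., 0., np.nan, 1.],
              [0., 0., 1.,     1.]])           # (N=2, L=4), NaN = missing call
coords = np.array([1000, 1003, 1050, 1300])    # uneven genomic spacing
history = dist_hmm.fit_em(
    X, coords, device=torch.device("cpu"), max_iter=10, tol=1e-4,
    update_start=True, update_trans=True, update_emission=True, verbose=False,
)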
+
+    def adapt_emissions(
+        self,
+        X: np.ndarray,
+        coords: Optional[np.ndarray] = None,
+        *,
+        device: Optional[Union[str, torch.device]] = None,
+        iters: Optional[int] = None,
+        max_iter: Optional[int] = None,
+        verbose: bool = False,
+        **kwargs,
+    ):
+        """Adapt emissions with legacy parameter names.
+
+        Args:
+            X: Observations array.
+            coords: Optional coordinate array.
+            device: Device specifier.
+            iters: Number of iterations.
+            max_iter: Alias for iters.
+            verbose: Whether to log progress.
+            **kwargs: Additional kwargs forwarded to BaseHMM.adapt_emissions.
+
+        Returns:
+            List of log-likelihood values.
+        """
+        if iters is None:
+            iters = int(max_iter) if max_iter is not None else int(kwargs.pop("iters", 5))
+        return super().adapt_emissions(
+            np.asarray(X, dtype=float),
+            coords if coords is not None else None,
+            iters=int(iters),
+            device=device,
+            verbose=verbose,
+        )
+
+
+# =============================================================================
+# Facade class to match workflow import style
+# =============================================================================
+
+
+class HMM:
+    """
+    Facade so workflow can do:
+        from ..hmm.HMM import HMM
+        hmm = HMM.from_config(cfg, arch="single")
+        hmm.save(...)
+        hmm = HMM.load(...)
+    """
+
+    @staticmethod
+    def from_config(cfg, arch: Optional[str] = None, **kwargs) -> BaseHMM:
+        """Create an HMM instance from configuration.
+
+        Args:
+            cfg: Configuration mapping or object.
+            arch: Optional HMM architecture name.
+            **kwargs: Additional parameters passed to the factory.
+
+        Returns:
+            Initialized HMM instance.
+        """
+        return create_hmm(cfg, arch=arch, **kwargs)
+
+    @staticmethod
+    def load(path: Union[str, Path], device: Optional[Union[str, torch.device]] = None) -> BaseHMM:
+        """Load an HMM instance from disk.
+
+        Args:
+            path: Path to the serialized model.
+            device: Optional device specifier.
+
+        Returns:
+            Loaded HMM instance.
+        """
+        return BaseHMM.load(path, device=device)
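End-to-end usage of the facade, following the pattern its own docstring describes; the cfg contents and the checkpoint path are placeholders, and save(...) is the method referenced in that docstring rather than shown in this hunk.

from smftools.hmm.HMM import HMM

hmm = HMM.from_config(cfg, arch="single")   # cfg: experiment config mapping
hmm.save("hmm_checkpoint.pt")               # path is a placeholder
hmm = HMM.load("hmm_checkpoint.pt", device="cpu")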
|