smftools 0.2.4__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (181)
  1. smftools/__init__.py +43 -13
  2. smftools/_settings.py +6 -6
  3. smftools/_version.py +3 -1
  4. smftools/cli/__init__.py +1 -0
  5. smftools/cli/archived/cli_flows.py +2 -0
  6. smftools/cli/helpers.py +9 -1
  7. smftools/cli/hmm_adata.py +905 -242
  8. smftools/cli/load_adata.py +432 -280
  9. smftools/cli/preprocess_adata.py +287 -171
  10. smftools/cli/spatial_adata.py +141 -53
  11. smftools/cli_entry.py +119 -178
  12. smftools/config/__init__.py +3 -1
  13. smftools/config/conversion.yaml +5 -1
  14. smftools/config/deaminase.yaml +1 -1
  15. smftools/config/default.yaml +26 -18
  16. smftools/config/direct.yaml +8 -3
  17. smftools/config/discover_input_files.py +19 -5
  18. smftools/config/experiment_config.py +511 -276
  19. smftools/constants.py +37 -0
  20. smftools/datasets/__init__.py +4 -8
  21. smftools/datasets/datasets.py +32 -18
  22. smftools/hmm/HMM.py +2133 -1428
  23. smftools/hmm/__init__.py +24 -14
  24. smftools/hmm/archived/apply_hmm_batched.py +2 -0
  25. smftools/hmm/archived/calculate_distances.py +2 -0
  26. smftools/hmm/archived/call_hmm_peaks.py +18 -1
  27. smftools/hmm/archived/train_hmm.py +2 -0
  28. smftools/hmm/call_hmm_peaks.py +176 -193
  29. smftools/hmm/display_hmm.py +23 -7
  30. smftools/hmm/hmm_readwrite.py +20 -6
  31. smftools/hmm/nucleosome_hmm_refinement.py +104 -14
  32. smftools/informatics/__init__.py +55 -13
  33. smftools/informatics/archived/bam_conversion.py +2 -0
  34. smftools/informatics/archived/bam_direct.py +2 -0
  35. smftools/informatics/archived/basecall_pod5s.py +2 -0
  36. smftools/informatics/archived/basecalls_to_adata.py +2 -0
  37. smftools/informatics/archived/conversion_smf.py +2 -0
  38. smftools/informatics/archived/deaminase_smf.py +1 -0
  39. smftools/informatics/archived/direct_smf.py +2 -0
  40. smftools/informatics/archived/fast5_to_pod5.py +2 -0
  41. smftools/informatics/archived/helpers/archived/__init__.py +2 -0
  42. smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +16 -1
  43. smftools/informatics/archived/helpers/archived/aligned_BAM_to_bed.py +2 -0
  44. smftools/informatics/archived/helpers/archived/bam_qc.py +14 -1
  45. smftools/informatics/archived/helpers/archived/bed_to_bigwig.py +2 -0
  46. smftools/informatics/archived/helpers/archived/canoncall.py +2 -0
  47. smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +8 -1
  48. smftools/informatics/archived/helpers/archived/converted_BAM_to_adata.py +2 -0
  49. smftools/informatics/archived/helpers/archived/count_aligned_reads.py +2 -0
  50. smftools/informatics/archived/helpers/archived/demux_and_index_BAM.py +2 -0
  51. smftools/informatics/archived/helpers/archived/extract_base_identities.py +2 -0
  52. smftools/informatics/archived/helpers/archived/extract_mods.py +2 -0
  53. smftools/informatics/archived/helpers/archived/extract_read_features_from_bam.py +2 -0
  54. smftools/informatics/archived/helpers/archived/extract_read_lengths_from_bed.py +2 -0
  55. smftools/informatics/archived/helpers/archived/extract_readnames_from_BAM.py +2 -0
  56. smftools/informatics/archived/helpers/archived/find_conversion_sites.py +2 -0
  57. smftools/informatics/archived/helpers/archived/generate_converted_FASTA.py +2 -0
  58. smftools/informatics/archived/helpers/archived/get_chromosome_lengths.py +2 -0
  59. smftools/informatics/archived/helpers/archived/get_native_references.py +2 -0
  60. smftools/informatics/archived/helpers/archived/index_fasta.py +2 -0
  61. smftools/informatics/archived/helpers/archived/informatics.py +2 -0
  62. smftools/informatics/archived/helpers/archived/load_adata.py +5 -3
  63. smftools/informatics/archived/helpers/archived/make_modbed.py +2 -0
  64. smftools/informatics/archived/helpers/archived/modQC.py +2 -0
  65. smftools/informatics/archived/helpers/archived/modcall.py +2 -0
  66. smftools/informatics/archived/helpers/archived/ohe_batching.py +2 -0
  67. smftools/informatics/archived/helpers/archived/ohe_layers_decode.py +2 -0
  68. smftools/informatics/archived/helpers/archived/one_hot_decode.py +2 -0
  69. smftools/informatics/archived/helpers/archived/one_hot_encode.py +2 -0
  70. smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +5 -1
  71. smftools/informatics/archived/helpers/archived/separate_bam_by_bc.py +2 -0
  72. smftools/informatics/archived/helpers/archived/split_and_index_BAM.py +2 -0
  73. smftools/informatics/archived/print_bam_query_seq.py +9 -1
  74. smftools/informatics/archived/subsample_fasta_from_bed.py +2 -0
  75. smftools/informatics/archived/subsample_pod5.py +2 -0
  76. smftools/informatics/bam_functions.py +1059 -269
  77. smftools/informatics/basecalling.py +53 -9
  78. smftools/informatics/bed_functions.py +357 -114
  79. smftools/informatics/binarize_converted_base_identities.py +21 -7
  80. smftools/informatics/complement_base_list.py +9 -6
  81. smftools/informatics/converted_BAM_to_adata.py +324 -137
  82. smftools/informatics/fasta_functions.py +251 -89
  83. smftools/informatics/h5ad_functions.py +202 -30
  84. smftools/informatics/modkit_extract_to_adata.py +623 -274
  85. smftools/informatics/modkit_functions.py +87 -44
  86. smftools/informatics/ohe.py +46 -21
  87. smftools/informatics/pod5_functions.py +114 -74
  88. smftools/informatics/run_multiqc.py +20 -14
  89. smftools/logging_utils.py +51 -0
  90. smftools/machine_learning/__init__.py +23 -12
  91. smftools/machine_learning/data/__init__.py +2 -0
  92. smftools/machine_learning/data/anndata_data_module.py +157 -50
  93. smftools/machine_learning/data/preprocessing.py +4 -1
  94. smftools/machine_learning/evaluation/__init__.py +3 -1
  95. smftools/machine_learning/evaluation/eval_utils.py +13 -14
  96. smftools/machine_learning/evaluation/evaluators.py +52 -34
  97. smftools/machine_learning/inference/__init__.py +3 -1
  98. smftools/machine_learning/inference/inference_utils.py +9 -4
  99. smftools/machine_learning/inference/lightning_inference.py +14 -13
  100. smftools/machine_learning/inference/sklearn_inference.py +8 -8
  101. smftools/machine_learning/inference/sliding_window_inference.py +37 -25
  102. smftools/machine_learning/models/__init__.py +12 -5
  103. smftools/machine_learning/models/base.py +34 -43
  104. smftools/machine_learning/models/cnn.py +22 -13
  105. smftools/machine_learning/models/lightning_base.py +78 -42
  106. smftools/machine_learning/models/mlp.py +18 -5
  107. smftools/machine_learning/models/positional.py +10 -4
  108. smftools/machine_learning/models/rnn.py +8 -3
  109. smftools/machine_learning/models/sklearn_models.py +46 -24
  110. smftools/machine_learning/models/transformer.py +75 -55
  111. smftools/machine_learning/models/wrappers.py +8 -3
  112. smftools/machine_learning/training/__init__.py +4 -2
  113. smftools/machine_learning/training/train_lightning_model.py +42 -23
  114. smftools/machine_learning/training/train_sklearn_model.py +11 -15
  115. smftools/machine_learning/utils/__init__.py +3 -1
  116. smftools/machine_learning/utils/device.py +12 -5
  117. smftools/machine_learning/utils/grl.py +8 -2
  118. smftools/metadata.py +443 -0
  119. smftools/optional_imports.py +31 -0
  120. smftools/plotting/__init__.py +32 -17
  121. smftools/plotting/autocorrelation_plotting.py +153 -48
  122. smftools/plotting/classifiers.py +175 -73
  123. smftools/plotting/general_plotting.py +350 -168
  124. smftools/plotting/hmm_plotting.py +53 -14
  125. smftools/plotting/position_stats.py +155 -87
  126. smftools/plotting/qc_plotting.py +25 -12
  127. smftools/preprocessing/__init__.py +35 -37
  128. smftools/preprocessing/append_base_context.py +105 -79
  129. smftools/preprocessing/append_binary_layer_by_base_context.py +75 -37
  130. smftools/preprocessing/{archives → archived}/add_read_length_and_mapping_qc.py +2 -0
  131. smftools/preprocessing/{archives → archived}/calculate_complexity.py +5 -1
  132. smftools/preprocessing/{archives → archived}/mark_duplicates.py +2 -0
  133. smftools/preprocessing/{archives → archived}/preprocessing.py +10 -6
  134. smftools/preprocessing/{archives → archived}/remove_duplicates.py +2 -0
  135. smftools/preprocessing/binarize.py +21 -4
  136. smftools/preprocessing/binarize_on_Youden.py +127 -31
  137. smftools/preprocessing/binary_layers_to_ohe.py +18 -11
  138. smftools/preprocessing/calculate_complexity_II.py +89 -59
  139. smftools/preprocessing/calculate_consensus.py +28 -19
  140. smftools/preprocessing/calculate_coverage.py +44 -22
  141. smftools/preprocessing/calculate_pairwise_differences.py +4 -1
  142. smftools/preprocessing/calculate_pairwise_hamming_distances.py +7 -3
  143. smftools/preprocessing/calculate_position_Youden.py +110 -55
  144. smftools/preprocessing/calculate_read_length_stats.py +52 -23
  145. smftools/preprocessing/calculate_read_modification_stats.py +91 -57
  146. smftools/preprocessing/clean_NaN.py +38 -28
  147. smftools/preprocessing/filter_adata_by_nan_proportion.py +24 -12
  148. smftools/preprocessing/filter_reads_on_length_quality_mapping.py +72 -37
  149. smftools/preprocessing/filter_reads_on_modification_thresholds.py +183 -73
  150. smftools/preprocessing/flag_duplicate_reads.py +708 -303
  151. smftools/preprocessing/invert_adata.py +26 -11
  152. smftools/preprocessing/load_sample_sheet.py +40 -22
  153. smftools/preprocessing/make_dirs.py +9 -3
  154. smftools/preprocessing/min_non_diagonal.py +4 -1
  155. smftools/preprocessing/recipes.py +58 -23
  156. smftools/preprocessing/reindex_references_adata.py +93 -27
  157. smftools/preprocessing/subsample_adata.py +33 -16
  158. smftools/readwrite.py +264 -109
  159. smftools/schema/__init__.py +11 -0
  160. smftools/schema/anndata_schema_v1.yaml +227 -0
  161. smftools/tools/__init__.py +25 -18
  162. smftools/tools/archived/apply_hmm.py +2 -0
  163. smftools/tools/archived/classifiers.py +165 -0
  164. smftools/tools/archived/classify_methylated_features.py +2 -0
  165. smftools/tools/archived/classify_non_methylated_features.py +2 -0
  166. smftools/tools/archived/subset_adata_v1.py +12 -1
  167. smftools/tools/archived/subset_adata_v2.py +14 -1
  168. smftools/tools/calculate_umap.py +56 -15
  169. smftools/tools/cluster_adata_on_methylation.py +122 -47
  170. smftools/tools/general_tools.py +70 -25
  171. smftools/tools/position_stats.py +220 -99
  172. smftools/tools/read_stats.py +50 -29
  173. smftools/tools/spatial_autocorrelation.py +365 -192
  174. smftools/tools/subset_adata.py +23 -21
  175. smftools-0.3.0.dist-info/METADATA +147 -0
  176. smftools-0.3.0.dist-info/RECORD +182 -0
  177. smftools-0.2.4.dist-info/METADATA +0 -141
  178. smftools-0.2.4.dist-info/RECORD +0 -176
  179. {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/WHEEL +0 -0
  180. {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/entry_points.txt +0 -0
  181. {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/licenses/LICENSE +0 -0
smftools/hmm/HMM.py CHANGED
@@ -1,1587 +1,2292 @@
- import math
- from typing import List, Optional, Tuple, Union, Any, Dict, Sequence
+ from __future__ import annotations
+
  import ast
  import json
+ from pathlib import Path
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple, Union

  import numpy as np
- import pandas as pd
- import torch
- import torch.nn as nn
+ from scipy.sparse import issparse

- def _logsumexp(vec: torch.Tensor, dim: int = -1, keepdim: bool = False) -> torch.Tensor:
-     return torch.logsumexp(vec, dim=dim, keepdim=keepdim)
+ from smftools.logging_utils import get_logger
+ from smftools.optional_imports import require

- class HMM(nn.Module):
-     """
-     Vectorized HMM (Bernoulli emissions) implemented in PyTorch.
-
-     Methods:
-         fit(data, ...) -> trains params in-place
-         predict(data, ...) -> list of (L, K) posterior marginals (gamma) numpy arrays
-         viterbi(seq, ...) -> (path_list, score)
-         batch_viterbi(data, ...) -> list of (path_list, score)
-         score(seq_or_list, ...) -> float or list of floats
-
-     Notes:
-         data: list of sequences (each sequence is iterable of {0,1,np.nan}).
-         impute_strategy: "ignore" (NaN treated as missing), "random" (fill NaNs randomly with 0/1).
-     """
+ if TYPE_CHECKING:
+     import torch as torch_types
+     import torch.nn as nn_types

-     def __init__(
-         self,
-         n_states: int = 2,
-         init_start: Optional[List[float]] = None,
-         init_trans: Optional[List[List[float]]] = None,
-         init_emission: Optional[List[float]] = None,
-         dtype: torch.dtype = torch.float64,
-         eps: float = 1e-8,
-         smf_modality: Optional[str] = None,
-     ):
-         super().__init__()
-         if n_states < 2:
-             raise ValueError("n_states must be >= 2")
-         self.n_states = n_states
-         self.eps = float(eps)
-         self.dtype = dtype
-         self.smf_modality = smf_modality
+ torch = require("torch", extra="torch", purpose="HMM modeling")
+ nn = torch.nn

-         # initialize params (probabilities)
-         if init_start is None:
-             start = np.full((n_states,), 1.0 / n_states, dtype=float)
-         else:
-             start = np.asarray(init_start, dtype=float)
-         if init_trans is None:
-             trans = np.full((n_states, n_states), 1.0 / n_states, dtype=float)
-         else:
-             trans = np.asarray(init_trans, dtype=float)
-         # --- sanitize init_emission so it's a 1-D list of P(obs==1 | state) ---
-         if init_emission is None:
-             emission = np.full((n_states,), 0.5, dtype=float)
-         else:
-             em_arr = np.asarray(init_emission, dtype=float)
-             # case: (K,2) -> pick P(obs==1) from second column
-             if em_arr.ndim == 2 and em_arr.shape[1] == 2 and em_arr.shape[0] == n_states:
-                 emission = em_arr[:, 1].astype(float)
-             # case: maybe shape (1,K,2) etc. -> try to collapse trailing axis of length 2
-             elif em_arr.ndim >= 2 and em_arr.shape[-1] == 2:
-                 emission = em_arr.reshape(-1, 2)[:n_states, 1].astype(float)
-             else:
-                 emission = em_arr.reshape(-1)[:n_states].astype(float)
+ logger = get_logger(__name__)
+ # =============================================================================
+ # Registry / Factory
+ # =============================================================================

-         # store as parameters (not trainable via grad; EM updates .data in-place)
-         self.start = nn.Parameter(torch.tensor(start, dtype=self.dtype), requires_grad=False)
-         self.trans = nn.Parameter(torch.tensor(trans, dtype=self.dtype), requires_grad=False)
-         self.emission = nn.Parameter(torch.tensor(emission, dtype=self.dtype), requires_grad=False)
+ _HMM_REGISTRY: Dict[str, type] = {}

-         self._normalize_params()

-     def _normalize_params(self):
-         with torch.no_grad():
-             # coerce shapes
-             K = self.n_states
-             self.start.data = self.start.data.squeeze()
-             if self.start.data.numel() != K:
-                 self.start.data = torch.full((K,), 1.0 / K, dtype=self.dtype)
+ def register_hmm(name: str):
+     """Decorator to register an HMM backend under a string key."""

-             self.trans.data = self.trans.data.squeeze()
-             if not (self.trans.data.ndim == 2 and self.trans.data.shape == (K, K)):
-                 if K == 2:
-                     self.trans.data = torch.tensor([[0.9,0.1],[0.1,0.9]], dtype=self.dtype)
-                 else:
-                     self.trans.data = torch.full((K, K), 1.0 / K, dtype=self.dtype)
+     def deco(cls):
+         """Register the provided class in the HMM registry."""
+         _HMM_REGISTRY[name] = cls
+         cls.hmm_name = name
+         return cls

-             self.emission.data = self.emission.data.squeeze()
-             if self.emission.data.numel() != K:
-                 self.emission.data = torch.full((K,), 0.5, dtype=self.dtype)
+     return deco

-             # now perform smoothing/normalization
-             self.start.data = (self.start.data + self.eps)
-             self.start.data = self.start.data / self.start.data.sum()

-             self.trans.data = (self.trans.data + self.eps)
-             row_sums = self.trans.data.sum(dim=1, keepdim=True)
-             row_sums[row_sums == 0.0] = 1.0
-             self.trans.data = self.trans.data / row_sums
+ def create_hmm(cfg: Union[dict, Any, None], arch: Optional[str] = None, **kwargs):
+     """
+     Factory: creates an HMM from cfg + arch (override).
+     """
+     key = (
+         arch
+         or getattr(cfg, "hmm_arch", None)
+         or (cfg.get("hmm_arch") if isinstance(cfg, dict) else None)
+         or "single"
+     )
+     if key not in _HMM_REGISTRY:
+         raise KeyError(f"Unknown hmm_arch={key!r}. Known: {sorted(_HMM_REGISTRY.keys())}")
+     return _HMM_REGISTRY[key].from_config(cfg, **kwargs)
+
+
+ # =============================================================================
+ # Small utilities
+ # =============================================================================
+ def _coerce_dtype_for_device(
+     dtype: torch.dtype, device: Optional[Union[str, torch.device]]
+ ) -> torch.dtype:
+     """MPS does not support float64. When targeting MPS, coerce to float32."""
+     dev = torch.device(device) if isinstance(device, str) else device
+     if dev is not None and getattr(dev, "type", None) == "mps" and dtype == torch.float64:
+         return torch.float32
+     return dtype
+
+
+ def _try_json_or_literal(x: Any) -> Any:
+     """Parse a string value as JSON or a Python literal when possible.
+
+     Args:
+         x: Value to parse.
+
+     Returns:
+         The parsed value if possible, otherwise the original value.
+     """
+     if x is None:
+         return None
+     if not isinstance(x, str):
+         return x
+     s = x.strip()
+     if not s:
+         return None
+     try:
+         return json.loads(s)
+     except Exception:
+         pass
+     try:
+         return ast.literal_eval(s)
+     except Exception:
+         return x
+
+
+ def _coerce_bool(x: Any) -> bool:
+     """Coerce a value into a boolean using common truthy strings.
+
+     Args:
+         x: Value to coerce.
+
+     Returns:
+         Boolean interpretation of the input.
+     """
+     if x is None:
+         return False
+     if isinstance(x, bool):
+         return x
+     if isinstance(x, (int, float)):
+         return bool(x)
+     s = str(x).strip().lower()
+     return s in ("1", "true", "t", "yes", "y", "on")
+

-             self.emission.data = self.emission.data.clamp(min=self.eps, max=1.0 - self.eps)
+ def _resolve_dtype(dtype_entry: Any) -> torch.dtype:
+     """Resolve a torch dtype from a config entry.

-     @staticmethod
-     def _resolve_dtype(dtype_entry):
-         """Accept torch.dtype, string ('float32'/'float64') or None -> torch.dtype."""
-         if dtype_entry is None:
-             return torch.float64
-         if isinstance(dtype_entry, torch.dtype):
-             return dtype_entry
-         s = str(dtype_entry).lower()
-         if "32" in s:
-             return torch.float32
-         if "16" in s:
-             return torch.float16
+     Args:
+         dtype_entry: Config value (string or torch.dtype).
+
+     Returns:
+         Resolved torch dtype.
+     """
+     if dtype_entry is None:
          return torch.float64
+     if isinstance(dtype_entry, torch.dtype):
+         return dtype_entry
+     s = str(dtype_entry).lower()
+     if "16" in s:
+         return torch.float16
+     if "32" in s:
+         return torch.float32
+     return torch.float64

-     @classmethod
-     def from_config(cls, cfg: Union[dict, "ExperimentConfig", None], *,
-                     override: Optional[dict] = None,
-                     device: Optional[Union[str, torch.device]] = None) -> "HMM":
-         """
-         Construct an HMM using keys from an ExperimentConfig instance or a plain dict.

-         cfg may be:
-           - an ExperimentConfig (your dataclass instance)
-           - a dict (e.g. loader.var_dict or merged defaults)
-           - None (uses internal defaults)
+ def _safe_int_coords(var_names) -> Tuple[np.ndarray, bool]:
+     """
+     Try to cast var_names to int coordinates. If not possible,
+     fall back to 0..L-1 index coordinates.
+     """
+     try:
+         coords = np.asarray(var_names, dtype=int)
+         return coords, True
+     except Exception:
+         return np.arange(len(var_names), dtype=int), False

-         override: optional dict to override resolved keys (handy for tests).
-         device: optional device string or torch.device to move model to.
-         """
-         # Accept ExperimentConfig dataclass
-         if cfg is None:
-             merged = {}
-         elif hasattr(cfg, "to_dict") and callable(getattr(cfg, "to_dict")):
-             merged = dict(cfg.to_dict())
-         elif isinstance(cfg, dict):
-             merged = dict(cfg)
-         else:
-             # try attr access as fallback
-             try:
-                 merged = {k: getattr(cfg, k) for k in dir(cfg) if k.startswith("hmm_")}
-             except Exception:
-                 merged = {}

-         if override:
-             merged.update(override)
+ def _logsumexp(x: torch.Tensor, dim: int) -> torch.Tensor:
+     """Compute log-sum-exp in a numerically stable way.

-         # basic resolution with fallback
-         n_states = int(merged.get("hmm_n_states", merged.get("n_states", 2)))
-         init_start = merged.get("hmm_init_start_probs", merged.get("hmm_init_start", None))
-         init_trans = merged.get("hmm_init_transition_probs", merged.get("hmm_init_trans", None))
-         init_emission = merged.get("hmm_init_emission_probs", merged.get("hmm_init_emission", None))
-         eps = float(merged.get("hmm_eps", merged.get("eps", 1e-8)))
-         dtype = cls._resolve_dtype(merged.get("hmm_dtype", merged.get("dtype", None)))
+     Args:
+         x: Input tensor.
+         dim: Dimension to reduce.

-         # coerce lists (if present) -> numpy arrays (the HMM constructor already sanitizes)
-         def _coerce_np(x):
-             if x is None:
-                 return None
-             return np.asarray(x, dtype=float)
+     Returns:
+         Reduced tensor.
+     """
+     return torch.logsumexp(x, dim=dim)

-         init_start = _coerce_np(init_start)
-         init_trans = _coerce_np(init_trans)
-         init_emission = _coerce_np(init_emission)

-         model = cls(
-             n_states=n_states,
-             init_start=init_start,
-             init_trans=init_trans,
-             init_emission=init_emission,
-             dtype=dtype,
-             eps=eps,
-             smf_modality=merged.get("smf_modality", None),
-         )
+ def _ensure_layer_full_shape(adata, name: str, dtype, fill_value=0):
+     """
+     Ensure adata.layers[name] exists with shape (n_obs, n_vars).
+     """
+     if name not in adata.layers:
+         arr = np.full((adata.n_obs, adata.n_vars), fill_value=fill_value, dtype=dtype)
+         adata.layers[name] = arr
+     else:
+         arr = _to_dense_np(adata.layers[name])
+         if arr.shape != (adata.n_obs, adata.n_vars):
+             raise ValueError(
+                 f"Layer '{name}' exists but has shape {arr.shape}; expected {(adata.n_obs, adata.n_vars)}"
+             )
+     return adata.layers[name]
+
+
+ def _assign_back_obs(final_adata, sub_adata, cols: List[str]):
+     """
+     Assign obs columns from sub_adata back into final_adata for the matching obs_names.
+     Works for list/object columns too.
+     """
+     idx = final_adata.obs_names.get_indexer(sub_adata.obs_names)
+     if (idx < 0).any():
+         raise ValueError("Some sub_adata.obs_names not found in final_adata.obs_names")

-         # move to device if requested
-         if device is not None:
-             if isinstance(device, str):
-                 device = torch.device(device)
-             model.to(device)
+     for c in cols:
+         final_adata.obs.iloc[idx, final_adata.obs.columns.get_loc(c)] = sub_adata.obs[c].values

-         # persist the config to the hmm class
-         cls.config = cfg

-         return model
+ def _to_dense_np(x):
+     """Convert sparse or array-like input to a dense NumPy array.

-     def update_from_config(self, cfg: Union[dict, "ExperimentConfig", None], *,
-                            override: Optional[dict] = None):
-         """
-         Update existing model parameters from a config or dict (in-place).
-         This will normalize / reinitialize start/trans/emission using same logic as constructor.
-         """
-         if cfg is None:
-             merged = {}
-         elif hasattr(cfg, "to_dict") and callable(getattr(cfg, "to_dict")):
-             merged = dict(cfg.to_dict())
-         elif isinstance(cfg, dict):
-             merged = dict(cfg)
-         else:
-             try:
-                 merged = {k: getattr(cfg, k) for k in dir(cfg) if k.startswith("hmm_")}
-             except Exception:
-                 merged = {}
+     Args:
+         x: Input array or sparse matrix.

-         if override:
-             merged.update(override)
+     Returns:
+         Dense NumPy array or None.
+     """
+     if x is None:
+         return None
+     if issparse(x):
+         return x.toarray()
+     return np.asarray(x)

-         # extract same keys as from_config
-         n_states = int(merged.get("hmm_n_states", self.n_states))
-         init_start = merged.get("hmm_init_start_probs", None)
-         init_trans = merged.get("hmm_init_transition_probs", None)
-         init_emission = merged.get("hmm_init_emission_probs", None)
-         eps = merged.get("hmm_eps", None)
-         dtype = merged.get("hmm_dtype", None)
-
-         # apply dtype/eps if present
-         if eps is not None:
-             self.eps = float(eps)
-         if dtype is not None:
-             self.dtype = self._resolve_dtype(dtype)
-
-         # if n_states changes we need a fresh re-init (easy approach: reconstruct)
-         if int(n_states) != int(self.n_states):
-             # rebuild self in-place: create a new model and copy tensors
-             new_model = HMM.from_config(merged)
-             # copy content
-             with torch.no_grad():
-                 self.n_states = new_model.n_states
-                 self.eps = new_model.eps
-                 self.dtype = new_model.dtype
-                 self.start.data = new_model.start.data.clone().to(self.start.device, dtype=self.dtype)
-                 self.trans.data = new_model.trans.data.clone().to(self.trans.device, dtype=self.dtype)
-                 self.emission.data = new_model.emission.data.clone().to(self.emission.device, dtype=self.dtype)
-             return

-         # else only update provided tensors
-         def _to_tensor(obj, shape_expected=None):
-             if obj is None:
-                 return None
-             arr = np.asarray(obj, dtype=float)
-             if shape_expected is not None:
-                 try:
-                     arr = arr.reshape(shape_expected)
-                 except Exception:
-                     # try to free-form slice/reshape (keep best-effort)
-                     arr = np.reshape(arr, shape_expected) if arr.size >= np.prod(shape_expected) else arr
-             return torch.tensor(arr, dtype=self.dtype, device=self.start.device)
+ def _ensure_2d_np(x):
+     """Ensure an array is 2D, reshaping 1D inputs.

-         with torch.no_grad():
-             if init_start is not None:
-                 t = _to_tensor(init_start, (self.n_states,))
-                 if t.numel() == self.n_states:
-                     self.start.data = t.clone()
-             if init_trans is not None:
-                 t = _to_tensor(init_trans, (self.n_states, self.n_states))
-                 if t.shape == (self.n_states, self.n_states):
-                     self.trans.data = t.clone()
-             if init_emission is not None:
-                 # attempt to extract P(obs==1) if shaped (K,2)
-                 arr = np.asarray(init_emission, dtype=float)
-                 if arr.ndim == 2 and arr.shape[1] == 2 and arr.shape[0] >= self.n_states:
-                     em = arr[: self.n_states, 1]
-                 else:
-                     em = arr.reshape(-1)[: self.n_states]
-                 t = torch.tensor(em, dtype=self.dtype, device=self.start.device)
-                 if t.numel() == self.n_states:
-                     self.emission.data = t.clone()
+     Args:
+         x: Input array-like.

-         # finally normalize
-         self._normalize_params()
+     Returns:
+         2D NumPy array.
+     """
+     x = _to_dense_np(x)
+     if x.ndim == 1:
+         x = x.reshape(1, -1)
+     if x.ndim != 2:
+         raise ValueError(f"Expected 2D array; got shape {x.shape}")
+     return x

-     def _ensure_device_dtype(self, device: Optional[torch.device]):
-         if device is None:
-             device = next(self.parameters()).device
-         self.start.data = self.start.data.to(device=device, dtype=self.dtype)
-         self.trans.data = self.trans.data.to(device=device, dtype=self.dtype)
-         self.emission.data = self.emission.data.to(device=device, dtype=self.dtype)
-         return device

-     @staticmethod
-     def _pad_and_mask(
-         data: List[List],
-         device: torch.device,
-         dtype: torch.dtype,
-         impute_strategy: str = "ignore",
-     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ # =============================================================================
+ # Feature-set normalization
+ # =============================================================================
+
+
+ def normalize_hmm_feature_sets(raw: Any) -> Dict[str, Dict[str, Any]]:
+     """
+     Canonical format:
+         {
+             "footprints": {"state": "Non-Modified", "features": {"small_bound_stretch": [0,50], ...}},
+             "accessible": {"state": "Modified", "features": {"all_accessible_features": [0, inf], ...}},
+             ...
+         }
+     Each feature range is [lo, hi) in genomic bp (or index units if coords aren't ints).
+     """
+     parsed = _try_json_or_literal(raw)
+     if not isinstance(parsed, dict):
+         return {}
+
+     def _coerce_bound(v):
+         """Coerce a bound value into a float or sentinel.
+
+         Args:
+             v: Bound value.
+
+         Returns:
+             Float, np.inf, or None.
          """
-         Pads sequences to shape (B, L). Returns (obs, mask, lengths)
-         - Accepts: list-of-seqs, or 2D ndarray (B, L).
-         - If a sequence element is itself an array (per-timestep feature vector),
-           collapse the last axis by mean (warns once).
+         if v is None:
+             return None
+         if isinstance(v, (int, float)):
+             return float(v)
+         s = str(v).strip().lower()
+         if s in ("inf", "infty", "np.inf", "infinite"):
+             return np.inf
+         if s in ("none", ""):
+             return None
+         try:
+             return float(v)
+         except Exception:
+             return None
+
+     def _coerce_map(feats):
+         """Coerce feature ranges into (lo, hi) tuples.
+
+         Args:
+             feats: Mapping of feature names to ranges.
+
+         Returns:
+             Mapping of feature names to numeric bounds.
          """
-         import warnings
-
-         # If somebody passed a 2-D ndarray directly, convert to list-of-rows
-         if isinstance(data, np.ndarray) and data.ndim == 2:
-             # convert rows -> python lists (scalars per timestep)
-             data = data.tolist()
-
-         B = len(data)
-         lengths = torch.tensor([len(s) for s in data], dtype=torch.long, device=device)
-         L = int(lengths.max().item()) if B > 0 else 0
-         obs = torch.zeros((B, L), dtype=dtype, device=device)
-         mask = torch.zeros((B, L), dtype=torch.bool, device=device)
-
-         warned_collapse = False
-
-         for i, seq in enumerate(data):
-             # seq may be list/ndarray of scalars OR list/ndarray of per-timestep arrays
-             arr = np.asarray(seq, dtype=float)
-
-             # If arr is shape (L,1,1,...) squeeze trailing singletons
-             while arr.ndim > 1 and arr.shape[-1] == 1:
-                 arr = np.squeeze(arr, axis=-1)
-
-             # If arr is still >1D (e.g., (L, F)), collapse the last axis by mean
-             if arr.ndim > 1:
-                 if not warned_collapse:
-                     warnings.warn(
-                         "HMM._pad_and_mask: collapsing per-timestep feature axis by mean "
-                         "(arr had shape {}). If you prefer a different reduction, "
-                         "preprocess your data.".format(arr.shape),
-                         stacklevel=2,
-                     )
-                     warned_collapse = True
-                 # collapse features -> scalar per timestep
-                 arr = np.asarray(arr, dtype=float).mean(axis=-1)
-
-             # now arr should be 1D (T,)
-             if arr.ndim == 0:
-                 # single scalar: treat as length-1 sequence
-                 arr = np.atleast_1d(arr)
-
-             nan_mask = np.isnan(arr)
-             if impute_strategy == "random" and nan_mask.any():
-                 arr[nan_mask] = np.random.choice([0, 1], size=nan_mask.sum())
-                 local_mask = np.ones_like(arr, dtype=bool)
+         out = {}
+         if not isinstance(feats, dict):
+             return out
+         for name, rng in feats.items():
+             if rng is None:
+                 out[name] = (0.0, np.inf)
+                 continue
+             if isinstance(rng, (list, tuple)) and len(rng) >= 2:
+                 lo = _coerce_bound(rng[0])
+                 hi = _coerce_bound(rng[1])
+                 lo = 0.0 if lo is None else float(lo)
+                 hi = np.inf if hi is None else float(hi)
+                 out[name] = (lo, hi)
              else:
-                 local_mask = ~nan_mask
-                 arr = np.where(local_mask, arr, 0.0)
+                 hi = _coerce_bound(rng)
+                 hi = np.inf if hi is None else float(hi)
+                 out[name] = (0.0, hi)
+         return out
+
+     out: Dict[str, Dict[str, Any]] = {}
+     for group, info in parsed.items():
+         if isinstance(info, dict):
+             feats = _coerce_map(info.get("features", info.get("ranges", {})))
+             state = info.get("state", info.get("label", "Modified"))
+         else:
+             feats = _coerce_map(info)
+             state = "Modified"
+         out[group] = {"features": feats, "state": state}
+     return out

-             L_i = arr.shape[0]
-             obs[i, :L_i] = torch.tensor(arr, dtype=dtype, device=device)
-             mask[i, :L_i] = torch.tensor(local_mask, dtype=torch.bool, device=device)

-         return obs, mask, lengths
+ # =============================================================================
+ # BaseHMM: shared decoding + annotation pipeline
+ # =============================================================================

-     def _log_emission(self, obs: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
-         """
-         obs: (B, L)
-         mask: (B, L) bool
-         returns logB (B, L, K)
-         """
-         B, L = obs.shape
-         p = self.emission # (K,)
-         logp = torch.log(p + self.eps)
-         log1mp = torch.log1p(-p + self.eps)
-         obs_expand = obs.unsqueeze(-1) # (B, L, 1)
-         logB = obs_expand * logp.unsqueeze(0).unsqueeze(0) + (1.0 - obs_expand) * log1mp.unsqueeze(0).unsqueeze(0)
-         logB = torch.where(mask.unsqueeze(-1), logB, torch.zeros_like(logB))
-         return logB

-     def fit(
-         self,
-         data: List[List],
-         max_iter: int = 100,
-         tol: float = 1e-4,
-         impute_strategy: str = "ignore",
-         verbose: bool = True,
-         return_history: bool = False,
-         device: Optional[Union[torch.device, str]] = None,
-     ):
-         """
-         Vectorized Baum-Welch EM across a batch of sequences (padded).
-         """
-         if device is None:
-             device = next(self.parameters()).device
-         elif isinstance(device, str):
-             device = torch.device(device)
-         device = self._ensure_device_dtype(device)
+ class BaseHMM(nn.Module):
+     """
+     BaseHMM responsibilities:
+       - config resolution (from_config)
+       - EM fit wrapper (fit / fit_em)
+       - decoding (gamma / viterbi)
+       - AnnData annotation from provided arrays (X + coords)
+       - save/load registry aware
+     Subclasses implement:
+       - _log_emission(...) -> logB
+       - optional distance-aware transition handling
+     """

-         if isinstance(data, np.ndarray):
-             if data.ndim == 2:
-                 # rows are sequences: convert to list of 1D arrays
-                 data = data.tolist()
-             elif data.ndim == 1:
-                 # single sequence
-                 data = [data.tolist()]
-             else:
-                 raise ValueError(f"Expected data to be 1D or 2D ndarray; got array with ndim={data.ndim}")
+     def __init__(self, n_states: int = 2, eps: float = 1e-8, dtype: torch.dtype = torch.float64):
+         """Initialize the base HMM with shared parameters.

-         obs, mask, lengths = self._pad_and_mask(data, device=device, dtype=self.dtype, impute_strategy=impute_strategy)
-         B, L = obs.shape
-         K = self.n_states
-         eps = float(self.eps)
+         Args:
+             n_states: Number of hidden states.
+             eps: Smoothing epsilon for probabilities.
+             dtype: Torch dtype for parameters.
+         """
+         super().__init__()
+         if n_states < 2:
+             raise ValueError("n_states must be >= 2")
+         self.n_states = int(n_states)
+         self.eps = float(eps)
+         self.dtype = dtype

-         if verbose:
-             print(f"[HMM.fit] device={device}, batch={B}, max_len={L}, states={K}")
+         # start probs + transitions (shared across backends)
+         start = np.full((self.n_states,), 1.0 / self.n_states, dtype=float)
+         trans = np.full((self.n_states, self.n_states), 1.0 / self.n_states, dtype=float)

-         loglik_history = []
+         self.start = nn.Parameter(torch.tensor(start, dtype=self.dtype), requires_grad=False)
+         self.trans = nn.Parameter(torch.tensor(trans, dtype=self.dtype), requires_grad=False)
+         self._normalize_params()

-         for it in range(1, max_iter + 1):
-             if verbose:
-                 print(f"[HMM.fit] EM iter {it}")
+     # ------------------------- config -------------------------

-             # compute batched emission logs
-             logB = self._log_emission(obs, mask) # (B, L, K)
+     @classmethod
+     def from_config(
+         cls, cfg: Union[dict, Any, None], *, override: Optional[dict] = None, device=None
+     ):
+         """Create a model from config with optional overrides.

-             # logs for start and transition
-             logA = torch.log(self.trans + eps) # (K, K)
-             logstart = torch.log(self.start + eps) # (K,)
+         Args:
+             cfg: Configuration mapping or object.
+             override: Override values to apply.
+             device: Device specifier.

-             # Forward (batched)
-             alpha = torch.empty((B, L, K), dtype=self.dtype, device=device)
-             alpha[:, 0, :] = logstart.unsqueeze(0) + logB[:, 0, :] # (B,K)
-             for t in range(1, L):
-                 # prev: (B, i, 1) + (1, i, j) broadcast => (B, i, j)
-                 prev = alpha[:, t - 1, :].unsqueeze(2) + logA.unsqueeze(0) # (B, K, K)
-                 alpha[:, t, :] = _logsumexp(prev, dim=1) + logB[:, t, :]
+         Returns:
+             Initialized HMM instance.
+         """
+         merged = cls._cfg_to_dict(cfg)
+         if override:
+             merged.update(override)

-             # Backward (batched)
-             beta = torch.empty((B, L, K), dtype=self.dtype, device=device)
-             beta[:, L - 1, :] = torch.zeros((K,), dtype=self.dtype, device=device).unsqueeze(0).expand(B, K)
-             for t in range(L - 2, -1, -1):
-                 # temp (B, i, j) = logA[i,j] + logB[:,t+1,j] + beta[:,t+1,j]
-                 temp = logA.unsqueeze(0) + (logB[:, t + 1, :].unsqueeze(1) + beta[:, t + 1, :].unsqueeze(1))
-                 beta[:, t, :] = _logsumexp(temp, dim=2)
+         n_states = int(merged.get("hmm_n_states", merged.get("n_states", 2)))
+         eps = float(merged.get("hmm_eps", merged.get("eps", 1e-8)))
+         dtype = _resolve_dtype(merged.get("hmm_dtype", merged.get("dtype", None)))
+         dtype = _coerce_dtype_for_device(dtype, device) # <<< NEW

-             # sequence log-likelihoods (use last real index)
-             last_idx = (lengths - 1).clamp(min=0)
-             idx_range = torch.arange(B, device=device)
-             final_alpha = alpha[idx_range, last_idx, :] # (B, K)
-             seq_loglikes = _logsumexp(final_alpha, dim=1) # (B,)
-             total_loglike = float(seq_loglikes.sum().item())
+         model = cls(n_states=n_states, eps=eps, dtype=dtype)
+         if device is not None:
+             model.to(torch.device(device) if isinstance(device, str) else device)
+         model._persisted_cfg = merged
+         return model

-             # posterior gamma (B, L, K)
-             log_gamma = alpha + beta # (B, L, K)
-             logZ_time = _logsumexp(log_gamma, dim=2, keepdim=True) # (B, L, 1)
-             gamma = (log_gamma - logZ_time).exp() # (B, L, K)
+     @staticmethod
+     def _cfg_to_dict(cfg: Union[dict, Any, None]) -> dict:
+         """Normalize a config object into a dictionary.

-             # accumulators: starts, transitions, emissions
-             gamma_start_accum = gamma[:, 0, :].sum(dim=0) # (K,)
+         Args:
+             cfg: Config mapping or object.

-             # emission accumulators: sum over observed positions only
-             mask_f = mask.unsqueeze(-1) # (B, L, 1)
-             emit_num = (gamma * obs.unsqueeze(-1) * mask_f).sum(dim=(0, 1)) # (K,)
-             emit_den = (gamma * mask_f).sum(dim=(0, 1)) # (K,)
+         Returns:
+             Dictionary of HMM-related config values.
+         """
+         if cfg is None:
+             return {}
+         if isinstance(cfg, dict):
+             return dict(cfg)
+         if hasattr(cfg, "to_dict") and callable(getattr(cfg, "to_dict")):
+             return dict(cfg.to_dict())
+         out = {}
+         for k in dir(cfg):
+             if k.startswith("hmm_") or k in ("smf_modality", "cpg"):
+                 try:
+                     out[k] = getattr(cfg, k)
+                 except Exception:
+                     pass
+         return out

-             # transitions: accumulate xi across t for valid positions
-             trans_accum = torch.zeros((K, K), dtype=self.dtype, device=device)
-             if L >= 2:
-                 time_idx = torch.arange(L - 1, device=device).unsqueeze(0).expand(B, L - 1) # (B, L-1)
-                 valid = time_idx < (lengths.unsqueeze(1) - 1) # (B, L-1) bool
-                 for t in range(L - 1):
-                     a_t = alpha[:, t, :].unsqueeze(2) # (B, i, 1)
-                     b_next = (logB[:, t + 1, :] + beta[:, t + 1, :]).unsqueeze(1) # (B, 1, j)
-                     log_xi_unnorm = a_t + logA.unsqueeze(0) + b_next # (B, i, j)
-                     log_xi_flat = log_xi_unnorm.view(B, -1) # (B, i*j)
-                     log_norm = _logsumexp(log_xi_flat, dim=1).unsqueeze(1).unsqueeze(2) # (B,1,1)
-                     xi = (log_xi_unnorm - log_norm).exp() # (B,i,j)
-                     valid_t = valid[:, t].float().unsqueeze(1).unsqueeze(2) # (B,1,1)
-                     xi_masked = xi * valid_t
-                     trans_accum += xi_masked.sum(dim=0) # (i,j)
-
-             # M-step: update parameters with smoothing
-             with torch.no_grad():
-                 new_start = gamma_start_accum + eps
-                 new_start = new_start / new_start.sum()
+     # ------------------------- params -------------------------

- new_trans = trans_accum + eps
477
- row_sums = new_trans.sum(dim=1, keepdim=True)
478
- row_sums[row_sums == 0.0] = 1.0
479
- new_trans = new_trans / row_sums
403
+ def _normalize_params(self):
404
+ """Normalize start and transition probabilities in-place."""
405
+ with torch.no_grad():
406
+ K = self.n_states
480
407
 
481
- new_emission = (emit_num + eps) / (emit_den + 2.0 * eps)
482
- new_emission = new_emission.clamp(min=eps, max=1.0 - eps)
408
+ # start
409
+ self.start.data = self.start.data.reshape(-1)
410
+ if self.start.data.numel() != K:
411
+ self.start.data = torch.full(
412
+ (K,), 1.0 / K, dtype=self.dtype, device=self.start.device
413
+ )
414
+ self.start.data = self.start.data + self.eps
415
+ self.start.data = self.start.data / self.start.data.sum()
483
416
 
484
- self.start.data = new_start
485
- self.trans.data = new_trans
486
- self.emission.data = new_emission
417
+ # trans
418
+ self.trans.data = self.trans.data.reshape(K, K)
419
+ self.trans.data = self.trans.data + self.eps
420
+ rs = self.trans.data.sum(dim=1, keepdim=True)
421
+ rs[rs == 0.0] = 1.0
422
+ self.trans.data = self.trans.data / rs
487
423
 
488
- loglik_history.append(total_loglike)
489
- if verbose:
490
- print(f" total loglik = {total_loglike:.6f}")
424
+ def _ensure_device_dtype(
425
+ self, device: Optional[Union[str, torch.device]] = None
426
+ ) -> torch.device:
427
+ """Move parameters to the requested device/dtype.
491
428
 
492
- if len(loglik_history) > 1 and abs(loglik_history[-1] - loglik_history[-2]) < tol:
493
- if verbose:
494
- print(f"[HMM.fit] converged (Δll < {tol}) at iter {it}")
495
- break
429
+ Args:
430
+ device: Device specifier or None to use current device.
496
431
 
497
- return loglik_history if return_history else None
498
-
499
- def get_params(self) -> dict:
432
+ Returns:
433
+ Resolved torch device.
500
434
  """
501
- Return model parameters as numpy arrays on CPU.
502
- """
503
- with torch.no_grad():
504
- return {
505
- "n_states": int(self.n_states),
506
- "start": self.start.detach().cpu().numpy().astype(float).reshape(-1),
507
- "trans": self.trans.detach().cpu().numpy().astype(float),
508
- "emission": self.emission.detach().cpu().numpy().astype(float).reshape(-1),
509
- }
435
+ if device is None:
436
+ device = next(self.parameters()).device
437
+ device = torch.device(device) if isinstance(device, str) else device
438
+ self.start.data = self.start.data.to(device=device, dtype=self.dtype)
439
+ self.trans.data = self.trans.data.to(device=device, dtype=self.dtype)
440
+ return device
510
441
 
511
- def print_params(self, decimals: int = 4):
442
+ # ------------------------- state labeling -------------------------
443
+
444
+ def _state_modified_score(self) -> torch.Tensor:
445
+ """Subclasses return (K,) score; higher => more “Modified/Accessible”."""
446
+ raise NotImplementedError
447
+
448
+ def modified_state_index(self) -> int:
449
+ """Return the index of the most modified/accessible state."""
450
+ scores = self._state_modified_score()
451
+ return int(torch.argmax(scores).item())
452
+
453
+ def resolve_target_state_index(self, state_target: Any) -> int:
512
454
  """
513
- Nicely print start, transition, and emission probabilities.
455
+ Accept:
456
+ - int -> explicit state index
457
+ - "Modified" / "Non-Modified" and aliases
514
458
  """
515
- params = self.get_params()
516
- K = params["n_states"]
517
- fmt = f"{{:.{decimals}f}}"
518
- print(f"HMM params (K={K} states):")
519
- print(" start probs:")
520
- print(" [" + ", ".join(fmt.format(v) for v in params["start"]) + "]")
521
- print(" transition matrix (rows = from-state, cols = to-state):")
522
- for i, row in enumerate(params["trans"]):
523
- print(" s{:d}: [".format(i) + ", ".join(fmt.format(v) for v in row) + "]")
524
- print(" emission P(obs==1 | state):")
525
- for i, v in enumerate(params["emission"]):
526
- print(f" s{i}: {fmt.format(v)}")
527
-
528
- def to_dataframes(self) -> dict:
459
+ if isinstance(state_target, (int, np.integer)):
460
+ idx = int(state_target)
461
+ return max(0, min(idx, self.n_states - 1))
462
+
463
+ s = str(state_target).strip().lower()
464
+ if s in ("modified", "open", "accessible", "1", "pos", "positive"):
465
+ return self.modified_state_index()
466
+ if s in ("non-modified", "closed", "inaccessible", "0", "neg", "negative"):
467
+ scores = self._state_modified_score()
468
+ return int(torch.argmin(scores).item())
469
+ return self.modified_state_index()
470
+
471
+ # ------------------------- emissions -------------------------
472
+
473
+ def _log_emission(self, obs: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
529
474
  """
530
- Return pandas DataFrames for start (Series), trans (DataFrame), emission (Series).
475
+ Return logB:
476
+ - single: obs (N,L), mask (N,L) -> logB (N,L,K)
477
+ - multi : obs (N,L,C), mask (N,L,C) -> logB (N,L,K)
531
478
  """
532
- p = self.get_params()
533
- K = p["n_states"]
534
- state_names = [f"state_{i}" for i in range(K)]
535
- start_s = pd.Series(p["start"], index=state_names, name="start_prob")
536
- trans_df = pd.DataFrame(p["trans"], index=state_names, columns=state_names)
537
- emission_s = pd.Series(p["emission"], index=state_names, name="p_obs1")
538
- return {"start": start_s, "trans": trans_df, "emission": emission_s}
539
-
540
- def predict(self, data: List[List], impute_strategy: str = "ignore", device: Optional[Union[torch.device, str]] = None) -> List[np.ndarray]:
479
+ raise NotImplementedError
480
+
481
+ # ------------------------- decoding core -------------------------
482
+
483
+ def _forward_backward(
484
+ self,
485
+ obs: torch.Tensor,
486
+ mask: torch.Tensor,
487
+ *,
488
+ coords: Optional[np.ndarray] = None,
489
+ ) -> torch.Tensor:
541
490
  """
542
- Return posterior marginals gamma_t(k) for each sequence as list of (L, K) numpy arrays.
491
+ Returns gamma (N,L,K) in probability space.
492
+ Subclasses can override for distance-aware transitions.
543
493
  """
544
- if device is None:
545
- device = next(self.parameters()).device
546
- elif isinstance(device, str):
547
- device = torch.device(device)
548
- device = self._ensure_device_dtype(device)
549
-
550
- obs, mask, lengths = self._pad_and_mask(data, device=device, dtype=self.dtype, impute_strategy=impute_strategy)
551
- B, L = obs.shape
552
- K = self.n_states
494
+ device = obs.device
553
495
  eps = float(self.eps)
496
+ K = self.n_states
554
497
 
555
- logB = self._log_emission(obs, mask) # (B, L, K)
556
- logA = torch.log(self.trans + eps)
557
- logstart = torch.log(self.start + eps)
498
+ logB = self._log_emission(obs, mask) # (N,L,K)
499
+ logA = torch.log(self.trans + eps) # (K,K)
500
+ logstart = torch.log(self.start + eps) # (K,)
501
+
502
+ N, L, _ = logB.shape
558
503
 
559
- # Forward
560
- alpha = torch.empty((B, L, K), dtype=self.dtype, device=device)
504
+ alpha = torch.empty((N, L, K), dtype=self.dtype, device=device)
561
505
  alpha[:, 0, :] = logstart.unsqueeze(0) + logB[:, 0, :]
506
+
562
507
  for t in range(1, L):
563
- prev = alpha[:, t - 1, :].unsqueeze(2) + logA.unsqueeze(0)
508
+ prev = alpha[:, t - 1, :].unsqueeze(2) + logA.unsqueeze(0) # (N,K,K)
564
509
  alpha[:, t, :] = _logsumexp(prev, dim=1) + logB[:, t, :]
565
510
 
566
- # Backward
567
- beta = torch.empty((B, L, K), dtype=self.dtype, device=device)
568
- beta[:, L - 1, :] = torch.zeros((K,), dtype=self.dtype, device=device).unsqueeze(0).expand(B, K)
511
+ beta = torch.empty((N, L, K), dtype=self.dtype, device=device)
512
+ beta[:, L - 1, :] = 0.0
569
513
  for t in range(L - 2, -1, -1):
570
- temp = logA.unsqueeze(0) + (logB[:, t + 1, :].unsqueeze(1) + beta[:, t + 1, :].unsqueeze(1))
514
+ temp = logA.unsqueeze(0) + (logB[:, t + 1, :] + beta[:, t + 1, :]).unsqueeze(1)
571
515
  beta[:, t, :] = _logsumexp(temp, dim=2)
572
516
 
573
- # gamma
574
517
  log_gamma = alpha + beta
575
- logZ_time = _logsumexp(log_gamma, dim=2, keepdim=True)
576
- gamma = (log_gamma - logZ_time).exp() # (B, L, K)
518
+ logZ = _logsumexp(log_gamma, dim=2).unsqueeze(2)
519
+ gamma = (log_gamma - logZ).exp()
520
+ return gamma
577
521
 
578
- results = []
579
- for i in range(B):
580
- L_i = int(lengths[i].item())
581
- results.append(gamma[i, :L_i, :].detach().cpu().numpy())
582
- return results
583
-
584
- def score(self, seq_or_list: Union[List[float], List[List[float]]], impute_strategy: str = "ignore", device: Optional[Union[torch.device, str]] = None) -> Union[float, List[float]]:
522
+ def _viterbi(
523
+ self,
524
+ obs: torch.Tensor,
525
+ mask: torch.Tensor,
526
+ *,
527
+ coords: Optional[np.ndarray] = None,
528
+ ) -> torch.Tensor:
585
529
  """
586
- Compute log-likelihood of a single sequence or list of sequences under current params.
587
- Returns float (single) or list of floats (batch).
530
+ Returns states (N,L) int64. Missing positions (mask False for all channels)
531
+ are still decoded, but you’ll overwrite them to -1 during writing.
532
+ Subclasses can override for distance-aware transitions.
588
533
  """
589
- single = False
590
- if not isinstance(seq_or_list[0], (list, tuple, np.ndarray)):
591
- seqs = [seq_or_list]
592
- single = True
593
- else:
594
- seqs = seq_or_list
534
+ device = obs.device
535
+ eps = float(self.eps)
536
+ K = self.n_states
595
537
 
596
- if device is None:
597
- device = next(self.parameters()).device
598
- elif isinstance(device, str):
599
- device = torch.device(device)
600
- device = self._ensure_device_dtype(device)
538
+ logB = self._log_emission(obs, mask) # (N,L,K)
539
+ logA = torch.log(self.trans + eps) # (K,K)
540
+ logstart = torch.log(self.start + eps) # (K,)
601
541
 
602
- obs, mask, lengths = self._pad_and_mask(seqs, device=device, dtype=self.dtype, impute_strategy=impute_strategy)
603
- B, L = obs.shape
604
- K = self.n_states
605
- eps = float(self.eps)
542
+ N, L, _ = logB.shape
543
+ delta = torch.empty((N, L, K), dtype=self.dtype, device=device)
544
+ psi = torch.empty((N, L, K), dtype=torch.long, device=device)
606
545
 
607
- logB = self._log_emission(obs, mask)
608
- logA = torch.log(self.trans + eps)
609
- logstart = torch.log(self.start + eps)
546
+ delta[:, 0, :] = logstart.unsqueeze(0) + logB[:, 0, :]
547
+ psi[:, 0, :] = -1
610
548
 
611
- alpha = torch.empty((B, L, K), dtype=self.dtype, device=device)
612
- alpha[:, 0, :] = logstart.unsqueeze(0) + logB[:, 0, :]
613
549
  for t in range(1, L):
614
- prev = alpha[:, t - 1, :].unsqueeze(2) + logA.unsqueeze(0)
615
- alpha[:, t, :] = _logsumexp(prev, dim=1) + logB[:, t, :]
550
+ cand = delta[:, t - 1, :].unsqueeze(2) + logA.unsqueeze(0) # (N,K,K)
551
+ best_val, best_idx = cand.max(dim=1)
552
+ delta[:, t, :] = best_val + logB[:, t, :]
553
+ psi[:, t, :] = best_idx
616
554
 
617
- last_idx = (lengths - 1).clamp(min=0)
618
- idx_range = torch.arange(B, device=device)
619
- final_alpha = alpha[idx_range, last_idx, :] # (B, K)
620
- seq_loglikes = _logsumexp(final_alpha, dim=1) # (B,)
621
- seq_loglikes = seq_loglikes.detach().cpu().numpy().tolist()
622
- return seq_loglikes[0] if single else seq_loglikes
555
+ last_state = torch.argmax(delta[:, L - 1, :], dim=1) # (N,)
556
+ states = torch.empty((N, L), dtype=torch.long, device=device)
623
557
 
624
- def viterbi(self, seq: List[float], impute_strategy: str = "ignore", device: Optional[Union[torch.device, str]] = None) -> Tuple[List[int], float]:
625
- """
626
- Viterbi decode a single sequence. Returns (state_path, log_probability_of_path).
627
- """
628
- paths, scores = self.batch_viterbi([seq], impute_strategy=impute_strategy, device=device)
629
- return paths[0], scores[0]
558
+ states[:, L - 1] = last_state
559
+ for t in range(L - 2, -1, -1):
560
+ states[:, t] = psi[torch.arange(N, device=device), t + 1, states[:, t + 1]]
630
561
 
631
- def batch_viterbi(self, data: List[List[float]], impute_strategy: str = "ignore", device: Optional[Union[torch.device, str]] = None) -> Tuple[List[List[int]], List[float]]:
562
+ return states
563
+
564
+ def decode(
565
+ self,
566
+ X: np.ndarray,
567
+ coords: Optional[np.ndarray] = None,
568
+ *,
569
+ decode: str = "marginal",
570
+ device: Optional[Union[str, torch.device]] = None,
571
+ ) -> Tuple[np.ndarray, np.ndarray]:
572
+ """Decode observations into state calls and posterior probabilities.
573
+
574
+ Args:
575
+ X: Observations array (N, L) or (N, L, C).
576
+ coords: Optional coordinates aligned to L.
577
+ decode: Decoding strategy ("marginal" or "viterbi").
578
+ device: Device specifier.
579
+
580
+ Returns:
581
+ Tuple of (states, posterior probabilities).
632
582
  """
633
- Batched Viterbi decoding on padded sequences. Returns (list_of_paths, list_of_scores).
634
- Each path is the length of the original sequence.
583
+ device = self._ensure_device_dtype(device)
584
+
585
+ X = np.asarray(X, dtype=float)
586
+ if X.ndim == 2:
587
+ L = X.shape[1]
588
+ elif X.ndim == 3:
589
+ L = X.shape[1]
590
+ else:
591
+ raise ValueError(f"X must be 2D or 3D; got shape {X.shape}")
592
+
593
+ if coords is None:
594
+ coords = np.arange(L, dtype=int)
595
+ coords = np.asarray(coords, dtype=int)
596
+
597
+ if X.ndim == 2:
598
+ obs = torch.tensor(np.nan_to_num(X, nan=0.0), dtype=self.dtype, device=device)
599
+ mask = torch.tensor(~np.isnan(X), dtype=torch.bool, device=device)
600
+ else:
601
+ obs = torch.tensor(np.nan_to_num(X, nan=0.0), dtype=self.dtype, device=device)
602
+ mask = torch.tensor(~np.isnan(X), dtype=torch.bool, device=device)
603
+
604
+ gamma = self._forward_backward(obs, mask, coords=coords)
605
+
606
+ if str(decode).lower() == "viterbi":
607
+ st = self._viterbi(obs, mask, coords=coords)
608
+ else:
609
+ st = torch.argmax(gamma, dim=2)
610
+
611
+ return st.detach().cpu().numpy(), gamma.detach().cpu().numpy()
612
+
613
+ # ------------------------- EM fit -------------------------
614
+
615
+ def fit(
616
+ self,
617
+ X: np.ndarray,
618
+ coords: Optional[np.ndarray] = None,
619
+ *,
620
+ max_iter: int = 50,
621
+ tol: float = 1e-4,
622
+ device: Optional[Union[str, torch.device]] = None,
623
+ update_start: bool = True,
624
+ update_trans: bool = True,
625
+ update_emission: bool = True,
626
+ verbose: bool = False,
627
+ **kwargs,
628
+ ) -> List[float]:
629
+ """Fit HMM parameters using EM.
630
+
631
+ Args:
632
+ X: Observations array.
633
+ coords: Optional coordinate array.
634
+ max_iter: Maximum EM iterations.
635
+ tol: Convergence tolerance.
636
+ device: Device specifier.
637
+ update_start: Whether to update start probabilities.
638
+ update_trans: Whether to update transition probabilities.
639
+ update_emission: Whether to update emission parameters.
640
+ verbose: Whether to log progress.
641
+ **kwargs: Additional implementation-specific kwargs.
642
+
643
+ Returns:
644
+ List of log-likelihood values across iterations.
635
645
  """
636
- if device is None:
637
- device = next(self.parameters()).device
638
- elif isinstance(device, str):
639
- device = torch.device(device)
646
+ X = np.asarray(X, dtype=float)
647
+ if X.ndim not in (2, 3):
648
+ raise ValueError(f"X must be 2D or 3D; got {X.shape}")
649
+ L = X.shape[1]
650
+
651
+ if coords is None:
652
+ coords = np.arange(L, dtype=int)
653
+ coords = np.asarray(coords, dtype=int)
654
+
640
655
  device = self._ensure_device_dtype(device)
656
+ return self.fit_em(
657
+ X,
658
+ coords,
659
+ device=device,
660
+ max_iter=max_iter,
661
+ tol=tol,
662
+ update_start=update_start,
663
+ update_trans=update_trans,
664
+ update_emission=update_emission,
665
+ verbose=verbose,
666
+ **kwargs,
667
+ )
641
668
 
642
- obs, mask, lengths = self._pad_and_mask(data, device=device, dtype=self.dtype, impute_strategy=impute_strategy)
643
- B, L = obs.shape
644
- K = self.n_states
645
- eps = float(self.eps)
669
+ def adapt_emissions(
670
+ self,
671
+ X: np.ndarray,
672
+ coords: np.ndarray,
673
+ *,
674
+ iters: int = 5,
675
+ device: Optional[Union[str, torch.device]] = None,
676
+ freeze_start: bool = True,
677
+ freeze_trans: bool = True,
678
+ verbose: bool = False,
679
+ **kwargs,
680
+ ) -> List[float]:
681
+ """Adapt emission parameters while keeping shared structure fixed.
682
+
683
+ Args:
684
+ X: Observations array.
685
+ coords: Coordinate array aligned to X.
686
+ iters: Number of EM iterations.
687
+ device: Device specifier.
688
+ freeze_start: Whether to freeze start probabilities.
689
+ freeze_trans: Whether to freeze transitions.
690
+ verbose: Whether to log progress.
691
+ **kwargs: Additional implementation-specific kwargs.
692
+
693
+ Returns:
694
+ List of log-likelihood values across iterations.
695
+ """
696
+ return self.fit(
697
+ X,
698
+ coords,
699
+ max_iter=int(iters),
700
+ tol=0.0,
701
+ device=device,
702
+ update_start=not freeze_start,
703
+ update_trans=not freeze_trans,
704
+ update_emission=True,
705
+ verbose=verbose,
706
+ **kwargs,
707
+ )
646
708
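
`adapt_emissions()` is a thin wrapper over `fit()` with the start and transition updates switched off. A short sketch of the two entry points, under the same assumed import path as above and using the `SingleBernoulliHMM` subclass defined later in this file:

    import numpy as np
    from smftools.hmm.HMM import SingleBernoulliHMM  # assumed import path

    hmm = SingleBernoulliHMM(n_states=2, init_emission=[0.1, 0.9])
    X = np.random.default_rng(0).integers(0, 2, size=(50, 200)).astype(float)
    coords = np.arange(200)

    # Full EM: start, transition, and emission parameters are all re-estimated.
    ll_history = hmm.fit(X, coords, max_iter=20, tol=1e-4)

    # Emission-only adaptation: with the default freeze flags this delegates to
    # fit(..., update_start=False, update_trans=False, update_emission=True).
    ll_history = hmm.adapt_emissions(X, coords, iters=3)
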
 
647
- p = self.emission
648
- logp = torch.log(p + eps)
649
- log1mp = torch.log1p(-p + eps)
650
- logB = obs.unsqueeze(-1) * logp.unsqueeze(0).unsqueeze(0) + (1.0 - obs.unsqueeze(-1)) * log1mp.unsqueeze(0).unsqueeze(0)
651
- logB = torch.where(mask.unsqueeze(-1), logB, torch.zeros_like(logB))
709
+ def fit_em(
710
+ self,
711
+ X: np.ndarray,
712
+ coords: np.ndarray,
713
+ *,
714
+ device: torch.device,
715
+ max_iter: int,
716
+ tol: float,
717
+ update_start: bool,
718
+ update_trans: bool,
719
+ update_emission: bool,
720
+ verbose: bool,
721
+ **kwargs,
722
+ ) -> List[float]:
723
+ """Run the core EM update loop (subclasses implement).
724
+
725
+ Args:
726
+ X: Observations array.
727
+ coords: Coordinate array aligned to X.
728
+ device: Torch device.
729
+ max_iter: Maximum iterations.
730
+ tol: Convergence tolerance.
731
+ update_start: Whether to update start probabilities.
732
+ update_trans: Whether to update transitions.
733
+ update_emission: Whether to update emission parameters.
734
+ verbose: Whether to log progress.
735
+ **kwargs: Additional subclass-specific kwargs.
736
+
737
+ Returns:
738
+ List of log-likelihood values across iterations.
739
+ """
740
+ raise NotImplementedError
652
741
 
653
- logstart = torch.log(self.start + eps)
654
- logA = torch.log(self.trans + eps)
742
+ # ------------------------- save/load -------------------------
655
743
 
656
- # delta (score) and psi (argmax pointers)
657
- delta = torch.empty((B, L, K), dtype=self.dtype, device=device)
658
- psi = torch.zeros((B, L, K), dtype=torch.long, device=device)
744
+ def _extra_save_payload(self) -> dict:
745
+ """Return extra model state to include when saving."""
746
+ return {}
659
747
 
660
- delta[:, 0, :] = logstart.unsqueeze(0) + logB[:, 0, :]
661
- psi[:, 0, :] = -1 # sentinel
748
+ def _load_extra_payload(self, payload: dict, *, device: torch.device):
749
+ """Load extra model state saved by subclasses.
662
750
 
663
- for t in range(1, L):
664
- # cand shape (B, i, j)
665
- cand = delta[:, t - 1, :].unsqueeze(2) + logA.unsqueeze(0) # (B, K, K)
666
- best_val, best_idx = cand.max(dim=1) # best over previous i: results (B, j)
667
- delta[:, t, :] = best_val + logB[:, t, :]
668
- psi[:, t, :] = best_idx # best previous state index for each (B, j)
669
-
670
- # backtrack
671
- last_idx = (lengths - 1).clamp(min=0)
672
- idx_range = torch.arange(B, device=device)
673
- final_delta = delta[idx_range, last_idx, :] # (B, K)
674
- best_last_val, best_last_state = final_delta.max(dim=1) # (B,), (B,)
675
- paths = []
676
- scores = []
677
- for b in range(B):
678
- Lb = int(lengths[b].item())
679
- if Lb == 0:
680
- paths.append([])
681
- scores.append(float("-inf"))
682
- continue
683
- s = int(best_last_state[b].item())
684
- path = [s]
685
- for t in range(Lb - 1, 0, -1):
686
- s = int(psi[b, t, s].item())
687
- path.append(s)
688
- path.reverse()
689
- paths.append(path)
690
- scores.append(float(best_last_val[b].item()))
691
- return paths, scores
692
-
693
- def save(self, path: str) -> None:
751
+ Args:
752
+ payload: Serialized model payload.
753
+ device: Torch device for tensors.
694
754
  """
695
- Save HMM to `path` using torch.save. Stores:
696
- - n_states, eps, dtype (string)
697
- - start, trans, emission (CPU tensors)
755
+ return
756
+
757
+ def save(self, path: Union[str, Path]) -> None:
758
+ """Serialize the model to disk.
759
+
760
+ Args:
761
+ path: Output path for the serialized model.
698
762
  """
763
+ path = str(path)
699
764
  payload = {
765
+ "hmm_name": getattr(self, "hmm_name", self.__class__.__name__),
766
+ "class": self.__class__.__name__,
700
767
  "n_states": int(self.n_states),
701
768
  "eps": float(self.eps),
702
- # store dtype as a string like "torch.float64" (portable)
703
769
  "dtype": str(self.dtype),
704
770
  "start": self.start.detach().cpu(),
705
771
  "trans": self.trans.detach().cpu(),
706
- "emission": self.emission.detach().cpu(),
707
772
  }
773
+ payload.update(self._extra_save_payload())
708
774
  torch.save(payload, path)
709
775
 
710
776
  @classmethod
711
- def load(cls, path: str, device: Optional[Union[torch.device, str]] = None) -> "HMM":
712
- """
713
- Load model from `path`. If `device` is provided (str or torch.device),
714
- parameters will be moved to that device; otherwise they remain on CPU.
715
- Example: model = HMM.load('hmm.pt', device='cuda')
716
- """
717
- payload = torch.load(path, map_location="cpu")
777
+ def load(cls, path: Union[str, Path], device: Optional[Union[str, torch.device]] = None):
778
+ """Load a serialized model from disk.
718
779
 
719
- n_states = int(payload.get("n_states"))
720
- eps = float(payload.get("eps", 1e-8))
721
- dtype_entry = payload.get("dtype", "torch.float64")
780
+ Args:
781
+ path: Path to the serialized model.
782
+ device: Optional device specifier.
722
783
 
723
- # Resolve dtype string robustly:
724
- # Accept "torch.float64" or "float64" or actual torch.dtype (older payloads)
725
- if isinstance(dtype_entry, torch.dtype):
726
- torch_dtype = dtype_entry
727
- else:
728
- # dtype_entry expected to be a string
729
- dtype_str = str(dtype_entry)
730
- # take last part after dot if present: "torch.float64" -> "float64"
731
- name = dtype_str.split(".")[-1]
732
- # map to torch dtype if available, else fallback mapping
733
- if hasattr(torch, name):
734
- torch_dtype = getattr(torch, name)
735
- else:
736
- fallback = {"float64": torch.float64, "float32": torch.float32, "float16": torch.float16}
737
- torch_dtype = fallback.get(name, torch.float64)
738
-
739
- # Build instance (use resolved dtype)
740
- model = cls(n_states=n_states, dtype=torch_dtype, eps=eps)
741
-
742
- # Determine target device
743
- if device is None:
744
- device = torch.device("cpu")
745
- elif isinstance(device, str):
746
- device = torch.device(device)
784
+ Returns:
785
+ Loaded HMM instance.
786
+ """
787
+ payload = torch.load(str(path), map_location="cpu")
788
+ hmm_name = payload.get("hmm_name", None)
789
+ klass = _HMM_REGISTRY.get(hmm_name, cls)
790
+
791
+ dtype_str = str(payload.get("dtype", "torch.float64"))
792
+ torch_dtype = getattr(torch, dtype_str.split(".")[-1], torch.float64)
793
+ torch_dtype = _coerce_dtype_for_device(torch_dtype, device) # <<< NEW
794
+
795
+ model = klass(
796
+ n_states=int(payload["n_states"]),
797
+ eps=float(payload.get("eps", 1e-8)),
798
+ dtype=torch_dtype,
799
+ )
800
+ dev = torch.device(device) if isinstance(device, str) else (device or torch.device("cpu"))
801
+ model.to(dev)
747
802
 
748
- # Load params (they were saved on CPU) and cast to model dtype/device
749
803
  with torch.no_grad():
750
- model.start.data = payload["start"].to(device=device, dtype=model.dtype)
751
- model.trans.data = payload["trans"].to(device=device, dtype=model.dtype)
752
- model.emission.data = payload["emission"].to(device=device, dtype=model.dtype)
804
+ model.start.data = payload["start"].to(device=dev, dtype=model.dtype)
805
+ model.trans.data = payload["trans"].to(device=dev, dtype=model.dtype)
753
806
 
754
- # Normalize / coerce shapes just in case
807
+ model._load_extra_payload(payload, device=dev)
755
808
  model._normalize_params()
756
809
  return model
757
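
`save()` now records `hmm_name` and the class name in the payload, and `load()` looks the name up in `_HMM_REGISTRY` before falling back to the calling class, so subclass state (e.g. the emission table restored by `_load_extra_payload`) survives a round trip. A minimal round-trip sketch, with the import path assumed:

    from smftools.hmm.HMM import SingleBernoulliHMM  # assumed import path

    model = SingleBernoulliHMM(n_states=2, init_emission=[0.05, 0.85])
    model.save("hmm_single.pt")

    restored = SingleBernoulliHMM.load("hmm_single.pt", device="cpu")
    print(type(restored).__name__)   # SingleBernoulliHMM
    print(restored.emission)         # emission table restored via _load_extra_payload
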
-
810
+
811
+ # ------------------------- interval helpers -------------------------
812
+
813
+ @staticmethod
814
+ def _runs_from_bool(mask_1d: np.ndarray) -> List[Tuple[int, int]]:
815
+ """
816
+ Return runs as (start_idx, end_idx_exclusive) for True segments.
817
+ """
818
+ idx = np.nonzero(mask_1d)[0]
819
+ if idx.size == 0:
820
+ return []
821
+ breaks = np.where(np.diff(idx) > 1)[0]
822
+ starts = np.r_[idx[0], idx[breaks + 1]]
823
+ ends = np.r_[idx[breaks] + 1, idx[-1] + 1]
824
+ return list(zip(starts, ends))
825
+
826
+ @staticmethod
827
+ def _interval_length(coords: np.ndarray, s: int, e: int) -> int:
828
+ """Genomic length for [s,e) on coords."""
829
+ if e <= s:
830
+ return 0
831
+ return int(coords[e - 1]) - int(coords[s]) + 1
832
+
833
+ @staticmethod
834
+ def _write_lengths_for_binary_layer(bin_mat: np.ndarray) -> np.ndarray:
835
+ """
836
+ For each row, each True-run gets its run-length assigned across that run.
837
+ Output same shape as bin_mat, int32.
838
+ """
839
+ n, L = bin_mat.shape
840
+ out = np.zeros((n, L), dtype=np.int32)
841
+ for i in range(n):
842
+ runs = BaseHMM._runs_from_bool(bin_mat[i].astype(bool))
843
+ for s, e in runs:
844
+ out[i, s:e] = e - s
845
+ return out
846
+
847
+ @staticmethod
848
+ def _write_lengths_for_state_layer(states: np.ndarray) -> np.ndarray:
849
+ """
850
+ For each row, each constant-state run gets run-length assigned across run.
851
+ Missing values should be -1 and will get 0 length.
852
+ """
853
+ n, L = states.shape
854
+ out = np.zeros((n, L), dtype=np.int32)
855
+ for i in range(n):
856
+ row = states[i]
857
+ valid = row >= 0
858
+ if not np.any(valid):
859
+ continue
860
+ # scan runs
861
+ s = 0
862
+ while s < L:
863
+ if row[s] < 0:
864
+ s += 1
865
+ continue
866
+ v = row[s]
867
+ e = s + 1
868
+ while e < L and row[e] == v:
869
+ e += 1
870
+ out[i, s:e] = e - s
871
+ s = e
872
+ return out
873
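
A small worked example of the run helpers above (both are static, so they can be exercised directly; only the import path is an assumption):

    import numpy as np
    from smftools.hmm.HMM import BaseHMM  # assumed import path

    row = np.array([0, 1, 1, 0, 0, 1, 1, 1, 0], dtype=bool)
    runs = BaseHMM._runs_from_bool(row)
    print([(int(s), int(e)) for s, e in runs])   # [(1, 3), (5, 8)] -- half-open [start, end) runs of True

    bin_mat = np.array([[0, 1, 1, 0, 1, 1, 1]], dtype=np.uint8)
    print(BaseHMM._write_lengths_for_binary_layer(bin_mat))
    # [[0 2 2 0 3 3 3]] -- every position inside a run carries that run's length
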
+
874
+ # ------------------------- merging -------------------------
875
+
876
+ def merge_intervals_to_new_layer(
877
+ self,
878
+ adata,
879
+ base_layer: str,
880
+ *,
881
+ distance_threshold: int,
882
+ suffix: str = "_merged",
883
+ overwrite: bool = True,
884
+ ) -> str:
885
+ """
886
+ Merge adjacent 1-intervals in a binary layer if gaps <= distance_threshold (in coords space),
887
+ writing:
888
+ - {base_layer}{suffix}
889
+ - {base_layer}{suffix}_lengths (run-length in index units)
890
+ """
891
+ if base_layer not in adata.layers:
892
+ raise KeyError(f"Layer '{base_layer}' not found.")
893
+
894
+ coords, coords_are_ints = _safe_int_coords(adata.var_names)
895
+ arr = np.asarray(adata.layers[base_layer])
896
+ arr = (arr > 0).astype(np.uint8)
897
+
898
+ merged_name = f"{base_layer}{suffix}"
899
+ merged_len_name = f"{merged_name}_lengths"
900
+
901
+ if (merged_name in adata.layers or merged_len_name in adata.layers) and not overwrite:
902
+ raise KeyError(f"Merged outputs exist (use overwrite=True): {merged_name}")
903
+
904
+ n, L = arr.shape
905
+ out = np.zeros_like(arr, dtype=np.uint8)
906
+
907
+ dt = int(distance_threshold)
908
+
909
+ for i in range(n):
910
+ ones = np.nonzero(arr[i] != 0)[0]
911
+ runs = self._runs_from_bool(arr[i] != 0)
912
+ if not runs:
913
+ continue
914
+ ms, me = runs[0]
915
+ merged_runs = []
916
+ for s, e in runs[1:]:
917
+ if coords_are_ints:
918
+ gap = int(coords[s]) - int(coords[me - 1]) - 1
919
+ else:
920
+ gap = s - me
921
+ if gap <= dt:
922
+ me = e
923
+ else:
924
+ merged_runs.append((ms, me))
925
+ ms, me = s, e
926
+ merged_runs.append((ms, me))
927
+
928
+ for s, e in merged_runs:
929
+ out[i, s:e] = 1
930
+
931
+ adata.layers[merged_name] = out
932
+ adata.layers[merged_len_name] = self._write_lengths_for_binary_layer(out)
933
+
934
+ # bookkeeping
935
+ key = "hmm_appended_layers"
936
+ if adata.uns.get(key) is None:
937
+ adata.uns[key] = []
938
+ for nm in (merged_name, merged_len_name):
939
+ if nm not in adata.uns[key]:
940
+ adata.uns[key].append(nm)
941
+
942
+ return merged_name
943
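
The gap rule above merges two called runs when the number of bases strictly between them (in `var_names` coordinate space, falling back to index distance for non-integer names) is at most `distance_threshold`. A sketch on a toy AnnData, with the import path assumed:

    import anndata as ad
    import numpy as np
    from smftools.hmm.HMM import SingleBernoulliHMM  # assumed import path

    adata = ad.AnnData(X=np.zeros((1, 8)))
    adata.var_names = [str(c) for c in range(10, 18)]
    adata.layers["footprint"] = np.array([[1, 1, 0, 0, 1, 1, 0, 0]], dtype=np.uint8)

    hmm = SingleBernoulliHMM(n_states=2)
    merged = hmm.merge_intervals_to_new_layer(adata, "footprint", distance_threshold=2)
    print(adata.layers[merged])                # [[1 1 1 1 1 1 0 0]] -- the 2-base gap is <= threshold, so the runs fuse
    print(adata.layers[f"{merged}_lengths"])   # [[6 6 6 6 6 6 0 0]]
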
+
944
+ def write_size_class_layers_from_binary(
945
+ self,
946
+ adata,
947
+ base_layer: str,
948
+ *,
949
+ out_prefix: str,
950
+ feature_ranges: Dict[str, Tuple[float, float]],
951
+ suffix: str = "",
952
+ overwrite: bool = True,
953
+ ) -> List[str]:
954
+ """
955
+ Take an existing binary layer (runs represent features) and write size-class layers:
956
+ - {out_prefix}_{feature}{suffix}
957
+ - plus lengths layers
958
+
959
+ feature_ranges: name -> (lo, hi) in genomic bp.
960
+ """
961
+ if base_layer not in adata.layers:
962
+ raise KeyError(f"Layer '{base_layer}' not found.")
963
+
964
+ coords, coords_are_ints = _safe_int_coords(adata.var_names)
965
+ bin_arr = (np.asarray(adata.layers[base_layer]) > 0).astype(np.uint8)
966
+ n, L = bin_arr.shape
967
+
968
+ created: List[str] = []
969
+ for feat_name in feature_ranges.keys():
970
+ nm = f"{out_prefix}_{feat_name}{suffix}"
971
+ ln = f"{nm}_lengths"
972
+ if (nm in adata.layers or ln in adata.layers) and not overwrite:
973
+ continue
974
+ adata.layers[nm] = np.zeros((n, L), dtype=np.uint8)
975
+ adata.layers[ln] = np.zeros((n, L), dtype=np.int32)
976
+ created.extend([nm, ln])
977
+
978
+ for i in range(n):
979
+ runs = self._runs_from_bool(bin_arr[i] != 0)
980
+ for s, e in runs:
981
+ length_bp = self._interval_length(coords, s, e) if coords_are_ints else (e - s)
982
+ for feat_name, (lo, hi) in feature_ranges.items():
983
+ if float(lo) <= float(length_bp) < float(hi):
984
+ nm = f"{out_prefix}_{feat_name}{suffix}"
985
+ adata.layers[nm][i, s:e] = 1
986
+ adata.layers[f"{nm}_lengths"][i, s:e] = e - s
987
+ break
988
+
989
+ # fill lengths for each size layer (consistent, even if overlaps)
990
+ for feat_name in feature_ranges.keys():
991
+ nm = f"{out_prefix}_{feat_name}{suffix}"
992
+ adata.layers[f"{nm}_lengths"] = self._write_lengths_for_binary_layer(
993
+ np.asarray(adata.layers[nm])
994
+ )
995
+
996
+ key = "hmm_appended_layers"
997
+ if adata.uns.get(key) is None:
998
+ adata.uns[key] = []
999
+ for nm in created:
1000
+ if nm not in adata.uns[key]:
1001
+ adata.uns[key].append(nm)
1002
+
1003
+ return created
1004
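
`feature_ranges` bins each run by its genomic length in bp, testing the half-open interval [lo, hi) in insertion order. The bin names below ("small"/"large") are placeholders, as is the import path:

    import anndata as ad
    import numpy as np
    from smftools.hmm.HMM import SingleBernoulliHMM  # assumed import path

    adata = ad.AnnData(X=np.zeros((1, 10)))
    adata.var_names = [str(c) for c in range(100, 110)]
    adata.layers["patches"] = np.array([[1, 1, 0, 0, 1, 1, 1, 1, 1, 0]], dtype=np.uint8)

    hmm = SingleBernoulliHMM(n_states=2)
    created = hmm.write_size_class_layers_from_binary(
        adata,
        "patches",
        out_prefix="patch",
        feature_ranges={"small": (0, 4), "large": (4, np.inf)},
    )
    print(created)                       # ['patch_small', 'patch_small_lengths', 'patch_large', 'patch_large_lengths']
    print(adata.layers["patch_small"])   # [[1 1 0 0 0 0 0 0 0 0]] -- the 2 bp run
    print(adata.layers["patch_large"])   # [[0 0 0 0 1 1 1 1 1 0]] -- the 5 bp run
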
+
1005
+ # ------------------------- AnnData annotation -------------------------
1006
+
1007
+ @staticmethod
1008
+ def _resolve_pos_mask_for_methbase(subset, ref: str, methbase: str) -> Optional[np.ndarray]:
1009
+ """
1010
+ Local helper to resolve per-base masks from subset.var.* columns.
1011
+ Returns a boolean np.ndarray of length subset.n_vars or None.
1012
+ """
1013
+ key = str(methbase).strip().lower()
1014
+ var = subset.var
1015
+
1016
+ def _has(col: str) -> bool:
1017
+ """Return True when a column exists on subset.var."""
1018
+ return col in var.columns
1019
+
1020
+ if key in ("a",):
1021
+ col = f"{ref}_strand_FASTA_base"
1022
+ if not _has(col):
1023
+ return None
1024
+ return np.asarray(var[col] == "A")
1025
+
1026
+ if key in ("c", "any_c", "anyc", "any-c"):
1027
+ for col in (f"{ref}_any_C_site", f"{ref}_C_site"):
1028
+ if _has(col):
1029
+ return np.asarray(var[col])
1030
+ return None
1031
+
1032
+ if key in ("gpc", "gpc_site", "gpc-site"):
1033
+ col = f"{ref}_GpC_site"
1034
+ if not _has(col):
1035
+ return None
1036
+ return np.asarray(var[col])
1037
+
1038
+ if key in ("cpg", "cpg_site", "cpg-site"):
1039
+ col = f"{ref}_CpG_site"
1040
+ if not _has(col):
1041
+ return None
1042
+ return np.asarray(var[col])
1043
+
1044
+ alt = f"{ref}_{methbase}_site"
1045
+ if not _has(alt):
1046
+ return None
1047
+ return np.asarray(var[alt])
1048
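
For reference, the lookups above map a methbase token onto `subset.var` columns: "a" -> `{ref}_strand_FASTA_base == "A"`, "c"/"any_c" -> `{ref}_any_C_site` (falling back to `{ref}_C_site`), "gpc" -> `{ref}_GpC_site`, "cpg" -> `{ref}_CpG_site`, and anything else -> `{ref}_{methbase}_site`. A tiny check on a toy object, with the import path assumed:

    import anndata as ad
    import numpy as np
    import pandas as pd
    from smftools.hmm.HMM import BaseHMM  # assumed import path

    subset = ad.AnnData(
        X=np.zeros((1, 4)),
        var=pd.DataFrame({"ref1_GpC_site": [True, False, True, False]}),
    )
    print(BaseHMM._resolve_pos_mask_for_methbase(subset, ref="ref1", methbase="GpC"))
    # [ True False  True False]
    print(BaseHMM._resolve_pos_mask_for_methbase(subset, ref="ref1", methbase="6mA"))
    # None -- falls back to ref1_6mA_site, which this toy object does not carry
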
+
758
1049
  def annotate_adata(
759
1050
  self,
760
1051
  adata,
761
- obs_column: str,
762
- layer: Optional[str] = None,
763
- footprints: Optional[bool] = None,
764
- accessible_patches: Optional[bool] = None,
765
- cpg: Optional[bool] = None,
766
- methbases: Optional[List[str]] = None,
767
- threshold: Optional[float] = None,
768
- device: Optional[Union[str, torch.device]] = None,
769
- batch_size: Optional[int] = None,
770
- use_viterbi: Optional[bool] = None,
771
- in_place: bool = True,
772
- verbose: bool = True,
1052
+ *,
1053
+ prefix: str,
1054
+ X: np.ndarray,
1055
+ coords: np.ndarray,
1056
+ var_mask: np.ndarray,
1057
+ span_fill: bool = True,
1058
+ config=None,
1059
+ decode: str = "marginal",
1060
+ write_posterior: bool = True,
1061
+ posterior_state: str = "Modified",
1062
+ feature_sets: Optional[Dict[str, Dict[str, Any]]] = None,
1063
+ prob_threshold: float = 0.5,
773
1064
  uns_key: str = "hmm_appended_layers",
774
- config: Optional[Union[dict, "ExperimentConfig"]] = None, # NEW: config/dict accepted
775
1065
  uns_flag: str = "hmm_annotated",
776
- force_redo: bool = False
1066
+ force_redo: bool = False,
1067
+ device: Optional[Union[str, torch.device]] = None,
1068
+ **kwargs,
777
1069
  ):
1070
+ """Decode and annotate an AnnData object with HMM-derived layers.
1071
+
1072
+ Args:
1073
+ adata: AnnData to annotate.
1074
+ prefix: Prefix for newly written layers.
1075
+ X: Observations array for decoding.
1076
+ coords: Coordinate array aligned to X.
1077
+ var_mask: Boolean mask for positions in adata.var.
1078
+ span_fill: Whether to fill missing spans.
1079
+ config: Optional config for naming and state selection.
1080
+ decode: Decode method ("marginal" or "viterbi").
1081
+ write_posterior: Whether to write posterior probabilities.
1082
+ posterior_state: State label to write posterior for.
1083
+ feature_sets: Optional feature set definition for size classes.
1084
+ prob_threshold: Posterior probability threshold for binary calls.
1085
+ uns_key: .uns key to track appended layers.
1086
+ uns_flag: .uns flag to mark annotations.
1087
+ force_redo: Whether to overwrite existing layers.
1088
+ device: Device specifier.
1089
+ **kwargs: Additional parameters for specialized workflows.
1090
+
1091
+ Returns:
1092
+ List of created layer names or None if skipped.
778
1093
  """
779
- Annotate an AnnData with HMM-derived features (in adata.obs and adata.layers).
780
-
781
- Parameters
782
- ----------
783
- config : optional ExperimentConfig instance or plain dict
784
- When provided, the following keys (if present) are used to override defaults:
785
- - hmm_feature_sets : dict (canonical feature set structure) OR a JSON/string repr
786
- - hmm_annotation_threshold : float
787
- - hmm_batch_size : int
788
- - hmm_use_viterbi : bool
789
- - hmm_methbases : list
790
- - footprints / accessible_patches / cpg (booleans)
791
- Other keyword args override config values if explicitly provided.
792
- """
793
- import json, ast, warnings
794
- import numpy as _np
795
- import torch as _torch
796
- from tqdm import trange, tqdm as _tqdm
797
-
798
- # Only run if not already performed
799
- already = bool(adata.uns.get(uns_flag, False))
800
- if (already and not force_redo):
801
- # QC already performed; nothing to do
802
- return None if in_place else adata
803
-
804
- # small helpers
805
- def _try_json_or_literal(s):
806
- if s is None:
807
- return None
808
- if not isinstance(s, str):
809
- return s
810
- s0 = s.strip()
811
- if s0 == "":
812
- return None
813
- try:
814
- return json.loads(s0)
815
- except Exception:
816
- pass
817
- try:
818
- return ast.literal_eval(s0)
819
- except Exception:
820
- return s
821
-
822
- def _coerce_bool(x):
823
- if x is None:
824
- return False
825
- if isinstance(x, bool):
826
- return x
827
- if isinstance(x, (int, float)):
828
- return bool(x)
829
- s = str(x).strip().lower()
830
- return s in ("1", "true", "t", "yes", "y", "on")
831
-
832
- def normalize_hmm_feature_sets(raw):
833
- if raw is None:
834
- return {}
835
- parsed = raw
836
- if isinstance(raw, str):
837
- parsed = _try_json_or_literal(raw)
838
- if not isinstance(parsed, dict):
839
- return {}
840
-
841
- def _coerce_bound(x):
842
- if x is None:
843
- return None
844
- if isinstance(x, (int, float)):
845
- return float(x)
846
- s = str(x).strip().lower()
847
- if s in ("inf", "infty", "infinite", "np.inf"):
848
- return _np.inf
849
- if s in ("none", ""):
850
- return None
851
- try:
852
- return float(x)
853
- except Exception:
854
- return None
855
-
856
- def _coerce_feature_map(feats):
857
- out = {}
858
- if not isinstance(feats, dict):
859
- return out
860
- for fname, rng in feats.items():
861
- if rng is None:
862
- out[fname] = (0.0, _np.inf)
1094
+ # skip logic
1095
+ if bool(adata.uns.get(uns_flag, False)) and not force_redo:
1096
+ return None
1097
+
1098
+ if adata.uns.get(uns_key) is None:
1099
+ adata.uns[uns_key] = []
1100
+ appended = list(adata.uns.get(uns_key, [])) if adata.uns.get(uns_key) is not None else []
1101
+
1102
+ X = np.asarray(X, dtype=float)
1103
+ coords = np.asarray(coords, dtype=int)
1104
+ var_mask = np.asarray(var_mask, dtype=bool)
1105
+ if var_mask.shape[0] != adata.n_vars:
1106
+ raise ValueError(f"var_mask length {var_mask.shape[0]} != adata.n_vars {adata.n_vars}")
1107
+
1108
+ # decode
1109
+ states, gamma = self.decode(
1110
+ X, coords, decode=decode, device=device
1111
+ ) # states (N,L), gamma (N,L,K)
1112
+ N, L = states.shape
1113
+ if N != adata.n_obs:
1114
+ raise ValueError(f"X has N={N} rows but adata.n_obs={adata.n_obs}")
1115
+
1116
+ # map coords -> full-var indices for span_fill
1117
+ full_coords, full_int = _safe_int_coords(adata.var_names)
1118
+
1119
+ # ---- write posterior + states on masked columns only ----
1120
+ masked_idx = np.nonzero(var_mask)[0]
1121
+ masked_coords, _ = _safe_int_coords(adata.var_names[var_mask])
1122
+
1123
+ # build mapping from coords order -> masked column order
1124
+ coord_to_pos_in_decoded = {int(c): i for i, c in enumerate(coords.tolist())}
1125
+ take = np.array(
1126
+ [coord_to_pos_in_decoded.get(int(c), -1) for c in masked_coords.tolist()], dtype=int
1127
+ )
1128
+ good = take >= 0
1129
+ masked_idx = masked_idx[good]
1130
+ take = take[good]
1131
+
1132
+ # states layer
1133
+ states_name = f"{prefix}_states"
1134
+ if states_name not in adata.layers:
1135
+ adata.layers[states_name] = np.full((adata.n_obs, adata.n_vars), -1, dtype=np.int8)
1136
+ adata.layers[states_name][:, masked_idx] = states[:, take].astype(np.int8)
1137
+ if states_name not in appended:
1138
+ appended.append(states_name)
1139
+
1140
+ # posterior layer (requested state)
1141
+ if write_posterior:
1142
+ t_idx = self.resolve_target_state_index(posterior_state)
1143
+ post = gamma[:, :, t_idx].astype(np.float32)
1144
+ post_name = f"{prefix}_posterior_{str(posterior_state).strip().lower().replace(' ', '_').replace('-', '_')}"
1145
+ if post_name not in adata.layers:
1146
+ adata.layers[post_name] = np.zeros((adata.n_obs, adata.n_vars), dtype=np.float32)
1147
+ adata.layers[post_name][:, masked_idx] = post[:, take]
1148
+ if post_name not in appended:
1149
+ appended.append(post_name)
1150
+
1151
+ # ---- feature layers ----
1152
+ if feature_sets is None:
1153
+ cfgd = self._cfg_to_dict(config)
1154
+ feature_sets = normalize_hmm_feature_sets(cfgd.get("hmm_feature_sets", None))
1155
+
1156
+ if not feature_sets:
1157
+ adata.uns[uns_key] = appended
1158
+ adata.uns[uns_flag] = True
1159
+ return None
1160
+
1161
+ # allocate outputs
1162
+ for group, fs in feature_sets.items():
1163
+ fmap = fs.get("features", {}) or {}
1164
+ if not fmap:
1165
+ continue
1166
+
1167
+ all_layer = f"{prefix}_all_{group}_features"
1168
+ if all_layer not in adata.layers:
1169
+ adata.layers[all_layer] = np.zeros((adata.n_obs, adata.n_vars), dtype=np.uint8)
1170
+ if f"{all_layer}_lengths" not in adata.layers:
1171
+ adata.layers[f"{all_layer}_lengths"] = np.zeros(
1172
+ (adata.n_obs, adata.n_vars), dtype=np.int32
1173
+ )
1174
+ for nm in (all_layer, f"{all_layer}_lengths"):
1175
+ if nm not in appended:
1176
+ appended.append(nm)
1177
+
1178
+ for feat in fmap.keys():
1179
+ nm = f"{prefix}_{feat}"
1180
+ if nm not in adata.layers:
1181
+ adata.layers[nm] = np.zeros(
1182
+ (adata.n_obs, adata.n_vars),
1183
+ dtype=np.int32 if nm.endswith("_lengths") else np.uint8,
1184
+ )
1185
+ if f"{nm}_lengths" not in adata.layers:
1186
+ adata.layers[f"{nm}_lengths"] = np.zeros(
1187
+ (adata.n_obs, adata.n_vars), dtype=np.int32
1188
+ )
1189
+ for outnm in (nm, f"{nm}_lengths"):
1190
+ if outnm not in appended:
1191
+ appended.append(outnm)
1192
+
1193
+ # classify runs per row
1194
+ target_idx = self.resolve_target_state_index(fs.get("state", "Modified"))
1195
+ membership = (
1196
+ (states == target_idx)
1197
+ if str(decode).lower() == "viterbi"
1198
+ else (gamma[:, :, target_idx] >= float(prob_threshold))
1199
+ )
1200
+
1201
+ for i in range(N):
1202
+ runs = self._runs_from_bool(membership[i].astype(bool))
1203
+ for s, e in runs:
1204
+ # genomic length in coords space
1205
+ glen = int(coords[e - 1]) - int(coords[s]) + 1 if e > s else 0
1206
+ if glen <= 0:
1207
+ continue
1208
+
1209
+ # pick feature bin
1210
+ chosen = None
1211
+ for feat_name, (lo, hi) in fmap.items():
1212
+ if float(lo) <= float(glen) < float(hi):
1213
+ chosen = feat_name
1214
+ break
1215
+ if chosen is None:
863
1216
  continue
864
- if isinstance(rng, (list, tuple)) and len(rng) >= 2:
865
- lo = _coerce_bound(rng[0]) or 0.0
866
- hi = _coerce_bound(rng[1])
867
- if hi is None:
868
- hi = _np.inf
869
- out[fname] = (float(lo), float(hi) if not _np.isinf(hi) else _np.inf)
1217
+
1218
+ # convert span to indices in full var grid
1219
+ if span_fill and full_int:
1220
+ left = int(np.searchsorted(full_coords, int(coords[s]), side="left"))
1221
+ right = int(np.searchsorted(full_coords, int(coords[e - 1]), side="right"))
1222
+ if left >= right:
1223
+ continue
1224
+ adata.layers[f"{prefix}_{chosen}"][i, left:right] = 1
1225
+ adata.layers[f"{prefix}_all_{group}_features"][i, left:right] = 1
870
1226
  else:
871
- val = _coerce_bound(rng)
872
- out[fname] = (0.0, float(val) if val is not None else _np.inf)
873
- return out
874
-
875
- canonical = {}
876
- for grp, info in parsed.items():
877
- if not isinstance(info, dict):
878
- feats = _coerce_feature_map(info)
879
- canonical[grp] = {"features": feats, "state": "Modified"}
880
- continue
881
- feats = _coerce_feature_map(info.get("features", info.get("ranges", {})))
882
- state = info.get("state", info.get("label", "Modified"))
883
- canonical[grp] = {"features": feats, "state": state}
884
- return canonical
885
-
886
- # ---------- resolve config dict ----------
887
- merged_cfg = {}
888
- if config is not None:
889
- if hasattr(config, "to_dict") and callable(getattr(config, "to_dict")):
890
- merged_cfg = dict(config.to_dict())
891
- elif isinstance(config, dict):
892
- merged_cfg = dict(config)
893
- else:
894
- try:
895
- merged_cfg = {k: getattr(config, k) for k in dir(config) if k.startswith("hmm_")}
896
- except Exception:
897
- merged_cfg = {}
898
-
899
- def _pick(key, local_val, fallback=None):
900
- if local_val is not None:
901
- return local_val
902
- if key in merged_cfg and merged_cfg[key] is not None:
903
- return merged_cfg[key]
904
- alt = f"hmm_{key}"
905
- if alt in merged_cfg and merged_cfg[alt] is not None:
906
- return merged_cfg[alt]
907
- return fallback
908
-
909
- # coerce booleans robustly
910
- footprints = _coerce_bool(_pick("footprints", footprints, merged_cfg.get("footprints", False)))
911
- accessible_patches = _coerce_bool(_pick("accessible_patches", accessible_patches, merged_cfg.get("accessible_patches", False)))
912
- cpg = _coerce_bool(_pick("cpg", cpg, merged_cfg.get("cpg", False)))
913
-
914
- threshold = float(_pick("threshold", threshold, merged_cfg.get("hmm_annotation_threshold", 0.5)))
915
- batch_size = int(_pick("batch_size", batch_size, merged_cfg.get("hmm_batch_size", 1024)))
916
- use_viterbi = _coerce_bool(_pick("use_viterbi", use_viterbi, merged_cfg.get("hmm_use_viterbi", False)))
917
-
918
- methbases = merged_cfg.get("hmm_methbases", None)
919
-
920
- # normalize whitespace/case for human-friendly inputs (but keep original tokens as given)
921
- methbases = [str(m).strip() for m in methbases if m is not None]
922
- if verbose:
923
- print("DEBUG: final methbases list =", methbases)
924
-
925
- # resolve feature sets: prefer canonical if it yields non-empty mapping, otherwise fall back to boolean defaults
926
- feature_sets = {}
927
- if "hmm_feature_sets" in merged_cfg and merged_cfg.get("hmm_feature_sets") is not None:
928
- cand = normalize_hmm_feature_sets(merged_cfg.get("hmm_feature_sets"))
929
- if isinstance(cand, dict) and len(cand) > 0:
930
- feature_sets = cand
1227
+ # only fill at masked indices
1228
+ cols = masked_idx[
1229
+ (masked_coords >= coords[s]) & (masked_coords <= coords[e - 1])
1230
+ ]
1231
+ if cols.size == 0:
1232
+ continue
1233
+ adata.layers[f"{prefix}_{chosen}"][i, cols] = 1
1234
+ adata.layers[f"{prefix}_all_{group}_features"][i, cols] = 1
1235
+
1236
+ # lengths derived from binary
1237
+ adata.layers[f"{prefix}_all_{group}_features_lengths"] = (
1238
+ self._write_lengths_for_binary_layer(
1239
+ np.asarray(adata.layers[f"{prefix}_all_{group}_features"])
1240
+ )
1241
+ )
1242
+ for feat in fmap.keys():
1243
+ nm = f"{prefix}_{feat}"
1244
+ adata.layers[f"{nm}_lengths"] = self._write_lengths_for_binary_layer(
1245
+ np.asarray(adata.layers[nm])
1246
+ )
1247
+
1248
+ adata.uns[uns_key] = appended
1249
+ adata.uns[uns_flag] = True
1250
+ return None
1251
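
Compared with the removed per-`obs_column` flow, the rewritten `annotate_adata()` expects the caller to hand in the already-subset observation matrix `X`, its `coords`, and a `var_mask` selecting the matching columns of `adata.var`. A usage sketch with placeholder names (prefix, feature group, site mask, and import path are all assumptions of this sketch):

    import anndata as ad
    import numpy as np
    from smftools.hmm.HMM import SingleBernoulliHMM  # assumed import path

    rng = np.random.default_rng(0)
    n_obs, n_vars = 4, 30
    adata = ad.AnnData(X=rng.integers(0, 2, size=(n_obs, n_vars)).astype(float))
    adata.var_names = [str(c) for c in range(1000, 1000 + n_vars)]

    var_mask = np.zeros(n_vars, dtype=bool)
    var_mask[::2] = True                                   # placeholder site mask
    coords = np.asarray(adata.var_names[var_mask], dtype=int)
    X = np.asarray(adata.X)[:, var_mask]

    hmm = SingleBernoulliHMM(n_states=2, init_emission=[0.1, 0.9])
    hmm.annotate_adata(
        adata,
        prefix="gpc_hmm",                                  # placeholder layer prefix
        X=X,
        coords=coords,
        var_mask=var_mask,
        decode="marginal",
        posterior_state="Modified",
        feature_sets={
            "accessible": {                                # placeholder feature group
                "state": "Modified",
                "features": {"small_patch": (0.0, 15.0), "large_patch": (15.0, np.inf)},
            }
        },
    )
    print(adata.uns["hmm_appended_layers"])
    # ['gpc_hmm_states', 'gpc_hmm_posterior_modified', 'gpc_hmm_all_accessible_features', ...]
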
+
1252
+ # ------------------------- row-copy helper (workflow uses it) -------------------------
1253
+
1254
+ def _ensure_final_layer_and_assign(
1255
+ self, final_adata, layer_name: str, subset_idx_mask: np.ndarray, sub_data
1256
+ ):
1257
+ """
1258
+ Assign rows from sub_data into final_adata.layers[layer_name] for rows where subset_idx_mask is True.
1259
+ Handles dense arrays. If you want sparse support, add it here.
1260
+ """
1261
+ n_final_obs, n_vars = final_adata.shape
1262
+ final_rows = np.nonzero(np.asarray(subset_idx_mask).astype(bool))[0]
1263
+ sub_arr = np.asarray(sub_data)
1264
+
1265
+ if layer_name not in final_adata.layers:
1266
+ final_adata.layers[layer_name] = np.zeros((n_final_obs, n_vars), dtype=sub_arr.dtype)
1267
+
1268
+ final_arr = np.asarray(final_adata.layers[layer_name])
1269
+ if sub_arr.shape[0] != final_rows.size:
1270
+ raise ValueError(f"Sub rows {sub_arr.shape[0]} != mask sum {final_rows.size}")
1271
+ final_arr[final_rows, :] = sub_arr
1272
+ final_adata.layers[layer_name] = final_arr
1273
+
1274
+
1275
+ # =============================================================================
1276
+ # Single-channel Bernoulli HMM
1277
+ # =============================================================================
1278
+
1279
+
1280
+ @register_hmm("single")
1281
+ class SingleBernoulliHMM(BaseHMM):
1282
+ """
1283
+ Bernoulli emission per state:
1284
+ emission[k] = P(obs==1 | state=k)
1285
+ """
1286
+
1287
+ def __init__(
1288
+ self,
1289
+ n_states: int = 2,
1290
+ init_emission: Optional[Sequence[float]] = None,
1291
+ eps: float = 1e-8,
1292
+ dtype: torch.dtype = torch.float64,
1293
+ ):
1294
+ """Initialize a single-channel Bernoulli HMM.
1295
+
1296
+ Args:
1297
+ n_states: Number of hidden states.
1298
+ init_emission: Initial emission probabilities per state.
1299
+ eps: Smoothing epsilon for probabilities.
1300
+ dtype: Torch dtype for parameters.
1301
+ """
1302
+ super().__init__(n_states=n_states, eps=eps, dtype=dtype)
1303
+ if init_emission is None:
1304
+ em = np.full((self.n_states,), 0.5, dtype=float)
1305
+ else:
1306
+ em = np.asarray(init_emission, dtype=float).reshape(-1)[: self.n_states]
1307
+ if em.size != self.n_states:
1308
+ em = np.full((self.n_states,), 0.5, dtype=float)
1309
+
1310
+ self.emission = nn.Parameter(torch.tensor(em, dtype=self.dtype), requires_grad=False)
1311
+ self._normalize_emission()
1312
+
1313
+ @classmethod
1314
+ def from_config(cls, cfg, *, override=None, device=None):
1315
+ """Create a single-channel Bernoulli HMM from config.
1316
+
1317
+ Args:
1318
+ cfg: Configuration mapping or object.
1319
+ override: Override values to apply.
1320
+ device: Optional device specifier.
1321
+
1322
+ Returns:
1323
+ Initialized SingleBernoulliHMM instance.
1324
+ """
1325
+ merged = cls._cfg_to_dict(cfg)
1326
+ if override:
1327
+ merged.update(override)
1328
+ n_states = int(merged.get("hmm_n_states", 2))
1329
+ eps = float(merged.get("hmm_eps", 1e-8))
1330
+ dtype = _resolve_dtype(merged.get("hmm_dtype", None))
1331
+ dtype = _coerce_dtype_for_device(dtype, device) # <<< NEW
1332
+ init_em = merged.get("hmm_init_emission_probs", merged.get("hmm_init_emission", None))
1333
+ model = cls(n_states=n_states, init_emission=init_em, eps=eps, dtype=dtype)
1334
+ if device is not None:
1335
+ model.to(torch.device(device) if isinstance(device, str) else device)
1336
+ model._persisted_cfg = merged
1337
+ return model
1338
+
1339
+ def _normalize_emission(self):
1340
+ """Normalize and clamp emission probabilities in-place."""
1341
+ with torch.no_grad():
1342
+ self.emission.data = self.emission.data.reshape(-1)
1343
+ if self.emission.data.numel() != self.n_states:
1344
+ self.emission.data = torch.full(
1345
+ (self.n_states,), 0.5, dtype=self.dtype, device=self.emission.device
1346
+ )
1347
+ self.emission.data = self.emission.data.clamp(min=self.eps, max=1.0 - self.eps)
1348
+
1349
+ def _ensure_device_dtype(self, device=None) -> torch.device:
1350
+ """Move emission parameters to the requested device/dtype."""
1351
+ device = super()._ensure_device_dtype(device)
1352
+ self.emission.data = self.emission.data.to(device=device, dtype=self.dtype)
1353
+ return device
1354
+
1355
+ def _state_modified_score(self) -> torch.Tensor:
1356
+ """Return per-state modified scores for ranking."""
1357
+ return self.emission.detach()
1358
+
1359
+ def _log_emission(self, obs: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
1360
+ """
1361
+ obs: (N,L), mask: (N,L) -> logB: (N,L,K)
1362
+ """
1363
+ p = self.emission # (K,)
1364
+ logp = torch.log(p + self.eps)
1365
+ log1mp = torch.log1p(-p + self.eps)
1366
+
1367
+ o = obs.unsqueeze(-1) # (N,L,1)
1368
+ logB = o * logp.view(1, 1, -1) + (1.0 - o) * log1mp.view(1, 1, -1)
1369
+ logB = torch.where(mask.unsqueeze(-1), logB, torch.zeros_like(logB))
1370
+ return logB
1371
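
The helper above is the standard Bernoulli log-likelihood, log B[n, l, k] = x * log(p_k) + (1 - x) * log(1 - p_k), with masked (uncovered) positions zeroed so they drop out of the forward/backward and Viterbi recursions. A tiny numeric illustration (values made up):

    import numpy as np

    p = np.array([0.2, 0.9])     # per-state P(obs == 1)
    print(np.log(p))             # log-likelihood of observing 1 in each state
    print(np.log1p(-p))          # log-likelihood of observing 0 in each state
    # A NaN/masked position contributes log B = 0 for every state, i.e. it is
    # uninformative rather than being treated as an unmethylated call.
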
+
1372
+ def _extra_save_payload(self) -> dict:
1373
+ """Return extra payload data for serialization."""
1374
+ return {"emission": self.emission.detach().cpu()}
1375
+
1376
+ def _load_extra_payload(self, payload: dict, *, device: torch.device):
1377
+ """Load serialized emission parameters.
1378
+
1379
+ Args:
1380
+ payload: Serialized payload dictionary.
1381
+ device: Target torch device.
1382
+ """
1383
+ with torch.no_grad():
1384
+ self.emission.data = payload["emission"].to(device=device, dtype=self.dtype)
1385
+ self._normalize_emission()
1386
+
1387
+ def fit_em(
1388
+ self,
1389
+ X: np.ndarray,
1390
+ coords: np.ndarray,
1391
+ *,
1392
+ device: torch.device,
1393
+ max_iter: int,
1394
+ tol: float,
1395
+ update_start: bool,
1396
+ update_trans: bool,
1397
+ update_emission: bool,
1398
+ verbose: bool,
1399
+ **kwargs,
1400
+ ) -> List[float]:
1401
+ """Run EM updates for a single-channel Bernoulli HMM.
1402
+
1403
+ Args:
1404
+ X: Observations array (N, L).
1405
+ coords: Coordinate array aligned to X.
1406
+ device: Torch device.
1407
+ max_iter: Maximum iterations.
1408
+ tol: Convergence tolerance.
1409
+ update_start: Whether to update start probabilities.
1410
+ update_trans: Whether to update transitions.
1411
+ update_emission: Whether to update emission parameters.
1412
+ verbose: Whether to log progress.
1413
+ **kwargs: Additional implementation-specific kwargs.
1414
+
1415
+ Returns:
1416
+ List of log-likelihood proxy values.
1417
+ """
1418
+ X = np.asarray(X, dtype=float)
1419
+ if X.ndim != 2:
1420
+ raise ValueError("SingleBernoulliHMM expects X shape (N,L).")
1421
+ obs = torch.tensor(np.nan_to_num(X, nan=0.0), dtype=self.dtype, device=device)
1422
+ mask = torch.tensor(~np.isnan(X), dtype=torch.bool, device=device)
1423
+
1424
+ eps = float(self.eps)
1425
+ K = self.n_states
1426
+ N, L = obs.shape
1427
+
1428
+ hist: List[float] = []
1429
+ for it in range(1, int(max_iter) + 1):
1430
+ gamma = self._forward_backward(obs, mask) # (N,L,K)
1431
+
1432
+ # log-likelihood proxy
1433
+ ll_proxy = float(torch.sum(torch.log(torch.clamp(gamma.sum(dim=2), min=eps))).item())
1434
+ hist.append(ll_proxy)
1435
+
1436
+ # expected start
1437
+ start_acc = gamma[:, 0, :].sum(dim=0) # (K,)
1438
+
1439
+ # expected transitions xi
1440
+ logB = self._log_emission(obs, mask)
1441
+ logA = torch.log(self.trans + eps)
1442
+ alpha = torch.empty((N, L, K), dtype=self.dtype, device=device)
1443
+ alpha[:, 0, :] = torch.log(self.start + eps).unsqueeze(0) + logB[:, 0, :]
1444
+ for t in range(1, L):
1445
+ prev = alpha[:, t - 1, :].unsqueeze(2) + logA.unsqueeze(0)
1446
+ alpha[:, t, :] = _logsumexp(prev, dim=1) + logB[:, t, :]
1447
+ beta = torch.empty((N, L, K), dtype=self.dtype, device=device)
1448
+ beta[:, L - 1, :] = 0.0
1449
+ for t in range(L - 2, -1, -1):
1450
+ temp = logA.unsqueeze(0) + (logB[:, t + 1, :] + beta[:, t + 1, :]).unsqueeze(1)
1451
+ beta[:, t, :] = _logsumexp(temp, dim=2)
1452
+
1453
+ trans_acc = torch.zeros((K, K), dtype=self.dtype, device=device)
1454
+ for t in range(L - 1):
1455
+ valid_t = (mask[:, t] & mask[:, t + 1]).float().view(N, 1, 1)
1456
+ log_xi = (
1457
+ alpha[:, t, :].unsqueeze(2)
1458
+ + logA.unsqueeze(0)
1459
+ + (logB[:, t + 1, :] + beta[:, t + 1, :]).unsqueeze(1)
1460
+ )
1461
+ log_norm = _logsumexp(log_xi.view(N, -1), dim=1).view(N, 1, 1)
1462
+ xi = (log_xi - log_norm).exp() * valid_t
1463
+ trans_acc += xi.sum(dim=0)
1464
+
1465
+ # emission update
1466
+ mask_f = mask.float().unsqueeze(-1) # (N,L,1)
1467
+ emit_num = (gamma * obs.unsqueeze(-1) * mask_f).sum(dim=(0, 1)) # (K,)
1468
+ emit_den = (gamma * mask_f).sum(dim=(0, 1)) # (K,)
1469
+
1470
+ with torch.no_grad():
1471
+ if update_start:
1472
+ new_start = start_acc + eps
1473
+ self.start.data = new_start / new_start.sum()
1474
+
1475
+ if update_trans:
1476
+ new_trans = trans_acc + eps
1477
+ rs = new_trans.sum(dim=1, keepdim=True)
1478
+ rs[rs == 0.0] = 1.0
1479
+ self.trans.data = new_trans / rs
1480
+
1481
+ if update_emission:
1482
+ new_em = (emit_num + eps) / (emit_den + 2.0 * eps)
1483
+ self.emission.data = new_em.clamp(min=eps, max=1.0 - eps)
1484
+
1485
+ self._normalize_params()
1486
+ self._normalize_emission()
931
1487
 
932
- if not feature_sets:
933
1488
  if verbose:
934
- print("[HMM.annotate_adata] no feature sets configured; nothing to append.")
935
- return None if in_place else adata
936
-
937
- if verbose:
938
- print("[HMM.annotate_adata] resolved feature sets:", list(feature_sets.keys()))
939
-
940
- # copy vs in-place
941
- if not in_place:
942
- adata = adata.copy()
943
-
944
- # prepare column names
945
- all_features = []
946
- combined_prefix = "Combined"
947
- for key, fs in feature_sets.items():
948
- feats = fs.get("features", {})
949
- if key == "cpg":
950
- all_features += [f"CpG_{f}" for f in feats]
951
- all_features.append(f"CpG_all_{key}_features")
952
- else:
953
- for methbase in methbases:
954
- all_features += [f"{methbase}_{f}" for f in feats]
955
- all_features.append(f"{methbase}_all_{key}_features")
956
- if len(methbases) > 1:
957
- all_features += [f"{combined_prefix}_{f}" for f in feats]
958
- all_features.append(f"{combined_prefix}_all_{key}_features")
959
-
960
- # initialize obs columns (unique lists per row)
961
- n_rows = adata.shape[0]
962
- for feature in all_features:
963
- if feature not in adata.obs.columns:
964
- adata.obs[feature] = [[] for _ in range(n_rows)]
965
- if f"{feature}_distances" not in adata.obs.columns:
966
- adata.obs[f"{feature}_distances"] = [None] * n_rows
967
- if f"n_{feature}" not in adata.obs.columns:
968
- adata.obs[f"n_{feature}"] = -1
969
-
970
- appended_layers: List[str] = []
971
-
972
- # device management
973
- if device is None:
974
- device = next(self.parameters()).device
975
- elif isinstance(device, str):
976
- device = _torch.device(device)
977
- self.to(device)
1489
+ logger.info(
1490
+ "[SingleBernoulliHMM.fit] iter=%s ll_proxy=%.6f",
1491
+ it,
1492
+ hist[-1],
1493
+ )
1494
+
1495
+ if len(hist) > 1 and abs(hist[-1] - hist[-2]) < float(tol):
1496
+ break
1497
+
1498
+ return hist
1499
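
The emission M-step in the loop above is the responsibility-weighted frequency of 1s per state, lightly smoothed by `eps`. A small numpy illustration of that update with made-up responsibilities:

    import numpy as np

    eps = 1e-8
    gamma = np.array([[[0.9, 0.1], [0.2, 0.8], [0.7, 0.3]],
                      [[0.6, 0.4], [0.1, 0.9], [0.5, 0.5]]])   # (N, L, K) posteriors
    obs = np.array([[1.0, 1.0, 0.0],
                    [0.0, 1.0, 1.0]])                          # (N, L) binary calls
    mask = np.ones_like(obs, dtype=bool)                       # all positions covered

    num = (gamma * obs[..., None] * mask[..., None]).sum(axis=(0, 1))
    den = (gamma * mask[..., None]).sum(axis=(0, 1))
    print((num + eps) / (den + 2.0 * eps))   # updated P(obs == 1 | state), one value per state
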
+
1500
+ def adapt_emissions(
1501
+ self,
1502
+ X: np.ndarray,
1503
+ coords: Optional[np.ndarray] = None,
1504
+ *,
1505
+ device: Optional[Union[str, torch.device]] = None,
1506
+ iters: Optional[int] = None,
1507
+ max_iter: Optional[int] = None, # alias for your trainer
1508
+ verbose: bool = False,
1509
+ **kwargs,
1510
+ ):
1511
+ """Adapt emissions with legacy parameter names.
1512
+
1513
+ Args:
1514
+ X: Observations array.
1515
+ coords: Optional coordinate array.
1516
+ device: Device specifier.
1517
+ iters: Number of iterations.
1518
+ max_iter: Alias for iters.
1519
+ verbose: Whether to log progress.
1520
+ **kwargs: Additional kwargs forwarded to BaseHMM.adapt_emissions.
1521
+
1522
+ Returns:
1523
+ List of log-likelihood values.
1524
+ """
1525
+ if iters is None:
1526
+ iters = int(max_iter) if max_iter is not None else int(kwargs.pop("iters", 5))
1527
+ return super().adapt_emissions(
1528
+ np.asarray(X, dtype=float),
1529
+ coords if coords is not None else None,
1530
+ iters=int(iters),
1531
+ device=device,
1532
+ verbose=verbose,
1533
+ )
1534
+
1535
+
1536
+ # =============================================================================
1537
+ # Multi-channel Bernoulli HMM (union coordinate grid)
1538
+ # =============================================================================
1539
+
1540
+
1541
+ @register_hmm("multi")
1542
+ class MultiBernoulliHMM(BaseHMM):
1543
+ """
1544
+ Multi-channel independent Bernoulli:
1545
+ emission[k,c] = P(obs_c==1 | state=k)
1546
+ X must be (N,L,C) on a union coordinate grid; NaN per-channel allowed.
1547
+ """
1548
+
1549
+ def __init__(
1550
+ self,
1551
+ n_states: int = 2,
1552
+ n_channels: int = 2,
1553
+ init_emission: Optional[Any] = None,
1554
+ eps: float = 1e-8,
1555
+ dtype: torch.dtype = torch.float64,
1556
+ ):
1557
+ """Initialize a multi-channel Bernoulli HMM.
1558
+
1559
+ Args:
1560
+ n_states: Number of hidden states.
1561
+ n_channels: Number of observed channels.
1562
+ init_emission: Initial emission probabilities.
1563
+ eps: Smoothing epsilon for probabilities.
1564
+ dtype: Torch dtype for parameters.
1565
+ """
1566
+ super().__init__(n_states=n_states, eps=eps, dtype=dtype)
1567
+ self.n_channels = int(n_channels)
1568
+ if self.n_channels < 1:
1569
+ raise ValueError("n_channels must be >=1")
978
1570
 
979
- # helpers ---------------------------------------------------------------
980
- def _ensure_2d_array_like(matrix):
981
- arr = _np.asarray(matrix)
1571
+ if init_emission is None:
1572
+ em = np.full((self.n_states, self.n_channels), 0.5, dtype=float)
1573
+ else:
1574
+ arr = np.asarray(init_emission, dtype=float)
982
1575
  if arr.ndim == 1:
983
- arr = arr[_np.newaxis, :]
984
- elif arr.ndim > 2:
985
- # squeeze trailing singletons
986
- while arr.ndim > 2 and arr.shape[-1] == 1:
987
- arr = _np.squeeze(arr, axis=-1)
988
- if arr.ndim != 2:
989
- raise ValueError(f"Expected 2D sequence matrix; got array with shape {arr.shape}")
990
- return arr
991
-
992
- def calculate_batch_distances(intervals_list, threshold_local=0.9):
993
- results_local = []
994
- for intervals in intervals_list:
995
- if not isinstance(intervals, list) or len(intervals) == 0:
996
- results_local.append([])
997
- continue
998
- valid = [iv for iv in intervals if iv[2] > threshold_local]
999
- if len(valid) <= 1:
1000
- results_local.append([])
1001
- continue
1002
- valid = sorted(valid, key=lambda x: x[0])
1003
- dists = [(valid[i + 1][0] - (valid[i][0] + valid[i][1])) for i in range(len(valid) - 1)]
1004
- results_local.append(dists)
1005
- return results_local
1006
-
1007
- def classify_batch_local(predicted_states_batch, probabilities_batch, coordinates, classification_mapping, target_state="Modified"):
1008
- # Accept numpy arrays or torch tensors
1009
- if isinstance(predicted_states_batch, _torch.Tensor):
1010
- pred_np = predicted_states_batch.detach().cpu().numpy()
1576
+ arr = arr.reshape(-1, 1)
1577
+ em = np.repeat(arr[: self.n_states, :], self.n_channels, axis=1)
1011
1578
  else:
1012
- pred_np = _np.asarray(predicted_states_batch)
1013
- if isinstance(probabilities_batch, _torch.Tensor):
1014
- probs_np = probabilities_batch.detach().cpu().numpy()
1015
- else:
1016
- probs_np = _np.asarray(probabilities_batch)
1017
-
1018
- batch_size, L = pred_np.shape
1019
- all_classifications_local = []
1020
- # allow caller to pass arbitrary state labels mapping; default two-state mapping:
1021
- state_labels = ["Non-Modified", "Modified"]
1022
- try:
1023
- target_idx = state_labels.index(target_state)
1024
- except ValueError:
1025
- target_idx = 1 # fallback
1026
-
1027
- for b in range(batch_size):
1028
- predicted_states = pred_np[b]
1029
- probabilities = probs_np[b]
1030
- regions = []
1031
- current_start, current_length, current_probs = None, 0, []
1032
- for i, state_index in enumerate(predicted_states):
1033
- state_prob = float(probabilities[i][state_index])
1034
- if state_index == target_idx:
1035
- if current_start is None:
1036
- current_start = i
1037
- current_length += 1
1038
- current_probs.append(state_prob)
1039
- elif current_start is not None:
1040
- regions.append((current_start, current_length, float(_np.mean(current_probs))))
1041
- current_start, current_length, current_probs = None, 0, []
1042
- if current_start is not None:
1043
- regions.append((current_start, current_length, float(_np.mean(current_probs))))
1044
-
1045
- final = []
1046
- for start, length, prob in regions:
1047
- # compute genomic length try/catch
1048
- try:
1049
- feature_length = int(coordinates[start + length - 1]) - int(coordinates[start]) + 1
1050
- except Exception:
1051
- feature_length = int(length)
1052
-
1053
- # classification_mapping values are (lo, hi) tuples or lists
1054
- label = None
1055
- for ftype, rng in classification_mapping.items():
1056
- lo, hi = rng[0], rng[1]
1057
- try:
1058
- if lo <= feature_length < hi:
1059
- label = ftype
1060
- break
1061
- except Exception:
1062
- continue
1063
- if label is None:
1064
- # fallback to first mapping key or 'unknown'
1065
- label = next(iter(classification_mapping.keys()), "feature")
1066
-
1067
- # Store reported start coordinate in same coordinate system as `coordinates`.
1068
- try:
1069
- genomic_start = int(coordinates[start])
1070
- except Exception:
1071
- genomic_start = int(start)
1072
- final.append((genomic_start, feature_length, label, prob))
1073
- all_classifications_local.append(final)
1074
- return all_classifications_local
1075
-
1076
- # -----------------------------------------------------------------------
1077
-
1078
- # Ensure obs_column is categorical-like for iteration
1079
- sseries = adata.obs[obs_column]
1080
- if not pd.api.types.is_categorical_dtype(sseries):
1081
- sseries = sseries.astype("category")
1082
- references = list(sseries.cat.categories)
1083
-
1084
- ref_iter = references if not verbose else _tqdm(references, desc="Processing References")
1085
- for ref in ref_iter:
1086
- # subset reads with this obs_column value
1087
- ref_mask = adata.obs[obs_column] == ref
1088
- ref_subset = adata[ref_mask].copy()
1089
- combined_mask = None
1090
-
1091
- # per-methbase processing
1092
- for methbase in methbases:
1093
- key_lower = methbase.strip().lower()
1094
-
1095
- # map several common synonyms -> canonical lookup
1096
- if key_lower in ("a",):
1097
- pos_mask = ref_subset.var.get(f"{ref}_strand_FASTA_base") == "A"
1098
- elif key_lower in ("c", "any_c", "anyc", "any-c"):
1099
- # unify 'C' or 'any_C' names to the any_C var column
1100
- pos_mask = ref_subset.var.get(f"{ref}_any_C_site") == True
1101
- elif key_lower in ("gpc", "gpc_site", "gpc-site"):
1102
- pos_mask = ref_subset.var.get(f"{ref}_GpC_site") == True
1103
- elif key_lower in ("cpg", "cpg_site", "cpg-site"):
1104
- pos_mask = ref_subset.var.get(f"{ref}_CpG_site") == True
1105
- else:
1106
- # try a best-effort: if a column named f"{ref}_{methbase}_site" exists, use it
1107
- alt_col = f"{ref}_{methbase}_site"
1108
- pos_mask = ref_subset.var.get(alt_col, None)
1579
+ em = arr[: self.n_states, : self.n_channels]
1580
+ if em.shape != (self.n_states, self.n_channels):
1581
+ em = np.full((self.n_states, self.n_channels), 0.5, dtype=float)
1109
1582
 
1110
- if pos_mask is None:
1111
- continue
1112
- combined_mask = pos_mask if combined_mask is None else (combined_mask | pos_mask)
1583
+ self.emission = nn.Parameter(torch.tensor(em, dtype=self.dtype), requires_grad=False)
1584
+ self._normalize_emission()
1113
1585
 
1114
- if pos_mask.sum() == 0:
1115
- continue
1586
+ @classmethod
1587
+ def from_config(cls, cfg, *, override=None, device=None):
1588
+ """Create a multi-channel Bernoulli HMM from config.
1116
1589
 
1117
- sub = ref_subset[:, pos_mask]
1118
- # choose matrix
1119
- matrix = sub.layers[layer] if (layer and layer in sub.layers) else sub.X
1120
- matrix = _ensure_2d_array_like(matrix)
1121
- n_reads = matrix.shape[0]
1590
+ Args:
1591
+ cfg: Configuration mapping or object.
1592
+ override: Override values to apply.
1593
+ device: Optional device specifier.
1122
1594
 
1123
- # coordinates for this sub (try to convert to ints, else fallback to indices)
1124
- try:
1125
- coords = _np.asarray(sub.var_names, dtype=int)
1126
- except Exception:
1127
- coords = _np.arange(sub.shape[1], dtype=int)
1128
-
1129
- # chunked processing
1130
- chunk_iter = range(0, n_reads, batch_size)
1131
- if verbose:
1132
- chunk_iter = _tqdm(list(chunk_iter), desc=f"{ref}:{methbase} chunks")
1133
- for start_idx in chunk_iter:
1134
- stop_idx = min(n_reads, start_idx + batch_size)
1135
- chunk = matrix[start_idx:stop_idx]
1136
- seqs = chunk.tolist()
1137
- # posterior marginals
1138
- gammas = self.predict(seqs, impute_strategy="ignore", device=device)
1139
- if len(gammas) == 0:
1140
- continue
1141
- probs_batch = _np.stack(gammas, axis=0) # (B, L, K)
1142
- if use_viterbi:
1143
- paths, _scores = self.batch_viterbi(seqs, impute_strategy="ignore", device=device)
1144
- pred_states = _np.asarray(paths)
1145
- else:
1146
- pred_states = _np.argmax(probs_batch, axis=2)
1595
+ Returns:
1596
+ Initialized MultiBernoulliHMM instance.
1597
+ """
1598
+ merged = cls._cfg_to_dict(cfg)
1599
+ if override:
1600
+ merged.update(override)
1601
+ n_states = int(merged.get("hmm_n_states", 2))
1602
+ eps = float(merged.get("hmm_eps", 1e-8))
1603
+ dtype = _resolve_dtype(merged.get("hmm_dtype", None))
1604
+ dtype = _coerce_dtype_for_device(dtype, device) # <<< NEW
1605
+ n_channels = int(merged.get("hmm_n_channels", merged.get("n_channels", 2)))
1606
+ init_em = merged.get("hmm_init_emission_probs", None)
1607
+ model = cls(
1608
+ n_states=n_states, n_channels=n_channels, init_emission=init_em, eps=eps, dtype=dtype
1609
+ )
1610
+ if device is not None:
1611
+ model.to(torch.device(device) if isinstance(device, str) else device)
1612
+ model._persisted_cfg = merged
1613
+ return model
1147
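
As read from `from_config()` above, the recognized config keys are `hmm_n_states`, `hmm_eps`, `hmm_dtype`, `hmm_n_channels` (or `n_channels`), and `hmm_init_emission_probs`. A construction sketch from a plain dict; the import path is assumed and the channel interpretation is only an example:

    from smftools.hmm.HMM import MultiBernoulliHMM  # assumed import path

    cfg = {
        "hmm_n_states": 2,
        "hmm_n_channels": 2,                                   # e.g. a GpC and a CpG channel
        "hmm_eps": 1e-8,
        "hmm_init_emission_probs": [[0.1, 0.1], [0.9, 0.8]],   # (n_states, n_channels)
    }
    model = MultiBernoulliHMM.from_config(cfg, device="cpu")
    print(model.emission)   # (2, 2) tensor, clamped into (eps, 1 - eps)
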
1614
 
1148
- # For each feature group, classify separately and write back
1149
- for key, fs in feature_sets.items():
1150
- if key == "cpg":
1151
- continue
1152
- state_target = fs.get("state", "Modified")
1153
- feature_map = fs.get("features", {})
1154
- classifications = classify_batch_local(pred_states, probs_batch, coords, feature_map, target_state=state_target)
1155
-
1156
- # write results to adata.obs rows (use original index names)
1157
- row_indices = list(sub.obs.index[start_idx:stop_idx])
1158
- for i_local, idx in enumerate(row_indices):
1159
- for start, length, label, prob in classifications[i_local]:
1160
- col_name = f"{methbase}_{label}"
1161
- all_col = f"{methbase}_all_{key}_features"
1162
- adata.obs.at[idx, col_name].append([start, length, prob])
1163
- adata.obs.at[idx, all_col].append([start, length, prob])
1164
-
1165
- # Combined subset (if multiple methbases)
1166
- if len(methbases) > 1 and (combined_mask is not None) and (combined_mask.sum() > 0):
1167
- comb = ref_subset[:, combined_mask]
1168
- if comb.shape[1] > 0:
1169
- matrix = comb.layers[layer] if (layer and layer in comb.layers) else comb.X
1170
- matrix = _ensure_2d_array_like(matrix)
1171
- n_reads_comb = matrix.shape[0]
1172
- try:
1173
- coords_comb = _np.asarray(comb.var_names, dtype=int)
1174
- except Exception:
1175
- coords_comb = _np.arange(comb.shape[1], dtype=int)
1176
-
1177
- chunk_iter = range(0, n_reads_comb, batch_size)
1178
- if verbose:
1179
- chunk_iter = _tqdm(list(chunk_iter), desc=f"{ref}:Combined chunks")
1180
- for start_idx in chunk_iter:
1181
- stop_idx = min(n_reads_comb, start_idx + batch_size)
1182
- chunk = matrix[start_idx:stop_idx]
1183
- seqs = chunk.tolist()
1184
- gammas = self.predict(seqs, impute_strategy="ignore", device=device)
1185
- if len(gammas) == 0:
1186
- continue
1187
- probs_batch = _np.stack(gammas, axis=0)
1188
- if use_viterbi:
1189
- paths, _scores = self.batch_viterbi(seqs, impute_strategy="ignore", device=device)
1190
- pred_states = _np.asarray(paths)
1191
- else:
1192
- pred_states = _np.argmax(probs_batch, axis=2)
1193
-
1194
- for key, fs in feature_sets.items():
1195
- if key == "cpg":
1196
- continue
1197
- state_target = fs.get("state", "Modified")
1198
- feature_map = fs.get("features", {})
1199
- classifications = classify_batch_local(pred_states, probs_batch, coords_comb, feature_map, target_state=state_target)
1200
- row_indices = list(comb.obs.index[start_idx:stop_idx])
1201
- for i_local, idx in enumerate(row_indices):
1202
- for start, length, label, prob in classifications[i_local]:
1203
- adata.obs.at[idx, f"{combined_prefix}_{label}"].append([start, length, prob])
1204
- adata.obs.at[idx, f"{combined_prefix}_all_{key}_features"].append([start, length, prob])
1205
-
1206
- # CpG special handling
1207
- if "cpg" in feature_sets and feature_sets.get("cpg") is not None:
1208
- cpg_iter = references if not verbose else _tqdm(references, desc="Processing CpG")
1209
- for ref in cpg_iter:
1210
- ref_mask = adata.obs[obs_column] == ref
1211
- ref_subset = adata[ref_mask].copy()
1212
- pos_mask = ref_subset.var[f"{ref}_CpG_site"] == True
1213
- if pos_mask.sum() == 0:
1214
- continue
1215
- cpg_sub = ref_subset[:, pos_mask]
1216
- matrix = cpg_sub.layers[layer] if (layer and layer in cpg_sub.layers) else cpg_sub.X
1217
- matrix = _ensure_2d_array_like(matrix)
1218
- n_reads = matrix.shape[0]
1219
- try:
1220
- coords_cpg = _np.asarray(cpg_sub.var_names, dtype=int)
1221
- except Exception:
1222
- coords_cpg = _np.arange(cpg_sub.shape[1], dtype=int)
1223
-
1224
- chunk_iter = range(0, n_reads, batch_size)
1225
- if verbose:
1226
- chunk_iter = _tqdm(list(chunk_iter), desc=f"{ref}:CpG chunks")
1227
- for start_idx in chunk_iter:
1228
- stop_idx = min(n_reads, start_idx + batch_size)
1229
- chunk = matrix[start_idx:stop_idx]
1230
- seqs = chunk.tolist()
1231
- gammas = self.predict(seqs, impute_strategy="ignore", device=device)
1232
- if len(gammas) == 0:
1233
- continue
1234
- probs_batch = _np.stack(gammas, axis=0)
1235
- if use_viterbi:
1236
- paths, _scores = self.batch_viterbi(seqs, impute_strategy="ignore", device=device)
1237
- pred_states = _np.asarray(paths)
1238
- else:
1239
- pred_states = _np.argmax(probs_batch, axis=2)
1240
-
1241
- fs = feature_sets["cpg"]
1242
- state_target = fs.get("state", "Modified")
1243
- feature_map = fs.get("features", {})
1244
- classifications = classify_batch_local(pred_states, probs_batch, coords_cpg, feature_map, target_state=state_target)
1245
- row_indices = list(cpg_sub.obs.index[start_idx:stop_idx])
1246
- for i_local, idx in enumerate(row_indices):
1247
- for start, length, label, prob in classifications[i_local]:
1248
- adata.obs.at[idx, f"CpG_{label}"].append([start, length, prob])
1249
- adata.obs.at[idx, f"CpG_all_cpg_features"].append([start, length, prob])
1250
-
1251
- # finalize: convert intervals into binary layers and distances
1252
- try:
1253
- coordinates = _np.asarray(adata.var_names, dtype=int)
1254
- coords_are_ints = True
1255
- except Exception:
1256
- coordinates = _np.arange(adata.shape[1], dtype=int)
1257
- coords_are_ints = False
1258
-
1259
- features_iter = all_features if not verbose else _tqdm(all_features, desc="Finalizing Layers")
1260
- for feature in features_iter:
1261
- bin_matrix = _np.zeros((adata.shape[0], adata.shape[1]), dtype=int)
1262
- counts = _np.zeros(adata.shape[0], dtype=int)
1263
-
1264
- # new: integer-length layer (0 where not inside a feature)
1265
- len_matrix = _np.zeros((adata.shape[0], adata.shape[1]), dtype=int)
1266
-
1267
- for row_idx, intervals in enumerate(adata.obs[feature]):
1268
- if not isinstance(intervals, list):
1269
- intervals = []
1270
- for start, length, prob in intervals:
1271
- if prob > threshold:
1272
- if coords_are_ints:
1273
- # map genomic start/length into index interval [start_idx, end_idx)
1274
- start_idx = _np.searchsorted(coordinates, int(start), side="left")
1275
- end_idx = _np.searchsorted(coordinates, int(start) + int(length) - 1, side="right")
1276
- else:
1277
- start_idx = int(start)
1278
- end_idx = start_idx + int(length)
1279
-
1280
- start_idx = max(0, min(start_idx, adata.shape[1]))
1281
- end_idx = max(0, min(end_idx, adata.shape[1]))
1282
-
1283
- if start_idx < end_idx:
1284
- span = end_idx - start_idx # number of positions covered
1285
- # set binary mask
1286
- bin_matrix[row_idx, start_idx:end_idx] = 1
1287
- # set length mask: use maximum in case of overlaps
1288
- existing = len_matrix[row_idx, start_idx:end_idx]
1289
- len_matrix[row_idx, start_idx:end_idx] = _np.maximum(existing, span)
1290
- counts[row_idx] += 1
1291
-
1292
- # write binary layer and length layer, track appended names
1293
- adata.layers[feature] = bin_matrix
1294
- appended_layers.append(feature)
1295
-
1296
- # name the integer-length layer (choose suffix you like)
1297
- length_layer_name = f"{feature}_lengths"
1298
- adata.layers[length_layer_name] = len_matrix
1299
- appended_layers.append(length_layer_name)
1300
-
1301
- adata.obs[f"n_{feature}"] = counts
1302
- adata.obs[f"{feature}_distances"] = calculate_batch_distances(adata.obs[feature].tolist(), threshold)
1303
-
1304
- # Merge appended_layers into adata.uns[uns_key] (preserve pre-existing and avoid duplicates)
1305
- existing = list(adata.uns.get(uns_key, [])) if adata.uns.get(uns_key) is not None else []
1306
- new_list = existing + [l for l in appended_layers if l not in existing]
1307
- adata.uns[uns_key] = new_list
1308
-
1309
- # Mark that the annotation has been completed
1310
- adata.uns[uns_flag] = True
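The removed finalization loop above maps each stored (start, length, prob) interval onto column indices before writing the binary and length layers; a minimal sketch of that coordinate-to-index mapping, assuming integer genomic coordinates in adata.var_names (values below are toy placeholders):

import numpy as np

coordinates = np.array([100, 105, 110, 120, 130])  # integer var_names (genomic positions)
start, length = 105, 20                            # one feature interval: positions 105..124

# Half-open index interval [start_idx, end_idx), as in the loop above.
start_idx = np.searchsorted(coordinates, start, side="left")              # -> 1
end_idx = np.searchsorted(coordinates, start + length - 1, side="right")  # -> 4

bin_row = np.zeros(coordinates.size, dtype=int)
bin_row[start_idx:end_idx] = 1                     # -> [0 1 1 1 0]

When var_names are not integers, the original code falls back to treating start and length as index positions directly.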
1615
+ def _normalize_emission(self):
1616
+ """Normalize and clamp emission probabilities in-place."""
1617
+ with torch.no_grad():
1618
+ self.emission.data = self.emission.data.reshape(self.n_states, self.n_channels)
1619
+ self.emission.data = self.emission.data.clamp(min=self.eps, max=1.0 - self.eps)
1620
+
1621
+ def _ensure_device_dtype(self, device=None) -> torch.device:
1622
+ """Move emission parameters to the requested device/dtype."""
1623
+ device = super()._ensure_device_dtype(device)
1624
+ self.emission.data = self.emission.data.to(device=device, dtype=self.dtype)
1625
+ return device
1626
+
1627
+ def _state_modified_score(self) -> torch.Tensor:
1628
+ """Return per-state modified scores for ranking."""
1629
+ # more “modified” = higher mean P(1) across channels
1630
+ return self.emission.detach().mean(dim=1)
1631
+
1632
+ def _log_emission(self, obs: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
1633
+ """
1634
+ obs: (N,L,C), mask: (N,L,C) -> logB: (N,L,K)
1635
+ """
1636
+ N, L, C = obs.shape
1637
+ K = self.n_states
1638
+
1639
+ p = self.emission # (K,C)
1640
+ logp = torch.log(p + self.eps).view(1, 1, K, C)
1641
+ log1mp = torch.log1p(-p + self.eps).view(1, 1, K, C)
1642
+
1643
+ o = obs.unsqueeze(2) # (N,L,1,C)
1644
+ m = mask.unsqueeze(2) # (N,L,1,C)
1311
1645
 
1312
- return None if in_place else adata
1646
+ logBC = o * logp + (1.0 - o) * log1mp
1647
+ logBC = torch.where(m, logBC, torch.zeros_like(logBC))
1648
+ return logBC.sum(dim=3) # sum channels -> (N,L,K)
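A small self-contained check of the masked Bernoulli log-emission term computed above, on toy tensors (shapes and values are illustrative only):

import torch

obs = torch.tensor([[[1.0, 0.0], [1.0, 1.0]]])        # (N=1, L=2, C=2)
mask = torch.tensor([[[True, True], [True, False]]])   # one missing channel at the second position
p = torch.tensor([[0.9, 0.8], [0.1, 0.2]])             # per-state P(obs=1), (K=2, C=2)
eps = 1e-8

logp = torch.log(p + eps).view(1, 1, 2, 2)
log1mp = torch.log1p(-p + eps).view(1, 1, 2, 2)
o, m = obs.unsqueeze(2), mask.unsqueeze(2)             # (N, L, 1, C)

logBC = o * logp + (1.0 - o) * log1mp                  # per-channel Bernoulli log-likelihood
logBC = torch.where(m, logBC, torch.zeros_like(logBC)) # masked channels contribute 0
logB = logBC.sum(dim=3)                                # (N, L, K)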
1313
1649
 
1314
- def merge_intervals_in_layer(
1650
+ def _extra_save_payload(self) -> dict:
1651
+ """Return extra payload data for serialization."""
1652
+ return {"n_channels": int(self.n_channels), "emission": self.emission.detach().cpu()}
1653
+
1654
+ def _load_extra_payload(self, payload: dict, *, device: torch.device):
1655
+ """Load serialized emission parameters.
1656
+
1657
+ Args:
1658
+ payload: Serialized payload dictionary.
1659
+ device: Target torch device.
1660
+ """
1661
+ self.n_channels = int(payload.get("n_channels", self.n_channels))
1662
+ with torch.no_grad():
1663
+ self.emission.data = payload["emission"].to(device=device, dtype=self.dtype)
1664
+ self._normalize_emission()
1665
+
1666
+ def fit_em(
1315
1667
  self,
1316
- adata,
1317
- layer: str,
1318
- distance_threshold: int = 0,
1319
- merged_suffix: str = "_merged",
1320
- length_layer_suffix: str = "_lengths",
1321
- update_obs: bool = True,
1322
- prob_strategy: str = "mean", # 'mean'|'max'|'orig_first'
1323
- inplace: bool = True,
1324
- overwrite: bool = False,
1668
+ X: np.ndarray,
1669
+ coords: np.ndarray,
1670
+ *,
1671
+ device: torch.device,
1672
+ max_iter: int,
1673
+ tol: float,
1674
+ update_start: bool,
1675
+ update_trans: bool,
1676
+ update_emission: bool,
1677
+ verbose: bool,
1678
+ **kwargs,
1679
+ ) -> List[float]:
1680
+ """Run EM updates for a multi-channel Bernoulli HMM.
1681
+
1682
+ Args:
1683
+ X: Observations array (N, L, C).
1684
+ coords: Coordinate array aligned to X.
1685
+ device: Torch device.
1686
+ max_iter: Maximum iterations.
1687
+ tol: Convergence tolerance.
1688
+ update_start: Whether to update start probabilities.
1689
+ update_trans: Whether to update transitions.
1690
+ update_emission: Whether to update emission parameters.
1691
+ verbose: Whether to log progress.
1692
+ **kwargs: Additional implementation-specific kwargs.
1693
+
1694
+ Returns:
1695
+ List of log-likelihood proxy values.
1696
+ """
1697
+ X = np.asarray(X, dtype=float)
1698
+ if X.ndim != 3:
1699
+ raise ValueError("MultiBernoulliHMM expects X shape (N,L,C).")
1700
+
1701
+ obs = torch.tensor(np.nan_to_num(X, nan=0.0), dtype=self.dtype, device=device)
1702
+ mask = torch.tensor(~np.isnan(X), dtype=torch.bool, device=device)
1703
+
1704
+ eps = float(self.eps)
1705
+ K = self.n_states
1706
+ N, L, C = obs.shape
1707
+
1708
+ self._ensure_n_channels(C, device)
1709
+
1710
+ hist: List[float] = []
1711
+ for it in range(1, int(max_iter) + 1):
1712
+ gamma = self._forward_backward(obs, mask) # (N,L,K)
1713
+
1714
+ ll_proxy = float(torch.sum(torch.log(torch.clamp(gamma.sum(dim=2), min=eps))).item())
1715
+ hist.append(ll_proxy)
1716
+
1717
+ # expected start
1718
+ start_acc = gamma[:, 0, :].sum(dim=0) # (K,)
1719
+
1720
+ # transitions xi
1721
+ logB = self._log_emission(obs, mask)
1722
+ logA = torch.log(self.trans + eps)
1723
+ alpha = torch.empty((N, L, K), dtype=self.dtype, device=device)
1724
+ alpha[:, 0, :] = torch.log(self.start + eps).unsqueeze(0) + logB[:, 0, :]
1725
+ for t in range(1, L):
1726
+ prev = alpha[:, t - 1, :].unsqueeze(2) + logA.unsqueeze(0)
1727
+ alpha[:, t, :] = _logsumexp(prev, dim=1) + logB[:, t, :]
1728
+ beta = torch.empty((N, L, K), dtype=self.dtype, device=device)
1729
+ beta[:, L - 1, :] = 0.0
1730
+ for t in range(L - 2, -1, -1):
1731
+ temp = logA.unsqueeze(0) + (logB[:, t + 1, :] + beta[:, t + 1, :]).unsqueeze(1)
1732
+ beta[:, t, :] = _logsumexp(temp, dim=2)
1733
+
1734
+ trans_acc = torch.zeros((K, K), dtype=self.dtype, device=device)
1735
+ # valid timestep if at least one channel observed at both positions
1736
+ valid_pos = mask.any(dim=2) # (N,L)
1737
+ for t in range(L - 1):
1738
+ valid_t = (valid_pos[:, t] & valid_pos[:, t + 1]).float().view(N, 1, 1)
1739
+ log_xi = (
1740
+ alpha[:, t, :].unsqueeze(2)
1741
+ + logA.unsqueeze(0)
1742
+ + (logB[:, t + 1, :] + beta[:, t + 1, :]).unsqueeze(1)
1743
+ )
1744
+ log_norm = _logsumexp(log_xi.view(N, -1), dim=1).view(N, 1, 1)
1745
+ xi = (log_xi - log_norm).exp() * valid_t
1746
+ trans_acc += xi.sum(dim=0)
1747
+
1748
+ # emission update per channel
1749
+ gamma_k = gamma.unsqueeze(-1) # (N,L,K,1)
1750
+ obs_c = obs.unsqueeze(2) # (N,L,1,C)
1751
+ mask_c = mask.unsqueeze(2).float() # (N,L,1,C)
1752
+
1753
+ emit_num = (gamma_k * obs_c * mask_c).sum(dim=(0, 1)) # (K,C)
1754
+ emit_den = (gamma_k * mask_c).sum(dim=(0, 1)) # (K,C)
1755
+
1756
+ with torch.no_grad():
1757
+ if update_start:
1758
+ new_start = start_acc + eps
1759
+ self.start.data = new_start / new_start.sum()
1760
+
1761
+ if update_trans:
1762
+ new_trans = trans_acc + eps
1763
+ rs = new_trans.sum(dim=1, keepdim=True)
1764
+ rs[rs == 0.0] = 1.0
1765
+ self.trans.data = new_trans / rs
1766
+
1767
+ if update_emission:
1768
+ new_em = (emit_num + eps) / (emit_den + 2.0 * eps)
1769
+ self.emission.data = new_em.clamp(min=eps, max=1.0 - eps)
1770
+
1771
+ self._normalize_params()
1772
+ self._normalize_emission()
1773
+
1774
+ if verbose:
1775
+ logger.info(
1776
+ "[MultiBernoulliHMM.fit] iter=%s ll_proxy=%.6f",
1777
+ it,
1778
+ hist[-1],
1779
+ )
1780
+
1781
+ if len(hist) > 1 and abs(hist[-1] - hist[-2]) < float(tol):
1782
+ break
1783
+
1784
+ return hist
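The emission M-step above is a posterior-weighted average of the observations per state and channel, with eps smoothing; a NumPy sketch of the same update on toy arrays (names and values are only illustrative):

import numpy as np

eps = 1e-8
gamma = np.array([[[0.9, 0.1], [0.2, 0.8], [0.7, 0.3]]])   # posteriors, (N=1, L=3, K=2)
obs = np.array([[[1.0], [0.0], [1.0]]])                     # observations, (N=1, L=3, C=1)
mask = np.ones_like(obs, dtype=bool)                        # everything observed

gamma_k = gamma[..., None]                                  # (N, L, K, 1)
obs_c = obs[:, :, None, :]                                  # (N, L, 1, C)
mask_c = mask[:, :, None, :].astype(float)

emit_num = (gamma_k * obs_c * mask_c).sum(axis=(0, 1))      # (K, C)
emit_den = (gamma_k * mask_c).sum(axis=(0, 1))              # (K, C)
new_emission = (emit_num + eps) / (emit_den + 2.0 * eps)    # smoothed P(obs=1 | state), ~[[0.889], [0.333]]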
1785
+
1786
+ def adapt_emissions(
1787
+ self,
1788
+ X: np.ndarray,
1789
+ coords: Optional[np.ndarray] = None,
1790
+ *,
1791
+ device: Optional[Union[str, torch.device]] = None,
1792
+ iters: Optional[int] = None,
1793
+ max_iter: Optional[int] = None, # alias for iters, kept for trainer compatibility
1325
1794
  verbose: bool = False,
1795
+ **kwargs,
1326
1796
  ):
1797
+ """Adapt emissions with legacy parameter names.
1798
+
1799
+ Args:
1800
+ X: Observations array.
1801
+ coords: Optional coordinate array.
1802
+ device: Device specifier.
1803
+ iters: Number of iterations.
1804
+ max_iter: Alias for iters.
1805
+ verbose: Whether to log progress.
1806
+ **kwargs: Additional kwargs forwarded to BaseHMM.adapt_emissions.
1807
+
1808
+ Returns:
1809
+ List of log-likelihood values.
1327
1810
  """
1328
- Merge intervals in `adata.layers[layer]` that are within `distance_threshold`.
1329
- Writes new merged binary layer named f"{layer}{merged_suffix}" and length layer
1330
- f"{layer}{merged_suffix}{length_layer_suffix}". Optionally updates adata.obs for merged intervals.
1331
-
1332
- Parameters
1333
- ----------
1334
- layer : str
1335
- Name of original binary layer (0/1 mask).
1336
- distance_threshold : int
1337
- Merge intervals whose gap <= this threshold (genomic coords if adata.var_names are ints).
1338
- merged_suffix : str
1339
- Suffix appended to original layer for the merged binary layer (default "_merged").
1340
- length_layer_suffix : str
1341
- Suffix appended after merged suffix for the lengths layer (default "_lengths").
1342
- update_obs : bool
1343
- If True, create/update adata.obs[f"{layer}{merged_suffix}"] with merged intervals.
1344
- prob_strategy : str
1345
- How to combine probs when merging ('mean', 'max', 'orig_first').
1346
- inplace : bool
1347
- If False, returns a new AnnData with changes (original untouched).
1348
- overwrite : bool
1349
- If True, will overwrite existing merged layers / obs entries; otherwise will error if they exist.
1350
- """
1351
- import numpy as _np
1352
- from scipy.sparse import issparse
1811
+ if iters is None:
1812
+ iters = int(max_iter) if max_iter is not None else int(kwargs.pop("iters", 5))
1813
+ return super().adapt_emissions(
1814
+ np.asarray(X, dtype=float),
1815
+ coords if coords is not None else None,
1816
+ iters=int(iters),
1817
+ device=device,
1818
+ verbose=verbose,
1819
+ )
1353
1820
 
1354
- if not inplace:
1355
- adata = adata.copy()
1821
+ def _ensure_n_channels(self, C: int, device: torch.device):
1822
+ """Expand emission parameters when channel count changes.
1356
1823
 
1357
- merged_bin_name = f"{layer}{merged_suffix}"
1358
- merged_len_name = f"{layer}{merged_suffix}{length_layer_suffix}"
1824
+ Args:
1825
+ C: Target channel count.
1826
+ device: Torch device for the new parameters.
1827
+ """
1828
+ C = int(C)
1829
+ if C == self.n_channels:
1830
+ return
1831
+ with torch.no_grad():
1832
+ old = self.emission.detach().cpu().numpy() # (K, Cold)
1833
+ K = old.shape[0]
1834
+ new = np.full((K, C), 0.5, dtype=float)
1835
+ m = min(old.shape[1], C)
1836
+ new[:, :m] = old[:, :m]
1837
+ if C > old.shape[1]:
1838
+ fill = old.mean(axis=1, keepdims=True)
1839
+ new[:, m:] = fill
1840
+ self.n_channels = C
1841
+ self.emission = nn.Parameter(
1842
+ torch.tensor(new, dtype=self.dtype, device=device), requires_grad=False
1843
+ )
1844
+ self._normalize_emission()
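The channel expansion above keeps existing per-channel probabilities and seeds any new channels at the per-state mean; a toy NumPy illustration:

import numpy as np

old = np.array([[0.9, 0.7],
                [0.1, 0.3]])                       # (K=2, C_old=2) per-state P(obs=1)
C = 3                                              # new channel count
new = np.full((old.shape[0], C), 0.5)
m = min(old.shape[1], C)
new[:, :m] = old[:, :m]
if C > old.shape[1]:
    new[:, m:] = old.mean(axis=1, keepdims=True)   # new channel starts at the row mean
# new -> [[0.9, 0.7, 0.8],
#         [0.1, 0.3, 0.2]]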
1845
+
1846
+
1847
+ # =============================================================================
1848
+ # Distance-binned transitions (single-channel only)
1849
+ # =============================================================================
1850
+
1851
+
1852
+ @register_hmm("single_distance_binned")
1853
+ class DistanceBinnedSingleBernoulliHMM(SingleBernoulliHMM):
1854
+ """
1855
+ Transition matrix depends on binned distances between consecutive coords.
1359
1856
 
1360
- if (merged_bin_name in adata.layers or merged_len_name in adata.layers or
1361
- (update_obs and merged_bin_name in adata.obs.columns)) and not overwrite:
1362
- raise KeyError(f"Merged outputs exist (use overwrite=True to replace): {merged_bin_name} / {merged_len_name}")
1857
+ Config keys:
1858
+ hmm_distance_bins: list[int] edges (bp)
1859
+ hmm_init_transitions_by_bin: optional (n_bins,K,K)
1860
+ """
1363
1861
 
1364
- if layer not in adata.layers:
1365
- raise KeyError(f"Layer '{layer}' not found in adata.layers")
1862
+ def __init__(
1863
+ self,
1864
+ n_states: int = 2,
1865
+ init_emission: Optional[Sequence[float]] = None,
1866
+ distance_bins: Optional[Sequence[int]] = None,
1867
+ init_trans_by_bin: Optional[Any] = None,
1868
+ eps: float = 1e-8,
1869
+ dtype: torch.dtype = torch.float64,
1870
+ ):
1871
+ """Initialize a distance-binned transition HMM.
1872
+
1873
+ Args:
1874
+ n_states: Number of hidden states.
1875
+ init_emission: Initial emission probabilities per state.
1876
+ distance_bins: Distance bin edges in base pairs.
1877
+ init_trans_by_bin: Initial transition matrices per bin.
1878
+ eps: Smoothing epsilon for probabilities.
1879
+ dtype: Torch dtype for parameters.
1880
+ """
1881
+ super().__init__(n_states=n_states, init_emission=init_emission, eps=eps, dtype=dtype)
1882
+
1883
+ self.distance_bins = np.asarray(
1884
+ distance_bins if distance_bins is not None else [1, 5, 10, 25, 50, 100], dtype=int
1885
+ )
1886
+ self.n_bins = int(len(self.distance_bins) + 1)
1366
1887
 
1367
- bin_layer = adata.layers[layer]
1368
- if issparse(bin_layer):
1369
- bin_arr = bin_layer.toarray().astype(int)
1888
+ if init_trans_by_bin is None:
1889
+ base = self.trans.detach().cpu().numpy()
1890
+ tb = np.stack([base for _ in range(self.n_bins)], axis=0)
1370
1891
  else:
1371
- bin_arr = _np.asarray(bin_layer, dtype=int)
1892
+ tb = np.asarray(init_trans_by_bin, dtype=float)
1893
+ if tb.shape != (self.n_bins, self.n_states, self.n_states):
1894
+ base = self.trans.detach().cpu().numpy()
1895
+ tb = np.stack([base for _ in range(self.n_bins)], axis=0)
1372
1896
 
1373
- n_rows, n_cols = bin_arr.shape
1897
+ self.trans_by_bin = nn.Parameter(torch.tensor(tb, dtype=self.dtype), requires_grad=False)
1898
+ self._normalize_trans_by_bin()
1374
1899
 
1375
- # coordinates in genomic units if possible
1376
- try:
1377
- coords = _np.asarray(adata.var_names, dtype=int)
1378
- coords_are_ints = True
1379
- except Exception:
1380
- coords = _np.arange(n_cols, dtype=int)
1381
- coords_are_ints = False
1382
-
1383
- # helper: contiguous runs of 1s -> list of (start_idx, end_idx) (end exclusive)
1384
- def _runs_from_mask(mask_1d):
1385
- idx = _np.nonzero(mask_1d)[0]
1386
- if idx.size == 0:
1387
- return []
1388
- runs = []
1389
- start = idx[0]
1390
- prev = idx[0]
1391
- for i in idx[1:]:
1392
- if i == prev + 1:
1393
- prev = i
1394
- continue
1395
- runs.append((start, prev + 1))
1396
- start = i
1397
- prev = i
1398
- runs.append((start, prev + 1))
1399
- return runs
1400
-
1401
- # read original obs intervals/probs if available (for combining probs)
1402
- orig_obs = None
1403
- if update_obs and (layer in adata.obs.columns):
1404
- orig_obs = list(adata.obs[layer]) # might be non-list entries
1405
-
1406
- # prepare outputs
1407
- merged_bin = _np.zeros_like(bin_arr, dtype=int)
1408
- merged_len = _np.zeros_like(bin_arr, dtype=int)
1409
- merged_obs_col = [[] for _ in range(n_rows)]
1410
- merged_counts = _np.zeros(n_rows, dtype=int)
1411
-
1412
- for r in range(n_rows):
1413
- mask = bin_arr[r, :] != 0
1414
- runs = _runs_from_mask(mask)
1415
- if not runs:
1416
- merged_obs_col[r] = []
1417
- continue
1900
+ @classmethod
1901
+ def from_config(cls, cfg, *, override=None, device=None):
1902
+ """Create a distance-binned HMM from config.
1418
1903
 
1419
- # merge runs where gap <= distance_threshold (gap in genomic coords when possible)
1420
- merged_runs = []
1421
- cur_s, cur_e = runs[0]
1422
- for (s, e) in runs[1:]:
1423
- if coords_are_ints:
1424
- end_coord = int(coords[cur_e - 1])
1425
- next_start_coord = int(coords[s])
1426
- gap = next_start_coord - end_coord - 1
1427
- else:
1428
- gap = s - cur_e
1429
- if gap <= distance_threshold:
1430
- # extend
1431
- cur_e = e
1432
- else:
1433
- merged_runs.append((cur_s, cur_e))
1434
- cur_s, cur_e = s, e
1435
- merged_runs.append((cur_s, cur_e))
1436
-
1437
- # assemble merged mask/lengths and obs entries
1438
- row_entries = []
1439
- for (s_idx, e_idx) in merged_runs:
1440
- if e_idx <= s_idx:
1441
- continue
1442
- span_positions = e_idx - s_idx
1443
- if coords_are_ints:
1444
- try:
1445
- length_val = int(coords[e_idx - 1]) - int(coords[s_idx]) + 1
1446
- except Exception:
1447
- length_val = span_positions
1448
- else:
1449
- length_val = span_positions
1450
-
1451
- # set binary and length masks
1452
- merged_bin[r, s_idx:e_idx] = 1
1453
- existing_segment = merged_len[r, s_idx:e_idx]
1454
- # set to max(existing, length_val)
1455
- if existing_segment.size > 0:
1456
- merged_len[r, s_idx:e_idx] = _np.maximum(existing_segment, length_val)
1457
- else:
1458
- merged_len[r, s_idx:e_idx] = length_val
1459
-
1460
- # determine prob from overlapping original obs (if present)
1461
- prob_val = 1.0
1462
- if update_obs and orig_obs is not None:
1463
- overlaps = []
1464
- for orig in (orig_obs[r] or []):
1465
- try:
1466
- ostart, olen, opro = orig[0], int(orig[1]), float(orig[2])
1467
- except Exception:
1468
- continue
1469
- if coords_are_ints:
1470
- ostart_idx = _np.searchsorted(coords, int(ostart), side="left")
1471
- oend_idx = ostart_idx + olen
1472
- else:
1473
- ostart_idx = int(ostart)
1474
- oend_idx = ostart_idx + olen
1475
- # overlap test in index space
1476
- if not (oend_idx <= s_idx or ostart_idx >= e_idx):
1477
- overlaps.append(opro)
1478
- if overlaps:
1479
- if prob_strategy == "mean":
1480
- prob_val = float(_np.mean(overlaps))
1481
- elif prob_strategy == "max":
1482
- prob_val = float(_np.max(overlaps))
1483
- else:
1484
- prob_val = float(overlaps[0])
1485
-
1486
- start_coord = int(coords[s_idx]) if coords_are_ints else int(s_idx)
1487
- row_entries.append((start_coord, int(length_val), float(prob_val)))
1488
-
1489
- merged_obs_col[r] = row_entries
1490
- merged_counts[r] = len(row_entries)
1491
-
1492
- # write merged layers (do not overwrite originals unless overwrite=True was set above)
1493
- adata.layers[merged_bin_name] = merged_bin
1494
- adata.layers[merged_len_name] = merged_len
1495
-
1496
- if update_obs:
1497
- adata.obs[merged_bin_name] = merged_obs_col
1498
- adata.obs[f"n_{merged_bin_name}"] = merged_counts
1499
-
1500
- # recompute distances list per-row (gaps between adjacent merged intervals)
1501
- def _calc_distances(obs_list):
1502
- out = []
1503
- for intervals in obs_list:
1504
- if not intervals:
1505
- out.append([])
1506
- continue
1507
- iv = sorted(intervals, key=lambda x: int(x[0]))
1508
- if len(iv) <= 1:
1509
- out.append([])
1510
- continue
1511
- dlist = []
1512
- for i in range(len(iv) - 1):
1513
- endi = int(iv[i][0]) + int(iv[i][1]) - 1
1514
- startn = int(iv[i + 1][0])
1515
- dlist.append(startn - endi - 1)
1516
- out.append(dlist)
1517
- return out
1518
-
1519
- adata.obs[f"{merged_bin_name}_distances"] = _calc_distances(merged_obs_col)
1520
-
1521
- # update uns appended list
1522
- uns_key = "hmm_appended_layers"
1523
- existing = list(adata.uns.get(uns_key, [])) if adata.uns.get(uns_key, None) is not None else []
1524
- for nm in (merged_bin_name, merged_len_name):
1525
- if nm not in existing:
1526
- existing.append(nm)
1527
- adata.uns[uns_key] = existing
1528
-
1529
- if verbose:
1530
- print(f"Created merged binary layer: {merged_bin_name}")
1531
- print(f"Created merged length layer: {merged_len_name}")
1532
- if update_obs:
1533
- print(f"Updated adata.obs columns: {merged_bin_name}, n_{merged_bin_name}, {merged_bin_name}_distances")
1534
-
1535
- return None if inplace else adata
1536
-
1537
- def _ensure_final_layer_and_assign(self, final_adata, layer_name: str, subset_idx_mask: np.ndarray, sub_data):
1904
+ Args:
1905
+ cfg: Configuration mapping or object.
1906
+ override: Override values to apply.
1907
+ device: Optional device specifier.
1908
+
1909
+ Returns:
1910
+ Initialized DistanceBinnedSingleBernoulliHMM instance.
1911
+ """
1912
+ merged = cls._cfg_to_dict(cfg)
1913
+ if override:
1914
+ merged.update(override)
1915
+
1916
+ n_states = int(merged.get("hmm_n_states", 2))
1917
+ eps = float(merged.get("hmm_eps", 1e-8))
1918
+ dtype = _resolve_dtype(merged.get("hmm_dtype", None))
1919
+ dtype = _coerce_dtype_for_device(dtype, device) # coerce dtype to one supported on the target device
1920
+ init_em = merged.get("hmm_init_emission_probs", None)
1921
+
1922
+ bins = merged.get("hmm_distance_bins", [1, 5, 10, 25, 50, 100])
1923
+ init_tb = merged.get("hmm_init_transitions_by_bin", None)
1924
+
1925
+ model = cls(
1926
+ n_states=n_states,
1927
+ init_emission=init_em,
1928
+ distance_bins=bins,
1929
+ init_trans_by_bin=init_tb,
1930
+ eps=eps,
1931
+ dtype=dtype,
1932
+ )
1933
+ if device is not None:
1934
+ model.to(torch.device(device) if isinstance(device, str) else device)
1935
+ model._persisted_cfg = merged
1936
+ return model
1937
+
1938
+ def _ensure_device_dtype(self, device=None) -> torch.device:
1939
+ """Move transition-by-bin parameters to the requested device/dtype."""
1940
+ device = super()._ensure_device_dtype(device)
1941
+ self.trans_by_bin.data = self.trans_by_bin.data.to(device=device, dtype=self.dtype)
1942
+ return device
1943
+
1944
+ def _normalize_trans_by_bin(self):
1945
+ """Normalize transition matrices per distance bin in-place."""
1946
+ with torch.no_grad():
1947
+ tb = self.trans_by_bin.data.reshape(self.n_bins, self.n_states, self.n_states)
1948
+ tb = tb + self.eps
1949
+ rs = tb.sum(dim=2, keepdim=True)
1950
+ rs[rs == 0.0] = 1.0
1951
+ self.trans_by_bin.data = tb / rs
1952
+
1953
+ def _extra_save_payload(self) -> dict:
1954
+ """Return extra payload data for serialization."""
1955
+ p = super()._extra_save_payload()
1956
+ p.update(
1957
+ {
1958
+ "distance_bins": torch.tensor(self.distance_bins, dtype=torch.long),
1959
+ "trans_by_bin": self.trans_by_bin.detach().cpu(),
1960
+ }
1961
+ )
1962
+ return p
1963
+
1964
+ def _load_extra_payload(self, payload: dict, *, device: torch.device):
1965
+ """Load serialized distance-bin parameters.
1966
+
1967
+ Args:
1968
+ payload: Serialized payload dictionary.
1969
+ device: Target torch device.
1538
1970
  """
1539
- Ensure final_adata.layers[layer_name] exists and assign rows corresponding to subset_idx_mask.
1540
- sub_data has shape (n_subset_rows, n_vars).
1541
- subset_idx_mask: boolean array of length final_adata.n_obs with True where rows belong to subset.
1971
+ super()._load_extra_payload(payload, device=device)
1972
+ self.distance_bins = (
1973
+ payload.get("distance_bins", torch.tensor([1, 5, 10, 25, 50, 100]))
1974
+ .cpu()
1975
+ .numpy()
1976
+ .astype(int)
1977
+ )
1978
+ self.n_bins = int(len(self.distance_bins) + 1)
1979
+ with torch.no_grad():
1980
+ self.trans_by_bin.data = payload["trans_by_bin"].to(device=device, dtype=self.dtype)
1981
+ self._normalize_trans_by_bin()
1982
+
1983
+ def _bin_index(self, coords: np.ndarray) -> np.ndarray:
1984
+ """Return per-step distance bin indices for coordinates.
1985
+
1986
+ Args:
1987
+ coords: Coordinate array.
1988
+
1989
+ Returns:
1990
+ Array of bin indices (length L-1).
1542
1991
  """
1543
- from scipy.sparse import issparse, csr_matrix
1544
- import warnings
1992
+ d = np.diff(np.asarray(coords, dtype=int))
1993
+ return np.digitize(d, self.distance_bins, right=True) # length L-1
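A quick worked example of the binning performed above, using the default edges (the gap values are arbitrary):

import numpy as np

distance_bins = np.array([1, 5, 10, 25, 50, 100])    # edges in bp
coords = np.array([100, 101, 103, 150, 400])
gaps = np.diff(coords)                                # [  1   2  47 250]
bins = np.digitize(gaps, distance_bins, right=True)   # [0 1 4 6] -> one transition matrix per bin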
1545
1994
 
1546
- n_final_obs, n_vars = final_adata.shape
1547
- n_sub_rows = int(subset_idx_mask.sum())
1548
-
1549
- # prepare row indices in final_adata
1550
- final_row_indices = np.nonzero(subset_idx_mask)[0]
1551
-
1552
- # if sub_data is sparse, work with sparse
1553
- if issparse(sub_data):
1554
- sub_csr = sub_data.tocsr()
1555
- # if final layer not present, create sparse CSR with zero rows and same n_vars
1556
- if layer_name not in final_adata.layers:
1557
- # create an empty CSR of shape (n_final_obs, n_vars)
1558
- final_adata.layers[layer_name] = csr_matrix((n_final_obs, n_vars), dtype=sub_csr.dtype)
1559
- final_csr = final_adata.layers[layer_name]
1560
- if not issparse(final_csr):
1561
- # convert dense final to sparse first
1562
- final_csr = csr_matrix(final_csr)
1563
- # replace the block of rows: easiest is to build a new csr by stacking pieces
1564
- # (efficient for moderate sizes; for huge data you might want an in-place approach)
1565
- # Build list of blocks: rows before, the subset rows (from final where mask False -> zeros), rows after
1566
- # We'll convert final to LIL for row assignment (mutable), then back to CSR.
1567
- final_lil = final_csr.tolil()
1568
- for i_local, r in enumerate(final_row_indices):
1569
- final_lil.rows[r] = sub_csr.getrow(i_local).indices.tolist()
1570
- final_lil.data[r] = sub_csr.getrow(i_local).data.tolist()
1571
- final_csr = final_lil.tocsr()
1572
- final_adata.layers[layer_name] = final_csr
1573
- else:
1574
- # dense numpy array
1575
- sub_arr = np.asarray(sub_data)
1576
- if sub_arr.shape[0] != n_sub_rows:
1577
- raise ValueError(f"Sub data rows ({sub_arr.shape[0]}) != mask selected rows ({n_sub_rows})")
1578
- if layer_name not in final_adata.layers:
1579
- # create zero array with small dtype
1580
- final_adata.layers[layer_name] = np.zeros((n_final_obs, n_vars), dtype=sub_arr.dtype)
1581
- final_arr = final_adata.layers[layer_name]
1582
- if issparse(final_arr):
1583
- # convert sparse final to dense (or convert sub to sparse); we'll convert final to dense here
1584
- final_arr = final_arr.toarray()
1585
- # assign
1586
- final_arr[final_row_indices, :] = sub_arr
1587
- final_adata.layers[layer_name] = final_arr
1995
+ def _forward_backward(
1996
+ self, obs: torch.Tensor, mask: torch.Tensor, *, coords: Optional[np.ndarray] = None
1997
+ ) -> torch.Tensor:
1998
+ """Run forward-backward using distance-binned transitions.
1999
+
2000
+ Args:
2001
+ obs: Observation tensor.
2002
+ mask: Observation mask.
2003
+ coords: Coordinate array.
2004
+
2005
+ Returns:
2006
+ Posterior probabilities (gamma).
2007
+ """
2008
+ if coords is None:
2009
+ raise ValueError("Distance-binned HMM requires coords.")
2010
+ device = obs.device
2011
+ eps = float(self.eps)
2012
+ K = self.n_states
2013
+
2014
+ coords = np.asarray(coords, dtype=int)
2015
+ bins = torch.tensor(self._bin_index(coords), dtype=torch.long, device=device) # (L-1,)
2016
+
2017
+ logB = self._log_emission(obs, mask) # (N,L,K)
2018
+ logstart = torch.log(self.start + eps)
2019
+ logA_by_bin = torch.log(self.trans_by_bin + eps) # (nb,K,K)
2020
+
2021
+ N, L, _ = logB.shape
2022
+ alpha = torch.empty((N, L, K), dtype=self.dtype, device=device)
2023
+ alpha[:, 0, :] = logstart.unsqueeze(0) + logB[:, 0, :]
2024
+
2025
+ for t in range(1, L):
2026
+ b = int(bins[t - 1].item()) if (t - 1) < bins.numel() else 0
2027
+ logA = logA_by_bin[b]
2028
+ prev = alpha[:, t - 1, :].unsqueeze(2) + logA.unsqueeze(0)
2029
+ alpha[:, t, :] = _logsumexp(prev, dim=1) + logB[:, t, :]
2030
+
2031
+ beta = torch.empty((N, L, K), dtype=self.dtype, device=device)
2032
+ beta[:, L - 1, :] = 0.0
2033
+ for t in range(L - 2, -1, -1):
2034
+ b = int(bins[t].item()) if t < bins.numel() else 0
2035
+ logA = logA_by_bin[b]
2036
+ temp = logA.unsqueeze(0) + (logB[:, t + 1, :] + beta[:, t + 1, :]).unsqueeze(1)
2037
+ beta[:, t, :] = _logsumexp(temp, dim=2)
2038
+
2039
+ log_gamma = alpha + beta
2040
+ logZ = _logsumexp(log_gamma, dim=2).unsqueeze(2)
2041
+ return (log_gamma - logZ).exp()
2042
+
2043
+ def _viterbi(
2044
+ self, obs: torch.Tensor, mask: torch.Tensor, *, coords: Optional[np.ndarray] = None
2045
+ ) -> torch.Tensor:
2046
+ """Run Viterbi decoding using distance-binned transitions.
2047
+
2048
+ Args:
2049
+ obs: Observation tensor.
2050
+ mask: Observation mask.
2051
+ coords: Coordinate array.
2052
+
2053
+ Returns:
2054
+ Decoded state sequence tensor.
2055
+ """
2056
+ if coords is None:
2057
+ raise ValueError("Distance-binned HMM requires coords.")
2058
+ device = obs.device
2059
+ eps = float(self.eps)
2060
+ K = self.n_states
2061
+
2062
+ coords = np.asarray(coords, dtype=int)
2063
+ bins = torch.tensor(self._bin_index(coords), dtype=torch.long, device=device) # (L-1,)
2064
+
2065
+ logB = self._log_emission(obs, mask)
2066
+ logstart = torch.log(self.start + eps)
2067
+ logA_by_bin = torch.log(self.trans_by_bin + eps)
2068
+
2069
+ N, L, _ = logB.shape
2070
+ delta = torch.empty((N, L, K), dtype=self.dtype, device=device)
2071
+ psi = torch.empty((N, L, K), dtype=torch.long, device=device)
2072
+
2073
+ delta[:, 0, :] = logstart.unsqueeze(0) + logB[:, 0, :]
2074
+ psi[:, 0, :] = -1
2075
+
2076
+ for t in range(1, L):
2077
+ b = int(bins[t - 1].item()) if (t - 1) < bins.numel() else 0
2078
+ logA = logA_by_bin[b]
2079
+ cand = delta[:, t - 1, :].unsqueeze(2) + logA.unsqueeze(0)
2080
+ best_val, best_idx = cand.max(dim=1)
2081
+ delta[:, t, :] = best_val + logB[:, t, :]
2082
+ psi[:, t, :] = best_idx
2083
+
2084
+ last_state = torch.argmax(delta[:, L - 1, :], dim=1)
2085
+ states = torch.empty((N, L), dtype=torch.long, device=device)
2086
+ states[:, L - 1] = last_state
2087
+ for t in range(L - 2, -1, -1):
2088
+ states[:, t] = psi[torch.arange(N, device=device), t + 1, states[:, t + 1]]
2089
+ return states
2090
+
2091
+ def fit_em(
2092
+ self,
2093
+ X: np.ndarray,
2094
+ coords: np.ndarray,
2095
+ *,
2096
+ device: torch.device,
2097
+ max_iter: int,
2098
+ tol: float,
2099
+ update_start: bool,
2100
+ update_trans: bool,
2101
+ update_emission: bool,
2102
+ verbose: bool,
2103
+ **kwargs,
2104
+ ) -> List[float]:
2105
+ """Run EM updates for distance-binned transitions.
2106
+
2107
+ Args:
2108
+ X: Observations array (N, L).
2109
+ coords: Coordinate array aligned to X.
2110
+ device: Torch device.
2111
+ max_iter: Maximum iterations.
2112
+ tol: Convergence tolerance.
2113
+ update_start: Whether to update start probabilities.
2114
+ update_trans: Whether to update transitions.
2115
+ update_emission: Whether to update emission parameters.
2116
+ verbose: Whether to log progress.
2117
+ **kwargs: Additional implementation-specific kwargs.
2118
+
2119
+ Returns:
2120
+ List of log-likelihood proxy values.
2121
+ """
2122
+ # Keep this simple: use gamma for emissions; transitions-by-bin updated via xi (same pattern).
2123
+ X = np.asarray(X, dtype=float)
2124
+ if X.ndim != 2:
2125
+ raise ValueError("DistanceBinnedSingleBernoulliHMM expects X shape (N,L).")
2126
+
2127
+ coords = np.asarray(coords, dtype=int)
2128
+ bins_np = self._bin_index(coords) # (L-1,)
2129
+
2130
+ obs = torch.tensor(np.nan_to_num(X, nan=0.0), dtype=self.dtype, device=device)
2131
+ mask = torch.tensor(~np.isnan(X), dtype=torch.bool, device=device)
2132
+
2133
+ eps = float(self.eps)
2134
+ K = self.n_states
2135
+ N, L = obs.shape
2136
+
2137
+ hist: List[float] = []
2138
+ for it in range(1, int(max_iter) + 1):
2139
+ gamma = self._forward_backward(obs, mask, coords=coords) # (N,L,K)
2140
+ ll_proxy = float(torch.sum(torch.log(torch.clamp(gamma.sum(dim=2), min=eps))).item())
2141
+ hist.append(ll_proxy)
2142
+
2143
+ # expected start
2144
+ start_acc = gamma[:, 0, :].sum(dim=0)
2145
+
2146
+ # compute alpha/beta for xi
2147
+ logB = self._log_emission(obs, mask)
2148
+ logstart = torch.log(self.start + eps)
2149
+ logA_by_bin = torch.log(self.trans_by_bin + eps)
2150
+
2151
+ alpha = torch.empty((N, L, K), dtype=self.dtype, device=device)
2152
+ alpha[:, 0, :] = logstart.unsqueeze(0) + logB[:, 0, :]
2153
+
2154
+ for t in range(1, L):
2155
+ b = int(bins_np[t - 1])
2156
+ logA = logA_by_bin[b]
2157
+ prev = alpha[:, t - 1, :].unsqueeze(2) + logA.unsqueeze(0)
2158
+ alpha[:, t, :] = _logsumexp(prev, dim=1) + logB[:, t, :]
2159
+
2160
+ beta = torch.empty((N, L, K), dtype=self.dtype, device=device)
2161
+ beta[:, L - 1, :] = 0.0
2162
+ for t in range(L - 2, -1, -1):
2163
+ b = int(bins_np[t])
2164
+ logA = logA_by_bin[b]
2165
+ temp = logA.unsqueeze(0) + (logB[:, t + 1, :] + beta[:, t + 1, :]).unsqueeze(1)
2166
+ beta[:, t, :] = _logsumexp(temp, dim=2)
2167
+
2168
+ trans_acc_by_bin = torch.zeros((self.n_bins, K, K), dtype=self.dtype, device=device)
2169
+ for t in range(L - 1):
2170
+ b = int(bins_np[t])
2171
+ logA = logA_by_bin[b]
2172
+ valid_t = (mask[:, t] & mask[:, t + 1]).float().view(N, 1, 1)
2173
+ log_xi = (
2174
+ alpha[:, t, :].unsqueeze(2)
2175
+ + logA.unsqueeze(0)
2176
+ + (logB[:, t + 1, :] + beta[:, t + 1, :]).unsqueeze(1)
2177
+ )
2178
+ log_norm = _logsumexp(log_xi.view(N, -1), dim=1).view(N, 1, 1)
2179
+ xi = (log_xi - log_norm).exp() * valid_t
2180
+ trans_acc_by_bin[b] += xi.sum(dim=0)
2181
+
2182
+ mask_f = mask.float().unsqueeze(-1)
2183
+ emit_num = (gamma * obs.unsqueeze(-1) * mask_f).sum(dim=(0, 1))
2184
+ emit_den = (gamma * mask_f).sum(dim=(0, 1))
2185
+
2186
+ with torch.no_grad():
2187
+ if update_start:
2188
+ new_start = start_acc + eps
2189
+ self.start.data = new_start / new_start.sum()
2190
+
2191
+ if update_trans:
2192
+ tb = trans_acc_by_bin + eps
2193
+ rs = tb.sum(dim=2, keepdim=True)
2194
+ rs[rs == 0.0] = 1.0
2195
+ self.trans_by_bin.data = tb / rs
2196
+
2197
+ if update_emission:
2198
+ new_em = (emit_num + eps) / (emit_den + 2.0 * eps)
2199
+ self.emission.data = new_em.clamp(min=eps, max=1.0 - eps)
2200
+
2201
+ self._normalize_params()
2202
+ self._normalize_emission()
2203
+ self._normalize_trans_by_bin()
2204
+
2205
+ if verbose:
2206
+ logger.info(
2207
+ "[DistanceBinnedSingle.fit] iter=%s ll_proxy=%.6f",
2208
+ it,
2209
+ hist[-1],
2210
+ )
2211
+
2212
+ if len(hist) > 1 and abs(hist[-1] - hist[-2]) < float(tol):
2213
+ break
2214
+
2215
+ return hist
2216
+
2217
+ def adapt_emissions(
2218
+ self,
2219
+ X: np.ndarray,
2220
+ coords: Optional[np.ndarray] = None,
2221
+ *,
2222
+ device: Optional[Union[str, torch.device]] = None,
2223
+ iters: Optional[int] = None,
2224
+ max_iter: Optional[int] = None,
2225
+ verbose: bool = False,
2226
+ **kwargs,
2227
+ ):
2228
+ """Adapt emissions with legacy parameter names.
2229
+
2230
+ Args:
2231
+ X: Observations array.
2232
+ coords: Optional coordinate array.
2233
+ device: Device specifier.
2234
+ iters: Number of iterations.
2235
+ max_iter: Alias for iters.
2236
+ verbose: Whether to log progress.
2237
+ **kwargs: Additional kwargs forwarded to BaseHMM.adapt_emissions.
2238
+
2239
+ Returns:
2240
+ List of log-likelihood values.
2241
+ """
2242
+ if iters is None:
2243
+ iters = int(max_iter) if max_iter is not None else int(kwargs.pop("iters", 5))
2244
+ return super().adapt_emissions(
2245
+ np.asarray(X, dtype=float),
2246
+ coords if coords is not None else None,
2247
+ iters=int(iters),
2248
+ device=device,
2249
+ verbose=verbose,
2250
+ )
2251
+
2252
+
2253
+ # =============================================================================
2254
+ # Facade class to match workflow import style
2255
+ # =============================================================================
2256
+
2257
+
2258
+ class HMM:
2259
+ """
2260
+ Facade so workflow can do:
2261
+ from ..hmm.HMM import HMM
2262
+ hmm = HMM.from_config(cfg, arch="single")
2263
+ hmm.save(...)
2264
+ hmm = HMM.load(...)
2265
+ """
2266
+
2267
+ @staticmethod
2268
+ def from_config(cfg, arch: Optional[str] = None, **kwargs) -> BaseHMM:
2269
+ """Create an HMM instance from configuration.
2270
+
2271
+ Args:
2272
+ cfg: Configuration mapping or object.
2273
+ arch: Optional HMM architecture name.
2274
+ **kwargs: Additional parameters passed to the factory.
2275
+
2276
+ Returns:
2277
+ Initialized HMM instance.
2278
+ """
2279
+ return create_hmm(cfg, arch=arch, **kwargs)
2280
+
2281
+ @staticmethod
2282
+ def load(path: Union[str, Path], device: Optional[Union[str, torch.device]] = None) -> BaseHMM:
2283
+ """Load an HMM instance from disk.
2284
+
2285
+ Args:
2286
+ path: Path to the serialized model.
2287
+ device: Optional device specifier.
2288
+
2289
+ Returns:
2290
+ Loaded HMM instance.
2291
+ """
2292
+ return BaseHMM.load(path, device=device)
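A minimal round-trip sketch of the facade usage shown in its docstring; the config keys and file path are placeholders, and save's exact signature lives in BaseHMM rather than in this section:

cfg = {"hmm_n_states": 2, "hmm_init_emission_probs": [0.9, 0.1]}   # placeholder config

hmm = HMM.from_config(cfg, arch="single")     # dispatches to the registered "single" architecture
hmm.save("hmm_model.pt")                      # serialization handled by BaseHMM
restored = HMM.load("hmm_model.pt", device="cpu")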