smftools 0.2.3__py3-none-any.whl → 0.2.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (137)
  1. smftools/__init__.py +6 -8
  2. smftools/_settings.py +4 -6
  3. smftools/_version.py +1 -1
  4. smftools/cli/helpers.py +54 -0
  5. smftools/cli/hmm_adata.py +937 -256
  6. smftools/cli/load_adata.py +448 -268
  7. smftools/cli/preprocess_adata.py +469 -263
  8. smftools/cli/spatial_adata.py +536 -319
  9. smftools/cli_entry.py +97 -182
  10. smftools/config/__init__.py +1 -1
  11. smftools/config/conversion.yaml +17 -6
  12. smftools/config/deaminase.yaml +12 -10
  13. smftools/config/default.yaml +142 -33
  14. smftools/config/direct.yaml +11 -3
  15. smftools/config/discover_input_files.py +19 -5
  16. smftools/config/experiment_config.py +594 -264
  17. smftools/constants.py +37 -0
  18. smftools/datasets/__init__.py +2 -8
  19. smftools/datasets/datasets.py +32 -18
  20. smftools/hmm/HMM.py +2128 -1418
  21. smftools/hmm/__init__.py +2 -9
  22. smftools/hmm/archived/call_hmm_peaks.py +121 -0
  23. smftools/hmm/call_hmm_peaks.py +299 -91
  24. smftools/hmm/display_hmm.py +19 -6
  25. smftools/hmm/hmm_readwrite.py +13 -4
  26. smftools/hmm/nucleosome_hmm_refinement.py +102 -14
  27. smftools/informatics/__init__.py +30 -7
  28. smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +14 -1
  29. smftools/informatics/archived/helpers/archived/bam_qc.py +14 -1
  30. smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +8 -1
  31. smftools/informatics/archived/helpers/archived/load_adata.py +3 -3
  32. smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +3 -1
  33. smftools/informatics/archived/print_bam_query_seq.py +7 -1
  34. smftools/informatics/bam_functions.py +397 -175
  35. smftools/informatics/basecalling.py +51 -9
  36. smftools/informatics/bed_functions.py +90 -57
  37. smftools/informatics/binarize_converted_base_identities.py +18 -7
  38. smftools/informatics/complement_base_list.py +7 -6
  39. smftools/informatics/converted_BAM_to_adata.py +265 -122
  40. smftools/informatics/fasta_functions.py +161 -83
  41. smftools/informatics/h5ad_functions.py +196 -30
  42. smftools/informatics/modkit_extract_to_adata.py +609 -270
  43. smftools/informatics/modkit_functions.py +85 -44
  44. smftools/informatics/ohe.py +44 -21
  45. smftools/informatics/pod5_functions.py +112 -73
  46. smftools/informatics/run_multiqc.py +20 -14
  47. smftools/logging_utils.py +51 -0
  48. smftools/machine_learning/__init__.py +2 -7
  49. smftools/machine_learning/data/anndata_data_module.py +143 -50
  50. smftools/machine_learning/data/preprocessing.py +2 -1
  51. smftools/machine_learning/evaluation/__init__.py +1 -1
  52. smftools/machine_learning/evaluation/eval_utils.py +11 -14
  53. smftools/machine_learning/evaluation/evaluators.py +46 -33
  54. smftools/machine_learning/inference/__init__.py +1 -1
  55. smftools/machine_learning/inference/inference_utils.py +7 -4
  56. smftools/machine_learning/inference/lightning_inference.py +9 -13
  57. smftools/machine_learning/inference/sklearn_inference.py +6 -8
  58. smftools/machine_learning/inference/sliding_window_inference.py +35 -25
  59. smftools/machine_learning/models/__init__.py +10 -5
  60. smftools/machine_learning/models/base.py +28 -42
  61. smftools/machine_learning/models/cnn.py +15 -11
  62. smftools/machine_learning/models/lightning_base.py +71 -40
  63. smftools/machine_learning/models/mlp.py +13 -4
  64. smftools/machine_learning/models/positional.py +3 -2
  65. smftools/machine_learning/models/rnn.py +3 -2
  66. smftools/machine_learning/models/sklearn_models.py +39 -22
  67. smftools/machine_learning/models/transformer.py +68 -53
  68. smftools/machine_learning/models/wrappers.py +2 -1
  69. smftools/machine_learning/training/__init__.py +2 -2
  70. smftools/machine_learning/training/train_lightning_model.py +29 -20
  71. smftools/machine_learning/training/train_sklearn_model.py +9 -15
  72. smftools/machine_learning/utils/__init__.py +1 -1
  73. smftools/machine_learning/utils/device.py +7 -4
  74. smftools/machine_learning/utils/grl.py +3 -1
  75. smftools/metadata.py +443 -0
  76. smftools/plotting/__init__.py +19 -5
  77. smftools/plotting/autocorrelation_plotting.py +145 -44
  78. smftools/plotting/classifiers.py +162 -72
  79. smftools/plotting/general_plotting.py +422 -197
  80. smftools/plotting/hmm_plotting.py +42 -13
  81. smftools/plotting/position_stats.py +147 -87
  82. smftools/plotting/qc_plotting.py +20 -12
  83. smftools/preprocessing/__init__.py +10 -12
  84. smftools/preprocessing/append_base_context.py +115 -80
  85. smftools/preprocessing/append_binary_layer_by_base_context.py +77 -39
  86. smftools/preprocessing/{calculate_complexity.py → archived/calculate_complexity.py} +3 -1
  87. smftools/preprocessing/{archives → archived}/preprocessing.py +8 -6
  88. smftools/preprocessing/binarize.py +21 -4
  89. smftools/preprocessing/binarize_on_Youden.py +129 -31
  90. smftools/preprocessing/binary_layers_to_ohe.py +17 -11
  91. smftools/preprocessing/calculate_complexity_II.py +86 -59
  92. smftools/preprocessing/calculate_consensus.py +28 -19
  93. smftools/preprocessing/calculate_coverage.py +50 -25
  94. smftools/preprocessing/calculate_pairwise_differences.py +2 -1
  95. smftools/preprocessing/calculate_pairwise_hamming_distances.py +4 -3
  96. smftools/preprocessing/calculate_position_Youden.py +118 -54
  97. smftools/preprocessing/calculate_read_length_stats.py +52 -23
  98. smftools/preprocessing/calculate_read_modification_stats.py +91 -57
  99. smftools/preprocessing/clean_NaN.py +38 -28
  100. smftools/preprocessing/filter_adata_by_nan_proportion.py +24 -12
  101. smftools/preprocessing/filter_reads_on_length_quality_mapping.py +71 -38
  102. smftools/preprocessing/filter_reads_on_modification_thresholds.py +181 -73
  103. smftools/preprocessing/flag_duplicate_reads.py +689 -272
  104. smftools/preprocessing/invert_adata.py +26 -11
  105. smftools/preprocessing/load_sample_sheet.py +40 -22
  106. smftools/preprocessing/make_dirs.py +8 -3
  107. smftools/preprocessing/min_non_diagonal.py +2 -1
  108. smftools/preprocessing/recipes.py +56 -23
  109. smftools/preprocessing/reindex_references_adata.py +103 -0
  110. smftools/preprocessing/subsample_adata.py +33 -16
  111. smftools/readwrite.py +331 -82
  112. smftools/schema/__init__.py +11 -0
  113. smftools/schema/anndata_schema_v1.yaml +227 -0
  114. smftools/tools/__init__.py +3 -4
  115. smftools/tools/archived/classifiers.py +163 -0
  116. smftools/tools/archived/subset_adata_v1.py +10 -1
  117. smftools/tools/archived/subset_adata_v2.py +12 -1
  118. smftools/tools/calculate_umap.py +54 -15
  119. smftools/tools/cluster_adata_on_methylation.py +115 -46
  120. smftools/tools/general_tools.py +70 -25
  121. smftools/tools/position_stats.py +229 -98
  122. smftools/tools/read_stats.py +50 -29
  123. smftools/tools/spatial_autocorrelation.py +365 -192
  124. smftools/tools/subset_adata.py +23 -21
  125. {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/METADATA +17 -39
  126. smftools-0.2.5.dist-info/RECORD +181 -0
  127. smftools-0.2.3.dist-info/RECORD +0 -173
  128. /smftools/cli/{cli_flows.py → archived/cli_flows.py} +0 -0
  129. /smftools/hmm/{apply_hmm_batched.py → archived/apply_hmm_batched.py} +0 -0
  130. /smftools/hmm/{calculate_distances.py → archived/calculate_distances.py} +0 -0
  131. /smftools/hmm/{train_hmm.py → archived/train_hmm.py} +0 -0
  132. /smftools/preprocessing/{add_read_length_and_mapping_qc.py → archived/add_read_length_and_mapping_qc.py} +0 -0
  133. /smftools/preprocessing/{archives → archived}/mark_duplicates.py +0 -0
  134. /smftools/preprocessing/{archives → archived}/remove_duplicates.py +0 -0
  135. {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/WHEEL +0 -0
  136. {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/entry_points.txt +0 -0
  137. {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/licenses/LICENSE +0 -0
smftools/hmm/HMM.py CHANGED
@@ -1,1576 +1,2286 @@
- import math
- from typing import List, Optional, Tuple, Union, Any, Dict
+ from __future__ import annotations
+
  import ast
  import json
+ from pathlib import Path
+ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

  import numpy as np
- import pandas as pd
  import torch
  import torch.nn as nn
+ from scipy.sparse import issparse

- def _logsumexp(vec: torch.Tensor, dim: int = -1, keepdim: bool = False) -> torch.Tensor:
- return torch.logsumexp(vec, dim=dim, keepdim=keepdim)
+ from smftools.logging_utils import get_logger

- class HMM(nn.Module):
- """
- Vectorized HMM (Bernoulli emissions) implemented in PyTorch.
-
- Methods:
- - fit(data, ...) -> trains params in-place
- - predict(data, ...) -> list of (L, K) posterior marginals (gamma) numpy arrays
- - viterbi(seq, ...) -> (path_list, score)
- - batch_viterbi(data, ...) -> list of (path_list, score)
- - score(seq_or_list, ...) -> float or list of floats
-
- Notes:
- - data: list of sequences (each sequence is iterable of {0,1,np.nan}).
- - impute_strategy: "ignore" (NaN treated as missing), "random" (fill NaNs randomly with 0/1).
- """
+ logger = get_logger(__name__)
+ # =============================================================================
+ # Registry / Factory
+ # =============================================================================

- def __init__(
- self,
- n_states: int = 2,
- init_start: Optional[List[float]] = None,
- init_trans: Optional[List[List[float]]] = None,
- init_emission: Optional[List[float]] = None,
- dtype: torch.dtype = torch.float64,
- eps: float = 1e-8,
- smf_modality: Optional[str] = None,
- ):
- super().__init__()
- if n_states < 2:
- raise ValueError("n_states must be >= 2")
- self.n_states = n_states
- self.eps = float(eps)
- self.dtype = dtype
- self.smf_modality = smf_modality
+ _HMM_REGISTRY: Dict[str, type] = {}

- # initialize params (probabilities)
- if init_start is None:
- start = np.full((n_states,), 1.0 / n_states, dtype=float)
- else:
- start = np.asarray(init_start, dtype=float)
- if init_trans is None:
- trans = np.full((n_states, n_states), 1.0 / n_states, dtype=float)
- else:
- trans = np.asarray(init_trans, dtype=float)
- # --- sanitize init_emission so it's a 1-D list of P(obs==1 | state) ---
- if init_emission is None:
- emission = np.full((n_states,), 0.5, dtype=float)
- else:
- em_arr = np.asarray(init_emission, dtype=float)
- # case: (K,2) -> pick P(obs==1) from second column
- if em_arr.ndim == 2 and em_arr.shape[1] == 2 and em_arr.shape[0] == n_states:
- emission = em_arr[:, 1].astype(float)
- # case: maybe shape (1,K,2) etc. -> try to collapse trailing axis of length 2
- elif em_arr.ndim >= 2 and em_arr.shape[-1] == 2:
- emission = em_arr.reshape(-1, 2)[:n_states, 1].astype(float)
- else:
- emission = em_arr.reshape(-1)[:n_states].astype(float)

- # store as parameters (not trainable via grad; EM updates .data in-place)
- self.start = nn.Parameter(torch.tensor(start, dtype=self.dtype), requires_grad=False)
- self.trans = nn.Parameter(torch.tensor(trans, dtype=self.dtype), requires_grad=False)
- self.emission = nn.Parameter(torch.tensor(emission, dtype=self.dtype), requires_grad=False)
+ def register_hmm(name: str):
+ """Decorator to register an HMM backend under a string key."""

- self._normalize_params()
+ def deco(cls):
+ """Register the provided class in the HMM registry."""
+ _HMM_REGISTRY[name] = cls
+ cls.hmm_name = name
+ return cls

- def _normalize_params(self):
- with torch.no_grad():
- # coerce shapes
- K = self.n_states
- self.start.data = self.start.data.squeeze()
- if self.start.data.numel() != K:
- self.start.data = torch.full((K,), 1.0 / K, dtype=self.dtype)
+ return deco

- self.trans.data = self.trans.data.squeeze()
- if not (self.trans.data.ndim == 2 and self.trans.data.shape == (K, K)):
- if K == 2:
- self.trans.data = torch.tensor([[0.9,0.1],[0.1,0.9]], dtype=self.dtype)
- else:
- self.trans.data = torch.full((K, K), 1.0 / K, dtype=self.dtype)

- self.emission.data = self.emission.data.squeeze()
- if self.emission.data.numel() != K:
- self.emission.data = torch.full((K,), 0.5, dtype=self.dtype)
+ def create_hmm(cfg: Union[dict, Any, None], arch: Optional[str] = None, **kwargs):
+ """
+ Factory: creates an HMM from cfg + arch (override).
+ """
+ key = (
+ arch
+ or getattr(cfg, "hmm_arch", None)
+ or (cfg.get("hmm_arch") if isinstance(cfg, dict) else None)
+ or "single"
+ )
+ if key not in _HMM_REGISTRY:
+ raise KeyError(f"Unknown hmm_arch={key!r}. Known: {sorted(_HMM_REGISTRY.keys())}")
+ return _HMM_REGISTRY[key].from_config(cfg, **kwargs)
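A minimal usage sketch of the registry/factory pattern introduced above (illustrative, not part of the package diff; register_hmm, create_hmm, the "single" default, and the hmm_arch/hmm_n_states keys come from the code shown, while the subclass name is hypothetical):

# Register a concrete backend under a string key, then build it from config.
@register_hmm("single")
class SingleChannelBernoulliHMM(BaseHMM):
    ...  # would implement _log_emission and fit_em as described below

model = create_hmm({"hmm_arch": "single", "hmm_n_states": 2}, device="cpu")
# Unknown architectures fail fast: create_hmm raises KeyError listing registered names.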
+
+
+ # =============================================================================
+ # Small utilities
+ # =============================================================================
+ def _coerce_dtype_for_device(
+ dtype: torch.dtype, device: Optional[Union[str, torch.device]]
+ ) -> torch.dtype:
+ """MPS does not support float64. When targeting MPS, coerce to float32."""
+ dev = torch.device(device) if isinstance(device, str) else device
+ if dev is not None and getattr(dev, "type", None) == "mps" and dtype == torch.float64:
+ return torch.float32
+ return dtype
+
+
+ def _try_json_or_literal(x: Any) -> Any:
+ """Parse a string value as JSON or a Python literal when possible.
+
+ Args:
+ x: Value to parse.
+
+ Returns:
+ The parsed value if possible, otherwise the original value.
+ """
+ if x is None:
+ return None
+ if not isinstance(x, str):
+ return x
+ s = x.strip()
+ if not s:
+ return None
+ try:
+ return json.loads(s)
+ except Exception:
+ pass
+ try:
+ return ast.literal_eval(s)
+ except Exception:
+ return x
+
+
+ def _coerce_bool(x: Any) -> bool:
+ """Coerce a value into a boolean using common truthy strings.
+
+ Args:
+ x: Value to coerce.
+
+ Returns:
+ Boolean interpretation of the input.
+ """
+ if x is None:
+ return False
+ if isinstance(x, bool):
+ return x
+ if isinstance(x, (int, float)):
+ return bool(x)
+ s = str(x).strip().lower()
+ return s in ("1", "true", "t", "yes", "y", "on")

- # now perform smoothing/normalization
- self.start.data = (self.start.data + self.eps)
- self.start.data = self.start.data / self.start.data.sum()

- self.trans.data = (self.trans.data + self.eps)
- row_sums = self.trans.data.sum(dim=1, keepdim=True)
- row_sums[row_sums == 0.0] = 1.0
- self.trans.data = self.trans.data / row_sums
+ def _resolve_dtype(dtype_entry: Any) -> torch.dtype:
+ """Resolve a torch dtype from a config entry.

- self.emission.data = self.emission.data.clamp(min=self.eps, max=1.0 - self.eps)
+ Args:
+ dtype_entry: Config value (string or torch.dtype).

- @staticmethod
- def _resolve_dtype(dtype_entry):
- """Accept torch.dtype, string ('float32'/'float64') or None -> torch.dtype."""
- if dtype_entry is None:
- return torch.float64
- if isinstance(dtype_entry, torch.dtype):
- return dtype_entry
- s = str(dtype_entry).lower()
- if "32" in s:
- return torch.float32
- if "16" in s:
- return torch.float16
+ Returns:
+ Resolved torch dtype.
+ """
+ if dtype_entry is None:
  return torch.float64
+ if isinstance(dtype_entry, torch.dtype):
+ return dtype_entry
+ s = str(dtype_entry).lower()
+ if "16" in s:
+ return torch.float16
+ if "32" in s:
+ return torch.float32
+ return torch.float64

- @classmethod
- def from_config(cls, cfg: Union[dict, "ExperimentConfig", None], *,
- override: Optional[dict] = None,
- device: Optional[Union[str, torch.device]] = None) -> "HMM":
- """
- Construct an HMM using keys from an ExperimentConfig instance or a plain dict.

- cfg may be:
- - an ExperimentConfig (your dataclass instance)
- - a dict (e.g. loader.var_dict or merged defaults)
- - None (uses internal defaults)
+ def _safe_int_coords(var_names) -> Tuple[np.ndarray, bool]:
+ """
+ Try to cast var_names to int coordinates. If not possible,
+ fall back to 0..L-1 index coordinates.
+ """
+ try:
+ coords = np.asarray(var_names, dtype=int)
+ return coords, True
+ except Exception:
+ return np.arange(len(var_names), dtype=int), False

- override: optional dict to override resolved keys (handy for tests).
- device: optional device string or torch.device to move model to.
- """
- # Accept ExperimentConfig dataclass
- if cfg is None:
- merged = {}
- elif hasattr(cfg, "to_dict") and callable(getattr(cfg, "to_dict")):
- merged = dict(cfg.to_dict())
- elif isinstance(cfg, dict):
- merged = dict(cfg)
- else:
- # try attr access as fallback
- try:
- merged = {k: getattr(cfg, k) for k in dir(cfg) if k.startswith("hmm_")}
- except Exception:
- merged = {}

- if override:
- merged.update(override)
+ def _logsumexp(x: torch.Tensor, dim: int) -> torch.Tensor:
+ """Compute log-sum-exp in a numerically stable way.

- # basic resolution with fallback
- n_states = int(merged.get("hmm_n_states", merged.get("n_states", 2)))
- init_start = merged.get("hmm_init_start_probs", merged.get("hmm_init_start", None))
- init_trans = merged.get("hmm_init_transition_probs", merged.get("hmm_init_trans", None))
- init_emission = merged.get("hmm_init_emission_probs", merged.get("hmm_init_emission", None))
- eps = float(merged.get("hmm_eps", merged.get("eps", 1e-8)))
- dtype = cls._resolve_dtype(merged.get("hmm_dtype", merged.get("dtype", None)))
+ Args:
+ x: Input tensor.
+ dim: Dimension to reduce.

- # coerce lists (if present) -> numpy arrays (the HMM constructor already sanitizes)
- def _coerce_np(x):
- if x is None:
- return None
- return np.asarray(x, dtype=float)
+ Returns:
+ Reduced tensor.
+ """
+ return torch.logsumexp(x, dim=dim)

- init_start = _coerce_np(init_start)
- init_trans = _coerce_np(init_trans)
- init_emission = _coerce_np(init_emission)

- model = cls(
- n_states=n_states,
- init_start=init_start,
- init_trans=init_trans,
- init_emission=init_emission,
- dtype=dtype,
- eps=eps,
- smf_modality=merged.get("smf_modality", None),
- )
+ def _ensure_layer_full_shape(adata, name: str, dtype, fill_value=0):
+ """
+ Ensure adata.layers[name] exists with shape (n_obs, n_vars).
+ """
+ if name not in adata.layers:
+ arr = np.full((adata.n_obs, adata.n_vars), fill_value=fill_value, dtype=dtype)
+ adata.layers[name] = arr
+ else:
+ arr = _to_dense_np(adata.layers[name])
+ if arr.shape != (adata.n_obs, adata.n_vars):
+ raise ValueError(
+ f"Layer '{name}' exists but has shape {arr.shape}; expected {(adata.n_obs, adata.n_vars)}"
+ )
+ return adata.layers[name]
+
+
+ def _assign_back_obs(final_adata, sub_adata, cols: List[str]):
+ """
+ Assign obs columns from sub_adata back into final_adata for the matching obs_names.
+ Works for list/object columns too.
+ """
+ idx = final_adata.obs_names.get_indexer(sub_adata.obs_names)
+ if (idx < 0).any():
+ raise ValueError("Some sub_adata.obs_names not found in final_adata.obs_names")

- # move to device if requested
- if device is not None:
- if isinstance(device, str):
- device = torch.device(device)
- model.to(device)
+ for c in cols:
+ final_adata.obs.iloc[idx, final_adata.obs.columns.get_loc(c)] = sub_adata.obs[c].values

- # persist the config to the hmm class
- cls.config = cfg

- return model
+ def _to_dense_np(x):
+ """Convert sparse or array-like input to a dense NumPy array.

- def update_from_config(self, cfg: Union[dict, "ExperimentConfig", None], *,
- override: Optional[dict] = None):
- """
- Update existing model parameters from a config or dict (in-place).
- This will normalize / reinitialize start/trans/emission using same logic as constructor.
- """
- if cfg is None:
- merged = {}
- elif hasattr(cfg, "to_dict") and callable(getattr(cfg, "to_dict")):
- merged = dict(cfg.to_dict())
- elif isinstance(cfg, dict):
- merged = dict(cfg)
- else:
- try:
- merged = {k: getattr(cfg, k) for k in dir(cfg) if k.startswith("hmm_")}
- except Exception:
- merged = {}
+ Args:
+ x: Input array or sparse matrix.

- if override:
- merged.update(override)
+ Returns:
+ Dense NumPy array or None.
+ """
+ if x is None:
+ return None
+ if issparse(x):
+ return x.toarray()
+ return np.asarray(x)

- # extract same keys as from_config
- n_states = int(merged.get("hmm_n_states", self.n_states))
- init_start = merged.get("hmm_init_start_probs", None)
- init_trans = merged.get("hmm_init_transition_probs", None)
- init_emission = merged.get("hmm_init_emission_probs", None)
- eps = merged.get("hmm_eps", None)
- dtype = merged.get("hmm_dtype", None)
-
- # apply dtype/eps if present
- if eps is not None:
- self.eps = float(eps)
- if dtype is not None:
- self.dtype = self._resolve_dtype(dtype)
-
- # if n_states changes we need a fresh re-init (easy approach: reconstruct)
- if int(n_states) != int(self.n_states):
- # rebuild self in-place: create a new model and copy tensors
- new_model = HMM.from_config(merged)
- # copy content
- with torch.no_grad():
- self.n_states = new_model.n_states
- self.eps = new_model.eps
- self.dtype = new_model.dtype
- self.start.data = new_model.start.data.clone().to(self.start.device, dtype=self.dtype)
- self.trans.data = new_model.trans.data.clone().to(self.trans.device, dtype=self.dtype)
- self.emission.data = new_model.emission.data.clone().to(self.emission.device, dtype=self.dtype)
- return

- # else only update provided tensors
- def _to_tensor(obj, shape_expected=None):
- if obj is None:
- return None
- arr = np.asarray(obj, dtype=float)
- if shape_expected is not None:
- try:
- arr = arr.reshape(shape_expected)
- except Exception:
- # try to free-form slice/reshape (keep best-effort)
- arr = np.reshape(arr, shape_expected) if arr.size >= np.prod(shape_expected) else arr
- return torch.tensor(arr, dtype=self.dtype, device=self.start.device)
+ def _ensure_2d_np(x):
+ """Ensure an array is 2D, reshaping 1D inputs.

- with torch.no_grad():
- if init_start is not None:
- t = _to_tensor(init_start, (self.n_states,))
- if t.numel() == self.n_states:
- self.start.data = t.clone()
- if init_trans is not None:
- t = _to_tensor(init_trans, (self.n_states, self.n_states))
- if t.shape == (self.n_states, self.n_states):
- self.trans.data = t.clone()
- if init_emission is not None:
- # attempt to extract P(obs==1) if shaped (K,2)
- arr = np.asarray(init_emission, dtype=float)
- if arr.ndim == 2 and arr.shape[1] == 2 and arr.shape[0] >= self.n_states:
- em = arr[: self.n_states, 1]
- else:
- em = arr.reshape(-1)[: self.n_states]
- t = torch.tensor(em, dtype=self.dtype, device=self.start.device)
- if t.numel() == self.n_states:
- self.emission.data = t.clone()
+ Args:
+ x: Input array-like.

- # finally normalize
- self._normalize_params()
+ Returns:
+ 2D NumPy array.
+ """
+ x = _to_dense_np(x)
+ if x.ndim == 1:
+ x = x.reshape(1, -1)
+ if x.ndim != 2:
+ raise ValueError(f"Expected 2D array; got shape {x.shape}")
+ return x

- def _ensure_device_dtype(self, device: Optional[torch.device]):
- if device is None:
- device = next(self.parameters()).device
- self.start.data = self.start.data.to(device=device, dtype=self.dtype)
- self.trans.data = self.trans.data.to(device=device, dtype=self.dtype)
- self.emission.data = self.emission.data.to(device=device, dtype=self.dtype)
- return device

- @staticmethod
- def _pad_and_mask(
- data: List[List],
- device: torch.device,
- dtype: torch.dtype,
- impute_strategy: str = "ignore",
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ # =============================================================================
+ # Feature-set normalization
+ # =============================================================================
+
+
+ def normalize_hmm_feature_sets(raw: Any) -> Dict[str, Dict[str, Any]]:
+ """
+ Canonical format:
+ {
+ "footprints": {"state": "Non-Modified", "features": {"small_bound_stretch": [0,50], ...}},
+ "accessible": {"state": "Modified", "features": {"all_accessible_features": [0, inf], ...}},
+ ...
+ }
+ Each feature range is [lo, hi) in genomic bp (or index units if coords aren't ints).
+ """
+ parsed = _try_json_or_literal(raw)
+ if not isinstance(parsed, dict):
+ return {}
+
+ def _coerce_bound(v):
+ """Coerce a bound value into a float or sentinel.
+
+ Args:
+ v: Bound value.
+
+ Returns:
+ Float, np.inf, or None.
  """
- Pads sequences to shape (B, L). Returns (obs, mask, lengths)
- - Accepts: list-of-seqs, or 2D ndarray (B, L).
- - If a sequence element is itself an array (per-timestep feature vector),
- collapse the last axis by mean (warns once).
+ if v is None:
+ return None
+ if isinstance(v, (int, float)):
+ return float(v)
+ s = str(v).strip().lower()
+ if s in ("inf", "infty", "np.inf", "infinite"):
+ return np.inf
+ if s in ("none", ""):
+ return None
+ try:
+ return float(v)
+ except Exception:
+ return None
+
+ def _coerce_map(feats):
+ """Coerce feature ranges into (lo, hi) tuples.
+
+ Args:
+ feats: Mapping of feature names to ranges.
+
+ Returns:
+ Mapping of feature names to numeric bounds.
  """
- import warnings
-
- # If somebody passed a 2-D ndarray directly, convert to list-of-rows
- if isinstance(data, np.ndarray) and data.ndim == 2:
- # convert rows -> python lists (scalars per timestep)
- data = data.tolist()
-
- B = len(data)
- lengths = torch.tensor([len(s) for s in data], dtype=torch.long, device=device)
- L = int(lengths.max().item()) if B > 0 else 0
- obs = torch.zeros((B, L), dtype=dtype, device=device)
- mask = torch.zeros((B, L), dtype=torch.bool, device=device)
-
- warned_collapse = False
-
- for i, seq in enumerate(data):
- # seq may be list/ndarray of scalars OR list/ndarray of per-timestep arrays
- arr = np.asarray(seq, dtype=float)
-
- # If arr is shape (L,1,1,...) squeeze trailing singletons
- while arr.ndim > 1 and arr.shape[-1] == 1:
- arr = np.squeeze(arr, axis=-1)
-
- # If arr is still >1D (e.g., (L, F)), collapse the last axis by mean
- if arr.ndim > 1:
- if not warned_collapse:
- warnings.warn(
- "HMM._pad_and_mask: collapsing per-timestep feature axis by mean "
- "(arr had shape {}). If you prefer a different reduction, "
- "preprocess your data.".format(arr.shape),
- stacklevel=2,
- )
- warned_collapse = True
- # collapse features -> scalar per timestep
- arr = np.asarray(arr, dtype=float).mean(axis=-1)
-
- # now arr should be 1D (T,)
- if arr.ndim == 0:
- # single scalar: treat as length-1 sequence
- arr = np.atleast_1d(arr)
-
- nan_mask = np.isnan(arr)
- if impute_strategy == "random" and nan_mask.any():
- arr[nan_mask] = np.random.choice([0, 1], size=nan_mask.sum())
- local_mask = np.ones_like(arr, dtype=bool)
+ out = {}
+ if not isinstance(feats, dict):
+ return out
+ for name, rng in feats.items():
+ if rng is None:
+ out[name] = (0.0, np.inf)
+ continue
+ if isinstance(rng, (list, tuple)) and len(rng) >= 2:
+ lo = _coerce_bound(rng[0])
+ hi = _coerce_bound(rng[1])
+ lo = 0.0 if lo is None else float(lo)
+ hi = np.inf if hi is None else float(hi)
+ out[name] = (lo, hi)
  else:
- local_mask = ~nan_mask
- arr = np.where(local_mask, arr, 0.0)
+ hi = _coerce_bound(rng)
+ hi = np.inf if hi is None else float(hi)
+ out[name] = (0.0, hi)
+ return out
+
+ out: Dict[str, Dict[str, Any]] = {}
+ for group, info in parsed.items():
+ if isinstance(info, dict):
+ feats = _coerce_map(info.get("features", info.get("ranges", {})))
+ state = info.get("state", info.get("label", "Modified"))
+ else:
+ feats = _coerce_map(info)
+ state = "Modified"
+ out[group] = {"features": feats, "state": state}
+ return out

- L_i = arr.shape[0]
- obs[i, :L_i] = torch.tensor(arr, dtype=dtype, device=device)
- mask[i, :L_i] = torch.tensor(local_mask, dtype=torch.bool, device=device)

- return obs, mask, lengths
+ # =============================================================================
+ # BaseHMM: shared decoding + annotation pipeline
+ # =============================================================================

- def _log_emission(self, obs: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
- """
- obs: (B, L)
- mask: (B, L) bool
- returns logB (B, L, K)
- """
- B, L = obs.shape
- p = self.emission # (K,)
- logp = torch.log(p + self.eps)
- log1mp = torch.log1p(-p + self.eps)
- obs_expand = obs.unsqueeze(-1) # (B, L, 1)
- logB = obs_expand * logp.unsqueeze(0).unsqueeze(0) + (1.0 - obs_expand) * log1mp.unsqueeze(0).unsqueeze(0)
- logB = torch.where(mask.unsqueeze(-1), logB, torch.zeros_like(logB))
- return logB

- def fit(
- self,
- data: List[List],
- max_iter: int = 100,
- tol: float = 1e-4,
- impute_strategy: str = "ignore",
- verbose: bool = True,
- return_history: bool = False,
- device: Optional[Union[torch.device, str]] = None,
- ):
- """
- Vectorized Baum-Welch EM across a batch of sequences (padded).
- """
- if device is None:
- device = next(self.parameters()).device
- elif isinstance(device, str):
- device = torch.device(device)
- device = self._ensure_device_dtype(device)
+ class BaseHMM(nn.Module):
+ """
+ BaseHMM responsibilities:
+ - config resolution (from_config)
+ - EM fit wrapper (fit / fit_em)
+ - decoding (gamma / viterbi)
+ - AnnData annotation from provided arrays (X + coords)
+ - save/load registry aware
+ Subclasses implement:
+ - _log_emission(...) -> logB
+ - optional distance-aware transition handling
+ """

- if isinstance(data, np.ndarray):
- if data.ndim == 2:
- # rows are sequences: convert to list of 1D arrays
- data = data.tolist()
- elif data.ndim == 1:
- # single sequence
- data = [data.tolist()]
- else:
- raise ValueError(f"Expected data to be 1D or 2D ndarray; got array with ndim={data.ndim}")
+ def __init__(self, n_states: int = 2, eps: float = 1e-8, dtype: torch.dtype = torch.float64):
+ """Initialize the base HMM with shared parameters.

- obs, mask, lengths = self._pad_and_mask(data, device=device, dtype=self.dtype, impute_strategy=impute_strategy)
- B, L = obs.shape
- K = self.n_states
- eps = float(self.eps)
+ Args:
+ n_states: Number of hidden states.
+ eps: Smoothing epsilon for probabilities.
+ dtype: Torch dtype for parameters.
+ """
+ super().__init__()
+ if n_states < 2:
+ raise ValueError("n_states must be >= 2")
+ self.n_states = int(n_states)
+ self.eps = float(eps)
+ self.dtype = dtype

- if verbose:
- print(f"[HMM.fit] device={device}, batch={B}, max_len={L}, states={K}")
+ # start probs + transitions (shared across backends)
+ start = np.full((self.n_states,), 1.0 / self.n_states, dtype=float)
+ trans = np.full((self.n_states, self.n_states), 1.0 / self.n_states, dtype=float)

- loglik_history = []
+ self.start = nn.Parameter(torch.tensor(start, dtype=self.dtype), requires_grad=False)
+ self.trans = nn.Parameter(torch.tensor(trans, dtype=self.dtype), requires_grad=False)
+ self._normalize_params()

- for it in range(1, max_iter + 1):
- if verbose:
- print(f"[HMM.fit] EM iter {it}")
+ # ------------------------- config -------------------------

- # compute batched emission logs
- logB = self._log_emission(obs, mask) # (B, L, K)
+ @classmethod
+ def from_config(
+ cls, cfg: Union[dict, Any, None], *, override: Optional[dict] = None, device=None
+ ):
+ """Create a model from config with optional overrides.

- # logs for start and transition
- logA = torch.log(self.trans + eps) # (K, K)
- logstart = torch.log(self.start + eps) # (K,)
+ Args:
+ cfg: Configuration mapping or object.
+ override: Override values to apply.
+ device: Device specifier.

- # Forward (batched)
- alpha = torch.empty((B, L, K), dtype=self.dtype, device=device)
- alpha[:, 0, :] = logstart.unsqueeze(0) + logB[:, 0, :] # (B,K)
- for t in range(1, L):
- # prev: (B, i, 1) + (1, i, j) broadcast => (B, i, j)
- prev = alpha[:, t - 1, :].unsqueeze(2) + logA.unsqueeze(0) # (B, K, K)
- alpha[:, t, :] = _logsumexp(prev, dim=1) + logB[:, t, :]
+ Returns:
+ Initialized HMM instance.
+ """
+ merged = cls._cfg_to_dict(cfg)
+ if override:
+ merged.update(override)

- # Backward (batched)
- beta = torch.empty((B, L, K), dtype=self.dtype, device=device)
- beta[:, L - 1, :] = torch.zeros((K,), dtype=self.dtype, device=device).unsqueeze(0).expand(B, K)
- for t in range(L - 2, -1, -1):
- # temp (B, i, j) = logA[i,j] + logB[:,t+1,j] + beta[:,t+1,j]
- temp = logA.unsqueeze(0) + (logB[:, t + 1, :].unsqueeze(1) + beta[:, t + 1, :].unsqueeze(1))
- beta[:, t, :] = _logsumexp(temp, dim=2)
+ n_states = int(merged.get("hmm_n_states", merged.get("n_states", 2)))
+ eps = float(merged.get("hmm_eps", merged.get("eps", 1e-8)))
+ dtype = _resolve_dtype(merged.get("hmm_dtype", merged.get("dtype", None)))
+ dtype = _coerce_dtype_for_device(dtype, device) # <<< NEW

- # sequence log-likelihoods (use last real index)
- last_idx = (lengths - 1).clamp(min=0)
- idx_range = torch.arange(B, device=device)
- final_alpha = alpha[idx_range, last_idx, :] # (B, K)
- seq_loglikes = _logsumexp(final_alpha, dim=1) # (B,)
- total_loglike = float(seq_loglikes.sum().item())
+ model = cls(n_states=n_states, eps=eps, dtype=dtype)
+ if device is not None:
+ model.to(torch.device(device) if isinstance(device, str) else device)
+ model._persisted_cfg = merged
+ return model

- # posterior gamma (B, L, K)
- log_gamma = alpha + beta # (B, L, K)
- logZ_time = _logsumexp(log_gamma, dim=2, keepdim=True) # (B, L, 1)
- gamma = (log_gamma - logZ_time).exp() # (B, L, K)
+ @staticmethod
+ def _cfg_to_dict(cfg: Union[dict, Any, None]) -> dict:
+ """Normalize a config object into a dictionary.

- # accumulators: starts, transitions, emissions
- gamma_start_accum = gamma[:, 0, :].sum(dim=0) # (K,)
+ Args:
+ cfg: Config mapping or object.

- # emission accumulators: sum over observed positions only
- mask_f = mask.unsqueeze(-1) # (B, L, 1)
- emit_num = (gamma * obs.unsqueeze(-1) * mask_f).sum(dim=(0, 1)) # (K,)
- emit_den = (gamma * mask_f).sum(dim=(0, 1)) # (K,)
+ Returns:
+ Dictionary of HMM-related config values.
+ """
+ if cfg is None:
+ return {}
+ if isinstance(cfg, dict):
+ return dict(cfg)
+ if hasattr(cfg, "to_dict") and callable(getattr(cfg, "to_dict")):
+ return dict(cfg.to_dict())
+ out = {}
+ for k in dir(cfg):
+ if k.startswith("hmm_") or k in ("smf_modality", "cpg"):
+ try:
+ out[k] = getattr(cfg, k)
+ except Exception:
+ pass
+ return out

- # transitions: accumulate xi across t for valid positions
- trans_accum = torch.zeros((K, K), dtype=self.dtype, device=device)
- if L >= 2:
- time_idx = torch.arange(L - 1, device=device).unsqueeze(0).expand(B, L - 1) # (B, L-1)
- valid = time_idx < (lengths.unsqueeze(1) - 1) # (B, L-1) bool
- for t in range(L - 1):
- a_t = alpha[:, t, :].unsqueeze(2) # (B, i, 1)
- b_next = (logB[:, t + 1, :] + beta[:, t + 1, :]).unsqueeze(1) # (B, 1, j)
- log_xi_unnorm = a_t + logA.unsqueeze(0) + b_next # (B, i, j)
- log_xi_flat = log_xi_unnorm.view(B, -1) # (B, i*j)
- log_norm = _logsumexp(log_xi_flat, dim=1).unsqueeze(1).unsqueeze(2) # (B,1,1)
- xi = (log_xi_unnorm - log_norm).exp() # (B,i,j)
- valid_t = valid[:, t].float().unsqueeze(1).unsqueeze(2) # (B,1,1)
- xi_masked = xi * valid_t
- trans_accum += xi_masked.sum(dim=0) # (i,j)
-
- # M-step: update parameters with smoothing
- with torch.no_grad():
- new_start = gamma_start_accum + eps
- new_start = new_start / new_start.sum()
+ # ------------------------- params -------------------------

- new_trans = trans_accum + eps
- row_sums = new_trans.sum(dim=1, keepdim=True)
- row_sums[row_sums == 0.0] = 1.0
- new_trans = new_trans / row_sums
+ def _normalize_params(self):
+ """Normalize start and transition probabilities in-place."""
+ with torch.no_grad():
+ K = self.n_states

- new_emission = (emit_num + eps) / (emit_den + 2.0 * eps)
- new_emission = new_emission.clamp(min=eps, max=1.0 - eps)
+ # start
+ self.start.data = self.start.data.reshape(-1)
+ if self.start.data.numel() != K:
+ self.start.data = torch.full(
+ (K,), 1.0 / K, dtype=self.dtype, device=self.start.device
+ )
+ self.start.data = self.start.data + self.eps
+ self.start.data = self.start.data / self.start.data.sum()

- self.start.data = new_start
- self.trans.data = new_trans
- self.emission.data = new_emission
+ # trans
+ self.trans.data = self.trans.data.reshape(K, K)
+ self.trans.data = self.trans.data + self.eps
+ rs = self.trans.data.sum(dim=1, keepdim=True)
+ rs[rs == 0.0] = 1.0
+ self.trans.data = self.trans.data / rs

- loglik_history.append(total_loglike)
- if verbose:
- print(f" total loglik = {total_loglike:.6f}")
+ def _ensure_device_dtype(
+ self, device: Optional[Union[str, torch.device]] = None
+ ) -> torch.device:
+ """Move parameters to the requested device/dtype.

- if len(loglik_history) > 1 and abs(loglik_history[-1] - loglik_history[-2]) < tol:
- if verbose:
- print(f"[HMM.fit] converged (Δll < {tol}) at iter {it}")
- break
+ Args:
+ device: Device specifier or None to use current device.

- return loglik_history if return_history else None
-
- def get_params(self) -> dict:
+ Returns:
+ Resolved torch device.
  """
- Return model parameters as numpy arrays on CPU.
- """
- with torch.no_grad():
- return {
- "n_states": int(self.n_states),
- "start": self.start.detach().cpu().numpy().astype(float).reshape(-1),
- "trans": self.trans.detach().cpu().numpy().astype(float),
- "emission": self.emission.detach().cpu().numpy().astype(float).reshape(-1),
- }
+ if device is None:
+ device = next(self.parameters()).device
+ device = torch.device(device) if isinstance(device, str) else device
+ self.start.data = self.start.data.to(device=device, dtype=self.dtype)
+ self.trans.data = self.trans.data.to(device=device, dtype=self.dtype)
+ return device
+
+ # ------------------------- state labeling -------------------------
+
+ def _state_modified_score(self) -> torch.Tensor:
+ """Subclasses return (K,) score; higher => more “Modified/Accessible”."""
+ raise NotImplementedError

- def print_params(self, decimals: int = 4):
+ def modified_state_index(self) -> int:
+ """Return the index of the most modified/accessible state."""
+ scores = self._state_modified_score()
+ return int(torch.argmax(scores).item())
+
+ def resolve_target_state_index(self, state_target: Any) -> int:
  """
- Nicely print start, transition, and emission probabilities.
+ Accept:
+ - int -> explicit state index
+ - "Modified" / "Non-Modified" and aliases
  """
- params = self.get_params()
- K = params["n_states"]
- fmt = f"{{:.{decimals}f}}"
- print(f"HMM params (K={K} states):")
- print(" start probs:")
- print(" [" + ", ".join(fmt.format(v) for v in params["start"]) + "]")
- print(" transition matrix (rows = from-state, cols = to-state):")
- for i, row in enumerate(params["trans"]):
- print(" s{:d}: [".format(i) + ", ".join(fmt.format(v) for v in row) + "]")
- print(" emission P(obs==1 | state):")
- for i, v in enumerate(params["emission"]):
- print(f" s{i}: {fmt.format(v)}")
-
- def to_dataframes(self) -> dict:
+ if isinstance(state_target, (int, np.integer)):
+ idx = int(state_target)
+ return max(0, min(idx, self.n_states - 1))
+
+ s = str(state_target).strip().lower()
+ if s in ("modified", "open", "accessible", "1", "pos", "positive"):
+ return self.modified_state_index()
+ if s in ("non-modified", "closed", "inaccessible", "0", "neg", "negative"):
+ scores = self._state_modified_score()
+ return int(torch.argmin(scores).item())
+ return self.modified_state_index()
+
+ # ------------------------- emissions -------------------------
+
+ def _log_emission(self, obs: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
  """
- Return pandas DataFrames for start (Series), trans (DataFrame), emission (Series).
+ Return logB:
+ - single: obs (N,L), mask (N,L) -> logB (N,L,K)
+ - multi : obs (N,L,C), mask (N,L,C) -> logB (N,L,K)
  """
- p = self.get_params()
- K = p["n_states"]
- state_names = [f"state_{i}" for i in range(K)]
- start_s = pd.Series(p["start"], index=state_names, name="start_prob")
- trans_df = pd.DataFrame(p["trans"], index=state_names, columns=state_names)
- emission_s = pd.Series(p["emission"], index=state_names, name="p_obs1")
- return {"start": start_s, "trans": trans_df, "emission": emission_s}
-
- def predict(self, data: List[List], impute_strategy: str = "ignore", device: Optional[Union[torch.device, str]] = None) -> List[np.ndarray]:
+ raise NotImplementedError
+
+ # ------------------------- decoding core -------------------------
+
+ def _forward_backward(
+ self,
+ obs: torch.Tensor,
+ mask: torch.Tensor,
+ *,
+ coords: Optional[np.ndarray] = None,
+ ) -> torch.Tensor:
  """
- Return posterior marginals gamma_t(k) for each sequence as list of (L, K) numpy arrays.
+ Returns gamma (N,L,K) in probability space.
+ Subclasses can override for distance-aware transitions.
  """
- if device is None:
- device = next(self.parameters()).device
- elif isinstance(device, str):
- device = torch.device(device)
- device = self._ensure_device_dtype(device)
-
- obs, mask, lengths = self._pad_and_mask(data, device=device, dtype=self.dtype, impute_strategy=impute_strategy)
- B, L = obs.shape
- K = self.n_states
+ device = obs.device
  eps = float(self.eps)
+ K = self.n_states

- logB = self._log_emission(obs, mask) # (B, L, K)
- logA = torch.log(self.trans + eps)
- logstart = torch.log(self.start + eps)
+ logB = self._log_emission(obs, mask) # (N,L,K)
+ logA = torch.log(self.trans + eps) # (K,K)
+ logstart = torch.log(self.start + eps) # (K,)
+
+ N, L, _ = logB.shape

- # Forward
- alpha = torch.empty((B, L, K), dtype=self.dtype, device=device)
+ alpha = torch.empty((N, L, K), dtype=self.dtype, device=device)
  alpha[:, 0, :] = logstart.unsqueeze(0) + logB[:, 0, :]
+
  for t in range(1, L):
- prev = alpha[:, t - 1, :].unsqueeze(2) + logA.unsqueeze(0)
+ prev = alpha[:, t - 1, :].unsqueeze(2) + logA.unsqueeze(0) # (N,K,K)
  alpha[:, t, :] = _logsumexp(prev, dim=1) + logB[:, t, :]

- # Backward
- beta = torch.empty((B, L, K), dtype=self.dtype, device=device)
- beta[:, L - 1, :] = torch.zeros((K,), dtype=self.dtype, device=device).unsqueeze(0).expand(B, K)
+ beta = torch.empty((N, L, K), dtype=self.dtype, device=device)
+ beta[:, L - 1, :] = 0.0
  for t in range(L - 2, -1, -1):
- temp = logA.unsqueeze(0) + (logB[:, t + 1, :].unsqueeze(1) + beta[:, t + 1, :].unsqueeze(1))
+ temp = logA.unsqueeze(0) + (logB[:, t + 1, :] + beta[:, t + 1, :]).unsqueeze(1)
  beta[:, t, :] = _logsumexp(temp, dim=2)

- # gamma
  log_gamma = alpha + beta
- logZ_time = _logsumexp(log_gamma, dim=2, keepdim=True)
- gamma = (log_gamma - logZ_time).exp() # (B, L, K)
+ logZ = _logsumexp(log_gamma, dim=2).unsqueeze(2)
+ gamma = (log_gamma - logZ).exp()
+ return gamma

- results = []
- for i in range(B):
- L_i = int(lengths[i].item())
- results.append(gamma[i, :L_i, :].detach().cpu().numpy())
- return results
-
- def score(self, seq_or_list: Union[List[float], List[List[float]]], impute_strategy: str = "ignore", device: Optional[Union[torch.device, str]] = None) -> Union[float, List[float]]:
+ def _viterbi(
+ self,
+ obs: torch.Tensor,
+ mask: torch.Tensor,
+ *,
+ coords: Optional[np.ndarray] = None,
+ ) -> torch.Tensor:
  """
- Compute log-likelihood of a single sequence or list of sequences under current params.
- Returns float (single) or list of floats (batch).
+ Returns states (N,L) int64. Missing positions (mask False for all channels)
+ are still decoded, but you’ll overwrite them to -1 during writing.
+ Subclasses can override for distance-aware transitions.
  """
- single = False
- if not isinstance(seq_or_list[0], (list, tuple, np.ndarray)):
- seqs = [seq_or_list]
- single = True
- else:
- seqs = seq_or_list
+ device = obs.device
+ eps = float(self.eps)
+ K = self.n_states

- if device is None:
- device = next(self.parameters()).device
- elif isinstance(device, str):
- device = torch.device(device)
- device = self._ensure_device_dtype(device)
+ logB = self._log_emission(obs, mask) # (N,L,K)
+ logA = torch.log(self.trans + eps) # (K,K)
+ logstart = torch.log(self.start + eps) # (K,)

- obs, mask, lengths = self._pad_and_mask(seqs, device=device, dtype=self.dtype, impute_strategy=impute_strategy)
- B, L = obs.shape
- K = self.n_states
- eps = float(self.eps)
+ N, L, _ = logB.shape
+ delta = torch.empty((N, L, K), dtype=self.dtype, device=device)
+ psi = torch.empty((N, L, K), dtype=torch.long, device=device)

- logB = self._log_emission(obs, mask)
- logA = torch.log(self.trans + eps)
- logstart = torch.log(self.start + eps)
+ delta[:, 0, :] = logstart.unsqueeze(0) + logB[:, 0, :]
+ psi[:, 0, :] = -1

- alpha = torch.empty((B, L, K), dtype=self.dtype, device=device)
- alpha[:, 0, :] = logstart.unsqueeze(0) + logB[:, 0, :]
  for t in range(1, L):
- prev = alpha[:, t - 1, :].unsqueeze(2) + logA.unsqueeze(0)
- alpha[:, t, :] = _logsumexp(prev, dim=1) + logB[:, t, :]
+ cand = delta[:, t - 1, :].unsqueeze(2) + logA.unsqueeze(0) # (N,K,K)
+ best_val, best_idx = cand.max(dim=1)
+ delta[:, t, :] = best_val + logB[:, t, :]
+ psi[:, t, :] = best_idx

- last_idx = (lengths - 1).clamp(min=0)
- idx_range = torch.arange(B, device=device)
- final_alpha = alpha[idx_range, last_idx, :] # (B, K)
- seq_loglikes = _logsumexp(final_alpha, dim=1) # (B,)
- seq_loglikes = seq_loglikes.detach().cpu().numpy().tolist()
- return seq_loglikes[0] if single else seq_loglikes
+ last_state = torch.argmax(delta[:, L - 1, :], dim=1) # (N,)
+ states = torch.empty((N, L), dtype=torch.long, device=device)

- def viterbi(self, seq: List[float], impute_strategy: str = "ignore", device: Optional[Union[torch.device, str]] = None) -> Tuple[List[int], float]:
- """
- Viterbi decode a single sequence. Returns (state_path, log_probability_of_path).
- """
- paths, scores = self.batch_viterbi([seq], impute_strategy=impute_strategy, device=device)
- return paths[0], scores[0]
+ states[:, L - 1] = last_state
+ for t in range(L - 2, -1, -1):
+ states[:, t] = psi[torch.arange(N, device=device), t + 1, states[:, t + 1]]
+
+ return states
+
+ def decode(
+ self,
+ X: np.ndarray,
+ coords: Optional[np.ndarray] = None,
+ *,
+ decode: str = "marginal",
+ device: Optional[Union[str, torch.device]] = None,
+ ) -> Tuple[np.ndarray, np.ndarray]:
+ """Decode observations into state calls and posterior probabilities.

- def batch_viterbi(self, data: List[List[float]], impute_strategy: str = "ignore", device: Optional[Union[torch.device, str]] = None) -> Tuple[List[List[int]], List[float]]:
+ Args:
+ X: Observations array (N, L) or (N, L, C).
+ coords: Optional coordinates aligned to L.
+ decode: Decoding strategy ("marginal" or "viterbi").
+ device: Device specifier.
+
+ Returns:
+ Tuple of (states, posterior probabilities).
  """
- Batched Viterbi decoding on padded sequences. Returns (list_of_paths, list_of_scores).
- Each path is the length of the original sequence.
+ device = self._ensure_device_dtype(device)
+
+ X = np.asarray(X, dtype=float)
+ if X.ndim == 2:
+ L = X.shape[1]
+ elif X.ndim == 3:
+ L = X.shape[1]
+ else:
+ raise ValueError(f"X must be 2D or 3D; got shape {X.shape}")
+
+ if coords is None:
+ coords = np.arange(L, dtype=int)
+ coords = np.asarray(coords, dtype=int)
+
+ if X.ndim == 2:
+ obs = torch.tensor(np.nan_to_num(X, nan=0.0), dtype=self.dtype, device=device)
+ mask = torch.tensor(~np.isnan(X), dtype=torch.bool, device=device)
+ else:
+ obs = torch.tensor(np.nan_to_num(X, nan=0.0), dtype=self.dtype, device=device)
+ mask = torch.tensor(~np.isnan(X), dtype=torch.bool, device=device)
+
+ gamma = self._forward_backward(obs, mask, coords=coords)
+
+ if str(decode).lower() == "viterbi":
+ st = self._viterbi(obs, mask, coords=coords)
+ else:
+ st = torch.argmax(gamma, dim=2)
+
+ return st.detach().cpu().numpy(), gamma.detach().cpu().numpy()
+
+ # ------------------------- EM fit -------------------------
+
+ def fit(
+ self,
+ X: np.ndarray,
+ coords: Optional[np.ndarray] = None,
+ *,
+ max_iter: int = 50,
+ tol: float = 1e-4,
+ device: Optional[Union[str, torch.device]] = None,
+ update_start: bool = True,
+ update_trans: bool = True,
+ update_emission: bool = True,
+ verbose: bool = False,
+ **kwargs,
+ ) -> List[float]:
+ """Fit HMM parameters using EM.
+
+ Args:
+ X: Observations array.
+ coords: Optional coordinate array.
+ max_iter: Maximum EM iterations.
+ tol: Convergence tolerance.
+ device: Device specifier.
+ update_start: Whether to update start probabilities.
+ update_trans: Whether to update transition probabilities.
+ update_emission: Whether to update emission parameters.
+ verbose: Whether to log progress.
+ **kwargs: Additional implementation-specific kwargs.
+
+ Returns:
+ List of log-likelihood values across iterations.
  """
- if device is None:
- device = next(self.parameters()).device
- elif isinstance(device, str):
- device = torch.device(device)
+ X = np.asarray(X, dtype=float)
+ if X.ndim not in (2, 3):
+ raise ValueError(f"X must be 2D or 3D; got {X.shape}")
+ L = X.shape[1]
+
+ if coords is None:
+ coords = np.arange(L, dtype=int)
+ coords = np.asarray(coords, dtype=int)
+
  device = self._ensure_device_dtype(device)
+ return self.fit_em(
+ X,
+ coords,
+ device=device,
+ max_iter=max_iter,
+ tol=tol,
+ update_start=update_start,
+ update_trans=update_trans,
+ update_emission=update_emission,
+ verbose=verbose,
+ **kwargs,
+ )

- obs, mask, lengths = self._pad_and_mask(data, device=device, dtype=self.dtype, impute_strategy=impute_strategy)
- B, L = obs.shape
- K = self.n_states
- eps = float(self.eps)
+ def adapt_emissions(
+ self,
+ X: np.ndarray,
+ coords: np.ndarray,
+ *,
+ iters: int = 5,
+ device: Optional[Union[str, torch.device]] = None,
+ freeze_start: bool = True,
+ freeze_trans: bool = True,
+ verbose: bool = False,
+ **kwargs,
+ ) -> List[float]:
+ """Adapt emission parameters while keeping shared structure fixed.
+
+ Args:
+ X: Observations array.
+ coords: Coordinate array aligned to X.
+ iters: Number of EM iterations.
+ device: Device specifier.
+ freeze_start: Whether to freeze start probabilities.
+ freeze_trans: Whether to freeze transitions.
+ verbose: Whether to log progress.
+ **kwargs: Additional implementation-specific kwargs.
+
+ Returns:
+ List of log-likelihood values across iterations.
+ """
+ return self.fit(
+ X,
+ coords,
+ max_iter=int(iters),
+ tol=0.0,
+ device=device,
+ update_start=not freeze_start,
+ update_trans=not freeze_trans,
+ update_emission=True,
+ verbose=verbose,
+ **kwargs,
+ )

- p = self.emission
- logp = torch.log(p + eps)
- log1mp = torch.log1p(-p + eps)
- logB = obs.unsqueeze(-1) * logp.unsqueeze(0).unsqueeze(0) + (1.0 - obs.unsqueeze(-1)) * log1mp.unsqueeze(0).unsqueeze(0)
- logB = torch.where(mask.unsqueeze(-1), logB, torch.zeros_like(logB))
+ def fit_em(
+ self,
+ X: np.ndarray,
+ coords: np.ndarray,
+ *,
+ device: torch.device,
+ max_iter: int,
+ tol: float,
+ update_start: bool,
+ update_trans: bool,
+ update_emission: bool,
+ verbose: bool,
+ **kwargs,
+ ) -> List[float]:
+ """Run the core EM update loop (subclasses implement).
+
+ Args:
+ X: Observations array.
+ coords: Coordinate array aligned to X.
+ device: Torch device.
+ max_iter: Maximum iterations.
+ tol: Convergence tolerance.
+ update_start: Whether to update start probabilities.
+ update_trans: Whether to update transitions.
+ update_emission: Whether to update emission parameters.
+ verbose: Whether to log progress.
+ **kwargs: Additional subclass-specific kwargs.
+
+ Returns:
+ List of log-likelihood values across iterations.
+ """
+ raise NotImplementedError

- logstart = torch.log(self.start + eps)
- logA = torch.log(self.trans + eps)
+ # ------------------------- save/load -------------------------

- # delta (score) and psi (argmax pointers)
- delta = torch.empty((B, L, K), dtype=self.dtype, device=device)
- psi = torch.zeros((B, L, K), dtype=torch.long, device=device)
+ def _extra_save_payload(self) -> dict:
+ """Return extra model state to include when saving."""
+ return {}

- delta[:, 0, :] = logstart.unsqueeze(0) + logB[:, 0, :]
- psi[:, 0, :] = -1 # sentinel
+ def _load_extra_payload(self, payload: dict, *, device: torch.device):
+ """Load extra model state saved by subclasses.

- for t in range(1, L):
- # cand shape (B, i, j)
- cand = delta[:, t - 1, :].unsqueeze(2) + logA.unsqueeze(0) # (B, K, K)
- best_val, best_idx = cand.max(dim=1) # best over previous i: results (B, j)
- delta[:, t, :] = best_val + logB[:, t, :]
- psi[:, t, :] = best_idx # best previous state index for each (B, j)
-
- # backtrack
- last_idx = (lengths - 1).clamp(min=0)
- idx_range = torch.arange(B, device=device)
- final_delta = delta[idx_range, last_idx, :] # (B, K)
- best_last_val, best_last_state = final_delta.max(dim=1) # (B,), (B,)
- paths = []
- scores = []
- for b in range(B):
- Lb = int(lengths[b].item())
- if Lb == 0:
- paths.append([])
- scores.append(float("-inf"))
- continue
- s = int(best_last_state[b].item())
- path = [s]
- for t in range(Lb - 1, 0, -1):
- s = int(psi[b, t, s].item())
- path.append(s)
- path.reverse()
- paths.append(path)
- scores.append(float(best_last_val[b].item()))
- return paths, scores
-
- def save(self, path: str) -> None:
+ Args:
+ payload: Serialized model payload.
+ device: Torch device for tensors.
  """
- Save HMM to `path` using torch.save. Stores:
- - n_states, eps, dtype (string)
- - start, trans, emission (CPU tensors)
+ return
+
+ def save(self, path: Union[str, Path]) -> None:
+ """Serialize the model to disk.
+
+ Args:
+ path: Output path for the serialized model.
  """
+ path = str(path)
  payload = {
+ "hmm_name": getattr(self, "hmm_name", self.__class__.__name__),
+ "class": self.__class__.__name__,
  "n_states": int(self.n_states),
  "eps": float(self.eps),
- # store dtype as a string like "torch.float64" (portable)
  "dtype": str(self.dtype),
  "start": self.start.detach().cpu(),
  "trans": self.trans.detach().cpu(),
- "emission": self.emission.detach().cpu(),
  }
+ payload.update(self._extra_save_payload())
  torch.save(payload, path)

  @classmethod
- def load(cls, path: str, device: Optional[Union[torch.device, str]] = None) -> "HMM":
- """
- Load model from `path`. If `device` is provided (str or torch.device),
- parameters will be moved to that device; otherwise they remain on CPU.
- Example: model = HMM.load('hmm.pt', device='cuda')
+ def load(cls, path: Union[str, Path], device: Optional[Union[str, torch.device]] = None):
+ """Load a serialized model from disk.
+
+ Args:
+ path: Path to the serialized model.
+ device: Optional device specifier.
+
+ Returns:
+ Loaded HMM instance.
  """
- payload = torch.load(path, map_location="cpu")
+ payload = torch.load(str(path), map_location="cpu")
+ hmm_name = payload.get("hmm_name", None)
+ klass = _HMM_REGISTRY.get(hmm_name, cls)
+
+ dtype_str = str(payload.get("dtype", "torch.float64"))
+ torch_dtype = getattr(torch, dtype_str.split(".")[-1], torch.float64)
+ torch_dtype = _coerce_dtype_for_device(torch_dtype, device) # <<< NEW
788
+
789
+ model = klass(
790
+ n_states=int(payload["n_states"]),
791
+ eps=float(payload.get("eps", 1e-8)),
792
+ dtype=torch_dtype,
793
+ )
794
+ dev = torch.device(device) if isinstance(device, str) else (device or torch.device("cpu"))
795
+ model.to(dev)
718
796
 
719
- n_states = int(payload.get("n_states"))
720
- eps = float(payload.get("eps", 1e-8))
721
- dtype_entry = payload.get("dtype", "torch.float64")
797
+ with torch.no_grad():
798
+ model.start.data = payload["start"].to(device=dev, dtype=model.dtype)
799
+ model.trans.data = payload["trans"].to(device=dev, dtype=model.dtype)
722
800
 
723
- # Resolve dtype string robustly:
724
- # Accept "torch.float64" or "float64" or actual torch.dtype (older payloads)
725
- if isinstance(dtype_entry, torch.dtype):
726
- torch_dtype = dtype_entry
727
- else:
728
- # dtype_entry expected to be a string
729
- dtype_str = str(dtype_entry)
730
- # take last part after dot if present: "torch.float64" -> "float64"
731
- name = dtype_str.split(".")[-1]
732
- # map to torch dtype if available, else fallback mapping
733
- if hasattr(torch, name):
734
- torch_dtype = getattr(torch, name)
735
- else:
736
- fallback = {"float64": torch.float64, "float32": torch.float32, "float16": torch.float16}
737
- torch_dtype = fallback.get(name, torch.float64)
801
+ model._load_extra_payload(payload, device=dev)
802
+ model._normalize_params()
803
+ return model
738
804
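A round-trip sketch for the save/load pair above. Because the payload records `hmm_name`, `BaseHMM.load` resolves the concrete registered subclass from `_HMM_REGISTRY`; the checkpoint path is a placeholder.

    model.save("hmm_checkpoint.pt")                    # start/trans plus subclass payload, stored on CPU
    restored = BaseHMM.load("hmm_checkpoint.pt", device="cpu")
    assert restored.n_states == model.n_states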
 
739
- # Build instance (use resolved dtype)
740
- model = cls(n_states=n_states, dtype=torch_dtype, eps=eps)
805
+ # ------------------------- interval helpers -------------------------
741
806
 
742
- # Determine target device
743
- if device is None:
744
- device = torch.device("cpu")
745
- elif isinstance(device, str):
746
- device = torch.device(device)
807
+ @staticmethod
808
+ def _runs_from_bool(mask_1d: np.ndarray) -> List[Tuple[int, int]]:
809
+ """
810
+ Return runs as (start_idx, end_idx_exclusive) for True segments.
811
+ """
812
+ idx = np.nonzero(mask_1d)[0]
813
+ if idx.size == 0:
814
+ return []
815
+ breaks = np.where(np.diff(idx) > 1)[0]
816
+ starts = np.r_[idx[0], idx[breaks + 1]]
817
+ ends = np.r_[idx[breaks] + 1, idx[-1] + 1]
818
+ return list(zip(starts, ends))
747
819
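A quick worked example of the run encoding used by the interval helpers (np and BaseHMM as defined in this module): runs are half-open index intervals over True segments.

    mask = np.array([0, 1, 1, 0, 1], dtype=bool)
    BaseHMM._runs_from_bool(mask)   # -> [(1, 3), (4, 5)]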
 
748
- # Load params (they were saved on CPU) and cast to model dtype/device
749
- with torch.no_grad():
750
- model.start.data = payload["start"].to(device=device, dtype=model.dtype)
751
- model.trans.data = payload["trans"].to(device=device, dtype=model.dtype)
752
- model.emission.data = payload["emission"].to(device=device, dtype=model.dtype)
820
+ @staticmethod
821
+ def _interval_length(coords: np.ndarray, s: int, e: int) -> int:
822
+ """Genomic length for [s,e) on coords."""
823
+ if e <= s:
824
+ return 0
825
+ return int(coords[e - 1]) - int(coords[s]) + 1
753
826
 
754
- # Normalize / coerce shapes just in case
755
- model._normalize_params()
756
- return model
757
-
758
- def annotate_adata(
827
+ @staticmethod
828
+ def _write_lengths_for_binary_layer(bin_mat: np.ndarray) -> np.ndarray:
829
+ """
830
+ For each row, each True-run gets its run-length assigned across that run.
831
+ Output same shape as bin_mat, int32.
832
+ """
833
+ n, L = bin_mat.shape
834
+ out = np.zeros((n, L), dtype=np.int32)
835
+ for i in range(n):
836
+ runs = BaseHMM._runs_from_bool(bin_mat[i].astype(bool))
837
+ for s, e in runs:
838
+ out[i, s:e] = e - s
839
+ return out
840
+
841
+ @staticmethod
842
+ def _write_lengths_for_state_layer(states: np.ndarray) -> np.ndarray:
843
+ """
844
+ For each row, each constant-state run gets run-length assigned across run.
845
+ Missing values should be -1 and will get 0 length.
846
+ """
847
+ n, L = states.shape
848
+ out = np.zeros((n, L), dtype=np.int32)
849
+ for i in range(n):
850
+ row = states[i]
851
+ valid = row >= 0
852
+ if not np.any(valid):
853
+ continue
854
+ # scan runs
855
+ s = 0
856
+ while s < L:
857
+ if row[s] < 0:
858
+ s += 1
859
+ continue
860
+ v = row[s]
861
+ e = s + 1
862
+ while e < L and row[e] == v:
863
+ e += 1
864
+ out[i, s:e] = e - s
865
+ s = e
866
+ return out
867
+
868
+ # ------------------------- merging -------------------------
869
+
870
+ def merge_intervals_to_new_layer(
759
871
  self,
760
872
  adata,
761
- obs_column: str,
762
- layer: Optional[str] = None,
763
- footprints: Optional[bool] = None,
764
- accessible_patches: Optional[bool] = None,
765
- cpg: Optional[bool] = None,
766
- methbases: Optional[List[str]] = None,
767
- threshold: Optional[float] = None,
768
- device: Optional[Union[str, torch.device]] = None,
769
- batch_size: Optional[int] = None,
770
- use_viterbi: Optional[bool] = None,
771
- in_place: bool = True,
772
- verbose: bool = True,
773
- uns_key: str = "hmm_appended_layers",
774
- config: Optional[Union[dict, "ExperimentConfig"]] = None, # NEW: config/dict accepted
775
- ):
873
+ base_layer: str,
874
+ *,
875
+ distance_threshold: int,
876
+ suffix: str = "_merged",
877
+ overwrite: bool = True,
878
+ ) -> str:
879
+ """
880
+        Merge adjacent 1-intervals in a binary layer when the gap between runs is <= distance_threshold (measured in genomic coordinates when var_names parse as integers, otherwise in index units),
881
+ writing:
882
+ - {base_layer}{suffix}
883
+ - {base_layer}{suffix}_lengths (run-length in index units)
884
+ """
885
+ if base_layer not in adata.layers:
886
+ raise KeyError(f"Layer '{base_layer}' not found.")
887
+
888
+ coords, coords_are_ints = _safe_int_coords(adata.var_names)
889
+ arr = np.asarray(adata.layers[base_layer])
890
+ arr = (arr > 0).astype(np.uint8)
891
+
892
+ merged_name = f"{base_layer}{suffix}"
893
+ merged_len_name = f"{merged_name}_lengths"
894
+
895
+ if (merged_name in adata.layers or merged_len_name in adata.layers) and not overwrite:
896
+ raise KeyError(f"Merged outputs exist (use overwrite=True): {merged_name}")
897
+
898
+ n, L = arr.shape
899
+ out = np.zeros_like(arr, dtype=np.uint8)
900
+
901
+ dt = int(distance_threshold)
902
+
903
+ for i in range(n):
905
+ runs = self._runs_from_bool(arr[i] != 0)
906
+ if not runs:
907
+ continue
908
+ ms, me = runs[0]
909
+ merged_runs = []
910
+ for s, e in runs[1:]:
911
+ if coords_are_ints:
912
+ gap = int(coords[s]) - int(coords[me - 1]) - 1
913
+ else:
914
+ gap = s - me
915
+ if gap <= dt:
916
+ me = e
917
+ else:
918
+ merged_runs.append((ms, me))
919
+ ms, me = s, e
920
+ merged_runs.append((ms, me))
921
+
922
+ for s, e in merged_runs:
923
+ out[i, s:e] = 1
924
+
925
+ adata.layers[merged_name] = out
926
+ adata.layers[merged_len_name] = self._write_lengths_for_binary_layer(out)
927
+
928
+ # bookkeeping
929
+ key = "hmm_appended_layers"
930
+ if adata.uns.get(key) is None:
931
+ adata.uns[key] = []
932
+ for nm in (merged_name, merged_len_name):
933
+ if nm not in adata.uns[key]:
934
+ adata.uns[key].append(nm)
935
+
936
+ return merged_name
937
+
938
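A usage sketch for merge_intervals_to_new_layer, assuming an AnnData whose var_names are genomic positions and a purely illustrative layer name.

    merged = model.merge_intervals_to_new_layer(
        adata,
        "GpC_all_accessible_features",   # hypothetical existing binary layer
        distance_threshold=30,           # bridge gaps <= 30 bp (or 30 positions if var_names are not integers)
    )
    # `merged` is the new layer name; a matching "_lengths" layer is written and both
    # are tracked in adata.uns["hmm_appended_layers"].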
+ def write_size_class_layers_from_binary(
939
+ self,
940
+ adata,
941
+ base_layer: str,
942
+ *,
943
+ out_prefix: str,
944
+ feature_ranges: Dict[str, Tuple[float, float]],
945
+ suffix: str = "",
946
+ overwrite: bool = True,
947
+ ) -> List[str]:
948
+ """
949
+ Take an existing binary layer (runs represent features) and write size-class layers:
950
+ - {out_prefix}_{feature}{suffix}
951
+ - plus lengths layers
952
+
953
+ feature_ranges: name -> (lo, hi) in genomic bp.
954
+ """
955
+ if base_layer not in adata.layers:
956
+ raise KeyError(f"Layer '{base_layer}' not found.")
957
+
958
+ coords, coords_are_ints = _safe_int_coords(adata.var_names)
959
+ bin_arr = (np.asarray(adata.layers[base_layer]) > 0).astype(np.uint8)
960
+ n, L = bin_arr.shape
961
+
962
+ created: List[str] = []
963
+ for feat_name in feature_ranges.keys():
964
+ nm = f"{out_prefix}_{feat_name}{suffix}"
965
+ ln = f"{nm}_lengths"
966
+ if (nm in adata.layers or ln in adata.layers) and not overwrite:
967
+ continue
968
+ adata.layers[nm] = np.zeros((n, L), dtype=np.uint8)
969
+ adata.layers[ln] = np.zeros((n, L), dtype=np.int32)
970
+ created.extend([nm, ln])
971
+
972
+ for i in range(n):
973
+ runs = self._runs_from_bool(bin_arr[i] != 0)
974
+ for s, e in runs:
975
+ length_bp = self._interval_length(coords, s, e) if coords_are_ints else (e - s)
976
+ for feat_name, (lo, hi) in feature_ranges.items():
977
+ if float(lo) <= float(length_bp) < float(hi):
978
+ nm = f"{out_prefix}_{feat_name}{suffix}"
979
+ adata.layers[nm][i, s:e] = 1
980
+ adata.layers[f"{nm}_lengths"][i, s:e] = e - s
981
+ break
982
+
983
+        # recompute lengths from each final size-class layer so overlapping runs stay consistent
984
+ for feat_name in feature_ranges.keys():
985
+ nm = f"{out_prefix}_{feat_name}{suffix}"
986
+ adata.layers[f"{nm}_lengths"] = self._write_lengths_for_binary_layer(
987
+ np.asarray(adata.layers[nm])
988
+ )
989
+
990
+ key = "hmm_appended_layers"
991
+ if adata.uns.get(key) is None:
992
+ adata.uns[key] = []
993
+ for nm in created:
994
+ if nm not in adata.uns[key]:
995
+ adata.uns[key].append(nm)
996
+
997
+ return created
998
+
999
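A follow-on sketch for write_size_class_layers_from_binary, assuming `merged` names an existing binary layer (for instance the output of merge_intervals_to_new_layer above); the class names and bp ranges are placeholders.

    created = model.write_size_class_layers_from_binary(
        adata,
        merged,                                  # binary layer whose runs get size-classified
        out_prefix="GpC",
        feature_ranges={
            "small_bound": (0, 80),              # hypothetical classes, in genomic bp
            "putative_nucleosome": (80, 200),
            "large_feature": (200, np.inf),
        },
    )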
+ # ------------------------- AnnData annotation -------------------------
1000
+
1001
+ @staticmethod
1002
+ def _resolve_pos_mask_for_methbase(subset, ref: str, methbase: str) -> Optional[np.ndarray]:
776
1003
  """
777
- Annotate an AnnData with HMM-derived features (in adata.obs and adata.layers).
778
-
779
- Parameters
780
- ----------
781
- config : optional ExperimentConfig instance or plain dict
782
- When provided, the following keys (if present) are used to override defaults:
783
- - hmm_feature_sets : dict (canonical feature set structure) OR a JSON/string repr
784
- - hmm_annotation_threshold : float
785
- - hmm_batch_size : int
786
- - hmm_use_viterbi : bool
787
- - hmm_methbases : list
788
- - footprints / accessible_patches / cpg (booleans)
789
- Other keyword args override config values if explicitly provided.
1004
+ Local helper to resolve per-base masks from subset.var.* columns.
1005
+ Returns a boolean np.ndarray of length subset.n_vars or None.
790
1006
  """
791
- import json, ast, warnings
792
- import numpy as _np
793
- import torch as _torch
794
- from tqdm import trange, tqdm as _tqdm
795
-
796
- # small helpers
797
- def _try_json_or_literal(s):
798
- if s is None:
1007
+ key = str(methbase).strip().lower()
1008
+ var = subset.var
1009
+
1010
+ def _has(col: str) -> bool:
1011
+ """Return True when a column exists on subset.var."""
1012
+ return col in var.columns
1013
+
1014
+ if key in ("a",):
1015
+ col = f"{ref}_strand_FASTA_base"
1016
+ if not _has(col):
799
1017
  return None
800
- if not isinstance(s, str):
801
- return s
802
- s0 = s.strip()
803
- if s0 == "":
1018
+ return np.asarray(var[col] == "A")
1019
+
1020
+ if key in ("c", "any_c", "anyc", "any-c"):
1021
+ for col in (f"{ref}_any_C_site", f"{ref}_C_site"):
1022
+ if _has(col):
1023
+ return np.asarray(var[col])
1024
+ return None
1025
+
1026
+ if key in ("gpc", "gpc_site", "gpc-site"):
1027
+ col = f"{ref}_GpC_site"
1028
+ if not _has(col):
804
1029
  return None
805
- try:
806
- return json.loads(s0)
807
- except Exception:
808
- pass
809
- try:
810
- return ast.literal_eval(s0)
811
- except Exception:
812
- return s
813
-
814
- def _coerce_bool(x):
815
- if x is None:
816
- return False
817
- if isinstance(x, bool):
818
- return x
819
- if isinstance(x, (int, float)):
820
- return bool(x)
821
- s = str(x).strip().lower()
822
- return s in ("1", "true", "t", "yes", "y", "on")
823
-
824
- def normalize_hmm_feature_sets(raw):
825
- if raw is None:
826
- return {}
827
- parsed = raw
828
- if isinstance(raw, str):
829
- parsed = _try_json_or_literal(raw)
830
- if not isinstance(parsed, dict):
831
- return {}
832
-
833
- def _coerce_bound(x):
834
- if x is None:
835
- return None
836
- if isinstance(x, (int, float)):
837
- return float(x)
838
- s = str(x).strip().lower()
839
- if s in ("inf", "infty", "infinite", "np.inf"):
840
- return _np.inf
841
- if s in ("none", ""):
842
- return None
843
- try:
844
- return float(x)
845
- except Exception:
846
- return None
847
-
848
- def _coerce_feature_map(feats):
849
- out = {}
850
- if not isinstance(feats, dict):
851
- return out
852
- for fname, rng in feats.items():
853
- if rng is None:
854
- out[fname] = (0.0, _np.inf)
1030
+ return np.asarray(var[col])
1031
+
1032
+ if key in ("cpg", "cpg_site", "cpg-site"):
1033
+ col = f"{ref}_CpG_site"
1034
+ if not _has(col):
1035
+ return None
1036
+ return np.asarray(var[col])
1037
+
1038
+ alt = f"{ref}_{methbase}_site"
1039
+ if not _has(alt):
1040
+ return None
1041
+ return np.asarray(var[alt])
1042
+
1043
+ def annotate_adata(
1044
+ self,
1045
+ adata,
1046
+ *,
1047
+ prefix: str,
1048
+ X: np.ndarray,
1049
+ coords: np.ndarray,
1050
+ var_mask: np.ndarray,
1051
+ span_fill: bool = True,
1052
+ config=None,
1053
+ decode: str = "marginal",
1054
+ write_posterior: bool = True,
1055
+ posterior_state: str = "Modified",
1056
+ feature_sets: Optional[Dict[str, Dict[str, Any]]] = None,
1057
+ prob_threshold: float = 0.5,
1058
+ uns_key: str = "hmm_appended_layers",
1059
+ uns_flag: str = "hmm_annotated",
1060
+ force_redo: bool = False,
1061
+ device: Optional[Union[str, torch.device]] = None,
1062
+ **kwargs,
1063
+ ):
1064
+ """Decode and annotate an AnnData object with HMM-derived layers.
1065
+
1066
+ Args:
1067
+ adata: AnnData to annotate.
1068
+ prefix: Prefix for newly written layers.
1069
+ X: Observations array for decoding.
1070
+ coords: Coordinate array aligned to X.
1071
+ var_mask: Boolean mask for positions in adata.var.
1072
+ span_fill: Whether to fill missing spans.
1073
+ config: Optional config for naming and state selection.
1074
+ decode: Decode method ("marginal" or "viterbi").
1075
+ write_posterior: Whether to write posterior probabilities.
1076
+ posterior_state: State label to write posterior for.
1077
+ feature_sets: Optional feature set definition for size classes.
1078
+ prob_threshold: Posterior probability threshold for binary calls.
1079
+ uns_key: .uns key to track appended layers.
1080
+ uns_flag: .uns flag to mark annotations.
1081
+ force_redo: Whether to overwrite existing layers.
1082
+ device: Device specifier.
1083
+ **kwargs: Additional parameters for specialized workflows.
1084
+
1085
+ Returns:
1086
+ List of created layer names or None if skipped.
1087
+ """
1088
+ # skip logic
1089
+ if bool(adata.uns.get(uns_flag, False)) and not force_redo:
1090
+ return None
1091
+
1092
+ if adata.uns.get(uns_key) is None:
1093
+ adata.uns[uns_key] = []
1094
+ appended = list(adata.uns.get(uns_key, [])) if adata.uns.get(uns_key) is not None else []
1095
+
1096
+ X = np.asarray(X, dtype=float)
1097
+ coords = np.asarray(coords, dtype=int)
1098
+ var_mask = np.asarray(var_mask, dtype=bool)
1099
+ if var_mask.shape[0] != adata.n_vars:
1100
+ raise ValueError(f"var_mask length {var_mask.shape[0]} != adata.n_vars {adata.n_vars}")
1101
+
1102
+ # decode
1103
+ states, gamma = self.decode(
1104
+ X, coords, decode=decode, device=device
1105
+ ) # states (N,L), gamma (N,L,K)
1106
+ N, L = states.shape
1107
+ if N != adata.n_obs:
1108
+ raise ValueError(f"X has N={N} rows but adata.n_obs={adata.n_obs}")
1109
+
1110
+ # map coords -> full-var indices for span_fill
1111
+ full_coords, full_int = _safe_int_coords(adata.var_names)
1112
+
1113
+ # ---- write posterior + states on masked columns only ----
1114
+ masked_idx = np.nonzero(var_mask)[0]
1115
+ masked_coords, _ = _safe_int_coords(adata.var_names[var_mask])
1116
+
1117
+ # build mapping from coords order -> masked column order
1118
+ coord_to_pos_in_decoded = {int(c): i for i, c in enumerate(coords.tolist())}
1119
+ take = np.array(
1120
+ [coord_to_pos_in_decoded.get(int(c), -1) for c in masked_coords.tolist()], dtype=int
1121
+ )
1122
+ good = take >= 0
1123
+ masked_idx = masked_idx[good]
1124
+ take = take[good]
1125
+
1126
+ # states layer
1127
+ states_name = f"{prefix}_states"
1128
+ if states_name not in adata.layers:
1129
+ adata.layers[states_name] = np.full((adata.n_obs, adata.n_vars), -1, dtype=np.int8)
1130
+ adata.layers[states_name][:, masked_idx] = states[:, take].astype(np.int8)
1131
+ if states_name not in appended:
1132
+ appended.append(states_name)
1133
+
1134
+ # posterior layer (requested state)
1135
+ if write_posterior:
1136
+ t_idx = self.resolve_target_state_index(posterior_state)
1137
+ post = gamma[:, :, t_idx].astype(np.float32)
1138
+ post_name = f"{prefix}_posterior_{str(posterior_state).strip().lower().replace(' ', '_').replace('-', '_')}"
1139
+ if post_name not in adata.layers:
1140
+ adata.layers[post_name] = np.zeros((adata.n_obs, adata.n_vars), dtype=np.float32)
1141
+ adata.layers[post_name][:, masked_idx] = post[:, take]
1142
+ if post_name not in appended:
1143
+ appended.append(post_name)
1144
+
1145
+ # ---- feature layers ----
1146
+ if feature_sets is None:
1147
+ cfgd = self._cfg_to_dict(config)
1148
+ feature_sets = normalize_hmm_feature_sets(cfgd.get("hmm_feature_sets", None))
1149
+
1150
+ if not feature_sets:
1151
+ adata.uns[uns_key] = appended
1152
+ adata.uns[uns_flag] = True
1153
+ return None
1154
+
1155
+ # allocate outputs
1156
+ for group, fs in feature_sets.items():
1157
+ fmap = fs.get("features", {}) or {}
1158
+ if not fmap:
1159
+ continue
1160
+
1161
+ all_layer = f"{prefix}_all_{group}_features"
1162
+ if all_layer not in adata.layers:
1163
+ adata.layers[all_layer] = np.zeros((adata.n_obs, adata.n_vars), dtype=np.uint8)
1164
+ if f"{all_layer}_lengths" not in adata.layers:
1165
+ adata.layers[f"{all_layer}_lengths"] = np.zeros(
1166
+ (adata.n_obs, adata.n_vars), dtype=np.int32
1167
+ )
1168
+ for nm in (all_layer, f"{all_layer}_lengths"):
1169
+ if nm not in appended:
1170
+ appended.append(nm)
1171
+
1172
+ for feat in fmap.keys():
1173
+ nm = f"{prefix}_{feat}"
1174
+ if nm not in adata.layers:
1175
+ adata.layers[nm] = np.zeros(
1176
+ (adata.n_obs, adata.n_vars),
1177
+ dtype=np.int32 if nm.endswith("_lengths") else np.uint8,
1178
+ )
1179
+ if f"{nm}_lengths" not in adata.layers:
1180
+ adata.layers[f"{nm}_lengths"] = np.zeros(
1181
+ (adata.n_obs, adata.n_vars), dtype=np.int32
1182
+ )
1183
+ for outnm in (nm, f"{nm}_lengths"):
1184
+ if outnm not in appended:
1185
+ appended.append(outnm)
1186
+
1187
+ # classify runs per row
1188
+ target_idx = self.resolve_target_state_index(fs.get("state", "Modified"))
1189
+ membership = (
1190
+ (states == target_idx)
1191
+ if str(decode).lower() == "viterbi"
1192
+ else (gamma[:, :, target_idx] >= float(prob_threshold))
1193
+ )
1194
+
1195
+ for i in range(N):
1196
+ runs = self._runs_from_bool(membership[i].astype(bool))
1197
+ for s, e in runs:
1198
+ # genomic length in coords space
1199
+ glen = int(coords[e - 1]) - int(coords[s]) + 1 if e > s else 0
1200
+ if glen <= 0:
855
1201
  continue
856
- if isinstance(rng, (list, tuple)) and len(rng) >= 2:
857
- lo = _coerce_bound(rng[0]) or 0.0
858
- hi = _coerce_bound(rng[1])
859
- if hi is None:
860
- hi = _np.inf
861
- out[fname] = (float(lo), float(hi) if not _np.isinf(hi) else _np.inf)
1202
+
1203
+ # pick feature bin
1204
+ chosen = None
1205
+ for feat_name, (lo, hi) in fmap.items():
1206
+ if float(lo) <= float(glen) < float(hi):
1207
+ chosen = feat_name
1208
+ break
1209
+ if chosen is None:
1210
+ continue
1211
+
1212
+ # convert span to indices in full var grid
1213
+ if span_fill and full_int:
1214
+ left = int(np.searchsorted(full_coords, int(coords[s]), side="left"))
1215
+ right = int(np.searchsorted(full_coords, int(coords[e - 1]), side="right"))
1216
+ if left >= right:
1217
+ continue
1218
+ adata.layers[f"{prefix}_{chosen}"][i, left:right] = 1
1219
+ adata.layers[f"{prefix}_all_{group}_features"][i, left:right] = 1
862
1220
  else:
863
- val = _coerce_bound(rng)
864
- out[fname] = (0.0, float(val) if val is not None else _np.inf)
865
- return out
866
-
867
- canonical = {}
868
- for grp, info in parsed.items():
869
- if not isinstance(info, dict):
870
- feats = _coerce_feature_map(info)
871
- canonical[grp] = {"features": feats, "state": "Modified"}
872
- continue
873
- feats = _coerce_feature_map(info.get("features", info.get("ranges", {})))
874
- state = info.get("state", info.get("label", "Modified"))
875
- canonical[grp] = {"features": feats, "state": state}
876
- return canonical
877
-
878
- # ---------- resolve config dict ----------
879
- merged_cfg = {}
880
- if config is not None:
881
- if hasattr(config, "to_dict") and callable(getattr(config, "to_dict")):
882
- merged_cfg = dict(config.to_dict())
883
- elif isinstance(config, dict):
884
- merged_cfg = dict(config)
885
- else:
886
- try:
887
- merged_cfg = {k: getattr(config, k) for k in dir(config) if k.startswith("hmm_")}
888
- except Exception:
889
- merged_cfg = {}
890
-
891
- def _pick(key, local_val, fallback=None):
892
- if local_val is not None:
893
- return local_val
894
- if key in merged_cfg and merged_cfg[key] is not None:
895
- return merged_cfg[key]
896
- alt = f"hmm_{key}"
897
- if alt in merged_cfg and merged_cfg[alt] is not None:
898
- return merged_cfg[alt]
899
- return fallback
900
-
901
- # coerce booleans robustly
902
- footprints = _coerce_bool(_pick("footprints", footprints, merged_cfg.get("footprints", False)))
903
- accessible_patches = _coerce_bool(_pick("accessible_patches", accessible_patches, merged_cfg.get("accessible_patches", False)))
904
- cpg = _coerce_bool(_pick("cpg", cpg, merged_cfg.get("cpg", False)))
905
-
906
- threshold = float(_pick("threshold", threshold, merged_cfg.get("hmm_annotation_threshold", 0.5)))
907
- batch_size = int(_pick("batch_size", batch_size, merged_cfg.get("hmm_batch_size", 1024)))
908
- use_viterbi = _coerce_bool(_pick("use_viterbi", use_viterbi, merged_cfg.get("hmm_use_viterbi", False)))
909
-
910
- methbases = merged_cfg.get("hmm_methbases", None)
911
-
912
- # normalize whitespace/case for human-friendly inputs (but keep original tokens as given)
913
- methbases = [str(m).strip() for m in methbases if m is not None]
914
- if verbose:
915
- print("DEBUG: final methbases list =", methbases)
916
-
917
- # resolve feature sets: prefer canonical if it yields non-empty mapping, otherwise fall back to boolean defaults
918
- feature_sets = {}
919
- if "hmm_feature_sets" in merged_cfg and merged_cfg.get("hmm_feature_sets") is not None:
920
- cand = normalize_hmm_feature_sets(merged_cfg.get("hmm_feature_sets"))
921
- if isinstance(cand, dict) and len(cand) > 0:
922
- feature_sets = cand
1221
+ # only fill at masked indices
1222
+ cols = masked_idx[
1223
+ (masked_coords >= coords[s]) & (masked_coords <= coords[e - 1])
1224
+ ]
1225
+ if cols.size == 0:
1226
+ continue
1227
+ adata.layers[f"{prefix}_{chosen}"][i, cols] = 1
1228
+ adata.layers[f"{prefix}_all_{group}_features"][i, cols] = 1
1229
+
1230
+ # lengths derived from binary
1231
+ adata.layers[f"{prefix}_all_{group}_features_lengths"] = (
1232
+ self._write_lengths_for_binary_layer(
1233
+ np.asarray(adata.layers[f"{prefix}_all_{group}_features"])
1234
+ )
1235
+ )
1236
+ for feat in fmap.keys():
1237
+ nm = f"{prefix}_{feat}"
1238
+ adata.layers[f"{nm}_lengths"] = self._write_lengths_for_binary_layer(
1239
+ np.asarray(adata.layers[nm])
1240
+ )
1241
+
1242
+ adata.uns[uns_key] = appended
1243
+ adata.uns[uns_flag] = True
1244
+ return None
1245
+
1246
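A sketch of the decode-and-annotate entry point defined above, assuming a fitted model and observations pulled from a hypothetical per-reference GpC-site column; all names besides the documented parameters are placeholders.

    var_mask = adata.var["ref1_GpC_site"].to_numpy(dtype=bool)   # hypothetical site column
    sub = adata[:, var_mask]
    X = np.asarray(sub.X, dtype=float)                           # (N, L), NaN where a read has no call
    coords = np.asarray(sub.var_names, dtype=int)

    model.annotate_adata(
        adata,
        prefix="GpC_hmm",
        X=X,
        coords=coords,
        var_mask=var_mask,
        decode="marginal",
        posterior_state="Modified",
        prob_threshold=0.5,
    )
    # writes "GpC_hmm_states" and a posterior layer; feature layers follow only when
    # feature_sets (or config["hmm_feature_sets"]) is provided.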
+ # ------------------------- row-copy helper (workflow uses it) -------------------------
1247
+
1248
+ def _ensure_final_layer_and_assign(
1249
+ self, final_adata, layer_name: str, subset_idx_mask: np.ndarray, sub_data
1250
+ ):
1251
+ """
1252
+ Assign rows from sub_data into final_adata.layers[layer_name] for rows where subset_idx_mask is True.
1253
+ Handles dense arrays. If you want sparse support, add it here.
1254
+ """
1255
+ n_final_obs, n_vars = final_adata.shape
1256
+ final_rows = np.nonzero(np.asarray(subset_idx_mask).astype(bool))[0]
1257
+ sub_arr = np.asarray(sub_data)
1258
+
1259
+ if layer_name not in final_adata.layers:
1260
+ final_adata.layers[layer_name] = np.zeros((n_final_obs, n_vars), dtype=sub_arr.dtype)
1261
+
1262
+ final_arr = np.asarray(final_adata.layers[layer_name])
1263
+ if sub_arr.shape[0] != final_rows.size:
1264
+ raise ValueError(f"Sub rows {sub_arr.shape[0]} != mask sum {final_rows.size}")
1265
+ final_arr[final_rows, :] = sub_arr
1266
+ final_adata.layers[layer_name] = final_arr
1267
+
1268
+
1269
+ # =============================================================================
1270
+ # Single-channel Bernoulli HMM
1271
+ # =============================================================================
1272
+
1273
+
1274
+ @register_hmm("single")
1275
+ class SingleBernoulliHMM(BaseHMM):
1276
+ """
1277
+ Bernoulli emission per state:
1278
+ emission[k] = P(obs==1 | state=k)
1279
+ """
1280
+
1281
+ def __init__(
1282
+ self,
1283
+ n_states: int = 2,
1284
+ init_emission: Optional[Sequence[float]] = None,
1285
+ eps: float = 1e-8,
1286
+ dtype: torch.dtype = torch.float64,
1287
+ ):
1288
+ """Initialize a single-channel Bernoulli HMM.
1289
+
1290
+ Args:
1291
+ n_states: Number of hidden states.
1292
+ init_emission: Initial emission probabilities per state.
1293
+ eps: Smoothing epsilon for probabilities.
1294
+ dtype: Torch dtype for parameters.
1295
+ """
1296
+ super().__init__(n_states=n_states, eps=eps, dtype=dtype)
1297
+ if init_emission is None:
1298
+ em = np.full((self.n_states,), 0.5, dtype=float)
1299
+ else:
1300
+ em = np.asarray(init_emission, dtype=float).reshape(-1)[: self.n_states]
1301
+ if em.size != self.n_states:
1302
+ em = np.full((self.n_states,), 0.5, dtype=float)
1303
+
1304
+ self.emission = nn.Parameter(torch.tensor(em, dtype=self.dtype), requires_grad=False)
1305
+ self._normalize_emission()
1306
+
1307
+ @classmethod
1308
+ def from_config(cls, cfg, *, override=None, device=None):
1309
+ """Create a single-channel Bernoulli HMM from config.
1310
+
1311
+ Args:
1312
+ cfg: Configuration mapping or object.
1313
+ override: Override values to apply.
1314
+ device: Optional device specifier.
1315
+
1316
+ Returns:
1317
+ Initialized SingleBernoulliHMM instance.
1318
+ """
1319
+ merged = cls._cfg_to_dict(cfg)
1320
+ if override:
1321
+ merged.update(override)
1322
+ n_states = int(merged.get("hmm_n_states", 2))
1323
+ eps = float(merged.get("hmm_eps", 1e-8))
1324
+ dtype = _resolve_dtype(merged.get("hmm_dtype", None))
1325
+ dtype = _coerce_dtype_for_device(dtype, device) # <<< NEW
1326
+ init_em = merged.get("hmm_init_emission_probs", merged.get("hmm_init_emission", None))
1327
+ model = cls(n_states=n_states, init_emission=init_em, eps=eps, dtype=dtype)
1328
+ if device is not None:
1329
+ model.to(torch.device(device) if isinstance(device, str) else device)
1330
+ model._persisted_cfg = merged
1331
+ return model
1332
+
1333
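A config-driven construction sketch using only the keys read by from_config above; the values are placeholders and `_cfg_to_dict` is assumed to accept a plain dict.

    cfg = {
        "hmm_n_states": 2,
        "hmm_eps": 1e-8,
        "hmm_init_emission_probs": [0.1, 0.8],   # P(obs == 1) per state, e.g. Non-Modified vs Modified
    }
    model = SingleBernoulliHMM.from_config(cfg, device="cpu")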
+ def _normalize_emission(self):
1334
+ """Normalize and clamp emission probabilities in-place."""
1335
+ with torch.no_grad():
1336
+ self.emission.data = self.emission.data.reshape(-1)
1337
+ if self.emission.data.numel() != self.n_states:
1338
+ self.emission.data = torch.full(
1339
+ (self.n_states,), 0.5, dtype=self.dtype, device=self.emission.device
1340
+ )
1341
+ self.emission.data = self.emission.data.clamp(min=self.eps, max=1.0 - self.eps)
1342
+
1343
+ def _ensure_device_dtype(self, device=None) -> torch.device:
1344
+ """Move emission parameters to the requested device/dtype."""
1345
+ device = super()._ensure_device_dtype(device)
1346
+ self.emission.data = self.emission.data.to(device=device, dtype=self.dtype)
1347
+ return device
1348
+
1349
+ def _state_modified_score(self) -> torch.Tensor:
1350
+ """Return per-state modified scores for ranking."""
1351
+ return self.emission.detach()
1352
+
1353
+ def _log_emission(self, obs: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
1354
+ """
1355
+ obs: (N,L), mask: (N,L) -> logB: (N,L,K)
1356
+ """
1357
+ p = self.emission # (K,)
1358
+ logp = torch.log(p + self.eps)
1359
+ log1mp = torch.log1p(-p + self.eps)
1360
+
1361
+ o = obs.unsqueeze(-1) # (N,L,1)
1362
+ logB = o * logp.view(1, 1, -1) + (1.0 - o) * log1mp.view(1, 1, -1)
1363
+ logB = torch.where(mask.unsqueeze(-1), logB, torch.zeros_like(logB))
1364
+ return logB
1365
+
1366
+ def _extra_save_payload(self) -> dict:
1367
+ """Return extra payload data for serialization."""
1368
+ return {"emission": self.emission.detach().cpu()}
1369
+
1370
+ def _load_extra_payload(self, payload: dict, *, device: torch.device):
1371
+ """Load serialized emission parameters.
1372
+
1373
+ Args:
1374
+ payload: Serialized payload dictionary.
1375
+ device: Target torch device.
1376
+ """
1377
+ with torch.no_grad():
1378
+ self.emission.data = payload["emission"].to(device=device, dtype=self.dtype)
1379
+ self._normalize_emission()
1380
+
1381
+ def fit_em(
1382
+ self,
1383
+ X: np.ndarray,
1384
+ coords: np.ndarray,
1385
+ *,
1386
+ device: torch.device,
1387
+ max_iter: int,
1388
+ tol: float,
1389
+ update_start: bool,
1390
+ update_trans: bool,
1391
+ update_emission: bool,
1392
+ verbose: bool,
1393
+ **kwargs,
1394
+ ) -> List[float]:
1395
+ """Run EM updates for a single-channel Bernoulli HMM.
1396
+
1397
+ Args:
1398
+ X: Observations array (N, L).
1399
+ coords: Coordinate array aligned to X.
1400
+ device: Torch device.
1401
+ max_iter: Maximum iterations.
1402
+ tol: Convergence tolerance.
1403
+ update_start: Whether to update start probabilities.
1404
+ update_trans: Whether to update transitions.
1405
+ update_emission: Whether to update emission parameters.
1406
+ verbose: Whether to log progress.
1407
+ **kwargs: Additional implementation-specific kwargs.
1408
+
1409
+ Returns:
1410
+ List of log-likelihood proxy values.
1411
+ """
1412
+ X = np.asarray(X, dtype=float)
1413
+ if X.ndim != 2:
1414
+ raise ValueError("SingleBernoulliHMM expects X shape (N,L).")
1415
+ obs = torch.tensor(np.nan_to_num(X, nan=0.0), dtype=self.dtype, device=device)
1416
+ mask = torch.tensor(~np.isnan(X), dtype=torch.bool, device=device)
1417
+
1418
+ eps = float(self.eps)
1419
+ K = self.n_states
1420
+ N, L = obs.shape
1421
+
1422
+ hist: List[float] = []
1423
+ for it in range(1, int(max_iter) + 1):
1424
+ gamma = self._forward_backward(obs, mask) # (N,L,K)
1425
+
1426
+ # log-likelihood proxy
1427
+ ll_proxy = float(torch.sum(torch.log(torch.clamp(gamma.sum(dim=2), min=eps))).item())
1428
+ hist.append(ll_proxy)
1429
+
1430
+ # expected start
1431
+ start_acc = gamma[:, 0, :].sum(dim=0) # (K,)
1432
+
1433
+ # expected transitions xi
1434
+ logB = self._log_emission(obs, mask)
1435
+ logA = torch.log(self.trans + eps)
1436
+ alpha = torch.empty((N, L, K), dtype=self.dtype, device=device)
1437
+ alpha[:, 0, :] = torch.log(self.start + eps).unsqueeze(0) + logB[:, 0, :]
1438
+ for t in range(1, L):
1439
+ prev = alpha[:, t - 1, :].unsqueeze(2) + logA.unsqueeze(0)
1440
+ alpha[:, t, :] = _logsumexp(prev, dim=1) + logB[:, t, :]
1441
+ beta = torch.empty((N, L, K), dtype=self.dtype, device=device)
1442
+ beta[:, L - 1, :] = 0.0
1443
+ for t in range(L - 2, -1, -1):
1444
+ temp = logA.unsqueeze(0) + (logB[:, t + 1, :] + beta[:, t + 1, :]).unsqueeze(1)
1445
+ beta[:, t, :] = _logsumexp(temp, dim=2)
1446
+
1447
+ trans_acc = torch.zeros((K, K), dtype=self.dtype, device=device)
1448
+ for t in range(L - 1):
1449
+ valid_t = (mask[:, t] & mask[:, t + 1]).float().view(N, 1, 1)
1450
+ log_xi = (
1451
+ alpha[:, t, :].unsqueeze(2)
1452
+ + logA.unsqueeze(0)
1453
+ + (logB[:, t + 1, :] + beta[:, t + 1, :]).unsqueeze(1)
1454
+ )
1455
+ log_norm = _logsumexp(log_xi.view(N, -1), dim=1).view(N, 1, 1)
1456
+ xi = (log_xi - log_norm).exp() * valid_t
1457
+ trans_acc += xi.sum(dim=0)
1458
+
1459
+ # emission update
1460
+ mask_f = mask.float().unsqueeze(-1) # (N,L,1)
1461
+ emit_num = (gamma * obs.unsqueeze(-1) * mask_f).sum(dim=(0, 1)) # (K,)
1462
+ emit_den = (gamma * mask_f).sum(dim=(0, 1)) # (K,)
1463
+
1464
+ with torch.no_grad():
1465
+ if update_start:
1466
+ new_start = start_acc + eps
1467
+ self.start.data = new_start / new_start.sum()
1468
+
1469
+ if update_trans:
1470
+ new_trans = trans_acc + eps
1471
+ rs = new_trans.sum(dim=1, keepdim=True)
1472
+ rs[rs == 0.0] = 1.0
1473
+ self.trans.data = new_trans / rs
1474
+
1475
+ if update_emission:
1476
+ new_em = (emit_num + eps) / (emit_den + 2.0 * eps)
1477
+ self.emission.data = new_em.clamp(min=eps, max=1.0 - eps)
1478
+
1479
+ self._normalize_params()
1480
+ self._normalize_emission()
923
1481
 
924
- if not feature_sets:
925
1482
  if verbose:
926
- print("[HMM.annotate_adata] no feature sets configured; nothing to append.")
927
- return None if in_place else adata
928
-
929
- if verbose:
930
- print("[HMM.annotate_adata] resolved feature sets:", list(feature_sets.keys()))
931
-
932
- # copy vs in-place
933
- if not in_place:
934
- adata = adata.copy()
935
-
936
- # prepare column names
937
- all_features = []
938
- combined_prefix = "Combined"
939
- for key, fs in feature_sets.items():
940
- feats = fs.get("features", {})
941
- if key == "cpg":
942
- all_features += [f"CpG_{f}" for f in feats]
943
- all_features.append(f"CpG_all_{key}_features")
944
- else:
945
- for methbase in methbases:
946
- all_features += [f"{methbase}_{f}" for f in feats]
947
- all_features.append(f"{methbase}_all_{key}_features")
948
- if len(methbases) > 1:
949
- all_features += [f"{combined_prefix}_{f}" for f in feats]
950
- all_features.append(f"{combined_prefix}_all_{key}_features")
951
-
952
- # initialize obs columns (unique lists per row)
953
- n_rows = adata.shape[0]
954
- for feature in all_features:
955
- if feature not in adata.obs.columns:
956
- adata.obs[feature] = [[] for _ in range(n_rows)]
957
- if f"{feature}_distances" not in adata.obs.columns:
958
- adata.obs[f"{feature}_distances"] = [None] * n_rows
959
- if f"n_{feature}" not in adata.obs.columns:
960
- adata.obs[f"n_{feature}"] = -1
961
-
962
- appended_layers: List[str] = []
963
-
964
- # device management
965
- if device is None:
966
- device = next(self.parameters()).device
967
- elif isinstance(device, str):
968
- device = _torch.device(device)
969
- self.to(device)
1483
+ logger.info(
1484
+ "[SingleBernoulliHMM.fit] iter=%s ll_proxy=%.6f",
1485
+ it,
1486
+ hist[-1],
1487
+ )
1488
+
1489
+ if len(hist) > 1 and abs(hist[-1] - hist[-2]) < float(tol):
1490
+ break
970
1491
 
971
- # helpers ---------------------------------------------------------------
972
- def _ensure_2d_array_like(matrix):
973
- arr = _np.asarray(matrix)
1492
+ return hist
1493
+
1494
+ def adapt_emissions(
1495
+ self,
1496
+ X: np.ndarray,
1497
+ coords: Optional[np.ndarray] = None,
1498
+ *,
1499
+ device: Optional[Union[str, torch.device]] = None,
1500
+ iters: Optional[int] = None,
1501
+ max_iter: Optional[int] = None, # alias for your trainer
1502
+ verbose: bool = False,
1503
+ **kwargs,
1504
+ ):
1505
+ """Adapt emissions with legacy parameter names.
1506
+
1507
+ Args:
1508
+ X: Observations array.
1509
+ coords: Optional coordinate array.
1510
+ device: Device specifier.
1511
+ iters: Number of iterations.
1512
+ max_iter: Alias for iters.
1513
+ verbose: Whether to log progress.
1514
+ **kwargs: Additional kwargs forwarded to BaseHMM.adapt_emissions.
1515
+
1516
+ Returns:
1517
+ List of log-likelihood values.
1518
+ """
1519
+ if iters is None:
1520
+ iters = int(max_iter) if max_iter is not None else int(kwargs.pop("iters", 5))
1521
+ return super().adapt_emissions(
1522
+ np.asarray(X, dtype=float),
1523
+ coords if coords is not None else None,
1524
+ iters=int(iters),
1525
+ device=device,
1526
+ verbose=verbose,
1527
+ )
1528
+
1529
+
1530
+ # =============================================================================
1531
+ # Multi-channel Bernoulli HMM (union coordinate grid)
1532
+ # =============================================================================
1533
+
1534
+
1535
+ @register_hmm("multi")
1536
+ class MultiBernoulliHMM(BaseHMM):
1537
+ """
1538
+ Multi-channel independent Bernoulli:
1539
+ emission[k,c] = P(obs_c==1 | state=k)
1540
+ X must be (N,L,C) on a union coordinate grid; NaN per-channel allowed.
1541
+ """
1542
+
1543
+ def __init__(
1544
+ self,
1545
+ n_states: int = 2,
1546
+ n_channels: int = 2,
1547
+ init_emission: Optional[Any] = None,
1548
+ eps: float = 1e-8,
1549
+ dtype: torch.dtype = torch.float64,
1550
+ ):
1551
+ """Initialize a multi-channel Bernoulli HMM.
1552
+
1553
+ Args:
1554
+ n_states: Number of hidden states.
1555
+ n_channels: Number of observed channels.
1556
+ init_emission: Initial emission probabilities.
1557
+ eps: Smoothing epsilon for probabilities.
1558
+ dtype: Torch dtype for parameters.
1559
+ """
1560
+ super().__init__(n_states=n_states, eps=eps, dtype=dtype)
1561
+ self.n_channels = int(n_channels)
1562
+ if self.n_channels < 1:
1563
+ raise ValueError("n_channels must be >=1")
1564
+
1565
+ if init_emission is None:
1566
+ em = np.full((self.n_states, self.n_channels), 0.5, dtype=float)
1567
+ else:
1568
+ arr = np.asarray(init_emission, dtype=float)
974
1569
  if arr.ndim == 1:
975
- arr = arr[_np.newaxis, :]
976
- elif arr.ndim > 2:
977
- # squeeze trailing singletons
978
- while arr.ndim > 2 and arr.shape[-1] == 1:
979
- arr = _np.squeeze(arr, axis=-1)
980
- if arr.ndim != 2:
981
- raise ValueError(f"Expected 2D sequence matrix; got array with shape {arr.shape}")
982
- return arr
983
-
984
- def calculate_batch_distances(intervals_list, threshold_local=0.9):
985
- results_local = []
986
- for intervals in intervals_list:
987
- if not isinstance(intervals, list) or len(intervals) == 0:
988
- results_local.append([])
989
- continue
990
- valid = [iv for iv in intervals if iv[2] > threshold_local]
991
- if len(valid) <= 1:
992
- results_local.append([])
993
- continue
994
- valid = sorted(valid, key=lambda x: x[0])
995
- dists = [(valid[i + 1][0] - (valid[i][0] + valid[i][1])) for i in range(len(valid) - 1)]
996
- results_local.append(dists)
997
- return results_local
998
-
999
- def classify_batch_local(predicted_states_batch, probabilities_batch, coordinates, classification_mapping, target_state="Modified"):
1000
- # Accept numpy arrays or torch tensors
1001
- if isinstance(predicted_states_batch, _torch.Tensor):
1002
- pred_np = predicted_states_batch.detach().cpu().numpy()
1570
+ arr = arr.reshape(-1, 1)
1571
+ em = np.repeat(arr[: self.n_states, :], self.n_channels, axis=1)
1003
1572
  else:
1004
- pred_np = _np.asarray(predicted_states_batch)
1005
- if isinstance(probabilities_batch, _torch.Tensor):
1006
- probs_np = probabilities_batch.detach().cpu().numpy()
1007
- else:
1008
- probs_np = _np.asarray(probabilities_batch)
1009
-
1010
- batch_size, L = pred_np.shape
1011
- all_classifications_local = []
1012
- # allow caller to pass arbitrary state labels mapping; default two-state mapping:
1013
- state_labels = ["Non-Modified", "Modified"]
1014
- try:
1015
- target_idx = state_labels.index(target_state)
1016
- except ValueError:
1017
- target_idx = 1 # fallback
1018
-
1019
- for b in range(batch_size):
1020
- predicted_states = pred_np[b]
1021
- probabilities = probs_np[b]
1022
- regions = []
1023
- current_start, current_length, current_probs = None, 0, []
1024
- for i, state_index in enumerate(predicted_states):
1025
- state_prob = float(probabilities[i][state_index])
1026
- if state_index == target_idx:
1027
- if current_start is None:
1028
- current_start = i
1029
- current_length += 1
1030
- current_probs.append(state_prob)
1031
- elif current_start is not None:
1032
- regions.append((current_start, current_length, float(_np.mean(current_probs))))
1033
- current_start, current_length, current_probs = None, 0, []
1034
- if current_start is not None:
1035
- regions.append((current_start, current_length, float(_np.mean(current_probs))))
1036
-
1037
- final = []
1038
- for start, length, prob in regions:
1039
- # compute genomic length try/catch
1040
- try:
1041
- feature_length = int(coordinates[start + length - 1]) - int(coordinates[start]) + 1
1042
- except Exception:
1043
- feature_length = int(length)
1044
-
1045
- # classification_mapping values are (lo, hi) tuples or lists
1046
- label = None
1047
- for ftype, rng in classification_mapping.items():
1048
- lo, hi = rng[0], rng[1]
1049
- try:
1050
- if lo <= feature_length < hi:
1051
- label = ftype
1052
- break
1053
- except Exception:
1054
- continue
1055
- if label is None:
1056
- # fallback to first mapping key or 'unknown'
1057
- label = next(iter(classification_mapping.keys()), "feature")
1058
-
1059
- # Store reported start coordinate in same coordinate system as `coordinates`.
1060
- try:
1061
- genomic_start = int(coordinates[start])
1062
- except Exception:
1063
- genomic_start = int(start)
1064
- final.append((genomic_start, feature_length, label, prob))
1065
- all_classifications_local.append(final)
1066
- return all_classifications_local
1067
-
1068
- # -----------------------------------------------------------------------
1069
-
1070
- # Ensure obs_column is categorical-like for iteration
1071
- sseries = adata.obs[obs_column]
1072
- if not pd.api.types.is_categorical_dtype(sseries):
1073
- sseries = sseries.astype("category")
1074
- references = list(sseries.cat.categories)
1075
-
1076
- ref_iter = references if not verbose else _tqdm(references, desc="Processing References")
1077
- for ref in ref_iter:
1078
- # subset reads with this obs_column value
1079
- ref_mask = adata.obs[obs_column] == ref
1080
- ref_subset = adata[ref_mask].copy()
1081
- combined_mask = None
1082
-
1083
- # per-methbase processing
1084
- for methbase in methbases:
1085
- key_lower = methbase.strip().lower()
1086
-
1087
- # map several common synonyms -> canonical lookup
1088
- if key_lower in ("a",):
1089
- pos_mask = ref_subset.var.get(f"{ref}_strand_FASTA_base") == "A"
1090
- elif key_lower in ("c", "any_c", "anyc", "any-c"):
1091
- # unify 'C' or 'any_C' names to the any_C var column
1092
- pos_mask = ref_subset.var.get(f"{ref}_any_C_site") == True
1093
- elif key_lower in ("gpc", "gpc_site", "gpc-site"):
1094
- pos_mask = ref_subset.var.get(f"{ref}_GpC_site") == True
1095
- elif key_lower in ("cpg", "cpg_site", "cpg-site"):
1096
- pos_mask = ref_subset.var.get(f"{ref}_CpG_site") == True
1097
- else:
1098
- # try a best-effort: if a column named f"{ref}_{methbase}_site" exists, use it
1099
- alt_col = f"{ref}_{methbase}_site"
1100
- pos_mask = ref_subset.var.get(alt_col, None)
1573
+ em = arr[: self.n_states, : self.n_channels]
1574
+ if em.shape != (self.n_states, self.n_channels):
1575
+ em = np.full((self.n_states, self.n_channels), 0.5, dtype=float)
1101
1576
 
1102
- if pos_mask is None:
1103
- continue
1104
- combined_mask = pos_mask if combined_mask is None else (combined_mask | pos_mask)
1577
+ self.emission = nn.Parameter(torch.tensor(em, dtype=self.dtype), requires_grad=False)
1578
+ self._normalize_emission()
1105
1579
 
1106
- if pos_mask.sum() == 0:
1107
- continue
1580
+ @classmethod
1581
+ def from_config(cls, cfg, *, override=None, device=None):
1582
+ """Create a multi-channel Bernoulli HMM from config.
1108
1583
 
1109
- sub = ref_subset[:, pos_mask]
1110
- # choose matrix
1111
- matrix = sub.layers[layer] if (layer and layer in sub.layers) else sub.X
1112
- matrix = _ensure_2d_array_like(matrix)
1113
- n_reads = matrix.shape[0]
1584
+ Args:
1585
+ cfg: Configuration mapping or object.
1586
+ override: Override values to apply.
1587
+ device: Optional device specifier.
1114
1588
 
1115
- # coordinates for this sub (try to convert to ints, else fallback to indices)
1116
- try:
1117
- coords = _np.asarray(sub.var_names, dtype=int)
1118
- except Exception:
1119
- coords = _np.arange(sub.shape[1], dtype=int)
1120
-
1121
- # chunked processing
1122
- chunk_iter = range(0, n_reads, batch_size)
1123
- if verbose:
1124
- chunk_iter = _tqdm(list(chunk_iter), desc=f"{ref}:{methbase} chunks")
1125
- for start_idx in chunk_iter:
1126
- stop_idx = min(n_reads, start_idx + batch_size)
1127
- chunk = matrix[start_idx:stop_idx]
1128
- seqs = chunk.tolist()
1129
- # posterior marginals
1130
- gammas = self.predict(seqs, impute_strategy="ignore", device=device)
1131
- if len(gammas) == 0:
1132
- continue
1133
- probs_batch = _np.stack(gammas, axis=0) # (B, L, K)
1134
- if use_viterbi:
1135
- paths, _scores = self.batch_viterbi(seqs, impute_strategy="ignore", device=device)
1136
- pred_states = _np.asarray(paths)
1137
- else:
1138
- pred_states = _np.argmax(probs_batch, axis=2)
1589
+ Returns:
1590
+ Initialized MultiBernoulliHMM instance.
1591
+ """
1592
+ merged = cls._cfg_to_dict(cfg)
1593
+ if override:
1594
+ merged.update(override)
1595
+ n_states = int(merged.get("hmm_n_states", 2))
1596
+ eps = float(merged.get("hmm_eps", 1e-8))
1597
+ dtype = _resolve_dtype(merged.get("hmm_dtype", None))
1598
+ dtype = _coerce_dtype_for_device(dtype, device) # <<< NEW
1599
+ n_channels = int(merged.get("hmm_n_channels", merged.get("n_channels", 2)))
1600
+ init_em = merged.get("hmm_init_emission_probs", None)
1601
+ model = cls(
1602
+ n_states=n_states, n_channels=n_channels, init_emission=init_em, eps=eps, dtype=dtype
1603
+ )
1604
+ if device is not None:
1605
+ model.to(torch.device(device) if isinstance(device, str) else device)
1606
+ model._persisted_cfg = merged
1607
+ return model
1139
1608
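A sketch of shaping the (N, L, C) input expected by MultiBernoulliHMM: per-channel call matrices (for example GpC and CpG) already aligned to a shared union coordinate grid are stacked along the last axis, with NaN wherever a channel has no site. Here `gpc_calls`, `cpg_calls`, and `coords` are assumed to exist.

    X_multi = np.stack([gpc_calls, cpg_calls], axis=-1)   # (N, L, 2); NaN = channel missing at that position
    model = MultiBernoulliHMM(n_states=2, n_channels=2)
    history = model.fit(
        X_multi, coords,
        max_iter=25, tol=1e-4, device="cpu",
        update_start=True, update_trans=True, update_emission=True,
        verbose=False,
    )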
 
1140
- # For each feature group, classify separately and write back
1141
- for key, fs in feature_sets.items():
1142
- if key == "cpg":
1143
- continue
1144
- state_target = fs.get("state", "Modified")
1145
- feature_map = fs.get("features", {})
1146
- classifications = classify_batch_local(pred_states, probs_batch, coords, feature_map, target_state=state_target)
1147
-
1148
- # write results to adata.obs rows (use original index names)
1149
- row_indices = list(sub.obs.index[start_idx:stop_idx])
1150
- for i_local, idx in enumerate(row_indices):
1151
- for start, length, label, prob in classifications[i_local]:
1152
- col_name = f"{methbase}_{label}"
1153
- all_col = f"{methbase}_all_{key}_features"
1154
- adata.obs.at[idx, col_name].append([start, length, prob])
1155
- adata.obs.at[idx, all_col].append([start, length, prob])
1156
-
1157
- # Combined subset (if multiple methbases)
1158
- if len(methbases) > 1 and (combined_mask is not None) and (combined_mask.sum() > 0):
1159
- comb = ref_subset[:, combined_mask]
1160
- if comb.shape[1] > 0:
1161
- matrix = comb.layers[layer] if (layer and layer in comb.layers) else comb.X
1162
- matrix = _ensure_2d_array_like(matrix)
1163
- n_reads_comb = matrix.shape[0]
1164
- try:
1165
- coords_comb = _np.asarray(comb.var_names, dtype=int)
1166
- except Exception:
1167
- coords_comb = _np.arange(comb.shape[1], dtype=int)
1168
-
1169
- chunk_iter = range(0, n_reads_comb, batch_size)
1170
- if verbose:
1171
- chunk_iter = _tqdm(list(chunk_iter), desc=f"{ref}:Combined chunks")
1172
- for start_idx in chunk_iter:
1173
- stop_idx = min(n_reads_comb, start_idx + batch_size)
1174
- chunk = matrix[start_idx:stop_idx]
1175
- seqs = chunk.tolist()
1176
- gammas = self.predict(seqs, impute_strategy="ignore", device=device)
1177
- if len(gammas) == 0:
1178
- continue
1179
- probs_batch = _np.stack(gammas, axis=0)
1180
- if use_viterbi:
1181
- paths, _scores = self.batch_viterbi(seqs, impute_strategy="ignore", device=device)
1182
- pred_states = _np.asarray(paths)
1183
- else:
1184
- pred_states = _np.argmax(probs_batch, axis=2)
1185
-
1186
- for key, fs in feature_sets.items():
1187
- if key == "cpg":
1188
- continue
1189
- state_target = fs.get("state", "Modified")
1190
- feature_map = fs.get("features", {})
1191
- classifications = classify_batch_local(pred_states, probs_batch, coords_comb, feature_map, target_state=state_target)
1192
- row_indices = list(comb.obs.index[start_idx:stop_idx])
1193
- for i_local, idx in enumerate(row_indices):
1194
- for start, length, label, prob in classifications[i_local]:
1195
- adata.obs.at[idx, f"{combined_prefix}_{label}"].append([start, length, prob])
1196
- adata.obs.at[idx, f"{combined_prefix}_all_{key}_features"].append([start, length, prob])
1197
-
1198
- # CpG special handling
1199
- if "cpg" in feature_sets and feature_sets.get("cpg") is not None:
1200
- cpg_iter = references if not verbose else _tqdm(references, desc="Processing CpG")
1201
- for ref in cpg_iter:
1202
- ref_mask = adata.obs[obs_column] == ref
1203
- ref_subset = adata[ref_mask].copy()
1204
- pos_mask = ref_subset.var[f"{ref}_CpG_site"] == True
1205
- if pos_mask.sum() == 0:
1206
- continue
1207
- cpg_sub = ref_subset[:, pos_mask]
1208
- matrix = cpg_sub.layers[layer] if (layer and layer in cpg_sub.layers) else cpg_sub.X
1209
- matrix = _ensure_2d_array_like(matrix)
1210
- n_reads = matrix.shape[0]
1211
- try:
1212
- coords_cpg = _np.asarray(cpg_sub.var_names, dtype=int)
1213
- except Exception:
1214
- coords_cpg = _np.arange(cpg_sub.shape[1], dtype=int)
1215
-
1216
- chunk_iter = range(0, n_reads, batch_size)
1217
- if verbose:
1218
- chunk_iter = _tqdm(list(chunk_iter), desc=f"{ref}:CpG chunks")
1219
- for start_idx in chunk_iter:
1220
- stop_idx = min(n_reads, start_idx + batch_size)
1221
- chunk = matrix[start_idx:stop_idx]
1222
- seqs = chunk.tolist()
1223
- gammas = self.predict(seqs, impute_strategy="ignore", device=device)
1224
- if len(gammas) == 0:
1225
- continue
1226
- probs_batch = _np.stack(gammas, axis=0)
1227
- if use_viterbi:
1228
- paths, _scores = self.batch_viterbi(seqs, impute_strategy="ignore", device=device)
1229
- pred_states = _np.asarray(paths)
1230
- else:
1231
- pred_states = _np.argmax(probs_batch, axis=2)
1232
-
1233
- fs = feature_sets["cpg"]
1234
- state_target = fs.get("state", "Modified")
1235
- feature_map = fs.get("features", {})
1236
- classifications = classify_batch_local(pred_states, probs_batch, coords_cpg, feature_map, target_state=state_target)
1237
- row_indices = list(cpg_sub.obs.index[start_idx:stop_idx])
1238
- for i_local, idx in enumerate(row_indices):
1239
- for start, length, label, prob in classifications[i_local]:
1240
- adata.obs.at[idx, f"CpG_{label}"].append([start, length, prob])
1241
- adata.obs.at[idx, f"CpG_all_cpg_features"].append([start, length, prob])
1242
-
1243
- # finalize: convert intervals into binary layers and distances
1244
- try:
1245
- coordinates = _np.asarray(adata.var_names, dtype=int)
1246
- coords_are_ints = True
1247
- except Exception:
1248
- coordinates = _np.arange(adata.shape[1], dtype=int)
1249
- coords_are_ints = False
1250
-
1251
- features_iter = all_features if not verbose else _tqdm(all_features, desc="Finalizing Layers")
1252
- for feature in features_iter:
1253
- bin_matrix = _np.zeros((adata.shape[0], adata.shape[1]), dtype=int)
1254
- counts = _np.zeros(adata.shape[0], dtype=int)
1255
-
1256
- # new: integer-length layer (0 where not inside a feature)
1257
- len_matrix = _np.zeros((adata.shape[0], adata.shape[1]), dtype=int)
1258
-
1259
- for row_idx, intervals in enumerate(adata.obs[feature]):
1260
- if not isinstance(intervals, list):
1261
- intervals = []
1262
- for start, length, prob in intervals:
1263
- if prob > threshold:
1264
- if coords_are_ints:
1265
- # map genomic start/length into index interval [start_idx, end_idx)
1266
- start_idx = _np.searchsorted(coordinates, int(start), side="left")
1267
- end_idx = _np.searchsorted(coordinates, int(start) + int(length) - 1, side="right")
1268
- else:
1269
- start_idx = int(start)
1270
- end_idx = start_idx + int(length)
1271
-
1272
- start_idx = max(0, min(start_idx, adata.shape[1]))
1273
- end_idx = max(0, min(end_idx, adata.shape[1]))
1274
-
1275
- if start_idx < end_idx:
1276
- span = end_idx - start_idx # number of positions covered
1277
- # set binary mask
1278
- bin_matrix[row_idx, start_idx:end_idx] = 1
1279
- # set length mask: use maximum in case of overlaps
1280
- existing = len_matrix[row_idx, start_idx:end_idx]
1281
- len_matrix[row_idx, start_idx:end_idx] = _np.maximum(existing, span)
1282
- counts[row_idx] += 1
1283
-
1284
- # write binary layer and length layer, track appended names
1285
- adata.layers[feature] = bin_matrix
1286
- appended_layers.append(feature)
1287
-
1288
- # name the integer-length layer (choose suffix you like)
1289
- length_layer_name = f"{feature}_lengths"
1290
- adata.layers[length_layer_name] = len_matrix
1291
- appended_layers.append(length_layer_name)
1292
-
1293
- adata.obs[f"n_{feature}"] = counts
1294
- adata.obs[f"{feature}_distances"] = calculate_batch_distances(adata.obs[feature].tolist(), threshold)
1295
-
1296
- # Merge appended_layers into adata.uns[uns_key] (preserve pre-existing and avoid duplicates)
1297
- existing = list(adata.uns.get(uns_key, [])) if adata.uns.get(uns_key) is not None else []
1298
- new_list = existing + [l for l in appended_layers if l not in existing]
1299
- adata.uns[uns_key] = new_list
1300
-
1301
- return None if in_place else adata
1302
-
1303
- def merge_intervals_in_layer(
1609
+ def _normalize_emission(self):
1610
+ """Normalize and clamp emission probabilities in-place."""
1611
+ with torch.no_grad():
1612
+ self.emission.data = self.emission.data.reshape(self.n_states, self.n_channels)
1613
+ self.emission.data = self.emission.data.clamp(min=self.eps, max=1.0 - self.eps)
1614
+
1615
+ def _ensure_device_dtype(self, device=None) -> torch.device:
1616
+ """Move emission parameters to the requested device/dtype."""
1617
+ device = super()._ensure_device_dtype(device)
1618
+ self.emission.data = self.emission.data.to(device=device, dtype=self.dtype)
1619
+ return device
1620
+
1621
+ def _state_modified_score(self) -> torch.Tensor:
1622
+ """Return per-state modified scores for ranking."""
1623
+ # more "modified" = higher mean P(1) across channels
1624
+ return self.emission.detach().mean(dim=1)
1625
+
1626
+ def _log_emission(self, obs: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
1627
+ """
1628
+ obs: (N,L,C), mask: (N,L,C) -> logB: (N,L,K)
1629
+ """
1630
+ N, L, C = obs.shape
1631
+ K = self.n_states
1632
+
1633
+ p = self.emission # (K,C)
1634
+ logp = torch.log(p + self.eps).view(1, 1, K, C)
1635
+ log1mp = torch.log1p(-p + self.eps).view(1, 1, K, C)
1636
+
1637
+ o = obs.unsqueeze(2) # (N,L,1,C)
1638
+ m = mask.unsqueeze(2) # (N,L,1,C)
1639
+
1640
+ logBC = o * logp + (1.0 - o) * log1mp
1641
+ logBC = torch.where(m, logBC, torch.zeros_like(logBC))
1642
+ return logBC.sum(dim=3) # sum channels -> (N,L,K)
1643
+
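As a sanity check on the masked Bernoulli emission term above, here is a tiny standalone sketch of the same per-channel computation for one position and one state (PyTorch only; the eps smoothing used by the class is omitted, and all numbers are toy values):

import torch

p = torch.tensor([0.9, 0.2])        # one state's emission P(1) over C=2 channels
obs = torch.tensor([1.0, 0.0])      # observed values at a single position
mask = torch.tensor([True, False])  # second channel unobserved

per_channel = obs * torch.log(p) + (1.0 - obs) * torch.log1p(-p)
per_channel = torch.where(mask, per_channel, torch.zeros_like(per_channel))
log_b = per_channel.sum()           # only the observed channel contributes: log(0.9)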
1644
+ def _extra_save_payload(self) -> dict:
1645
+ """Return extra payload data for serialization."""
1646
+ return {"n_channels": int(self.n_channels), "emission": self.emission.detach().cpu()}
1647
+
1648
+ def _load_extra_payload(self, payload: dict, *, device: torch.device):
1649
+ """Load serialized emission parameters.
1650
+
1651
+ Args:
1652
+ payload: Serialized payload dictionary.
1653
+ device: Target torch device.
1654
+ """
1655
+ self.n_channels = int(payload.get("n_channels", self.n_channels))
1656
+ with torch.no_grad():
1657
+ self.emission.data = payload["emission"].to(device=device, dtype=self.dtype)
1658
+ self._normalize_emission()
1659
+
1660
+ def fit_em(
1304
1661
  self,
1305
- adata,
1306
- layer: str,
1307
- distance_threshold: int = 0,
1308
- merged_suffix: str = "_merged",
1309
- length_layer_suffix: str = "_lengths",
1310
- update_obs: bool = True,
1311
- prob_strategy: str = "mean", # 'mean'|'max'|'orig_first'
1312
- inplace: bool = True,
1313
- overwrite: bool = False,
1662
+ X: np.ndarray,
1663
+ coords: np.ndarray,
1664
+ *,
1665
+ device: torch.device,
1666
+ max_iter: int,
1667
+ tol: float,
1668
+ update_start: bool,
1669
+ update_trans: bool,
1670
+ update_emission: bool,
1671
+ verbose: bool,
1672
+ **kwargs,
1673
+ ) -> List[float]:
1674
+ """Run EM updates for a multi-channel Bernoulli HMM.
1675
+
1676
+ Args:
1677
+ X: Observations array (N, L, C).
1678
+ coords: Coordinate array aligned to X.
1679
+ device: Torch device.
1680
+ max_iter: Maximum iterations.
1681
+ tol: Convergence tolerance.
1682
+ update_start: Whether to update start probabilities.
1683
+ update_trans: Whether to update transitions.
1684
+ update_emission: Whether to update emission parameters.
1685
+ verbose: Whether to log progress.
1686
+ **kwargs: Additional implementation-specific kwargs.
1687
+
1688
+ Returns:
1689
+ List of log-likelihood proxy values.
1690
+ """
1691
+ X = np.asarray(X, dtype=float)
1692
+ if X.ndim != 3:
1693
+ raise ValueError("MultiBernoulliHMM expects X shape (N,L,C).")
1694
+
1695
+ obs = torch.tensor(np.nan_to_num(X, nan=0.0), dtype=self.dtype, device=device)
1696
+ mask = torch.tensor(~np.isnan(X), dtype=torch.bool, device=device)
1697
+
1698
+ eps = float(self.eps)
1699
+ K = self.n_states
1700
+ N, L, C = obs.shape
1701
+
1702
+ self._ensure_n_channels(C, device)
1703
+
1704
+ hist: List[float] = []
1705
+ for it in range(1, int(max_iter) + 1):
1706
+ gamma = self._forward_backward(obs, mask) # (N,L,K)
1707
+
1708
+ ll_proxy = float(torch.sum(torch.log(torch.clamp(gamma.sum(dim=2), min=eps))).item())
1709
+ hist.append(ll_proxy)
1710
+
1711
+ # expected start
1712
+ start_acc = gamma[:, 0, :].sum(dim=0) # (K,)
1713
+
1714
+ # transitions xi
1715
+ logB = self._log_emission(obs, mask)
1716
+ logA = torch.log(self.trans + eps)
1717
+ alpha = torch.empty((N, L, K), dtype=self.dtype, device=device)
1718
+ alpha[:, 0, :] = torch.log(self.start + eps).unsqueeze(0) + logB[:, 0, :]
1719
+ for t in range(1, L):
1720
+ prev = alpha[:, t - 1, :].unsqueeze(2) + logA.unsqueeze(0)
1721
+ alpha[:, t, :] = _logsumexp(prev, dim=1) + logB[:, t, :]
1722
+ beta = torch.empty((N, L, K), dtype=self.dtype, device=device)
1723
+ beta[:, L - 1, :] = 0.0
1724
+ for t in range(L - 2, -1, -1):
1725
+ temp = logA.unsqueeze(0) + (logB[:, t + 1, :] + beta[:, t + 1, :]).unsqueeze(1)
1726
+ beta[:, t, :] = _logsumexp(temp, dim=2)
1727
+
1728
+ trans_acc = torch.zeros((K, K), dtype=self.dtype, device=device)
1729
+ # valid timestep if at least one channel observed at both positions
1730
+ valid_pos = mask.any(dim=2) # (N,L)
1731
+ for t in range(L - 1):
1732
+ valid_t = (valid_pos[:, t] & valid_pos[:, t + 1]).float().view(N, 1, 1)
1733
+ log_xi = (
1734
+ alpha[:, t, :].unsqueeze(2)
1735
+ + logA.unsqueeze(0)
1736
+ + (logB[:, t + 1, :] + beta[:, t + 1, :]).unsqueeze(1)
1737
+ )
1738
+ log_norm = _logsumexp(log_xi.view(N, -1), dim=1).view(N, 1, 1)
1739
+ xi = (log_xi - log_norm).exp() * valid_t
1740
+ trans_acc += xi.sum(dim=0)
1741
+
1742
+ # emission update per channel
1743
+ gamma_k = gamma.unsqueeze(-1) # (N,L,K,1)
1744
+ obs_c = obs.unsqueeze(2) # (N,L,1,C)
1745
+ mask_c = mask.unsqueeze(2).float() # (N,L,1,C)
1746
+
1747
+ emit_num = (gamma_k * obs_c * mask_c).sum(dim=(0, 1)) # (K,C)
1748
+ emit_den = (gamma_k * mask_c).sum(dim=(0, 1)) # (K,C)
1749
+
1750
+ with torch.no_grad():
1751
+ if update_start:
1752
+ new_start = start_acc + eps
1753
+ self.start.data = new_start / new_start.sum()
1754
+
1755
+ if update_trans:
1756
+ new_trans = trans_acc + eps
1757
+ rs = new_trans.sum(dim=1, keepdim=True)
1758
+ rs[rs == 0.0] = 1.0
1759
+ self.trans.data = new_trans / rs
1760
+
1761
+ if update_emission:
1762
+ new_em = (emit_num + eps) / (emit_den + 2.0 * eps)
1763
+ self.emission.data = new_em.clamp(min=eps, max=1.0 - eps)
1764
+
1765
+ self._normalize_params()
1766
+ self._normalize_emission()
1767
+
1768
+ if verbose:
1769
+ logger.info(
1770
+ "[MultiBernoulliHMM.fit] iter=%s ll_proxy=%.6f",
1771
+ it,
1772
+ hist[-1],
1773
+ )
1774
+
1775
+ if len(hist) > 1 and abs(hist[-1] - hist[-2]) < float(tol):
1776
+ break
1777
+
1778
+ return hist
1779
+
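To make the per-channel emission M-step in fit_em above concrete, a small NumPy sketch of the same accumulation on toy arrays (eps smoothing replaced by a simple clamp; shapes and values are illustrative only):

import numpy as np

N, L, K, C = 2, 3, 2, 2
rng = np.random.default_rng(1)
gamma = rng.dirichlet(np.ones(K), size=(N, L))       # posteriors over states, (N, L, K)
obs = (rng.random((N, L, C)) > 0.5).astype(float)    # binary observations, (N, L, C)
mask = rng.random((N, L, C)) > 0.2                   # True where the channel was observed

g = gamma[..., None]            # (N, L, K, 1)
o = obs[:, :, None, :]          # (N, L, 1, C)
m = mask[:, :, None, :].astype(float)

emit_num = (g * o * m).sum(axis=(0, 1))   # (K, C): expected count of observed 1s
emit_den = (g * m).sum(axis=(0, 1))       # (K, C): expected number of observed positions
new_emission = emit_num / np.maximum(emit_den, 1e-12)   # updated P(1) per state and channel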
1780
+ def adapt_emissions(
1781
+ self,
1782
+ X: np.ndarray,
1783
+ coords: Optional[np.ndarray] = None,
1784
+ *,
1785
+ device: Optional[Union[str, torch.device]] = None,
1786
+ iters: Optional[int] = None,
1787
+ max_iter: Optional[int] = None, # alias for iters, accepted for trainer compatibility
1314
1788
  verbose: bool = False,
1789
+ **kwargs,
1315
1790
  ):
1791
+ """Adapt emissions with legacy parameter names.
1792
+
1793
+ Args:
1794
+ X: Observations array.
1795
+ coords: Optional coordinate array.
1796
+ device: Device specifier.
1797
+ iters: Number of iterations.
1798
+ max_iter: Alias for iters.
1799
+ verbose: Whether to log progress.
1800
+ **kwargs: Additional kwargs forwarded to BaseHMM.adapt_emissions.
1801
+
1802
+ Returns:
1803
+ List of log-likelihood values.
1316
1804
  """
1317
- Merge intervals in `adata.layers[layer]` that are within `distance_threshold`.
1318
- Writes new merged binary layer named f"{layer}{merged_suffix}" and length layer
1319
- f"{layer}{merged_suffix}{length_layer_suffix}". Optionally updates adata.obs for merged intervals.
1320
-
1321
- Parameters
1322
- ----------
1323
- layer : str
1324
- Name of original binary layer (0/1 mask).
1325
- distance_threshold : int
1326
- Merge intervals whose gap <= this threshold (genomic coords if adata.var_names are ints).
1327
- merged_suffix : str
1328
- Suffix appended to original layer for the merged binary layer (default "_merged").
1329
- length_layer_suffix : str
1330
- Suffix appended after merged suffix for the lengths layer (default "_lengths").
1331
- update_obs : bool
1332
- If True, create/update adata.obs[f"{layer}{merged_suffix}"] with merged intervals.
1333
- prob_strategy : str
1334
- How to combine probs when merging ('mean', 'max', 'orig_first').
1335
- inplace : bool
1336
- If False, returns a new AnnData with changes (original untouched).
1337
- overwrite : bool
1338
- If True, overwrite existing merged layers / obs entries; otherwise raise an error if they exist.
1339
- """
1340
- import numpy as _np
1341
- from scipy.sparse import issparse
1805
+ if iters is None:
1806
+ iters = int(max_iter) if max_iter is not None else int(kwargs.pop("iters", 5))
1807
+ return super().adapt_emissions(
1808
+ np.asarray(X, dtype=float),
1809
+ coords if coords is not None else None,
1810
+ iters=int(iters),
1811
+ device=device,
1812
+ verbose=verbose,
1813
+ )
1342
1814
 
1343
- if not inplace:
1344
- adata = adata.copy()
1815
+ def _ensure_n_channels(self, C: int, device: torch.device):
1816
+ """Expand emission parameters when channel count changes.
1345
1817
 
1346
- merged_bin_name = f"{layer}{merged_suffix}"
1347
- merged_len_name = f"{layer}{merged_suffix}{length_layer_suffix}"
1818
+ Args:
1819
+ C: Target channel count.
1820
+ device: Torch device for the new parameters.
1821
+ """
1822
+ C = int(C)
1823
+ if C == self.n_channels:
1824
+ return
1825
+ with torch.no_grad():
1826
+ old = self.emission.detach().cpu().numpy() # (K, Cold)
1827
+ K = old.shape[0]
1828
+ new = np.full((K, C), 0.5, dtype=float)
1829
+ m = min(old.shape[1], C)
1830
+ new[:, :m] = old[:, :m]
1831
+ if C > old.shape[1]:
1832
+ fill = old.mean(axis=1, keepdims=True)
1833
+ new[:, m:] = fill
1834
+ self.n_channels = C
1835
+ self.emission = nn.Parameter(
1836
+ torch.tensor(new, dtype=self.dtype, device=device), requires_grad=False
1837
+ )
1838
+ self._normalize_emission()
1839
+
1840
+
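A small NumPy illustration of the channel-expansion rule in _ensure_n_channels above: existing channels are copied over and any newly added channels start at that state's mean P(1) (values below are arbitrary):

import numpy as np

old = np.array([[0.9, 0.7],
                [0.1, 0.3]])   # (K=2 states, C_old=2 channels)
C_new = 3

new = np.full((old.shape[0], C_new), 0.5)
m = min(old.shape[1], C_new)
new[:, :m] = old[:, :m]                          # copy existing channels
new[:, m:] = old.mean(axis=1, keepdims=True)     # new channel filled with [0.8, 0.2]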
1841
+ # =============================================================================
1842
+ # Distance-binned transitions (single-channel only)
1843
+ # =============================================================================
1844
+
1845
+
1846
+ @register_hmm("single_distance_binned")
1847
+ class DistanceBinnedSingleBernoulliHMM(SingleBernoulliHMM):
1848
+ """
1849
+ Transition matrix depends on binned distances between consecutive coords.
1348
1850
 
1349
- if (merged_bin_name in adata.layers or merged_len_name in adata.layers or
1350
- (update_obs and merged_bin_name in adata.obs.columns)) and not overwrite:
1351
- raise KeyError(f"Merged outputs exist (use overwrite=True to replace): {merged_bin_name} / {merged_len_name}")
1851
+ Config keys:
1852
+ hmm_distance_bins: list[int] edges (bp)
1853
+ hmm_init_transitions_by_bin: optional (n_bins,K,K)
1854
+ """
1352
1855
 
1353
- if layer not in adata.layers:
1354
- raise KeyError(f"Layer '{layer}' not found in adata.layers")
1856
+ def __init__(
1857
+ self,
1858
+ n_states: int = 2,
1859
+ init_emission: Optional[Sequence[float]] = None,
1860
+ distance_bins: Optional[Sequence[int]] = None,
1861
+ init_trans_by_bin: Optional[Any] = None,
1862
+ eps: float = 1e-8,
1863
+ dtype: torch.dtype = torch.float64,
1864
+ ):
1865
+ """Initialize a distance-binned transition HMM.
1866
+
1867
+ Args:
1868
+ n_states: Number of hidden states.
1869
+ init_emission: Initial emission probabilities per state.
1870
+ distance_bins: Distance bin edges in base pairs.
1871
+ init_trans_by_bin: Initial transition matrices per bin.
1872
+ eps: Smoothing epsilon for probabilities.
1873
+ dtype: Torch dtype for parameters.
1874
+ """
1875
+ super().__init__(n_states=n_states, init_emission=init_emission, eps=eps, dtype=dtype)
1355
1876
 
1356
- bin_layer = adata.layers[layer]
1357
- if issparse(bin_layer):
1358
- bin_arr = bin_layer.toarray().astype(int)
1877
+ self.distance_bins = np.asarray(
1878
+ distance_bins if distance_bins is not None else [1, 5, 10, 25, 50, 100], dtype=int
1879
+ )
1880
+ self.n_bins = int(len(self.distance_bins) + 1)
1881
+
1882
+ if init_trans_by_bin is None:
1883
+ base = self.trans.detach().cpu().numpy()
1884
+ tb = np.stack([base for _ in range(self.n_bins)], axis=0)
1359
1885
  else:
1360
- bin_arr = _np.asarray(bin_layer, dtype=int)
1886
+ tb = np.asarray(init_trans_by_bin, dtype=float)
1887
+ if tb.shape != (self.n_bins, self.n_states, self.n_states):
1888
+ base = self.trans.detach().cpu().numpy()
1889
+ tb = np.stack([base for _ in range(self.n_bins)], axis=0)
1361
1890
 
1362
- n_rows, n_cols = bin_arr.shape
1891
+ self.trans_by_bin = nn.Parameter(torch.tensor(tb, dtype=self.dtype), requires_grad=False)
1892
+ self._normalize_trans_by_bin()
1363
1893
 
1364
- # coordinates in genomic units if possible
1365
- try:
1366
- coords = _np.asarray(adata.var_names, dtype=int)
1367
- coords_are_ints = True
1368
- except Exception:
1369
- coords = _np.arange(n_cols, dtype=int)
1370
- coords_are_ints = False
1371
-
1372
- # helper: contiguous runs of 1s -> list of (start_idx, end_idx) (end exclusive)
1373
- def _runs_from_mask(mask_1d):
1374
- idx = _np.nonzero(mask_1d)[0]
1375
- if idx.size == 0:
1376
- return []
1377
- runs = []
1378
- start = idx[0]
1379
- prev = idx[0]
1380
- for i in idx[1:]:
1381
- if i == prev + 1:
1382
- prev = i
1383
- continue
1384
- runs.append((start, prev + 1))
1385
- start = i
1386
- prev = i
1387
- runs.append((start, prev + 1))
1388
- return runs
1389
-
1390
- # read original obs intervals/probs if available (for combining probs)
1391
- orig_obs = None
1392
- if update_obs and (layer in adata.obs.columns):
1393
- orig_obs = list(adata.obs[layer]) # might be non-list entries
1394
-
1395
- # prepare outputs
1396
- merged_bin = _np.zeros_like(bin_arr, dtype=int)
1397
- merged_len = _np.zeros_like(bin_arr, dtype=int)
1398
- merged_obs_col = [[] for _ in range(n_rows)]
1399
- merged_counts = _np.zeros(n_rows, dtype=int)
1400
-
1401
- for r in range(n_rows):
1402
- mask = bin_arr[r, :] != 0
1403
- runs = _runs_from_mask(mask)
1404
- if not runs:
1405
- merged_obs_col[r] = []
1406
- continue
1894
+ @classmethod
1895
+ def from_config(cls, cfg, *, override=None, device=None):
1896
+ """Create a distance-binned HMM from config.
1407
1897
 
1408
- # merge runs where gap <= distance_threshold (gap in genomic coords when possible)
1409
- merged_runs = []
1410
- cur_s, cur_e = runs[0]
1411
- for (s, e) in runs[1:]:
1412
- if coords_are_ints:
1413
- end_coord = int(coords[cur_e - 1])
1414
- next_start_coord = int(coords[s])
1415
- gap = next_start_coord - end_coord - 1
1416
- else:
1417
- gap = s - cur_e
1418
- if gap <= distance_threshold:
1419
- # extend
1420
- cur_e = e
1421
- else:
1422
- merged_runs.append((cur_s, cur_e))
1423
- cur_s, cur_e = s, e
1424
- merged_runs.append((cur_s, cur_e))
1425
-
1426
- # assemble merged mask/lengths and obs entries
1427
- row_entries = []
1428
- for (s_idx, e_idx) in merged_runs:
1429
- if e_idx <= s_idx:
1430
- continue
1431
- span_positions = e_idx - s_idx
1432
- if coords_are_ints:
1433
- try:
1434
- length_val = int(coords[e_idx - 1]) - int(coords[s_idx]) + 1
1435
- except Exception:
1436
- length_val = span_positions
1437
- else:
1438
- length_val = span_positions
1439
-
1440
- # set binary and length masks
1441
- merged_bin[r, s_idx:e_idx] = 1
1442
- existing_segment = merged_len[r, s_idx:e_idx]
1443
- # set to max(existing, length_val)
1444
- if existing_segment.size > 0:
1445
- merged_len[r, s_idx:e_idx] = _np.maximum(existing_segment, length_val)
1446
- else:
1447
- merged_len[r, s_idx:e_idx] = length_val
1448
-
1449
- # determine prob from overlapping original obs (if present)
1450
- prob_val = 1.0
1451
- if update_obs and orig_obs is not None:
1452
- overlaps = []
1453
- for orig in (orig_obs[r] or []):
1454
- try:
1455
- ostart, olen, opro = orig[0], int(orig[1]), float(orig[2])
1456
- except Exception:
1457
- continue
1458
- if coords_are_ints:
1459
- ostart_idx = _np.searchsorted(coords, int(ostart), side="left")
1460
- oend_idx = ostart_idx + olen
1461
- else:
1462
- ostart_idx = int(ostart)
1463
- oend_idx = ostart_idx + olen
1464
- # overlap test in index space
1465
- if not (oend_idx <= s_idx or ostart_idx >= e_idx):
1466
- overlaps.append(opro)
1467
- if overlaps:
1468
- if prob_strategy == "mean":
1469
- prob_val = float(_np.mean(overlaps))
1470
- elif prob_strategy == "max":
1471
- prob_val = float(_np.max(overlaps))
1472
- else:
1473
- prob_val = float(overlaps[0])
1474
-
1475
- start_coord = int(coords[s_idx]) if coords_are_ints else int(s_idx)
1476
- row_entries.append((start_coord, int(length_val), float(prob_val)))
1477
-
1478
- merged_obs_col[r] = row_entries
1479
- merged_counts[r] = len(row_entries)
1480
-
1481
- # write merged layers (do not overwrite originals unless overwrite=True was set above)
1482
- adata.layers[merged_bin_name] = merged_bin
1483
- adata.layers[merged_len_name] = merged_len
1484
-
1485
- if update_obs:
1486
- adata.obs[merged_bin_name] = merged_obs_col
1487
- adata.obs[f"n_{merged_bin_name}"] = merged_counts
1488
-
1489
- # recompute distances list per-row (gaps between adjacent merged intervals)
1490
- def _calc_distances(obs_list):
1491
- out = []
1492
- for intervals in obs_list:
1493
- if not intervals:
1494
- out.append([])
1495
- continue
1496
- iv = sorted(intervals, key=lambda x: int(x[0]))
1497
- if len(iv) <= 1:
1498
- out.append([])
1499
- continue
1500
- dlist = []
1501
- for i in range(len(iv) - 1):
1502
- endi = int(iv[i][0]) + int(iv[i][1]) - 1
1503
- startn = int(iv[i + 1][0])
1504
- dlist.append(startn - endi - 1)
1505
- out.append(dlist)
1506
- return out
1507
-
1508
- adata.obs[f"{merged_bin_name}_distances"] = _calc_distances(merged_obs_col)
1509
-
1510
- # update uns appended list
1511
- uns_key = "hmm_appended_layers"
1512
- existing = list(adata.uns.get(uns_key, [])) if adata.uns.get(uns_key, None) is not None else []
1513
- for nm in (merged_bin_name, merged_len_name):
1514
- if nm not in existing:
1515
- existing.append(nm)
1516
- adata.uns[uns_key] = existing
1517
-
1518
- if verbose:
1519
- print(f"Created merged binary layer: {merged_bin_name}")
1520
- print(f"Created merged length layer: {merged_len_name}")
1521
- if update_obs:
1522
- print(f"Updated adata.obs columns: {merged_bin_name}, n_{merged_bin_name}, {merged_bin_name}_distances")
1523
-
1524
- return None if inplace else adata
1525
-
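For reference while reading the removed merge_intervals_in_layer above, a standalone sketch of its gap-based merging rule in index space (the genomic-coordinate branch differs only in how the gap is computed; data below are toy values):

import numpy as np

mask = np.array([1, 1, 0, 0, 1, 1, 0, 0, 0, 1])
distance_threshold = 2

# contiguous runs of 1s as half-open (start, end) index pairs
idx = np.flatnonzero(mask)
runs, start, prev = [], idx[0], idx[0]
for i in idx[1:]:
    if i == prev + 1:
        prev = i
        continue
    runs.append((start, prev + 1))
    start = prev = i
runs.append((start, prev + 1))        # [(0, 2), (4, 6), (9, 10)]

# fuse runs whose gap is <= distance_threshold
merged = [list(runs[0])]
for s, e in runs[1:]:
    if s - merged[-1][1] <= distance_threshold:
        merged[-1][1] = e
    else:
        merged.append([s, e])
# merged == [[0, 6], [9, 10]]  (gap of 2 fused, gap of 3 kept separate)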
1526
- def _ensure_final_layer_and_assign(self, final_adata, layer_name: str, subset_idx_mask: np.ndarray, sub_data):
1898
+ Args:
1899
+ cfg: Configuration mapping or object.
1900
+ override: Override values to apply.
1901
+ device: Optional device specifier.
1902
+
1903
+ Returns:
1904
+ Initialized DistanceBinnedSingleBernoulliHMM instance.
1905
+ """
1906
+ merged = cls._cfg_to_dict(cfg)
1907
+ if override:
1908
+ merged.update(override)
1909
+
1910
+ n_states = int(merged.get("hmm_n_states", 2))
1911
+ eps = float(merged.get("hmm_eps", 1e-8))
1912
+ dtype = _resolve_dtype(merged.get("hmm_dtype", None))
1913
+ dtype = _coerce_dtype_for_device(dtype, device) # coerce dtype for the target device
1914
+ init_em = merged.get("hmm_init_emission_probs", None)
1915
+
1916
+ bins = merged.get("hmm_distance_bins", [1, 5, 10, 25, 50, 100])
1917
+ init_tb = merged.get("hmm_init_transitions_by_bin", None)
1918
+
1919
+ model = cls(
1920
+ n_states=n_states,
1921
+ init_emission=init_em,
1922
+ distance_bins=bins,
1923
+ init_trans_by_bin=init_tb,
1924
+ eps=eps,
1925
+ dtype=dtype,
1926
+ )
1927
+ if device is not None:
1928
+ model.to(torch.device(device) if isinstance(device, str) else device)
1929
+ model._persisted_cfg = merged
1930
+ return model
1931
+
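A hedged sketch of the configuration keys that from_config above reads. The dict-style cfg, the string dtype value, and the import path are assumptions for illustration (in the package, configs normally come from the YAML files, and _cfg_to_dict/_resolve_dtype handle the conversion):

from smftools.hmm.HMM import DistanceBinnedSingleBernoulliHMM  # import path assumed from this diff

cfg = {
    "hmm_n_states": 2,
    "hmm_eps": 1e-8,
    "hmm_dtype": "float64",                         # assumed string form; resolved by _resolve_dtype
    "hmm_init_emission_probs": [0.9, 0.1],          # per-state P(1), illustrative
    "hmm_distance_bins": [1, 5, 10, 25, 50, 100],   # bp edges -> n_bins = 7
    # "hmm_init_transitions_by_bin": optional (n_bins, K, K) array
}

model = DistanceBinnedSingleBernoulliHMM.from_config(cfg, device="cpu")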
1932
+ def _ensure_device_dtype(self, device=None) -> torch.device:
1933
+ """Move transition-by-bin parameters to the requested device/dtype."""
1934
+ device = super()._ensure_device_dtype(device)
1935
+ self.trans_by_bin.data = self.trans_by_bin.data.to(device=device, dtype=self.dtype)
1936
+ return device
1937
+
1938
+ def _normalize_trans_by_bin(self):
1939
+ """Normalize transition matrices per distance bin in-place."""
1940
+ with torch.no_grad():
1941
+ tb = self.trans_by_bin.data.reshape(self.n_bins, self.n_states, self.n_states)
1942
+ tb = tb + self.eps
1943
+ rs = tb.sum(dim=2, keepdim=True)
1944
+ rs[rs == 0.0] = 1.0
1945
+ self.trans_by_bin.data = tb / rs
1946
+
1947
+ def _extra_save_payload(self) -> dict:
1948
+ """Return extra payload data for serialization."""
1949
+ p = super()._extra_save_payload()
1950
+ p.update(
1951
+ {
1952
+ "distance_bins": torch.tensor(self.distance_bins, dtype=torch.long),
1953
+ "trans_by_bin": self.trans_by_bin.detach().cpu(),
1954
+ }
1955
+ )
1956
+ return p
1957
+
1958
+ def _load_extra_payload(self, payload: dict, *, device: torch.device):
1959
+ """Load serialized distance-bin parameters.
1960
+
1961
+ Args:
1962
+ payload: Serialized payload dictionary.
1963
+ device: Target torch device.
1527
1964
  """
1528
- Ensure final_adata.layers[layer_name] exists and assign the rows selected by subset_idx_mask.
1529
- sub_data has shape (n_subset_rows, n_vars).
1530
- subset_idx_mask: boolean array of length final_adata.n_obs with True where rows belong to subset.
1965
+ super()._load_extra_payload(payload, device=device)
1966
+ self.distance_bins = (
1967
+ payload.get("distance_bins", torch.tensor([1, 5, 10, 25, 50, 100]))
1968
+ .cpu()
1969
+ .numpy()
1970
+ .astype(int)
1971
+ )
1972
+ self.n_bins = int(len(self.distance_bins) + 1)
1973
+ with torch.no_grad():
1974
+ self.trans_by_bin.data = payload["trans_by_bin"].to(device=device, dtype=self.dtype)
1975
+ self._normalize_trans_by_bin()
1976
+
1977
+ def _bin_index(self, coords: np.ndarray) -> np.ndarray:
1978
+ """Return per-step distance bin indices for coordinates.
1979
+
1980
+ Args:
1981
+ coords: Coordinate array.
1982
+
1983
+ Returns:
1984
+ Array of bin indices (length L-1).
1531
1985
  """
1532
- from scipy.sparse import issparse, csr_matrix
1533
- import warnings
1986
+ d = np.diff(np.asarray(coords, dtype=int))
1987
+ return np.digitize(d, self.distance_bins, right=True) # length L-1
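A quick worked example of the binning rule above, using the default edges [1, 5, 10, 25, 50, 100] (right-inclusive):

import numpy as np

coords = np.array([10, 11, 13, 20, 120])
bins = np.array([1, 5, 10, 25, 50, 100])

d = np.diff(coords)                      # [1, 2, 7, 100]
idx = np.digitize(d, bins, right=True)   # [0, 1, 2, 5]
# d == 1   -> bin 0 (d <= 1); d == 2 -> bin 1 (1 < d <= 5)
# d == 100 -> bin 5 (50 < d <= 100); any d > 100 would fall into bin 6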
1534
1988
 
1535
- n_final_obs, n_vars = final_adata.shape
1536
- n_sub_rows = int(subset_idx_mask.sum())
1537
-
1538
- # prepare row indices in final_adata
1539
- final_row_indices = np.nonzero(subset_idx_mask)[0]
1540
-
1541
- # if sub_data is sparse, work with sparse
1542
- if issparse(sub_data):
1543
- sub_csr = sub_data.tocsr()
1544
- # if final layer not present, create sparse CSR with zero rows and same n_vars
1545
- if layer_name not in final_adata.layers:
1546
- # create an empty CSR of shape (n_final_obs, n_vars)
1547
- final_adata.layers[layer_name] = csr_matrix((n_final_obs, n_vars), dtype=sub_csr.dtype)
1548
- final_csr = final_adata.layers[layer_name]
1549
- if not issparse(final_csr):
1550
- # convert dense final to sparse first
1551
- final_csr = csr_matrix(final_csr)
1552
- # Replace the subset rows inside the final sparse layer.
1553
- # (Fine for moderate sizes; very large matrices may need an in-place strategy.)
1554
- # Convert the final matrix to LIL so rows are mutable, assign the subset rows,
1555
- # then convert back to CSR.
1556
- final_lil = final_csr.tolil()
1557
- for i_local, r in enumerate(final_row_indices):
1558
- final_lil.rows[r] = sub_csr.getrow(i_local).indices.tolist()
1559
- final_lil.data[r] = sub_csr.getrow(i_local).data.tolist()
1560
- final_csr = final_lil.tocsr()
1561
- final_adata.layers[layer_name] = final_csr
1562
- else:
1563
- # dense numpy array
1564
- sub_arr = np.asarray(sub_data)
1565
- if sub_arr.shape[0] != n_sub_rows:
1566
- raise ValueError(f"Sub data rows ({sub_arr.shape[0]}) != mask selected rows ({n_sub_rows})")
1567
- if layer_name not in final_adata.layers:
1568
- # create zero array with small dtype
1569
- final_adata.layers[layer_name] = np.zeros((n_final_obs, n_vars), dtype=sub_arr.dtype)
1570
- final_arr = final_adata.layers[layer_name]
1571
- if issparse(final_arr):
1572
- # convert sparse final to dense (or convert sub to sparse); we'll convert final to dense here
1573
- final_arr = final_arr.toarray()
1574
- # assign
1575
- final_arr[final_row_indices, :] = sub_arr
1576
- final_adata.layers[layer_name] = final_arr
1989
+ def _forward_backward(
1990
+ self, obs: torch.Tensor, mask: torch.Tensor, *, coords: Optional[np.ndarray] = None
1991
+ ) -> torch.Tensor:
1992
+ """Run forward-backward using distance-binned transitions.
1993
+
1994
+ Args:
1995
+ obs: Observation tensor.
1996
+ mask: Observation mask.
1997
+ coords: Coordinate array.
1998
+
1999
+ Returns:
2000
+ Posterior probabilities (gamma).
2001
+ """
2002
+ if coords is None:
2003
+ raise ValueError("Distance-binned HMM requires coords.")
2004
+ device = obs.device
2005
+ eps = float(self.eps)
2006
+ K = self.n_states
2007
+
2008
+ coords = np.asarray(coords, dtype=int)
2009
+ bins = torch.tensor(self._bin_index(coords), dtype=torch.long, device=device) # (L-1,)
2010
+
2011
+ logB = self._log_emission(obs, mask) # (N,L,K)
2012
+ logstart = torch.log(self.start + eps)
2013
+ logA_by_bin = torch.log(self.trans_by_bin + eps) # (nb,K,K)
2014
+
2015
+ N, L, _ = logB.shape
2016
+ alpha = torch.empty((N, L, K), dtype=self.dtype, device=device)
2017
+ alpha[:, 0, :] = logstart.unsqueeze(0) + logB[:, 0, :]
2018
+
2019
+ for t in range(1, L):
2020
+ b = int(bins[t - 1].item()) if (t - 1) < bins.numel() else 0
2021
+ logA = logA_by_bin[b]
2022
+ prev = alpha[:, t - 1, :].unsqueeze(2) + logA.unsqueeze(0)
2023
+ alpha[:, t, :] = _logsumexp(prev, dim=1) + logB[:, t, :]
2024
+
2025
+ beta = torch.empty((N, L, K), dtype=self.dtype, device=device)
2026
+ beta[:, L - 1, :] = 0.0
2027
+ for t in range(L - 2, -1, -1):
2028
+ b = int(bins[t].item()) if t < bins.numel() else 0
2029
+ logA = logA_by_bin[b]
2030
+ temp = logA.unsqueeze(0) + (logB[:, t + 1, :] + beta[:, t + 1, :]).unsqueeze(1)
2031
+ beta[:, t, :] = _logsumexp(temp, dim=2)
2032
+
2033
+ log_gamma = alpha + beta
2034
+ logZ = _logsumexp(log_gamma, dim=2).unsqueeze(2)
2035
+ return (log_gamma - logZ).exp()
2036
+
2037
+ def _viterbi(
2038
+ self, obs: torch.Tensor, mask: torch.Tensor, *, coords: Optional[np.ndarray] = None
2039
+ ) -> torch.Tensor:
2040
+ """Run Viterbi decoding using distance-binned transitions.
2041
+
2042
+ Args:
2043
+ obs: Observation tensor.
2044
+ mask: Observation mask.
2045
+ coords: Coordinate array.
2046
+
2047
+ Returns:
2048
+ Decoded state sequence tensor.
2049
+ """
2050
+ if coords is None:
2051
+ raise ValueError("Distance-binned HMM requires coords.")
2052
+ device = obs.device
2053
+ eps = float(self.eps)
2054
+ K = self.n_states
2055
+
2056
+ coords = np.asarray(coords, dtype=int)
2057
+ bins = torch.tensor(self._bin_index(coords), dtype=torch.long, device=device) # (L-1,)
2058
+
2059
+ logB = self._log_emission(obs, mask)
2060
+ logstart = torch.log(self.start + eps)
2061
+ logA_by_bin = torch.log(self.trans_by_bin + eps)
2062
+
2063
+ N, L, _ = logB.shape
2064
+ delta = torch.empty((N, L, K), dtype=self.dtype, device=device)
2065
+ psi = torch.empty((N, L, K), dtype=torch.long, device=device)
2066
+
2067
+ delta[:, 0, :] = logstart.unsqueeze(0) + logB[:, 0, :]
2068
+ psi[:, 0, :] = -1
2069
+
2070
+ for t in range(1, L):
2071
+ b = int(bins[t - 1].item()) if (t - 1) < bins.numel() else 0
2072
+ logA = logA_by_bin[b]
2073
+ cand = delta[:, t - 1, :].unsqueeze(2) + logA.unsqueeze(0)
2074
+ best_val, best_idx = cand.max(dim=1)
2075
+ delta[:, t, :] = best_val + logB[:, t, :]
2076
+ psi[:, t, :] = best_idx
2077
+
2078
+ last_state = torch.argmax(delta[:, L - 1, :], dim=1)
2079
+ states = torch.empty((N, L), dtype=torch.long, device=device)
2080
+ states[:, L - 1] = last_state
2081
+ for t in range(L - 2, -1, -1):
2082
+ states[:, t] = psi[torch.arange(N, device=device), t + 1, states[:, t + 1]]
2083
+ return states
2084
+
2085
+ def fit_em(
2086
+ self,
2087
+ X: np.ndarray,
2088
+ coords: np.ndarray,
2089
+ *,
2090
+ device: torch.device,
2091
+ max_iter: int,
2092
+ tol: float,
2093
+ update_start: bool,
2094
+ update_trans: bool,
2095
+ update_emission: bool,
2096
+ verbose: bool,
2097
+ **kwargs,
2098
+ ) -> List[float]:
2099
+ """Run EM updates for distance-binned transitions.
2100
+
2101
+ Args:
2102
+ X: Observations array (N, L).
2103
+ coords: Coordinate array aligned to X.
2104
+ device: Torch device.
2105
+ max_iter: Maximum iterations.
2106
+ tol: Convergence tolerance.
2107
+ update_start: Whether to update start probabilities.
2108
+ update_trans: Whether to update transitions.
2109
+ update_emission: Whether to update emission parameters.
2110
+ verbose: Whether to log progress.
2111
+ **kwargs: Additional implementation-specific kwargs.
2112
+
2113
+ Returns:
2114
+ List of log-likelihood proxy values.
2115
+ """
2116
+ # Keep this simple: gamma drives the emission update; the per-bin transitions are updated via xi (same pattern as the base EM).
2117
+ X = np.asarray(X, dtype=float)
2118
+ if X.ndim != 2:
2119
+ raise ValueError("DistanceBinnedSingleBernoulliHMM expects X shape (N,L).")
2120
+
2121
+ coords = np.asarray(coords, dtype=int)
2122
+ bins_np = self._bin_index(coords) # (L-1,)
2123
+
2124
+ obs = torch.tensor(np.nan_to_num(X, nan=0.0), dtype=self.dtype, device=device)
2125
+ mask = torch.tensor(~np.isnan(X), dtype=torch.bool, device=device)
2126
+
2127
+ eps = float(self.eps)
2128
+ K = self.n_states
2129
+ N, L = obs.shape
2130
+
2131
+ hist: List[float] = []
2132
+ for it in range(1, int(max_iter) + 1):
2133
+ gamma = self._forward_backward(obs, mask, coords=coords) # (N,L,K)
2134
+ ll_proxy = float(torch.sum(torch.log(torch.clamp(gamma.sum(dim=2), min=eps))).item())
2135
+ hist.append(ll_proxy)
2136
+
2137
+ # expected start
2138
+ start_acc = gamma[:, 0, :].sum(dim=0)
2139
+
2140
+ # compute alpha/beta for xi
2141
+ logB = self._log_emission(obs, mask)
2142
+ logstart = torch.log(self.start + eps)
2143
+ logA_by_bin = torch.log(self.trans_by_bin + eps)
2144
+
2145
+ alpha = torch.empty((N, L, K), dtype=self.dtype, device=device)
2146
+ alpha[:, 0, :] = logstart.unsqueeze(0) + logB[:, 0, :]
2147
+
2148
+ for t in range(1, L):
2149
+ b = int(bins_np[t - 1])
2150
+ logA = logA_by_bin[b]
2151
+ prev = alpha[:, t - 1, :].unsqueeze(2) + logA.unsqueeze(0)
2152
+ alpha[:, t, :] = _logsumexp(prev, dim=1) + logB[:, t, :]
2153
+
2154
+ beta = torch.empty((N, L, K), dtype=self.dtype, device=device)
2155
+ beta[:, L - 1, :] = 0.0
2156
+ for t in range(L - 2, -1, -1):
2157
+ b = int(bins_np[t])
2158
+ logA = logA_by_bin[b]
2159
+ temp = logA.unsqueeze(0) + (logB[:, t + 1, :] + beta[:, t + 1, :]).unsqueeze(1)
2160
+ beta[:, t, :] = _logsumexp(temp, dim=2)
2161
+
2162
+ trans_acc_by_bin = torch.zeros((self.n_bins, K, K), dtype=self.dtype, device=device)
2163
+ for t in range(L - 1):
2164
+ b = int(bins_np[t])
2165
+ logA = logA_by_bin[b]
2166
+ valid_t = (mask[:, t] & mask[:, t + 1]).float().view(N, 1, 1)
2167
+ log_xi = (
2168
+ alpha[:, t, :].unsqueeze(2)
2169
+ + logA.unsqueeze(0)
2170
+ + (logB[:, t + 1, :] + beta[:, t + 1, :]).unsqueeze(1)
2171
+ )
2172
+ log_norm = _logsumexp(log_xi.view(N, -1), dim=1).view(N, 1, 1)
2173
+ xi = (log_xi - log_norm).exp() * valid_t
2174
+ trans_acc_by_bin[b] += xi.sum(dim=0)
2175
+
2176
+ mask_f = mask.float().unsqueeze(-1)
2177
+ emit_num = (gamma * obs.unsqueeze(-1) * mask_f).sum(dim=(0, 1))
2178
+ emit_den = (gamma * mask_f).sum(dim=(0, 1))
2179
+
2180
+ with torch.no_grad():
2181
+ if update_start:
2182
+ new_start = start_acc + eps
2183
+ self.start.data = new_start / new_start.sum()
2184
+
2185
+ if update_trans:
2186
+ tb = trans_acc_by_bin + eps
2187
+ rs = tb.sum(dim=2, keepdim=True)
2188
+ rs[rs == 0.0] = 1.0
2189
+ self.trans_by_bin.data = tb / rs
2190
+
2191
+ if update_emission:
2192
+ new_em = (emit_num + eps) / (emit_den + 2.0 * eps)
2193
+ self.emission.data = new_em.clamp(min=eps, max=1.0 - eps)
2194
+
2195
+ self._normalize_params()
2196
+ self._normalize_emission()
2197
+ self._normalize_trans_by_bin()
2198
+
2199
+ if verbose:
2200
+ logger.info(
2201
+ "[DistanceBinnedSingle.fit] iter=%s ll_proxy=%.6f",
2202
+ it,
2203
+ hist[-1],
2204
+ )
2205
+
2206
+ if len(hist) > 1 and abs(hist[-1] - hist[-2]) < float(tol):
2207
+ break
2208
+
2209
+ return hist
2210
+
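A hedged usage sketch of the distance-binned EM above; the direct constructor call is for illustration only (the workflow normally builds models via from_config/create_hmm), and all sizes and probabilities are toy values. Note that coords is required here, unlike the plain single-channel model:

import numpy as np
import torch

model = DistanceBinnedSingleBernoulliHMM(n_states=2, init_emission=[0.9, 0.1])

N, L = 16, 40
rng = np.random.default_rng(0)
X = (rng.random((N, L)) > 0.5).astype(float)
X[rng.random((N, L)) > 0.9] = np.nan                          # NaN marks unobserved positions
coords = np.sort(rng.choice(10_000, size=L, replace=False))   # genomic positions for the L columns

history = model.fit_em(
    X, coords,
    device=torch.device("cpu"),
    max_iter=10, tol=1e-4,
    update_start=True, update_trans=True, update_emission=True,
    verbose=False,
)   # list of per-iteration log-likelihood proxy values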
2211
+ def adapt_emissions(
2212
+ self,
2213
+ X: np.ndarray,
2214
+ coords: Optional[np.ndarray] = None,
2215
+ *,
2216
+ device: Optional[Union[str, torch.device]] = None,
2217
+ iters: Optional[int] = None,
2218
+ max_iter: Optional[int] = None,
2219
+ verbose: bool = False,
2220
+ **kwargs,
2221
+ ):
2222
+ """Adapt emissions with legacy parameter names.
2223
+
2224
+ Args:
2225
+ X: Observations array.
2226
+ coords: Optional coordinate array.
2227
+ device: Device specifier.
2228
+ iters: Number of iterations.
2229
+ max_iter: Alias for iters.
2230
+ verbose: Whether to log progress.
2231
+ **kwargs: Additional kwargs forwarded to BaseHMM.adapt_emissions.
2232
+
2233
+ Returns:
2234
+ List of log-likelihood values.
2235
+ """
2236
+ if iters is None:
2237
+ iters = int(max_iter) if max_iter is not None else int(kwargs.pop("iters", 5))
2238
+ return super().adapt_emissions(
2239
+ np.asarray(X, dtype=float),
2240
+ coords if coords is not None else None,
2241
+ iters=int(iters),
2242
+ device=device,
2243
+ verbose=verbose,
2244
+ )
2245
+
2246
+
2247
+ # =============================================================================
2248
+ # Facade class to match workflow import style
2249
+ # =============================================================================
2250
+
2251
+
2252
+ class HMM:
2253
+ """
2254
+ Facade so workflow code can do:
2255
+ from ..hmm.HMM import HMM
2256
+ hmm = HMM.from_config(cfg, arch="single")
2257
+ hmm.save(...)
2258
+ hmm = HMM.load(...)
2259
+ """
2260
+
2261
+ @staticmethod
2262
+ def from_config(cfg, arch: Optional[str] = None, **kwargs) -> BaseHMM:
2263
+ """Create an HMM instance from configuration.
2264
+
2265
+ Args:
2266
+ cfg: Configuration mapping or object.
2267
+ arch: Optional HMM architecture name.
2268
+ **kwargs: Additional parameters passed to the factory.
2269
+
2270
+ Returns:
2271
+ Initialized HMM instance.
2272
+ """
2273
+ return create_hmm(cfg, arch=arch, **kwargs)
2274
+
2275
+ @staticmethod
2276
+ def load(path: Union[str, Path], device: Optional[Union[str, torch.device]] = None) -> BaseHMM:
2277
+ """Load an HMM instance from disk.
2278
+
2279
+ Args:
2280
+ path: Path to the serialized model.
2281
+ device: Optional device specifier.
2282
+
2283
+ Returns:
2284
+ Loaded HMM instance.
2285
+ """
2286
+ return BaseHMM.load(path, device=device)
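Finally, a hedged sketch of the facade round trip described in the class docstring above. The arch string comes from the register_hmm decorators (e.g. "single_distance_binned" above); the path-only save signature and the cfg variable are assumptions for illustration:

from smftools.hmm.HMM import HMM   # import path assumed from this diff

hmm = HMM.from_config(cfg, arch="single_distance_binned")   # cfg as in the from_config sketch above
hmm.save("hmm_model.pt")                                    # save signature assumed (docstring shows hmm.save(...))
restored = HMM.load("hmm_model.pt", device="cpu")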