smftools 0.1.7__py3-none-any.whl → 0.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (174)
  1. smftools/__init__.py +7 -6
  2. smftools/_version.py +1 -1
  3. smftools/cli/cli_flows.py +94 -0
  4. smftools/cli/hmm_adata.py +338 -0
  5. smftools/cli/load_adata.py +577 -0
  6. smftools/cli/preprocess_adata.py +363 -0
  7. smftools/cli/spatial_adata.py +564 -0
  8. smftools/cli_entry.py +435 -0
  9. smftools/config/__init__.py +1 -0
  10. smftools/config/conversion.yaml +38 -0
  11. smftools/config/deaminase.yaml +61 -0
  12. smftools/config/default.yaml +264 -0
  13. smftools/config/direct.yaml +41 -0
  14. smftools/config/discover_input_files.py +115 -0
  15. smftools/config/experiment_config.py +1288 -0
  16. smftools/hmm/HMM.py +1576 -0
  17. smftools/hmm/__init__.py +20 -0
  18. smftools/{tools → hmm}/apply_hmm_batched.py +8 -7
  19. smftools/hmm/call_hmm_peaks.py +106 -0
  20. smftools/{tools → hmm}/display_hmm.py +3 -3
  21. smftools/{tools → hmm}/nucleosome_hmm_refinement.py +2 -2
  22. smftools/{tools → hmm}/train_hmm.py +1 -1
  23. smftools/informatics/__init__.py +13 -9
  24. smftools/informatics/archived/deaminase_smf.py +132 -0
  25. smftools/informatics/archived/fast5_to_pod5.py +43 -0
  26. smftools/informatics/archived/helpers/archived/__init__.py +71 -0
  27. smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +126 -0
  28. smftools/informatics/archived/helpers/archived/aligned_BAM_to_bed.py +87 -0
  29. smftools/informatics/archived/helpers/archived/bam_qc.py +213 -0
  30. smftools/informatics/archived/helpers/archived/bed_to_bigwig.py +90 -0
  31. smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +259 -0
  32. smftools/informatics/{helpers → archived/helpers/archived}/count_aligned_reads.py +2 -2
  33. smftools/informatics/{helpers → archived/helpers/archived}/demux_and_index_BAM.py +8 -10
  34. smftools/informatics/{helpers → archived/helpers/archived}/extract_base_identities.py +30 -4
  35. smftools/informatics/{helpers → archived/helpers/archived}/extract_mods.py +15 -13
  36. smftools/informatics/{helpers → archived/helpers/archived}/extract_read_features_from_bam.py +4 -2
  37. smftools/informatics/{helpers → archived/helpers/archived}/find_conversion_sites.py +5 -4
  38. smftools/informatics/{helpers → archived/helpers/archived}/generate_converted_FASTA.py +2 -0
  39. smftools/informatics/{helpers → archived/helpers/archived}/get_chromosome_lengths.py +9 -8
  40. smftools/informatics/archived/helpers/archived/index_fasta.py +24 -0
  41. smftools/informatics/{helpers → archived/helpers/archived}/make_modbed.py +1 -2
  42. smftools/informatics/{helpers → archived/helpers/archived}/modQC.py +2 -2
  43. smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +250 -0
  44. smftools/informatics/{helpers → archived/helpers/archived}/separate_bam_by_bc.py +8 -7
  45. smftools/informatics/{helpers → archived/helpers/archived}/split_and_index_BAM.py +8 -12
  46. smftools/informatics/archived/subsample_fasta_from_bed.py +49 -0
  47. smftools/informatics/bam_functions.py +812 -0
  48. smftools/informatics/basecalling.py +67 -0
  49. smftools/informatics/bed_functions.py +366 -0
  50. smftools/informatics/binarize_converted_base_identities.py +172 -0
  51. smftools/informatics/{helpers/converted_BAM_to_adata_II.py → converted_BAM_to_adata.py} +198 -50
  52. smftools/informatics/fasta_functions.py +255 -0
  53. smftools/informatics/h5ad_functions.py +197 -0
  54. smftools/informatics/{helpers/modkit_extract_to_adata.py → modkit_extract_to_adata.py} +147 -61
  55. smftools/informatics/modkit_functions.py +129 -0
  56. smftools/informatics/ohe.py +160 -0
  57. smftools/informatics/pod5_functions.py +224 -0
  58. smftools/informatics/{helpers/run_multiqc.py → run_multiqc.py} +5 -2
  59. smftools/machine_learning/__init__.py +12 -0
  60. smftools/machine_learning/data/__init__.py +2 -0
  61. smftools/machine_learning/data/anndata_data_module.py +234 -0
  62. smftools/machine_learning/evaluation/__init__.py +2 -0
  63. smftools/machine_learning/evaluation/eval_utils.py +31 -0
  64. smftools/machine_learning/evaluation/evaluators.py +223 -0
  65. smftools/machine_learning/inference/__init__.py +3 -0
  66. smftools/machine_learning/inference/inference_utils.py +27 -0
  67. smftools/machine_learning/inference/lightning_inference.py +68 -0
  68. smftools/machine_learning/inference/sklearn_inference.py +55 -0
  69. smftools/machine_learning/inference/sliding_window_inference.py +114 -0
  70. smftools/machine_learning/models/base.py +295 -0
  71. smftools/machine_learning/models/cnn.py +138 -0
  72. smftools/machine_learning/models/lightning_base.py +345 -0
  73. smftools/machine_learning/models/mlp.py +26 -0
  74. smftools/{tools → machine_learning}/models/positional.py +3 -2
  75. smftools/{tools → machine_learning}/models/rnn.py +2 -1
  76. smftools/machine_learning/models/sklearn_models.py +273 -0
  77. smftools/machine_learning/models/transformer.py +303 -0
  78. smftools/machine_learning/training/__init__.py +2 -0
  79. smftools/machine_learning/training/train_lightning_model.py +135 -0
  80. smftools/machine_learning/training/train_sklearn_model.py +114 -0
  81. smftools/plotting/__init__.py +4 -1
  82. smftools/plotting/autocorrelation_plotting.py +609 -0
  83. smftools/plotting/general_plotting.py +1292 -140
  84. smftools/plotting/hmm_plotting.py +260 -0
  85. smftools/plotting/qc_plotting.py +270 -0
  86. smftools/preprocessing/__init__.py +15 -8
  87. smftools/preprocessing/add_read_length_and_mapping_qc.py +129 -0
  88. smftools/preprocessing/append_base_context.py +122 -0
  89. smftools/preprocessing/append_binary_layer_by_base_context.py +143 -0
  90. smftools/preprocessing/binarize.py +17 -0
  91. smftools/preprocessing/binarize_on_Youden.py +2 -2
  92. smftools/preprocessing/calculate_complexity_II.py +248 -0
  93. smftools/preprocessing/calculate_coverage.py +10 -1
  94. smftools/preprocessing/calculate_position_Youden.py +1 -1
  95. smftools/preprocessing/calculate_read_modification_stats.py +101 -0
  96. smftools/preprocessing/clean_NaN.py +17 -1
  97. smftools/preprocessing/filter_reads_on_length_quality_mapping.py +158 -0
  98. smftools/preprocessing/filter_reads_on_modification_thresholds.py +352 -0
  99. smftools/preprocessing/flag_duplicate_reads.py +1326 -124
  100. smftools/preprocessing/invert_adata.py +12 -5
  101. smftools/preprocessing/load_sample_sheet.py +19 -4
  102. smftools/readwrite.py +1021 -89
  103. smftools/tools/__init__.py +3 -32
  104. smftools/tools/calculate_umap.py +5 -5
  105. smftools/tools/general_tools.py +3 -3
  106. smftools/tools/position_stats.py +468 -106
  107. smftools/tools/read_stats.py +115 -1
  108. smftools/tools/spatial_autocorrelation.py +562 -0
  109. {smftools-0.1.7.dist-info → smftools-0.2.3.dist-info}/METADATA +14 -9
  110. smftools-0.2.3.dist-info/RECORD +173 -0
  111. smftools-0.2.3.dist-info/entry_points.txt +2 -0
  112. smftools/informatics/fast5_to_pod5.py +0 -21
  113. smftools/informatics/helpers/LoadExperimentConfig.py +0 -75
  114. smftools/informatics/helpers/__init__.py +0 -74
  115. smftools/informatics/helpers/align_and_sort_BAM.py +0 -59
  116. smftools/informatics/helpers/aligned_BAM_to_bed.py +0 -74
  117. smftools/informatics/helpers/bam_qc.py +0 -66
  118. smftools/informatics/helpers/bed_to_bigwig.py +0 -39
  119. smftools/informatics/helpers/binarize_converted_base_identities.py +0 -79
  120. smftools/informatics/helpers/concatenate_fastqs_to_bam.py +0 -55
  121. smftools/informatics/helpers/index_fasta.py +0 -12
  122. smftools/informatics/helpers/make_dirs.py +0 -21
  123. smftools/informatics/helpers/plot_read_length_and_coverage_histograms.py +0 -53
  124. smftools/informatics/load_adata.py +0 -182
  125. smftools/informatics/readwrite.py +0 -106
  126. smftools/informatics/subsample_fasta_from_bed.py +0 -47
  127. smftools/preprocessing/append_C_context.py +0 -82
  128. smftools/preprocessing/calculate_converted_read_methylation_stats.py +0 -94
  129. smftools/preprocessing/filter_converted_reads_on_methylation.py +0 -44
  130. smftools/preprocessing/filter_reads_on_length.py +0 -51
  131. smftools/tools/call_hmm_peaks.py +0 -105
  132. smftools/tools/data/__init__.py +0 -2
  133. smftools/tools/data/anndata_data_module.py +0 -90
  134. smftools/tools/inference/__init__.py +0 -1
  135. smftools/tools/inference/lightning_inference.py +0 -41
  136. smftools/tools/models/base.py +0 -14
  137. smftools/tools/models/cnn.py +0 -34
  138. smftools/tools/models/lightning_base.py +0 -41
  139. smftools/tools/models/mlp.py +0 -17
  140. smftools/tools/models/sklearn_models.py +0 -40
  141. smftools/tools/models/transformer.py +0 -133
  142. smftools/tools/training/__init__.py +0 -1
  143. smftools/tools/training/train_lightning_model.py +0 -47
  144. smftools-0.1.7.dist-info/RECORD +0 -136
  145. /smftools/{tools/evaluation → cli}/__init__.py +0 -0
  146. /smftools/{tools → hmm}/calculate_distances.py +0 -0
  147. /smftools/{tools → hmm}/hmm_readwrite.py +0 -0
  148. /smftools/informatics/{basecall_pod5s.py → archived/basecall_pod5s.py} +0 -0
  149. /smftools/informatics/{conversion_smf.py → archived/conversion_smf.py} +0 -0
  150. /smftools/informatics/{direct_smf.py → archived/direct_smf.py} +0 -0
  151. /smftools/informatics/{helpers → archived/helpers/archived}/canoncall.py +0 -0
  152. /smftools/informatics/{helpers → archived/helpers/archived}/converted_BAM_to_adata.py +0 -0
  153. /smftools/informatics/{helpers → archived/helpers/archived}/extract_read_lengths_from_bed.py +0 -0
  154. /smftools/informatics/{helpers → archived/helpers/archived}/extract_readnames_from_BAM.py +0 -0
  155. /smftools/informatics/{helpers → archived/helpers/archived}/get_native_references.py +0 -0
  156. /smftools/informatics/{helpers → archived/helpers}/archived/informatics.py +0 -0
  157. /smftools/informatics/{helpers → archived/helpers}/archived/load_adata.py +0 -0
  158. /smftools/informatics/{helpers → archived/helpers/archived}/modcall.py +0 -0
  159. /smftools/informatics/{helpers → archived/helpers/archived}/ohe_batching.py +0 -0
  160. /smftools/informatics/{helpers → archived/helpers/archived}/ohe_layers_decode.py +0 -0
  161. /smftools/informatics/{helpers → archived/helpers/archived}/one_hot_decode.py +0 -0
  162. /smftools/informatics/{helpers → archived/helpers/archived}/one_hot_encode.py +0 -0
  163. /smftools/informatics/{subsample_pod5.py → archived/subsample_pod5.py} +0 -0
  164. /smftools/informatics/{helpers/complement_base_list.py → complement_base_list.py} +0 -0
  165. /smftools/{tools → machine_learning}/data/preprocessing.py +0 -0
  166. /smftools/{tools → machine_learning}/models/__init__.py +0 -0
  167. /smftools/{tools → machine_learning}/models/wrappers.py +0 -0
  168. /smftools/{tools → machine_learning}/utils/__init__.py +0 -0
  169. /smftools/{tools → machine_learning}/utils/device.py +0 -0
  170. /smftools/{tools → machine_learning}/utils/grl.py +0 -0
  171. /smftools/tools/{apply_hmm.py → archived/apply_hmm.py} +0 -0
  172. /smftools/tools/{classifiers.py → archived/classifiers.py} +0 -0
  173. {smftools-0.1.7.dist-info → smftools-0.2.3.dist-info}/WHEEL +0 -0
  174. {smftools-0.1.7.dist-info → smftools-0.2.3.dist-info}/licenses/LICENSE +0 -0
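The largest single addition is the new smftools/hmm/HMM.py module, whose full diff follows. As orientation, here is a minimal usage sketch based only on the public API visible in that file; the `from smftools.hmm import HMM` import path is an assumption based on the new hmm/__init__.py (not shown), and the toy data are illustrative:

    import numpy as np
    from smftools.hmm import HMM  # assumed re-export; otherwise smftools.hmm.HMM.HMM

    # Two-state Bernoulli HMM over binarized SMF calls (0/1, with np.nan = missing)
    model = HMM(n_states=2, init_emission=[0.1, 0.9])

    # variable-length sequences; NaNs are treated as missing when impute_strategy="ignore"
    data = [
        [1, 1, np.nan, 0, 0, 0, 1],
        [0, 0, 1, 1, 1],
    ]
    model.fit(data, max_iter=50, tol=1e-4, verbose=False)   # batched Baum-Welch EM
    gammas = model.predict(data)          # list of (L, K) posterior marginal arrays
    path, score = model.viterbi(data[0])  # most likely state path and its log-probability
    model.save("hmm.pt")
    restored = HMM.load("hmm.pt", device="cpu")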
smftools/hmm/HMM.py ADDED
@@ -0,0 +1,1576 @@
1
+ import math
2
+ from typing import List, Optional, Tuple, Union, Any, Dict
3
+ import ast
4
+ import json
5
+
6
+ import numpy as np
7
+ import pandas as pd
8
+ import torch
9
+ import torch.nn as nn
10
+
11
+ def _logsumexp(vec: torch.Tensor, dim: int = -1, keepdim: bool = False) -> torch.Tensor:
12
+ return torch.logsumexp(vec, dim=dim, keepdim=keepdim)
13
+
14
+ class HMM(nn.Module):
15
+ """
16
+ Vectorized HMM (Bernoulli emissions) implemented in PyTorch.
17
+
18
+ Methods:
19
+ - fit(data, ...) -> trains params in-place
20
+ - predict(data, ...) -> list of (L, K) posterior marginals (gamma) numpy arrays
21
+ - viterbi(seq, ...) -> (path_list, score)
22
+ - batch_viterbi(data, ...) -> list of (path_list, score)
23
+ - score(seq_or_list, ...) -> float or list of floats
24
+
25
+ Notes:
26
+ - data: list of sequences (each sequence is iterable of {0,1,np.nan}).
27
+ - impute_strategy: "ignore" (NaN treated as missing), "random" (fill NaNs randomly with 0/1).
28
+ """
29
+
30
+ def __init__(
31
+ self,
32
+ n_states: int = 2,
33
+ init_start: Optional[List[float]] = None,
34
+ init_trans: Optional[List[List[float]]] = None,
35
+ init_emission: Optional[List[float]] = None,
36
+ dtype: torch.dtype = torch.float64,
37
+ eps: float = 1e-8,
38
+ smf_modality: Optional[str] = None,
39
+ ):
40
+ super().__init__()
41
+ if n_states < 2:
42
+ raise ValueError("n_states must be >= 2")
43
+ self.n_states = n_states
44
+ self.eps = float(eps)
45
+ self.dtype = dtype
46
+ self.smf_modality = smf_modality
47
+
48
+ # initialize params (probabilities)
49
+ if init_start is None:
50
+ start = np.full((n_states,), 1.0 / n_states, dtype=float)
51
+ else:
52
+ start = np.asarray(init_start, dtype=float)
53
+ if init_trans is None:
54
+ trans = np.full((n_states, n_states), 1.0 / n_states, dtype=float)
55
+ else:
56
+ trans = np.asarray(init_trans, dtype=float)
57
+ # --- sanitize init_emission so it's a 1-D list of P(obs==1 | state) ---
58
+ if init_emission is None:
59
+ emission = np.full((n_states,), 0.5, dtype=float)
60
+ else:
61
+ em_arr = np.asarray(init_emission, dtype=float)
62
+ # case: (K,2) -> pick P(obs==1) from second column
63
+ if em_arr.ndim == 2 and em_arr.shape[1] == 2 and em_arr.shape[0] == n_states:
64
+ emission = em_arr[:, 1].astype(float)
65
+ # case: maybe shape (1,K,2) etc. -> try to collapse trailing axis of length 2
66
+ elif em_arr.ndim >= 2 and em_arr.shape[-1] == 2:
67
+ emission = em_arr.reshape(-1, 2)[:n_states, 1].astype(float)
68
+ else:
69
+ emission = em_arr.reshape(-1)[:n_states].astype(float)
70
+
71
+ # store as parameters (not trainable via grad; EM updates .data in-place)
72
+ self.start = nn.Parameter(torch.tensor(start, dtype=self.dtype), requires_grad=False)
73
+ self.trans = nn.Parameter(torch.tensor(trans, dtype=self.dtype), requires_grad=False)
74
+ self.emission = nn.Parameter(torch.tensor(emission, dtype=self.dtype), requires_grad=False)
75
+
76
+ self._normalize_params()
77
+
78
+ def _normalize_params(self):
79
+ with torch.no_grad():
80
+ # coerce shapes
81
+ K = self.n_states
82
+ self.start.data = self.start.data.squeeze()
83
+ if self.start.data.numel() != K:
84
+ self.start.data = torch.full((K,), 1.0 / K, dtype=self.dtype)
85
+
86
+ self.trans.data = self.trans.data.squeeze()
87
+ if not (self.trans.data.ndim == 2 and self.trans.data.shape == (K, K)):
88
+ if K == 2:
89
+ self.trans.data = torch.tensor([[0.9,0.1],[0.1,0.9]], dtype=self.dtype)
90
+ else:
91
+ self.trans.data = torch.full((K, K), 1.0 / K, dtype=self.dtype)
92
+
93
+ self.emission.data = self.emission.data.squeeze()
94
+ if self.emission.data.numel() != K:
95
+ self.emission.data = torch.full((K,), 0.5, dtype=self.dtype)
96
+
97
+ # now perform smoothing/normalization
98
+ self.start.data = (self.start.data + self.eps)
99
+ self.start.data = self.start.data / self.start.data.sum()
100
+
101
+ self.trans.data = (self.trans.data + self.eps)
102
+ row_sums = self.trans.data.sum(dim=1, keepdim=True)
103
+ row_sums[row_sums == 0.0] = 1.0
104
+ self.trans.data = self.trans.data / row_sums
105
+
106
+ self.emission.data = self.emission.data.clamp(min=self.eps, max=1.0 - self.eps)
107
+
108
+ @staticmethod
109
+ def _resolve_dtype(dtype_entry):
110
+ """Accept torch.dtype, string ('float32'/'float64') or None -> torch.dtype."""
111
+ if dtype_entry is None:
112
+ return torch.float64
113
+ if isinstance(dtype_entry, torch.dtype):
114
+ return dtype_entry
115
+ s = str(dtype_entry).lower()
116
+ if "32" in s:
117
+ return torch.float32
118
+ if "16" in s:
119
+ return torch.float16
120
+ return torch.float64
121
+
122
+ @classmethod
123
+ def from_config(cls, cfg: Union[dict, "ExperimentConfig", None], *,
124
+ override: Optional[dict] = None,
125
+ device: Optional[Union[str, torch.device]] = None) -> "HMM":
126
+ """
127
+ Construct an HMM using keys from an ExperimentConfig instance or a plain dict.
128
+
129
+ cfg may be:
130
+ - an ExperimentConfig (your dataclass instance)
131
+ - a dict (e.g. loader.var_dict or merged defaults)
132
+ - None (uses internal defaults)
133
+
134
+ override: optional dict to override resolved keys (handy for tests).
135
+ device: optional device string or torch.device to move model to.
136
+ """
137
+ # Accept ExperimentConfig dataclass
138
+ if cfg is None:
139
+ merged = {}
140
+ elif hasattr(cfg, "to_dict") and callable(getattr(cfg, "to_dict")):
141
+ merged = dict(cfg.to_dict())
142
+ elif isinstance(cfg, dict):
143
+ merged = dict(cfg)
144
+ else:
145
+ # try attr access as fallback
146
+ try:
147
+ merged = {k: getattr(cfg, k) for k in dir(cfg) if k.startswith("hmm_")}
148
+ except Exception:
149
+ merged = {}
150
+
151
+ if override:
152
+ merged.update(override)
153
+
154
+ # basic resolution with fallback
155
+ n_states = int(merged.get("hmm_n_states", merged.get("n_states", 2)))
156
+ init_start = merged.get("hmm_init_start_probs", merged.get("hmm_init_start", None))
157
+ init_trans = merged.get("hmm_init_transition_probs", merged.get("hmm_init_trans", None))
158
+ init_emission = merged.get("hmm_init_emission_probs", merged.get("hmm_init_emission", None))
159
+ eps = float(merged.get("hmm_eps", merged.get("eps", 1e-8)))
160
+ dtype = cls._resolve_dtype(merged.get("hmm_dtype", merged.get("dtype", None)))
161
+
162
+ # coerce lists (if present) -> numpy arrays (the HMM constructor already sanitizes)
163
+ def _coerce_np(x):
164
+ if x is None:
165
+ return None
166
+ return np.asarray(x, dtype=float)
167
+
168
+ init_start = _coerce_np(init_start)
169
+ init_trans = _coerce_np(init_trans)
170
+ init_emission = _coerce_np(init_emission)
171
+
172
+ model = cls(
173
+ n_states=n_states,
174
+ init_start=init_start,
175
+ init_trans=init_trans,
176
+ init_emission=init_emission,
177
+ dtype=dtype,
178
+ eps=eps,
179
+ smf_modality=merged.get("smf_modality", None),
180
+ )
181
+
182
+ # move to device if requested
183
+ if device is not None:
184
+ if isinstance(device, str):
185
+ device = torch.device(device)
186
+ model.to(device)
187
+
188
+ # persist the config to the hmm class
189
+ cls.config = cfg
190
+
191
+ return model
192
+
193
+ def update_from_config(self, cfg: Union[dict, "ExperimentConfig", None], *,
194
+ override: Optional[dict] = None):
195
+ """
196
+ Update existing model parameters from a config or dict (in-place).
197
+ This will normalize / reinitialize start/trans/emission using same logic as constructor.
198
+ """
199
+ if cfg is None:
200
+ merged = {}
201
+ elif hasattr(cfg, "to_dict") and callable(getattr(cfg, "to_dict")):
202
+ merged = dict(cfg.to_dict())
203
+ elif isinstance(cfg, dict):
204
+ merged = dict(cfg)
205
+ else:
206
+ try:
207
+ merged = {k: getattr(cfg, k) for k in dir(cfg) if k.startswith("hmm_")}
208
+ except Exception:
209
+ merged = {}
210
+
211
+ if override:
212
+ merged.update(override)
213
+
214
+ # extract same keys as from_config
215
+ n_states = int(merged.get("hmm_n_states", self.n_states))
216
+ init_start = merged.get("hmm_init_start_probs", None)
217
+ init_trans = merged.get("hmm_init_transition_probs", None)
218
+ init_emission = merged.get("hmm_init_emission_probs", None)
219
+ eps = merged.get("hmm_eps", None)
220
+ dtype = merged.get("hmm_dtype", None)
221
+
222
+ # apply dtype/eps if present
223
+ if eps is not None:
224
+ self.eps = float(eps)
225
+ if dtype is not None:
226
+ self.dtype = self._resolve_dtype(dtype)
227
+
228
+ # if n_states changes we need a fresh re-init (easy approach: reconstruct)
229
+ if int(n_states) != int(self.n_states):
230
+ # rebuild self in-place: create a new model and copy tensors
231
+ new_model = HMM.from_config(merged)
232
+ # copy content
233
+ with torch.no_grad():
234
+ self.n_states = new_model.n_states
235
+ self.eps = new_model.eps
236
+ self.dtype = new_model.dtype
237
+ self.start.data = new_model.start.data.clone().to(self.start.device, dtype=self.dtype)
238
+ self.trans.data = new_model.trans.data.clone().to(self.trans.device, dtype=self.dtype)
239
+ self.emission.data = new_model.emission.data.clone().to(self.emission.device, dtype=self.dtype)
240
+ return
241
+
242
+ # else only update provided tensors
243
+ def _to_tensor(obj, shape_expected=None):
244
+ if obj is None:
245
+ return None
246
+ arr = np.asarray(obj, dtype=float)
247
+ if shape_expected is not None:
248
+ try:
249
+ arr = arr.reshape(shape_expected)
250
+ except Exception:
251
+ # best-effort: trim the flattened array to the expected size before reshaping
252
+ arr = arr.reshape(-1)[: int(np.prod(shape_expected))].reshape(shape_expected) if arr.size >= np.prod(shape_expected) else arr
253
+ return torch.tensor(arr, dtype=self.dtype, device=self.start.device)
254
+
255
+ with torch.no_grad():
256
+ if init_start is not None:
257
+ t = _to_tensor(init_start, (self.n_states,))
258
+ if t.numel() == self.n_states:
259
+ self.start.data = t.clone()
260
+ if init_trans is not None:
261
+ t = _to_tensor(init_trans, (self.n_states, self.n_states))
262
+ if t.shape == (self.n_states, self.n_states):
263
+ self.trans.data = t.clone()
264
+ if init_emission is not None:
265
+ # attempt to extract P(obs==1) if shaped (K,2)
266
+ arr = np.asarray(init_emission, dtype=float)
267
+ if arr.ndim == 2 and arr.shape[1] == 2 and arr.shape[0] >= self.n_states:
268
+ em = arr[: self.n_states, 1]
269
+ else:
270
+ em = arr.reshape(-1)[: self.n_states]
271
+ t = torch.tensor(em, dtype=self.dtype, device=self.start.device)
272
+ if t.numel() == self.n_states:
273
+ self.emission.data = t.clone()
274
+
275
+ # finally normalize
276
+ self._normalize_params()
277
+
278
+ def _ensure_device_dtype(self, device: Optional[torch.device]):
279
+ if device is None:
280
+ device = next(self.parameters()).device
281
+ self.start.data = self.start.data.to(device=device, dtype=self.dtype)
282
+ self.trans.data = self.trans.data.to(device=device, dtype=self.dtype)
283
+ self.emission.data = self.emission.data.to(device=device, dtype=self.dtype)
284
+ return device
285
+
286
+ @staticmethod
287
+ def _pad_and_mask(
288
+ data: List[List],
289
+ device: torch.device,
290
+ dtype: torch.dtype,
291
+ impute_strategy: str = "ignore",
292
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
293
+ """
294
+ Pads sequences to shape (B, L). Returns (obs, mask, lengths)
295
+ - Accepts: list-of-seqs, or 2D ndarray (B, L).
296
+ - If a sequence element is itself an array (per-timestep feature vector),
297
+ collapse the last axis by mean (warns once).
298
+ """
299
+ import warnings
300
+
301
+ # If somebody passed a 2-D ndarray directly, convert to list-of-rows
302
+ if isinstance(data, np.ndarray) and data.ndim == 2:
303
+ # convert rows -> python lists (scalars per timestep)
304
+ data = data.tolist()
305
+
306
+ B = len(data)
307
+ lengths = torch.tensor([len(s) for s in data], dtype=torch.long, device=device)
308
+ L = int(lengths.max().item()) if B > 0 else 0
309
+ obs = torch.zeros((B, L), dtype=dtype, device=device)
310
+ mask = torch.zeros((B, L), dtype=torch.bool, device=device)
311
+
312
+ warned_collapse = False
313
+
314
+ for i, seq in enumerate(data):
315
+ # seq may be list/ndarray of scalars OR list/ndarray of per-timestep arrays
316
+ arr = np.asarray(seq, dtype=float)
317
+
318
+ # If arr is shape (L,1,1,...) squeeze trailing singletons
319
+ while arr.ndim > 1 and arr.shape[-1] == 1:
320
+ arr = np.squeeze(arr, axis=-1)
321
+
322
+ # If arr is still >1D (e.g., (L, F)), collapse the last axis by mean
323
+ if arr.ndim > 1:
324
+ if not warned_collapse:
325
+ warnings.warn(
326
+ "HMM._pad_and_mask: collapsing per-timestep feature axis by mean "
327
+ "(arr had shape {}). If you prefer a different reduction, "
328
+ "preprocess your data.".format(arr.shape),
329
+ stacklevel=2,
330
+ )
331
+ warned_collapse = True
332
+ # collapse features -> scalar per timestep
333
+ arr = np.asarray(arr, dtype=float).mean(axis=-1)
334
+
335
+ # now arr should be 1D (T,)
336
+ if arr.ndim == 0:
337
+ # single scalar: treat as length-1 sequence
338
+ arr = np.atleast_1d(arr)
339
+
340
+ nan_mask = np.isnan(arr)
341
+ if impute_strategy == "random" and nan_mask.any():
342
+ arr[nan_mask] = np.random.choice([0, 1], size=nan_mask.sum())
343
+ local_mask = np.ones_like(arr, dtype=bool)
344
+ else:
345
+ local_mask = ~nan_mask
346
+ arr = np.where(local_mask, arr, 0.0)
347
+
348
+ L_i = arr.shape[0]
349
+ obs[i, :L_i] = torch.tensor(arr, dtype=dtype, device=device)
350
+ mask[i, :L_i] = torch.tensor(local_mask, dtype=torch.bool, device=device)
351
+
352
+ return obs, mask, lengths
353
+
354
+ def _log_emission(self, obs: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
355
+ """
356
+ obs: (B, L)
357
+ mask: (B, L) bool
358
+ returns logB (B, L, K)
359
+ """
360
+ B, L = obs.shape
361
+ p = self.emission # (K,)
362
+ logp = torch.log(p + self.eps)
363
+ log1mp = torch.log1p(-p + self.eps)
364
+ obs_expand = obs.unsqueeze(-1) # (B, L, 1)
365
+ logB = obs_expand * logp.unsqueeze(0).unsqueeze(0) + (1.0 - obs_expand) * log1mp.unsqueeze(0).unsqueeze(0)
366
+ logB = torch.where(mask.unsqueeze(-1), logB, torch.zeros_like(logB))
367
+ return logB
368
+
369
+ def fit(
370
+ self,
371
+ data: List[List],
372
+ max_iter: int = 100,
373
+ tol: float = 1e-4,
374
+ impute_strategy: str = "ignore",
375
+ verbose: bool = True,
376
+ return_history: bool = False,
377
+ device: Optional[Union[torch.device, str]] = None,
378
+ ):
379
+ """
380
+ Vectorized Baum-Welch EM across a batch of sequences (padded).
381
+ """
382
+ if device is None:
383
+ device = next(self.parameters()).device
384
+ elif isinstance(device, str):
385
+ device = torch.device(device)
386
+ device = self._ensure_device_dtype(device)
387
+
388
+ if isinstance(data, np.ndarray):
389
+ if data.ndim == 2:
390
+ # rows are sequences: convert to list of 1D arrays
391
+ data = data.tolist()
392
+ elif data.ndim == 1:
393
+ # single sequence
394
+ data = [data.tolist()]
395
+ else:
396
+ raise ValueError(f"Expected data to be 1D or 2D ndarray; got array with ndim={data.ndim}")
397
+
398
+ obs, mask, lengths = self._pad_and_mask(data, device=device, dtype=self.dtype, impute_strategy=impute_strategy)
399
+ B, L = obs.shape
400
+ K = self.n_states
401
+ eps = float(self.eps)
402
+
403
+ if verbose:
404
+ print(f"[HMM.fit] device={device}, batch={B}, max_len={L}, states={K}")
405
+
406
+ loglik_history = []
407
+
408
+ for it in range(1, max_iter + 1):
409
+ if verbose:
410
+ print(f"[HMM.fit] EM iter {it}")
411
+
412
+ # compute batched emission logs
413
+ logB = self._log_emission(obs, mask) # (B, L, K)
414
+
415
+ # logs for start and transition
416
+ logA = torch.log(self.trans + eps) # (K, K)
417
+ logstart = torch.log(self.start + eps) # (K,)
418
+
419
+ # Forward (batched)
420
+ alpha = torch.empty((B, L, K), dtype=self.dtype, device=device)
421
+ alpha[:, 0, :] = logstart.unsqueeze(0) + logB[:, 0, :] # (B,K)
422
+ for t in range(1, L):
423
+ # prev: (B, i, 1) + (1, i, j) broadcast => (B, i, j)
424
+ prev = alpha[:, t - 1, :].unsqueeze(2) + logA.unsqueeze(0) # (B, K, K)
425
+ alpha[:, t, :] = _logsumexp(prev, dim=1) + logB[:, t, :]
426
+
427
+ # Backward (batched)
428
+ beta = torch.empty((B, L, K), dtype=self.dtype, device=device)
429
+ beta[:, L - 1, :] = torch.zeros((K,), dtype=self.dtype, device=device).unsqueeze(0).expand(B, K)
430
+ for t in range(L - 2, -1, -1):
431
+ # temp (B, i, j) = logA[i,j] + logB[:,t+1,j] + beta[:,t+1,j]
432
+ temp = logA.unsqueeze(0) + (logB[:, t + 1, :].unsqueeze(1) + beta[:, t + 1, :].unsqueeze(1))
433
+ beta[:, t, :] = _logsumexp(temp, dim=2)
434
+
435
+ # sequence log-likelihoods (use last real index)
436
+ last_idx = (lengths - 1).clamp(min=0)
437
+ idx_range = torch.arange(B, device=device)
438
+ final_alpha = alpha[idx_range, last_idx, :] # (B, K)
439
+ seq_loglikes = _logsumexp(final_alpha, dim=1) # (B,)
440
+ total_loglike = float(seq_loglikes.sum().item())
441
+
442
+ # posterior gamma (B, L, K)
443
+ log_gamma = alpha + beta # (B, L, K)
444
+ logZ_time = _logsumexp(log_gamma, dim=2, keepdim=True) # (B, L, 1)
445
+ gamma = (log_gamma - logZ_time).exp() # (B, L, K)
446
+
447
+ # accumulators: starts, transitions, emissions
448
+ gamma_start_accum = gamma[:, 0, :].sum(dim=0) # (K,)
449
+
450
+ # emission accumulators: sum over observed positions only
451
+ mask_f = mask.unsqueeze(-1) # (B, L, 1)
452
+ emit_num = (gamma * obs.unsqueeze(-1) * mask_f).sum(dim=(0, 1)) # (K,)
453
+ emit_den = (gamma * mask_f).sum(dim=(0, 1)) # (K,)
454
+
455
+ # transitions: accumulate xi across t for valid positions
456
+ trans_accum = torch.zeros((K, K), dtype=self.dtype, device=device)
457
+ if L >= 2:
458
+ time_idx = torch.arange(L - 1, device=device).unsqueeze(0).expand(B, L - 1) # (B, L-1)
459
+ valid = time_idx < (lengths.unsqueeze(1) - 1) # (B, L-1) bool
460
+ for t in range(L - 1):
461
+ a_t = alpha[:, t, :].unsqueeze(2) # (B, i, 1)
462
+ b_next = (logB[:, t + 1, :] + beta[:, t + 1, :]).unsqueeze(1) # (B, 1, j)
463
+ log_xi_unnorm = a_t + logA.unsqueeze(0) + b_next # (B, i, j)
464
+ log_xi_flat = log_xi_unnorm.view(B, -1) # (B, i*j)
465
+ log_norm = _logsumexp(log_xi_flat, dim=1).unsqueeze(1).unsqueeze(2) # (B,1,1)
466
+ xi = (log_xi_unnorm - log_norm).exp() # (B,i,j)
467
+ valid_t = valid[:, t].float().unsqueeze(1).unsqueeze(2) # (B,1,1)
468
+ xi_masked = xi * valid_t
469
+ trans_accum += xi_masked.sum(dim=0) # (i,j)
470
+
471
+ # M-step: update parameters with smoothing
472
+ with torch.no_grad():
473
+ new_start = gamma_start_accum + eps
474
+ new_start = new_start / new_start.sum()
475
+
476
+ new_trans = trans_accum + eps
477
+ row_sums = new_trans.sum(dim=1, keepdim=True)
478
+ row_sums[row_sums == 0.0] = 1.0
479
+ new_trans = new_trans / row_sums
480
+
481
+ new_emission = (emit_num + eps) / (emit_den + 2.0 * eps)
482
+ new_emission = new_emission.clamp(min=eps, max=1.0 - eps)
483
+
484
+ self.start.data = new_start
485
+ self.trans.data = new_trans
486
+ self.emission.data = new_emission
487
+
488
+ loglik_history.append(total_loglike)
489
+ if verbose:
490
+ print(f" total loglik = {total_loglike:.6f}")
491
+
492
+ if len(loglik_history) > 1 and abs(loglik_history[-1] - loglik_history[-2]) < tol:
493
+ if verbose:
494
+ print(f"[HMM.fit] converged (Δll < {tol}) at iter {it}")
495
+ break
496
+
497
+ return loglik_history if return_history else None
498
+
499
+ def get_params(self) -> dict:
500
+ """
501
+ Return model parameters as numpy arrays on CPU.
502
+ """
503
+ with torch.no_grad():
504
+ return {
505
+ "n_states": int(self.n_states),
506
+ "start": self.start.detach().cpu().numpy().astype(float).reshape(-1),
507
+ "trans": self.trans.detach().cpu().numpy().astype(float),
508
+ "emission": self.emission.detach().cpu().numpy().astype(float).reshape(-1),
509
+ }
510
+
511
+ def print_params(self, decimals: int = 4):
512
+ """
513
+ Nicely print start, transition, and emission probabilities.
514
+ """
515
+ params = self.get_params()
516
+ K = params["n_states"]
517
+ fmt = f"{{:.{decimals}f}}"
518
+ print(f"HMM params (K={K} states):")
519
+ print(" start probs:")
520
+ print(" [" + ", ".join(fmt.format(v) for v in params["start"]) + "]")
521
+ print(" transition matrix (rows = from-state, cols = to-state):")
522
+ for i, row in enumerate(params["trans"]):
523
+ print(" s{:d}: [".format(i) + ", ".join(fmt.format(v) for v in row) + "]")
524
+ print(" emission P(obs==1 | state):")
525
+ for i, v in enumerate(params["emission"]):
526
+ print(f" s{i}: {fmt.format(v)}")
527
+
528
+ def to_dataframes(self) -> dict:
529
+ """
530
+ Return pandas DataFrames for start (Series), trans (DataFrame), emission (Series).
531
+ """
532
+ p = self.get_params()
533
+ K = p["n_states"]
534
+ state_names = [f"state_{i}" for i in range(K)]
535
+ start_s = pd.Series(p["start"], index=state_names, name="start_prob")
536
+ trans_df = pd.DataFrame(p["trans"], index=state_names, columns=state_names)
537
+ emission_s = pd.Series(p["emission"], index=state_names, name="p_obs1")
538
+ return {"start": start_s, "trans": trans_df, "emission": emission_s}
539
+
540
+ def predict(self, data: List[List], impute_strategy: str = "ignore", device: Optional[Union[torch.device, str]] = None) -> List[np.ndarray]:
541
+ """
542
+ Return posterior marginals gamma_t(k) for each sequence as list of (L, K) numpy arrays.
543
+ """
544
+ if device is None:
545
+ device = next(self.parameters()).device
546
+ elif isinstance(device, str):
547
+ device = torch.device(device)
548
+ device = self._ensure_device_dtype(device)
549
+
550
+ obs, mask, lengths = self._pad_and_mask(data, device=device, dtype=self.dtype, impute_strategy=impute_strategy)
551
+ B, L = obs.shape
552
+ K = self.n_states
553
+ eps = float(self.eps)
554
+
555
+ logB = self._log_emission(obs, mask) # (B, L, K)
556
+ logA = torch.log(self.trans + eps)
557
+ logstart = torch.log(self.start + eps)
558
+
559
+ # Forward
560
+ alpha = torch.empty((B, L, K), dtype=self.dtype, device=device)
561
+ alpha[:, 0, :] = logstart.unsqueeze(0) + logB[:, 0, :]
562
+ for t in range(1, L):
563
+ prev = alpha[:, t - 1, :].unsqueeze(2) + logA.unsqueeze(0)
564
+ alpha[:, t, :] = _logsumexp(prev, dim=1) + logB[:, t, :]
565
+
566
+ # Backward
567
+ beta = torch.empty((B, L, K), dtype=self.dtype, device=device)
568
+ beta[:, L - 1, :] = torch.zeros((K,), dtype=self.dtype, device=device).unsqueeze(0).expand(B, K)
569
+ for t in range(L - 2, -1, -1):
570
+ temp = logA.unsqueeze(0) + (logB[:, t + 1, :].unsqueeze(1) + beta[:, t + 1, :].unsqueeze(1))
571
+ beta[:, t, :] = _logsumexp(temp, dim=2)
572
+
573
+ # gamma
574
+ log_gamma = alpha + beta
575
+ logZ_time = _logsumexp(log_gamma, dim=2, keepdim=True)
576
+ gamma = (log_gamma - logZ_time).exp() # (B, L, K)
577
+
578
+ results = []
579
+ for i in range(B):
580
+ L_i = int(lengths[i].item())
581
+ results.append(gamma[i, :L_i, :].detach().cpu().numpy())
582
+ return results
583
+
584
+ def score(self, seq_or_list: Union[List[float], List[List[float]]], impute_strategy: str = "ignore", device: Optional[Union[torch.device, str]] = None) -> Union[float, List[float]]:
585
+ """
586
+ Compute log-likelihood of a single sequence or list of sequences under current params.
587
+ Returns float (single) or list of floats (batch).
588
+ """
589
+ single = False
590
+ if not isinstance(seq_or_list[0], (list, tuple, np.ndarray)):
591
+ seqs = [seq_or_list]
592
+ single = True
593
+ else:
594
+ seqs = seq_or_list
595
+
596
+ if device is None:
597
+ device = next(self.parameters()).device
598
+ elif isinstance(device, str):
599
+ device = torch.device(device)
600
+ device = self._ensure_device_dtype(device)
601
+
602
+ obs, mask, lengths = self._pad_and_mask(seqs, device=device, dtype=self.dtype, impute_strategy=impute_strategy)
603
+ B, L = obs.shape
604
+ K = self.n_states
605
+ eps = float(self.eps)
606
+
607
+ logB = self._log_emission(obs, mask)
608
+ logA = torch.log(self.trans + eps)
609
+ logstart = torch.log(self.start + eps)
610
+
611
+ alpha = torch.empty((B, L, K), dtype=self.dtype, device=device)
612
+ alpha[:, 0, :] = logstart.unsqueeze(0) + logB[:, 0, :]
613
+ for t in range(1, L):
614
+ prev = alpha[:, t - 1, :].unsqueeze(2) + logA.unsqueeze(0)
615
+ alpha[:, t, :] = _logsumexp(prev, dim=1) + logB[:, t, :]
616
+
617
+ last_idx = (lengths - 1).clamp(min=0)
618
+ idx_range = torch.arange(B, device=device)
619
+ final_alpha = alpha[idx_range, last_idx, :] # (B, K)
620
+ seq_loglikes = _logsumexp(final_alpha, dim=1) # (B,)
621
+ seq_loglikes = seq_loglikes.detach().cpu().numpy().tolist()
622
+ return seq_loglikes[0] if single else seq_loglikes
623
+
624
+ def viterbi(self, seq: List[float], impute_strategy: str = "ignore", device: Optional[Union[torch.device, str]] = None) -> Tuple[List[int], float]:
625
+ """
626
+ Viterbi decode a single sequence. Returns (state_path, log_probability_of_path).
627
+ """
628
+ paths, scores = self.batch_viterbi([seq], impute_strategy=impute_strategy, device=device)
629
+ return paths[0], scores[0]
630
+
631
+ def batch_viterbi(self, data: List[List[float]], impute_strategy: str = "ignore", device: Optional[Union[torch.device, str]] = None) -> Tuple[List[List[int]], List[float]]:
632
+ """
633
+ Batched Viterbi decoding on padded sequences. Returns (list_of_paths, list_of_scores).
634
+ Each path is the length of the original sequence.
635
+ """
636
+ if device is None:
637
+ device = next(self.parameters()).device
638
+ elif isinstance(device, str):
639
+ device = torch.device(device)
640
+ device = self._ensure_device_dtype(device)
641
+
642
+ obs, mask, lengths = self._pad_and_mask(data, device=device, dtype=self.dtype, impute_strategy=impute_strategy)
643
+ B, L = obs.shape
644
+ K = self.n_states
645
+ eps = float(self.eps)
646
+
647
+ p = self.emission
648
+ logp = torch.log(p + eps)
649
+ log1mp = torch.log1p(-p + eps)
650
+ logB = obs.unsqueeze(-1) * logp.unsqueeze(0).unsqueeze(0) + (1.0 - obs.unsqueeze(-1)) * log1mp.unsqueeze(0).unsqueeze(0)
651
+ logB = torch.where(mask.unsqueeze(-1), logB, torch.zeros_like(logB))
652
+
653
+ logstart = torch.log(self.start + eps)
654
+ logA = torch.log(self.trans + eps)
655
+
656
+ # delta (score) and psi (argmax pointers)
657
+ delta = torch.empty((B, L, K), dtype=self.dtype, device=device)
658
+ psi = torch.zeros((B, L, K), dtype=torch.long, device=device)
659
+
660
+ delta[:, 0, :] = logstart.unsqueeze(0) + logB[:, 0, :]
661
+ psi[:, 0, :] = -1 # sentinel
662
+
663
+ for t in range(1, L):
664
+ # cand shape (B, i, j)
665
+ cand = delta[:, t - 1, :].unsqueeze(2) + logA.unsqueeze(0) # (B, K, K)
666
+ best_val, best_idx = cand.max(dim=1) # best over previous i: results (B, j)
667
+ delta[:, t, :] = best_val + logB[:, t, :]
668
+ psi[:, t, :] = best_idx # best previous state index for each (B, j)
669
+
670
+ # backtrack
671
+ last_idx = (lengths - 1).clamp(min=0)
672
+ idx_range = torch.arange(B, device=device)
673
+ final_delta = delta[idx_range, last_idx, :] # (B, K)
674
+ best_last_val, best_last_state = final_delta.max(dim=1) # (B,), (B,)
675
+ paths = []
676
+ scores = []
677
+ for b in range(B):
678
+ Lb = int(lengths[b].item())
679
+ if Lb == 0:
680
+ paths.append([])
681
+ scores.append(float("-inf"))
682
+ continue
683
+ s = int(best_last_state[b].item())
684
+ path = [s]
685
+ for t in range(Lb - 1, 0, -1):
686
+ s = int(psi[b, t, s].item())
687
+ path.append(s)
688
+ path.reverse()
689
+ paths.append(path)
690
+ scores.append(float(best_last_val[b].item()))
691
+ return paths, scores
692
+
693
+ def save(self, path: str) -> None:
694
+ """
695
+ Save HMM to `path` using torch.save. Stores:
696
+ - n_states, eps, dtype (string)
697
+ - start, trans, emission (CPU tensors)
698
+ """
699
+ payload = {
700
+ "n_states": int(self.n_states),
701
+ "eps": float(self.eps),
702
+ # store dtype as a string like "torch.float64" (portable)
703
+ "dtype": str(self.dtype),
704
+ "start": self.start.detach().cpu(),
705
+ "trans": self.trans.detach().cpu(),
706
+ "emission": self.emission.detach().cpu(),
707
+ }
708
+ torch.save(payload, path)
709
+
710
+ @classmethod
711
+ def load(cls, path: str, device: Optional[Union[torch.device, str]] = None) -> "HMM":
712
+ """
713
+ Load model from `path`. If `device` is provided (str or torch.device),
714
+ parameters will be moved to that device; otherwise they remain on CPU.
715
+ Example: model = HMM.load('hmm.pt', device='cuda')
716
+ """
717
+ payload = torch.load(path, map_location="cpu")
718
+
719
+ n_states = int(payload.get("n_states"))
720
+ eps = float(payload.get("eps", 1e-8))
721
+ dtype_entry = payload.get("dtype", "torch.float64")
722
+
723
+ # Resolve dtype string robustly:
724
+ # Accept "torch.float64" or "float64" or actual torch.dtype (older payloads)
725
+ if isinstance(dtype_entry, torch.dtype):
726
+ torch_dtype = dtype_entry
727
+ else:
728
+ # dtype_entry expected to be a string
729
+ dtype_str = str(dtype_entry)
730
+ # take last part after dot if present: "torch.float64" -> "float64"
731
+ name = dtype_str.split(".")[-1]
732
+ # map to torch dtype if available, else fallback mapping
733
+ if hasattr(torch, name):
734
+ torch_dtype = getattr(torch, name)
735
+ else:
736
+ fallback = {"float64": torch.float64, "float32": torch.float32, "float16": torch.float16}
737
+ torch_dtype = fallback.get(name, torch.float64)
738
+
739
+ # Build instance (use resolved dtype)
740
+ model = cls(n_states=n_states, dtype=torch_dtype, eps=eps)
741
+
742
+ # Determine target device
743
+ if device is None:
744
+ device = torch.device("cpu")
745
+ elif isinstance(device, str):
746
+ device = torch.device(device)
747
+
748
+ # Load params (they were saved on CPU) and cast to model dtype/device
749
+ with torch.no_grad():
750
+ model.start.data = payload["start"].to(device=device, dtype=model.dtype)
751
+ model.trans.data = payload["trans"].to(device=device, dtype=model.dtype)
752
+ model.emission.data = payload["emission"].to(device=device, dtype=model.dtype)
753
+
754
+ # Normalize / coerce shapes just in case
755
+ model._normalize_params()
756
+ return model
757
+
758
+ def annotate_adata(
759
+ self,
760
+ adata,
761
+ obs_column: str,
762
+ layer: Optional[str] = None,
763
+ footprints: Optional[bool] = None,
764
+ accessible_patches: Optional[bool] = None,
765
+ cpg: Optional[bool] = None,
766
+ methbases: Optional[List[str]] = None,
767
+ threshold: Optional[float] = None,
768
+ device: Optional[Union[str, torch.device]] = None,
769
+ batch_size: Optional[int] = None,
770
+ use_viterbi: Optional[bool] = None,
771
+ in_place: bool = True,
772
+ verbose: bool = True,
773
+ uns_key: str = "hmm_appended_layers",
774
+ config: Optional[Union[dict, "ExperimentConfig"]] = None, # NEW: config/dict accepted
775
+ ):
776
+ """
777
+ Annotate an AnnData with HMM-derived features (in adata.obs and adata.layers).
778
+
779
+ Parameters
780
+ ----------
781
+ config : optional ExperimentConfig instance or plain dict
782
+ When provided, the following keys (if present) are used to override defaults:
783
+ - hmm_feature_sets : dict (canonical feature set structure) OR a JSON/string repr
784
+ - hmm_annotation_threshold : float
785
+ - hmm_batch_size : int
786
+ - hmm_use_viterbi : bool
787
+ - hmm_methbases : list
788
+ - footprints / accessible_patches / cpg (booleans)
789
+ Other keyword args override config values if explicitly provided.
790
+ """
791
+ import json, ast, warnings
792
+ import numpy as _np
793
+ import torch as _torch
794
+ from tqdm import trange, tqdm as _tqdm
795
+
796
+ # small helpers
797
+ def _try_json_or_literal(s):
798
+ if s is None:
799
+ return None
800
+ if not isinstance(s, str):
801
+ return s
802
+ s0 = s.strip()
803
+ if s0 == "":
804
+ return None
805
+ try:
806
+ return json.loads(s0)
807
+ except Exception:
808
+ pass
809
+ try:
810
+ return ast.literal_eval(s0)
811
+ except Exception:
812
+ return s
813
+
814
+ def _coerce_bool(x):
815
+ if x is None:
816
+ return False
817
+ if isinstance(x, bool):
818
+ return x
819
+ if isinstance(x, (int, float)):
820
+ return bool(x)
821
+ s = str(x).strip().lower()
822
+ return s in ("1", "true", "t", "yes", "y", "on")
823
+
824
+ def normalize_hmm_feature_sets(raw):
825
+ if raw is None:
826
+ return {}
827
+ parsed = raw
828
+ if isinstance(raw, str):
829
+ parsed = _try_json_or_literal(raw)
830
+ if not isinstance(parsed, dict):
831
+ return {}
832
+
833
+ def _coerce_bound(x):
834
+ if x is None:
835
+ return None
836
+ if isinstance(x, (int, float)):
837
+ return float(x)
838
+ s = str(x).strip().lower()
839
+ if s in ("inf", "infty", "infinite", "np.inf"):
840
+ return _np.inf
841
+ if s in ("none", ""):
842
+ return None
843
+ try:
844
+ return float(x)
845
+ except Exception:
846
+ return None
847
+
848
+ def _coerce_feature_map(feats):
849
+ out = {}
850
+ if not isinstance(feats, dict):
851
+ return out
852
+ for fname, rng in feats.items():
853
+ if rng is None:
854
+ out[fname] = (0.0, _np.inf)
855
+ continue
856
+ if isinstance(rng, (list, tuple)) and len(rng) >= 2:
857
+ lo = _coerce_bound(rng[0]) or 0.0
858
+ hi = _coerce_bound(rng[1])
859
+ if hi is None:
860
+ hi = _np.inf
861
+ out[fname] = (float(lo), float(hi) if not _np.isinf(hi) else _np.inf)
862
+ else:
863
+ val = _coerce_bound(rng)
864
+ out[fname] = (0.0, float(val) if val is not None else _np.inf)
865
+ return out
866
+
867
+ canonical = {}
868
+ for grp, info in parsed.items():
869
+ if not isinstance(info, dict):
870
+ feats = _coerce_feature_map(info)
871
+ canonical[grp] = {"features": feats, "state": "Modified"}
872
+ continue
873
+ feats = _coerce_feature_map(info.get("features", info.get("ranges", {})))
874
+ state = info.get("state", info.get("label", "Modified"))
875
+ canonical[grp] = {"features": feats, "state": state}
876
+ return canonical
877
+
878
+ # ---------- resolve config dict ----------
879
+ merged_cfg = {}
880
+ if config is not None:
881
+ if hasattr(config, "to_dict") and callable(getattr(config, "to_dict")):
882
+ merged_cfg = dict(config.to_dict())
883
+ elif isinstance(config, dict):
884
+ merged_cfg = dict(config)
885
+ else:
886
+ try:
887
+ merged_cfg = {k: getattr(config, k) for k in dir(config) if k.startswith("hmm_")}
888
+ except Exception:
889
+ merged_cfg = {}
890
+
891
+ def _pick(key, local_val, fallback=None):
892
+ if local_val is not None:
893
+ return local_val
894
+ if key in merged_cfg and merged_cfg[key] is not None:
895
+ return merged_cfg[key]
896
+ alt = f"hmm_{key}"
897
+ if alt in merged_cfg and merged_cfg[alt] is not None:
898
+ return merged_cfg[alt]
899
+ return fallback
900
+
901
+ # coerce booleans robustly
902
+ footprints = _coerce_bool(_pick("footprints", footprints, merged_cfg.get("footprints", False)))
903
+ accessible_patches = _coerce_bool(_pick("accessible_patches", accessible_patches, merged_cfg.get("accessible_patches", False)))
904
+ cpg = _coerce_bool(_pick("cpg", cpg, merged_cfg.get("cpg", False)))
905
+
906
+ threshold = float(_pick("threshold", threshold, merged_cfg.get("hmm_annotation_threshold", 0.5)))
907
+ batch_size = int(_pick("batch_size", batch_size, merged_cfg.get("hmm_batch_size", 1024)))
908
+ use_viterbi = _coerce_bool(_pick("use_viterbi", use_viterbi, merged_cfg.get("hmm_use_viterbi", False)))
909
+
910
+ methbases = _pick("methbases", methbases, merged_cfg.get("hmm_methbases", None)) or []
911
+
912
+ # normalize whitespace/case for human-friendly inputs (but keep original tokens as given)
913
+ methbases = [str(m).strip() for m in methbases if m is not None]
914
+ if verbose:
915
+ print("DEBUG: final methbases list =", methbases)
916
+
917
+ # resolve feature sets: prefer canonical if it yields non-empty mapping, otherwise fall back to boolean defaults
918
+ feature_sets = {}
919
+ if "hmm_feature_sets" in merged_cfg and merged_cfg.get("hmm_feature_sets") is not None:
920
+ cand = normalize_hmm_feature_sets(merged_cfg.get("hmm_feature_sets"))
921
+ if isinstance(cand, dict) and len(cand) > 0:
922
+ feature_sets = cand
923
+
924
+ if not feature_sets:
925
+ if verbose:
926
+ print("[HMM.annotate_adata] no feature sets configured; nothing to append.")
927
+ return None if in_place else adata
928
+
929
+ if verbose:
930
+ print("[HMM.annotate_adata] resolved feature sets:", list(feature_sets.keys()))
931
+
932
+ # copy vs in-place
933
+ if not in_place:
934
+ adata = adata.copy()
935
+
936
+ # prepare column names
937
+ all_features = []
938
+ combined_prefix = "Combined"
939
+ for key, fs in feature_sets.items():
940
+ feats = fs.get("features", {})
941
+ if key == "cpg":
942
+ all_features += [f"CpG_{f}" for f in feats]
943
+ all_features.append(f"CpG_all_{key}_features")
944
+ else:
945
+ for methbase in methbases:
946
+ all_features += [f"{methbase}_{f}" for f in feats]
947
+ all_features.append(f"{methbase}_all_{key}_features")
948
+ if len(methbases) > 1:
949
+ all_features += [f"{combined_prefix}_{f}" for f in feats]
950
+ all_features.append(f"{combined_prefix}_all_{key}_features")
951
+
952
+ # initialize obs columns (unique lists per row)
953
+ n_rows = adata.shape[0]
954
+ for feature in all_features:
955
+ if feature not in adata.obs.columns:
956
+ adata.obs[feature] = [[] for _ in range(n_rows)]
957
+ if f"{feature}_distances" not in adata.obs.columns:
958
+ adata.obs[f"{feature}_distances"] = [None] * n_rows
959
+ if f"n_{feature}" not in adata.obs.columns:
960
+ adata.obs[f"n_{feature}"] = -1
961
+
962
+ appended_layers: List[str] = []
963
+
964
+ # device management
965
+ if device is None:
966
+ device = next(self.parameters()).device
967
+ elif isinstance(device, str):
968
+ device = _torch.device(device)
969
+ self.to(device)
970
+
971
+ # helpers ---------------------------------------------------------------
972
+ def _ensure_2d_array_like(matrix):
973
+ arr = _np.asarray(matrix)
974
+ if arr.ndim == 1:
975
+ arr = arr[_np.newaxis, :]
976
+ elif arr.ndim > 2:
977
+ # squeeze trailing singletons
978
+ while arr.ndim > 2 and arr.shape[-1] == 1:
979
+ arr = _np.squeeze(arr, axis=-1)
980
+ if arr.ndim != 2:
981
+ raise ValueError(f"Expected 2D sequence matrix; got array with shape {arr.shape}")
982
+ return arr
983
+
984
+ def calculate_batch_distances(intervals_list, threshold_local=0.9):
985
+ results_local = []
986
+ for intervals in intervals_list:
987
+ if not isinstance(intervals, list) or len(intervals) == 0:
988
+ results_local.append([])
989
+ continue
990
+ valid = [iv for iv in intervals if iv[2] > threshold_local]
991
+ if len(valid) <= 1:
992
+ results_local.append([])
993
+ continue
994
+ valid = sorted(valid, key=lambda x: x[0])
995
+ dists = [(valid[i + 1][0] - (valid[i][0] + valid[i][1])) for i in range(len(valid) - 1)]
996
+ results_local.append(dists)
997
+ return results_local
998
+
999
+ def classify_batch_local(predicted_states_batch, probabilities_batch, coordinates, classification_mapping, target_state="Modified"):
1000
+ # Accept numpy arrays or torch tensors
1001
+ if isinstance(predicted_states_batch, _torch.Tensor):
1002
+ pred_np = predicted_states_batch.detach().cpu().numpy()
1003
+ else:
1004
+ pred_np = _np.asarray(predicted_states_batch)
1005
+ if isinstance(probabilities_batch, _torch.Tensor):
1006
+ probs_np = probabilities_batch.detach().cpu().numpy()
1007
+ else:
1008
+ probs_np = _np.asarray(probabilities_batch)
1009
+
1010
+ batch_size, L = pred_np.shape
1011
+ all_classifications_local = []
1012
+ # default two-state label convention: `target_state` selects which label counts as modified
1013
+ state_labels = ["Non-Modified", "Modified"]
1014
+ try:
1015
+ target_idx = state_labels.index(target_state)
1016
+ except ValueError:
1017
+ target_idx = 1 # fallback
1018
+
1019
+ for b in range(batch_size):
1020
+ predicted_states = pred_np[b]
1021
+ probabilities = probs_np[b]
1022
+ regions = []
1023
+ current_start, current_length, current_probs = None, 0, []
1024
+ for i, state_index in enumerate(predicted_states):
1025
+ state_prob = float(probabilities[i][state_index])
1026
+ if state_index == target_idx:
1027
+ if current_start is None:
1028
+ current_start = i
1029
+ current_length += 1
1030
+ current_probs.append(state_prob)
1031
+ elif current_start is not None:
1032
+ regions.append((current_start, current_length, float(_np.mean(current_probs))))
1033
+ current_start, current_length, current_probs = None, 0, []
1034
+ if current_start is not None:
1035
+ regions.append((current_start, current_length, float(_np.mean(current_probs))))
1036
+
1037
+ final = []
1038
+ for start, length, prob in regions:
1039
+ # compute genomic length from coordinates; fall back to positional length on failure
1040
+ try:
1041
+ feature_length = int(coordinates[start + length - 1]) - int(coordinates[start]) + 1
1042
+ except Exception:
1043
+ feature_length = int(length)
1044
+
1045
+ # classification_mapping values are (lo, hi) tuples or lists
1046
+ label = None
1047
+ for ftype, rng in classification_mapping.items():
1048
+ lo, hi = rng[0], rng[1]
1049
+ try:
1050
+ if lo <= feature_length < hi:
1051
+ label = ftype
1052
+ break
1053
+ except Exception:
1054
+ continue
1055
+ if label is None:
1056
+ # fallback to first mapping key or 'unknown'
1057
+ label = next(iter(classification_mapping.keys()), "feature")
1058
+
1059
+ # Store reported start coordinate in same coordinate system as `coordinates`.
1060
+ try:
1061
+ genomic_start = int(coordinates[start])
1062
+ except Exception:
1063
+ genomic_start = int(start)
1064
+ final.append((genomic_start, feature_length, label, prob))
1065
+ all_classifications_local.append(final)
1066
+ return all_classifications_local
1067
+
1068
+ # -----------------------------------------------------------------------
1069
+
1070
+ # Ensure obs_column is categorical-like for iteration
1071
+ sseries = adata.obs[obs_column]
1072
+ if not pd.api.types.is_categorical_dtype(sseries):
1073
+ sseries = sseries.astype("category")
1074
+ references = list(sseries.cat.categories)
1075
+
1076
+ ref_iter = references if not verbose else _tqdm(references, desc="Processing References")
1077
+ for ref in ref_iter:
1078
+ # subset reads with this obs_column value
1079
+ ref_mask = adata.obs[obs_column] == ref
1080
+ ref_subset = adata[ref_mask].copy()
1081
+ combined_mask = None
1082
+
1083
+ # per-methbase processing
1084
+ for methbase in methbases:
1085
+ key_lower = methbase.strip().lower()
1086
+
1087
+ # map several common synonyms -> canonical lookup
1088
+ if key_lower in ("a",):
1089
+ pos_mask = ref_subset.var.get(f"{ref}_strand_FASTA_base") == "A"
1090
+ elif key_lower in ("c", "any_c", "anyc", "any-c"):
1091
+ # unify 'C' or 'any_C' names to the any_C var column
1092
+ pos_mask = ref_subset.var.get(f"{ref}_any_C_site") == True
1093
+ elif key_lower in ("gpc", "gpc_site", "gpc-site"):
1094
+ pos_mask = ref_subset.var.get(f"{ref}_GpC_site") == True
1095
+ elif key_lower in ("cpg", "cpg_site", "cpg-site"):
1096
+ pos_mask = ref_subset.var.get(f"{ref}_CpG_site") == True
1097
+ else:
1098
+ # try a best-effort: if a column named f"{ref}_{methbase}_site" exists, use it
1099
+ alt_col = f"{ref}_{methbase}_site"
1100
+ pos_mask = ref_subset.var.get(alt_col, None)
1101
+
1102
+ if pos_mask is None:
1103
+ continue
1104
+ combined_mask = pos_mask if combined_mask is None else (combined_mask | pos_mask)
1105
+
1106
+ if pos_mask.sum() == 0:
1107
+ continue
1108
+
1109
+ sub = ref_subset[:, pos_mask]
1110
+ # choose matrix
1111
+ matrix = sub.layers[layer] if (layer and layer in sub.layers) else sub.X
1112
+ matrix = _ensure_2d_array_like(matrix)
1113
+ n_reads = matrix.shape[0]
1114
+
1115
+ # coordinates for this sub (try to convert to ints, else fallback to indices)
1116
+ try:
1117
+ coords = _np.asarray(sub.var_names, dtype=int)
1118
+ except Exception:
1119
+ coords = _np.arange(sub.shape[1], dtype=int)
1120
+
1121
+ # chunked processing
1122
+ chunk_iter = range(0, n_reads, batch_size)
1123
+ if verbose:
1124
+ chunk_iter = _tqdm(list(chunk_iter), desc=f"{ref}:{methbase} chunks")
1125
+ for start_idx in chunk_iter:
1126
+ stop_idx = min(n_reads, start_idx + batch_size)
1127
+ chunk = matrix[start_idx:stop_idx]
1128
+ seqs = chunk.tolist()
1129
+ # posterior marginals
1130
+ gammas = self.predict(seqs, impute_strategy="ignore", device=device)
1131
+ if len(gammas) == 0:
1132
+ continue
1133
+ probs_batch = _np.stack(gammas, axis=0) # (B, L, K)
1134
+ if use_viterbi:
1135
+ paths, _scores = self.batch_viterbi(seqs, impute_strategy="ignore", device=device)
1136
+ pred_states = _np.asarray(paths)
1137
+ else:
1138
+ pred_states = _np.argmax(probs_batch, axis=2)
1139
+
1140
+ # For each feature group, classify separately and write back
1141
+ for key, fs in feature_sets.items():
1142
+ if key == "cpg":
1143
+ continue
1144
+ state_target = fs.get("state", "Modified")
1145
+ feature_map = fs.get("features", {})
1146
+ classifications = classify_batch_local(pred_states, probs_batch, coords, feature_map, target_state=state_target)
1147
+
1148
+ # write results to adata.obs rows (use original index names)
1149
+ row_indices = list(sub.obs.index[start_idx:stop_idx])
1150
+ for i_local, idx in enumerate(row_indices):
1151
+ for start, length, label, prob in classifications[i_local]:
1152
+ col_name = f"{methbase}_{label}"
1153
+ all_col = f"{methbase}_all_{key}_features"
1154
+ adata.obs.at[idx, col_name].append([start, length, prob])
1155
+ adata.obs.at[idx, all_col].append([start, length, prob])
1156
+
1157
+             # Combined subset (if multiple methbases)
+             if len(methbases) > 1 and (combined_mask is not None) and (combined_mask.sum() > 0):
+                 comb = ref_subset[:, combined_mask]
+                 if comb.shape[1] > 0:
+                     matrix = comb.layers[layer] if (layer and layer in comb.layers) else comb.X
+                     matrix = _ensure_2d_array_like(matrix)
+                     n_reads_comb = matrix.shape[0]
+                     try:
+                         coords_comb = _np.asarray(comb.var_names, dtype=int)
+                     except Exception:
+                         coords_comb = _np.arange(comb.shape[1], dtype=int)
+
+                     chunk_iter = range(0, n_reads_comb, batch_size)
+                     if verbose:
+                         chunk_iter = _tqdm(list(chunk_iter), desc=f"{ref}:Combined chunks")
+                     for start_idx in chunk_iter:
+                         stop_idx = min(n_reads_comb, start_idx + batch_size)
+                         chunk = matrix[start_idx:stop_idx]
+                         seqs = chunk.tolist()
+                         gammas = self.predict(seqs, impute_strategy="ignore", device=device)
+                         if len(gammas) == 0:
+                             continue
+                         probs_batch = _np.stack(gammas, axis=0)
+                         if use_viterbi:
+                             paths, _scores = self.batch_viterbi(seqs, impute_strategy="ignore", device=device)
+                             pred_states = _np.asarray(paths)
+                         else:
+                             pred_states = _np.argmax(probs_batch, axis=2)
+
+                         for key, fs in feature_sets.items():
+                             if key == "cpg":
+                                 continue
+                             state_target = fs.get("state", "Modified")
+                             feature_map = fs.get("features", {})
+                             classifications = classify_batch_local(pred_states, probs_batch, coords_comb, feature_map, target_state=state_target)
+                             row_indices = list(comb.obs.index[start_idx:stop_idx])
+                             for i_local, idx in enumerate(row_indices):
+                                 for start, length, label, prob in classifications[i_local]:
+                                     adata.obs.at[idx, f"{combined_prefix}_{label}"].append([start, length, prob])
+                                     adata.obs.at[idx, f"{combined_prefix}_all_{key}_features"].append([start, length, prob])
+
+         # CpG special handling
+         if "cpg" in feature_sets and feature_sets.get("cpg") is not None:
+             cpg_iter = references if not verbose else _tqdm(references, desc="Processing CpG")
+             for ref in cpg_iter:
+                 ref_mask = adata.obs[obs_column] == ref
+                 ref_subset = adata[ref_mask].copy()
+                 pos_mask = ref_subset.var[f"{ref}_CpG_site"] == True
+                 if pos_mask.sum() == 0:
+                     continue
+                 cpg_sub = ref_subset[:, pos_mask]
+                 matrix = cpg_sub.layers[layer] if (layer and layer in cpg_sub.layers) else cpg_sub.X
+                 matrix = _ensure_2d_array_like(matrix)
+                 n_reads = matrix.shape[0]
+                 try:
+                     coords_cpg = _np.asarray(cpg_sub.var_names, dtype=int)
+                 except Exception:
+                     coords_cpg = _np.arange(cpg_sub.shape[1], dtype=int)
+
+                 chunk_iter = range(0, n_reads, batch_size)
+                 if verbose:
+                     chunk_iter = _tqdm(list(chunk_iter), desc=f"{ref}:CpG chunks")
+                 for start_idx in chunk_iter:
+                     stop_idx = min(n_reads, start_idx + batch_size)
+                     chunk = matrix[start_idx:stop_idx]
+                     seqs = chunk.tolist()
+                     gammas = self.predict(seqs, impute_strategy="ignore", device=device)
+                     if len(gammas) == 0:
+                         continue
+                     probs_batch = _np.stack(gammas, axis=0)
+                     if use_viterbi:
+                         paths, _scores = self.batch_viterbi(seqs, impute_strategy="ignore", device=device)
+                         pred_states = _np.asarray(paths)
+                     else:
+                         pred_states = _np.argmax(probs_batch, axis=2)
+
+                     fs = feature_sets["cpg"]
+                     state_target = fs.get("state", "Modified")
+                     feature_map = fs.get("features", {})
+                     classifications = classify_batch_local(pred_states, probs_batch, coords_cpg, feature_map, target_state=state_target)
+                     row_indices = list(cpg_sub.obs.index[start_idx:stop_idx])
+                     for i_local, idx in enumerate(row_indices):
+                         for start, length, label, prob in classifications[i_local]:
+                             adata.obs.at[idx, f"CpG_{label}"].append([start, length, prob])
+                             adata.obs.at[idx, "CpG_all_cpg_features"].append([start, length, prob])
+
+         # finalize: convert intervals into binary layers and distances
+         try:
+             coordinates = _np.asarray(adata.var_names, dtype=int)
+             coords_are_ints = True
+         except Exception:
+             coordinates = _np.arange(adata.shape[1], dtype=int)
+             coords_are_ints = False
+
+         features_iter = all_features if not verbose else _tqdm(all_features, desc="Finalizing Layers")
+         for feature in features_iter:
+             bin_matrix = _np.zeros((adata.shape[0], adata.shape[1]), dtype=int)
+             counts = _np.zeros(adata.shape[0], dtype=int)
+
+             # integer-length layer (0 where a position is not inside a feature)
+             len_matrix = _np.zeros((adata.shape[0], adata.shape[1]), dtype=int)
+
+             for row_idx, intervals in enumerate(adata.obs[feature]):
+                 if not isinstance(intervals, list):
+                     intervals = []
+                 for start, length, prob in intervals:
+                     if prob > threshold:
+                         if coords_are_ints:
+                             # map genomic start/length into the index interval [start_idx, end_idx)
+                             start_idx = _np.searchsorted(coordinates, int(start), side="left")
+                             end_idx = _np.searchsorted(coordinates, int(start) + int(length) - 1, side="right")
+                         else:
+                             start_idx = int(start)
+                             end_idx = start_idx + int(length)
+
+                         start_idx = max(0, min(start_idx, adata.shape[1]))
+                         end_idx = max(0, min(end_idx, adata.shape[1]))
+
+                         if start_idx < end_idx:
+                             span = end_idx - start_idx  # number of positions covered
+                             # set the binary mask
+                             bin_matrix[row_idx, start_idx:end_idx] = 1
+                             # set the length mask, using the maximum in case of overlaps
+                             existing = len_matrix[row_idx, start_idx:end_idx]
+                             len_matrix[row_idx, start_idx:end_idx] = _np.maximum(existing, span)
+                             counts[row_idx] += 1
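+
+             # Illustrative note (hypothetical coordinates): with coordinates = [100, 105, 110, 115],
+             # an interval with start=105 and length=10 spans genomic positions 105-114, so searchsorted
+             # gives start_idx=1 and end_idx=3, and columns 1-2 (positions 105 and 110) are set in
+             # bin_matrix for that read.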
+
+             # write binary layer and length layer, track appended names
+             adata.layers[feature] = bin_matrix
+             appended_layers.append(feature)
+
+             # the integer-length layer gets its own suffixed name
+             length_layer_name = f"{feature}_lengths"
+             adata.layers[length_layer_name] = len_matrix
+             appended_layers.append(length_layer_name)
+
+             adata.obs[f"n_{feature}"] = counts
+             adata.obs[f"{feature}_distances"] = calculate_batch_distances(adata.obs[feature].tolist(), threshold)
+
+         # merge appended_layers into adata.uns[uns_key] (preserve pre-existing entries, avoid duplicates)
+         existing = list(adata.uns.get(uns_key, [])) if adata.uns.get(uns_key) is not None else []
+         new_list = existing + [l for l in appended_layers if l not in existing]
+         adata.uns[uns_key] = new_list
+
+         return None if in_place else adata
+
+     def merge_intervals_in_layer(
+         self,
+         adata,
+         layer: str,
+         distance_threshold: int = 0,
+         merged_suffix: str = "_merged",
+         length_layer_suffix: str = "_lengths",
+         update_obs: bool = True,
+         prob_strategy: str = "mean",  # 'mean' | 'max' | 'orig_first'
+         inplace: bool = True,
+         overwrite: bool = False,
+         verbose: bool = False,
+     ):
+         """
+         Merge intervals in `adata.layers[layer]` that are within `distance_threshold` of each other.
+
+         Writes a new merged binary layer named f"{layer}{merged_suffix}" and a length layer named
+         f"{layer}{merged_suffix}{length_layer_suffix}". Optionally updates adata.obs with the merged intervals.
+
+         Parameters
+         ----------
+         layer : str
+             Name of the original binary layer (0/1 mask).
+         distance_threshold : int
+             Merge intervals whose gap is <= this threshold (in genomic coordinates if adata.var_names are ints).
+         merged_suffix : str
+             Suffix appended to the original layer name for the merged binary layer (default "_merged").
+         length_layer_suffix : str
+             Suffix appended after the merged suffix for the lengths layer (default "_lengths").
+         update_obs : bool
+             If True, create/update adata.obs[f"{layer}{merged_suffix}"] with the merged intervals.
+         prob_strategy : str
+             How to combine probabilities when merging ('mean', 'max', or 'orig_first').
+         inplace : bool
+             If False, return a new AnnData with the changes (the original is untouched).
+         overwrite : bool
+             If True, overwrite existing merged layers / obs entries; otherwise raise if they already exist.
+         """
+         import numpy as _np
+         from scipy.sparse import issparse
+
+         if not inplace:
+             adata = adata.copy()
+
+         merged_bin_name = f"{layer}{merged_suffix}"
+         merged_len_name = f"{layer}{merged_suffix}{length_layer_suffix}"
+
+         if (merged_bin_name in adata.layers or merged_len_name in adata.layers or
+                 (update_obs and merged_bin_name in adata.obs.columns)) and not overwrite:
+             raise KeyError(f"Merged outputs exist (use overwrite=True to replace): {merged_bin_name} / {merged_len_name}")
+
+         if layer not in adata.layers:
+             raise KeyError(f"Layer '{layer}' not found in adata.layers")
+
+         bin_layer = adata.layers[layer]
+         if issparse(bin_layer):
+             bin_arr = bin_layer.toarray().astype(int)
+         else:
+             bin_arr = _np.asarray(bin_layer, dtype=int)
+
+         n_rows, n_cols = bin_arr.shape
+
+         # coordinates in genomic units if possible
+         try:
+             coords = _np.asarray(adata.var_names, dtype=int)
+             coords_are_ints = True
+         except Exception:
+             coords = _np.arange(n_cols, dtype=int)
+             coords_are_ints = False
+
+         # helper: contiguous runs of 1s -> list of (start_idx, end_idx) pairs (end exclusive)
+         def _runs_from_mask(mask_1d):
+             idx = _np.nonzero(mask_1d)[0]
+             if idx.size == 0:
+                 return []
+             runs = []
+             start = idx[0]
+             prev = idx[0]
+             for i in idx[1:]:
+                 if i == prev + 1:
+                     prev = i
+                     continue
+                 runs.append((start, prev + 1))
+                 start = i
+                 prev = i
+             runs.append((start, prev + 1))
+             return runs
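+
+         # Illustrative note: for a row mask like [0, 1, 1, 0, 0, 1], the helper above returns
+         # [(1, 3), (5, 6)], i.e. half-open index runs of consecutive nonzero positions.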
+
+         # read original obs intervals/probs if available (used when combining probabilities)
+         orig_obs = None
+         if update_obs and (layer in adata.obs.columns):
+             orig_obs = list(adata.obs[layer])  # entries might not be lists
+
+         # prepare outputs
+         merged_bin = _np.zeros_like(bin_arr, dtype=int)
+         merged_len = _np.zeros_like(bin_arr, dtype=int)
+         merged_obs_col = [[] for _ in range(n_rows)]
+         merged_counts = _np.zeros(n_rows, dtype=int)
+
+         for r in range(n_rows):
+             mask = bin_arr[r, :] != 0
+             runs = _runs_from_mask(mask)
+             if not runs:
+                 merged_obs_col[r] = []
+                 continue
+
+             # merge runs whose gap <= distance_threshold (gap in genomic coords when possible)
+             merged_runs = []
+             cur_s, cur_e = runs[0]
+             for (s, e) in runs[1:]:
+                 if coords_are_ints:
+                     end_coord = int(coords[cur_e - 1])
+                     next_start_coord = int(coords[s])
+                     gap = next_start_coord - end_coord - 1
+                 else:
+                     gap = s - cur_e
+                 if gap <= distance_threshold:
+                     # extend the current run
+                     cur_e = e
+                 else:
+                     merged_runs.append((cur_s, cur_e))
+                     cur_s, cur_e = s, e
+             merged_runs.append((cur_s, cur_e))
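+             # Illustrative note (hypothetical numbers): with coords = [100, 110, 120, 130] and
+             # runs [(0, 2), (3, 4)], the gap between the run ending at coord 110 and the run
+             # starting at coord 130 is 130 - 110 - 1 = 19, so the two runs merge only when
+             # distance_threshold >= 19.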
+
+             # assemble merged mask/lengths and obs entries
+             row_entries = []
+             for (s_idx, e_idx) in merged_runs:
+                 if e_idx <= s_idx:
+                     continue
+                 span_positions = e_idx - s_idx
+                 if coords_are_ints:
+                     try:
+                         length_val = int(coords[e_idx - 1]) - int(coords[s_idx]) + 1
+                     except Exception:
+                         length_val = span_positions
+                 else:
+                     length_val = span_positions
+
+                 # set binary and length masks
+                 merged_bin[r, s_idx:e_idx] = 1
+                 existing_segment = merged_len[r, s_idx:e_idx]
+                 # set to max(existing, length_val)
+                 if existing_segment.size > 0:
+                     merged_len[r, s_idx:e_idx] = _np.maximum(existing_segment, length_val)
+                 else:
+                     merged_len[r, s_idx:e_idx] = length_val
+
+                 # determine the probability from overlapping original obs intervals (if present)
+                 prob_val = 1.0
+                 if update_obs and orig_obs is not None:
+                     overlaps = []
+                     for orig in (orig_obs[r] or []):
+                         try:
+                             ostart, olen, opro = orig[0], int(orig[1]), float(orig[2])
+                         except Exception:
+                             continue
+                         if coords_are_ints:
+                             ostart_idx = _np.searchsorted(coords, int(ostart), side="left")
+                             oend_idx = ostart_idx + olen
+                         else:
+                             ostart_idx = int(ostart)
+                             oend_idx = ostart_idx + olen
+                         # overlap test in index space
+                         if not (oend_idx <= s_idx or ostart_idx >= e_idx):
+                             overlaps.append(opro)
+                     if overlaps:
+                         if prob_strategy == "mean":
+                             prob_val = float(_np.mean(overlaps))
+                         elif prob_strategy == "max":
+                             prob_val = float(_np.max(overlaps))
+                         else:
+                             prob_val = float(overlaps[0])
+
+                 start_coord = int(coords[s_idx]) if coords_are_ints else int(s_idx)
+                 row_entries.append((start_coord, int(length_val), float(prob_val)))
+
+             merged_obs_col[r] = row_entries
+             merged_counts[r] = len(row_entries)
+
+         # write merged layers (originals stay untouched; existing merged outputs are replaced only when overwrite=True)
+         adata.layers[merged_bin_name] = merged_bin
+         adata.layers[merged_len_name] = merged_len
+
+         if update_obs:
+             adata.obs[merged_bin_name] = merged_obs_col
+             adata.obs[f"n_{merged_bin_name}"] = merged_counts
+
+             # recompute the per-row distances list (gaps between adjacent merged intervals)
+             def _calc_distances(obs_list):
+                 out = []
+                 for intervals in obs_list:
+                     if not intervals:
+                         out.append([])
+                         continue
+                     iv = sorted(intervals, key=lambda x: int(x[0]))
+                     if len(iv) <= 1:
+                         out.append([])
+                         continue
+                     dlist = []
+                     for i in range(len(iv) - 1):
+                         endi = int(iv[i][0]) + int(iv[i][1]) - 1
+                         startn = int(iv[i + 1][0])
+                         dlist.append(startn - endi - 1)
+                     out.append(dlist)
+                 return out
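+             # Illustrative note (hypothetical values): for merged intervals
+             # [(100, 20, 0.9), (150, 10, 0.8)] (start, length, prob), the first interval ends at
+             # 100 + 20 - 1 = 119, so the recorded gap to the next interval is 150 - 119 - 1 = 30.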
+
+             adata.obs[f"{merged_bin_name}_distances"] = _calc_distances(merged_obs_col)
+
+         # update the uns list of appended layers
+         uns_key = "hmm_appended_layers"
+         existing = list(adata.uns.get(uns_key, [])) if adata.uns.get(uns_key, None) is not None else []
+         for nm in (merged_bin_name, merged_len_name):
+             if nm not in existing:
+                 existing.append(nm)
+         adata.uns[uns_key] = existing
+
+         if verbose:
+             print(f"Created merged binary layer: {merged_bin_name}")
+             print(f"Created merged length layer: {merged_len_name}")
+             if update_obs:
+                 print(f"Updated adata.obs columns: {merged_bin_name}, n_{merged_bin_name}, {merged_bin_name}_distances")
+
+         return None if inplace else adata
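+
+     # Usage sketch (illustrative only; the instance and layer names below are hypothetical):
+     #   hmm.merge_intervals_in_layer(adata, layer="GpC_accessible_patch",
+     #                                distance_threshold=10, overwrite=True)
+     # This would add layers "GpC_accessible_patch_merged" and "GpC_accessible_patch_merged_lengths",
+     # plus the matching obs columns when update_obs=True.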
+
+     def _ensure_final_layer_and_assign(self, final_adata, layer_name: str, subset_idx_mask: np.ndarray, sub_data):
+         """
+         Ensure final_adata.layers[layer_name] exists and assign the rows selected by subset_idx_mask.
+
+         sub_data has shape (n_subset_rows, n_vars).
+         subset_idx_mask is a boolean array of length final_adata.n_obs, True where a row belongs to the subset.
+         """
+         from scipy.sparse import issparse, csr_matrix
+
+         n_final_obs, n_vars = final_adata.shape
+         n_sub_rows = int(subset_idx_mask.sum())
+
+         # row indices in final_adata that belong to the subset
+         final_row_indices = np.nonzero(subset_idx_mask)[0]
+
+         if issparse(sub_data):
+             # sparse input: keep the final layer sparse as well
+             sub_csr = sub_data.tocsr()
+             if layer_name not in final_adata.layers:
+                 # create an empty CSR of shape (n_final_obs, n_vars)
+                 final_adata.layers[layer_name] = csr_matrix((n_final_obs, n_vars), dtype=sub_csr.dtype)
+             final_csr = final_adata.layers[layer_name]
+             if not issparse(final_csr):
+                 # convert a dense final layer to sparse first
+                 final_csr = csr_matrix(final_csr)
+             # Row assignment on CSR is awkward, so convert to LIL (row-mutable), assign the subset
+             # rows, then convert back to CSR. This is fine for moderate sizes; for very large data
+             # an in-place approach may be preferable.
+             final_lil = final_csr.tolil()
+             for i_local, r in enumerate(final_row_indices):
+                 final_lil.rows[r] = sub_csr.getrow(i_local).indices.tolist()
+                 final_lil.data[r] = sub_csr.getrow(i_local).data.tolist()
+             final_csr = final_lil.tocsr()
+             final_adata.layers[layer_name] = final_csr
+         else:
+             # dense numpy array
+             sub_arr = np.asarray(sub_data)
+             if sub_arr.shape[0] != n_sub_rows:
+                 raise ValueError(f"Sub data rows ({sub_arr.shape[0]}) != mask selected rows ({n_sub_rows})")
+             if layer_name not in final_adata.layers:
+                 # create a zero array matching the final shape
+                 final_adata.layers[layer_name] = np.zeros((n_final_obs, n_vars), dtype=sub_arr.dtype)
+             final_arr = final_adata.layers[layer_name]
+             if issparse(final_arr):
+                 # an existing sparse final layer is densified before row assignment
+                 final_arr = final_arr.toarray()
+             # assign the subset rows
+             final_arr[final_row_indices, :] = sub_arr
+             final_adata.layers[layer_name] = final_arr
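+
+     # Usage sketch (illustrative only; the obs column, layer name, and matrix below are hypothetical):
+     #   mask = (final_adata.obs["some_obs_column"] == ref).values          # bool, length n_obs
+     #   self._ensure_final_layer_and_assign(final_adata, "hmm_layer", mask, sub_matrix)
+     # where sub_matrix has shape (mask.sum(), final_adata.n_vars); rows outside the mask remain zero.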