smftools 0.1.7__py3-none-any.whl → 0.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- smftools/__init__.py +7 -6
- smftools/_version.py +1 -1
- smftools/cli/cli_flows.py +94 -0
- smftools/cli/hmm_adata.py +338 -0
- smftools/cli/load_adata.py +577 -0
- smftools/cli/preprocess_adata.py +363 -0
- smftools/cli/spatial_adata.py +564 -0
- smftools/cli_entry.py +435 -0
- smftools/config/__init__.py +1 -0
- smftools/config/conversion.yaml +38 -0
- smftools/config/deaminase.yaml +61 -0
- smftools/config/default.yaml +264 -0
- smftools/config/direct.yaml +41 -0
- smftools/config/discover_input_files.py +115 -0
- smftools/config/experiment_config.py +1288 -0
- smftools/hmm/HMM.py +1576 -0
- smftools/hmm/__init__.py +20 -0
- smftools/{tools → hmm}/apply_hmm_batched.py +8 -7
- smftools/hmm/call_hmm_peaks.py +106 -0
- smftools/{tools → hmm}/display_hmm.py +3 -3
- smftools/{tools → hmm}/nucleosome_hmm_refinement.py +2 -2
- smftools/{tools → hmm}/train_hmm.py +1 -1
- smftools/informatics/__init__.py +13 -9
- smftools/informatics/archived/deaminase_smf.py +132 -0
- smftools/informatics/archived/fast5_to_pod5.py +43 -0
- smftools/informatics/archived/helpers/archived/__init__.py +71 -0
- smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +126 -0
- smftools/informatics/archived/helpers/archived/aligned_BAM_to_bed.py +87 -0
- smftools/informatics/archived/helpers/archived/bam_qc.py +213 -0
- smftools/informatics/archived/helpers/archived/bed_to_bigwig.py +90 -0
- smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +259 -0
- smftools/informatics/{helpers → archived/helpers/archived}/count_aligned_reads.py +2 -2
- smftools/informatics/{helpers → archived/helpers/archived}/demux_and_index_BAM.py +8 -10
- smftools/informatics/{helpers → archived/helpers/archived}/extract_base_identities.py +30 -4
- smftools/informatics/{helpers → archived/helpers/archived}/extract_mods.py +15 -13
- smftools/informatics/{helpers → archived/helpers/archived}/extract_read_features_from_bam.py +4 -2
- smftools/informatics/{helpers → archived/helpers/archived}/find_conversion_sites.py +5 -4
- smftools/informatics/{helpers → archived/helpers/archived}/generate_converted_FASTA.py +2 -0
- smftools/informatics/{helpers → archived/helpers/archived}/get_chromosome_lengths.py +9 -8
- smftools/informatics/archived/helpers/archived/index_fasta.py +24 -0
- smftools/informatics/{helpers → archived/helpers/archived}/make_modbed.py +1 -2
- smftools/informatics/{helpers → archived/helpers/archived}/modQC.py +2 -2
- smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +250 -0
- smftools/informatics/{helpers → archived/helpers/archived}/separate_bam_by_bc.py +8 -7
- smftools/informatics/{helpers → archived/helpers/archived}/split_and_index_BAM.py +8 -12
- smftools/informatics/archived/subsample_fasta_from_bed.py +49 -0
- smftools/informatics/bam_functions.py +812 -0
- smftools/informatics/basecalling.py +67 -0
- smftools/informatics/bed_functions.py +366 -0
- smftools/informatics/binarize_converted_base_identities.py +172 -0
- smftools/informatics/{helpers/converted_BAM_to_adata_II.py → converted_BAM_to_adata.py} +198 -50
- smftools/informatics/fasta_functions.py +255 -0
- smftools/informatics/h5ad_functions.py +197 -0
- smftools/informatics/{helpers/modkit_extract_to_adata.py → modkit_extract_to_adata.py} +147 -61
- smftools/informatics/modkit_functions.py +129 -0
- smftools/informatics/ohe.py +160 -0
- smftools/informatics/pod5_functions.py +224 -0
- smftools/informatics/{helpers/run_multiqc.py → run_multiqc.py} +5 -2
- smftools/machine_learning/__init__.py +12 -0
- smftools/machine_learning/data/__init__.py +2 -0
- smftools/machine_learning/data/anndata_data_module.py +234 -0
- smftools/machine_learning/evaluation/__init__.py +2 -0
- smftools/machine_learning/evaluation/eval_utils.py +31 -0
- smftools/machine_learning/evaluation/evaluators.py +223 -0
- smftools/machine_learning/inference/__init__.py +3 -0
- smftools/machine_learning/inference/inference_utils.py +27 -0
- smftools/machine_learning/inference/lightning_inference.py +68 -0
- smftools/machine_learning/inference/sklearn_inference.py +55 -0
- smftools/machine_learning/inference/sliding_window_inference.py +114 -0
- smftools/machine_learning/models/base.py +295 -0
- smftools/machine_learning/models/cnn.py +138 -0
- smftools/machine_learning/models/lightning_base.py +345 -0
- smftools/machine_learning/models/mlp.py +26 -0
- smftools/{tools → machine_learning}/models/positional.py +3 -2
- smftools/{tools → machine_learning}/models/rnn.py +2 -1
- smftools/machine_learning/models/sklearn_models.py +273 -0
- smftools/machine_learning/models/transformer.py +303 -0
- smftools/machine_learning/training/__init__.py +2 -0
- smftools/machine_learning/training/train_lightning_model.py +135 -0
- smftools/machine_learning/training/train_sklearn_model.py +114 -0
- smftools/plotting/__init__.py +4 -1
- smftools/plotting/autocorrelation_plotting.py +609 -0
- smftools/plotting/general_plotting.py +1292 -140
- smftools/plotting/hmm_plotting.py +260 -0
- smftools/plotting/qc_plotting.py +270 -0
- smftools/preprocessing/__init__.py +15 -8
- smftools/preprocessing/add_read_length_and_mapping_qc.py +129 -0
- smftools/preprocessing/append_base_context.py +122 -0
- smftools/preprocessing/append_binary_layer_by_base_context.py +143 -0
- smftools/preprocessing/binarize.py +17 -0
- smftools/preprocessing/binarize_on_Youden.py +2 -2
- smftools/preprocessing/calculate_complexity_II.py +248 -0
- smftools/preprocessing/calculate_coverage.py +10 -1
- smftools/preprocessing/calculate_position_Youden.py +1 -1
- smftools/preprocessing/calculate_read_modification_stats.py +101 -0
- smftools/preprocessing/clean_NaN.py +17 -1
- smftools/preprocessing/filter_reads_on_length_quality_mapping.py +158 -0
- smftools/preprocessing/filter_reads_on_modification_thresholds.py +352 -0
- smftools/preprocessing/flag_duplicate_reads.py +1326 -124
- smftools/preprocessing/invert_adata.py +12 -5
- smftools/preprocessing/load_sample_sheet.py +19 -4
- smftools/readwrite.py +1021 -89
- smftools/tools/__init__.py +3 -32
- smftools/tools/calculate_umap.py +5 -5
- smftools/tools/general_tools.py +3 -3
- smftools/tools/position_stats.py +468 -106
- smftools/tools/read_stats.py +115 -1
- smftools/tools/spatial_autocorrelation.py +562 -0
- {smftools-0.1.7.dist-info → smftools-0.2.3.dist-info}/METADATA +14 -9
- smftools-0.2.3.dist-info/RECORD +173 -0
- smftools-0.2.3.dist-info/entry_points.txt +2 -0
- smftools/informatics/fast5_to_pod5.py +0 -21
- smftools/informatics/helpers/LoadExperimentConfig.py +0 -75
- smftools/informatics/helpers/__init__.py +0 -74
- smftools/informatics/helpers/align_and_sort_BAM.py +0 -59
- smftools/informatics/helpers/aligned_BAM_to_bed.py +0 -74
- smftools/informatics/helpers/bam_qc.py +0 -66
- smftools/informatics/helpers/bed_to_bigwig.py +0 -39
- smftools/informatics/helpers/binarize_converted_base_identities.py +0 -79
- smftools/informatics/helpers/concatenate_fastqs_to_bam.py +0 -55
- smftools/informatics/helpers/index_fasta.py +0 -12
- smftools/informatics/helpers/make_dirs.py +0 -21
- smftools/informatics/helpers/plot_read_length_and_coverage_histograms.py +0 -53
- smftools/informatics/load_adata.py +0 -182
- smftools/informatics/readwrite.py +0 -106
- smftools/informatics/subsample_fasta_from_bed.py +0 -47
- smftools/preprocessing/append_C_context.py +0 -82
- smftools/preprocessing/calculate_converted_read_methylation_stats.py +0 -94
- smftools/preprocessing/filter_converted_reads_on_methylation.py +0 -44
- smftools/preprocessing/filter_reads_on_length.py +0 -51
- smftools/tools/call_hmm_peaks.py +0 -105
- smftools/tools/data/__init__.py +0 -2
- smftools/tools/data/anndata_data_module.py +0 -90
- smftools/tools/inference/__init__.py +0 -1
- smftools/tools/inference/lightning_inference.py +0 -41
- smftools/tools/models/base.py +0 -14
- smftools/tools/models/cnn.py +0 -34
- smftools/tools/models/lightning_base.py +0 -41
- smftools/tools/models/mlp.py +0 -17
- smftools/tools/models/sklearn_models.py +0 -40
- smftools/tools/models/transformer.py +0 -133
- smftools/tools/training/__init__.py +0 -1
- smftools/tools/training/train_lightning_model.py +0 -47
- smftools-0.1.7.dist-info/RECORD +0 -136
- /smftools/{tools/evaluation → cli}/__init__.py +0 -0
- /smftools/{tools → hmm}/calculate_distances.py +0 -0
- /smftools/{tools → hmm}/hmm_readwrite.py +0 -0
- /smftools/informatics/{basecall_pod5s.py → archived/basecall_pod5s.py} +0 -0
- /smftools/informatics/{conversion_smf.py → archived/conversion_smf.py} +0 -0
- /smftools/informatics/{direct_smf.py → archived/direct_smf.py} +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/canoncall.py +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/converted_BAM_to_adata.py +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/extract_read_lengths_from_bed.py +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/extract_readnames_from_BAM.py +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/get_native_references.py +0 -0
- /smftools/informatics/{helpers → archived/helpers}/archived/informatics.py +0 -0
- /smftools/informatics/{helpers → archived/helpers}/archived/load_adata.py +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/modcall.py +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/ohe_batching.py +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/ohe_layers_decode.py +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/one_hot_decode.py +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/one_hot_encode.py +0 -0
- /smftools/informatics/{subsample_pod5.py → archived/subsample_pod5.py} +0 -0
- /smftools/informatics/{helpers/complement_base_list.py → complement_base_list.py} +0 -0
- /smftools/{tools → machine_learning}/data/preprocessing.py +0 -0
- /smftools/{tools → machine_learning}/models/__init__.py +0 -0
- /smftools/{tools → machine_learning}/models/wrappers.py +0 -0
- /smftools/{tools → machine_learning}/utils/__init__.py +0 -0
- /smftools/{tools → machine_learning}/utils/device.py +0 -0
- /smftools/{tools → machine_learning}/utils/grl.py +0 -0
- /smftools/tools/{apply_hmm.py → archived/apply_hmm.py} +0 -0
- /smftools/tools/{classifiers.py → archived/classifiers.py} +0 -0
- {smftools-0.1.7.dist-info → smftools-0.2.3.dist-info}/WHEEL +0 -0
- {smftools-0.1.7.dist-info → smftools-0.2.3.dist-info}/licenses/LICENSE +0 -0
smftools/hmm/HMM.py
ADDED
@@ -0,0 +1,1576 @@
import math
from typing import List, Optional, Tuple, Union, Any, Dict
import ast
import json

import numpy as np
import pandas as pd
import torch
import torch.nn as nn

def _logsumexp(vec: torch.Tensor, dim: int = -1, keepdim: bool = False) -> torch.Tensor:
    return torch.logsumexp(vec, dim=dim, keepdim=keepdim)

class HMM(nn.Module):
    """
    Vectorized HMM (Bernoulli emissions) implemented in PyTorch.

    Methods:
    - fit(data, ...) -> trains params in-place
    - predict(data, ...) -> list of (L, K) posterior marginals (gamma) numpy arrays
    - viterbi(seq, ...) -> (path_list, score)
    - batch_viterbi(data, ...) -> list of (path_list, score)
    - score(seq_or_list, ...) -> float or list of floats

    Notes:
    - data: list of sequences (each sequence is iterable of {0,1,np.nan}).
    - impute_strategy: "ignore" (NaN treated as missing), "random" (fill NaNs randomly with 0/1).
    """

    def __init__(
        self,
        n_states: int = 2,
        init_start: Optional[List[float]] = None,
        init_trans: Optional[List[List[float]]] = None,
        init_emission: Optional[List[float]] = None,
        dtype: torch.dtype = torch.float64,
        eps: float = 1e-8,
        smf_modality: Optional[str] = None,
    ):
        super().__init__()
        if n_states < 2:
            raise ValueError("n_states must be >= 2")
        self.n_states = n_states
        self.eps = float(eps)
        self.dtype = dtype
        self.smf_modality = smf_modality

        # initialize params (probabilities)
        if init_start is None:
            start = np.full((n_states,), 1.0 / n_states, dtype=float)
        else:
            start = np.asarray(init_start, dtype=float)
        if init_trans is None:
            trans = np.full((n_states, n_states), 1.0 / n_states, dtype=float)
        else:
            trans = np.asarray(init_trans, dtype=float)
        # --- sanitize init_emission so it's a 1-D list of P(obs==1 | state) ---
        if init_emission is None:
            emission = np.full((n_states,), 0.5, dtype=float)
        else:
            em_arr = np.asarray(init_emission, dtype=float)
            # case: (K,2) -> pick P(obs==1) from second column
            if em_arr.ndim == 2 and em_arr.shape[1] == 2 and em_arr.shape[0] == n_states:
                emission = em_arr[:, 1].astype(float)
            # case: maybe shape (1,K,2) etc. -> try to collapse trailing axis of length 2
            elif em_arr.ndim >= 2 and em_arr.shape[-1] == 2:
                emission = em_arr.reshape(-1, 2)[:n_states, 1].astype(float)
            else:
                emission = em_arr.reshape(-1)[:n_states].astype(float)

        # store as parameters (not trainable via grad; EM updates .data in-place)
        self.start = nn.Parameter(torch.tensor(start, dtype=self.dtype), requires_grad=False)
        self.trans = nn.Parameter(torch.tensor(trans, dtype=self.dtype), requires_grad=False)
        self.emission = nn.Parameter(torch.tensor(emission, dtype=self.dtype), requires_grad=False)

        self._normalize_params()

    def _normalize_params(self):
        with torch.no_grad():
            # coerce shapes
            K = self.n_states
            self.start.data = self.start.data.squeeze()
            if self.start.data.numel() != K:
                self.start.data = torch.full((K,), 1.0 / K, dtype=self.dtype)

            self.trans.data = self.trans.data.squeeze()
            if not (self.trans.data.ndim == 2 and self.trans.data.shape == (K, K)):
                if K == 2:
                    self.trans.data = torch.tensor([[0.9, 0.1], [0.1, 0.9]], dtype=self.dtype)
                else:
                    self.trans.data = torch.full((K, K), 1.0 / K, dtype=self.dtype)

            self.emission.data = self.emission.data.squeeze()
            if self.emission.data.numel() != K:
                self.emission.data = torch.full((K,), 0.5, dtype=self.dtype)

            # now perform smoothing/normalization
            self.start.data = (self.start.data + self.eps)
            self.start.data = self.start.data / self.start.data.sum()

            self.trans.data = (self.trans.data + self.eps)
            row_sums = self.trans.data.sum(dim=1, keepdim=True)
            row_sums[row_sums == 0.0] = 1.0
            self.trans.data = self.trans.data / row_sums

            self.emission.data = self.emission.data.clamp(min=self.eps, max=1.0 - self.eps)
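    # Illustrative sketch (not part of the package source): constructing a
    # 2-state model directly. After __init__, _normalize_params guarantees each
    # row of `trans` sums to 1 and `emission` is clamped to (eps, 1 - eps).
    #
    #   model = HMM(n_states=2,
    #               init_trans=[[0.9, 0.1], [0.1, 0.9]],
    #               init_emission=[0.1, 0.8])   # P(obs == 1 | state)
    #   print(model.trans.sum(dim=1))           # ~tensor([1., 1.])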
    @staticmethod
    def _resolve_dtype(dtype_entry):
        """Accept torch.dtype, string ('float32'/'float64') or None -> torch.dtype."""
        if dtype_entry is None:
            return torch.float64
        if isinstance(dtype_entry, torch.dtype):
            return dtype_entry
        s = str(dtype_entry).lower()
        if "32" in s:
            return torch.float32
        if "16" in s:
            return torch.float16
        return torch.float64

    @classmethod
    def from_config(cls, cfg: Union[dict, "ExperimentConfig", None], *,
                    override: Optional[dict] = None,
                    device: Optional[Union[str, torch.device]] = None) -> "HMM":
        """
        Construct an HMM using keys from an ExperimentConfig instance or a plain dict.

        cfg may be:
        - an ExperimentConfig (your dataclass instance)
        - a dict (e.g. loader.var_dict or merged defaults)
        - None (uses internal defaults)

        override: optional dict to override resolved keys (handy for tests).
        device: optional device string or torch.device to move model to.
        """
        # Accept ExperimentConfig dataclass
        if cfg is None:
            merged = {}
        elif hasattr(cfg, "to_dict") and callable(getattr(cfg, "to_dict")):
            merged = dict(cfg.to_dict())
        elif isinstance(cfg, dict):
            merged = dict(cfg)
        else:
            # try attr access as fallback
            try:
                merged = {k: getattr(cfg, k) for k in dir(cfg) if k.startswith("hmm_")}
            except Exception:
                merged = {}

        if override:
            merged.update(override)

        # basic resolution with fallback
        n_states = int(merged.get("hmm_n_states", merged.get("n_states", 2)))
        init_start = merged.get("hmm_init_start_probs", merged.get("hmm_init_start", None))
        init_trans = merged.get("hmm_init_transition_probs", merged.get("hmm_init_trans", None))
        init_emission = merged.get("hmm_init_emission_probs", merged.get("hmm_init_emission", None))
        eps = float(merged.get("hmm_eps", merged.get("eps", 1e-8)))
        dtype = cls._resolve_dtype(merged.get("hmm_dtype", merged.get("dtype", None)))

        # coerce lists (if present) -> numpy arrays (the HMM constructor already sanitizes)
        def _coerce_np(x):
            if x is None:
                return None
            return np.asarray(x, dtype=float)

        init_start = _coerce_np(init_start)
        init_trans = _coerce_np(init_trans)
        init_emission = _coerce_np(init_emission)

        model = cls(
            n_states=n_states,
            init_start=init_start,
            init_trans=init_trans,
            init_emission=init_emission,
            dtype=dtype,
            eps=eps,
            smf_modality=merged.get("smf_modality", None),
        )

        # move to device if requested
        if device is not None:
            if isinstance(device, str):
                device = torch.device(device)
            model.to(device)

        # persist the config to the hmm class
        cls.config = cfg

        return model
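    # Illustrative sketch (assumed plain-dict config, not from the package
    # source): from_config resolves "hmm_*" keys with fallbacks, so a minimal
    # dict suffices.
    #
    #   model = HMM.from_config({
    #       "hmm_n_states": 2,
    #       "hmm_init_emission_probs": [[0.9, 0.1], [0.2, 0.8]],  # (K,2) collapses to P(obs==1)
    #       "hmm_dtype": "float32",
    #   }, device="cpu")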
    def update_from_config(self, cfg: Union[dict, "ExperimentConfig", None], *,
                           override: Optional[dict] = None):
        """
        Update existing model parameters from a config or dict (in-place).
        This will normalize / reinitialize start/trans/emission using same logic as constructor.
        """
        if cfg is None:
            merged = {}
        elif hasattr(cfg, "to_dict") and callable(getattr(cfg, "to_dict")):
            merged = dict(cfg.to_dict())
        elif isinstance(cfg, dict):
            merged = dict(cfg)
        else:
            try:
                merged = {k: getattr(cfg, k) for k in dir(cfg) if k.startswith("hmm_")}
            except Exception:
                merged = {}

        if override:
            merged.update(override)

        # extract same keys as from_config
        n_states = int(merged.get("hmm_n_states", self.n_states))
        init_start = merged.get("hmm_init_start_probs", None)
        init_trans = merged.get("hmm_init_transition_probs", None)
        init_emission = merged.get("hmm_init_emission_probs", None)
        eps = merged.get("hmm_eps", None)
        dtype = merged.get("hmm_dtype", None)

        # apply dtype/eps if present
        if eps is not None:
            self.eps = float(eps)
        if dtype is not None:
            self.dtype = self._resolve_dtype(dtype)

        # if n_states changes we need a fresh re-init (easy approach: reconstruct)
        if int(n_states) != int(self.n_states):
            # rebuild self in-place: create a new model and copy tensors
            new_model = HMM.from_config(merged)
            # copy content
            with torch.no_grad():
                self.n_states = new_model.n_states
                self.eps = new_model.eps
                self.dtype = new_model.dtype
                self.start.data = new_model.start.data.clone().to(self.start.device, dtype=self.dtype)
                self.trans.data = new_model.trans.data.clone().to(self.trans.device, dtype=self.dtype)
                self.emission.data = new_model.emission.data.clone().to(self.emission.device, dtype=self.dtype)
            return

        # else only update provided tensors
        def _to_tensor(obj, shape_expected=None):
            if obj is None:
                return None
            arr = np.asarray(obj, dtype=float)
            if shape_expected is not None:
                try:
                    arr = arr.reshape(shape_expected)
                except Exception:
                    # try to free-form slice/reshape (keep best-effort)
                    arr = np.reshape(arr, shape_expected) if arr.size >= np.prod(shape_expected) else arr
            return torch.tensor(arr, dtype=self.dtype, device=self.start.device)

        with torch.no_grad():
            if init_start is not None:
                t = _to_tensor(init_start, (self.n_states,))
                if t.numel() == self.n_states:
                    self.start.data = t.clone()
            if init_trans is not None:
                t = _to_tensor(init_trans, (self.n_states, self.n_states))
                if t.shape == (self.n_states, self.n_states):
                    self.trans.data = t.clone()
            if init_emission is not None:
                # attempt to extract P(obs==1) if shaped (K,2)
                arr = np.asarray(init_emission, dtype=float)
                if arr.ndim == 2 and arr.shape[1] == 2 and arr.shape[0] >= self.n_states:
                    em = arr[: self.n_states, 1]
                else:
                    em = arr.reshape(-1)[: self.n_states]
                t = torch.tensor(em, dtype=self.dtype, device=self.start.device)
                if t.numel() == self.n_states:
                    self.emission.data = t.clone()

        # finally normalize
        self._normalize_params()

    def _ensure_device_dtype(self, device: Optional[torch.device]):
        if device is None:
            device = next(self.parameters()).device
        self.start.data = self.start.data.to(device=device, dtype=self.dtype)
        self.trans.data = self.trans.data.to(device=device, dtype=self.dtype)
        self.emission.data = self.emission.data.to(device=device, dtype=self.dtype)
        return device

    @staticmethod
    def _pad_and_mask(
        data: List[List],
        device: torch.device,
        dtype: torch.dtype,
        impute_strategy: str = "ignore",
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Pads sequences to shape (B, L). Returns (obs, mask, lengths)
        - Accepts: list-of-seqs, or 2D ndarray (B, L).
        - If a sequence element is itself an array (per-timestep feature vector),
          collapse the last axis by mean (warns once).
        """
        import warnings

        # If somebody passed a 2-D ndarray directly, convert to list-of-rows
        if isinstance(data, np.ndarray) and data.ndim == 2:
            # convert rows -> python lists (scalars per timestep)
            data = data.tolist()

        B = len(data)
        lengths = torch.tensor([len(s) for s in data], dtype=torch.long, device=device)
        L = int(lengths.max().item()) if B > 0 else 0
        obs = torch.zeros((B, L), dtype=dtype, device=device)
        mask = torch.zeros((B, L), dtype=torch.bool, device=device)

        warned_collapse = False

        for i, seq in enumerate(data):
            # seq may be list/ndarray of scalars OR list/ndarray of per-timestep arrays
            arr = np.asarray(seq, dtype=float)

            # If arr is shape (L,1,1,...) squeeze trailing singletons
            while arr.ndim > 1 and arr.shape[-1] == 1:
                arr = np.squeeze(arr, axis=-1)

            # If arr is still >1D (e.g., (L, F)), collapse the last axis by mean
            if arr.ndim > 1:
                if not warned_collapse:
                    warnings.warn(
                        "HMM._pad_and_mask: collapsing per-timestep feature axis by mean "
                        "(arr had shape {}). If you prefer a different reduction, "
                        "preprocess your data.".format(arr.shape),
                        stacklevel=2,
                    )
                    warned_collapse = True
                # collapse features -> scalar per timestep
                arr = np.asarray(arr, dtype=float).mean(axis=-1)

            # now arr should be 1D (T,)
            if arr.ndim == 0:
                # single scalar: treat as length-1 sequence
                arr = np.atleast_1d(arr)

            nan_mask = np.isnan(arr)
            if impute_strategy == "random" and nan_mask.any():
                arr[nan_mask] = np.random.choice([0, 1], size=nan_mask.sum())
                local_mask = np.ones_like(arr, dtype=bool)
            else:
                local_mask = ~nan_mask
                arr = np.where(local_mask, arr, 0.0)

            L_i = arr.shape[0]
            obs[i, :L_i] = torch.tensor(arr, dtype=dtype, device=device)
            mask[i, :L_i] = torch.tensor(local_mask, dtype=torch.bool, device=device)

        return obs, mask, lengths
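    # Illustrative sketch (not part of the package source): what _pad_and_mask
    # produces for ragged input with missing values under the default
    # impute_strategy="ignore".
    #
    #   obs, mask, lengths = HMM._pad_and_mask(
    #       [[1, 0, float("nan")], [1]], device=torch.device("cpu"), dtype=torch.float64)
    #   # obs     -> [[1., 0., 0.], [1., 0., 0.]]   (NaN and padding zero-filled)
    #   # mask    -> [[True, True, False], [True, False, False]]
    #   # lengths -> [3, 1]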
    def _log_emission(self, obs: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """
        obs: (B, L)
        mask: (B, L) bool
        returns logB (B, L, K)
        """
        B, L = obs.shape
        p = self.emission  # (K,)
        logp = torch.log(p + self.eps)
        log1mp = torch.log1p(-p + self.eps)
        obs_expand = obs.unsqueeze(-1)  # (B, L, 1)
        logB = obs_expand * logp.unsqueeze(0).unsqueeze(0) + (1.0 - obs_expand) * log1mp.unsqueeze(0).unsqueeze(0)
        logB = torch.where(mask.unsqueeze(-1), logB, torch.zeros_like(logB))
        return logB

    def fit(
        self,
        data: List[List],
        max_iter: int = 100,
        tol: float = 1e-4,
        impute_strategy: str = "ignore",
        verbose: bool = True,
        return_history: bool = False,
        device: Optional[Union[torch.device, str]] = None,
    ):
        """
        Vectorized Baum-Welch EM across a batch of sequences (padded).
        """
        if device is None:
            device = next(self.parameters()).device
        elif isinstance(device, str):
            device = torch.device(device)
        device = self._ensure_device_dtype(device)

        if isinstance(data, np.ndarray):
            if data.ndim == 2:
                # rows are sequences: convert to list of 1D arrays
                data = data.tolist()
            elif data.ndim == 1:
                # single sequence
                data = [data.tolist()]
            else:
                raise ValueError(f"Expected data to be 1D or 2D ndarray; got array with ndim={data.ndim}")

        obs, mask, lengths = self._pad_and_mask(data, device=device, dtype=self.dtype, impute_strategy=impute_strategy)
        B, L = obs.shape
        K = self.n_states
        eps = float(self.eps)

        if verbose:
            print(f"[HMM.fit] device={device}, batch={B}, max_len={L}, states={K}")

        loglik_history = []

        for it in range(1, max_iter + 1):
            if verbose:
                print(f"[HMM.fit] EM iter {it}")

            # compute batched emission logs
            logB = self._log_emission(obs, mask)  # (B, L, K)

            # logs for start and transition
            logA = torch.log(self.trans + eps)  # (K, K)
            logstart = torch.log(self.start + eps)  # (K,)

            # Forward (batched)
            alpha = torch.empty((B, L, K), dtype=self.dtype, device=device)
            alpha[:, 0, :] = logstart.unsqueeze(0) + logB[:, 0, :]  # (B,K)
            for t in range(1, L):
                # prev: (B, i, 1) + (1, i, j) broadcast => (B, i, j)
                prev = alpha[:, t - 1, :].unsqueeze(2) + logA.unsqueeze(0)  # (B, K, K)
                alpha[:, t, :] = _logsumexp(prev, dim=1) + logB[:, t, :]

            # Backward (batched)
            beta = torch.empty((B, L, K), dtype=self.dtype, device=device)
            beta[:, L - 1, :] = torch.zeros((K,), dtype=self.dtype, device=device).unsqueeze(0).expand(B, K)
            for t in range(L - 2, -1, -1):
                # temp (B, i, j) = logA[i,j] + logB[:,t+1,j] + beta[:,t+1,j]
                temp = logA.unsqueeze(0) + (logB[:, t + 1, :].unsqueeze(1) + beta[:, t + 1, :].unsqueeze(1))
                beta[:, t, :] = _logsumexp(temp, dim=2)

            # sequence log-likelihoods (use last real index)
            last_idx = (lengths - 1).clamp(min=0)
            idx_range = torch.arange(B, device=device)
            final_alpha = alpha[idx_range, last_idx, :]  # (B, K)
            seq_loglikes = _logsumexp(final_alpha, dim=1)  # (B,)
            total_loglike = float(seq_loglikes.sum().item())

            # posterior gamma (B, L, K)
            log_gamma = alpha + beta  # (B, L, K)
            logZ_time = _logsumexp(log_gamma, dim=2, keepdim=True)  # (B, L, 1)
            gamma = (log_gamma - logZ_time).exp()  # (B, L, K)

            # accumulators: starts, transitions, emissions
            gamma_start_accum = gamma[:, 0, :].sum(dim=0)  # (K,)

            # emission accumulators: sum over observed positions only
            mask_f = mask.unsqueeze(-1)  # (B, L, 1)
            emit_num = (gamma * obs.unsqueeze(-1) * mask_f).sum(dim=(0, 1))  # (K,)
            emit_den = (gamma * mask_f).sum(dim=(0, 1))  # (K,)

            # transitions: accumulate xi across t for valid positions
            trans_accum = torch.zeros((K, K), dtype=self.dtype, device=device)
            if L >= 2:
                time_idx = torch.arange(L - 1, device=device).unsqueeze(0).expand(B, L - 1)  # (B, L-1)
                valid = time_idx < (lengths.unsqueeze(1) - 1)  # (B, L-1) bool
                for t in range(L - 1):
                    a_t = alpha[:, t, :].unsqueeze(2)  # (B, i, 1)
                    b_next = (logB[:, t + 1, :] + beta[:, t + 1, :]).unsqueeze(1)  # (B, 1, j)
                    log_xi_unnorm = a_t + logA.unsqueeze(0) + b_next  # (B, i, j)
                    log_xi_flat = log_xi_unnorm.view(B, -1)  # (B, i*j)
                    log_norm = _logsumexp(log_xi_flat, dim=1).unsqueeze(1).unsqueeze(2)  # (B,1,1)
                    xi = (log_xi_unnorm - log_norm).exp()  # (B,i,j)
                    valid_t = valid[:, t].float().unsqueeze(1).unsqueeze(2)  # (B,1,1)
                    xi_masked = xi * valid_t
                    trans_accum += xi_masked.sum(dim=0)  # (i,j)

            # M-step: update parameters with smoothing
            with torch.no_grad():
                new_start = gamma_start_accum + eps
                new_start = new_start / new_start.sum()

                new_trans = trans_accum + eps
                row_sums = new_trans.sum(dim=1, keepdim=True)
                row_sums[row_sums == 0.0] = 1.0
                new_trans = new_trans / row_sums

                new_emission = (emit_num + eps) / (emit_den + 2.0 * eps)
                new_emission = new_emission.clamp(min=eps, max=1.0 - eps)

                self.start.data = new_start
                self.trans.data = new_trans
                self.emission.data = new_emission

            loglik_history.append(total_loglike)
            if verbose:
                print(f"    total loglik = {total_loglike:.6f}")

            if len(loglik_history) > 1 and abs(loglik_history[-1] - loglik_history[-2]) < tol:
                if verbose:
                    print(f"[HMM.fit] converged (Δll < {tol}) at iter {it}")
                break

        return loglik_history if return_history else None
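    # Illustrative sketch (not part of the package source): fitting on a toy
    # batch. NaNs are treated as missing under the default impute_strategy="ignore".
    #
    #   model = HMM(n_states=2)
    #   history = model.fit(
    #       [[0, 0, 1, 1, 1, float("nan"), 0],
    #        [1, 1, 1, 0, 0]],
    #       max_iter=50, tol=1e-4, verbose=False, return_history=True)
    #   # history holds the per-iteration total log-likelihood
    #   # (non-decreasing, up to the eps smoothing in the M-step)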
    def get_params(self) -> dict:
        """
        Return model parameters as numpy arrays on CPU.
        """
        with torch.no_grad():
            return {
                "n_states": int(self.n_states),
                "start": self.start.detach().cpu().numpy().astype(float).reshape(-1),
                "trans": self.trans.detach().cpu().numpy().astype(float),
                "emission": self.emission.detach().cpu().numpy().astype(float).reshape(-1),
            }

    def print_params(self, decimals: int = 4):
        """
        Nicely print start, transition, and emission probabilities.
        """
        params = self.get_params()
        K = params["n_states"]
        fmt = f"{{:.{decimals}f}}"
        print(f"HMM params (K={K} states):")
        print("  start probs:")
        print("    [" + ", ".join(fmt.format(v) for v in params["start"]) + "]")
        print("  transition matrix (rows = from-state, cols = to-state):")
        for i, row in enumerate(params["trans"]):
            print("    s{:d}: [".format(i) + ", ".join(fmt.format(v) for v in row) + "]")
        print("  emission P(obs==1 | state):")
        for i, v in enumerate(params["emission"]):
            print(f"    s{i}: {fmt.format(v)}")

    def to_dataframes(self) -> dict:
        """
        Return pandas DataFrames for start (Series), trans (DataFrame), emission (Series).
        """
        p = self.get_params()
        K = p["n_states"]
        state_names = [f"state_{i}" for i in range(K)]
        start_s = pd.Series(p["start"], index=state_names, name="start_prob")
        trans_df = pd.DataFrame(p["trans"], index=state_names, columns=state_names)
        emission_s = pd.Series(p["emission"], index=state_names, name="p_obs1")
        return {"start": start_s, "trans": trans_df, "emission": emission_s}

    def predict(self, data: List[List], impute_strategy: str = "ignore", device: Optional[Union[torch.device, str]] = None) -> List[np.ndarray]:
        """
        Return posterior marginals gamma_t(k) for each sequence as list of (L, K) numpy arrays.
        """
        if device is None:
            device = next(self.parameters()).device
        elif isinstance(device, str):
            device = torch.device(device)
        device = self._ensure_device_dtype(device)

        obs, mask, lengths = self._pad_and_mask(data, device=device, dtype=self.dtype, impute_strategy=impute_strategy)
        B, L = obs.shape
        K = self.n_states
        eps = float(self.eps)

        logB = self._log_emission(obs, mask)  # (B, L, K)
        logA = torch.log(self.trans + eps)
        logstart = torch.log(self.start + eps)

        # Forward
        alpha = torch.empty((B, L, K), dtype=self.dtype, device=device)
        alpha[:, 0, :] = logstart.unsqueeze(0) + logB[:, 0, :]
        for t in range(1, L):
            prev = alpha[:, t - 1, :].unsqueeze(2) + logA.unsqueeze(0)
            alpha[:, t, :] = _logsumexp(prev, dim=1) + logB[:, t, :]

        # Backward
        beta = torch.empty((B, L, K), dtype=self.dtype, device=device)
        beta[:, L - 1, :] = torch.zeros((K,), dtype=self.dtype, device=device).unsqueeze(0).expand(B, K)
        for t in range(L - 2, -1, -1):
            temp = logA.unsqueeze(0) + (logB[:, t + 1, :].unsqueeze(1) + beta[:, t + 1, :].unsqueeze(1))
            beta[:, t, :] = _logsumexp(temp, dim=2)

        # gamma
        log_gamma = alpha + beta
        logZ_time = _logsumexp(log_gamma, dim=2, keepdim=True)
        gamma = (log_gamma - logZ_time).exp()  # (B, L, K)

        results = []
        for i in range(B):
            L_i = int(lengths[i].item())
            results.append(gamma[i, :L_i, :].detach().cpu().numpy())
        return results

    def score(self, seq_or_list: Union[List[float], List[List[float]]], impute_strategy: str = "ignore", device: Optional[Union[torch.device, str]] = None) -> Union[float, List[float]]:
        """
        Compute log-likelihood of a single sequence or list of sequences under current params.
        Returns float (single) or list of floats (batch).
        """
        single = False
        if not isinstance(seq_or_list[0], (list, tuple, np.ndarray)):
            seqs = [seq_or_list]
            single = True
        else:
            seqs = seq_or_list

        if device is None:
            device = next(self.parameters()).device
        elif isinstance(device, str):
            device = torch.device(device)
        device = self._ensure_device_dtype(device)

        obs, mask, lengths = self._pad_and_mask(seqs, device=device, dtype=self.dtype, impute_strategy=impute_strategy)
        B, L = obs.shape
        K = self.n_states
        eps = float(self.eps)

        logB = self._log_emission(obs, mask)
        logA = torch.log(self.trans + eps)
        logstart = torch.log(self.start + eps)

        alpha = torch.empty((B, L, K), dtype=self.dtype, device=device)
        alpha[:, 0, :] = logstart.unsqueeze(0) + logB[:, 0, :]
        for t in range(1, L):
            prev = alpha[:, t - 1, :].unsqueeze(2) + logA.unsqueeze(0)
            alpha[:, t, :] = _logsumexp(prev, dim=1) + logB[:, t, :]

        last_idx = (lengths - 1).clamp(min=0)
        idx_range = torch.arange(B, device=device)
        final_alpha = alpha[idx_range, last_idx, :]  # (B, K)
        seq_loglikes = _logsumexp(final_alpha, dim=1)  # (B,)
        seq_loglikes = seq_loglikes.detach().cpu().numpy().tolist()
        return seq_loglikes[0] if single else seq_loglikes
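    # Illustrative sketch (not part of the package source): posterior marginals
    # and per-sequence log-likelihoods after fitting.
    #
    #   gammas = model.predict([[0, 1, 1, 0]])   # one (L, K) array; each row sums to ~1
    #   ll = model.score([0, 1, 1, 0])           # float; a list of sequences returns a list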
    def viterbi(self, seq: List[float], impute_strategy: str = "ignore", device: Optional[Union[torch.device, str]] = None) -> Tuple[List[int], float]:
        """
        Viterbi decode a single sequence. Returns (state_path, log_probability_of_path).
        """
        paths, scores = self.batch_viterbi([seq], impute_strategy=impute_strategy, device=device)
        return paths[0], scores[0]

    def batch_viterbi(self, data: List[List[float]], impute_strategy: str = "ignore", device: Optional[Union[torch.device, str]] = None) -> Tuple[List[List[int]], List[float]]:
        """
        Batched Viterbi decoding on padded sequences. Returns (list_of_paths, list_of_scores).
        Each path is the length of the original sequence.
        """
        if device is None:
            device = next(self.parameters()).device
        elif isinstance(device, str):
            device = torch.device(device)
        device = self._ensure_device_dtype(device)

        obs, mask, lengths = self._pad_and_mask(data, device=device, dtype=self.dtype, impute_strategy=impute_strategy)
        B, L = obs.shape
        K = self.n_states
        eps = float(self.eps)

        p = self.emission
        logp = torch.log(p + eps)
        log1mp = torch.log1p(-p + eps)
        logB = obs.unsqueeze(-1) * logp.unsqueeze(0).unsqueeze(0) + (1.0 - obs.unsqueeze(-1)) * log1mp.unsqueeze(0).unsqueeze(0)
        logB = torch.where(mask.unsqueeze(-1), logB, torch.zeros_like(logB))

        logstart = torch.log(self.start + eps)
        logA = torch.log(self.trans + eps)

        # delta (score) and psi (argmax pointers)
        delta = torch.empty((B, L, K), dtype=self.dtype, device=device)
        psi = torch.zeros((B, L, K), dtype=torch.long, device=device)

        delta[:, 0, :] = logstart.unsqueeze(0) + logB[:, 0, :]
        psi[:, 0, :] = -1  # sentinel

        for t in range(1, L):
            # cand shape (B, i, j)
            cand = delta[:, t - 1, :].unsqueeze(2) + logA.unsqueeze(0)  # (B, K, K)
            best_val, best_idx = cand.max(dim=1)  # best over previous i: results (B, j)
            delta[:, t, :] = best_val + logB[:, t, :]
            psi[:, t, :] = best_idx  # best previous state index for each (B, j)

        # backtrack
        last_idx = (lengths - 1).clamp(min=0)
        idx_range = torch.arange(B, device=device)
        final_delta = delta[idx_range, last_idx, :]  # (B, K)
        best_last_val, best_last_state = final_delta.max(dim=1)  # (B,), (B,)
        paths = []
        scores = []
        for b in range(B):
            Lb = int(lengths[b].item())
            if Lb == 0:
                paths.append([])
                scores.append(float("-inf"))
                continue
            s = int(best_last_state[b].item())
            path = [s]
            for t in range(Lb - 1, 0, -1):
                s = int(psi[b, t, s].item())
                path.append(s)
            path.reverse()
            paths.append(path)
            scores.append(float(best_last_val[b].item()))
        return paths, scores

    def save(self, path: str) -> None:
        """
        Save HMM to `path` using torch.save. Stores:
        - n_states, eps, dtype (string)
        - start, trans, emission (CPU tensors)
        """
        payload = {
            "n_states": int(self.n_states),
            "eps": float(self.eps),
            # store dtype as a string like "torch.float64" (portable)
            "dtype": str(self.dtype),
            "start": self.start.detach().cpu(),
            "trans": self.trans.detach().cpu(),
            "emission": self.emission.detach().cpu(),
        }
        torch.save(payload, path)
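    # Illustrative sketch (not part of the package source): hard state paths
    # via Viterbi.
    #
    #   path, logp = model.viterbi([0, 1, 1, float("nan"), 0])
    #   # path is a list of state indices the length of the input; masked
    #   # positions contribute no emission term but still pass through the
    #   # transition model.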
    @classmethod
    def load(cls, path: str, device: Optional[Union[torch.device, str]] = None) -> "HMM":
        """
        Load model from `path`. If `device` is provided (str or torch.device),
        parameters will be moved to that device; otherwise they remain on CPU.
        Example: model = HMM.load('hmm.pt', device='cuda')
        """
        payload = torch.load(path, map_location="cpu")

        n_states = int(payload.get("n_states"))
        eps = float(payload.get("eps", 1e-8))
        dtype_entry = payload.get("dtype", "torch.float64")

        # Resolve dtype string robustly:
        # Accept "torch.float64" or "float64" or actual torch.dtype (older payloads)
        if isinstance(dtype_entry, torch.dtype):
            torch_dtype = dtype_entry
        else:
            # dtype_entry expected to be a string
            dtype_str = str(dtype_entry)
            # take last part after dot if present: "torch.float64" -> "float64"
            name = dtype_str.split(".")[-1]
            # map to torch dtype if available, else fallback mapping
            if hasattr(torch, name):
                torch_dtype = getattr(torch, name)
            else:
                fallback = {"float64": torch.float64, "float32": torch.float32, "float16": torch.float16}
                torch_dtype = fallback.get(name, torch.float64)

        # Build instance (use resolved dtype)
        model = cls(n_states=n_states, dtype=torch_dtype, eps=eps)

        # Determine target device
        if device is None:
            device = torch.device("cpu")
        elif isinstance(device, str):
            device = torch.device(device)

        # Load params (they were saved on CPU) and cast to model dtype/device
        with torch.no_grad():
            model.start.data = payload["start"].to(device=device, dtype=model.dtype)
            model.trans.data = payload["trans"].to(device=device, dtype=model.dtype)
            model.emission.data = payload["emission"].to(device=device, dtype=model.dtype)

        # Normalize / coerce shapes just in case
        model._normalize_params()
        return model
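    # Illustrative sketch (not part of the package source): save/load round trip.
    #
    #   model.save("hmm.pt")
    #   restored = HMM.load("hmm.pt")   # stays on CPU unless device= is given
    #   # restored.get_params() matches model.get_params() up to the eps
    #   # smoothing applied by _normalize_params on load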
    def annotate_adata(
        self,
        adata,
        obs_column: str,
        layer: Optional[str] = None,
        footprints: Optional[bool] = None,
        accessible_patches: Optional[bool] = None,
        cpg: Optional[bool] = None,
        methbases: Optional[List[str]] = None,
        threshold: Optional[float] = None,
        device: Optional[Union[str, torch.device]] = None,
        batch_size: Optional[int] = None,
        use_viterbi: Optional[bool] = None,
        in_place: bool = True,
        verbose: bool = True,
        uns_key: str = "hmm_appended_layers",
        config: Optional[Union[dict, "ExperimentConfig"]] = None,  # NEW: config/dict accepted
    ):
        """
        Annotate an AnnData with HMM-derived features (in adata.obs and adata.layers).

        Parameters
        ----------
        config : optional ExperimentConfig instance or plain dict
            When provided, the following keys (if present) are used to override defaults:
            - hmm_feature_sets : dict (canonical feature set structure) OR a JSON/string repr
            - hmm_annotation_threshold : float
            - hmm_batch_size : int
            - hmm_use_viterbi : bool
            - hmm_methbases : list
            - footprints / accessible_patches / cpg (booleans)
            Other keyword args override config values if explicitly provided.
        """
        import json, ast, warnings
        import numpy as _np
        import torch as _torch
        from tqdm import trange, tqdm as _tqdm

        # small helpers
        def _try_json_or_literal(s):
            if s is None:
                return None
            if not isinstance(s, str):
                return s
            s0 = s.strip()
            if s0 == "":
                return None
            try:
                return json.loads(s0)
            except Exception:
                pass
            try:
                return ast.literal_eval(s0)
            except Exception:
                return s

        def _coerce_bool(x):
            if x is None:
                return False
            if isinstance(x, bool):
                return x
            if isinstance(x, (int, float)):
                return bool(x)
            s = str(x).strip().lower()
            return s in ("1", "true", "t", "yes", "y", "on")

        def normalize_hmm_feature_sets(raw):
            if raw is None:
                return {}
            parsed = raw
            if isinstance(raw, str):
                parsed = _try_json_or_literal(raw)
            if not isinstance(parsed, dict):
                return {}

            def _coerce_bound(x):
                if x is None:
                    return None
                if isinstance(x, (int, float)):
                    return float(x)
                s = str(x).strip().lower()
                if s in ("inf", "infty", "infinite", "np.inf"):
                    return _np.inf
                if s in ("none", ""):
                    return None
                try:
                    return float(x)
                except Exception:
                    return None

            def _coerce_feature_map(feats):
                out = {}
                if not isinstance(feats, dict):
                    return out
                for fname, rng in feats.items():
                    if rng is None:
                        out[fname] = (0.0, _np.inf)
                        continue
                    if isinstance(rng, (list, tuple)) and len(rng) >= 2:
                        lo = _coerce_bound(rng[0]) or 0.0
                        hi = _coerce_bound(rng[1])
                        if hi is None:
                            hi = _np.inf
                        out[fname] = (float(lo), float(hi) if not _np.isinf(hi) else _np.inf)
                    else:
                        val = _coerce_bound(rng)
                        out[fname] = (0.0, float(val) if val is not None else _np.inf)
                return out

            canonical = {}
            for grp, info in parsed.items():
                if not isinstance(info, dict):
                    feats = _coerce_feature_map(info)
                    canonical[grp] = {"features": feats, "state": "Modified"}
                    continue
                feats = _coerce_feature_map(info.get("features", info.get("ranges", {})))
                state = info.get("state", info.get("label", "Modified"))
                canonical[grp] = {"features": feats, "state": state}
            return canonical
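        # Illustrative sketch (assumed values, not from the package source):
        # the shape normalize_hmm_feature_sets accepts, given either a dict or
        # its JSON/string repr.
        #
        #   raw = {"footprints": {"features": {"nucleosome": [80, "inf"],
        #                                      "small": [0, 80]},
        #                         "state": "Non-Modified"}}
        #   normalize_hmm_feature_sets(raw)
        #   # -> {"footprints": {"features": {"nucleosome": (80.0, inf),
        #   #                                 "small": (0.0, 80.0)},
        #   #                    "state": "Non-Modified"}}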
        # ---------- resolve config dict ----------
        merged_cfg = {}
        if config is not None:
            if hasattr(config, "to_dict") and callable(getattr(config, "to_dict")):
                merged_cfg = dict(config.to_dict())
            elif isinstance(config, dict):
                merged_cfg = dict(config)
            else:
                try:
                    merged_cfg = {k: getattr(config, k) for k in dir(config) if k.startswith("hmm_")}
                except Exception:
                    merged_cfg = {}

        def _pick(key, local_val, fallback=None):
            if local_val is not None:
                return local_val
            if key in merged_cfg and merged_cfg[key] is not None:
                return merged_cfg[key]
            alt = f"hmm_{key}"
            if alt in merged_cfg and merged_cfg[alt] is not None:
                return merged_cfg[alt]
            return fallback

        # coerce booleans robustly
        footprints = _coerce_bool(_pick("footprints", footprints, merged_cfg.get("footprints", False)))
        accessible_patches = _coerce_bool(_pick("accessible_patches", accessible_patches, merged_cfg.get("accessible_patches", False)))
        cpg = _coerce_bool(_pick("cpg", cpg, merged_cfg.get("cpg", False)))

        threshold = float(_pick("threshold", threshold, merged_cfg.get("hmm_annotation_threshold", 0.5)))
        batch_size = int(_pick("batch_size", batch_size, merged_cfg.get("hmm_batch_size", 1024)))
        use_viterbi = _coerce_bool(_pick("use_viterbi", use_viterbi, merged_cfg.get("hmm_use_viterbi", False)))
        # resolve methbases from the explicit argument or config; guard against
        # None so the normalization below cannot fail when neither provides a list
        methbases = _pick("methbases", methbases, merged_cfg.get("hmm_methbases", None)) or []

        # normalize whitespace/case for human-friendly inputs (but keep original tokens as given)
        methbases = [str(m).strip() for m in methbases if m is not None]
|
|
914
|
+
if verbose:
|
|
915
|
+
print("DEBUG: final methbases list =", methbases)
|
|
916
|
+
|
|
917
|
+
# resolve feature sets: prefer canonical if it yields non-empty mapping, otherwise fall back to boolean defaults
|
|
918
|
+
feature_sets = {}
|
|
919
|
+
if "hmm_feature_sets" in merged_cfg and merged_cfg.get("hmm_feature_sets") is not None:
|
|
920
|
+
cand = normalize_hmm_feature_sets(merged_cfg.get("hmm_feature_sets"))
|
|
921
|
+
if isinstance(cand, dict) and len(cand) > 0:
|
|
922
|
+
feature_sets = cand
|
|
923
|
+
|
|
924
|
+
if not feature_sets:
|
|
925
|
+
if verbose:
|
|
926
|
+
print("[HMM.annotate_adata] no feature sets configured; nothing to append.")
|
|
927
|
+
return None if in_place else adata
|
|
928
|
+
|
|
929
|
+
if verbose:
|
|
930
|
+
print("[HMM.annotate_adata] resolved feature sets:", list(feature_sets.keys()))
|
|
931
|
+
|
|
932
|
+
# copy vs in-place
|
|
933
|
+
if not in_place:
|
|
934
|
+
adata = adata.copy()
|
|
935
|
+
|
|
936
|
+
# prepare column names
|
|
937
|
+
all_features = []
|
|
938
|
+
combined_prefix = "Combined"
|
|
939
|
+
for key, fs in feature_sets.items():
|
|
940
|
+
feats = fs.get("features", {})
|
|
941
|
+
if key == "cpg":
|
|
942
|
+
all_features += [f"CpG_{f}" for f in feats]
|
|
943
|
+
all_features.append(f"CpG_all_{key}_features")
|
|
944
|
+
else:
|
|
945
|
+
for methbase in methbases:
|
|
946
|
+
all_features += [f"{methbase}_{f}" for f in feats]
|
|
947
|
+
all_features.append(f"{methbase}_all_{key}_features")
|
|
948
|
+
if len(methbases) > 1:
|
|
949
|
+
all_features += [f"{combined_prefix}_{f}" for f in feats]
|
|
950
|
+
all_features.append(f"{combined_prefix}_all_{key}_features")
|
|
951
|
+
|
|
952
|
+
# initialize obs columns (unique lists per row)
|
|
953
|
+
n_rows = adata.shape[0]
|
|
954
|
+
for feature in all_features:
|
|
955
|
+
if feature not in adata.obs.columns:
|
|
956
|
+
adata.obs[feature] = [[] for _ in range(n_rows)]
|
|
957
|
+
if f"{feature}_distances" not in adata.obs.columns:
|
|
958
|
+
adata.obs[f"{feature}_distances"] = [None] * n_rows
|
|
959
|
+
if f"n_{feature}" not in adata.obs.columns:
|
|
960
|
+
adata.obs[f"n_{feature}"] = -1
|
|
961
|
+
|
|
962
|
+
appended_layers: List[str] = []
|
|
963
|
+
|
|
964
|
+
# device management
|
|
965
|
+
if device is None:
|
|
966
|
+
device = next(self.parameters()).device
|
|
967
|
+
elif isinstance(device, str):
|
|
968
|
+
device = _torch.device(device)
|
|
969
|
+
self.to(device)
|
|
970
|
+
|
|
971
|
+
# helpers ---------------------------------------------------------------
|
|
972
|
+
def _ensure_2d_array_like(matrix):
|
|
973
|
+
arr = _np.asarray(matrix)
|
|
974
|
+
if arr.ndim == 1:
|
|
975
|
+
arr = arr[_np.newaxis, :]
|
|
976
|
+
elif arr.ndim > 2:
|
|
977
|
+
# squeeze trailing singletons
|
|
978
|
+
while arr.ndim > 2 and arr.shape[-1] == 1:
|
|
979
|
+
arr = _np.squeeze(arr, axis=-1)
|
|
980
|
+
if arr.ndim != 2:
|
|
981
|
+
raise ValueError(f"Expected 2D sequence matrix; got array with shape {arr.shape}")
|
|
982
|
+
return arr
|
|
983
|
+
|
|
984
|
+
def calculate_batch_distances(intervals_list, threshold_local=0.9):
|
|
985
|
+
results_local = []
|
|
986
|
+
for intervals in intervals_list:
|
|
987
|
+
if not isinstance(intervals, list) or len(intervals) == 0:
|
|
988
|
+
results_local.append([])
|
|
989
|
+
continue
|
|
990
|
+
valid = [iv for iv in intervals if iv[2] > threshold_local]
|
|
991
|
+
if len(valid) <= 1:
|
|
992
|
+
results_local.append([])
|
|
993
|
+
continue
|
|
994
|
+
valid = sorted(valid, key=lambda x: x[0])
|
|
995
|
+
dists = [(valid[i + 1][0] - (valid[i][0] + valid[i][1])) for i in range(len(valid) - 1)]
|
|
996
|
+
results_local.append(dists)
|
|
997
|
+
return results_local
|
|
998
|
+
+        def classify_batch_local(predicted_states_batch, probabilities_batch, coordinates, classification_mapping, target_state="Modified"):
+            # Accept numpy arrays or torch tensors
+            if isinstance(predicted_states_batch, _torch.Tensor):
+                pred_np = predicted_states_batch.detach().cpu().numpy()
+            else:
+                pred_np = _np.asarray(predicted_states_batch)
+            if isinstance(probabilities_batch, _torch.Tensor):
+                probs_np = probabilities_batch.detach().cpu().numpy()
+            else:
+                probs_np = _np.asarray(probabilities_batch)
+
+            batch_size, L = pred_np.shape
+            all_classifications_local = []
+            # default two-state label convention; fall back to index 1 if target_state is unrecognized
+            state_labels = ["Non-Modified", "Modified"]
+            try:
+                target_idx = state_labels.index(target_state)
+            except ValueError:
+                target_idx = 1  # fallback
+
+            for b in range(batch_size):
+                predicted_states = pred_np[b]
+                probabilities = probs_np[b]
+                regions = []
+                current_start, current_length, current_probs = None, 0, []
+                for i, state_index in enumerate(predicted_states):
+                    state_prob = float(probabilities[i][state_index])
+                    if state_index == target_idx:
+                        if current_start is None:
+                            current_start = i
+                        current_length += 1
+                        current_probs.append(state_prob)
+                    elif current_start is not None:
+                        regions.append((current_start, current_length, float(_np.mean(current_probs))))
+                        current_start, current_length, current_probs = None, 0, []
+                if current_start is not None:
+                    regions.append((current_start, current_length, float(_np.mean(current_probs))))
+
+                final = []
+                for start, length, prob in regions:
+                    # compute genomic length, falling back to the positional length
+                    try:
+                        feature_length = int(coordinates[start + length - 1]) - int(coordinates[start]) + 1
+                    except Exception:
+                        feature_length = int(length)
+
+                    # classification_mapping values are (lo, hi) tuples or lists
+                    label = None
+                    for ftype, rng in classification_mapping.items():
+                        lo, hi = rng[0], rng[1]
+                        try:
+                            if lo <= feature_length < hi:
+                                label = ftype
+                                break
+                        except Exception:
+                            continue
+                    if label is None:
+                        # fallback to the first mapping key, or "feature" if the mapping is empty
+                        label = next(iter(classification_mapping.keys()), "feature")
+
+                    # Store the reported start coordinate in the same coordinate system as `coordinates`.
+                    try:
+                        genomic_start = int(coordinates[start])
+                    except Exception:
+                        genomic_start = int(start)
+                    final.append((genomic_start, feature_length, label, prob))
+                all_classifications_local.append(final)
+            return all_classifications_local
+
+        # -----------------------------------------------------------------------
+
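The `classification_mapping` argument buckets each run of the target state by its genomic length, using half-open [lo, hi) ranges. A hypothetical two-class mapping (the labels and ranges are illustrative, not package defaults):

    classification_mapping = {
        "small_footprint": (0, 80),        # runs spanning fewer than 80 bp
        "putative_nucleosome": (80, 200),  # runs spanning 80-199 bp
    }
    # A "Modified" run covering coordinates 1000..1120 has
    # feature_length = 1120 - 1000 + 1 = 121 -> "putative_nucleosome"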
+        # Ensure obs_column is categorical-like for iteration
+        sseries = adata.obs[obs_column]
+        if not pd.api.types.is_categorical_dtype(sseries):
+            sseries = sseries.astype("category")
+        references = list(sseries.cat.categories)
+
+        ref_iter = references if not verbose else _tqdm(references, desc="Processing References")
+        for ref in ref_iter:
+            # subset reads with this obs_column value
+            ref_mask = adata.obs[obs_column] == ref
+            ref_subset = adata[ref_mask].copy()
+            combined_mask = None
+
+            # per-methbase processing
+            for methbase in methbases:
+                key_lower = methbase.strip().lower()
+
+                # map several common synonyms -> canonical lookup
+                if key_lower in ("a",):
+                    pos_mask = ref_subset.var.get(f"{ref}_strand_FASTA_base") == "A"
+                elif key_lower in ("c", "any_c", "anyc", "any-c"):
+                    # unify 'C' or 'any_C' names to the any_C var column
+                    pos_mask = ref_subset.var.get(f"{ref}_any_C_site") == True
+                elif key_lower in ("gpc", "gpc_site", "gpc-site"):
+                    pos_mask = ref_subset.var.get(f"{ref}_GpC_site") == True
+                elif key_lower in ("cpg", "cpg_site", "cpg-site"):
+                    pos_mask = ref_subset.var.get(f"{ref}_CpG_site") == True
+                else:
+                    # best effort: if a column named f"{ref}_{methbase}_site" exists, use it
+                    alt_col = f"{ref}_{methbase}_site"
+                    pos_mask = ref_subset.var.get(alt_col, None)
+
+                if pos_mask is None:
+                    continue
+                combined_mask = pos_mask if combined_mask is None else (combined_mask | pos_mask)
+
+                if pos_mask.sum() == 0:
+                    continue
+
+                sub = ref_subset[:, pos_mask]
+                # choose matrix
+                matrix = sub.layers[layer] if (layer and layer in sub.layers) else sub.X
+                matrix = _ensure_2d_array_like(matrix)
+                n_reads = matrix.shape[0]
+
+                # coordinates for this sub (try to convert to ints, else fall back to indices)
+                try:
+                    coords = _np.asarray(sub.var_names, dtype=int)
+                except Exception:
+                    coords = _np.arange(sub.shape[1], dtype=int)
+
+                # chunked processing
+                chunk_iter = range(0, n_reads, batch_size)
+                if verbose:
+                    chunk_iter = _tqdm(list(chunk_iter), desc=f"{ref}:{methbase} chunks")
+                for start_idx in chunk_iter:
+                    stop_idx = min(n_reads, start_idx + batch_size)
+                    chunk = matrix[start_idx:stop_idx]
+                    seqs = chunk.tolist()
+                    # posterior marginals
+                    gammas = self.predict(seqs, impute_strategy="ignore", device=device)
+                    if len(gammas) == 0:
+                        continue
+                    probs_batch = _np.stack(gammas, axis=0)  # (B, L, K)
+                    if use_viterbi:
+                        paths, _scores = self.batch_viterbi(seqs, impute_strategy="ignore", device=device)
+                        pred_states = _np.asarray(paths)
+                    else:
+                        pred_states = _np.argmax(probs_batch, axis=2)
+
+                    # For each feature group, classify separately and write back
+                    for key, fs in feature_sets.items():
+                        if key == "cpg":
+                            continue
+                        state_target = fs.get("state", "Modified")
+                        feature_map = fs.get("features", {})
+                        classifications = classify_batch_local(pred_states, probs_batch, coords, feature_map, target_state=state_target)
+
+                        # write results to adata.obs rows (use original index names)
+                        row_indices = list(sub.obs.index[start_idx:stop_idx])
+                        for i_local, idx in enumerate(row_indices):
+                            for start, length, label, prob in classifications[i_local]:
+                                col_name = f"{methbase}_{label}"
+                                all_col = f"{methbase}_all_{key}_features"
+                                adata.obs.at[idx, col_name].append([start, length, prob])
+                                adata.obs.at[idx, all_col].append([start, length, prob])
+
+            # Combined subset (if multiple methbases)
+            if len(methbases) > 1 and (combined_mask is not None) and (combined_mask.sum() > 0):
+                comb = ref_subset[:, combined_mask]
+                if comb.shape[1] > 0:
+                    matrix = comb.layers[layer] if (layer and layer in comb.layers) else comb.X
+                    matrix = _ensure_2d_array_like(matrix)
+                    n_reads_comb = matrix.shape[0]
+                    try:
+                        coords_comb = _np.asarray(comb.var_names, dtype=int)
+                    except Exception:
+                        coords_comb = _np.arange(comb.shape[1], dtype=int)
+
+                    chunk_iter = range(0, n_reads_comb, batch_size)
+                    if verbose:
+                        chunk_iter = _tqdm(list(chunk_iter), desc=f"{ref}:Combined chunks")
+                    for start_idx in chunk_iter:
+                        stop_idx = min(n_reads_comb, start_idx + batch_size)
+                        chunk = matrix[start_idx:stop_idx]
+                        seqs = chunk.tolist()
+                        gammas = self.predict(seqs, impute_strategy="ignore", device=device)
+                        if len(gammas) == 0:
+                            continue
+                        probs_batch = _np.stack(gammas, axis=0)
+                        if use_viterbi:
+                            paths, _scores = self.batch_viterbi(seqs, impute_strategy="ignore", device=device)
+                            pred_states = _np.asarray(paths)
+                        else:
+                            pred_states = _np.argmax(probs_batch, axis=2)
+
+                        for key, fs in feature_sets.items():
+                            if key == "cpg":
+                                continue
+                            state_target = fs.get("state", "Modified")
+                            feature_map = fs.get("features", {})
+                            classifications = classify_batch_local(pred_states, probs_batch, coords_comb, feature_map, target_state=state_target)
+                            row_indices = list(comb.obs.index[start_idx:stop_idx])
+                            for i_local, idx in enumerate(row_indices):
+                                for start, length, label, prob in classifications[i_local]:
+                                    adata.obs.at[idx, f"{combined_prefix}_{label}"].append([start, length, prob])
+                                    adata.obs.at[idx, f"{combined_prefix}_all_{key}_features"].append([start, length, prob])
+
+        # CpG special handling
+        if "cpg" in feature_sets and feature_sets.get("cpg") is not None:
+            cpg_iter = references if not verbose else _tqdm(references, desc="Processing CpG")
+            for ref in cpg_iter:
+                ref_mask = adata.obs[obs_column] == ref
+                ref_subset = adata[ref_mask].copy()
+                pos_mask = ref_subset.var[f"{ref}_CpG_site"] == True
+                if pos_mask.sum() == 0:
+                    continue
+                cpg_sub = ref_subset[:, pos_mask]
+                matrix = cpg_sub.layers[layer] if (layer and layer in cpg_sub.layers) else cpg_sub.X
+                matrix = _ensure_2d_array_like(matrix)
+                n_reads = matrix.shape[0]
+                try:
+                    coords_cpg = _np.asarray(cpg_sub.var_names, dtype=int)
+                except Exception:
+                    coords_cpg = _np.arange(cpg_sub.shape[1], dtype=int)
+
+                chunk_iter = range(0, n_reads, batch_size)
+                if verbose:
+                    chunk_iter = _tqdm(list(chunk_iter), desc=f"{ref}:CpG chunks")
+                for start_idx in chunk_iter:
+                    stop_idx = min(n_reads, start_idx + batch_size)
+                    chunk = matrix[start_idx:stop_idx]
+                    seqs = chunk.tolist()
+                    gammas = self.predict(seqs, impute_strategy="ignore", device=device)
+                    if len(gammas) == 0:
+                        continue
+                    probs_batch = _np.stack(gammas, axis=0)
+                    if use_viterbi:
+                        paths, _scores = self.batch_viterbi(seqs, impute_strategy="ignore", device=device)
+                        pred_states = _np.asarray(paths)
+                    else:
+                        pred_states = _np.argmax(probs_batch, axis=2)
+
+                    fs = feature_sets["cpg"]
+                    state_target = fs.get("state", "Modified")
+                    feature_map = fs.get("features", {})
+                    classifications = classify_batch_local(pred_states, probs_batch, coords_cpg, feature_map, target_state=state_target)
+                    row_indices = list(cpg_sub.obs.index[start_idx:stop_idx])
+                    for i_local, idx in enumerate(row_indices):
+                        for start, length, label, prob in classifications[i_local]:
+                            adata.obs.at[idx, f"CpG_{label}"].append([start, length, prob])
+                            adata.obs.at[idx, "CpG_all_cpg_features"].append([start, length, prob])
+
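All three passes above read the same shape out of `feature_sets`: each group names a `state` to segment on and a `features` length mapping, and the "cpg" group is routed through its own pass. A hypothetical value consistent with the lookups in the code (the group names, labels, and ranges are illustrative):

    feature_sets = {
        "accessibility": {
            "state": "Modified",
            "features": {"small_footprint": (0, 80), "nucleosome": (80, 200)},
        },
        "cpg": {
            "state": "Modified",
            "features": {"methylated_patch": (0, 500)},
        },
    }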
+        # finalize: convert intervals into binary layers and distances
+        try:
+            coordinates = _np.asarray(adata.var_names, dtype=int)
+            coords_are_ints = True
+        except Exception:
+            coordinates = _np.arange(adata.shape[1], dtype=int)
+            coords_are_ints = False
+
+        features_iter = all_features if not verbose else _tqdm(all_features, desc="Finalizing Layers")
+        for feature in features_iter:
+            bin_matrix = _np.zeros((adata.shape[0], adata.shape[1]), dtype=int)
+            counts = _np.zeros(adata.shape[0], dtype=int)
+
+            # integer-length layer (0 where a position is not inside a feature)
+            len_matrix = _np.zeros((adata.shape[0], adata.shape[1]), dtype=int)
+
+            for row_idx, intervals in enumerate(adata.obs[feature]):
+                if not isinstance(intervals, list):
+                    intervals = []
+                for start, length, prob in intervals:
+                    if prob > threshold:
+                        if coords_are_ints:
+                            # map genomic start/length into index interval [start_idx, end_idx)
+                            start_idx = _np.searchsorted(coordinates, int(start), side="left")
+                            end_idx = _np.searchsorted(coordinates, int(start) + int(length) - 1, side="right")
+                        else:
+                            start_idx = int(start)
+                            end_idx = start_idx + int(length)
+
+                        start_idx = max(0, min(start_idx, adata.shape[1]))
+                        end_idx = max(0, min(end_idx, adata.shape[1]))
+
+                        if start_idx < end_idx:
+                            span = end_idx - start_idx  # number of positions covered
+                            # set binary mask
+                            bin_matrix[row_idx, start_idx:end_idx] = 1
+                            # set length mask: use the maximum in case of overlaps
+                            existing = len_matrix[row_idx, start_idx:end_idx]
+                            len_matrix[row_idx, start_idx:end_idx] = _np.maximum(existing, span)
+                            counts[row_idx] += 1
+
+            # write binary layer and length layer, track appended names
+            adata.layers[feature] = bin_matrix
+            appended_layers.append(feature)
+
+            # the integer-length layer gets a suffixed name
+            length_layer_name = f"{feature}_lengths"
+            adata.layers[length_layer_name] = len_matrix
+            appended_layers.append(length_layer_name)
+
+            adata.obs[f"n_{feature}"] = counts
+            adata.obs[f"{feature}_distances"] = calculate_batch_distances(adata.obs[feature].tolist(), threshold)
+
+        # Merge appended_layers into adata.uns[uns_key] (preserve pre-existing entries, avoid duplicates)
+        existing = list(adata.uns.get(uns_key, [])) if adata.uns.get(uns_key) is not None else []
+        new_list = existing + [l for l in appended_layers if l not in existing]
+        adata.uns[uns_key] = new_list
+
+        return None if in_place else adata
+
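The finalize pass maps each retained interval from genomic coordinates back into column indices with `searchsorted`; a small worked example (the coordinates are illustrative):

    import numpy as np

    coordinates = np.array([100, 105, 110, 120, 130])  # adata.var_names as ints
    start, length = 105, 16                            # interval spans 105..120
    start_idx = np.searchsorted(coordinates, start, side="left")              # -> 1
    end_idx = np.searchsorted(coordinates, start + length - 1, side="right")  # -> 4
    # columns 1..3 (positions 105, 110, 120) are set to 1 in the binary layer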
+    def merge_intervals_in_layer(
+        self,
+        adata,
+        layer: str,
+        distance_threshold: int = 0,
+        merged_suffix: str = "_merged",
+        length_layer_suffix: str = "_lengths",
+        update_obs: bool = True,
+        prob_strategy: str = "mean",  # 'mean' | 'max' | 'orig_first'
+        inplace: bool = True,
+        overwrite: bool = False,
+        verbose: bool = False,
+    ):
+        """
+        Merge intervals in `adata.layers[layer]` that are within `distance_threshold`.
+        Writes a new merged binary layer named f"{layer}{merged_suffix}" and a length layer
+        f"{layer}{merged_suffix}{length_layer_suffix}". Optionally updates adata.obs with the merged intervals.
+
+        Parameters
+        ----------
+        layer : str
+            Name of the original binary layer (0/1 mask).
+        distance_threshold : int
+            Merge intervals whose gap <= this threshold (genomic coords if adata.var_names are ints).
+        merged_suffix : str
+            Suffix appended to the original layer name for the merged binary layer (default "_merged").
+        length_layer_suffix : str
+            Suffix appended after the merged suffix for the lengths layer (default "_lengths").
+        update_obs : bool
+            If True, create/update adata.obs[f"{layer}{merged_suffix}"] with merged intervals.
+        prob_strategy : str
+            How to combine probabilities when merging ('mean', 'max', 'orig_first').
+        inplace : bool
+            If False, return a new AnnData with the changes (original untouched).
+        overwrite : bool
+            If True, overwrite existing merged layers / obs entries; otherwise raise if they exist.
+        """
+        import numpy as _np
+        from scipy.sparse import issparse
+
+        if not inplace:
+            adata = adata.copy()
+
+        merged_bin_name = f"{layer}{merged_suffix}"
+        merged_len_name = f"{layer}{merged_suffix}{length_layer_suffix}"
+
+        if (merged_bin_name in adata.layers or merged_len_name in adata.layers or
+                (update_obs and merged_bin_name in adata.obs.columns)) and not overwrite:
+            raise KeyError(f"Merged outputs exist (use overwrite=True to replace): {merged_bin_name} / {merged_len_name}")
+
+        if layer not in adata.layers:
+            raise KeyError(f"Layer '{layer}' not found in adata.layers")
+
+        bin_layer = adata.layers[layer]
+        if issparse(bin_layer):
+            bin_arr = bin_layer.toarray().astype(int)
+        else:
+            bin_arr = _np.asarray(bin_layer, dtype=int)
+
+        n_rows, n_cols = bin_arr.shape
+
+        # coordinates in genomic units if possible
+        try:
+            coords = _np.asarray(adata.var_names, dtype=int)
+            coords_are_ints = True
+        except Exception:
+            coords = _np.arange(n_cols, dtype=int)
+            coords_are_ints = False
+
+        # helper: contiguous runs of 1s -> list of (start_idx, end_idx) (end exclusive)
+        def _runs_from_mask(mask_1d):
+            idx = _np.nonzero(mask_1d)[0]
+            if idx.size == 0:
+                return []
+            runs = []
+            start = idx[0]
+            prev = idx[0]
+            for i in idx[1:]:
+                if i == prev + 1:
+                    prev = i
+                    continue
+                runs.append((start, prev + 1))
+                start = i
+                prev = i
+            runs.append((start, prev + 1))
+            return runs
+
+        # read original obs intervals/probs if available (for combining probs)
+        orig_obs = None
+        if update_obs and (layer in adata.obs.columns):
+            orig_obs = list(adata.obs[layer])  # entries might not be lists
+
+        # prepare outputs
+        merged_bin = _np.zeros_like(bin_arr, dtype=int)
+        merged_len = _np.zeros_like(bin_arr, dtype=int)
+        merged_obs_col = [[] for _ in range(n_rows)]
+        merged_counts = _np.zeros(n_rows, dtype=int)
+
+        for r in range(n_rows):
+            mask = bin_arr[r, :] != 0
+            runs = _runs_from_mask(mask)
+            if not runs:
+                merged_obs_col[r] = []
+                continue
+
+            # merge runs where gap <= distance_threshold (gap in genomic coords when possible)
+            merged_runs = []
+            cur_s, cur_e = runs[0]
+            for (s, e) in runs[1:]:
+                if coords_are_ints:
+                    end_coord = int(coords[cur_e - 1])
+                    next_start_coord = int(coords[s])
+                    gap = next_start_coord - end_coord - 1
+                else:
+                    gap = s - cur_e
+                if gap <= distance_threshold:
+                    # extend the current run
+                    cur_e = e
+                else:
+                    merged_runs.append((cur_s, cur_e))
+                    cur_s, cur_e = s, e
+            merged_runs.append((cur_s, cur_e))
+
+            # assemble merged mask/lengths and obs entries
+            row_entries = []
+            for (s_idx, e_idx) in merged_runs:
+                if e_idx <= s_idx:
+                    continue
+                span_positions = e_idx - s_idx
+                if coords_are_ints:
+                    try:
+                        length_val = int(coords[e_idx - 1]) - int(coords[s_idx]) + 1
+                    except Exception:
+                        length_val = span_positions
+                else:
+                    length_val = span_positions
+
+                # set binary and length masks
+                merged_bin[r, s_idx:e_idx] = 1
+                existing_segment = merged_len[r, s_idx:e_idx]
+                # set to max(existing, length_val)
+                if existing_segment.size > 0:
+                    merged_len[r, s_idx:e_idx] = _np.maximum(existing_segment, length_val)
+                else:
+                    merged_len[r, s_idx:e_idx] = length_val
+
+                # determine prob from overlapping original obs intervals (if present)
+                prob_val = 1.0
+                if update_obs and orig_obs is not None:
+                    overlaps = []
+                    for orig in (orig_obs[r] or []):
+                        try:
+                            ostart, olen, opro = orig[0], int(orig[1]), float(orig[2])
+                        except Exception:
+                            continue
+                        if coords_are_ints:
+                            ostart_idx = _np.searchsorted(coords, int(ostart), side="left")
+                            oend_idx = ostart_idx + olen
+                        else:
+                            ostart_idx = int(ostart)
+                            oend_idx = ostart_idx + olen
+                        # overlap test in index space
+                        if not (oend_idx <= s_idx or ostart_idx >= e_idx):
+                            overlaps.append(opro)
+                    if overlaps:
+                        if prob_strategy == "mean":
+                            prob_val = float(_np.mean(overlaps))
+                        elif prob_strategy == "max":
+                            prob_val = float(_np.max(overlaps))
+                        else:
+                            prob_val = float(overlaps[0])
+
+                start_coord = int(coords[s_idx]) if coords_are_ints else int(s_idx)
+                row_entries.append((start_coord, int(length_val), float(prob_val)))
+
+            merged_obs_col[r] = row_entries
+            merged_counts[r] = len(row_entries)
+
+        # write merged layers (originals are only replaced when overwrite=True was set above)
+        adata.layers[merged_bin_name] = merged_bin
+        adata.layers[merged_len_name] = merged_len
+
+        if update_obs:
+            adata.obs[merged_bin_name] = merged_obs_col
+            adata.obs[f"n_{merged_bin_name}"] = merged_counts
+
+            # recompute per-row distances (gaps between adjacent merged intervals)
+            def _calc_distances(obs_list):
+                out = []
+                for intervals in obs_list:
+                    if not intervals:
+                        out.append([])
+                        continue
+                    iv = sorted(intervals, key=lambda x: int(x[0]))
+                    if len(iv) <= 1:
+                        out.append([])
+                        continue
+                    dlist = []
+                    for i in range(len(iv) - 1):
+                        endi = int(iv[i][0]) + int(iv[i][1]) - 1
+                        startn = int(iv[i + 1][0])
+                        dlist.append(startn - endi - 1)
+                    out.append(dlist)
+                return out
+
+            adata.obs[f"{merged_bin_name}_distances"] = _calc_distances(merged_obs_col)
+
+        # update the uns appended-layers list
+        uns_key = "hmm_appended_layers"
+        existing = list(adata.uns.get(uns_key, [])) if adata.uns.get(uns_key, None) is not None else []
+        for nm in (merged_bin_name, merged_len_name):
+            if nm not in existing:
+                existing.append(nm)
+        adata.uns[uns_key] = existing
+
+        if verbose:
+            print(f"Created merged binary layer: {merged_bin_name}")
+            print(f"Created merged length layer: {merged_len_name}")
+            if update_obs:
+                print(f"Updated adata.obs columns: {merged_bin_name}, n_{merged_bin_name}, {merged_bin_name}_distances")
+
+        return None if inplace else adata
+
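A usage sketch for the method above, assuming a model instance `hmm` and an AnnData `adata` that already carries a binary feature layer (the layer name is hypothetical):

    # merge features separated by <= 30 bp into single intervals
    hmm.merge_intervals_in_layer(
        adata,
        layer="GpC_putative_nucleosome",  # hypothetical layer from the annotation step
        distance_threshold=30,
        prob_strategy="mean",
        overwrite=True,
    )
    # writes adata.layers["GpC_putative_nucleosome_merged"] and
    # adata.layers["GpC_putative_nucleosome_merged_lengths"], plus the matching obs columns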
+    def _ensure_final_layer_and_assign(self, final_adata, layer_name: str, subset_idx_mask: np.ndarray, sub_data):
+        """
+        Ensure final_adata.layers[layer_name] exists and assign the rows selected by subset_idx_mask.
+        sub_data has shape (n_subset_rows, n_vars).
+        subset_idx_mask: boolean array of length final_adata.n_obs, True where rows belong to the subset.
+        """
+        from scipy.sparse import issparse, csr_matrix
+        import warnings
+
+        n_final_obs, n_vars = final_adata.shape
+        n_sub_rows = int(subset_idx_mask.sum())
+
+        # prepare row indices in final_adata
+        final_row_indices = np.nonzero(subset_idx_mask)[0]
+
+        # if sub_data is sparse, work with sparse
+        if issparse(sub_data):
+            sub_csr = sub_data.tocsr()
+            # if the final layer is not present, create an empty CSR of shape (n_final_obs, n_vars)
+            if layer_name not in final_adata.layers:
+                final_adata.layers[layer_name] = csr_matrix((n_final_obs, n_vars), dtype=sub_csr.dtype)
+            final_csr = final_adata.layers[layer_name]
+            if not issparse(final_csr):
+                # convert a dense final layer to sparse first
+                final_csr = csr_matrix(final_csr)
+            # Replace the block of subset rows. This is efficient for moderate sizes; for huge
+            # data an in-place approach may be preferable. Convert the final matrix to LIL for
+            # mutable row assignment, then back to CSR.
+            final_lil = final_csr.tolil()
+            for i_local, r in enumerate(final_row_indices):
+                final_lil.rows[r] = sub_csr.getrow(i_local).indices.tolist()
+                final_lil.data[r] = sub_csr.getrow(i_local).data.tolist()
+            final_csr = final_lil.tocsr()
+            final_adata.layers[layer_name] = final_csr
+        else:
+            # dense numpy array
+            sub_arr = np.asarray(sub_data)
+            if sub_arr.shape[0] != n_sub_rows:
+                raise ValueError(f"Sub data rows ({sub_arr.shape[0]}) != mask selected rows ({n_sub_rows})")
+            if layer_name not in final_adata.layers:
+                # create a zero array with the sub data's dtype
+                final_adata.layers[layer_name] = np.zeros((n_final_obs, n_vars), dtype=sub_arr.dtype)
+            final_arr = final_adata.layers[layer_name]
+            if issparse(final_arr):
+                # convert a sparse final layer to dense (alternatively, sub could be converted to sparse)
+                final_arr = final_arr.toarray()
+            # assign the subset rows
+            final_arr[final_row_indices, :] = sub_arr
+            final_adata.layers[layer_name] = final_arr