yamcot-1.0.0-cp310-cp310-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- yamcot/__init__.py +46 -0
- yamcot/_core/__init__.py +17 -0
- yamcot/_core/_core.cpython-310-darwin.so +0 -0
- yamcot/_core/bindings.cpp +28 -0
- yamcot/_core/core_functions.h +29 -0
- yamcot/_core/fasta_to_plain.h +182 -0
- yamcot/_core/mco_prc.cpp +1476 -0
- yamcot/_core/pfm_to_pwm.h +130 -0
- yamcot/cli.py +621 -0
- yamcot/comparison.py +1066 -0
- yamcot/execute.py +97 -0
- yamcot/functions.py +787 -0
- yamcot/io.py +522 -0
- yamcot/models.py +1161 -0
- yamcot/pipeline.py +402 -0
- yamcot/ragged.py +126 -0
- yamcot-1.0.0.dist-info/METADATA +433 -0
- yamcot-1.0.0.dist-info/RECORD +21 -0
- yamcot-1.0.0.dist-info/WHEEL +6 -0
- yamcot-1.0.0.dist-info/entry_points.txt +3 -0
- yamcot-1.0.0.dist-info/licenses/LICENSE +21 -0
yamcot/models.py
ADDED
@@ -0,0 +1,1161 @@
"""
models
======

This module defines abstractions for motif models. A motif is
represented internally as a numeric array and exposes methods for
scoring sequences and computing the best match score across all
positions and strand orientations. Concrete subclasses implement
different motif scoring models such as PWMs or Bayesian Markov
models (BaMMs).
"""

from __future__ import annotations

import copy
import logging
import os
import sys
from dataclasses import dataclass, field
from typing import ClassVar, Dict, Literal, Optional

import joblib
import numpy as np
import pandas as pd

from yamcot.functions import batch_all_scores, pfm_to_pwm, scores_to_frequencies
from yamcot.io import parse_file_content, read_bamm, read_meme, read_pfm, read_sitega, write_sitega
from yamcot.ragged import RaggedData

StrandMode = Literal["best", "+", "-", "both"]


@dataclass(frozen=True)
class RaggedScores:
    """Container for ragged arrays of scores with variable-length sequences.

    Attributes
    ----------
    values : np.ndarray
        Float32 array of shape (n_seq, Lmax) containing the score values.
    lengths : np.ndarray
        Int32 array of shape (n_seq,) containing the length of each sequence.
    """

    values: np.ndarray  # float32, shape (n_seq, Lmax)
    lengths: np.ndarray  # int32, shape (n_seq,)

    @classmethod
    def from_numba(cls, rs_numba: RaggedData) -> RaggedScores:
        """Convert a RaggedData object from Numba to RaggedScores.

        Parameters
        ----------
        rs_numba : RaggedData
            Input RaggedData object from Numba computations.

        Returns
        -------
        RaggedScores
            Converted RaggedScores object.
        """
        n = rs_numba.num_sequences
        lengths = np.zeros(n, dtype=np.int32)
        for i in range(n):
            lengths[i] = rs_numba.get_length(i)

        l_max = int(lengths.max()) if n > 0 else 0
        values = np.zeros((n, l_max), dtype=np.float32)
        for i in range(n):
            row = rs_numba.get_slice(i)
            values[i, : len(row)] = row
        return cls(values, lengths)
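For orientation, a minimal sketch of how these ragged containers fit together. This is a usage example, not part of models.py; the values are invented, and the RaggedData constructor and offsets layout (flat data plus n+1 boundary offsets) are inferred from how this module itself builds and consumes RaggedData, so treat them as assumptions.

import numpy as np

from yamcot.models import RaggedScores
from yamcot.ragged import RaggedData

# Two score tracks of different lengths stored back to back, the way the
# scan() methods below return them: sequence 0 -> data[0:2], sequence 1 -> data[2:5].
data = np.array([0.1, 0.9, 0.3, 0.7, 0.2], dtype=np.float32)
offsets = np.array([0, 2, 5], dtype=np.int64)
ragged = RaggedData(data, offsets)

# Pad into a dense (n_seq, Lmax) matrix plus per-sequence lengths.
scores = RaggedScores.from_numba(ragged)
print(scores.values.shape)  # (2, 3)
print(scores.lengths)       # [2 3]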
@dataclass
class MotifModel:
    """
    Abstract base class for motif models.

    Subclasses must implement the :meth:`scan` method which returns
    per-position scores for the forward and reverse complement of a
    sequence. The :meth:`best_score` helper uses :meth:`scan` to
    compute the maximal score across all positions and strands.

    Attributes
    ----------
    matrix : np.ndarray
        Numeric representation of the motif.
    name : str
        Human readable identifier of the motif.
    length : int
        Length of the motif (number of positions).
    strand_mode : {"best", "+", "-"}
        Default strand mode for scores/frequencies properties.
    """

    _registry: ClassVar[Dict[str, type[MotifModel]]] = {}

    matrix: np.ndarray
    name: str
    length: int

    strand_mode: StrandMode = field(default="best", init=True)

    statistics: Optional[Dict[str, float]] = field(default=None, init=False)
    _threshold_table: Optional[np.ndarray] = field(default=None, init=False)
    _pfm: Optional[np.ndarray] = field(default=None, init=False)

    _freq_cache: Dict[StrandMode, RaggedData] = field(
        default_factory=dict,
        init=False,
    )

    @classmethod
    def register_subclass(cls, model_type: str, subclass: type[MotifModel]):
        """Register a subclass for a specific model type."""
        cls._registry[model_type.lower()] = subclass

    # ------------------------------------------------------------------
    # Thresholds
    # ------------------------------------------------------------------
    @property
    def threshold_table(self) -> Optional[np.ndarray]:
        """
        Lazy property for threshold table.

        Returns
        -------
        Optional[np.ndarray]
            Precomputed threshold values for scoring.
        """
        if self._threshold_table is None:
            logger = logging.getLogger(__name__)
            logger.warning(f"There is no threshold table for model: {self.name}")
        return self._threshold_table

    def get_threshold_table(self, promoters) -> np.ndarray:
        """Compute the threshold table based on the matrix.

        Returns
        -------
        np.ndarray
            The computed threshold table.
        """
        logger = logging.getLogger(__name__)
        logger.info(f"Computing threshold table for model: {self.name}")

        if self._threshold_table is None:
            self._threshold_table = self._calculate_threshold_table(promoters)

        return self._threshold_table

    def _calculate_threshold_table(self, promoters: RaggedData) -> np.ndarray:
        """Calculate the threshold table using all sites in all sequences.

        Computes a lookup table mapping motif scores to false positive rates (FPR)
        by evaluating the motif against all possible positions in the promoter sequences.

        Parameters
        ----------
        promoters : RaggedData
            Collection of promoter sequences to calculate thresholds against.

        Returns
        -------
        np.ndarray, shape (N, 2)
            Array where each row contains [score, -log10(fpr)], sorted by descending score.
        """
        # Get scores for all positions
        ragged_scores = self.get_scores(promoters, strand=self.strand_mode)

        # Flatten - collect all scores in one array directly from data (since there's no padding)
        flat_scores = ragged_scores.data

        if flat_scores.size == 0:
            # Fallback: table with one "bad" value
            return np.array([[0.0, 0.0]], dtype=np.float64)

        # Sort in descending order
        scores_sorted = np.sort(flat_scores)[::-1]
        n_total = flat_scores.size

        # Calculate FPR (-log10)
        # FPR = rank / N_total
        # Use unique to ensure identical scores have the same FPR (worst case)
        unique_scores, inverse, counts = np.unique(scores_sorted, return_inverse=True, return_counts=True)

        # np.unique sorts in ascending order. We need descending.
        unique_scores = unique_scores[::-1]  # Score desc
        counts = counts[::-1]

        # Cumulative sum (how many sites have score >= current)
        cum_counts = np.cumsum(counts)

        # FPR = cum_counts / n_total
        fpr_values = cum_counts / n_total

        # Log transform: -log10(FPR)
        # Avoid log(0) (though here fpr > 0 always, since n >= 1)
        log_fpr_values = -np.log10(fpr_values)

        # Table: [Score, -log10(FPR)]
        table = np.column_stack([unique_scores, log_fpr_values])

        return table.astype(np.float64)

    def _score_to_frequency(self, score: float) -> float:
        """Return -log10(FPR) for a given score using the threshold table.

        Parameters
        ----------
        score : float
            Motif score to convert to frequency.

        Returns
        -------
        float
            The -log10(false positive rate) corresponding to the score.
        """
        if self._threshold_table is None:
            return np.nan

        scores_col = self._threshold_table[:, 0]
        logfpr_col = self._threshold_table[:, 1]  # <-- this is -log10(FPR)

        if score >= scores_col[0]:
            return float(logfpr_col[0])
        if score <= scores_col[-1]:
            return float(logfpr_col[-1])

        idx = np.searchsorted(-scores_col, -score, side="left")
        if idx >= len(logfpr_col):
            return float(logfpr_col[-1])
        return float(logfpr_col[idx])

    def _frequency_to_score(self, frequency: float, background_data: Optional[RaggedData] = None) -> float:
        """Convert frequency (FPR) to threshold score.

        If background_data is provided, calculation is done across ALL positions (sites)
        in these sequences using scores_to_frequencies.

        Parameters
        ----------
        frequency : float
            False positive rate to convert to a score threshold.
        background_data : RaggedData, optional
            Background sequences to calculate threshold from, if not using precomputed table.

        Returns
        -------
        float
            Score threshold corresponding to the given frequency.
        """
        if background_data is not None:
            # 1. Get scores for ALL positions in background data
            all_sites_scores = self.scan(background_data, strand="best")

            if all_sites_scores.data.size == 0:
                return 0.0

            # 2. Use scores_to_frequencies to get -log10(FPR) for each position
            #    Function internally performs np.unique and np.cumsum on all .data values
            freq_ragged = scores_to_frequencies(all_sites_scores)
            log_p_values = freq_ragged.data

            # 3. Determine target -log10(FPR) value
            target_log_p = -np.log10(frequency) if frequency > 0 else 100.0

            # 4. Find minimum score among those whose frequency <= target (i.e. -log_p >= target)
            #    This precisely corresponds to the _threshold_table construction logic
            mask = log_p_values >= target_log_p
            if not np.any(mask):
                return float(all_sites_scores.data.min())

            return float(all_sites_scores.data[mask].min())

        # Fallback via precomputed table
        if self._threshold_table is None:
            raise ValueError("Threshold table not computed. Call get_threshold_table() or provide background_data.")

        if frequency <= 0:
            return float(self._threshold_table[0, 0])  # infinitely strict threshold -> maximum score

        target_logfpr = -np.log10(frequency)

        scores_col = self._threshold_table[:, 0]
        logfpr_col = self._threshold_table[:, 1]  # -log10(FPR)

        mask = logfpr_col >= target_logfpr
        if not np.any(mask):
            # even the weakest score doesn't give such FPR
            return float(scores_col[-1])

        last_valid = np.where(mask)[0][-1]
        return float(scores_col[last_valid])

    # ------------------------------------------------------------------
    # Scanning
    # ------------------------------------------------------------------
    def scan(self, sequences: RaggedData, strand: Optional[StrandMode] = None) -> RaggedData:
        """
        Perform batch scanning of sequences for motif matches.

        Parameters
        ----------
        sequences : RaggedData
            Encoded sequences (int8).
        strand : {"+", "-", "best"}, optional
            Strand selection mode:
            - None: use self.strand_mode.
            - "+": forward strand only.
            - "-": reverse complement only.
            - "best": maximum score between strands at each position.

        Returns
        -------
        RaggedData
            Scan results.
        """
        raise NotImplementedError("MotifModel subclasses must implement scan()")

    @classmethod
    def from_file(cls, path: str, **kwargs) -> "MotifModel":
        """
        Abstract class method to create a motif model from a file.

        Parameters
        ----------
        path : str
            Path to the motif file.
        **kwargs : dict
            Additional arguments for specific model types.

        Returns
        -------
        MotifModel
            A motif model instance.
        """
        raise NotImplementedError("Subclasses must implement from_file()")

    @classmethod
    def create_from_file(cls, path: str, model_type: str, **kwargs) -> "MotifModel":
        """
        Factory method to create a motif model from a file based on model_type.

        Parameters
        ----------
        path : str
            Path to the motif file.
        model_type : str
            Type of the model ('pwm', 'bamm', 'sitega').
        **kwargs : dict
            Additional arguments for specific model types.

        Returns
        -------
        MotifModel
            A motif model instance.
        """
        model_type = model_type.lower()
        if model_type not in cls._registry:
            raise ValueError(f"Unsupported model type: {model_type}. Registered types: {list(cls._registry.keys())}")

        subclass = cls._registry[model_type]
        return subclass.from_file(path, **kwargs)

    def get_sites(
        self,
        sequences: RaggedData,
        mode: str = "best",
        fpr_threshold: Optional[float] = None,
    ) -> pd.DataFrame:
        """Find motif binding sites in sequences.

        The method operates in two modes:
        - "best": finds the single best site in each sequence
        - "threshold": finds all sites with false positive rate ≤ fpr_threshold

        Parameters
        ----------
        sequences : RaggedData
            Encoded sequences to search for motif sites.
        mode : {"best", "threshold"}, optional
            Site finding mode (default "best").
        fpr_threshold : float, optional
            False positive rate threshold for "threshold" mode (e.g., 0.001).
            Required for mode="threshold".

        Returns
        -------
        pd.DataFrame
            DataFrame with columns:
            - seq_index: sequence index
            - start: start position of site (0-based)
            - end: end position of site (exclusive)
            - strand: DNA strand ("+" or "-")
            - score: recognition function value
            - frequency: false positive rate (FPR) from threshold_table
            - site: site as string representation (ACGT)
        """
        # Validate parameters
        if mode not in ["best", "threshold"]:
            raise ValueError(f"mode must be 'best' or 'threshold', got {mode!r}")
        if mode == "threshold" and fpr_threshold is None:
            raise ValueError("fpr_threshold is required for mode='threshold'")

        # Determine threshold score
        score_threshold = (
            self._frequency_to_score(fpr_threshold) if mode == "threshold" and fpr_threshold is not None else None
        )
        if score_threshold is not None:
            logger = logging.getLogger(__name__)
            logger.info(f"FPR threshold: {fpr_threshold} → Score threshold: {score_threshold:.4f}")

        # Helper function to add a single site
        def add_site(seq_idx: int, seq: np.ndarray, pos: int, strand_idx: int, score: float):
            """Add a site to results."""
            if pos + self.length > len(seq):
                return

            site_seq = seq[pos : pos + self.length]
            strand = "+" if strand_idx == 0 else "-"

            # Reverse complement for minus strand
            if strand_idx == 1:
                site_seq = self._get_rc_sequence(site_seq)

            results.append(
                {
                    "seq_index": seq_idx,
                    "start": int(pos),
                    "end": int(pos + self.length),
                    "strand": strand,
                    "score": score,
                    "frequency": self._score_to_frequency(score),
                    "site": self._int_to_seq(site_seq),
                }
            )

        # Collect results
        results = []

        # Batch scanning
        s_fwd_ragged = self.scan(sequences, strand="+")
        s_rev_ragged = self.scan(sequences, strand="-")
        n_seq = sequences.num_sequences

        for seq_idx in range(n_seq):
            seq = sequences.get_slice(seq_idx)
            s_fwd = s_fwd_ragged.get_slice(seq_idx)
            s_rev = s_rev_ragged.get_slice(seq_idx)

            if mode == "best":
                f_max = s_fwd.max() if s_fwd.size > 0 else -1e9
                r_max = s_rev.max() if s_rev.size > 0 else -1e9

                if f_max >= r_max:
                    best_pos = int(np.argmax(s_fwd))
                    best_score = float(f_max)
                    add_site(seq_idx, seq, best_pos, 0, best_score)
                else:
                    best_pos = int(np.argmax(s_rev))
                    best_score = float(r_max)
                    add_site(seq_idx, seq, best_pos, 1, best_score)

            else:  # mode == "threshold"
                # Forward strand
                f_pos = np.where(s_fwd >= score_threshold)[0]
                for pos in f_pos:
                    add_site(seq_idx, seq, int(pos), 0, float(s_fwd[pos]))

                # Reverse strand
                r_pos = np.where(s_rev >= score_threshold)[0]
                for pos in r_pos:
                    add_site(seq_idx, seq, int(pos), 1, float(s_rev[pos]))

        # Create and sort DataFrame
        df = pd.DataFrame(results)
        if len(df) > 0:
            df = df.sort_values(["seq_index", "score"], ascending=[True, False]).reset_index(drop=True)

        logger = logging.getLogger(__name__)
        logger.info(f"Found {len(df)} site(s) in {sequences.num_sequences} sequence(s)")
        return df

    @staticmethod
    def _int_to_seq(seq_int: np.ndarray) -> str:
        """Convert integer-encoded sequence to ACGT string.

        Parameters
        ----------
        seq_int : np.ndarray
            Integer-encoded sequence (0=A, 1=C, 2=G, 3=T, 4=N).

        Returns
        -------
        str
            Sequence as string.
        """
        decoder = np.array(["A", "C", "G", "T", "N"], dtype="U1")
        safe_seq = np.clip(seq_int, 0, 4)
        return "".join(decoder[safe_seq])

    @staticmethod
    def _get_rc_sequence(seq_int: np.ndarray) -> np.ndarray:
        """Return reverse complement of sequence.

        Parameters
        ----------
        seq_int : np.ndarray
            Integer-encoded sequence.

        Returns
        -------
        np.ndarray
            Reverse complement of sequence.
        """
        RC_TABLE = np.array([3, 2, 1, 0, 4], dtype=np.int8)
        return RC_TABLE[seq_int[::-1]]

    def write_pfm(self, path: str) -> None:
        """Write the motif to a PFM formatted file.

        Parameters
        ----------
        path : str
            Path of the output file.
        """
        if self.pfm is not None:
            with open(path, "w") as fname:
                header = f">{self.name}"
                np.savetxt(
                    fname,
                    self.pfm[:4, :].T,
                    fmt="%.8f",
                    delimiter="\t",
                    newline="\n",
                    header=header,
                    footer="",
                    comments="",
                    encoding=None,
                )

    def write_dist(self, path: str) -> None:
        """Write the threshold table of motif to a DIST formatted file.

        Parameters
        ----------
        path : str
            Path of the output file.
        """
        table = self.threshold_table
        if table is None:
            logger = logging.getLogger(__name__)
            logger.error(f"Cannot write DIST file: threshold table not computed for {self.name}")
            return

        table = copy.deepcopy(table)
        max_score = self.matrix.max(axis=0).sum()
        min_score = self.matrix.min(axis=0).sum()

        table[:, 0] = (table[:, 0] - min_score) / (max_score - min_score)
        with open(path, "w") as fname:
            np.savetxt(fname, table, fmt="%.18f", delimiter="\t", newline="\n", footer="", comments="", encoding=None)

    @property
    def pfm(self) -> Optional[np.ndarray]:
        """
        Lazy property for position frequency matrix.

        Returns
        -------
        Optional[np.ndarray]
            Cached PFM if available, otherwise None.
            Use get_pfm() to compute and cache.
        """
        if self._pfm is None:
            logger = logging.getLogger(__name__)
            logger.warning(f"PFM not computed for model: {self.name}. Use get_pfm(sequences) to compute.")
        return self._pfm

    def get_pfm(
        self,
        sequences: RaggedData,
        mode: str = "best",
        fpr_threshold: Optional[float] = None,
        top_fraction: Optional[float] = None,
        pseudocount: float = 0.25,
        force_recompute: bool = False,
    ) -> np.ndarray:
        """Construct Position Frequency Matrix (PFM) from binding sites.

        Result is cached in the _pfm attribute for reuse.

        Parameters
        ----------
        sequences : RaggedData
            Encoded sequences to extract binding sites from.
        mode : {"best", "threshold"}, optional
            Site finding mode (default "best").
        fpr_threshold : float, optional
            Frequency threshold for "threshold" mode.
        top_fraction : float, optional
            Selects only top N% of sites by score.
        pseudocount : float, optional
            Pseudocount for smoothing (default 0.25).
        force_recompute : bool, optional
            If True, recomputes PFM even if it's already cached.

        Returns
        -------
        np.ndarray
            Normalized PFM of shape (4, motif_length).
        """
        # Return cached version if available
        if self._pfm is not None and not force_recompute:
            logger = logging.getLogger(__name__)
            logger.info(f"Returning cached PFM for model: {self.name}")
            return self._pfm

        logger = logging.getLogger(__name__)
        logger.info(f"Computing PFM for model: {self.name}")

        # Get sites
        sites_df = self.get_sites(sequences, mode=mode, fpr_threshold=fpr_threshold)
        if len(sites_df) == 0:
            raise ValueError("No sites found")

        sites_df = sites_df.sort_values(by=["score"], axis=0, ascending=False)

        # Select top N% if specified
        if top_fraction is not None:
            n_keep = max(1, int(len(sites_df) * top_fraction))
            sites_df = sites_df.nlargest(n_keep, "score")
            logger = logging.getLogger(__name__)
            logger.info(f"Selected top {top_fraction * 100:.1f}%: {n_keep} sites")

        # Initialize PFM with pseudocounts
        pfm = np.full((4, self.length), pseudocount, dtype=np.float32)
        nuc_map = {"A": 0, "C": 1, "G": 2, "T": 3}

        # Fill counters
        for site_str in sites_df["site"]:
            for pos, nuc in enumerate(site_str):
                if nuc in nuc_map:
                    pfm[nuc_map[nuc], pos] += 1.0

        # Normalize to probabilities
        pfm = pfm / pfm.sum(axis=0, keepdims=True)

        # Cache the result
        self._pfm = pfm

        return pfm

    @staticmethod
    def _reduce_strand(seq_scores: np.ndarray, strand: StrandMode) -> np.ndarray:
        """Transform (2, N) → (N,) depending on strand mode.

        Parameters
        ----------
        seq_scores : np.ndarray
            Scores with shape (2, N_positions).
        strand : {"best", "+", "-"}

        Returns
        -------
        np.ndarray
            One-dimensional scores per position (N_positions,).
        """
        if strand == "best":
            # maximum across both strands for each position
            return np.max(seq_scores, axis=0)
        if strand == "+":
            return seq_scores[0]
        if strand == "-":
            return seq_scores[1]
        raise ValueError(f"Unknown strand={strand!r}. Use '+', '-', or 'best'.")

    def get_scores(self, sequences: RaggedData, strand: Optional[StrandMode] = None) -> RaggedData:
        """Calculate motif scores for each position in the sequences using batch processing.

        Parameters
        ----------
        sequences : RaggedData
            Encoded sequences (int8).
        strand : {"best", "+", "-", "both"}, optional
            Strand to score. Default is self.strand_mode.

        Returns
        -------
        RaggedData
            Ragged array of scores (float32).
        """
        return self.scan(sequences, strand=strand)

    def get_frequencies(self, sequences: RaggedData, strand: Optional[StrandMode] = None) -> RaggedData:
        """Calculate per-position hit frequencies (probability maps) using batch processing.

        Parameters
        ----------
        sequences : RaggedData
            Encoded nucleotide sequences.
        strand : {"best", "+", "-", "both"}, optional
            Strand to evaluate.

        Returns
        -------
        RaggedData
            Per-position hit frequencies.
        """
        return scores_to_frequencies(self.scan(sequences, strand))

    def clear_cache(self) -> None:
        """Clear all cached scores, frequencies and normalization range."""
        self._freq_cache.clear()
        self._threshold_table = None
        if not getattr(self, "_pfm_is_required", False):
            self._pfm = None

    def save(self, filepath: str, clear_cache: bool = True) -> None:
        """Save the motif model to a file using joblib.

        Parameters
        ----------
        filepath : str
            Full path to the output file (e.g., 'motifs/M1.pkl').
        clear_cache : bool
            If True, clears caches before saving to minimize file size.
        """
        if clear_cache:
            self.clear_cache()

        # Create directory if it doesn't exist
        os.makedirs(os.path.dirname(os.path.abspath(filepath)), exist_ok=True)

        joblib.dump(self, filepath)

    @staticmethod
    def load(filepath: str) -> MotifModel:
        """Load a motif model from a .pkl file.

        Parameters
        ----------
        filepath : str
            Path to the .pkl file.

        Returns
        -------
        MotifModel
            Loaded motif model instance.
        """
        if not os.path.exists(filepath):
            raise FileNotFoundError(f"Motif file not found: {filepath}")

        model = joblib.load(filepath)
        return model

    @property
    def model_type(self) -> str:
        """Abstract property to return the type of model."""
        raise NotImplementedError("Subclasses must implement model_type property")

    def write(self, path: str) -> None:
        """Abstract method to write the motif to a file in its native format."""
        raise NotImplementedError("Subclasses must implement write() method")
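The threshold machinery in MotifModel boils down to an empirical survival function over background scores: the FPR of a score is the fraction of scanned positions that score at least as high, and the table stores -log10 of that fraction per unique score. A standalone sketch of the same bookkeeping with made-up numbers, written in plain NumPy and independent of the class:

import numpy as np

# Invented background scores for every scanned position.
flat_scores = np.array([1.2, 3.4, 3.4, 0.5, 2.0, 0.5, 4.1, 1.2], dtype=np.float64)
n_total = flat_scores.size

# Same construction as _calculate_threshold_table: unique scores in descending
# order, FPR = count of positions scoring >= the score, divided by the total.
unique_scores, counts = np.unique(flat_scores, return_counts=True)
unique_scores, counts = unique_scores[::-1], counts[::-1]
fpr = np.cumsum(counts) / n_total
table = np.column_stack([unique_scores, -np.log10(fpr)])  # rows: [score, -log10(FPR)]

# _score_to_frequency is then a descending-order lookup: 3 of 8 positions
# score >= 3.4, so the stored value is -log10(3/8) ≈ 0.426.
row = table[np.searchsorted(-table[:, 0], -3.4, side="left")]
print(row[1])

# _frequency_to_score inverts it: the weakest score whose FPR is still <= 0.4.
target = -np.log10(0.4)
print(table[table[:, 1] >= target][-1, 0])  # 3.4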
class PwmMotif(MotifModel):
    """Position Weight Matrix motif model.

    This class wraps a PWM (log odds scores) and provides efficient
    scoring using the precompiled :func:`yamcot.functions.batch_all_scores`
    Numba kernel. The input matrix is expected to have shape (5, L)
    where the final row contains column minima. The scoring logic
    treats any encoded nucleotide equal to 4 as an ambiguous 'N' and
    assigns a zero contribution.

    Parameters
    ----------
    matrix : np.ndarray
        PWM matrix with shape (5, L) where the 5th row contains column minima.
    name : str
        Name of the motif.
    length : int
        Length of the motif.
    pfm : np.ndarray
        Position Frequency Matrix of shape (4, L). Required attribute for PWM.
    kmer : int, optional
        Size of k-mer for scanning (default is 1).
    """

    def __init__(
        self,
        matrix: np.ndarray,
        name: str,
        length: int,
        pfm: np.ndarray,
        kmer: int = 1,
    ) -> None:
        super().__init__(
            matrix=matrix,
            name=name,
            length=length,
        )
        self.kmer = kmer
        # Assign to private attribute so property works correctly
        self._pfm = pfm
        self._pfm_is_required = True

    def scan(self, sequences: RaggedData, strand: Optional[StrandMode] = None) -> RaggedData:
        """Score sequences with the PWM.

        Parameters
        ----------
        sequences : RaggedData
            Encoded nucleotide sequences (int8).
        strand : {"best", "+", "-", "both"}, optional
            Strand selection mode. If None, uses self.strand_mode.

        Returns
        -------
        RaggedData
            Scanning results with scores for each position.
        """
        strand = strand or self.strand_mode
        matrix = self.matrix.astype(np.float32)

        if strand == "+":
            return batch_all_scores(sequences, matrix, kmer=self.kmer, is_revcomp=False)
        elif strand == "-":
            return batch_all_scores(sequences, matrix, kmer=self.kmer, is_revcomp=True)
        elif strand == "best":
            sf = batch_all_scores(sequences, matrix, kmer=self.kmer, is_revcomp=False)
            sr = batch_all_scores(sequences, matrix, kmer=self.kmer, is_revcomp=True)
            return RaggedData(np.maximum(sf.data, sr.data), sf.offsets)
        else:
            logger = logging.getLogger(__name__)
            logger.error(f"Unknown strand mode: {strand}")
            sys.exit(1)

    @classmethod
    def from_file(cls, path: str, index: int = 0, **kwargs) -> PwmMotif:
        """Create a PwmMotif from a file.

        Supports MEME, ProSampler, and PFM formats.

        Parameters
        ----------
        path : str
            Path to the motif file.
        index : int, optional
            Index of motif to read from file if multiple motifs are present (default 0).
        **kwargs : dict
            Additional arguments.

        Returns
        -------
        PwmMotif
            A PwmMotif object created from the file.
        """
        # Determine file format based on extension
        _, ext = os.path.splitext(path.lower())

        if ext == ".pkl":
            return joblib.load(path)
        elif ext == ".meme":
            # MEME format
            matrix, info, _number_of_motifs = read_meme(path, index=index)
            pfm = matrix
            name, length = info
        elif ext == ".pfm":
            # PFM format
            pwm, length, _minimum, _maximum = read_pfm(path)
            # Extract PFM from PWM (first 4 rows, excluding the 5th row of minimums)
            pfm = pwm[:4, :]  # Get the original PFM from the extended PWM
            name = os.path.splitext(os.path.basename(path))[0]
        else:
            logger = logging.getLogger(__name__)
            logger.error(f"Wrong format of PWM model: {path}")
            sys.exit(1)

        # Convert PFM to PWM
        pwm = pfm_to_pwm(pfm)
        # Add the 5th row for 'N' characters (minimum values at each position)
        pwm_ext = np.concatenate((pwm, np.min(pwm, axis=0, keepdims=True)), axis=0)
        return cls(matrix=pwm_ext, name=name, length=int(length), pfm=pfm)

    @property
    def model_type(self) -> str:
        """Return the type of model ('pwm')."""
        return "pwm"

    def write(self, path: str) -> None:
        """Write the PWM motif to a file in PFM format."""
        # Create directory if it doesn't exist
        os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True)
        self.write_pfm(path)
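A minimal usage sketch for the PWM path, not part of models.py. The motif file name is a placeholder, the sequences are hand-built rather than read from FASTA, and the RaggedData offsets layout is inferred from this module's own usage; in practice the threshold table would be built from a separate background set rather than the scanned sequences themselves.

import numpy as np

from yamcot.models import MotifModel, PwmMotif
from yamcot.ragged import RaggedData

# Two short int8-encoded sequences (0=A, 1=C, 2=G, 3=T, 4=N) packed back to back.
seq = np.array([0, 1, 2, 3, 0, 0, 2, 2, 1, 3, 0, 1], dtype=np.int8)
sequences = RaggedData(seq, np.array([0, 7, 12], dtype=np.int64))

# Load a PWM ("motif.meme" is a placeholder path; index selects the motif in the file),
# or go through the registry-based factory with model_type="pwm".
motif = PwmMotif.from_file("motif.meme", index=0)
# motif = MotifModel.create_from_file("motif.meme", model_type="pwm")

scores = motif.scan(sequences, strand="best")   # per-position scores, ragged
motif.get_threshold_table(sequences)            # build the score -> -log10(FPR) table
hits = motif.get_sites(sequences, mode="threshold", fpr_threshold=0.05)
print(hits[["seq_index", "start", "strand", "score", "site"]])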
class BammMotif(MotifModel):
    """Bayesian Markov Model motif.

    BaMMs extend PWMs by modelling dependencies between neighbouring
    nucleotides. The representation and scanning procedure differ
    substantially from PWMs and require specialised software. This
    class provides an interface compatible with the :class:`MotifModel`
    API but does not implement the scoring logic. Users must supply
    an external scoring function via :meth:`scan`.

    Parameters
    ----------
    matrix : np.ndarray
        BaMM matrix representation.
    name : str
        Name of the motif.
    length : int
        Length of the motif.
    kmer : int, optional
        Size of k-mer for scanning (default is 2).
    """

    def __init__(
        self,
        matrix: np.ndarray,
        name: str,
        length: int,
        kmer: int = 2,
    ) -> None:
        super().__init__(
            matrix=matrix,
            name=name,
            length=length,
        )
        self.kmer = kmer

    def scan(self, sequences: RaggedData, strand: Optional[StrandMode] = None) -> RaggedData:
        """Score sequences with the BaMM.

        Parameters
        ----------
        sequences : RaggedData
            Encoded nucleotide sequences (int8).
        strand : {"best", "+", "-"}, optional
            Strand selection mode. If None, uses self.strand_mode.

        Returns
        -------
        RaggedData
            Scanning results with scores for each position.
        """
        strand = strand or self.strand_mode

        matrix = self.matrix.astype(np.float32)

        if strand == "+":
            return batch_all_scores(sequences, matrix, kmer=self.kmer, is_revcomp=False, with_context=True)
        elif strand == "-":
            return batch_all_scores(sequences, matrix, kmer=self.kmer, is_revcomp=True, with_context=True)
        elif strand == "best":
            sf = batch_all_scores(sequences, matrix, kmer=self.kmer, is_revcomp=False, with_context=True)
            sr = batch_all_scores(sequences, matrix, kmer=self.kmer, is_revcomp=True, with_context=True)
            return RaggedData(np.maximum(sf.data, sr.data), sf.offsets)
        else:
            logger = logging.getLogger(__name__)
            logger.error(f"Unknown strand mode: {strand}")
            sys.exit(1)

    @classmethod
    def from_file(cls, path: str, bg_path: str | None = None, order: int = 2, **kwargs) -> "BammMotif":
        """Create a BammMotif from a file.

        Parameters
        ----------
        path : str
            Path to the BaMM motif file (.ihbcp format or base path).
        bg_path : str, optional
            Path to the background BaMM file (.hbcp format). If not provided,
            it attempts to find a background file in the same directory.
        order : int, optional
            Order of the BaMM model (default is 2).
        **kwargs : dict
            Additional arguments.

        Returns
        -------
        BammMotif
            A BammMotif object created from the file.
        """
        # Handle case where path is provided without extension (as in pipeline.py)
        if not path.endswith(".ihbcp") and not os.path.exists(path):
            ihbcp_path = f"{path}.ihbcp"
            hbcp_path = f"{path}.hbcp"
            if os.path.exists(ihbcp_path):
                path = ihbcp_path
                if bg_path is None and os.path.exists(hbcp_path):
                    bg_path = hbcp_path

        # If no background path is provided, try to find it
        if bg_path is None:
            # Look for background file in the same directory
            dir_path = os.path.dirname(path)
            basename = os.path.basename(path)
            # Try to find background file by replacing extension or using common names
            possible_bg_names = ["bamm.hbcp", "background.hbcp", basename.replace(".ihbcp", ".hbcp")]
            for bg_name in possible_bg_names:
                possible_bg_path = os.path.join(dir_path, bg_name)
                if os.path.exists(possible_bg_path):
                    bg_path = possible_bg_path
                    break

            if bg_path is None:
                raise ValueError(f"Background file not found for {path}. Please provide bg_path parameter.")

        _, max_order, length = parse_file_content(path)
        if order > max_order:
            order = max_order

        # Read the BaMM motif and background files
        matrix = read_bamm(path, bg_path, order)
        name = os.path.splitext(os.path.basename(path))[0]
        return cls(
            matrix=matrix,
            name=name,
            length=length,
            kmer=order + 1,  # BaMM kmer is typically order + 1
        )

    @property
    def model_type(self) -> str:
        """Return the type of model ('bamm')."""
        return "bamm"

    def write(self, path: str) -> None:
        """Write the BaMM motif to a file."""
        # BaMM writing functionality would go here
        # For now, we'll implement a basic version that saves as joblib
        os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True)
        self.save(path)
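For the BaMM route, from_file accepts either an explicit .ihbcp/.hbcp pair or a base path from which both files are resolved, as the code above shows. A short sketch with placeholder file names, not part of models.py:

from yamcot.models import BammMotif, MotifModel

# Explicit motif + background files ("model.ihbcp" / "model.hbcp" are placeholders).
bamm = BammMotif.from_file("model.ihbcp", bg_path="model.hbcp", order=2)

# Or a base path without extension: "run/bamm" resolves to run/bamm.ihbcp and, if
# present, run/bamm.hbcp; otherwise common names such as bamm.hbcp are searched.
bamm = MotifModel.create_from_file("run/bamm", model_type="bamm", order=2)

print(bamm.model_type, bamm.length, bamm.kmer)  # kmer is order + 1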
class SitegaMotif(MotifModel):
    """SiteGA motif model.

    SiteGA motifs are based on a linear probabilistic grammar and
    represent dinucleotide dependencies across multiple positions. The
    scoring logic can be derived from the original SiteGA
    implementation but is non-trivial. This class acts as a
    placeholder for future SiteGA integration. A custom scanning
    function can be provided via :meth:`scan`.

    Parameters
    ----------
    matrix : np.ndarray
        SiteGA matrix representation.
    name : str
        Name of the motif.
    length : int
        Length of the motif.
    kmer : int, optional
        Size of k-mer for scanning (default is 2).
    """

    def __init__(
        self,
        matrix: np.ndarray,
        name: str,
        length: int,
        minimum: float = 0.0,
        maximum: float = 0.0,
        kmer: int = 2,
    ) -> None:
        super().__init__(
            matrix=matrix,
            name=name,
            length=length,
        )
        self.minimum = minimum
        self.maximum = maximum
        self.kmer = kmer

    def scan(self, sequences: RaggedData, strand: Optional[StrandMode] = None) -> RaggedData:
        """Score sequences with SiteGA model.

        Parameters
        ----------
        sequences : RaggedData
            Encoded nucleotide sequences (int8).
        strand : {"best", "+", "-"}, optional
            Strand selection mode. If None, uses self.strand_mode.

        Returns
        -------
        RaggedData
            Scanning results with scores for each position.
        """
        strand = strand or self.strand_mode
        matrix = self.matrix.astype(np.float32)

        if strand == "+":
            return batch_all_scores(sequences, matrix, kmer=self.kmer, is_revcomp=False)
        elif strand == "-":
            return batch_all_scores(sequences, matrix, kmer=self.kmer, is_revcomp=True)
        elif strand == "best":
            sf = batch_all_scores(sequences, matrix, kmer=self.kmer, is_revcomp=False)
            sr = batch_all_scores(sequences, matrix, kmer=self.kmer, is_revcomp=True)
            return RaggedData(np.maximum(sf.data, sr.data), sf.offsets)
        else:
            logger = logging.getLogger(__name__)
            logger.error(f"Unknown strand mode: {strand}")
            sys.exit(1)

    @classmethod
    def from_file(cls, path: str, **kwargs) -> SitegaMotif:
        """Parse SiteGA output file to create a SitegaMotif object.

        Parameters
        ----------
        path : str
            Path to the SiteGA output file (typically ends with '.mat').
        **kwargs : dict
            Additional arguments.

        Returns
        -------
        SitegaMotif
            A SitegaMotif object created from the parsed data.
        """
        # Parse the SiteGA output file
        matrix, length, minimum, maximum = read_sitega(path)

        # Extract motif name from the file path
        name = os.path.splitext(os.path.basename(path))[0]

        # Create and return the SitegaMotif instance
        return cls(
            matrix=matrix,
            name=name,
            length=length,
            minimum=minimum,
            maximum=maximum,
            kmer=2,  # SiteGA typically uses dinucleotide dependencies
        )

    @property
    def model_type(self) -> str:
        """Return the type of model ('sitega')."""
        return "sitega"

    def write(self, path: str) -> None:
        """Write the SiteGA motif to a file in .mat format."""
        # Create directory if it doesn't exist
        os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True)
        write_sitega(self, path)


# Register subclasses for polymorphic loading
MotifModel.register_subclass("pwm", PwmMotif)
MotifModel.register_subclass("bamm", BammMotif)
MotifModel.register_subclass("sitega", SitegaMotif)
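The three register_subclass calls are what make MotifModel.create_from_file dispatch on the model_type string. A closing sketch of that dispatch plus the joblib-based save/load round trip; the paths are placeholders and this is a usage illustration, not part of the packaged code:

from yamcot.models import MotifModel

# Dispatch by name: "pwm" -> PwmMotif, "bamm" -> BammMotif, "sitega" -> SitegaMotif.
motif = MotifModel.create_from_file("motif.pfm", model_type="pwm")

# An unregistered name raises ValueError listing the registered types.
try:
    MotifModel.create_from_file("motif.pfm", model_type="dimont")
except ValueError as err:
    print(err)

# Persist and restore the model with joblib; caches are cleared before dumping.
motif.save("motifs/M1.pkl")
restored = MotifModel.load("motifs/M1.pkl")
print(restored.model_type, restored.name)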