yamcot 1.0.0__cp310-cp310-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
yamcot/models.py ADDED
@@ -0,0 +1,1161 @@
1
+ """
2
+ models
3
+ ======
4
+
5
+ This module defines abstractions for motif models. A motif is
6
+ represented internally as a numeric array and exposes methods for
7
+ scoring sequences and computing the best match score across all
8
+ positions and strand orientations. Concrete subclasses implement
9
+ different motif scoring models such as PWMs or Bayesian Markov
10
+ models (BaMMs).
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
+ import copy
16
+ import logging
17
+ import os
18
+ import sys
19
+ from dataclasses import dataclass, field
20
+ from typing import ClassVar, Dict, Literal, Optional
21
+
22
+ import joblib
23
+ import numpy as np
24
+ import pandas as pd
25
+
26
+ from yamcot.functions import batch_all_scores, pfm_to_pwm, scores_to_frequencies
27
+ from yamcot.io import parse_file_content, read_bamm, read_meme, read_pfm, read_sitega, write_sitega
28
+ from yamcot.ragged import RaggedData
29
+
30
# Strand selection modes accepted by the scanning/scoring APIs:
# "+" forward strand, "-" reverse complement, "best" per-position maximum of both.
# NOTE(review): "both" is declared here, but PwmMotif.scan / BammMotif.scan only
# handle "+", "-" and "best" -- confirm whether "both" is meant to be supported.
StrandMode = Literal["best", "+", "-", "both"]
31
+
32
+
33
@dataclass(frozen=True)
class RaggedScores:
    """Dense, zero-padded view of per-sequence score rows of varying length.

    Attributes
    ----------
    values : np.ndarray
        Float32 matrix of shape (n_seq, Lmax); row ``i`` holds the scores of
        sequence ``i`` followed by zero padding.
    lengths : np.ndarray
        Int32 vector of shape (n_seq,) with the true (unpadded) row lengths.
    """

    values: np.ndarray  # float32, shape (n_seq, Lmax)
    lengths: np.ndarray  # int32, shape (n_seq,)

    @classmethod
    def from_numba(cls, rs_numba: RaggedData) -> RaggedScores:
        """Build a padded RaggedScores from a ragged container.

        Parameters
        ----------
        rs_numba : RaggedData
            Ragged container exposing ``num_sequences``, ``get_length(i)``
            and ``get_slice(i)``.

        Returns
        -------
        RaggedScores
            Zero-padded copy of the input rows together with their lengths.
        """
        count = rs_numba.num_sequences
        sizes = np.array([rs_numba.get_length(k) for k in range(count)], dtype=np.int32)

        # Width of the padded matrix: the longest row (0 when there are no rows).
        width = int(sizes.max()) if count > 0 else 0
        padded = np.zeros((count, width), dtype=np.float32)
        for k in range(count):
            chunk = rs_numba.get_slice(k)
            padded[k, : len(chunk)] = chunk
        return cls(padded, sizes)
73
+
74
+
75
@dataclass
class MotifModel:
    """
    Abstract base class for motif models.

    Subclasses must implement :meth:`scan`, which returns per-position
    scores for a batch of encoded sequences in a given strand mode, as well
    as :meth:`from_file`, :attr:`model_type` and :meth:`write`. The rest of
    this class (threshold tables, site extraction, PFM reconstruction,
    persistence) is built on top of :meth:`scan`.

    Attributes
    ----------
    matrix : np.ndarray
        Numeric representation of the motif.
    name : str
        Human readable identifier of the motif.
    length : int
        Length of the motif (number of positions).
    strand_mode : {"best", "+", "-"}
        Default strand mode used when a method is called without an explicit
        ``strand`` argument.
    statistics : dict, optional
        Free-form per-model statistics; not populated by this class.
    """

    # Maps lower-case model type names ("pwm", ...) to subclasses for create_from_file().
    _registry: ClassVar[Dict[str, type[MotifModel]]] = {}

    matrix: np.ndarray
    name: str
    length: int

    strand_mode: StrandMode = field(default="best", init=True)

    statistics: Optional[Dict[str, float]] = field(default=None, init=False)
    # Rows of [score, -log10(FPR)], sorted by descending score.
    _threshold_table: Optional[np.ndarray] = field(default=None, init=False)
    _pfm: Optional[np.ndarray] = field(default=None, init=False)

    _freq_cache: Dict[StrandMode, RaggedData] = field(
        default_factory=dict,
        init=False,
    )

    @classmethod
    def register_subclass(cls, model_type: str, subclass: type[MotifModel]):
        """Register ``subclass`` under ``model_type`` (case-insensitive) for the factory."""
        cls._registry[model_type.lower()] = subclass

    # ------------------------------------------------------------------
    # Thresholds
    # ------------------------------------------------------------------
    @property
    def threshold_table(self) -> Optional[np.ndarray]:
        """
        Cached threshold table, or None (with a logged warning) if not computed.

        Returns
        -------
        Optional[np.ndarray]
            Array of rows ``[score, -log10(FPR)]`` sorted by descending score.
            Use :meth:`get_threshold_table` to compute it.
        """
        if self._threshold_table is None:
            logger = logging.getLogger(__name__)
            logger.warning(f"There is no threshold table for model: {self.name}")
        return self._threshold_table

    def get_threshold_table(self, promoters: RaggedData) -> np.ndarray:
        """Return the threshold table, computing and caching it on first use.

        Parameters
        ----------
        promoters : RaggedData
            Background sequences whose every scored position contributes to
            the empirical score distribution.

        Returns
        -------
        np.ndarray
            Array of rows ``[score, -log10(FPR)]`` sorted by descending score.
        """
        if self._threshold_table is None:
            logger = logging.getLogger(__name__)
            logger.info(f"Computing threshold table for model: {self.name}")
            self._threshold_table = self._calculate_threshold_table(promoters)
        return self._threshold_table

    def _calculate_threshold_table(self, promoters: RaggedData) -> np.ndarray:
        """Build the score -> -log10(FPR) lookup table from all scored positions.

        Parameters
        ----------
        promoters : RaggedData
            Collection of sequences to calculate thresholds against.

        Returns
        -------
        np.ndarray, shape (N, 2)
            Rows ``[score, -log10(fpr)]`` for each distinct score, sorted by
            descending score, where ``fpr`` is the fraction of positions
            scoring ``>= score``. Returns ``[[0.0, 0.0]]`` when there are no
            scored positions.
        """
        ragged_scores = self.get_scores(promoters, strand=self.strand_mode)

        # The ragged .data buffer holds every position's score with no padding.
        flat_scores = ragged_scores.data
        if flat_scores.size == 0:
            # Fallback: table with one "bad" value
            return np.array([[0.0, 0.0]], dtype=np.float64)

        n_total = flat_scores.size

        # np.unique sorts ascending; reverse both outputs for descending scores.
        # Identical scores collapse to one row and share the same (worst-case) FPR.
        unique_scores, counts = np.unique(flat_scores, return_counts=True)
        unique_scores = unique_scores[::-1]
        counts = counts[::-1]

        # Number of positions whose score is >= each unique score.
        cum_counts = np.cumsum(counts)

        # cum_counts >= 1, so FPR > 0 and the logarithm is finite.
        log_fpr_values = -np.log10(cum_counts / n_total)

        return np.column_stack([unique_scores, log_fpr_values]).astype(np.float64)

    def _score_to_frequency(self, score: float) -> float:
        """Return -log10(FPR) for a given score using the threshold table.

        Parameters
        ----------
        score : float
            Motif score to convert.

        Returns
        -------
        float
            -log10(FPR) of the largest table score that does not exceed
            ``score`` (clamped to the table ends), or NaN if no threshold
            table has been computed.
        """
        if self._threshold_table is None:
            return np.nan

        scores_col = self._threshold_table[:, 0]  # descending scores
        logfpr_col = self._threshold_table[:, 1]  # -log10(FPR)

        if score >= scores_col[0]:
            return float(logfpr_col[0])
        if score <= scores_col[-1]:
            return float(logfpr_col[-1])

        # Negate to turn the descending column into an ascending one for searchsorted.
        idx = np.searchsorted(-scores_col, -score, side="left")
        if idx >= len(logfpr_col):
            return float(logfpr_col[-1])
        return float(logfpr_col[idx])

    def _frequency_to_score(self, frequency: float, background_data: Optional[RaggedData] = None) -> float:
        """Convert a false positive rate (FPR) into a score threshold.

        The threshold is the weakest score whose FPR is still ``<= frequency``.
        If no score reaches that FPR (including ``frequency <= 0``), the
        strictest available threshold, the maximum score, is returned.

        If ``background_data`` is provided, the FPR is computed on the fly over
        ALL positions of those sequences (scanned with strand="best") via
        ``scores_to_frequencies``. Otherwise the precomputed threshold table is used.

        Parameters
        ----------
        frequency : float
            Target false positive rate.
        background_data : RaggedData, optional
            Background sequences to calculate the threshold from, instead of
            the precomputed table.

        Returns
        -------
        float
            Score threshold corresponding to the given frequency (0.0 if the
            background yields no scored positions).

        Raises
        ------
        ValueError
            If ``background_data`` is None and no threshold table is computed.
        """
        if background_data is not None:
            all_sites_scores = self.scan(background_data, strand="best")
            site_scores = all_sites_scores.data

            if site_scores.size == 0:
                return 0.0

            # -log10(FPR) for each position, aligned with site_scores.
            log_p_values = scores_to_frequencies(all_sites_scores).data

            target_log_p = -np.log10(frequency) if frequency > 0 else 100.0

            # Weakest score among positions whose FPR <= target (mirrors the table logic).
            mask = log_p_values >= target_log_p
            if not np.any(mask):
                # Even the top-scoring positions are more frequent than requested:
                # use the strictest threshold instead of one that accepts everything.
                return float(site_scores.max())
            return float(site_scores[mask].min())

        # Fallback via precomputed table
        if self._threshold_table is None:
            raise ValueError("Threshold table not computed. Call get_threshold_table() or provide background_data.")

        scores_col = self._threshold_table[:, 0]  # descending scores
        logfpr_col = self._threshold_table[:, 1]  # -log10(FPR), descending

        if frequency <= 0:
            return float(scores_col[0])  # infinitely strict threshold -> maximum score

        target_logfpr = -np.log10(frequency)
        valid = np.flatnonzero(logfpr_col >= target_logfpr)
        if valid.size == 0:
            # Even the maximum score has FPR > frequency: be as strict as possible.
            return float(scores_col[0])

        # Rows are sorted by descending score, so the last valid row is the weakest
        # score that still satisfies the requested FPR.
        return float(scores_col[valid[-1]])

    # ------------------------------------------------------------------
    # Scanning
    # ------------------------------------------------------------------
    def scan(self, sequences: RaggedData, strand: Optional[StrandMode] = None) -> RaggedData:
        """
        Perform batch scanning of sequences for motif matches.

        Parameters
        ----------
        sequences : RaggedData
            Encoded sequences (int8).
        strand : {"+", "-", "best"}, optional
            Strand selection mode:
            - None: use self.strand_mode.
            - "+": forward strand only.
            - "-": reverse complement only.
            - "best": maximum score between strands at each position.

        Returns
        -------
        RaggedData
            Scan results.

        Raises
        ------
        NotImplementedError
            Always, in the base class.
        """
        raise NotImplementedError("MotifModel subclasses must implement scan()")

    @classmethod
    def from_file(cls, path: str, **kwargs) -> "MotifModel":
        """
        Abstract class method to create a motif model from a file.

        Parameters
        ----------
        path : str
            Path to the motif file.
        **kwargs : dict
            Additional arguments for specific model types.

        Returns
        -------
        MotifModel
            A motif model instance.
        """
        raise NotImplementedError("Subclasses must implement from_file()")

    @classmethod
    def create_from_file(cls, path: str, model_type: str, **kwargs) -> "MotifModel":
        """
        Factory method to create a motif model from a file based on model_type.

        Parameters
        ----------
        path : str
            Path to the motif file.
        model_type : str
            Registered model type name (case-insensitive), e.g. 'pwm'.
        **kwargs : dict
            Forwarded to the subclass's ``from_file``.

        Returns
        -------
        MotifModel
            A motif model instance.

        Raises
        ------
        ValueError
            If ``model_type`` has not been registered.
        """
        model_type = model_type.lower()
        if model_type not in cls._registry:
            raise ValueError(f"Unsupported model type: {model_type}. Registered types: {list(cls._registry.keys())}")

        subclass = cls._registry[model_type]
        return subclass.from_file(path, **kwargs)

    def get_sites(
        self,
        sequences: RaggedData,
        mode: str = "best",
        fpr_threshold: Optional[float] = None,
    ) -> pd.DataFrame:
        """Find motif binding sites in sequences.

        The method operates in two modes:
        - "best": the single best-scoring site (over both strands) per sequence;
          sequences without any scored position yield no row
        - "threshold": all sites whose score reaches the threshold derived from
          ``fpr_threshold`` via the precomputed threshold table

        Parameters
        ----------
        sequences : RaggedData
            Encoded sequences to search for motif sites.
        mode : {"best", "threshold"}, optional
            Site finding mode (default "best").
        fpr_threshold : float, optional
            False positive rate threshold for "threshold" mode (e.g., 0.001).
            Required for mode="threshold".

        Returns
        -------
        pd.DataFrame
            One row per site, sorted by seq_index then descending score, with columns:
            - seq_index: sequence index
            - start: start position of site (0-based)
            - end: end position of site (exclusive)
            - strand: DNA strand ("+" or "-")
            - score: recognition function value
            - frequency: -log10(FPR) from the threshold table (NaN if none)
            - site: site as string (ACGT/N), reverse-complemented for "-"
            Sites that would extend past the end of their sequence are dropped.

        Raises
        ------
        ValueError
            If ``mode`` is invalid, ``fpr_threshold`` is missing in "threshold"
            mode, or no threshold table is available in "threshold" mode.
        """
        if mode not in ("best", "threshold"):
            raise ValueError(f"mode must be 'best' or 'threshold', got {mode!r}")
        if mode == "threshold" and fpr_threshold is None:
            raise ValueError("fpr_threshold is required for mode='threshold'")

        logger = logging.getLogger(__name__)

        score_threshold = None
        if mode == "threshold":
            score_threshold = self._frequency_to_score(fpr_threshold)
            logger.info(f"FPR threshold: {fpr_threshold} → Score threshold: {score_threshold:.4f}")

        results = []

        def add_site(seq_idx: int, seq: np.ndarray, pos: int, strand_idx: int, score: float) -> None:
            """Append one site record (strand_idx 0 = '+', 1 = '-'); skip out-of-range sites."""
            end = pos + self.length
            if end > len(seq):
                return

            site_seq = seq[pos:end]
            if strand_idx == 1:
                # Report minus-strand sites in their own 5'->3' orientation.
                site_seq = self._get_rc_sequence(site_seq)

            results.append(
                {
                    "seq_index": seq_idx,
                    "start": int(pos),
                    "end": int(end),
                    "strand": "+" if strand_idx == 0 else "-",
                    "score": score,
                    "frequency": self._score_to_frequency(score),
                    "site": self._int_to_seq(site_seq),
                }
            )

        s_fwd_ragged = self.scan(sequences, strand="+")
        s_rev_ragged = self.scan(sequences, strand="-")

        for seq_idx in range(sequences.num_sequences):
            seq = sequences.get_slice(seq_idx)
            s_fwd = s_fwd_ragged.get_slice(seq_idx)
            s_rev = s_rev_ragged.get_slice(seq_idx)

            if mode == "best":
                if s_fwd.size == 0 and s_rev.size == 0:
                    # No scored positions on either strand (np.argmax would raise).
                    continue
                # Ties go to the forward strand; an empty strand never wins.
                use_forward = s_rev.size == 0 or (s_fwd.size > 0 and s_fwd.max() >= s_rev.max())
                strand_scores = s_fwd if use_forward else s_rev
                best_pos = int(np.argmax(strand_scores))
                add_site(seq_idx, seq, best_pos, 0 if use_forward else 1, float(strand_scores[best_pos]))
            else:  # mode == "threshold"
                for strand_idx, strand_scores in ((0, s_fwd), (1, s_rev)):
                    for pos in np.flatnonzero(strand_scores >= score_threshold):
                        add_site(seq_idx, seq, int(pos), strand_idx, float(strand_scores[pos]))

        df = pd.DataFrame(results)
        if len(df) > 0:
            df = df.sort_values(["seq_index", "score"], ascending=[True, False]).reset_index(drop=True)

        logger.info(f"Found {len(df)} site(s) in {sequences.num_sequences} sequence(s)")
        return df

    @staticmethod
    def _int_to_seq(seq_int: np.ndarray) -> str:
        """Convert integer-encoded sequence to ACGT string.

        Parameters
        ----------
        seq_int : np.ndarray
            Integer-encoded sequence (0=A, 1=C, 2=G, 3=T, 4=N); values outside
            0..4 are clipped into that range.

        Returns
        -------
        str
            Sequence as string.
        """
        decoder = np.array(["A", "C", "G", "T", "N"], dtype="U1")
        safe_seq = np.clip(seq_int, 0, 4)
        return "".join(decoder[safe_seq])

    @staticmethod
    def _get_rc_sequence(seq_int: np.ndarray) -> np.ndarray:
        """Return the reverse complement of an integer-encoded sequence.

        Parameters
        ----------
        seq_int : np.ndarray
            Integer-encoded sequence (0=A, 1=C, 2=G, 3=T, 4=N).

        Returns
        -------
        np.ndarray
            Reverse complement (int8); N maps to N.
        """
        RC_TABLE = np.array([3, 2, 1, 0, 4], dtype=np.int8)
        return RC_TABLE[seq_int[::-1]]

    def write_pfm(self, path: str) -> None:
        """Write the motif's PFM to a tab-separated file with a ``>name`` header.

        Nothing is written if no PFM is cached (the ``pfm`` property logs a warning).

        Parameters
        ----------
        path : str
            Path of the output file.
        """
        if self.pfm is not None:
            with open(path, "w") as fname:
                header = f">{self.name}"
                # One row per motif position, columns A C G T.
                np.savetxt(
                    fname,
                    self.pfm[:4, :].T,
                    fmt="%.8f",
                    delimiter="\t",
                    newline="\n",
                    header=header,
                    footer="",
                    comments="",
                    encoding=None,
                )

    def write_dist(self, path: str) -> None:
        """Write the threshold table of the motif to a DIST formatted file.

        Scores are min-max normalised using the sum of per-column minima and
        maxima of ``self.matrix``; the -log10(FPR) column is written as is.

        Parameters
        ----------
        path : str
            Path of the output file.
        """
        table = self.threshold_table
        if table is None:
            logger = logging.getLogger(__name__)
            logger.error(f"Cannot write DIST file: threshold table not computed for {self.name}")
            return

        # Work on a copy so the cached table keeps raw scores.
        table = table.copy()
        max_score = self.matrix.max(axis=0).sum()
        min_score = self.matrix.min(axis=0).sum()

        # NOTE(review): a constant matrix gives max_score == min_score (division by zero).
        table[:, 0] = (table[:, 0] - min_score) / (max_score - min_score)
        with open(path, "w") as fname:
            np.savetxt(fname, table, fmt="%.18f", delimiter="\t", newline="\n", footer="", comments="", encoding=None)

    @property
    def pfm(self) -> Optional[np.ndarray]:
        """
        Cached position frequency matrix.

        Returns
        -------
        Optional[np.ndarray]
            Cached PFM if available, otherwise None (and a warning is logged).
            Use get_pfm() to compute and cache.
        """
        if self._pfm is None:
            logger = logging.getLogger(__name__)
            logger.warning(f"PFM not computed for model: {self.name}. Use get_pfm(sequences) to compute.")
        return self._pfm

    def get_pfm(
        self,
        sequences: RaggedData,
        mode: str = "best",
        fpr_threshold: Optional[float] = None,
        top_fraction: Optional[float] = None,
        pseudocount: float = 0.25,
        force_recompute: bool = False,
    ) -> np.ndarray:
        """Construct a Position Frequency Matrix (PFM) from binding sites.

        The result is cached in the ``_pfm`` attribute for reuse.

        Parameters
        ----------
        sequences : RaggedData
            Encoded sequences to extract binding sites from.
        mode : {"best", "threshold"}, optional
            Site finding mode passed to :meth:`get_sites` (default "best").
        fpr_threshold : float, optional
            False positive rate threshold for "threshold" mode.
        top_fraction : float, optional
            Keep only this fraction of sites, highest scores first (at least one).
        pseudocount : float, optional
            Pseudocount added to every cell before normalisation (default 0.25).
        force_recompute : bool, optional
            If True, recomputes the PFM even if it is already cached.

        Returns
        -------
        np.ndarray
            Column-normalised float32 PFM of shape (4, motif_length); 'N'
            characters in sites are not counted.

        Raises
        ------
        ValueError
            If no sites are found.
        """
        logger = logging.getLogger(__name__)

        if self._pfm is not None and not force_recompute:
            logger.info(f"Returning cached PFM for model: {self.name}")
            return self._pfm

        logger.info(f"Computing PFM for model: {self.name}")

        sites_df = self.get_sites(sequences, mode=mode, fpr_threshold=fpr_threshold)
        if len(sites_df) == 0:
            raise ValueError("No sites found")

        # Site order does not affect the counts; nlargest handles the top-N selection.
        if top_fraction is not None:
            n_keep = max(1, int(len(sites_df) * top_fraction))
            sites_df = sites_df.nlargest(n_keep, "score")
            logger.info(f"Selected top {top_fraction * 100:.1f}%: {n_keep} sites")

        pfm = np.full((4, self.length), pseudocount, dtype=np.float32)
        nuc_map = {"A": 0, "C": 1, "G": 2, "T": 3}

        for site_str in sites_df["site"]:
            for pos, nuc in enumerate(site_str):
                if nuc in nuc_map:
                    pfm[nuc_map[nuc], pos] += 1.0

        # Normalise each column to probabilities (pseudocount > 0 keeps sums non-zero).
        pfm = pfm / pfm.sum(axis=0, keepdims=True)

        self._pfm = pfm
        return pfm

    @staticmethod
    def _reduce_strand(seq_scores: np.ndarray, strand: StrandMode) -> np.ndarray:
        """Transform (2, N) -> (N,) depending on strand mode.

        Parameters
        ----------
        seq_scores : np.ndarray
            Scores with shape (2, N_positions); row 0 forward, row 1 reverse.
        strand : {"best", "+", "-"}

        Returns
        -------
        np.ndarray
            One-dimensional scores per position (N_positions,).

        Raises
        ------
        ValueError
            For any other strand value.
        """
        if strand == "best":
            # maximum across both strands for each position
            return np.max(seq_scores, axis=0)
        if strand == "+":
            return seq_scores[0]
        if strand == "-":
            return seq_scores[1]
        raise ValueError(f"Unknown strand={strand!r}. Use '+', '-', or 'best'.")

    def get_scores(self, sequences: RaggedData, strand: Optional[StrandMode] = None) -> RaggedData:
        """Return per-position motif scores; thin wrapper around :meth:`scan`.

        Parameters
        ----------
        sequences : RaggedData
            Encoded sequences (int8).
        strand : {"best", "+", "-"}, optional
            Strand to score. Default is self.strand_mode.

        Returns
        -------
        RaggedData
            Ragged array of scores.
        """
        return self.scan(sequences, strand=strand)

    def get_frequencies(self, sequences: RaggedData, strand: Optional[StrandMode] = None) -> RaggedData:
        """Convert per-position scores to frequencies with ``scores_to_frequencies``.

        Parameters
        ----------
        sequences : RaggedData
            Encoded nucleotide sequences.
        strand : {"best", "+", "-"}, optional
            Strand to evaluate. Default is self.strand_mode.

        Returns
        -------
        RaggedData
            Per-position frequencies.
        """
        return scores_to_frequencies(self.scan(sequences, strand))

    def clear_cache(self) -> None:
        """Drop the frequency cache, threshold table and (unless required) the PFM."""
        self._freq_cache.clear()
        self._threshold_table = None
        # Subclasses whose PFM defines the model set _pfm_is_required to keep it.
        if not getattr(self, "_pfm_is_required", False):
            self._pfm = None

    def save(self, filepath: str, clear_cache: bool = True) -> None:
        """Save the motif model to a file using joblib.

        Parameters
        ----------
        filepath : str
            Full path to the output file (e.g., 'motifs/M1.pkl'); parent
            directories are created if needed.
        clear_cache : bool
            If True, calls :meth:`clear_cache` first (this also discards the
            threshold table) to minimise file size.
        """
        if clear_cache:
            self.clear_cache()

        os.makedirs(os.path.dirname(os.path.abspath(filepath)), exist_ok=True)
        joblib.dump(self, filepath)

    @staticmethod
    def load(filepath: str) -> MotifModel:
        """Load a motif model saved with :meth:`save`.

        joblib uses pickle, which can execute arbitrary code while loading:
        only load files from trusted sources.

        Parameters
        ----------
        filepath : str
            Path to the .pkl file.

        Returns
        -------
        MotifModel
            Loaded motif model instance.

        Raises
        ------
        FileNotFoundError
            If ``filepath`` does not exist.
        """
        if not os.path.exists(filepath):
            raise FileNotFoundError(f"Motif file not found: {filepath}")

        return joblib.load(filepath)

    @property
    def model_type(self) -> str:
        """Abstract property to return the type of model."""
        raise NotImplementedError("Subclasses must implement model_type property")

    def write(self, path: str) -> None:
        """Abstract method to write the motif to a file in its native format."""
        raise NotImplementedError("Subclasses must implement write() method")
768
+
769
+
770
class PwmMotif(MotifModel):
    """Position Weight Matrix motif model.

    Wraps a PWM (log-odds scores) and scores sequences with the precompiled
    :func:`yamcot.functions.batch_all_scores` Numba kernel. The matrix is
    expected to have shape (5, L), where the fifth row holds the column
    minima (see :meth:`from_file`).
    NOTE(review): presumably the kernel uses that fifth row for the ambiguous
    nucleotide code 4 ('N'); confirm in ``batch_all_scores``.

    Parameters
    ----------
    matrix : np.ndarray
        PWM matrix with shape (5, L) where the 5th row contains column minima.
    name : str
        Name of the motif.
    length : int
        Length of the motif.
    pfm : np.ndarray
        Position Frequency Matrix of shape (4, L). Required attribute for PWM;
        it survives :meth:`clear_cache`.
    kmer : int, optional
        Size of k-mer for scanning (default is 1).
    """

    def __init__(
        self,
        matrix: np.ndarray,
        name: str,
        length: int,
        pfm: np.ndarray,
        kmer: int = 1,
    ) -> None:
        super().__init__(
            matrix=matrix,
            name=name,
            length=length,
        )
        self.kmer = kmer
        # Assign to the private attribute so the ``pfm`` property returns it.
        self._pfm = pfm
        # The PFM defines this model, so clear_cache() must not discard it.
        self._pfm_is_required = True

    def scan(self, sequences: RaggedData, strand: Optional[StrandMode] = None) -> RaggedData:
        """Score sequences with the PWM.

        Parameters
        ----------
        sequences : RaggedData
            Encoded nucleotide sequences (int8).
        strand : {"best", "+", "-"}, optional
            Strand selection mode. If None, uses self.strand_mode.

        Returns
        -------
        RaggedData
            Scanning results with scores for each position. For "best" the
            element-wise maximum of both strands is returned.

        Raises
        ------
        ValueError
            If ``strand`` is not one of "+", "-" or "best".
        """
        strand = strand or self.strand_mode
        matrix = self.matrix.astype(np.float32)

        if strand == "+":
            return batch_all_scores(sequences, matrix, kmer=self.kmer, is_revcomp=False)
        if strand == "-":
            return batch_all_scores(sequences, matrix, kmer=self.kmer, is_revcomp=True)
        if strand == "best":
            sf = batch_all_scores(sequences, matrix, kmer=self.kmer, is_revcomp=False)
            sr = batch_all_scores(sequences, matrix, kmer=self.kmer, is_revcomp=True)
            # Both strands share the same ragged layout, so offsets are reused.
            return RaggedData(np.maximum(sf.data, sr.data), sf.offsets)

        logger = logging.getLogger(__name__)
        logger.error(f"Unknown strand mode: {strand}")
        # Raise instead of sys.exit(): a library must not terminate the interpreter.
        raise ValueError(f"Unknown strand mode: {strand!r}. Use '+', '-', or 'best'.")

    @classmethod
    def from_file(cls, path: str, index: int = 0, **kwargs) -> PwmMotif:
        """Create a PwmMotif from a file.

        The format is chosen by file extension (case-insensitive):
        ``.meme`` (MEME), ``.pfm`` (PFM) or ``.pkl`` (a model pickled with joblib).

        Parameters
        ----------
        path : str
            Path to the motif file.
        index : int, optional
            Index of the motif to read from a MEME file with several motifs (default 0).
        **kwargs : dict
            Additional arguments (ignored).

        Returns
        -------
        PwmMotif
            A PwmMotif object created from the file.

        Raises
        ------
        ValueError
            If the file extension is not supported.
        """
        _, ext = os.path.splitext(path.lower())

        if ext == ".pkl":
            # joblib/pickle can execute arbitrary code on load: trusted files only.
            return joblib.load(path)
        if ext == ".meme":
            matrix, info, _number_of_motifs = read_meme(path, index=index)
            pfm = matrix
            name, length = info
        elif ext == ".pfm":
            pwm, length, _minimum, _maximum = read_pfm(path)
            # NOTE(review): the first 4 rows of read_pfm's matrix are used as the PFM
            # (the 5th row is excluded); confirm read_pfm returns frequencies there.
            pfm = pwm[:4, :]
            name = os.path.splitext(os.path.basename(path))[0]
        else:
            logger = logging.getLogger(__name__)
            logger.error(f"Wrong format of PWM model: {path}")
            raise ValueError(f"Unsupported PWM file format {ext!r}: {path}")

        pwm = pfm_to_pwm(pfm)
        # Append the 5th row of per-position minima (used for 'N' characters).
        pwm_ext = np.concatenate((pwm, np.min(pwm, axis=0, keepdims=True)), axis=0)
        return cls(matrix=pwm_ext, name=name, length=int(length), pfm=pfm)

    @property
    def model_type(self) -> str:
        """Return the type of model ('pwm')."""
        return "pwm"

    def write(self, path: str) -> None:
        """Write the PWM motif to ``path`` in PFM format, creating parent directories."""
        os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True)
        self.write_pfm(path)
900
+
901
+
902
+ class BammMotif(MotifModel):
903
+ """Bayesian Markov Model motif.
904
+
905
+ BaMMs extend PWMs by modelling dependencies between neighbouring
906
+ nucleotides. The representation and scanning procedure differ
907
+ substantially from PWMs and require specialised software. This
908
+ class provides an interface compatible with the :class:`MotifModel`
909
+ API but does not implement the scoring logic. Users must supply
910
+ an external scoring function via :meth:`scan`.
911
+
912
+ Parameters
913
+ ----------
914
+ matrix : np.ndarray
915
+ BaMM matrix representation.
916
+ name : str
917
+ Name of the motif.
918
+ length : int
919
+ Length of the motif.
920
+ kmer : int, optional
921
+ Size of k-mer for scanning (default is 2).
922
+ """
923
+
924
+ def __init__(
925
+ self,
926
+ matrix: np.ndarray,
927
+ name: str,
928
+ length: int,
929
+ kmer: int = 2,
930
+ ) -> None:
931
+ super().__init__(
932
+ matrix=matrix,
933
+ name=name,
934
+ length=length,
935
+ )
936
+ self.kmer = kmer
937
+
938
def scan(self, sequences: RaggedData, strand: Optional[StrandMode] = None) -> RaggedData:
    """Score sequences with the BaMM.

    Parameters
    ----------
    sequences : RaggedData
        Encoded nucleotide sequences (int8).
    strand : {"best", "+", "-"}, optional
        Strand selection mode. If None (or otherwise falsy),
        ``self.strand_mode`` is used.

        - ``"+"`` scores with ``is_revcomp=False``.
        - ``"-"`` scores with ``is_revcomp=True``.
        - ``"best"`` returns the element-wise maximum of both passes.
        - ``"both"`` is not supported by this model.

    Returns
    -------
    RaggedData
        Scanning results with scores for each position.

    Raises
    ------
    ValueError
        If the resolved strand mode is not ``"+"``, ``"-"`` or ``"best"``.
    """
    strand = strand or self.strand_mode
    if strand not in ("+", "-", "best"):
        # A library must not terminate the interpreter (the previous
        # behaviour was sys.exit(1)); report the bad argument to the caller.
        logging.getLogger(__name__).error(f"Unknown strand mode: {strand}")
        raise ValueError(f"Unknown strand mode: {strand!r}; expected '+', '-' or 'best'")

    matrix = self.matrix.astype(np.float32)
    kmer = self.kmer

    def _score(is_revcomp):
        # Single BaMM scoring pass; with_context=True is required for this model.
        return batch_all_scores(sequences, matrix, kmer=kmer, is_revcomp=is_revcomp, with_context=True)

    if strand == "+":
        return _score(False)
    if strand == "-":
        return _score(True)

    forward = _score(False)
    reverse = _score(True)
    # Assumes both passes produce identical offsets (same input sequences).
    return RaggedData(np.maximum(forward.data, reverse.data), forward.offsets)
969
+
970
@classmethod
def from_file(cls, path: str, bg_path: str | None = None, order: int = 2, **kwargs) -> "BammMotif":
    """Create a BammMotif from a BaMM motif file and its background model.

    Parameters
    ----------
    path : str
        Path to the BaMM motif file (``.ihbcp``), or a base path without
        extension. In the latter case, when ``path`` itself does not exist,
        ``<path>.ihbcp`` is used if present.
    bg_path : str, optional
        Path to the background BaMM file (``.hbcp``). If omitted, the
        following are tried in order:

        - ``<path>.hbcp`` (base-path form only);
        - ``bamm.hbcp`` in the motif file's directory;
        - ``background.hbcp`` in the motif file's directory;
        - ``<stem>.hbcp`` in the motif file's directory, when the motif file
          ends in ``.ihbcp``.
    order : int, optional
        Requested model order (default 2). It is clamped to the maximum order
        reported by ``parse_file_content`` for the motif file.
    **kwargs : dict
        Accepted for interface compatibility; not used.

    Returns
    -------
    BammMotif
        Motif named after the file stem, with ``kmer = order + 1``.

    Raises
    ------
    ValueError
        If no background file was given and none could be found.
    """
    suffix = ".ihbcp"

    # Handle case where path is provided without extension (as in pipeline.py).
    if not path.endswith(suffix) and not os.path.exists(path):
        ihbcp_path = f"{path}.ihbcp"
        hbcp_path = f"{path}.hbcp"
        if os.path.exists(ihbcp_path):
            path = ihbcp_path
        if bg_path is None and os.path.exists(hbcp_path):
            bg_path = hbcp_path

    # If no background path is provided, look next to the motif file.
    if bg_path is None:
        dir_path = os.path.dirname(path)
        basename = os.path.basename(path)
        candidates = ["bamm.hbcp", "background.hbcp"]
        # Derive a sibling name only from a real .ihbcp suffix. Otherwise the
        # "replacement" is the motif file's own name, and an existing motif
        # file would be picked up as its own background.
        if basename.endswith(suffix):
            candidates.append(basename[: -len(suffix)] + ".hbcp")
        for bg_name in candidates:
            candidate = os.path.join(dir_path, bg_name)
            if os.path.exists(candidate):
                bg_path = candidate
                break

    if bg_path is None:
        raise ValueError(f"Background file not found for {path}. Please provide bg_path parameter.")

    _, max_order, length = parse_file_content(path)
    order = min(order, max_order)

    # Read the BaMM motif and background files.
    matrix = read_bamm(path, bg_path, order)
    name = os.path.splitext(os.path.basename(path))[0]
    return cls(
        matrix=matrix,
        name=name,
        length=length,
        kmer=order + 1,  # BaMM kmer is typically order + 1
    )
1029
+
1030
@property
def model_type(self) -> str:
    """Model family identifier for BaMM motifs: the constant ``"bamm"``."""
    kind = "bamm"
    return kind
1034
+
1035
def write(self, path: str) -> None:
    """Persist the motif to *path*, creating parent directories if needed.

    There is no native BaMM writer yet; serialisation is delegated to
    ``self.save``.
    """
    # abspath guarantees a non-empty dirname even for bare file names.
    parent_dir = os.path.dirname(os.path.abspath(path))
    os.makedirs(parent_dir, exist_ok=True)
    self.save(path)
1041
+
1042
+
1043
class SitegaMotif(MotifModel):
    """SiteGA motif model.

    SiteGA motifs are based on a linear probabilistic grammar and
    represent dinucleotide dependencies across multiple positions.
    Sequences are scored by :meth:`scan`, which delegates to
    ``batch_all_scores`` with the model's ``kmer`` setting.

    Parameters
    ----------
    matrix : np.ndarray
        SiteGA matrix representation.
    name : str
        Name of the motif.
    length : int
        Length of the motif.
    minimum : float, optional
        Minimum value as returned by ``read_sitega`` (default 0.0);
        presumably the lower score bound — confirm against ``read_sitega``.
    maximum : float, optional
        Maximum value as returned by ``read_sitega`` (default 0.0);
        presumably the upper score bound — confirm against ``read_sitega``.
    kmer : int, optional
        Size of k-mer for scanning (default is 2).
    """

    def __init__(
        self,
        matrix: np.ndarray,
        name: str,
        length: int,
        minimum: float = 0.0,
        maximum: float = 0.0,
        kmer: int = 2,
    ) -> None:
        super().__init__(
            matrix=matrix,
            name=name,
            length=length,
        )
        self.minimum = minimum
        self.maximum = maximum
        self.kmer = kmer

    def scan(self, sequences: RaggedData, strand: Optional[StrandMode] = None) -> RaggedData:
        """Score sequences with the SiteGA model.

        Parameters
        ----------
        sequences : RaggedData
            Encoded nucleotide sequences (int8).
        strand : {"best", "+", "-"}, optional
            Strand selection mode. If None (or otherwise falsy),
            ``self.strand_mode`` is used.

            - ``"+"`` scores with ``is_revcomp=False``.
            - ``"-"`` scores with ``is_revcomp=True``.
            - ``"best"`` returns the element-wise maximum of both passes.
            - ``"both"`` is not supported by this model.

        Returns
        -------
        RaggedData
            Scanning results with scores for each position.

        Raises
        ------
        ValueError
            If the resolved strand mode is not ``"+"``, ``"-"`` or ``"best"``.
        """
        strand = strand or self.strand_mode
        if strand not in ("+", "-", "best"):
            # A library must not terminate the interpreter (the previous
            # behaviour was sys.exit(1)); report the bad argument instead.
            logging.getLogger(__name__).error(f"Unknown strand mode: {strand}")
            raise ValueError(f"Unknown strand mode: {strand!r}; expected '+', '-' or 'best'")

        matrix = self.matrix.astype(np.float32)
        kmer = self.kmer

        def _score(is_revcomp):
            # Single SiteGA scoring pass (no with_context flag, as before).
            return batch_all_scores(sequences, matrix, kmer=kmer, is_revcomp=is_revcomp)

        if strand == "+":
            return _score(False)
        if strand == "-":
            return _score(True)

        forward = _score(False)
        reverse = _score(True)
        # Assumes both passes produce identical offsets (same input sequences).
        return RaggedData(np.maximum(forward.data, reverse.data), forward.offsets)

    @classmethod
    def from_file(cls, path: str, **kwargs) -> SitegaMotif:
        """Parse a SiteGA output file into a SitegaMotif.

        Parameters
        ----------
        path : str
            Path to the SiteGA output file (typically ends with '.mat').
        **kwargs : dict
            Accepted for interface compatibility; not used.

        Returns
        -------
        SitegaMotif
            Motif named after the file stem, with ``kmer=2``.
        """
        matrix, length, minimum, maximum = read_sitega(path)

        # Motif name is the file name without directory or extension.
        name = os.path.splitext(os.path.basename(path))[0]

        return cls(
            matrix=matrix,
            name=name,
            length=length,
            minimum=minimum,
            maximum=maximum,
            kmer=2,  # SiteGA typically uses dinucleotide dependencies
        )

    @property
    def model_type(self) -> str:
        """Return the type of model ('sitega')."""
        return "sitega"

    def write(self, path: str) -> None:
        """Write the SiteGA motif to *path* via ``write_sitega`` (.mat format).

        Parent directories are created if they do not exist.
        """
        os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True)
        write_sitega(self, path)
1156
+
1157
+
1158
# Register each concrete model under its type key so MotifModel can load
# instances polymorphically.
for _model_key, _model_cls in (
    ("pwm", PwmMotif),
    ("bamm", BammMotif),
    ("sitega", SitegaMotif),
):
    MotifModel.register_subclass(_model_key, _model_cls)
# Keep the module namespace identical to the unrolled form.
del _model_key, _model_cls