pyseqalignment 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyseqalign/__init__.py +14 -0
- pyseqalign/core/__init__.py +12 -0
- pyseqalign/core/alignment.py +67 -0
- pyseqalign/core/needleman_wunsch.py +122 -0
- pyseqalign/core/smith_waterman.py +173 -0
- pyseqalign/learning/__init__.py +20 -0
- pyseqalign/learning/aleph.py +212 -0
- pyseqalign/learning/aleph_files/__init__.py +0 -0
- pyseqalign/learning/aleph_files/aleph_swi_ak.pl +10420 -0
- pyseqalign/learning/base.py +68 -0
- pyseqalign/learning/popper.py +215 -0
- pyseqalign/learning/task_builder.py +213 -0
- pyseqalign/prolog/__init__.py +5 -0
- pyseqalign/prolog/engine.py +102 -0
- pyseqalign/prolog/knowledge/__init__.py +0 -0
- pyseqalign/prolog/knowledge/amino_acids.pl +53 -0
- pyseqalign/prolog/knowledge/blosum50.pl +800 -0
- pyseqalign/prolog/knowledge/defaults.pl +15 -0
- pyseqalign/prolog/knowledge/distances.pl +119 -0
- pyseqalign/scoring/__init__.py +11 -0
- pyseqalign/scoring/distance.py +100 -0
- pyseqalign/scoring/matrices.py +362 -0
- pyseqalign/scoring/matrix_data/BLOSUM100 +31 -0
- pyseqalign/scoring/matrix_data/BLOSUM50 +31 -0
- pyseqalign/scoring/matrix_data/BLOSUM60 +31 -0
- pyseqalign/scoring/matrix_data/BLOSUM62 +31 -0
- pyseqalign/scoring/matrix_data/BLOSUM70 +31 -0
- pyseqalign/scoring/matrix_data/BLOSUM80 +31 -0
- pyseqalign/scoring/matrix_data/BLOSUM90 +31 -0
- pyseqalign/scoring/matrix_data/PAM150 +34 -0
- pyseqalign/scoring/matrix_data/PAM200 +34 -0
- pyseqalign/scoring/matrix_data/PAM250 +34 -0
- pyseqalign/scoring/matrix_data/PAM50 +34 -0
- pyseqalign/scoring/matrix_data/__init__.py +0 -0
- pyseqalign/utils/__init__.py +9 -0
- pyseqalign/utils/helpers.py +47 -0
- pyseqalignment-0.1.0.dist-info/METADATA +317 -0
- pyseqalignment-0.1.0.dist-info/RECORD +41 -0
- pyseqalignment-0.1.0.dist-info/WHEEL +5 -0
- pyseqalignment-0.1.0.dist-info/licenses/LICENSE +21 -0
- pyseqalignment-0.1.0.dist-info/top_level.txt +1 -0
pyseqalign/__init__.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
"""pySeqAlign -- Sequence alignment with Prolog-style distance functions and ILP learning."""
|
|
2
|
+
|
|
3
|
+
from pyseqalign.core.alignment import AlignmentResult, LocalAlignmentResult
|
|
4
|
+
from pyseqalign.core.needleman_wunsch import NeedlemanWunsch
|
|
5
|
+
from pyseqalign.core.smith_waterman import SmithWaterman
|
|
6
|
+
|
|
7
|
+
# Package version string; keep in sync with the wheel metadata.
__version__ = "0.1.0"

# Names re-exported by ``from pyseqalign import *``.
__all__ = ["SmithWaterman", "NeedlemanWunsch", "AlignmentResult", "LocalAlignmentResult"]
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
"""Core alignment algorithms."""
|
|
2
|
+
|
|
3
|
+
from pyseqalign.core.alignment import AlignmentResult, LocalAlignmentResult
|
|
4
|
+
from pyseqalign.core.needleman_wunsch import NeedlemanWunsch
|
|
5
|
+
from pyseqalign.core.smith_waterman import SmithWaterman
|
|
6
|
+
|
|
7
|
+
# Public names of the ``pyseqalign.core`` subpackage.
__all__ = ["SmithWaterman", "NeedlemanWunsch", "AlignmentResult", "LocalAlignmentResult"]
|
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
"""Data structures for alignment results."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass, field
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@dataclass
class AlignmentResult:
    """Outcome of a global (Needleman-Wunsch) alignment.

    Attributes:
        query: The query sequence after alignment; ``0`` marks a gap.
        target: The target sequence after alignment; ``0`` marks a gap.
        score: Total alignment score.
        length: Number of columns in the alignment.
    """

    # Aligned sequences (gap positions hold 0).
    query: list[int]
    target: list[int]
    # Summary values for the alignment as a whole.
    score: float
    length: int
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@dataclass
class LocalAlignmentResult:
    """One local (Smith-Waterman) alignment.

    Attributes:
        query_path: Query-side indices of the cells on the alignment path,
            ordered from start to end.
        target_path: Target-side indices of the same cells, in the same order.
        start_query: Query position where the alignment starts.
        start_target: Target position where the alignment starts.
        end_query: Query position where the alignment ends.
        end_target: Target position where the alignment ends.
        length: Number of cells on the alignment path.
        score: Score assigned to this alignment.
    """

    # Path through the alignment, one entry per cell.
    query_path: list[int]
    target_path: list[int]
    # Endpoints of the alignment.
    start_query: int
    start_target: int
    end_query: int
    end_target: int
    # Summary values.
    length: int
    score: float
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
@dataclass
class KLocalAlignmentResults:
    """Sequence-like container of k non-overlapping local alignments.

    Supports ``len()``, indexing and iteration, all delegated to the
    underlying ``alignments`` list.

    Attributes:
        alignments: The local alignment results, sorted by score descending.
    """

    # Each instance gets its own fresh list (no shared mutable default).
    alignments: list[LocalAlignmentResult] = field(default_factory=list)

    def __len__(self) -> int:
        """Return how many alignments are stored."""
        count = len(self.alignments)
        return count

    def __getitem__(self, index: int) -> LocalAlignmentResult:
        """Return the alignment at *index* (delegates to the list)."""
        items = self.alignments
        return items[index]

    def __iter__(self):
        """Iterate over the stored alignments in order."""
        items = self.alignments
        return iter(items)
|
|
@@ -0,0 +1,122 @@
|
|
|
1
|
+
"""Needleman-Wunsch global sequence alignment.
|
|
2
|
+
|
|
3
|
+
Translated from the legacy C implementation in pyAlign.c.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
from pyseqalign.core.alignment import AlignmentResult
|
|
9
|
+
from pyseqalign.core.smith_waterman import ScoringFunction
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class NeedlemanWunsch:
    """Needleman-Wunsch global alignment.

    Args:
        scoring: A scoring function implementing the ``ScoringFunction`` protocol.
        gap_penalty: Score contribution of a single gap position.

            When omitted (``None``), gaps are scored by the scoring function
            itself, which is called with element ID ``0`` to represent a gap
            character. ``scoring.score(0, 0)`` is used for the
            leading/trailing border gaps.

            When given, this value is used for *every* gap, border and
            interior alike. The value is *added* to the running score (unlike
            ``SmithWaterman``, which subtracts a positive penalty). It should
            therefore normally be negative when ``scoring`` returns
            similarities.
    """

    def __init__(self, scoring: ScoringFunction, gap_penalty: float | None = None) -> None:
        self.scoring = scoring
        self._explicit_gap_penalty = gap_penalty

    @property
    def gap_penalty(self) -> float:
        """Return the per-position border gap score.

        This is the explicit ``gap_penalty`` when one was given, otherwise
        ``scoring.score(0, 0)``.
        """
        if self._explicit_gap_penalty is not None:
            return self._explicit_gap_penalty
        return self.scoring.score(0, 0)

    def align(self, seq1: list[int], seq2: list[int]) -> AlignmentResult:
        """Compute the optimal global alignment of *seq1* and *seq2*.

        Args:
            seq1: First input sequence (list of integer element IDs).
            seq2: Second input sequence.

        Returns:
            An ``AlignmentResult`` with aligned sequences (gaps encoded as
            ``0``), score, and length.
        """
        rows = len(seq1) + 1
        cols = len(seq2) + 1

        score_fn = self.scoring.score
        gap = self.gap_penalty
        explicit_gap = self._explicit_gap_penalty

        # F-matrix of best prefix scores; traceback codes:
        # 0 = diagonal (match/mismatch), 1 = up (gap in seq2), 2 = left (gap in seq1).
        f_matrix = [[0.0] * cols for _ in range(rows)]
        tb_matrix = [[-1] * cols for _ in range(rows)]

        # Fill border gaps.
        for i in range(1, rows):
            f_matrix[i][0] = gap * i
        for j in range(1, cols):
            f_matrix[0][j] = gap * j

        # Fill matrices.
        for i in range(1, rows):
            a = seq1[i - 1]
            for j in range(1, cols):
                b = seq2[j - 1]
                match = f_matrix[i - 1][j - 1] + score_fn(a, b)
                if explicit_gap is None:
                    delete = f_matrix[i - 1][j] + score_fn(a, 0)
                    insert = f_matrix[i][j - 1] + score_fn(0, b)
                else:
                    # An explicit penalty must govern interior gaps as well as
                    # the border; previously only the border honoured it.
                    delete = f_matrix[i - 1][j] + explicit_gap
                    insert = f_matrix[i][j - 1] + explicit_gap

                choices = [match, delete, insert]
                best = _argmax(choices)  # ties prefer match, then delete

                f_matrix[i][j] = choices[best]
                tb_matrix[i][j] = best

        score = f_matrix[rows - 1][cols - 1]

        # Traceback from the bottom-right corner.
        align1: list[int] = []
        align2: list[int] = []
        i = rows - 1
        j = cols - 1

        while i > 0 and j > 0:
            move = tb_matrix[i][j]
            if move == 0:
                # Diagonal -- match/mismatch.
                i -= 1
                j -= 1
                align1.append(seq1[i])
                align2.append(seq2[j])
            elif move == 1:
                # Up -- gap in seq2.
                i -= 1
                align1.append(seq1[i])
                align2.append(0)
            else:
                # Left -- gap in seq1.
                j -= 1
                align1.append(0)
                align2.append(seq2[j])

        # Consume whatever remains of either sequence as border gaps.
        while i > 0:
            i -= 1
            align1.append(seq1[i])
            align2.append(0)

        while j > 0:
            j -= 1
            align1.append(0)
            align2.append(seq2[j])

        align1.reverse()
        align2.reverse()

        return AlignmentResult(
            query=align1,
            target=align2,
            score=score,
            length=len(align1),
        )
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def _argmax(values: list[float]) -> int:
|
|
118
|
+
best = 0
|
|
119
|
+
for i in range(1, len(values)):
|
|
120
|
+
if values[i] > values[best]:
|
|
121
|
+
best = i
|
|
122
|
+
return best
|
|
@@ -0,0 +1,173 @@
|
|
|
1
|
+
"""Smith-Waterman local sequence alignment.
|
|
2
|
+
|
|
3
|
+
Translated from the legacy C implementation in swAlign.c.
|
|
4
|
+
Computes the k best non-overlapping local alignments between two sequences.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
from typing import Protocol
|
|
10
|
+
|
|
11
|
+
from pyseqalign.core.alignment import KLocalAlignmentResults, LocalAlignmentResult
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class ScoringFunction(Protocol):
    """Structural interface for the pairwise scorers the aligners consume.

    Any object exposing a compatible ``score`` method satisfies it; no
    inheritance is required.
    """

    def score(self, a: int, b: int) -> float:
        """Return how similar element *a* is to element *b*.

        ``NeedlemanWunsch`` passes element ID ``0`` to denote a gap.
        """
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class SmithWaterman:
    """Smith-Waterman local alignment.

    Args:
        scoring: A scoring function implementing the ``ScoringFunction`` protocol.
        gap_penalty: Cost applied when introducing a gap (should be positive;
            it is subtracted internally).
    """

    def __init__(self, scoring: ScoringFunction, gap_penalty: float = 8.0) -> None:
        self.scoring = scoring
        self.gap_penalty = gap_penalty

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def align(
        self,
        seq1: list[int],
        seq2: list[int],
        k: int = 1,
        cutoff: float = 0.0,
        min_score: float = 2.0,
    ) -> KLocalAlignmentResults:
        """Compute up to *k* best non-overlapping local alignments.

        Args:
            seq1: First input sequence (list of integer element IDs).
            seq2: Second input sequence.
            k: Maximum number of non-overlapping alignments to return; values
                ``<= 0`` yield an empty result.
            cutoff: Floor for every F-matrix cell (default 0, i.e. classic
                Smith-Waterman). A cell where the floor wins (including ties)
                gets no traceback pointer, so alignments start there.
            min_score: Cells whose F-matrix value exceeds this are used as
                alignment end points (the traceback starts from them).

        Returns:
            A ``KLocalAlignmentResults`` containing up to *k* alignments sorted
            by score descending. Two alignments overlap if they share any
            matrix cell; only the higher-scoring one is kept.
        """
        if k == 0:
            return KLocalAlignmentResults()

        rows = len(seq1) + 1
        cols = len(seq2) + 1
        score_fn = self.scoring.score
        d = self.gap_penalty
        no_pred = (-10, -10)

        # F-matrix plus back-pointers. back_ptr[i][j] is the predecessor cell
        # on the optimal path, or the (-10, -10) sentinel for border cells and
        # cells where ``cutoff`` won.
        f_matrix = [[0.0] * cols for _ in range(rows)]
        back_ptr = [[no_pred] * cols for _ in range(rows)]

        # Fill the matrices and collect high-scoring cells.
        max_traces: list[tuple[int, int]] = []

        for i in range(1, rows):
            a = seq1[i - 1]
            prev_row = f_matrix[i - 1]
            row = f_matrix[i]
            for j in range(1, cols):
                match = prev_row[j - 1] + score_fn(a, seq2[j - 1])
                delete = prev_row[j] - d
                insert = row[j - 1] - d

                choices = [cutoff, match, delete, insert]
                best_idx = _argmax(choices)
                best = choices[best_idx]
                row[j] = best

                if best_idx == 1:
                    back_ptr[i][j] = (i - 1, j - 1)
                elif best_idx == 2:
                    back_ptr[i][j] = (i - 1, j)
                elif best_idx == 3:
                    back_ptr[i][j] = (i, j - 1)

                if best > min_score:
                    max_traces.append((i, j))

        # Generate all candidate traces (sorted by score descending).
        candidates = self._generate_traces(f_matrix, back_ptr, max_traces, rows, cols)

        # Greedily select up to k non-overlapping alignments. A single running
        # set of claimed cells replaces re-checking every selected alignment
        # per candidate; the outcome is the same (disjoint from each selected
        # alignment <=> disjoint from their union).
        selected: list[LocalAlignmentResult] = []
        used_cells: set[tuple[int, int]] = set()
        for candidate in candidates:
            if len(selected) >= k:
                break
            cells = set(zip(candidate.query_path, candidate.target_path))
            if used_cells.isdisjoint(cells):
                selected.append(candidate)
                used_cells |= cells

        return KLocalAlignmentResults(alignments=selected)

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _generate_traces(
        f_matrix: list[list[float]],
        traceback: list[list[tuple[int, int]]],
        max_traces: list[tuple[int, int]],
        rows: int,
        cols: int,
    ) -> list[LocalAlignmentResult]:
        """Trace back from each high-scoring cell to produce alignment candidates.

        ``rows`` and ``cols`` are accepted for interface compatibility but
        are not used.
        """
        results: list[LocalAlignmentResult] = []

        for end_i, end_j in max_traces:
            path_a: list[int] = []
            path_b: list[int] = []
            # NOTE(review): this accumulates the F-matrix value of every cell
            # on the path rather than using the end cell's value
            # (f_matrix[end_i][end_j], the usual Smith-Waterman score). Kept
            # as-is because candidate ranking depends on it. Confirm against
            # the legacy swAlign.c.
            score = 0.0

            ci, cj = end_i, end_j
            while traceback[ci][cj] != (-10, -10):
                path_a.append(ci)
                path_b.append(cj)
                score += f_matrix[ci][cj]
                ci, cj = traceback[ci][cj]

            # The cell with no predecessor is included in the path but
            # contributes nothing to the score.
            path_a.append(ci)
            path_b.append(cj)

            # Reverse to get start-to-end order.
            path_a.reverse()
            path_b.reverse()

            results.append(
                LocalAlignmentResult(
                    query_path=path_a,
                    target_path=path_b,
                    start_query=ci,
                    start_target=cj,
                    end_query=end_i,
                    end_target=end_j,
                    length=len(path_a),
                    score=score,
                )
            )

        # Sort by score descending (stable, so ties keep discovery order).
        results.sort(key=lambda r: r.score, reverse=True)
        return results

    @staticmethod
    def _overlaps(a: LocalAlignmentResult, b: LocalAlignmentResult) -> bool:
        """Check whether two local alignments share any (i, j) cell."""
        cells_a = set(zip(a.query_path, a.target_path))
        cells_b = set(zip(b.query_path, b.target_path))
        return bool(cells_a & cells_b)
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
def _argmax(values: list[float]) -> int:
|
|
168
|
+
"""Return the index of the maximum value."""
|
|
169
|
+
best = 0
|
|
170
|
+
for i in range(1, len(values)):
|
|
171
|
+
if values[i] > values[best]:
|
|
172
|
+
best = i
|
|
173
|
+
return best
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
"""Inductive Logic Programming (ILP) backends for learning alignment rules.
|
|
2
|
+
|
|
3
|
+
This subpackage provides a common interface for learning scoring functions and
|
|
4
|
+
alignment rules from example alignments. Two backends are supported:
|
|
5
|
+
|
|
6
|
+
- **Aleph** -- the classic ILP system (Srinivasan, 2001) via SWI-Prolog.
|
|
7
|
+
Ported from the legacy pySeqAlign code.
|
|
8
|
+
- **Popper** -- a modern ILP system (Cropper & Morel, 2021) that learns from
|
|
9
|
+
failures using ASP/SAT solvers. Recommended for new projects.
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from pyseqalign.learning.base import ILPLearner, ILPTask, LearnedProgram
|
|
13
|
+
from pyseqalign.learning.task_builder import AlignmentTaskBuilder
|
|
14
|
+
|
|
15
|
+
# Public names of the ``pyseqalign.learning`` subpackage.
__all__ = ["ILPTask", "LearnedProgram", "ILPLearner", "AlignmentTaskBuilder"]
|
|
@@ -0,0 +1,212 @@
|
|
|
1
|
+
"""Aleph ILP backend.
|
|
2
|
+
|
|
3
|
+
Runs the Aleph ILP system (Srinivasan, 2001) via SWI-Prolog to learn
|
|
4
|
+
Prolog clauses from alignment examples.
|
|
5
|
+
|
|
6
|
+
Requires SWI-Prolog installed on the system and accessible via ``pyswip``
|
|
7
|
+
or the ``swipl`` command.
|
|
8
|
+
|
|
9
|
+
The bundled ``aleph_swi_ak.pl`` file is a SWI-Prolog compatible version
|
|
10
|
+
of Aleph 5, originally ported from YAP Prolog in the legacy pySeqAlign
|
|
11
|
+
codebase.
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
from __future__ import annotations
|
|
15
|
+
|
|
16
|
+
import subprocess
|
|
17
|
+
import tempfile
|
|
18
|
+
from pathlib import Path
|
|
19
|
+
|
|
20
|
+
from pyseqalign.learning.base import ILPTask, LearnedProgram
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class AlephLearner:
    """Aleph ILP backend.

    Uses SWI-Prolog to run the selected Aleph induction command (for example
    ``induce/0``) on the provided task.

    Args:
        aleph_path: Path to the ``aleph_swi_ak.pl`` file. Defaults to the
            bundled version shipped with pyseqalign.
        swipl_cmd: Command to invoke SWI-Prolog (default ``"swipl"``).
        induce_mode: Aleph induction mode. One of ``"induce"``,
            ``"induce_max"``, ``"induce_cover"``, ``"induce_tree"``,
            ``"induce_features"``, ``"induce_constraints"``,
            ``"induce_incremental"``, ``"induce_theory"``
            (default ``"induce"``).
        timeout: Maximum seconds for the SWI-Prolog process (default 300).

    Raises:
        ValueError: If *induce_mode* is not in ``VALID_MODES``.
    """

    VALID_MODES = {
        "induce",
        "induce_max",
        "induce_cover",
        "induce_tree",
        "induce_features",
        "induce_constraints",
        "induce_incremental",
        "induce_theory",
    }

    def __init__(
        self,
        aleph_path: str | Path | None = None,
        swipl_cmd: str = "swipl",
        induce_mode: str = "induce",
        timeout: int = 300,
    ) -> None:
        if aleph_path is None:
            # Fall back to the copy of Aleph bundled alongside this module.
            aleph_path = Path(__file__).parent / "aleph_files" / "aleph_swi_ak.pl"
        self.aleph_path = Path(aleph_path)
        self.swipl_cmd = swipl_cmd
        self.timeout = timeout

        if induce_mode not in self.VALID_MODES:
            raise ValueError(
                f"Unknown induce_mode '{induce_mode}'. "
                f"Valid modes: {sorted(self.VALID_MODES)}"
            )
        # Safe to splice into the Prolog script: it is one of VALID_MODES.
        self.induce_mode = induce_mode

    def learn(self, task: ILPTask) -> LearnedProgram:
        """Run Aleph on the given task.

        Writes the task to Aleph's ``task.b``/``task.f``/``task.n`` files,
        invokes SWI-Prolog with Aleph, and parses the output for learned
        clauses.

        When ``task.work_dir`` is set, the files are written there and left
        in place. Otherwise a temporary directory is used and removed once
        the results have been parsed.

        Raises:
            RuntimeError: If the SWI-Prolog executable cannot be found.
        """
        if task.work_dir:
            work_dir = Path(task.work_dir)
            work_dir.mkdir(parents=True, exist_ok=True)
            return self._learn_in(task, work_dir)
        # The temporary path is never exposed to the caller, so clean it up
        # instead of leaking one directory per call.
        with tempfile.TemporaryDirectory(prefix="pyseqalign_aleph_") as tmp:
            return self._learn_in(task, Path(tmp))

    def _learn_in(self, task: ILPTask, work_dir: Path) -> LearnedProgram:
        """Write *task* into *work_dir*, run Aleph there and collect the results."""
        self._write_task_files(task, work_dir)

        # Construct SWI-Prolog script.
        aleph_abs = self.aleph_path.resolve()
        task_abs = (work_dir / "task").resolve()
        result_abs = (work_dir / "result.pl").resolve()

        # A rules file left over from an earlier run in a reused work_dir
        # would otherwise be parsed as this run's result.
        if result_abs.exists():
            result_abs.unlink()

        script = (
            f":- consult('{self._quote_path(aleph_abs)}').\n"
            f":- read_all('{self._quote_path(task_abs)}').\n"
            f":- {self.induce_mode}.\n"
            f":- write_rules('{self._quote_path(result_abs)}').\n"
            ":- halt.\n"
        )
        script_path = work_dir / "run_aleph.pl"
        script_path.write_text(script)

        # Run SWI-Prolog. stdin is closed so the toplevel (or an interactive
        # prompt) sees EOF instead of blocking until the timeout.
        try:
            result = subprocess.run(
                [self.swipl_cmd, "-s", str(script_path)],
                stdin=subprocess.DEVNULL,
                capture_output=True,
                text=True,
                timeout=self.timeout,
                cwd=str(work_dir),
            )
        except FileNotFoundError as exc:
            raise RuntimeError(
                f"SWI-Prolog not found at '{self.swipl_cmd}'. "
                "Install SWI-Prolog or set swipl_cmd to the correct path."
            ) from exc
        except subprocess.TimeoutExpired:
            return LearnedProgram(
                raw_output=f"Aleph timed out after {self.timeout}s",
                stats={"timeout": True},
            )

        raw_output = result.stdout + result.stderr

        # Parse results.
        clauses = self._parse_output(raw_output, result_abs)

        return LearnedProgram(
            clauses=clauses,
            score=self._extract_score(raw_output),
            stats=self._extract_stats(raw_output),
            raw_output=raw_output,
        )

    # ------------------------------------------------------------------
    # File / script helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _write_task_files(task: ILPTask, work_dir: Path) -> None:
        """Write the Aleph background (``.b``), positive (``.f``) and negative (``.n``) files."""
        bk_lines = [f":- set({key},{value})." for key, value in task.settings.items()]
        bk_lines.extend(task.bias)
        bk_lines.append("")
        bk_lines.extend(task.background)
        (work_dir / "task.b").write_text("\n".join(bk_lines) + "\n")
        (work_dir / "task.f").write_text("\n".join(task.positive) + "\n")
        (work_dir / "task.n").write_text("\n".join(task.negative) + "\n")

    @staticmethod
    def _quote_path(path: Path) -> str:
        """Return *path* escaped for use inside a single-quoted Prolog atom.

        Forward slashes are used so Windows separators are not read as
        Prolog escape sequences. Any remaining backslash is doubled, and a
        single quote is written as ``''`` (the ISO quoted-atom escape).
        """
        text = Path(path).as_posix()
        return text.replace("\\", "\\\\").replace("'", "''")

    # ------------------------------------------------------------------
    # Parsing helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _parse_output(raw_output: str, result_file: Path) -> list[str]:
        """Extract learned clauses from Aleph output.

        Prefers the rules file written by ``write_rules/1``. A clause that
        spans several lines there is joined into one string with single
        spaces; ``%`` comment lines and blank lines are skipped. If that file
        is missing or yields nothing, ``[Rule N]`` blocks are scraped from
        *raw_output* instead.
        """
        clauses: list[str] = []

        # Try reading the written rules file first.
        if result_file.exists():
            pending: list[str] = []
            for line in result_file.read_text().splitlines():
                stripped = line.strip()
                if not stripped or stripped.startswith("%"):
                    continue
                pending.append(stripped)
                # A Prolog clause is terminated by a full stop; body lines of
                # a multi-line clause end with ',' or ':-'.
                if stripped.endswith("."):
                    clauses.append(" ".join(pending))
                    pending = []
            if pending:
                # Keep unterminated trailing text rather than dropping it.
                clauses.append(" ".join(pending))
            if clauses:
                return clauses

        # Fall back to parsing stdout for [Rule N] blocks.
        in_rule = False
        current: list[str] = []
        for line in raw_output.splitlines():
            if "[Rule" in line:
                in_rule = True
                current = []
                continue
            if in_rule:
                stripped = line.strip()
                if stripped == "":
                    if current:
                        clauses.append(" ".join(current))
                        current = []
                    in_rule = False
                else:
                    current.append(stripped)
        if current:
            clauses.append(" ".join(current))

        return clauses

    @staticmethod
    def _extract_score(raw_output: str) -> float:
        """Return the first number on the first line mentioning accuracy, else ``0.0``."""
        for line in raw_output.splitlines():
            if "Accuracy" in line or "accuracy" in line:
                for p in line.split():
                    try:
                        return float(p.strip("()%,"))
                    except ValueError:
                        continue
        return 0.0

    @staticmethod
    def _extract_stats(raw_output: str) -> dict[str, object]:
        """Extract clause/node counts from Aleph output (later lines override earlier ones)."""
        stats: dict[str, object] = {}
        for line in raw_output.splitlines():
            lowered = line.lower()
            if "clauses constructed" in lowered:
                for p in line.split():
                    try:
                        stats["clauses_constructed"] = int(p)
                        break
                    except ValueError:
                        continue
            if "nodes explored" in lowered or "nodes visited" in lowered:
                for p in line.split():
                    try:
                        stats["nodes_explored"] = int(p)
                        break
                    except ValueError:
                        continue
        return stats
|
|
File without changes
|