shrinkray 0.0.0__py3-none-any.whl → 25.12.26__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- shrinkray/__main__.py +130 -960
- shrinkray/cli.py +70 -0
- shrinkray/display.py +75 -0
- shrinkray/formatting.py +108 -0
- shrinkray/passes/bytes.py +217 -10
- shrinkray/passes/clangdelta.py +47 -17
- shrinkray/passes/definitions.py +84 -4
- shrinkray/passes/genericlanguages.py +61 -7
- shrinkray/passes/json.py +6 -0
- shrinkray/passes/patching.py +65 -57
- shrinkray/passes/python.py +66 -23
- shrinkray/passes/sat.py +505 -91
- shrinkray/passes/sequences.py +26 -6
- shrinkray/problem.py +206 -27
- shrinkray/process.py +49 -0
- shrinkray/reducer.py +187 -25
- shrinkray/state.py +599 -0
- shrinkray/subprocess/__init__.py +24 -0
- shrinkray/subprocess/client.py +253 -0
- shrinkray/subprocess/protocol.py +190 -0
- shrinkray/subprocess/worker.py +491 -0
- shrinkray/tui.py +915 -0
- shrinkray/ui.py +72 -0
- shrinkray/work.py +34 -6
- {shrinkray-0.0.0.dist-info → shrinkray-25.12.26.0.dist-info}/METADATA +44 -27
- shrinkray-25.12.26.0.dist-info/RECORD +33 -0
- {shrinkray-0.0.0.dist-info → shrinkray-25.12.26.0.dist-info}/WHEEL +2 -1
- shrinkray-25.12.26.0.dist-info/entry_points.txt +3 -0
- shrinkray-25.12.26.0.dist-info/top_level.txt +1 -0
- shrinkray/learning.py +0 -221
- shrinkray-0.0.0.dist-info/RECORD +0 -22
- shrinkray-0.0.0.dist-info/entry_points.txt +0 -3
- {shrinkray-0.0.0.dist-info → shrinkray-25.12.26.0.dist-info/licenses}/LICENSE +0 -0
shrinkray/passes/sat.py
CHANGED
|
@@ -1,8 +1,17 @@
|
|
|
1
|
-
from
|
|
2
|
-
from
|
|
1
|
+
from collections import Counter, defaultdict
|
|
2
|
+
from collections.abc import Callable, Iterable, Iterator
|
|
3
|
+
|
|
4
|
+
from shrinkray.passes.definitions import (
|
|
5
|
+
DumpError,
|
|
6
|
+
Format,
|
|
7
|
+
ParseError,
|
|
8
|
+
ReductionPass,
|
|
9
|
+
)
|
|
10
|
+
from shrinkray.passes.patching import Conflict, SetPatches, apply_patches
|
|
3
11
|
from shrinkray.passes.sequences import delete_elements
|
|
4
12
|
from shrinkray.problem import ReductionProblem
|
|
5
13
|
|
|
14
|
+
|
|
6
15
|
Clause = list[int]
|
|
7
16
|
SAT = list[Clause]
|
|
8
17
|
|
|
@@ -17,7 +26,7 @@ class _DimacsCNF(Format[bytes, SAT]):
|
|
|
17
26
|
contents = input.decode("utf-8")
|
|
18
27
|
except UnicodeDecodeError as e:
|
|
19
28
|
raise ParseError(*e.args)
|
|
20
|
-
clauses = []
|
|
29
|
+
clauses: SAT = []
|
|
21
30
|
for line in contents.splitlines():
|
|
22
31
|
line = line.strip()
|
|
23
32
|
if line.startswith("c"):
|
|
@@ -27,7 +36,7 @@ class _DimacsCNF(Format[bytes, SAT]):
|
|
|
27
36
|
if not line.strip():
|
|
28
37
|
continue
|
|
29
38
|
try:
|
|
30
|
-
clause = list(map(int, line.strip().split()))
|
|
39
|
+
clause: Clause = list(map(int, line.strip().split()))
|
|
31
40
|
except ValueError as e:
|
|
32
41
|
raise ParseError(*e.args)
|
|
33
42
|
if clause[-1] != 0:
|
|
@@ -39,6 +48,8 @@ class _DimacsCNF(Format[bytes, SAT]):
|
|
|
39
48
|
return clauses
|
|
40
49
|
|
|
41
50
|
def dumps(self, input: SAT) -> bytes:
|
|
51
|
+
if not input or not all(input):
|
|
52
|
+
raise DumpError(input)
|
|
42
53
|
n_variables = max(abs(literal) for clause in input for literal in clause)
|
|
43
54
|
|
|
44
55
|
parts = [f"p cnf {n_variables} {len(input)}"]
|
|
@@ -52,71 +63,48 @@ class _DimacsCNF(Format[bytes, SAT]):
|
|
|
52
63
|
DimacsCNF = _DimacsCNF()
|
|
53
64
|
|
|
54
65
|
|
|
55
|
-
async def
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
def renumber(l):
|
|
59
|
-
if l < 0:
|
|
60
|
-
return -renumber(-l)
|
|
61
|
-
try:
|
|
62
|
-
return renumbering[l]
|
|
63
|
-
except KeyError:
|
|
64
|
-
pass
|
|
65
|
-
result = len(renumbering) + 1
|
|
66
|
-
renumbering[l] = result
|
|
67
|
-
return result
|
|
68
|
-
|
|
69
|
-
renumbered = [
|
|
70
|
-
[renumber(literal) for literal in clause]
|
|
71
|
-
for clause in problem.current_test_case
|
|
72
|
-
]
|
|
66
|
+
async def flip_literal_signs(problem: ReductionProblem[SAT]):
|
|
67
|
+
"""Make negative literals positive.
|
|
73
68
|
|
|
74
|
-
|
|
69
|
+
Tries to replace negative literals (-x) with positive ones (x).
|
|
70
|
+
This normalizes the formula toward using positive literals only.
|
|
71
|
+
"""
|
|
75
72
|
|
|
73
|
+
def flip_terms(terms: frozenset[tuple[int, int]], sat: SAT) -> SAT:
|
|
74
|
+
result = list(map(list, sat))
|
|
75
|
+
for i, j in terms:
|
|
76
|
+
result[i][j] = abs(result[i][j])
|
|
77
|
+
return result
|
|
76
78
|
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
new_clause.append(-literal)
|
|
89
|
-
else:
|
|
90
|
-
new_clause.append(literal)
|
|
91
|
-
attempt.append(new_clause)
|
|
92
|
-
if await problem.is_interesting(attempt):
|
|
93
|
-
target = attempt
|
|
94
|
-
seen_variables.add(abs(v))
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
async def remove_redundant_clauses(problem: ReductionProblem[SAT]):
|
|
98
|
-
attempt = []
|
|
99
|
-
seen = set()
|
|
100
|
-
for clause in problem.current_test_case:
|
|
101
|
-
if len(set(map(abs, clause))) < len(set(clause)):
|
|
102
|
-
continue
|
|
103
|
-
key = tuple(clause)
|
|
104
|
-
if key in seen:
|
|
105
|
-
continue
|
|
106
|
-
seen.add(key)
|
|
107
|
-
attempt.append(clause)
|
|
108
|
-
await problem.is_interesting(attempt)
|
|
79
|
+
await apply_patches(
|
|
80
|
+
problem,
|
|
81
|
+
SetPatches(flip_terms),
|
|
82
|
+
[
|
|
83
|
+
frozenset({(i, j)})
|
|
84
|
+
for i, clause in enumerate(problem.current_test_case)
|
|
85
|
+
for j, v in enumerate(clause)
|
|
86
|
+
if v < 0
|
|
87
|
+
],
|
|
88
|
+
)
|
|
89
|
+
await unit_propagate(problem)
|
|
109
90
|
|
|
110
91
|
|
|
111
92
|
def literals_in(sat: SAT) -> frozenset[int]:
    """Return every signed literal occurring anywhere in ``sat``.

    ``x`` and ``-x`` are reported as distinct members of the result.
    """
    seen: set[int] = set()
    for clause in sat:
        seen.update(clause)
    return frozenset(seen)
|
|
113
94
|
|
|
114
95
|
|
|
115
|
-
async def delete_literals(problem: ReductionProblem[SAT]):
|
|
96
|
+
async def delete_literals(problem: ReductionProblem[SAT]) -> None:
|
|
97
|
+
"""Remove entire literals from the formula.
|
|
98
|
+
|
|
99
|
+
Tries to remove all occurrences of a literal (both positive and
|
|
100
|
+
negative forms) from all clauses. Clauses that become empty are
|
|
101
|
+
removed entirely.
|
|
102
|
+
"""
|
|
103
|
+
|
|
116
104
|
def remove_literals(literals: frozenset[int], sat: SAT) -> SAT:
|
|
117
|
-
result = []
|
|
105
|
+
result: SAT = []
|
|
118
106
|
for clause in sat:
|
|
119
|
-
new_clause = [v for v in clause if v not in literals]
|
|
107
|
+
new_clause: Clause = [v for v in clause if v not in literals]
|
|
120
108
|
if new_clause:
|
|
121
109
|
result.append(new_clause)
|
|
122
110
|
return result
|
|
@@ -128,49 +116,475 @@ async def delete_literals(problem: ReductionProblem[SAT]):
|
|
|
128
116
|
)
|
|
129
117
|
|
|
130
118
|
|
|
131
|
-
async def
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
119
|
+
async def delete_single_terms(problem: ReductionProblem[SAT]) -> None:
    """Remove individual literal occurrences from specific clauses.

    Where delete_literals strips a literal from every clause at once, each
    candidate here names exactly one ``(clause_index, position)`` slot, so the
    same literal may survive in one clause and vanish from another. Clauses
    left empty by the deletions are dropped, and unit propagation runs
    afterwards.
    """

    def remove_terms(terms: frozenset[tuple[int, int]], sat: SAT) -> SAT:
        clauses: list[list[int]] = [list(c) for c in sat]
        # Gather the doomed positions per clause before deleting anything.
        positions_by_clause: defaultdict[int, set[int]] = defaultdict(set)
        for clause_index, position in terms:
            positions_by_clause[clause_index].add(position)
        for clause_index, positions in positions_by_clause.items():
            target = clauses[clause_index]
            # Delete right-to-left so the remaining positions stay valid.
            for position in sorted(positions, reverse=True):
                del target[position]
        return list(filter(None, clauses))

    candidates = [
        frozenset({(clause_index, position)})
        for clause_index, clause in enumerate(problem.current_test_case)
        for position, _ in enumerate(clause)
    ]
    await apply_patches(problem, SetPatches(remove_terms), candidates)
    await unit_propagate(problem)
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
async def renumber_variables(problem: ReductionProblem[SAT]) -> None:
    """Renumber variables to use smaller indices.

    Tries to replace variable numbers with smaller ones (1, 2, 3, etc.)
    to minimize the variable indices used. This normalizes the formula
    toward using the smallest possible variable numbers.

    Each candidate patch is a single ``(new_number, old_variable)`` pair with
    ``0 < new_number < old_variable``. The new number may already be in use,
    in which case the two variables are merged; clauses that end up holding
    both ``x`` and ``-x`` are dropped. Unit propagation runs afterwards.
    """
    # Distinct variable numbers (sign ignored), ascending.
    variables = sorted(
        {abs(lit) for clause in problem.current_test_case for lit in clause}
    )

    def renumber(terms: frozenset[tuple[int, int]], sat: SAT) -> SAT:
        # Each term is (new_number, old_variable). Iterating in sorted order
        # means that when several terms target the same old variable, the
        # smallest new number wins.
        renumbering: dict[int, int] = {}
        for i, j in sorted(terms):
            if j not in renumbering:
                renumbering[j] = i
        result: SAT = []
        for clause in sat:
            # Rename each literal while keeping its sign. set() collapses
            # literals that coincide after renaming; sorted() fixes the order.
            new_clause: Clause = sorted(
                set(
                    [
                        (renumbering[lit] if lit > 0 else -renumbering[-lit])
                        if abs(lit) in renumbering
                        else lit
                        for lit in clause
                    ]
                )
            )
            # Keep the clause only if no variable appears with both signs
            # (i.e. drop tautologies created by the renaming).
            if len(set(map(abs, new_clause))) == len(new_clause):
                result.append(new_clause)
        return result

    # ideal_number[v] is v's 1-based rank among the variables: the number it
    # would receive under a fully compact renumbering.
    ideal_number: dict[int, int] = {v: i for i, v in enumerate(variables, 1)}
    # backup_number[v] is the smallest positive number not used by any
    # variable, recorded only for variables larger than it. `used` never
    # changes, so once `i` reaches an unused number it stays there.
    backup_number: dict[int, int] = {}
    used = set(variables)
    i = 1
    for v in variables:
        while i in used and i <= v:
            i += 1
        if i < v:
            backup_number[v] = i

    # Candidate targets per variable: small constants, fractions of v, its
    # near neighbours, its ideal rank and the backup slot. backup_number
    # defaults to v itself, which the `0 < u < v` filter then discards.
    await apply_patches(
        problem,
        SetPatches(renumber),
        [
            frozenset({(u, v)})
            for v in variables
            for u in {
                1,
                2,
                3,
                4,
                5,
                v // 3,
                v // 2,
                v - 3,
                v - 2,
                v - 1,
                ideal_number[v],
                backup_number.get(v, v),
            }
            if 0 < u < v
        ],
    )
    await unit_propagate(problem)
|
|
215
|
+
|
|
216
|
+
|
|
217
|
+
class UnionFind[T]:
|
|
218
|
+
table: dict[T, T]
|
|
219
|
+
key: Callable[[T], object] | None
|
|
220
|
+
generation: int
|
|
221
|
+
representatives: int
|
|
222
|
+
|
|
223
|
+
def __init__(
|
|
224
|
+
self,
|
|
225
|
+
initial_merges: Iterable[tuple[T, T]] = (),
|
|
226
|
+
key: Callable[[T], object] | None = None,
|
|
227
|
+
) -> None:
|
|
228
|
+
self.table = {}
|
|
229
|
+
self.key = key
|
|
230
|
+
self.generation = 0
|
|
231
|
+
self.representatives = 0
|
|
232
|
+
for k, v in initial_merges:
|
|
233
|
+
self.merge(k, v)
|
|
234
|
+
|
|
235
|
+
def components(self) -> list[list[T]]:
|
|
236
|
+
groupings: defaultdict[T, list[T]] = defaultdict(list)
|
|
237
|
+
for k in list(self.table):
|
|
238
|
+
groupings[self.find(k)].append(k)
|
|
239
|
+
return list(groupings.values())
|
|
240
|
+
|
|
241
|
+
def find(self, value: T) -> T:
|
|
242
|
+
try:
|
|
243
|
+
if self.table[value] == value:
|
|
244
|
+
return value
|
|
245
|
+
except KeyError:
|
|
246
|
+
self.representatives += 1
|
|
247
|
+
self.table[value] = value
|
|
248
|
+
return value
|
|
249
|
+
|
|
250
|
+
trail: list[T] = []
|
|
251
|
+
while value != self.table[value]:
|
|
252
|
+
trail.append(value)
|
|
253
|
+
value = self.table[value]
|
|
254
|
+
for t in trail:
|
|
255
|
+
self.table[t] = value
|
|
256
|
+
return value
|
|
257
|
+
|
|
258
|
+
def merge(self, left: T, right: T) -> None:
|
|
259
|
+
if left == right:
|
|
260
|
+
return
|
|
261
|
+
left = self.find(left)
|
|
262
|
+
right = self.find(right)
|
|
263
|
+
if left == right:
|
|
264
|
+
return
|
|
265
|
+
self.representatives -= 1
|
|
266
|
+
self.generation += 1
|
|
267
|
+
left, right = sorted((left, right), key=self.key) # type: ignore[arg-type]
|
|
268
|
+
self.table[right] = left
|
|
269
|
+
|
|
270
|
+
def merge_all(self, values: list[T]) -> None:
|
|
271
|
+
if len(values) > 1:
|
|
272
|
+
sorted_values: list[T] = sorted(values, key=self.key) # type: ignore[arg-type]
|
|
273
|
+
a: T = sorted_values[0] # type: ignore[reportUnknownVariableType]
|
|
274
|
+
for b in sorted_values[1:]: # type: ignore[reportUnknownVariableType]
|
|
275
|
+
self.merge(a, b) # type: ignore[reportUnknownArgumentType]
|
|
276
|
+
|
|
277
|
+
def __repr__(self) -> str:
|
|
278
|
+
return "%s(%d components)" % (
|
|
279
|
+
type(self).__name__,
|
|
280
|
+
len(self.components()),
|
|
281
|
+
)
|
|
282
|
+
|
|
283
|
+
|
|
284
|
+
class BooleanEquivalence(UnionFind[int]):
    """Union-find over non-zero int literals in which ``x`` and ``-x`` are tied.

    The backing table is a NegatingMap, so recording ``x -> y`` also records
    ``-x -> -y``; merging ``a`` with ``b`` therefore implicitly merges ``-a``
    with ``-b``. Representatives are chosen by smallest absolute value
    (``key=abs``). Merging a literal with its own negation raises
    Inconsistent.
    """

    # Negation-aware replacement for UnionFind's plain dict table.
    table: "NegatingMap"  # type: ignore[reportIncompatibleVariableOverride]

    def __init__(self, initial_merges: Iterable[tuple[int, int]] = ()) -> None:
        # Install the NegatingMap *before* applying any merges. Forwarding
        # initial_merges to UnionFind.__init__ would perform them against the
        # plain dict it creates, which is then discarded when the NegatingMap
        # replaces it -- silently losing the merges while leaving
        # `representatives`/`generation` counting them.
        super().__init__(key=abs)
        self.table = NegatingMap()  # pyright: ignore[reportIncompatibleVariableOverride]
        for left, right in initial_merges:
            self.merge(left, right)

    def find(self, value: int) -> int:
        """Return the representative literal of ``value``'s class.

        Raises:
            ValueError: if ``value`` is 0, which is not a valid literal.
        """
        if not value:
            raise ValueError("Invalid variable %r" % (value,))
        return super().find(value)

    def merge(self, left: int, right: int) -> None:
        """Declare ``left`` and ``right`` equivalent (hence ``-left`` and ``-right`` too).

        Raises:
            Inconsistent: if ``left`` is already equivalent to ``-right``.
        """
        if left == right:
            return
        left2 = self.find(left)
        right2 = self.find(right)
        if left2 == right2:
            return
        if left2 == -right2:
            raise Inconsistent(
                "Attempted to merge %d (=%d) with %d (=%d)"
                % (left, left2, right, right2)
            )
        super().merge(left, right)
|
|
309
|
+
|
|
310
|
+
|
|
311
|
+
class NegatingMap:
    """Int-to-int mapping that is odd: ``m[-k] == -m[k]`` for every key.

    Only positive keys are stored; a negative key is answered by negating the
    entry for its absolute value. Keys and values must be non-zero. Iteration
    yields the stored (positive) keys first, then their negations.
    """

    _data: dict[int, int]

    def __init__(self) -> None:
        self._data = {}

    def __repr__(self) -> str:
        mirrored: dict[int, int] = {}
        for key, value in self._data.items():
            mirrored.update({key: value, -key: -value})
        return repr(mirrored)

    def __iter__(self) -> Iterator[int]:
        yield from self._data
        yield from (-key for key in self._data)

    def __getitem__(self, key: int) -> int:
        assert key != 0
        if key > 0:
            return self._data[key]
        return -self._data[-key]

    def __setitem__(self, key: int, value: int) -> None:
        assert key != 0
        assert value != 0
        if key > 0:
            self._data[key] = value
        else:
            self._data[-key] = -value
|
|
343
|
+
|
|
344
|
+
|
|
345
|
+
async def merge_literals(problem: ReductionProblem[SAT]) -> None:
    """Merge pairs of literals into single variables.

    Every candidate asserts ``u == -v`` for two distinct literals ``u`` and
    ``v`` that occur together in some clause. Applying a set of candidates
    rewrites each literal to the representative of its equivalence class
    (the member with the smallest absolute value) and collapses duplicate
    clauses. A candidate set that would equate a literal with its own
    negation is rejected as a Conflict. Note that the clause supplying a
    pair becomes ``{r, -r}`` after its merge; such clauses are kept here.
    Unit propagation runs afterwards.
    """

    def apply_merges(terms: frozenset[tuple[int, int]], sat: SAT) -> SAT:
        equivalence = BooleanEquivalence()
        try:
            for left, right in terms:
                equivalence.merge(left, right)
        except Inconsistent:
            raise Conflict()

        distinct: set[frozenset[int]] = set()
        for clause in sat:
            distinct.add(frozenset(map(equivalence.find, clause)))
        normalised = [sorted(clause, key=abs) for clause in distinct]
        normalised.sort(key=len)
        return normalised

    candidates = []
    for clause in problem.current_test_case:
        for u in clause:
            for v in clause:
                if u != v:
                    candidates.append(frozenset({(u, -v)}))

    await apply_patches(problem, SetPatches(apply_merges), candidates)
    await unit_propagate(problem)
|
|
379
|
+
|
|
380
|
+
|
|
381
|
+
async def pass_to_component(problem: ReductionProblem[SAT]) -> None:
    """Try to reduce to a single connected component.

    Clauses are connected when they share a variable (sign ignored). If the
    formula falls apart into more than one such component, each component is
    offered on its own as a candidate, smallest first.

    NOTE(review): every clause's first literal is read, so this assumes the
    test case holds no empty clauses.
    """
    clauses = problem.current_test_case

    connectivity: UnionFind[int] = UnionFind()
    for clause in clauses:
        connectivity.merge_all([abs(lit) for lit in clause])

    components: defaultdict[int, SAT] = defaultdict(list)
    for clause in clauses:
        root = connectivity.find(abs(clause[0]))
        components[root].append(clause)

    if len(components) <= 1:
        return
    for component in sorted(components.values(), key=len):
        await problem.is_interesting(component)
|
|
398
|
+
|
|
399
|
+
|
|
400
|
+
async def sort_clauses(problem: ReductionProblem[SAT]) -> None:
    """Sort clauses and literals into canonical order.

    Literals are sorted within each clause, then the clause list itself is
    sorted, and the result is offered as a single candidate.
    """
    canonical = [sorted(clause) for clause in problem.current_test_case]
    canonical.sort()
    await problem.is_interesting(canonical)
|
|
141
407
|
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
408
|
+
|
|
409
|
+
class Inconsistent(Exception):
    """Signals a contradiction in a set of clauses or literal equivalences.

    Raised when a literal would be merged with its own negation, or when
    unit propagation meets an empty or fully falsified clause.
    """
|
|
411
|
+
|
|
412
|
+
|
|
413
|
+
class UnitPropagator:
    """Unit propagation over a fixed list of clauses using two watched literals.

    Construction runs propagation to a fixpoint. Every literal forced by a
    unit clause, directly or transitively, ends up in ``units``.
    ``Inconsistent`` is raised if an empty clause is present, or if some
    clause has all of its literals falsified.
    """

    # Input clauses frozen as tuples; list indices identify clauses.
    __clauses: list[tuple[int, ...]]
    # Per-variable occurrence counts; decremented when a clause is found
    # satisfied. NOTE(review): maintained but never read within this file.
    __clause_counts: Counter[int]
    # literal -> indices of the clauses currently watching that literal.
    __watches: defaultdict[int, frozenset[int]]
    # clause index -> the (at most two) literals that clause is watching.
    __watched_by: list[frozenset[int]]
    # Literals forced true so far.
    units: set[int]
    # Variables (absolute values) of the literals in `units`.
    forced_variables: set[int]
    # Indices of clauses whose watches must be re-examined.
    __dirty: set[int]

    def __init__(self, clauses: Iterable[Iterable[int]]) -> None:
        self.__clauses = [tuple(c) for c in clauses]
        self.__clause_counts = Counter()
        for clause in self.__clauses:
            for literal in clause:
                self.__clause_counts[abs(literal)] += 1
        self.__watches = defaultdict(frozenset)
        self.__watched_by = [frozenset() for _ in self.__clauses]

        self.units = set()
        self.forced_variables = set()
        # Every clause starts dirty so the first pass sets up all watches.
        self.__dirty = set(range(len(self.__clauses)))
        self.__clean_dirty_clauses()

    def __enqueue_unit(self, unit: int) -> None:
        """Record ``unit`` as forced true and mark clauses watching ``-unit`` dirty."""
        # Invariant: unit should not already be in self.units, because
        # __clean_dirty_clauses handles clauses already containing a unit
        # (the `any(literal in self.units ...)` branch) before it can ever
        # enqueue a unit from them.
        assert unit not in self.units, f"unit {unit} already enqueued"
        # Invariant: -unit should not be in self.units, because a literal is
        # only added to watched_by when its negation is not a unit (the
        # `if -literal not in self.units` check in __clean_dirty_clauses).
        assert -unit not in self.units, (
            f"Tried to add {unit} as a unit but {-unit} is already a unit"
        )
        self.units.add(unit)
        self.forced_variables.add(abs(unit))
        # Clauses watching -unit just lost a live watch; revisit them.
        self.__dirty.update(self.__watches.pop(-unit, ()))

    def __clean_dirty_clauses(self) -> None:
        """Process dirty clauses until no clause is dirty.

        Raises:
            Inconsistent: on an empty clause, or on a clause with no
                remaining non-falsified literal.
        """
        iters = 0
        while self.__dirty:
            iters += 1
            # Safety cap on the number of rounds.
            assert iters <= 10**6
            # Swap in a fresh set so units enqueued below schedule the next round.
            dirty = self.__dirty
            self.__dirty = set()

            for i in dirty:
                clause = self.__clauses[i]
                if not clause:
                    raise Inconsistent("Clauses contain an empty clause")
                if any(literal in self.units for literal in clause):
                    # Satisfied clause: stop watching it entirely.
                    for literal in self.__watched_by[i]:
                        if literal in self.__watches:
                            self.__watches[literal] -= {i}
                    self.__watched_by[i] = frozenset()
                    for literal in clause:
                        self.__clause_counts[abs(literal)] -= 1
                else:
                    # Drop watched literals that have been falsified.
                    for literal in list(self.__watched_by[i]):
                        if -literal in self.units:
                            self.__watched_by[i] -= {literal}
                    # Top the watches back up to two non-falsified literals.
                    for literal in clause:
                        if len(self.__watched_by[i]) == 2:
                            break
                        if -literal not in self.units:
                            self.__watches[literal] |= {i}
                            self.__watched_by[i] |= {literal}
                    if len(self.__watched_by[i]) == 0:
                        raise Inconsistent(
                            f"Clause {' '.join(map(str, clause))} can no longer be satisfied"
                        )
                    elif len(self.__watched_by[i]) == 1:
                        # Only one literal can still satisfy the clause: force it.
                        self.__enqueue_unit(*self.__watched_by[i])

    def propagated_clauses(self) -> SAT:
        """Return the simplified formula.

        The result is one singleton clause per forced unit, followed by the
        clauses not satisfied by any unit. Falsified literals are removed from
        those clauses, duplicates are merged, and the clauses are ordered by
        (length, variable sequence, literals).
        """
        results: set[tuple[int, ...]] = set()
        neg_units = {-v for v in self.units}
        for clause in self.__clauses:
            if any(literal in self.units for literal in clause):
                continue
            if not neg_units.isdisjoint(clause):
                clause = tuple(sorted(set(clause) - neg_units))
            results.add(clause)
        return [[literal] for literal in self.units] + [
            list(c)
            for c in sorted(results, key=lambda c: (len(c), list(map(abs, c)), c))
        ]
|
|
498
|
+
|
|
499
|
+
|
|
500
|
+
async def unit_propagate(problem: ReductionProblem[SAT]) -> None:
    """Apply unit propagation to simplify the formula.

    Satisfied clauses are dropped and falsified literals are removed from the
    rest. The propagated formula is first offered without its unit clauses.
    Only if that candidate is rejected is the full propagated formula
    (units included) tried. If the clauses turn out to be contradictory,
    nothing is attempted.
    """
    try:
        propagator = UnitPropagator(problem.current_test_case)
        propagated = propagator.propagated_clauses()
    except Inconsistent:
        # Clauses are unsatisfiable, nothing to propagate
        return
    without_units = [clause for clause in propagated if len(clause) > 1]
    if await problem.is_interesting(without_units):
        return
    await problem.is_interesting(propagated)
|
|
514
|
+
|
|
515
|
+
|
|
516
|
+
async def force_literals(problem: ReductionProblem[SAT]) -> None:
    """Try forcing each literal to a specific value.

    For each literal in the formula, tries adding it as a unit clause and
    propagating. If the propagated result is interesting, the formula is
    simplified by that forced assignment. Literals whose forcing contradicts
    the current clauses are skipped.
    """
    # Snapshot the literals once. The test case itself is re-read on every
    # iteration, so each attempt builds on the latest accepted reduction.
    literals = literals_in(problem.current_test_case)
    for lit in literals:
        # Only the propagation can raise Inconsistent; keep the try body to
        # exactly that so errors from the interestingness test are not masked.
        try:
            forced = UnitPropagator(
                problem.current_test_case + [[lit]]
            ).propagated_clauses()
        except Inconsistent:
            continue
        await problem.is_interesting(forced)
|
|
157
531
|
|
|
158
|
-
assert new_clauses != problem.current_test_case
|
|
159
|
-
await problem.is_interesting(new_clauses)
|
|
160
|
-
if new_clauses != problem.current_test_case:
|
|
161
|
-
j += 1
|
|
162
532
|
|
|
533
|
+
async def combine_clauses(problem: ReductionProblem[SAT]) -> None:
    """Merge groups of clauses into single clauses.

    Each candidate names a pair of clause indices: two clauses that share a
    literal, or two adjacent clauses. Applying a set of candidates unions
    each connected group of clauses into one clause containing all of their
    literals. If that combined clause would hold both ``x`` and ``-x`` (a
    tautology), the group's clauses are dropped without a replacement. This
    reduces the clause count, possibly at the cost of larger clauses. Unit
    propagation runs afterwards.
    """

    def apply_merges(terms: frozenset[tuple[int, int]], sat: SAT) -> SAT:
        uf: UnionFind[int] = UnionFind()
        for u, v in terms:
            uf.merge(u, v)

        result: list[Clause | None] = [list(c) for c in sat]
        for c in uf.components():
            # Note: len(c) == 1 can't occur because every element in uf
            # came from a merge(u, v) pair where u != v, so all
            # components have size >= 2.
            combined: Clause = sorted({lit for i in c for lit in sat[i]}, key=abs)
            for i in c:
                result[i] = None
            # Keep the combined clause only if it is not a tautology; the
            # source clauses are removed either way.
            if len(combined) == len(set(map(abs, combined))):
                result.append(combined)
        return [clause for clause in result if clause is not None]

    clauses = problem.current_test_case

    by_literal: defaultdict[int, list[int]] = defaultdict(list)
    for i, clause in enumerate(clauses):
        for lit in clause:
            by_literal[lit].append(i)

    # Merging is symmetric: (i, j) and (j, i) produce identical results. So
    # each unordered pair is recorded once as (smaller, larger), in
    # first-occurrence order; the dict acts as an insertion-ordered set.
    pairs: dict[tuple[int, int], None] = {}
    for group in by_literal.values():
        # Indices were appended in increasing order, so each group is sorted.
        for position, i in enumerate(group):
            for j in group[position + 1 :]:
                if i < j:  # skip repeats of the same clause index
                    pairs[(i, j)] = None
    for i in range(len(clauses) - 1):
        pairs.setdefault((i, i + 1), None)

    await apply_patches(
        problem,
        SetPatches(apply_merges),
        [frozenset({pair}) for pair in pairs],
    )
    await unit_propagate(problem)
|
|
166
576
|
|
|
167
577
|
|
|
168
578
|
SAT_PASSES: list[ReductionPass[SAT]] = [
|
|
169
579
|
sort_clauses,
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
delete_elements,
|
|
580
|
+
force_literals,
|
|
581
|
+
pass_to_component,
|
|
582
|
+
unit_propagate,
|
|
174
583
|
delete_literals,
|
|
175
|
-
|
|
584
|
+
delete_single_terms,
|
|
585
|
+
delete_elements,
|
|
586
|
+
flip_literal_signs,
|
|
587
|
+
combine_clauses,
|
|
588
|
+
merge_literals,
|
|
589
|
+
renumber_variables,
|
|
176
590
|
]
|