shrinkray 0.0.0__py3-none-any.whl → 25.12.26.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
shrinkray/passes/sat.py CHANGED
@@ -1,8 +1,17 @@
1
- from shrinkray.passes.definitions import Format, ParseError, ReductionPass
2
- from shrinkray.passes.patching import SetPatches, apply_patches
1
+ from collections import Counter, defaultdict
2
+ from collections.abc import Callable, Iterable, Iterator
3
+
4
+ from shrinkray.passes.definitions import (
5
+ DumpError,
6
+ Format,
7
+ ParseError,
8
+ ReductionPass,
9
+ )
10
+ from shrinkray.passes.patching import Conflict, SetPatches, apply_patches
3
11
  from shrinkray.passes.sequences import delete_elements
4
12
  from shrinkray.problem import ReductionProblem
5
13
 
14
+
6
15
  Clause = list[int]
7
16
  SAT = list[Clause]
8
17
 
@@ -17,7 +26,7 @@ class _DimacsCNF(Format[bytes, SAT]):
17
26
  contents = input.decode("utf-8")
18
27
  except UnicodeDecodeError as e:
19
28
  raise ParseError(*e.args)
20
- clauses = []
29
+ clauses: SAT = []
21
30
  for line in contents.splitlines():
22
31
  line = line.strip()
23
32
  if line.startswith("c"):
@@ -27,7 +36,7 @@ class _DimacsCNF(Format[bytes, SAT]):
27
36
  if not line.strip():
28
37
  continue
29
38
  try:
30
- clause = list(map(int, line.strip().split()))
39
+ clause: Clause = list(map(int, line.strip().split()))
31
40
  except ValueError as e:
32
41
  raise ParseError(*e.args)
33
42
  if clause[-1] != 0:
@@ -39,6 +48,8 @@ class _DimacsCNF(Format[bytes, SAT]):
39
48
  return clauses
40
49
 
41
50
  def dumps(self, input: SAT) -> bytes:
51
+ if not input or not all(input):
52
+ raise DumpError(input)
42
53
  n_variables = max(abs(literal) for clause in input for literal in clause)
43
54
 
44
55
  parts = [f"p cnf {n_variables} {len(input)}"]
@@ -52,71 +63,48 @@ class _DimacsCNF(Format[bytes, SAT]):
52
63
  DimacsCNF = _DimacsCNF()
53
64
 
54
65
 
55
- async def renumber_variables(problem: ReductionProblem[SAT]):
56
- renumbering = {}
57
-
58
- def renumber(l):
59
- if l < 0:
60
- return -renumber(-l)
61
- try:
62
- return renumbering[l]
63
- except KeyError:
64
- pass
65
- result = len(renumbering) + 1
66
- renumbering[l] = result
67
- return result
68
-
69
- renumbered = [
70
- [renumber(literal) for literal in clause]
71
- for clause in problem.current_test_case
72
- ]
66
+ async def flip_literal_signs(problem: ReductionProblem[SAT]):
67
+ """Make negative literals positive.
73
68
 
74
- await problem.is_interesting(renumbered)
69
+ Tries to replace negative literals (-x) with positive ones (x).
70
+ This normalizes the formula toward using positive literals only.
71
+ """
75
72
 
73
+ def flip_terms(terms: frozenset[tuple[int, int]], sat: SAT) -> SAT:
74
+ result = list(map(list, sat))
75
+ for i, j in terms:
76
+ result[i][j] = abs(result[i][j])
77
+ return result
76
78
 
77
- async def flip_literal_signs(problem: ReductionProblem[SAT]):
78
- seen_variables = set()
79
- target = problem.current_test_case
80
- for i in range(len(target)):
81
- for j, v in enumerate(target[i]):
82
- if abs(v) not in seen_variables and v < 0:
83
- attempt = []
84
- for clause in target:
85
- new_clause = []
86
- for literal in clause:
87
- if abs(literal) == abs(v):
88
- new_clause.append(-literal)
89
- else:
90
- new_clause.append(literal)
91
- attempt.append(new_clause)
92
- if await problem.is_interesting(attempt):
93
- target = attempt
94
- seen_variables.add(abs(v))
95
-
96
-
97
- async def remove_redundant_clauses(problem: ReductionProblem[SAT]):
98
- attempt = []
99
- seen = set()
100
- for clause in problem.current_test_case:
101
- if len(set(map(abs, clause))) < len(set(clause)):
102
- continue
103
- key = tuple(clause)
104
- if key in seen:
105
- continue
106
- seen.add(key)
107
- attempt.append(clause)
108
- await problem.is_interesting(attempt)
79
+ await apply_patches(
80
+ problem,
81
+ SetPatches(flip_terms),
82
+ [
83
+ frozenset({(i, j)})
84
+ for i, clause in enumerate(problem.current_test_case)
85
+ for j, v in enumerate(clause)
86
+ if v < 0
87
+ ],
88
+ )
89
+ await unit_propagate(problem)
109
90
 
110
91
 
111
92
  def literals_in(sat: SAT) -> frozenset[int]:
112
93
  return frozenset({literal for clause in sat for literal in clause})
113
94
 
114
95
 
115
- async def delete_literals(problem: ReductionProblem[SAT]):
96
+ async def delete_literals(problem: ReductionProblem[SAT]) -> None:
97
+ """Remove entire literals from the formula.
98
+
99
+ Tries to remove all occurrences of a literal (both positive and
100
+ negative forms) from all clauses. Clauses that become empty are
101
+ removed entirely.
102
+ """
103
+
116
104
  def remove_literals(literals: frozenset[int], sat: SAT) -> SAT:
117
- result = []
105
+ result: SAT = []
118
106
  for clause in sat:
119
- new_clause = [v for v in clause if v not in literals]
107
+ new_clause: Clause = [v for v in clause if v not in literals]
120
108
  if new_clause:
121
109
  result.append(new_clause)
122
110
  return result
@@ -128,49 +116,475 @@ async def delete_literals(problem: ReductionProblem[SAT]):
128
116
  )
129
117
 
130
118
 
131
- async def merge_variables(problem: ReductionProblem[SAT]):
132
- i = 0
133
- j = 1
134
- while True:
135
- variables = sorted({abs(l) for c in problem.current_test_case for l in c})
136
- if j >= len(variables):
119
+ async def delete_single_terms(problem: ReductionProblem[SAT]) -> None:
120
+ """Remove individual literal occurrences from specific clauses.
121
+
122
+ Unlike delete_literals (which removes a literal everywhere), this
123
+ tries removing literals from individual positions, allowing different
124
+ clauses to keep or lose the same literal independently.
125
+ """
126
+
127
+ def remove_terms(terms: frozenset[tuple[int, int]], sat: SAT) -> SAT:
128
+ result: list[list[int]] = [list(c) for c in sat]
129
+ grouped: defaultdict[int, set[int]] = defaultdict(set)
130
+ for i, j in terms:
131
+ grouped[i].add(j)
132
+ for i, js in grouped.items():
133
+ for j in sorted(js, reverse=True):
134
+ del result[i][j]
135
+ return [c for c in result if c]
136
+
137
+ await apply_patches(
138
+ problem,
139
+ SetPatches(remove_terms),
140
+ [
141
+ frozenset({(i, j)})
142
+ for i, clause in enumerate(problem.current_test_case)
143
+ for j in range(len(clause))
144
+ ],
145
+ )
146
+ await unit_propagate(problem)
147
+
148
+
149
+ async def renumber_variables(problem: ReductionProblem[SAT]) -> None:
150
+ """Renumber variables to use smaller indices.
151
+
152
+ Tries to replace variable numbers with smaller ones (1, 2, 3, etc.)
153
+ to minimize the variable indices used. This normalizes the formula
154
+ toward using the smallest possible variable numbers.
155
+ """
156
+ variables = sorted(
157
+ {abs(lit) for clause in problem.current_test_case for lit in clause}
158
+ )
159
+
160
+ def renumber(terms: frozenset[tuple[int, int]], sat: SAT) -> SAT:
161
+ renumbering: dict[int, int] = {}
162
+ for i, j in sorted(terms):
163
+ if j not in renumbering:
164
+ renumbering[j] = i
165
+ result: SAT = []
166
+ for clause in sat:
167
+ new_clause: Clause = sorted(
168
+ set(
169
+ [
170
+ (renumbering[lit] if lit > 0 else -renumbering[-lit])
171
+ if abs(lit) in renumbering
172
+ else lit
173
+ for lit in clause
174
+ ]
175
+ )
176
+ )
177
+ if len(set(map(abs, new_clause))) == len(new_clause):
178
+ result.append(new_clause)
179
+ return result
180
+
181
+ ideal_number: dict[int, int] = {v: i for i, v in enumerate(variables, 1)}
182
+ backup_number: dict[int, int] = {}
183
+ used = set(variables)
184
+ i = 1
185
+ for v in variables:
186
+ while i in used and i <= v:
137
187
  i += 1
138
- j = i + 1
139
- if j >= len(variables):
188
+ if i < v:
189
+ backup_number[v] = i
190
+
191
+ await apply_patches(
192
+ problem,
193
+ SetPatches(renumber),
194
+ [
195
+ frozenset({(u, v)})
196
+ for v in variables
197
+ for u in {
198
+ 1,
199
+ 2,
200
+ 3,
201
+ 4,
202
+ 5,
203
+ v // 3,
204
+ v // 2,
205
+ v - 3,
206
+ v - 2,
207
+ v - 1,
208
+ ideal_number[v],
209
+ backup_number.get(v, v),
210
+ }
211
+ if 0 < u < v
212
+ ],
213
+ )
214
+ await unit_propagate(problem)
215
+
216
+
217
+ class UnionFind[T]:
218
+ table: dict[T, T]
219
+ key: Callable[[T], object] | None
220
+ generation: int
221
+ representatives: int
222
+
223
+ def __init__(
224
+ self,
225
+ initial_merges: Iterable[tuple[T, T]] = (),
226
+ key: Callable[[T], object] | None = None,
227
+ ) -> None:
228
+ self.table = {}
229
+ self.key = key
230
+ self.generation = 0
231
+ self.representatives = 0
232
+ for k, v in initial_merges:
233
+ self.merge(k, v)
234
+
235
+ def components(self) -> list[list[T]]:
236
+ groupings: defaultdict[T, list[T]] = defaultdict(list)
237
+ for k in list(self.table):
238
+ groupings[self.find(k)].append(k)
239
+ return list(groupings.values())
240
+
241
+ def find(self, value: T) -> T:
242
+ try:
243
+ if self.table[value] == value:
244
+ return value
245
+ except KeyError:
246
+ self.representatives += 1
247
+ self.table[value] = value
248
+ return value
249
+
250
+ trail: list[T] = []
251
+ while value != self.table[value]:
252
+ trail.append(value)
253
+ value = self.table[value]
254
+ for t in trail:
255
+ self.table[t] = value
256
+ return value
257
+
258
+ def merge(self, left: T, right: T) -> None:
259
+ if left == right:
260
+ return
261
+ left = self.find(left)
262
+ right = self.find(right)
263
+ if left == right:
264
+ return
265
+ self.representatives -= 1
266
+ self.generation += 1
267
+ left, right = sorted((left, right), key=self.key) # type: ignore[arg-type]
268
+ self.table[right] = left
269
+
270
+ def merge_all(self, values: list[T]) -> None:
271
+ if len(values) > 1:
272
+ sorted_values: list[T] = sorted(values, key=self.key) # type: ignore[arg-type]
273
+ a: T = sorted_values[0] # type: ignore[reportUnknownVariableType]
274
+ for b in sorted_values[1:]: # type: ignore[reportUnknownVariableType]
275
+ self.merge(a, b) # type: ignore[reportUnknownArgumentType]
276
+
277
+ def __repr__(self) -> str:
278
+ return "%s(%d components)" % (
279
+ type(self).__name__,
280
+ len(self.components()),
281
+ )
282
+
283
+
284
+ class BooleanEquivalence(UnionFind[int]):
285
+ table: "NegatingMap" # type: ignore[reportIncompatibleVariableOverride]
286
+
287
+ def __init__(self, initial_merges: Iterable[tuple[int, int]] = ()) -> None:
288
+ super().__init__(initial_merges, key=abs)
289
+ self.table = NegatingMap() # pyright: ignore[reportIncompatibleVariableOverride]
290
+
291
+ def find(self, value: int) -> int:
292
+ if not value:
293
+ raise ValueError("Invalid variable %r" % (value,))
294
+ return super().find(value)
295
+
296
+ def merge(self, left: int, right: int) -> None:
297
+ if left == right:
298
+ return
299
+ left2 = self.find(left)
300
+ right2 = self.find(right)
301
+ if left2 == right2:
140
302
  return
303
+ if left2 == -right2:
304
+ raise Inconsistent(
305
+ "Attempted to merge %d (=%d) with %d (=%d)"
306
+ % (left, left2, right, right2)
307
+ )
308
+ super().merge(left, right)
309
+
310
+
311
+ class NegatingMap:
312
+ _data: dict[int, int]
313
+
314
+ def __init__(self) -> None:
315
+ self._data = {}
316
+
317
+ def __repr__(self) -> str:
318
+ m: dict[int, int] = {}
319
+ for k, v in self._data.items():
320
+ m[k] = v
321
+ m[-k] = -v
322
+ return repr(m)
323
+
324
+ def __iter__(self) -> Iterator[int]:
325
+ yield from self._data.keys()
326
+ for k in self._data.keys():
327
+ yield -k
328
+
329
+ def __getitem__(self, key: int) -> int:
330
+ assert key != 0
331
+ if key < 0:
332
+ return -self._data[-key]
333
+ else:
334
+ return self._data[key]
335
+
336
+ def __setitem__(self, key: int, value: int) -> None:
337
+ assert key != 0
338
+ assert value != 0
339
+ if key < 0:
340
+ self._data[-key] = -value
341
+ else:
342
+ self._data[key] = value
343
+
344
+
345
+ async def merge_literals(problem: ReductionProblem[SAT]) -> None:
346
+ """Merge pairs of literals into single variables.
347
+
348
+ Tries to identify pairs of literals that can be treated as equivalent
349
+ (or negations of each other) and replaces them with a single variable.
350
+ This reduces the number of distinct variables in the formula.
351
+ """
352
+
353
+ def apply_merges(terms: frozenset[tuple[int, int]], sat: SAT) -> SAT:
354
+ uf = BooleanEquivalence()
355
+ try:
356
+ for u, v in terms:
357
+ uf.merge(u, v)
358
+ except Inconsistent:
359
+ raise Conflict()
360
+
361
+ result: set[frozenset[int]] = set()
362
+ for clause in sat:
363
+ new_clause = frozenset(map(uf.find, clause))
364
+ result.add(new_clause)
365
+ return sorted([sorted(clause, key=abs) for clause in result], key=len)
366
+
367
+ await apply_patches(
368
+ problem,
369
+ SetPatches(apply_merges),
370
+ [
371
+ frozenset({(u, -v)})
372
+ for clause in problem.current_test_case
373
+ for u in clause
374
+ for v in clause
375
+ if u != v
376
+ ],
377
+ )
378
+ await unit_propagate(problem)
379
+
380
+
381
+ async def pass_to_component(problem: ReductionProblem[SAT]) -> None:
382
+ """Try to reduce to a single connected component.
383
+
384
+ If the formula can be split into independent components (clauses that
385
+ share no variables), tries each component individually to see if any
386
+ single component is sufficient to maintain interestingness.
387
+ """
388
+ groups: UnionFind[int] = UnionFind()
389
+ clauses = problem.current_test_case
390
+ for clause in clauses:
391
+ groups.merge_all(list(map(abs, clause)))
392
+ partitions: defaultdict[int, SAT] = defaultdict(list)
393
+ for clause in clauses:
394
+ partitions[groups.find(abs(clause[0]))].append(clause)
395
+ if len(partitions) > 1:
396
+ for p in sorted(partitions.values(), key=len):
397
+ await problem.is_interesting(p)
398
+
399
+
400
+ async def sort_clauses(problem: ReductionProblem[SAT]) -> None:
401
+ """Sort clauses and literals into canonical order.
402
+
403
+ Sorts literals within each clause and sorts clauses themselves.
404
+ This normalizes the formula representation for consistent output.
405
+ """
406
+ await problem.is_interesting(sorted(map(sorted, problem.current_test_case)))
141
407
 
142
- target = variables[i]
143
- to_replace = variables[j]
144
-
145
- new_clauses = []
146
- for c in problem.current_test_case:
147
- c = set(c)
148
- if to_replace in c:
149
- c.discard(to_replace)
150
- c.add(target)
151
- if -to_replace in c:
152
- c.discard(-to_replace)
153
- c.add(-target)
154
- if len(set(map(abs, c))) < len(c):
408
+
409
+ class Inconsistent(Exception):
410
+ pass
411
+
412
+
413
+ class UnitPropagator:
414
+ __clauses: list[tuple[int, ...]]
415
+ __clause_counts: Counter[int]
416
+ __watches: defaultdict[int, frozenset[int]]
417
+ __watched_by: list[frozenset[int]]
418
+ units: set[int]
419
+ forced_variables: set[int]
420
+ __dirty: set[int]
421
+
422
+ def __init__(self, clauses: Iterable[Iterable[int]]) -> None:
423
+ self.__clauses = [tuple(c) for c in clauses]
424
+ self.__clause_counts = Counter()
425
+ for clause in self.__clauses:
426
+ for literal in clause:
427
+ self.__clause_counts[abs(literal)] += 1
428
+ self.__watches = defaultdict(frozenset)
429
+ self.__watched_by = [frozenset() for _ in self.__clauses]
430
+
431
+ self.units = set()
432
+ self.forced_variables = set()
433
+ self.__dirty = set(range(len(self.__clauses)))
434
+ self.__clean_dirty_clauses()
435
+
436
+ def __enqueue_unit(self, unit: int) -> None:
437
+ # Invariant: unit should not already be in self.units because satisfied
438
+ # clauses are skipped at line 424 before we try to enqueue their units.
439
+ assert unit not in self.units, f"unit {unit} already enqueued"
440
+ # Invariant: -unit should not be in self.units because we only add
441
+ # literals to watched_by if their negation is not in units (line 438).
442
+ assert -unit not in self.units, (
443
+ f"Tried to add {unit} as a unit but {-unit} is already a unit"
444
+ )
445
+ self.units.add(unit)
446
+ self.forced_variables.add(abs(unit))
447
+ self.__dirty.update(self.__watches.pop(-unit, ()))
448
+
449
+ def __clean_dirty_clauses(self) -> None:
450
+ iters = 0
451
+ while self.__dirty:
452
+ iters += 1
453
+ assert iters <= 10**6
454
+ dirty = self.__dirty
455
+ self.__dirty = set()
456
+
457
+ for i in dirty:
458
+ clause = self.__clauses[i]
459
+ if not clause:
460
+ raise Inconsistent("Clauses contain an empty clause")
461
+ if any(literal in self.units for literal in clause):
462
+ for literal in self.__watched_by[i]:
463
+ if literal in self.__watches:
464
+ self.__watches[literal] -= {i}
465
+ self.__watched_by[i] = frozenset()
466
+ for literal in clause:
467
+ self.__clause_counts[abs(literal)] -= 1
468
+ else:
469
+ for literal in list(self.__watched_by[i]):
470
+ if -literal in self.units:
471
+ self.__watched_by[i] -= {literal}
472
+ for literal in clause:
473
+ if len(self.__watched_by[i]) == 2:
474
+ break
475
+ if -literal not in self.units:
476
+ self.__watches[literal] |= {i}
477
+ self.__watched_by[i] |= {literal}
478
+ if len(self.__watched_by[i]) == 0:
479
+ raise Inconsistent(
480
+ f"Clause {' '.join(map(str, clause))} can no longer be satisfied"
481
+ )
482
+ elif len(self.__watched_by[i]) == 1:
483
+ self.__enqueue_unit(*self.__watched_by[i])
484
+
485
+ def propagated_clauses(self) -> SAT:
486
+ results: set[tuple[int, ...]] = set()
487
+ neg_units = {-v for v in self.units}
488
+ for clause in self.__clauses:
489
+ if any(literal in self.units for literal in clause):
155
490
  continue
156
- new_clauses.append(sorted(c))
491
+ if not neg_units.isdisjoint(clause):
492
+ clause = tuple(sorted(set(clause) - neg_units))
493
+ results.add(clause)
494
+ return [[literal] for literal in self.units] + [
495
+ list(c)
496
+ for c in sorted(results, key=lambda c: (len(c), list(map(abs, c)), c))
497
+ ]
498
+
499
+
500
+ async def unit_propagate(problem: ReductionProblem[SAT]) -> None:
501
+ """Apply unit propagation to simplify the formula.
502
+
503
+ Finds unit clauses (single-literal clauses) and propagates their
504
+ implications: removes satisfied clauses and removes the negated
505
+ literal from other clauses. This is a standard SAT preprocessing step.
506
+ """
507
+ try:
508
+ propagated = UnitPropagator(problem.current_test_case).propagated_clauses()
509
+ except Inconsistent:
510
+ # Clauses are unsatisfiable, nothing to propagate
511
+ return
512
+ if not await problem.is_interesting([c for c in propagated if len(c) > 1]):
513
+ await problem.is_interesting(propagated)
514
+
515
+
516
+ async def force_literals(problem: ReductionProblem[SAT]) -> None:
517
+ """Try forcing each literal to a specific value.
518
+
519
+ For each literal in the formula, tries adding it as a unit clause
520
+ and propagating. If the result is interesting, the formula is
521
+ simplified by that forced assignment.
522
+ """
523
+ literals = literals_in(problem.current_test_case)
524
+ for lit in literals:
525
+ try:
526
+ await problem.is_interesting(
527
+ UnitPropagator(problem.current_test_case + [[lit]]).propagated_clauses()
528
+ )
529
+ except Inconsistent:
530
+ pass
157
531
 
158
- assert new_clauses != problem.current_test_case
159
- await problem.is_interesting(new_clauses)
160
- if new_clauses != problem.current_test_case:
161
- j += 1
162
532
 
533
+ async def combine_clauses(problem: ReductionProblem[SAT]) -> None:
534
+ """Merge pairs of clauses into single clauses.
535
+
536
+ Tries to combine clauses that share literals, creating a single
537
+ clause containing all literals from both. This reduces clause count
538
+ while potentially creating larger but fewer clauses.
539
+ """
540
+
541
+ def apply_merges(terms: frozenset[tuple[int, int]], sat: SAT) -> SAT:
542
+ uf: UnionFind[int] = UnionFind()
543
+ for u, v in terms:
544
+ uf.merge(u, v)
545
+
546
+ result: list[Clause | None] = [list(c) for c in sat]
547
+ for c in uf.components():
548
+ # Note: len(c) == 1 can't occur because every element in uf
549
+ # came from a merge(u, v) pair where u != v, so all
550
+ # components have size >= 2.
551
+ combined: Clause = sorted({lit for i in c for lit in sat[i]}, key=abs)
552
+ for i in c:
553
+ result[i] = None
554
+ if len(combined) == len(set(map(abs, combined))):
555
+ result.append(combined)
556
+ return [clause for clause in result if clause is not None]
557
+
558
+ by_literal: defaultdict[int, list[int]] = defaultdict(list)
559
+ for i, clause in enumerate(problem.current_test_case):
560
+ for lit in clause:
561
+ by_literal[lit].append(i)
163
562
 
164
- async def sort_clauses(problem: ReductionProblem[SAT]):
165
- await problem.is_interesting(sorted(map(sorted, problem.current_test_case)))
563
+ await apply_patches(
564
+ problem,
565
+ SetPatches(apply_merges),
566
+ [
567
+ frozenset({(i, j)})
568
+ for group in by_literal.values()
569
+ for i in group
570
+ for j in group
571
+ if i != j
572
+ ]
573
+ + [frozenset({(i, i + 1)}) for i in range(len(problem.current_test_case) - 1)],
574
+ )
575
+ await unit_propagate(problem)
166
576
 
167
577
 
168
578
  SAT_PASSES: list[ReductionPass[SAT]] = [
169
579
  sort_clauses,
170
- renumber_variables,
171
- flip_literal_signs,
172
- remove_redundant_clauses,
173
- delete_elements,
580
+ force_literals,
581
+ pass_to_component,
582
+ unit_propagate,
174
583
  delete_literals,
175
- merge_variables,
584
+ delete_single_terms,
585
+ delete_elements,
586
+ flip_literal_signs,
587
+ combine_clauses,
588
+ merge_literals,
589
+ renumber_variables,
176
590
  ]