crosshair-tool 0.0.95__cp310-cp310-macosx_11_0_arm64.whl → 0.0.97__cp310-cp310-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of crosshair-tool might be problematic. Click here for more details.

Files changed (46) hide show
  1. _crosshair_tracers.cpython-310-darwin.so +0 -0
  2. crosshair/__init__.py +1 -1
  3. crosshair/_tracers_test.py +5 -5
  4. crosshair/codeconfig.py +3 -2
  5. crosshair/condition_parser.py +1 -0
  6. crosshair/condition_parser_test.py +0 -2
  7. crosshair/core.py +8 -9
  8. crosshair/core_test.py +2 -3
  9. crosshair/diff_behavior_test.py +0 -2
  10. crosshair/dynamic_typing.py +3 -3
  11. crosshair/enforce.py +1 -0
  12. crosshair/examples/check_examples_test.py +1 -0
  13. crosshair/fnutil.py +2 -3
  14. crosshair/fnutil_test.py +1 -4
  15. crosshair/fuzz_core_test.py +9 -1
  16. crosshair/libimpl/arraylib.py +1 -1
  17. crosshair/libimpl/builtinslib.py +77 -24
  18. crosshair/libimpl/builtinslib_ch_test.py +15 -5
  19. crosshair/libimpl/builtinslib_test.py +38 -1
  20. crosshair/libimpl/collectionslib_test.py +4 -4
  21. crosshair/libimpl/datetimelib.py +1 -3
  22. crosshair/libimpl/datetimelib_ch_test.py +5 -5
  23. crosshair/libimpl/encodings/_encutil.py +11 -6
  24. crosshair/libimpl/functoolslib.py +8 -2
  25. crosshair/libimpl/functoolslib_test.py +22 -6
  26. crosshair/libimpl/relib.py +1 -1
  27. crosshair/libimpl/unicodedatalib_test.py +3 -3
  28. crosshair/main.py +5 -3
  29. crosshair/opcode_intercept.py +45 -17
  30. crosshair/path_cover.py +5 -1
  31. crosshair/pathing_oracle.py +40 -3
  32. crosshair/pathing_oracle_test.py +21 -0
  33. crosshair/register_contract.py +1 -0
  34. crosshair/register_contract_test.py +2 -4
  35. crosshair/simplestructs.py +10 -8
  36. crosshair/statespace.py +74 -19
  37. crosshair/statespace_test.py +16 -0
  38. crosshair/tools/generate_demo_table.py +2 -2
  39. crosshair/tracers.py +8 -6
  40. crosshair/util.py +6 -6
  41. {crosshair_tool-0.0.95.dist-info → crosshair_tool-0.0.97.dist-info}/METADATA +4 -5
  42. {crosshair_tool-0.0.95.dist-info → crosshair_tool-0.0.97.dist-info}/RECORD +46 -45
  43. {crosshair_tool-0.0.95.dist-info → crosshair_tool-0.0.97.dist-info}/WHEEL +0 -0
  44. {crosshair_tool-0.0.95.dist-info → crosshair_tool-0.0.97.dist-info}/entry_points.txt +0 -0
  45. {crosshair_tool-0.0.95.dist-info → crosshair_tool-0.0.97.dist-info}/licenses/LICENSE +0 -0
  46. {crosshair_tool-0.0.95.dist-info → crosshair_tool-0.0.97.dist-info}/top_level.txt +0 -0
@@ -939,6 +939,31 @@ def test_str_replace_method() -> None:
939
939
  check_states(f, POST_FAIL)
940
940
 
941
941
 
942
+ def test_str_startswith(space) -> None:
943
+ symbolic_char = proxy_for_type(str, "x")
944
+ symbolic_empty = proxy_for_type(str, "y")
945
+ with ResumedTracing():
946
+ space.add(len(symbolic_char) == 1)
947
+ space.add(len(symbolic_empty) == 0)
948
+ assert symbolic_char.startswith(symbolic_empty)
949
+ assert symbolic_char.startswith(symbolic_char)
950
+ assert symbolic_char.startswith(("foo", symbolic_empty))
951
+ assert not symbolic_char.startswith(("foo", "bar"))
952
+ assert symbolic_char.startswith(("", "bar"))
953
+ assert symbolic_char.startswith("")
954
+ assert symbolic_char.startswith(symbolic_empty, 1)
955
+ assert symbolic_char.startswith(symbolic_empty, 1, 1)
956
+ assert str.startswith(symbolic_char, symbolic_empty)
957
+ assert "foo".startswith(symbolic_empty)
958
+ assert not "".startswith(symbolic_char)
959
+
960
+ # Yes, the empty string is findable off the left side but not the right
961
+ assert "x".startswith("", -10, -9)
962
+ assert symbolic_char.startswith(symbolic_empty, -10, -9)
963
+ assert not "x".startswith("", 9, 10)
964
+ assert not symbolic_char.startswith(symbolic_empty, 9, 10)
965
+
966
+
942
967
  @pytest.mark.demo
943
968
  def test_str_index_method() -> None:
944
969
  def f(a: str) -> int:
@@ -1425,7 +1450,7 @@ def test_str_lower():
1425
1450
 
1426
1451
 
1427
1452
  def test_str_title():
1428
- chr_lj = "\u01C9" # "lj"
1453
+ chr_lj = "\u01c9" # "lj"
1429
1454
  chr_Lj = "\u01c8" # "Lj" (different from "LJ", "\u01c7")
1430
1455
  with standalone_statespace:
1431
1456
  with NoTracing():
@@ -2624,6 +2649,7 @@ if sys.version_info >= (3, 9):
2624
2649
  def test_set_basic_fail() -> None:
2625
2650
  def f(a: Set[int], k: int) -> None:
2626
2651
  """
2652
+ pre: len(a) <= 2
2627
2653
  post[a]: k+1 in a
2628
2654
  """
2629
2655
  a.add(k)
@@ -3134,6 +3160,17 @@ def test_callable_as_bool() -> None:
3134
3160
  check_states(f, CONFIRMED)
3135
3161
 
3136
3162
 
3163
+ def test_callable_can_return_different_values(space) -> None:
3164
+ fn = proxy_for_type(Callable[[], int], "fn")
3165
+ with ResumedTracing():
3166
+ first_return = fn()
3167
+ second_return = fn()
3168
+ returns_are_equal = first_return == second_return
3169
+ returns_are_not_equal = first_return != second_return
3170
+ assert space.is_possible(returns_are_equal)
3171
+ assert space.is_possible(returns_are_not_equal)
3172
+
3173
+
3137
3174
  @pytest.mark.smoke
3138
3175
  def test_callable_repr() -> None:
3139
3176
  def f(f1: Callable[[int], int]) -> int:
@@ -3,7 +3,7 @@ import sys
3
3
  from collections import Counter, defaultdict, deque, namedtuple
4
4
  from copy import deepcopy
5
5
  from inspect import Parameter, Signature
6
- from typing import Counter, DefaultDict, Deque, NamedTuple, Tuple
6
+ from typing import Callable, Counter, DefaultDict, Deque, Dict, NamedTuple, Tuple
7
7
 
8
8
  import pytest
9
9
 
@@ -14,7 +14,7 @@ from crosshair.core import (
14
14
  realize,
15
15
  standalone_statespace,
16
16
  )
17
- from crosshair.libimpl.collectionslib import ListBasedDeque
17
+ from crosshair.libimpl.collectionslib import ListBasedDeque, PureDefaultDict
18
18
  from crosshair.statespace import CANNOT_CONFIRM, CONFIRMED, POST_FAIL, MessageType
19
19
  from crosshair.test_util import check_states
20
20
  from crosshair.tracers import NoTracing, ResumedTracing
@@ -246,10 +246,10 @@ def test_defaultdict_default_fail(test_list) -> None:
246
246
 
247
247
 
248
248
  def test_defaultdict_default_ok(test_list) -> None:
249
- def f(a: DefaultDict[int, int], k1: int, k2: int) -> DefaultDict[int, int]:
249
+ def f(a: DefaultDict[int, int], k: int) -> DefaultDict[int, int]:
250
250
  """
251
251
  pre: len(a) == 0 and a.default_factory is not None
252
- post: _[k1] == _[k2]
252
+ post: _[k] == _[k]
253
253
  """
254
254
  return a
255
255
 
@@ -627,9 +627,7 @@ class timedelta:
627
627
 
628
628
  def total_seconds(self):
629
629
  """Total seconds in the duration."""
630
- return (
631
- (self.days * 86400 + self.seconds) * 10**6 + self.microseconds
632
- ) / 10**6
630
+ return ((self.days * 86400 + self.seconds) * 10**6 + self.microseconds) / 10**6
633
631
 
634
632
  # Read-only field accessors
635
633
  @property
@@ -29,7 +29,7 @@ def check_datetimelib_lt(
29
29
  Tuple[timedelta, timedelta],
30
30
  Tuple[date, datetime],
31
31
  Tuple[datetime, datetime],
32
- ]
32
+ ],
33
33
  ) -> ResultComparison:
34
34
  """post: _"""
35
35
  return compare_results(operator.lt, *p)
@@ -40,7 +40,7 @@ def check_datetimelib_add(
40
40
  Tuple[timedelta, timedelta],
41
41
  Tuple[date, timedelta],
42
42
  Tuple[timedelta, datetime],
43
- ]
43
+ ],
44
44
  ) -> ResultComparison:
45
45
  """post: _"""
46
46
  return compare_results(operator.add, *p)
@@ -52,14 +52,14 @@ def check_datetimelib_subtract(
52
52
  Tuple[date, timedelta],
53
53
  Tuple[datetime, timedelta],
54
54
  Tuple[datetime, datetime],
55
- ]
55
+ ],
56
56
  ) -> ResultComparison:
57
57
  """post: _"""
58
58
  return compare_results(operator.sub, *p)
59
59
 
60
60
 
61
61
  def check_datetimelib_str(
62
- obj: Union[timedelta, timezone, date, time, datetime]
62
+ obj: Union[timedelta, timezone, date, time, datetime],
63
63
  ) -> ResultComparison:
64
64
  """post: _"""
65
65
  return compare_results(_invoker("__str__"), obj)
@@ -67,7 +67,7 @@ def check_datetimelib_str(
67
67
 
68
68
  def check_datetimelib_repr(
69
69
  # TODO: re-enable time, datetime repr checking after fixing in Python 3.11
70
- obj: Union[timedelta, timezone, date]
70
+ obj: Union[timedelta, timezone, date],
71
71
  ) -> ResultComparison:
72
72
  """post: _"""
73
73
  return compare_results(_invoker("__repr__"), obj)
@@ -25,6 +25,7 @@ class UnexpectedEndError(ChunkError):
25
25
  @dataclass
26
26
  class MidChunkError(ChunkError):
27
27
  _reason: str
28
+
28
29
  # _errlen: int = 1
29
30
  def reason(self) -> str:
30
31
  return self._reason
@@ -112,7 +113,7 @@ class StemEncoder:
112
113
  continue
113
114
  if errors == "replace":
114
115
  idx += 1
115
- parts.append("\uFFFD")
116
+ parts.append("\ufffd")
116
117
  continue
117
118
 
118
119
  # 2. Then fall back to native implementations if necessary:
@@ -132,24 +133,28 @@ class StemEncoder:
132
133
 
133
134
  def _getregentry(stem_encoder: Type[StemEncoder]):
134
135
  class StemIncrementalEncoder(codecs.BufferedIncrementalEncoder):
135
- def _buffer_encode(self, input: str, errors: str, final: bool) -> bytes:
136
+ def _buffer_encode(
137
+ self, input: str, errors: str, final: bool
138
+ ) -> Tuple[bytes, int]:
136
139
  enc_name = stem_encoder.encoding_name
137
140
  out, idx, err = stem_encoder._encode_chunk(input, 0)
138
141
  assert isinstance(out, bytes)
139
142
  if not err:
140
- return out
143
+ return (out, idx)
141
144
  if isinstance(err, UnexpectedEndError) or not final:
142
- return out
145
+ return (out, idx)
143
146
  exc = UnicodeEncodeError(enc_name, input, idx, idx + 1, err.reason())
144
147
  replacement, idx = codecs.lookup_error(errors)(exc)
145
148
  if isinstance(replacement, str):
146
149
  replacement = codecs.encode(replacement, enc_name)
147
- return out + replacement
150
+ return (out + replacement, idx)
148
151
 
149
152
  class StemIncrementalDecoder(codecs.BufferedIncrementalDecoder):
150
153
  def _buffer_decode(
151
- self, input: bytes, errors: str, final: bool
154
+ self, input: Buffer, errors: str, final: bool
152
155
  ) -> Tuple[str, int]:
156
+ if not isinstance(input, bytes):
157
+ input = memoryview(input).tobytes()
153
158
  enc_name = stem_encoder.encoding_name
154
159
  out, idx, err = stem_encoder._decode_chunk(input, 0)
155
160
  assert isinstance(out, str)
@@ -1,4 +1,4 @@
1
- from functools import _lru_cache_wrapper, partial, reduce
1
+ from functools import _lru_cache_wrapper, partial, reduce, update_wrapper, wraps
2
2
 
3
3
  from crosshair.core import register_patch
4
4
 
@@ -7,7 +7,13 @@ from crosshair.core import register_patch
7
7
 
8
8
  def _partial(func, *a1, **kw1):
9
9
  if callable(func):
10
- return partial(lambda *a2, **kw2: func(*a2, **kw2), *a1, **kw1)
10
+ # We make a do-nothing wrapper to ensure that the tracer has a crack
11
+ # at this function when it is called.
12
+ def wrapper(*a2, **kw2):
13
+ return func(*a2, **kw2)
14
+
15
+ update_wrapper(wrapper, func)
16
+ return partial(wrapper, *a1, **kw1)
11
17
  else:
12
18
  raise TypeError
13
19
 
@@ -1,20 +1,36 @@
1
1
  import functools
2
+ import inspect
2
3
 
3
4
  from crosshair.core import proxy_for_type, standalone_statespace
4
5
  from crosshair.libimpl.builtinslib import LazyIntSymbolicStr
5
- from crosshair.tracers import NoTracing
6
+ from crosshair.tracers import NoTracing, ResumedTracing
6
7
 
7
8
 
8
- def test_partial():
9
- with standalone_statespace as space:
10
- with NoTracing():
11
- abc = LazyIntSymbolicStr(list(map(ord, "abc")))
12
- xyz = LazyIntSymbolicStr(list(map(ord, "xyz")))
9
+ def test_partial(space):
10
+ abc = LazyIntSymbolicStr(list(map(ord, "abc")))
11
+ xyz = LazyIntSymbolicStr(list(map(ord, "xyz")))
12
+ with ResumedTracing():
13
13
  joiner = functools.partial(str.join, ",")
14
14
  ret = joiner([abc, xyz])
15
15
  assert ret == "abc,xyz"
16
16
 
17
17
 
18
+ def test_partial_is_interceptable(space):
19
+ x = proxy_for_type(str, "x")
20
+ y = proxy_for_type(str, "y")
21
+ with ResumedTracing():
22
+ joiner = functools.partial(str.startswith, x)
23
+ # Ensure we don't explode
24
+ list(map(joiner, ["foo", y]))
25
+
26
+
27
+ def test_partial_arg_is_inspectable(space):
28
+ with ResumedTracing():
29
+ joiner = functools.partial(str.join, ",")
30
+ assert isinstance(joiner, functools.partial)
31
+ assert inspect.getdoc(joiner.func) == inspect.getdoc(str.join)
32
+
33
+
18
34
  def test_reduce():
19
35
  with standalone_statespace as space:
20
36
  with NoTracing():
@@ -7,7 +7,7 @@ from unicodedata import category
7
7
  if sys.version_info < (3, 11):
8
8
  import sre_parse as re_parser
9
9
  else:
10
- import re._parser as re_parser
10
+ import re._parser as re_parser # type: ignore
11
11
 
12
12
  from sys import maxunicode
13
13
  from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union, cast
@@ -11,7 +11,7 @@ def test_numeric():
11
11
  with standalone_statespace as space:
12
12
  with NoTracing():
13
13
  fourstr = LazyIntSymbolicStr(list(map(ord, "4")))
14
- halfstr = LazyIntSymbolicStr(list(map(ord, "\u00BD"))) # (1/2 character)
14
+ halfstr = LazyIntSymbolicStr(list(map(ord, "\u00bd"))) # (1/2 character)
15
15
  four = unicodedata.numeric(fourstr)
16
16
  half = unicodedata.numeric(halfstr)
17
17
  assert type(four) is float
@@ -22,7 +22,7 @@ def test_numeric():
22
22
  def test_decimal():
23
23
  with standalone_statespace as space:
24
24
  with NoTracing():
25
- thai4 = LazyIntSymbolicStr(list(map(ord, "\u0E54"))) # (Thai numerial 4)
25
+ thai4 = LazyIntSymbolicStr(list(map(ord, "\u0e54"))) # (Thai numerial 4)
26
26
  super4 = LazyIntSymbolicStr(list(map(ord, "\u2074"))) # (superscript 4)
27
27
  four = unicodedata.decimal(thai4)
28
28
  assert type(four) is int
@@ -34,7 +34,7 @@ def test_decimal():
34
34
  def test_digit():
35
35
  with standalone_statespace as space:
36
36
  with NoTracing():
37
- thai4 = LazyIntSymbolicStr(list(map(ord, "\u0E54"))) # (Thai numerial 4)
37
+ thai4 = LazyIntSymbolicStr(list(map(ord, "\u0e54"))) # (Thai numerial 4)
38
38
  super4 = LazyIntSymbolicStr(list(map(ord, "\u2074"))) # (superscript 4)
39
39
  four = unicodedata.digit(thai4)
40
40
  assert type(four) is int
crosshair/main.py CHANGED
@@ -774,9 +774,11 @@ def cover(
774
774
  ctxfn,
775
775
  options,
776
776
  args.coverage_type,
777
- arg_formatter=format_boundargs_as_dictionary
778
- if example_output_format == ExampleOutputFormat.ARG_DICTIONARY
779
- else format_boundargs,
777
+ arg_formatter=(
778
+ format_boundargs_as_dictionary
779
+ if example_output_format == ExampleOutputFormat.ARG_DICTIONARY
780
+ else format_boundargs
781
+ ),
780
782
  )
781
783
  except NotDeterministic:
782
784
  print(
@@ -5,7 +5,9 @@ from collections import defaultdict
5
5
  from collections.abc import MutableMapping, Set
6
6
  from sys import version_info
7
7
  from types import CodeType, FrameType
8
- from typing import Any, Callable, Iterable, Mapping, Tuple, Union
8
+ from typing import Any, Callable, Iterable, List, Mapping, Tuple, Union
9
+
10
+ from z3 import ExprRef # type: ignore
9
11
 
10
12
  from crosshair.core import (
11
13
  ATOMIC_IMMUTABLE_TYPES,
@@ -53,6 +55,7 @@ TO_BOOL = dis.opmap.get("TO_BOOL", 256)
53
55
  IS_OP = dis.opmap.get("IS_OP", 256)
54
56
  BINARY_MODULO = dis.opmap.get("BINARY_MODULO", 256)
55
57
  BINARY_OP = dis.opmap.get("BINARY_OP", 256)
58
+ LOAD_COMMON_CONSTANT = dis.opmap.get("LOAD_COMMON_CONSTANT", 256)
56
59
 
57
60
 
58
61
  def frame_op_arg(frame):
@@ -83,13 +86,6 @@ class MultiSubscriptableContainer:
83
86
  if isinstance(container, Mapping):
84
87
  kv_pairs: Iterable[Tuple[Any, Any]] = container.items()
85
88
  else:
86
- in_bounds = space.smt_fork(
87
- z3Or(-len(container) <= key.var, key.var < len(container)),
88
- desc=f"index_in_bounds",
89
- probability_true=0.9,
90
- )
91
- if not in_bounds:
92
- raise IndexError
93
89
  kv_pairs = enumerate(container)
94
90
 
95
91
  values_by_type = defaultdict(list)
@@ -118,7 +114,7 @@ class MultiSubscriptableContainer:
118
114
  keys_by_value_id[id(cur_value)].append(cur_key)
119
115
  for value_type, cur_pairs in values_by_type.items():
120
116
  hypothetical_result = symbolic_for_pytype(value_type)(
121
- "item_at_" + space.uniq(), value_type
117
+ "item_" + space.uniq(), value_type
122
118
  )
123
119
  with ResumedTracing():
124
120
  condition_pairs = []
@@ -139,20 +135,17 @@ class MultiSubscriptableContainer:
139
135
  space.add(any([all(pair) for pair in condition_pairs]))
140
136
  return hypothetical_result
141
137
 
142
- for (value_id, value), probability_true in with_uniform_probabilities(
143
- values_by_id.items()
144
- ):
138
+ exprs_and_values: List[Tuple[ExprRef, object]] = []
139
+ for value_id, value in values_by_id.items():
145
140
  keys_for_value = keys_by_value_id[value_id]
146
141
  with ResumedTracing():
147
142
  is_match = any([key == k for k in keys_for_value])
148
143
  if isinstance(is_match, SymbolicBool):
149
- if space.smt_fork(
150
- is_match.var,
151
- probability_true=probability_true,
152
- ):
153
- return value
144
+ exprs_and_values.append((is_match.var, value))
154
145
  elif is_match:
155
146
  return value
147
+ if exprs_and_values:
148
+ return space.smt_fanout(exprs_and_values, desc="multi_subscript")
156
149
 
157
150
  if type(container) is dict:
158
151
  raise KeyError # ( f"Key {key} not found in dict")
@@ -160,6 +153,39 @@ class MultiSubscriptableContainer:
160
153
  raise IndexError # (f"Index {key} out of range for list/tuple of length {len(container)}")
161
154
 
162
155
 
156
+ class LoadCommonConstantInterceptor(TracingModule):
157
+ """
158
+ As of 3.14, the bytecode generation process generates optimizations
159
+ for builtins.any/all when invoked on a generator expression.
160
+ It essentially "inlines" the logic as bytecode.
161
+ We need to avoid this.
162
+ Before entering the optimized code path, it will check that the any/all
163
+ function is identity-equal to the original builtin, which is loaded using
164
+ the LOAD_COMMON_CONSTANT opcode.
165
+
166
+ This interceptor replaces that function with a proxy that functions
167
+ identically but is not identity-equal (so that we avoid the optimized
168
+ path),
169
+ """
170
+
171
+ opcodes_wanted = frozenset([LOAD_COMMON_CONSTANT])
172
+
173
+ def trace_op(self, frame, codeobj, codenum):
174
+ CONSTANT_BUILTIN_ALL = 3
175
+ CONSTANT_BUILTIN_ANY = 4
176
+ index = frame_op_arg(frame)
177
+
178
+ def post_op():
179
+ expected_fn = all if index == CONSTANT_BUILTIN_ALL else any
180
+ if CROSSHAIR_EXTRA_ASSERTS:
181
+ if frame_stack_read(frame, -1) is not expected_fn:
182
+ raise CrossHairInternal
183
+ frame_stack_write(frame, -1, lambda *a: expected_fn(*a))
184
+
185
+ if index == CONSTANT_BUILTIN_ALL or index == CONSTANT_BUILTIN_ANY:
186
+ COMPOSITE_TRACER.set_postop_callback(post_op, frame)
187
+
188
+
163
189
  class SymbolicSubscriptInterceptor(TracingModule):
164
190
  opcodes_wanted = frozenset([BINARY_SUBSCR, BINARY_OP])
165
191
 
@@ -562,6 +588,8 @@ def make_registrations():
562
588
  register_opcode_patch(SymbolicSliceInterceptor())
563
589
  if sys.version_info < (3, 9):
564
590
  register_opcode_patch(ComparisonInterceptForwarder())
591
+ if sys.version_info >= (3, 14):
592
+ register_opcode_patch(LoadCommonConstantInterceptor())
565
593
  register_opcode_patch(ContainmentInterceptor())
566
594
  register_opcode_patch(BuildStringInterceptor())
567
595
  register_opcode_patch(FormatValueInterceptor())
crosshair/path_cover.py CHANGED
@@ -133,7 +133,11 @@ def path_cover(
133
133
  selected: List[PathSummary] = []
134
134
  while paths:
135
135
  next_best = max(
136
- paths, key=lambda p: len(p.coverage.offsets_covered - opcodes_found)
136
+ paths,
137
+ key=lambda p: (
138
+ len(p.coverage.offsets_covered - opcodes_found), # high coverage
139
+ -len(p.formatted_args), # with small input size
140
+ ),
137
141
  )
138
142
  cur_offsets = next_best.coverage.offsets_covered
139
143
  if coverage_type == CoverageType.OPCODE:
@@ -11,9 +11,11 @@ from crosshair.statespace import (
11
11
  NodeLike,
12
12
  RootNode,
13
13
  SearchTreeNode,
14
+ StateSpace,
14
15
  WorstResultNode,
15
16
  )
16
17
  from crosshair.util import CrossHairInternal, debug, in_debug
18
+ from crosshair.z3util import z3And, z3Not, z3Or
17
19
 
18
20
  CodeLoc = Tuple[str, ...]
19
21
 
@@ -60,7 +62,8 @@ class CoveragePathingOracle(AbstractPathingOracle):
60
62
  # (even just a 10% change could be much larger than it would be otherwise)
61
63
  _delta_probabilities = {-1: 0.1, 0: 0.25, 1: 0.9}
62
64
 
63
- def pre_path_hook(self, root: RootNode) -> None:
65
+ def pre_path_hook(self, space: StateSpace) -> None:
66
+ root = space._root
64
67
  visits = self.visits
65
68
  _delta_probabilities = self._delta_probabilities
66
69
 
@@ -213,16 +216,50 @@ class PreferNegativeOracle(AbstractPathingOracle):
213
216
  return 0.25
214
217
 
215
218
 
219
+ class ConstrainedOracle(AbstractPathingOracle):
220
+ """
221
+ A pathing oracle that prefers to take a path that satisfies
222
+ explicitly provided constraints.
223
+ """
224
+
225
+ def __init__(self, inner_oracle: AbstractPathingOracle):
226
+ self.inner_oracle = inner_oracle
227
+ self.exprs: List[ExprRef] = []
228
+
229
+ def prefer(self, expr: ExprRef):
230
+ self.exprs.append(expr)
231
+
232
+ def pre_path_hook(self, space: StateSpace) -> None:
233
+ self.space = space
234
+ self.exprs = []
235
+ self.inner_oracle.pre_path_hook(space)
236
+
237
+ def post_path_hook(self, path: Sequence["SearchTreeNode"]) -> None:
238
+ self.inner_oracle.post_path_hook(path)
239
+
240
+ def decide(
241
+ self, root, node: "WorstResultNode", engine_probability: Optional[float]
242
+ ) -> float:
243
+ # We always run the inner oracle in case it's tracking something about the path.
244
+ default_probability = self.inner_oracle.decide(root, node, engine_probability)
245
+ if not self.space.is_possible(z3And(*[node.expr, *self.exprs])):
246
+ return 0.0
247
+ elif not self.space.is_possible(z3And(*[z3Not(node.expr), *self.exprs])):
248
+ return 1.0
249
+ else:
250
+ return default_probability
251
+
252
+
216
253
  class RotatingOracle(AbstractPathingOracle):
217
254
  def __init__(self, oracles: List[AbstractPathingOracle]):
218
255
  self.oracles = oracles
219
256
  self.index = -1
220
257
 
221
- def pre_path_hook(self, root: "RootNode") -> None:
258
+ def pre_path_hook(self, space: StateSpace) -> None:
222
259
  oracles = self.oracles
223
260
  self.index = (self.index + 1) % len(oracles)
224
261
  for oracle in oracles:
225
- oracle.pre_path_hook(root)
262
+ oracle.pre_path_hook(space)
226
263
 
227
264
  def post_path_hook(self, path: Sequence["SearchTreeNode"]) -> None:
228
265
  for oracle in self.oracles:
@@ -0,0 +1,21 @@
1
+ import random
2
+
3
+ import z3 # type: ignore
4
+
5
+ from crosshair.pathing_oracle import ConstrainedOracle, PreferNegativeOracle
6
+ from crosshair.statespace import RootNode, SimpleStateSpace, WorstResultNode
7
+
8
+
9
+ def test_constrained_oracle():
10
+ oracle = ConstrainedOracle(PreferNegativeOracle())
11
+ x = z3.Int("x")
12
+ root = RootNode()
13
+ space = SimpleStateSpace()
14
+ oracle.pre_path_hook(space)
15
+ oracle.prefer(x >= 7)
16
+ rand = random.Random()
17
+ assert oracle.decide(root, WorstResultNode(rand, x < 7, space.solver), None) == 0.0
18
+ assert oracle.decide(root, WorstResultNode(rand, x >= 3, space.solver), None) == 1.0
19
+ assert (
20
+ oracle.decide(root, WorstResultNode(rand, x == 7, space.solver), None) == 0.25
21
+ )
@@ -1,4 +1,5 @@
1
1
  """API for registering contracts for external libraries."""
2
+
2
3
  from dataclasses import dataclass
3
4
  from inspect import Parameter, Signature, getmodule, ismethod, signature
4
5
  from types import MethodDescriptorType, ModuleType, WrapperDescriptorType
@@ -99,12 +99,10 @@ def test_register_numpy_randint():
99
99
 
100
100
  def test_register_overload():
101
101
  @overload
102
- def overld(a: int) -> int:
103
- ...
102
+ def overld(a: int) -> int: ...
104
103
 
105
104
  @overload
106
- def overld(a: str) -> str:
107
- ...
105
+ def overld(a: str) -> str: ...
108
106
 
109
107
  def overld(a: Union[int, str]) -> Union[int, str]:
110
108
  if isinstance(a, int):
@@ -38,10 +38,10 @@ class MapBase(collections.abc.MutableMapping):
38
38
  return NotImplemented
39
39
  if len(self) != len(other):
40
40
  return False
41
- for (k, self_value) in self.items():
41
+ for k, self_value in self.items():
42
42
  found = False
43
43
  # We do a slow nested loop search because we don't want to hash the key.
44
- for (other_key, other_value) in other.items():
44
+ for other_key, other_value in other.items():
45
45
  if other_key != k:
46
46
  continue
47
47
  if self_value == other_value:
@@ -122,7 +122,7 @@ class SimpleDict(MapBase):
122
122
  def __getitem__(self, key, default=_MISSING):
123
123
  if not is_hashable(key):
124
124
  raise TypeError("unhashable type")
125
- for (k, v) in self.contents_:
125
+ for k, v in self.contents_:
126
126
  # Note that the identity check below is not just an optimization;
127
127
  # it is required to implement the semantics of NaN dict keys
128
128
  if k is key or k == key:
@@ -134,7 +134,7 @@ class SimpleDict(MapBase):
134
134
  def __setitem__(self, key, value):
135
135
  if not is_hashable(key):
136
136
  raise TypeError("unhashable type")
137
- for (i, (k, v)) in enumerate(self.contents_):
137
+ for i, (k, v) in enumerate(self.contents_):
138
138
  if k == key:
139
139
  self.contents_[i] = (k, value)
140
140
  return
@@ -143,7 +143,7 @@ class SimpleDict(MapBase):
143
143
  def __delitem__(self, key):
144
144
  if not is_hashable(key):
145
145
  raise TypeError("unhashable type")
146
- for (i, (k, v)) in enumerate(self.contents_):
146
+ for i, (k, v) in enumerate(self.contents_):
147
147
  if k == key:
148
148
  del self.contents_[i]
149
149
  return
@@ -493,9 +493,11 @@ class SequenceConcatenation(collections.abc.Sequence, SeqBase):
493
493
  return second.__getitem__(
494
494
  slice(
495
495
  i.start - firstlen,
496
- i.stop
497
- if i.stop is None or i.stop < 0
498
- else i.stop - firstlen,
496
+ (
497
+ i.stop
498
+ if i.stop is None or i.stop < 0
499
+ else i.stop - firstlen
500
+ ),
499
501
  i.step,
500
502
  )
501
503
  )