crosshair-tool 0.0.56__cp39-cp39-macosx_11_0_arm64.whl → 0.0.100__cp39-cp39-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (123) hide show
  1. _crosshair_tracers.cpython-39-darwin.so +0 -0
  2. crosshair/__init__.py +1 -1
  3. crosshair/_mark_stacks.h +51 -24
  4. crosshair/_tracers.h +9 -5
  5. crosshair/_tracers_test.py +19 -9
  6. crosshair/auditwall.py +9 -8
  7. crosshair/auditwall_test.py +31 -19
  8. crosshair/codeconfig.py +3 -2
  9. crosshair/condition_parser.py +17 -133
  10. crosshair/condition_parser_test.py +54 -96
  11. crosshair/conftest.py +1 -1
  12. crosshair/copyext.py +91 -22
  13. crosshair/copyext_test.py +33 -0
  14. crosshair/core.py +259 -203
  15. crosshair/core_and_libs.py +20 -0
  16. crosshair/core_regestered_types_test.py +82 -0
  17. crosshair/core_test.py +693 -664
  18. crosshair/diff_behavior.py +76 -21
  19. crosshair/diff_behavior_test.py +132 -23
  20. crosshair/dynamic_typing.py +128 -18
  21. crosshair/dynamic_typing_test.py +91 -4
  22. crosshair/enforce.py +1 -6
  23. crosshair/enforce_test.py +15 -23
  24. crosshair/examples/check_examples_test.py +2 -1
  25. crosshair/fnutil.py +2 -3
  26. crosshair/fnutil_test.py +0 -7
  27. crosshair/fuzz_core_test.py +70 -83
  28. crosshair/libimpl/arraylib.py +10 -7
  29. crosshair/libimpl/binascii_ch_test.py +30 -0
  30. crosshair/libimpl/binascii_test.py +67 -0
  31. crosshair/libimpl/binasciilib.py +150 -0
  32. crosshair/libimpl/bisectlib_test.py +5 -5
  33. crosshair/libimpl/builtinslib.py +1002 -682
  34. crosshair/libimpl/builtinslib_ch_test.py +108 -30
  35. crosshair/libimpl/builtinslib_test.py +431 -143
  36. crosshair/libimpl/codecslib.py +22 -2
  37. crosshair/libimpl/codecslib_test.py +41 -9
  38. crosshair/libimpl/collectionslib.py +44 -8
  39. crosshair/libimpl/collectionslib_test.py +108 -20
  40. crosshair/libimpl/copylib.py +1 -1
  41. crosshair/libimpl/copylib_test.py +18 -0
  42. crosshair/libimpl/datetimelib.py +84 -67
  43. crosshair/libimpl/datetimelib_ch_test.py +12 -7
  44. crosshair/libimpl/datetimelib_test.py +5 -6
  45. crosshair/libimpl/decimallib.py +5257 -0
  46. crosshair/libimpl/decimallib_ch_test.py +78 -0
  47. crosshair/libimpl/decimallib_test.py +76 -0
  48. crosshair/libimpl/encodings/_encutil.py +21 -11
  49. crosshair/libimpl/fractionlib.py +16 -0
  50. crosshair/libimpl/fractionlib_test.py +80 -0
  51. crosshair/libimpl/functoolslib.py +19 -7
  52. crosshair/libimpl/functoolslib_test.py +22 -6
  53. crosshair/libimpl/hashliblib.py +30 -0
  54. crosshair/libimpl/hashliblib_test.py +18 -0
  55. crosshair/libimpl/heapqlib.py +32 -5
  56. crosshair/libimpl/heapqlib_test.py +15 -12
  57. crosshair/libimpl/iolib.py +7 -4
  58. crosshair/libimpl/ipaddresslib.py +8 -0
  59. crosshair/libimpl/itertoolslib_test.py +1 -1
  60. crosshair/libimpl/mathlib.py +165 -2
  61. crosshair/libimpl/mathlib_ch_test.py +44 -0
  62. crosshair/libimpl/mathlib_test.py +59 -16
  63. crosshair/libimpl/oslib.py +7 -0
  64. crosshair/libimpl/pathliblib_test.py +10 -0
  65. crosshair/libimpl/randomlib.py +1 -0
  66. crosshair/libimpl/randomlib_test.py +6 -4
  67. crosshair/libimpl/relib.py +180 -59
  68. crosshair/libimpl/relib_ch_test.py +26 -2
  69. crosshair/libimpl/relib_test.py +77 -14
  70. crosshair/libimpl/timelib.py +35 -13
  71. crosshair/libimpl/timelib_test.py +13 -3
  72. crosshair/libimpl/typeslib.py +15 -0
  73. crosshair/libimpl/typeslib_test.py +36 -0
  74. crosshair/libimpl/unicodedatalib_test.py +3 -3
  75. crosshair/libimpl/weakreflib.py +13 -0
  76. crosshair/libimpl/weakreflib_test.py +69 -0
  77. crosshair/libimpl/zliblib.py +15 -0
  78. crosshair/libimpl/zliblib_test.py +13 -0
  79. crosshair/lsp_server.py +21 -10
  80. crosshair/main.py +48 -28
  81. crosshair/main_test.py +59 -14
  82. crosshair/objectproxy.py +39 -14
  83. crosshair/objectproxy_test.py +27 -13
  84. crosshair/opcode_intercept.py +212 -24
  85. crosshair/opcode_intercept_test.py +172 -18
  86. crosshair/options.py +0 -1
  87. crosshair/patch_equivalence_test.py +5 -21
  88. crosshair/path_cover.py +7 -5
  89. crosshair/path_search.py +6 -4
  90. crosshair/path_search_test.py +1 -2
  91. crosshair/pathing_oracle.py +53 -10
  92. crosshair/pathing_oracle_test.py +21 -0
  93. crosshair/pure_importer_test.py +5 -21
  94. crosshair/register_contract.py +16 -6
  95. crosshair/register_contract_test.py +2 -14
  96. crosshair/simplestructs.py +154 -85
  97. crosshair/simplestructs_test.py +16 -2
  98. crosshair/smtlib.py +24 -0
  99. crosshair/smtlib_test.py +14 -0
  100. crosshair/statespace.py +319 -196
  101. crosshair/statespace_test.py +45 -0
  102. crosshair/stubs_parser.py +0 -2
  103. crosshair/test_util.py +87 -25
  104. crosshair/test_util_test.py +26 -0
  105. crosshair/tools/check_init_and_setup_coincide.py +0 -3
  106. crosshair/tools/generate_demo_table.py +2 -2
  107. crosshair/tracers.py +141 -49
  108. crosshair/type_repo.py +11 -4
  109. crosshair/unicode_categories.py +1 -0
  110. crosshair/util.py +158 -76
  111. crosshair/util_test.py +13 -20
  112. crosshair/watcher.py +4 -4
  113. crosshair/z3util.py +1 -1
  114. {crosshair_tool-0.0.56.dist-info → crosshair_tool-0.0.100.dist-info}/METADATA +45 -36
  115. crosshair_tool-0.0.100.dist-info/RECORD +176 -0
  116. {crosshair_tool-0.0.56.dist-info → crosshair_tool-0.0.100.dist-info}/WHEEL +2 -1
  117. crosshair/examples/hypothesis/__init__.py +0 -2
  118. crosshair/examples/hypothesis/bugs_detected/simple_strategies.py +0 -74
  119. crosshair_tool-0.0.56.dist-info/RECORD +0 -152
  120. /crosshair/{examples/hypothesis/bugs_detected/__init__.py → py.typed} +0 -0
  121. {crosshair_tool-0.0.56.dist-info → crosshair_tool-0.0.100.dist-info}/entry_points.txt +0 -0
  122. {crosshair_tool-0.0.56.dist-info → crosshair_tool-0.0.100.dist-info/licenses}/LICENSE +0 -0
  123. {crosshair_tool-0.0.56.dist-info → crosshair_tool-0.0.100.dist-info}/top_level.txt +0 -0
@@ -1,21 +1,25 @@
1
+ import operator
1
2
  import re
2
3
  import sys
4
+ from array import array
5
+ from unicodedata import category
3
6
 
4
7
  if sys.version_info < (3, 11):
5
8
  import sre_parse as re_parser
6
9
  else:
7
- import re._parser as re_parser
10
+ import re._parser as re_parser # type: ignore
11
+
8
12
  from sys import maxunicode
9
- from typing import Any, Callable, Iterable, List, Optional, Tuple, Union, cast
13
+ from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union, cast
10
14
 
11
15
  import z3 # type: ignore
12
16
 
13
17
  from crosshair.core import deep_realize, realize, register_patch, with_realized_args
14
- from crosshair.libimpl.builtinslib import AnySymbolicStr, SymbolicInt
18
+ from crosshair.libimpl.builtinslib import AnySymbolicStr, BytesLike, SymbolicInt
15
19
  from crosshair.statespace import context_statespace
16
20
  from crosshair.tracers import NoTracing, ResumedTracing, is_tracing
17
21
  from crosshair.unicode_categories import CharMask, get_unicode_categories
18
- from crosshair.util import CrosshairInternal, debug, is_iterable
22
+ from crosshair.util import CrossHairInternal, CrossHairValue, debug, is_iterable
19
23
 
20
24
  ANY = re_parser.ANY
21
25
  ASSERT = re_parser.ASSERT
@@ -51,6 +55,8 @@ class ReUnhandled(Exception):
51
55
  pass
52
56
 
53
57
 
58
+ _ALL_BYTES_TYPES = (bytes, bytearray, memoryview, array)
59
+ _STR_AND_BYTES_TYPES = (str, *_ALL_BYTES_TYPES)
54
60
  _NO_CHAR = CharMask([])
55
61
  _ANY_CHAR = CharMask([(0, maxunicode + 1)])
56
62
  _ANY_NON_NEWLINE_CHAR = _ANY_CHAR.subtract(CharMask([ord("\n")]))
@@ -74,8 +80,41 @@ _UNICODE_WHITESPACE_CHAR = _ASCII_WHITESPACE_CHAR.union(
74
80
  )
75
81
  )
76
82
 
83
_CASEABLE_CHARS = None  # Lazily-built cache; filled in by caseable_chars().


def caseable_chars():
    """
    Return a string of every character that may take part in case-insensitive matching.

    The string holds each character from chr(0) through chr(sys.maxunicode),
    in codepoint order, except those in the Unicode "Lo" (Other Letter)
    category. It is built on the first call and cached in a module global,
    so later calls return the same string object.
    """
    global _CASEABLE_CHARS
    if _CASEABLE_CHARS is None:
        codepoints = []
        for i in range(sys.maxunicode + 1):
            ch = chr(i)
            # Exclude the (large) "Other Letter" group that doesn't caseswap.
            # Compare with ==: `in ("Lo")` would be a substring test against
            # a plain string, not membership in a one-element tuple.
            if category(ch) == "Lo":
                assert ch.casefold() == ch
            else:
                codepoints.append(ch)

        _CASEABLE_CHARS = "".join(codepoints)
    return _CASEABLE_CHARS
100
+
101
+
102
_UNICODE_IGNORECASE_MASKS: Dict[int, CharMask] = {}  # codepoint -> CharMask


def unicode_ignorecase_mask(cp: int) -> CharMask:
    """
    Return the CharMask of characters matching codepoint `cp` under re.IGNORECASE.

    The answer comes from running Python's own regex engine over
    caseable_chars(); results are memoized per codepoint in
    _UNICODE_IGNORECASE_MASKS.
    """
    mask = _UNICODE_IGNORECASE_MASKS.get(cp)
    if mask is None:
        chars = caseable_chars()
        # `cp` denotes a literal character, so escape it: otherwise
        # metacharacters such as "." or "*" would be read as regex syntax.
        literal = re.escape(chr(cp))
        matches = re.compile(literal, re.IGNORECASE).findall(chars)
        # A literal always matches itself. Adding `cp` explicitly covers
        # "Lo" characters, which caseable_chars() leaves out.
        mask = CharMask(sorted({cp, *map(ord, matches)}))
        _UNICODE_IGNORECASE_MASKS[cp] = mask
    return mask
113
+
77
114
 
78
- def single_char_mask(parsed: Tuple[object, Any], flags: int) -> Optional[CharMask]:
115
+ def single_char_mask(
116
+ parsed: Tuple[object, Any], flags: int, ord=ord, chr=chr
117
+ ) -> Optional[CharMask]:
79
118
  """
80
119
  Compute a CharMask from a parsed regex.
81
120
 
@@ -88,10 +127,7 @@ def single_char_mask(parsed: Tuple[object, Any], flags: int) -> Optional[CharMas
88
127
  isascii = re.ASCII & flags
89
128
  if op in (LITERAL, NOT_LITERAL):
90
129
  if re.IGNORECASE & flags:
91
- # NOTE: I *think* IGNORECASE does not do "fancy" case matching like the
92
- # casefold() builtin.
93
- # TODO: This fails on 1-to-many case transformations
94
- ret = CharMask([ord(chr(arg).lower()), ord(chr(arg).upper())])
130
+ ret = unicode_ignorecase_mask(arg)
95
131
  else:
96
132
  ret = CharMask([arg])
97
133
  if op is NOT_LITERAL:
@@ -101,6 +137,7 @@ def single_char_mask(parsed: Tuple[object, Any], flags: int) -> Optional[CharMas
101
137
  if re.IGNORECASE & flags:
102
138
  ret = CharMask(
103
139
  [
140
+ # TODO: among other issues, this doesn't handle multi-codepoint caseswaps:
104
141
  (ord(chr(lo).lower()), ord(chr(hi).lower()) + 1),
105
142
  (ord(chr(lo).upper()), ord(chr(hi).upper()) + 1),
106
143
  ]
@@ -113,7 +150,7 @@ def single_char_mask(parsed: Tuple[object, Any], flags: int) -> Optional[CharMas
113
150
  if negate:
114
151
  arg = arg[1:]
115
152
  for term in arg:
116
- submask = single_char_mask(term, flags)
153
+ submask = single_char_mask(term, flags, ord=ord, chr=chr)
117
154
  if submask is None:
118
155
  raise ReUnhandled("IN contains non-single-char expression")
119
156
  ret = ret.union(submask)
@@ -137,6 +174,7 @@ def single_char_mask(parsed: Tuple[object, Any], flags: int) -> Optional[CharMas
137
174
  else:
138
175
  raise ReUnhandled("Unsupported category: ", arg)
139
176
  elif op is ANY and arg is None:
177
+ # TODO: test dot under ascii mode (seems like we should fall through to the re.ASCII check below)
140
178
  return _ANY_CHAR if re.DOTALL & flags else _ANY_NON_NEWLINE_CHAR
141
179
  else:
142
180
  return None
@@ -149,6 +187,13 @@ def single_char_mask(parsed: Tuple[object, Any], flags: int) -> Optional[CharMas
149
187
  Span = Tuple[int, Union[int, SymbolicInt]]
150
188
 
151
189
 
190
def _traced_binop(a, op, b):
    """
    Apply the binary callable `op` to `a` and `b`.

    When either operand is a CrossHairValue, the call is made inside
    ResumedTracing(); otherwise it is made directly.
    """
    neither_is_symbolic = not isinstance(a, CrossHairValue) and not isinstance(
        b, CrossHairValue
    )
    if neither_is_symbolic:
        return op(a, b)
    with ResumedTracing():
        return op(a, b)
195
+
196
+
152
197
  class _MatchPart:
153
198
  def __init__(self, groups: List[Optional[Span]]):
154
199
  self._groups = groups
@@ -158,11 +203,21 @@ class _MatchPart:
158
203
  assert span is not None
159
204
  return span
160
205
 
206
def _clamp_all_spans(self, start, end):
    """
    Pull zero-width group spans back inside the range [start, end].

    Only spans whose two endpoints are equal are touched: one lying before
    `start` becomes (start, start), and one lying after `end` becomes
    (end, end). Unset (None) groups and non-empty spans are left alone.
    The list in self._groups is modified in place.
    """
    spans = self._groups
    for position, pair in enumerate(spans):
        if pair is None:
            continue
        lower, upper = pair
        # Comparisons run with tracing resumed, as the bounds may be symbolic.
        with ResumedTracing():
            if lower == upper:
                if lower < start:
                    spans[position] = (start, start)
                if lower > end:
                    spans[position] = (end, end)
217
+
161
218
def isempty(self):
    """Report whether the overall span (group 0) covers no characters."""
    overall_span = self._groups[0]
    begin, finish = overall_span
    # Equivalent to `finish <= begin`, evaluated via _traced_binop.
    return _traced_binop(finish, operator.le, begin)
166
221
 
167
222
  def __bool__(self):
168
223
  return True
@@ -194,8 +249,7 @@ class _MatchPart:
194
249
  return self._groups[group]
195
250
 
196
251
 
197
- _BACKREF_RE = re.compile(
198
- r"""
252
+ _BACKREF_RE_SOURCE = rb"""
199
253
  (?P<prefix> .*?)
200
254
  \\
201
255
  (?:
@@ -206,8 +260,10 @@ _BACKREF_RE = re.compile(
206
260
  g\< (?P<namedother> .* ) \>
207
261
  )
208
262
  (?P<suffix> .*)
209
- """,
210
- re.VERBOSE | re.MULTILINE,
263
+ """
264
+ _BACKREF_BYTES_RE = re.compile(_BACKREF_RE_SOURCE, re.VERBOSE | re.MULTILINE)
265
+ _BACKREF_STR_RE = re.compile(
266
+ str(_BACKREF_RE_SOURCE, "ascii"), re.VERBOSE | re.MULTILINE
211
267
  )
212
268
 
213
269
 
@@ -236,14 +292,14 @@ class _Match(_MatchPart):
236
292
  if idx in _idx_to_name:
237
293
  self.lastgroup = _idx_to_name[idx]
238
294
 
239
- def __ch_deep_realize__(self):
295
+ def __ch_deep_realize__(self, memo):
240
296
  # We cannot manually create realistic Match instances.
241
297
  # Realize our contents - it's better than nothing
242
298
  return _Match(
243
- deep_realize(self._groups),
299
+ deep_realize(self._groups, memo),
244
300
  realize(self.pos),
245
301
  realize(self.endpos),
246
- deep_realize(self.re),
302
+ deep_realize(self.re, memo),
247
303
  realize(self.string),
248
304
  )
249
305
 
@@ -251,11 +307,10 @@ class _Match(_MatchPart):
251
307
  return self.group(idx)
252
308
 
253
309
  def expand(self, template):
254
- if not isinstance(template, str):
255
- raise TypeError
310
+ backref_re = _BACKREF_STR_RE if isinstance(template, str) else _BACKREF_BYTES_RE
256
311
  with NoTracing():
257
312
  template = realize(template) # Usually this is a literal string
258
- match = _BACKREF_RE.fullmatch(template)
313
+ match = backref_re.fullmatch(template)
259
314
  if match is None:
260
315
  return template
261
316
  prefix, num, namednum, named, _, suffix = match.groups()
@@ -338,6 +393,8 @@ def _internal_match_patterns(
338
393
  string: AnySymbolicStr,
339
394
  offset: int,
340
395
  allow_empty: bool = True,
396
+ ord=ord,
397
+ chr=chr,
341
398
  ) -> Optional[_MatchPart]:
342
399
  """
343
400
  >>> import sre_parse
@@ -361,7 +418,13 @@ def _internal_match_patterns(
361
418
  def continue_matching(prefix):
362
419
  sub_allow_empty = allow_empty if prefix.isempty() else True
363
420
  suffix = _internal_match_patterns(
364
- top_patterns[1:], flags, string, prefix.end(), sub_allow_empty
421
+ top_patterns[1:],
422
+ flags,
423
+ string,
424
+ prefix.end(),
425
+ sub_allow_empty,
426
+ ord=ord,
427
+ chr=chr,
365
428
  )
366
429
  if suffix is None:
367
430
  return None
@@ -371,19 +434,23 @@ def _internal_match_patterns(
371
434
  # Seems like this casues nondeterminism due to a global LRU cache used by the typing module.
372
435
  def fork_on(expr, sz):
373
436
  if space.smt_fork(expr):
374
- return continue_matching(_MatchPart([(offset, offset + sz)]))
437
+ return continue_matching(
438
+ _MatchPart([(offset, _traced_binop(offset, operator.add, sz))])
439
+ )
375
440
  else:
376
441
  return None
377
442
 
378
- mask = single_char_mask(pattern, flags)
443
+ mask = single_char_mask(pattern, flags, ord=ord, chr=chr)
379
444
  if mask is not None:
380
445
  with ResumedTracing():
381
- if len(string) <= offset:
446
+ if any([offset < 0, offset >= len(string)]):
382
447
  return None
383
448
  char = ord(string[offset])
384
449
  if isinstance(char, int): # Concrete int? Just check it!
385
450
  if mask.covers(char):
386
- return continue_matching(_MatchPart([(offset, offset + 1)]))
451
+ return continue_matching(
452
+ _MatchPart([(offset, _traced_binop(offset, operator.add, 1))])
453
+ )
387
454
  else:
388
455
  return None
389
456
  smt_ch = SymbolicInt._coerce_to_smt_sort(char)
@@ -398,7 +465,7 @@ def _internal_match_patterns(
398
465
  overall_match = _MatchPart([(offset, offset)])
399
466
  while reps < min_repeat:
400
467
  submatch = _internal_match_patterns(
401
- subpattern, flags, string, overall_match.end(), True
468
+ subpattern, flags, string, overall_match.end(), True, ord=ord, chr=chr
402
469
  )
403
470
  if submatch is None:
404
471
  return None
@@ -423,7 +490,13 @@ def _internal_match_patterns(
423
490
  )
424
491
  remainder_allow_empty = allow_empty or not overall_match.isempty()
425
492
  remainder_match = _internal_match_patterns(
426
- remaining_matcher, flags, string, overall_match.end(), remainder_allow_empty
493
+ remaining_matcher,
494
+ flags,
495
+ string,
496
+ overall_match.end(),
497
+ remainder_allow_empty,
498
+ ord=ord,
499
+ chr=chr,
427
500
  )
428
501
  if remainder_match is not None:
429
502
  return overall_match._add_match(remainder_match)
@@ -438,7 +511,7 @@ def _internal_match_patterns(
438
511
  branches = arg[1]
439
512
  first_path = list(branches[0]) + list(top_patterns)[1:]
440
513
  submatch = _internal_match_patterns(
441
- first_path, flags, string, offset, allow_empty
514
+ first_path, flags, string, offset, allow_empty, ord=ord, chr=chr
442
515
  )
443
516
  if submatch is not None:
444
517
  return submatch
@@ -451,6 +524,8 @@ def _internal_match_patterns(
451
524
  string,
452
525
  offset,
453
526
  allow_empty,
527
+ ord=ord,
528
+ chr=chr,
454
529
  )
455
530
  elif op is AT:
456
531
  if arg in (AT_BEGINNING, AT_BEGINNING_STRING):
@@ -500,7 +575,9 @@ def _internal_match_patterns(
500
575
  (direction_int, subpattern) = arg
501
576
  positive_look = op == ASSERT
502
577
  if direction_int == 1:
503
- matched = _internal_match_patterns(subpattern, flags, string, offset, True)
578
+ matched = _internal_match_patterns(
579
+ subpattern, flags, string, offset, True, ord=ord, chr=chr
580
+ )
504
581
  else:
505
582
  assert direction_int == -1
506
583
  minwidth, maxwidth = subpattern.getwidth()
@@ -509,11 +586,13 @@ def _internal_match_patterns(
509
586
  rewound = offset - minwidth
510
587
  if rewound < 0:
511
588
  return None
512
- matched = _internal_match_patterns(subpattern, flags, string, rewound, True)
589
+ matched = _internal_match_patterns(
590
+ subpattern, flags, string, rewound, True, ord=ord, chr=chr
591
+ )
513
592
  if bool(matched) != bool(positive_look):
514
593
  return None
515
594
  return _internal_match_patterns(
516
- top_patterns[1:], flags, string, offset, allow_empty
595
+ top_patterns[1:], flags, string, offset, allow_empty, ord=ord, chr=chr
517
596
  )
518
597
  elif op is SUBPATTERN:
519
598
  (groupnum, _a, _b, subpatterns) = arg
@@ -524,7 +603,9 @@ def _internal_match_patterns(
524
603
  + [(_END_GROUP_MARKER, (groupnum, offset))]
525
604
  + list(top_patterns)[1:]
526
605
  )
527
- return _internal_match_patterns(new_top, flags, string, offset, allow_empty)
606
+ return _internal_match_patterns(
607
+ new_top, flags, string, offset, allow_empty, ord=ord, chr=chr
608
+ )
528
609
  elif op is _END_GROUP_MARKER:
529
610
  (group_num, begin) = arg
530
611
  match = continue_matching(_MatchPart([(offset, offset)]))
@@ -539,21 +620,33 @@ def _internal_match_patterns(
539
620
 
540
621
  def _match_pattern(
541
622
  compiled_regex: re.Pattern,
542
- orig_str: AnySymbolicStr,
623
+ orig_str: Union[AnySymbolicStr, BytesLike],
543
624
  pos: int,
544
625
  endpos: Optional[int] = None,
545
626
  subpattern: Optional[List] = None,
546
627
  allow_empty=True,
628
+ ord=ord,
629
+ chr=chr,
547
630
  ) -> Optional[_Match]:
548
631
  assert not is_tracing()
549
632
  if subpattern is None:
550
633
  subpattern = cast(List, parse(compiled_regex.pattern, compiled_regex.flags))
551
- trimmed_str = orig_str[:endpos]
634
+ with ResumedTracing():
635
+ trimmed_str = orig_str[:endpos]
552
636
  matchpart = _internal_match_patterns(
553
- subpattern, compiled_regex.flags, trimmed_str, pos, allow_empty
637
+ subpattern,
638
+ compiled_regex.flags,
639
+ trimmed_str,
640
+ pos,
641
+ allow_empty,
642
+ ord=ord,
643
+ chr=chr,
554
644
  )
555
645
  if matchpart is None:
556
646
  return None
647
+ match_start, match_end = matchpart._fullspan()
648
+ if _traced_binop(match_start, operator.eq, match_end):
649
+ matchpart._clamp_all_spans(0, len(orig_str))
557
650
  return _Match(matchpart._groups, pos, endpos, compiled_regex, orig_str)
558
651
 
559
652
 
@@ -564,8 +657,23 @@ def _compile(*a):
564
657
  return re._compile(*deep_realize(a))
565
658
 
566
659
 
660
def _check_str_or_bytes(patt: re.Pattern, obj: Any):
    """
    Check that `obj` can be searched by `patt` and pick matching chr/ord helpers.

    Returns a `(chr, ord)` pair. For a str pattern these are the builtins;
    for a bytes pattern they are byte-oriented stand-ins (an int becomes a
    one-byte `bytes`, and "ord" is the identity).

    Raises TypeError if `patt` is not a compiled pattern, if `obj` is neither
    a str nor one of the bytes-like types, or if a str pattern is paired
    with bytes-like input (or the reverse).
    """
    if not isinstance(patt, re.Pattern):
        raise TypeError  # TODO: e.g. "descriptor 'search' for 're.Pattern' objects doesn't apply to a 'str' object"
    if not isinstance(obj, _STR_AND_BYTES_TYPES):
        # Report the bare type name (e.g. 'int'), as CPython's message does.
        raise TypeError(
            f"expected string or bytes-like object, got '{type(obj).__name__}'"
        )
    if isinstance(patt.pattern, str):
        if isinstance(obj, str):
            return (chr, ord)
        # A str pattern was given bytes-like input.
        raise TypeError("cannot use a string pattern on a bytes-like object")
    else:
        if isinstance(obj, _ALL_BYTES_TYPES):
            return (lambda i: bytes([i]), lambda i: i)
        # A bytes pattern was given str input.
        raise TypeError("cannot use a bytes pattern on a string-like object")
673
+
674
+
567
675
  def _finditer_symbolic(
568
- patt: re.Pattern, string: AnySymbolicStr, pos: int, endpos: int
676
+ patt: re.Pattern, string: AnySymbolicStr, pos: int, endpos: int, chr=chr, ord=ord
569
677
  ) -> Iterable[_Match]:
570
678
  last_match_was_empty = False
571
679
  while True:
@@ -573,7 +681,9 @@ def _finditer_symbolic(
573
681
  if pos > endpos:
574
682
  break
575
683
  allow_empty = not last_match_was_empty
576
- match = _match_pattern(patt, string, pos, endpos, allow_empty=allow_empty)
684
+ match = _match_pattern(
685
+ patt, string, pos, endpos, allow_empty=allow_empty, chr=chr, ord=ord
686
+ )
577
687
  last_match_was_empty = False
578
688
  if not match:
579
689
  pos += 1
@@ -582,7 +692,7 @@ def _finditer_symbolic(
582
692
  with NoTracing():
583
693
  if match.start() == match.end():
584
694
  if not allow_empty:
585
- raise CrosshairInternal("Unexpected empty match")
695
+ raise CrossHairInternal("Unexpected empty match")
586
696
  last_match_was_empty = True
587
697
  else:
588
698
  pos = match.end()
@@ -590,12 +700,11 @@ def _finditer_symbolic(
590
700
 
591
701
  def _finditer(
592
702
  self: re.Pattern,
593
- string: Union[str, AnySymbolicStr],
703
+ string: Union[str, AnySymbolicStr, bytes],
594
704
  pos: int = 0,
595
705
  endpos: Optional[int] = None,
596
706
  ) -> Iterable[Union[re.Match, _Match]]:
597
- if not isinstance(string, str):
598
- raise TypeError
707
+ chr, ord = _check_str_or_bytes(self, string)
599
708
  if not isinstance(pos, int):
600
709
  raise TypeError
601
710
  if not (endpos is None or isinstance(endpos, int)):
@@ -607,7 +716,9 @@ def _finditer(
607
716
  pos, endpos, _ = slice(pos, endpos, 1).indices(realize(strlen))
608
717
  with ResumedTracing():
609
718
  try:
610
- yield from _finditer_symbolic(self, string, pos, endpos)
719
+ yield from _finditer_symbolic(
720
+ self, string, pos, endpos, chr=chr, ord=ord
721
+ )
611
722
  return
612
723
  except ReUnhandled as e:
613
724
  debug("Unsupported symbolic regex", self.pattern, e)
@@ -617,13 +728,19 @@ def _finditer(
617
728
  yield from re.Pattern.finditer(self, realize(string), pos, endpos)
618
729
 
619
730
 
620
- def _fullmatch(self, string: Union[str, AnySymbolicStr], pos=0, endpos=None):
731
+ def _fullmatch(
732
+ self: re.Pattern, string: Union[str, AnySymbolicStr, bytes], pos=0, endpos=None
733
+ ):
621
734
  with NoTracing():
622
- if isinstance(string, AnySymbolicStr):
735
+ if isinstance(string, (AnySymbolicStr, BytesLike)):
736
+ with ResumedTracing():
737
+ chr, ord = _check_str_or_bytes(self, string)
623
738
  try:
624
739
  compiled = cast(List, parse(self.pattern, self.flags))
625
740
  compiled.append((AT, AT_END_STRING))
626
- return _match_pattern(self, string, pos, endpos, compiled)
741
+ return _match_pattern(
742
+ self, string, pos, endpos, compiled, chr=chr, ord=ord
743
+ )
627
744
  except ReUnhandled as e:
628
745
  debug("Unsupported symbolic regex", self.pattern, e)
629
746
  if endpos is None:
@@ -636,9 +753,11 @@ def _match(
636
753
  self, string: Union[str, AnySymbolicStr], pos=0, endpos=None
637
754
  ) -> Union[None, re.Match, _Match]:
638
755
  with NoTracing():
639
- if isinstance(string, AnySymbolicStr):
756
+ if isinstance(string, (AnySymbolicStr, BytesLike)):
757
+ with ResumedTracing():
758
+ chr, ord = _check_str_or_bytes(self, string)
640
759
  try:
641
- return _match_pattern(self, string, pos, endpos)
760
+ return _match_pattern(self, string, pos, endpos, chr=chr, ord=ord)
642
761
  except ReUnhandled as e:
643
762
  debug("Unsupported symbolic regex", self.pattern, e)
644
763
  if endpos is None:
@@ -648,10 +767,12 @@ def _match(
648
767
 
649
768
 
650
769
  def _search(
651
- self, string: Union[str, AnySymbolicStr], pos: int = 0, endpos: Optional[int] = None
770
+ self: re.Pattern,
771
+ string: Union[str, AnySymbolicStr, bytes],
772
+ pos: int = 0,
773
+ endpos: Optional[int] = None,
652
774
  ) -> Union[None, re.Match, _Match]:
653
- if not isinstance(string, str):
654
- raise TypeError
775
+ chr, ord = _check_str_or_bytes(self, string)
655
776
  if not isinstance(pos, int):
656
777
  raise TypeError
657
778
  if not (endpos is None or isinstance(endpos, int)):
@@ -659,11 +780,11 @@ def _search(
659
780
  pos, endpos = realize(pos), realize(endpos)
660
781
  mylen = string.__len__()
661
782
  with NoTracing():
662
- if isinstance(string, AnySymbolicStr):
783
+ if isinstance(string, (AnySymbolicStr, BytesLike)):
663
784
  pos, endpos, _ = slice(pos, endpos, 1).indices(realize(mylen))
664
785
  try:
665
786
  while pos < endpos:
666
- match = _match_pattern(self, string, pos, endpos)
787
+ match = _match_pattern(self, string, pos, endpos, chr=chr, ord=ord)
667
788
  if match:
668
789
  return match
669
790
  pos += 1
@@ -686,7 +807,8 @@ def _subn(
686
807
  ) -> Tuple[str, int]:
687
808
  if not isinstance(self, re.Pattern):
688
809
  raise TypeError
689
- if isinstance(repl, str):
810
+ if isinstance(repl, _STR_AND_BYTES_TYPES):
811
+ _check_str_or_bytes(self, repl)
690
812
 
691
813
  def replfn(m):
692
814
  return m.expand(repl)
@@ -695,8 +817,7 @@ def _subn(
695
817
  replfn = repl
696
818
  else:
697
819
  raise TypeError
698
- if not isinstance(string, str):
699
- raise TypeError
820
+ _check_str_or_bytes(self, string)
700
821
  if not isinstance(count, int):
701
822
  raise TypeError
702
823
  match = self.search(string)
@@ -12,7 +12,10 @@ from crosshair.test_util import ResultComparison, compare_results
12
12
def groups(match: Optional[re.Match]) -> Optional[Sequence]:
    """
    Summarize a match as a list of (start, end, text) triples, one per group.

    Index 0 describes the whole match; the remaining entries describe each
    capture group in order. Returns None when there was no match.
    """
    if match is not None:
        group_count = len(match.groups())
        summary = []
        for index in range(group_count + 1):
            summary.append((match.start(index), match.end(index), match.group(index)))
        return summary
    return None
16
19
 
17
20
 
18
21
  def check_inverted_categories(text: str, flags: int) -> ResultComparison:
@@ -45,6 +48,14 @@ def check_match_with_sliced_string(text: str) -> ResultComparison:
45
48
  return compare_results(lambda t: groups(re.match(r"^[ab]{2}\Z", t)), text[1:])
46
49
 
47
50
 
51
# Check function: matches "(a*)(a*)" against `text` using explicit `start`
# and `end` offsets, and hands the per-group summary from groups() to
# compare_results. The docstring below is the contract text; leave it as is.
def check_match_with_offsets(text: str, start: int, end: int) -> ResultComparison:
    """post: _"""
    # Simpler single-literal variant, kept for reference:
    # return compare_results(lambda t: groups(re.compile(r"a").match(t, start, end)), text)
    return compare_results(
        lambda t: groups(re.compile(r"(a*)(a*)").match(t, start, end)), text
    )
57
+
58
+
48
59
  def check_findall(text: str, flags: int) -> ResultComparison:
49
60
  """post: _"""
50
61
  return compare_results(lambda t, f: re.findall("aa", t, f), text, flags)
@@ -111,7 +122,7 @@ def check_search_anchored_end(text: str, flags: int) -> ResultComparison:
111
122
 
112
123
  def check_subn(text: str, flags: int) -> ResultComparison:
113
124
  """post: _"""
114
- return compare_results(lambda t, f: re.subn("aa", "ba", t, f), text, flags)
125
+ return compare_results(lambda t, f: re.subn("aa", "ba", t, flags=f), text, flags)
115
126
 
116
127
 
117
128
  def check_lookahead(text: str) -> ResultComparison:
@@ -134,6 +145,19 @@ def check_negative_lookbehind(text: str) -> ResultComparison:
134
145
  return compare_results(lambda t: groups(re.search(".(?<!b)", t)), text)
135
146
 
136
147
 
148
+ # Bytes-based regexes
149
+
150
+
151
# Check function: runs re.subn with a bytes pattern (b"a") and a bytes
# replacement (b"b") over `text`, passing `flags` by keyword.
# The docstring below is the contract text; leave it as is.
def check_subn_bytes(text: bytes, flags: int) -> ResultComparison:
    """post: _"""
    return compare_results(lambda t, f: re.subn(b"a", b"b", t, flags=f), text, flags)
154
+
155
+
156
# Check function: runs re.findall with a bytes pattern over bytes `text`.
# The pattern must be bytes: a str pattern against bytes input only ever
# raises TypeError, so bytes matching would never be exercised.
def check_findall_bytes(text: bytes, flags: int) -> ResultComparison:
    """post: _"""
    return compare_results(lambda t, f: re.findall(b"aa", t, f), text, flags)
159
+
160
+
137
161
  # This is the only real test definition.
138
162
  # It runs crosshair on each of the "check" functions defined above.
139
163
  @pytest.mark.parametrize("fn_name", [fn for fn in dir() if fn.startswith("check_")])