angr 9.2.139__py3-none-manylinux2014_x86_64.whl → 9.2.141__py3-none-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of angr might be problematic. Click here for more details.

Files changed (87)
  1. angr/__init__.py +1 -1
  2. angr/analyses/calling_convention/calling_convention.py +136 -53
  3. angr/analyses/calling_convention/fact_collector.py +44 -18
  4. angr/analyses/calling_convention/utils.py +3 -1
  5. angr/analyses/cfg/cfg_base.py +13 -0
  6. angr/analyses/cfg/cfg_fast.py +11 -0
  7. angr/analyses/cfg/indirect_jump_resolvers/jumptable.py +9 -8
  8. angr/analyses/decompiler/ail_simplifier.py +115 -72
  9. angr/analyses/decompiler/callsite_maker.py +24 -11
  10. angr/analyses/decompiler/clinic.py +78 -43
  11. angr/analyses/decompiler/decompiler.py +18 -7
  12. angr/analyses/decompiler/expression_narrower.py +1 -1
  13. angr/analyses/decompiler/optimization_passes/const_prop_reverter.py +8 -7
  14. angr/analyses/decompiler/optimization_passes/duplication_reverter/duplication_reverter.py +3 -1
  15. angr/analyses/decompiler/optimization_passes/flip_boolean_cmp.py +21 -2
  16. angr/analyses/decompiler/optimization_passes/ite_region_converter.py +21 -13
  17. angr/analyses/decompiler/optimization_passes/lowered_switch_simplifier.py +84 -15
  18. angr/analyses/decompiler/optimization_passes/optimization_pass.py +92 -11
  19. angr/analyses/decompiler/optimization_passes/return_duplicator_base.py +53 -9
  20. angr/analyses/decompiler/peephole_optimizations/eager_eval.py +44 -7
  21. angr/analyses/decompiler/region_identifier.py +6 -4
  22. angr/analyses/decompiler/region_simplifiers/expr_folding.py +287 -122
  23. angr/analyses/decompiler/region_simplifiers/region_simplifier.py +31 -13
  24. angr/analyses/decompiler/ssailification/rewriting.py +23 -15
  25. angr/analyses/decompiler/ssailification/rewriting_engine.py +105 -24
  26. angr/analyses/decompiler/ssailification/ssailification.py +22 -14
  27. angr/analyses/decompiler/structured_codegen/c.py +73 -137
  28. angr/analyses/decompiler/structuring/dream.py +22 -18
  29. angr/analyses/decompiler/structuring/phoenix.py +158 -41
  30. angr/analyses/decompiler/structuring/recursive_structurer.py +1 -0
  31. angr/analyses/decompiler/structuring/structurer_base.py +37 -10
  32. angr/analyses/decompiler/structuring/structurer_nodes.py +4 -1
  33. angr/analyses/decompiler/utils.py +106 -21
  34. angr/analyses/deobfuscator/api_obf_finder.py +8 -5
  35. angr/analyses/deobfuscator/api_obf_type2_finder.py +18 -10
  36. angr/analyses/deobfuscator/string_obf_finder.py +105 -18
  37. angr/analyses/forward_analysis/forward_analysis.py +1 -1
  38. angr/analyses/propagator/top_checker_mixin.py +6 -6
  39. angr/analyses/reaching_definitions/__init__.py +2 -1
  40. angr/analyses/reaching_definitions/dep_graph.py +1 -12
  41. angr/analyses/reaching_definitions/engine_vex.py +36 -31
  42. angr/analyses/reaching_definitions/function_handler.py +15 -2
  43. angr/analyses/reaching_definitions/rd_state.py +1 -37
  44. angr/analyses/reaching_definitions/reaching_definitions.py +13 -24
  45. angr/analyses/s_propagator.py +6 -41
  46. angr/analyses/s_reaching_definitions/s_rda_model.py +7 -1
  47. angr/analyses/s_reaching_definitions/s_rda_view.py +43 -25
  48. angr/analyses/stack_pointer_tracker.py +36 -22
  49. angr/analyses/typehoon/simple_solver.py +45 -7
  50. angr/analyses/typehoon/typeconsts.py +18 -5
  51. angr/analyses/variable_recovery/engine_ail.py +1 -1
  52. angr/analyses/variable_recovery/engine_base.py +7 -5
  53. angr/analyses/variable_recovery/engine_vex.py +20 -4
  54. angr/block.py +69 -107
  55. angr/callable.py +14 -7
  56. angr/calling_conventions.py +30 -11
  57. angr/distributed/__init__.py +1 -1
  58. angr/engines/__init__.py +7 -8
  59. angr/engines/engine.py +1 -120
  60. angr/engines/failure.py +2 -2
  61. angr/engines/hook.py +2 -2
  62. angr/engines/light/engine.py +2 -2
  63. angr/engines/pcode/engine.py +2 -14
  64. angr/engines/procedure.py +2 -2
  65. angr/engines/soot/engine.py +2 -2
  66. angr/engines/soot/statements/switch.py +1 -1
  67. angr/engines/successors.py +124 -11
  68. angr/engines/syscall.py +2 -2
  69. angr/engines/unicorn.py +3 -3
  70. angr/engines/vex/heavy/heavy.py +3 -15
  71. angr/factory.py +12 -22
  72. angr/knowledge_plugins/key_definitions/atoms.py +8 -4
  73. angr/knowledge_plugins/key_definitions/live_definitions.py +41 -103
  74. angr/knowledge_plugins/variables/variable_manager.py +7 -5
  75. angr/sim_type.py +19 -17
  76. angr/simos/simos.py +3 -1
  77. angr/state_plugins/plugin.py +19 -4
  78. angr/storage/memory_mixins/memory_mixin.py +1 -1
  79. angr/storage/memory_mixins/paged_memory/pages/multi_values.py +10 -5
  80. angr/utils/ssa/__init__.py +119 -4
  81. angr/utils/types.py +48 -0
  82. {angr-9.2.139.dist-info → angr-9.2.141.dist-info}/METADATA +6 -6
  83. {angr-9.2.139.dist-info → angr-9.2.141.dist-info}/RECORD +87 -86
  84. {angr-9.2.139.dist-info → angr-9.2.141.dist-info}/LICENSE +0 -0
  85. {angr-9.2.139.dist-info → angr-9.2.141.dist-info}/WHEEL +0 -0
  86. {angr-9.2.139.dist-info → angr-9.2.141.dist-info}/entry_points.txt +0 -0
  87. {angr-9.2.139.dist-info → angr-9.2.141.dist-info}/top_level.txt +0 -0
@@ -16,7 +16,6 @@ from angr.calling_conventions import SimCC
16
16
  from angr.sim_type import SimTypeFunction
17
17
  from angr.knowledge_plugins.key_definitions.definition import Definition
18
18
  from angr.knowledge_plugins.functions import Function
19
- from angr.analyses.reaching_definitions.dep_graph import FunctionCallRelationships
20
19
  from angr.code_location import CodeLocation, ExternalCodeLocation
21
20
  from angr.knowledge_plugins.key_definitions.constants import ObservationPointType
22
21
  from angr import SIM_LIBRARIES, SIM_TYPE_COLLECTIONS
@@ -246,7 +245,7 @@ class FunctionCallDataUnwrapped(FunctionCallData):
246
245
  @staticmethod
247
246
  @wraps
248
247
  def decorate(
249
- f: Callable[[FunctionHandler, ReachingDefinitionsState, FunctionCallDataUnwrapped], None]
248
+ f: Callable[[FunctionHandler, ReachingDefinitionsState, FunctionCallDataUnwrapped], None],
250
249
  ) -> Callable[[FunctionHandler, ReachingDefinitionsState, FunctionCallData], None]:
251
250
  """
252
251
  Decorate a function handler method with this to make it take a FunctionCallDataUnwrapped instead of a
@@ -263,6 +262,20 @@ def _mk_wrapper(func, iself):
263
262
  return lambda *args, **kwargs: func(iself, *args, **kwargs)
264
263
 
265
264
 
265
+ @dataclass
266
+ class FunctionCallRelationships:
267
+ """
268
+ Produced by the function handler, provides associated callsite info and function input/output definitions.
269
+ """
270
+
271
+ callsite: CodeLocation
272
+ target: int | None
273
+ args_defns: list[set[Definition]]
274
+ other_input_defns: set[Definition]
275
+ ret_defns: set[Definition]
276
+ other_output_defns: set[Definition]
277
+
278
+
266
279
  # pylint: disable=unused-argument, no-self-use
267
280
  class FunctionHandler:
268
281
  """
@@ -1,5 +1,5 @@
1
1
  from __future__ import annotations
2
- from typing import Any, TYPE_CHECKING, cast, overload
2
+ from typing import Any, TYPE_CHECKING, overload
3
3
  from collections.abc import Iterable, Iterator
4
4
  import logging
5
5
  from typing_extensions import Self
@@ -7,7 +7,6 @@ from typing_extensions import Self
7
7
  import archinfo
8
8
  import claripy
9
9
 
10
- from angr.misc.ux import deprecated
11
10
  from angr.knowledge_plugins.key_definitions.environment import Environment
12
11
  from angr.knowledge_plugins.key_definitions.tag import Tag
13
12
  from angr.knowledge_plugins.key_definitions.heap_address import HeapAddress
@@ -541,41 +540,6 @@ class ReachingDefinitionsState:
541
540
  self.all_definitions = set()
542
541
  self.live_definitions.reset_uses()
543
542
 
544
- @deprecated("deref")
545
- def pointer_to_atoms(self, pointer: MultiValues, size: int, endness: str) -> set[MemoryLocation]:
546
- """
547
- Given a MultiValues, return the set of atoms that loading or storing to the pointer with that value
548
- could define or use.
549
- """
550
- result = set()
551
- for vs in pointer.values():
552
- for value in vs:
553
- atom = self.pointer_to_atom(value, size, endness)
554
- if atom is not None:
555
- result.add(atom)
556
-
557
- return result
558
-
559
- @deprecated("deref")
560
- def pointer_to_atom(self, value: claripy.ast.BV, size: int, endness: str) -> MemoryLocation | None:
561
- if self.is_top(value):
562
- return None
563
-
564
- stack_offset = self.get_stack_offset(value)
565
- if stack_offset is not None:
566
- addr = SpOffset(len(value), stack_offset)
567
- else:
568
- heap_offset = self.get_heap_offset(value)
569
- if heap_offset is not None:
570
- addr = HeapAddress(heap_offset)
571
- elif value.op == "BVV":
572
- addr = cast(int, value.args[0])
573
- else:
574
- # cannot resolve
575
- return None
576
-
577
- return MemoryLocation(addr, size, endness)
578
-
579
543
  @overload
580
544
  def deref(
581
545
  self,
@@ -16,7 +16,6 @@ from angr.knowledge_plugins.functions import Function
16
16
  from angr.knowledge_plugins.key_definitions import ReachingDefinitionsModel, LiveDefinitions
17
17
  from angr.knowledge_plugins.key_definitions.constants import OP_BEFORE, OP_AFTER, ObservationPointType, ObservationPoint
18
18
  from angr.code_location import CodeLocation, ExternalCodeLocation
19
- from angr.misc.ux import deprecated
20
19
  from angr.analyses.forward_analysis.visitors.graph import NodeType
21
20
  from angr.analyses.analysis import Analysis
22
21
  from .engine_ail import SimEngineRDAIL
@@ -50,13 +49,13 @@ class ReachingDefinitionsAnalysis(
50
49
 
51
50
  def __init__(
52
51
  self,
53
- subject: Subject | ailment.Block | Block | Function | str = None,
52
+ subject: Subject | ailment.Block | Block | Function | str,
54
53
  func_graph=None,
55
54
  max_iterations=30,
56
55
  track_tmps=False,
57
56
  track_consts=True,
58
57
  observation_points: Iterable[ObservationPoint] | None = None,
59
- init_state: ReachingDefinitionsState = None,
58
+ init_state: ReachingDefinitionsState | None = None,
60
59
  init_context=None,
61
60
  state_initializer: RDAStateInitializer | None = None,
62
61
  cc=None,
@@ -242,10 +241,6 @@ class ReachingDefinitionsAnalysis(
242
241
  def visited_blocks(self):
243
242
  return self._visited_blocks
244
243
 
245
- @deprecated(replacement="get_reaching_definitions_by_insn")
246
- def get_reaching_definitions(self, ins_addr, op_type):
247
- return self.get_reaching_definitions_by_insn(ins_addr, op_type)
248
-
249
244
  def get_reaching_definitions_by_insn(self, ins_addr, op_type):
250
245
  key = "insn", ins_addr, op_type
251
246
  if key not in self.observed_results:
@@ -280,29 +275,22 @@ class ReachingDefinitionsAnalysis(
280
275
  :param node_idx: ID of the node. Used in AIL to differentiate blocks with the same address.
281
276
  """
282
277
 
283
- key = None
284
-
278
+ key: ObservationPoint | None = None
285
279
  observe = False
286
280
 
287
281
  if self._observe_all:
288
282
  observe = True
289
- key: ObservationPoint = (
290
- ("node", node_addr, op_type) if node_idx is None else ("node", (node_addr, node_idx), op_type)
291
- )
283
+ key = ("node", node_addr, op_type) if node_idx is None else ("node", (node_addr, node_idx), op_type)
292
284
  elif self._observation_points is not None:
293
- key: ObservationPoint = (
294
- ("node", node_addr, op_type) if node_idx is None else ("node", (node_addr, node_idx), op_type)
295
- )
285
+ key = ("node", node_addr, op_type) if node_idx is None else ("node", (node_addr, node_idx), op_type)
296
286
  if key in self._observation_points:
297
287
  observe = True
298
288
  elif self._observe_callback is not None:
299
289
  observe = self._observe_callback("node", addr=node_addr, state=state, op_type=op_type, node_idx=node_idx)
300
290
  if observe:
301
- key: ObservationPoint = (
302
- ("node", node_addr, op_type) if node_idx is None else ("node", (node_addr, node_idx), op_type)
303
- )
291
+ key = ("node", node_addr, op_type) if node_idx is None else ("node", (node_addr, node_idx), op_type)
304
292
 
305
- if observe:
293
+ if observe and key:
306
294
  self.observed_results[key] = state.live_definitions
307
295
 
308
296
  def insn_observe(
@@ -321,14 +309,14 @@ class ReachingDefinitionsAnalysis(
321
309
  :param op_type: Type of the observation point. Must be one of the following: OP_BEORE, OP_AFTER.
322
310
  """
323
311
 
324
- key = None
312
+ key: ObservationPoint | None = None
325
313
  observe = False
326
314
 
327
315
  if self._observe_all:
328
316
  observe = True
329
- key: ObservationPoint = ("insn", insn_addr, op_type)
317
+ key = ("insn", insn_addr, op_type)
330
318
  elif self._observation_points is not None:
331
- key: ObservationPoint = ("insn", insn_addr, op_type)
319
+ key = ("insn", insn_addr, op_type)
332
320
  if key in self._observation_points:
333
321
  observe = True
334
322
  elif self._observe_callback is not None:
@@ -336,9 +324,9 @@ class ReachingDefinitionsAnalysis(
336
324
  "insn", addr=insn_addr, stmt=stmt, block=block, state=state, op_type=op_type
337
325
  )
338
326
  if observe:
339
- key: ObservationPoint = ("insn", insn_addr, op_type)
327
+ key = ("insn", insn_addr, op_type)
340
328
 
341
- if not observe:
329
+ if not (observe and key):
342
330
  return
343
331
 
344
332
  if isinstance(stmt, pyvex.stmt.IRStmt):
@@ -533,6 +521,7 @@ class ReachingDefinitionsAnalysis(
533
521
  ]
534
522
  if node.addr == self.subject.content.addr:
535
523
  node_parents += [ExternalCodeLocation()]
524
+ assert block is not None
536
525
  self.model.at_new_block(
537
526
  CodeLocation(block.addr, 0, block_idx=block.idx if isinstance(block, ailment.Block) else None),
538
527
  node_parents,
@@ -35,6 +35,7 @@ from angr.utils.ssa import (
35
35
  get_tmp_uselocs,
36
36
  get_tmp_deflocs,
37
37
  phi_assignment_get_src,
38
+ has_store_stmt_in_between_stmts,
38
39
  )
39
40
 
40
41
 
@@ -186,6 +187,8 @@ class SPropagatorAnalysis(Analysis):
186
187
 
187
188
  # function mode only
188
189
  if self.mode == "function":
190
+ assert self.func_graph is not None
191
+
189
192
  for vvar, defloc in vvar_deflocs.items():
190
193
  if vvar.varid not in vvar_uselocs:
191
194
  continue
@@ -213,7 +216,7 @@ class SPropagatorAnalysis(Analysis):
213
216
  # }
214
217
  can_replace = True
215
218
  for _, vvar_useloc in vvar_uselocs[vvar.varid]:
216
- if self.has_store_stmt_in_between(blocks, defloc, vvar_useloc):
219
+ if has_store_stmt_in_between_stmts(self.func_graph, blocks, defloc, vvar_useloc):
217
220
  can_replace = False
218
221
 
219
222
  if can_replace:
@@ -241,8 +244,8 @@ class SPropagatorAnalysis(Analysis):
241
244
  if vvar.was_reg or vvar.was_parameter:
242
245
  if len(vvar_uselocs[vvar.varid]) == 1:
243
246
  vvar_used, vvar_useloc = next(iter(vvar_uselocs[vvar.varid]))
244
- if is_const_vvar_load_assignment(stmt) and not self.has_store_stmt_in_between(
245
- blocks, defloc, vvar_useloc
247
+ if is_const_vvar_load_assignment(stmt) and not has_store_stmt_in_between_stmts(
248
+ self.func_graph, blocks, defloc, vvar_useloc
246
249
  ):
247
250
  # we can propagate this load because there is no store between its def and use
248
251
  replacements[vvar_useloc][vvar_used] = stmt.src
@@ -462,44 +465,6 @@ class SPropagatorAnalysis(Analysis):
462
465
 
463
466
  return False
464
467
 
465
- def has_store_stmt_in_between(
466
- self, blocks: dict[tuple[int, int | None], Block], defloc: CodeLocation, useloc: CodeLocation
467
- ) -> bool:
468
- assert defloc.block_addr is not None
469
- assert defloc.stmt_idx is not None
470
- assert useloc.block_addr is not None
471
- assert useloc.stmt_idx is not None
472
- assert self.func_graph is not None
473
-
474
- use_block = blocks[(useloc.block_addr, useloc.block_idx)]
475
- def_block = blocks[(defloc.block_addr, defloc.block_idx)]
476
-
477
- # traverse the graph, go from use_block until we reach def_block, and look for Store statements
478
- seen = {use_block}
479
- queue = [use_block]
480
- while queue:
481
- block = queue.pop(0)
482
-
483
- starting_stmt_idx, ending_stmt_idx = 0, len(block.statements)
484
- if block is def_block:
485
- starting_stmt_idx = defloc.stmt_idx + 1
486
- if block is use_block:
487
- ending_stmt_idx = useloc.stmt_idx
488
-
489
- for i in range(starting_stmt_idx, ending_stmt_idx):
490
- if isinstance(block.statements[i], Store):
491
- return True
492
-
493
- if block is def_block:
494
- continue
495
-
496
- for pred in self.func_graph.predecessors(block):
497
- if pred not in seen:
498
- seen.add(pred)
499
- queue.append(pred)
500
-
501
- return False
502
-
503
468
  @staticmethod
504
469
  def is_vvar_used_for_addr_loading_switch_case(uselocs: set[CodeLocation], blocks) -> bool:
505
470
  """
@@ -2,7 +2,7 @@ from __future__ import annotations
2
2
 
3
3
  from collections import defaultdict
4
4
  from collections.abc import Generator
5
- from typing import Any
5
+ from typing import Any, Literal, overload
6
6
 
7
7
  from ailment.expression import VirtualVariable, Tmp
8
8
 
@@ -48,6 +48,12 @@ class SRDAModel:
48
48
  s.add(Definition(tmp_atom, CodeLocation(block_loc.block_addr, stmt_idx, block_idx=block_loc.block_idx)))
49
49
  return s
50
50
 
51
+ @overload
52
+ def get_uses_by_location(self, loc: CodeLocation, exprs: Literal[True]) -> set[tuple[Definition, Any | None]]: ...
53
+
54
+ @overload
55
+ def get_uses_by_location(self, loc: CodeLocation, exprs: Literal[False] = ...) -> set[Definition]: ...
56
+
51
57
  def get_uses_by_location(
52
58
  self, loc: CodeLocation, exprs: bool = False
53
59
  ) -> set[Definition] | set[tuple[Definition, Any | None]]:
@@ -1,8 +1,10 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import logging
4
+ from collections.abc import Callable
4
5
  from collections import defaultdict
5
6
 
7
+ from ailment import Block
6
8
  from ailment.statement import Statement, Assignment, Call, Label
7
9
  from ailment.expression import VirtualVariable, VirtualVariableCategory, Expression
8
10
 
@@ -22,7 +24,7 @@ class RegVVarPredicate:
22
24
  Implements a predicate that is used in get_reg_vvar_by_stmt_idx and get_reg_vvar_by_insn.
23
25
  """
24
26
 
25
- def __init__(self, reg_offset: int, vvars: set[VirtualVariable], arch):
27
+ def __init__(self, reg_offset: int, vvars: list[VirtualVariable], arch):
26
28
  self.reg_offset = reg_offset
27
29
  self.vvars = vvars
28
30
  self.arch = arch
@@ -47,7 +49,8 @@ class RegVVarPredicate:
47
49
  and stmt.dst.was_reg
48
50
  and stmt.dst.reg_offset == self.reg_offset
49
51
  ):
50
- self.vvars.add(stmt.dst)
52
+ if stmt.dst not in self.vvars:
53
+ self.vvars.append(stmt.dst)
51
54
  return True
52
55
  if isinstance(stmt, Call):
53
56
  if (
@@ -55,7 +58,8 @@ class RegVVarPredicate:
55
58
  and stmt.ret_expr.was_reg
56
59
  and stmt.ret_expr.reg_offset == self.reg_offset
57
60
  ):
58
- self.vvars.add(stmt.ret_expr)
61
+ if stmt.ret_expr not in self.vvars:
62
+ self.vvars.append(stmt.ret_expr)
59
63
  return True
60
64
  # is it clobbered maybe?
61
65
  clobbered_regs = self._get_call_clobbered_regs(stmt)
@@ -69,7 +73,7 @@ class StackVVarPredicate:
69
73
  Implements a predicate that is used in get_stack_vvar_by_stmt_idx and get_stack_vvar_by_insn.
70
74
  """
71
75
 
72
- def __init__(self, stack_offset: int, size: int, vvars: set[VirtualVariable]):
76
+ def __init__(self, stack_offset: int, size: int, vvars: list[VirtualVariable]):
73
77
  self.stack_offset = stack_offset
74
78
  self.size = size
75
79
  self.vvars = vvars
@@ -82,7 +86,8 @@ class StackVVarPredicate:
82
86
  and stmt.dst.stack_offset <= self.stack_offset < stmt.dst.stack_offset + stmt.dst.size
83
87
  and stmt.dst.stack_offset <= self.stack_offset + self.size <= stmt.dst.stack_offset + stmt.dst.size
84
88
  ):
85
- self.vvars.add(stmt.dst)
89
+ if stmt.dst not in self.vvars:
90
+ self.vvars.append(stmt.dst)
86
91
  return True
87
92
  return False
88
93
 
@@ -96,7 +101,13 @@ class SRDAView:
96
101
  self.model = model
97
102
 
98
103
  def _get_vvar_by_stmt(
99
- self, block_addr: int, block_idx: int | None, stmt_idx: int, op_type: ObservationPointType, predicate
104
+ self,
105
+ block_addr: int,
106
+ block_idx: int | None,
107
+ stmt_idx: int,
108
+ op_type: ObservationPointType,
109
+ predicate: Callable,
110
+ consecutive: bool = False,
100
111
  ):
101
112
  # find the starting block
102
113
  for block in self.model.func_graph:
@@ -107,7 +118,10 @@ class SRDAView:
107
118
  return
108
119
 
109
120
  traversed = set()
110
- queue = [(the_block, stmt_idx if op_type == ObservationPointType.OP_BEFORE else stmt_idx + 1)]
121
+ queue: list[tuple[Block, int | None]] = [
122
+ (the_block, stmt_idx if op_type == ObservationPointType.OP_BEFORE else stmt_idx + 1)
123
+ ]
124
+ predicate_returned_true = False
111
125
  while queue:
112
126
  block, start_stmt_idx = queue.pop(0)
113
127
  traversed.add(block)
@@ -115,7 +129,8 @@ class SRDAView:
115
129
  stmts = block.statements[:start_stmt_idx] if start_stmt_idx is not None else block.statements
116
130
 
117
131
  for stmt in reversed(stmts):
118
- should_break = predicate(stmt)
132
+ r = predicate(stmt)
133
+ should_break = (predicate_returned_true and r is False) if consecutive else r
119
134
  if should_break:
120
135
  break
121
136
  else:
@@ -129,7 +144,7 @@ class SRDAView:
129
144
  self, reg_offset: int, block_addr: int, block_idx: int | None, stmt_idx: int, op_type: ObservationPointType
130
145
  ) -> VirtualVariable | None:
131
146
  reg_offset = get_reg_offset_base(reg_offset, self.model.arch)
132
- vvars = set()
147
+ vvars = []
133
148
  predicater = RegVVarPredicate(reg_offset, vvars, self.model.arch)
134
149
  self._get_vvar_by_stmt(block_addr, block_idx, stmt_idx, op_type, predicater.predicate)
135
150
 
@@ -137,14 +152,14 @@ class SRDAView:
137
152
  # not found - check function arguments
138
153
  for func_arg in self.model.func_args:
139
154
  if isinstance(func_arg, VirtualVariable):
140
- func_arg_category = func_arg.oident[0]
155
+ func_arg_category = func_arg.parameter_category
141
156
  if func_arg_category == VirtualVariableCategory.REGISTER:
142
- func_arg_regoff = func_arg.oident[1]
157
+ func_arg_regoff = func_arg.parameter_reg_offset
143
158
  if func_arg_regoff == reg_offset:
144
- vvars.add(func_arg)
159
+ vvars.append(func_arg)
145
160
 
146
161
  assert len(vvars) <= 1
147
- return next(iter(vvars), None)
162
+ return vvars[0] if vvars else None
148
163
 
149
164
  def get_stack_vvar_by_stmt( # pylint: disable=too-many-positional-arguments
150
165
  self,
@@ -155,21 +170,24 @@ class SRDAView:
155
170
  stmt_idx: int,
156
171
  op_type: ObservationPointType,
157
172
  ) -> VirtualVariable | None:
158
- vvars = set()
173
+ vvars = []
159
174
  predicater = StackVVarPredicate(stack_offset, size, vvars)
160
- self._get_vvar_by_stmt(block_addr, block_idx, stmt_idx, op_type, predicater.predicate)
175
+ self._get_vvar_by_stmt(block_addr, block_idx, stmt_idx, op_type, predicater.predicate, consecutive=True)
161
176
 
162
177
  if not vvars:
163
178
  # not found - check function arguments
164
179
  for func_arg in self.model.func_args:
165
180
  if isinstance(func_arg, VirtualVariable):
166
- func_arg_category = func_arg.oident[0]
181
+ func_arg_category = func_arg.parameter_category
167
182
  if func_arg_category == VirtualVariableCategory.STACK:
168
- func_arg_stackoff = func_arg.oident[1]
183
+ func_arg_stackoff = func_arg.oident[1] # type: ignore
169
184
  if func_arg_stackoff == stack_offset and func_arg.size == size:
170
- vvars.add(func_arg)
171
- assert len(vvars) <= 1
172
- return next(iter(vvars), None)
185
+ vvars.append(func_arg)
186
+ # there might be multiple vvars; we prioritize the one whose size fits the best
187
+ for v in vvars:
188
+ if v.stack_offset == stack_offset and v.size == size:
189
+ return v
190
+ return vvars[0] if vvars else None
173
191
 
174
192
  def _get_vvar_by_insn(self, addr: int, op_type: ObservationPointType, predicate, block_idx: int | None = None):
175
193
  # find the starting block
@@ -202,23 +220,23 @@ class SRDAView:
202
220
  self, reg_offset: int, addr: int, op_type: ObservationPointType, block_idx: int | None = None
203
221
  ) -> VirtualVariable | None:
204
222
  reg_offset = get_reg_offset_base(reg_offset, self.model.arch)
205
- vvars = set()
223
+ vvars = []
206
224
  predicater = RegVVarPredicate(reg_offset, vvars, self.model.arch)
207
225
 
208
226
  self._get_vvar_by_insn(addr, op_type, predicater.predicate, block_idx=block_idx)
209
227
 
210
228
  assert len(vvars) <= 1
211
- return next(iter(vvars), None)
229
+ return vvars[0] if vvars else None
212
230
 
213
231
  def get_stack_vvar_by_insn( # pylint: disable=too-many-positional-arguments
214
232
  self, stack_offset: int, size: int, addr: int, op_type: ObservationPointType, block_idx: int | None = None
215
233
  ) -> VirtualVariable | None:
216
- vvars = set()
234
+ vvars = []
217
235
  predicater = StackVVarPredicate(stack_offset, size, vvars)
218
236
  self._get_vvar_by_insn(addr, op_type, predicater.predicate, block_idx=block_idx)
219
237
 
220
238
  assert len(vvars) <= 1
221
- return next(iter(vvars), None)
239
+ return vvars[0] if vvars else None
222
240
 
223
241
  def get_vvar_value(self, vvar: VirtualVariable) -> Expression | None:
224
242
  if vvar not in self.model.all_vvar_definitions:
@@ -227,7 +245,7 @@ class SRDAView:
227
245
 
228
246
  for block in self.model.func_graph:
229
247
  if block.addr == codeloc.block_addr and block.idx == codeloc.block_idx:
230
- if codeloc.stmt_idx < len(block.statements):
248
+ if codeloc.stmt_idx is not None and codeloc.stmt_idx < len(block.statements):
231
249
  stmt = block.statements[codeloc.stmt_idx]
232
250
  if isinstance(stmt, Assignment) and stmt.dst.likes(vvar):
233
251
  return stmt.src
@@ -22,6 +22,7 @@ try:
22
22
  from angr.engines import pcode
23
23
  except ImportError:
24
24
  pypcode = None
25
+ pcode = None
25
26
 
26
27
  if TYPE_CHECKING:
27
28
  from angr.block import Block
@@ -93,6 +94,11 @@ class Register:
93
94
  return self.offset == other.offset
94
95
  return False
95
96
 
97
+ def __add__(self, other) -> OffsetVal:
98
+ if type(other) is Constant:
99
+ return OffsetVal(self, other.val)
100
+ raise CouldNotResolveException
101
+
96
102
  def __repr__(self):
97
103
  return str(self.offset)
98
104
 
@@ -232,6 +238,7 @@ class StackPointerTrackerState:
232
238
  def give_up_on_memory_tracking(self):
233
239
  self.memory = {}
234
240
  self.is_tracking_memory = False
241
+ return self
235
242
 
236
243
  def store(self, addr, val):
237
244
  # strong update
@@ -370,7 +377,8 @@ class StackPointerTracker(Analysis, ForwardAnalysis):
370
377
  self._mem_merge_cache = {}
371
378
 
372
379
  if initial_reg_values:
373
- self._reg_value_at_block_start[func.addr if func is not None else block.addr] = initial_reg_values
380
+ block_start_addr = func.addr if func is not None else block.addr # type: ignore
381
+ self._reg_value_at_block_start[block_start_addr] = initial_reg_values
374
382
 
375
383
  _l.debug("Running on function %r", self._func)
376
384
  self._analyze()
@@ -461,9 +469,13 @@ class StackPointerTracker(Analysis, ForwardAnalysis):
461
469
  return any(self.inconsistent_for(r) for r in self.reg_offsets)
462
470
 
463
471
  def inconsistent_for(self, reg):
472
+ if self._func is None:
473
+ raise ValueError("inconsistent_for() is only supported in function mode")
464
474
  return any(self.offset_after_block(endpoint.addr, reg) is TOP for endpoint in self._func.endpoints)
465
475
 
466
476
  def offsets_for(self, reg):
477
+ if self._func is None:
478
+ raise ValueError("offsets_for() is only supported in function mode")
467
479
  return [
468
480
  o for block in self._func.blocks if (o := self.offset_after_block(block.addr, reg)) not in (TOP, BOTTOM)
469
481
  ]
@@ -481,7 +493,7 @@ class StackPointerTracker(Analysis, ForwardAnalysis):
481
493
  def _post_analysis(self):
482
494
  pass
483
495
 
484
- def _get_register(self, offset):
496
+ def _get_register(self, offset) -> Register:
485
497
  name = self.project.arch.register_names[offset]
486
498
  size = self.project.arch.registers[name][1]
487
499
  return Register(offset, size * self.project.arch.byte_width)
@@ -557,7 +569,7 @@ class StackPointerTracker(Analysis, ForwardAnalysis):
557
569
  output_state = state.freeze()
558
570
  return None, output_state
559
571
 
560
- def _process_vex_irsb(self, node, vex_block: pyvex.IRSB, state: StackPointerTrackerState) -> int:
572
+ def _process_vex_irsb(self, node, vex_block: pyvex.IRSB, state: StackPointerTrackerState) -> int | None:
561
573
  tmps = {}
562
574
  curr_stmt_start_addr = None
563
575
 
@@ -704,21 +716,16 @@ class StackPointerTracker(Analysis, ForwardAnalysis):
704
716
  if callees:
705
717
  if len(callees) == 1:
706
718
  callee = callees[0]
719
+ track_rax = False
720
+ if (
721
+ (callee.info.get("is_rust_probestack", False) and self.project.arch.name == "AMD64")
722
+ or (callee.info.get("is_alloca_probe", False) and self.project.arch.name == "AMD64")
723
+ or callee.name == "__chkstk"
724
+ ):
725
+ # sp = sp - rax right after returning from the call
726
+ track_rax = True
707
727
 
708
- if callee.info.get("is_rust_probestack", False) is True and self.project.arch.name == "AMD64":
709
- # special-case for rust_probestack: sp = sp - rax right after returning from the call, so we
710
- # need to keep track of rax
711
- for stmt in reversed(vex_block.statements):
712
- if (
713
- isinstance(stmt, pyvex.IRStmt.Put)
714
- and stmt.offset == self.project.arch.registers["rax"][0]
715
- and isinstance(stmt.data, pyvex.IRExpr.Const)
716
- ):
717
- state.put(stmt.offset, Constant(stmt.data.con.value), force=True)
718
- break
719
- elif callee.name == "__chkstk":
720
- # special-case for __chkstk: sp = sp - rax right after returning from the call, so we need to
721
- # keep track of rax
728
+ if track_rax:
722
729
  for stmt in reversed(vex_block.statements):
723
730
  if (
724
731
  isinstance(stmt, pyvex.IRStmt.Put)
@@ -737,18 +744,20 @@ class StackPointerTracker(Analysis, ForwardAnalysis):
737
744
  # found callee clean-up cases...
738
745
  try:
739
746
  v = state.get(self.project.arch.sp_offset)
747
+ incremented = None
740
748
  if v is BOTTOM:
741
749
  incremented = BOTTOM
742
750
  elif callee_cleanups[0].prototype is not None:
743
751
  num_args = len(callee_cleanups[0].prototype.args)
744
752
  incremented = v + Constant(self.project.arch.bytes * num_args)
745
- state.put(self.project.arch.sp_offset, incremented)
753
+ if incremented is not None:
754
+ state.put(self.project.arch.sp_offset, incremented)
746
755
  except CouldNotResolveException:
747
756
  pass
748
757
 
749
758
  return curr_stmt_start_addr
750
759
 
751
- def _process_pcode_irsb(self, node, pcode_irsb: pcode.lifter.IRSB, state: StackPointerTrackerState) -> int:
760
+ def _process_pcode_irsb(self, node, pcode_irsb: pcode.lifter.IRSB, state: StackPointerTrackerState) -> int | None:
752
761
  unique = {}
753
762
  curr_stmt_start_addr = None
754
763
 
@@ -830,18 +839,20 @@ class StackPointerTracker(Analysis, ForwardAnalysis):
830
839
  # found callee clean-up cases...
831
840
  try:
832
841
  v = state.get(self.project.arch.sp_offset)
842
+ incremented = None
833
843
  if v is BOTTOM:
834
844
  incremented = BOTTOM
835
845
  elif callee_cleanups[0].prototype is not None:
836
846
  num_args = len(callee_cleanups[0].prototype.args)
837
847
  incremented = v + Constant(self.project.arch.bytes * num_args)
838
- state.put(self.project.arch.sp_offset, incremented)
848
+ if incremented is not None:
849
+ state.put(self.project.arch.sp_offset, incremented)
839
850
  except CouldNotResolveException:
840
851
  pass
841
852
 
842
853
  return curr_stmt_start_addr
843
854
 
844
- def _widen_states(self, *states):
855
+ def _widen_states(self, *states: FrozenStackPointerTrackerState):
845
856
  assert len(states) == 2
846
857
  merged, _ = self._merge_states(None, *states)
847
858
  if len(merged.memory) > 5:
@@ -849,13 +860,16 @@ class StackPointerTracker(Analysis, ForwardAnalysis):
849
860
  merged = merged.unfreeze().give_up_on_memory_tracking().freeze()
850
861
  return merged
851
862
 
852
- def _merge_states(self, node, *states: StackPointerTrackerState):
863
+ def _merge_states(self, node, *states: FrozenStackPointerTrackerState):
853
864
  merged_state = states[0]
854
865
  for other in states[1:]:
855
866
  merged_state = merged_state.merge(other, node.addr, self._reg_merge_cache, self._mem_merge_cache)
856
867
  return merged_state, merged_state == states[0]
857
868
 
858
869
  def _find_callees(self, node) -> list[Function]:
870
+ if self._func is None:
871
+ raise ValueError("find_callees() is only supported in function mode")
872
+
859
873
  callees: list[Function] = []
860
874
  for _, dst, data in self._func.transition_graph.out_edges(node, data=True):
861
875
  if data.get("type") == "call" and isinstance(dst, Function):