angr 9.2.142__py3-none-manylinux2014_aarch64.whl → 9.2.143__py3-none-manylinux2014_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of angr might be problematic. Click here for more details.

Files changed (28)
  1. angr/__init__.py +1 -1
  2. angr/analyses/calling_convention/calling_convention.py +9 -9
  3. angr/analyses/calling_convention/fact_collector.py +31 -9
  4. angr/analyses/cfg/indirect_jump_resolvers/const_resolver.py +12 -1
  5. angr/analyses/cfg/indirect_jump_resolvers/jumptable.py +4 -1
  6. angr/analyses/complete_calling_conventions.py +18 -5
  7. angr/analyses/decompiler/ail_simplifier.py +90 -65
  8. angr/analyses/decompiler/optimization_passes/condition_constprop.py +49 -14
  9. angr/analyses/decompiler/optimization_passes/ite_region_converter.py +8 -0
  10. angr/analyses/decompiler/peephole_optimizations/simplify_pc_relative_loads.py +15 -1
  11. angr/analyses/decompiler/sequence_walker.py +8 -0
  12. angr/analyses/decompiler/utils.py +13 -0
  13. angr/analyses/s_propagator.py +40 -29
  14. angr/analyses/s_reaching_definitions/s_rda_model.py +45 -36
  15. angr/analyses/s_reaching_definitions/s_rda_view.py +6 -3
  16. angr/analyses/s_reaching_definitions/s_reaching_definitions.py +21 -21
  17. angr/analyses/variable_recovery/engine_ail.py +6 -6
  18. angr/calling_conventions.py +18 -8
  19. angr/procedures/definitions/linux_kernel.py +5 -0
  20. angr/utils/doms.py +40 -33
  21. angr/utils/ssa/__init__.py +21 -14
  22. angr/utils/ssa/vvar_uses_collector.py +2 -2
  23. {angr-9.2.142.dist-info → angr-9.2.143.dist-info}/METADATA +6 -6
  24. {angr-9.2.142.dist-info → angr-9.2.143.dist-info}/RECORD +28 -28
  25. {angr-9.2.142.dist-info → angr-9.2.143.dist-info}/LICENSE +0 -0
  26. {angr-9.2.142.dist-info → angr-9.2.143.dist-info}/WHEEL +0 -0
  27. {angr-9.2.142.dist-info → angr-9.2.143.dist-info}/entry_points.txt +0 -0
  28. {angr-9.2.142.dist-info → angr-9.2.143.dist-info}/top_level.txt +0 -0
angr/__init__.py CHANGED
@@ -2,7 +2,7 @@
2
2
  # pylint: disable=wrong-import-position
3
3
  from __future__ import annotations
4
4
 
5
- __version__ = "9.2.142"
5
+ __version__ = "9.2.143"
6
6
 
7
7
  if bytes is str:
8
8
  raise Exception(
angr/analyses/calling_convention/calling_convention.py CHANGED
@@ -220,9 +220,9 @@ class CallingConventionAnalysis(Analysis):
220
220
  self.prototype = prototype # type: ignore
221
221
  return
222
222
  if self._function.is_plt:
223
- r = self._analyze_plt()
224
- if r is not None:
225
- self.cc, self.prototype = r
223
+ r_plt = self._analyze_plt()
224
+ if r_plt is not None:
225
+ self.cc, self.prototype, self.prototype_libname = r_plt
226
226
  return
227
227
 
228
228
  r = self._analyze_function()
@@ -278,11 +278,11 @@ class CallingConventionAnalysis(Analysis):
278
278
  self.cc = cc
279
279
  self.prototype = prototype
280
280
 
281
- def _analyze_plt(self) -> tuple[SimCC, SimTypeFunction | None] | None:
281
+ def _analyze_plt(self) -> tuple[SimCC, SimTypeFunction | None, str | None] | None:
282
282
  """
283
283
  Get the calling convention for a PLT stub.
284
284
 
285
- :return: A calling convention.
285
+ :return: A calling convention, the function type, as well as the library name if available.
286
286
  """
287
287
  assert self._function is not None
288
288
 
@@ -326,11 +326,11 @@ class CallingConventionAnalysis(Analysis):
326
326
  # we only take the prototype from the SimProcedure if
327
327
  # - the SimProcedure is a function
328
328
  # - the prototype of the SimProcedure is not guessed
329
- return cc, hooker.prototype
329
+ return cc, hooker.prototype, hooker.library_name
330
330
  if real_func.prototype is not None:
331
- return cc, real_func.prototype
331
+ return cc, real_func.prototype, real_func.prototype_libname
332
332
  else:
333
- return cc, real_func.prototype
333
+ return cc, real_func.prototype, real_func.prototype_libname
334
334
 
335
335
  if self.analyze_callsites:
336
336
  # determine the calling convention by analyzing its callsites
@@ -344,7 +344,7 @@ class CallingConventionAnalysis(Analysis):
344
344
  prototype = self._adjust_prototype(
345
345
  prototype, callsite_facts, update_arguments=UpdateArgumentsOption.AlwaysUpdate
346
346
  )
347
- return cc, prototype
347
+ return cc, prototype, None
348
348
 
349
349
  return None
350
350
 
angr/analyses/calling_convention/fact_collector.py CHANGED
@@ -1,10 +1,11 @@
1
1
  # pylint:disable=too-many-boolean-expressions
2
2
  from __future__ import annotations
3
- from typing import Any
3
+ from typing import Any, TYPE_CHECKING
4
4
 
5
5
  import pyvex
6
6
  import claripy
7
7
 
8
+ from angr import SIM_LIBRARIES, SIM_TYPE_COLLECTIONS
8
9
  from angr.utils.bits import s2u, u2s
9
10
  from angr.block import Block
10
11
  from angr.analyses.analysis import Analysis
@@ -13,9 +14,12 @@ from angr.knowledge_plugins.functions import Function
13
14
  from angr.codenode import BlockNode, HookNode
14
15
  from angr.engines.light import SimEngineNostmtVEX, SimEngineLight, SpOffset, RegisterOffset
15
16
  from angr.calling_conventions import SimRegArg, SimStackArg, default_cc
16
- from angr.sim_type import SimTypeBottom
17
+ from angr.sim_type import SimTypeBottom, dereference_simtype, SimTypeFunction
17
18
  from .utils import is_sane_register_variable
18
19
 
20
+ if TYPE_CHECKING:
21
+ from angr.codenode import CodeNode
22
+
19
23
 
20
24
  class FactCollectorState:
21
25
  """
@@ -224,9 +228,12 @@ class FactCollector(Analysis):
224
228
  callee_restored_regs = self._analyze_endpoints_for_restored_regs()
225
229
  self._determine_input_args(end_states, callee_restored_regs)
226
230
 
227
- def _analyze_startpoint(self):
231
+ def _analyze_startpoint(self) -> list[FactCollectorState]:
228
232
  func_graph = self.function.transition_graph
229
233
  startpoint = self.function.startpoint
234
+ if startpoint is None:
235
+ return []
236
+
230
237
  bp_as_gpr = self.function.info.get("bp_as_gpr", False)
231
238
  engine = SimEngineFactCollectorVEX(self.project, bp_as_gpr)
232
239
  init_state = FactCollectorState()
@@ -235,9 +242,9 @@ class FactCollector(Analysis):
235
242
  init_state.bp_value = init_state.sp_value
236
243
 
237
244
  traversed = set()
238
- queue: list[tuple[int, FactCollectorState, BlockNode | HookNode | Function, BlockNode | HookNode | None]] = [
239
- (0, init_state, startpoint, None)
240
- ]
245
+ queue: list[
246
+ tuple[int, FactCollectorState, CodeNode | BlockNode | HookNode | Function, BlockNode | HookNode | None]
247
+ ] = [(0, init_state, startpoint, None)]
241
248
  end_states: list[FactCollectorState] = []
242
249
  while queue:
243
250
  depth, state, node, retnode = queue.pop(0)
@@ -398,9 +405,24 @@ class FactCollector(Analysis):
398
405
  and not isinstance(func_succ.prototype.returnty, SimTypeBottom)
399
406
  ):
400
407
  # assume the function overwrites the return variable
401
- returnty_size = func_succ.prototype.returnty.with_arch(self.project.arch).size
402
- assert returnty_size is not None
403
- retval_size = returnty_size // self.project.arch.byte_width
408
+ proto = func_succ.prototype
409
+ if func_succ.prototype_libname is not None:
410
+ # we need to deref the prototype in case it uses SimTypeRef internally
411
+ type_collections = []
412
+ prototype_lib = SIM_LIBRARIES[func_succ.prototype_libname]
413
+ if prototype_lib.type_collection_names:
414
+ for typelib_name in prototype_lib.type_collection_names:
415
+ type_collections.append(SIM_TYPE_COLLECTIONS[typelib_name])
416
+ proto = dereference_simtype(proto, type_collections)
417
+
418
+ assert isinstance(proto, SimTypeFunction) and proto.returnty is not None
419
+ returnty_size = proto.returnty.with_arch(self.project.arch).size
420
+ if returnty_size is None:
421
+ # it may be None if somehow we cannot resolve a SimTypeRef; we fall back to the full
422
+ # machine word size
423
+ retval_size = self.project.arch.bytes
424
+ else:
425
+ retval_size = returnty_size // self.project.arch.byte_width
404
426
  retval_sizes.append(retval_size)
405
427
  continue
406
428
 
angr/analyses/cfg/indirect_jump_resolvers/const_resolver.py CHANGED
@@ -43,11 +43,22 @@ class ConstantResolver(IndirectJumpResolver):
43
43
  be resolved to a constant value. This resolver must be run after all other more specific resolvers.
44
44
  """
45
45
 
46
- def __init__(self, project):
46
+ def __init__(self, project, max_func_nodes: int = 512):
47
47
  super().__init__(project, timeless=False)
48
+ self.max_func_nodes = max_func_nodes
48
49
 
49
50
  def filter(self, cfg, addr, func_addr, block, jumpkind):
51
+ if not cfg.functions.contains_addr(func_addr):
52
+ # the function does not exist
53
+ return False
54
+
55
+ # for performance, we don't run constant resolver if the function is too large
56
+ func = cfg.functions.get_by_addr(func_addr)
57
+ if len(func.block_addrs_set) > self.max_func_nodes:
58
+ return False
59
+
50
60
  # we support both an indirect call and jump since the value can be resolved
61
+
51
62
  return jumpkind in {"Ijk_Boring", "Ijk_Call"}
52
63
 
53
64
  def resolve( # pylint:disable=unused-argument
angr/analyses/cfg/indirect_jump_resolvers/jumptable.py CHANGED
@@ -6,6 +6,7 @@ from collections.abc import Sequence
6
6
  from collections import defaultdict, OrderedDict
7
7
  import logging
8
8
  import functools
9
+ import contextlib
9
10
 
10
11
  import pyvex
11
12
  import claripy
@@ -1798,7 +1799,9 @@ class JumpTableResolver(IndirectJumpResolver):
1798
1799
  # swap the two tmps
1799
1800
  jump_base_addr.tmp, jump_base_addr.tmp_1 = jump_base_addr.tmp_1, jump_base_addr.tmp
1800
1801
  # Load the concrete base address
1801
- jump_base_addr.base_addr = state.solver.eval(state.scratch.temps[jump_base_addr.tmp_1])
1802
+ with contextlib.suppress(SimError):
1803
+ # silently eat the claripy exception
1804
+ jump_base_addr.base_addr = state.solver.eval(state.scratch.temps[jump_base_addr.tmp_1])
1802
1805
  else:
1803
1806
  # We do not support the cases where the base address involves more than one addition.
1804
1807
  # One such case exists in libc-2.27.so shipped with Ubuntu x86 where esi is used as the address of the
angr/analyses/complete_calling_conventions.py CHANGED
@@ -63,7 +63,7 @@ class CompleteCallingConventionsAnalysis(Analysis):
63
63
  max_function_size: int | None = None,
64
64
  workers: int = 0,
65
65
  cc_callback: Callable | None = None,
66
- prioritize_func_addrs: Iterable[int] | None = None,
66
+ prioritize_func_addrs: list[int] | set[int] | None = None,
67
67
  skip_other_funcs: bool = False,
68
68
  auto_start: bool = True,
69
69
  func_graphs: dict[int, networkx.DiGraph] | None = None,
@@ -130,9 +130,20 @@ class CompleteCallingConventionsAnalysis(Analysis):
130
130
  Infer calling conventions for all functions in the current project.
131
131
  """
132
132
 
133
- # get an ordering of functions based on the call graph
134
- # note that the call graph is a multi-digraph. we convert it to a digraph to speed up topological sort
135
- directed_callgraph = networkx.DiGraph(self.kb.functions.callgraph)
133
+ # special case: if both _prioritize_func_addrs and _skip_other_funcs are set, we only need to sort part of
134
+ # the call graph; even better, if there is only one function set, we don't need to sort the call graph at all!
135
+ if self._prioritize_func_addrs and self._skip_other_funcs:
136
+ if len(self._prioritize_func_addrs) == 1:
137
+ self._func_addrs = list(self._prioritize_func_addrs)
138
+ self._total_funcs = 1
139
+ return
140
+ directed_callgraph = networkx.DiGraph(self.kb.functions.callgraph)
141
+ directed_callgraph = directed_callgraph.subgraph(self._prioritize_func_addrs)
142
+ else:
143
+ # get an ordering of functions based on the call graph
144
+ # note that the call graph is a multi-digraph. we convert it to a digraph to speed up topological sort
145
+ directed_callgraph = networkx.DiGraph(self.kb.functions.callgraph)
146
+ assert isinstance(directed_callgraph, networkx.DiGraph)
136
147
  sorted_funcs = GraphUtils.quasi_topological_sort_nodes(directed_callgraph)
137
148
 
138
149
  total_funcs = 0
@@ -148,7 +159,7 @@ class CompleteCallingConventionsAnalysis(Analysis):
148
159
  continue
149
160
 
150
161
  if self._max_function_size is not None:
151
- func_size = sum(block.size for block in func.blocks)
162
+ func_size = sum(block.size for block in func.blocks if block.size is not None)
152
163
  if func_size > self._max_function_size:
153
164
  _l.info(
154
165
  "Skipping variable recovery for %r since its size (%d) is greater than the cutoff "
@@ -189,6 +200,7 @@ class CompleteCallingConventionsAnalysis(Analysis):
189
200
 
190
201
  def work(self):
191
202
  total_funcs = self._total_funcs
203
+ assert total_funcs is not None
192
204
  if self._workers == 0:
193
205
  self._update_progress(0)
194
206
  for idx, func_addr in enumerate(self._func_addrs):
@@ -211,6 +223,7 @@ class CompleteCallingConventionsAnalysis(Analysis):
211
223
  self._finish_progress()
212
224
 
213
225
  else:
226
+ assert self._remaining_funcs is not None and self._func_queue is not None
214
227
  self._remaining_funcs.value = len(self._func_addrs)
215
228
 
216
229
  # generate a call tree (obviously, it's acyclic)
angr/analyses/decompiler/ail_simplifier.py CHANGED
@@ -36,6 +36,7 @@ from angr.knowledge_plugins.key_definitions.definition import Definition
36
36
  from angr.knowledge_plugins.key_definitions.constants import OP_BEFORE
37
37
  from angr.errors import AngrRuntimeError
38
38
  from angr.analyses import Analysis, AnalysesHub
39
+ from angr.utils.timing import timethis
39
40
  from .ailgraph_walker import AILGraphWalker
40
41
  from .expression_narrower import ExprNarrowingInfo, NarrowingInfoExtractor, ExpressionNarrower
41
42
  from .block_simplifier import BlockSimplifier
@@ -202,6 +203,7 @@ class AILSimplifier(Analysis):
202
203
  AILGraphWalker(self.func_graph, _handler, replace_nodes=True).walk()
203
204
  self.blocks = {}
204
205
 
206
+ @timethis
205
207
  def _compute_reaching_definitions(self) -> SRDAModel:
206
208
  # Computing reaching definitions or return the cached one
207
209
  if self._reaching_definitions is not None:
@@ -217,6 +219,7 @@ class AILSimplifier(Analysis):
217
219
  self._reaching_definitions = rd
218
220
  return rd
219
221
 
222
+ @timethis
220
223
  def _compute_propagation(self) -> SPropagatorAnalysis:
221
224
  # Propagate expressions or return the existing result
222
225
  if self._propagator is not None:
@@ -233,6 +236,7 @@ class AILSimplifier(Analysis):
233
236
  self._propagator_dead_vvar_ids = prop.dead_vvar_ids
234
237
  return prop
235
238
 
239
+ @timethis
236
240
  def _compute_equivalence(self) -> set[Equivalence]:
237
241
  equivalence = set()
238
242
  for block in self.func_graph:
@@ -281,6 +285,7 @@ class AILSimplifier(Analysis):
281
285
  # Expression narrowing
282
286
  #
283
287
 
288
+ @timethis
284
289
  def _narrow_exprs(self) -> bool:
285
290
  """
286
291
  A register may be used with full width even when only the lower bytes are really needed. This results in the
@@ -511,9 +516,9 @@ class AILSimplifier(Analysis):
511
516
  atom = atom_queue.pop(0)
512
517
  seen.add(atom)
513
518
 
514
- use_and_exprs = rd.get_vvar_uses_with_expr(atom)
519
+ expr_and_uses = rd.all_vvar_uses[atom.varid]
515
520
 
516
- for loc, expr in use_and_exprs:
521
+ for expr, loc in set(expr_and_uses):
517
522
  old_block = block_dict.get((loc.block_addr, loc.block_idx), None)
518
523
  if old_block is None:
519
524
  # missing a block for whatever reason
@@ -532,6 +537,7 @@ class AILSimplifier(Analysis):
532
537
  )
533
538
  if new_atom not in seen:
534
539
  atom_queue.append(new_atom)
540
+ seen.add(new_atom)
535
541
  else:
536
542
  result.append((atom, loc, expr))
537
543
  return result, phi_vars
@@ -659,6 +665,7 @@ class AILSimplifier(Analysis):
659
665
  # Unifying local variables
660
666
  #
661
667
 
668
+ @timethis
662
669
  def _unify_local_variables(self) -> bool:
663
670
  """
664
671
  Find variables that are definitely equivalent and then eliminate unnecessary copies.
@@ -822,14 +829,14 @@ class AILSimplifier(Analysis):
822
829
  continue
823
830
 
824
831
  # find all its uses
825
- all_arg_copy_var_uses: set[tuple[CodeLocation, Any]] = set(
826
- rd.get_vvar_uses_with_expr(arg_copy_def.atom)
832
+ all_arg_copy_var_uses: set[tuple[Any, CodeLocation]] = rd.get_vvar_uses_with_expr(
833
+ arg_copy_def.atom
827
834
  )
828
835
  all_uses_with_def = set()
829
836
 
830
837
  should_abort = False
831
838
  for use in all_arg_copy_var_uses:
832
- used_expr = use[1]
839
+ used_expr = use[0]
833
840
  if used_expr is not None and used_expr.size != arg_copy_def.size:
834
841
  should_abort = True
835
842
  break
@@ -924,15 +931,19 @@ class AILSimplifier(Analysis):
924
931
 
925
932
  # find all uses of this definition
926
933
  # we make a copy of the set since we may touch the set (uses) when replacing expressions
927
- all_uses: set[tuple[CodeLocation, Any]] = set(rd.get_vvar_uses_with_expr(to_replace_def.atom))
934
+ all_uses: set[tuple[Any, CodeLocation]] = set(rd.all_vvar_uses[to_replace_def.atom.varid])
928
935
  # make sure none of these uses are phi nodes (depends on more than one def)
929
936
  all_uses_with_unique_def = set()
930
- for use_and_expr in all_uses:
931
- use_loc, used_expr = use_and_expr
937
+ for expr_and_use in all_uses:
938
+ used_expr, use_loc = expr_and_use
932
939
  defs_and_exprs = rd.get_uses_by_location(use_loc, exprs=True)
933
- filtered_defs = {def_ for def_, expr_ in defs_and_exprs if expr_ == used_expr}
940
+ filtered_defs = {
941
+ def_
942
+ for def_, expr_ in defs_and_exprs
943
+ if expr_ is not None and used_expr is not None and expr_.varid == used_expr.varid
944
+ }
934
945
  if len(filtered_defs) == 1:
935
- all_uses_with_unique_def.add(use_and_expr)
946
+ all_uses_with_unique_def.add(expr_and_use)
936
947
  else:
937
948
  # optimization: break early
938
949
  break
@@ -947,7 +958,7 @@ class AILSimplifier(Analysis):
947
958
 
948
959
  if not (isinstance(replace_with, VirtualVariable) and replace_with.was_parameter):
949
960
  assignment_ctr = 0
950
- all_use_locs = {use_loc for use_loc, _ in all_uses}
961
+ all_use_locs = {use_loc for _, use_loc in all_uses}
951
962
  for use_loc in all_use_locs:
952
963
  if use_loc == eq.codeloc:
953
964
  continue
@@ -960,17 +971,17 @@ class AILSimplifier(Analysis):
960
971
  if assignment_ctr > 1:
961
972
  continue
962
973
 
963
- all_uses_with_def = {(to_replace_def, use_and_expr) for use_and_expr in all_uses}
974
+ all_uses_with_def = {(to_replace_def, expr_and_use) for expr_and_use in all_uses}
964
975
 
965
976
  remove_initial_assignment = False # expression folding will take care of it
966
977
 
967
978
  assert replace_with is not None
968
979
 
969
- if any(not isinstance(use_and_expr[1], VirtualVariable) for _, use_and_expr in all_uses_with_def):
980
+ if any(not isinstance(expr_and_use[0], VirtualVariable) for _, expr_and_use in all_uses_with_def):
970
981
  # if any of the uses are phi assignments, we skip
971
982
  used_in_phi_assignment = False
972
- for _, use_and_expr in all_uses_with_def:
973
- u = use_and_expr[0]
983
+ for _, expr_and_use in all_uses_with_def:
984
+ u = expr_and_use[1]
974
985
  assert u.block_addr is not None
975
986
  assert u.stmt_idx is not None
976
987
  block = addr_and_idx_to_block[(u.block_addr, u.block_idx)]
@@ -983,8 +994,8 @@ class AILSimplifier(Analysis):
983
994
 
984
995
  # ensure the uses we consider are all after the eq location
985
996
  filtered_all_uses_with_def = []
986
- for def_, use_and_expr in all_uses_with_def:
987
- u = use_and_expr[0]
997
+ for def_, expr_and_use in all_uses_with_def:
998
+ u = expr_and_use[1]
988
999
  if (
989
1000
  u.block_addr == eq.codeloc.block_addr
990
1001
  and u.block_idx == eq.codeloc.block_idx
@@ -992,7 +1003,7 @@ class AILSimplifier(Analysis):
992
1003
  ):
993
1004
  # this use happens before the assignment - ignore it
994
1005
  continue
995
- filtered_all_uses_with_def.append((def_, use_and_expr))
1006
+ filtered_all_uses_with_def.append((def_, expr_and_use))
996
1007
  all_uses_with_def = filtered_all_uses_with_def
997
1008
 
998
1009
  if not all_uses_with_def:
@@ -1004,8 +1015,8 @@ class AILSimplifier(Analysis):
1004
1015
 
1005
1016
  # replace all uses
1006
1017
  all_uses_replaced = True
1007
- for def_, use_and_expr in all_uses_with_def:
1008
- u, used_expr = use_and_expr
1018
+ for def_, expr_and_use in all_uses_with_def:
1019
+ used_expr, u = expr_and_use
1009
1020
 
1010
1021
  use_expr_defns = []
1011
1022
  for d in rd.get_uses_by_location(u):
@@ -1110,6 +1121,7 @@ class AILSimplifier(Analysis):
1110
1121
  walker.walk_statement(stmt)
1111
1122
  return len(walker.temps) > 0
1112
1123
 
1124
+ @timethis
1113
1125
  def _fold_call_exprs(self) -> bool:
1114
1126
  """
1115
1127
  Fold a call expression (statement) into other statements if the return value of the call expression (statement)
@@ -1183,11 +1195,11 @@ class AILSimplifier(Analysis):
1183
1195
  assert the_def.codeloc.block_addr is not None
1184
1196
  assert the_def.codeloc.stmt_idx is not None
1185
1197
 
1186
- all_uses: set[tuple[CodeLocation, Any]] = set(rd.get_vvar_uses_with_expr(the_def.atom))
1198
+ all_uses: set[tuple[Any, CodeLocation]] = rd.get_vvar_uses_with_expr(the_def.atom)
1187
1199
 
1188
1200
  if len(all_uses) != 1:
1189
1201
  continue
1190
- u, used_expr = next(iter(all_uses))
1202
+ used_expr, u = next(iter(all_uses))
1191
1203
  if used_expr is None:
1192
1204
  continue
1193
1205
  assert u.block_addr is not None
@@ -1314,6 +1326,7 @@ class AILSimplifier(Analysis):
1314
1326
 
1315
1327
  return False, None
1316
1328
 
1329
+ @timethis
1317
1330
  def _iteratively_remove_dead_assignments(self) -> bool:
1318
1331
  anything_removed = False
1319
1332
  while True:
@@ -1323,6 +1336,7 @@ class AILSimplifier(Analysis):
1323
1336
  self._rebuild_func_graph()
1324
1337
  self._clear_cache()
1325
1338
 
1339
+ @timethis
1326
1340
  def _remove_dead_assignments(self) -> bool:
1327
1341
 
1328
1342
  # keeping tracking of statements to remove and statements (as well as dead vvars) to keep allows us to handle
@@ -1330,7 +1344,7 @@ class AILSimplifier(Analysis):
1330
1344
  # value and the floating-point return value.
1331
1345
  stmts_to_remove_per_block: dict[tuple[int, int | None], set[int]] = defaultdict(set)
1332
1346
  stmts_to_keep_per_block: dict[tuple[int, int | None], set[int]] = defaultdict(set)
1333
- dead_vvar_ids: set[int] = set()
1347
+ dead_vvar_ids: set[int] = self._removed_vvar_ids.copy()
1334
1348
  dead_vvar_codelocs: set[CodeLocation] = set()
1335
1349
  blocks: dict[tuple[int, int | None], Block] = {
1336
1350
  (node.addr, node.idx): self.blocks.get(node, node) for node in self.func_graph.nodes()
@@ -1343,36 +1357,43 @@ class AILSimplifier(Analysis):
1343
1357
  stackarg_offsets = (
1344
1358
  {(tpl[1] & mask) for tpl in self._stack_arg_offsets} if self._stack_arg_offsets is not None else None
1345
1359
  )
1360
+
1346
1361
  while True:
1347
1362
  new_dead_vars_found = False
1348
- for vvar, codeloc in rd.all_vvar_definitions.items():
1349
- if vvar.varid in dead_vvar_ids:
1363
+
1364
+ # traverse all virtual variable definitions
1365
+ for vvar_id, codeloc in rd.all_vvar_definitions.items():
1366
+ if vvar_id in dead_vvar_ids:
1350
1367
  continue
1351
- if vvar.varid in self._propagator_dead_vvar_ids:
1368
+ uses = None
1369
+ if vvar_id in self._propagator_dead_vvar_ids:
1352
1370
  # we are definitely removing this variable if it has no uses
1353
- uses = rd.all_vvar_uses[vvar]
1354
- elif vvar.was_stack:
1355
- if not self._remove_dead_memdefs:
1356
- if rd.is_phi_vvar_id(vvar.varid):
1357
- # we always remove unused phi variables
1358
- pass
1359
- elif vvar.varid in self._secondary_stackvars:
1360
- # secondary stack variables are potentially removable
1361
- pass
1362
- elif stackarg_offsets is not None:
1363
- # we always remove definitions for stack arguments
1364
- assert vvar.stack_offset is not None
1365
- if (vvar.stack_offset & mask) not in stackarg_offsets:
1371
+ uses = rd.all_vvar_uses[vvar_id]
1372
+
1373
+ if uses is None:
1374
+ vvar = rd.varid_to_vvar[vvar_id]
1375
+ if vvar.was_stack:
1376
+ if not self._remove_dead_memdefs:
1377
+ if rd.is_phi_vvar_id(vvar_id):
1378
+ # we always remove unused phi variables
1379
+ pass
1380
+ elif vvar_id in self._secondary_stackvars:
1381
+ # secondary stack variables are potentially removable
1382
+ pass
1383
+ elif stackarg_offsets is not None:
1384
+ # we always remove definitions for stack arguments
1385
+ assert vvar.stack_offset is not None
1386
+ if (vvar.stack_offset & mask) not in stackarg_offsets:
1387
+ continue
1388
+ else:
1366
1389
  continue
1367
- else:
1368
- continue
1369
- uses = rd.all_vvar_uses[vvar]
1390
+ uses = rd.all_vvar_uses[vvar_id]
1370
1391
 
1371
- elif vvar.was_tmp or vvar.was_reg or vvar.was_parameter:
1372
- uses = rd.all_vvar_uses[vvar]
1392
+ elif vvar.was_tmp or vvar.was_reg or vvar.was_parameter:
1393
+ uses = rd.all_vvar_uses[vvar_id]
1373
1394
 
1374
- else:
1375
- uses = set()
1395
+ else:
1396
+ uses = set()
1376
1397
 
1377
1398
  # remove uses where vvars are going to be removed
1378
1399
  filtered_uses_count = 0
@@ -1385,7 +1406,7 @@ class AILSimplifier(Analysis):
1385
1406
 
1386
1407
  if filtered_uses_count == 0:
1387
1408
  new_dead_vars_found = True
1388
- dead_vvar_ids.add(vvar.varid)
1409
+ dead_vvar_ids.add(vvar_id)
1389
1410
  dead_vvar_codelocs.add(codeloc)
1390
1411
  if not isinstance(codeloc, ExternalCodeLocation):
1391
1412
  assert codeloc.block_addr is not None
@@ -1403,30 +1424,29 @@ class AILSimplifier(Analysis):
1403
1424
  break
1404
1425
 
1405
1426
  # find all phi variables that rely on variables that no longer exist
1406
- all_removed_var_ids = self._removed_vvar_ids.copy()
1407
1427
  removed_vvar_ids = self._removed_vvar_ids
1408
1428
  while True:
1409
1429
  new_removed_vvar_ids = set()
1410
1430
  for phi_varid, phi_use_varids in rd.phivarid_to_varids.items():
1411
- if phi_varid not in all_removed_var_ids and any(
1412
- vvarid in removed_vvar_ids for vvarid in phi_use_varids
1413
- ):
1414
- loc = rd.all_vvar_definitions[rd.varid_to_vvar[phi_varid]]
1431
+ if phi_varid not in dead_vvar_ids and any(vvarid in removed_vvar_ids for vvarid in phi_use_varids):
1432
+ loc = rd.all_vvar_definitions[phi_varid]
1415
1433
  assert loc.block_addr is not None and loc.stmt_idx is not None
1416
- stmts_to_remove_per_block[(loc.block_addr, loc.block_idx)].add(loc.stmt_idx)
1417
- new_removed_vvar_ids.add(phi_varid)
1418
- all_removed_var_ids.add(phi_varid)
1434
+ if loc.stmt_idx not in stmts_to_remove_per_block[(loc.block_addr, loc.block_idx)]:
1435
+ stmts_to_remove_per_block[(loc.block_addr, loc.block_idx)].add(loc.stmt_idx)
1436
+ new_removed_vvar_ids.add(phi_varid)
1437
+ dead_vvar_ids.add(phi_varid)
1419
1438
  if not new_removed_vvar_ids:
1420
1439
  break
1421
1440
  removed_vvar_ids = new_removed_vvar_ids
1422
1441
 
1423
1442
  # find all phi variables that are only ever used by other phi variables
1424
- redundant_phi_and_dirty_varids = self._find_cyclic_dependent_phis_and_dirty_vvars(rd)
1443
+ redundant_phi_and_dirty_varids = self._find_cyclic_dependent_phis_and_dirty_vvars(rd, dead_vvar_ids)
1425
1444
  for varid in redundant_phi_and_dirty_varids:
1426
- loc = rd.all_vvar_definitions[rd.varid_to_vvar[varid]]
1445
+ loc = rd.all_vvar_definitions[varid]
1427
1446
  assert loc.block_addr is not None and loc.stmt_idx is not None
1428
- stmts_to_remove_per_block[(loc.block_addr, loc.block_idx)].add(loc.stmt_idx)
1429
- stmts_to_keep_per_block[(loc.block_addr, loc.block_idx)].discard(loc.stmt_idx)
1447
+ if loc.stmt_idx not in stmts_to_remove_per_block[(loc.block_addr, loc.block_idx)]:
1448
+ stmts_to_remove_per_block[(loc.block_addr, loc.block_idx)].add(loc.stmt_idx)
1449
+ stmts_to_keep_per_block[(loc.block_addr, loc.block_idx)].discard(loc.stmt_idx)
1430
1450
 
1431
1451
  for codeloc in self._calls_to_remove | self._assignments_to_remove:
1432
1452
  # this call can be removed. make sure it exists in stmts_to_remove_per_block
@@ -1481,6 +1501,7 @@ class AILSimplifier(Analysis):
1481
1501
  if self._statement_has_call_exprs(stmt):
1482
1502
  if codeloc in self._calls_to_remove:
1483
1503
  # it has a call and must be removed
1504
+ self._calls_to_remove.discard(codeloc)
1484
1505
  simplified = True
1485
1506
  continue
1486
1507
  if isinstance(stmt, Assignment) and isinstance(stmt.dst, VirtualVariable):
@@ -1538,9 +1559,8 @@ class AILSimplifier(Analysis):
1538
1559
  :return: The set of vvar use atoms.
1539
1560
  """
1540
1561
 
1541
- vvar = rd.varid_to_vvar[vvar_id]
1542
1562
  used_by: set[int | None] = set()
1543
- for used_vvar, loc in rd.all_vvar_uses[vvar]:
1563
+ for used_vvar, loc in rd.all_vvar_uses[vvar_id]:
1544
1564
  if used_vvar is None:
1545
1565
  # no explicit reference
1546
1566
  used_by.add(None)
@@ -1553,7 +1573,7 @@ class AILSimplifier(Analysis):
1553
1573
  used_by.add(None)
1554
1574
  return used_by
1555
1575
 
1556
- def _find_cyclic_dependent_phis_and_dirty_vvars(self, rd: SRDAModel) -> set[int]:
1576
+ def _find_cyclic_dependent_phis_and_dirty_vvars(self, rd: SRDAModel, dead_vvar_ids: set[int]) -> set[int]:
1557
1577
  blocks_dict: dict[tuple[int, int | None], Block] = {(bb.addr, bb.idx): bb for bb in self.func_graph}
1558
1578
 
1559
1579
  # find dirty vvars and vexccall vvars
@@ -1568,16 +1588,21 @@ class AILSimplifier(Analysis):
1568
1588
  ):
1569
1589
  dirty_vvar_ids.add(stmt.dst.varid)
1570
1590
 
1571
- phi_and_dirty_vvar_ids = rd.phi_vvar_ids | dirty_vvar_ids
1591
+ phi_and_dirty_vvar_ids = (rd.phi_vvar_ids | dirty_vvar_ids).difference(dead_vvar_ids)
1572
1592
 
1573
1593
  vvar_used_by: dict[int, set[int | None]] = defaultdict(set)
1574
1594
  for var_id in phi_and_dirty_vvar_ids:
1575
1595
  if var_id in rd.phivarid_to_varids:
1576
1596
  for used_by_varid in rd.phivarid_to_varids[var_id]:
1597
+ if used_by_varid in dead_vvar_ids:
1598
+ # this variable no longer exists
1599
+ continue
1577
1600
  if used_by_varid not in vvar_used_by:
1578
- vvar_used_by[used_by_varid] |= self._get_vvar_used_by(used_by_varid, rd, blocks_dict)
1601
+ vvar_used_by[used_by_varid] |= self._get_vvar_used_by(
1602
+ used_by_varid, rd, blocks_dict
1603
+ ).difference(dead_vvar_ids)
1579
1604
  vvar_used_by[used_by_varid].add(var_id) # probably unnecessary
1580
- vvar_used_by[var_id] |= self._get_vvar_used_by(var_id, rd, blocks_dict)
1605
+ vvar_used_by[var_id] |= self._get_vvar_used_by(var_id, rd, blocks_dict).difference(dead_vvar_ids)
1581
1606
 
1582
1607
  g = networkx.DiGraph()
1583
1608
  dummy_vvar_id = -1