angr 9.2.126__py3-none-macosx_11_0_arm64.whl → 9.2.128__py3-none-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of angr might be problematic; see the registry's release page for more details.

Files changed (44)
  1. angr/__init__.py +1 -1
  2. angr/analyses/analysis.py +8 -2
  3. angr/analyses/cfg/cfg_fast.py +12 -1
  4. angr/analyses/decompiler/clinic.py +23 -2
  5. angr/analyses/decompiler/condition_processor.py +5 -7
  6. angr/analyses/decompiler/decompilation_cache.py +4 -0
  7. angr/analyses/decompiler/decompiler.py +36 -7
  8. angr/analyses/decompiler/dephication/graph_vvar_mapping.py +1 -2
  9. angr/analyses/decompiler/graph_region.py +3 -6
  10. angr/analyses/decompiler/label_collector.py +32 -0
  11. angr/analyses/decompiler/optimization_passes/__init__.py +3 -0
  12. angr/analyses/decompiler/optimization_passes/optimization_pass.py +6 -3
  13. angr/analyses/decompiler/optimization_passes/switch_default_case_duplicator.py +41 -3
  14. angr/analyses/decompiler/optimization_passes/switch_reused_entry_rewriter.py +102 -0
  15. angr/analyses/decompiler/presets/basic.py +2 -0
  16. angr/analyses/decompiler/presets/fast.py +2 -0
  17. angr/analyses/decompiler/presets/full.py +2 -0
  18. angr/analyses/decompiler/region_identifier.py +8 -8
  19. angr/analyses/decompiler/ssailification/traversal.py +1 -0
  20. angr/analyses/decompiler/ssailification/traversal_engine.py +15 -0
  21. angr/analyses/decompiler/structured_codegen/c.py +0 -3
  22. angr/analyses/decompiler/structured_codegen/dwarf_import.py +4 -1
  23. angr/analyses/decompiler/structuring/phoenix.py +131 -31
  24. angr/analyses/decompiler/structuring/recursive_structurer.py +3 -1
  25. angr/analyses/decompiler/structuring/structurer_base.py +33 -1
  26. angr/analyses/reaching_definitions/function_handler_library/string.py +2 -2
  27. angr/analyses/s_liveness.py +3 -3
  28. angr/analyses/s_propagator.py +74 -3
  29. angr/angrdb/models.py +2 -1
  30. angr/angrdb/serializers/kb.py +3 -3
  31. angr/angrdb/serializers/structured_code.py +5 -3
  32. angr/calling_conventions.py +1 -1
  33. angr/knowledge_base.py +1 -1
  34. angr/knowledge_plugins/__init__.py +0 -2
  35. angr/knowledge_plugins/structured_code.py +1 -1
  36. angr/lib/angr_native.dylib +0 -0
  37. angr/utils/ssa/__init__.py +8 -3
  38. {angr-9.2.126.dist-info → angr-9.2.128.dist-info}/METADATA +6 -6
  39. {angr-9.2.126.dist-info → angr-9.2.128.dist-info}/RECORD +43 -42
  40. {angr-9.2.126.dist-info → angr-9.2.128.dist-info}/WHEEL +1 -1
  41. angr/knowledge_plugins/decompilation.py +0 -45
  42. {angr-9.2.126.dist-info → angr-9.2.128.dist-info}/LICENSE +0 -0
  43. {angr-9.2.126.dist-info → angr-9.2.128.dist-info}/entry_points.txt +0 -0
  44. {angr-9.2.126.dist-info → angr-9.2.128.dist-info}/top_level.txt +0 -0
angr/__init__.py CHANGED
@@ -2,7 +2,7 @@
2
2
  # pylint: disable=wrong-import-position
3
3
  from __future__ import annotations
4
4
 
5
- __version__ = "9.2.126"
5
+ __version__ = "9.2.128"
6
6
 
7
7
  if bytes is str:
8
8
  raise Exception(
angr/analyses/analysis.py CHANGED
@@ -9,6 +9,7 @@ from typing import TYPE_CHECKING, TypeVar, Generic, cast
9
9
  from collections.abc import Callable
10
10
  from types import NoneType
11
11
  from itertools import chain
12
+ from traceback import format_exception
12
13
 
13
14
  import logging
14
15
  import time
@@ -75,6 +76,11 @@ class AnalysisLogEntry:
75
76
 
76
77
  self.message = message
77
78
 
79
+ def format(self) -> str:
80
+ if self.exc_traceback is None:
81
+ return self.message
82
+ return "\n".join((*format_exception(self.exc_type, self.exc_value, self.exc_traceback), "", self.message))
83
+
78
84
  def __getstate__(self):
79
85
  return (
80
86
  str(self.__dict__.get("exc_type")),
@@ -281,8 +287,8 @@ class Analysis:
281
287
  kb: KnowledgeBase
282
288
  _fail_fast: bool
283
289
  _name: str
284
- errors = []
285
- named_errors = defaultdict(list)
290
+ errors: list[AnalysisLogEntry] = []
291
+ named_errors: defaultdict[str, list[AnalysisLogEntry]] = defaultdict(list)
286
292
  _progress_callback = None
287
293
  _show_progressbar = False
288
294
  _progressbar = None
@@ -3309,7 +3309,18 @@ class CFGFast(ForwardAnalysis[CFGNode, CFGNode, CFGJob, int], CFGBase): # pylin
3309
3309
  # this is not a no-op block. Determine where nop instructions terminate.
3310
3310
  insns = block.capstone.insns
3311
3311
  if insns:
3312
- nop_length = self._get_nop_length(insns)
3312
+ if (
3313
+ self.project.simos is not None
3314
+ and self.project.simos.name == "Win32"
3315
+ and insns[0].mnemonic == "mov"
3316
+ ):
3317
+ op0, op1 = insns[0].operands
3318
+ if op0.type == 1 and op1.type == 1 and op0.reg == op1.reg:
3319
+ # hot-patch points on Windows DLLs
3320
+ # https://devblogs.microsoft.com/oldnewthing/20110921-00/?p=9583
3321
+ nop_length = None
3322
+ else:
3323
+ nop_length = self._get_nop_length(insns)
3313
3324
 
3314
3325
  if nop_length is None or nop_length <= 0:
3315
3326
  continue
@@ -113,6 +113,8 @@ class Clinic(Analysis):
113
113
  vvar_id_start: int = 0,
114
114
  optimization_scratch: dict[str, Any] | None = None,
115
115
  desired_variables: set[str] | None = None,
116
+ force_loop_single_exit: bool = True,
117
+ complete_successors: bool = False,
116
118
  ):
117
119
  if not func.normalized and mode == ClinicMode.DECOMPILE:
118
120
  raise ValueError("Decompilation must work on normalized function graphs.")
@@ -157,6 +159,8 @@ class Clinic(Analysis):
157
159
  self._inlined_counts = {} if inlined_counts is None else inlined_counts
158
160
  self._inlining_parents = inlining_parents or ()
159
161
  self._desired_variables = desired_variables
162
+ self._force_loop_single_exit = force_loop_single_exit
163
+ self._complete_successors = complete_successors
160
164
 
161
165
  self._register_save_areas_removed: bool = False
162
166
 
@@ -333,6 +337,7 @@ class Clinic(Analysis):
333
337
  optimization_passes=[StackCanarySimplifier],
334
338
  sp_shift=self._max_stack_depth,
335
339
  vvar_id_start=self.vvar_id_start,
340
+ fail_fast=self._fail_fast,
336
341
  )
337
342
  self.vvar_id_start = callee_clinic.vvar_id_start + 1
338
343
  self._max_stack_depth = callee_clinic._max_stack_depth
@@ -787,7 +792,7 @@ class Clinic(Analysis):
787
792
 
788
793
  # case 2: the callee is a SimProcedure
789
794
  if target_func.is_simprocedure:
790
- cc = self.project.analyses.CallingConvention(target_func)
795
+ cc = self.project.analyses.CallingConvention(target_func, fail_fast=self._fail_fast)
791
796
  if cc.cc is not None and cc.prototype is not None:
792
797
  target_func.calling_convention = cc.cc
793
798
  target_func.prototype = cc.prototype
@@ -795,7 +800,7 @@ class Clinic(Analysis):
795
800
 
796
801
  # case 3: the callee is a PLT function
797
802
  if target_func.is_plt:
798
- cc = self.project.analyses.CallingConvention(target_func)
803
+ cc = self.project.analyses.CallingConvention(target_func, fail_fast=self._fail_fast)
799
804
  if cc.cc is not None and cc.prototype is not None:
800
805
  target_func.calling_convention = cc.cc
801
806
  target_func.prototype = cc.prototype
@@ -834,6 +839,7 @@ class Clinic(Analysis):
834
839
  callsite_block_addr=callsite.addr,
835
840
  callsite_insn_addr=callsite_ins_addr,
836
841
  func_graph=func_graph,
842
+ fail_fast=self._fail_fast,
837
843
  )
838
844
 
839
845
  if cc.cc is not None and cc.prototype is not None:
@@ -864,6 +870,7 @@ class Clinic(Analysis):
864
870
  # finally, recover the calling convention of the current function
865
871
  if self.function.prototype is None or self.function.calling_convention is None:
866
872
  self.project.analyses.CompleteCallingConventions(
873
+ fail_fast=self._fail_fast,
867
874
  recover_variables=True,
868
875
  prioritize_func_addrs=[self.function.addr],
869
876
  skip_other_funcs=True,
@@ -896,6 +903,7 @@ class Clinic(Analysis):
896
903
  spt = self.project.analyses.StackPointerTracker(
897
904
  self.function,
898
905
  regs,
906
+ fail_fast=self._fail_fast,
899
907
  track_memory=self._sp_tracker_track_memory,
900
908
  cross_insn_opt=False,
901
909
  initial_reg_values=initial_reg_values,
@@ -1043,6 +1051,7 @@ class Clinic(Analysis):
1043
1051
  reg_name=self.project.arch.translate_register_name(
1044
1052
  ret_reg_offset, size=self.project.arch.bits
1045
1053
  ),
1054
+ **target.tags,
1046
1055
  )
1047
1056
  call_stmt = ailment.Stmt.Call(
1048
1057
  None,
@@ -1129,6 +1138,7 @@ class Clinic(Analysis):
1129
1138
  simp = self.project.analyses.AILBlockSimplifier(
1130
1139
  ail_block,
1131
1140
  self.function.addr,
1141
+ fail_fast=self._fail_fast,
1132
1142
  remove_dead_memdefs=remove_dead_memdefs,
1133
1143
  stack_pointer_tracker=stack_pointer_tracker,
1134
1144
  peephole_optimizations=self.peephole_optimizations,
@@ -1200,6 +1210,7 @@ class Clinic(Analysis):
1200
1210
 
1201
1211
  simp = self.project.analyses.AILSimplifier(
1202
1212
  self.function,
1213
+ fail_fast=self._fail_fast,
1203
1214
  func_graph=ail_graph,
1204
1215
  remove_dead_memdefs=remove_dead_memdefs,
1205
1216
  unify_variables=unify_variables,
@@ -1258,6 +1269,8 @@ class Clinic(Analysis):
1258
1269
  vvar_id_start=self.vvar_id_start,
1259
1270
  entry_node_addr=self.entry_node_addr,
1260
1271
  scratch=self.optimization_scratch,
1272
+ force_loop_single_exit=self._force_loop_single_exit,
1273
+ complete_successors=self._complete_successors,
1261
1274
  **kwargs,
1262
1275
  )
1263
1276
  if a.out_graph:
@@ -1341,6 +1354,7 @@ class Clinic(Analysis):
1341
1354
  ssailification = self.project.analyses.Ssailification(
1342
1355
  self.function,
1343
1356
  ail_graph,
1357
+ fail_fast=self._fail_fast,
1344
1358
  entry=next(iter(bb for bb in ail_graph if (bb.addr, bb.idx) == self.entry_node_addr)),
1345
1359
  ail_manager=self._ail_manager,
1346
1360
  ssa_stackvars=False,
@@ -1354,6 +1368,7 @@ class Clinic(Analysis):
1354
1368
  ssailification = self.project.analyses.Ssailification(
1355
1369
  self.function,
1356
1370
  ail_graph,
1371
+ fail_fast=self._fail_fast,
1357
1372
  entry=next(iter(bb for bb in ail_graph if (bb.addr, bb.idx) == self.entry_node_addr)),
1358
1373
  ail_manager=self._ail_manager,
1359
1374
  ssa_tmps=True,
@@ -1368,6 +1383,7 @@ class Clinic(Analysis):
1368
1383
  dephication = self.project.analyses.GraphDephicationVVarMapping(
1369
1384
  self.function,
1370
1385
  ail_graph,
1386
+ fail_fast=self._fail_fast,
1371
1387
  entry=next(iter(bb for bb in ail_graph if (bb.addr, bb.idx) == self.entry_node_addr)),
1372
1388
  vvar_id_start=self.vvar_id_start,
1373
1389
  )
@@ -1420,6 +1436,7 @@ class Clinic(Analysis):
1420
1436
  rd = self.project.analyses.SReachingDefinitions(
1421
1437
  subject=self.function,
1422
1438
  func_graph=ail_graph,
1439
+ fail_fast=self._fail_fast,
1423
1440
  # use_callee_saved_regs_at_return=not self._register_save_areas_removed, FIXME
1424
1441
  )
1425
1442
 
@@ -1430,6 +1447,7 @@ class Clinic(Analysis):
1430
1447
  def _handler(block):
1431
1448
  csm = self.project.analyses.AILCallSiteMaker(
1432
1449
  block,
1450
+ fail_fast=self._fail_fast,
1433
1451
  reaching_definitions=rd,
1434
1452
  stack_pointer_tracker=stack_pointer_tracker,
1435
1453
  ail_manager=self._ail_manager,
@@ -1443,6 +1461,7 @@ class Clinic(Analysis):
1443
1461
  simp = self.project.analyses.AILBlockSimplifier(
1444
1462
  ail_block,
1445
1463
  self.function.addr,
1464
+ fail_fast=self._fail_fast,
1446
1465
  stack_pointer_tracker=stack_pointer_tracker,
1447
1466
  peephole_optimizations=self.peephole_optimizations,
1448
1467
  )
@@ -1526,6 +1545,7 @@ class Clinic(Analysis):
1526
1545
  tmp_kb.functions = self.kb.functions
1527
1546
  vr = self.project.analyses.VariableRecoveryFast(
1528
1547
  self.function, # pylint:disable=unused-variable
1548
+ fail_fast=self._fail_fast,
1529
1549
  func_graph=ail_graph,
1530
1550
  kb=tmp_kb,
1531
1551
  track_sp=False,
@@ -1558,6 +1578,7 @@ class Clinic(Analysis):
1558
1578
  vr.type_constraints,
1559
1579
  vr.func_typevar,
1560
1580
  kb=tmp_kb,
1581
+ fail_fast=self._fail_fast,
1561
1582
  var_mapping=vr.var_to_typevars,
1562
1583
  must_struct=must_struct,
1563
1584
  ground_truth=groundtruth,
@@ -1113,17 +1113,15 @@ class ConditionProcessor:
1113
1113
  r1_with: claripy.ast.Bool,
1114
1114
  ) -> claripy.ast.Bool:
1115
1115
  if ast.op == "And":
1116
- return ast.make_like(
1117
- "And", (ConditionProcessor._replace_term_in_ast(arg, r0, r0_with, r1, r1_with) for arg in ast.args)
1116
+ return claripy.And(
1117
+ *(ConditionProcessor._replace_term_in_ast(arg, r0, r0_with, r1, r1_with) for arg in ast.args)
1118
1118
  )
1119
1119
  if ast.op == "Or":
1120
- return ast.make_like(
1121
- "Or", (ConditionProcessor._replace_term_in_ast(arg, r0, r0_with, r1, r1_with) for arg in ast.args)
1120
+ return claripy.Or(
1121
+ *(ConditionProcessor._replace_term_in_ast(arg, r0, r0_with, r1, r1_with) for arg in ast.args)
1122
1122
  )
1123
1123
  if ast.op == "Not":
1124
- return ast.make_like(
1125
- "Not", (ConditionProcessor._replace_term_in_ast(ast.args[0], r0, r0_with, r1, r1_with),)
1126
- )
1124
+ return claripy.Not(ConditionProcessor._replace_term_in_ast(ast.args[0], r0, r0_with, r1, r1_with))
1127
1125
  if ast is r0:
1128
1126
  return r0_with
1129
1127
  if ast is r1:
@@ -23,6 +23,7 @@ class DecompilationCache:
23
23
  "clinic",
24
24
  "ite_exprs",
25
25
  "binop_operators",
26
+ "errors",
26
27
  )
27
28
 
28
29
  def __init__(self, addr):
@@ -35,7 +36,10 @@ class DecompilationCache:
35
36
  self.clinic: Clinic | None = None
36
37
  self.ite_exprs: set[tuple[int, Any]] | None = None
37
38
  self.binop_operators: dict[OpDescriptor, str] | None = None
39
+ self.errors: list[str] = []
38
40
 
39
41
  @property
40
42
  def local_types(self):
43
+ if self.clinic is None or self.clinic.variable_kb is None:
44
+ return None
41
45
  return self.clinic.variable_kb.variables[self.addr].types
@@ -139,7 +139,22 @@ class Decompiler(Analysis):
139
139
  self.expr_collapse_depth = expr_collapse_depth
140
140
 
141
141
  if decompile:
142
- self._decompile()
142
+ with self._resilience():
143
+ self._decompile()
144
+ if self.errors:
145
+ if (self.func.addr, self._flavor) not in self.kb.decompilations:
146
+ self.kb.decompilations[(self.func.addr, self._flavor)] = DecompilationCache(self.func.addr)
147
+ for error in self.errors:
148
+ self.kb.decompilations[(self.func.addr, self._flavor)].errors.append(error.format())
149
+ with self._resilience():
150
+ l.info("Decompilation failed for %s. Switching to basic preset and trying again.")
151
+ if preset != DECOMPILATION_PRESETS["basic"]:
152
+ self._optimization_passes = DECOMPILATION_PRESETS["basic"].get_optimization_passes(
153
+ self.project.arch, self.project.simos.name
154
+ )
155
+ self._decompile()
156
+ for error in self.errors:
157
+ self.kb.decompilations[(self.func.addr, self._flavor)].errors.append(error.format())
143
158
 
144
159
  def _can_use_decompilation_cache(self, cache: DecompilationCache) -> bool:
145
160
  a, b = self._cache_parameters, cache.parameters
@@ -155,7 +170,7 @@ class Decompiler(Analysis):
155
170
 
156
171
  if self._cache_parameters is not None:
157
172
  try:
158
- cache = self.kb.decompilations[self.func.addr]
173
+ cache = self.kb.decompilations[(self.func.addr, self._flavor)]
159
174
  if not self._can_use_decompilation_cache(cache):
160
175
  cache = None
161
176
  except KeyError:
@@ -220,6 +235,7 @@ class Decompiler(Analysis):
220
235
  clinic = self.project.analyses.Clinic(
221
236
  self.func,
222
237
  kb=self.kb,
238
+ fail_fast=self._fail_fast,
223
239
  variable_kb=variable_kb,
224
240
  reset_variable_names=reset_variable_names,
225
241
  optimization_passes=self._optimization_passes,
@@ -233,6 +249,8 @@ class Decompiler(Analysis):
233
249
  inline_functions=self._inline_functions,
234
250
  desired_variables=self._desired_variables,
235
251
  optimization_scratch=self._optimization_scratch,
252
+ force_loop_single_exit=self._force_loop_single_exit,
253
+ complete_successors=self._complete_successors,
236
254
  **self.options_to_params(self.options_by_class["clinic"]),
237
255
  )
238
256
  else:
@@ -293,7 +311,7 @@ class Decompiler(Analysis):
293
311
  self._update_progress(75.0, text="Structuring code")
294
312
 
295
313
  # structure it
296
- rs = self.project.analyses[RecursiveStructurer].prep(kb=self.kb)(
314
+ rs = self.project.analyses[RecursiveStructurer].prep(kb=self.kb, fail_fast=self._fail_fast)(
297
315
  ri.region,
298
316
  cond_proc=cond_proc,
299
317
  func=self.func,
@@ -306,6 +324,7 @@ class Decompiler(Analysis):
306
324
  self.func,
307
325
  rs.result,
308
326
  kb=self.kb,
327
+ fail_fast=self._fail_fast,
309
328
  variable_kb=clinic.variable_kb,
310
329
  **self.options_to_params(self.options_by_class["region_simplifier"]),
311
330
  )
@@ -330,6 +349,7 @@ class Decompiler(Analysis):
330
349
  flavor=self._flavor,
331
350
  func_args=clinic.arg_list,
332
351
  kb=self.kb,
352
+ fail_fast=self._fail_fast,
333
353
  variable_kb=clinic.variable_kb,
334
354
  expr_comments=old_codegen.expr_comments if old_codegen is not None else None,
335
355
  stmt_comments=old_codegen.stmt_comments if old_codegen is not None else None,
@@ -347,10 +367,10 @@ class Decompiler(Analysis):
347
367
  self.cache.codegen = codegen
348
368
  self.cache.clinic = self.clinic
349
369
 
350
- self.kb.decompilations[self.func.addr] = self.cache
370
+ self.kb.decompilations[(self.func.addr, self._flavor)] = self.cache
351
371
 
352
372
  def _recover_regions(self, graph: networkx.DiGraph, condition_processor, update_graph: bool = True):
353
- return self.project.analyses[RegionIdentifier].prep(kb=self.kb)(
373
+ return self.project.analyses[RegionIdentifier].prep(kb=self.kb, fail_fast=self._fail_fast)(
354
374
  self.func,
355
375
  graph=graph,
356
376
  cond_proc=condition_processor,
@@ -403,6 +423,8 @@ class Decompiler(Analysis):
403
423
  reaching_definitions=reaching_definitions,
404
424
  entry_node_addr=self.clinic.entry_node_addr,
405
425
  scratch=self._optimization_scratch,
426
+ force_loop_single_exit=self._force_loop_single_exit,
427
+ complete_successors=self._complete_successors,
406
428
  **kwargs,
407
429
  )
408
430
 
@@ -463,6 +485,8 @@ class Decompiler(Analysis):
463
485
  vvar_id_start=self.vvar_id_start,
464
486
  entry_node_addr=self.clinic.entry_node_addr,
465
487
  scratch=self._optimization_scratch,
488
+ force_loop_single_exit=self._force_loop_single_exit,
489
+ complete_successors=self._complete_successors,
466
490
  **kwargs,
467
491
  )
468
492
 
@@ -546,6 +570,7 @@ class Decompiler(Analysis):
546
570
  type_constraints,
547
571
  func_typevar,
548
572
  kb=var_kb,
573
+ fail_fast=self._fail_fast,
549
574
  var_mapping=var_to_typevar,
550
575
  must_struct=must_struct,
551
576
  ground_truth=groundtruth,
@@ -613,11 +638,15 @@ class Decompiler(Analysis):
613
638
  )
614
639
 
615
640
  def _transform_graph_from_ssa(self, ail_graph: networkx.DiGraph) -> networkx.DiGraph:
616
- dephication = self.project.analyses.GraphDephication(self.func, ail_graph, rewrite=True)
641
+ dephication = self.project.analyses.GraphDephication(
642
+ self.func, ail_graph, rewrite=True, kb=self.kb, fail_fast=self._fail_fast
643
+ )
617
644
  return dephication.output
618
645
 
619
646
  def _transform_seqnode_from_ssa(self, seq_node: SequenceNode) -> SequenceNode:
620
- dephication = self.project.analyses.SeqNodeDephication(self.func, seq_node, rewrite=True)
647
+ dephication = self.project.analyses.SeqNodeDephication(
648
+ self.func, seq_node, rewrite=True, kb=self.kb, fail_fast=self._fail_fast
649
+ )
621
650
  return dephication.output
622
651
 
623
652
  @staticmethod
@@ -290,8 +290,7 @@ class GraphDephicationVVarMapping(Analysis): # pylint:disable=abstract-method
290
290
  for stmt in dst_block.statements:
291
291
  if isinstance(stmt, Label):
292
292
  continue
293
- r, _ = is_phi_assignment(stmt)
294
- if r:
293
+ if is_phi_assignment(stmt):
295
294
  for src_, vvar in stmt.src.src_and_vvars:
296
295
  if src_ == src and vvar is not None and vvar.varid == vvar_id:
297
296
  return True
@@ -382,12 +382,9 @@ class GraphRegion:
382
382
  if src in graph:
383
383
  graph.add_edge(src, dst)
384
384
  else:
385
- # it may happen that the dst node does not exist in sub_graph
386
- # fallback
387
- l.info("Node dst is not found in sub_graph. Enter the fall back logic.")
388
- for src in sub_graph.nodes:
389
- if sub_graph.out_degree[src] == 0:
390
- graph.add_edge(src, dst)
385
+ # it may happen that the dst node no longer exists in sub_graph or its successors
386
+ # this is because we have deemed that the dst node is no longer a valid successor for sub_graph
387
+ pass
391
388
 
392
389
  graph.add_nodes_from(sub_graph_nodes)
393
390
  graph.add_edges_from(sub_graph_edges)
@@ -0,0 +1,32 @@
1
+ # pylint:disable=unused-argument
2
+ from __future__ import annotations
3
+ from collections import defaultdict
4
+
5
+ import ailment
6
+
7
+ from .sequence_walker import SequenceWalker
8
+
9
+
10
+ class LabelCollector:
11
+ """
12
+ Collect all labels.
13
+ """
14
+
15
+ def __init__(self, node):
16
+ self.root = node
17
+ self.labels: defaultdict[str, list[tuple[int, int | None]]] = defaultdict(list)
18
+
19
+ handlers = {
20
+ ailment.Block: self._handle_Block,
21
+ }
22
+ self._walker = SequenceWalker(handlers=handlers)
23
+ self._walker.walk(self.root)
24
+
25
+ #
26
+ # Handlers
27
+ #
28
+
29
+ def _handle_Block(self, block: ailment.Block, **kwargs):
30
+ for stmt in block.statements:
31
+ if isinstance(stmt, ailment.Stmt.Label):
32
+ self.labels[stmt.name].append((block.addr, block.idx))
@@ -32,6 +32,7 @@ from .inlined_string_transformation_simplifier import InlinedStringTransformatio
32
32
  from .const_prop_reverter import ConstPropOptReverter
33
33
  from .call_stmt_rewriter import CallStatementRewriter
34
34
  from .duplication_reverter import DuplicationReverter
35
+ from .switch_reused_entry_rewriter import SwitchReusedEntryRewriter
35
36
 
36
37
  if TYPE_CHECKING:
37
38
  from angr.analyses.decompiler.presets import DecompilationPreset
@@ -55,6 +56,7 @@ ALL_OPTIMIZATION_PASSES = [
55
56
  ReturnDuplicatorHigh,
56
57
  DeadblockRemover,
57
58
  SwitchDefaultCaseDuplicator,
59
+ SwitchReusedEntryRewriter,
58
60
  ConstPropOptReverter,
59
61
  DuplicationReverter,
60
62
  LoweredSwitchSimplifier,
@@ -129,6 +131,7 @@ __all__ = (
129
131
  "CrossJumpReverter",
130
132
  "CodeMotionOptimization",
131
133
  "SwitchDefaultCaseDuplicator",
134
+ "SwitchReusedEntryRewriter",
132
135
  "DeadblockRemover",
133
136
  "InlinedStringTransformationSimplifier",
134
137
  "ConstPropOptReverter",
@@ -118,6 +118,8 @@ class OptimizationPass(BaseOptimizationPass):
118
118
  vvar_id_start=None,
119
119
  entry_node_addr=None,
120
120
  scratch: dict[str, Any] | None = None,
121
+ force_loop_single_exit: bool = True,
122
+ complete_successors: bool = False,
121
123
  **kwargs,
122
124
  ):
123
125
  super().__init__(func)
@@ -134,6 +136,8 @@ class OptimizationPass(BaseOptimizationPass):
134
136
  self.entry_node_addr: tuple[int, int | None] = (
135
137
  entry_node_addr if entry_node_addr is not None else (func.addr, None)
136
138
  )
139
+ self._force_loop_single_exit = force_loop_single_exit
140
+ self._complete_successors = complete_successors
137
141
 
138
142
  # output
139
143
  self.out_graph: networkx.DiGraph | None = None
@@ -255,9 +259,8 @@ class OptimizationPass(BaseOptimizationPass):
255
259
  graph=graph,
256
260
  cond_proc=condition_processor or ConditionProcessor(self.project.arch),
257
261
  update_graph=update_graph,
258
- # TODO: find a way to pass Phoenix/DREAM options here (see decompiler.py for correct use)
259
- force_loop_single_exit=True,
260
- complete_successors=False,
262
+ force_loop_single_exit=self._force_loop_single_exit,
263
+ complete_successors=self._complete_successors,
261
264
  entry_node_addr=self.entry_node_addr,
262
265
  )
263
266
 
@@ -1,10 +1,15 @@
1
1
  # pylint:disable=too-many-boolean-expressions
2
2
  from __future__ import annotations
3
3
  from itertools import count
4
+ from collections import defaultdict
4
5
  import logging
5
6
 
6
7
  import networkx
7
8
 
9
+ from ailment.block import Block
10
+ from ailment.statement import Jump
11
+ from ailment.expression import Const
12
+
8
13
  from angr.knowledge_plugins.cfg import IndirectJumpType
9
14
  from .optimization_pass import OptimizationPass, OptimizationPassStage
10
15
 
@@ -29,17 +34,19 @@ class SwitchDefaultCaseDuplicator(OptimizationPass):
29
34
 
30
35
  ARCHES = None
31
36
  PLATFORMS = None
32
- STAGE = OptimizationPassStage.BEFORE_REGION_IDENTIFICATION
37
+ STAGE = OptimizationPassStage.AFTER_AIL_GRAPH_CREATION
33
38
  NAME = "Duplicate default-case nodes to undo default-case node reuse caused by compiler code deduplication"
34
39
  DESCRIPTION = __doc__.strip()
35
40
 
36
41
  def __init__(self, func, **kwargs):
37
42
  super().__init__(func, **kwargs)
38
43
 
39
- self.node_idx = count(start=0)
44
+ self.node_idx = count(start=self._scratch.get("node_idx", 0))
40
45
 
41
46
  self.analyze()
42
47
 
48
+ self._scratch["node_idx"] = next(self.node_idx)
49
+
43
50
  def _check(self):
44
51
  jumptables = self.kb.cfgs.get_most_accurate().jump_tables
45
52
  switch_jump_block_addrs = {
@@ -77,8 +84,39 @@ class SwitchDefaultCaseDuplicator(OptimizationPass):
77
84
  out_graph = None
78
85
  duplicated_default_addrs: set[int] = set()
79
86
 
87
+ default_addr_count = defaultdict(int)
88
+ goto_rewritten_default_addrs = set()
89
+ for _, _, default_addr in default_case_node_addrs:
90
+ default_addr_count[default_addr] += 1
91
+ for default_addr, cnt in default_addr_count.items():
92
+ if cnt > 1:
93
+ # rewrite all of them into gotos
94
+ default_node = self._get_block(default_addr)
95
+ for switch_head_addr in sorted((sa for sa, _, da in default_case_node_addrs if da == default_addr)):
96
+ switch_head_node = self._get_block(switch_head_addr)
97
+ goto_stmt = Jump(
98
+ None,
99
+ Const(None, None, default_addr, self.project.arch.bits, ins_addr=default_addr),
100
+ target_idx=None, # I'm assuming the ID of the default node is None here
101
+ ins_addr=default_addr,
102
+ )
103
+ goto_node = Block(
104
+ default_addr,
105
+ 0,
106
+ statements=[goto_stmt],
107
+ idx=next(self.node_idx),
108
+ )
109
+
110
+ if out_graph is None:
111
+ out_graph = self._graph
112
+ out_graph.remove_edge(switch_head_node, default_node)
113
+ out_graph.add_edge(switch_head_node, goto_node)
114
+ out_graph.add_edge(goto_node, default_node)
115
+
116
+ goto_rewritten_default_addrs.add(default_addr)
117
+
80
118
  for switch_head_addr, jump_node_addr, default_addr in default_case_node_addrs:
81
- if default_addr in duplicated_default_addrs:
119
+ if default_addr in duplicated_default_addrs or default_addr in goto_rewritten_default_addrs:
82
120
  continue
83
121
 
84
122
  default_case_node = self._func.get_node(default_addr)