angr-9.2.111-py3-none-manylinux2014_x86_64.whl → angr-9.2.113-py3-none-manylinux2014_x86_64.whl

This diff shows the changes between two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.


Files changed (35)
  1. angr/__init__.py +1 -1
  2. angr/analyses/cfg/cfg_base.py +4 -1
  3. angr/analyses/decompiler/condition_processor.py +9 -2
  4. angr/analyses/decompiler/optimization_passes/__init__.py +3 -1
  5. angr/analyses/decompiler/optimization_passes/const_prop_reverter.py +367 -0
  6. angr/analyses/decompiler/optimization_passes/deadblock_remover.py +1 -1
  7. angr/analyses/decompiler/optimization_passes/lowered_switch_simplifier.py +99 -12
  8. angr/analyses/decompiler/optimization_passes/optimization_pass.py +79 -9
  9. angr/analyses/decompiler/optimization_passes/return_duplicator_base.py +21 -0
  10. angr/analyses/decompiler/optimization_passes/return_duplicator_low.py +111 -9
  11. angr/analyses/decompiler/redundant_label_remover.py +17 -0
  12. angr/analyses/decompiler/seq_cf_structure_counter.py +37 -0
  13. angr/analyses/decompiler/structured_codegen/c.py +4 -5
  14. angr/analyses/decompiler/structuring/phoenix.py +3 -3
  15. angr/analyses/reaching_definitions/rd_state.py +2 -0
  16. angr/analyses/reaching_definitions/reaching_definitions.py +7 -0
  17. angr/angrdb/serializers/loader.py +91 -7
  18. angr/calling_conventions.py +11 -9
  19. angr/knowledge_plugins/key_definitions/live_definitions.py +5 -0
  20. angr/knowledge_plugins/propagations/states.py +3 -2
  21. angr/knowledge_plugins/variables/variable_manager.py +1 -1
  22. angr/procedures/stubs/ReturnUnconstrained.py +1 -2
  23. angr/procedures/stubs/syscall_stub.py +1 -2
  24. angr/sim_type.py +354 -136
  25. angr/state_plugins/debug_variables.py +2 -2
  26. angr/state_plugins/solver.py +5 -13
  27. angr/storage/memory_mixins/multi_value_merger_mixin.py +13 -3
  28. angr/utils/orderedset.py +70 -0
  29. angr/vaults.py +0 -1
  30. {angr-9.2.111.dist-info → angr-9.2.113.dist-info}/METADATA +6 -6
  31. {angr-9.2.111.dist-info → angr-9.2.113.dist-info}/RECORD +35 -32
  32. {angr-9.2.111.dist-info → angr-9.2.113.dist-info}/WHEEL +1 -1
  33. {angr-9.2.111.dist-info → angr-9.2.113.dist-info}/LICENSE +0 -0
  34. {angr-9.2.111.dist-info → angr-9.2.113.dist-info}/entry_points.txt +0 -0
  35. {angr-9.2.111.dist-info → angr-9.2.113.dist-info}/top_level.txt +0 -0
angr/__init__.py CHANGED
@@ -1,7 +1,7 @@
  # pylint: disable=wildcard-import
  # pylint: disable=wrong-import-position

- __version__ = "9.2.111"
+ __version__ = "9.2.113"

  if bytes is str:
      raise Exception(
angr/analyses/cfg/cfg_base.py CHANGED
@@ -7,7 +7,6 @@ import networkx
  from sortedcontainers import SortedDict

  import pyvex
- from claripy.utils.orderedset import OrderedSet
  from cle import ELF, PE, Blob, TLSObject, MachO, ExternObject, KernelObject, FunctionHintSource, Hex, Coff, SRec, XBE
  from cle.backends import NamedRegion
  import archinfo
@@ -34,6 +33,7 @@ from angr.codenode import HookNode, BlockNode
  from angr.engines.vex.lifter import VEX_IRSB_MAX_SIZE, VEX_IRSB_MAX_INST
  from angr.analyses import Analysis
  from angr.analyses.stack_pointer_tracker import StackPointerTracker
+ from angr.utils.orderedset import OrderedSet
  from .indirect_jump_resolvers.default_resolvers import default_indirect_jump_resolvers

  if TYPE_CHECKING:
@@ -746,6 +746,9 @@ class CFGBase(Analysis):
          memory_regions = []

          for b in binaries:
+             if not b.has_memory:
+                 continue
+
              if isinstance(b, ELF):
                  # If we have sections, we get result from sections
                  sections = []
angr/analyses/decompiler/condition_processor.py CHANGED
@@ -184,7 +184,12 @@ class ConditionProcessor:
          self.edge_conditions = edge_conditions

      def recover_reaching_conditions(
-         self, region, graph=None, with_successors=False, case_entry_to_switch_head: dict[int, int] | None = None
+         self,
+         region,
+         graph=None,
+         with_successors=False,
+         case_entry_to_switch_head: dict[int, int] | None = None,
+         simplify_conditions: bool = True,
      ):
          """
          Recover the reaching conditions for each block in an acyclic graph. Note that we assume the graph that's passed
@@ -255,7 +260,9 @@
                  reaching_condition = claripy.Or(claripy.And(pred_condition, edge_condition), reaching_condition)

              if reaching_condition is not None:
-                 reaching_conditions[node] = self.simplify_condition(reaching_condition)
+                 reaching_conditions[node] = (
+                     self.simplify_condition(reaching_condition) if simplify_conditions else reaching_condition
+                 )

          # My hypothesis: for nodes where two paths come together *and* those that cannot be further structured into
          # another if-else construct (we take the short-cut by testing if the operator is an "Or" after running our
angr/analyses/decompiler/optimization_passes/__init__.py CHANGED
@@ -28,6 +28,7 @@ from .code_motion import CodeMotionOptimization
  from .switch_default_case_duplicator import SwitchDefaultCaseDuplicator
  from .deadblock_remover import DeadblockRemover
  from .inlined_string_transformation_simplifier import InlinedStringTransformationSimplifier
+ from .const_prop_reverter import ConstPropOptReverter

  # order matters!
  _all_optimization_passes = [
@@ -47,7 +48,8 @@ _all_optimization_passes = [
      (ReturnDuplicatorHigh, True),
      (DeadblockRemover, True),
      (SwitchDefaultCaseDuplicator, True),
-     (LoweredSwitchSimplifier, False),
+     (ConstPropOptReverter, True),
+     (LoweredSwitchSimplifier, True),
      (ReturnDuplicatorLow, True),
      (ReturnDeduplicator, True),
      (CodeMotionOptimization, True),
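
Note: as a quick sanity check of the registration above, a minimal sketch (it assumes an AMD64/Linux target and that the get_default_optimization_passes helper keeps its (arch, platform) signature; both passes are flagged True, i.e. enabled by default):

    from archinfo import ArchAMD64

    from angr.analyses.decompiler.optimization_passes import (
        ConstPropOptReverter,
        LoweredSwitchSimplifier,
        get_default_optimization_passes,
    )

    # Both passes are registered as enabled-by-default in this release,
    # so they should be returned for a Linux target.
    default_passes = get_default_optimization_passes(ArchAMD64(), "linux")
    assert ConstPropOptReverter in default_passes
    assert LoweredSwitchSimplifier in default_passes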
angr/analyses/decompiler/optimization_passes/const_prop_reverter.py ADDED
@@ -0,0 +1,367 @@
+ import logging
+ from collections.abc import Callable
+ import itertools
+
+ import networkx
+ import claripy
+ from ailment import Const
+ from ailment.block_walker import AILBlockWalkerBase
+ from ailment.statement import Call, Statement, ConditionalJump, Assignment, Store, Return
+ from ailment.expression import Convert, Register
+
+ from .optimization_pass import OptimizationPass, OptimizationPassStage
+ from ..utils import remove_labels, add_labels
+ from ....knowledge_plugins.key_definitions.atoms import MemoryLocation
+ from ....knowledge_plugins.key_definitions.constants import OP_BEFORE
+
+
+ _l = logging.getLogger(__name__)
+
+
+ class PairAILBlockWalker:
+     """
+     This AILBlockWalker will walk two blocks at a time and call a handler for each pair of statements that are
+     instances of the same type. This is useful for comparing two statements for similarity across blocks.
+     """
+
+     def __init__(self, graph: networkx.DiGraph, stmt_pair_handlers=None):
+         self.graph = graph
+
+         _default_stmt_handlers = {
+             Assignment: self._handle_Assignment_pair,
+             Call: self._handle_Call_pair,
+             Store: self._handle_Store_pair,
+             ConditionalJump: self._handle_ConditionalJump_pair,
+             Return: self._handle_Return_pair,
+         }
+
+         self.stmt_pair_handlers: dict[Statement, Callable] = (
+             stmt_pair_handlers if stmt_pair_handlers else _default_stmt_handlers
+         )
+
+     # pylint: disable=no-self-use
+     def _walk_block(self, block):
+         walked_objs = {Assignment: set(), Call: set(), Store: set(), ConditionalJump: set(), Return: set()}
+
+         # create a walker that will:
+         # 1. recursively expand a stmt with the default handler, then
+         # 2. record the stmt parts in the walked_objs dict with the overwritten handler
+         #
+         # Call expressions are a special case that require a handler among the expressions, since they are statements.
+         walker = AILBlockWalkerBase()
+         _default_stmt_handlers = {
+             Assignment: walker._handle_Assignment,
+             Call: walker._handle_Call,
+             Store: walker._handle_Store,
+             ConditionalJump: walker._handle_ConditionalJump,
+             Return: walker._handle_Return,
+         }
+
+         def _handle_ail_obj(stmt_idx, stmt, block_):
+             _default_stmt_handlers[type(stmt)](stmt_idx, stmt, block_)
+             walked_objs[type(stmt)].add(stmt)
+
+         # pylint: disable=unused-argument
+         def _handle_call_expr(expr_idx: int, expr: Call, stmt_idx: int, stmt: Statement, block_):
+             walked_objs[Call].add(expr)
+
+         _stmt_handlers = {typ: _handle_ail_obj for typ in walked_objs}
+         walker.stmt_handlers = _stmt_handlers
+         walker.expr_handlers[Call] = _handle_call_expr
+
+         walker.walk(block)
+         return walked_objs
+
+     def walk(self):
+         for b0, b1 in itertools.combinations(self.graph.nodes, 2):
+             walked_obj_by_blk = {}
+
+             for blk in (b0, b1):
+                 walked_obj_by_blk[blk] = self._walk_block(blk)
+
+             for typ, objs0 in walked_obj_by_blk[b0].items():
+                 try:
+                     handler = self.stmt_pair_handlers[typ]
+                 except KeyError:
+                     continue
+
+                 if not objs0:
+                     continue
+
+                 objs1 = walked_obj_by_blk[b1][typ]
+                 if not objs1:
+                     continue
+
+                 for o0 in objs0:
+                     for o1 in objs1:
+                         handler(o0, b0, o1, b1)
+
+     #
+     # default handlers
+     #
+
+     # pylint: disable=unused-argument,no-self-use
+     def _handle_Assignment_pair(self, obj0, blk0, obj1, blk1):
+         return
+
+     # pylint: disable=unused-argument,no-self-use
+     def _handle_Call_pair(self, obj0, blk0, obj1, blk1):
+         return
+
+     # pylint: disable=unused-argument,no-self-use
+     def _handle_Store_pair(self, obj0, blk0, obj1, blk1):
+         return
+
+     # pylint: disable=unused-argument,no-self-use
+     def _handle_ConditionalJump_pair(self, obj0, blk0, obj1, blk1):
+         return
+
+     # pylint: disable=unused-argument,no-self-use
+     def _handle_Return_pair(self, obj0, blk0, obj1, blk1):
+         return
+
+
+ class ConstPropOptReverter(OptimizationPass):
+     """
+     This optimization reverts the effects of constant propagation done by the compiler, as discussed in the
+     USENIX 2024 paper SAILR. Its main goal is to enable later optimizations that rely on symbolic variables
+     to be more effective. The pass converts two statements that differ only in a constant versus a symbolic
+     variable into two statements that both use the symbolic variable.
+
+     As an example:
+         x = 75
+         puts(x)
+         puts(75)
+
+     will be converted to:
+         x = 75
+         puts(x)
+         puts(x)
+     """
+
+     ARCHES = None
+     PLATFORMS = None
+     STAGE = OptimizationPassStage.DURING_REGION_IDENTIFICATION
+     NAME = "Revert Constant Propagation Optimizations"
+     DESCRIPTION = __doc__.strip()
+
+     def __init__(self, func, region_identifier=None, reaching_definitions=None, **kwargs):
+         self.ri = region_identifier
+         self.rd = reaching_definitions
+         super().__init__(func, **kwargs)
+
+         self._call_pair_targets = []
+         self.resolution = False
+         self.analyze()
+
+     def _check(self):
+         return True, {}
+
+     def _analyze(self, cache=None):
+         self.resolution = False
+         self.out_graph = remove_labels(self._graph)
+         # self.out_graph = self._graph
+
+         _pair_stmt_handlers = {
+             Call: self._handle_Call_pair,
+             Return: self._handle_Return_pair,
+         }
+
+         if self.out_graph is None:
+             return
+
+         walker = PairAILBlockWalker(self.out_graph, stmt_pair_handlers=_pair_stmt_handlers)
+         walker.walk()
+         if self._call_pair_targets:
+             self._analyze_call_pair_targets()
+
+         if not self.resolution:
+             self.out_graph = None
+         else:
+             self.out_graph = add_labels(self.out_graph)
+
+     def _analyze_call_pair_targets(self):
+         all_obs_points = []
+         for _, observation_points in self._call_pair_targets:
+             all_obs_points.extend(observation_points)
+
+         self.rd = self.project.analyses.ReachingDefinitions(subject=self._func, observation_points=all_obs_points)
+
+         for (call0, blk0, call1, blk1, arg_conflicts), _ in self._call_pair_targets:
+             # attempt constant resolution for each argument that differs
+             for i, args in arg_conflicts.items():
+                 a0, a1 = args[:]
+                 calls = {a0: call0, a1: call1}
+                 blks = {call0: blk0, call1: blk1}
+
+                 # we can only resolve two arguments where one is constant and one is symbolic
+                 const_arg = None
+                 sym_arg = None
+                 for arg in calls:
+                     if isinstance(arg, Const) and const_arg is None:
+                         const_arg = arg
+                     elif not isinstance(arg, Const) and sym_arg is None:
+                         sym_arg = arg
+
+                 if const_arg is None or sym_arg is None:
+                     continue
+
+                 unwrapped_sym_arg = sym_arg.operands[0] if isinstance(sym_arg, Convert) else sym_arg
+                 try:
+                     # TODO: make this support more than just Loads
+                     # target must be a Load of a memory location
+                     target_atom = MemoryLocation(unwrapped_sym_arg.addr.value, unwrapped_sym_arg.size, "Iend_LE")
+                     const_state = self.rd.get_reaching_definitions_by_node(blks[calls[const_arg]].addr, OP_BEFORE)
+
+                     state_load_vals = const_state.get_value_from_atom(target_atom)
+                 except AttributeError:
+                     continue
+                 except KeyError:
+                     continue
+
+                 if not state_load_vals:
+                     continue
+
+                 state_vals = list(state_load_vals.values())
+                 # the symbolic variable MUST resolve to only a single value
+                 if len(state_vals) != 1:
+                     continue
+
+                 state_val = list(state_vals[0])[0]
+                 if hasattr(state_val, "concrete") and state_val.concrete:
+                     const_value = claripy.Solver().eval(state_val, 1)[0]
+                 else:
+                     continue
+
+                 if not const_value == const_arg.value:
+                     continue
+
+                 _l.debug("Constant argument at position %d was resolved to symbolic arg %s", i, sym_arg)
+                 const_call = calls[const_arg]
+                 const_arg_i = const_call.args.index(const_arg)
+                 const_call.args[const_arg_i] = sym_arg
+                 self.resolution = True
+
+     #
+     # Handle Similar Returns
+     #
+
+     def _handle_Return_pair(self, obj0: Return, blk0, obj1: Return, blk1):
+         if obj0 is obj1:
+             return
+
+         rexp0, rexp1 = obj0.ret_exprs, obj1.ret_exprs
+         if rexp0 is None or rexp1 is None or len(rexp0) != len(rexp1):
+             return
+
+         conflicts = {
+             i: ret_exprs
+             for i, ret_exprs in enumerate(zip(rexp0, rexp1))
+             if hasattr(ret_exprs[0], "likes") and not ret_exprs[0].likes(ret_exprs[1])
+         }
+         # only single-expr returns are supported
+         if len(conflicts) != 1:
+             return
+
+         _, ret_exprs = list(conflicts.items())[0]
+         expr_to_blk = {ret_exprs[0]: blk0, ret_exprs[1]: blk1}
+         # find the expression that is symbolic
+         symb_expr, const_expr = None, None
+         for expr in ret_exprs:
+             unpacked_expr = expr
+             if isinstance(expr, Convert):
+                 unpacked_expr = expr.operands[0]
+
+             if isinstance(unpacked_expr, Const):
+                 const_expr = expr
+             elif isinstance(unpacked_expr, Call):
+                 const_expr = expr
+             else:
+                 symb_expr = expr
+
+         if symb_expr is None or const_expr is None:
+             return
+
+         # now we do specific cases for matching
+         if (
+             isinstance(symb_expr, Register)
+             and isinstance(const_expr, Call)
+             and isinstance(const_expr.ret_expr, Register)
+         ):
+             # Handles the following case:
+             # B0:
+             #     return foo();  // considered constant
+             # B1:
+             #     return rax;    // considered symbolic
+             #
+             # =>
+             #
+             # B0:
+             #     rax = foo();
+             #     return rax;
+             # B1:
+             #     return rax;
+             #
+             # This is useful later for merging the returns.
+             #
+             call_return_reg = const_expr.ret_expr
+             if symb_expr.likes(call_return_reg):
+                 symb_return_stmt = expr_to_blk[symb_expr].statements[-1]
+                 const_block = expr_to_blk[const_expr]
+
+                 # rax = foo();
+                 reg_assign = Assignment(None, symb_expr, const_expr, **const_expr.tags)
+
+                 # construct the new constant block
+                 new_const_block = const_block.copy()
+                 new_const_block.statements = new_const_block.statements[:-1] + [reg_assign] + [symb_return_stmt.copy()]
+                 self._update_block(const_block, new_const_block)
+                 self.resolution = True
+         else:
+             _l.debug("This case is not supported yet for Return de-propagation")
+
+     #
+     # Handle Similar Calls
+     #
+
+     def _handle_Call_pair(self, obj0: Call, blk0, obj1: Call, blk1):
+         if obj0 is obj1:
+             return
+
+         # verify both calls are calls to the same function
+         if (isinstance(obj0.target, str) or isinstance(obj1.target, str)) and obj0.target != obj1.target:
+             return
+         elif not obj0.target.likes(obj1.target):
+             return
+
+         call0, call1 = obj0, obj1
+         arg_conflicts = self.find_conflicting_call_args(call0, call1)
+         # if there is no conflict, then there is nothing to fix
+         if not arg_conflicts:
+             return
+
+         _l.debug(
+             "Found two calls at (%x, %x) that are similar. Attempting to resolve const args now...",
+             blk0.addr,
+             blk1.addr,
+         )
+
+         # destroy the old ReachDefs, since we need a new one
+         observation_points = ("node", blk0.addr, OP_BEFORE), ("node", blk1.addr, OP_BEFORE)
+
+         # do the full analysis after collecting all calls in _analyze
+         self._call_pair_targets.append(((call0, blk0, call1, blk1, arg_conflicts), observation_points))
+
+     @staticmethod
+     def find_conflicting_call_args(call0: Call, call1: Call):
+         if not call0.args or not call1.args:
+             return None
+
+         # TODO: update this to work for variable-arg functions
+         if len(call0.args) != len(call1.args):
+             return None
+
+         # zipped args of call0 and call1 conflict if they are not like each other
+         conflicts = {i: args for i, args in enumerate(zip(call0.args, call1.args)) if not args[0].likes(args[1])}
+
+         return conflicts
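
For orientation, a minimal sketch of exercising the new pass end to end (the binary path and function name are hypothetical; ConstPropOptReverter runs as one of the decompiler's default optimization passes, as registered above):

    import angr

    # hypothetical target containing the docstring's pattern: x = 75; puts(x); puts(75)
    proj = angr.Project("./example_binary", auto_load_libs=False)
    cfg = proj.analyses.CFGFast(normalize=True)
    func = proj.kb.functions["main"]  # hypothetical function name

    # With the pass enabled, the second call should decompile as puts(x)
    # rather than puts(75).
    dec = proj.analyses.Decompiler(func, cfg=cfg)
    print(dec.codegen.text)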
angr/analyses/decompiler/optimization_passes/deadblock_remover.py CHANGED
@@ -36,7 +36,7 @@ class DeadblockRemover(OptimizationPass):
              acyclic_graph = self._graph
          else:
              acyclic_graph = to_acyclic_graph(self._graph)
-         cond_proc.recover_reaching_conditions(region=None, graph=acyclic_graph)
+         cond_proc.recover_reaching_conditions(region=None, graph=acyclic_graph, simplify_conditions=False)

          if not any(claripy.is_false(c) for c in cond_proc.reaching_conditions.values()):
              return False, None
angr/analyses/decompiler/optimization_passes/lowered_switch_simplifier.py CHANGED
@@ -11,7 +11,8 @@ from ailment.expression import Expression, BinaryOp, Const, Load
  from angr.utils.graph import GraphUtils
  from ..utils import first_nonlabel_statement, remove_last_statement
  from ..structuring.structurer_nodes import IncompleteSwitchCaseHeadStatement, SequenceNode, MultiNode
- from .optimization_pass import OptimizationPass, OptimizationPassStage, MultipleBlocksException
+ from .optimization_pass import OptimizationPassStage, MultipleBlocksException, StructuringOptimizationPass
+ from ..region_simplifiers.switch_cluster_simplifier import SwitchClusterFinder

  if TYPE_CHECKING:
      from ailment.expression import UnaryOp, Convert
@@ -130,15 +131,19 @@ class StableVarExprHasher(AILBlockWalkerBase):
          super()._handle_Convert(expr_idx, expr, stmt_idx, stmt, block)


- class LoweredSwitchSimplifier(OptimizationPass):
+ class LoweredSwitchSimplifier(StructuringOptimizationPass):
      """
-     Recognize and simplify lowered switch-case constructs.
+     This optimization recognizes and reverts switch-cases that have been lowered and possibly split into multiple
+     if-else statements. This optimization, discussed in the USENIX 2024 paper SAILR, aims to undo the compiler
+     optimization known as "switch lowering", which is present in both GCC and Clang. An in-depth discussion of this
+     deoptimization can be found in the paper or in our documentation of the optimization:
+     https://github.com/mahaloz/sailr-eval/issues/14#issue-2232616411
+
+     Note that this optimization does not occur in MSVC, which uses a different optimization strategy for
+     switch-cases. As a hack for now, we only run this deoptimization on Linux binaries.
      """

-     ARCHES = [
-         "AMD64",
-     ]
-     PLATFORMS = ["linux", "windows"]
+     PLATFORMS = ["linux"]
      STAGE = OptimizationPassStage.DURING_REGION_IDENTIFICATION
      NAME = "Convert lowered switch-cases (if-else) to switch-cases"
      DESCRIPTION = (
@@ -147,12 +152,60 @@ class LoweredSwitchSimplifier(OptimizationPass):
      )
      STRUCTURING = ["phoenix"]

-     def __init__(self, func, blocks_by_addr=None, blocks_by_addr_and_idx=None, graph=None, **kwargs):
+     def __init__(self, func, min_distinct_cases=2, **kwargs):
          super().__init__(
-             func, blocks_by_addr=blocks_by_addr, blocks_by_addr_and_idx=blocks_by_addr_and_idx, graph=graph, **kwargs
+             func,
+             require_gotos=False,
+             prevent_new_gotos=False,
+             simplify_ail=False,
+             must_improve_rel_quality=True,
+             **kwargs,
          )
+
+         # the max number of cases a switch may have and still be converted to an if-tree
+         # (if the number of cases is greater than this, the switch will not be converted)
+         # https://github.com/gcc-mirror/gcc/blob/f9a60d575f02822852aa22513c636be38f9c63ea/gcc/targhooks.cc#L1899
+         # TODO: add architecture-specific values
+         default_case_values_threshold = 6
+         # NOTE: this means there must be fewer than default_case_values_threshold cases for us to convert an
+         # if-tree back into a switch
+         self._max_case_values = default_case_values_threshold
+
+         self._min_distinct_cases = min_distinct_cases
+
+         # used to determine if a switch-case construct is present in the code, which is useful for invalidating
+         # other heuristics that minimize false positives
+         self._switches_present_in_code = 0
+
          self.analyze()

+     @staticmethod
+     def _count_max_continuous_cases(cases: list[Case]) -> int:
+         if not cases:  # return 0 if the list is empty
+             return 0
+
+         max_len = 0
+         current_len = 1  # start with 1, since a single number is a sequence of length 1
+         sorted_cases = sorted(cases, key=lambda c: c.value)
+         for i in range(1, len(sorted_cases)):
+             if sorted_cases[i].value == sorted_cases[i - 1].value + 1:
+                 current_len += 1
+             else:
+                 max_len = max(max_len, current_len)
+                 current_len = 1
+
+         # final check to include the last sequence
+         max_len = max(max_len, current_len)
+         return max_len
+
+     @staticmethod
+     def _count_distinct_cases(cases: list[Case]) -> int:
+         return len({case.target for case in cases})
+
+     def _analyze_simplified_region(self, region, initial=False):
+         super()._analyze_simplified_region(region, initial=initial)
+         finder = SwitchClusterFinder(region)
+         self._switches_present_in_code = len(finder.var2switches.values())
+
      def _check(self):
          # TODO: More filtering
          return True, None
@@ -161,7 +214,7 @@ class LoweredSwitchSimplifier(OptimizationPass):
          variablehash_to_cases = self._find_cascading_switch_variable_comparisons()

          if not variablehash_to_cases:
-             return
+             return False

          graph_copy = networkx.DiGraph(self._graph)
          self.out_graph = graph_copy
@@ -169,7 +222,39 @@
          for _, caselists in variablehash_to_cases.items():
              for cases, redundant_nodes in caselists:
-                 original_nodes = [case.original_node for case in cases if case.value != "default"]
+                 real_cases = [case for case in cases if case.value != "default"]
+                 max_continuous_cases = self._count_max_continuous_cases(real_cases)
+
+                 # Most compilers follow a few rules about when to lower a switch, which would otherwise be a
+                 # jump table, into either a series of if-trees or a series of bit tests.
+                 #
+                 # RULE 1: A switch is only ever converted into if-stmts if it has fewer continuous cases than
+                 # specified by default_case_values_threshold, so we should never try to revert one with that
+                 # many or more.
+                 # https://github.com/gcc-mirror/gcc/blob/f9a60d575f02822852aa22513c636be38f9c63ea/gcc/tree-switch-conversion.cc#L1406
+                 if max_continuous_cases >= self._max_case_values:
+                     _l.debug("Skipping switch-case conversion due to too many cases for %s", real_cases[0])
+                     continue
+
+                 # RULE 2: A switch is only ever converted into if-stmts if at least one of the cases is not
+                 # continuous.
+                 # https://github.com/gcc-mirror/gcc/blob/f9a60d575f02822852aa22513c636be38f9c63ea/gcc/tree-switch-conversion.cc#L1960
+                 #
+                 # However, we also need to consider the case where the cases we are looking at are a smaller
+                 # cluster split off a non-continuous cluster. In that case, we should still convert them to a
+                 # switch-case iff a switch-case construct is present in the code.
+                 is_all_continuous = max_continuous_cases == len(real_cases)
+                 if is_all_continuous and self._switches_present_in_code == 0:
+                     _l.debug("Skipping switch-case conversion due to all cases being continuous for %s", real_cases[0])
+                     continue
+
+                 # RULE 3: It is not a real cluster if there are not enough distinct cases.
+                 # A distinct case is a case that has a different body of code.
+                 distinct_cases = self._count_distinct_cases(real_cases)
+                 if distinct_cases < self._min_distinct_cases and self._switches_present_in_code == 0:
+                     _l.debug("Skipping switch-case conversion due to too few distinct cases for %s", real_cases[0])
+                     continue
+
+                 original_nodes = [case.original_node for case in real_cases]
                  original_head: Block = original_nodes[0]
                  original_nodes = original_nodes[1:]
                  existing_nodes_by_addr_and_idx = {(nn.addr, nn.idx): nn for nn in graph_copy}
@@ -221,7 +306,7 @@
                  # would result in a successor node no longer being present in the graph
                  if any(onode not in graph_copy for onode in original_nodes):
                      self.out_graph = None
-                     return
+                     return False

                  # add edges between the head and case nodes
                  for onode in original_nodes:
@@ -277,6 +362,8 @@
                      else:
                          graph_copy.add_edge(node_copy, succ)

+         return True
+
      def _find_cascading_switch_variable_comparisons(self):
          sorted_nodes = GraphUtils.quasi_topological_sort_nodes(self._graph)
          variable_comparisons = OrderedDict()
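
To make RULE 1 above concrete, a small standalone illustration of the continuous-run counting that gates the conversion (Case here is a stand-in namedtuple for this sketch, not the decompiler's class):

    from collections import namedtuple

    # stand-in for the decompiler's Case objects; only .value and .target matter here
    Case = namedtuple("Case", ["value", "target"])

    def count_max_continuous_cases(cases):
        """Length of the longest run of consecutive case values, e.g. 1,2,3,7,8 -> 3."""
        values = sorted(c.value for c in cases)
        max_len = current = 1 if values else 0
        for prev, cur in zip(values, values[1:]):
            current = current + 1 if cur == prev + 1 else 1
            max_len = max(max_len, current)
        return max_len

    # GCC's default case_values_threshold is 6: a run of six or more consecutive
    # values would have stayed a jump table, so such an if-tree is not reverted.
    cases = [Case(v, 0x400000 + v) for v in (1, 2, 3, 7, 8)]
    assert count_max_continuous_cases(cases) == 3  # longest run: 1, 2, 3
    assert count_max_continuous_cases(cases) < 6   # RULE 1 permits reverting here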