angr 9.2.111__py3-none-macosx_10_9_x86_64.whl → 9.2.113__py3-none-macosx_10_9_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- angr/__init__.py +1 -1
- angr/analyses/cfg/cfg_base.py +4 -1
- angr/analyses/decompiler/condition_processor.py +9 -2
- angr/analyses/decompiler/optimization_passes/__init__.py +3 -1
- angr/analyses/decompiler/optimization_passes/const_prop_reverter.py +367 -0
- angr/analyses/decompiler/optimization_passes/deadblock_remover.py +1 -1
- angr/analyses/decompiler/optimization_passes/lowered_switch_simplifier.py +99 -12
- angr/analyses/decompiler/optimization_passes/optimization_pass.py +79 -9
- angr/analyses/decompiler/optimization_passes/return_duplicator_base.py +21 -0
- angr/analyses/decompiler/optimization_passes/return_duplicator_low.py +111 -9
- angr/analyses/decompiler/redundant_label_remover.py +17 -0
- angr/analyses/decompiler/seq_cf_structure_counter.py +37 -0
- angr/analyses/decompiler/structured_codegen/c.py +4 -5
- angr/analyses/decompiler/structuring/phoenix.py +3 -3
- angr/analyses/reaching_definitions/rd_state.py +2 -0
- angr/analyses/reaching_definitions/reaching_definitions.py +7 -0
- angr/angrdb/serializers/loader.py +91 -7
- angr/calling_conventions.py +11 -9
- angr/knowledge_plugins/key_definitions/live_definitions.py +5 -0
- angr/knowledge_plugins/propagations/states.py +3 -2
- angr/knowledge_plugins/variables/variable_manager.py +1 -1
- angr/lib/angr_native.dylib +0 -0
- angr/procedures/stubs/ReturnUnconstrained.py +1 -2
- angr/procedures/stubs/syscall_stub.py +1 -2
- angr/sim_type.py +354 -136
- angr/state_plugins/debug_variables.py +2 -2
- angr/state_plugins/solver.py +5 -13
- angr/storage/memory_mixins/multi_value_merger_mixin.py +13 -3
- angr/utils/orderedset.py +70 -0
- angr/vaults.py +0 -1
- {angr-9.2.111.dist-info → angr-9.2.113.dist-info}/METADATA +6 -6
- {angr-9.2.111.dist-info → angr-9.2.113.dist-info}/RECORD +36 -33
- {angr-9.2.111.dist-info → angr-9.2.113.dist-info}/WHEEL +1 -1
- {angr-9.2.111.dist-info → angr-9.2.113.dist-info}/LICENSE +0 -0
- {angr-9.2.111.dist-info → angr-9.2.113.dist-info}/entry_points.txt +0 -0
- {angr-9.2.111.dist-info → angr-9.2.113.dist-info}/top_level.txt +0 -0
angr/__init__.py
CHANGED
angr/analyses/cfg/cfg_base.py
CHANGED
@@ -7,7 +7,6 @@ import networkx
 from sortedcontainers import SortedDict
 
 import pyvex
-from claripy.utils.orderedset import OrderedSet
 from cle import ELF, PE, Blob, TLSObject, MachO, ExternObject, KernelObject, FunctionHintSource, Hex, Coff, SRec, XBE
 from cle.backends import NamedRegion
 import archinfo
@@ -34,6 +33,7 @@ from angr.codenode import HookNode, BlockNode
 from angr.engines.vex.lifter import VEX_IRSB_MAX_SIZE, VEX_IRSB_MAX_INST
 from angr.analyses import Analysis
 from angr.analyses.stack_pointer_tracker import StackPointerTracker
+from angr.utils.orderedset import OrderedSet
 from .indirect_jump_resolvers.default_resolvers import default_indirect_jump_resolvers
 
 if TYPE_CHECKING:
@@ -746,6 +746,9 @@ class CFGBase(Analysis):
         memory_regions = []
 
         for b in binaries:
+            if not b.has_memory:
+                continue
+
             if isinstance(b, ELF):
                 # If we have sections, we get result from sections
                 sections = []
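The guard added above skips loader objects that carry no backing memory before executable regions are collected. A minimal sketch of the same predicate applied over a project's loader; the target binary and the printout are illustrative, not part of the diff:

```python
import angr

proj = angr.Project("/bin/true", auto_load_libs=False)  # illustrative target
for obj in proj.loader.all_objects:
    # Same check the new CFGBase code performs: objects without backing
    # memory contribute no executable regions and are skipped.
    if not obj.has_memory:
        continue
    print(obj)
```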
angr/analyses/decompiler/condition_processor.py
CHANGED
@@ -184,7 +184,12 @@ class ConditionProcessor:
         self.edge_conditions = edge_conditions
 
     def recover_reaching_conditions(
-        self, region, graph=None, with_successors=False, case_entry_to_switch_head: dict[int, int] | None = None
+        self,
+        region,
+        graph=None,
+        with_successors=False,
+        case_entry_to_switch_head: dict[int, int] | None = None,
+        simplify_conditions: bool = True,
     ):
         """
         Recover the reaching conditions for each block in an acyclic graph. Note that we assume the graph that's passed
@@ -255,7 +260,9 @@ class ConditionProcessor:
                 reaching_condition = claripy.Or(claripy.And(pred_condition, edge_condition), reaching_condition)
 
             if reaching_condition is not None:
-                reaching_conditions[node] = self.simplify_condition(reaching_condition)
+                reaching_conditions[node] = (
+                    self.simplify_condition(reaching_condition) if simplify_conditions else reaching_condition
+                )
 
         # My hypothesis: for nodes where two paths come together *and* those that cannot be further structured into
         # another if-else construct (we take the short-cut by testing if the operator is an "Or" after running our
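The new simplify_conditions flag lets a caller skip claripy simplification when conditions are only ever tested for truth; the DeadblockRemover hunk later in this diff does exactly that. A small sketch of that usage pattern, where cond_proc and acyclic_graph stand in for a real ConditionProcessor and its acyclic graph:

```python
import claripy

def find_dead_nodes(cond_proc, acyclic_graph):
    # Skip claripy simplification: we only test each condition for
    # falseness, so the unsimplified expressions suffice (and are cheaper).
    cond_proc.recover_reaching_conditions(region=None, graph=acyclic_graph, simplify_conditions=False)
    return [node for node, cond in cond_proc.reaching_conditions.items() if claripy.is_false(cond)]
```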
angr/analyses/decompiler/optimization_passes/__init__.py
CHANGED
@@ -28,6 +28,7 @@ from .code_motion import CodeMotionOptimization
 from .switch_default_case_duplicator import SwitchDefaultCaseDuplicator
 from .deadblock_remover import DeadblockRemover
 from .inlined_string_transformation_simplifier import InlinedStringTransformationSimplifier
+from .const_prop_reverter import ConstPropOptReverter
 
 # order matters!
 _all_optimization_passes = [
@@ -47,7 +48,8 @@ _all_optimization_passes = [
     (ReturnDuplicatorHigh, True),
     (DeadblockRemover, True),
     (SwitchDefaultCaseDuplicator, True),
-    (
+    (ConstPropOptReverter, True),
+    (LoweredSwitchSimplifier, True),
     (ReturnDuplicatorLow, True),
     (ReturnDeduplicator, True),
     (CodeMotionOptimization, True),
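Because this registry is ordered ("order matters!"), ConstPropOptReverter is slotted right before LoweredSwitchSimplifier, whose case recovery benefits from de-propagated constants. A hedged sketch of inspecting the default pass list; get_default_optimization_passes is the accessor this module exposes, and the arch/platform values are illustrative:

```python
from angr.analyses.decompiler.optimization_passes import get_default_optimization_passes

# Passes enabled by default for an illustrative target, in run order.
for pass_cls in get_default_optimization_passes("AMD64", "linux"):
    print(pass_cls.__name__)
```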
angr/analyses/decompiler/optimization_passes/const_prop_reverter.py
ADDED
@@ -0,0 +1,367 @@
+import logging
+from collections.abc import Callable
+import itertools
+
+import networkx
+import claripy
+from ailment import Const
+from ailment.block_walker import AILBlockWalkerBase
+from ailment.statement import Call, Statement, ConditionalJump, Assignment, Store, Return
+from ailment.expression import Convert, Register
+
+from .optimization_pass import OptimizationPass, OptimizationPassStage
+from ..utils import remove_labels, add_labels
+from ....knowledge_plugins.key_definitions.atoms import MemoryLocation
+from ....knowledge_plugins.key_definitions.constants import OP_BEFORE
+
+
+_l = logging.getLogger(__name__)
+
+
+class PairAILBlockWalker:
+    """
+    This AILBlockWalker will walk two blocks at a time and call a handler for each pair of statements that are
+    instances of the same type. This is useful for comparing two statements for similarity across blocks.
+    """
+
+    def __init__(self, graph: networkx.DiGraph, stmt_pair_handlers=None):
+        self.graph = graph
+
+        _default_stmt_handlers = {
+            Assignment: self._handle_Assignment_pair,
+            Call: self._handle_Call_pair,
+            Store: self._handle_Store_pair,
+            ConditionalJump: self._handle_ConditionalJump_pair,
+            Return: self._handle_Return_pair,
+        }
+
+        self.stmt_pair_handlers: dict[Statement, Callable] = (
+            stmt_pair_handlers if stmt_pair_handlers else _default_stmt_handlers
+        )
+
+    # pylint: disable=no-self-use
+    def _walk_block(self, block):
+        walked_objs = {Assignment: set(), Call: set(), Store: set(), ConditionalJump: set(), Return: set()}
+
+        # create a walker that will:
+        # 1. recursively expand a stmt with the default handler then,
+        # 2. record the stmt parts in the walked_objs dict with the overwritten handler
+        #
+        # CallExpressions are a special case that require a handler in expressions, since they are statements.
+        walker = AILBlockWalkerBase()
+        _default_stmt_handlers = {
+            Assignment: walker._handle_Assignment,
+            Call: walker._handle_Call,
+            Store: walker._handle_Store,
+            ConditionalJump: walker._handle_ConditionalJump,
+            Return: walker._handle_Return,
+        }
+
+        def _handle_ail_obj(stmt_idx, stmt, block_):
+            _default_stmt_handlers[type(stmt)](stmt_idx, stmt, block_)
+            walked_objs[type(stmt)].add(stmt)
+
+        # pylint: disable=unused-argument
+        def _handle_call_expr(expr_idx: int, expr: Call, stmt_idx: int, stmt: Statement, block_):
+            walked_objs[Call].add(expr)
+
+        _stmt_handlers = {typ: _handle_ail_obj for typ in walked_objs}
+        walker.stmt_handlers = _stmt_handlers
+        walker.expr_handlers[Call] = _handle_call_expr
+
+        walker.walk(block)
+        return walked_objs
+
+    def walk(self):
+        for b0, b1 in itertools.combinations(self.graph.nodes, 2):
+            walked_obj_by_blk = {}
+
+            for blk in (b0, b1):
+                walked_obj_by_blk[blk] = self._walk_block(blk)
+
+            for typ, objs0 in walked_obj_by_blk[b0].items():
+                try:
+                    handler = self.stmt_pair_handlers[typ]
+                except KeyError:
+                    continue
+
+                if not objs0:
+                    continue
+
+                objs1 = walked_obj_by_blk[b1][typ]
+                if not objs1:
+                    continue
+
+                for o0 in objs0:
+                    for o1 in objs1:
+                        handler(o0, b0, o1, b1)
+
+    #
+    # default handlers
+    #
+
+    # pylint: disable=unused-argument,no-self-use
+    def _handle_Assignment_pair(self, obj0, blk0, obj1, blk1):
+        return
+
+    # pylint: disable=unused-argument,no-self-use
+    def _handle_Call_pair(self, obj0, blk0, obj1, blk1):
+        return
+
+    # pylint: disable=unused-argument,no-self-use
+    def _handle_Store_pair(self, obj0, blk0, obj1, blk1):
+        return
+
+    # pylint: disable=unused-argument,no-self-use
+    def _handle_ConditionalJump_pair(self, obj0, blk0, obj1, blk1):
+        return
+
+    # pylint: disable=unused-argument,no-self-use
+    def _handle_Return_pair(self, obj0, blk0, obj1, blk1):
+        return
+
+
+class ConstPropOptReverter(OptimizationPass):
+    """
+    This optimization reverts the effects of constant propagation done by the compiler as discussed in the
+    USENIX 2024 paper SAILR. This optimization's main goal is to enable later optimizations that rely on
+    symbolic variables to be more effective. This optimization pass will convert two statements with a difference of
+    a const and a symbolic variable into two statements with the symbolic variables.
+
+    As an example:
+        x = 75
+        puts(x)
+        puts(75)
+
+    will be converted to:
+        x = 75
+        puts(x)
+        puts(x)
+    """
+
+    ARCHES = None
+    PLATFORMS = None
+    STAGE = OptimizationPassStage.DURING_REGION_IDENTIFICATION
+    NAME = "Revert Constant Propagation Optimizations"
+    DESCRIPTION = __doc__.strip()
+
+    def __init__(self, func, region_identifier=None, reaching_definitions=None, **kwargs):
+        self.ri = region_identifier
+        self.rd = reaching_definitions
+        super().__init__(func, **kwargs)
+
+        self._call_pair_targets = []
+        self.resolution = False
+        self.analyze()
+
+    def _check(self):
+        return True, {}
+
+    def _analyze(self, cache=None):
+        self.resolution = False
+        self.out_graph = remove_labels(self._graph)
+        # self.out_graph = self._graph
+
+        _pair_stmt_handlers = {
+            Call: self._handle_Call_pair,
+            Return: self._handle_Return_pair,
+        }
+
+        if self.out_graph is None:
+            return
+
+        walker = PairAILBlockWalker(self.out_graph, stmt_pair_handlers=_pair_stmt_handlers)
+        walker.walk()
+        if self._call_pair_targets:
+            self._analyze_call_pair_targets()
+
+        if not self.resolution:
+            self.out_graph = None
+        else:
+            self.out_graph = add_labels(self.out_graph)
+
+    def _analyze_call_pair_targets(self):
+        all_obs_points = []
+        for _, observation_points in self._call_pair_targets:
+            all_obs_points.extend(observation_points)
+
+        self.rd = self.project.analyses.ReachingDefinitions(subject=self._func, observation_points=all_obs_points)
+
+        for (call0, blk0, call1, blk1, arg_conflicts), _ in self._call_pair_targets:
+            # attempt to do constant resolution for each argument that differs
+            for i, args in arg_conflicts.items():
+                a0, a1 = args[:]
+                calls = {a0: call0, a1: call1}
+                blks = {call0: blk0, call1: blk1}
+
+                # we can only resolve two arguments where one is constant and one is symbolic
+                const_arg = None
+                sym_arg = None
+                for arg in calls:
+                    if isinstance(arg, Const) and const_arg is None:
+                        const_arg = arg
+                    elif not isinstance(arg, Const) and sym_arg is None:
+                        sym_arg = arg
+
+                if const_arg is None or sym_arg is None:
+                    continue
+
+                unwrapped_sym_arg = sym_arg.operands[0] if isinstance(sym_arg, Convert) else sym_arg
+                try:
+                    # TODO: make this support more than just Loads
+                    # target must be a Load of a memory location
+                    target_atom = MemoryLocation(unwrapped_sym_arg.addr.value, unwrapped_sym_arg.size, "Iend_LE")
+                    const_state = self.rd.get_reaching_definitions_by_node(blks[calls[const_arg]].addr, OP_BEFORE)
+
+                    state_load_vals = const_state.get_value_from_atom(target_atom)
+                except AttributeError:
+                    continue
+                except KeyError:
+                    continue
+
+                if not state_load_vals:
+                    continue
+
+                state_vals = list(state_load_vals.values())
+                # the symbolic variable MUST resolve to only a single value
+                if len(state_vals) != 1:
+                    continue
+
+                state_val = list(state_vals[0])[0]
+                if hasattr(state_val, "concrete") and state_val.concrete:
+                    const_value = claripy.Solver().eval(state_val, 1)[0]
+                else:
+                    continue
+
+                if not const_value == const_arg.value:
+                    continue
+
+                _l.debug("Constant argument at position %d was resolved to symbolic arg %s", i, sym_arg)
+                const_call = calls[const_arg]
+                const_arg_i = const_call.args.index(const_arg)
+                const_call.args[const_arg_i] = sym_arg
+                self.resolution = True
+
+    #
+    # Handle Similar Returns
+    #
+
+    def _handle_Return_pair(self, obj0: Return, blk0: Return, obj1, blk1):
+        if obj0 is obj1:
+            return
+
+        rexp0, rexp1 = obj0.ret_exprs, obj1.ret_exprs
+        if rexp0 is None or rexp1 is None or len(rexp0) != len(rexp1):
+            return
+
+        conflicts = {
+            i: ret_exprs
+            for i, ret_exprs in enumerate(zip(rexp0, rexp1))
+            if hasattr(ret_exprs[0], "likes") and not ret_exprs[0].likes(ret_exprs[1])
+        }
+        # only single expr return is supported
+        if len(conflicts) != 1:
+            return
+
+        _, ret_exprs = list(conflicts.items())[0]
+        expr_to_blk = {ret_exprs[0]: blk0, ret_exprs[1]: blk1}
+        # find the expression that is symbolic
+        symb_expr, const_expr = None, None
+        for expr in ret_exprs:
+            unpacked_expr = expr
+            if isinstance(expr, Convert):
+                unpacked_expr = expr.operands[0]
+
+            if isinstance(unpacked_expr, Const):
+                const_expr = expr
+            elif isinstance(unpacked_expr, Call):
+                const_expr = expr
+            else:
+                symb_expr = expr
+
+        if symb_expr is None or const_expr is None:
+            return
+
+        # now we do specific cases for matching
+        if (
+            isinstance(symb_expr, Register)
+            and isinstance(const_expr, Call)
+            and isinstance(const_expr.ret_expr, Register)
+        ):
+            # Handles the following case
+            # B0:
+            #   return foo(); // considered constant
+            # B1:
+            #   return rax; // considered symbolic
+            #
+            # =>
+            #
+            # B0:
+            #   rax = foo();
+            #   return rax;
+            # B1:
+            #   return rax;
+            #
+            # This is useful later for merging the return.
+            #
+            call_return_reg = const_expr.ret_expr
+            if symb_expr.likes(call_return_reg):
+                symb_return_stmt = expr_to_blk[symb_expr].statements[-1]
+                const_block = expr_to_blk[const_expr]
+
+                # rax = foo();
+                reg_assign = Assignment(None, symb_expr, const_expr, **const_expr.tags)
+
+                # construct new constant block
+                new_const_block = const_block.copy()
+                new_const_block.statements = new_const_block.statements[:-1] + [reg_assign] + [symb_return_stmt.copy()]
+                self._update_block(const_block, new_const_block)
+                self.resolution = True
+            else:
+                _l.debug("This case is not supported yet for Return de-propagation")
+
+    #
+    # Handle Similar Calls
+    #
+
+    def _handle_Call_pair(self, obj0: Call, blk0, obj1: Call, blk1):
+        if obj0 is obj1:
+            return
+
+        # verify both calls are calls to the same function
+        if (isinstance(obj0.target, str) or isinstance(obj1.target, str)) and obj0.target != obj1.target:
+            return
+        elif not obj0.target.likes(obj1.target):
+            return
+
+        call0, call1 = obj0, obj1
+        arg_conflicts = self.find_conflicting_call_args(call0, call1)
+        # if there is no conflict, then there is nothing to fix
+        if not arg_conflicts:
+            return
+
+        _l.debug(
+            "Found two calls at (%x, %x) that are similar. Attempting to resolve const args now...",
+            blk0.addr,
+            blk1.addr,
+        )
+
+        # destroy old ReachDefs, since we need a new one
+        observation_points = ("node", blk0.addr, OP_BEFORE), ("node", blk1.addr, OP_BEFORE)
+
+        # do full analysis after collecting all calls in _analyze
+        self._call_pair_targets.append(((call0, blk0, call1, blk1, arg_conflicts), observation_points))
+
+    @staticmethod
+    def find_conflicting_call_args(call0: Call, call1: Call):
+        if not call0.args or not call1.args:
+            return None
+
+        # TODO: update this to work for variable-arg functions
+        if len(call0.args) != len(call1.args):
+            return None
+
+        # zip args of call 0 and 1 conflict if they are not like each other
+        conflicts = {i: args for i, args in enumerate(zip(call0.args, call1.args)) if not args[0].likes(args[1])}
+
+        return conflicts
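For context, a sketch of how these passes get exercised end to end through the decompiler, which picks up the default pass list (now including ConstPropOptReverter) when none is given explicitly; the binary path and function name are illustrative:

```python
import angr

proj = angr.Project("/path/to/binary", auto_load_libs=False)  # illustrative path
cfg = proj.analyses.CFGFast(normalize=True)
func = proj.kb.functions["main"]  # illustrative function
# With no explicit optimization-pass list, the defaults for the
# target's arch/platform are applied during decompilation.
dec = proj.analyses.Decompiler(func, cfg=cfg.model)
print(dec.codegen.text)
```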
angr/analyses/decompiler/optimization_passes/deadblock_remover.py
CHANGED
@@ -36,7 +36,7 @@ class DeadblockRemover(OptimizationPass):
            acyclic_graph = self._graph
        else:
            acyclic_graph = to_acyclic_graph(self._graph)
-       cond_proc.recover_reaching_conditions(region=None, graph=acyclic_graph)
+       cond_proc.recover_reaching_conditions(region=None, graph=acyclic_graph, simplify_conditions=False)
 
        if not any(claripy.is_false(c) for c in cond_proc.reaching_conditions.values()):
            return False, None
angr/analyses/decompiler/optimization_passes/lowered_switch_simplifier.py
CHANGED
@@ -11,7 +11,8 @@ from ailment.expression import Expression, BinaryOp, Const, Load
 from angr.utils.graph import GraphUtils
 from ..utils import first_nonlabel_statement, remove_last_statement
 from ..structuring.structurer_nodes import IncompleteSwitchCaseHeadStatement, SequenceNode, MultiNode
-from .optimization_pass import OptimizationPass, OptimizationPassStage, MultipleBlocksException
+from .optimization_pass import OptimizationPassStage, MultipleBlocksException, StructuringOptimizationPass
+from ..region_simplifiers.switch_cluster_simplifier import SwitchClusterFinder
 
 if TYPE_CHECKING:
     from ailment.expression import UnaryOp, Convert
@@ -130,15 +131,19 @@ class StableVarExprHasher(AILBlockWalkerBase):
         super()._handle_Convert(expr_idx, expr, stmt_idx, stmt, block)
 
 
-class LoweredSwitchSimplifier(OptimizationPass):
+class LoweredSwitchSimplifier(StructuringOptimizationPass):
     """
-
+    This optimization recognizes and reverts switch cases that have been lowered and possibly split into multiple
+    if-else statements. This optimization, discussed in the USENIX 2024 paper SAILR, aims to undo the compiler
+    optimization known as "Switch Lowering", present in both GCC and Clang. An in-depth discussion of this
+    optimization can be found in the paper or in our documentation of the optimization:
+    https://github.com/mahaloz/sailr-eval/issues/14#issue-2232616411
+
+    Note, this optimization does not occur in MSVC, which uses a different optimization strategy for switch cases.
+    As a hack for now, we only run this deoptimization on Linux binaries.
     """
 
-    ARCHES = [
-        "AMD64",
-    ]
-    PLATFORMS = ["linux", "windows"]
+    PLATFORMS = ["linux"]
     STAGE = OptimizationPassStage.DURING_REGION_IDENTIFICATION
     NAME = "Convert lowered switch-cases (if-else) to switch-cases"
     DESCRIPTION = (
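For intuition, "switch lowering" turns a dense switch into a cascade of comparisons on one variable; a toy Python rendering of the lowered shape the pass recognizes (names are invented, and the real pass operates on AIL graphs, not source code):

```python
def handle_one() -> str:
    return "one"

def handle_two() -> str:
    return "two"

def handle_default() -> str:
    return "default"

def dispatch(x: int) -> str:
    # The lowered form: a chain of equality tests against the same
    # variable, which LoweredSwitchSimplifier folds back into the
    # switch-case structure the original source most likely had.
    if x == 1:
        return handle_one()
    if x == 2:
        return handle_two()
    return handle_default()
```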
@@ -147,12 +152,60 @@ class LoweredSwitchSimplifier(OptimizationPass):
     )
     STRUCTURING = ["phoenix"]
 
-    def __init__(self, func,
+    def __init__(self, func, min_distinct_cases=2, **kwargs):
         super().__init__(
-            func,
+            func,
+            require_gotos=False,
+            prevent_new_gotos=False,
+            simplify_ail=False,
+            must_improve_rel_quality=True,
+            **kwargs,
         )
+
+        # this is the max number of cases that can be in a switch that can be converted to a
+        # if-tree (if the number of cases is greater than this, the switch will not be converted)
+        # https://github.com/gcc-mirror/gcc/blob/f9a60d575f02822852aa22513c636be38f9c63ea/gcc/targhooks.cc#L1899
+        # TODO: add architecture specific values
+        default_case_values_threshold = 6
+        # NOTE: this means that there must be less than default_case_values for us to convert an if-tree to a switch
+        self._max_case_values = default_case_values_threshold
+
+        self._min_distinct_cases = min_distinct_cases
+
+        # used to determine if a switch-case construct is present in the code, useful for invalidating
+        # other heuristics that minimize false positives
+        self._switches_present_in_code = 0
+
         self.analyze()
 
+    @staticmethod
+    def _count_max_continuous_cases(cases: list[Case]) -> int:
+        if not cases:  # Return 0 if the list is empty
+            return 0
+
+        max_len = 0
+        current_len = 1  # Start with 1 since a single number is a sequence of length 1
+        sorted_cases = sorted(cases, key=lambda c: c.value)
+        for i in range(1, len(sorted_cases)):
+            if sorted_cases[i].value == sorted_cases[i - 1].value + 1:
+                current_len += 1
+            else:
+                max_len = max(max_len, current_len)
+                current_len = 1
+
+        # Final check to include the last sequence
+        max_len = max(max_len, current_len)
+        return max_len
+
+    @staticmethod
+    def _count_distinct_cases(cases: list[Case]) -> int:
+        return len({case.target for case in cases})
+
+    def _analyze_simplified_region(self, region, initial=False):
+        super()._analyze_simplified_region(region, initial=initial)
+        finder = SwitchClusterFinder(region)
+        self._switches_present_in_code = len(finder.var2switches.values())
+
     def _check(self):
         # TODO: More filtering
         return True, None
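A worked example of the run-length count above, with made-up case values: {1, 2, 3, 7, 8} contains the runs [1, 2, 3] and [7, 8], so the maximum continuous run is 3, under the GCC-derived threshold of 6. A standalone sketch of the same counting logic over raw integers:

```python
def max_continuous_run(values: list[int]) -> int:
    # Same idea as _count_max_continuous_cases, applied to plain ints.
    if not values:
        return 0
    vals = sorted(values)
    max_len = cur = 1
    for prev, nxt in zip(vals, vals[1:]):
        cur = cur + 1 if nxt == prev + 1 else 1
        max_len = max(max_len, cur)
    return max_len

assert max_continuous_run([7, 1, 8, 2, 3]) == 3
```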
@@ -161,7 +214,7 @@ class LoweredSwitchSimplifier(OptimizationPass):
         variablehash_to_cases = self._find_cascading_switch_variable_comparisons()
 
         if not variablehash_to_cases:
-            return
+            return False
 
         graph_copy = networkx.DiGraph(self._graph)
         self.out_graph = graph_copy
@@ -169,7 +222,39 @@ class LoweredSwitchSimplifier(OptimizationPass):
 
         for _, caselists in variablehash_to_cases.items():
             for cases, redundant_nodes in caselists:
-                original_nodes = [case.original_node for case in cases]
+                real_cases = [case for case in cases if case.value != "default"]
+                max_continuous_cases = self._count_max_continuous_cases(real_cases)
+
+                # There are a few rules used in most compilers about when to lower a switch that would otherwise
+                # be a jump table into either a series of if-trees or into series of bit tests.
+                #
+                # RULE 1: You only ever convert a Switch into if-stmts if there are less continuous cases
+                # then specified by the default_case_values_threshold, therefore we should never try to rever it
+                # if there is more or equal than that.
+                # https://github.com/gcc-mirror/gcc/blob/f9a60d575f02822852aa22513c636be38f9c63ea/gcc/tree-switch-conversion.cc#L1406
+                if max_continuous_cases >= self._max_case_values:
+                    _l.debug("Skipping switch-case conversion due to too many cases for %s", real_cases[0])
+                    continue
+
+                # RULE 2: You only ever convert a Switch into if-stmts if at least one of the cases is not continuous.
+                # https://github.com/gcc-mirror/gcc/blob/f9a60d575f02822852aa22513c636be38f9c63ea/gcc/tree-switch-conversion.cc#L1960
+                #
+                # However, we need to also consider the case where the cases we are looking at are currently a smaller
+                # cluster split off a non-continuous cluster. In this case, we should still convert it to a switch-case
+                # iff a switch-case construct is present in the code.
+                is_all_continuous = max_continuous_cases == len(real_cases)
+                if is_all_continuous and self._switches_present_in_code == 0:
+                    _l.debug("Skipping switch-case conversion due to all cases being continuous for %s", real_cases[0])
+                    continue
+
+                # RULE 3: It is not a real cluster if there are not enough distinct cases.
+                # A distinct case is a case that has a different body of code.
+                distinct_cases = self._count_distinct_cases(real_cases)
+                if distinct_cases < self._min_distinct_cases and self._switches_present_in_code == 0:
+                    _l.debug("Skipping switch-case conversion due to too few distinct cases for %s", real_cases[0])
+                    continue
+
+                original_nodes = [case.original_node for case in real_cases]
                 original_head: Block = original_nodes[0]
                 original_nodes = original_nodes[1:]
                 existing_nodes_by_addr_and_idx = {(nn.addr, nn.idx): nn for nn in graph_copy}
@@ -221,7 +306,7 @@ class LoweredSwitchSimplifier(OptimizationPass):
             # would result in a successor node no longer being present in the graph
             if any(onode not in graph_copy for onode in original_nodes):
                 self.out_graph = None
-                return
+                return False
 
             # add edges between the head and case nodes
             for onode in original_nodes:
@@ -277,6 +362,8 @@ class LoweredSwitchSimplifier(OptimizationPass):
                 else:
                     graph_copy.add_edge(node_copy, succ)
 
+        return True
+
     def _find_cascading_switch_variable_comparisons(self):
         sorted_nodes = GraphUtils.quasi_topological_sort_nodes(self._graph)
         variable_comparisons = OrderedDict()