angr: angr-9.2.64-py3-none-win_amd64.whl → angr-9.2.66-py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of angr might be problematic; further details are available on the registry's page for this release.

Files changed (45)
  1. angr/__init__.py +55 -2
  2. angr/analyses/calling_convention.py +4 -3
  3. angr/analyses/cfg/cfg_base.py +2 -2
  4. angr/analyses/cfg/cfg_fast.py +128 -60
  5. angr/analyses/decompiler/ail_simplifier.py +1 -2
  6. angr/analyses/decompiler/block_simplifier.py +4 -3
  7. angr/analyses/decompiler/callsite_maker.py +1 -1
  8. angr/analyses/decompiler/condition_processor.py +5 -3
  9. angr/analyses/decompiler/optimization_passes/flip_boolean_cmp.py +51 -8
  10. angr/analyses/decompiler/peephole_optimizations/__init__.py +1 -1
  11. angr/analyses/decompiler/peephole_optimizations/const_mull_a_shift.py +92 -0
  12. angr/analyses/decompiler/structured_codegen/c.py +59 -6
  13. angr/analyses/decompiler/utils.py +1 -1
  14. angr/analyses/find_objects_static.py +4 -4
  15. angr/analyses/propagator/engine_ail.py +2 -1
  16. angr/analyses/reaching_definitions/__init__.py +1 -3
  17. angr/analyses/reaching_definitions/dep_graph.py +33 -4
  18. angr/analyses/reaching_definitions/engine_ail.py +5 -6
  19. angr/analyses/reaching_definitions/engine_vex.py +6 -7
  20. angr/analyses/reaching_definitions/external_codeloc.py +0 -27
  21. angr/analyses/reaching_definitions/function_handler.py +145 -23
  22. angr/analyses/reaching_definitions/rd_initializer.py +221 -0
  23. angr/analyses/reaching_definitions/rd_state.py +95 -153
  24. angr/analyses/reaching_definitions/reaching_definitions.py +15 -3
  25. angr/calling_conventions.py +2 -2
  26. angr/code_location.py +24 -0
  27. angr/exploration_techniques/__init__.py +28 -0
  28. angr/knowledge_plugins/cfg/cfg_model.py +1 -1
  29. angr/knowledge_plugins/key_definitions/__init__.py +12 -1
  30. angr/knowledge_plugins/key_definitions/atoms.py +9 -0
  31. angr/knowledge_plugins/key_definitions/definition.py +13 -18
  32. angr/knowledge_plugins/key_definitions/live_definitions.py +350 -106
  33. angr/lib/angr_native.dll +0 -0
  34. angr/project.py +1 -1
  35. angr/sim_manager.py +15 -0
  36. angr/sim_state.py +3 -3
  37. angr/storage/memory_mixins/paged_memory/pages/multi_values.py +56 -8
  38. angr/storage/memory_object.py +3 -1
  39. angr/utils/typing.py +16 -0
  40. {angr-9.2.64.dist-info → angr-9.2.66.dist-info}/METADATA +8 -8
  41. {angr-9.2.64.dist-info → angr-9.2.66.dist-info}/RECORD +44 -42
  42. {angr-9.2.64.dist-info → angr-9.2.66.dist-info}/WHEEL +1 -1
  43. angr/analyses/decompiler/peephole_optimizations/conv_const_mull_a_shift.py +0 -75
  44. {angr-9.2.64.dist-info → angr-9.2.66.dist-info}/LICENSE +0 -0
  45. {angr-9.2.64.dist-info → angr-9.2.66.dist-info}/top_level.txt +0 -0
angr/__init__.py CHANGED
@@ -1,7 +1,7 @@
1
1
  # pylint: disable=wildcard-import
2
2
  # pylint: disable=wrong-import-position
3
3
 
4
- __version__ = "9.2.64"
4
+ __version__ = "9.2.66"
5
5
 
6
6
  if bytes is str:
7
7
  raise Exception(
@@ -57,7 +57,7 @@ from .state_plugins.inspect import BP_BEFORE, BP_AFTER, BP_BOTH, BP_IPDB, BP_IPY
57
57
  from .state_plugins.inspect import BP
58
58
  from .state_plugins import SimStatePlugin
59
59
 
60
- from .project import *
60
+ from .project import Project, load_shellcode
61
61
  from .errors import *
62
62
  from .blade import Blade
63
63
  from .simos import SimOS
@@ -87,9 +87,62 @@ from .state_plugins.filesystem import SimMount, SimHostFilesystem
87
87
  from .state_plugins.heap import SimHeapBrk, SimHeapPTMalloc, PTChunk
88
88
  from . import concretization_strategies
89
89
  from .distributed import Server
90
+ from .knowledge_base import KnowledgeBase
90
91
 
91
92
  # for compatibility reasons
92
93
  from . import sim_manager as manager
93
94
 
94
95
  # now that we have everything loaded, re-grab the list of loggers
95
96
  loggers.load_all_loggers()
97
+
98
+ __all__ = (
99
+ "SimProcedure",
100
+ "SIM_PROCEDURES",
101
+ "SIM_LIBRARIES",
102
+ "sim_options",
103
+ "options",
104
+ "BP_BEFORE",
105
+ "BP_AFTER",
106
+ "BP_BOTH",
107
+ "BP_IPDB",
108
+ "BP_IPYTHON",
109
+ "BP",
110
+ "SimStatePlugin",
111
+ "Project",
112
+ "load_shellcode",
113
+ "Blade",
114
+ "SimOS",
115
+ "Block",
116
+ "SimulationManager",
117
+ "Analysis",
118
+ "register_analysis",
119
+ "analyses",
120
+ "knowledge_plugins",
121
+ "exploration_techniques",
122
+ "ExplorationTechnique",
123
+ "types",
124
+ "StateHierarchy",
125
+ "SimState",
126
+ "engines",
127
+ "DEFAULT_CC",
128
+ "SYSCALL_CC",
129
+ "PointerWrapper",
130
+ "SimCC",
131
+ "SimFileBase",
132
+ "SimFile",
133
+ "SimPackets",
134
+ "SimFileStream",
135
+ "SimPacketsStream",
136
+ "SimFileDescriptor",
137
+ "SimFileDescriptorDuplex",
138
+ "SimMount",
139
+ "SimHostFilesystem",
140
+ "SimHeapBrk",
141
+ "SimHeapPTMalloc",
142
+ "PTChunk",
143
+ "concretization_strategies",
144
+ "Server",
145
+ "manager",
146
+ "SimProcedures",
147
+ "KnowledgeBase",
148
+ )
@@ -10,6 +10,8 @@ from pyvex.expr import RdTmp
10
10
  from archinfo.arch_arm import is_arm_arch, ArchARMHF
11
11
  import ailment
12
12
 
13
+ from angr.code_location import ExternalCodeLocation
14
+
13
15
  from ..calling_conventions import SimFunctionArgument, SimRegArg, SimStackArg, SimCC, default_cc
14
16
  from ..sim_type import (
15
17
  SimTypeInt,
@@ -32,7 +34,6 @@ from ..knowledge_plugins.functions import Function
32
34
  from ..utils.constants import DEFAULT_STATEMENT
33
35
  from .. import SIM_PROCEDURES
34
36
  from .reaching_definitions import get_all_definitions
35
- from .reaching_definitions.external_codeloc import ExternalCodeLocation
36
37
  from . import Analysis, register_analysis, ReachingDefinitionsAnalysis
37
38
 
38
39
  if TYPE_CHECKING:
@@ -526,8 +527,8 @@ class CallingConventionAnalysis(Analysis):
526
527
  # determine if potential register and stack arguments are set
527
528
  state = rda.observed_results[("insn", call_insn_addr, OP_BEFORE)]
528
529
  defs_by_reg_offset: Dict[int, List["Definition"]] = defaultdict(list)
529
- all_reg_defs: Set["Definition"] = get_all_definitions(state.register_definitions)
530
- all_stack_defs: Set["Definition"] = get_all_definitions(state.stack_definitions)
530
+ all_reg_defs: Set["Definition"] = get_all_definitions(state.registers)
531
+ all_stack_defs: Set["Definition"] = get_all_definitions(state.stack)
531
532
  for d in all_reg_defs:
532
533
  if (
533
534
  isinstance(d.atom, Register)
@@ -8,7 +8,7 @@ from sortedcontainers import SortedDict
8
8
 
9
9
  import pyvex
10
10
  from claripy.utils.orderedset import OrderedSet
11
- from cle import ELF, PE, Blob, TLSObject, MachO, ExternObject, KernelObject, FunctionHintSource, Hex
11
+ from cle import ELF, PE, Blob, TLSObject, MachO, ExternObject, KernelObject, FunctionHintSource, Hex, Coff
12
12
  from cle.backends import NamedRegion
13
13
  import archinfo
14
14
  from archinfo.arch_soot import SootAddressDescriptor
@@ -759,7 +759,7 @@ class CFGBase(Analysis):
759
759
  tpl = (segment.min_addr, segment.max_addr + 1)
760
760
  memory_regions.append(tpl)
761
761
 
762
- elif isinstance(b, PE):
762
+ elif isinstance(b, (Coff, PE)):
763
763
  for section in b.sections:
764
764
  if section.is_executable:
765
765
  tpl = (section.min_addr, section.max_addr + 1)
@@ -4,12 +4,13 @@ import logging
4
4
  import math
5
5
  import re
6
6
  import string
7
- from typing import DefaultDict, List, Set, Dict, Optional
7
+ from typing import DefaultDict, List, Set, Dict, Optional, Tuple
8
8
  from collections import defaultdict, OrderedDict
9
9
  from enum import Enum, unique
10
10
 
11
11
  import networkx
12
12
  from sortedcontainers import SortedDict
13
+ import capstone
13
14
 
14
15
  import claripy
15
16
  import cle
@@ -3845,6 +3846,84 @@ class CFGFast(ForwardAnalysis[CFGNode, CFGNode, CFGJob, int], CFGBase): # pylin
3845
3846
  pyvex.pvc.deregister_all_readonly_regions()
3846
3847
  self._ro_region_cdata_cache = None
3847
3848
 
3849
+ #
3850
+ # Initial registers
3851
+ #
3852
+
3853
+ def _get_initial_registers(self, addr, cfg_job, current_function_addr) -> Optional[List[Tuple[int, int, int]]]:
3854
+ initial_regs = None
3855
+ if self.project.arch.name in {"MIPS64", "MIPS32"}:
3856
+ initial_regs = [
3857
+ (
3858
+ self.project.arch.registers["t9"][0],
3859
+ self.project.arch.registers["t9"][1],
3860
+ current_function_addr,
3861
+ )
3862
+ ]
3863
+ if self.kb.functions.contains_addr(current_function_addr):
3864
+ func = self.kb.functions.get_by_addr(current_function_addr)
3865
+ if "gp" in func.info:
3866
+ initial_regs.append(
3867
+ (
3868
+ self.project.arch.registers["gp"][0],
3869
+ self.project.arch.registers["gp"][1],
3870
+ func.info["gp"],
3871
+ )
3872
+ )
3873
+ elif self.project.arch.name == "X86":
3874
+ # for x86 GCC-generated PIE binaries, detect calls to __x86.get_pc_thunk
3875
+ if (
3876
+ cfg_job.jumpkind == "Ijk_FakeRet"
3877
+ and cfg_job.returning_source is not None
3878
+ and self.kb.functions.contains_addr(cfg_job.returning_source)
3879
+ ):
3880
+ return_from_func = self.kb.functions.get_by_addr(cfg_job.returning_source)
3881
+ if "get_pc" in return_from_func.info:
3882
+ func = self.kb.functions.get_by_addr(current_function_addr)
3883
+ pc_reg = return_from_func.info["get_pc"]
3884
+ # the crazy thing is that GCC-generated code may adjust the register value accordingly after
3885
+ # returning! we must take into account the added offset (in the followin example, 0x8d36)
3886
+ #
3887
+ # e.g.
3888
+ # 000011A1 call __x86_get_pc_thunk_bx
3889
+ # 000011A6 add ebx, 8D36h
3890
+ #
3891
+ # this means, for the current block, the initial value of ebx is whatever __x86_get_pc_thunk_bx
3892
+ # returns. for future blocks in this function, the initial value of ebx must be the returning
3893
+ # value plus 0x8d36.
3894
+ pc_reg_offset, pc_reg_size = self.project.arch.registers[pc_reg]
3895
+ initial_regs = [(pc_reg_offset, pc_reg_size, addr)]
3896
+ # find adjustment
3897
+ adjustment = self._x86_gcc_pie_find_pc_register_adjustment(addr, pc_reg_offset)
3898
+ if adjustment is not None:
3899
+ func.info["pc_reg"] = (pc_reg, addr + adjustment)
3900
+ else:
3901
+ func.info["pc_reg"] = (pc_reg, addr)
3902
+ if self.kb.functions.contains_addr(current_function_addr):
3903
+ func = self.kb.functions.get_by_addr(current_function_addr)
3904
+ if not initial_regs and "pc_reg" in func.info:
3905
+ pc_reg, pc_reg_value = func.info["pc_reg"]
3906
+ initial_regs = [
3907
+ (
3908
+ self.project.arch.registers[pc_reg][0],
3909
+ self.project.arch.registers[pc_reg][1],
3910
+ pc_reg_value,
3911
+ )
3912
+ ]
3913
+ elif is_arm_arch(self.project.arch):
3914
+ if addr != current_function_addr and self.kb.functions.contains_addr(current_function_addr):
3915
+ func = self.kb.functions.get_by_addr(current_function_addr)
3916
+ if "constant_r4" in func.info:
3917
+ initial_regs = [
3918
+ (
3919
+ self.project.arch.registers["r4"][0],
3920
+ self.project.arch.registers["r4"][1],
3921
+ func.info["constant_r4"],
3922
+ )
3923
+ ]
3924
+
3925
+ return initial_regs
3926
+
3848
3927
  #
3849
3928
  # Other methods
3850
3929
  #
@@ -4001,65 +4080,7 @@ class CFGFast(ForwardAnalysis[CFGNode, CFGNode, CFGJob, int], CFGBase): # pylin
4001
4080
  self._cascading_remove_lifted_blocks(cfg_job.src_node.addr & 0xFFFF_FFFE)
4002
4081
  return None, None, None, None
4003
4082
 
4004
- initial_regs = None
4005
- if self.project.arch.name in {"MIPS64", "MIPS32"}:
4006
- initial_regs = [
4007
- (
4008
- self.project.arch.registers["t9"][0],
4009
- self.project.arch.registers["t9"][1],
4010
- current_function_addr,
4011
- )
4012
- ]
4013
- if self.kb.functions.contains_addr(current_function_addr):
4014
- func = self.kb.functions.get_by_addr(current_function_addr)
4015
- if "gp" in func.info:
4016
- initial_regs.append(
4017
- (
4018
- self.project.arch.registers["gp"][0],
4019
- self.project.arch.registers["gp"][1],
4020
- func.info["gp"],
4021
- )
4022
- )
4023
- elif self.project.arch.name == "X86":
4024
- # for x86 GCC-generated PIE binaries, detect calls to __x86.get_pc_thunk
4025
- if (
4026
- cfg_job.jumpkind == "Ijk_FakeRet"
4027
- and cfg_job.returning_source is not None
4028
- and self.kb.functions.contains_addr(cfg_job.returning_source)
4029
- ):
4030
- return_from_func = self.kb.functions.get_by_addr(cfg_job.returning_source)
4031
- if "get_pc" in return_from_func.info:
4032
- func = self.kb.functions.get_by_addr(current_function_addr)
4033
- pc_reg = return_from_func.info["get_pc"]
4034
- # the crazy thing is that GCC-generated code may adjust the register value accordingly after
4035
- # returning! we must take into account the added offset (in the followin example, 0x8d36)
4036
- #
4037
- # e.g.
4038
- # 000011A1 call __x86_get_pc_thunk_bx
4039
- # 000011A6 add ebx, 8D36h
4040
- #
4041
- # this means, for the current block, the initial value of ebx is whatever __x86_get_pc_thunk_bx
4042
- # returns. for future blocks in this function, the initial value of ebx must be the returning
4043
- # value plus 0x8d36.
4044
- pc_reg_offset, pc_reg_size = self.project.arch.registers[pc_reg]
4045
- initial_regs = [(pc_reg_offset, pc_reg_size, addr)]
4046
- # find adjustment
4047
- adjustment = self._x86_gcc_pie_find_pc_register_adjustment(addr, pc_reg_offset)
4048
- if adjustment is not None:
4049
- func.info["pc_reg"] = (pc_reg, addr + adjustment)
4050
- else:
4051
- func.info["pc_reg"] = (pc_reg, addr)
4052
- if self.kb.functions.contains_addr(current_function_addr):
4053
- func = self.kb.functions.get_by_addr(current_function_addr)
4054
- if not initial_regs and "pc_reg" in func.info:
4055
- pc_reg, pc_reg_value = func.info["pc_reg"]
4056
- initial_regs = [
4057
- (
4058
- self.project.arch.registers[pc_reg][0],
4059
- self.project.arch.registers[pc_reg][1],
4060
- pc_reg_value,
4061
- )
4062
- ]
4083
+ initial_regs = self._get_initial_registers(addr, cfg_job, current_function_addr)
4063
4084
 
4064
4085
  # Let's try to create the pyvex IRSB directly, since it's much faster
4065
4086
  nodecode = False
@@ -4367,6 +4388,53 @@ class CFGFast(ForwardAnalysis[CFGNode, CFGNode, CFGJob, int], CFGBase): # pylin
4367
4388
  self._insert_job(job)
4368
4389
  added_addrs.add(ref.data_addr)
4369
4390
 
4391
+ # detect if there are instructions that set r4 as a constant value
4392
+ if (addr & 1) == 0 and addr == func_addr and irsb.size > 0:
4393
+ # re-lift the block to get capstone access
4394
+ lifted_block = self._lift(irsb.addr, size=irsb.size, collect_data_refs=False, strict_block_end=True)
4395
+ for i in range(len(lifted_block.capstone.insns) - 1):
4396
+ insn0 = lifted_block.capstone.insns[i]
4397
+ insn1 = lifted_block.capstone.insns[i + 1]
4398
+ matched_0 = False
4399
+ matched_1 = False
4400
+ reg_dst = None
4401
+ pc_offset = None
4402
+ if insn0.mnemonic == "ldr" and len(insn0.operands) == 2:
4403
+ op0, op1 = insn0.operands
4404
+ if (
4405
+ op0.type == capstone.arm.ARM_OP_REG
4406
+ and op0.value.reg == capstone.arm.ARM_REG_R4
4407
+ and op1.type == capstone.arm.ARM_OP_MEM
4408
+ and op1.mem.base == capstone.arm.ARM_REG_PC
4409
+ and op1.mem.disp > 0
4410
+ and op1.mem.index == 0
4411
+ ):
4412
+ # ldr r4, [pc, #N]
4413
+ matched_0 = True
4414
+ reg_dst = op0.value.reg
4415
+ pc_offset = op1.value.mem.disp
4416
+ if matched_0 and insn1.mnemonic == "add" and len(insn1.operands) == 3:
4417
+ op0, op1, op2 = insn1.operands
4418
+ if (
4419
+ op0.type == capstone.arm.ARM_OP_REG
4420
+ and op0.value.reg == reg_dst
4421
+ and op1.type == capstone.arm.ARM_OP_REG
4422
+ and op1.value.reg == capstone.arm.ARM_REG_PC
4423
+ and op2.type == capstone.arm.ARM_OP_REG
4424
+ and op2.value.reg == reg_dst
4425
+ ):
4426
+ # add r4, pc, r4
4427
+ matched_1 = True
4428
+
4429
+ if matched_1:
4430
+ r4 = self.project.loader.fast_memory_load_pointer(insn0.address + 4 * 2 + pc_offset, 4)
4431
+ if r4 is not None:
4432
+ r4 += insn1.address + 4 * 2
4433
+ r4 &= 0xFFFF_FFFF
4434
+ func = self.kb.functions.get_by_addr(func_addr)
4435
+ func.info["constant_r4"] = r4
4436
+ break
4437
+
4370
4438
  elif self.project.arch.name in {"MIPS32", "MIPS64"}:
4371
4439
  func = self.kb.functions.get_by_addr(func_addr)
4372
4440
  if "gp" not in func.info and addr >= func_addr and addr - func_addr < 15 * 4:
@@ -19,8 +19,7 @@ from ailment.expression import (
19
19
  )
20
20
 
21
21
  from ...engines.light import SpOffset
22
- from ...code_location import CodeLocation
23
- from ...analyses.reaching_definitions.external_codeloc import ExternalCodeLocation
22
+ from ...code_location import CodeLocation, ExternalCodeLocation
24
23
  from ...sim_variable import SimStackVariable, SimMemoryVariable
25
24
  from ...knowledge_plugins.propagations.states import Equivalence
26
25
  from ...knowledge_plugins.key_definitions import atoms
@@ -6,10 +6,11 @@ from ailment.statement import Statement, Assignment, Call, Store, Jump
6
6
  from ailment.expression import Expression, Tmp, Load, Const, Register, Convert
7
7
  from ailment import AILBlockWalker
8
8
 
9
+ from angr.code_location import ExternalCodeLocation
10
+
9
11
  from ...engines.light.data import SpOffset
10
12
  from ...knowledge_plugins.key_definitions.constants import OP_AFTER
11
13
  from ...knowledge_plugins.key_definitions import atoms
12
- from ...analyses.reaching_definitions.external_codeloc import ExternalCodeLocation
13
14
  from ...analyses.propagator import PropagatorAnalysis
14
15
  from ...analyses.reaching_definitions import ReachingDefinitionsAnalysis
15
16
  from ...errors import SimMemoryMissingError
@@ -300,13 +301,13 @@ class BlockSimplifier(Analysis):
300
301
  defs_ = set()
301
302
  if isinstance(d.atom, atoms.Register):
302
303
  try:
303
- vs: "MultiValues" = live_defs.register_definitions.load(d.atom.reg_offset, size=d.atom.size)
304
+ vs: "MultiValues" = live_defs.registers.load(d.atom.reg_offset, size=d.atom.size)
304
305
  except SimMemoryMissingError:
305
306
  vs = None
306
307
  elif isinstance(d.atom, atoms.MemoryLocation) and isinstance(d.atom.addr, SpOffset):
307
308
  stack_addr = live_defs.stack_offset_to_stack_addr(d.atom.addr.offset)
308
309
  try:
309
- vs: "MultiValues" = live_defs.stack_definitions.load(
310
+ vs: "MultiValues" = live_defs.stack.load(
310
311
  stack_addr, size=d.atom.size, endness=d.atom.endness
311
312
  )
312
313
  except SimMemoryMissingError:
@@ -243,7 +243,7 @@ class CallSiteMaker(Analysis):
243
243
  return set()
244
244
 
245
245
  try:
246
- vs: "MultiValues" = rd.register_definitions.load(offset, size=size)
246
+ vs: "MultiValues" = rd.registers.load(offset, size=size)
247
247
  except SimMemoryMissingError:
248
248
  return set()
249
249
  values_and_defs_ = set()
@@ -68,8 +68,8 @@ def _op_with_unified_size(op, conv, operand0, operand1):
68
68
  return op(conv(operand0, nobool=True), conv(operand1, nobool=True))
69
69
 
70
70
 
71
- def _dummy_bvs(condition, condition_mapping):
72
- var = claripy.BVS("ailexpr_%s" % repr(condition), condition.bits, explicit_name=True)
71
+ def _dummy_bvs(condition, condition_mapping, name_suffix=""):
72
+ var = claripy.BVS(f"ailexpr_{repr(condition)}{name_suffix}", condition.bits, explicit_name=True)
73
73
  condition_mapping[var.args[0]] = condition
74
74
  return var
75
75
 
@@ -698,9 +698,11 @@ class ConditionProcessor:
698
698
 
699
699
  if isinstance(
700
700
  condition,
701
- (ailment.Expr.DirtyExpression, ailment.Expr.BasePointerOffset, ailment.Expr.ITE, ailment.Stmt.Call),
701
+ (ailment.Expr.DirtyExpression, ailment.Expr.BasePointerOffset, ailment.Expr.ITE),
702
702
  ):
703
703
  return _dummy_bvs(condition, self._condition_mapping)
704
+ elif isinstance(condition, ailment.Stmt.Call):
705
+ return _dummy_bvs(condition, self._condition_mapping, name_suffix=hex(condition.tags.get("ins_addr", 0)))
704
706
  elif isinstance(condition, (ailment.Expr.Load, ailment.Expr.Register)):
705
707
  # does it have a variable associated?
706
708
  if condition.variable is not None:
@@ -1,13 +1,59 @@
1
- from typing import List
1
+ # pylint:disable=arguments-renamed,too-many-boolean-expressions
2
+ from typing import List, Tuple, Any
2
3
 
3
4
  import ailment
4
5
  from ailment.expression import Op
5
6
 
6
7
  from ..structuring.structurer_nodes import ConditionNode
7
8
  from ..utils import structured_node_is_simple_return
9
+ from ..sequence_walker import SequenceWalker
8
10
  from .optimization_pass import SequenceOptimizationPass, OptimizationPassStage
9
11
 
10
12
 
13
+ class FlipBooleanWalker(SequenceWalker):
14
+ """
15
+ Walks a SequenceNode and handles every sequence.
16
+ """
17
+
18
+ def __init__(self, graph):
19
+ super().__init__()
20
+ self._graph = graph
21
+
22
+ def _handle_Sequence(self, seq_node, **kwargs):
23
+ # Type 1:
24
+ # if (cond) { ... } else { return; } --> if (!cond) { return; } else { ... }
25
+ #
26
+ # Type 2:
27
+ # if (cond) { ... } return; --> if (!cond) return; ...
28
+ type1_condition_nodes = [node for node in seq_node.nodes if isinstance(node, ConditionNode) and node.false_node]
29
+ type2_condition_nodes: List[Tuple[int, ConditionNode, Any]] = []
30
+
31
+ if len(seq_node.nodes) >= 2:
32
+ idx = len(seq_node.nodes) - 2
33
+ node = seq_node.nodes[idx]
34
+ if (
35
+ isinstance(node, ConditionNode)
36
+ and node.true_node is not None
37
+ and node.false_node is None
38
+ and idx < len(seq_node.nodes) - 1
39
+ and structured_node_is_simple_return(seq_node.nodes[idx + 1], self._graph)
40
+ and node not in type1_condition_nodes
41
+ ):
42
+ type2_condition_nodes.append((idx, node, seq_node.nodes[idx + 1]))
43
+
44
+ for node in type1_condition_nodes:
45
+ if isinstance(node.condition, Op) and structured_node_is_simple_return(node.false_node, self._graph):
46
+ node.condition = ailment.expression.negate(node.condition)
47
+ node.true_node, node.false_node = node.false_node, node.true_node
48
+
49
+ for idx, cond_node, successor in type2_condition_nodes:
50
+ cond_node.condition = ailment.expression.negate(cond_node.condition)
51
+ seq_node.nodes[idx + 1] = cond_node.true_node
52
+ cond_node.true_node = successor
53
+
54
+ return super()._handle_Sequence(seq_node, **kwargs)
55
+
56
+
11
57
  class FlipBooleanCmp(SequenceOptimizationPass):
12
58
  """
13
59
  In the scenario in which a false node has no apparent successors, flip the condition on that if-stmt.
@@ -27,12 +73,9 @@ class FlipBooleanCmp(SequenceOptimizationPass):
27
73
  self.analyze()
28
74
 
29
75
  def _check(self):
30
- condition_nodes = [node for node in self.seq.nodes if isinstance(node, ConditionNode) and node.false_node]
31
- return len(condition_nodes) > 0, condition_nodes
76
+ return bool(self.seq.nodes), None
32
77
 
33
78
  def _analyze(self, cache=None):
34
- condition_nodes: List[ConditionNode] = cache or []
35
- for node in condition_nodes:
36
- if isinstance(node.condition, Op) and structured_node_is_simple_return(node.false_node, self._graph):
37
- node.condition = ailment.expression.negate(node.condition)
38
- node.true_node, node.false_node = node.false_node, node.true_node
79
+ walker = FlipBooleanWalker(self._graph)
80
+ walker.walk(self.seq)
81
+ self.out_seq = self.seq
@@ -9,7 +9,7 @@ from .arm_cmpf import ARMCmpF
9
9
  from .bswap import Bswap
10
10
  from .coalesce_same_cascading_ifs import CoalesceSameCascadingIfs
11
11
  from .constant_derefs import ConstantDereferences
12
- from .conv_const_mull_a_shift import ConvConstMullAShift
12
+ from .const_mull_a_shift import ConstMullAShift
13
13
  from .extended_byte_and_mask import ExtendedByteAndMask
14
14
  from .remove_empty_if_body import RemoveEmptyIfBody
15
15
  from .remove_redundant_ite_branch import RemoveRedundantITEBranches
@@ -0,0 +1,92 @@
1
+ # pylint:disable=too-many-boolean-expressions
2
+ from typing import Union
3
+
4
+ from ailment.expression import Convert, BinaryOp, Const
5
+
6
+ from .base import PeepholeOptimizationExprBase
7
+
8
+
9
+ class ConstMullAShift(PeepholeOptimizationExprBase):
10
+ """
11
+ Convert expressions with right shifts into expressions with divisions.
12
+ """
13
+
14
+ __slots__ = ()
15
+
16
+ NAME = "Conv(64->32, (N * a) >> M) => a / N1"
17
+ expr_classes = (Convert, BinaryOp)
18
+
19
+ def optimize(self, expr: Union[Convert, BinaryOp]):
20
+ r = None
21
+
22
+ if isinstance(expr, Convert):
23
+ if expr.from_bits == 64 and expr.to_bits == 32:
24
+ r = self.optimize_binaryop(expr)
25
+
26
+ elif isinstance(expr, BinaryOp):
27
+ r = self.optimize_binaryop(expr)
28
+
29
+ # keep size
30
+ if r is not None and r.bits < expr.bits:
31
+ r = Convert(expr.idx, r.bits, expr.bits, False, r, **expr.tags)
32
+
33
+ return r
34
+
35
+ def optimize_binaryop(self, expr: BinaryOp):
36
+ if isinstance(expr, BinaryOp) and expr.op == "Shr" and isinstance(expr.operands[1], Const):
37
+ # (N * a) >> M ==> a / N1
38
+ inner = expr.operands[0]
39
+ if isinstance(inner, BinaryOp) and inner.op == "Mull" and isinstance(inner.operands[0], Const):
40
+ C = inner.operands[0].value
41
+ X = inner.operands[1]
42
+ V = expr.operands[1].value
43
+ ndigits = 5 if V == 32 else 6
44
+ divisor = self._check_divisor(pow(2, V), C, ndigits)
45
+ if divisor is not None:
46
+ new_const = Const(None, None, divisor, X.bits)
47
+ new_div = BinaryOp(inner.idx, "Div", [X, new_const], inner.signed, **inner.tags)
48
+ return new_div
49
+
50
+ elif isinstance(expr, BinaryOp) and expr.op in {"Add", "Sub"}:
51
+ expr0, expr1 = expr.operands
52
+ if (
53
+ isinstance(expr0, BinaryOp)
54
+ and expr0.op in {"Shr", "Sar"}
55
+ and isinstance(expr0.operands[1], Const)
56
+ and isinstance(expr1, BinaryOp)
57
+ and expr1.op in {"Shr", "Sar"}
58
+ and isinstance(expr1.operands[1], Const)
59
+ ):
60
+ if (
61
+ isinstance(expr0.operands[0], BinaryOp)
62
+ and expr0.operands[0].op in {"Mull", "Mul"}
63
+ and isinstance(expr0.operands[0].operands[1], Const)
64
+ ):
65
+ a0 = expr0.operands[0].operands[0]
66
+ a1 = expr1.operands[0]
67
+ if a0 == a1:
68
+ # (a * x >> M1) +/- (a >> M2) ==> a / N
69
+ C = expr0.operands[0].operands[1].value
70
+ X = a0
71
+ V = expr0.operands[1].value
72
+ ndigits = 5 if V == 32 else 6
73
+ divisor = self._check_divisor(pow(2, V), C, ndigits)
74
+ if divisor is not None:
75
+ new_const = Const(None, None, divisor, X.bits)
76
+ new_div = BinaryOp(
77
+ expr0.operands[0].idx,
78
+ "Div",
79
+ [X, new_const],
80
+ expr0.operands[0].signed,
81
+ **expr0.operands[0].tags,
82
+ )
83
+ # we cannot drop the convert in this case
84
+ return new_div
85
+
86
+ return None
87
+
88
+ @staticmethod
89
+ def _check_divisor(a, b, ndigits=6):
90
+ divisor_1 = 1 + (a // b)
91
+ divisor_2 = int(round(a / float(b), ndigits))
92
+ return divisor_1 if divisor_1 == divisor_2 else None