angr 9.2.87__py3-none-manylinux2014_x86_64.whl → 9.2.89__py3-none-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of angr has been flagged as potentially problematic; the original registry page links to further details.

Files changed (248)
  1. angr/__init__.py +4 -1
  2. angr/analyses/decompiler/clinic.py +16 -0
  3. angr/analyses/decompiler/decompiler.py +3 -0
  4. angr/analyses/decompiler/optimization_passes/__init__.py +5 -0
  5. angr/analyses/decompiler/optimization_passes/cross_jump_reverter.py +108 -0
  6. angr/analyses/decompiler/optimization_passes/optimization_pass.py +17 -4
  7. angr/analyses/decompiler/optimization_passes/return_duplicator.py +4 -32
  8. angr/analyses/decompiler/structured_codegen/c.py +12 -2
  9. angr/analyses/decompiler/utils.py +13 -0
  10. angr/analyses/typehoon/dfa.py +108 -0
  11. angr/analyses/typehoon/lifter.py +34 -2
  12. angr/analyses/typehoon/simple_solver.py +1043 -503
  13. angr/analyses/typehoon/translator.py +13 -4
  14. angr/analyses/typehoon/typeconsts.py +117 -36
  15. angr/analyses/typehoon/typehoon.py +31 -11
  16. angr/analyses/typehoon/typevars.py +88 -21
  17. angr/analyses/typehoon/variance.py +10 -0
  18. angr/analyses/variable_recovery/engine_ail.py +28 -9
  19. angr/analyses/variable_recovery/engine_base.py +50 -43
  20. angr/analyses/variable_recovery/variable_recovery_base.py +16 -3
  21. angr/analyses/variable_recovery/variable_recovery_fast.py +14 -5
  22. angr/exploration_techniques/tracer.py +2 -0
  23. angr/misc/autoimport.py +26 -0
  24. angr/procedures/definitions/__init__.py +32 -3
  25. angr/utils/constants.py +1 -0
  26. angr/utils/graph.py +20 -1
  27. {angr-9.2.87.dist-info → angr-9.2.89.dist-info}/METADATA +7 -6
  28. {angr-9.2.87.dist-info → angr-9.2.89.dist-info}/RECORD +32 -244
  29. angr-9.2.89.dist-info/top_level.txt +1 -0
  30. angr/procedures/definitions/ntdll.py +0 -12
  31. angr-9.2.87.dist-info/top_level.txt +0 -2
  32. tests/__init__.py +0 -0
  33. tests/analyses/__init__.py +0 -0
  34. tests/analyses/cfg/__init__.py +0 -0
  35. tests/analyses/cfg/test_cfg_clflush.py +0 -43
  36. tests/analyses/cfg/test_cfg_get_any_node.py +0 -34
  37. tests/analyses/cfg/test_cfg_manager.py +0 -32
  38. tests/analyses/cfg/test_cfg_model.py +0 -55
  39. tests/analyses/cfg/test_cfg_patching.py +0 -378
  40. tests/analyses/cfg/test_cfg_rust_got_resolution.py +0 -36
  41. tests/analyses/cfg/test_cfg_thumb_firmware.py +0 -50
  42. tests/analyses/cfg/test_cfg_vex_postprocessor.py +0 -27
  43. tests/analyses/cfg/test_cfgemulated.py +0 -634
  44. tests/analyses/cfg/test_cfgfast.py +0 -1123
  45. tests/analyses/cfg/test_cfgfast_soot.py +0 -38
  46. tests/analyses/cfg/test_const_resolver.py +0 -38
  47. tests/analyses/cfg/test_iat_resolver.py +0 -37
  48. tests/analyses/cfg/test_jumptables.py +0 -3008
  49. tests/analyses/cfg/test_noop_blocks.py +0 -54
  50. tests/analyses/cfg_slice_to_sink/__init__.py +0 -0
  51. tests/analyses/cfg_slice_to_sink/test_cfg_slice_to_sink.py +0 -93
  52. tests/analyses/cfg_slice_to_sink/test_graph.py +0 -114
  53. tests/analyses/cfg_slice_to_sink/test_transitions.py +0 -28
  54. tests/analyses/decompiler/__init__.py +0 -0
  55. tests/analyses/decompiler/test_baseptr_save_simplifier.py +0 -80
  56. tests/analyses/decompiler/test_decompiler.py +0 -3336
  57. tests/analyses/decompiler/test_peephole_optimizations.py +0 -48
  58. tests/analyses/decompiler/test_propagator_loops.py +0 -101
  59. tests/analyses/decompiler/test_structurer.py +0 -275
  60. tests/analyses/reaching_definitions/__init__.py +0 -0
  61. tests/analyses/reaching_definitions/test_dep_graph.py +0 -432
  62. tests/analyses/reaching_definitions/test_function_handler.py +0 -131
  63. tests/analyses/reaching_definitions/test_heap_allocator.py +0 -46
  64. tests/analyses/reaching_definitions/test_rd_state.py +0 -78
  65. tests/analyses/reaching_definitions/test_reachingdefinitions.py +0 -463
  66. tests/analyses/reaching_definitions/test_subject.py +0 -76
  67. tests/analyses/test_bindiff.py +0 -52
  68. tests/analyses/test_block_simplifier.py +0 -112
  69. tests/analyses/test_boyscout.py +0 -104
  70. tests/analyses/test_calling_convention_analysis.py +0 -352
  71. tests/analyses/test_callsite_maker.py +0 -60
  72. tests/analyses/test_cdg.py +0 -165
  73. tests/analyses/test_cfb.py +0 -37
  74. tests/analyses/test_class_identifier.py +0 -46
  75. tests/analyses/test_clinic.py +0 -30
  76. tests/analyses/test_codetagging.py +0 -32
  77. tests/analyses/test_constantpropagation.py +0 -88
  78. tests/analyses/test_ddg.py +0 -95
  79. tests/analyses/test_ddg_global_var_dependencies.py +0 -83
  80. tests/analyses/test_ddg_memvar_addresses.py +0 -40
  81. tests/analyses/test_disassembly.py +0 -121
  82. tests/analyses/test_find_objects_static.py +0 -35
  83. tests/analyses/test_flirt.py +0 -49
  84. tests/analyses/test_identifier.py +0 -33
  85. tests/analyses/test_init_finder.py +0 -38
  86. tests/analyses/test_proximitygraph.py +0 -31
  87. tests/analyses/test_reassembler.py +0 -295
  88. tests/analyses/test_regionidentifier.py +0 -27
  89. tests/analyses/test_slicing.py +0 -164
  90. tests/analyses/test_stack_pointer_tracker.py +0 -74
  91. tests/analyses/test_static_hooker.py +0 -28
  92. tests/analyses/test_typehoon.py +0 -55
  93. tests/analyses/test_variablerecovery.py +0 -464
  94. tests/analyses/test_vfg.py +0 -221
  95. tests/analyses/test_vtable.py +0 -31
  96. tests/analyses/test_xrefs.py +0 -77
  97. tests/common.py +0 -128
  98. tests/engines/__init__.py +0 -0
  99. tests/engines/light/__init__.py +0 -0
  100. tests/engines/light/test_data.py +0 -17
  101. tests/engines/pcode/__init__.py +0 -0
  102. tests/engines/pcode/test_emulate.py +0 -607
  103. tests/engines/pcode/test_pcode.py +0 -84
  104. tests/engines/test_actions.py +0 -27
  105. tests/engines/test_hook.py +0 -112
  106. tests/engines/test_java.py +0 -697
  107. tests/engines/test_unicorn.py +0 -518
  108. tests/engines/vex/__init__.py +0 -0
  109. tests/engines/vex/test_lifter.py +0 -124
  110. tests/engines/vex/test_vex.py +0 -574
  111. tests/exploration_techniques/__init__.py +0 -0
  112. tests/exploration_techniques/test_cacher.py +0 -45
  113. tests/exploration_techniques/test_director.py +0 -67
  114. tests/exploration_techniques/test_driller_core.py +0 -48
  115. tests/exploration_techniques/test_loop_seer.py +0 -158
  116. tests/exploration_techniques/test_memory_watcher.py +0 -46
  117. tests/exploration_techniques/test_oppologist.py +0 -65
  118. tests/exploration_techniques/test_spiller.py +0 -82
  119. tests/exploration_techniques/test_stochastic.py +0 -40
  120. tests/exploration_techniques/test_tech_builder.py +0 -61
  121. tests/exploration_techniques/test_tracer.py +0 -856
  122. tests/exploration_techniques/test_unique.py +0 -40
  123. tests/exploration_techniques/test_veritesting.py +0 -120
  124. tests/factory/__init__.py +0 -0
  125. tests/factory/block/__init__.py +0 -0
  126. tests/factory/block/test_block_cache.py +0 -33
  127. tests/factory/block/test_keystone.py +0 -106
  128. tests/factory/test_argc.py +0 -101
  129. tests/factory/test_argc_sym.py +0 -110
  130. tests/factory/test_argv.py +0 -158
  131. tests/factory/test_callable.py +0 -266
  132. tests/factory/test_windows_args.py +0 -36
  133. tests/knowledge_plugins/__init__.py +0 -0
  134. tests/knowledge_plugins/cfg/__init__.py +0 -0
  135. tests/knowledge_plugins/cfg/test_cfg_manager.py +0 -36
  136. tests/knowledge_plugins/functions/__init__.py +0 -0
  137. tests/knowledge_plugins/functions/test_function.py +0 -91
  138. tests/knowledge_plugins/functions/test_function2.py +0 -79
  139. tests/knowledge_plugins/functions/test_function_manager.py +0 -139
  140. tests/knowledge_plugins/functions/test_prototypes.py +0 -53
  141. tests/knowledge_plugins/key_definitions/__init__.py +0 -0
  142. tests/knowledge_plugins/key_definitions/test_atoms.py +0 -24
  143. tests/knowledge_plugins/key_definitions/test_environment.py +0 -126
  144. tests/knowledge_plugins/key_definitions/test_heap_address.py +0 -27
  145. tests/knowledge_plugins/key_definitions/test_live_definitions.py +0 -72
  146. tests/knowledge_plugins/test_dwarf_variables.py +0 -240
  147. tests/knowledge_plugins/test_kb_plugins.py +0 -91
  148. tests/knowledge_plugins/test_kb_plugins_dwarf.py +0 -36
  149. tests/knowledge_plugins/test_patches.py +0 -48
  150. tests/misc/__init__.py +0 -0
  151. tests/misc/test_hookset.py +0 -57
  152. tests/perf/__init__.py +0 -0
  153. tests/perf/perf_cfgemulated.py +0 -19
  154. tests/perf/perf_cfgfast.py +0 -18
  155. tests/perf/perf_concrete_execution.py +0 -41
  156. tests/perf/perf_siminspect_nop.py +0 -36
  157. tests/perf/perf_state_copy.py +0 -33
  158. tests/perf/perf_unicorn_0.py +0 -27
  159. tests/perf/perf_unicorn_1.py +0 -23
  160. tests/procedures/__init__.py +0 -0
  161. tests/procedures/glibc/__init__.py +0 -0
  162. tests/procedures/glibc/test_ctype_locale.py +0 -164
  163. tests/procedures/libc/__init__.py +0 -0
  164. tests/procedures/libc/test_fgets.py +0 -53
  165. tests/procedures/libc/test_scanf.py +0 -205
  166. tests/procedures/libc/test_sprintf.py +0 -44
  167. tests/procedures/libc/test_sscanf.py +0 -63
  168. tests/procedures/libc/test_strcasecmp.py +0 -37
  169. tests/procedures/libc/test_string.py +0 -1102
  170. tests/procedures/libc/test_strtol.py +0 -78
  171. tests/procedures/linux_kernel/__init__.py +0 -0
  172. tests/procedures/linux_kernel/test_lseek.py +0 -174
  173. tests/procedures/posix/__init__.py +0 -0
  174. tests/procedures/posix/test_chroot.py +0 -33
  175. tests/procedures/posix/test_getenv.py +0 -78
  176. tests/procedures/posix/test_pwrite_pread.py +0 -57
  177. tests/procedures/posix/test_sim_time.py +0 -46
  178. tests/procedures/posix/test_unlink.py +0 -46
  179. tests/procedures/test_project_resolve_simproc.py +0 -43
  180. tests/procedures/test_sim_procedure.py +0 -117
  181. tests/procedures/test_stub_procedure_args.py +0 -53
  182. tests/serialization/__init__.py +0 -0
  183. tests/serialization/test_db.py +0 -197
  184. tests/serialization/test_pickle.py +0 -95
  185. tests/serialization/test_serialization.py +0 -132
  186. tests/serialization/test_vault.py +0 -169
  187. tests/sim/__init__.py +0 -3
  188. tests/sim/exec_func/__init__.py +0 -0
  189. tests/sim/exec_func/test_mem_funcs.py +0 -55
  190. tests/sim/exec_func/test_str_funcs.py +0 -93
  191. tests/sim/exec_func/test_syscall_result.py +0 -39
  192. tests/sim/exec_insn/__init__.py +0 -0
  193. tests/sim/exec_insn/test_adc.py +0 -44
  194. tests/sim/exec_insn/test_ops.py +0 -83
  195. tests/sim/exec_insn/test_rcr.py +0 -26
  196. tests/sim/exec_insn/test_rol.py +0 -51
  197. tests/sim/exec_insn/test_signed_div.py +0 -34
  198. tests/sim/exec_insn/test_sqrt.py +0 -56
  199. tests/sim/options/__init__.py +0 -0
  200. tests/sim/options/test_0div.py +0 -54
  201. tests/sim/options/test_symbolic_fd.py +0 -59
  202. tests/sim/options/test_unsupported.py +0 -34
  203. tests/sim/test_accuracy.py +0 -137
  204. tests/sim/test_checkbyte.py +0 -53
  205. tests/sim/test_echo.py +0 -36
  206. tests/sim/test_fauxware.py +0 -202
  207. tests/sim/test_self_modifying_code.py +0 -65
  208. tests/sim/test_simple_api.py +0 -36
  209. tests/sim/test_simulation_manager.py +0 -147
  210. tests/sim/test_stack_alignment.py +0 -65
  211. tests/sim/test_state.py +0 -303
  212. tests/sim/test_state_customization.py +0 -54
  213. tests/sim/test_symbol_hooked_by.py +0 -49
  214. tests/simos/__init__.py +0 -0
  215. tests/simos/windows/__init__.py +0 -0
  216. tests/simos/windows/test_windows_stack_cookie.py +0 -58
  217. tests/state_plugins/__init__.py +0 -0
  218. tests/state_plugins/inspect/__init__.py +0 -0
  219. tests/state_plugins/inspect/test_inspect.py +0 -310
  220. tests/state_plugins/inspect/test_syscall_override.py +0 -90
  221. tests/state_plugins/posix/__init__.py +0 -0
  222. tests/state_plugins/posix/test_file_struct_funcs.py +0 -56
  223. tests/state_plugins/posix/test_files.py +0 -69
  224. tests/state_plugins/posix/test_posix.py +0 -72
  225. tests/state_plugins/solver/__init__.py +0 -0
  226. tests/state_plugins/solver/test_simsolver.py +0 -58
  227. tests/state_plugins/solver/test_symbolic.py +0 -153
  228. tests/state_plugins/solver/test_variable_registration.py +0 -46
  229. tests/state_plugins/test_callstack.py +0 -54
  230. tests/state_plugins/test_gdb_plugin.py +0 -35
  231. tests/state_plugins/test_multi_open_file.py +0 -47
  232. tests/state_plugins/test_symbolization.py +0 -38
  233. tests/storage/__init__.py +0 -0
  234. tests/storage/test_memory.py +0 -960
  235. tests/storage/test_memory_merge.py +0 -114
  236. tests/storage/test_memview.py +0 -205
  237. tests/storage/test_mmap.py +0 -26
  238. tests/storage/test_multivalues.py +0 -44
  239. tests/storage/test_permissions.py +0 -32
  240. tests/storage/test_ptmalloc.py +0 -291
  241. tests/storage/test_relro_perm.py +0 -49
  242. tests/test_calling_conventions.py +0 -86
  243. tests/test_types.py +0 -329
  244. tests/utils/__init__.py +0 -0
  245. tests/utils/test_graph.py +0 -41
  246. {angr-9.2.87.dist-info → angr-9.2.89.dist-info}/LICENSE +0 -0
  247. {angr-9.2.87.dist-info → angr-9.2.89.dist-info}/WHEEL +0 -0
  248. {angr-9.2.87.dist-info → angr-9.2.89.dist-info}/entry_points.txt +0 -0
@@ -1,20 +1,28 @@
1
1
  # pylint:disable=missing-class-docstring
2
- import itertools
2
+ from typing import Union, Type, Set, Dict, Optional, Tuple, List, DefaultDict
3
+ import enum
3
4
  from collections import defaultdict
4
- from typing import Union, Type, Callable
5
+ import logging
5
6
 
6
7
  import networkx
7
8
 
9
+ from angr.utils.constants import MAX_POINTSTO_BITS
8
10
  from .typevars import (
9
11
  Existence,
10
- Equivalence,
11
12
  Subtype,
13
+ Equivalence,
14
+ Add,
12
15
  TypeVariable,
13
16
  DerivedTypeVariable,
14
17
  HasField,
15
- Add,
16
- ConvertTo,
17
18
  IsArray,
19
+ TypeConstraint,
20
+ Load,
21
+ Store,
22
+ BaseLabel,
23
+ FuncIn,
24
+ FuncOut,
25
+ ConvertTo,
18
26
  )
19
27
  from .typeconsts import (
20
28
  BottomType,
@@ -29,35 +37,67 @@ from .typeconsts import (
29
37
  Pointer32,
30
38
  Pointer64,
31
39
  Struct,
40
+ Array,
41
+ Function,
32
42
  int_type,
33
- TypeVariableReference,
34
43
  )
44
+ from .variance import Variance
45
+ from .dfa import DFAConstraintSolver, EmptyEpsilonNFAError
46
+
47
+ _l = logging.getLogger(__name__)
48
+
49
+
50
+ PRIMITIVE_TYPES = {
51
+ TopType(),
52
+ Int(),
53
+ Int8(),
54
+ Int16(),
55
+ Int32(),
56
+ Int64(),
57
+ Pointer32(),
58
+ Pointer64(),
59
+ BottomType(),
60
+ Struct(),
61
+ Array(),
62
+ }
63
+
64
+ Top_ = TopType()
65
+ Int_ = Int()
66
+ Int64_ = Int64()
67
+ Int32_ = Int32()
68
+ Int16_ = Int16()
69
+ Int8_ = Int8()
70
+ Bottom_ = BottomType()
71
+ Pointer64_ = Pointer64()
72
+ Pointer32_ = Pointer32()
73
+ Struct_ = Struct()
74
+ Array_ = Array()
35
75
 
36
76
  # lattice for 64-bit binaries
37
77
  BASE_LATTICE_64 = networkx.DiGraph()
38
- BASE_LATTICE_64.add_edge(TopType, Int)
39
- BASE_LATTICE_64.add_edge(Int, Int64)
40
- BASE_LATTICE_64.add_edge(Int, Int32)
41
- BASE_LATTICE_64.add_edge(Int, Int16)
42
- BASE_LATTICE_64.add_edge(Int, Int8)
43
- BASE_LATTICE_64.add_edge(Int32, BottomType)
44
- BASE_LATTICE_64.add_edge(Int16, BottomType)
45
- BASE_LATTICE_64.add_edge(Int8, BottomType)
46
- BASE_LATTICE_64.add_edge(Int64, Pointer64)
47
- BASE_LATTICE_64.add_edge(Pointer64, BottomType)
78
+ BASE_LATTICE_64.add_edge(Top_, Int_)
79
+ BASE_LATTICE_64.add_edge(Int_, Int64_)
80
+ BASE_LATTICE_64.add_edge(Int_, Int32_)
81
+ BASE_LATTICE_64.add_edge(Int_, Int16_)
82
+ BASE_LATTICE_64.add_edge(Int_, Int8_)
83
+ BASE_LATTICE_64.add_edge(Int32_, Bottom_)
84
+ BASE_LATTICE_64.add_edge(Int16_, Bottom_)
85
+ BASE_LATTICE_64.add_edge(Int8_, Bottom_)
86
+ BASE_LATTICE_64.add_edge(Int64_, Pointer64_)
87
+ BASE_LATTICE_64.add_edge(Pointer64_, Bottom_)
48
88
 
49
89
  # lattice for 32-bit binaries
50
90
  BASE_LATTICE_32 = networkx.DiGraph()
51
- BASE_LATTICE_32.add_edge(TopType, Int)
52
- BASE_LATTICE_32.add_edge(Int, Int64)
53
- BASE_LATTICE_32.add_edge(Int, Int32)
54
- BASE_LATTICE_32.add_edge(Int, Int16)
55
- BASE_LATTICE_32.add_edge(Int, Int8)
56
- BASE_LATTICE_32.add_edge(Int32, Pointer32)
57
- BASE_LATTICE_32.add_edge(Int64, BottomType)
58
- BASE_LATTICE_32.add_edge(Pointer32, BottomType)
59
- BASE_LATTICE_32.add_edge(Int16, BottomType)
60
- BASE_LATTICE_32.add_edge(Int8, BottomType)
91
+ BASE_LATTICE_32.add_edge(Top_, Int_)
92
+ BASE_LATTICE_32.add_edge(Int_, Int64_)
93
+ BASE_LATTICE_32.add_edge(Int_, Int32_)
94
+ BASE_LATTICE_32.add_edge(Int_, Int16_)
95
+ BASE_LATTICE_32.add_edge(Int_, Int8_)
96
+ BASE_LATTICE_32.add_edge(Int32_, Pointer32_)
97
+ BASE_LATTICE_32.add_edge(Int64_, Bottom_)
98
+ BASE_LATTICE_32.add_edge(Pointer32_, Bottom_)
99
+ BASE_LATTICE_32.add_edge(Int16_, Bottom_)
100
+ BASE_LATTICE_32.add_edge(Int8_, Bottom_)
61
101
 
62
102
  BASE_LATTICES = {
63
103
  32: BASE_LATTICE_32,
@@ -65,93 +105,535 @@ BASE_LATTICES = {
65
105
  }
66
106
 
67
107
 
68
- class RecursiveType:
69
- def __init__(self, typevar, offset):
108
+ #
109
+ # Sketch
110
+ #
111
+
112
+
113
+ class SketchNodeBase:
114
+ """
115
+ The base class for nodes in a sketch.
116
+ """
117
+
118
+ __slots__ = ()
119
+
120
+
121
+ class SketchNode(SketchNodeBase):
122
+ """
123
+ Represents a node in a sketch graph.
124
+ """
125
+
126
+ __slots__ = ("typevar", "upper_bound", "lower_bound")
127
+
128
+ def __init__(self, typevar: Union[TypeVariable, DerivedTypeVariable]):
129
+ self.typevar: Union[TypeVariable, DerivedTypeVariable] = typevar
130
+ self.upper_bound = TopType()
131
+ self.lower_bound = BottomType()
132
+
133
+ def __repr__(self):
134
+ return f"{self.lower_bound} <: {self.typevar} <: {self.upper_bound}"
135
+
136
+ def __eq__(self, other):
137
+ return isinstance(other, SketchNode) and self.typevar == other.typevar
138
+
139
+ def __hash__(self):
140
+ return hash((SketchNode, self.typevar))
141
+
142
+
143
+ class RecursiveRefNode(SketchNodeBase):
144
+ """
145
+ Represents a cycle in a sketch graph.
146
+
147
+ This is equivalent to sketches.LabelNode in the reference implementation of retypd.
148
+ """
149
+
150
+ def __init__(self, target: DerivedTypeVariable):
151
+ self.target: DerivedTypeVariable = target
152
+
153
+ def __hash__(self):
154
+ return hash((RecursiveRefNode, self.target))
155
+
156
+ def __eq__(self, other):
157
+ return type(other) is RecursiveRefNode and other.target == self.target
158
+
159
+
160
+ class Sketch:
161
+ """
162
+ Describes the sketch of a type variable.
163
+ """
164
+
165
+ __slots__ = (
166
+ "graph",
167
+ "root",
168
+ "node_mapping",
169
+ "solver",
170
+ )
171
+
172
+ def __init__(self, solver: "SimpleSolver", root: TypeVariable):
173
+ self.root: SketchNode = SketchNode(root)
174
+ self.graph = networkx.DiGraph()
175
+ self.node_mapping: Dict[Union[TypeVariable, DerivedTypeVariable], SketchNodeBase] = {}
176
+ self.solver = solver
177
+
178
+ # add the root node
179
+ self.graph.add_node(self.root)
180
+ self.node_mapping[root] = self.root
181
+
182
+ def lookup(self, typevar: Union[TypeVariable, DerivedTypeVariable]) -> Optional[SketchNodeBase]:
183
+ if typevar in self.node_mapping:
184
+ return self.node_mapping[typevar]
185
+ node: Optional[SketchNodeBase] = None
186
+ if isinstance(typevar, DerivedTypeVariable):
187
+ node = self.node_mapping[SimpleSolver._to_typevar_or_typeconst(typevar.type_var)]
188
+ for label in typevar.labels:
189
+ succs = []
190
+ for _, dst, data in self.graph.out_edges(node, data=True):
191
+ if "label" in data and data["label"] == label:
192
+ succs.append(dst)
193
+ assert len(succs) <= 1
194
+ if not succs:
195
+ return None
196
+ node = succs[0]
197
+ if isinstance(node, RecursiveRefNode):
198
+ node = self.lookup(node.target)
199
+ return node
200
+
201
+ def add_edge(self, src: SketchNodeBase, dst: SketchNodeBase, label):
202
+ self.graph.add_edge(src, dst, label=label)
203
+
204
+ def add_constraint(self, constraint: TypeConstraint) -> None:
205
+ # sub <: super
206
+ if not isinstance(constraint, Subtype):
207
+ return
208
+ subtype = self.flatten_typevar(constraint.sub_type)
209
+ supertype = self.flatten_typevar(constraint.super_type)
210
+ if SimpleSolver._typevar_inside_set(subtype, PRIMITIVE_TYPES) and not SimpleSolver._typevar_inside_set(
211
+ supertype, PRIMITIVE_TYPES
212
+ ):
213
+ super_node: Optional[SketchNode] = self.lookup(supertype)
214
+ if super_node is not None:
215
+ super_node.lower_bound = self.solver.join(super_node.lower_bound, subtype)
216
+ elif SimpleSolver._typevar_inside_set(supertype, PRIMITIVE_TYPES) and not SimpleSolver._typevar_inside_set(
217
+ subtype, PRIMITIVE_TYPES
218
+ ):
219
+ sub_node: Optional[SketchNode] = self.lookup(subtype)
220
+ # assert sub_node is not None
221
+ if sub_node is not None:
222
+ sub_node.upper_bound = self.solver.meet(sub_node.upper_bound, supertype)
223
+
224
+ @staticmethod
225
+ def flatten_typevar(
226
+ derived_typevar: Union[TypeVariable, TypeConstant, DerivedTypeVariable]
227
+ ) -> Union[DerivedTypeVariable, TypeVariable, TypeConstant]: # pylint:disable=too-many-boolean-expressions
228
+ if (
229
+ isinstance(derived_typevar, DerivedTypeVariable)
230
+ and isinstance(derived_typevar.type_var, Pointer)
231
+ and SimpleSolver._typevar_inside_set(derived_typevar.type_var.basetype, PRIMITIVE_TYPES)
232
+ and len(derived_typevar.labels) == 2
233
+ and isinstance(derived_typevar.labels[0], Load)
234
+ and isinstance(derived_typevar.labels[1], HasField)
235
+ and derived_typevar.labels[1].offset == 0
236
+ and derived_typevar.labels[1].bits == MAX_POINTSTO_BITS
237
+ ):
238
+ return derived_typevar.type_var.basetype
239
+ return derived_typevar
240
+
241
+
242
+ #
243
+ # Constraint graph
244
+ #
245
+
246
+
247
+ class ConstraintGraphTag(enum.Enum):
248
+ LEFT = 0
249
+ RIGHT = 1
250
+ UNKNOWN = 2
251
+
252
+
253
+ class FORGOTTEN(enum.Enum):
254
+ PRE_FORGOTTEN = 0
255
+ POST_FORGOTTEN = 1
256
+
257
+
258
+ class ConstraintGraphNode:
259
+ __slots__ = ("typevar", "variance", "tag", "forgotten")
260
+
261
+ def __init__(
262
+ self,
263
+ typevar: Union[TypeVariable, DerivedTypeVariable],
264
+ variance: Variance,
265
+ tag: ConstraintGraphTag,
266
+ forgotten: FORGOTTEN,
267
+ ):
70
268
  self.typevar = typevar
71
- self.offset = offset
269
+ self.variance = variance
270
+ self.tag = tag
271
+ self.forgotten = forgotten
272
+
273
+ def __repr__(self):
274
+ variance_str = "CO" if self.variance == Variance.COVARIANT else "CONTRA"
275
+ if self.tag == ConstraintGraphTag.LEFT:
276
+ tag_str = "L"
277
+ elif self.tag == ConstraintGraphTag.RIGHT:
278
+ tag_str = "R"
279
+ else:
280
+ tag_str = "U"
281
+ forgotten_str = "PRE" if FORGOTTEN.PRE_FORGOTTEN else "POST"
282
+ s = f"{self.typevar}#{variance_str}.{tag_str}.{forgotten_str}"
283
+ if ":" in s:
284
+ return '"' + s + '"'
285
+ return s
286
+
287
+ def __eq__(self, other):
288
+ if not isinstance(other, ConstraintGraphNode):
289
+ return False
290
+ return (
291
+ self.typevar == other.typevar
292
+ and self.variance == other.variance
293
+ and self.tag == other.tag
294
+ and self.forgotten == other.forgotten
295
+ )
296
+
297
+ def __hash__(self):
298
+ return hash((ConstraintGraphNode, self.typevar, self.variance, self.tag, self.forgotten))
299
+
300
+ def forget_last_label(self) -> Optional[Tuple["ConstraintGraphNode", BaseLabel]]:
301
+ if isinstance(self.typevar, DerivedTypeVariable) and self.typevar.labels:
302
+ last_label = self.typevar.labels[-1]
303
+ if len(self.typevar.labels) == 1:
304
+ prefix = self.typevar.type_var
305
+ else:
306
+ prefix = DerivedTypeVariable(self.typevar.type_var, None, labels=self.typevar.labels[:-1])
307
+ if self.variance == last_label.variance:
308
+ variance = Variance.COVARIANT
309
+ else:
310
+ variance = Variance.CONTRAVARIANT
311
+ return (
312
+ ConstraintGraphNode(prefix, variance, self.tag, FORGOTTEN.PRE_FORGOTTEN),
313
+ self.typevar.labels[-1],
314
+ )
315
+ return None
316
+
317
+ def recall(self, label: BaseLabel) -> "ConstraintGraphNode":
318
+ if isinstance(self.typevar, DerivedTypeVariable):
319
+ labels = self.typevar.labels + (label,)
320
+ typevar = self.typevar.type_var
321
+ elif isinstance(self.typevar, TypeVariable):
322
+ labels = (label,)
323
+ typevar = self.typevar
324
+ elif isinstance(self.typevar, TypeConstant):
325
+ labels = (label,)
326
+ typevar = self.typevar
327
+ else:
328
+ raise TypeError(f"Unsupported type {type(self.typevar)}")
329
+ if self.variance == label.variance:
330
+ variance = Variance.COVARIANT
331
+ else:
332
+ variance = Variance.CONTRAVARIANT
333
+ if not labels:
334
+ var = typevar
335
+ else:
336
+ var = DerivedTypeVariable(typevar, None, labels=labels)
337
+ return ConstraintGraphNode(var, variance, self.tag, FORGOTTEN.PRE_FORGOTTEN)
338
+
339
+ def inverse(self) -> "ConstraintGraphNode":
340
+ if self.tag == ConstraintGraphTag.LEFT:
341
+ tag = ConstraintGraphTag.RIGHT
342
+ elif self.tag == ConstraintGraphTag.RIGHT:
343
+ tag = ConstraintGraphTag.LEFT
344
+ else:
345
+ tag = ConstraintGraphTag.UNKNOWN
346
+
347
+ if self.variance == Variance.COVARIANT:
348
+ variance = Variance.CONTRAVARIANT
349
+ else:
350
+ variance = Variance.COVARIANT
351
+
352
+ return ConstraintGraphNode(self.typevar, variance, tag, self.forgotten)
353
+
354
+ def inverse_wo_tag(self) -> "ConstraintGraphNode":
355
+ """
356
+ Invert the variance only.
357
+ """
358
+ if self.variance == Variance.COVARIANT:
359
+ variance = Variance.CONTRAVARIANT
360
+ else:
361
+ variance = Variance.COVARIANT
362
+
363
+ return ConstraintGraphNode(self.typevar, variance, self.tag, self.forgotten)
364
+
365
+
366
+ #
367
+ # The solver
368
+ #
72
369
 
73
370
 
74
371
  class SimpleSolver:
75
372
  """
76
- SimpleSolver is, literally, a simple, unification-based type constraint solver.
373
+ SimpleSolver is, by its name, a simple solver. Most of this solver is based on the (complex) simplification logic
374
+ that the retypd paper describes and the retypd re-implementation (https://github.com/GrammaTech/retypd) implements.
375
+ Additionally, we add some improvements to allow type propagation of known struct names, among a few other
376
+ improvements.
77
377
  """
78
378
 
79
- def __init__(self, bits: int, constraints):
379
+ def __init__(self, bits: int, constraints, typevars):
80
380
  if bits not in (32, 64):
81
381
  raise ValueError("Pointer size %d is not supported. Expect 32 or 64." % bits)
82
382
 
83
383
  self.bits = bits
84
- self._constraints = constraints
384
+ self._constraints: Dict[TypeVariable, Set[TypeConstraint]] = constraints
385
+ self._typevars: Set[TypeVariable] = typevars
85
386
  self._base_lattice = BASE_LATTICES[bits]
387
+ self._base_lattice_inverted = networkx.DiGraph()
388
+ for src, dst in self._base_lattice.edges:
389
+ self._base_lattice_inverted.add_edge(dst, src)
86
390
 
87
391
  #
88
392
  # Solving state
89
393
  #
90
- self._equivalence = {}
91
- self._lower_bounds = defaultdict(BottomType)
92
- self._upper_bounds = defaultdict(TopType)
93
- self._recursive_types = defaultdict(set)
94
-
95
- self.solve()
96
- self.solution = self.determine()
394
+ self._equivalence = defaultdict(dict)
395
+ for typevar in list(self._constraints):
396
+ if self._constraints[typevar]:
397
+ self._constraints[typevar] |= self._eq_constraints_from_add(typevar)
398
+ self._constraints[typevar] = self._handle_equivalence(typevar)
399
+ equ_classes, sketches, _ = self.solve()
400
+ self.solution = {}
401
+ self._solution_cache = {}
402
+ self.determine(equ_classes, sketches, self.solution)
403
+ for typevar in list(self._constraints):
404
+ self._convert_arrays(self._constraints[typevar])
97
405
 
98
406
  def solve(self):
99
- # import pprint
100
- # pprint.pprint(self._constraints)
101
-
102
- eq_constraints = self._eq_constraints_from_add()
103
- self._constraints |= eq_constraints
104
- constraints = self._handle_equivalence()
105
- subtypevars, supertypevars = self._calculate_closure(constraints)
106
- self._find_recursive_types(subtypevars)
107
- self._compute_lower_upper_bounds(subtypevars, supertypevars)
108
- self._lower_struct_fields()
109
- self._convert_arrays(constraints)
110
- # import pprint
111
- # print("Lower bounds")
112
- # pprint.pprint(self._lower_bounds)
113
- # print("Upper bounds")
114
- # pprint.pprint(self._upper_bounds)
115
-
116
- def determine(self):
117
- solution = {}
118
-
119
- for v in self._lower_bounds:
120
- if isinstance(v, TypeVariable) and not isinstance(v, DerivedTypeVariable):
121
- lb = self._lower_bounds[v]
122
- if isinstance(lb, BottomType):
123
- # use its upper bound instead
124
- solution[v] = self._upper_bounds[v]
125
- else:
126
- solution[v] = lb
407
+ """
408
+ Steps:
127
409
 
128
- for v in self._upper_bounds:
129
- if v not in solution:
130
- ub = self._upper_bounds[v]
131
- if not isinstance(ub, TopType):
132
- solution[v] = ub
410
+ For each type variable,
411
+ - Infer the shape in its sketch
412
+ - Build the constraint graph
413
+ - Collect all constraints
414
+ - Apply constraints to derive the lower and upper bounds
415
+ """
133
416
 
134
- for v, e in self._equivalence.items():
135
- if v not in solution:
136
- solution[v] = solution.get(e, None)
417
+ typevars = set(self._constraints) | self._typevars
418
+ constraints = set()
419
+ for tv in typevars:
420
+ if tv in self._constraints:
421
+ constraints |= self._constraints[tv]
422
+ equivalence_classes, sketches = self.infer_shapes(typevars, constraints)
423
+ # TODO: Handle global variables
137
424
 
138
- # import pprint
139
- # print("Lower bounds")
140
- # pprint.pprint(self._lower_bounds)
141
- # print("Upper bounds")
142
- # pprint.pprint(self._upper_bounds)
143
- # print("Solution")
144
- # pprint.pprint(solution)
145
- return solution
425
+ type_schemes = constraints
426
+
427
+ for tv in typevars:
428
+ primitive_constraints = self._generate_primitive_constraints(type_schemes, {tv})
429
+ for primitive_constraint in primitive_constraints:
430
+ sketches[tv].add_constraint(primitive_constraint)
431
+
432
+ return equivalence_classes, sketches, type_schemes
433
+
434
+ def infer_shapes(
435
+ self, typevars: Set[TypeVariable], constraints: Set[TypeConstraint]
436
+ ) -> Tuple[Dict, Dict[TypeVariable, Sketch]]:
437
+ """
438
+ Computing sketches from constraint sets. Implements Algorithm E.1 in the retypd paper.
439
+ """
146
440
 
147
- def _handle_equivalence(self):
441
+ equivalence_classes, quotient_graph = self.compute_quotient_graph(constraints)
442
+
443
+ sketches: Dict[TypeVariable, Sketch] = {}
444
+ for tv in typevars:
445
+ sketches[tv] = Sketch(self, tv)
446
+
447
+ for tv, sketch in sketches.items():
448
+ sketch_node = sketch.lookup(tv)
449
+ graph_node = equivalence_classes.get(tv, None)
450
+ # assert graph_node is not None
451
+ if graph_node is None:
452
+ continue
453
+ visited = {graph_node: sketch_node}
454
+ self._get_all_paths(quotient_graph, sketch, graph_node, visited)
455
+ return equivalence_classes, sketches
456
+
457
+ def compute_quotient_graph(self, constraints: Set[TypeConstraint]):
458
+ """
459
+ Compute the quotient graph (the constraint graph modulo ~ in Algorithm E.1 in the retypd paper) with respect to
460
+ a given set of type constraints.
461
+ """
462
+
463
+ g = networkx.DiGraph()
464
+ # collect all derived type variables
465
+ typevars = self._typevars_from_constraints(constraints)
466
+ g.add_nodes_from(typevars)
467
+ # add paths for each derived type variable into the graph
468
+ for tv in typevars:
469
+ last_node = tv
470
+ prefix = tv
471
+ while isinstance(prefix, DerivedTypeVariable) and prefix.labels:
472
+ prefix = prefix.longest_prefix()
473
+ if prefix is None:
474
+ continue
475
+ g.add_edge(prefix, last_node, label=last_node.labels[-1])
476
+ last_node = prefix
477
+
478
+ # compute the constraint graph modulo ~
479
+ equivalence_classes = {node: node for node in g}
480
+
481
+ load = Load()
482
+ store = Store()
483
+ for node in g.nodes:
484
+ lbl_to_node = {}
485
+ for succ in g.successors(node):
486
+ lbl_to_node[succ.labels[-1]] = succ
487
+ if load in lbl_to_node and store in lbl_to_node:
488
+ self._unify(equivalence_classes, lbl_to_node[load], lbl_to_node[store], g)
489
+
490
+ for constraint in constraints:
491
+ if isinstance(constraint, Subtype):
492
+ if self._typevar_inside_set(constraint.super_type, PRIMITIVE_TYPES) or self._typevar_inside_set(
493
+ constraint.sub_type, PRIMITIVE_TYPES
494
+ ):
495
+ continue
496
+ self._unify(equivalence_classes, constraint.super_type, constraint.sub_type, g)
497
+
498
+ out_graph = networkx.MultiDiGraph() # there can be multiple edges between two nodes, each edge is associated
499
+ # with a different label
500
+ for src, dst, data in g.edges(data=True):
501
+ src_cls = equivalence_classes[src]
502
+ dst_cls = equivalence_classes[dst]
503
+ label = None if not data else data["label"]
504
+ if label is not None and out_graph.has_edge(src_cls, dst_cls):
505
+ # do not add the same edge twice
506
+ existing_labels = {
507
+ data_["label"]
508
+ for _, dst_cls_, data_ in out_graph.out_edges(src_cls, data=True)
509
+ if dst_cls_ == dst_cls and data
510
+ }
511
+ if label in existing_labels:
512
+ continue
513
+ out_graph.add_edge(src_cls, dst_cls, label=label)
514
+
515
+ return equivalence_classes, out_graph
516
+
517
+ def _generate_primitive_constraints(
518
+ self, constraints: Set[TypeConstraint], non_primitive_endpoints: Set[Union[TypeVariable, DerivedTypeVariable]]
519
+ ) -> Set[TypeConstraint]:
520
+ # FIXME: Extract interesting variables
521
+ constraint_graph = self._generate_constraint_graph(constraints, non_primitive_endpoints | PRIMITIVE_TYPES)
522
+ constraints_0 = self._solve_constraints_between(constraint_graph, non_primitive_endpoints, PRIMITIVE_TYPES)
523
+ constraints_1 = self._solve_constraints_between(constraint_graph, PRIMITIVE_TYPES, non_primitive_endpoints)
524
+ return constraints_0 | constraints_1
525
+
526
+ @staticmethod
527
+ def _typevars_from_constraints(constraints: Set[TypeConstraint]) -> Set[Union[TypeVariable, DerivedTypeVariable]]:
528
+ """
529
+ Collect derived type variables from a set of constraints.
530
+ """
531
+
532
+ typevars: Set[Union[TypeVariable, DerivedTypeVariable]] = set()
533
+ for constraint in constraints:
534
+ if isinstance(constraint, Subtype):
535
+ typevars.add(constraint.sub_type)
536
+ typevars.add(constraint.super_type)
537
+ # TODO: Other types of constraints?
538
+ return typevars
539
+
540
+ @staticmethod
541
+ def _get_all_paths(
542
+ graph: networkx.DiGraph,
543
+ sketch: Sketch,
544
+ node: DerivedTypeVariable,
545
+ visited: Dict[Union[TypeVariable, DerivedTypeVariable], SketchNode],
546
+ ):
547
+ if node not in graph:
548
+ return
549
+ curr_node = visited[node]
550
+ for _, succ, data in graph.out_edges(node, data=True):
551
+ label = data["label"]
552
+ if succ not in visited:
553
+ if isinstance(curr_node.typevar, DerivedTypeVariable):
554
+ base_typevar = curr_node.typevar.type_var
555
+ labels = curr_node.typevar.labels
556
+ elif isinstance(curr_node.typevar, TypeVariable):
557
+ base_typevar = curr_node.typevar
558
+ labels = ()
559
+ else:
560
+ raise TypeError("Unexpected")
561
+ labels += (label,)
562
+ succ_derived_typevar = DerivedTypeVariable(
563
+ base_typevar,
564
+ None,
565
+ labels=labels,
566
+ )
567
+ succ_node = SketchNode(succ_derived_typevar)
568
+ sketch.add_edge(curr_node, succ_node, label)
569
+ visited[succ] = succ_node
570
+ SimpleSolver._get_all_paths(graph, sketch, succ, visited)
571
+ del visited[succ]
572
+ else:
573
+ # a cycle exists
574
+ ref_node = RecursiveRefNode(visited[succ].typevar)
575
+ sketch.add_edge(curr_node, ref_node, label)
576
+
577
+ @staticmethod
578
+ def _unify(
579
+ equivalence_classes: Dict, cls0: DerivedTypeVariable, cls1: DerivedTypeVariable, graph: networkx.DiGraph
580
+ ) -> None:
581
+ # first convert cls0 and cls1 to their equivalence classes
582
+ cls0 = equivalence_classes[cls0]
583
+ cls1 = equivalence_classes[cls1]
584
+
585
+ # unify if needed
586
+ if cls0 != cls1:
587
+ # MakeEquiv
588
+ existing_elements = {key for key, item in equivalence_classes.items() if item in {cls0, cls1}}
589
+ rep_cls = cls0
590
+ for elem in existing_elements:
591
+ equivalence_classes[elem] = rep_cls
592
+ # the logic below refers to the retypd reference implementation. it is different from Algorithm E.1
593
+ # note that graph is used read-only in this method, so we do not need to make copy of edges
594
+ for _, dst0, data0 in graph.out_edges(cls0, data=True):
595
+ if "label" in data0 and data0["label"] is not None:
596
+ for _, dst1, data1 in graph.out_edges(cls1, data=True):
597
+ if (
598
+ data0["label"] == data1["label"]
599
+ or isinstance(data0["label"], Load)
600
+ and isinstance(data1["label"], Store)
601
+ ):
602
+ SimpleSolver._unify(
603
+ equivalence_classes, equivalence_classes[dst0], equivalence_classes[dst1], graph
604
+ )
605
+
606
+ def _eq_constraints_from_add(self, typevar: TypeVariable):
607
+ """
608
+ Handle Add constraints.
609
+ """
610
+ new_constraints = set()
611
+ for constraint in self._constraints[typevar]:
612
+ if isinstance(constraint, Add):
613
+ if (
614
+ isinstance(constraint.type_0, TypeVariable)
615
+ and not isinstance(constraint.type_0, DerivedTypeVariable)
616
+ and isinstance(constraint.type_r, TypeVariable)
617
+ and not isinstance(constraint.type_r, DerivedTypeVariable)
618
+ ):
619
+ new_constraints.add(Equivalence(constraint.type_0, constraint.type_r))
620
+ if (
621
+ isinstance(constraint.type_1, TypeVariable)
622
+ and not isinstance(constraint.type_1, DerivedTypeVariable)
623
+ and isinstance(constraint.type_r, TypeVariable)
624
+ and not isinstance(constraint.type_r, DerivedTypeVariable)
625
+ ):
626
+ new_constraints.add(Equivalence(constraint.type_1, constraint.type_r))
627
+ return new_constraints
628
+
629
+ def _handle_equivalence(self, typevar: TypeVariable):
148
630
  graph = networkx.Graph()
149
631
 
150
632
  replacements = {}
151
633
  constraints = set()
152
634
 
153
635
  # collect equivalence relations
154
- for constraint in self._constraints:
636
+ for constraint in self._constraints[typevar]:
155
637
  if isinstance(constraint, Equivalence):
156
638
  # | type_a == type_b
157
639
  # we apply unification and removes one of them
@@ -173,7 +655,7 @@ class SimpleSolver:
173
655
  replacements[tv] = representative
174
656
 
175
657
  # replace
176
- for constraint in self._constraints:
658
+ for constraint in self._constraints[typevar]:
177
659
  if isinstance(constraint, Existence):
178
660
  replaced, new_constraint = constraint.replace(replacements)
179
661
 
@@ -201,248 +683,14 @@ class SimpleSolver:
201
683
  self._equivalence = replacements
202
684
  return constraints
203
685
 
204
- def _eq_constraints_from_add(self):
205
- """
206
- Handle Add constraints.
207
- """
208
- new_constraints = set()
209
- for constraint in self._constraints:
210
- if isinstance(constraint, Add):
211
- if (
212
- isinstance(constraint.type_0, TypeVariable)
213
- and not isinstance(constraint.type_0, DerivedTypeVariable)
214
- and isinstance(constraint.type_r, TypeVariable)
215
- and not isinstance(constraint.type_r, DerivedTypeVariable)
216
- ):
217
- new_constraints.add(Equivalence(constraint.type_0, constraint.type_r))
218
- if (
219
- isinstance(constraint.type_1, TypeVariable)
220
- and not isinstance(constraint.type_1, DerivedTypeVariable)
221
- and isinstance(constraint.type_r, TypeVariable)
222
- and not isinstance(constraint.type_r, DerivedTypeVariable)
223
- ):
224
- new_constraints.add(Equivalence(constraint.type_1, constraint.type_r))
225
- return new_constraints
226
-
227
- def _pointer_class(self) -> Union[Type[Pointer32], Type[Pointer64]]:
228
- if self.bits == 32:
229
- return Pointer32
230
- elif self.bits == 64:
231
- return Pointer64
232
- raise NotImplementedError("Unsupported bits %d" % self.bits)
233
-
234
- def _calculate_closure(self, constraints):
235
- ptr_class = self._pointer_class()
236
-
237
- # a mapping from type variables to all the variables which are {super,sub}types of them
238
- subtypevars = defaultdict(set) # {k: {v}}: v <: k
239
- supertypevars = defaultdict(set) # {k: {v}}: k <: v
240
-
241
- constraints = set(constraints) # make a copy
242
-
243
- while constraints:
244
- constraint = constraints.pop()
245
-
246
- if isinstance(constraint, Existence):
247
- # has a derived type
248
- if isinstance(constraint.type_, DerivedTypeVariable):
249
- # handle label
250
- if isinstance(constraint.type_.label, HasField):
251
- # the original variable is a pointer
252
- v = constraint.type_.type_var.type_var
253
- if isinstance(v, TypeVariable):
254
- subtypevars[v].add(
255
- ptr_class(
256
- Struct(
257
- fields={
258
- constraint.type_.label.offset: int_type(constraint.type_.label.bits),
259
- }
260
- )
261
- )
262
- )
263
-
264
- elif isinstance(constraint, Subtype):
265
- # subtype <: supertype
266
-
267
- subtype, supertype = constraint.sub_type, constraint.super_type
268
-
269
- if isinstance(supertype, TypeVariable):
270
- if subtype not in subtypevars[supertype]:
271
- if supertype is not subtype:
272
- subtypevars[supertype].add(subtype)
273
- for s in supertypevars[subtype]:
274
- # re-add impacted constraints
275
- constraints.add(Subtype(subtype, s))
276
-
277
- if subtype in subtypevars:
278
- for v in subtypevars[subtype]:
279
- if v not in subtypevars[supertype]:
280
- if supertype is not v:
281
- subtypevars[supertype].add(v)
282
- for sup in supertypevars[v]:
283
- constraints.add(Subtype(subtype, sup))
284
-
285
- if isinstance(subtype, TypeVariable):
286
- if supertype not in supertypevars[subtype]:
287
- if subtype is not supertype:
288
- supertypevars[subtype].add(supertype)
289
- for s in subtypevars[supertype]:
290
- # re-add impacted constraints
291
- constraints.add(Subtype(s, supertype))
292
-
293
- if supertype in supertypevars:
294
- for v in supertypevars[supertype]:
295
- if v not in supertypevars[subtype]:
296
- if v is not subtype:
297
- supertypevars[subtype].add(v)
298
- for sup in supertypevars[v]:
299
- constraints.add(Subtype(subtype, sup))
300
-
301
- elif isinstance(constraint, Equivalence):
302
- raise Exception("Shouldn't exist anymore.")
303
-
304
- else:
305
- raise NotImplementedError("Unsupported instance type %s." % type(constraint))
306
-
307
- # import pprint
308
- # print("Subtype vars")
309
- # pprint.pprint(subtypevars)
310
- # print("Supertype vars")
311
- # pprint.pprint(supertypevars)
312
-
313
- return subtypevars, supertypevars
314
-
315
- def _find_recursive_types(self, subtypevars):
316
- ptr_class = self._pointer_class()
317
-
318
- for var in list(subtypevars.keys()):
319
- sts = subtypevars[var].copy()
320
- if isinstance(var, DerivedTypeVariable) and isinstance(var.label, HasField):
321
- for subtype_var in sts:
322
- if var.type_var.type_var == subtype_var:
323
- subtypevars[subtype_var].add(
324
- ptr_class(Struct({var.label.offset: TypeVariableReference(subtype_var)}))
325
- )
326
- self._recursive_types[subtype_var].add(var.label.offset)
327
-
328
- def _get_lower_bound(self, v):
329
- if isinstance(v, TypeConstant):
330
- return v
331
- return self._lower_bounds[v]
332
-
333
- def _get_upper_bound(self, v):
334
- if isinstance(v, TypeConstant):
335
- return v
336
- if v in self._upper_bounds:
337
- return self._upper_bounds[v]
338
-
339
- # try to compute it
340
- if isinstance(v, DerivedTypeVariable):
341
- if isinstance(v.label, ConvertTo):
342
- # after integer conversion,
343
- ub = int_type(v.label.to_bits)
344
- if ub is not None:
345
- self._upper_bounds[v] = ub
346
- elif isinstance(v.label, HasField):
347
- ub = int_type(v.label.bits)
348
- if ub is not None:
349
- self._upper_bounds[v] = ub
350
-
351
- # if all that failed, let the defaultdict generate a Top
352
- return self._upper_bounds[v]
353
-
354
- def _compute_lower_upper_bounds(self, subtypevars, supertypevars):
355
- # compute the least upper bound for each type variable
356
- for typevar, upper_bounds in supertypevars.items():
357
- if typevar is None:
358
- continue
359
- if isinstance(typevar, TypeConstant):
360
- continue
361
- self._upper_bounds[typevar] = self._meet(typevar, *upper_bounds, translate=self._get_upper_bound)
362
-
363
- # compute the greatest lower bound for each type variable
364
- seen = set() # loop avoidance
365
- queue = list(subtypevars)
366
- while queue:
367
- typevar = queue.pop(0)
368
- lower_bounds = subtypevars[typevar]
369
-
370
- if typevar not in seen:
371
- # we detect if it depends on any other typevar upon the first encounter
372
- seen.add(typevar)
373
-
374
- abort = False
375
- for subtypevar in lower_bounds:
376
- if isinstance(subtypevar, TypeVariable) and subtypevar not in self._lower_bounds:
377
- # oops - we should analyze the subtypevar first
378
- queue.append(typevar)
379
- # to avoid loops, make sure typevar does not rely on
380
- abort = True
381
- break
382
- if abort:
383
- continue
384
- else:
385
- # avoid loop and continue no matter what
386
- pass
387
-
388
- self._lower_bounds[typevar] = self._join(typevar, *lower_bounds, translate=self._get_lower_bound)
389
-
390
- # because of T-InheritR, fields are propagated *both ways* in a subtype relation
391
- for subtypevar in lower_bounds:
392
- if not isinstance(subtypevar, TypeVariable):
393
- continue
394
- subtype_infimum = self._lower_bounds[subtypevar]
395
- if isinstance(subtype_infimum, Pointer) and isinstance(subtype_infimum.basetype, Struct):
396
- subtype_infimum = self._join(subtypevar, typevar, translate=self._get_lower_bound)
397
- self._lower_bounds[subtypevar] = subtype_infimum
398
-
399
- def _lower_struct_fields(self):
400
- # tv_680: ptr32(struct{0: int32})
401
- # tv_680.load.<32>@0: ptr32(struct{5: int8})
402
- # becomes
403
- # tv_680: ptr32(struct{0: ptr32(struct{5: int8})})
404
-
405
- for outer, outer_lb in self._lower_bounds.items():
406
- if (
407
- isinstance(outer, DerivedTypeVariable)
408
- and isinstance(outer.label, HasField)
409
- and not isinstance(outer_lb, BottomType)
410
- ):
411
- # unpack v
412
- base = outer.type_var.type_var
413
-
414
- if base in self._lower_bounds:
415
- base_lb = self._lower_bounds[base]
416
-
417
- # make sure it's a pointer at the offset that v.label specifies
418
- if isinstance(base_lb, Pointer):
419
- if isinstance(base_lb.basetype, Struct):
420
- the_field = base_lb.basetype.fields[outer.label.offset]
421
- # replace this field
422
- new_field = self._meet(the_field, outer_lb, translate=self._get_upper_bound)
423
- if new_field != the_field:
424
- new_fields = base_lb.basetype.fields.copy()
425
- new_fields.update(
426
- {
427
- outer.label.offset: new_field,
428
- }
429
- )
430
- base_lb = base_lb.__class__(Struct(new_fields))
431
- self._lower_bounds[base] = base_lb
432
-
433
- # another attempt: if a pointer to a struct has only one field, remove the struct
434
- if len(base_lb.basetype.fields) == 1 and 0 in base_lb.basetype.fields:
435
- base_lb = base_lb.__class__(base_lb.basetype.fields[0])
436
- self._lower_bounds[base] = base_lb
437
-
438
686
  def _convert_arrays(self, constraints):
439
687
  for constraint in constraints:
440
688
  if not isinstance(constraint, Existence):
441
689
  continue
442
690
  inner = constraint.type_
443
- if isinstance(inner, DerivedTypeVariable) and isinstance(inner.label, IsArray):
444
- if inner.type_var in self._lower_bounds:
445
- curr_type = self._lower_bounds[inner.type_var]
691
+ if isinstance(inner, DerivedTypeVariable) and isinstance(inner.one_label(), IsArray):
692
+ if inner.type_var in self.solution:
693
+ curr_type = self.solution[inner.type_var]
446
694
  if isinstance(curr_type, Pointer) and isinstance(curr_type.basetype, Struct):
447
695
  # replace all fields with the first field
448
696
  if 0 in curr_type.basetype.fields:
@@ -450,202 +698,494 @@ class SimpleSolver:
450
698
  for offset in curr_type.basetype.fields.keys():
451
699
  curr_type.basetype.fields[offset] = first_field
452
700
 
453
- def _abstract(self, t): # pylint:disable=no-self-use
454
- return t.__class__
701
+ #
702
+ # Constraint graph
703
+ #
455
704
 
456
- def _concretize(self, n_cls, t1, t2, join_or_meet, translate):
457
- ptr_class = self._pointer_class()
705
+ def _generate_constraint_graph(
706
+ self, constraints: Set[TypeConstraint], interesting_variables: Set[DerivedTypeVariable]
707
+ ) -> networkx.DiGraph:
708
+ """
709
+ A constraint graph is the same as the finite state transducer that is presented in Appendix D in the retypd
710
+ paper.
711
+ """
458
712
 
459
- if n_cls is ptr_class:
460
- if isinstance(t1, ptr_class) and isinstance(t2, ptr_class):
461
- # we need to merge them
462
- return ptr_class(join_or_meet(t1.basetype, t2.basetype, translate=translate))
463
- if isinstance(t1, ptr_class):
713
+ graph = networkx.DiGraph()
714
+ for constraint in constraints:
715
+ if isinstance(constraint, Subtype):
716
+ self._constraint_graph_add_edges(
717
+ graph, constraint.sub_type, constraint.super_type, interesting_variables
718
+ )
719
+ self._constraint_graph_saturate(graph)
720
+ self._constraint_graph_remove_self_loops(graph)
721
+ self._constraint_graph_recall_forget_split(graph)
722
+ return graph
723
+
724
+ @staticmethod
725
+ def _constraint_graph_add_recall_edges(graph: networkx.DiGraph, node: ConstraintGraphNode) -> None:
726
+ while True:
727
+ r = node.forget_last_label()
728
+ if r is None:
729
+ break
730
+ prefix, last_label = r
731
+ graph.add_edge(prefix, node, label=(last_label, "recall"))
732
+ node = prefix
733
+
734
+ @staticmethod
735
+ def _constraint_graph_add_forget_edges(graph: networkx.DiGraph, node: ConstraintGraphNode) -> None:
736
+ while True:
737
+ r = node.forget_last_label()
738
+ if r is None:
739
+ break
740
+ prefix, last_label = r
741
+ graph.add_edge(node, prefix, label=(last_label, "forget"))
742
+ node = prefix
743
+
744
+ def _constraint_graph_add_edges(
745
+ self,
746
+ graph: networkx.DiGraph,
747
+ subtype: Union[TypeVariable, DerivedTypeVariable],
748
+ supertype: Union[TypeVariable, DerivedTypeVariable],
749
+ interesting_variables: Set[DerivedTypeVariable],
750
+ ):
751
+ # left and right tags
752
+ if self._typevar_inside_set(self._to_typevar_or_typeconst(subtype), interesting_variables):
753
+ left_tag = ConstraintGraphTag.LEFT
754
+ else:
755
+ left_tag = ConstraintGraphTag.UNKNOWN
756
+ if self._typevar_inside_set(self._to_typevar_or_typeconst(supertype), interesting_variables):
757
+ right_tag = ConstraintGraphTag.RIGHT
758
+ else:
759
+ right_tag = ConstraintGraphTag.UNKNOWN
760
+ # nodes
761
+ forward_src = ConstraintGraphNode(subtype, Variance.COVARIANT, left_tag, FORGOTTEN.PRE_FORGOTTEN)
762
+ forward_dst = ConstraintGraphNode(supertype, Variance.COVARIANT, right_tag, FORGOTTEN.PRE_FORGOTTEN)
763
+ graph.add_edge(forward_src, forward_dst)
764
+ # add recall edges and forget edges
765
+ self._constraint_graph_add_recall_edges(graph, forward_src)
766
+ self._constraint_graph_add_forget_edges(graph, forward_dst)
767
+
768
+ # backward edges
769
+ backward_src = forward_dst.inverse()
770
+ backward_dst = forward_src.inverse()
771
+ graph.add_edge(backward_src, backward_dst)
772
+ self._constraint_graph_add_recall_edges(graph, backward_src)
773
+ self._constraint_graph_add_forget_edges(graph, backward_dst)
774
+
775
+ @staticmethod
776
+ def _constraint_graph_saturate(graph: networkx.DiGraph) -> None:
777
+ """
778
+ The saturation algorithm D.2 as described in Appendix of the retypd paper.
779
+ """
780
+ R: DefaultDict[ConstraintGraphNode, Set[Tuple[BaseLabel, ConstraintGraphNode]]] = defaultdict(set)
781
+
782
+ # initialize the reaching-push sets R(x)
783
+ for x, y, data in graph.edges(data=True):
784
+ if "label" in data and data.get("label")[1] == "forget":
785
+ d = data["label"][0], x
786
+ R[y].add(d)
787
+
788
+ # repeat ... until fixed point
789
+ changed = True
790
+ while changed:
791
+ changed = False
792
+ for x, y, data in graph.edges(data=True):
793
+ if "label" not in data:
794
+ if R[y].issuperset(R[x]):
795
+ continue
796
+ changed = True
797
+ R[y] |= R[x]
798
+ for x, y, data in graph.edges(data=True):
799
+ lbl = data.get("label")
800
+ if lbl and lbl[1] == "recall":
801
+ for label, z in R[x]:
802
+ if not graph.has_edge(z, y):
803
+ changed = True
804
+ graph.add_edge(z, y)
805
+ v_contravariant = []
806
+ for node in graph.nodes:
807
+ node: ConstraintGraphNode
808
+ if node.variance == Variance.CONTRAVARIANT:
809
+ v_contravariant.append(node)
810
+ # lazily apply saturation rules corresponding to S-Pointer
811
+ for x in v_contravariant:
812
+ for z_label, z in R[x]:
813
+ label = None
814
+ if isinstance(z_label, Store):
815
+ label = Load()
816
+ elif isinstance(z_label, Load):
817
+ label = Store()
818
+ if label is not None:
819
+ x_inverse = x.inverse_wo_tag()
820
+ d = label, z
821
+ if d not in R[x_inverse]:
822
+ changed = True
823
+ R[x_inverse].add(d)
824
+
825
+ @staticmethod
826
+ def _constraint_graph_remove_self_loops(graph: networkx.DiGraph):
827
+ for node in list(graph.nodes):
828
+ if graph.has_edge(node, node):
829
+ graph.remove_edge(node, node)
830
+
831
+ @staticmethod
832
+ def _constraint_graph_recall_forget_split(graph: networkx.DiGraph):
833
+ """
834
+ Ensure that recall edges are not reachable after traversing a forget node.
835
+ """
836
+ for src, dst, data in list(graph.edges(data=True)):
837
+ src: ConstraintGraphNode
838
+ dst: ConstraintGraphNode
839
+ if "label" in data and data["label"][1] == "recall":
840
+ continue
841
+ forget_src = ConstraintGraphNode(src.typevar, src.variance, src.tag, FORGOTTEN.POST_FORGOTTEN)
842
+ forget_dst = ConstraintGraphNode(dst.typevar, dst.variance, dst.tag, FORGOTTEN.POST_FORGOTTEN)
843
+ if "label" in data and data["label"][1] == "forget":
844
+ graph.remove_edge(src, dst)
845
+ graph.add_edge(src, forget_dst, **data)
846
+ graph.add_edge(forget_src, forget_dst, **data)
847
+
848
+ @staticmethod
849
+ def _to_typevar_or_typeconst(
850
+ obj: Union[TypeVariable, DerivedTypeVariable, TypeConstant]
851
+ ) -> Union[TypeVariable, TypeConstant]:
852
+ if isinstance(obj, DerivedTypeVariable):
853
+ return SimpleSolver._to_typevar_or_typeconst(obj.type_var)
854
+ elif isinstance(obj, TypeVariable):
855
+ return obj
856
+ elif isinstance(obj, TypeConstant):
857
+ return obj
858
+ raise TypeError(f"Unsupported type {type(obj)}")
859
+
860
+ #
861
+ # Graph solver
862
+ #
863
+
864
+ @staticmethod
865
+ def _typevar_inside_set(typevar, typevar_set: Set[Union[TypeConstant, TypeVariable, DerivedTypeVariable]]) -> bool:
866
+ if typevar in typevar_set:
867
+ return True
868
+ if isinstance(typevar, Struct) and Struct_ in typevar_set:
869
+ if not typevar.fields:
870
+ return True
871
+ return all(
872
+ SimpleSolver._typevar_inside_set(field_typevar, typevar_set)
873
+ for field_typevar in typevar.fields.values()
874
+ )
875
+ if isinstance(typevar, Array) and Array_ in typevar_set:
876
+ return SimpleSolver._typevar_inside_set(typevar.element, typevar_set)
877
+ if isinstance(typevar, Pointer) and (Pointer32_ in typevar_set or Pointer64_ in typevar_set):
878
+ return SimpleSolver._typevar_inside_set(typevar.basetype, typevar_set)
879
+ return False
880
+
881
+ def _solve_constraints_between(
882
+ self,
883
+ graph: networkx.DiGraph,
884
+ starts: Set[Union[TypeConstant, TypeVariable, DerivedTypeVariable]],
885
+ ends: Set[Union[TypeConstant, TypeVariable, DerivedTypeVariable]],
886
+ ) -> Set[TypeConstraint]:
887
+ start_nodes = set()
888
+ end_nodes = set()
889
+ for node in graph.nodes:
890
+ node: ConstraintGraphNode
891
+ if (
892
+ self._typevar_inside_set(self._to_typevar_or_typeconst(node.typevar), starts)
893
+ and node.tag == ConstraintGraphTag.LEFT
894
+ ):
895
+ start_nodes.add(node)
896
+ if (
897
+ self._typevar_inside_set(self._to_typevar_or_typeconst(node.typevar), ends)
898
+ and node.tag == ConstraintGraphTag.RIGHT
899
+ ):
900
+ end_nodes.add(node)
901
+
902
+ if not start_nodes or not end_nodes:
903
+ return set()
904
+
905
+ dfa_solver = DFAConstraintSolver()
906
+ try:
907
+ return dfa_solver.generate_constraints_between(graph, start_nodes, end_nodes)
908
+ except EmptyEpsilonNFAError:
909
+ return set()
910
+
911
+ #
912
+ # Type lattice
913
+ #
914
+
915
+ def join(self, t1: Union[TypeConstant, TypeVariable], t2: Union[TypeConstant, TypeVariable]) -> TypeConstant:
916
+ abstract_t1 = self.abstract(t1)
917
+ abstract_t2 = self.abstract(t2)
918
+ if abstract_t1 in self._base_lattice and abstract_t2 in self._base_lattice:
919
+ ancestor = networkx.lowest_common_ancestor(self._base_lattice, abstract_t1, abstract_t2)
920
+ if ancestor == abstract_t1:
464
921
  return t1
465
- elif isinstance(t2, ptr_class):
922
+ elif ancestor == abstract_t2:
466
923
  return t2
467
924
  else:
468
- # huh?
469
- return ptr_class(BottomType())
470
-
471
- return n_cls()
472
-
473
- def _join(self, *args, translate: Callable):
474
- """
475
- Get the least upper bound (V, maximum) of the arguments.
476
- """
477
-
478
- if len(args) == 0:
479
- return BottomType()
480
- if len(args) == 1:
481
- return translate(args[0])
482
- if len(args) > 2:
483
- split = len(args) // 2
484
- first = self._join(*args[:split], translate=translate)
485
- second = self._join(*args[split:], translate=translate)
486
- return self._join(first, second, translate=translate)
487
-
488
- t1 = translate(args[0])
489
- t2 = translate(args[1])
490
-
491
- # Trivial cases
492
- if t1 == t2:
493
- return t1
494
- if isinstance(t1, TopType):
495
- return t1
496
- elif isinstance(t2, TopType):
925
+ return ancestor
926
+ if t1 == Bottom_:
497
927
  return t2
498
- if isinstance(t1, BottomType):
499
- return t2
500
- elif isinstance(t2, BottomType):
501
- return t1
502
- if isinstance(t1, TypeVariableReference) and not isinstance(t2, TypeVariableReference):
928
+ if t2 == Bottom_:
503
929
  return t1
504
- elif isinstance(t2, TypeVariableReference) and not isinstance(t1, TypeVariableReference):
505
- return t2
506
-
507
- # consult the graph
508
- t1_cls = self._abstract(t1)
509
- t2_cls = self._abstract(t2)
510
-
511
- if t1_cls in self._base_lattice and t2_cls in self._base_lattice:
512
- queue = [t1_cls]
513
- while queue:
514
- n = queue[0]
515
- queue = queue[1:]
516
-
517
- if networkx.has_path(self._base_lattice, n, t2_cls):
518
- return self._concretize(n, t1, t2, self._join, translate)
519
- # go up
520
- queue.extend(self._base_lattice.predecessors(n))
521
-
522
- # handling Struct
523
- if t1_cls is Struct and t2_cls is Struct:
524
- fields = {}
525
- for offset in sorted(set(itertools.chain(t1.fields.keys(), t2.fields.keys()))):
526
- if offset in t1.fields and offset in t2.fields:
527
- v = self._join(t1.fields[offset], t2.fields[offset], translate=translate)
528
- elif offset in t1.fields:
529
- v = t1.fields[offset]
530
- elif offset in t2.fields:
531
- v = t2.fields[offset]
532
- else:
533
- raise Exception("Impossible")
534
- fields[offset] = v
535
- return Struct(fields=fields)
536
-
537
- # single element and single-element struct
538
- if issubclass(t2_cls, Int) and t1_cls is Struct:
539
- # swap them
540
- t1, t1_cls, t2, t2_cls = t2, t2_cls, t1, t1_cls
541
- if issubclass(t1_cls, Int) and t2_cls is Struct and len(t2.fields) == 1 and 0 in t2.fields:
542
- # e.g., char & struct {0: char}
543
- return Struct(fields={0: self._join(t1, t2.fields[0], translate=translate)})
544
-
545
- ptr_class = self._pointer_class()
546
-
547
- # Struct and Pointers
548
- if t1_cls is ptr_class and t2_cls is Struct:
549
- # swap them
550
- t1, t1_cls, t2, t2_cls = t2, t2_cls, t1, t1_cls
551
- if t1_cls is Struct and len(t1.fields) == 1 and 0 in t1.fields:
552
- if t1.fields[0].size == 8 and t2_cls is Pointer64:
553
- # they are equivalent
554
- # e.g., struct{0: int64} ptr64(int8)
555
- # return t2 since t2 is more specific
556
- return t2
557
- elif t1.fields[0].size == 4 and t2_cls is Pointer32:
930
+ return Bottom_
931
+
932
+ def meet(self, t1: Union[TypeConstant, TypeVariable], t2: Union[TypeConstant, TypeVariable]) -> TypeConstant:
933
+ abstract_t1 = self.abstract(t1)
934
+ abstract_t2 = self.abstract(t2)
935
+ if abstract_t1 in self._base_lattice_inverted and abstract_t2 in self._base_lattice_inverted:
936
+ ancestor = networkx.lowest_common_ancestor(self._base_lattice_inverted, abstract_t1, abstract_t2)
937
+ if ancestor == abstract_t1:
938
+ return t1
939
+ elif ancestor == abstract_t2:
558
940
  return t2
941
+ else:
942
+ return ancestor
943
+ if t1 == Top_:
944
+ return t2
945
+ if t2 == Top_:
946
+ return t1
947
+ return Top_
948
+
949
+ def abstract(self, t: Union[TypeConstant, TypeVariable]) -> Union[TypeConstant, TypeVariable]:
950
+ if isinstance(t, Pointer32):
951
+ return Pointer32()
952
+ elif isinstance(t, Pointer64):
953
+ return Pointer64()
954
+ return t
955
+
956
+ def determine(
957
+ self,
958
+ equivalent_classes: Dict[TypeVariable, TypeVariable],
959
+ sketches,
960
+ solution: Dict,
961
+ nodes: Optional[Set[SketchNode]] = None,
962
+ ) -> None:
963
+ """
964
+ Determine C-like types from sketches.
965
+
966
+ :param equivalent_classes: A dictionary mapping each type variable from its representative in the equivalence
967
+ class over ~.
968
+ :param sketches: A dictionary storing sketches for each type variable.
969
+ :param solution: The dictionary storing C-like types for each type variable. Output.
970
+ :param nodes: Optional. Nodes that should be considered in the sketch.
971
+ :return: None
972
+ """
973
+ for typevar, sketch in sketches.items():
974
+ self._determine(equivalent_classes, typevar, sketch, solution, nodes=nodes)
559
975
 
560
- # import ipdb; ipdb.set_trace()
561
- return TopType()
976
+ for v, e in self._equivalence.items():
977
+ if v not in solution and e in solution:
978
+ solution[v] = solution[e]
562
979
 
563
- def _meet(self, *args, translate: Callable):
980
+ def _determine(
981
+ self, equivalent_classes, the_typevar, sketch, solution: Dict, nodes: Optional[Set[SketchNode]] = None
982
+ ):
564
983
  """
565
- Get the greatest lower bound (^, minimum) of the arguments.
984
+ Return the solution from sketches
566
985
  """
567
986
 
568
- if len(args) == 0:
569
- return TopType()
570
- if len(args) == 1:
571
- return translate(args[0])
572
- if len(args) > 2:
573
- split = len(args) // 2
574
- first = self._meet(*args[:split], translate=translate)
575
- second = self._meet(*args[split:], translate=translate)
576
- return self._meet(first, second, translate=translate)
577
-
578
- t1 = translate(args[0])
579
- t2 = translate(args[1])
580
-
581
- # Trivial cases
582
- if t1 == t2:
583
- return t1
584
- elif isinstance(t1, BottomType):
585
- return t1
586
- elif isinstance(t2, BottomType):
587
- return t2
588
- if isinstance(t1, TopType):
589
- return t2
590
- elif isinstance(t2, TopType):
591
- return t1
592
- if isinstance(t1, TypeVariableReference) and not isinstance(t2, TypeVariableReference):
593
- return t1
594
- elif isinstance(t2, TypeVariableReference) and not isinstance(t1, TypeVariableReference):
595
- return t2
987
+ if not nodes:
988
+ # TODO: resolve references
989
+ node = sketch.lookup(the_typevar)
990
+ assert node is not None
991
+ nodes = {node}
992
+
993
+ # consult the cache
994
+ cached_results = set()
995
+ for node in nodes:
996
+ if node.typevar in self._solution_cache:
997
+ cached_results.add(self._solution_cache[node.typevar])
998
+ if len(cached_results) == 1:
999
+ return next(iter(cached_results))
1000
+ elif len(cached_results) > 1:
1001
+ # we get nodes for multiple type variables?
1002
+ raise RuntimeError("Getting nodes for multiple type variables. Unexpected.")
1003
+
1004
+ # collect all successors and the paths (labels) of this type variable
1005
+ path_and_successors = []
1006
+ last_labels = []
1007
+ for node in nodes:
1008
+ path_and_successors += self._collect_sketch_paths(node, sketch)
1009
+ for labels, _ in path_and_successors:
1010
+ if labels:
1011
+ last_labels.append(labels[-1])
1012
+
1013
+ # now, what is this variable?
1014
+ if last_labels and all(isinstance(label, (FuncIn, FuncOut)) for label in last_labels):
1015
+ # create a dummy result and dump it to the cache
1016
+ func_type = Function([], [])
1017
+ result = self._pointer_class()(basetype=func_type)
1018
+ for node in nodes:
1019
+ self._solution_cache[node.typevar] = result
1020
+
1021
+ # this is a function variable
1022
+ func_inputs = defaultdict(set)
1023
+ func_outputs = defaultdict(set)
1024
+
1025
+ for labels, succ in path_and_successors:
1026
+ last_label = labels[-1] if labels else None
1027
+
1028
+ if isinstance(last_label, FuncIn):
1029
+ func_inputs[last_label.loc].add(succ)
1030
+ elif isinstance(last_label, FuncOut):
1031
+ func_outputs[last_label.loc].add(succ)
1032
+ else:
1033
+ raise RuntimeError("Unreachable")
1034
+
1035
+ input_args = []
1036
+ output_values = []
1037
+ for vals, out in [(func_inputs, input_args), (func_outputs, output_values)]:
1038
+ for idx in range(0, max(vals) + 1):
1039
+ if idx in vals:
1040
+ sol = self._determine(equivalent_classes, the_typevar, sketch, solution, nodes=vals[idx])
1041
+ out.append(sol)
1042
+ else:
1043
+ out.append(None)
1044
+
1045
+ # back patch
1046
+ func_type.params = input_args
1047
+ func_type.outputs = output_values
1048
+
1049
+ for node in nodes:
1050
+ solution[node.typevar] = result
1051
+
1052
+ elif not path_and_successors:
1053
+ # this is a primitive variable
1054
+ lower_bound = Bottom_
1055
+ upper_bound = Top_
1056
+
1057
+ for node in nodes:
1058
+ lower_bound = self.join(lower_bound, node.lower_bound)
1059
+ upper_bound = self.meet(upper_bound, node.upper_bound)
1060
+ # TODO: Support variables that are accessed via differently sized pointers
1061
+
1062
+ result = lower_bound if not isinstance(lower_bound, BottomType) else upper_bound
1063
+ for node in nodes:
1064
+ solution[node.typevar] = result
1065
+ self._solution_cache[node.typevar] = result
1066
+
1067
+ else:
1068
+ if len(nodes) == 1:
1069
+ the_node = next(iter(nodes))
1070
+ if (
1071
+ isinstance(the_node.upper_bound, self._pointer_class())
1072
+ and isinstance(the_node.upper_bound.basetype, Struct)
1073
+ and the_node.upper_bound.basetype.name
1074
+ ):
1075
+ # handle pointers to known struct types
1076
+ result = (
1077
+ the_node.lower_bound
1078
+ if not isinstance(the_node.lower_bound, BottomType)
1079
+ else the_node.upper_bound
1080
+ )
1081
+ for node in nodes:
1082
+ solution[node.typevar] = result
1083
+ self._solution_cache[node.typevar] = result
1084
+ return result
1085
+
1086
+ # create a dummy result and shove it into the cache
1087
+ struct_type = Struct(fields={})
1088
+ result = self._pointer_class()(struct_type)
1089
+ for node in nodes:
1090
+ self._solution_cache[node.typevar] = result
1091
+
1092
+ # this might be a struct
1093
+ fields = {}
596
1094
 
597
- # consult the graph
598
- t1_cls = self._abstract(t1)
599
- t2_cls = self._abstract(t2)
1095
+ candidate_bases = defaultdict(set)
1096
+
1097
+ for labels, succ in path_and_successors:
1098
+ last_label = labels[-1] if labels else None
1099
+ if isinstance(last_label, HasField):
1100
+ candidate_bases[last_label.offset].add(last_label.bits // 8)
1101
+
1102
+ node_to_base = {}
1103
+
1104
+ for labels, succ in path_and_successors:
1105
+ last_label = labels[-1] if labels else None
1106
+ if isinstance(last_label, HasField):
1107
+ for start_offset, sizes in candidate_bases.items():
1108
+ for size in sizes:
1109
+ if last_label.offset > start_offset:
1110
+ if last_label.offset < start_offset + size: # ???
1111
+ node_to_base[succ] = start_offset
1112
+
1113
+ node_by_offset = defaultdict(set)
1114
+
1115
+ for labels, succ in path_and_successors:
1116
+ last_label = labels[-1] if labels else None
1117
+ if isinstance(last_label, HasField):
1118
+ if succ in node_to_base:
1119
+ node_by_offset[node_to_base[succ]].add(succ)
1120
+ else:
1121
+ node_by_offset[last_label.offset].add(succ)
1122
+
1123
+ for offset, child_nodes in node_by_offset.items():
1124
+ sol = self._determine(equivalent_classes, the_typevar, sketch, solution, nodes=child_nodes)
1125
+ if isinstance(sol, TopType):
1126
+ sol = int_type(min(candidate_bases[offset]) * 8)
1127
+ fields[offset] = sol
1128
+
1129
+ if not fields:
1130
+ result = Top_
1131
+ for node in nodes:
1132
+ self._solution_cache[node.typevar] = result
1133
+ else:
1134
+ # back-patch
1135
+ struct_type.fields = fields
1136
+ for node in nodes:
1137
+ solution[node.typevar] = result
600
1138
 
601
- if t1_cls in self._base_lattice and t2_cls in self._base_lattice:
602
- queue = [t1_cls]
603
- while queue:
604
- n = queue[0]
605
- queue = queue[1:]
1139
+ # import pprint
606
1140
 
607
- if networkx.has_path(self._base_lattice, t2_cls, n):
608
- return self._concretize(n, t1, t2, self._meet, translate)
609
- # go down
610
- queue.extend(self._base_lattice.successors(n))
1141
+ # print("Solution")
1142
+ # pprint.pprint(result)
1143
+ return result
611
1144
 
612
- # handling Struct
613
- if t1_cls is Struct and t2_cls is Struct:
614
- fields = {}
615
- for offset in sorted(set(itertools.chain(t1.fields.keys(), t2.fields.keys()))):
616
- if offset in t1.fields and offset in t2.fields:
617
- v = self._meet(t1.fields[offset], t2.fields[offset], translate=translate)
618
- elif offset in t1.fields:
619
- v = t1.fields[offset]
620
- elif offset in t2.fields:
621
- v = t2.fields[offset]
1145
+ @staticmethod
1146
+ def _collect_sketch_paths(node: SketchNodeBase, sketch: Sketch) -> List[Tuple[List[BaseLabel], SketchNodeBase]]:
1147
+ """
1148
+ Collect all paths that go from `typevar` to its leaves.
1149
+ """
1150
+ paths = []
1151
+ visited: Set[SketchNodeBase] = set()
1152
+ queue: List[Tuple[List[BaseLabel], SketchNodeBase]] = [([], node)]
1153
+
1154
+ while queue:
1155
+ curr_labels, curr_node = queue.pop(0)
1156
+ if curr_node in visited:
1157
+ continue
1158
+ visited.add(curr_node)
1159
+
1160
+ out_edges = sketch.graph.out_edges(curr_node, data=True)
1161
+ for _, succ, data in out_edges:
1162
+ if isinstance(succ, RecursiveRefNode):
1163
+ ref = succ
1164
+ succ: Optional[SketchNode] = sketch.lookup(succ.target)
1165
+ if succ is None:
1166
+ # failed to resolve...
1167
+ _l.warning(
1168
+ "Failed to resolve reference node to a real sketch node for type variable %s", ref.target
1169
+ )
1170
+ continue
1171
+ label = data["label"]
1172
+ if isinstance(label, ConvertTo):
1173
+ # drop conv labels for now
1174
+ continue
1175
+ if isinstance(label, IsArray):
1176
+ continue
1177
+ new_labels = curr_labels + [label]
1178
+ succ: SketchNode
1179
+ if isinstance(succ.typevar, DerivedTypeVariable) and isinstance(succ.typevar.labels[-1], (Load, Store)):
1180
+ queue.append((new_labels, succ))
622
1181
  else:
623
- raise Exception("Impossible")
624
- fields[offset] = v
625
- return Struct(fields=fields)
626
-
627
- # single element and single-element struct
628
- if issubclass(t2_cls, Int) and t1_cls is Struct:
629
- # swap them
630
- t1, t1_cls, t2, t2_cls = t2, t2_cls, t1, t1_cls
631
- if issubclass(t1_cls, Int) and t2_cls is Struct and len(t2.fields) == 1 and 0 in t2.fields:
632
- # e.g., char & struct {0: char}
633
- return Struct(fields={0: self._meet(t1, t2.fields[0], translate=translate)})
634
-
635
- ptr_class = self._pointer_class()
636
-
637
- # Struct and Pointers
638
- if t1_cls is ptr_class and t2_cls is Struct:
639
- # swap them
640
- t1, t1_cls, t2, t2_cls = t2, t2_cls, t1, t1_cls
641
- if t1_cls is Struct and len(t1.fields) == 1 and 0 in t1.fields:
642
- if t1.fields[0].size == 8 and t2_cls is Pointer64:
643
- # they are equivalent
644
- # e.g., struct{0: int64} ptr64(int8)
645
- # return t2 since t2 is more specific
646
- return t2
647
- elif t1.fields[0].size == 4 and t2_cls is Pointer32:
648
- return t2
1182
+ paths.append((new_labels, succ))
649
1183
 
650
- # import ipdb; ipdb.set_trace()
651
- return BottomType()
1184
+ return paths
1185
+
1186
+ def _pointer_class(self) -> Union[Type[Pointer32], Type[Pointer64]]:
1187
+ if self.bits == 32:
1188
+ return Pointer32
1189
+ elif self.bits == 64:
1190
+ return Pointer64
1191
+ raise NotImplementedError("Unsupported bits %d" % self.bits)