compiled-knowledge 4.0.0a5__cp313-cp313-macosx_10_13_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of compiled-knowledge might be problematic. Click here for more details.

Files changed (167) hide show
  1. ck/__init__.py +0 -0
  2. ck/circuit/__init__.py +13 -0
  3. ck/circuit/circuit.c +38749 -0
  4. ck/circuit/circuit.cpython-313-darwin.so +0 -0
  5. ck/circuit/circuit_py.py +807 -0
  6. ck/circuit/tmp_const.py +74 -0
  7. ck/circuit_compiler/__init__.py +2 -0
  8. ck/circuit_compiler/circuit_compiler.py +26 -0
  9. ck/circuit_compiler/cython_vm_compiler/__init__.py +1 -0
  10. ck/circuit_compiler/cython_vm_compiler/_compiler.c +17373 -0
  11. ck/circuit_compiler/cython_vm_compiler/_compiler.cpython-313-darwin.so +0 -0
  12. ck/circuit_compiler/cython_vm_compiler/cython_vm_compiler.py +96 -0
  13. ck/circuit_compiler/interpret_compiler.py +223 -0
  14. ck/circuit_compiler/llvm_compiler.py +388 -0
  15. ck/circuit_compiler/llvm_vm_compiler.py +546 -0
  16. ck/circuit_compiler/named_circuit_compilers.py +57 -0
  17. ck/circuit_compiler/support/__init__.py +0 -0
  18. ck/circuit_compiler/support/circuit_analyser.py +81 -0
  19. ck/circuit_compiler/support/input_vars.py +148 -0
  20. ck/circuit_compiler/support/llvm_ir_function.py +234 -0
  21. ck/example/__init__.py +53 -0
  22. ck/example/alarm.py +366 -0
  23. ck/example/asia.py +28 -0
  24. ck/example/binary_clique.py +32 -0
  25. ck/example/bow_tie.py +33 -0
  26. ck/example/cancer.py +37 -0
  27. ck/example/chain.py +38 -0
  28. ck/example/child.py +199 -0
  29. ck/example/clique.py +33 -0
  30. ck/example/cnf_pgm.py +39 -0
  31. ck/example/diamond_square.py +68 -0
  32. ck/example/earthquake.py +36 -0
  33. ck/example/empty.py +10 -0
  34. ck/example/hailfinder.py +539 -0
  35. ck/example/hepar2.py +628 -0
  36. ck/example/insurance.py +504 -0
  37. ck/example/loop.py +40 -0
  38. ck/example/mildew.py +38161 -0
  39. ck/example/munin.py +22982 -0
  40. ck/example/pathfinder.py +53674 -0
  41. ck/example/rain.py +39 -0
  42. ck/example/rectangle.py +161 -0
  43. ck/example/run.py +30 -0
  44. ck/example/sachs.py +129 -0
  45. ck/example/sprinkler.py +30 -0
  46. ck/example/star.py +44 -0
  47. ck/example/stress.py +64 -0
  48. ck/example/student.py +43 -0
  49. ck/example/survey.py +46 -0
  50. ck/example/triangle_square.py +54 -0
  51. ck/example/truss.py +49 -0
  52. ck/in_out/__init__.py +3 -0
  53. ck/in_out/parse_ace_lmap.py +216 -0
  54. ck/in_out/parse_ace_nnf.py +288 -0
  55. ck/in_out/parse_net.py +480 -0
  56. ck/in_out/parser_utils.py +185 -0
  57. ck/in_out/pgm_pickle.py +42 -0
  58. ck/in_out/pgm_python.py +268 -0
  59. ck/in_out/render_bugs.py +111 -0
  60. ck/in_out/render_net.py +177 -0
  61. ck/in_out/render_pomegranate.py +184 -0
  62. ck/pgm.py +3494 -0
  63. ck/pgm_circuit/__init__.py +1 -0
  64. ck/pgm_circuit/marginals_program.py +352 -0
  65. ck/pgm_circuit/mpe_program.py +237 -0
  66. ck/pgm_circuit/pgm_circuit.py +75 -0
  67. ck/pgm_circuit/program_with_slotmap.py +234 -0
  68. ck/pgm_circuit/slot_map.py +35 -0
  69. ck/pgm_circuit/support/__init__.py +0 -0
  70. ck/pgm_circuit/support/compile_circuit.py +83 -0
  71. ck/pgm_circuit/target_marginals_program.py +103 -0
  72. ck/pgm_circuit/wmc_program.py +323 -0
  73. ck/pgm_compiler/__init__.py +2 -0
  74. ck/pgm_compiler/ace/__init__.py +1 -0
  75. ck/pgm_compiler/ace/ace.py +252 -0
  76. ck/pgm_compiler/factor_elimination.py +383 -0
  77. ck/pgm_compiler/named_pgm_compilers.py +63 -0
  78. ck/pgm_compiler/pgm_compiler.py +19 -0
  79. ck/pgm_compiler/recursive_conditioning.py +226 -0
  80. ck/pgm_compiler/support/__init__.py +0 -0
  81. ck/pgm_compiler/support/circuit_table/__init__.py +9 -0
  82. ck/pgm_compiler/support/circuit_table/circuit_table.c +16042 -0
  83. ck/pgm_compiler/support/circuit_table/circuit_table.cpython-313-darwin.so +0 -0
  84. ck/pgm_compiler/support/circuit_table/circuit_table_py.py +269 -0
  85. ck/pgm_compiler/support/clusters.py +556 -0
  86. ck/pgm_compiler/support/factor_tables.py +398 -0
  87. ck/pgm_compiler/support/join_tree.py +275 -0
  88. ck/pgm_compiler/support/named_compiler_maker.py +33 -0
  89. ck/pgm_compiler/variable_elimination.py +89 -0
  90. ck/probability/__init__.py +0 -0
  91. ck/probability/empirical_probability_space.py +47 -0
  92. ck/probability/probability_space.py +568 -0
  93. ck/program/__init__.py +3 -0
  94. ck/program/program.py +129 -0
  95. ck/program/program_buffer.py +180 -0
  96. ck/program/raw_program.py +61 -0
  97. ck/sampling/__init__.py +0 -0
  98. ck/sampling/forward_sampler.py +211 -0
  99. ck/sampling/marginals_direct_sampler.py +113 -0
  100. ck/sampling/sampler.py +62 -0
  101. ck/sampling/sampler_support.py +232 -0
  102. ck/sampling/uniform_sampler.py +66 -0
  103. ck/sampling/wmc_direct_sampler.py +169 -0
  104. ck/sampling/wmc_gibbs_sampler.py +147 -0
  105. ck/sampling/wmc_metropolis_sampler.py +159 -0
  106. ck/sampling/wmc_rejection_sampler.py +113 -0
  107. ck/utils/__init__.py +0 -0
  108. ck/utils/iter_extras.py +153 -0
  109. ck/utils/map_list.py +128 -0
  110. ck/utils/map_set.py +128 -0
  111. ck/utils/np_extras.py +51 -0
  112. ck/utils/random_extras.py +64 -0
  113. ck/utils/tmp_dir.py +94 -0
  114. ck_demos/__init__.py +0 -0
  115. ck_demos/ace/__init__.py +0 -0
  116. ck_demos/ace/copy_ace_to_ck.py +15 -0
  117. ck_demos/ace/demo_ace.py +44 -0
  118. ck_demos/all_demos.py +88 -0
  119. ck_demos/circuit/__init__.py +0 -0
  120. ck_demos/circuit/demo_circuit_dump.py +22 -0
  121. ck_demos/circuit/demo_derivatives.py +43 -0
  122. ck_demos/circuit_compiler/__init__.py +0 -0
  123. ck_demos/circuit_compiler/compare_circuit_compilers.py +32 -0
  124. ck_demos/circuit_compiler/show_llvm_program.py +26 -0
  125. ck_demos/pgm/__init__.py +0 -0
  126. ck_demos/pgm/demo_pgm_dump.py +18 -0
  127. ck_demos/pgm/demo_pgm_dump_stress.py +18 -0
  128. ck_demos/pgm/demo_pgm_string_rendering.py +15 -0
  129. ck_demos/pgm/show_examples.py +25 -0
  130. ck_demos/pgm_compiler/__init__.py +0 -0
  131. ck_demos/pgm_compiler/compare_pgm_compilers.py +50 -0
  132. ck_demos/pgm_compiler/demo_compiler_dump.py +50 -0
  133. ck_demos/pgm_compiler/demo_factor_elimination.py +47 -0
  134. ck_demos/pgm_compiler/demo_join_tree.py +25 -0
  135. ck_demos/pgm_compiler/demo_marginals_program.py +53 -0
  136. ck_demos/pgm_compiler/demo_mpe_program.py +55 -0
  137. ck_demos/pgm_compiler/demo_pgm_compiler.py +38 -0
  138. ck_demos/pgm_compiler/demo_recursive_conditioning.py +33 -0
  139. ck_demos/pgm_compiler/demo_variable_elimination.py +33 -0
  140. ck_demos/pgm_compiler/demo_wmc_program.py +29 -0
  141. ck_demos/pgm_inference/__init__.py +0 -0
  142. ck_demos/pgm_inference/demo_inferencing_basic.py +188 -0
  143. ck_demos/pgm_inference/demo_inferencing_mpe_cancer.py +45 -0
  144. ck_demos/pgm_inference/demo_inferencing_wmc_and_mpe_sprinkler.py +154 -0
  145. ck_demos/pgm_inference/demo_inferencing_wmc_student.py +110 -0
  146. ck_demos/programs/__init__.py +0 -0
  147. ck_demos/programs/demo_program_buffer.py +24 -0
  148. ck_demos/programs/demo_program_multi.py +24 -0
  149. ck_demos/programs/demo_program_none.py +19 -0
  150. ck_demos/programs/demo_program_single.py +23 -0
  151. ck_demos/programs/demo_raw_program_interpreted.py +21 -0
  152. ck_demos/programs/demo_raw_program_llvm.py +21 -0
  153. ck_demos/sampling/__init__.py +0 -0
  154. ck_demos/sampling/check_sampler.py +71 -0
  155. ck_demos/sampling/demo_marginal_direct_sampler.py +40 -0
  156. ck_demos/sampling/demo_uniform_sampler.py +38 -0
  157. ck_demos/sampling/demo_wmc_direct_sampler.py +40 -0
  158. ck_demos/utils/__init__.py +0 -0
  159. ck_demos/utils/compare.py +88 -0
  160. ck_demos/utils/convert_network.py +45 -0
  161. ck_demos/utils/sample_model.py +216 -0
  162. ck_demos/utils/stop_watch.py +384 -0
  163. compiled_knowledge-4.0.0a5.dist-info/METADATA +50 -0
  164. compiled_knowledge-4.0.0a5.dist-info/RECORD +167 -0
  165. compiled_knowledge-4.0.0a5.dist-info/WHEEL +5 -0
  166. compiled_knowledge-4.0.0a5.dist-info/licenses/LICENSE.txt +21 -0
  167. compiled_knowledge-4.0.0a5.dist-info/top_level.txt +2 -0
@@ -0,0 +1,546 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Sequence, Optional, Tuple, List, Dict
5
+
6
+ import llvmlite.binding as llvm
7
+ import llvmlite.ir as ir
8
+ import numpy as np
9
+ import ctypes as ct
10
+
11
+ from .support.circuit_analyser import CircuitAnalysis, analyze_circuit
12
+ from .support.input_vars import InputVars, InferVars, infer_input_vars
13
+ from .support.llvm_ir_function import IRFunction, DataType, TypeInfo, compile_llvm_program, LLVMRawProgram
14
+ from ..circuit import ADD as _ADD, MUL as _MUL, ConstValue
15
+ from ..circuit import Circuit, VarNode, CircuitNode, OpNode
16
+ from ..program.raw_program import RawProgramFunction
17
+
18
# Arithmetic type used when the caller does not nominate one.
DEFAULT_TYPE_INFO: TypeInfo = DataType.FLOAT_64.value

# Byte code op codes.
# The arithmetic op codes are the circuit's own `_ADD` and `_MUL` symbols (imported above).
# `_END` is the sentinel op code that terminates the byte code; it is chosen to be
# different from both arithmetic op codes.
_END: int = 1 + max(_ADD, _MUL)

# First-level index into the interpreter's table of data arrays.
_VARS: int = 0  # the input variable values
_TMPS: int = 1  # the temporary working memory
_RESULT: int = 2  # the result values
_CONSTS: int = 3  # the constant values

# Names of the generated LLVM functions that install externally held arrays.
_SET_CONSTS_FUNCTION_NAME: str = 'set_consts'
_SET_INSTRUCTIONS_FUNCTION_NAME: str = 'set_instructions'
33
+
34
+
35
def compile_circuit(
        *result: CircuitNode,
        input_vars: InputVars = InferVars.ALL,
        circuit: Optional[Circuit] = None,
        data_type: DataType | TypeInfo = DEFAULT_TYPE_INFO,
        keep_llvm_program: bool = True,
        compile_arrays: bool = False,
        opt: int = 2,
) -> LLVMRawProgram:
    """
    Compile the given circuit to a byte code program run by an LLVM compiled interpreter.

    The circuit operations are encoded as byte code and the LLVM module holds
    a small interpreter (virtual machine) that executes that byte code. The
    byte code and the circuit constants are either embedded in the LLVM module
    (`compile_arrays=True`) or kept as numpy arrays that are handed to the
    compiled module after compilation (`compile_arrays=False`).

    This compiler produces a RawProgram that _does_ use client managed working memory.

    Conforms to the CircuitCompiler protocol.

    Args:
        *result: result nodes nominating the results of the returned program.
        input_vars: How to determine the input variables.
        circuit: optionally explicitly specify the Circuit.
        data_type: What data type to use for arithmetic calculations. Either a DataType member or TypeInfo.
        keep_llvm_program: if true, the LLVM program will be kept. This is required for pickling.
        compile_arrays: if true, the global array values are included in the LLVM program.
        opt: The optimization level to use by LLVM MC JIT.

    Returns:
        a raw program.

    Raises:
        ValueError: if the circuit is unknown, but it is needed.
        ValueError: if not all nodes are from the same circuit.
        ValueError: if the program data type could not be interpreted.
    """
    in_vars: Sequence[VarNode] = infer_input_vars(circuit, result, input_vars)
    var_indices: Sequence[int] = tuple(var.idx for var in in_vars)

    # Resolve the arithmetic type: a DataType member wraps its TypeInfo as the member value.
    if not isinstance(data_type, (DataType, TypeInfo)):
        raise ValueError(f'could not interpret program data type: {data_type!r}')
    type_info: TypeInfo = data_type.value if isinstance(data_type, DataType) else data_type

    # Build the LLVM module (as text) along with the byte code and constants it runs on.
    llvm_program, number_of_tmps, consts, byte_code = _make_llvm_program(
        in_vars, result, type_info, compile_arrays
    )

    # JIT compile the LLVM module to native code.
    engine, function = compile_llvm_program(llvm_program, dtype=type_info.dtype, opt=opt)

    # Constructor arguments shared by both kinds of returned program.
    common = dict(
        function=function,
        dtype=type_info.dtype,
        number_of_vars=len(var_indices),
        number_of_tmps=number_of_tmps,
        number_of_results=len(result),
        var_indices=var_indices,
        llvm_program=llvm_program if keep_llvm_program else None,
        engine=engine,
        opt=opt,
    )

    if compile_arrays:
        # The byte code and constants live inside the LLVM module.
        return LLVMRawProgram(**common)

    # The byte code and constants are not part of the LLVM module,
    # so the returned program has to carry them explicitly.
    return LLVMRawProgramWithArrays(
        **common,
        instructions=np.array(byte_code, dtype=np.uint8),
        consts=np.array(consts, dtype=type_info.dtype),
    )
122
+
123
+
124
@dataclass
class LLVMRawProgramWithArrays(LLVMRawProgram):
    """
    An LLVMRawProgram whose byte code and constants are held as numpy arrays
    rather than being embedded in the LLVM module.

    The arrays are kept as fields of this object and their buffer addresses
    are passed to the compiled module through its `set_instructions` and
    `set_consts` functions. This happens on construction and again after
    the object state is restored via `__setstate__`.
    """
    instructions: np.ndarray
    consts: np.ndarray

    def __post_init__(self):
        self._install_arrays()

    def __getstate__(self):
        # Extend the parent state with the two arrays owned by this class.
        state = super().__getstate__()
        state['instructions'] = self.instructions
        state['consts'] = self.consts
        return state

    def __setstate__(self, state):
        super().__setstate__(state)
        self.instructions = state['instructions']
        self.consts = state['consts']
        # Hand the restored arrays to the compiled module.
        self._install_arrays()

    def _install_arrays(self) -> None:
        """
        Pass both arrays (instructions first, then constants) to the compiled module.
        """
        self._set_globals(self.instructions, _SET_INSTRUCTIONS_FUNCTION_NAME)
        self._set_globals(self.consts, _SET_CONSTS_FUNCTION_NAME)

    def _set_globals(self, data: np.ndarray, func_name: str) -> None:
        """
        Call the compiled setter function `func_name`, passing a pointer to the buffer of `data`.

        Args:
            data: the array whose buffer address is passed to the setter.
            func_name: name of a compiled function taking a single pointer and returning nothing.
        """
        element_type = np.ctypeslib.as_ctypes_type(data.dtype)
        pointer_type = ct.POINTER(element_type)
        data_pointer = data.ctypes.data_as(pointer_type)

        address = self.engine.get_function_address(func_name)
        setter = ct.CFUNCTYPE(None, pointer_type)(address)
        setter(data_pointer)
153
+
154
+
155
def _make_llvm_program(
        in_vars: Sequence[VarNode],
        result: Sequence[CircuitNode],
        type_info: TypeInfo,
        compile_arrays: bool,
) -> Tuple[str, int, List[ConstValue], List[int]]:
    """
    Construct the LLVM program (i.e., LLVM module) holding the byte code interpreter.

    Args:
        in_vars: the input variable nodes of the program.
        result: the result nodes of the program.
        type_info: the arithmetic type to use.
        compile_arrays: if true, the byte code and constants are embedded in the module
            as constant global arrays; otherwise the module only holds two global pointers
            and a setter function for each of them.

    Returns:
        (llvm_program, number_of_tmps, consts, byte_code)
    """
    ir_function = IRFunction(type_info)

    builder = ir_function.builder
    type_info = ir_function.type_info
    module = ir_function.module

    analysis: CircuitAnalysis = analyze_circuit(in_vars, result)
    const_values: List[ConstValue] = [node.value for node in analysis.const_nodes]

    # The width of a data index is sized for the longest of the four data arrays.
    longest_array: int = max(
        len(analysis.var_nodes),  # number of inputs
        len(analysis.result_nodes),  # number of outputs
        len(analysis.op_to_tmp),  # number of tmps
        len(analysis.const_nodes),  # number of constants
    )
    data_idx_bytes: int = _get_bytes_needed(longest_array)

    # The width of an argument count is sized for the op node with the most arguments.
    most_args: int = max((len(op_node.args) for op_node in analysis.op_nodes), default=0)
    num_args_bytes: int = _get_bytes_needed(most_args)

    data_type: ir.Type = type_info.llvm_type
    byte_type: ir.Type = ir.IntType(8)
    data_idx_type: ir.Type = ir.IntType(8 * data_idx_bytes)

    byte_code: List[int] = _make_byte_code(analysis, data_idx_bytes, num_args_bytes)

    # The width of an instruction index is sized for the length of the byte code.
    inst_idx_bytes: int = _get_bytes_needed(len(byte_code))
    inst_idx_type: ir.Type = ir.IntType(8 * inst_idx_bytes)

    if compile_arrays:
        # Embed the constants in the LLVM module and take a pointer to the first element.
        consts_type = ir.ArrayType(data_type, len(analysis.const_nodes))
        consts_global = ir.GlobalVariable(module, consts_type, name='consts')
        consts_global.global_constant = True
        consts_global.initializer = ir.Constant(consts_type, const_values)
        zero_data_idx = ir.Constant(data_idx_type, 0)
        consts: ir.Value = builder.gep(consts_global, [zero_data_idx, zero_data_idx])

        # Embed the byte code in the LLVM module and take a pointer to the first element.
        instructions_type = ir.ArrayType(byte_type, len(byte_code))
        instructions_global = ir.GlobalVariable(module, instructions_type, name='instructions')
        instructions_global.global_constant = True
        instructions_global.initializer = ir.Constant(instructions_type, byte_code)
        zero_inst_idx = ir.Constant(inst_idx_type, 0)
        instructions: ir.Value = builder.gep(instructions_global, [zero_inst_idx, zero_inst_idx])
    else:
        # Only declare two global pointers (initially null); they get set externally.
        consts_ptr_type = data_type.as_pointer()
        consts_global = ir.GlobalVariable(module, consts_ptr_type, name='consts')
        consts_global.initializer = ir.Constant(consts_ptr_type, None)
        consts: ir.Value = builder.load(consts_global)

        instructions_ptr_type = byte_type.as_pointer()
        instructions_global = ir.GlobalVariable(module, instructions_ptr_type, name='instructions')
        instructions_global.initializer = ir.Constant(instructions_ptr_type, None)
        instructions: ir.Value = builder.load(instructions_global)

    interp = _InterpBuilder(
        builder,
        type_info,
        inst_idx_type,
        data_idx_bytes,
        num_args_bytes,
        consts,
        instructions,
    )
    interp.make_interpreter()

    if not compile_arrays:
        # Add the functions that set the global pointers.
        # This is done after the interpreter is written as these calls
        # move the builder into the newly created functions.
        interp.make_set_consts_function(consts_global)
        interp.make_set_instructions_function(instructions_global)

    number_of_tmps: int = len(analysis.op_to_tmp)
    return ir_function.llvm_program(), number_of_tmps, const_values, byte_code
236
+
237
+
238
class _InterpBuilder:
    """
    Helper to write the LLVM function for the byte code interpreter.

    The constructor emits the function prologue (local variables and the
    table of data arrays) at the builder's current position, and
    `make_interpreter` emits the interpreter loop after it.
    """

    def __init__(
            self,
            builder: ir.IRBuilder,
            type_info: TypeInfo,
            inst_idx_type: ir.Type,
            index_bytes: int,
            num_args_bytes: int,
            consts: ir.Value,
            instructions: ir.Value,
    ):
        """
        Args:
            builder: IR builder positioned in the interpreter function.
            type_info: the arithmetic type of the data values.
            inst_idx_type: integer type used to index the byte code.
            index_bytes: number of bytes used in the byte code for an index into a data array.
            num_args_bytes: number of bytes used in the byte code for an argument count.
            consts: pointer to the first constant value.
            instructions: pointer to the first byte of the byte code.
        """
        self.builder: ir.IRBuilder = builder
        self.index_bytes: int = index_bytes
        self.num_args_bytes: int = num_args_bytes
        self.type_info: TypeInfo = type_info

        # LLVM types
        self.data_type: ir.Type = type_info.llvm_type
        self.byte_type: ir.Type = ir.IntType(8)
        self.inst_idx_type: ir.Type = inst_idx_type
        self.data_idx_type: ir.Type = ir.IntType(8 * index_bytes)
        self.num_args_type: ir.Type = ir.IntType(8 * num_args_bytes)

        # The constants zero and one, for each of the integer types.
        self.data_idx_0, self.data_idx_1 = (ir.Constant(self.data_idx_type, k) for k in (0, 1))
        self.inst_idx_0, self.inst_idx_1 = (ir.Constant(self.inst_idx_type, k) for k in (0, 1))
        self.num_args_0, self.num_args_1 = (ir.Constant(self.num_args_type, k) for k in (0, 1))

        self.consts: ir.Value = consts
        self.instructions: ir.Value = instructions

        # Allocate the local variables of the interpreter.
        self.local_idx = builder.alloca(self.inst_idx_type, name='idx')
        self.local_num_args = builder.alloca(self.num_args_type, name='num_args')
        self.local_accumulator = builder.alloca(self.data_type, name='accumulator')
        self.local_arrays = builder.alloca(self.data_type.as_pointer(), size=4, name='arrays')

        # Fill the table: local_arrays = [vars, tmps, result, consts].
        # The first three pointers are the arguments of the interpreter function.
        function: ir.Function = builder.function
        table_entries = (
            (_VARS, function.args[0]),
            (_TMPS, function.args[1]),
            (_RESULT, function.args[2]),
            (_CONSTS, consts),
        )
        for slot, array_ptr in table_entries:
            slot_ptr = builder.gep(self.local_arrays, [ir.Constant(self.byte_type, slot)])
            builder.store(array_ptr, slot_ptr)

        # The interpreter starts at the first byte of the byte code: local_idx = 0
        builder.store(self.inst_idx_0, self.local_idx)

    def make_set_consts_function(self, consts_ptr: ir.GlobalVariable):
        """
        Write a function that stores its pointer argument into the global `consts_ptr`.

        This moves the builder into the new function.
        """
        self._make_pointer_setter(_SET_CONSTS_FUNCTION_NAME, self.data_type, consts_ptr)

    def make_set_instructions_function(self, instructions_ptr: ir.GlobalVariable):
        """
        Write a function that stores its pointer argument into the global `instructions_ptr`.

        This moves the builder into the new function.
        """
        self._make_pointer_setter(_SET_INSTRUCTIONS_FUNCTION_NAME, self.byte_type, instructions_ptr)

    def _make_pointer_setter(self, name: str, element_type: ir.Type, target: ir.GlobalVariable) -> None:
        """
        Write `void name(element_type* p) { target = p; }` into the builder's module.
        """
        builder = self.builder
        setter_type = ir.FunctionType(ir.VoidType(), (element_type.as_pointer(),))
        setter = ir.Function(builder.module, setter_type, name=name)
        builder.position_at_end(setter.append_basic_block('entry'))
        builder.store(setter.args[0], target)
        builder.ret_void()

    def add(self, x: ir.Value, y: ir.Value) -> ir.Value:
        """Emit `x + y` using the arithmetic of the program data type."""
        return self.type_info.add(self.builder, x, y)

    def mul(self, x: ir.Value, y: ir.Value) -> ir.Value:
        """Emit `x * y` using the arithmetic of the program data type."""
        return self.type_info.mul(self.builder, x, y)

    def make_interpreter(self):
        """
        Write the bytecode interpreter

        The emitted loop reads an op code; if it is the `end` op code the
        function returns. Otherwise, it reads the argument count, loads the
        first argument into the accumulator, folds the remaining arguments
        into the accumulator (by multiplication if the op code is `_MUL`,
        by addition otherwise), stores the accumulator at the destination
        element, and loops.
        """
        builder: ir.IRBuilder = self.builder
        function: ir.Function = builder.function

        block_labels = ('while', 'body', 'mul', 'mul_op', 'add', 'add_op', 'op_continue', 'finish')
        (
            bb_while, bb_body, bb_mul, bb_mul_op, bb_add, bb_add_op, bb_op_continue, bb_finish
        ) = (function.append_basic_block(label) for label in block_labels)

        # block: entry
        # (the locals were set up by the constructor)
        builder.branch(bb_while)

        # block: while
        builder.position_at_end(bb_while)
        # read the op code of the current instruction
        idx = builder.load(self.local_idx)
        inst = builder.load(builder.gep(self.instructions, [idx]))
        idx = builder.add(idx, self.inst_idx_1)
        # the `end` op code terminates the program
        is_end = builder.icmp_unsigned('==', inst, ir.Constant(self.byte_type, _END))
        builder.cbranch(is_end, bb_finish, bb_body)

        # block: body
        builder.position_at_end(bb_body)
        # read the number of args
        idx, num_args = self._read_number(idx, self.num_args_bytes)
        builder.store(num_args, self.local_num_args)
        # the first arg value initialises the accumulator
        idx, first_value = self._load_value(idx)
        builder.store(first_value, self.local_accumulator)
        # remember where we are in the byte code
        builder.store(idx, self.local_idx)
        # anything that is not a multiply is treated as an add
        is_mul = builder.icmp_unsigned('==', inst, ir.Constant(self.byte_type, _MUL))
        builder.cbranch(is_mul, bb_mul, bb_add)

        # blocks: mul, mul_op
        self._emit_fold_loop(bb_mul, bb_mul_op, bb_op_continue, self.mul)

        # blocks: add, add_op
        self._emit_fold_loop(bb_add, bb_add_op, bb_op_continue, self.add)

        # block: op_continue
        builder.position_at_end(bb_op_continue)
        # read the destination element
        idx = builder.load(self.local_idx)
        idx, dest_ptr = self._load_value_ptr(idx)
        builder.store(idx, self.local_idx)
        # write the accumulator to the destination
        acc = builder.load(self.local_accumulator)
        builder.store(acc, dest_ptr)
        builder.branch(bb_while)

        # block: finish
        builder.position_at_end(bb_finish)
        builder.ret_void()

    def _emit_fold_loop(self, bb_head, bb_op, bb_done, combine) -> None:
        """
        Write the two blocks that fold the remaining args of an op into the accumulator.

        Args:
            bb_head: loop head block; decrements the arg counter and
                branches to `bb_op` while it is above zero, else to `bb_done`.
            bb_op: loop body block; loads the next arg value, combines it
                with the accumulator and branches back to `bb_head`.
            bb_done: the block to branch to once all args are consumed.
            combine: `self.add` or `self.mul`.
        """
        builder = self.builder

        # loop head: one less arg to go; are there any left?
        builder.position_at_end(bb_head)
        remaining = builder.load(self.local_num_args)
        remaining = builder.sub(remaining, self.num_args_1)
        builder.store(remaining, self.local_num_args)
        has_more = builder.icmp_unsigned('>', remaining, self.num_args_0)
        builder.cbranch(has_more, bb_op, bb_done)

        # loop body: accumulator = combine(accumulator, next arg value)
        builder.position_at_end(bb_op)
        idx = builder.load(self.local_idx)
        idx, value = self._load_value(idx)
        acc = builder.load(self.local_accumulator)
        acc = combine(acc, value)
        builder.store(acc, self.local_accumulator)
        builder.store(idx, self.local_idx)
        builder.branch(bb_head)

    def _read_number(self, idx: ir.Value, num_bytes: int) -> Tuple[ir.Value, ir.Value]:
        """
        Write code to read an unsigned number from the byte code, most significant byte first.

        Args:
            idx: current instruction index
            num_bytes: how many bytes to read from the instruction stream to form the number

        Returns:
            (idx, number)
            idx: is the updated instruction index
            number: is the read number
        """
        builder = self.builder

        # The first byte is always read.
        number: ir.Value = builder.load(builder.gep(self.instructions, [idx]))
        idx = builder.add(idx, self.inst_idx_1)

        if num_bytes <= 1:
            # A single byte is returned as read, without widening.
            return idx, number

        # Widen, then shift in each of the remaining bytes.
        number_type: ir.Type = ir.IntType(8 * num_bytes)
        shift = ir.Constant(number_type, 8)
        number = builder.zext(number, number_type)
        for _ in range(1, num_bytes):
            next_byte = builder.load(builder.gep(self.instructions, [idx]))
            shifted = builder.shl(number, shift)
            widened = builder.zext(next_byte, number_type)
            number = builder.add(shifted, widened)
            idx = builder.add(idx, self.inst_idx_1)

        return idx, number

    def _load_value_ptr(self, idx: ir.Value) -> Tuple[ir.Value, ir.Value]:
        """
        Write code to read an element id from the byte code and form a pointer to that element.

        An element id is one byte selecting a data array, followed by
        `self.index_bytes` bytes giving the index into that array.

        Returns:
            (idx, ptr)
            idx: is the updated instruction index
            ptr: is the pointer to the identified element
        """
        builder = self.builder

        # which array
        array_id = builder.load(builder.gep(self.instructions, [idx]))
        idx = builder.add(idx, self.inst_idx_1)

        # which element of the array
        idx, element_idx = self._read_number(idx, self.index_bytes)

        # form the pointer
        array_ptr = builder.load(builder.gep(self.local_arrays, [array_id]))
        element_ptr = builder.gep(array_ptr, [element_idx])

        return idx, element_ptr

    def _load_value(self, idx: ir.Value) -> Tuple[ir.Value, ir.Value]:
        """
        Write code to read an element id from the byte code and load that element's value.

        Returns:
            (idx, value)
            idx: is the updated instruction index
            value: is the loaded value
        """
        idx, element_ptr = self._load_value_ptr(idx)
        return idx, self.builder.load(element_ptr)
467
+
468
+
469
+ @dataclass
470
+ class _ElementID:
471
+ """
472
+ A 2D index into the function's `arrays`.
473
+ """
474
+ array: int # which array: VARS, TMPS, CONSTS, RESULT
475
+ index: int # index into the array
476
+
477
+
478
def _make_byte_code(analysis: CircuitAnalysis, data_idx_bytes: int, num_args_bytes: int) -> List[int]:
    """
    Encode the analysed circuit as byte code for the interpreter.

    Each instruction is laid out as:
        op code (one byte, `_ADD` or `_MUL`),
        number of args (`num_args_bytes` bytes),
        an element id for each arg,
        an element id for the destination.
    An element id is an array id (one byte) followed by an index
    (`data_idx_bytes` bytes). Multibyte numbers are written most significant
    byte first. The byte code ends with the `_END` op code.

    Args:
        analysis: the analysis of the circuit to encode.
        data_idx_bytes: number of bytes used to write an index into a data array.
        num_args_bytes: number of bytes used to write an argument count.

    Returns:
        the byte code, as a list of byte values.
    """
    # Map id(node) to the element holding the value of the node.
    node_to_element: Dict[int, _ElementID] = {}
    # ...const nodes
    for const_idx, const_node in enumerate(analysis.const_nodes):
        node_to_element[id(const_node)] = _ElementID(_CONSTS, const_idx)
    # ...var nodes (a var node set to a constant shares the element of that constant)
    for var_idx, var_node in enumerate(analysis.var_nodes):
        if var_node.is_const():
            element = node_to_element[id(var_node.const)]
        else:
            element = _ElementID(_VARS, var_idx)
        node_to_element[id(var_node)] = element
    # ...op nodes, either written to a tmp or to a result
    for op_id, tmp_idx in analysis.op_to_tmp.items():
        node_to_element[op_id] = _ElementID(_TMPS, tmp_idx)
    for op_id, result_idx in analysis.op_to_result.items():
        node_to_element[op_id] = _ElementID(_RESULT, result_idx)

    byte_code: List[int] = []

    def write_element(array: int, index: int) -> None:
        # An element id is the array id followed by the index bytes.
        byte_code.append(array)
        byte_code.extend(_to_bytes(index, data_idx_bytes))

    def write_node(node) -> None:
        element: _ElementID = node_to_element[id(node)]
        write_element(element.array, element.index)

    # One instruction per op node
    for op_node in analysis.op_nodes:
        byte_code.append(op_node.symbol)  # _ADD or _MUL
        byte_code.extend(_to_bytes(len(op_node.args), num_args_bytes))
        for arg_node in op_node.args:
            write_node(arg_node)
        write_node(op_node)  # the destination

    # A result that is not an op node is copied into place
    # using an add instruction with a single arg.
    for result_idx, result_node in enumerate(analysis.result_nodes):
        if isinstance(result_node, OpNode):
            continue
        byte_code.append(_ADD)
        byte_code.extend(_to_bytes(1, num_args_bytes))
        write_node(result_node)
        write_element(_RESULT, result_idx)

    # The sentinel - 'end' op code
    byte_code.append(_END)

    return byte_code
529
+
530
+
531
+ def _to_bytes(value: int, num_bytes: int) -> List[int]:
532
+ buffer: List[int] = []
533
+ for _ in range(num_bytes):
534
+ buffer.append(value % 256)
535
+ value //= 256
536
+ assert value == 0
537
+ buffer.reverse()
538
+ return buffer
539
+
540
+
541
+ def _get_bytes_needed(size: int) -> int:
542
+ index_bytes: int
543
+ for index_bytes in [1, 2, 4, 8]:
544
+ if size < 2 ** (index_bytes * 8 - 1):
545
+ return index_bytes
546
+ raise ValueError(f'size are too large to represent: {size}')
@@ -0,0 +1,57 @@
1
+ from enum import Enum
2
+ from functools import partial
3
+ from typing import Optional
4
+
5
+ from .llvm_compiler import Flavour
6
+ from ..circuit import CircuitNode, Circuit
7
+ from ..circuit_compiler import interpret_compiler, cython_vm_compiler, llvm_compiler, llvm_vm_compiler, CircuitCompiler
8
+ from ..circuit_compiler.support.input_vars import InputVars, InferVars
9
+ from ..program import RawProgram
10
+
11
+
12
class NamedCircuitCompiler(Enum):
    """
    A standard collection of named circuit compiler functions.

    The `value` of each enum member is a one-element tuple holding a compiler function.
    The function is wrapped in a tuple as otherwise Python erases the type of the
    member, which can cause problems.
    Each member is itself callable, conforming to the CircuitCompiler protocol,
    and delegates to its compiler function.
    """

    LLVM_STACK: CircuitCompiler = (partial(llvm_compiler.compile_circuit, flavour=Flavour.STACK),)
    LLVM_TMPS: CircuitCompiler = (partial(llvm_compiler.compile_circuit, flavour=Flavour.TMPS, opt=0),)
    LLVM_VM: CircuitCompiler = (llvm_vm_compiler.compile_circuit,)
    CYTHON_VM: CircuitCompiler = (cython_vm_compiler.compile_circuit,)
    INTERPRET: CircuitCompiler = (interpret_compiler.compile_circuit,)

    # The following circuit compilers were experimental, but are not really useful,
    # so they are not offered as members.
    #
    # Slow to compile and slow to execute:
    # LLVM_FUNCS: CircuitCompiler = (partial(llvm_compiler.compile_circuit, flavour=Flavour.FUNCS, opt=0),)
    #
    # Slow to compile, with the same execution as LLVM_VM:
    # LLVM_VM_COMPILED_ARRAYS: CircuitCompiler = (partial(llvm_vm_compiler.compile_circuit, compile_arrays=True),)

    def __call__(
            self,
            *result: CircuitNode,
            input_vars: InputVars = InferVars.ALL,
            circuit: Optional[Circuit] = None,
    ) -> RawProgram:
        """
        Compile a circuit using the compiler function of this member.

        This implements the `CircuitCompiler` protocol for each member of the enum.

        Args:
            *result: result nodes nominating the results of the returned program.
            input_vars: How to determine the input variables.
            circuit: optionally explicitly specify the Circuit.

        Returns:
            whatever the compiler function of this member returns.
        """
        compile_function: CircuitCompiler = self.compiler
        return compile_function(*result, input_vars=input_vars, circuit=circuit)

    @property
    def compiler(self) -> CircuitCompiler:
        """
        Returns:
            The compiler function, conforming to the CircuitCompiler protocol.
        """
        # Unwrap the function from the one-element tuple that is the member value.
        return self.value[0]


# The circuit compiler to use when none is nominated.
DEFAULT_CIRCUIT_COMPILER: NamedCircuitCompiler = NamedCircuitCompiler.LLVM_VM
File without changes
@@ -0,0 +1,81 @@
1
+ from dataclasses import dataclass
2
+ from itertools import count
3
+ from typing import List, Dict, Sequence, Set
4
+
5
+ from ck.circuit import OpNode, VarNode, CircuitNode, ConstNode
6
+
7
+
8
@dataclass
class CircuitAnalysis:
    """
    The result of analysing a circuit as a function from var nodes to result nodes.
    """
    # input var nodes, in VarNode idx order
    var_nodes: Sequence[VarNode]
    # result nodes
    result_nodes: Sequence[CircuitNode]
    # in-use op nodes, in computation order
    op_nodes: Sequence[OpNode]
    # in-use const nodes, in arbitrary order
    const_nodes: Sequence[ConstNode]
    # op nodes in the result, op_node = result[idx]: id(op_node) -> idx
    op_to_result: Dict[int, int]
    # op nodes needing tmp memory, using tmp[idx]: id(op_node) -> idx
    op_to_tmp: Dict[int, int]
16
+
17
+
18
def analyze_circuit(
        var_nodes: Sequence[VarNode],
        result_nodes: Sequence[CircuitNode],
) -> CircuitAnalysis:
    """
    Analyzes a circuit as a function from var_nodes to result_nodes,
    returning a CircuitAnalysis object.

    The analysis determines:
        the op nodes reachable from the result nodes,
        the distinct const nodes in use (as args of those op nodes, as result
        nodes, or as the constant of a var node that is set to a constant),
        which op nodes write directly to a result slot, and
        a tmp slot for every other op node.

    Args:
        var_nodes: the input var nodes.
        result_nodes: the result nodes.

    Returns:
        the analysis, holding `var_nodes` and `result_nodes` as given.
    """
    # The op nodes in use (none if there are no results).
    op_nodes: List[OpNode]
    if len(result_nodes) == 0:
        op_nodes = []
    else:
        op_nodes = result_nodes[0].circuit.reachable_op_nodes(*result_nodes)

    def _used_const_nodes():
        # Yields every use of a const node; the same node may be yielded many times.
        for _op_node in op_nodes:
            for _arg in _op_node.args:
                if isinstance(_arg, ConstNode):
                    yield _arg
        for _result_node in result_nodes:
            if isinstance(_result_node, ConstNode):
                yield _result_node
        for _var_node in var_nodes:
            if _var_node.is_const():
                yield _var_node.const

    # The const nodes in use, without repeats, in order of first use.
    # Nodes are distinguished by object identity.
    distinct_consts: Dict[int, ConstNode] = {}
    for const_node in _used_const_nodes():
        distinct_consts.setdefault(id(const_node), const_node)
    const_nodes: List[ConstNode] = list(distinct_consts.values())

    # The op nodes that are in the result.
    # Dict op_to_result maps id(OpNode) to result index.
    # Note: if the same op node is given more than once in `result_nodes`
    # then only its last result index is recorded here.
    op_to_result: Dict[int, int] = {}
    for result_idx, result_node in enumerate(result_nodes):
        if isinstance(result_node, OpNode):
            op_to_result[id(result_node)] = result_idx

    # Every other op node is assigned a tmp slot, numbered in computation order.
    # Dict op_to_tmp maps id(OpNode) to tmp index.
    ops_needing_tmp = (op_node for op_node in op_nodes if id(op_node) not in op_to_result)
    op_to_tmp: Dict[int, int] = {
        id(op_node): tmp_slot
        for tmp_slot, op_node in enumerate(ops_needing_tmp)
    }

    return CircuitAnalysis(
        var_nodes=var_nodes,
        result_nodes=result_nodes,
        op_nodes=op_nodes,
        const_nodes=const_nodes,
        op_to_result=op_to_result,
        op_to_tmp=op_to_tmp,
    )