compiled-knowledge 4.0.0__cp313-cp313-win32.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of compiled-knowledge might be problematic. Click here for more details.

Files changed (182) hide show
  1. ck/__init__.py +0 -0
  2. ck/circuit/__init__.py +17 -0
  3. ck/circuit/_circuit_cy.c +37513 -0
  4. ck/circuit/_circuit_cy.cp313-win32.pyd +0 -0
  5. ck/circuit/_circuit_cy.pxd +32 -0
  6. ck/circuit/_circuit_cy.pyx +768 -0
  7. ck/circuit/_circuit_py.py +836 -0
  8. ck/circuit/tmp_const.py +75 -0
  9. ck/circuit_compiler/__init__.py +2 -0
  10. ck/circuit_compiler/circuit_compiler.py +27 -0
  11. ck/circuit_compiler/cython_vm_compiler/__init__.py +1 -0
  12. ck/circuit_compiler/cython_vm_compiler/_compiler.c +19833 -0
  13. ck/circuit_compiler/cython_vm_compiler/_compiler.cp313-win32.pyd +0 -0
  14. ck/circuit_compiler/cython_vm_compiler/_compiler.pyx +380 -0
  15. ck/circuit_compiler/cython_vm_compiler/cython_vm_compiler.py +128 -0
  16. ck/circuit_compiler/interpret_compiler.py +255 -0
  17. ck/circuit_compiler/llvm_compiler.py +388 -0
  18. ck/circuit_compiler/llvm_vm_compiler.py +552 -0
  19. ck/circuit_compiler/named_circuit_compilers.py +57 -0
  20. ck/circuit_compiler/support/__init__.py +0 -0
  21. ck/circuit_compiler/support/circuit_analyser/__init__.py +13 -0
  22. ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.c +10618 -0
  23. ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.cp313-win32.pyd +0 -0
  24. ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.pyx +98 -0
  25. ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_py.py +93 -0
  26. ck/circuit_compiler/support/input_vars.py +148 -0
  27. ck/circuit_compiler/support/llvm_ir_function.py +251 -0
  28. ck/example/__init__.py +53 -0
  29. ck/example/alarm.py +366 -0
  30. ck/example/asia.py +28 -0
  31. ck/example/binary_clique.py +32 -0
  32. ck/example/bow_tie.py +33 -0
  33. ck/example/cancer.py +37 -0
  34. ck/example/chain.py +38 -0
  35. ck/example/child.py +199 -0
  36. ck/example/clique.py +33 -0
  37. ck/example/cnf_pgm.py +39 -0
  38. ck/example/diamond_square.py +70 -0
  39. ck/example/earthquake.py +36 -0
  40. ck/example/empty.py +10 -0
  41. ck/example/hailfinder.py +539 -0
  42. ck/example/hepar2.py +628 -0
  43. ck/example/insurance.py +504 -0
  44. ck/example/loop.py +40 -0
  45. ck/example/mildew.py +38161 -0
  46. ck/example/munin.py +22982 -0
  47. ck/example/pathfinder.py +53747 -0
  48. ck/example/rain.py +39 -0
  49. ck/example/rectangle.py +161 -0
  50. ck/example/run.py +30 -0
  51. ck/example/sachs.py +129 -0
  52. ck/example/sprinkler.py +30 -0
  53. ck/example/star.py +44 -0
  54. ck/example/stress.py +64 -0
  55. ck/example/student.py +43 -0
  56. ck/example/survey.py +46 -0
  57. ck/example/triangle_square.py +56 -0
  58. ck/example/truss.py +51 -0
  59. ck/in_out/__init__.py +3 -0
  60. ck/in_out/parse_ace_lmap.py +216 -0
  61. ck/in_out/parse_ace_nnf.py +322 -0
  62. ck/in_out/parse_net.py +482 -0
  63. ck/in_out/parser_utils.py +189 -0
  64. ck/in_out/pgm_pickle.py +42 -0
  65. ck/in_out/pgm_python.py +268 -0
  66. ck/in_out/render_bugs.py +111 -0
  67. ck/in_out/render_net.py +177 -0
  68. ck/in_out/render_pomegranate.py +184 -0
  69. ck/pgm.py +3482 -0
  70. ck/pgm_circuit/__init__.py +1 -0
  71. ck/pgm_circuit/marginals_program.py +352 -0
  72. ck/pgm_circuit/mpe_program.py +236 -0
  73. ck/pgm_circuit/pgm_circuit.py +88 -0
  74. ck/pgm_circuit/program_with_slotmap.py +217 -0
  75. ck/pgm_circuit/slot_map.py +35 -0
  76. ck/pgm_circuit/support/__init__.py +0 -0
  77. ck/pgm_circuit/support/compile_circuit.py +78 -0
  78. ck/pgm_circuit/target_marginals_program.py +103 -0
  79. ck/pgm_circuit/wmc_program.py +323 -0
  80. ck/pgm_compiler/__init__.py +2 -0
  81. ck/pgm_compiler/ace/__init__.py +1 -0
  82. ck/pgm_compiler/ace/ace.py +299 -0
  83. ck/pgm_compiler/factor_elimination.py +395 -0
  84. ck/pgm_compiler/named_pgm_compilers.py +60 -0
  85. ck/pgm_compiler/pgm_compiler.py +19 -0
  86. ck/pgm_compiler/recursive_conditioning.py +231 -0
  87. ck/pgm_compiler/support/__init__.py +0 -0
  88. ck/pgm_compiler/support/circuit_table/__init__.py +17 -0
  89. ck/pgm_compiler/support/circuit_table/_circuit_table_cy.c +16396 -0
  90. ck/pgm_compiler/support/circuit_table/_circuit_table_cy.cp313-win32.pyd +0 -0
  91. ck/pgm_compiler/support/circuit_table/_circuit_table_cy.pyx +332 -0
  92. ck/pgm_compiler/support/circuit_table/_circuit_table_py.py +304 -0
  93. ck/pgm_compiler/support/clusters.py +572 -0
  94. ck/pgm_compiler/support/factor_tables.py +406 -0
  95. ck/pgm_compiler/support/join_tree.py +332 -0
  96. ck/pgm_compiler/support/named_compiler_maker.py +43 -0
  97. ck/pgm_compiler/variable_elimination.py +91 -0
  98. ck/probability/__init__.py +0 -0
  99. ck/probability/empirical_probability_space.py +52 -0
  100. ck/probability/pgm_probability_space.py +36 -0
  101. ck/probability/probability_space.py +627 -0
  102. ck/program/__init__.py +3 -0
  103. ck/program/program.py +137 -0
  104. ck/program/program_buffer.py +180 -0
  105. ck/program/raw_program.py +106 -0
  106. ck/sampling/__init__.py +0 -0
  107. ck/sampling/forward_sampler.py +211 -0
  108. ck/sampling/marginals_direct_sampler.py +113 -0
  109. ck/sampling/sampler.py +62 -0
  110. ck/sampling/sampler_support.py +234 -0
  111. ck/sampling/uniform_sampler.py +72 -0
  112. ck/sampling/wmc_direct_sampler.py +171 -0
  113. ck/sampling/wmc_gibbs_sampler.py +153 -0
  114. ck/sampling/wmc_metropolis_sampler.py +165 -0
  115. ck/sampling/wmc_rejection_sampler.py +115 -0
  116. ck/utils/__init__.py +0 -0
  117. ck/utils/iter_extras.py +164 -0
  118. ck/utils/local_config.py +278 -0
  119. ck/utils/map_list.py +128 -0
  120. ck/utils/map_set.py +128 -0
  121. ck/utils/np_extras.py +51 -0
  122. ck/utils/random_extras.py +64 -0
  123. ck/utils/tmp_dir.py +94 -0
  124. ck_demos/__init__.py +0 -0
  125. ck_demos/ace/__init__.py +0 -0
  126. ck_demos/ace/copy_ace_to_ck.py +15 -0
  127. ck_demos/ace/demo_ace.py +49 -0
  128. ck_demos/ace/simple_ace_demo.py +18 -0
  129. ck_demos/all_demos.py +88 -0
  130. ck_demos/circuit/__init__.py +0 -0
  131. ck_demos/circuit/demo_circuit_dump.py +22 -0
  132. ck_demos/circuit/demo_derivatives.py +43 -0
  133. ck_demos/circuit_compiler/__init__.py +0 -0
  134. ck_demos/circuit_compiler/compare_circuit_compilers.py +32 -0
  135. ck_demos/circuit_compiler/show_llvm_program.py +26 -0
  136. ck_demos/getting_started/__init__.py +0 -0
  137. ck_demos/getting_started/simple_demo.py +18 -0
  138. ck_demos/pgm/__init__.py +0 -0
  139. ck_demos/pgm/demo_pgm_dump.py +18 -0
  140. ck_demos/pgm/demo_pgm_dump_stress.py +18 -0
  141. ck_demos/pgm/demo_pgm_string_rendering.py +15 -0
  142. ck_demos/pgm/show_examples.py +25 -0
  143. ck_demos/pgm_compiler/__init__.py +0 -0
  144. ck_demos/pgm_compiler/compare_pgm_compilers.py +63 -0
  145. ck_demos/pgm_compiler/demo_compiler_dump.py +60 -0
  146. ck_demos/pgm_compiler/demo_factor_elimination.py +47 -0
  147. ck_demos/pgm_compiler/demo_join_tree.py +25 -0
  148. ck_demos/pgm_compiler/demo_marginals_program.py +53 -0
  149. ck_demos/pgm_compiler/demo_mpe_program.py +55 -0
  150. ck_demos/pgm_compiler/demo_pgm_compiler.py +38 -0
  151. ck_demos/pgm_compiler/demo_recursive_conditioning.py +33 -0
  152. ck_demos/pgm_compiler/demo_variable_elimination.py +33 -0
  153. ck_demos/pgm_compiler/demo_wmc_program.py +29 -0
  154. ck_demos/pgm_compiler/time_fe_compiler.py +93 -0
  155. ck_demos/pgm_inference/__init__.py +0 -0
  156. ck_demos/pgm_inference/demo_inferencing_basic.py +188 -0
  157. ck_demos/pgm_inference/demo_inferencing_mpe_cancer.py +45 -0
  158. ck_demos/pgm_inference/demo_inferencing_wmc_and_mpe_sprinkler.py +154 -0
  159. ck_demos/pgm_inference/demo_inferencing_wmc_student.py +110 -0
  160. ck_demos/programs/__init__.py +0 -0
  161. ck_demos/programs/demo_program_buffer.py +24 -0
  162. ck_demos/programs/demo_program_multi.py +24 -0
  163. ck_demos/programs/demo_program_none.py +19 -0
  164. ck_demos/programs/demo_program_single.py +23 -0
  165. ck_demos/programs/demo_raw_program_dump.py +17 -0
  166. ck_demos/programs/demo_raw_program_interpreted.py +21 -0
  167. ck_demos/programs/demo_raw_program_llvm.py +21 -0
  168. ck_demos/sampling/__init__.py +0 -0
  169. ck_demos/sampling/check_sampler.py +71 -0
  170. ck_demos/sampling/demo_marginal_direct_sampler.py +40 -0
  171. ck_demos/sampling/demo_uniform_sampler.py +38 -0
  172. ck_demos/sampling/demo_wmc_direct_sampler.py +40 -0
  173. ck_demos/utils/__init__.py +0 -0
  174. ck_demos/utils/compare.py +120 -0
  175. ck_demos/utils/convert_network.py +45 -0
  176. ck_demos/utils/sample_model.py +216 -0
  177. ck_demos/utils/stop_watch.py +384 -0
  178. compiled_knowledge-4.0.0.dist-info/METADATA +50 -0
  179. compiled_knowledge-4.0.0.dist-info/RECORD +182 -0
  180. compiled_knowledge-4.0.0.dist-info/WHEEL +5 -0
  181. compiled_knowledge-4.0.0.dist-info/licenses/LICENSE.txt +21 -0
  182. compiled_knowledge-4.0.0.dist-info/top_level.txt +2 -0
@@ -0,0 +1,552 @@
1
+ from __future__ import annotations
2
+
3
+ import ctypes as ct
4
+ from dataclasses import dataclass
5
+ from typing import Sequence, Optional, Tuple, List, Dict
6
+
7
+ import llvmlite.binding as llvm
8
+ import llvmlite.ir as ir
9
+ import numpy as np
10
+
11
+ from .support.circuit_analyser import CircuitAnalysis, analyze_circuit
12
+ from .support.input_vars import InputVars, InferVars, infer_input_vars
13
+ from .support.llvm_ir_function import IRFunction, DataType, TypeInfo, compile_llvm_program, LLVMRawProgram
14
+ from ..circuit import ADD as _ADD, MUL as _MUL, ConstValue
15
+ from ..circuit import Circuit, VarNode, CircuitNode, OpNode
16
+ from ..program.raw_program import RawProgramFunction
17
+
18
# Default arithmetic data type info for compiled programs (`DataType.FLOAT_64`).
DEFAULT_TYPE_INFO: TypeInfo = DataType.FLOAT_64.value

# Byte code op codes.
# `_ADD` and `_MUL` are the circuit's own op symbols (imported from `..circuit`).
# `_END` is chosen to differ from both; it is the sentinel marking the end of the byte code.
_END: int = max(_ADD, _MUL) + 1

# Array selectors used in byte code elements. They index the interpreter's
# local table of array pointers: [vars, tmps, result, consts].
_VARS: int = 0
_TMPS: int = 1
_RESULT: int = 2
_CONSTS: int = 3

# Names of the LLVM functions generated (when arrays are not compiled into the
# module) that install pointers to the externally held consts and byte code arrays.
_SET_CONSTS_FUNCTION_NAME: str = 'set_consts'
_SET_INSTRUCTIONS_FUNCTION_NAME: str = 'set_instructions'
33
+
34
+
35
def compile_circuit(
        *result: CircuitNode,
        input_vars: InputVars = InferVars.ALL,
        circuit: Optional[Circuit] = None,
        data_type: DataType | TypeInfo = DEFAULT_TYPE_INFO,
        keep_llvm_program: bool = True,
        compile_arrays: bool = False,
        opt: int = 2,
) -> LLVMRawProgram:
    """
    Compile the given circuit to byte code executed by an LLVM-compiled interpreter.

    The circuit's op nodes are encoded as byte code, and a fixed byte code
    interpreter is generated as an LLVM function and JIT compiled. The
    interpreter itself does not grow with the number of op nodes; only the
    byte code and constants arrays do. Those arrays are embedded in the LLVM
    module when `compile_arrays` is true. Otherwise they are held by the
    returned program and installed into the module at run time.

    This compiler produces a RawProgram that _does_ use client managed working memory.

    Conforms to the CircuitCompiler protocol.

    Args:
        *result: result nodes nominating the results of the returned program.
        input_vars: How to determine the input variables.
        circuit: optionally explicitly specify the Circuit.
        data_type: What data type to use for arithmetic calculations. Either a DataType member or TypeInfo.
        keep_llvm_program: if true, the LLVM program text is kept on the returned program.
            This is required for pickling.
        compile_arrays: if true, the global array values (byte code and constants)
            are included in the LLVM program.
        opt: The optimization level to use by LLVM MC JIT.

    Returns:
        a raw program. When `compile_arrays` is false, this is an
        `LLVMRawProgramWithArrays`, which also carries the byte code and constants.

    Raises:
        ValueError: if the circuit is unknown, but it is needed.
        ValueError: if not all nodes are from the same circuit.
        ValueError: if the program data type could not be interpreted.
    """
    in_vars: Sequence[VarNode] = infer_input_vars(circuit, result, input_vars)
    var_indices: Sequence[int] = tuple(var.idx for var in in_vars)

    # Get the type info
    type_info: TypeInfo
    if isinstance(data_type, DataType):
        type_info = data_type.value
    elif isinstance(data_type, TypeInfo):
        type_info = data_type
    else:
        raise ValueError(f'could not interpret program data type: {data_type!r}')

    # Build the LLVM module (the byte code interpreter) and the byte code itself.
    llvm_program: str
    number_of_tmps: int
    llvm_program, number_of_tmps, consts, byte_code = _make_llvm_program(in_vars, result, type_info, compile_arrays)

    # Compile the LLVM program to a native executable
    engine: llvm.ExecutionEngine
    function: RawProgramFunction
    engine, function = compile_llvm_program(llvm_program, dtype=type_info.dtype, opt=opt)

    # Constructor arguments shared by both program classes.
    program_args = dict(
        function=function,
        dtype=type_info.dtype,
        number_of_vars=len(var_indices),
        number_of_tmps=number_of_tmps,
        number_of_results=len(result),
        var_indices=var_indices,
        llvm_program=llvm_program if keep_llvm_program else None,
        engine=engine,
        opt=opt,
    )

    if compile_arrays:
        return LLVMRawProgram(**program_args)

    # Arrays `consts` and `byte_code` are not compiled into the LLVM program
    # so they need to be stored explicitly.
    return LLVMRawProgramWithArrays(
        **program_args,
        instructions=np.array(byte_code, dtype=np.uint8),
        consts=np.array(consts, dtype=type_info.dtype),
    )
122
+
123
+
124
@dataclass
class LLVMRawProgramWithArrays(LLVMRawProgram):
    """
    An LLVMRawProgram whose byte code and constants are held outside the LLVM module.

    The compiled module contains only global pointers to these arrays. The
    pointers are installed by calling the module's `set_instructions` and
    `set_consts` functions. This happens after construction (`__post_init__`)
    and again after unpickling (`__setstate__`).

    Because the module keeps raw pointers into the arrays' buffers, the arrays
    held by this object must not be replaced without re-installing the pointers.
    """
    instructions: np.ndarray  # byte code; the interpreter reads it as unsigned bytes
    consts: np.ndarray  # constant values; the interpreter reads them as `dtype`

    def dump(self, *, prefix: str = '', indent: str = '    ', show_instructions: bool = True) -> None:
        super().dump(prefix=prefix, indent=indent)
        print(f'{prefix}LLVM byte code size = {len(self.instructions)}')
        if show_instructions:
            self.dump_llvm_program(prefix=prefix, indent=indent)

    def __post_init__(self):
        self._bind_arrays()

    def __getstate__(self):
        state = super().__getstate__()
        state['instructions'] = self.instructions
        state['consts'] = self.consts
        return state

    def __setstate__(self, state):
        super().__setstate__(state)
        self.instructions = state['instructions']
        self.consts = state['consts']
        self._bind_arrays()

    def _bind_arrays(self) -> None:
        """
        Normalise the held arrays and install pointers to them into the LLVM module.

        The generated code dereferences raw pointers, so each buffer must be
        C-contiguous and have the element type the interpreter reads. The
        normalised arrays are stored back on `self`, so the buffers stay alive
        for as long as this program object does.
        """
        # np.ascontiguousarray returns the same object when no conversion is needed.
        self.instructions = np.ascontiguousarray(self.instructions, dtype=np.uint8)
        self.consts = np.ascontiguousarray(self.consts, dtype=self.dtype)
        self._set_globals(self.instructions, _SET_INSTRUCTIONS_FUNCTION_NAME)
        self._set_globals(self.consts, _SET_CONSTS_FUNCTION_NAME)

    def _set_globals(self, data: np.ndarray, func_name: str) -> None:
        """
        Call the module function `func_name`, passing a pointer to `data`'s buffer.

        Args:
            data: a C-contiguous array whose buffer must outlive any use of the program.
            func_name: name of a `void f(T*)` function in the compiled module.
        """
        ptr_type = ct.POINTER(np.ctypeslib.as_ctypes_type(data.dtype))
        c_np_data = data.ctypes.data_as(ptr_type)

        function_ptr = self.engine.get_function_address(func_name)
        function = ct.CFUNCTYPE(None, ptr_type)(function_ptr)
        function(c_np_data)
159
+
160
+
161
def _make_llvm_program(
        in_vars: Sequence[VarNode],
        result: Sequence[CircuitNode],
        type_info: TypeInfo,
        compile_arrays: bool,
) -> Tuple[str, int, List[ConstValue], List[int]]:
    """
    Construct the LLVM program (i.e., LLVM module) for the byte code interpreter.

    The circuit is analysed and encoded as byte code. An interpreter for that
    byte code is then written into the module's function. Each integer width
    used in the byte code (array index, argument count, byte code position) is
    the smallest of 1, 2, 4 or 8 bytes that suffices (see `_get_bytes_needed`).

    Args:
        in_vars: the input variable nodes.
        result: the result nodes.
        type_info: the arithmetic data type.
        compile_arrays: if true, the constants and byte code are embedded in the
            module as constant global arrays. Otherwise the module holds
            null-initialised global pointers, plus `set_consts` and
            `set_instructions` functions to install the arrays at run time.

    Returns:
        (llvm_program, number_of_tmps, consts, byte_code)
        llvm_program: the LLVM program text.
        number_of_tmps: the number of temporary values the program needs.
        consts: the constant values, in `analysis.const_nodes` order.
        byte_code: the byte code, one int per byte.
    """
    llvm_function = IRFunction(type_info)

    builder = llvm_function.builder
    type_info = llvm_function.type_info
    module = llvm_function.module

    analysis: CircuitAnalysis = analyze_circuit(in_vars, result)
    const_values: List[ConstValue] = [const_node.value for const_node in analysis.const_nodes]

    # Width of an array index in the byte code: must address the largest array.
    max_index_size: int = max(
        len(analysis.var_nodes),  # number of inputs
        len(analysis.result_nodes),  # number of outputs
        len(analysis.op_to_tmp),  # number of tmps
        len(analysis.const_nodes),  # number of constants
    )
    data_idx_bytes: int = _get_bytes_needed(max_index_size)

    # Width of an op's argument count in the byte code.
    max_num_args: int = max((len(op_node.args) for op_node in analysis.op_nodes), default=0)
    num_args_bytes: int = _get_bytes_needed(max_num_args)

    data_type: ir.Type = type_info.llvm_type
    byte_type: ir.Type = ir.IntType(8)
    data_idx_type: ir.Type = ir.IntType(data_idx_bytes * 8)

    byte_code: List[int] = _make_byte_code(analysis, data_idx_bytes, num_args_bytes)

    # Width of a byte code position: must address every byte of the byte code.
    inst_idx_bytes: int = _get_bytes_needed(len(byte_code))
    inst_idx_type: ir.Type = ir.IntType(inst_idx_bytes * 8)

    if compile_arrays:
        # Put constants into the LLVM module
        consts_array_type = ir.ArrayType(data_type, len(analysis.const_nodes))
        consts_global = ir.GlobalVariable(module, consts_array_type, name='consts')
        consts_global.global_constant = True
        consts_global.initializer = ir.Constant(consts_array_type, const_values)
        data_idx_0 = ir.Constant(data_idx_type, 0)
        # pointer to the first element of the global array
        consts: ir.Value = builder.gep(consts_global, [data_idx_0, data_idx_0])

        # Put bytecode into the LLVM module
        instructions_array_type = ir.ArrayType(byte_type, len(byte_code))
        instructions_global = ir.GlobalVariable(module, instructions_array_type, name='instructions')
        instructions_global.global_constant = True
        instructions_global.initializer = ir.Constant(instructions_array_type, byte_code)
        inst_idx_0 = ir.Constant(inst_idx_type, 0)
        # pointer to the first byte of the global array
        instructions: ir.Value = builder.gep(instructions_global, [inst_idx_0, inst_idx_0])
    else:
        # Just create two global pointer variables (initially null) that will be set externally.
        # The interpreter loads their current values on entry.
        const_ptr_type = data_type.as_pointer()
        consts_global = ir.GlobalVariable(module, const_ptr_type, name='consts')
        consts_global.initializer = ir.Constant(const_ptr_type, None)
        consts: ir.Value = builder.load(consts_global)

        instructions_ptr_type = byte_type.as_pointer()
        instructions_global = ir.GlobalVariable(module, instructions_ptr_type, name='instructions')
        instructions_global.initializer = ir.Constant(instructions_ptr_type, None)
        instructions: ir.Value = builder.load(instructions_global)

    interp = _InterpBuilder(builder, type_info, inst_idx_type, data_idx_bytes, num_args_bytes, consts, instructions)
    interp.make_interpreter()

    if not compile_arrays:
        # Add functions to set the global arrays. This must come after
        # make_interpreter: each setter repositions the shared builder into a
        # new function.
        interp.make_set_consts_function(consts_global)
        interp.make_set_instructions_function(instructions_global)

    return llvm_function.llvm_program(), len(analysis.op_to_tmp), const_values, byte_code
242
+
243
+
244
class _InterpBuilder:
    """
    Helper to write the LLVM function for the byte code interpreter.

    The interpreter is written into `builder.function`. That function's first
    three arguments are used as the `vars`, `tmps` and `result` arrays; the
    `consts` pointer supplied to the constructor is the fourth array.

    Byte code, as read by `make_interpreter`, is a sequence of ops, each of the form:
        op_code (1 byte) | num_args (num_args_bytes) | num_args x element | result element
    An element is `array selector (1 byte) | array index (index_bytes)`.
    Multi-byte numbers are read most significant byte first. An op code equal
    to `_END` stops the interpreter.
    """

    def __init__(
            self,
            builder: ir.IRBuilder,
            type_info: TypeInfo,
            inst_idx_type: ir.Type,
            index_bytes: int,
            num_args_bytes: int,
            consts: ir.Value,
            instructions: ir.Value,
    ):
        """
        Set up the interpreter's locals in the builder's current function.

        Args:
            builder: IR builder positioned in the function to receive the interpreter.
            type_info: arithmetic data type; provides `llvm_type`, `add` and `mul`.
            inst_idx_type: integer type used for byte code positions.
            index_bytes: width in bytes of an array index within the byte code.
            num_args_bytes: width in bytes of an op's argument count within the byte code.
            consts: pointer to the constants array.
            instructions: pointer to the first byte of the byte code.
        """
        self.builder: ir.IRBuilder = builder
        self.index_bytes: int = index_bytes
        self.num_args_bytes: int = num_args_bytes
        self.type_info: TypeInfo = type_info

        self.data_type: ir.Type = type_info.llvm_type
        self.byte_type: ir.Type = ir.IntType(8)
        self.inst_idx_type: ir.Type = inst_idx_type
        self.data_idx_type: ir.Type = ir.IntType(index_bytes * 8)
        self.num_args_type: ir.Type = ir.IntType(num_args_bytes * 8)

        self.data_idx_0 = ir.Constant(self.data_idx_type, 0)
        self.data_idx_1 = ir.Constant(self.data_idx_type, 1)
        self.inst_idx_0 = ir.Constant(self.inst_idx_type, 0)
        self.inst_idx_1 = ir.Constant(self.inst_idx_type, 1)
        self.num_args_0 = ir.Constant(self.num_args_type, 0)
        self.num_args_1 = ir.Constant(self.num_args_type, 1)

        self.consts: ir.Value = consts
        self.instructions: ir.Value = instructions

        # allocate locals
        self.local_idx = builder.alloca(self.inst_idx_type, name='idx')
        self.local_num_args = builder.alloca(self.num_args_type, name='num_args')
        self.local_accumulator = builder.alloca(self.data_type, name='accumulator')
        self.local_arrays = builder.alloca(self.data_type.as_pointer(), size=4, name='arrays')

        # local_arrays = [vars, tmps, result, consts]
        ir_vars_idx = ir.Constant(self.byte_type, _VARS)
        ir_tmps_idx = ir.Constant(self.byte_type, _TMPS)
        ir_result_idx = ir.Constant(self.byte_type, _RESULT)
        ir_consts_idx = ir.Constant(self.byte_type, _CONSTS)
        function: ir.Function = builder.function
        local_arrays = self.local_arrays
        builder.store(function.args[0], builder.gep(local_arrays, [ir_vars_idx]))
        builder.store(function.args[1], builder.gep(local_arrays, [ir_tmps_idx]))
        builder.store(function.args[2], builder.gep(local_arrays, [ir_result_idx]))
        builder.store(consts, builder.gep(local_arrays, [ir_consts_idx]))

        # local_idx = 0
        builder.store(self.inst_idx_0, self.local_idx)

    def make_set_consts_function(self, consts_ptr: ir.GlobalVariable):
        """
        Write `void set_consts(T* ptr)`, which stores `ptr` into the global `consts_ptr`.

        Leaves the builder positioned inside the new function.
        """
        self._make_setter_function(_SET_CONSTS_FUNCTION_NAME, self.data_type.as_pointer(), consts_ptr)

    def make_set_instructions_function(self, instructions_ptr: ir.GlobalVariable):
        """
        Write `void set_instructions(i8* ptr)`, which stores `ptr` into the global `instructions_ptr`.

        Leaves the builder positioned inside the new function.
        """
        self._make_setter_function(_SET_INSTRUCTIONS_FUNCTION_NAME, self.byte_type.as_pointer(), instructions_ptr)

    def _make_setter_function(self, name: str, arg_type: ir.Type, global_ptr: ir.GlobalVariable) -> None:
        """Write `void name(arg_type value)` that stores `value` into `global_ptr`."""
        builder = self.builder
        function_type = ir.FunctionType(ir.VoidType(), (arg_type,))
        function = ir.Function(builder.module, function_type, name=name)
        bb_entry = function.append_basic_block('entry')
        builder.position_at_end(bb_entry)
        builder.store(function.args[0], global_ptr)
        builder.ret_void()

    def add(self, x: ir.Value, y: ir.Value) -> ir.Value:
        return self.type_info.add(self.builder, x, y)

    def mul(self, x: ir.Value, y: ir.Value) -> ir.Value:
        return self.type_info.mul(self.builder, x, y)

    def make_interpreter(self):
        """
        Write the bytecode interpreter.

        For each op, the interpreter loads the first argument into an
        accumulator and folds the remaining arguments in with `mul` or `add`
        (any op code other than `_MUL` takes the add path). It then stores the
        accumulator to the op's result element.
        """
        builder: ir.IRBuilder = self.builder
        function: ir.Function = builder.function

        bb_while = function.append_basic_block('while')
        bb_body = function.append_basic_block('body')
        bb_mul = function.append_basic_block('mul')
        bb_mul_op = function.append_basic_block('mul_op')
        bb_add = function.append_basic_block('add')
        bb_add_op = function.append_basic_block('add_op')
        bb_op_continue = function.append_basic_block('op_continue')
        bb_finish = function.append_basic_block('finish')

        # block: entry
        # (locals already set up in the constructor)
        builder.branch(bb_while)

        # block: while - fetch the op code; stop on the _END sentinel
        builder.position_at_end(bb_while)
        idx = builder.load(self.local_idx)
        inst = builder.load(builder.gep(self.instructions, [idx]))
        idx = builder.add(idx, self.inst_idx_1)
        is_end = builder.icmp_unsigned('==', inst, ir.Constant(self.byte_type, _END))
        builder.cbranch(is_end, bb_finish, bb_body)

        # block: body - read the arg count and load the first arg into the accumulator
        # NOTE(review): the first arg is loaded unconditionally and the count is
        # decremented as an unsigned value. An op with zero args would therefore
        # misread its result element as an arg and loop on a wrapped count.
        # Presumably the circuit analysis never emits such ops - confirm.
        builder.position_at_end(bb_body)
        idx, num_args = self._read_number(idx, self.num_args_bytes)
        builder.store(num_args, self.local_num_args)
        idx, arg0 = self._load_value(idx)
        builder.store(arg0, self.local_accumulator)
        # save the current bytecode index
        builder.store(idx, self.local_idx)
        is_mul = builder.icmp_unsigned('==', inst, ir.Constant(self.byte_type, _MUL))
        builder.cbranch(is_mul, bb_mul, bb_add)

        # blocks: mul, mul_op, add, add_op - fold in the remaining args
        self._write_op_loop(bb_mul, bb_mul_op, bb_op_continue, self.mul)
        self._write_op_loop(bb_add, bb_add_op, bb_op_continue, self.add)

        # block: op_continue - store the accumulator to the op's result element
        builder.position_at_end(bb_op_continue)
        idx = builder.load(self.local_idx)
        idx, ptr = self._load_value_ptr(idx)
        builder.store(idx, self.local_idx)
        acc = builder.load(self.local_accumulator)
        builder.store(acc, ptr)
        builder.branch(bb_while)

        # block: finish
        builder.position_at_end(bb_finish)
        builder.ret_void()

    def _write_op_loop(self, bb_loop, bb_op, bb_done, combine) -> None:
        """
        Write the loop that folds an op's remaining args into the accumulator.

        Args:
            bb_loop: block that decrements the arg count and tests for more args.
            bb_op: block that loads the next arg and combines it into the accumulator.
            bb_done: block to branch to when no args remain.
            combine: `self.mul` or `self.add`, called as combine(accumulator, value).
        """
        builder = self.builder

        builder.position_at_end(bb_loop)
        num_args = builder.load(self.local_num_args)
        num_args = builder.sub(num_args, self.num_args_1)
        builder.store(num_args, self.local_num_args)
        more_args = builder.icmp_unsigned('>', num_args, self.num_args_0)
        builder.cbranch(more_args, bb_op, bb_done)

        builder.position_at_end(bb_op)
        idx = builder.load(self.local_idx)
        idx, value = self._load_value(idx)
        acc = builder.load(self.local_accumulator)
        acc = combine(acc, value)
        builder.store(acc, self.local_accumulator)
        builder.store(idx, self.local_idx)
        builder.branch(bb_loop)

    def _read_number(self, idx: ir.Value, num_bytes: int) -> Tuple[ir.Value, ir.Value]:
        """
        Emit code to read a multi-byte number, most significant byte first.

        Args:
            idx: current instruction index
            num_bytes: how many bytes to read from the instruction stream to form the number

        Returns:
            (idx, number)
            idx: is the updated instruction index
            number: is the read number (an i8 when num_bytes == 1, else an integer of num_bytes * 8 bits)
        """
        builder = self.builder

        llvm_type: ir.Type = ir.IntType(num_bytes * 8)

        number: ir.Value = builder.load(builder.gep(self.instructions, [idx]))
        idx = builder.add(idx, self.inst_idx_1)

        if num_bytes > 1:
            eight = ir.Constant(llvm_type, 8)
            number = builder.zext(number, llvm_type)
            for _ in range(num_bytes - 1):
                next_byte = builder.load(builder.gep(self.instructions, [idx]))
                number = builder.add(builder.shl(number, eight), builder.zext(next_byte, llvm_type))
                idx = builder.add(idx, self.inst_idx_1)

        return idx, number

    def _load_value_ptr(self, idx: ir.Value) -> Tuple[ir.Value, ir.Value]:
        """
        Emit code to decode an element (array selector + index) into a pointer.

        Returns:
            (idx, ptr): the updated instruction index and a pointer to the element.
        """
        builder = self.builder

        # load array first index
        index_0 = builder.load(builder.gep(self.instructions, [idx]))
        idx = builder.add(idx, self.inst_idx_1)

        # load array second index
        idx, index_1 = self._read_number(idx, self.index_bytes)

        # get the pointer
        array = builder.load(builder.gep(self.local_arrays, [index_0]))
        ptr = builder.gep(array, [index_1])

        return idx, ptr

    def _load_value(self, idx: ir.Value) -> Tuple[ir.Value, ir.Value]:
        """Emit code to decode an element and load its value; returns (idx, value)."""
        idx, ptr = self._load_value_ptr(idx)
        value = self.builder.load(ptr)
        return idx, value
473
+
474
+
475
+ @dataclass
476
+ class _ElementID:
477
+ """
478
+ A 2D index into the function's `arrays`.
479
+ """
480
+ array: int # which array: VARS, TMPS, CONSTS, RESULT
481
+ index: int # index into the array
482
+
483
+
484
def _make_byte_code(analysis: CircuitAnalysis, data_idx_bytes: int, num_args_bytes: int) -> List[int]:
    """
    Encode the analysed circuit as byte code for the LLVM interpreter.

    For each op node, in `analysis.op_nodes` order, the encoding is:
        op code (the op node's symbol, `_ADD` or `_MUL`)
        number of args (`num_args_bytes` bytes)
        one element per arg
        one element for where the op's value is stored
    An element is an array selector (`_VARS`, `_TMPS`, `_RESULT` or `_CONSTS`)
    followed by an index of `data_idx_bytes` bytes. Multi-byte numbers are
    big-endian (see `_to_bytes`).

    Result nodes that are not op nodes are then copied into their result slot
    using a one-argument `_ADD`. The byte code ends with the `_END` sentinel.

    Args:
        analysis: the circuit analysis to encode.
        data_idx_bytes: bytes per array index.
        num_args_bytes: bytes per argument count.

    Returns:
        the byte code, one int per byte.
    """
    node_to_element: Dict[int, _ElementID] = _index_elements(analysis)

    byte_code: List[int] = []
    for op_node in analysis.op_nodes:
        byte_code.append(op_node.symbol)  # _ADD or _MUL
        byte_code.extend(_to_bytes(len(op_node.args), num_args_bytes))
        for arg_node in op_node.args:
            _append_element(byte_code, node_to_element[id(arg_node)], data_idx_bytes)
        _append_element(byte_code, node_to_element[id(op_node)], data_idx_bytes)

    # ...any final copy instructions
    for result_idx, node in enumerate(analysis.result_nodes):
        if not isinstance(node, OpNode):
            byte_code.append(_ADD)
            byte_code.extend(_to_bytes(1, num_args_bytes))
            _append_element(byte_code, node_to_element[id(node)], data_idx_bytes)
            _append_element(byte_code, _ElementID(_RESULT, result_idx), data_idx_bytes)

    # write the sentinel - 'end' op code
    byte_code.append(_END)

    return byte_code


def _index_elements(analysis: CircuitAnalysis) -> Dict[int, _ElementID]:
    """Map id(node) to the array element holding that node's value."""
    node_to_element: Dict[int, _ElementID] = {}
    for i, const_node in enumerate(analysis.const_nodes):
        node_to_element[id(const_node)] = _ElementID(_CONSTS, i)
    for i, var_node in enumerate(analysis.var_nodes):
        if var_node.is_const():
            # a var fixed to a constant reads that constant's element
            node_to_element[id(var_node)] = node_to_element[id(var_node.const)]
        else:
            node_to_element[id(var_node)] = _ElementID(_VARS, i)
    for node_id, tmp_idx in analysis.op_to_tmp.items():
        node_to_element[node_id] = _ElementID(_TMPS, tmp_idx)
    # applied after tmps, so a result mapping takes precedence
    for node_id, result_idx in analysis.op_to_result.items():
        node_to_element[node_id] = _ElementID(_RESULT, result_idx)
    return node_to_element


def _append_element(byte_code: List[int], element_id: _ElementID, data_idx_bytes: int) -> None:
    """Append an element (array selector byte, then big-endian index) to `byte_code`."""
    byte_code.append(element_id.array)
    byte_code.extend(_to_bytes(element_id.index, data_idx_bytes))
535
+
536
+
537
+ def _to_bytes(value: int, num_bytes: int) -> List[int]:
538
+ buffer: List[int] = []
539
+ for _ in range(num_bytes):
540
+ buffer.append(value % 256)
541
+ value //= 256
542
+ assert value == 0
543
+ buffer.reverse()
544
+ return buffer
545
+
546
+
547
+ def _get_bytes_needed(size: int) -> int:
548
+ index_bytes: int
549
+ for index_bytes in [1, 2, 4, 8]:
550
+ if size < 2 ** (index_bytes * 8 - 1):
551
+ return index_bytes
552
+ raise ValueError(f'size are too large to represent: {size}')
@@ -0,0 +1,57 @@
1
+ from enum import Enum
2
+ from functools import partial
3
+ from typing import Optional
4
+
5
+ from .llvm_compiler import Flavour
6
+ from ..circuit import CircuitNode, Circuit
7
+ from ..circuit_compiler import interpret_compiler, cython_vm_compiler, llvm_compiler, llvm_vm_compiler, CircuitCompiler
8
+ from ..circuit_compiler.support.input_vars import InputVars, InferVars
9
+ from ..program import RawProgram
10
+
11
+
12
class NamedCircuitCompiler(Enum):
    """
    Standard, named circuit compiler functions.

    Each member's `value` is a one-element tuple holding the compiler function.
    The tuple wrapper is deliberate: a function assigned directly in an Enum
    body is treated as a method rather than becoming a member. Wrapping keeps
    every entry a genuine member.

    Members are themselves callable. They conform to the CircuitCompiler
    protocol by delegating to the wrapped function.
    """

    LLVM_STACK = (partial(llvm_compiler.compile_circuit, flavour=Flavour.STACK),)
    LLVM_TMPS = (partial(llvm_compiler.compile_circuit, flavour=Flavour.TMPS, opt=0),)
    LLVM_VM = (llvm_vm_compiler.compile_circuit,)
    CYTHON_VM = (cython_vm_compiler.compile_circuit,)
    INTERPRET = (interpret_compiler.compile_circuit,)

    # The following circuit compilers were experimental but are not really useful.
    #
    # Slow compile and execution:
    # LLVM_FUNCS = (partial(llvm_compiler.compile_circuit, flavour=Flavour.FUNCS, opt=0),)
    #
    # Slow compile and same execution as LLVM_VM:
    # LLVM_VM_COMPILED_ARRAYS = (partial(llvm_vm_compiler.compile_circuit, compile_arrays=True),)

    def __call__(
            self,
            *result: CircuitNode,
            input_vars: InputVars = InferVars.ALL,
            circuit: Optional[Circuit] = None,
    ) -> RawProgram:
        """
        Compile the given result nodes with this member's compiler function.

        This makes every member satisfy the `CircuitCompiler` protocol.
        """
        compiler_function = self.compiler
        return compiler_function(*result, circuit=circuit, input_vars=input_vars)

    @property
    def compiler(self) -> CircuitCompiler:
        """
        Returns:
            The wrapped compiler function, conforming to the CircuitCompiler protocol.
        """
        (compiler_function,) = self.value
        return compiler_function
55
+
56
+
57
# The default named circuit compiler (the Cython VM); presumably used by
# callers that do not nominate a compiler - confirm against call sites.
DEFAULT_CIRCUIT_COMPILER: NamedCircuitCompiler = NamedCircuitCompiler.CYTHON_VM
File without changes
@@ -0,0 +1,13 @@
1
+ # Two implementations of the `circuit_analyser` module are provided
2
+ # for developer R&D purposes. One is pure Python and the other is Cython.
3
+ # Which implementation is used can be selected here.
4
+ #
5
+ # A similar selection can be made for the `circuit` module.
6
+ # Note that if the Cython implementation is chosen for `circuit_analyser` then
7
+ # the Cython implementation must be chosen for `circuit`.
8
+
9
+ # from ._circuit_analyser_py import (
10
+ from ._circuit_analyser_cy import (
11
+ CircuitAnalysis,
12
+ analyze_circuit,
13
+ )