compiled-knowledge 4.0.0a20__cp313-cp313-musllinux_1_2_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of compiled-knowledge might be problematic. Click here for more details.
- ck/__init__.py +0 -0
- ck/circuit/__init__.py +17 -0
- ck/circuit/_circuit_cy.c +37520 -0
- ck/circuit/_circuit_cy.cpython-313-x86_64-linux-musl.so +0 -0
- ck/circuit/_circuit_cy.pxd +32 -0
- ck/circuit/_circuit_cy.pyx +768 -0
- ck/circuit/_circuit_py.py +836 -0
- ck/circuit/tmp_const.py +74 -0
- ck/circuit_compiler/__init__.py +2 -0
- ck/circuit_compiler/circuit_compiler.py +26 -0
- ck/circuit_compiler/cython_vm_compiler/__init__.py +1 -0
- ck/circuit_compiler/cython_vm_compiler/_compiler.c +19821 -0
- ck/circuit_compiler/cython_vm_compiler/_compiler.cpython-313-x86_64-linux-musl.so +0 -0
- ck/circuit_compiler/cython_vm_compiler/_compiler.pyx +380 -0
- ck/circuit_compiler/cython_vm_compiler/cython_vm_compiler.py +121 -0
- ck/circuit_compiler/interpret_compiler.py +223 -0
- ck/circuit_compiler/llvm_compiler.py +388 -0
- ck/circuit_compiler/llvm_vm_compiler.py +546 -0
- ck/circuit_compiler/named_circuit_compilers.py +57 -0
- ck/circuit_compiler/support/__init__.py +0 -0
- ck/circuit_compiler/support/circuit_analyser/__init__.py +13 -0
- ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.c +10615 -0
- ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.cpython-313-x86_64-linux-musl.so +0 -0
- ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.pyx +98 -0
- ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_py.py +93 -0
- ck/circuit_compiler/support/input_vars.py +148 -0
- ck/circuit_compiler/support/llvm_ir_function.py +234 -0
- ck/example/__init__.py +53 -0
- ck/example/alarm.py +366 -0
- ck/example/asia.py +28 -0
- ck/example/binary_clique.py +32 -0
- ck/example/bow_tie.py +33 -0
- ck/example/cancer.py +37 -0
- ck/example/chain.py +38 -0
- ck/example/child.py +199 -0
- ck/example/clique.py +33 -0
- ck/example/cnf_pgm.py +39 -0
- ck/example/diamond_square.py +68 -0
- ck/example/earthquake.py +36 -0
- ck/example/empty.py +10 -0
- ck/example/hailfinder.py +539 -0
- ck/example/hepar2.py +628 -0
- ck/example/insurance.py +504 -0
- ck/example/loop.py +40 -0
- ck/example/mildew.py +38161 -0
- ck/example/munin.py +22982 -0
- ck/example/pathfinder.py +53747 -0
- ck/example/rain.py +39 -0
- ck/example/rectangle.py +161 -0
- ck/example/run.py +30 -0
- ck/example/sachs.py +129 -0
- ck/example/sprinkler.py +30 -0
- ck/example/star.py +44 -0
- ck/example/stress.py +64 -0
- ck/example/student.py +43 -0
- ck/example/survey.py +46 -0
- ck/example/triangle_square.py +54 -0
- ck/example/truss.py +49 -0
- ck/in_out/__init__.py +3 -0
- ck/in_out/parse_ace_lmap.py +216 -0
- ck/in_out/parse_ace_nnf.py +322 -0
- ck/in_out/parse_net.py +480 -0
- ck/in_out/parser_utils.py +185 -0
- ck/in_out/pgm_pickle.py +42 -0
- ck/in_out/pgm_python.py +268 -0
- ck/in_out/render_bugs.py +111 -0
- ck/in_out/render_net.py +177 -0
- ck/in_out/render_pomegranate.py +184 -0
- ck/pgm.py +3475 -0
- ck/pgm_circuit/__init__.py +1 -0
- ck/pgm_circuit/marginals_program.py +352 -0
- ck/pgm_circuit/mpe_program.py +237 -0
- ck/pgm_circuit/pgm_circuit.py +79 -0
- ck/pgm_circuit/program_with_slotmap.py +236 -0
- ck/pgm_circuit/slot_map.py +35 -0
- ck/pgm_circuit/support/__init__.py +0 -0
- ck/pgm_circuit/support/compile_circuit.py +83 -0
- ck/pgm_circuit/target_marginals_program.py +103 -0
- ck/pgm_circuit/wmc_program.py +323 -0
- ck/pgm_compiler/__init__.py +2 -0
- ck/pgm_compiler/ace/__init__.py +1 -0
- ck/pgm_compiler/ace/ace.py +299 -0
- ck/pgm_compiler/factor_elimination.py +395 -0
- ck/pgm_compiler/named_pgm_compilers.py +63 -0
- ck/pgm_compiler/pgm_compiler.py +19 -0
- ck/pgm_compiler/recursive_conditioning.py +231 -0
- ck/pgm_compiler/support/__init__.py +0 -0
- ck/pgm_compiler/support/circuit_table/__init__.py +17 -0
- ck/pgm_compiler/support/circuit_table/_circuit_table_cy.c +16393 -0
- ck/pgm_compiler/support/circuit_table/_circuit_table_cy.cpython-313-x86_64-linux-musl.so +0 -0
- ck/pgm_compiler/support/circuit_table/_circuit_table_cy.pyx +332 -0
- ck/pgm_compiler/support/circuit_table/_circuit_table_py.py +304 -0
- ck/pgm_compiler/support/clusters.py +568 -0
- ck/pgm_compiler/support/factor_tables.py +406 -0
- ck/pgm_compiler/support/join_tree.py +332 -0
- ck/pgm_compiler/support/named_compiler_maker.py +43 -0
- ck/pgm_compiler/variable_elimination.py +91 -0
- ck/probability/__init__.py +0 -0
- ck/probability/empirical_probability_space.py +50 -0
- ck/probability/pgm_probability_space.py +32 -0
- ck/probability/probability_space.py +622 -0
- ck/program/__init__.py +3 -0
- ck/program/program.py +137 -0
- ck/program/program_buffer.py +180 -0
- ck/program/raw_program.py +67 -0
- ck/sampling/__init__.py +0 -0
- ck/sampling/forward_sampler.py +211 -0
- ck/sampling/marginals_direct_sampler.py +113 -0
- ck/sampling/sampler.py +62 -0
- ck/sampling/sampler_support.py +232 -0
- ck/sampling/uniform_sampler.py +72 -0
- ck/sampling/wmc_direct_sampler.py +171 -0
- ck/sampling/wmc_gibbs_sampler.py +153 -0
- ck/sampling/wmc_metropolis_sampler.py +165 -0
- ck/sampling/wmc_rejection_sampler.py +115 -0
- ck/utils/__init__.py +0 -0
- ck/utils/iter_extras.py +163 -0
- ck/utils/local_config.py +270 -0
- ck/utils/map_list.py +128 -0
- ck/utils/map_set.py +128 -0
- ck/utils/np_extras.py +51 -0
- ck/utils/random_extras.py +64 -0
- ck/utils/tmp_dir.py +94 -0
- ck_demos/__init__.py +0 -0
- ck_demos/ace/__init__.py +0 -0
- ck_demos/ace/copy_ace_to_ck.py +15 -0
- ck_demos/ace/demo_ace.py +49 -0
- ck_demos/all_demos.py +88 -0
- ck_demos/circuit/__init__.py +0 -0
- ck_demos/circuit/demo_circuit_dump.py +22 -0
- ck_demos/circuit/demo_derivatives.py +43 -0
- ck_demos/circuit_compiler/__init__.py +0 -0
- ck_demos/circuit_compiler/compare_circuit_compilers.py +32 -0
- ck_demos/circuit_compiler/show_llvm_program.py +26 -0
- ck_demos/pgm/__init__.py +0 -0
- ck_demos/pgm/demo_pgm_dump.py +18 -0
- ck_demos/pgm/demo_pgm_dump_stress.py +18 -0
- ck_demos/pgm/demo_pgm_string_rendering.py +15 -0
- ck_demos/pgm/show_examples.py +25 -0
- ck_demos/pgm_compiler/__init__.py +0 -0
- ck_demos/pgm_compiler/compare_pgm_compilers.py +63 -0
- ck_demos/pgm_compiler/demo_compiler_dump.py +60 -0
- ck_demos/pgm_compiler/demo_factor_elimination.py +47 -0
- ck_demos/pgm_compiler/demo_join_tree.py +25 -0
- ck_demos/pgm_compiler/demo_marginals_program.py +53 -0
- ck_demos/pgm_compiler/demo_mpe_program.py +55 -0
- ck_demos/pgm_compiler/demo_pgm_compiler.py +38 -0
- ck_demos/pgm_compiler/demo_recursive_conditioning.py +33 -0
- ck_demos/pgm_compiler/demo_variable_elimination.py +33 -0
- ck_demos/pgm_compiler/demo_wmc_program.py +29 -0
- ck_demos/pgm_compiler/time_fe_compiler.py +93 -0
- ck_demos/pgm_inference/__init__.py +0 -0
- ck_demos/pgm_inference/demo_inferencing_basic.py +188 -0
- ck_demos/pgm_inference/demo_inferencing_mpe_cancer.py +45 -0
- ck_demos/pgm_inference/demo_inferencing_wmc_and_mpe_sprinkler.py +154 -0
- ck_demos/pgm_inference/demo_inferencing_wmc_student.py +110 -0
- ck_demos/programs/__init__.py +0 -0
- ck_demos/programs/demo_program_buffer.py +24 -0
- ck_demos/programs/demo_program_multi.py +24 -0
- ck_demos/programs/demo_program_none.py +19 -0
- ck_demos/programs/demo_program_single.py +23 -0
- ck_demos/programs/demo_raw_program_interpreted.py +21 -0
- ck_demos/programs/demo_raw_program_llvm.py +21 -0
- ck_demos/sampling/__init__.py +0 -0
- ck_demos/sampling/check_sampler.py +71 -0
- ck_demos/sampling/demo_marginal_direct_sampler.py +40 -0
- ck_demos/sampling/demo_uniform_sampler.py +38 -0
- ck_demos/sampling/demo_wmc_direct_sampler.py +40 -0
- ck_demos/utils/__init__.py +0 -0
- ck_demos/utils/compare.py +120 -0
- ck_demos/utils/convert_network.py +45 -0
- ck_demos/utils/sample_model.py +216 -0
- ck_demos/utils/stop_watch.py +384 -0
- compiled_knowledge-4.0.0a20.dist-info/METADATA +50 -0
- compiled_knowledge-4.0.0a20.dist-info/RECORD +178 -0
- compiled_knowledge-4.0.0a20.dist-info/WHEEL +5 -0
- compiled_knowledge-4.0.0a20.dist-info/licenses/LICENSE.txt +21 -0
- compiled_knowledge-4.0.0a20.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,546 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from typing import Sequence, Optional, Tuple, List, Dict
|
|
5
|
+
|
|
6
|
+
import llvmlite.binding as llvm
|
|
7
|
+
import llvmlite.ir as ir
|
|
8
|
+
import numpy as np
|
|
9
|
+
import ctypes as ct
|
|
10
|
+
|
|
11
|
+
from .support.circuit_analyser import CircuitAnalysis, analyze_circuit
|
|
12
|
+
from .support.input_vars import InputVars, InferVars, infer_input_vars
|
|
13
|
+
from .support.llvm_ir_function import IRFunction, DataType, TypeInfo, compile_llvm_program, LLVMRawProgram
|
|
14
|
+
from ..circuit import ADD as _ADD, MUL as _MUL, ConstValue
|
|
15
|
+
from ..circuit import Circuit, VarNode, CircuitNode, OpNode
|
|
16
|
+
from ..program.raw_program import RawProgramFunction
|
|
17
|
+
|
|
18
|
+
# Arithmetic type used by `compile_circuit` when the caller does not pass `data_type`.
DEFAULT_TYPE_INFO: TypeInfo = DataType.FLOAT_64.value

# Byte code op codes.
# The op codes for addition and multiplication are the circuit symbols imported
# above as `_ADD` and `_MUL`. `_END` is the sentinel op code that terminates a
# byte code program; it is chosen to be different from both of them.
_END: int = 1 + max(_ADD, _MUL)

# Identifiers of the arrays that byte code can address.
# Each identifier is the position of that array's pointer in the
# interpreter's local `arrays` table.
_VARS: int = 0
_TMPS: int = 1
_RESULT: int = 2
_CONSTS: int = 3

# Names of the LLVM functions used to install externally held arrays
# (only generated when the arrays are not compiled into the LLVM program).
_SET_CONSTS_FUNCTION_NAME: str = 'set_consts'
_SET_INSTRUCTIONS_FUNCTION_NAME: str = 'set_instructions'
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def compile_circuit(
    *result: CircuitNode,
    input_vars: InputVars = InferVars.ALL,
    circuit: Optional[Circuit] = None,
    data_type: DataType | TypeInfo = DEFAULT_TYPE_INFO,
    keep_llvm_program: bool = True,
    compile_arrays: bool = False,
    opt: int = 2,
) -> LLVMRawProgram:
    """
    Compile the given circuit to a byte code program run by an LLVM interpreter.

    The circuit is translated to a byte code array and a constants array, and
    an LLVM function is generated that interprets the byte code. The arrays
    are either embedded in the LLVM program (`compile_arrays` true) or held
    by the returned program object and handed to the compiled code through
    generated setter functions (`compile_arrays` false).

    This compiler produces a RawProgram that _does_ use client managed working memory.

    Conforms to the CircuitCompiler protocol.

    Args:
        *result: result nodes nominating the results of the returned program.
        input_vars: How to determine the input variables.
        circuit: optionally explicitly specify the Circuit.
        data_type: What data type to use for arithmetic calculations. Either a DataType member or TypeInfo.
        keep_llvm_program: if true, the LLVM program text is kept on the returned
            program (this is required for pickling), otherwise it is dropped.
        compile_arrays: if true, the global array values are included in the LLVM program.
        opt: The optimization level to use by LLVM MC JIT.

    Returns:
        a raw program; an `LLVMRawProgramWithArrays` when `compile_arrays` is false.

    Raises:
        ValueError: if the circuit is unknown, but it is needed.
        ValueError: if not all nodes are from the same circuit.
        ValueError: if the program data type could not be interpreted.
    """
    # Resolve the input variables first, so input errors are reported before type errors.
    in_vars: Sequence[VarNode] = infer_input_vars(circuit, result, input_vars)
    var_indices: Sequence[int] = tuple(var.idx for var in in_vars)

    # Normalise `data_type` to a TypeInfo.
    if not isinstance(data_type, (DataType, TypeInfo)):
        raise ValueError(f'could not interpret program data type: {data_type!r}')
    type_info: TypeInfo = data_type.value if isinstance(data_type, DataType) else data_type

    # Build the LLVM module (as text) representing a RawProgramFunction,
    # along with the arrays it interprets.
    llvm_program, number_of_tmps, consts, byte_code = _make_llvm_program(
        in_vars, result, type_info, compile_arrays
    )

    # JIT compile the LLVM program to native code.
    engine, function = compile_llvm_program(llvm_program, dtype=type_info.dtype, opt=opt)

    # Fields common to both kinds of returned program.
    program_fields = dict(
        function=function,
        dtype=type_info.dtype,
        number_of_vars=len(var_indices),
        number_of_tmps=number_of_tmps,
        number_of_results=len(result),
        var_indices=var_indices,
        llvm_program=llvm_program if keep_llvm_program else None,
        engine=engine,
        opt=opt,
    )

    if compile_arrays:
        # The arrays live inside the LLVM program; nothing more to keep.
        return LLVMRawProgram(**program_fields)

    # Arrays `consts` and `byte_code` are not compiled into the LLVM program
    # so they need to be stored explicitly.
    return LLVMRawProgramWithArrays(
        **program_fields,
        instructions=np.array(byte_code, dtype=np.uint8),
        consts=np.array(consts, dtype=type_info.dtype),
    )
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
@dataclass
class LLVMRawProgramWithArrays(LLVMRawProgram):
    """
    An LLVM raw program whose byte code and constants are held outside the
    LLVM program, as numpy arrays on this object.

    The compiled code is given a pointer to each array by calling the
    generated setter functions. This is done when the object is constructed
    and again when it is restored from pickled state.
    """
    instructions: np.ndarray  # the byte code interpreted by the compiled function
    consts: np.ndarray  # the constant values referenced by the byte code

    def __post_init__(self):
        self._install_arrays()

    def __getstate__(self):
        # Extend the parent's state with the two arrays held by this class.
        state = super().__getstate__()
        state.update(instructions=self.instructions, consts=self.consts)
        return state

    def __setstate__(self, state):
        super().__setstate__(state)
        self.instructions = state['instructions']
        self.consts = state['consts']
        # The engine restored by the parent needs to be told about the arrays again.
        self._install_arrays()

    def _install_arrays(self) -> None:
        """
        Hand both arrays to the compiled code (instructions first, then constants).
        """
        self._set_globals(self.instructions, _SET_INSTRUCTIONS_FUNCTION_NAME)
        self._set_globals(self.consts, _SET_CONSTS_FUNCTION_NAME)

    def _set_globals(self, data: np.ndarray, func_name: str) -> None:
        """
        Call the compiled function `func_name`, passing it a pointer to the
        first element of `data`.

        Args:
            data: the array whose memory is to be shared with the compiled code.
            func_name: name of a compiled function taking a single pointer argument.
        """
        element_ptr_type = ct.POINTER(np.ctypeslib.as_ctypes_type(data.dtype))
        setter_type = ct.CFUNCTYPE(None, element_ptr_type)

        setter = setter_type(self.engine.get_function_address(func_name))
        setter(data.ctypes.data_as(element_ptr_type))
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
def _make_llvm_program(
    in_vars: Sequence[VarNode],
    result: Sequence[CircuitNode],
    type_info: TypeInfo,
    compile_arrays: bool,
) -> Tuple[str, int, List[ConstValue], List[int]]:
    """
    Construct the LLVM program (i.e., LLVM module) holding the byte code interpreter.

    Args:
        in_vars: the input variable nodes of the program.
        result: the result nodes of the program.
        type_info: the arithmetic type to use.
        compile_arrays: if true, the constants and byte code are embedded in the
            module as constant global arrays; otherwise the module gets two global
            pointers, initially null, plus functions to set them.

    Returns:
        (llvm_program, number_of_tmps, consts, byte_code)
    """
    llvm_function = IRFunction(type_info)
    builder = llvm_function.builder
    type_info = llvm_function.type_info
    module = llvm_function.module

    analysis: CircuitAnalysis = analyze_circuit(in_vars, result)
    const_values: List[ConstValue] = [const_node.value for const_node in analysis.const_nodes]
    number_of_tmps: int = len(analysis.op_to_tmp)

    # An element index must be able to address the largest of the four arrays.
    largest_array: int = max(
        len(analysis.var_nodes),  # number of inputs
        len(analysis.result_nodes),  # number of outputs
        number_of_tmps,  # number of tmps
        len(analysis.const_nodes),  # number of constants
    )
    data_idx_bytes: int = _get_bytes_needed(largest_array)

    # An argument count must be able to hold the largest op node arity.
    largest_arity: int = max((len(op_node.args) for op_node in analysis.op_nodes), default=0)
    num_args_bytes: int = _get_bytes_needed(largest_arity)

    data_type: ir.Type = type_info.llvm_type
    byte_type: ir.Type = ir.IntType(8)
    data_idx_type: ir.Type = ir.IntType(data_idx_bytes * 8)

    byte_code: List[int] = _make_byte_code(analysis, data_idx_bytes, num_args_bytes)

    # The instruction index type is sized from the byte code length.
    inst_idx_type: ir.Type = ir.IntType(_get_bytes_needed(len(byte_code)) * 8)

    def embed_array(name: str, element_type: ir.Type, index_type: ir.Type, values: list):
        # Declare a constant global array initialised with `values` and
        # return (global, pointer to its first element).
        array_type = ir.ArrayType(element_type, len(values))
        global_array = ir.GlobalVariable(module, array_type, name=name)
        global_array.global_constant = True
        global_array.initializer = ir.Constant(array_type, values)
        zero = ir.Constant(index_type, 0)
        return global_array, builder.gep(global_array, [zero, zero])

    def declare_pointer(name: str, element_type: ir.Type):
        # Declare a global pointer, initially null, that will be set externally,
        # and return (global, value loaded from it).
        pointer_type = element_type.as_pointer()
        global_pointer = ir.GlobalVariable(module, pointer_type, name=name)
        global_pointer.initializer = ir.Constant(pointer_type, None)
        return global_pointer, builder.load(global_pointer)

    consts: ir.Value
    instructions: ir.Value
    if compile_arrays:
        # Put the constants, then the byte code, into the LLVM module.
        consts_global, consts = embed_array('consts', data_type, data_idx_type, const_values)
        instructions_global, instructions = embed_array('instructions', byte_type, inst_idx_type, byte_code)
    else:
        # Just create two global variables that will be set externally.
        consts_global, consts = declare_pointer('consts', data_type)
        instructions_global, instructions = declare_pointer('instructions', byte_type)

    interp = _InterpBuilder(builder, type_info, inst_idx_type, data_idx_bytes, num_args_bytes, consts, instructions)
    interp.make_interpreter()

    if not compile_arrays:
        # Add the functions that set the global array pointers.
        interp.make_set_consts_function(consts_global)
        interp.make_set_instructions_function(instructions_global)

    return llvm_function.llvm_program(), number_of_tmps, const_values, byte_code
|
|
236
|
+
|
|
237
|
+
|
|
238
|
+
class _InterpBuilder:
    """
    Helper to write the LLVM function for the byte code interpreter.

    The byte code is a sequence of instructions terminated by the `_END`
    op code. Each instruction is laid out as:
        op code (1 byte),
        number of arguments (`num_args_bytes` bytes, most significant byte first),
        one element id per argument,
        one element id for where to store the result.
    An element id is an array id (1 byte) followed by an index into that
    array (`index_bytes` bytes, most significant byte first).
    """

    def __init__(
        self,
        builder: ir.IRBuilder,
        type_info: TypeInfo,
        inst_idx_type: ir.Type,
        index_bytes: int,
        num_args_bytes: int,
        consts: ir.Value,
        instructions: ir.Value,
    ):
        """
        Record the types and values needed by the interpreter and emit the
        set-up code (local allocations and initialisation) at the builder's
        current position.

        Args:
            builder: IR builder positioned in the interpreter function.
            type_info: the arithmetic type of data values.
            inst_idx_type: integer type used to index the byte code.
            index_bytes: number of bytes used in byte code for an array element index.
            num_args_bytes: number of bytes used in byte code for an argument count.
            consts: pointer to the first constant value.
            instructions: pointer to the first byte of byte code.
        """
        self.builder: ir.IRBuilder = builder
        self.type_info: TypeInfo = type_info
        self.index_bytes: int = index_bytes
        self.num_args_bytes: int = num_args_bytes

        # LLVM types
        self.data_type: ir.Type = type_info.llvm_type
        self.byte_type: ir.Type = ir.IntType(8)
        self.inst_idx_type: ir.Type = inst_idx_type
        self.data_idx_type: ir.Type = ir.IntType(index_bytes * 8)
        self.num_args_type: ir.Type = ir.IntType(num_args_bytes * 8)

        # the constants zero and one, for each integer type
        self.data_idx_0, self.data_idx_1 = (ir.Constant(self.data_idx_type, k) for k in (0, 1))
        self.inst_idx_0, self.inst_idx_1 = (ir.Constant(self.inst_idx_type, k) for k in (0, 1))
        self.num_args_0, self.num_args_1 = (ir.Constant(self.num_args_type, k) for k in (0, 1))

        self.consts: ir.Value = consts
        self.instructions: ir.Value = instructions

        # allocate locals
        self.local_idx = builder.alloca(self.inst_idx_type, name='idx')
        self.local_num_args = builder.alloca(self.num_args_type, name='num_args')
        self.local_accumulator = builder.alloca(self.data_type, name='accumulator')
        self.local_arrays = builder.alloca(self.data_type.as_pointer(), size=4, name='arrays')

        # local_arrays = [vars, tmps, result, consts]
        # The first three are the arguments of the interpreter function.
        function: ir.Function = builder.function
        array_pointers = (
            (_VARS, function.args[0]),
            (_TMPS, function.args[1]),
            (_RESULT, function.args[2]),
            (_CONSTS, consts),
        )
        for array_id, array_pointer in array_pointers:
            slot = builder.gep(self.local_arrays, [ir.Constant(self.byte_type, array_id)])
            builder.store(array_pointer, slot)

        # local_idx = 0
        builder.store(self.inst_idx_0, self.local_idx)

    def make_set_consts_function(self, consts_ptr: ir.GlobalVariable):
        """
        Add a function to the module that stores its pointer argument
        into the global variable `consts_ptr`.
        """
        self._make_setter_function(_SET_CONSTS_FUNCTION_NAME, self.data_type, consts_ptr)

    def make_set_instructions_function(self, instructions_ptr: ir.GlobalVariable):
        """
        Add a function to the module that stores its pointer argument
        into the global variable `instructions_ptr`.
        """
        self._make_setter_function(_SET_INSTRUCTIONS_FUNCTION_NAME, self.byte_type, instructions_ptr)

    def _make_setter_function(self, name: str, element_type: ir.Type, target: ir.GlobalVariable) -> None:
        """
        Add `void name(element_type*)` to the module; the function stores its
        argument into `target`. The builder is left positioned in the new function.
        """
        builder = self.builder
        signature = ir.FunctionType(ir.VoidType(), (element_type.as_pointer(),))
        setter = ir.Function(builder.module, signature, name=name)
        builder.position_at_end(setter.append_basic_block('entry'))
        builder.store(setter.args[0], target)
        builder.ret_void()

    def add(self, x: ir.Value, y: ir.Value) -> ir.Value:
        """Emit an addition of two data values, as defined by the type info."""
        return self.type_info.add(self.builder, x, y)

    def mul(self, x: ir.Value, y: ir.Value) -> ir.Value:
        """Emit a multiplication of two data values, as defined by the type info."""
        return self.type_info.mul(self.builder, x, y)

    def make_interpreter(self):
        """
        Write the bytecode interpreter
        """
        builder: ir.IRBuilder = self.builder
        function: ir.Function = builder.function

        # All blocks are appended up front, fixing their order in the function.
        bb_while, bb_body, bb_mul, bb_mul_op, bb_add, bb_add_op, bb_op_continue, bb_finish = (
            function.append_basic_block(block_name)
            for block_name in ('while', 'body', 'mul', 'mul_op', 'add', 'add_op', 'op_continue', 'finish')
        )

        # block: entry
        # (locals already set up in the constructor)
        builder.branch(bb_while)

        # block: while
        # Fetch the op code of the next instruction and stop at the sentinel.
        builder.position_at_end(bb_while)
        idx = builder.load(self.local_idx)
        op_code = builder.load(builder.gep(self.instructions, [idx]))
        idx = builder.add(idx, self.inst_idx_1)
        is_end = builder.icmp_unsigned('==', op_code, ir.Constant(self.byte_type, _END))
        builder.cbranch(is_end, bb_finish, bb_body)

        # block: body
        builder.position_at_end(bb_body)
        # load number of args
        idx, num_args = self._read_number(idx, self.num_args_bytes)
        builder.store(num_args, self.local_num_args)
        # load first arg value into the accumulator
        idx, first_value = self._load_value(idx)
        builder.store(first_value, self.local_accumulator)
        # save the current bytecode index
        builder.store(idx, self.local_idx)
        # dispatch on the op code: anything that is not a multiply is treated as an add
        is_mul = builder.icmp_unsigned('==', op_code, ir.Constant(self.byte_type, _MUL))
        builder.cbranch(is_mul, bb_mul, bb_add)

        # blocks: mul, mul_op
        self._make_fold_loop(bb_mul, bb_mul_op, bb_op_continue, self.mul)

        # blocks: add, add_op
        self._make_fold_loop(bb_add, bb_add_op, bb_op_continue, self.add)

        # block: op_continue
        builder.position_at_end(bb_op_continue)
        # get where we store the result
        idx = builder.load(self.local_idx)
        idx, destination = self._load_value_ptr(idx)
        builder.store(idx, self.local_idx)
        # get and store the result
        builder.store(builder.load(self.local_accumulator), destination)
        builder.branch(bb_while)

        # block: finish
        builder.position_at_end(bb_finish)
        builder.ret_void()

    def _make_fold_loop(self, bb_test: ir.Block, bb_op: ir.Block, bb_done: ir.Block, combine) -> None:
        """
        Write the pair of blocks that fold the remaining arguments of the
        current instruction into the accumulator.

        Args:
            bb_test: block that decrements the remaining argument count and
                branches to `bb_op` while it is still above zero, else to `bb_done`.
            bb_op: block that loads the next argument value, combines it with
                the accumulator, and branches back to `bb_test`.
            bb_done: block to continue at once all arguments are consumed.
            combine: `self.add` or `self.mul`.
        """
        builder = self.builder

        # test block: one argument was consumed (the first one is already in the accumulator)
        builder.position_at_end(bb_test)
        remaining = builder.load(self.local_num_args)
        remaining = builder.sub(remaining, self.num_args_1)
        builder.store(remaining, self.local_num_args)
        more_args = builder.icmp_unsigned('>', remaining, self.num_args_0)
        builder.cbranch(more_args, bb_op, bb_done)

        # op block: accumulator = combine(accumulator, next argument)
        builder.position_at_end(bb_op)
        idx = builder.load(self.local_idx)
        idx, value = self._load_value(idx)
        accumulated = builder.load(self.local_accumulator)
        accumulated = combine(accumulated, value)
        builder.store(accumulated, self.local_accumulator)
        builder.store(idx, self.local_idx)
        builder.branch(bb_test)

    def _read_number(self, idx: ir.Value, num_bytes: int) -> Tuple[ir.Value, ir.Value]:
        """
        Emit code to read an unsigned number from the instruction stream,
        most significant byte first.

        Args:
            idx: current instruction index
            num_bytes: how many bytes to read from the instruction stream to form the number

        Returns:
            (idx, number)
            idx: is the updated instruction index
            number: is the read number
        """
        builder = self.builder
        number_type: ir.Type = ir.IntType(num_bytes * 8)

        # The first byte is always read.
        number: ir.Value = builder.load(builder.gep(self.instructions, [idx]))
        idx = builder.add(idx, self.inst_idx_1)

        if num_bytes <= 1:
            # A single byte already has the right type.
            return idx, number

        # Widen, then shift in each remaining byte.
        eight = ir.Constant(number_type, 8)
        number = builder.zext(number, number_type)
        for _ in range(1, num_bytes):
            next_byte = builder.load(builder.gep(self.instructions, [idx]))
            shifted = builder.shl(number, eight)
            widened = builder.zext(next_byte, number_type)
            number = builder.add(shifted, widened)
            idx = builder.add(idx, self.inst_idx_1)

        return idx, number

    def _load_value_ptr(self, idx: ir.Value) -> Tuple[ir.Value, ir.Value]:
        """
        Emit code to read an element id from the instruction stream and
        form a pointer to the data value it identifies.

        Args:
            idx: current instruction index

        Returns:
            (idx, ptr)
            idx: is the updated instruction index
            ptr: is the pointer to the identified data value
        """
        builder = self.builder

        # which array (one byte)
        array_id = builder.load(builder.gep(self.instructions, [idx]))
        idx = builder.add(idx, self.inst_idx_1)

        # which element of that array
        idx, element_index = self._read_number(idx, self.index_bytes)

        # get the pointer
        array = builder.load(builder.gep(self.local_arrays, [array_id]))
        return idx, builder.gep(array, [element_index])

    def _load_value(self, idx: ir.Value) -> Tuple[ir.Value, ir.Value]:
        """
        Emit code to read an element id from the instruction stream and
        load the data value it identifies.

        Args:
            idx: current instruction index

        Returns:
            (idx, value)
            idx: is the updated instruction index
            value: is the loaded data value
        """
        idx, value_ptr = self._load_value_ptr(idx)
        return idx, self.builder.load(value_ptr)
|
|
467
|
+
|
|
468
|
+
|
|
469
|
+
@dataclass
|
|
470
|
+
class _ElementID:
|
|
471
|
+
"""
|
|
472
|
+
A 2D index into the function's `arrays`.
|
|
473
|
+
"""
|
|
474
|
+
array: int # which array: VARS, TMPS, CONSTS, RESULT
|
|
475
|
+
index: int # index into the array
|
|
476
|
+
|
|
477
|
+
|
|
478
|
+
def _make_byte_code(analysis: CircuitAnalysis, data_idx_bytes: int, num_args_bytes: int) -> List[int]:
    """
    Translate an analysed circuit into byte code for the interpreter.

    Each instruction is: op code, argument count, one element id per
    argument, then the element id of where to store the result. The byte
    code finishes with the `_END` op code.

    Args:
        analysis: the analysed circuit.
        data_idx_bytes: number of bytes to use for each array element index.
        num_args_bytes: number of bytes to use for each argument count.

    Returns:
        the byte code, as a list of byte values.
    """
    # Work out where the value of each node lives, keyed by node id.
    # const nodes
    node_to_element: Dict[int, _ElementID] = {
        id(const_node): _ElementID(_CONSTS, position)
        for position, const_node in enumerate(analysis.const_nodes)
    }
    # var nodes (a var node fixed to a constant shares that constant's element)
    for position, var_node in enumerate(analysis.var_nodes):
        if var_node.is_const():
            element = node_to_element[id(var_node.const)]
        else:
            element = _ElementID(_VARS, position)
        node_to_element[id(var_node)] = element
    # op nodes; those that are results are written directly to the result array
    node_to_element.update(
        (node_id, _ElementID(_TMPS, tmp_idx)) for node_id, tmp_idx in analysis.op_to_tmp.items()
    )
    node_to_element.update(
        (node_id, _ElementID(_RESULT, result_idx)) for node_id, result_idx in analysis.op_to_result.items()
    )

    byte_code: List[int] = []

    def emit_num_args(count: int) -> None:
        byte_code.extend(_to_bytes(count, num_args_bytes))

    def emit_element(array: int, index: int) -> None:
        byte_code.append(array)
        byte_code.extend(_to_bytes(index, data_idx_bytes))

    # One instruction per op node.
    for op_node in analysis.op_nodes:
        byte_code.append(op_node.symbol)  # _ADD or _MUL
        emit_num_args(len(op_node.args))
        for arg_node in op_node.args:
            source = node_to_element[id(arg_node)]
            emit_element(source.array, source.index)
        destination = node_to_element[id(op_node)]
        emit_element(destination.array, destination.index)

    # A result that is not an op node has not been written by any instruction
    # above, so copy it to the result array using a single-argument add.
    for result_idx, result_node in enumerate(analysis.result_nodes):
        if isinstance(result_node, OpNode):
            continue
        byte_code.append(_ADD)
        emit_num_args(1)
        source = node_to_element[id(result_node)]
        emit_element(source.array, source.index)
        emit_element(_RESULT, result_idx)

    # write the sentinel - 'end' op code
    byte_code.append(_END)

    return byte_code
|
|
529
|
+
|
|
530
|
+
|
|
531
|
+
def _to_bytes(value: int, num_bytes: int) -> List[int]:
|
|
532
|
+
buffer: List[int] = []
|
|
533
|
+
for _ in range(num_bytes):
|
|
534
|
+
buffer.append(value % 256)
|
|
535
|
+
value //= 256
|
|
536
|
+
assert value == 0
|
|
537
|
+
buffer.reverse()
|
|
538
|
+
return buffer
|
|
539
|
+
|
|
540
|
+
|
|
541
|
+
def _get_bytes_needed(size: int) -> int:
|
|
542
|
+
index_bytes: int
|
|
543
|
+
for index_bytes in [1, 2, 4, 8]:
|
|
544
|
+
if size < 2 ** (index_bytes * 8 - 1):
|
|
545
|
+
return index_bytes
|
|
546
|
+
raise ValueError(f'size are too large to represent: {size}')
|
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
from enum import Enum
|
|
2
|
+
from functools import partial
|
|
3
|
+
from typing import Optional
|
|
4
|
+
|
|
5
|
+
from .llvm_compiler import Flavour
|
|
6
|
+
from ..circuit import CircuitNode, Circuit
|
|
7
|
+
from ..circuit_compiler import interpret_compiler, cython_vm_compiler, llvm_compiler, llvm_vm_compiler, CircuitCompiler
|
|
8
|
+
from ..circuit_compiler.support.input_vars import InputVars, InferVars
|
|
9
|
+
from ..program import RawProgram
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class NamedCircuitCompiler(Enum):
    """
    The standard set of circuit compiler functions, addressable by name.

    The `value` of every member is a one-element tuple holding the compiler
    function. The tuple wrapper is required: without it Python erases the
    type of the member, which can cause problems.

    Every member is itself callable and conforms to the `CircuitCompiler`
    protocol; a call is forwarded to the wrapped compiler function.
    """

    LLVM_STACK = (partial(llvm_compiler.compile_circuit, flavour=Flavour.STACK),)
    LLVM_TMPS = (partial(llvm_compiler.compile_circuit, flavour=Flavour.TMPS, opt=0),)
    LLVM_VM = (llvm_vm_compiler.compile_circuit,)
    CYTHON_VM = (cython_vm_compiler.compile_circuit,)
    INTERPRET = (interpret_compiler.compile_circuit,)

    # The following circuit compilers were experimental but are not really useful.
    #
    # Slow compile and execution:
    # LLVM_FUNCS = (partial(llvm_compiler.compile_circuit, flavour=Flavour.FUNCS, opt=0),)
    #
    # Slow compile and same execution as LLVM_VM:
    # LLVM_VM_COMPILED_ARRAYS = (partial(llvm_vm_compiler.compile_circuit, compile_arrays=True),)

    def __call__(
            self,
            *result: CircuitNode,
            input_vars: InputVars = InferVars.ALL,
            circuit: Optional[Circuit] = None,
    ) -> RawProgram:
        """
        Compile by delegating to this member's compiler function.

        This is what makes each enum member satisfy the `CircuitCompiler`
        protocol. All arguments are passed through unchanged.

        Args:
            *result: forwarded positionally to the compiler function.
            input_vars: forwarded as the `input_vars` keyword argument.
            circuit: forwarded as the `circuit` keyword argument.

        Returns:
            whatever the wrapped compiler function returns.
        """
        compile_fn = self.compiler
        return compile_fn(*result, input_vars=input_vars, circuit=circuit)

    @property
    def compiler(self) -> CircuitCompiler:
        """
        Returns:
            The compiler function held by this member, conforming to the
            CircuitCompiler protocol.
        """
        # The member value is a one-element tuple; unwrap it.
        (compile_fn,) = self.value
        return compile_fn
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
# Module-level default, bound to the Cython VM compiler member.
# NOTE(review): presumably used wherever a caller does not name a compiler - confirm at the usage sites.
DEFAULT_CIRCUIT_COMPILER: NamedCircuitCompiler = NamedCircuitCompiler.CYTHON_VM
|
|
File without changes
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# Two implementations of the `circuit_analyser` module are provided
|
|
2
|
+
# for developer R&D purposes. One is pure Python and the other is Cython.
|
|
3
|
+
# Which implementation is used can be selected here.
|
|
4
|
+
#
|
|
5
|
+
# A similar selection can be made for the `circuit` module.
|
|
6
|
+
# Note that if the Cython implementation is chosen for `circuit_analyser` then
|
|
7
|
+
# the Cython implementation must be chosen for `circuit`.
|
|
8
|
+
|
|
9
|
+
# from ._circuit_analyser_py import (
|
|
10
|
+
from ._circuit_analyser_cy import (
|
|
11
|
+
CircuitAnalysis,
|
|
12
|
+
analyze_circuit,
|
|
13
|
+
)
|