compiled-knowledge 4.0.0a20__cp312-cp312-win32.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of compiled-knowledge might be problematic. Click here for more details.
- ck/__init__.py +0 -0
- ck/circuit/__init__.py +17 -0
- ck/circuit/_circuit_cy.c +37523 -0
- ck/circuit/_circuit_cy.cp312-win32.pyd +0 -0
- ck/circuit/_circuit_cy.pxd +32 -0
- ck/circuit/_circuit_cy.pyx +768 -0
- ck/circuit/_circuit_py.py +836 -0
- ck/circuit/tmp_const.py +74 -0
- ck/circuit_compiler/__init__.py +2 -0
- ck/circuit_compiler/circuit_compiler.py +26 -0
- ck/circuit_compiler/cython_vm_compiler/__init__.py +1 -0
- ck/circuit_compiler/cython_vm_compiler/_compiler.c +19824 -0
- ck/circuit_compiler/cython_vm_compiler/_compiler.cp312-win32.pyd +0 -0
- ck/circuit_compiler/cython_vm_compiler/_compiler.pyx +380 -0
- ck/circuit_compiler/cython_vm_compiler/cython_vm_compiler.py +121 -0
- ck/circuit_compiler/interpret_compiler.py +223 -0
- ck/circuit_compiler/llvm_compiler.py +388 -0
- ck/circuit_compiler/llvm_vm_compiler.py +546 -0
- ck/circuit_compiler/named_circuit_compilers.py +57 -0
- ck/circuit_compiler/support/__init__.py +0 -0
- ck/circuit_compiler/support/circuit_analyser/__init__.py +13 -0
- ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.c +10618 -0
- ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.cp312-win32.pyd +0 -0
- ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.pyx +98 -0
- ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_py.py +93 -0
- ck/circuit_compiler/support/input_vars.py +148 -0
- ck/circuit_compiler/support/llvm_ir_function.py +234 -0
- ck/example/__init__.py +53 -0
- ck/example/alarm.py +366 -0
- ck/example/asia.py +28 -0
- ck/example/binary_clique.py +32 -0
- ck/example/bow_tie.py +33 -0
- ck/example/cancer.py +37 -0
- ck/example/chain.py +38 -0
- ck/example/child.py +199 -0
- ck/example/clique.py +33 -0
- ck/example/cnf_pgm.py +39 -0
- ck/example/diamond_square.py +68 -0
- ck/example/earthquake.py +36 -0
- ck/example/empty.py +10 -0
- ck/example/hailfinder.py +539 -0
- ck/example/hepar2.py +628 -0
- ck/example/insurance.py +504 -0
- ck/example/loop.py +40 -0
- ck/example/mildew.py +38161 -0
- ck/example/munin.py +22982 -0
- ck/example/pathfinder.py +53747 -0
- ck/example/rain.py +39 -0
- ck/example/rectangle.py +161 -0
- ck/example/run.py +30 -0
- ck/example/sachs.py +129 -0
- ck/example/sprinkler.py +30 -0
- ck/example/star.py +44 -0
- ck/example/stress.py +64 -0
- ck/example/student.py +43 -0
- ck/example/survey.py +46 -0
- ck/example/triangle_square.py +54 -0
- ck/example/truss.py +49 -0
- ck/in_out/__init__.py +3 -0
- ck/in_out/parse_ace_lmap.py +216 -0
- ck/in_out/parse_ace_nnf.py +322 -0
- ck/in_out/parse_net.py +480 -0
- ck/in_out/parser_utils.py +185 -0
- ck/in_out/pgm_pickle.py +42 -0
- ck/in_out/pgm_python.py +268 -0
- ck/in_out/render_bugs.py +111 -0
- ck/in_out/render_net.py +177 -0
- ck/in_out/render_pomegranate.py +184 -0
- ck/pgm.py +3475 -0
- ck/pgm_circuit/__init__.py +1 -0
- ck/pgm_circuit/marginals_program.py +352 -0
- ck/pgm_circuit/mpe_program.py +237 -0
- ck/pgm_circuit/pgm_circuit.py +79 -0
- ck/pgm_circuit/program_with_slotmap.py +236 -0
- ck/pgm_circuit/slot_map.py +35 -0
- ck/pgm_circuit/support/__init__.py +0 -0
- ck/pgm_circuit/support/compile_circuit.py +83 -0
- ck/pgm_circuit/target_marginals_program.py +103 -0
- ck/pgm_circuit/wmc_program.py +323 -0
- ck/pgm_compiler/__init__.py +2 -0
- ck/pgm_compiler/ace/__init__.py +1 -0
- ck/pgm_compiler/ace/ace.py +299 -0
- ck/pgm_compiler/factor_elimination.py +395 -0
- ck/pgm_compiler/named_pgm_compilers.py +63 -0
- ck/pgm_compiler/pgm_compiler.py +19 -0
- ck/pgm_compiler/recursive_conditioning.py +231 -0
- ck/pgm_compiler/support/__init__.py +0 -0
- ck/pgm_compiler/support/circuit_table/__init__.py +17 -0
- ck/pgm_compiler/support/circuit_table/_circuit_table_cy.c +16396 -0
- ck/pgm_compiler/support/circuit_table/_circuit_table_cy.cp312-win32.pyd +0 -0
- ck/pgm_compiler/support/circuit_table/_circuit_table_cy.pyx +332 -0
- ck/pgm_compiler/support/circuit_table/_circuit_table_py.py +304 -0
- ck/pgm_compiler/support/clusters.py +568 -0
- ck/pgm_compiler/support/factor_tables.py +406 -0
- ck/pgm_compiler/support/join_tree.py +332 -0
- ck/pgm_compiler/support/named_compiler_maker.py +43 -0
- ck/pgm_compiler/variable_elimination.py +91 -0
- ck/probability/__init__.py +0 -0
- ck/probability/empirical_probability_space.py +50 -0
- ck/probability/pgm_probability_space.py +32 -0
- ck/probability/probability_space.py +622 -0
- ck/program/__init__.py +3 -0
- ck/program/program.py +137 -0
- ck/program/program_buffer.py +180 -0
- ck/program/raw_program.py +67 -0
- ck/sampling/__init__.py +0 -0
- ck/sampling/forward_sampler.py +211 -0
- ck/sampling/marginals_direct_sampler.py +113 -0
- ck/sampling/sampler.py +62 -0
- ck/sampling/sampler_support.py +232 -0
- ck/sampling/uniform_sampler.py +72 -0
- ck/sampling/wmc_direct_sampler.py +171 -0
- ck/sampling/wmc_gibbs_sampler.py +153 -0
- ck/sampling/wmc_metropolis_sampler.py +165 -0
- ck/sampling/wmc_rejection_sampler.py +115 -0
- ck/utils/__init__.py +0 -0
- ck/utils/iter_extras.py +163 -0
- ck/utils/local_config.py +270 -0
- ck/utils/map_list.py +128 -0
- ck/utils/map_set.py +128 -0
- ck/utils/np_extras.py +51 -0
- ck/utils/random_extras.py +64 -0
- ck/utils/tmp_dir.py +94 -0
- ck_demos/__init__.py +0 -0
- ck_demos/ace/__init__.py +0 -0
- ck_demos/ace/copy_ace_to_ck.py +15 -0
- ck_demos/ace/demo_ace.py +49 -0
- ck_demos/all_demos.py +88 -0
- ck_demos/circuit/__init__.py +0 -0
- ck_demos/circuit/demo_circuit_dump.py +22 -0
- ck_demos/circuit/demo_derivatives.py +43 -0
- ck_demos/circuit_compiler/__init__.py +0 -0
- ck_demos/circuit_compiler/compare_circuit_compilers.py +32 -0
- ck_demos/circuit_compiler/show_llvm_program.py +26 -0
- ck_demos/pgm/__init__.py +0 -0
- ck_demos/pgm/demo_pgm_dump.py +18 -0
- ck_demos/pgm/demo_pgm_dump_stress.py +18 -0
- ck_demos/pgm/demo_pgm_string_rendering.py +15 -0
- ck_demos/pgm/show_examples.py +25 -0
- ck_demos/pgm_compiler/__init__.py +0 -0
- ck_demos/pgm_compiler/compare_pgm_compilers.py +63 -0
- ck_demos/pgm_compiler/demo_compiler_dump.py +60 -0
- ck_demos/pgm_compiler/demo_factor_elimination.py +47 -0
- ck_demos/pgm_compiler/demo_join_tree.py +25 -0
- ck_demos/pgm_compiler/demo_marginals_program.py +53 -0
- ck_demos/pgm_compiler/demo_mpe_program.py +55 -0
- ck_demos/pgm_compiler/demo_pgm_compiler.py +38 -0
- ck_demos/pgm_compiler/demo_recursive_conditioning.py +33 -0
- ck_demos/pgm_compiler/demo_variable_elimination.py +33 -0
- ck_demos/pgm_compiler/demo_wmc_program.py +29 -0
- ck_demos/pgm_compiler/time_fe_compiler.py +93 -0
- ck_demos/pgm_inference/__init__.py +0 -0
- ck_demos/pgm_inference/demo_inferencing_basic.py +188 -0
- ck_demos/pgm_inference/demo_inferencing_mpe_cancer.py +45 -0
- ck_demos/pgm_inference/demo_inferencing_wmc_and_mpe_sprinkler.py +154 -0
- ck_demos/pgm_inference/demo_inferencing_wmc_student.py +110 -0
- ck_demos/programs/__init__.py +0 -0
- ck_demos/programs/demo_program_buffer.py +24 -0
- ck_demos/programs/demo_program_multi.py +24 -0
- ck_demos/programs/demo_program_none.py +19 -0
- ck_demos/programs/demo_program_single.py +23 -0
- ck_demos/programs/demo_raw_program_interpreted.py +21 -0
- ck_demos/programs/demo_raw_program_llvm.py +21 -0
- ck_demos/sampling/__init__.py +0 -0
- ck_demos/sampling/check_sampler.py +71 -0
- ck_demos/sampling/demo_marginal_direct_sampler.py +40 -0
- ck_demos/sampling/demo_uniform_sampler.py +38 -0
- ck_demos/sampling/demo_wmc_direct_sampler.py +40 -0
- ck_demos/utils/__init__.py +0 -0
- ck_demos/utils/compare.py +120 -0
- ck_demos/utils/convert_network.py +45 -0
- ck_demos/utils/sample_model.py +216 -0
- ck_demos/utils/stop_watch.py +384 -0
- compiled_knowledge-4.0.0a20.dist-info/METADATA +50 -0
- compiled_knowledge-4.0.0a20.dist-info/RECORD +178 -0
- compiled_knowledge-4.0.0a20.dist-info/WHEEL +5 -0
- compiled_knowledge-4.0.0a20.dist-info/licenses/LICENSE.txt +21 -0
- compiled_knowledge-4.0.0a20.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,223 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from typing import Sequence, Optional, Dict, List, Tuple, Callable
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
import ctypes as ct
|
|
8
|
+
|
|
9
|
+
from ..circuit import Circuit, CircuitNode, VarNode, OpNode, ADD, MUL
|
|
10
|
+
from ..program.raw_program import RawProgram, RawProgramFunction
|
|
11
|
+
from ..utils.iter_extras import multiply, first
|
|
12
|
+
from ..utils.np_extras import NDArrayNumeric, DTypeNumeric
|
|
13
|
+
from .support.circuit_analyser import CircuitAnalysis, analyze_circuit
|
|
14
|
+
from .support.input_vars import InputVars, InferVars, infer_input_vars
|
|
15
|
+
|
|
16
|
+
# index to a value array
|
|
17
|
+
_VARS = 0
|
|
18
|
+
_CONSTS = 1
|
|
19
|
+
_TMPS = 2
|
|
20
|
+
_RESULT = 3
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def compile_circuit(
        *result: CircuitNode,
        input_vars: InputVars = InferVars.ALL,
        circuit: Optional[Circuit] = None,
        dtype: DTypeNumeric = np.double,
) -> InterpreterRawProgram:
    """
    Build a RawProgram that evaluates the given circuit by interpretation.

    Args:
        *result: the circuit nodes whose values become the program's results.
        input_vars: how the program's input variables are determined.
        circuit: the Circuit, when it should be given explicitly.
        dtype: the numpy dtype used by the raw program.

    Returns:
        an InterpreterRawProgram.

    Raises:
        ValueError: if the circuit is unknown, but it is needed.
        ValueError: if not all nodes are from the same circuit.
    """
    in_vars = infer_input_vars(circuit, result, input_vars)
    analysis = analyze_circuit(in_vars, result)
    instructions, np_consts = _make_instructions(analysis, dtype)

    return InterpreterRawProgram(
        in_vars=in_vars,
        result=result,
        op_nodes=analysis.op_nodes,
        dtype=dtype,
        instructions=instructions,
        np_consts=np_consts,
    )
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
class InterpreterRawProgram(RawProgram):
    """
    A RawProgram whose function interprets a list of instructions in Python.

    The instructions and the numpy array of constants are kept on the instance.
    The executable function is a closure, so for pickling it is rebuilt from
    those two attributes rather than being saved.
    """

    # Attributes captured by __getstate__ and restored by __setstate__, in order.
    _STATE_KEYS = (
        'dtype',
        'number_of_vars',
        'number_of_tmps',
        'number_of_results',
        'var_indices',
        'instructions',
        'np_consts',
    )

    def __init__(
            self,
            in_vars: Sequence[VarNode],
            result: Sequence[CircuitNode],
            op_nodes: Sequence[OpNode],
            dtype: DTypeNumeric,
            instructions: List[_Instruction],
            np_consts: NDArrayNumeric,
    ):
        self.instructions = instructions
        self.np_consts = np_consts

        super().__init__(
            function=_make_function(instructions=instructions, np_consts=np_consts),
            dtype=dtype,
            number_of_vars=len(in_vars),
            number_of_tmps=len(op_nodes),  # one tmp slot per op node
            number_of_results=len(result),
            var_indices=tuple(var_node.idx for var_node in in_vars),
        )

    def __getstate__(self):
        """
        Support for pickle: capture every attribute listed in _STATE_KEYS.
        """
        return {key: getattr(self, key) for key in self._STATE_KEYS}

    def __setstate__(self, state):
        """
        Support for pickle: restore the saved attributes, then rebuild the function.
        """
        for key in self._STATE_KEYS:
            setattr(self, key, state[key])

        self.function = _make_function(
            instructions=self.instructions,
            np_consts=self.np_consts,
        )
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
def _make_instructions(
        analysis: CircuitAnalysis,
        dtype: DTypeNumeric,
) -> Tuple[List[_Instruction], NDArrayNumeric]:
    """
    Translate a circuit analysis into a list of interpreter instructions.

    Args:
        analysis: the analysis of the circuit to interpret.
        dtype: the numpy dtype of the returned array of constant values.

    Returns:
        (instructions, np_consts) where `instructions` must be executed in order
        and `np_consts[i]` holds the value of `analysis.const_nodes[i]`.

    Raises:
        RuntimeError: if an op node has a symbol other than ADD or MUL.
    """
    # Where to get input values for each possible node, keyed by id(node).
    node_to_element: Dict[int, _ElementID] = {}

    # Const nodes: store their values in a numpy array and address them there.
    const_nodes = analysis.const_nodes
    np_consts: NDArrayNumeric = np.zeros(len(const_nodes), dtype=dtype)
    for const_idx, const_node in enumerate(const_nodes):
        np_consts[const_idx] = const_node.value
        node_to_element[id(const_node)] = _ElementID(_CONSTS, const_idx)

    # Var nodes: a var with a constant reads that constant's slot, otherwise it
    # reads its position in the vars array.
    for var_position, var_node in enumerate(analysis.var_nodes):
        if var_node.is_const():
            node_to_element[id(var_node)] = node_to_element[id(var_node.const)]
        else:
            node_to_element[id(var_node)] = _ElementID(_VARS, var_position)

    # Op nodes: either a tmp slot or (for result op nodes) a result slot.
    for node_id, tmp_index in analysis.op_to_tmp.items():
        node_to_element[node_id] = _ElementID(_TMPS, tmp_index)
    for node_id, result_index in analysis.op_to_result.items():
        node_to_element[node_id] = _ElementID(_RESULT, result_index)

    # Build one instruction per op node, in analysis order.
    instructions: List[_Instruction] = []
    for op_node in analysis.op_nodes:
        dest: _ElementID = node_to_element[id(op_node)]
        args: List[_ElementID] = [node_to_element[id(arg)] for arg in op_node.args]
        if op_node.symbol == MUL:
            operation = multiply
        elif op_node.symbol == ADD:
            operation = sum
        else:
            # Raise rather than assert: asserts are stripped under `python -O`,
            # which would silently reuse the previous iteration's operation.
            raise RuntimeError(f'unknown op node: {op_node.symbol!r}')
        instructions.append(_Instruction(operation, args, dest))

    # Add copy operations for result nodes that are not op nodes.
    for result_index, node in enumerate(analysis.result_nodes):
        if not isinstance(node, OpNode):
            source: _ElementID = node_to_element[id(node)]
            instructions.append(_Instruction(first, [source], _ElementID(_RESULT, result_index)))

    return instructions, np_consts
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
def _make_function(
        instructions: List[_Instruction],
        np_consts: NDArrayNumeric,
) -> RawProgramFunction:
    """
    Make a RawProgram function that executes the given instructions.

    The returned closure takes (vars, tmps, results). Each call runs the
    instructions in order. An instruction's operation receives a single
    iterable of argument values gathered from the four value arrays, and its
    return value is written to the instruction's destination slot.
    """

    # RawProgramFunction = Callable[[ct.POINTER, ct.POINTER, ct.POINTER], None]
    def raw_program_function(vars_in: ct.POINTER, tmps: ct.POINTER, result_out: ct.POINTER) -> None:
        arrays: List[ct.POINTER] = [None] * 4
        arrays[_VARS] = vars_in
        arrays[_CONSTS] = np_consts
        arrays[_TMPS] = tmps
        arrays[_RESULT] = result_out

        for step in instructions:
            operands = (arrays[src.array][src.index] for src in step.args)
            target = step.dest
            arrays[target.array][target.index] = step.operation(operands)

    return raw_program_function
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
@dataclass
|
|
214
|
+
class _ElementID:
|
|
215
|
+
array: int # VARS, TMPS, CONSTS, RESULT
|
|
216
|
+
index: int # index into the array
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
@dataclass
class _Instruction:
    """
    One interpreter step: apply `operation` to the values found at `args`
    and store what it returns at `dest`.
    """
    operation: Callable  # called with a single iterable of the argument values
    args: Sequence[_ElementID]  # where each argument value is read from
    dest: _ElementID  # where the computed value is written
|
|
@@ -0,0 +1,388 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from enum import Enum
|
|
5
|
+
from typing import Sequence, Optional, Tuple, Dict, Protocol
|
|
6
|
+
|
|
7
|
+
import llvmlite.binding as llvm
|
|
8
|
+
import llvmlite.ir as ir
|
|
9
|
+
|
|
10
|
+
from .support.circuit_analyser import CircuitAnalysis, analyze_circuit
|
|
11
|
+
from .support.input_vars import InputVars, InferVars, infer_input_vars
|
|
12
|
+
from .support.llvm_ir_function import IRFunction, DataType, TypeInfo, compile_llvm_program, LLVMRawProgram, IrBOp
|
|
13
|
+
from ..circuit import Circuit, VarNode, CircuitNode, OpNode, MUL, ADD, ConstNode
|
|
14
|
+
from ..program.raw_program import RawProgramFunction
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class Flavour(Enum):
    """
    How the generated LLVM program holds intermediate op-node values.
    """
    STACK = 0  # no client working memory is requested; intermediates stay on the stack
    TMPS = 1  # op-node intermediates go into the client's temporary working memory
    FUNCS = 2  # like TMPS, but each op node is computed in its own sub-function
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
# Defaults applied by compile_circuit when the caller does not override them.
DEFAULT_FLAVOUR: Flavour = Flavour.TMPS  # op-node intermediates in working memory
DEFAULT_TYPE_INFO: TypeInfo = DataType.FLOAT_64.value
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def compile_circuit(
        *result: CircuitNode,
        input_vars: InputVars = InferVars.ALL,
        circuit: Optional[Circuit] = None,
        data_type: DataType | TypeInfo = DEFAULT_TYPE_INFO,
        flavour: Flavour = DEFAULT_FLAVOUR,
        keep_llvm_program: bool = True,
        opt: int = 2,
) -> LLVMRawProgram:
    """
    Compile the given circuit using LLVM.

    This creates an LLVM program where each circuit op node is converted to
    one or more LLVM binary op machine code instructions. For large circuits
    this results in a large LLVM program which can be slow to compile.

    Unless `flavour` is Flavour.STACK, the produced RawProgram may use client
    managed working memory (temporaries). With Flavour.STACK it requests none.

    Conforms to the CircuitCompiler protocol.

    Args:
        *result: result nodes nominating the results of the returned program.
        input_vars: How to determine the input variables.
        circuit: optionally explicitly specify the Circuit.
        data_type: What data type to use for arithmetic calculations. Either a DataType member or TypeInfo.
        flavour: what flavour of LLVM program to construct.
        keep_llvm_program: if true, the LLVM program will be kept. This is required for pickling.
        opt: The optimization level to use by LLVM MC JIT.

    Returns:
        a raw program.

    Raises:
        ValueError: if the circuit is unknown, but it is needed.
        ValueError: if not all nodes are from the same circuit.
        ValueError: if the program data type could not be interpreted.
        ValueError: if the LLVM program flavour is unknown.
    """
    in_vars: Sequence[VarNode] = infer_input_vars(circuit, result, input_vars)
    var_indices: Sequence[int] = tuple(var.idx for var in in_vars)

    # Get the type info
    type_info: TypeInfo
    if isinstance(data_type, DataType):
        type_info = data_type.value
    elif isinstance(data_type, TypeInfo):
        type_info = data_type
    else:
        raise ValueError(f'could not interpret program data type: {data_type!r}')

    # Compile the circuit to an LLVM module representing a RawProgramFunction
    llvm_program: str
    number_of_tmps: int
    llvm_program, number_of_tmps = _make_llvm_program(in_vars, result, type_info, flavour)

    # Compile the LLVM program to a native executable
    engine: llvm.ExecutionEngine
    function: RawProgramFunction
    engine, function = compile_llvm_program(llvm_program, dtype=type_info.dtype, opt=opt)

    return LLVMRawProgram(
        function=function,
        dtype=type_info.dtype,
        number_of_vars=len(var_indices),
        number_of_tmps=number_of_tmps,
        number_of_results=len(result),
        var_indices=var_indices,
        llvm_program=llvm_program if keep_llvm_program else None,
        engine=engine,
        opt=opt,
    )
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def _make_llvm_program(
        in_vars: Sequence[VarNode],
        result: Sequence[CircuitNode],
        type_info: TypeInfo,
        flavour: Flavour,
) -> Tuple[str, int]:
    """
    Build the LLVM IR text of a RawProgramFunction computing `result` from `in_vars`.

    Args:
        in_vars: the program's input var nodes.
        result: the nodes whose values are the program's results.
        type_info: the arithmetic data type to use.
        flavour: how intermediate op-node values are held.

    Returns:
        (llvm_program, number_of_tmps)

    Raises:
        ValueError: if `flavour` is not a known Flavour.
        RuntimeError: if an op node has a symbol other than ADD or MUL.
    """
    llvm_function = IRFunction(type_info)
    builder = llvm_function.builder
    type_info = llvm_function.type_info
    function = llvm_function.function
    # The function's three arguments: vars in, tmps, results out.
    vars_arg, tmps_arg, results_arg = function.args[0], function.args[1], function.args[2]

    analysis: CircuitAnalysis = analyze_circuit(in_vars, result)

    function_builder: _FunctionBuilder
    if flavour == Flavour.STACK:
        function_builder = _FunctionBuilderStack(
            builder=builder,
            analysis=analysis,
            llvm_type=type_info.llvm_type,
            llvm_idx_type=ir.IntType(32),
            in_args=vars_arg,
            out_args=results_arg,
            ir_cache={},
        )
    elif flavour in (Flavour.TMPS, Flavour.FUNCS):
        builder_class = _FunctionBuilderTmps if flavour == Flavour.TMPS else _FunctionBuilderFuncs
        function_builder = builder_class(
            builder=builder,
            analysis=analysis,
            llvm_type=type_info.llvm_type,
            llvm_idx_type=ir.IntType(32),
            in_args=vars_arg,
            tmp_args=tmps_arg,
            out_args=results_arg,
        )
    else:
        raise ValueError(f'unknown LLVM program flavour: {flavour!r}')

    # Emit the calculation for every op node, in analysis order.
    for op_node in analysis.op_nodes:
        symbol = op_node.symbol
        if symbol == ADD:
            binary_op: IrBOp = type_info.add
        elif symbol == MUL:
            binary_op: IrBOp = type_info.mul
        else:
            raise RuntimeError(f'unknown op node: {op_node.symbol!r}')
        function_builder.process_op_node(op_node, binary_op)

    # Results that are not op nodes (vars, consts) are copied into their slot.
    for result_idx, node in enumerate(result):
        if isinstance(node, OpNode):
            continue
        function_builder.store_result(function_builder.value(node), result_idx)

    builder.ret_void()

    return llvm_function.llvm_program(), function_builder.number_of_tmps()
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
class _FunctionBuilder(Protocol):
    """
    Strategy used by _make_llvm_program to emit the IR of a circuit.
    """

    def process_op_node(self, op_node: OpNode, op: IrBOp) -> None:
        """Emit IR computing the value of `op_node` with the binary op `op`."""

    def value(self, node: CircuitNode) -> ir.Value:
        """Return an IR value holding the value of `node`."""

    def store_result(self, value: ir.Value, idx: int) -> None:
        """Store `value` into result slot `idx`."""

    def number_of_tmps(self) -> int:
        """Return how many temporary working memory slots the built function needs."""
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
@dataclass
class _FunctionBuilderTmps(_FunctionBuilder):
    """
    A function builder that puts op node calculations into the temporary working memory.

    Op nodes in `analysis.op_to_tmp` are stored in the tmps array; op nodes in
    `analysis.op_to_result` are stored directly in the results array.
    """
    builder: ir.IRBuilder
    analysis: CircuitAnalysis
    llvm_type: ir.Type
    llvm_idx_type: ir.Type
    in_args: ir.Value  # pointer to the input vars array
    tmp_args: ir.Value  # pointer to the temporary working memory
    out_args: ir.Value  # pointer to the results array

    def __post_init__(self) -> None:
        # The vars array has one slot per program input var (number_of_vars is
        # len(in_vars)), so a var node must be addressed by its position among the
        # input vars, not by its circuit-wide `idx`. This assumes analysis.var_nodes
        # lists the input vars in vars-array order, as the interpreter compiler does.
        self._var_positions: Dict[int, int] = {
            id(var_node): position
            for position, var_node in enumerate(self.analysis.var_nodes)
        }

    def number_of_tmps(self) -> int:
        return len(self.analysis.op_to_tmp)

    def process_op_node(self, op_node: OpNode, op: IrBOp) -> None:
        value: ir.Value = self.value(op_node.args[0])
        for arg in op_node.args[1:]:
            next_value: ir.Value = self.value(arg)
            value = op(self.builder, value, next_value)
        self.store_calculation(value, op_node)

    def value(self, node: CircuitNode) -> ir.Value:
        """
        Return an IR value for the given circuit node.

        Raises:
            ValueError: if `node` is a non-const var that is not an input var.
            RuntimeError: if `node` has no known storage location.
        """
        # If it is a constant...
        if isinstance(node, ConstNode):
            return ir.Constant(self.llvm_type, node.value)

        builder = self.builder

        # If it is a var...
        if isinstance(node, VarNode):
            if node.is_const():
                return ir.Constant(self.llvm_type, node.const.value)
            return builder.load(self._slot(self.in_args, self._var_position(node)))

        node_id: int = id(node)
        analysis = self.analysis

        # If it is an op _not_ in the results...
        idx: Optional[int] = analysis.op_to_tmp.get(node_id)
        if idx is not None:
            return builder.load(self._slot(self.tmp_args, idx))

        # If it is an op in the results...
        idx = analysis.op_to_result.get(node_id)
        if idx is not None:
            return builder.load(self._slot(self.out_args, idx))

        # Raise rather than assert so the check survives `python -O`.
        raise RuntimeError(f'no storage location for circuit node: {node!r}')

    def store_calculation(self, value: ir.Value, op_node: OpNode) -> None:
        """
        Store the given IR value as a result for the given op node.

        Raises:
            RuntimeError: if the op node is neither a tmp nor a result.
        """
        builder = self.builder
        analysis = self.analysis
        node_id: int = id(op_node)

        # If it is an op _not_ in the results...
        idx: Optional[int] = analysis.op_to_tmp.get(node_id)
        if idx is not None:
            builder.store(value, self._slot(self.tmp_args, idx))
            return

        # If it is an op in the results...
        idx = analysis.op_to_result.get(node_id)
        if idx is not None:
            builder.store(value, self._slot(self.out_args, idx))
            return

        raise RuntimeError(f'no storage location for op node: {op_node!r}')

    def store_result(self, value: ir.Value, idx: int) -> None:
        """
        Store the given IR value in the indexed result slot.
        """
        self.builder.store(value, self._slot(self.out_args, idx))

    def _slot(self, array: ir.Value, idx: int) -> ir.Value:
        """Emit a GEP giving a pointer to element `idx` of the pointer argument `array`."""
        return self.builder.gep(array, [ir.Constant(self.llvm_idx_type, idx)])

    def _var_position(self, node: VarNode) -> int:
        """
        Return the vars-array index of the given (non-const) var node.

        Raises:
            ValueError: if the var node is not one of the program's input vars.
        """
        position: Optional[int] = self._var_positions.get(id(node))
        if position is None:
            raise ValueError(f'var node is not an input var of the program: {node!r}')
        return position
|
|
274
|
+
|
|
275
|
+
|
|
276
|
+
class _FunctionBuilderFuncs(_FunctionBuilderTmps):
    """
    A function builder that puts op node calculations into the temporary working memory,
    but each op node becomes its own sub-function.

    Each sub-function has the main function's type and is called from the main
    function with the main function's (vars, tmps, results) arguments.
    """

    def process_op_node(self, op_node: OpNode, op: IrBOp) -> None:
        builder: ir.IRBuilder = self.builder
        save_block = builder.block

        sub_function_name: str = f'sub_{id(op_node)}'
        function_type = builder.function.type.pointee
        sub_function = ir.Function(builder.module, function_type, name=sub_function_name)
        sub_function.attributes.add('noinline')  # alwaysinline, noinline
        bb_entry = sub_function.append_basic_block(sub_function_name + '_entry')
        builder.position_at_end(bb_entry)

        # LLVM function arguments are local to their function, so while emitting
        # the sub-function body, read and write through the sub-function's own
        # arguments. The main function's arguments are restored afterwards.
        main_args = (self.in_args, self.tmp_args, self.out_args)
        self.in_args, self.tmp_args, self.out_args = sub_function.args
        try:
            value: ir.Value = self.value(op_node.args[0])
            for arg in op_node.args[1:]:
                next_value: ir.Value = self.value(arg)
                value = op(builder, value, next_value)
            self.store_calculation(value, op_node)
            builder.ret_void()
        finally:
            self.in_args, self.tmp_args, self.out_args = main_args

        # Restore builder to main function and call the sub-function from there.
        builder.position_at_end(save_block)
        builder.call(sub_function, list(main_args))
|
|
304
|
+
|
|
305
|
+
|
|
306
|
+
@dataclass
class _FunctionBuilderStack(_FunctionBuilder):
    """
    A function builder that puts op node calculations onto the stack.

    Op nodes that are not results are never written to memory by this builder:
    their IR values are kept in `ir_cache` and reused, leaving LLVM to hold them
    in registers or spill them to the stack. Result op nodes are stored in the
    results array.
    """
    builder: ir.IRBuilder
    analysis: CircuitAnalysis
    llvm_type: ir.Type
    llvm_idx_type: ir.Type
    in_args: ir.Value  # pointer to the input vars array
    out_args: ir.Value  # pointer to the results array
    ir_cache: Dict[int, ir.Value]  # id(node) -> IR value already emitted for that node

    def __post_init__(self) -> None:
        # The vars array has one slot per program input var (number_of_vars is
        # len(in_vars)), so a var node must be addressed by its position among the
        # input vars, not by its circuit-wide `idx`. This assumes analysis.var_nodes
        # lists the input vars in vars-array order, as the interpreter compiler does.
        self._var_positions: Dict[int, int] = {
            id(var_node): position
            for position, var_node in enumerate(self.analysis.var_nodes)
        }

    def number_of_tmps(self) -> int:
        # This flavour requests no client working memory.
        return 0

    def process_op_node(self, op_node: OpNode, op: IrBOp) -> None:
        value: ir.Value = self.value(op_node.args[0])
        for arg in op_node.args[1:]:
            next_value: ir.Value = self.value(arg)
            value = op(self.builder, value, next_value)
        self.store_calculation(value, op_node)

    def value(self, node: CircuitNode) -> ir.Value:
        """
        Return an IR value for the given circuit node.

        Raises:
            ValueError: if `node` is a non-const var that is not an input var.
            RuntimeError: if `node` has no known IR value or storage location.
        """
        node_id: int = id(node)

        # First check if it is in the IR cache
        cached: Optional[ir.Value] = self.ir_cache.get(node_id)
        if cached is not None:
            return cached

        # If it is a constant...
        if isinstance(node, ConstNode):
            value = ir.Constant(self.llvm_type, node.value)
            self.ir_cache[node_id] = value
            return value

        builder = self.builder

        # If it is a var...
        if isinstance(node, VarNode):
            if node.is_const():
                value = ir.Constant(self.llvm_type, node.const.value)
            else:
                value = builder.load(self._slot(self.in_args, self._var_position(node)))
            self.ir_cache[node_id] = value
            return value

        # If it is an op in the results...
        idx: Optional[int] = self.analysis.op_to_result.get(node_id)
        if idx is not None:
            return builder.load(self._slot(self.out_args, idx))

        # Raise rather than assert so the check survives `python -O`.
        raise RuntimeError(f'no IR value for circuit node: {node!r}')

    def store_calculation(self, value: ir.Value, op_node: OpNode) -> None:
        """
        Store the given IR value as a result for the given op node.
        """
        node_id: int = id(op_node)

        # If it is an op in the results...
        idx: Optional[int] = self.analysis.op_to_result.get(node_id)
        if idx is not None:
            self.builder.store(value, self._slot(self.out_args, idx))
            return

        # Just put it in the ir_cache.
        # This effectively forces the LLVM compiler to put it on the stack when registers run out.
        self.ir_cache[node_id] = value

    def store_result(self, value: ir.Value, idx: int) -> None:
        """
        Store the given IR value in the indexed result slot.
        """
        self.builder.store(value, self._slot(self.out_args, idx))

    def _slot(self, array: ir.Value, idx: int) -> ir.Value:
        """Emit a GEP giving a pointer to element `idx` of the pointer argument `array`."""
        return self.builder.gep(array, [ir.Constant(self.llvm_idx_type, idx)])

    def _var_position(self, node: VarNode) -> int:
        """
        Return the vars-array index of the given (non-const) var node.

        Raises:
            ValueError: if the var node is not one of the program's input vars.
        """
        position: Optional[int] = self._var_positions.get(id(node))
        if position is None:
            raise ValueError(f'var node is not an input var of the program: {node!r}')
        return position
|