compiled-knowledge 4.0.0 (cp313-cp313-musllinux_1_2_x86_64.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of compiled-knowledge might be problematic. Click here for more details.

Files changed (182)
  1. ck/__init__.py +0 -0
  2. ck/circuit/__init__.py +17 -0
  3. ck/circuit/_circuit_cy.c +37510 -0
  4. ck/circuit/_circuit_cy.cpython-313-x86_64-linux-musl.so +0 -0
  5. ck/circuit/_circuit_cy.pxd +32 -0
  6. ck/circuit/_circuit_cy.pyx +768 -0
  7. ck/circuit/_circuit_py.py +836 -0
  8. ck/circuit/tmp_const.py +75 -0
  9. ck/circuit_compiler/__init__.py +2 -0
  10. ck/circuit_compiler/circuit_compiler.py +27 -0
  11. ck/circuit_compiler/cython_vm_compiler/__init__.py +1 -0
  12. ck/circuit_compiler/cython_vm_compiler/_compiler.c +19830 -0
  13. ck/circuit_compiler/cython_vm_compiler/_compiler.cpython-313-x86_64-linux-musl.so +0 -0
  14. ck/circuit_compiler/cython_vm_compiler/_compiler.pyx +380 -0
  15. ck/circuit_compiler/cython_vm_compiler/cython_vm_compiler.py +128 -0
  16. ck/circuit_compiler/interpret_compiler.py +255 -0
  17. ck/circuit_compiler/llvm_compiler.py +388 -0
  18. ck/circuit_compiler/llvm_vm_compiler.py +552 -0
  19. ck/circuit_compiler/named_circuit_compilers.py +57 -0
  20. ck/circuit_compiler/support/__init__.py +0 -0
  21. ck/circuit_compiler/support/circuit_analyser/__init__.py +13 -0
  22. ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.c +10615 -0
  23. ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.cpython-313-x86_64-linux-musl.so +0 -0
  24. ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.pyx +98 -0
  25. ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_py.py +93 -0
  26. ck/circuit_compiler/support/input_vars.py +148 -0
  27. ck/circuit_compiler/support/llvm_ir_function.py +251 -0
  28. ck/example/__init__.py +53 -0
  29. ck/example/alarm.py +366 -0
  30. ck/example/asia.py +28 -0
  31. ck/example/binary_clique.py +32 -0
  32. ck/example/bow_tie.py +33 -0
  33. ck/example/cancer.py +37 -0
  34. ck/example/chain.py +38 -0
  35. ck/example/child.py +199 -0
  36. ck/example/clique.py +33 -0
  37. ck/example/cnf_pgm.py +39 -0
  38. ck/example/diamond_square.py +70 -0
  39. ck/example/earthquake.py +36 -0
  40. ck/example/empty.py +10 -0
  41. ck/example/hailfinder.py +539 -0
  42. ck/example/hepar2.py +628 -0
  43. ck/example/insurance.py +504 -0
  44. ck/example/loop.py +40 -0
  45. ck/example/mildew.py +38161 -0
  46. ck/example/munin.py +22982 -0
  47. ck/example/pathfinder.py +53747 -0
  48. ck/example/rain.py +39 -0
  49. ck/example/rectangle.py +161 -0
  50. ck/example/run.py +30 -0
  51. ck/example/sachs.py +129 -0
  52. ck/example/sprinkler.py +30 -0
  53. ck/example/star.py +44 -0
  54. ck/example/stress.py +64 -0
  55. ck/example/student.py +43 -0
  56. ck/example/survey.py +46 -0
  57. ck/example/triangle_square.py +56 -0
  58. ck/example/truss.py +51 -0
  59. ck/in_out/__init__.py +3 -0
  60. ck/in_out/parse_ace_lmap.py +216 -0
  61. ck/in_out/parse_ace_nnf.py +322 -0
  62. ck/in_out/parse_net.py +482 -0
  63. ck/in_out/parser_utils.py +189 -0
  64. ck/in_out/pgm_pickle.py +42 -0
  65. ck/in_out/pgm_python.py +268 -0
  66. ck/in_out/render_bugs.py +111 -0
  67. ck/in_out/render_net.py +177 -0
  68. ck/in_out/render_pomegranate.py +184 -0
  69. ck/pgm.py +3482 -0
  70. ck/pgm_circuit/__init__.py +1 -0
  71. ck/pgm_circuit/marginals_program.py +352 -0
  72. ck/pgm_circuit/mpe_program.py +236 -0
  73. ck/pgm_circuit/pgm_circuit.py +88 -0
  74. ck/pgm_circuit/program_with_slotmap.py +217 -0
  75. ck/pgm_circuit/slot_map.py +35 -0
  76. ck/pgm_circuit/support/__init__.py +0 -0
  77. ck/pgm_circuit/support/compile_circuit.py +78 -0
  78. ck/pgm_circuit/target_marginals_program.py +103 -0
  79. ck/pgm_circuit/wmc_program.py +323 -0
  80. ck/pgm_compiler/__init__.py +2 -0
  81. ck/pgm_compiler/ace/__init__.py +1 -0
  82. ck/pgm_compiler/ace/ace.py +299 -0
  83. ck/pgm_compiler/factor_elimination.py +395 -0
  84. ck/pgm_compiler/named_pgm_compilers.py +60 -0
  85. ck/pgm_compiler/pgm_compiler.py +19 -0
  86. ck/pgm_compiler/recursive_conditioning.py +231 -0
  87. ck/pgm_compiler/support/__init__.py +0 -0
  88. ck/pgm_compiler/support/circuit_table/__init__.py +17 -0
  89. ck/pgm_compiler/support/circuit_table/_circuit_table_cy.c +16393 -0
  90. ck/pgm_compiler/support/circuit_table/_circuit_table_cy.cpython-313-x86_64-linux-musl.so +0 -0
  91. ck/pgm_compiler/support/circuit_table/_circuit_table_cy.pyx +332 -0
  92. ck/pgm_compiler/support/circuit_table/_circuit_table_py.py +304 -0
  93. ck/pgm_compiler/support/clusters.py +572 -0
  94. ck/pgm_compiler/support/factor_tables.py +406 -0
  95. ck/pgm_compiler/support/join_tree.py +332 -0
  96. ck/pgm_compiler/support/named_compiler_maker.py +43 -0
  97. ck/pgm_compiler/variable_elimination.py +91 -0
  98. ck/probability/__init__.py +0 -0
  99. ck/probability/empirical_probability_space.py +52 -0
  100. ck/probability/pgm_probability_space.py +36 -0
  101. ck/probability/probability_space.py +627 -0
  102. ck/program/__init__.py +3 -0
  103. ck/program/program.py +137 -0
  104. ck/program/program_buffer.py +180 -0
  105. ck/program/raw_program.py +106 -0
  106. ck/sampling/__init__.py +0 -0
  107. ck/sampling/forward_sampler.py +211 -0
  108. ck/sampling/marginals_direct_sampler.py +113 -0
  109. ck/sampling/sampler.py +62 -0
  110. ck/sampling/sampler_support.py +234 -0
  111. ck/sampling/uniform_sampler.py +72 -0
  112. ck/sampling/wmc_direct_sampler.py +171 -0
  113. ck/sampling/wmc_gibbs_sampler.py +153 -0
  114. ck/sampling/wmc_metropolis_sampler.py +165 -0
  115. ck/sampling/wmc_rejection_sampler.py +115 -0
  116. ck/utils/__init__.py +0 -0
  117. ck/utils/iter_extras.py +164 -0
  118. ck/utils/local_config.py +278 -0
  119. ck/utils/map_list.py +128 -0
  120. ck/utils/map_set.py +128 -0
  121. ck/utils/np_extras.py +51 -0
  122. ck/utils/random_extras.py +64 -0
  123. ck/utils/tmp_dir.py +94 -0
  124. ck_demos/__init__.py +0 -0
  125. ck_demos/ace/__init__.py +0 -0
  126. ck_demos/ace/copy_ace_to_ck.py +15 -0
  127. ck_demos/ace/demo_ace.py +49 -0
  128. ck_demos/ace/simple_ace_demo.py +18 -0
  129. ck_demos/all_demos.py +88 -0
  130. ck_demos/circuit/__init__.py +0 -0
  131. ck_demos/circuit/demo_circuit_dump.py +22 -0
  132. ck_demos/circuit/demo_derivatives.py +43 -0
  133. ck_demos/circuit_compiler/__init__.py +0 -0
  134. ck_demos/circuit_compiler/compare_circuit_compilers.py +32 -0
  135. ck_demos/circuit_compiler/show_llvm_program.py +26 -0
  136. ck_demos/getting_started/__init__.py +0 -0
  137. ck_demos/getting_started/simple_demo.py +18 -0
  138. ck_demos/pgm/__init__.py +0 -0
  139. ck_demos/pgm/demo_pgm_dump.py +18 -0
  140. ck_demos/pgm/demo_pgm_dump_stress.py +18 -0
  141. ck_demos/pgm/demo_pgm_string_rendering.py +15 -0
  142. ck_demos/pgm/show_examples.py +25 -0
  143. ck_demos/pgm_compiler/__init__.py +0 -0
  144. ck_demos/pgm_compiler/compare_pgm_compilers.py +63 -0
  145. ck_demos/pgm_compiler/demo_compiler_dump.py +60 -0
  146. ck_demos/pgm_compiler/demo_factor_elimination.py +47 -0
  147. ck_demos/pgm_compiler/demo_join_tree.py +25 -0
  148. ck_demos/pgm_compiler/demo_marginals_program.py +53 -0
  149. ck_demos/pgm_compiler/demo_mpe_program.py +55 -0
  150. ck_demos/pgm_compiler/demo_pgm_compiler.py +38 -0
  151. ck_demos/pgm_compiler/demo_recursive_conditioning.py +33 -0
  152. ck_demos/pgm_compiler/demo_variable_elimination.py +33 -0
  153. ck_demos/pgm_compiler/demo_wmc_program.py +29 -0
  154. ck_demos/pgm_compiler/time_fe_compiler.py +93 -0
  155. ck_demos/pgm_inference/__init__.py +0 -0
  156. ck_demos/pgm_inference/demo_inferencing_basic.py +188 -0
  157. ck_demos/pgm_inference/demo_inferencing_mpe_cancer.py +45 -0
  158. ck_demos/pgm_inference/demo_inferencing_wmc_and_mpe_sprinkler.py +154 -0
  159. ck_demos/pgm_inference/demo_inferencing_wmc_student.py +110 -0
  160. ck_demos/programs/__init__.py +0 -0
  161. ck_demos/programs/demo_program_buffer.py +24 -0
  162. ck_demos/programs/demo_program_multi.py +24 -0
  163. ck_demos/programs/demo_program_none.py +19 -0
  164. ck_demos/programs/demo_program_single.py +23 -0
  165. ck_demos/programs/demo_raw_program_dump.py +17 -0
  166. ck_demos/programs/demo_raw_program_interpreted.py +21 -0
  167. ck_demos/programs/demo_raw_program_llvm.py +21 -0
  168. ck_demos/sampling/__init__.py +0 -0
  169. ck_demos/sampling/check_sampler.py +71 -0
  170. ck_demos/sampling/demo_marginal_direct_sampler.py +40 -0
  171. ck_demos/sampling/demo_uniform_sampler.py +38 -0
  172. ck_demos/sampling/demo_wmc_direct_sampler.py +40 -0
  173. ck_demos/utils/__init__.py +0 -0
  174. ck_demos/utils/compare.py +120 -0
  175. ck_demos/utils/convert_network.py +45 -0
  176. ck_demos/utils/sample_model.py +216 -0
  177. ck_demos/utils/stop_watch.py +384 -0
  178. compiled_knowledge-4.0.0.dist-info/METADATA +50 -0
  179. compiled_knowledge-4.0.0.dist-info/RECORD +182 -0
  180. compiled_knowledge-4.0.0.dist-info/WHEEL +5 -0
  181. compiled_knowledge-4.0.0.dist-info/licenses/LICENSE.txt +21 -0
  182. compiled_knowledge-4.0.0.dist-info/top_level.txt +2 -0
@@ -0,0 +1,255 @@
1
+ from __future__ import annotations
2
+
3
+ import ctypes as ct
4
+ from dataclasses import dataclass
5
+ from typing import Sequence, Optional, Dict, List, Tuple, Callable, assert_never
6
+
7
+ import numpy as np
8
+
9
+ from .support.circuit_analyser import CircuitAnalysis, analyze_circuit
10
+ from .support.input_vars import InputVars, InferVars, infer_input_vars
11
+ from ..circuit import Circuit, CircuitNode, VarNode, OpNode, ADD, MUL
12
+ from ..program.raw_program import RawProgram, RawProgramFunction
13
+ from ..utils.iter_extras import multiply, first
14
+ from ..utils.np_extras import NDArrayNumeric, DTypeNumeric
15
+
16
# Selectors for the value arrays an _ElementID can refer to
# (vars, consts, tmps, results — in that numeric order).
_VARS, _CONSTS, _TMPS, _RESULT = range(4)
21
+
22
+
23
def compile_circuit(
        *result: CircuitNode,
        input_vars: InputVars = InferVars.ALL,
        circuit: Optional[Circuit] = None,
        dtype: DTypeNumeric = np.double,
) -> InterpreterRawProgram:
    """
    Build a RawProgram that evaluates the circuit by interpreting a list of
    instructions in Python.

    Args:
        *result: result nodes nominating the results of the returned program.
        input_vars: How to determine the input variables.
        circuit: optionally explicitly specify the Circuit.
        dtype: the numpy DType to use for the raw program.

    Returns:
        a raw program.

    Raises:
        ValueError: if the circuit is unknown, but it is needed.
        ValueError: if not all nodes are from the same circuit.
    """
    program_vars: Sequence[VarNode] = infer_input_vars(circuit, result, input_vars)
    circuit_analysis: CircuitAnalysis = analyze_circuit(program_vars, result)

    program_instructions: List[_Instruction]
    const_values: NDArrayNumeric
    program_instructions, const_values = _make_instructions(circuit_analysis, dtype)

    return InterpreterRawProgram(
        in_vars=program_vars,
        result=result,
        op_nodes=circuit_analysis.op_nodes,
        dtype=dtype,
        instructions=program_instructions,
        np_consts=const_values,
    )
59
+
60
+
61
class InterpreterRawProgram(RawProgram):
    """
    A RawProgram whose function interprets a list of `_Instruction` objects.

    The instructions and the numpy array of constants are stored on the object
    so they can be dumped and pickled; on unpickling, the executable function
    is rebuilt from them.
    """

    def __init__(
            self,
            in_vars: Sequence[VarNode],
            result: Sequence[CircuitNode],
            op_nodes: Sequence[OpNode],
            dtype: DTypeNumeric,
            instructions: List[_Instruction],
            np_consts: NDArrayNumeric,
    ):
        self.instructions = instructions
        self.np_consts = np_consts
        super().__init__(
            function=_make_function(instructions=instructions, np_consts=np_consts),
            dtype=dtype,
            number_of_vars=len(in_vars),
            number_of_tmps=len(op_nodes),
            number_of_results=len(result),
            var_indices=tuple(v.idx for v in in_vars),
        )

    def dump(self, *, prefix: str = '', indent: str = ' ', show_instructions: bool = True) -> None:
        """
        Print the RawProgram summary, the instruction count and (optionally)
        each instruction on its own line.
        """
        super().dump(prefix=prefix, indent=indent)
        print(f'{prefix}number of instructions = {len(self.instructions)}')
        if not show_instructions:
            return
        print(f'{prefix}instructions:')
        inner_prefix: str = prefix + indent
        for instr in self.instructions:
            print(f'{inner_prefix}{instr.to_str(self.var_indices, self.np_consts)}')

    def __getstate__(self):
        """
        Support for pickle: capture the RawProgram fields, the instructions and the constants.
        """
        state_keys = (
            'dtype',
            'number_of_vars',
            'number_of_tmps',
            'number_of_results',
            'var_indices',
            'instructions',
            'np_consts',
        )
        return {key: getattr(self, key) for key in state_keys}

    def __setstate__(self, state):
        """
        Support for pickle: restore the captured fields, then rebuild the
        (unpicklable) interpreter function from the instructions and constants.
        """
        state_keys = (
            'dtype',
            'number_of_vars',
            'number_of_tmps',
            'number_of_results',
            'var_indices',
            'instructions',
            'np_consts',
        )
        for key in state_keys:
            setattr(self, key, state[key])

        self.function = _make_function(instructions=self.instructions, np_consts=self.np_consts)
129
+
130
+
131
def _make_instructions(
        analysis: CircuitAnalysis,
        dtype: DTypeNumeric,
) -> Tuple[List[_Instruction], NDArrayNumeric]:
    """
    Translate a circuit analysis into a list of interpreter instructions.

    Args:
        analysis: the analysis of the circuit to interpret.
        dtype: the numpy dtype of the returned constants array.

    Returns:
        (instructions, np_consts) where np_consts holds the value of each
        const node at the index referenced by the instructions.

    Raises:
        RuntimeError: if an op node has a symbol other than ADD or MUL.
    """
    # Where to get the value of each node, keyed by id(node).
    node_to_element: Dict[int, _ElementID] = {}

    # Const nodes: their values are stored in a numpy array. The array is sized
    # from the const nodes themselves so indexing always stays in bounds.
    const_nodes = analysis.const_nodes
    np_consts: NDArrayNumeric = np.zeros(len(const_nodes), dtype=dtype)
    for const_idx, const_node in enumerate(const_nodes):
        np_consts[const_idx] = const_node.value
        node_to_element[id(const_node)] = _ElementID(_CONSTS, const_idx)

    # Var nodes: read from the input array by position; a var fixed to a
    # constant reads that constant instead.
    var_node: VarNode
    for var_idx, var_node in enumerate(analysis.var_nodes):
        if var_node.is_const():
            node_to_element[id(var_node)] = node_to_element[id(var_node.const)]
        else:
            node_to_element[id(var_node)] = _ElementID(_VARS, var_idx)

    # Op nodes: written to a tmp slot, or directly to a result slot.
    for node_id, tmp_idx in analysis.op_to_tmp.items():
        node_to_element[node_id] = _ElementID(_TMPS, tmp_idx)
    for node_id, result_idx in analysis.op_to_result.items():
        node_to_element[node_id] = _ElementID(_RESULT, result_idx)

    # One instruction per op node.
    instructions: List[_Instruction] = []
    op_node: OpNode
    for op_node in analysis.op_nodes:
        dest: _ElementID = node_to_element[id(op_node)]
        args: List[_ElementID] = [node_to_element[id(arg)] for arg in op_node.args]
        if op_node.symbol == MUL:
            operation = multiply
        elif op_node.symbol == ADD:
            operation = sum
        else:
            # Same error as the LLVM compiler raises for an unknown op symbol.
            raise RuntimeError(f'unknown op node: {op_node.symbol!r}')
        instructions.append(_Instruction(operation, args, dest))

    # Result nodes that are not op nodes (consts or vars) are copied into their result slot.
    for result_idx, node in enumerate(analysis.result_nodes):
        if not isinstance(node, OpNode):
            source: _ElementID = node_to_element[id(node)]
            instructions.append(_Instruction(first, [source], _ElementID(_RESULT, result_idx)))

    return instructions, np_consts
188
+
189
+
190
def _make_function(
        instructions: List[_Instruction],
        np_consts: NDArrayNumeric,
) -> RawProgramFunction:
    """
    Build a RawProgramFunction that runs the given instructions in order.

    Each instruction reads its argument values from one of four arrays
    (input vars, constants, temporaries, results), applies its operation to
    them, and writes the outcome to its destination element.
    """

    # RawProgramFunction = Callable[[ct.POINTER, ct.POINTER, ct.POINTER], None]
    def raw_program_function(vars_in: ct.POINTER, tmps: ct.POINTER, result_out: ct.POINTER) -> None:
        # Slots in this list are selected by the _VARS/_CONSTS/_TMPS/_RESULT constants.
        arrays: List[ct.POINTER] = [None] * 4
        arrays[_VARS] = vars_in
        arrays[_CONSTS] = np_consts
        arrays[_TMPS] = tmps
        arrays[_RESULT] = result_out

        for instr in instructions:
            computed = instr.operation(
                arrays[src.array][src.index] for src in instr.args
            )
            target: _ElementID = instr.dest
            arrays[target.array][target.index] = computed

    return raw_program_function
220
+
221
+
222
@dataclass
class _ElementID:
    """
    Locates one value: which array it lives in and its position within that array.
    """
    array: int  # one of _VARS, _TMPS, _CONSTS, _RESULT
    index: int  # position within the selected array

    def to_str(self, var_indices: Sequence[int], consts: NDArrayNumeric) -> str:
        """
        Render this element for an instruction dump. Vars show their circuit
        var index, consts show their value, and anything unrecognised shows as '?'.
        """
        which: int = self.array
        pos: int = self.index
        if which == _VARS:
            return f'var[{var_indices[pos]}]'
        if which == _TMPS:
            return f'tmp[{pos}]'
        if which == _CONSTS:
            return str(consts.item(pos))
        if which == _RESULT:
            return f'result[{pos}]'
        return f'?[{pos}]'
238
+
239
+
240
@dataclass
class _Instruction:
    """
    One interpreter step: `dest = operation(values of args)`.
    """
    operation: Callable  # multiply, sum, or first (copy of a single value)
    args: Sequence[_ElementID]
    dest: _ElementID

    def to_str(self, var_indices: Sequence[int], consts: NDArrayNumeric) -> str:
        """
        Render this instruction for a program dump, e.g. `tmp[0] = mul var[1] 0.5`.
        """
        symbol: str
        if self.operation is multiply:
            symbol = 'mul'
        elif self.operation is sum:
            symbol = 'sum'
        elif self.operation is first:
            # Copy instructions are emitted for result nodes that are not op nodes.
            symbol = 'copy'
        else:
            symbol = '<?>'
        args: str = ' '.join(elem.to_str(var_indices, consts) for elem in self.args)
        return f'{self.dest.to_str(var_indices, consts)} = {symbol} {args}'
@@ -0,0 +1,388 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from enum import Enum
5
+ from typing import Sequence, Optional, Tuple, Dict, Protocol, assert_never
6
+
7
+ import llvmlite.binding as llvm
8
+ import llvmlite.ir as ir
9
+
10
+ from .support.circuit_analyser import CircuitAnalysis, analyze_circuit
11
+ from .support.input_vars import InputVars, InferVars, infer_input_vars
12
+ from .support.llvm_ir_function import IRFunction, DataType, TypeInfo, compile_llvm_program, LLVMRawProgram, IrBOp
13
+ from ..circuit import Circuit, VarNode, CircuitNode, OpNode, MUL, ADD, ConstNode
14
+ from ..program.raw_program import RawProgramFunction
15
+
16
+
17
class Flavour(Enum):
    """
    The layout of the LLVM program generated for a circuit.

    STACK: no working temporary memory is requested; op node values are kept
        as IR values (spilling to the stack as needed).
    TMPS: op node calculations are stored in working temporary memory.
    FUNCS: like TMPS, but each op node is calculated by its own sub-function.
    """
    STACK = 0
    TMPS = 1
    FUNCS = 2
21
+
22
+
23
# Arithmetic type used by compile_circuit when no data_type is given:
# the TypeInfo carried by DataType.FLOAT_64.
DEFAULT_TYPE_INFO: TypeInfo = DataType.FLOAT_64.value

# Program layout used by compile_circuit when no flavour is given.
DEFAULT_FLAVOUR: Flavour = Flavour.TMPS
25
+
26
+
27
def compile_circuit(
        *result: CircuitNode,
        input_vars: InputVars = InferVars.ALL,
        circuit: Optional[Circuit] = None,
        data_type: DataType | TypeInfo = DEFAULT_TYPE_INFO,
        flavour: Flavour = DEFAULT_FLAVOUR,
        keep_llvm_program: bool = True,
        opt: int = 2,
) -> LLVMRawProgram:
    """
    Compile the given circuit using LLVM.

    This creates an LLVM program where each circuit op node is converted to
    one or more LLVM binary op machine code instructions. For large circuits
    this results in a large LLVM program which can be slow to compile.

    With the TMPS and FUNCS flavours, the resulting RawProgram _does_ use client
    managed working memory. With the STACK flavour it requests none
    (number_of_tmps is zero).

    Conforms to the CircuitCompiler protocol.

    Args:
        *result: result nodes nominating the results of the returned program.
        input_vars: How to determine the input variables.
        circuit: optionally explicitly specify the Circuit.
        data_type: What data type to use for arithmetic calculations. Either a DataType member or TypeInfo.
        flavour: what flavour of LLVM program to construct.
        keep_llvm_program: if true, the LLVM program will be kept. This is required for pickling.
        opt: The optimization level to use by LLVM MC JIT.

    Returns:
        a raw program.

    Raises:
        ValueError: if the circuit is unknown, but it is needed.
        ValueError: if not all nodes are from the same circuit.
        ValueError: if the program data type could not be interpreted.
        ValueError: if the flavour is not recognised.
    """
    in_vars: Sequence[VarNode] = infer_input_vars(circuit, result, input_vars)
    var_indices: Sequence[int] = tuple(var.idx for var in in_vars)

    # Get the type info
    type_info: TypeInfo
    if isinstance(data_type, DataType):
        type_info = data_type.value
    elif isinstance(data_type, TypeInfo):
        type_info = data_type
    else:
        raise ValueError(f'could not interpret program data type: {data_type!r}')

    # Compile the circuit to an LLVM module representing a RawProgramFunction
    llvm_program: str
    number_of_tmps: int
    llvm_program, number_of_tmps = _make_llvm_program(in_vars, result, type_info, flavour)

    # Compile the LLVM program to a native executable
    engine: llvm.ExecutionEngine
    function: RawProgramFunction
    engine, function = compile_llvm_program(llvm_program, dtype=type_info.dtype, opt=opt)

    return LLVMRawProgram(
        function=function,
        dtype=type_info.dtype,
        number_of_vars=len(var_indices),
        number_of_tmps=number_of_tmps,
        number_of_results=len(result),
        var_indices=var_indices,
        llvm_program=llvm_program if keep_llvm_program else None,
        engine=engine,
        opt=opt,
    )
97
+
98
+
99
def _make_llvm_program(
        in_vars: Sequence[VarNode],
        result: Sequence[CircuitNode],
        type_info: TypeInfo,
        flavour: Flavour,
) -> Tuple[str, int]:
    """
    Build the LLVM IR text for a RawProgramFunction that computes the given results.

    Args:
        in_vars: the program's input variables.
        result: the nodes whose values the program writes to its results.
        type_info: the arithmetic data type to use.
        flavour: which program layout to build.

    Returns:
        (llvm_program, number_of_tmps)

    Raises:
        ValueError: if the flavour is not recognised.
        RuntimeError: if an op node has a symbol other than ADD or MUL.
    """
    llvm_function = IRFunction(type_info)
    builder = llvm_function.builder
    type_info = llvm_function.type_info
    fn_args = llvm_function.function.args  # (vars in, tmps, results out)

    analysis: CircuitAnalysis = analyze_circuit(in_vars, result)

    function_builder: _FunctionBuilder
    if flavour == Flavour.STACK:
        function_builder = _FunctionBuilderStack(
            builder=builder,
            analysis=analysis,
            llvm_type=type_info.llvm_type,
            llvm_idx_type=ir.IntType(32),
            in_args=fn_args[0],
            out_args=fn_args[2],
            ir_cache={},
        )
    elif flavour == Flavour.TMPS or flavour == Flavour.FUNCS:
        builder_class = _FunctionBuilderTmps if flavour == Flavour.TMPS else _FunctionBuilderFuncs
        function_builder = builder_class(
            builder=builder,
            analysis=analysis,
            llvm_type=type_info.llvm_type,
            llvm_idx_type=ir.IntType(32),
            in_args=fn_args[0],
            tmp_args=fn_args[1],
            out_args=fn_args[2],
        )
    else:
        raise ValueError(f'unknown LLVM program flavour: {flavour!r}')

    # Emit one calculation per op node.
    for op_node in analysis.op_nodes:
        symbol = op_node.symbol
        op: IrBOp
        if symbol == ADD:
            op = type_info.add
        elif symbol == MUL:
            op = type_info.mul
        else:
            raise RuntimeError(f'unknown op node: {op_node.symbol!r}')
        function_builder.process_op_node(op_node, op)

    # Results that are not op nodes are copied into their result slot.
    for idx, node in enumerate(result):
        if isinstance(node, OpNode):
            continue
        function_builder.store_result(function_builder.value(node), idx)

    builder.ret_void()

    return llvm_function.llvm_program(), function_builder.number_of_tmps()
171
+
172
+
173
class _FunctionBuilder(Protocol):
    """
    Interface used by _make_llvm_program to emit the body of the LLVM
    program function for one flavour.
    """

    def process_op_node(self, op_node: OpNode, op: IrBOp) -> None:
        """Emit IR computing the given op node by combining its arguments with `op`."""

    def value(self, node: CircuitNode) -> ir.Value:
        """Return an IR value for the given circuit node."""

    def store_result(self, value: ir.Value, idx: int) -> None:
        """Store the given IR value in the indexed result slot."""

    def number_of_tmps(self) -> int:
        """Return how many temporary working memory slots the program needs."""
185
+
186
+
187
@dataclass
class _FunctionBuilderTmps(_FunctionBuilder):
    """
    A function builder that puts op node calculations into the temporary working memory.

    Op nodes listed in `analysis.op_to_tmp` are written to the tmps array, and op
    nodes listed in `analysis.op_to_result` are written directly to the result array.
    """
    builder: ir.IRBuilder
    analysis: CircuitAnalysis
    llvm_type: ir.Type
    llvm_idx_type: ir.Type
    in_args: ir.Value  # pointer to the input variable values
    tmp_args: ir.Value  # pointer to the temporary working memory
    out_args: ir.Value  # pointer to the result values

    def __post_init__(self) -> None:
        # The input array is ordered like analysis.var_nodes (the program's input
        # vars, as the interpreter compiler also assumes). That order need not
        # match circuit var indices when the input vars are a subset or reordered.
        self._var_positions: Dict[int, int] = {
            id(var_node): position
            for position, var_node in enumerate(self.analysis.var_nodes)
        }

    def number_of_tmps(self) -> int:
        return len(self.analysis.op_to_tmp)

    def process_op_node(self, op_node: OpNode, op: IrBOp) -> None:
        """
        Emit IR folding `op` over the op node's arguments, left to right, and store the outcome.
        """
        value: ir.Value = self.value(op_node.args[0])
        for arg in op_node.args[1:]:
            next_value: ir.Value = self.value(arg)
            value = op(self.builder, value, next_value)
        self.store_calculation(value, op_node)

    def value(self, node: CircuitNode) -> ir.Value:
        """
        Return an IR value for the given circuit node.

        Raises:
            RuntimeError: if the node is an op node with no tmp or result slot.
        """
        # If it is a constant...
        if isinstance(node, ConstNode):
            return ir.Constant(self.llvm_type, node.value)

        builder = self.builder

        # If it is a var...
        if isinstance(node, VarNode):
            if node.is_const():
                return ir.Constant(self.llvm_type, node.const.value)
            # Fall back to the circuit var index for a var that is not a program input.
            position: int = self._var_positions.get(id(node), node.idx)
            return builder.load(self._element_ptr(self.in_args, position))

        node_id: int = id(node)
        analysis = self.analysis

        # If it is an op _not_ in the results...
        idx: Optional[int] = analysis.op_to_tmp.get(node_id)
        if idx is not None:
            return builder.load(self._element_ptr(self.tmp_args, idx))

        # If it is an op in the results...
        idx = analysis.op_to_result.get(node_id)
        if idx is not None:
            return builder.load(self._element_ptr(self.out_args, idx))

        raise RuntimeError(f'no IR value available for circuit node {node_id}')

    def store_calculation(self, value: ir.Value, op_node: OpNode) -> None:
        """
        Store the given IR value as the result for the given op node.

        Raises:
            RuntimeError: if the op node has no tmp or result slot.
        """
        node_id: int = id(op_node)
        analysis = self.analysis

        # If it is an op _not_ in the results...
        idx: Optional[int] = analysis.op_to_tmp.get(node_id)
        if idx is not None:
            self.builder.store(value, self._element_ptr(self.tmp_args, idx))
            return

        # If it is an op in the results...
        idx = analysis.op_to_result.get(node_id)
        if idx is not None:
            self.builder.store(value, self._element_ptr(self.out_args, idx))
            return

        raise RuntimeError(f'no storage slot for op node {node_id}')

    def store_result(self, value: ir.Value, idx: int) -> None:
        """
        Store the given IR value in the indexed result slot.
        """
        self.builder.store(value, self._element_ptr(self.out_args, idx))

    def _element_ptr(self, array: ir.Value, idx: int) -> ir.GEPInstr:
        """Emit a pointer to element `idx` of the given array argument."""
        return self.builder.gep(array, [ir.Constant(self.llvm_idx_type, idx)])
274
+
275
+
276
class _FunctionBuilderFuncs(_FunctionBuilderTmps):
    """
    A function builder that puts op node calculations into the temporary working memory,
    but each op node becomes its own sub-function.

    Each sub-function has the same signature as the main program function and is
    called from the main function with the main function's arguments.
    """

    def process_op_node(self, op_node: OpNode, op: IrBOp) -> None:
        builder: ir.IRBuilder = self.builder
        save_block = builder.block
        main_args = (self.in_args, self.tmp_args, self.out_args)

        sub_function_name: str = f'sub_{id(op_node)}'
        # ftype is the function type (signature) of the main program function.
        sub_function = ir.Function(builder.module, builder.function.ftype, name=sub_function_name)
        sub_function.attributes.add('noinline')  # alwaysinline, noinline
        bb_entry = sub_function.append_basic_block(sub_function_name + '_entry')
        builder.position_at_end(bb_entry)

        # IR emitted inside the sub-function must address memory through the
        # sub-function's own parameters. An argument of the main function is
        # not a valid operand inside another function.
        self.in_args, self.tmp_args, self.out_args = sub_function.args
        try:
            super().process_op_node(op_node, op)
            builder.ret_void()
        finally:
            # Restore the arguments and the builder position to the main function.
            self.in_args, self.tmp_args, self.out_args = main_args
            builder.position_at_end(save_block)

        builder.call(sub_function, list(main_args))
304
+
305
+
306
@dataclass
class _FunctionBuilderStack(_FunctionBuilder):
    """
    A function builder that puts op node calculations onto the stack.

    Op node values are kept as IR values in `ir_cache`. Only op nodes that are
    program results are also written to memory (the result array).
    """
    builder: ir.IRBuilder
    analysis: CircuitAnalysis
    llvm_type: ir.Type
    llvm_idx_type: ir.Type
    in_args: ir.Value  # pointer to the input variable values
    out_args: ir.Value  # pointer to the result values
    ir_cache: Dict[int, ir.Value]  # id(circuit node) -> IR value already emitted

    def __post_init__(self) -> None:
        # The input array is ordered like analysis.var_nodes (the program's input
        # vars), which need not match circuit var indices.
        self._var_positions: Dict[int, int] = {
            id(var_node): position
            for position, var_node in enumerate(self.analysis.var_nodes)
        }

    def number_of_tmps(self) -> int:
        return 0

    def process_op_node(self, op_node: OpNode, op: IrBOp) -> None:
        """
        Emit IR folding `op` over the op node's arguments, left to right, and store the outcome.
        """
        value: ir.Value = self.value(op_node.args[0])
        for arg in op_node.args[1:]:
            next_value: ir.Value = self.value(arg)
            value = op(self.builder, value, next_value)
        self.store_calculation(value, op_node)

    def value(self, node: CircuitNode) -> ir.Value:
        """
        Return an IR value for the given circuit node.

        Raises:
            RuntimeError: if no value has been emitted for the node.
        """
        node_id: int = id(node)

        # First check if it is in the IR cache
        cached: Optional[ir.Value] = self.ir_cache.get(node_id)
        if cached is not None:
            return cached

        # If it is a constant...
        if isinstance(node, ConstNode):
            value = ir.Constant(self.llvm_type, node.value)
            self.ir_cache[node_id] = value
            return value

        builder = self.builder

        # If it is a var...
        if isinstance(node, VarNode):
            if node.is_const():
                value = ir.Constant(self.llvm_type, node.const.value)
            else:
                # Fall back to the circuit var index for a var that is not a program input.
                position: int = self._var_positions.get(node_id, node.idx)
                value = builder.load(builder.gep(self.in_args, [ir.Constant(self.llvm_idx_type, position)]))
            self.ir_cache[node_id] = value
            return value

        # If it is an op in the results...
        idx: Optional[int] = self.analysis.op_to_result.get(node_id)
        if idx is not None:
            return builder.load(builder.gep(self.out_args, [ir.Constant(self.llvm_idx_type, idx)]))

        raise RuntimeError(f'no IR value available for circuit node {node_id}')

    def store_calculation(self, value: ir.Value, op_node: OpNode) -> None:
        """
        Record the given IR value for the op node; if the op node is a result,
        also store it in its result slot.
        """
        node_id: int = id(op_node)

        # If it is an op in the results...
        idx: Optional[int] = self.analysis.op_to_result.get(node_id)
        if idx is not None:
            builder = self.builder
            ptr: ir.GEPInstr = builder.gep(self.out_args, [ir.Constant(self.llvm_idx_type, idx)])
            builder.store(value, ptr)

        # Keep the IR value in the ir_cache so later uses need not reload it.
        # This effectively forces the LLVM compiler to put it on the stack when registers run out.
        self.ir_cache[node_id] = value

    def store_result(self, value: ir.Value, idx: int) -> None:
        """
        Store the given IR value in the indexed result slot.
        """
        builder = self.builder
        ptr: ir.GEPInstr = builder.gep(self.out_args, [ir.Constant(self.llvm_idx_type, idx)])
        builder.store(value, ptr)