compiled-knowledge 4.0.0a20__cp313-cp313-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of compiled-knowledge might be problematic. Click here for more details.

Files changed (178)
  1. ck/__init__.py +0 -0
  2. ck/circuit/__init__.py +17 -0
  3. ck/circuit/_circuit_cy.c +37525 -0
  4. ck/circuit/_circuit_cy.cpython-313-darwin.so +0 -0
  5. ck/circuit/_circuit_cy.pxd +32 -0
  6. ck/circuit/_circuit_cy.pyx +768 -0
  7. ck/circuit/_circuit_py.py +836 -0
  8. ck/circuit/tmp_const.py +74 -0
  9. ck/circuit_compiler/__init__.py +2 -0
  10. ck/circuit_compiler/circuit_compiler.py +26 -0
  11. ck/circuit_compiler/cython_vm_compiler/__init__.py +1 -0
  12. ck/circuit_compiler/cython_vm_compiler/_compiler.c +19826 -0
  13. ck/circuit_compiler/cython_vm_compiler/_compiler.cpython-313-darwin.so +0 -0
  14. ck/circuit_compiler/cython_vm_compiler/_compiler.pyx +380 -0
  15. ck/circuit_compiler/cython_vm_compiler/cython_vm_compiler.py +121 -0
  16. ck/circuit_compiler/interpret_compiler.py +223 -0
  17. ck/circuit_compiler/llvm_compiler.py +388 -0
  18. ck/circuit_compiler/llvm_vm_compiler.py +546 -0
  19. ck/circuit_compiler/named_circuit_compilers.py +57 -0
  20. ck/circuit_compiler/support/__init__.py +0 -0
  21. ck/circuit_compiler/support/circuit_analyser/__init__.py +13 -0
  22. ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.c +10620 -0
  23. ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.cpython-313-darwin.so +0 -0
  24. ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.pyx +98 -0
  25. ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_py.py +93 -0
  26. ck/circuit_compiler/support/input_vars.py +148 -0
  27. ck/circuit_compiler/support/llvm_ir_function.py +234 -0
  28. ck/example/__init__.py +53 -0
  29. ck/example/alarm.py +366 -0
  30. ck/example/asia.py +28 -0
  31. ck/example/binary_clique.py +32 -0
  32. ck/example/bow_tie.py +33 -0
  33. ck/example/cancer.py +37 -0
  34. ck/example/chain.py +38 -0
  35. ck/example/child.py +199 -0
  36. ck/example/clique.py +33 -0
  37. ck/example/cnf_pgm.py +39 -0
  38. ck/example/diamond_square.py +68 -0
  39. ck/example/earthquake.py +36 -0
  40. ck/example/empty.py +10 -0
  41. ck/example/hailfinder.py +539 -0
  42. ck/example/hepar2.py +628 -0
  43. ck/example/insurance.py +504 -0
  44. ck/example/loop.py +40 -0
  45. ck/example/mildew.py +38161 -0
  46. ck/example/munin.py +22982 -0
  47. ck/example/pathfinder.py +53747 -0
  48. ck/example/rain.py +39 -0
  49. ck/example/rectangle.py +161 -0
  50. ck/example/run.py +30 -0
  51. ck/example/sachs.py +129 -0
  52. ck/example/sprinkler.py +30 -0
  53. ck/example/star.py +44 -0
  54. ck/example/stress.py +64 -0
  55. ck/example/student.py +43 -0
  56. ck/example/survey.py +46 -0
  57. ck/example/triangle_square.py +54 -0
  58. ck/example/truss.py +49 -0
  59. ck/in_out/__init__.py +3 -0
  60. ck/in_out/parse_ace_lmap.py +216 -0
  61. ck/in_out/parse_ace_nnf.py +322 -0
  62. ck/in_out/parse_net.py +480 -0
  63. ck/in_out/parser_utils.py +185 -0
  64. ck/in_out/pgm_pickle.py +42 -0
  65. ck/in_out/pgm_python.py +268 -0
  66. ck/in_out/render_bugs.py +111 -0
  67. ck/in_out/render_net.py +177 -0
  68. ck/in_out/render_pomegranate.py +184 -0
  69. ck/pgm.py +3475 -0
  70. ck/pgm_circuit/__init__.py +1 -0
  71. ck/pgm_circuit/marginals_program.py +352 -0
  72. ck/pgm_circuit/mpe_program.py +237 -0
  73. ck/pgm_circuit/pgm_circuit.py +79 -0
  74. ck/pgm_circuit/program_with_slotmap.py +236 -0
  75. ck/pgm_circuit/slot_map.py +35 -0
  76. ck/pgm_circuit/support/__init__.py +0 -0
  77. ck/pgm_circuit/support/compile_circuit.py +83 -0
  78. ck/pgm_circuit/target_marginals_program.py +103 -0
  79. ck/pgm_circuit/wmc_program.py +323 -0
  80. ck/pgm_compiler/__init__.py +2 -0
  81. ck/pgm_compiler/ace/__init__.py +1 -0
  82. ck/pgm_compiler/ace/ace.py +299 -0
  83. ck/pgm_compiler/factor_elimination.py +395 -0
  84. ck/pgm_compiler/named_pgm_compilers.py +63 -0
  85. ck/pgm_compiler/pgm_compiler.py +19 -0
  86. ck/pgm_compiler/recursive_conditioning.py +231 -0
  87. ck/pgm_compiler/support/__init__.py +0 -0
  88. ck/pgm_compiler/support/circuit_table/__init__.py +17 -0
  89. ck/pgm_compiler/support/circuit_table/_circuit_table_cy.c +16398 -0
  90. ck/pgm_compiler/support/circuit_table/_circuit_table_cy.cpython-313-darwin.so +0 -0
  91. ck/pgm_compiler/support/circuit_table/_circuit_table_cy.pyx +332 -0
  92. ck/pgm_compiler/support/circuit_table/_circuit_table_py.py +304 -0
  93. ck/pgm_compiler/support/clusters.py +568 -0
  94. ck/pgm_compiler/support/factor_tables.py +406 -0
  95. ck/pgm_compiler/support/join_tree.py +332 -0
  96. ck/pgm_compiler/support/named_compiler_maker.py +43 -0
  97. ck/pgm_compiler/variable_elimination.py +91 -0
  98. ck/probability/__init__.py +0 -0
  99. ck/probability/empirical_probability_space.py +50 -0
  100. ck/probability/pgm_probability_space.py +32 -0
  101. ck/probability/probability_space.py +622 -0
  102. ck/program/__init__.py +3 -0
  103. ck/program/program.py +137 -0
  104. ck/program/program_buffer.py +180 -0
  105. ck/program/raw_program.py +67 -0
  106. ck/sampling/__init__.py +0 -0
  107. ck/sampling/forward_sampler.py +211 -0
  108. ck/sampling/marginals_direct_sampler.py +113 -0
  109. ck/sampling/sampler.py +62 -0
  110. ck/sampling/sampler_support.py +232 -0
  111. ck/sampling/uniform_sampler.py +72 -0
  112. ck/sampling/wmc_direct_sampler.py +171 -0
  113. ck/sampling/wmc_gibbs_sampler.py +153 -0
  114. ck/sampling/wmc_metropolis_sampler.py +165 -0
  115. ck/sampling/wmc_rejection_sampler.py +115 -0
  116. ck/utils/__init__.py +0 -0
  117. ck/utils/iter_extras.py +163 -0
  118. ck/utils/local_config.py +270 -0
  119. ck/utils/map_list.py +128 -0
  120. ck/utils/map_set.py +128 -0
  121. ck/utils/np_extras.py +51 -0
  122. ck/utils/random_extras.py +64 -0
  123. ck/utils/tmp_dir.py +94 -0
  124. ck_demos/__init__.py +0 -0
  125. ck_demos/ace/__init__.py +0 -0
  126. ck_demos/ace/copy_ace_to_ck.py +15 -0
  127. ck_demos/ace/demo_ace.py +49 -0
  128. ck_demos/all_demos.py +88 -0
  129. ck_demos/circuit/__init__.py +0 -0
  130. ck_demos/circuit/demo_circuit_dump.py +22 -0
  131. ck_demos/circuit/demo_derivatives.py +43 -0
  132. ck_demos/circuit_compiler/__init__.py +0 -0
  133. ck_demos/circuit_compiler/compare_circuit_compilers.py +32 -0
  134. ck_demos/circuit_compiler/show_llvm_program.py +26 -0
  135. ck_demos/pgm/__init__.py +0 -0
  136. ck_demos/pgm/demo_pgm_dump.py +18 -0
  137. ck_demos/pgm/demo_pgm_dump_stress.py +18 -0
  138. ck_demos/pgm/demo_pgm_string_rendering.py +15 -0
  139. ck_demos/pgm/show_examples.py +25 -0
  140. ck_demos/pgm_compiler/__init__.py +0 -0
  141. ck_demos/pgm_compiler/compare_pgm_compilers.py +63 -0
  142. ck_demos/pgm_compiler/demo_compiler_dump.py +60 -0
  143. ck_demos/pgm_compiler/demo_factor_elimination.py +47 -0
  144. ck_demos/pgm_compiler/demo_join_tree.py +25 -0
  145. ck_demos/pgm_compiler/demo_marginals_program.py +53 -0
  146. ck_demos/pgm_compiler/demo_mpe_program.py +55 -0
  147. ck_demos/pgm_compiler/demo_pgm_compiler.py +38 -0
  148. ck_demos/pgm_compiler/demo_recursive_conditioning.py +33 -0
  149. ck_demos/pgm_compiler/demo_variable_elimination.py +33 -0
  150. ck_demos/pgm_compiler/demo_wmc_program.py +29 -0
  151. ck_demos/pgm_compiler/time_fe_compiler.py +93 -0
  152. ck_demos/pgm_inference/__init__.py +0 -0
  153. ck_demos/pgm_inference/demo_inferencing_basic.py +188 -0
  154. ck_demos/pgm_inference/demo_inferencing_mpe_cancer.py +45 -0
  155. ck_demos/pgm_inference/demo_inferencing_wmc_and_mpe_sprinkler.py +154 -0
  156. ck_demos/pgm_inference/demo_inferencing_wmc_student.py +110 -0
  157. ck_demos/programs/__init__.py +0 -0
  158. ck_demos/programs/demo_program_buffer.py +24 -0
  159. ck_demos/programs/demo_program_multi.py +24 -0
  160. ck_demos/programs/demo_program_none.py +19 -0
  161. ck_demos/programs/demo_program_single.py +23 -0
  162. ck_demos/programs/demo_raw_program_interpreted.py +21 -0
  163. ck_demos/programs/demo_raw_program_llvm.py +21 -0
  164. ck_demos/sampling/__init__.py +0 -0
  165. ck_demos/sampling/check_sampler.py +71 -0
  166. ck_demos/sampling/demo_marginal_direct_sampler.py +40 -0
  167. ck_demos/sampling/demo_uniform_sampler.py +38 -0
  168. ck_demos/sampling/demo_wmc_direct_sampler.py +40 -0
  169. ck_demos/utils/__init__.py +0 -0
  170. ck_demos/utils/compare.py +120 -0
  171. ck_demos/utils/convert_network.py +45 -0
  172. ck_demos/utils/sample_model.py +216 -0
  173. ck_demos/utils/stop_watch.py +384 -0
  174. compiled_knowledge-4.0.0a20.dist-info/METADATA +50 -0
  175. compiled_knowledge-4.0.0a20.dist-info/RECORD +178 -0
  176. compiled_knowledge-4.0.0a20.dist-info/WHEEL +6 -0
  177. compiled_knowledge-4.0.0a20.dist-info/licenses/LICENSE.txt +21 -0
  178. compiled_knowledge-4.0.0a20.dist-info/top_level.txt +2 -0
@@ -0,0 +1,223 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Sequence, Optional, Dict, List, Tuple, Callable
5
+
6
+ import numpy as np
7
+ import ctypes as ct
8
+
9
+ from ..circuit import Circuit, CircuitNode, VarNode, OpNode, ADD, MUL
10
+ from ..program.raw_program import RawProgram, RawProgramFunction
11
+ from ..utils.iter_extras import multiply, first
12
+ from ..utils.np_extras import NDArrayNumeric, DTypeNumeric
13
+ from .support.circuit_analyser import CircuitAnalysis, analyze_circuit
14
+ from .support.input_vars import InputVars, InferVars, infer_input_vars
15
+
16
+ # index to a value array
17
+ _VARS = 0
18
+ _CONSTS = 1
19
+ _TMPS = 2
20
+ _RESULT = 3
21
+
22
+
23
def compile_circuit(
        *result: CircuitNode,
        input_vars: InputVars = InferVars.ALL,
        circuit: Optional[Circuit] = None,
        dtype: DTypeNumeric = np.double,
) -> InterpreterRawProgram:
    """
    Build a RawProgram that evaluates the circuit by interpreting instructions in Python.

    Args:
        *result: the circuit nodes whose values the program will output, in order.
        input_vars: how the program's input variables are determined.
        circuit: the Circuit, if it should be given explicitly.
        dtype: the numpy dtype used by the raw program.

    Returns:
        an InterpreterRawProgram.

    Raises:
        ValueError: if the circuit is unknown, but it is needed.
        ValueError: if not all nodes are from the same circuit.
    """
    in_vars: Sequence[VarNode] = infer_input_vars(circuit, result, input_vars)
    analysis: CircuitAnalysis = analyze_circuit(in_vars, result)
    instructions, np_consts = _make_instructions(analysis, dtype)
    return InterpreterRawProgram(
        in_vars=in_vars,
        result=result,
        op_nodes=analysis.op_nodes,
        dtype=dtype,
        instructions=instructions,
        np_consts=np_consts,
    )
59
+
60
+
61
class InterpreterRawProgram(RawProgram):
    """
    A RawProgram whose function is a Python closure that executes a list of
    `_Instruction`s against the caller's arrays and a constants array.

    Pickling stores the instructions and constants, not the closure. The
    closure is rebuilt from them on unpickling.
    """

    def __init__(
            self,
            in_vars: Sequence[VarNode],
            result: Sequence[CircuitNode],
            op_nodes: Sequence[OpNode],
            dtype: DTypeNumeric,
            instructions: List[_Instruction],
            np_consts: NDArrayNumeric,
    ):
        self.instructions = instructions
        self.np_consts = np_consts
        super().__init__(
            function=_make_function(instructions=instructions, np_consts=np_consts),
            dtype=dtype,
            number_of_vars=len(in_vars),
            number_of_tmps=len(op_nodes),
            number_of_results=len(result),
            var_indices=tuple(v.idx for v in in_vars),
        )

    def __getstate__(self):
        """
        Pickle support: capture the program's metadata, instructions and constants.
        """
        keys = (
            'dtype',
            'number_of_vars',
            'number_of_tmps',
            'number_of_results',
            'var_indices',
            'instructions',
            'np_consts',
        )
        return {key: getattr(self, key) for key in keys}

    def __setstate__(self, state):
        """
        Pickle support: restore the captured fields and rebuild the program function.
        """
        keys = (
            'dtype',
            'number_of_vars',
            'number_of_tmps',
            'number_of_results',
            'var_indices',
            'instructions',
            'np_consts',
        )
        for key in keys:
            setattr(self, key, state[key])
        self.function = _make_function(
            instructions=self.instructions,
            np_consts=self.np_consts,
        )
120
+
121
+
122
def _make_instructions(
        analysis: CircuitAnalysis,
        dtype: DTypeNumeric,
) -> Tuple[Sequence[_Instruction], NDArrayNumeric]:
    """
    Translate a circuit analysis into interpreter instructions.

    Every value the interpreter touches is addressed by an `_ElementID`:
    unbound input variables live in the `_VARS` array, constants in the
    `_CONSTS` array (returned here), op nodes that are not results in
    `_TMPS`, and result op nodes directly in `_RESULT`.

    Args:
        analysis: the analysis of the circuit being compiled.
        dtype: the numpy dtype of the returned constants array.

    Returns:
        (instructions, np_consts). There is one instruction per op node, in
        `analysis.op_nodes` order, followed by one copy instruction per result
        node that is not an op node.

    Raises:
        RuntimeError: if an op node's symbol is neither MUL nor ADD.
    """
    # Store const values in a numpy array
    node_to_const_idx: Dict[int, int] = {
        id(node): i
        for i, node in enumerate(analysis.const_nodes)
    }
    np_consts: NDArrayNumeric = np.zeros(len(node_to_const_idx), dtype=dtype)
    for i, node in enumerate(analysis.const_nodes):
        np_consts[i] = node.value

    # Where to get input values for each possible node (keyed by node identity).
    node_to_element: Dict[int, _ElementID] = {}
    # const nodes
    for node_id, const_idx in node_to_const_idx.items():
        node_to_element[node_id] = _ElementID(_CONSTS, const_idx)
    # var nodes
    for i, var_node in enumerate(analysis.var_nodes):
        if var_node.is_const():
            # A var fixed to a constant reads from that constant's slot.
            # NOTE(review): assumes var_node.const is among analysis.const_nodes — confirm.
            node_to_element[id(var_node)] = node_to_element[id(var_node.const)]
        else:
            # NOTE(review): slot i is the position in analysis.var_nodes; presumably this
            # matches the order of the program's input vars — confirm.
            node_to_element[id(var_node)] = _ElementID(_VARS, i)
    # op nodes
    for node_id, tmp_index in analysis.op_to_tmp.items():
        node_to_element[node_id] = _ElementID(_TMPS, tmp_index)
    for node_id, result_index in analysis.op_to_result.items():
        node_to_element[node_id] = _ElementID(_RESULT, result_index)

    # Build instructions
    instructions: List[_Instruction] = []

    op_node: OpNode
    for op_node in analysis.op_nodes:
        dest: _ElementID = node_to_element[id(op_node)]
        args: List[_ElementID] = [
            node_to_element[id(arg)]
            for arg in op_node.args
        ]
        if op_node.symbol == MUL:
            operation = multiply
        elif op_node.symbol == ADD:
            operation = sum
        else:
            # An explicit raise (not `assert`) so the check survives `python -O`;
            # otherwise `operation` would silently keep the previous node's value.
            raise RuntimeError(f'unknown op node: {op_node.symbol!r}')

        instructions.append(_Instruction(operation, args, dest))

    # Add any copy operations, i.e., result nodes that are not op nodes
    for i, node in enumerate(analysis.result_nodes):
        if not isinstance(node, OpNode):
            source: _ElementID = node_to_element[id(node)]
            instructions.append(_Instruction(first, [source], _ElementID(_RESULT, i)))

    return instructions, np_consts
179
+
180
+
181
def _make_function(
        instructions: List[_Instruction],
        np_consts: NDArrayNumeric,
) -> RawProgramFunction:
    """
    Build a RawProgramFunction that runs the given instructions in order.

    The returned closure keeps references to `instructions` and `np_consts`.
    On each call it places the three caller-supplied arrays and `np_consts`
    into a 4-slot table indexed by _VARS/_CONSTS/_TMPS/_RESULT. Then, for each
    instruction, it applies the instruction's operation to an iterator over
    its argument values and writes the outcome to the destination slot.
    """

    def raw_program_function(vars_in: ct.POINTER, tmps: ct.POINTER, result_out: ct.POINTER) -> None:
        table: List[ct.POINTER] = [None] * 4
        table[_VARS] = vars_in
        table[_TMPS] = tmps
        table[_RESULT] = result_out
        table[_CONSTS] = np_consts

        for step in instructions:
            operands = (table[src.array][src.index] for src in step.args)
            outcome = step.operation(operands)
            target = step.dest
            table[target.array][target.index] = outcome

    return raw_program_function
211
+
212
+
213
+ @dataclass
214
+ class _ElementID:
215
+ array: int # VARS, TMPS, CONSTS, RESULT
216
+ index: int # index into the array
217
+
218
+
219
+ @dataclass
220
+ class _Instruction:
221
+ operation: Callable
222
+ args: Sequence[_ElementID]
223
+ dest: _ElementID
@@ -0,0 +1,388 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from enum import Enum
5
+ from typing import Sequence, Optional, Tuple, Dict, Protocol
6
+
7
+ import llvmlite.binding as llvm
8
+ import llvmlite.ir as ir
9
+
10
+ from .support.circuit_analyser import CircuitAnalysis, analyze_circuit
11
+ from .support.input_vars import InputVars, InferVars, infer_input_vars
12
+ from .support.llvm_ir_function import IRFunction, DataType, TypeInfo, compile_llvm_program, LLVMRawProgram, IrBOp
13
+ from ..circuit import Circuit, VarNode, CircuitNode, OpNode, MUL, ADD, ConstNode
14
+ from ..program.raw_program import RawProgramFunction
15
+
16
+
17
class Flavour(Enum):
    """
    The layout strategy used when generating the LLVM program.
    """
    STACK = 0  # No temporary working memory requested; intermediate values stay on the stack.
    TMPS = 1  # Op node calculations are stored in temporary working memory.
    FUNCS = 2  # Like TMPS, but each op node is computed by its own sub-function.
21
+
22
+
23
# Default arithmetic type information used by `compile_circuit` (DataType.FLOAT_64).
DEFAULT_TYPE_INFO: TypeInfo = DataType.FLOAT_64.value

# Default program layout used by `compile_circuit`.
DEFAULT_FLAVOUR: Flavour = Flavour.TMPS
25
+
26
+
27
def compile_circuit(
        *result: CircuitNode,
        input_vars: InputVars = InferVars.ALL,
        circuit: Optional[Circuit] = None,
        data_type: DataType | TypeInfo = DEFAULT_TYPE_INFO,
        flavour: Flavour = DEFAULT_FLAVOUR,
        keep_llvm_program: bool = True,
        opt: int = 2,
) -> LLVMRawProgram:
    """
    Compile the given circuit using LLVM.

    This creates an LLVM program where each circuit op node is converted to
    one or more LLVM binary op machine code instructions. For large circuits
    this results in a large LLVM program which can be slow to compile.

    With the TMPS and FUNCS flavours, the resulting RawProgram _does_ use client
    managed working memory. With the STACK flavour it requests none
    (number_of_tmps is 0).

    Conforms to the CircuitCompiler protocol.

    Args:
        *result: result nodes nominating the results of the returned program.
        input_vars: How to determine the input variables.
        circuit: optionally explicitly specify the Circuit.
        data_type: What data type to use for arithmetic calculations. Either a DataType member or TypeInfo.
        flavour: what flavour of LLVM program to construct.
        keep_llvm_program: if true, the LLVM program will be kept. This is required for pickling.
        opt: The optimization level to use by LLVM MC JIT.

    Returns:
        a raw program.

    Raises:
        ValueError: if the circuit is unknown, but it is needed.
        ValueError: if not all nodes are from the same circuit.
        ValueError: if the program data type could not be interpreted.
        ValueError: if the flavour is not recognised.
    """
    in_vars: Sequence[VarNode] = infer_input_vars(circuit, result, input_vars)
    var_indices: Sequence[int] = tuple(var.idx for var in in_vars)

    # Get the type info
    type_info: TypeInfo
    if isinstance(data_type, DataType):
        type_info = data_type.value
    elif isinstance(data_type, TypeInfo):
        type_info = data_type
    else:
        raise ValueError(f'could not interpret program data type: {data_type!r}')

    # Compile the circuit to an LLVM module representing a RawProgramFunction
    llvm_program: str
    number_of_tmps: int
    llvm_program, number_of_tmps = _make_llvm_program(in_vars, result, type_info, flavour)

    # Compile the LLVM program to a native executable
    engine: llvm.ExecutionEngine
    function: RawProgramFunction
    engine, function = compile_llvm_program(llvm_program, dtype=type_info.dtype, opt=opt)

    return LLVMRawProgram(
        function=function,
        dtype=type_info.dtype,
        number_of_vars=len(var_indices),
        number_of_tmps=number_of_tmps,
        number_of_results=len(result),
        var_indices=var_indices,
        llvm_program=llvm_program if keep_llvm_program else None,
        engine=engine,
        opt=opt,
    )
97
+
98
+
99
def _make_llvm_program(
        in_vars: Sequence[VarNode],
        result: Sequence[CircuitNode],
        type_info: TypeInfo,
        flavour: Flavour,
) -> Tuple[str, int]:
    """
    Generate the textual LLVM IR of a RawProgramFunction for the given results.

    Args:
        in_vars: the program's input variables.
        result: the circuit nodes whose values are written to the result array.
        type_info: type information for the arithmetic.
        flavour: which function builder strategy to use.

    Returns:
        (llvm_program, number_of_tmps)

    Raises:
        ValueError: if the flavour is not recognised.
        RuntimeError: if an op node has a symbol other than ADD or MUL.
    """
    llvm_function = IRFunction(type_info)
    builder = llvm_function.builder
    type_info = llvm_function.type_info
    fn_args = llvm_function.function.args

    analysis: CircuitAnalysis = analyze_circuit(in_vars, result)

    if flavour not in (Flavour.STACK, Flavour.TMPS, Flavour.FUNCS):
        raise ValueError(f'unknown LLVM program flavour: {flavour!r}')

    # Constructor arguments shared by every flavour of function builder.
    shared = dict(
        builder=builder,
        analysis=analysis,
        llvm_type=type_info.llvm_type,
        llvm_idx_type=ir.IntType(32),
        in_args=fn_args[0],
        out_args=fn_args[2],
    )
    function_builder: _FunctionBuilder
    if flavour == Flavour.STACK:
        function_builder = _FunctionBuilderStack(**shared, ir_cache={})
    else:
        builder_class = _FunctionBuilderTmps if flavour == Flavour.TMPS else _FunctionBuilderFuncs
        function_builder = builder_class(**shared, tmp_args=fn_args[1])

    # Emit a calculation for every op node, in analysis order.
    for op_node in analysis.op_nodes:
        symbol = op_node.symbol
        if symbol == ADD:
            op: IrBOp = type_info.add
        elif symbol == MUL:
            op: IrBOp = type_info.mul
        else:
            raise RuntimeError(f'unknown op node: {op_node.symbol!r}')
        function_builder.process_op_node(op_node, op)

    # Results that are not op nodes still need their value copied out.
    for idx, node in enumerate(result):
        if isinstance(node, OpNode):
            continue
        function_builder.store_result(function_builder.value(node), idx)

    builder.ret_void()

    return llvm_function.llvm_program(), function_builder.number_of_tmps()
171
+
172
+
173
+ class _FunctionBuilder(Protocol):
174
+ def process_op_node(self, op_node: OpNode, op: IrBOp) -> None:
175
+ ...
176
+
177
+ def value(self, node: CircuitNode) -> ir.Value:
178
+ ...
179
+
180
+ def store_result(self, value: ir.Value, idx: int) -> None:
181
+ ...
182
+
183
+ def number_of_tmps(self) -> int:
184
+ ...
185
+
186
+
187
@dataclass
class _FunctionBuilderTmps(_FunctionBuilder):
    """
    A function builder that puts op node calculations into the temporary working memory.

    Storage locations:
    - An op node that is not a result goes to `tmp_args[analysis.op_to_tmp[id(node)]]`.
    - An op node that is a result goes directly to `out_args[analysis.op_to_result[id(node)]]`.
    - Const nodes, and var nodes bound to a constant, become IR constants.
    """
    builder: ir.IRBuilder
    analysis: CircuitAnalysis
    llvm_type: ir.Type
    llvm_idx_type: ir.Type
    in_args: ir.Value  # the function's input variables array argument
    tmp_args: ir.Value  # the function's temporary working memory argument
    out_args: ir.Value  # the function's results array argument

    def number_of_tmps(self) -> int:
        """
        One temporary slot for each op node that is not a result.
        """
        return len(self.analysis.op_to_tmp)

    def process_op_node(self, op_node: OpNode, op: IrBOp) -> None:
        """
        Fold `op` left-to-right over the op node's argument values, then store the outcome.
        """
        value: ir.Value = self.value(op_node.args[0])
        for arg in op_node.args[1:]:
            next_value: ir.Value = self.value(arg)
            value = op(self.builder, value, next_value)
        self.store_calculation(value, op_node)

    def _ptr(self, array: ir.Value, idx: int) -> ir.Value:
        """
        Emit a GEP giving a pointer to element `idx` of the given array argument.
        """
        return self.builder.gep(array, [ir.Constant(self.llvm_idx_type, idx)])

    def value(self, node: CircuitNode) -> ir.Value:
        """
        Return an IR value for the given circuit node.

        Raises:
            RuntimeError: if the node is an op node with no tmp or result slot.
        """
        node_id: int = id(node)

        # If it is a constant...
        if isinstance(node, ConstNode):
            return ir.Constant(self.llvm_type, node.value)

        builder = self.builder

        # If it is a var...
        if isinstance(node, VarNode):
            if node.is_const():
                return ir.Constant(self.llvm_type, node.const.value)
            # NOTE(review): indexes the input array by the var's circuit index (node.idx);
            # presumably the input vars cover the circuit's vars in index order — confirm.
            return builder.load(self._ptr(self.in_args, node.idx))

        analysis = self.analysis

        # If it is an op _not_ in the results...
        idx: Optional[int] = analysis.op_to_tmp.get(node_id)
        if idx is not None:
            return builder.load(self._ptr(self.tmp_args, idx))

        # If it is an op in the results...
        idx = analysis.op_to_result.get(node_id)
        if idx is not None:
            return builder.load(self._ptr(self.out_args, idx))

        # Explicit raise (not `assert`) so this check survives `python -O`.
        raise RuntimeError(f'no IR value available for circuit node of type {type(node).__name__}')

    def store_calculation(self, value: ir.Value, op_node: OpNode) -> None:
        """
        Store the given IR value as the result for the given op node.

        Raises:
            RuntimeError: if the op node has neither a tmp nor a result slot.
        """
        analysis = self.analysis
        node_id: int = id(op_node)

        # If it is an op _not_ in the results...
        idx: Optional[int] = analysis.op_to_tmp.get(node_id)
        if idx is not None:
            self.builder.store(value, self._ptr(self.tmp_args, idx))
            return

        # If it is an op in the results...
        idx = analysis.op_to_result.get(node_id)
        if idx is not None:
            self.builder.store(value, self._ptr(self.out_args, idx))
            return

        raise RuntimeError('op node has neither a temporary nor a result slot')

    def store_result(self, value: ir.Value, idx: int) -> None:
        """
        Store the given IR value in the indexed result slot.
        """
        self.builder.store(value, self._ptr(self.out_args, idx))
274
+
275
+
276
class _FunctionBuilderFuncs(_FunctionBuilderTmps):
    """
    A function builder that puts op node calculations into the temporary working memory,
    but each op node becomes its own sub-function.

    Each sub-function has the same type as the main function. The main function
    calls it with its own (in, tmp, out) arguments.
    """

    def process_op_node(self, op_node: OpNode, op: IrBOp) -> None:
        builder: ir.IRBuilder = self.builder
        save_block = builder.block

        sub_function_name: str = f'sub_{id(op_node)}'
        function_type = builder.function.type.pointee
        sub_function = ir.Function(builder.module, function_type, name=sub_function_name)
        sub_function.attributes.add('noinline')  # alwaysinline, noinline
        bb_entry = sub_function.append_basic_block(sub_function_name + '_entry')
        builder.position_at_end(bb_entry)

        # LLVM function arguments are only valid inside the function that owns them.
        # The sub-function body must therefore address memory through the
        # sub-function's own arguments, not the main function's `self.*_args`.
        sub_args = sub_function.args
        sub_builder = _FunctionBuilderTmps(
            builder=builder,
            analysis=self.analysis,
            llvm_type=self.llvm_type,
            llvm_idx_type=self.llvm_idx_type,
            in_args=sub_args[0],
            tmp_args=sub_args[1],
            out_args=sub_args[2],
        )
        _FunctionBuilderTmps.process_op_node(sub_builder, op_node, op)

        builder.ret_void()

        # Restore builder to main function, and call the sub-function from there.
        builder.position_at_end(save_block)
        builder.call(sub_function, [self.in_args, self.tmp_args, self.out_args])
304
+
305
+
306
@dataclass
class _FunctionBuilderStack(_FunctionBuilder):
    """
    A function builder that puts op node calculations onto the stack.

    Values of non-result op nodes, constants and input vars are remembered in
    `ir_cache` (keyed by node identity) and reused directly as IR values.
    Result op nodes are stored to, and re-loaded from, the results array.
    """
    builder: ir.IRBuilder
    analysis: CircuitAnalysis
    llvm_type: ir.Type
    llvm_idx_type: ir.Type
    in_args: ir.Value  # the function's input variables array argument
    out_args: ir.Value  # the function's results array argument
    ir_cache: Dict[int, ir.Value]  # id(circuit node) -> IR value already computed

    def number_of_tmps(self) -> int:
        """
        No temporary working memory is requested.
        """
        return 0

    def process_op_node(self, op_node: OpNode, op: IrBOp) -> None:
        """
        Fold `op` left-to-right over the op node's argument values, then store the outcome.
        """
        value: ir.Value = self.value(op_node.args[0])
        for arg in op_node.args[1:]:
            next_value: ir.Value = self.value(arg)
            value = op(self.builder, value, next_value)
        self.store_calculation(value, op_node)

    def value(self, node: CircuitNode) -> ir.Value:
        """
        Return an IR value for the given circuit node.

        Raises:
            RuntimeError: if the node is an op node that has not yet been
                calculated and is not a result.
        """
        node_id: int = id(node)

        # First check if it is in the IR cache
        cached: Optional[ir.Value] = self.ir_cache.get(node_id)
        if cached is not None:
            return cached

        # If it is a constant...
        if isinstance(node, ConstNode):
            value = ir.Constant(self.llvm_type, node.value)
            self.ir_cache[node_id] = value
            return value

        builder = self.builder

        # If it is a var...
        if isinstance(node, VarNode):
            if node.is_const():
                value = ir.Constant(self.llvm_type, node.const.value)
            else:
                # NOTE(review): indexes the input array by the var's circuit index (node.idx);
                # presumably the input vars cover the circuit's vars in index order — confirm.
                value = builder.load(builder.gep(self.in_args, [ir.Constant(self.llvm_idx_type, node.idx)]))
            self.ir_cache[node_id] = value
            return value

        # If it is an op in the results...
        idx: Optional[int] = self.analysis.op_to_result.get(node_id)
        if idx is not None:
            return builder.load(builder.gep(self.out_args, [ir.Constant(self.llvm_idx_type, idx)]))

        # Explicit raise (not `assert`) so this check survives `python -O`.
        raise RuntimeError(f'no IR value available for circuit node of type {type(node).__name__}')

    def store_calculation(self, value: ir.Value, op_node: OpNode) -> None:
        """
        Store the given IR value as the result for the given op node.
        """
        node_id: int = id(op_node)

        # If it is an op in the results...
        idx: Optional[int] = self.analysis.op_to_result.get(node_id)
        if idx is not None:
            builder = self.builder
            ptr: ir.GEPInstr = builder.gep(self.out_args, [ir.Constant(self.llvm_idx_type, idx)])
            builder.store(value, ptr)
            return

        # Just put it in the ir_cache.
        # This effectively forces the LLVM compiler to put it on the stack when registers run out.
        self.ir_cache[node_id] = value

    def store_result(self, value: ir.Value, idx: int) -> None:
        """
        Store the given IR value in the indexed result slot.
        """
        builder = self.builder
        ptr: ir.GEPInstr = builder.gep(self.out_args, [ir.Constant(self.llvm_idx_type, idx)])
        builder.store(value, ptr)