compiled-knowledge 4.0.0a20__cp312-cp312-musllinux_1_2_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of compiled-knowledge might be problematic; see the registry's release page for more details.

Files changed (178)
  1. ck/__init__.py +0 -0
  2. ck/circuit/__init__.py +17 -0
  3. ck/circuit/_circuit_cy.c +37520 -0
  4. ck/circuit/_circuit_cy.cpython-312-x86_64-linux-musl.so +0 -0
  5. ck/circuit/_circuit_cy.pxd +32 -0
  6. ck/circuit/_circuit_cy.pyx +768 -0
  7. ck/circuit/_circuit_py.py +836 -0
  8. ck/circuit/tmp_const.py +74 -0
  9. ck/circuit_compiler/__init__.py +2 -0
  10. ck/circuit_compiler/circuit_compiler.py +26 -0
  11. ck/circuit_compiler/cython_vm_compiler/__init__.py +1 -0
  12. ck/circuit_compiler/cython_vm_compiler/_compiler.c +19821 -0
  13. ck/circuit_compiler/cython_vm_compiler/_compiler.cpython-312-x86_64-linux-musl.so +0 -0
  14. ck/circuit_compiler/cython_vm_compiler/_compiler.pyx +380 -0
  15. ck/circuit_compiler/cython_vm_compiler/cython_vm_compiler.py +121 -0
  16. ck/circuit_compiler/interpret_compiler.py +223 -0
  17. ck/circuit_compiler/llvm_compiler.py +388 -0
  18. ck/circuit_compiler/llvm_vm_compiler.py +546 -0
  19. ck/circuit_compiler/named_circuit_compilers.py +57 -0
  20. ck/circuit_compiler/support/__init__.py +0 -0
  21. ck/circuit_compiler/support/circuit_analyser/__init__.py +13 -0
  22. ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.c +10615 -0
  23. ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.cpython-312-x86_64-linux-musl.so +0 -0
  24. ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.pyx +98 -0
  25. ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_py.py +93 -0
  26. ck/circuit_compiler/support/input_vars.py +148 -0
  27. ck/circuit_compiler/support/llvm_ir_function.py +234 -0
  28. ck/example/__init__.py +53 -0
  29. ck/example/alarm.py +366 -0
  30. ck/example/asia.py +28 -0
  31. ck/example/binary_clique.py +32 -0
  32. ck/example/bow_tie.py +33 -0
  33. ck/example/cancer.py +37 -0
  34. ck/example/chain.py +38 -0
  35. ck/example/child.py +199 -0
  36. ck/example/clique.py +33 -0
  37. ck/example/cnf_pgm.py +39 -0
  38. ck/example/diamond_square.py +68 -0
  39. ck/example/earthquake.py +36 -0
  40. ck/example/empty.py +10 -0
  41. ck/example/hailfinder.py +539 -0
  42. ck/example/hepar2.py +628 -0
  43. ck/example/insurance.py +504 -0
  44. ck/example/loop.py +40 -0
  45. ck/example/mildew.py +38161 -0
  46. ck/example/munin.py +22982 -0
  47. ck/example/pathfinder.py +53747 -0
  48. ck/example/rain.py +39 -0
  49. ck/example/rectangle.py +161 -0
  50. ck/example/run.py +30 -0
  51. ck/example/sachs.py +129 -0
  52. ck/example/sprinkler.py +30 -0
  53. ck/example/star.py +44 -0
  54. ck/example/stress.py +64 -0
  55. ck/example/student.py +43 -0
  56. ck/example/survey.py +46 -0
  57. ck/example/triangle_square.py +54 -0
  58. ck/example/truss.py +49 -0
  59. ck/in_out/__init__.py +3 -0
  60. ck/in_out/parse_ace_lmap.py +216 -0
  61. ck/in_out/parse_ace_nnf.py +322 -0
  62. ck/in_out/parse_net.py +480 -0
  63. ck/in_out/parser_utils.py +185 -0
  64. ck/in_out/pgm_pickle.py +42 -0
  65. ck/in_out/pgm_python.py +268 -0
  66. ck/in_out/render_bugs.py +111 -0
  67. ck/in_out/render_net.py +177 -0
  68. ck/in_out/render_pomegranate.py +184 -0
  69. ck/pgm.py +3475 -0
  70. ck/pgm_circuit/__init__.py +1 -0
  71. ck/pgm_circuit/marginals_program.py +352 -0
  72. ck/pgm_circuit/mpe_program.py +237 -0
  73. ck/pgm_circuit/pgm_circuit.py +79 -0
  74. ck/pgm_circuit/program_with_slotmap.py +236 -0
  75. ck/pgm_circuit/slot_map.py +35 -0
  76. ck/pgm_circuit/support/__init__.py +0 -0
  77. ck/pgm_circuit/support/compile_circuit.py +83 -0
  78. ck/pgm_circuit/target_marginals_program.py +103 -0
  79. ck/pgm_circuit/wmc_program.py +323 -0
  80. ck/pgm_compiler/__init__.py +2 -0
  81. ck/pgm_compiler/ace/__init__.py +1 -0
  82. ck/pgm_compiler/ace/ace.py +299 -0
  83. ck/pgm_compiler/factor_elimination.py +395 -0
  84. ck/pgm_compiler/named_pgm_compilers.py +63 -0
  85. ck/pgm_compiler/pgm_compiler.py +19 -0
  86. ck/pgm_compiler/recursive_conditioning.py +231 -0
  87. ck/pgm_compiler/support/__init__.py +0 -0
  88. ck/pgm_compiler/support/circuit_table/__init__.py +17 -0
  89. ck/pgm_compiler/support/circuit_table/_circuit_table_cy.c +16393 -0
  90. ck/pgm_compiler/support/circuit_table/_circuit_table_cy.cpython-312-x86_64-linux-musl.so +0 -0
  91. ck/pgm_compiler/support/circuit_table/_circuit_table_cy.pyx +332 -0
  92. ck/pgm_compiler/support/circuit_table/_circuit_table_py.py +304 -0
  93. ck/pgm_compiler/support/clusters.py +568 -0
  94. ck/pgm_compiler/support/factor_tables.py +406 -0
  95. ck/pgm_compiler/support/join_tree.py +332 -0
  96. ck/pgm_compiler/support/named_compiler_maker.py +43 -0
  97. ck/pgm_compiler/variable_elimination.py +91 -0
  98. ck/probability/__init__.py +0 -0
  99. ck/probability/empirical_probability_space.py +50 -0
  100. ck/probability/pgm_probability_space.py +32 -0
  101. ck/probability/probability_space.py +622 -0
  102. ck/program/__init__.py +3 -0
  103. ck/program/program.py +137 -0
  104. ck/program/program_buffer.py +180 -0
  105. ck/program/raw_program.py +67 -0
  106. ck/sampling/__init__.py +0 -0
  107. ck/sampling/forward_sampler.py +211 -0
  108. ck/sampling/marginals_direct_sampler.py +113 -0
  109. ck/sampling/sampler.py +62 -0
  110. ck/sampling/sampler_support.py +232 -0
  111. ck/sampling/uniform_sampler.py +72 -0
  112. ck/sampling/wmc_direct_sampler.py +171 -0
  113. ck/sampling/wmc_gibbs_sampler.py +153 -0
  114. ck/sampling/wmc_metropolis_sampler.py +165 -0
  115. ck/sampling/wmc_rejection_sampler.py +115 -0
  116. ck/utils/__init__.py +0 -0
  117. ck/utils/iter_extras.py +163 -0
  118. ck/utils/local_config.py +270 -0
  119. ck/utils/map_list.py +128 -0
  120. ck/utils/map_set.py +128 -0
  121. ck/utils/np_extras.py +51 -0
  122. ck/utils/random_extras.py +64 -0
  123. ck/utils/tmp_dir.py +94 -0
  124. ck_demos/__init__.py +0 -0
  125. ck_demos/ace/__init__.py +0 -0
  126. ck_demos/ace/copy_ace_to_ck.py +15 -0
  127. ck_demos/ace/demo_ace.py +49 -0
  128. ck_demos/all_demos.py +88 -0
  129. ck_demos/circuit/__init__.py +0 -0
  130. ck_demos/circuit/demo_circuit_dump.py +22 -0
  131. ck_demos/circuit/demo_derivatives.py +43 -0
  132. ck_demos/circuit_compiler/__init__.py +0 -0
  133. ck_demos/circuit_compiler/compare_circuit_compilers.py +32 -0
  134. ck_demos/circuit_compiler/show_llvm_program.py +26 -0
  135. ck_demos/pgm/__init__.py +0 -0
  136. ck_demos/pgm/demo_pgm_dump.py +18 -0
  137. ck_demos/pgm/demo_pgm_dump_stress.py +18 -0
  138. ck_demos/pgm/demo_pgm_string_rendering.py +15 -0
  139. ck_demos/pgm/show_examples.py +25 -0
  140. ck_demos/pgm_compiler/__init__.py +0 -0
  141. ck_demos/pgm_compiler/compare_pgm_compilers.py +63 -0
  142. ck_demos/pgm_compiler/demo_compiler_dump.py +60 -0
  143. ck_demos/pgm_compiler/demo_factor_elimination.py +47 -0
  144. ck_demos/pgm_compiler/demo_join_tree.py +25 -0
  145. ck_demos/pgm_compiler/demo_marginals_program.py +53 -0
  146. ck_demos/pgm_compiler/demo_mpe_program.py +55 -0
  147. ck_demos/pgm_compiler/demo_pgm_compiler.py +38 -0
  148. ck_demos/pgm_compiler/demo_recursive_conditioning.py +33 -0
  149. ck_demos/pgm_compiler/demo_variable_elimination.py +33 -0
  150. ck_demos/pgm_compiler/demo_wmc_program.py +29 -0
  151. ck_demos/pgm_compiler/time_fe_compiler.py +93 -0
  152. ck_demos/pgm_inference/__init__.py +0 -0
  153. ck_demos/pgm_inference/demo_inferencing_basic.py +188 -0
  154. ck_demos/pgm_inference/demo_inferencing_mpe_cancer.py +45 -0
  155. ck_demos/pgm_inference/demo_inferencing_wmc_and_mpe_sprinkler.py +154 -0
  156. ck_demos/pgm_inference/demo_inferencing_wmc_student.py +110 -0
  157. ck_demos/programs/__init__.py +0 -0
  158. ck_demos/programs/demo_program_buffer.py +24 -0
  159. ck_demos/programs/demo_program_multi.py +24 -0
  160. ck_demos/programs/demo_program_none.py +19 -0
  161. ck_demos/programs/demo_program_single.py +23 -0
  162. ck_demos/programs/demo_raw_program_interpreted.py +21 -0
  163. ck_demos/programs/demo_raw_program_llvm.py +21 -0
  164. ck_demos/sampling/__init__.py +0 -0
  165. ck_demos/sampling/check_sampler.py +71 -0
  166. ck_demos/sampling/demo_marginal_direct_sampler.py +40 -0
  167. ck_demos/sampling/demo_uniform_sampler.py +38 -0
  168. ck_demos/sampling/demo_wmc_direct_sampler.py +40 -0
  169. ck_demos/utils/__init__.py +0 -0
  170. ck_demos/utils/compare.py +120 -0
  171. ck_demos/utils/convert_network.py +45 -0
  172. ck_demos/utils/sample_model.py +216 -0
  173. ck_demos/utils/stop_watch.py +384 -0
  174. compiled_knowledge-4.0.0a20.dist-info/METADATA +50 -0
  175. compiled_knowledge-4.0.0a20.dist-info/RECORD +178 -0
  176. compiled_knowledge-4.0.0a20.dist-info/WHEEL +5 -0
  177. compiled_knowledge-4.0.0a20.dist-info/licenses/LICENSE.txt +21 -0
  178. compiled_knowledge-4.0.0a20.dist-info/top_level.txt +2 -0
ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.pyx ADDED
@@ -0,0 +1,98 @@
1
+ from dataclasses import dataclass
2
+ from typing import List, Dict, Sequence, Set
3
+
4
+ from ck.circuit._circuit_cy cimport Circuit, OpNode, VarNode, CircuitNode, ConstNode
5
+ from cython.operator cimport postincrement
6
+
7
+
8
@dataclass
class CircuitAnalysis:
    """
    The analysis of a function defined by a circuit, with chosen input
    variables and chosen output result nodes.

    Attributes:
        var_nodes: the specified input var nodes.
        result_nodes: the specified result nodes.
        op_nodes: the in-use op nodes, in computation order.
        const_nodes: the in-use const nodes, in arbitrary order.
        op_to_result: for op nodes that are results, id(op_node) -> idx
            where op_node = result[idx].
        op_to_tmp: for op nodes needing tmp memory, id(op_node) -> idx
            where the op node uses tmp[idx].
    """
    var_nodes: Sequence[VarNode]
    result_nodes: Sequence[CircuitNode]
    op_nodes: Sequence[OpNode]
    const_nodes: Sequence[ConstNode]
    op_to_result: Dict[int, int]
    op_to_tmp: Dict[int, int]
21
+
22
+
23
def analyze_circuit(
        var_nodes: Sequence[VarNode],
        result_nodes: Sequence[CircuitNode],
) -> CircuitAnalysis:
    """
    Analyzes a circuit as a function from var_nodes to result_nodes,
    returning a CircuitAnalysis object.

    Args:
        var_nodes: The chosen input variable nodes of the circuit.
        result_nodes: The chosen output result nodes of the circuit.

    Returns:
        A CircuitAnalysis object recording: the op nodes reachable from the
        results; the const nodes used as op node args, as results, or by var
        nodes for which is_const() is true (deduplicated by identity, kept in
        first-seen order); which op nodes are results; and a tmp slot index
        for every other in-use op node.
    """
    # Materialise the results once, as the typed list the cdef helper needs.
    cdef list[CircuitNode] results_list = list(result_nodes)

    # What op nodes are in use
    cdef list[OpNode] op_nodes = _reachable_op_nodes(results_list)

    # What constant values are in use
    cdef set[int] seen_const_nodes = set()
    cdef list[ConstNode] const_nodes = []

    def _register_const(_node: ConstNode) -> None:
        # Append _node to const_nodes the first time it is seen (by identity).
        nonlocal seen_const_nodes
        nonlocal const_nodes
        _node_id: int = id(_node)
        if _node_id not in seen_const_nodes:
            const_nodes.append(_node)
            seen_const_nodes.add(_node_id)

    # Register all the used constants
    for op_node in op_nodes:
        for node in op_node.args:
            if isinstance(node, ConstNode):
                _register_const(node)
    for node in results_list:
        if isinstance(node, ConstNode):
            _register_const(node)
    for node in var_nodes:
        if node.is_const():
            _register_const(node.const)

    # What op nodes are in the result.
    # Dict op_to_result maps id(OpNode) to result index.
    # NOTE(review): this enumerates result_nodes (not results_list), so
    # result_nodes must be re-iterable; an op node repeated in the results
    # keeps only its last index.
    cdef dict[int, int] op_to_result = {
        id(node): i
        for i, node in enumerate(result_nodes)
        if isinstance(node, OpNode)
    }

    # Assign all other op nodes to a tmp slot.
    # Dict op_to_tmp maps id(OpNode) to tmp index.
    # postincrement yields the current value then increments, so tmp
    # indices are 0, 1, 2, ... in op_nodes order.
    cdef int tmp_idx = 0
    op_to_tmp: Dict[int, int] = {
        id(op_node): postincrement(tmp_idx)
        for op_node in op_nodes
        if id(op_node) not in op_to_result
    }

    return CircuitAnalysis(
        var_nodes=var_nodes,
        result_nodes=result_nodes,
        op_nodes=op_nodes,
        const_nodes=const_nodes,
        op_to_result=op_to_result,
        op_to_tmp=op_to_tmp,
    )
92
+
93
+
94
cdef list[OpNode] _reachable_op_nodes(list[CircuitNode] results):
    # Return the op nodes reachable from `results`, as computed by
    # Circuit.find_reachable_op_nodes; an empty list when there are no results.
    # Only results[0].circuit is consulted, so all results are assumed to
    # belong to the same circuit -- presumably checked by callers; confirm.
    if len(results) == 0:
        return []
    cdef Circuit circuit = results[0].circuit
    return circuit.find_reachable_op_nodes(results)
ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_py.py ADDED
@@ -0,0 +1,93 @@
1
+ from dataclasses import dataclass
2
+ from itertools import count
3
+ from typing import List, Dict, Sequence, Set
4
+
5
+ from ck.circuit import OpNode, VarNode, CircuitNode, ConstNode
6
+
7
+
8
@dataclass
class CircuitAnalysis:
    """
    The analysis of a function defined by a circuit, with chosen input
    variables and chosen output result nodes.

    Attributes:
        var_nodes: the specified input var nodes.
        result_nodes: the specified result nodes.
        op_nodes: the in-use op nodes, in computation order.
        const_nodes: the in-use const nodes, in arbitrary order.
        op_to_result: for op nodes that are results, id(op_node) -> idx
            where op_node = result[idx].
        op_to_tmp: for op nodes needing tmp memory, id(op_node) -> idx
            where the op node uses tmp[idx].
    """
    var_nodes: Sequence[VarNode]
    result_nodes: Sequence[CircuitNode]
    op_nodes: Sequence[OpNode]
    const_nodes: Sequence[ConstNode]
    op_to_result: Dict[int, int]
    op_to_tmp: Dict[int, int]
21
+
22
+
23
def analyze_circuit(
        var_nodes: Sequence[VarNode],
        result_nodes: Sequence[CircuitNode],
) -> CircuitAnalysis:
    """
    Analyse a circuit, viewed as a function mapping `var_nodes` to `result_nodes`.

    Args:
        var_nodes: the chosen input variable nodes of the circuit.
        result_nodes: the chosen output result nodes of the circuit.

    Returns:
        a CircuitAnalysis recording the op nodes reachable from the results,
        the const nodes in use (deduplicated by identity, first-seen order),
        which op nodes are results, and a tmp slot for every other op node.
    """
    # Op nodes needed to compute the results, as reported by their circuit.
    if len(result_nodes) == 0:
        op_nodes: List[OpNode] = []
    else:
        circuit = result_nodes[0].circuit
        op_nodes = circuit.reachable_op_nodes(*result_nodes)

    def _candidate_consts():
        # Const nodes used as op args, then as results, then by const var nodes.
        for op_node in op_nodes:
            for arg in op_node.args:
                if isinstance(arg, ConstNode):
                    yield arg
        for result in result_nodes:
            if isinstance(result, ConstNode):
                yield result
        for var_node in var_nodes:
            if var_node.is_const():
                yield var_node.const

    # Keep the first occurrence of each distinct const node (by identity).
    seen_const_ids: Set[int] = set()
    const_nodes: List[ConstNode] = []
    for const_node in _candidate_consts():
        const_id = id(const_node)
        if const_id in seen_const_ids:
            continue
        seen_const_ids.add(const_id)
        const_nodes.append(const_node)

    # id(op node) -> index into the results, for op nodes that are results.
    op_to_result: Dict[int, int] = {
        id(result): position
        for position, result in enumerate(result_nodes)
        if isinstance(result, OpNode)
    }

    # Every other in-use op node gets the next tmp slot, in op_nodes order.
    tmp_op_nodes = [op_node for op_node in op_nodes if id(op_node) not in op_to_result]
    op_to_tmp: Dict[int, int] = {id(op_node): slot for slot, op_node in enumerate(tmp_op_nodes)}

    return CircuitAnalysis(
        var_nodes=var_nodes,
        result_nodes=result_nodes,
        op_nodes=op_nodes,
        const_nodes=const_nodes,
        op_to_result=op_to_result,
        op_to_tmp=op_to_tmp,
    )
ck/circuit_compiler/support/input_vars.py ADDED
@@ -0,0 +1,148 @@
1
+ """
2
+ This module supports circuit compilers and interpreters by inferring and checking input variables
3
+ that are explicitly or implicitly referred to by a client.
4
+ """
5
+
6
+ from enum import Enum
7
+ from itertools import chain
8
+ from typing import Sequence, Optional, Set, Iterable, List
9
+
10
+ from ck.circuit import VarNode, Circuit, CircuitNode, OpNode
11
+
12
+
13
class InferVars(Enum):
    """
    Strategies for automatically inferring the input variables of a program.

    Members:
        ALL: every circuit var is an input var.
        REF: only the vars referenced by the results are input vars.
        LOW: the input vars are circuit vars[0 : max_referenced + 1].
    """
    ALL = 'all'
    REF = 'ref'
    LOW = 'low'
21
+
22
+
23
# Type for specifying input circuit vars: an InferVars strategy, an explicit
# sequence of var nodes, or a single var node.
InputVars = InferVars | Sequence[VarNode] | VarNode
25
+
26
+
27
def infer_input_vars(
        circuit: Optional[Circuit],
        results: Sequence[CircuitNode],
        input_vars: InputVars,
) -> Sequence[VarNode]:
    """
    Resolve the circuit and the input variables referred to by Program
    constructor arguments, and check that they are consistent.

    Args:
        circuit: the circuit, if given explicitly; otherwise it is inferred
            from `results` or `input_vars`.
        results: the result nodes of the program.
        input_vars: an InferVars strategy, a sequence of var nodes, or a
            single var node.

    Returns:
        The inferred input circuit vars.

    Raises:
        ValueError: if explicit input vars contain duplicates or do not cover
            the vars needed by the results.
        ValueError: if not all nodes are from the same circuit.

    Ensures:
        circuit is None implies len(input_vars) == 0
    """
    inferred_circuit: Optional[Circuit] = _infer_circuit(circuit, results, input_vars)
    in_vars: Sequence[VarNode] = _infer_input(inferred_circuit, results, input_vars)

    # Every result node and input var must belong to the inferred circuit.
    if inferred_circuit is not None:
        if any(node.circuit is not inferred_circuit for node in chain(results, in_vars)):
            raise ValueError('a var node or result node is not in the inferred circuit')

    return in_vars
57
+
58
+
59
def _infer_circuit(
        cct: Optional[Circuit],
        results: Sequence[CircuitNode],
        input_vars: InputVars,
) -> Optional[Circuit]:
    """
    Determine which circuit the Program constructor arguments refer to.

    Preference order: the explicit circuit, then the circuit of the first
    result node, then the circuit of the given (or first given) input var.

    Returns:
        the inferred circuit, or None if it cannot be determined.
    """
    if cct is None:
        if len(results) > 0:
            cct = results[0].circuit
        elif isinstance(input_vars, CircuitNode):
            cct = input_vars.circuit
        elif not isinstance(input_vars, InferVars):
            # input_vars is a sequence of CircuitNode; use the first, if any.
            for first_var in input_vars:
                cct = first_var.circuit
                break
    return cct
79
+
80
+
81
def _infer_input(
        cct: Optional[Circuit],
        results: Sequence[CircuitNode],
        input_vars: InputVars,
) -> Sequence[VarNode]:
    """
    Determine which input variables the Program constructor arguments refer to.

    Args:
        cct: the inferred circuit, or None if unknown.
        results: the result nodes of the program.
        input_vars: an InferVars strategy, a sequence of var nodes, or a
            single var node.

    Returns:
        the input var nodes. For the InferVars strategies:
        ALL gives `cct.vars` (empty if the circuit is unknown);
        LOW gives circuit vars up to and including the highest-indexed var
        referenced by the results; REF gives the referenced vars, ordered by
        var index.

    Raises:
        ValueError: if explicitly given input vars contain duplicates.
        ValueError: if explicitly given input vars do not cover every var
            referenced by the results.
    """
    if input_vars == InferVars.ALL:
        # Guard on the circuit itself: an explicitly given circuit is used even
        # when there are no results (previously that case yielded no inputs).
        return () if cct is None else cct.vars

    if input_vars == InferVars.LOW:
        if len(results) == 0:
            return ()
        to_index: int = max((var.idx for var in _find_vars(results)), default=-1) + 1
        return cct.vars[:to_index]

    if input_vars == InferVars.REF:
        # Order explicitly by var index rather than relying on VarNode ordering.
        return tuple(sorted(_find_vars(results), key=lambda var: var.idx))

    if isinstance(input_vars, VarNode):
        input_vars = (input_vars,)

    # input_vars is now a Sequence[VarNode].
    in_vars: Sequence[VarNode] = tuple(input_vars)

    # check no duplicate in_vars
    input_var_indices: Set[int] = {var.idx for var in in_vars}
    if len(input_var_indices) != len(in_vars):
        raise ValueError('cannot have duplicate circuit variables as inputs')

    # ensure that the input vars cover what is needed.
    needed_var_indices: Set[int] = {var.idx for var in _find_vars(results)}
    if not input_var_indices.issuperset(needed_var_indices):
        raise ValueError('input var nodes does not cover all need var nodes for result')

    return in_vars
126
+
127
+
128
def _find_vars(nodes: Iterable[CircuitNode]) -> List[VarNode]:
    """
    Collect the VarNode objects that are reachable from the given nodes and
    are not set constant.

    Returns:
        a list of distinct var nodes (by identity), in depth-first discovery order.
    """
    found: List[VarNode] = []
    __find_vars_r(nodes, set(), found)
    return found
136
+
137
+
138
def __find_vars_r(nodes: Iterable[CircuitNode], seen: Set[int], var_nodes: List[VarNode]) -> None:
    """
    Support for _find_vars: walk the nodes reachable from `nodes`, depth first,
    appending each not-yet-seen VarNode that is not set constant to `var_nodes`.

    Nodes are tracked by id(node) in `seen`, so each node is visited at most once.
    The walk uses an explicit stack of iterators instead of Python recursion, so
    circuits deeper than the interpreter's recursion limit do not raise
    RecursionError. The visiting order is identical to a recursive pre-order walk.

    Args:
        nodes: the nodes to start from.
        seen: ids of nodes already visited; updated in place.
        var_nodes: accumulator for the found var nodes; appended to in place.
    """
    # Each entry is a partially consumed iterator over a list of nodes.
    pending = [iter(nodes)]
    while pending:
        for node in pending[-1]:
            node_id = id(node)
            if node_id in seen:
                continue
            seen.add(node_id)
            if isinstance(node, VarNode) and not node.is_const():
                var_nodes.append(node)
            elif isinstance(node, OpNode):
                # Descend into this op node's args before its remaining siblings.
                pending.append(iter(node.args))
                break
        else:
            # The current iterator is exhausted: resume its parent.
            pending.pop()
ck/circuit_compiler/support/llvm_ir_function.py ADDED
@@ -0,0 +1,234 @@
1
+ import ctypes as ct
2
+ from dataclasses import dataclass
3
+ from enum import Enum
4
+ from typing import Callable, Tuple, Optional
5
+
6
+ import llvmlite.binding as llvm
7
+ import llvmlite.ir as ir
8
+
9
+ from ck.program import RawProgram
10
+ from ck.program.raw_program import RawProgramFunction
11
+ from ck.utils.np_extras import DType, DTypeNumeric
12
+
13
# Set to True by _init_llvm once the LLVM bindings have been initialised.
__LLVM_INITIALISED: bool = False

# Name given to the generated function, and looked up again after compilation.
_LVM_FUNCTION_NAME: str = 'main'

# Signature of an LLVM builder binary operation: (builder, lhs, rhs) -> result.
IrBOp = Callable[[ir.IRBuilder, ir.Value, ir.Value], ir.Value]

# LLVM IR type used for Boolean values: a 1-bit integer.
IrBoolType = ir.IntType(1)
21
+
22
+
23
@dataclass(frozen=True)
class TypeInfo:
    """
    Compiler-related information for a given numpy/ctypes `dtype`.

    Each instance defines a mathematical ring: an atomic machine data type
    together with the arithmetic operations over it.

    Attributes:
        dtype: the data type (the same as the numpy `dtype`).
        llvm_type: the corresponding LLVM IR type.
        add: LLVM IR binary operation used for ring addition.
        mul: LLVM IR binary operation used for ring multiplication.
    """
    dtype: DTypeNumeric
    llvm_type: ir.Type
    add: IrBOp
    mul: IrBOp
36
+
37
+
38
# The Boolean constant "One", i.e., "True": an IrBoolType (i1) constant of value 1.
# The Boolean operations _bool_and, _bool_or and _bool_xor mask their results with it.
_IrBoolOne: ir.Value = ir.Constant(IrBoolType, 1)
40
+
41
+
42
def _bool_and(builder: ir.IRBuilder, x: ir.Value, y: ir.Value) -> ir.Value:
    """
    Emit LLVM IR for the Boolean "and" of `x` and `y`.

    The result is then and-ed with the constant 1, which is a no-op when the
    operands are i1 values.
    """
    conjunction: ir.Value = ir.IRBuilder.and_(builder, x, y)
    masked: ir.Value = ir.IRBuilder.and_(builder, conjunction, _IrBoolOne)
    return masked
48
+
49
+
50
def _bool_or(builder: ir.IRBuilder, x: ir.Value, y: ir.Value) -> ir.Value:
    """
    Emit LLVM IR for the Boolean "or" of `x` and `y`.

    The result is then and-ed with the constant 1, which is a no-op when the
    operands are i1 values.
    """
    disjunction: ir.Value = ir.IRBuilder.or_(builder, x, y)
    masked: ir.Value = ir.IRBuilder.and_(builder, disjunction, _IrBoolOne)
    return masked
56
+
57
+
58
def _bool_xor(builder: ir.IRBuilder, x: ir.Value, y: ir.Value) -> ir.Value:
    """
    Emit LLVM IR for the Boolean "exclusive or" of `x` and `y`.

    The result is then and-ed with the constant 1, which is a no-op when the
    operands are i1 values.
    """
    exclusive: ir.Value = ir.IRBuilder.xor(builder, x, y)
    masked: ir.Value = ir.IRBuilder.and_(builder, exclusive, _IrBoolOne)
    return masked
64
+
65
+
66
def _float_max(builder: ir.IRBuilder, x: ir.Value, y: ir.Value) -> ir.Value:
    """
    Emit LLVM IR selecting the larger of the floating point values `x` and `y`.

    An ordered `>` comparison is used, so if either operand is NaN the
    comparison is false and `y` is selected.
    """
    return builder.select(builder.fcmp_ordered('>', x, y), x, y)
72
+
73
+
74
def _float_min(builder: ir.IRBuilder, x: ir.Value, y: ir.Value) -> ir.Value:
    """
    Emit LLVM IR selecting the smaller of the floating point values `x` and `y`.

    An ordered `<` comparison is used, so if either operand is NaN the
    comparison is false and `y` is selected.
    """
    return builder.select(builder.fcmp_ordered('<', x, y), x, y)
80
+
81
+
82
# IR operations for TypeInfo: (add, mul)
_float_add: IrBOp = ir.IRBuilder.fadd
_float_mul: IrBOp = ir.IRBuilder.fmul
_int_add: IrBOp = ir.IRBuilder.add
_int_mul: IrBOp = ir.IRBuilder.mul

# Backward-compatible alias for the original, misspelt name of `_int_mul`.
_ind_mul: IrBOp = _int_mul
87
+
88
+
89
class DataType(Enum):
    """
    Predefined TypeInfo objects.

    Each member defines a mathematical ring: a ctypes data type, its LLVM IR
    type, and the IR operations used as "add" and "mul":
        FLOAT_32, FLOAT_64: floating point, (+, *).
        INT_8, INT_16, INT_32, INT_64: integers, (+, *).
        BOOL: Booleans, (or, and).
        XBOOL: Booleans, (xor, and).
        MAX_MIN: doubles, (max, min).
        MAX_MUL: doubles, (max, *).
        MAX_SUM: doubles, (max, +).
    """
    FLOAT_32 = TypeInfo(dtype=ct.c_float, llvm_type=ir.FloatType(), add=_float_add, mul=_float_mul)
    FLOAT_64 = TypeInfo(dtype=ct.c_double, llvm_type=ir.DoubleType(), add=_float_add, mul=_float_mul)
    INT_8 = TypeInfo(dtype=ct.c_int8, llvm_type=ir.IntType(8), add=_int_add, mul=_ind_mul)
    INT_16 = TypeInfo(dtype=ct.c_int16, llvm_type=ir.IntType(16), add=_int_add, mul=_ind_mul)
    INT_32 = TypeInfo(dtype=ct.c_int32, llvm_type=ir.IntType(32), add=_int_add, mul=_ind_mul)
    INT_64 = TypeInfo(dtype=ct.c_int64, llvm_type=ir.IntType(64), add=_int_add, mul=_ind_mul)
    BOOL = TypeInfo(dtype=ct.c_bool, llvm_type=IrBoolType, add=_bool_or, mul=_bool_and)
    XBOOL = TypeInfo(dtype=ct.c_bool, llvm_type=IrBoolType, add=_bool_xor, mul=_bool_and)
    MAX_MIN = TypeInfo(dtype=ct.c_double, llvm_type=ir.DoubleType(), add=_float_max, mul=_float_min)
    MAX_MUL = TypeInfo(dtype=ct.c_double, llvm_type=ir.DoubleType(), add=_float_max, mul=_float_mul)
    MAX_SUM = TypeInfo(dtype=ct.c_double, llvm_type=ir.DoubleType(), add=_float_max, mul=_float_add)
108
+
109
+
110
class IRFunction:
    """
    Holds the state used while building an LLVM IR program function: the
    type info, the LLVM module, the function within it, and an IR builder.
    """

    def __init__(self, type_info: TypeInfo):
        """
        Create an LLVM IR program function.

        Steps performed, in order:
        1. LLVM is initialised.
        2. The field `type_info` is set and the signature
           (T* in, T* tmp, T* out) -> Void is formed, where T is
           `type_info.llvm_type`.
        3. A module is created (field `module`) and a function with that
           signature is added to it (field `function`).
        4. An IRBuilder is constructed (field `builder`).
        5. A basic block named "entry" is appended to the function and the
           builder is positioned at its end.

        Args:
            type_info: the ring (data type and operations) to compute over.
        """
        _init_llvm()

        self.type_info: TypeInfo = type_info
        self.ret_type: ir.Type = ir.VoidType()
        self.ptr_type: ir.Type = type_info.llvm_type.as_pointer()
        signature = ir.FunctionType(self.ret_type, (self.ptr_type,) * 3)

        self.module = ir.Module()
        self.function = ir.Function(self.module, signature, name=_LVM_FUNCTION_NAME)
        self.builder = ir.IRBuilder()

        entry_block = self.function.append_basic_block('entry')
        self.builder.position_at_end(entry_block)

    def llvm_program(self) -> str:
        """
        Render the module as LLVM IR source text.

        Returns:
            an LLVM program string suitable for `compile_llvm_program`.
        """
        return str(self.module)
151
+
152
+
153
@dataclass
class LLVMRawProgram(RawProgram):
    """
    A RawProgram whose function is native code compiled from LLVM IR.

    Fields, in addition to those of RawProgram:
        llvm_program: the LLVM IR source, if retained (required for pickling).
        engine: the LLVM execution engine, which must be kept alive for the
            compiled function to remain valid.
        opt: the `opt` level passed to compile_llvm_program.
    """
    llvm_program: Optional[str]
    engine: llvm.ExecutionEngine
    opt: int

    def __getstate__(self):
        """
        Support for pickle: capture only what is needed to recompile.

        Raises:
            ValueError: if the LLVM program source was not retained.
        """
        if self.llvm_program is None:
            raise ValueError('need to have the LLVM program to pickle a Program object')

        state = {}
        for key in ('dtype', 'number_of_vars', 'number_of_tmps', 'number_of_results', 'llvm_program', 'opt'):
            state[key] = getattr(self, key)
        return state

    def __setstate__(self, state):
        """
        Support for pickle: restore the saved fields, then recompile the LLVM
        program to recreate `engine` and `function`.
        """
        for key in ('dtype', 'number_of_vars', 'number_of_tmps', 'number_of_results', 'llvm_program', 'opt'):
            setattr(self, key, state[key])

        self.engine, self.function = compile_llvm_program(self.llvm_program, self.dtype, self.opt)
188
+
189
+
190
def compile_llvm_program(
        llvm_program: str,
        dtype: DType,
        opt: int,
) -> Tuple[llvm.ExecutionEngine, RawProgramFunction]:
    """
    Compile LLVM IR source text to native code.

    Args:
        llvm_program: LLVM IR source, e.g., from IRFunction.llvm_program().
        dtype: the ctypes element type pointed to by each of the compiled
            function's three arguments.
        opt: optimisation level passed to the target machine.

    Returns:
        (engine, function) where
        engine: is an LLVM execution engine, which must remain
            in memory for the returned function to be valid.
        function: is the raw Python callable for the compiled function,
            taking three pointers and returning None.
    """
    _init_llvm()

    module_ref = llvm.parse_assembly(llvm_program)
    module_ref.verify()

    target_machine = llvm.Target.from_default_triple().create_target_machine(opt=opt)
    engine = llvm.create_mcjit_compiler(module_ref, target_machine)

    # Generate executable native code, then run any static constructors.
    engine.finalize_object()
    engine.run_static_constructors()

    # Wrap the entry point as a ctypes callable: void f(T*, T*, T*).
    address = engine.get_function_address(_LVM_FUNCTION_NAME)
    element_ptr = ct.POINTER(dtype)
    prototype = ct.CFUNCTYPE(None, element_ptr, element_ptr, element_ptr)
    return engine, prototype(address)
223
+
224
+
225
def _init_llvm() -> None:
    """
    Initialise the LLVM bindings, native target and native asm printer the
    first time this is called; later calls do nothing.
    """
    global __LLVM_INITIALISED
    if __LLVM_INITIALISED:
        return
    llvm.initialize()
    llvm.initialize_native_target()
    llvm.initialize_native_asmprinter()
    __LLVM_INITIALISED = True
ck/example/__init__.py ADDED
@@ -0,0 +1,53 @@
1
+ """
2
+ A package of standard probabilistic graphical models.
3
+ """
4
+ from ck.pgm import PGM
5
+ from ck.example.alarm import Alarm
6
+ from ck.example.binary_clique import BinaryClique
7
+ from ck.example.bow_tie import BowTie
8
+ from ck.example.cancer import Cancer
9
+ from ck.example.asia import Asia
10
+ from ck.example.chain import Chain
11
+ from ck.example.child import Child
12
+ from ck.example.clique import Clique
13
+ from ck.example.cnf_pgm import CNF_PGM
14
+ from ck.example.diamond_square import DiamondSquare
15
+ from ck.example.earthquake import Earthquake
16
+ from ck.example.empty import Empty
17
+ from ck.example.hailfinder import Hailfinder
18
+ from ck.example.hepar2 import Hepar2
19
+ from ck.example.insurance import Insurance
20
+ from ck.example.loop import Loop
21
+ from ck.example.mildew import Mildew
22
+ from ck.example.munin import Munin
23
+ from ck.example.pathfinder import Pathfinder
24
+ from ck.example.rectangle import Rectangle
25
+ from ck.example.rain import Rain
26
+ from ck.example.run import Run
27
+ from ck.example.sachs import Sachs
28
+ from ck.example.sprinkler import Sprinkler
29
+ from ck.example.survey import Survey
30
+ from ck.example.star import Star
31
+ from ck.example.stress import Stress
32
+ from ck.example.student import Student
33
+ from ck.example.triangle_square import TriangleSquare
34
+ from ck.example.truss import Truss
35
+
36
+
37
# Maps each public example PGM class imported into this module to its name,
# i.e., entries `name: class`. PGM itself is excluded.
#
# Example usage:
#     from ck.example import ALL_EXAMPLES
#
#     my_pgm: PGM = ALL_EXAMPLES['Alarm']()
#
ALL_EXAMPLES = {
    name: candidate
    for name, candidate in globals().items()
    if not (
        name.startswith('_')
        or name == PGM.__name__
        or not isinstance(candidate, type)
        or not issubclass(candidate, PGM)
    )
}