compiled-knowledge 4.0.0a20__cp312-cp312-musllinux_1_2_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of compiled-knowledge might be problematic. Click here for more details.
- ck/__init__.py +0 -0
- ck/circuit/__init__.py +17 -0
- ck/circuit/_circuit_cy.c +37520 -0
- ck/circuit/_circuit_cy.cpython-312-x86_64-linux-musl.so +0 -0
- ck/circuit/_circuit_cy.pxd +32 -0
- ck/circuit/_circuit_cy.pyx +768 -0
- ck/circuit/_circuit_py.py +836 -0
- ck/circuit/tmp_const.py +74 -0
- ck/circuit_compiler/__init__.py +2 -0
- ck/circuit_compiler/circuit_compiler.py +26 -0
- ck/circuit_compiler/cython_vm_compiler/__init__.py +1 -0
- ck/circuit_compiler/cython_vm_compiler/_compiler.c +19821 -0
- ck/circuit_compiler/cython_vm_compiler/_compiler.cpython-312-x86_64-linux-musl.so +0 -0
- ck/circuit_compiler/cython_vm_compiler/_compiler.pyx +380 -0
- ck/circuit_compiler/cython_vm_compiler/cython_vm_compiler.py +121 -0
- ck/circuit_compiler/interpret_compiler.py +223 -0
- ck/circuit_compiler/llvm_compiler.py +388 -0
- ck/circuit_compiler/llvm_vm_compiler.py +546 -0
- ck/circuit_compiler/named_circuit_compilers.py +57 -0
- ck/circuit_compiler/support/__init__.py +0 -0
- ck/circuit_compiler/support/circuit_analyser/__init__.py +13 -0
- ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.c +10615 -0
- ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.cpython-312-x86_64-linux-musl.so +0 -0
- ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.pyx +98 -0
- ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_py.py +93 -0
- ck/circuit_compiler/support/input_vars.py +148 -0
- ck/circuit_compiler/support/llvm_ir_function.py +234 -0
- ck/example/__init__.py +53 -0
- ck/example/alarm.py +366 -0
- ck/example/asia.py +28 -0
- ck/example/binary_clique.py +32 -0
- ck/example/bow_tie.py +33 -0
- ck/example/cancer.py +37 -0
- ck/example/chain.py +38 -0
- ck/example/child.py +199 -0
- ck/example/clique.py +33 -0
- ck/example/cnf_pgm.py +39 -0
- ck/example/diamond_square.py +68 -0
- ck/example/earthquake.py +36 -0
- ck/example/empty.py +10 -0
- ck/example/hailfinder.py +539 -0
- ck/example/hepar2.py +628 -0
- ck/example/insurance.py +504 -0
- ck/example/loop.py +40 -0
- ck/example/mildew.py +38161 -0
- ck/example/munin.py +22982 -0
- ck/example/pathfinder.py +53747 -0
- ck/example/rain.py +39 -0
- ck/example/rectangle.py +161 -0
- ck/example/run.py +30 -0
- ck/example/sachs.py +129 -0
- ck/example/sprinkler.py +30 -0
- ck/example/star.py +44 -0
- ck/example/stress.py +64 -0
- ck/example/student.py +43 -0
- ck/example/survey.py +46 -0
- ck/example/triangle_square.py +54 -0
- ck/example/truss.py +49 -0
- ck/in_out/__init__.py +3 -0
- ck/in_out/parse_ace_lmap.py +216 -0
- ck/in_out/parse_ace_nnf.py +322 -0
- ck/in_out/parse_net.py +480 -0
- ck/in_out/parser_utils.py +185 -0
- ck/in_out/pgm_pickle.py +42 -0
- ck/in_out/pgm_python.py +268 -0
- ck/in_out/render_bugs.py +111 -0
- ck/in_out/render_net.py +177 -0
- ck/in_out/render_pomegranate.py +184 -0
- ck/pgm.py +3475 -0
- ck/pgm_circuit/__init__.py +1 -0
- ck/pgm_circuit/marginals_program.py +352 -0
- ck/pgm_circuit/mpe_program.py +237 -0
- ck/pgm_circuit/pgm_circuit.py +79 -0
- ck/pgm_circuit/program_with_slotmap.py +236 -0
- ck/pgm_circuit/slot_map.py +35 -0
- ck/pgm_circuit/support/__init__.py +0 -0
- ck/pgm_circuit/support/compile_circuit.py +83 -0
- ck/pgm_circuit/target_marginals_program.py +103 -0
- ck/pgm_circuit/wmc_program.py +323 -0
- ck/pgm_compiler/__init__.py +2 -0
- ck/pgm_compiler/ace/__init__.py +1 -0
- ck/pgm_compiler/ace/ace.py +299 -0
- ck/pgm_compiler/factor_elimination.py +395 -0
- ck/pgm_compiler/named_pgm_compilers.py +63 -0
- ck/pgm_compiler/pgm_compiler.py +19 -0
- ck/pgm_compiler/recursive_conditioning.py +231 -0
- ck/pgm_compiler/support/__init__.py +0 -0
- ck/pgm_compiler/support/circuit_table/__init__.py +17 -0
- ck/pgm_compiler/support/circuit_table/_circuit_table_cy.c +16393 -0
- ck/pgm_compiler/support/circuit_table/_circuit_table_cy.cpython-312-x86_64-linux-musl.so +0 -0
- ck/pgm_compiler/support/circuit_table/_circuit_table_cy.pyx +332 -0
- ck/pgm_compiler/support/circuit_table/_circuit_table_py.py +304 -0
- ck/pgm_compiler/support/clusters.py +568 -0
- ck/pgm_compiler/support/factor_tables.py +406 -0
- ck/pgm_compiler/support/join_tree.py +332 -0
- ck/pgm_compiler/support/named_compiler_maker.py +43 -0
- ck/pgm_compiler/variable_elimination.py +91 -0
- ck/probability/__init__.py +0 -0
- ck/probability/empirical_probability_space.py +50 -0
- ck/probability/pgm_probability_space.py +32 -0
- ck/probability/probability_space.py +622 -0
- ck/program/__init__.py +3 -0
- ck/program/program.py +137 -0
- ck/program/program_buffer.py +180 -0
- ck/program/raw_program.py +67 -0
- ck/sampling/__init__.py +0 -0
- ck/sampling/forward_sampler.py +211 -0
- ck/sampling/marginals_direct_sampler.py +113 -0
- ck/sampling/sampler.py +62 -0
- ck/sampling/sampler_support.py +232 -0
- ck/sampling/uniform_sampler.py +72 -0
- ck/sampling/wmc_direct_sampler.py +171 -0
- ck/sampling/wmc_gibbs_sampler.py +153 -0
- ck/sampling/wmc_metropolis_sampler.py +165 -0
- ck/sampling/wmc_rejection_sampler.py +115 -0
- ck/utils/__init__.py +0 -0
- ck/utils/iter_extras.py +163 -0
- ck/utils/local_config.py +270 -0
- ck/utils/map_list.py +128 -0
- ck/utils/map_set.py +128 -0
- ck/utils/np_extras.py +51 -0
- ck/utils/random_extras.py +64 -0
- ck/utils/tmp_dir.py +94 -0
- ck_demos/__init__.py +0 -0
- ck_demos/ace/__init__.py +0 -0
- ck_demos/ace/copy_ace_to_ck.py +15 -0
- ck_demos/ace/demo_ace.py +49 -0
- ck_demos/all_demos.py +88 -0
- ck_demos/circuit/__init__.py +0 -0
- ck_demos/circuit/demo_circuit_dump.py +22 -0
- ck_demos/circuit/demo_derivatives.py +43 -0
- ck_demos/circuit_compiler/__init__.py +0 -0
- ck_demos/circuit_compiler/compare_circuit_compilers.py +32 -0
- ck_demos/circuit_compiler/show_llvm_program.py +26 -0
- ck_demos/pgm/__init__.py +0 -0
- ck_demos/pgm/demo_pgm_dump.py +18 -0
- ck_demos/pgm/demo_pgm_dump_stress.py +18 -0
- ck_demos/pgm/demo_pgm_string_rendering.py +15 -0
- ck_demos/pgm/show_examples.py +25 -0
- ck_demos/pgm_compiler/__init__.py +0 -0
- ck_demos/pgm_compiler/compare_pgm_compilers.py +63 -0
- ck_demos/pgm_compiler/demo_compiler_dump.py +60 -0
- ck_demos/pgm_compiler/demo_factor_elimination.py +47 -0
- ck_demos/pgm_compiler/demo_join_tree.py +25 -0
- ck_demos/pgm_compiler/demo_marginals_program.py +53 -0
- ck_demos/pgm_compiler/demo_mpe_program.py +55 -0
- ck_demos/pgm_compiler/demo_pgm_compiler.py +38 -0
- ck_demos/pgm_compiler/demo_recursive_conditioning.py +33 -0
- ck_demos/pgm_compiler/demo_variable_elimination.py +33 -0
- ck_demos/pgm_compiler/demo_wmc_program.py +29 -0
- ck_demos/pgm_compiler/time_fe_compiler.py +93 -0
- ck_demos/pgm_inference/__init__.py +0 -0
- ck_demos/pgm_inference/demo_inferencing_basic.py +188 -0
- ck_demos/pgm_inference/demo_inferencing_mpe_cancer.py +45 -0
- ck_demos/pgm_inference/demo_inferencing_wmc_and_mpe_sprinkler.py +154 -0
- ck_demos/pgm_inference/demo_inferencing_wmc_student.py +110 -0
- ck_demos/programs/__init__.py +0 -0
- ck_demos/programs/demo_program_buffer.py +24 -0
- ck_demos/programs/demo_program_multi.py +24 -0
- ck_demos/programs/demo_program_none.py +19 -0
- ck_demos/programs/demo_program_single.py +23 -0
- ck_demos/programs/demo_raw_program_interpreted.py +21 -0
- ck_demos/programs/demo_raw_program_llvm.py +21 -0
- ck_demos/sampling/__init__.py +0 -0
- ck_demos/sampling/check_sampler.py +71 -0
- ck_demos/sampling/demo_marginal_direct_sampler.py +40 -0
- ck_demos/sampling/demo_uniform_sampler.py +38 -0
- ck_demos/sampling/demo_wmc_direct_sampler.py +40 -0
- ck_demos/utils/__init__.py +0 -0
- ck_demos/utils/compare.py +120 -0
- ck_demos/utils/convert_network.py +45 -0
- ck_demos/utils/sample_model.py +216 -0
- ck_demos/utils/stop_watch.py +384 -0
- compiled_knowledge-4.0.0a20.dist-info/METADATA +50 -0
- compiled_knowledge-4.0.0a20.dist-info/RECORD +178 -0
- compiled_knowledge-4.0.0a20.dist-info/WHEEL +5 -0
- compiled_knowledge-4.0.0a20.dist-info/licenses/LICENSE.txt +21 -0
- compiled_knowledge-4.0.0a20.dist-info/top_level.txt +2 -0
ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.cpython-312-x86_64-linux-musl.so
ADDED
|
Binary file
|
|
@@ -0,0 +1,98 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
from typing import List, Dict, Sequence, Set
|
|
3
|
+
|
|
4
|
+
from ck.circuit._circuit_cy cimport Circuit, OpNode, VarNode, CircuitNode, ConstNode
|
|
5
|
+
from cython.operator cimport postincrement
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@dataclass
class CircuitAnalysis:
    """
    A data structure representing the analysis of a function defined by
    a circuit with chosen input variables and output result nodes.
    """

    var_nodes: Sequence[VarNode]  # specified input var nodes
    result_nodes: Sequence[CircuitNode]  # specified result nodes
    op_nodes: Sequence[OpNode]  # in-use op nodes, in computation order
    const_nodes: Sequence[ConstNode]  # in-use const nodes, in arbitrary order
    op_to_result: Dict[int, int]  # op nodes in the result, op_node = result[idx]: id(op_node) -> idx
    op_to_tmp: Dict[int, int]  # op nodes needing tmp memory, using tmp[idx]: id(op_node) -> idx
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def analyze_circuit(
        var_nodes: Sequence[VarNode],
        result_nodes: Sequence[CircuitNode],
) -> CircuitAnalysis:
    """
    Analyze a circuit, viewed as a function from `var_nodes` to `result_nodes`.

    Args:
        var_nodes: The chosen input variable nodes of the circuit.
        result_nodes: The chosen output result nodes of the circuit.

    Returns:
        A CircuitAnalysis recording the op nodes reachable from the results,
        the distinct const nodes in use (deduplicated by identity, first-seen
        order), and the result index or tmp slot assigned to each op node.
    """
    cdef list[CircuitNode] results = list(result_nodes)

    # Op nodes reachable from the results.
    cdef list[OpNode] op_nodes = _reachable_op_nodes(results)

    # Distinct ConstNode objects in use, deduplicated by object identity.
    cdef set[int] const_ids = set()
    cdef list[ConstNode] const_nodes = []

    def _add_const(_const: ConstNode) -> None:
        _key: int = id(_const)
        if _key in const_ids:
            return
        const_nodes.append(_const)
        const_ids.add(_key)

    # Constants can appear as op-node arguments, as results themselves,
    # or as the values that input vars have been set to.
    for op_node in op_nodes:
        for arg in op_node.args:
            if isinstance(arg, ConstNode):
                _add_const(arg)
    for result in results:
        if isinstance(result, ConstNode):
            _add_const(result)
    for var_node in var_nodes:
        if var_node.is_const():
            _add_const(var_node.const)

    # id(op node) -> result index, for op nodes that are results.
    # If the same op node appears more than once, the last index wins.
    cdef dict[int, int] op_to_result = {
        id(result): result_idx
        for result_idx, result in enumerate(result_nodes)
        if isinstance(result, OpNode)
    }

    # Every other op node gets a tmp slot, numbered consecutively from 0
    # following the order of `op_nodes`.
    cdef list[OpNode] tmp_op_nodes = [
        op_node for op_node in op_nodes if id(op_node) not in op_to_result
    ]
    op_to_tmp: Dict[int, int] = {
        id(op_node): tmp_slot for tmp_slot, op_node in enumerate(tmp_op_nodes)
    }

    return CircuitAnalysis(
        var_nodes=var_nodes,
        result_nodes=result_nodes,
        op_nodes=op_nodes,
        const_nodes=const_nodes,
        op_to_result=op_to_result,
        op_to_tmp=op_to_tmp,
    )
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
cdef list[OpNode] _reachable_op_nodes(list[CircuitNode] results):
    """
    Return the op nodes reachable from `results`, as reported by
    `find_reachable_op_nodes` of the circuit owning the first result.
    An empty `results` list gives an empty list.
    """
    if len(results) == 0:
        return []
    cdef Circuit owner = results[0].circuit
    return owner.find_reachable_op_nodes(results)
|
|
@@ -0,0 +1,93 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
from itertools import count
|
|
3
|
+
from typing import List, Dict, Sequence, Set
|
|
4
|
+
|
|
5
|
+
from ck.circuit import OpNode, VarNode, CircuitNode, ConstNode
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@dataclass
class CircuitAnalysis:
    """
    A data structure representing the analysis of a function defined by
    a circuit with chosen input variables and output result nodes.
    """

    var_nodes: Sequence[VarNode]  # specified input var nodes
    result_nodes: Sequence[CircuitNode]  # specified result nodes
    op_nodes: Sequence[OpNode]  # in-use op nodes, in computation order
    const_nodes: Sequence[ConstNode]  # in-use const nodes, in arbitrary order
    op_to_result: Dict[int, int]  # op nodes in the result, op_node = result[idx]: id(op_node) -> idx
    op_to_tmp: Dict[int, int]  # op nodes needing tmp memory, using tmp[idx]: id(op_node) -> idx
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def analyze_circuit(
        var_nodes: Sequence[VarNode],
        result_nodes: Sequence[CircuitNode],
) -> CircuitAnalysis:
    """
    Analyze a circuit, viewed as a function from `var_nodes` to `result_nodes`.

    Args:
        var_nodes: The chosen input variable nodes of the circuit.
        result_nodes: The chosen output result nodes of the circuit.

    Returns:
        A CircuitAnalysis recording the op nodes reachable from the results,
        the distinct const nodes in use (deduplicated by identity, first-seen
        order), and the result index or tmp slot assigned to each op node.
    """
    # Op nodes reachable from the results, as reported by the owning circuit.
    if len(result_nodes) == 0:
        op_nodes: List[OpNode] = []
    else:
        op_nodes = result_nodes[0].circuit.reachable_op_nodes(*result_nodes)

    def _candidate_consts():
        # Constants used as op-node arguments.
        for op_node in op_nodes:
            for arg in op_node.args:
                if isinstance(arg, ConstNode):
                    yield arg
        # Constants that are results themselves.
        for result in result_nodes:
            if isinstance(result, ConstNode):
                yield result
        # Constants that input vars have been set to.
        for var_node in var_nodes:
            if var_node.is_const():
                yield var_node.const

    # Deduplicate by identity, keeping the first occurrence of each.
    consts_by_id: Dict[int, ConstNode] = {}
    for const_node in _candidate_consts():
        consts_by_id.setdefault(id(const_node), const_node)
    const_nodes: List[ConstNode] = list(consts_by_id.values())

    # id(op node) -> result index, for op nodes that are results.
    # If the same op node appears more than once, the last index wins.
    op_to_result: Dict[int, int] = {
        id(result): result_idx
        for result_idx, result in enumerate(result_nodes)
        if isinstance(result, OpNode)
    }

    # Every other op node gets a tmp slot, numbered consecutively from 0
    # following the order of `op_nodes`.
    tmp_op_nodes: List[OpNode] = [op_node for op_node in op_nodes if id(op_node) not in op_to_result]
    op_to_tmp: Dict[int, int] = {id(op_node): tmp_slot for tmp_slot, op_node in enumerate(tmp_op_nodes)}

    return CircuitAnalysis(
        var_nodes=var_nodes,
        result_nodes=result_nodes,
        op_nodes=op_nodes,
        const_nodes=const_nodes,
        op_to_result=op_to_result,
        op_to_tmp=op_to_tmp,
    )
|
|
@@ -0,0 +1,148 @@
|
|
|
1
|
+
"""
|
|
2
|
+
This module supports circuit compilers and interpreters by inferring and checking input variables
|
|
3
|
+
that are explicitly or implicitly referred to by a client.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from enum import Enum
|
|
7
|
+
from itertools import chain
|
|
8
|
+
from typing import Sequence, Optional, Set, Iterable, List
|
|
9
|
+
|
|
10
|
+
from ck.circuit import VarNode, Circuit, CircuitNode, OpNode
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class InferVars(Enum):
    """
    Strategies for automatically inferring a program's input variables.

    ALL: every circuit var is an input var.
    REF: only the vars referenced by the results are input vars.
    LOW: the input vars are circuit vars[0 : max_referenced + 1].
    """

    ALL = 'all'
    REF = 'ref'
    LOW = 'low'
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
# Type for specifying a program's input circuit vars: an InferVars strategy,
# an explicit sequence of VarNode objects, or a single VarNode.
|
|
24
|
+
InputVars = InferVars | Sequence[VarNode] | VarNode
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def infer_input_vars(
        circuit: Optional[Circuit],
        results: Sequence[CircuitNode],
        input_vars: InputVars,
) -> Sequence[VarNode]:
    """
    Infer what circuit and what input variables are being referred to, based on
    Program constructor arguments, then check that all input vars and result
    nodes are in that circuit.

    Args:
        circuit: the circuit, if given explicitly; otherwise None, in which case the
            circuit is inferred from `results` or `input_vars`.
        results: the result nodes of the program.
        input_vars: an InferVars strategy, a single VarNode, or a sequence of VarNode.

    Returns:
        The inferred input circuit vars.

    Raises:
        ValueError: if a result node or input var is not in the inferred circuit.
        ValueError: if explicitly given input vars contain duplicates (by var index).
        ValueError: if explicitly given input vars do not cover every non-constant
            var reachable from the results.

    Ensures:
        if no circuit can be inferred, the returned sequence is empty.
    """
    cct: Optional[Circuit] = _infer_circuit(circuit, results, input_vars)
    inferred_vars: Sequence[VarNode] = _infer_input(cct, results, input_vars)

    # Check that all results nodes and input vars are in the circuit.
    if cct is not None:
        for node in chain(results, inferred_vars):
            if node.circuit is not cct:
                raise ValueError('a var node or result node is not in the inferred circuit')

    return inferred_vars
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def _infer_circuit(
        cct: Optional[Circuit],
        results: Sequence[CircuitNode],
        input_vars: InputVars,
) -> Optional[Circuit]:
    """
    Work out which circuit the Program constructor arguments refer to.

    The priority is as follows:
    1. the explicitly given circuit;
    2. the circuit of the first result node;
    3. the circuit of `input_vars`, either a single node or the first node of a sequence.

    Returns None when none of these is available.
    """
    if cct is None:
        if len(results) > 0:
            cct = results[0].circuit
        elif isinstance(input_vars, CircuitNode):
            cct = input_vars.circuit
        elif not isinstance(input_vars, InferVars):
            # input_vars is a sequence of nodes; use the first one, if any.
            for first_var in input_vars:
                cct = first_var.circuit
                break
    return cct
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def _infer_input(
        cct: Optional[Circuit],
        results: Sequence[CircuitNode],
        input_vars: InputVars,
) -> Sequence[VarNode]:
    """
    Infer what input variables are being referred to, based on Program constructor arguments.

    Args:
        cct: the inferred circuit, or None if no circuit could be inferred.
        results: the result nodes of the program.
        input_vars: an InferVars strategy, a single VarNode, or a sequence of VarNode.

    Returns:
        The input variables, as a sequence of VarNode.

    Raises:
        ValueError: if explicitly given input vars contain duplicates (by var index).
        ValueError: if explicitly given input vars do not cover every non-constant
            var reachable from the results.
    """
    have_results: bool = len(results) > 0

    # Enum members are singletons, so compare by identity. This also avoids
    # invoking a foreign `__eq__` when input_vars is a node or a sequence.
    if input_vars is InferVars.ALL:
        # NOTE(review): with an explicitly given circuit but no results this
        # returns () rather than cct.vars; confirm that is intended.
        if have_results:
            return cct.vars
        else:
            return ()

    elif input_vars is InferVars.LOW:
        if have_results:
            to_index: int = max((var.idx for var in _find_vars(results)), default=-1) + 1
            return cct.vars[:to_index]
        else:
            return ()

    elif input_vars is InferVars.REF:
        # Referenced vars, ordered by circuit var index (explicit key, so this
        # does not depend on VarNode defining an ordering).
        return tuple(sorted(_find_vars(results), key=lambda var: var.idx))

    elif isinstance(input_vars, VarNode):
        input_vars = (input_vars,)

    # Otherwise, input_vars is a Sequence[VarNode].
    in_vars: Sequence[VarNode] = tuple(input_vars)

    # Check there are no duplicate in_vars.
    input_var_indices: Set[int] = {var.idx for var in in_vars}
    if len(input_var_indices) != len(in_vars):
        raise ValueError('cannot have duplicate circuit variables as inputs')

    # Ensure that the input vars cover what is needed.
    needed_var_indices: Set[int] = {var.idx for var in _find_vars(results)}
    if not input_var_indices.issuperset(needed_var_indices):
        raise ValueError('input var nodes does not cover all need var nodes for result')

    return in_vars
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
def _find_vars(nodes: Iterable[CircuitNode]) -> List[VarNode]:
    """
    Collect the VarNode objects that are not set constant and are reachable
    from the given nodes by following OpNode arguments.

    Returns:
        a list with each such VarNode appearing once, in discovery order.
    """
    found: List[VarNode] = []
    __find_vars_r(nodes, set(), found)
    return found
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
def __find_vars_r(nodes: Iterable[CircuitNode], seen: Set[int], var_nodes: List[VarNode]) -> None:
    """
    Depth-first search support for _find_vars.

    Visits every node reachable from `nodes` (following OpNode args) whose id is
    not already in `seen`. It records each visited node id in `seen` and appends
    each non-constant VarNode to `var_nodes` in discovery order.

    The traversal uses an explicit stack of iterators instead of recursion, so
    deep circuits cannot exceed Python's recursion limit. The visiting order is
    identical to a recursive depth-first traversal.
    """
    pending = [iter(nodes)]
    while pending:
        for node in pending[-1]:
            node_id = id(node)
            if node_id in seen:
                continue
            seen.add(node_id)
            if isinstance(node, VarNode) and not node.is_const():
                var_nodes.append(node)
            elif isinstance(node, OpNode):
                # Descend into this op's args before continuing with its siblings.
                pending.append(iter(node.args))
                break
        else:
            # The current level is exhausted; resume the parent level.
            pending.pop()
|
|
@@ -0,0 +1,234 @@
|
|
|
1
|
+
import ctypes as ct
|
|
2
|
+
from dataclasses import dataclass
|
|
3
|
+
from enum import Enum
|
|
4
|
+
from typing import Callable, Tuple, Optional
|
|
5
|
+
|
|
6
|
+
import llvmlite.binding as llvm
|
|
7
|
+
import llvmlite.ir as ir
|
|
8
|
+
|
|
9
|
+
from ck.program import RawProgram
|
|
10
|
+
from ck.program.raw_program import RawProgramFunction
|
|
11
|
+
from ck.utils.np_extras import DType, DTypeNumeric
|
|
12
|
+
|
|
13
|
+
# Set to True by _init_llvm() once LLVM initialisation has been performed.
__LLVM_INITIALISED: bool = False
|
|
14
|
+
|
|
15
|
+
# Name of the LLVM IR function created by IRFunction; compile_llvm_program looks it up by this name.
_LVM_FUNCTION_NAME: str = 'main'
|
|
16
|
+
|
|
17
|
+
# Type for an LLVM IR builder binary operation: (builder, lhs, rhs) -> result value.
|
|
18
|
+
IrBOp = Callable[[ir.IRBuilder, ir.Value, ir.Value], ir.Value]
|
|
19
|
+
|
|
20
|
+
IrBoolType = ir.IntType(1)  # LLVM IR type for a Boolean: a 1-bit integer.
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
@dataclass(frozen=True)
class TypeInfo:
    """
    Compiler information tied to one numpy/ctypes `dtype`.

    Each instance describes a mathematical ring: an atomic machine data type
    (in both its ctypes and LLVM IR forms) together with the LLVM IR
    operations used for "addition" and "multiplication" over it.

    Fields are positional in the order: dtype, llvm_type, add, mul.
    """

    # The element type, as a numpy/ctypes `dtype`.
    dtype: DTypeNumeric
    # The matching LLVM IR type.
    llvm_type: ir.Type
    # IR builder binary operation used as the ring's addition.
    add: IrBOp
    # IR builder binary operation used as the ring's multiplication.
    mul: IrBOp
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
# The LLVM IR Boolean (i1) constant 1, i.e., "True".
# The Boolean IR operations in this module AND their results with it.
|
|
39
|
+
_IrBoolOne: ir.Value = ir.Constant(IrBoolType, 1)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def _bool_and(builder: ir.IRBuilder, x: ir.Value, y: ir.Value) -> ir.Value:
    """
    LLVM IR Boolean "and".

    Emits `x & y`, then ANDs that result with the i1 constant 1.
    """
    conjunction: ir.Value = builder.and_(x, y)
    return builder.and_(conjunction, _IrBoolOne)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def _bool_or(builder: ir.IRBuilder, x: ir.Value, y: ir.Value) -> ir.Value:
    """
    LLVM IR Boolean "or".

    Emits `x | y`, then ANDs that result with the i1 constant 1.
    """
    disjunction: ir.Value = builder.or_(x, y)
    return builder.and_(disjunction, _IrBoolOne)
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def _bool_xor(builder: ir.IRBuilder, x: ir.Value, y: ir.Value) -> ir.Value:
    """
    LLVM IR Boolean "xor".

    Emits `x ^ y`, then ANDs that result with the i1 constant 1.
    """
    exclusive: ir.Value = builder.xor(x, y)
    return builder.and_(exclusive, _IrBoolOne)
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def _float_max(builder: ir.IRBuilder, x: ir.Value, y: ir.Value) -> ir.Value:
    """
    LLVM IR floating point "max": selects `x` if `x > y` (ordered comparison), otherwise `y`.

    An ordered comparison is false when either operand is NaN, so in that case `y` is selected.
    """
    return builder.select(builder.fcmp_ordered('>', x, y), x, y)
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def _float_min(builder: ir.IRBuilder, x: ir.Value, y: ir.Value) -> ir.Value:
    """
    LLVM IR floating point "min": selects `x` if `x < y` (ordered comparison), otherwise `y`.

    An ordered comparison is false when either operand is NaN, so in that case `y` is selected.
    """
    return builder.select(builder.fcmp_ordered('<', x, y), x, y)
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
# IR operations for TypeInfo: (add, mul).
# These are unbound IRBuilder methods, so each has the IrBOp shape
# (builder, lhs, rhs) -> result.
_float_add: IrBOp = ir.IRBuilder.fadd
_float_mul: IrBOp = ir.IRBuilder.fmul
_int_add: IrBOp = ir.IRBuilder.add
_int_mul: IrBOp = ir.IRBuilder.mul

# Backward-compatible alias for the original, misspelt name of `_int_mul`.
_ind_mul: IrBOp = _int_mul
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
class DataType(Enum):
    """
    Predefined TypeInfo objects.

    Each member defines a mathematical ring: a machine data type together
    with the "add" and "mul" arithmetic operations over it.
    """

    # Ordinary floating point arithmetic.
    FLOAT_32 = TypeInfo(dtype=ct.c_float, llvm_type=ir.FloatType(), add=_float_add, mul=_float_mul)
    FLOAT_64 = TypeInfo(dtype=ct.c_double, llvm_type=ir.DoubleType(), add=_float_add, mul=_float_mul)

    # Ordinary integer arithmetic.
    INT_8 = TypeInfo(dtype=ct.c_int8, llvm_type=ir.IntType(8), add=_int_add, mul=_ind_mul)
    INT_16 = TypeInfo(dtype=ct.c_int16, llvm_type=ir.IntType(16), add=_int_add, mul=_ind_mul)
    INT_32 = TypeInfo(dtype=ct.c_int32, llvm_type=ir.IntType(32), add=_int_add, mul=_ind_mul)
    INT_64 = TypeInfo(dtype=ct.c_int64, llvm_type=ir.IntType(64), add=_int_add, mul=_ind_mul)

    # Booleans: "add" is `or` (BOOL) or `xor` (XBOOL); "mul" is `and`.
    BOOL = TypeInfo(dtype=ct.c_bool, llvm_type=IrBoolType, add=_bool_or, mul=_bool_and)
    XBOOL = TypeInfo(dtype=ct.c_bool, llvm_type=IrBoolType, add=_bool_xor, mul=_bool_and)

    # Doubles with "add" taken as max, and "mul" as min, mul or add respectively.
    MAX_MIN = TypeInfo(dtype=ct.c_double, llvm_type=ir.DoubleType(), add=_float_max, mul=_float_min)
    MAX_MUL = TypeInfo(dtype=ct.c_double, llvm_type=ir.DoubleType(), add=_float_max, mul=_float_mul)
    MAX_SUM = TypeInfo(dtype=ct.c_double, llvm_type=ir.DoubleType(), add=_float_max, mul=_float_add)
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
class IRFunction:
    """
    Holds the state used while building an LLVM IR program function.
    """

    def __init__(self, type_info: TypeInfo):
        """
        Start building an LLVM IR program function.

        On construction:
        1. LLVM is initialised (if not already).
        2. A module is created (field `module`).
        3. A function is added to the module (field `function`) with the signature
           (T* in, T* tmp, T* out) -> void, where T is `type_info.llvm_type`.
        4. A basic block named "entry" is added to the function, and an IRBuilder
           (field `builder`) is positioned at its end.

        Args:
            type_info: the ring (machine type plus add/mul operations) being compiled for.
        """
        _init_llvm()

        # Key IR types.
        self.type_info: TypeInfo = type_info
        self.ret_type: ir.Type = ir.VoidType()
        self.ptr_type: ir.Type = type_info.llvm_type.as_pointer()

        # Parameter order: in, tmp, out (all the same pointer type).
        param_types = (self.ptr_type,) * 3

        self.module = ir.Module()
        self.function = ir.Function(
            self.module,
            ir.FunctionType(self.ret_type, param_types),
            name=_LVM_FUNCTION_NAME,
        )

        entry_block = self.function.append_basic_block('entry')
        self.builder = ir.IRBuilder()
        self.builder.position_at_end(entry_block)

    def llvm_program(self) -> str:
        """
        Render the module as LLVM IR assembly text.

        Returns:
            an LLVM program string that can be passed to `compile_llvm_program`.
        """
        llvm_source: str = str(self.module)
        return llvm_source
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
@dataclass
class LLVMRawProgram(RawProgram):
    """
    A RawProgram whose function was JIT compiled from LLVM IR.

    When the LLVM IR source is retained, instances can be pickled. Unpickling
    recompiles that source rather than restoring native code.
    """

    llvm_program: Optional[str]  # LLVM IR source text, or None if not retained
    engine: llvm.ExecutionEngine  # must stay alive while `function` is in use
    opt: int  # optimisation level passed to the LLVM target machine

    def __getstate__(self):
        """
        Support for pickle: capture the fields needed to recompile the program.

        Raises:
            ValueError: if the LLVM program source was not retained.
        """
        if self.llvm_program is None:
            raise ValueError('need to have the LLVM program to pickle a Program object')
        pickled_fields = ('dtype', 'number_of_vars', 'number_of_tmps', 'number_of_results', 'llvm_program', 'opt')
        return {field_name: getattr(self, field_name) for field_name in pickled_fields}

    def __setstate__(self, state):
        """
        Support for pickle: restore the recorded fields, then recompile the LLVM program.
        """
        for field_name in ('dtype', 'number_of_vars', 'number_of_tmps', 'number_of_results', 'llvm_program', 'opt'):
            setattr(self, field_name, state[field_name])

        # Rebuild the execution engine and native function from the LLVM source.
        self.engine, self.function = compile_llvm_program(self.llvm_program, self.dtype, self.opt)
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
def compile_llvm_program(
        llvm_program: str,
        dtype: DType,
        opt: int,
) -> Tuple[llvm.ExecutionEngine, RawProgramFunction]:
    """
    JIT compile the given LLVM program to native code.

    Args:
        llvm_program: LLVM IR assembly text, e.g., from `IRFunction.llvm_program()`.
        dtype: the ctypes element type pointed to by each of the function's three arguments.
        opt: the optimisation level passed to the LLVM target machine.

    Returns:
        (engine, function) where
        engine: is an LLVM execution engine, which must remain
            in memory for the returned function to be valid.
        function: is the raw Python (ctypes) callable for the compiled function,
            taking three pointers to `dtype` and returning nothing.
    """
    _init_llvm()

    module_ref = llvm.parse_assembly(llvm_program)
    module_ref.verify()

    target_machine = llvm.Target.from_default_triple().create_target_machine(opt=opt)
    engine = llvm.create_mcjit_compiler(module_ref, target_machine)

    # Create native code and make it executable, then run any static constructors.
    engine.finalize_object()
    engine.run_static_constructors()

    # Wrap the native entry point as a ctypes function: void(T*, T*, T*).
    entry_address = engine.get_function_address(_LVM_FUNCTION_NAME)
    arg_type = ct.POINTER(dtype)
    prototype = ct.CFUNCTYPE(None, arg_type, arg_type, arg_type)
    function = prototype(entry_address)

    return engine, function
|
|
223
|
+
|
|
224
|
+
|
|
225
|
+
def _init_llvm() -> None:
    """
    Initialise LLVM, its native target and native asm printer, the first time this
    is called; later calls do nothing (guarded by the module flag).
    """
    global __LLVM_INITIALISED
    if __LLVM_INITIALISED:
        return
    llvm.initialize()
    llvm.initialize_native_target()
    llvm.initialize_native_asmprinter()
    __LLVM_INITIALISED = True
|
ck/example/__init__.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
1
|
+
"""
|
|
2
|
+
A package of standard probabilistic graphical models.
|
|
3
|
+
"""
|
|
4
|
+
from ck.pgm import PGM
|
|
5
|
+
from ck.example.alarm import Alarm
|
|
6
|
+
from ck.example.binary_clique import BinaryClique
|
|
7
|
+
from ck.example.bow_tie import BowTie
|
|
8
|
+
from ck.example.cancer import Cancer
|
|
9
|
+
from ck.example.asia import Asia
|
|
10
|
+
from ck.example.chain import Chain
|
|
11
|
+
from ck.example.child import Child
|
|
12
|
+
from ck.example.clique import Clique
|
|
13
|
+
from ck.example.cnf_pgm import CNF_PGM
|
|
14
|
+
from ck.example.diamond_square import DiamondSquare
|
|
15
|
+
from ck.example.earthquake import Earthquake
|
|
16
|
+
from ck.example.empty import Empty
|
|
17
|
+
from ck.example.hailfinder import Hailfinder
|
|
18
|
+
from ck.example.hepar2 import Hepar2
|
|
19
|
+
from ck.example.insurance import Insurance
|
|
20
|
+
from ck.example.loop import Loop
|
|
21
|
+
from ck.example.mildew import Mildew
|
|
22
|
+
from ck.example.munin import Munin
|
|
23
|
+
from ck.example.pathfinder import Pathfinder
|
|
24
|
+
from ck.example.rectangle import Rectangle
|
|
25
|
+
from ck.example.rain import Rain
|
|
26
|
+
from ck.example.run import Run
|
|
27
|
+
from ck.example.sachs import Sachs
|
|
28
|
+
from ck.example.sprinkler import Sprinkler
|
|
29
|
+
from ck.example.survey import Survey
|
|
30
|
+
from ck.example.star import Star
|
|
31
|
+
from ck.example.stress import Stress
|
|
32
|
+
from ck.example.student import Student
|
|
33
|
+
from ck.example.triangle_square import TriangleSquare
|
|
34
|
+
from ck.example.truss import Truss
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
# A dictionary with entries `name: class` for every example PGM class in this
# package, i.e. every public module-level subclass of PGM other than PGM itself.
#
# Example usage:
#     from ck.example import ALL_EXAMPLES
#
#     my_pgm: PGM = ALL_EXAMPLES['Alarm']()
#
ALL_EXAMPLES = {
    example_name: example_cls
    for example_name, example_cls in list(globals().items())
    if isinstance(example_cls, type)
    and issubclass(example_cls, PGM)
    and example_name != PGM.__name__
    and not example_name.startswith('_')
}
|