compiled-knowledge 4.0.0a20__cp312-cp312-win32.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of compiled-knowledge might be problematic. Click here for more details.
- ck/__init__.py +0 -0
- ck/circuit/__init__.py +17 -0
- ck/circuit/_circuit_cy.c +37523 -0
- ck/circuit/_circuit_cy.cp312-win32.pyd +0 -0
- ck/circuit/_circuit_cy.pxd +32 -0
- ck/circuit/_circuit_cy.pyx +768 -0
- ck/circuit/_circuit_py.py +836 -0
- ck/circuit/tmp_const.py +74 -0
- ck/circuit_compiler/__init__.py +2 -0
- ck/circuit_compiler/circuit_compiler.py +26 -0
- ck/circuit_compiler/cython_vm_compiler/__init__.py +1 -0
- ck/circuit_compiler/cython_vm_compiler/_compiler.c +19824 -0
- ck/circuit_compiler/cython_vm_compiler/_compiler.cp312-win32.pyd +0 -0
- ck/circuit_compiler/cython_vm_compiler/_compiler.pyx +380 -0
- ck/circuit_compiler/cython_vm_compiler/cython_vm_compiler.py +121 -0
- ck/circuit_compiler/interpret_compiler.py +223 -0
- ck/circuit_compiler/llvm_compiler.py +388 -0
- ck/circuit_compiler/llvm_vm_compiler.py +546 -0
- ck/circuit_compiler/named_circuit_compilers.py +57 -0
- ck/circuit_compiler/support/__init__.py +0 -0
- ck/circuit_compiler/support/circuit_analyser/__init__.py +13 -0
- ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.c +10618 -0
- ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.cp312-win32.pyd +0 -0
- ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.pyx +98 -0
- ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_py.py +93 -0
- ck/circuit_compiler/support/input_vars.py +148 -0
- ck/circuit_compiler/support/llvm_ir_function.py +234 -0
- ck/example/__init__.py +53 -0
- ck/example/alarm.py +366 -0
- ck/example/asia.py +28 -0
- ck/example/binary_clique.py +32 -0
- ck/example/bow_tie.py +33 -0
- ck/example/cancer.py +37 -0
- ck/example/chain.py +38 -0
- ck/example/child.py +199 -0
- ck/example/clique.py +33 -0
- ck/example/cnf_pgm.py +39 -0
- ck/example/diamond_square.py +68 -0
- ck/example/earthquake.py +36 -0
- ck/example/empty.py +10 -0
- ck/example/hailfinder.py +539 -0
- ck/example/hepar2.py +628 -0
- ck/example/insurance.py +504 -0
- ck/example/loop.py +40 -0
- ck/example/mildew.py +38161 -0
- ck/example/munin.py +22982 -0
- ck/example/pathfinder.py +53747 -0
- ck/example/rain.py +39 -0
- ck/example/rectangle.py +161 -0
- ck/example/run.py +30 -0
- ck/example/sachs.py +129 -0
- ck/example/sprinkler.py +30 -0
- ck/example/star.py +44 -0
- ck/example/stress.py +64 -0
- ck/example/student.py +43 -0
- ck/example/survey.py +46 -0
- ck/example/triangle_square.py +54 -0
- ck/example/truss.py +49 -0
- ck/in_out/__init__.py +3 -0
- ck/in_out/parse_ace_lmap.py +216 -0
- ck/in_out/parse_ace_nnf.py +322 -0
- ck/in_out/parse_net.py +480 -0
- ck/in_out/parser_utils.py +185 -0
- ck/in_out/pgm_pickle.py +42 -0
- ck/in_out/pgm_python.py +268 -0
- ck/in_out/render_bugs.py +111 -0
- ck/in_out/render_net.py +177 -0
- ck/in_out/render_pomegranate.py +184 -0
- ck/pgm.py +3475 -0
- ck/pgm_circuit/__init__.py +1 -0
- ck/pgm_circuit/marginals_program.py +352 -0
- ck/pgm_circuit/mpe_program.py +237 -0
- ck/pgm_circuit/pgm_circuit.py +79 -0
- ck/pgm_circuit/program_with_slotmap.py +236 -0
- ck/pgm_circuit/slot_map.py +35 -0
- ck/pgm_circuit/support/__init__.py +0 -0
- ck/pgm_circuit/support/compile_circuit.py +83 -0
- ck/pgm_circuit/target_marginals_program.py +103 -0
- ck/pgm_circuit/wmc_program.py +323 -0
- ck/pgm_compiler/__init__.py +2 -0
- ck/pgm_compiler/ace/__init__.py +1 -0
- ck/pgm_compiler/ace/ace.py +299 -0
- ck/pgm_compiler/factor_elimination.py +395 -0
- ck/pgm_compiler/named_pgm_compilers.py +63 -0
- ck/pgm_compiler/pgm_compiler.py +19 -0
- ck/pgm_compiler/recursive_conditioning.py +231 -0
- ck/pgm_compiler/support/__init__.py +0 -0
- ck/pgm_compiler/support/circuit_table/__init__.py +17 -0
- ck/pgm_compiler/support/circuit_table/_circuit_table_cy.c +16396 -0
- ck/pgm_compiler/support/circuit_table/_circuit_table_cy.cp312-win32.pyd +0 -0
- ck/pgm_compiler/support/circuit_table/_circuit_table_cy.pyx +332 -0
- ck/pgm_compiler/support/circuit_table/_circuit_table_py.py +304 -0
- ck/pgm_compiler/support/clusters.py +568 -0
- ck/pgm_compiler/support/factor_tables.py +406 -0
- ck/pgm_compiler/support/join_tree.py +332 -0
- ck/pgm_compiler/support/named_compiler_maker.py +43 -0
- ck/pgm_compiler/variable_elimination.py +91 -0
- ck/probability/__init__.py +0 -0
- ck/probability/empirical_probability_space.py +50 -0
- ck/probability/pgm_probability_space.py +32 -0
- ck/probability/probability_space.py +622 -0
- ck/program/__init__.py +3 -0
- ck/program/program.py +137 -0
- ck/program/program_buffer.py +180 -0
- ck/program/raw_program.py +67 -0
- ck/sampling/__init__.py +0 -0
- ck/sampling/forward_sampler.py +211 -0
- ck/sampling/marginals_direct_sampler.py +113 -0
- ck/sampling/sampler.py +62 -0
- ck/sampling/sampler_support.py +232 -0
- ck/sampling/uniform_sampler.py +72 -0
- ck/sampling/wmc_direct_sampler.py +171 -0
- ck/sampling/wmc_gibbs_sampler.py +153 -0
- ck/sampling/wmc_metropolis_sampler.py +165 -0
- ck/sampling/wmc_rejection_sampler.py +115 -0
- ck/utils/__init__.py +0 -0
- ck/utils/iter_extras.py +163 -0
- ck/utils/local_config.py +270 -0
- ck/utils/map_list.py +128 -0
- ck/utils/map_set.py +128 -0
- ck/utils/np_extras.py +51 -0
- ck/utils/random_extras.py +64 -0
- ck/utils/tmp_dir.py +94 -0
- ck_demos/__init__.py +0 -0
- ck_demos/ace/__init__.py +0 -0
- ck_demos/ace/copy_ace_to_ck.py +15 -0
- ck_demos/ace/demo_ace.py +49 -0
- ck_demos/all_demos.py +88 -0
- ck_demos/circuit/__init__.py +0 -0
- ck_demos/circuit/demo_circuit_dump.py +22 -0
- ck_demos/circuit/demo_derivatives.py +43 -0
- ck_demos/circuit_compiler/__init__.py +0 -0
- ck_demos/circuit_compiler/compare_circuit_compilers.py +32 -0
- ck_demos/circuit_compiler/show_llvm_program.py +26 -0
- ck_demos/pgm/__init__.py +0 -0
- ck_demos/pgm/demo_pgm_dump.py +18 -0
- ck_demos/pgm/demo_pgm_dump_stress.py +18 -0
- ck_demos/pgm/demo_pgm_string_rendering.py +15 -0
- ck_demos/pgm/show_examples.py +25 -0
- ck_demos/pgm_compiler/__init__.py +0 -0
- ck_demos/pgm_compiler/compare_pgm_compilers.py +63 -0
- ck_demos/pgm_compiler/demo_compiler_dump.py +60 -0
- ck_demos/pgm_compiler/demo_factor_elimination.py +47 -0
- ck_demos/pgm_compiler/demo_join_tree.py +25 -0
- ck_demos/pgm_compiler/demo_marginals_program.py +53 -0
- ck_demos/pgm_compiler/demo_mpe_program.py +55 -0
- ck_demos/pgm_compiler/demo_pgm_compiler.py +38 -0
- ck_demos/pgm_compiler/demo_recursive_conditioning.py +33 -0
- ck_demos/pgm_compiler/demo_variable_elimination.py +33 -0
- ck_demos/pgm_compiler/demo_wmc_program.py +29 -0
- ck_demos/pgm_compiler/time_fe_compiler.py +93 -0
- ck_demos/pgm_inference/__init__.py +0 -0
- ck_demos/pgm_inference/demo_inferencing_basic.py +188 -0
- ck_demos/pgm_inference/demo_inferencing_mpe_cancer.py +45 -0
- ck_demos/pgm_inference/demo_inferencing_wmc_and_mpe_sprinkler.py +154 -0
- ck_demos/pgm_inference/demo_inferencing_wmc_student.py +110 -0
- ck_demos/programs/__init__.py +0 -0
- ck_demos/programs/demo_program_buffer.py +24 -0
- ck_demos/programs/demo_program_multi.py +24 -0
- ck_demos/programs/demo_program_none.py +19 -0
- ck_demos/programs/demo_program_single.py +23 -0
- ck_demos/programs/demo_raw_program_interpreted.py +21 -0
- ck_demos/programs/demo_raw_program_llvm.py +21 -0
- ck_demos/sampling/__init__.py +0 -0
- ck_demos/sampling/check_sampler.py +71 -0
- ck_demos/sampling/demo_marginal_direct_sampler.py +40 -0
- ck_demos/sampling/demo_uniform_sampler.py +38 -0
- ck_demos/sampling/demo_wmc_direct_sampler.py +40 -0
- ck_demos/utils/__init__.py +0 -0
- ck_demos/utils/compare.py +120 -0
- ck_demos/utils/convert_network.py +45 -0
- ck_demos/utils/sample_model.py +216 -0
- ck_demos/utils/stop_watch.py +384 -0
- compiled_knowledge-4.0.0a20.dist-info/METADATA +50 -0
- compiled_knowledge-4.0.0a20.dist-info/RECORD +178 -0
- compiled_knowledge-4.0.0a20.dist-info/WHEEL +5 -0
- compiled_knowledge-4.0.0a20.dist-info/licenses/LICENSE.txt +21 -0
- compiled_knowledge-4.0.0a20.dist-info/top_level.txt +2 -0
ck/example/truss.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
import random as _random
|
|
2
|
+
from ck.pgm import PGM
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class Truss(PGM):
    r"""
    The 'Truss' factor graph, built as a PGM.

    There are five random variables (a, b, c, d, e), linked by pairwise
    factors in this pattern:

          b ---- d
         /|     /|
        a |   /  |
         \| /    |
          c ---- e

    i.e., factors over (a, b), (a, c), (b, c), (b, d), (c, d), (c, e), (d, e).

    When `include_unaries` is true, one unary factor per random variable is
    also added.

    Every factor is made dense and given values from a single stream,
    `random.Random(random_seed).random`. The PGM name records
    `states_per_var` and `include_unaries`, but not the seed.
    """

    def __init__(
            self,
            states_per_var=2,
            include_unaries=True,
            random_seed=123456
    ):
        name_args = ",".join(str(arg) for arg in (states_per_var, include_unaries))
        super().__init__(f'{self.__class__.__name__}({name_args})')

        next_value = _random.Random(random_seed).random

        # Create the random variables in the fixed order a, b, c, d, e.
        a, b, c, d, e = [self.new_rv(rv_name, states_per_var) for rv_name in 'abcde']

        # The factor order matters: each factor consumes values from the shared stream.
        truss_edges = ((a, b), (a, c), (b, c), (b, d), (c, d), (c, e), (d, e))
        for left, right in truss_edges:
            self.new_factor(left, right).set_dense().set_stream(next_value)

        if include_unaries:
            for rv in (a, b, c, d, e):
                self.new_factor(rv).set_dense().set_stream(next_value)
|
ck/in_out/parse_ace_lmap.py
ADDED
|
@@ -0,0 +1,216 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Functionality for parsing literal mapping files (lmap), as produced by the software ACE.
|
|
3
|
+
"""
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
6
|
+
from abc import ABC, abstractmethod
|
|
7
|
+
from dataclasses import dataclass
|
|
8
|
+
from typing import Optional, Dict, Sequence
|
|
9
|
+
|
|
10
|
+
from ck.in_out.parser_utils import ParseError, ParserInput
|
|
11
|
+
from ck.pgm import Indicator
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def read_lmap(
        input_stream,
        *,
        check_counts: bool = False,
        node_names: Sequence[str] = (),
) -> LiteralMap:
    """
    Parse the input literal-map file or string.

    An "lmap" file is produced by the software ACE to help interpret an "nnf" or "ac" file.
    See module `parse_ace_nnf` for more details.

    The returned LiteralMap provides a mapping from literal codes to indicators
    and a mapping from literal codes to parameter values.

    No PGM is involved. Each random variable declared in the input is assigned
    an index (used in its indicators):
    - A name listed in `node_names` gets its position in that sequence.
    - Any other name gets a new index, allocated in the order the names are
      encountered in the input.

    Args:
        input_stream: an input that can be passed to `ParserInput`.
        check_counts: if true then, after parsing, two checks are made:
            - the number of random variables must equal the declared variable
              count (when declared);
            - the number of literal-map parameter entries must equal twice the
              declared literal count (when declared).
        node_names: optional ordering of random variable names, fixing the
            random variable index used by the indicators in the returned literal map.

    Returns:
        a LiteralMap object

    Raises:
        ParseError: presumably raised (via `ParserInput.raise_error`) on malformed
            input or failed count checks — TODO confirm.
    """
    parser = _LiteralMapParser(check_counts, node_names)
    parser.parse(input_stream)
    return parser.lmap
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
@dataclass
class LiteralMapRV:
    """
    A random variable as declared by a literal-map ("lmap") input.

    Fields:
        name: the random variable name, as written in the input.
        rv_idx: the index assigned to this random variable.
        number_of_states: the number of states declared for this random variable.
    """
    name: str
    rv_idx: int
    number_of_states: int
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
@dataclass
class LiteralMap:
    """
    A parsed literal map. It maps literal codes to indicators and literal
    codes to parameter values.

    Fields:
        rvs: maps a random variable name to its `LiteralMapRV`,
            where `rvs[name].name == name`.
        indicators: maps the literal code of an indicator literal to its `Indicator`.
        params: maps a literal code (of an indicator or parameter literal) to its weight.
    """
    rvs: Dict[str, LiteralMapRV]
    indicators: Dict[int, Indicator]
    params: Dict[int, float]
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
class Parser(ABC):
|
|
70
|
+
|
|
71
|
+
def parse(self, input_stream):
|
|
72
|
+
input_stream = ParserInput(input_stream)
|
|
73
|
+
raise_f = lambda msg: input_stream.raise_error(msg)
|
|
74
|
+
try:
|
|
75
|
+
line = input_stream.readline()
|
|
76
|
+
while line:
|
|
77
|
+
line = line.strip()
|
|
78
|
+
if len(line) > 0:
|
|
79
|
+
if line[0] == 'c' and (len(line) == 1 or line[1] != 'c'):
|
|
80
|
+
self.comment(raise_f, line)
|
|
81
|
+
else:
|
|
82
|
+
line = line.split('$')
|
|
83
|
+
if line[0] != 'cc':
|
|
84
|
+
input_stream.raise_error(f'unexpected line start: {line[0]}')
|
|
85
|
+
code = line[1]
|
|
86
|
+
if code == 'N':
|
|
87
|
+
self.number_of_literals(raise_f, int(line[2]))
|
|
88
|
+
elif code == 'v':
|
|
89
|
+
self.number_of_rvs(raise_f, int(line[2]))
|
|
90
|
+
elif code == 'V':
|
|
91
|
+
self.rv(raise_f, line[2], int(line[3]))
|
|
92
|
+
elif code == 't':
|
|
93
|
+
self.number_of_tables(raise_f, int(line[2]))
|
|
94
|
+
elif code == 'T':
|
|
95
|
+
self.table(raise_f, line[2], int(line[3]))
|
|
96
|
+
elif code == 'I':
|
|
97
|
+
self.indicator(raise_f, int(line[2]), float(line[3]), line[4], line[5], int(line[6]))
|
|
98
|
+
elif code == 'C':
|
|
99
|
+
self.parameter(raise_f, int(line[2]), float(line[3]), line[4], line[5])
|
|
100
|
+
elif code in ['K', 'S']:
|
|
101
|
+
# ignore
|
|
102
|
+
pass
|
|
103
|
+
else:
|
|
104
|
+
input_stream.raise_error(f'unexpected line code: cc${code}')
|
|
105
|
+
|
|
106
|
+
line = input_stream.readline()
|
|
107
|
+
self.done(raise_f)
|
|
108
|
+
except ParseError as e:
|
|
109
|
+
raise e
|
|
110
|
+
except Exception as e:
|
|
111
|
+
input_stream.raise_error(str(e))
|
|
112
|
+
|
|
113
|
+
@abstractmethod
|
|
114
|
+
def comment(self, raise_f, message: str) -> None:
|
|
115
|
+
pass
|
|
116
|
+
|
|
117
|
+
@abstractmethod
|
|
118
|
+
def number_of_literals(self, raise_f, num_literals: int) -> None:
|
|
119
|
+
pass
|
|
120
|
+
|
|
121
|
+
@abstractmethod
|
|
122
|
+
def number_of_rvs(self, raise_f, num_rvs: int) -> None:
|
|
123
|
+
pass
|
|
124
|
+
|
|
125
|
+
@abstractmethod
|
|
126
|
+
def rv(self, raise_f, rv_name, num_states: int) -> None:
|
|
127
|
+
pass
|
|
128
|
+
|
|
129
|
+
@abstractmethod
|
|
130
|
+
def number_of_tables(self, raise_f, num_tables: int) -> None:
|
|
131
|
+
pass
|
|
132
|
+
|
|
133
|
+
@abstractmethod
|
|
134
|
+
def table(self, raise_f, child_rv_name: str, num_states: int) -> None:
|
|
135
|
+
pass
|
|
136
|
+
|
|
137
|
+
@abstractmethod
|
|
138
|
+
def indicator(self, raise_f, literal_code: int, weight: float, arithmetic_op: str, rv_name: str,
|
|
139
|
+
state: int) -> None:
|
|
140
|
+
pass
|
|
141
|
+
|
|
142
|
+
@abstractmethod
|
|
143
|
+
def parameter(self, raise_f, literal_code: int, weight: float, arithmetic_op: str, extra: str) -> None:
|
|
144
|
+
pass
|
|
145
|
+
|
|
146
|
+
@abstractmethod
|
|
147
|
+
def done(self, raise_f) -> None:
|
|
148
|
+
pass
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
class _LiteralMapParser(Parser):
    """
    Literal-map parser callbacks that accumulate a `LiteralMap` in `self.lmap`.
    """

    def __init__(self, check_counts: bool, node_names: Sequence[str]):
        # Map from random variable name to random variable index.
        self.node_names: Dict[str, int] = {
            name: i
            for i, name in enumerate(node_names)
        }
        # Next unused random variable index. This must not be `len(self.node_names)`:
        # if `node_names` contains duplicates the dict is shorter than the sequence,
        # and a new random variable would reuse an index already assigned.
        self._next_rv_idx: int = max(self.node_names.values(), default=-1) + 1
        self.check_counts = check_counts
        self.lmap = LiteralMap({}, {}, {})
        self._number_of_literals = None
        self._number_of_rvs = None

    def number_of_literals(self, raise_f, num):
        self._number_of_literals = num

    def number_of_rvs(self, raise_f, num):
        self._number_of_rvs = num

    def rv(self, raise_f, rv_name, num_states):
        """
        Record a random variable. Its index comes from `node_names` when the
        name is listed there; otherwise the next unused index is allocated.
        """
        if rv_name in self.lmap.rvs:
            raise_f(f'duplicated random variable: {rv_name}')

        idx: Optional[int] = self.node_names.get(rv_name)
        if idx is None:
            idx = self._next_rv_idx
            self._next_rv_idx += 1
            self.node_names[rv_name] = idx

        self.lmap.rvs[rv_name] = LiteralMapRV(rv_name, idx, num_states)

    def indicator(self, raise_f, literal_code, weight, arithmetic_op, rv_name, state):
        """
        Record an indicator literal. It gets both an `Indicator` entry and a
        parameter (weight) entry. `arithmetic_op` is not used.
        """
        rv: Optional[LiteralMapRV] = self.lmap.rvs.get(rv_name)
        if rv is None:
            raise_f(f'unknown random variable: {rv_name}')
        if literal_code in self.lmap.indicators or literal_code in self.lmap.params:
            raise_f(f'duplicated indicator literal: {literal_code}')
        self.lmap.indicators[literal_code] = Indicator(rv.rv_idx, state)
        self.lmap.params[literal_code] = weight

    def parameter(self, raise_f, literal_code, weight, arithmetic_op, extra):
        """
        Record a parameter literal's weight. `arithmetic_op` and `extra` are not used.
        """
        if literal_code in self.lmap.indicators or literal_code in self.lmap.params:
            raise_f(f'duplicated parameter literal: {literal_code}')
        self.lmap.params[literal_code] = weight

    def comment(self, raise_f, message: str) -> None:
        pass

    def number_of_tables(self, raise_f, num_tables: int) -> None:
        pass

    def table(self, raise_f, child_rv_name: str, num_states: int) -> None:
        pass

    def done(self, raise_f) -> None:
        """
        If `check_counts` is set, compare the parsed counts with any counts
        declared in the input.
        """
        if self.check_counts:
            # Perform consistency checks
            if self._number_of_rvs is not None:
                got_rvs: int = len(self.lmap.rvs)
                if got_rvs != self._number_of_rvs:
                    raise_f(f'unexpected number of random variables: expected {self._number_of_rvs}, got {got_rvs}')

            if self._number_of_literals is not None:
                got_params: int = len(self.lmap.params)
                # Every indicator and parameter literal contributes one params entry.
                expect_params: int = self._number_of_literals * 2
                if got_params != expect_params:
                    raise_f(f'unexpected number of parameters: expected {expect_params}, got: {got_params}')
|
|
ck/in_out/parse_ace_nnf.py
ADDED
|
@@ -0,0 +1,322 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from typing import Tuple, Optional, Dict, List, Sequence
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
|
|
6
|
+
from ck.circuit import Circuit, CircuitNode, VarNode, ConstValue, ConstNode
|
|
7
|
+
from ck.in_out.parse_ace_lmap import LiteralMap
|
|
8
|
+
from ck.in_out.parser_utils import ParseError, ParserInput
|
|
9
|
+
from ck.pgm import Indicator
|
|
10
|
+
from ck.pgm_circuit.slot_map import SlotKey, SlotMap
|
|
11
|
+
from ck.utils.np_extras import NDArrayFloat64
|
|
12
|
+
|
|
13
|
+
# Line codes recognised for node lines of an "nnf" / "ac" input.
VAR_NODE = {'L', 'l'}  # literal (leaf) node lines
ADD_NODE = {'+', 'O'}  # addition / OR node lines
MUL_NODE = {'*', 'A'}  # multiplication / AND node lines
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def read_nnf_with_literal_map(
        input_stream,
        literal_map: LiteralMap,
        *,
        indicators: Sequence[Indicator] = (),
        const_parameters: bool = True,
        optimise_ops: bool = True,
        check_header: bool = False,
) -> Tuple[CircuitNode, SlotMap, NDArrayFloat64]:
    """
    Parse an input, as per `read_nnf`, using the given literal map to
    create a slot map with indicator entries.

    See: `read_nnf` and `read_lmap` for more information.

    Circuit variables are allocated so that all indicator variables come first
    (slots `0 .. num_indicators - 1`), followed by any parameter variables.

    Args:
        input_stream: to parse, as per `ParserInput` argument.
        literal_map: mapping from literal code to indicators and parameter values.
        indicators: any indicators to pre allocate to circuit variables, in order.
            A repeated indicator keeps the slot of its first occurrence.
        const_parameters: if true, the potential function parameters will be circuit
            constants, otherwise they will be circuit variables.
        optimise_ops: if true then circuit optimised operations will be used.
        check_header: if true, an exception is raised if the number of nodes or arcs is not as expected.

    Returns:
        (circuit_top, slot_map, params)
        circuit_top: is the resulting top node from parsing the input.
        slot_map: is a map from indicator to a circuit var index (int).
        params: is a numpy array of parameter values, co-indexed with `circuit.vars[num_indicators:]`
    """
    circuit = Circuit()

    # Set the `const_literals` parameter for `read_nnf`. When parameters are constants,
    # every literal-map parameter that is not an indicator literal becomes a constant.
    const_literals: Dict[int, ConstValue]
    if const_parameters:
        indicator_literal_codes = literal_map.indicators.keys()
        const_literals = {
            literal_code: value
            for literal_code, value in literal_map.params.items()
            if literal_code not in indicator_literal_codes
        }
    else:
        const_literals = {}

    # Make a slot map to map from an indicator to a circuit variable index.
    # Preload `var_literals` to map literal codes to circuit vars.
    # We allocate the circuit variables here to ensure that indicators
    # come before parameters in the circuit variables.
    slot_map: Dict[SlotKey, int] = {}
    for indicator in indicators:
        # `setdefault` keeps the first slot of a repeated indicator. The slots then
        # stay dense (0, 1, 2, ...) and match the circuit vars allocated next.
        slot_map.setdefault(indicator, len(slot_map))
    circuit.new_vars(len(slot_map))
    var_literals: Dict[int, int] = {}
    for literal_code, indicator in literal_map.indicators.items():
        slot = slot_map.get(indicator)
        if slot is None:
            slot = circuit.new_var().idx
            slot_map[indicator] = slot
        var_literals[literal_code] = slot
    num_indicators: int = len(slot_map)

    # Parse the nnf file. `read_nnf` adds any new literal codes to `var_literals`
    # in place, so its second return value (that same dict) is not needed.
    top_node, _ = read_nnf(
        input_stream,
        var_literals=var_literals,
        const_literals=const_literals,
        circuit=circuit,
        check_header=check_header,
        optimise_ops=optimise_ops,
    )

    # Get the parameter values.
    # Any new circuit vars added to the circuit are parameters.
    # Parameter IDs are not added to the slot map as we don't know them.
    num_parameters: int = top_node.circuit.number_of_vars - num_indicators
    assert num_parameters == 0 or not const_parameters, 'const_parameters -> num_parameters == 0'
    params: NDArrayFloat64 = np.zeros(num_parameters, dtype=np.float64)
    for literal_code, value in literal_map.params.items():
        literal_slot: Optional[int] = var_literals.get(literal_code)
        # Indicator literals map to slots below `num_indicators`; skip them.
        if literal_slot is not None and literal_slot >= num_indicators:
            params[literal_slot - num_indicators] = value

    return top_node, slot_map, params
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
def read_nnf(
        input_stream,
        *,
        var_literals: Optional[Dict[int, int]] = None,
        const_literals: Optional[Dict[int, ConstValue]] = None,
        circuit: Optional[Circuit] = None,
        check_header: bool = False,
        optimise_ops: bool = True,
) -> Tuple[CircuitNode, Dict[int, int]]:
    """
    Parse an "nnf" or "ac" circuit description (file or string) into a circuit.

    The input holds propositional logical sentences in negation normal form (NNF),
    which covers both the ".ac" and ".nnf" files produced by the software ACE.

    The returned top node is the last node parsed. If no node was parsed, it is
    the circuit's constant zero node.

    Literal codes are resolved in this order:
    - a code in `const_literals` becomes a constant node with that value;
    - a code in `var_literals` becomes the existing circuit variable at that index;
    - any other code becomes a new circuit variable, and is added to `var_literals`.
    A literal code should not appear in both dictionaries.

    External software may remove arcs from an NNF file without updating its header.
    Such a file is not conformant, but it still parses when the header is ignored,
    which is the default. Set `check_header` to make a mismatch between the header
    and the file contents raise an exception.

    Args:
        input_stream: to parse, as per `ParserInput` argument.
        var_literals: an optional mapping from literal code to existing circuit variable index.
            If given, it is updated in place and returned.
        const_literals: an optional mapping from literal code to constant value.
        circuit: an optional empty circuit to reuse.
        check_header: if true, an exception is raised if the number of nodes or arcs is not as expected.
        optimise_ops: if true then circuit optimised operations will be used.

    Returns:
        (circuit_top, var_literals)
        circuit_top: is the resulting top node from parsing the input.
        var_literals: is a mapping from literal code (int) to a circuit variable index (int).
    """
    target: Circuit = Circuit() if circuit is None else circuit
    literal_vars: Dict[int, int] = {} if var_literals is None else var_literals
    literal_consts: Dict[int, ConstValue] = {} if const_literals is None else const_literals

    parser = CircuitParser(target, check_header, literal_vars, literal_consts, optimise_ops)
    parser.parse(input_stream)

    if parser.nodes:
        return parser.nodes[-1], literal_vars
    return target.zero, literal_vars
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
class Parser(ABC):
    """
    Abstract, line-oriented parser for "nnf" / "ac" circuit input.

    The input has two sections:
    - a header section: comment lines starting with 'c', then a line
      `nnf <num-nodes> <num-edges> <num-?>`;
    - node lines, one node per line.

    Parsing stops at end of input or at a line starting with '%'. Each
    callback receives `raise_f`, which reports a parse error at the current
    input position by calling `ParserInput.raise_error`.
    """

    def parse(self, input_stream):
        """
        Parse the input, calling the callbacks for each line and finally `done`.

        Any non-ParseError exception raised while parsing is reported through
        `ParserInput.raise_error` with the exception's message.

        Args:
            input_stream: anything accepted by `ParserInput`.
        """
        input_stream = ParserInput(input_stream)
        raise_f = lambda msg: input_stream.raise_error(msg)
        try:
            # state 0: reading comments / header; state 1: reading node lines;
            # state 999: finished (a '%' line was seen).
            state = 0
            line = input_stream.readline()
            while line and state < 999:
                line = line.strip()
                if len(line) > 0:
                    if state == 0:
                        if line[0] == 'c':
                            self.comment(raise_f, line)
                        else:
                            line = line.split()
                            if line[0] == 'nnf':
                                if len(line) != 4:
                                    raise_f('expected: nnf <num-nodes> <num-edges> <num-?>')
                                self.header(raise_f, int(line[1]), int(line[2]), int(line[3]))
                                state = 1
                            else:
                                raise_f('expected: "nnf"')
                    elif state == 1:
                        if line[0] == '%':
                            state = 999
                        else:
                            line = line.split()
                            code = line[0]
                            if code in VAR_NODE:
                                if len(line) != 2:
                                    raise_f(f'expected: {code} <literal-code>')
                                self.literal(raise_f, int(line[1]))
                            else:
                                is_add = code in ADD_NODE
                                is_mul = code in MUL_NODE
                                if not (is_add or is_mul):
                                    raise_f(f'unexpected line starting with: {code}')
                                # An 'O' line carries one extra field before the argument count.
                                num_args_idx = 2 if code == 'O' else 1
                                # The argument count field must be present.
                                # (Previously only `len(line) < 2` was checked, so a
                                # short 'O' line failed with an opaque IndexError message.)
                                if len(line) <= num_args_idx:
                                    if num_args_idx == 2:
                                        raise_f(f'expected: {code} <?> <num_args> <arguments>...')
                                    else:
                                        raise_f(f'expected: {code} <num_args> <arguments>...')
                                num_args = int(line[num_args_idx])
                                args = [int(arg) for arg in line[num_args_idx + 1:]]
                                if len(args) != num_args:
                                    raise_f(f'unexpected number of args for: {code}')
                                if is_add:
                                    self.add_node(raise_f, args)
                                else:
                                    self.mul_node(raise_f, args)
                    else:
                        raise_f(f'unexpected parser state: {state}')
                line = input_stream.readline()
            self.done(raise_f)
        except ParseError:
            raise
        except Exception as e:
            input_stream.raise_error(str(e))

    @abstractmethod
    def comment(self, raise_f, message: str) -> None:
        """Called for a header comment line (the whole stripped line is passed)."""
        ...

    @abstractmethod
    def header(self, raise_f, num_nodes: int, num_edges: int, num_: int):
        """Called for the `nnf <num-nodes> <num-edges> <num-?>` header line."""
        ...

    @abstractmethod
    def literal(self, raise_f, literal_code: int) -> None:
        """Called for a literal node line (a code in `VAR_NODE`)."""
        ...

    @abstractmethod
    def add_node(self, raise_f, args: List[int]) -> None:
        """Called for an addition node line (a code in `ADD_NODE`), with its argument node indexes."""
        ...

    @abstractmethod
    def mul_node(self, raise_f, args: List[int]) -> None:
        """Called for a multiplication node line (a code in `MUL_NODE`), with its argument node indexes."""
        ...

    @abstractmethod
    def done(self, raise_f) -> None:
        """Called once, after parsing stops."""
        ...
|
|
247
|
+
|
|
248
|
+
|
|
249
|
+
class CircuitParser(Parser):
    """
    Parser callbacks that build circuit nodes, one node per parsed node line.

    Nodes are appended to `self.nodes` in input order. The argument indexes of
    add/mul lines therefore index into that list.
    """

    def __init__(
            self,
            circuit: Circuit,
            check_header: bool,
            var_literals: Dict[int, int],
            const_literals: Dict[int, ConstValue],
            optimise_ops: bool,
    ):
        self.check_header: bool = check_header
        self.var_literals: Dict[int, int] = var_literals
        self.const_literals: Dict[int, ConstValue] = const_literals
        self.optimise_ops: bool = optimise_ops
        self.circuit: Circuit = circuit
        self.nodes: List[CircuitNode] = []
        self.num_nodes = None  # read from the file header for checking
        self.num_edges = None  # read from the file header for checking

    def comment(self, raise_f, message: str) -> None:
        pass

    def literal(self, raise_f, literal_code: int) -> None:
        """
        Makes either a VarNode or a ConstNode.

        A new circuit variable is created (and recorded in `var_literals`) when
        the literal code is neither a known constant nor a known variable.
        """
        const_value: Optional[ConstValue] = self.const_literals.get(literal_code)
        node: CircuitNode
        if const_value is not None:
            # Literal code maps to a constant value
            if literal_code in self.var_literals:
                # f-string so the offending code actually appears in the message.
                raise_f(f'literal code both constant and variable: {literal_code}')
            node = self.circuit.const(const_value)

        elif (var_idx := self.var_literals.get(literal_code)) is not None:
            # Literal code maps to an existing circuit variable
            node = self.circuit.vars[var_idx]

        else:
            # Literal code maps to a new circuit variable
            node = self.circuit.new_var()
            self.var_literals[literal_code] = node.idx

        self.nodes.append(node)

    def _arg_nodes(self, raise_f, args: List[int]) -> List[CircuitNode]:
        """
        Return the previously parsed nodes referenced by `args`.

        Each arg must index an already-parsed node. Without this check, a
        negative index would silently wrap around to a node counted from the end.
        """
        num_parsed = len(self.nodes)
        for arg in args:
            if not 0 <= arg < num_parsed:
                raise_f(f'node argument out of range: {arg} (nodes parsed so far: {num_parsed})')
        return [self.nodes[arg] for arg in args]

    def add_node(self, raise_f, args: List[int]) -> None:
        """
        Makes a AddNode (or other if optimised).
        """
        arg_nodes = self._arg_nodes(raise_f, args)
        if self.optimise_ops:
            self.nodes.append(self.circuit.optimised_add(arg_nodes))
        else:
            self.nodes.append(self.circuit.add(arg_nodes))

    def mul_node(self, raise_f, args: List[int]) -> None:
        """
        Makes a MulNode (or other if optimised).
        """
        arg_nodes = self._arg_nodes(raise_f, args)
        if self.optimise_ops:
            self.nodes.append(self.circuit.optimised_mul(arg_nodes))
        else:
            self.nodes.append(self.circuit.mul(arg_nodes))

    def header(self, raise_f, num_nodes: int, num_edges: int, num_: int) -> None:
        self.num_nodes = num_nodes
        self.num_edges = num_edges

    def done(self, raise_f) -> None:
        """
        If `check_header` is set, check the parsed node count and the circuit's
        arc count against the header.
        """
        if self.check_header:
            if len(self.nodes) != self.num_nodes:
                raise_f(f'unexpected number of nodes: {len(self.nodes)} expected: {self.num_nodes}')
            if self.circuit.number_of_arcs != self.num_edges:
                raise_f(f'unexpected number of arcs: {self.circuit.number_of_arcs} expected: {self.num_edges}')
|