compiled-knowledge 4.0.0a5__cp313-cp313-macosx_10_13_universal2.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of compiled-knowledge might be problematic. Click here for more details.
- ck/__init__.py +0 -0
- ck/circuit/__init__.py +13 -0
- ck/circuit/circuit.c +38749 -0
- ck/circuit/circuit.cpython-313-darwin.so +0 -0
- ck/circuit/circuit_py.py +807 -0
- ck/circuit/tmp_const.py +74 -0
- ck/circuit_compiler/__init__.py +2 -0
- ck/circuit_compiler/circuit_compiler.py +26 -0
- ck/circuit_compiler/cython_vm_compiler/__init__.py +1 -0
- ck/circuit_compiler/cython_vm_compiler/_compiler.c +17373 -0
- ck/circuit_compiler/cython_vm_compiler/_compiler.cpython-313-darwin.so +0 -0
- ck/circuit_compiler/cython_vm_compiler/cython_vm_compiler.py +96 -0
- ck/circuit_compiler/interpret_compiler.py +223 -0
- ck/circuit_compiler/llvm_compiler.py +388 -0
- ck/circuit_compiler/llvm_vm_compiler.py +546 -0
- ck/circuit_compiler/named_circuit_compilers.py +57 -0
- ck/circuit_compiler/support/__init__.py +0 -0
- ck/circuit_compiler/support/circuit_analyser.py +81 -0
- ck/circuit_compiler/support/input_vars.py +148 -0
- ck/circuit_compiler/support/llvm_ir_function.py +234 -0
- ck/example/__init__.py +53 -0
- ck/example/alarm.py +366 -0
- ck/example/asia.py +28 -0
- ck/example/binary_clique.py +32 -0
- ck/example/bow_tie.py +33 -0
- ck/example/cancer.py +37 -0
- ck/example/chain.py +38 -0
- ck/example/child.py +199 -0
- ck/example/clique.py +33 -0
- ck/example/cnf_pgm.py +39 -0
- ck/example/diamond_square.py +68 -0
- ck/example/earthquake.py +36 -0
- ck/example/empty.py +10 -0
- ck/example/hailfinder.py +539 -0
- ck/example/hepar2.py +628 -0
- ck/example/insurance.py +504 -0
- ck/example/loop.py +40 -0
- ck/example/mildew.py +38161 -0
- ck/example/munin.py +22982 -0
- ck/example/pathfinder.py +53674 -0
- ck/example/rain.py +39 -0
- ck/example/rectangle.py +161 -0
- ck/example/run.py +30 -0
- ck/example/sachs.py +129 -0
- ck/example/sprinkler.py +30 -0
- ck/example/star.py +44 -0
- ck/example/stress.py +64 -0
- ck/example/student.py +43 -0
- ck/example/survey.py +46 -0
- ck/example/triangle_square.py +54 -0
- ck/example/truss.py +49 -0
- ck/in_out/__init__.py +3 -0
- ck/in_out/parse_ace_lmap.py +216 -0
- ck/in_out/parse_ace_nnf.py +288 -0
- ck/in_out/parse_net.py +480 -0
- ck/in_out/parser_utils.py +185 -0
- ck/in_out/pgm_pickle.py +42 -0
- ck/in_out/pgm_python.py +268 -0
- ck/in_out/render_bugs.py +111 -0
- ck/in_out/render_net.py +177 -0
- ck/in_out/render_pomegranate.py +184 -0
- ck/pgm.py +3494 -0
- ck/pgm_circuit/__init__.py +1 -0
- ck/pgm_circuit/marginals_program.py +352 -0
- ck/pgm_circuit/mpe_program.py +237 -0
- ck/pgm_circuit/pgm_circuit.py +75 -0
- ck/pgm_circuit/program_with_slotmap.py +234 -0
- ck/pgm_circuit/slot_map.py +35 -0
- ck/pgm_circuit/support/__init__.py +0 -0
- ck/pgm_circuit/support/compile_circuit.py +83 -0
- ck/pgm_circuit/target_marginals_program.py +103 -0
- ck/pgm_circuit/wmc_program.py +323 -0
- ck/pgm_compiler/__init__.py +2 -0
- ck/pgm_compiler/ace/__init__.py +1 -0
- ck/pgm_compiler/ace/ace.py +252 -0
- ck/pgm_compiler/factor_elimination.py +383 -0
- ck/pgm_compiler/named_pgm_compilers.py +63 -0
- ck/pgm_compiler/pgm_compiler.py +19 -0
- ck/pgm_compiler/recursive_conditioning.py +226 -0
- ck/pgm_compiler/support/__init__.py +0 -0
- ck/pgm_compiler/support/circuit_table/__init__.py +9 -0
- ck/pgm_compiler/support/circuit_table/circuit_table.c +16042 -0
- ck/pgm_compiler/support/circuit_table/circuit_table.cpython-313-darwin.so +0 -0
- ck/pgm_compiler/support/circuit_table/circuit_table_py.py +269 -0
- ck/pgm_compiler/support/clusters.py +556 -0
- ck/pgm_compiler/support/factor_tables.py +398 -0
- ck/pgm_compiler/support/join_tree.py +275 -0
- ck/pgm_compiler/support/named_compiler_maker.py +33 -0
- ck/pgm_compiler/variable_elimination.py +89 -0
- ck/probability/__init__.py +0 -0
- ck/probability/empirical_probability_space.py +47 -0
- ck/probability/probability_space.py +568 -0
- ck/program/__init__.py +3 -0
- ck/program/program.py +129 -0
- ck/program/program_buffer.py +180 -0
- ck/program/raw_program.py +61 -0
- ck/sampling/__init__.py +0 -0
- ck/sampling/forward_sampler.py +211 -0
- ck/sampling/marginals_direct_sampler.py +113 -0
- ck/sampling/sampler.py +62 -0
- ck/sampling/sampler_support.py +232 -0
- ck/sampling/uniform_sampler.py +66 -0
- ck/sampling/wmc_direct_sampler.py +169 -0
- ck/sampling/wmc_gibbs_sampler.py +147 -0
- ck/sampling/wmc_metropolis_sampler.py +159 -0
- ck/sampling/wmc_rejection_sampler.py +113 -0
- ck/utils/__init__.py +0 -0
- ck/utils/iter_extras.py +153 -0
- ck/utils/map_list.py +128 -0
- ck/utils/map_set.py +128 -0
- ck/utils/np_extras.py +51 -0
- ck/utils/random_extras.py +64 -0
- ck/utils/tmp_dir.py +94 -0
- ck_demos/__init__.py +0 -0
- ck_demos/ace/__init__.py +0 -0
- ck_demos/ace/copy_ace_to_ck.py +15 -0
- ck_demos/ace/demo_ace.py +44 -0
- ck_demos/all_demos.py +88 -0
- ck_demos/circuit/__init__.py +0 -0
- ck_demos/circuit/demo_circuit_dump.py +22 -0
- ck_demos/circuit/demo_derivatives.py +43 -0
- ck_demos/circuit_compiler/__init__.py +0 -0
- ck_demos/circuit_compiler/compare_circuit_compilers.py +32 -0
- ck_demos/circuit_compiler/show_llvm_program.py +26 -0
- ck_demos/pgm/__init__.py +0 -0
- ck_demos/pgm/demo_pgm_dump.py +18 -0
- ck_demos/pgm/demo_pgm_dump_stress.py +18 -0
- ck_demos/pgm/demo_pgm_string_rendering.py +15 -0
- ck_demos/pgm/show_examples.py +25 -0
- ck_demos/pgm_compiler/__init__.py +0 -0
- ck_demos/pgm_compiler/compare_pgm_compilers.py +50 -0
- ck_demos/pgm_compiler/demo_compiler_dump.py +50 -0
- ck_demos/pgm_compiler/demo_factor_elimination.py +47 -0
- ck_demos/pgm_compiler/demo_join_tree.py +25 -0
- ck_demos/pgm_compiler/demo_marginals_program.py +53 -0
- ck_demos/pgm_compiler/demo_mpe_program.py +55 -0
- ck_demos/pgm_compiler/demo_pgm_compiler.py +38 -0
- ck_demos/pgm_compiler/demo_recursive_conditioning.py +33 -0
- ck_demos/pgm_compiler/demo_variable_elimination.py +33 -0
- ck_demos/pgm_compiler/demo_wmc_program.py +29 -0
- ck_demos/pgm_inference/__init__.py +0 -0
- ck_demos/pgm_inference/demo_inferencing_basic.py +188 -0
- ck_demos/pgm_inference/demo_inferencing_mpe_cancer.py +45 -0
- ck_demos/pgm_inference/demo_inferencing_wmc_and_mpe_sprinkler.py +154 -0
- ck_demos/pgm_inference/demo_inferencing_wmc_student.py +110 -0
- ck_demos/programs/__init__.py +0 -0
- ck_demos/programs/demo_program_buffer.py +24 -0
- ck_demos/programs/demo_program_multi.py +24 -0
- ck_demos/programs/demo_program_none.py +19 -0
- ck_demos/programs/demo_program_single.py +23 -0
- ck_demos/programs/demo_raw_program_interpreted.py +21 -0
- ck_demos/programs/demo_raw_program_llvm.py +21 -0
- ck_demos/sampling/__init__.py +0 -0
- ck_demos/sampling/check_sampler.py +71 -0
- ck_demos/sampling/demo_marginal_direct_sampler.py +40 -0
- ck_demos/sampling/demo_uniform_sampler.py +38 -0
- ck_demos/sampling/demo_wmc_direct_sampler.py +40 -0
- ck_demos/utils/__init__.py +0 -0
- ck_demos/utils/compare.py +88 -0
- ck_demos/utils/convert_network.py +45 -0
- ck_demos/utils/sample_model.py +216 -0
- ck_demos/utils/stop_watch.py +384 -0
- compiled_knowledge-4.0.0a5.dist-info/METADATA +50 -0
- compiled_knowledge-4.0.0a5.dist-info/RECORD +167 -0
- compiled_knowledge-4.0.0a5.dist-info/WHEEL +5 -0
- compiled_knowledge-4.0.0a5.dist-info/licenses/LICENSE.txt +21 -0
- compiled_knowledge-4.0.0a5.dist-info/top_level.txt +2 -0
ck/example/truss.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
import random as _random
|
|
2
|
+
from ck.pgm import PGM
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class Truss(PGM):
    r"""
    The 'Truss' factor graph: a PGM over five random variables (a, b, c, d, e).

    Pairwise (binary) factors connect the random variables in the pattern:
          b ---- d
         / |   / |
        a  |  /  |
         \ | /   |
          c ---- e

    If `include_unaries` is true, one unary factor per random variable is also included.
    Every factor is made dense and is given the same seeded random stream (`set_stream`).
    """

    def __init__(
            self,
            states_per_var=2,
            include_unaries=True,
            random_seed=123456,
    ):
        """
        Args:
            states_per_var: number of states of each random variable.
            include_unaries: if true, add one unary factor per random variable.
            random_seed: seed of the `random.Random` stream passed to every factor.
                Note that the seed is not part of the PGM name.
        """
        name_parts = ",".join(str(part) for part in (states_per_var, include_unaries))
        super().__init__(f'{self.__class__.__name__}({name_parts})')

        next_value = _random.Random(random_seed).random

        # Random variables, created in the order a, b, c, d, e.
        a, b, c, d, e = (self.new_rv(rv_name, states_per_var) for rv_name in 'abcde')

        # Pairwise factors. The creation order is kept fixed because all factors
        # share one random stream.
        edges = ((a, b), (a, c), (b, c), (b, d), (c, d), (c, e), (d, e))
        for rv_x, rv_y in edges:
            self.new_factor(rv_x, rv_y).set_dense().set_stream(next_value)

        if include_unaries:
            for rv in (a, b, c, d, e):
                self.new_factor(rv).set_dense().set_stream(next_value)
|
ck/in_out/parse_ace_lmap.py
ADDED
|
@@ -0,0 +1,216 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Functionality for parsing literal mapping files (lmap), as produced by the software ACE.
|
|
3
|
+
"""
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
6
|
+
from abc import ABC, abstractmethod
|
|
7
|
+
from dataclasses import dataclass
|
|
8
|
+
from typing import Optional, Dict, Sequence
|
|
9
|
+
|
|
10
|
+
from ck.in_out.parser_utils import ParseError, ParserInput
|
|
11
|
+
from ck.pgm import Indicator
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def read_lmap(
        input_stream,
        *,
        check_counts: bool = False,
        node_names: Sequence[str] = (),
) -> LiteralMap:
    """
    Parse the input literal-map file or string.

    An "lmap" file is produced by the software ACE to help interpret an "nnf" or "ac" file.
    See module `parse_ace_nnf` for more details.

    The returned LiteralMap provides a mapping from literal codes to indicators
    and a mapping from literal codes to weights. The weights mapping covers
    parameter literals and also indicator literals.

    No PGM is used or created. Each random variable declared in the input is
    given an index:
    - if its name is in `node_names`, the index is its position there;
    - otherwise it gets the next index after those of `node_names`.
    Indicators in the returned literal map refer to random variables by these indices.

    Args:
        input_stream: an input that can be passed to `ParserInput`.
        check_counts: if true, the random variable and literal counts declared
            in the input (if present) are checked against what was parsed.
        node_names: optional ordering of node names, determining the random
            variable indices used by the indicators in the returned literal map.

    Returns:
        a LiteralMap object

    Raises:
        Whatever `ParserInput.raise_error` raises (presumably `ParseError`) if
        the input is malformed or inconsistent.
    """
    parser = _LiteralMapParser(check_counts, node_names)
    parser.parse(input_stream)
    return parser.lmap
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
@dataclass
class LiteralMapRV:
    """
    A random variable as declared in a literal-map file.

    Fields:
        name: the random variable name, as given in the file.
        rv_idx: the index the parser assigned to the random variable.
        number_of_states: the number of states declared for the random variable.
    """
    name: str
    rv_idx: int
    number_of_states: int
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
@dataclass
class LiteralMap:
    """
    A data structure to hold a literal-map, i.e., a mapping from literal codes
    to indicators and a mapping from literal codes to weights.

    Fields:
        rvs[name] = literal_map_rv, a `LiteralMapRV` where literal_map_rv.name == name.
        indicators[literal_code] = indicator
        params[literal_code] = weight. When populated by `read_lmap`, this holds
            entries for parameter literals and for indicator literals.
    """
    rvs: Dict[str, LiteralMapRV]
    indicators: Dict[int, Indicator]
    params: Dict[int, float]
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
class Parser(ABC):
    """
    Abstract event-driven parser for ACE literal-map ("lmap") input.

    `parse` reads the input line by line and calls one abstract handler per
    recognised line; subclasses implement the handlers. Every handler receives
    `raise_f`, a callable that reports an error through `ParserInput.raise_error`.

    Recognised lines (after stripping whitespace; blank lines are skipped):
        c...                                            comment (starts with 'c' but not 'cc')
        cc$N$<num_literals>
        cc$v$<num_rvs>
        cc$V$<rv_name>$<num_states>
        cc$t$<num_tables>
        cc$T$<child_rv_name>$<num_states>
        cc$I$<literal>$<weight>$<op>$<rv_name>$<state>
        cc$C$<literal>$<weight>$<op>$<extra>
        cc$K$...  and  cc$S$...                         ignored
    """

    def parse(self, input_stream):
        """
        Parse the given input, dispatching each line to its handler, then call `done`.

        Args:
            input_stream: anything accepted by `ParserInput`.

        Raises:
            Whatever `ParserInput.raise_error` raises (presumably `ParseError`).
            A `ParseError` is propagated unchanged. Any other exception raised
            while parsing is reported through `ParserInput.raise_error`.
        """
        source = ParserInput(input_stream)

        def raise_f(msg):
            return source.raise_error(msg)

        try:
            while True:
                raw = source.readline()
                if not raw:
                    break
                text = raw.strip()
                if not text:
                    continue
                if text.startswith('c') and not text.startswith('cc'):
                    self.comment(raise_f, text)
                    continue

                fields = text.split('$')
                if fields[0] != 'cc':
                    source.raise_error(f'unexpected line start: {fields[0]}')
                code = fields[1]
                if code == 'N':
                    self.number_of_literals(raise_f, int(fields[2]))
                elif code == 'v':
                    self.number_of_rvs(raise_f, int(fields[2]))
                elif code == 'V':
                    self.rv(raise_f, fields[2], int(fields[3]))
                elif code == 't':
                    self.number_of_tables(raise_f, int(fields[2]))
                elif code == 'T':
                    self.table(raise_f, fields[2], int(fields[3]))
                elif code == 'I':
                    self.indicator(
                        raise_f, int(fields[2]), float(fields[3]), fields[4], fields[5], int(fields[6])
                    )
                elif code == 'C':
                    self.parameter(raise_f, int(fields[2]), float(fields[3]), fields[4], fields[5])
                elif code not in ('K', 'S'):
                    # 'K' and 'S' lines are deliberately ignored; anything else is an error.
                    source.raise_error(f'unexpected line code: cc${code}')
            self.done(raise_f)
        except ParseError:
            raise
        except Exception as err:
            source.raise_error(err)

    @abstractmethod
    def comment(self, raise_f, message: str) -> None:
        """Handle a comment line; `message` is the whole stripped line."""

    @abstractmethod
    def number_of_literals(self, raise_f, num_literals: int) -> None:
        """Handle a `cc$N$<num_literals>` line."""

    @abstractmethod
    def number_of_rvs(self, raise_f, num_rvs: int) -> None:
        """Handle a `cc$v$<num_rvs>` line."""

    @abstractmethod
    def rv(self, raise_f, rv_name, num_states: int) -> None:
        """Handle a `cc$V$<rv_name>$<num_states>` line."""

    @abstractmethod
    def number_of_tables(self, raise_f, num_tables: int) -> None:
        """Handle a `cc$t$<num_tables>` line."""

    @abstractmethod
    def table(self, raise_f, child_rv_name: str, num_states: int) -> None:
        """Handle a `cc$T$<child_rv_name>$<num_states>` line."""

    @abstractmethod
    def indicator(self, raise_f, literal_code: int, weight: float, arithmetic_op: str, rv_name: str,
                  state: int) -> None:
        """Handle a `cc$I$<literal>$<weight>$<op>$<rv_name>$<state>` line."""

    @abstractmethod
    def parameter(self, raise_f, literal_code: int, weight: float, arithmetic_op: str, extra: str) -> None:
        """Handle a `cc$C$<literal>$<weight>$<op>$<extra>` line."""

    @abstractmethod
    def done(self, raise_f) -> None:
        """Called once, after every input line has been handled."""
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
class _LiteralMapParser(Parser):
    """
    Literal-map parser that accumulates parsed content into `self.lmap`.
    """

    def __init__(self, check_counts: bool, node_names: Sequence[str]):
        """
        Args:
            check_counts: if true, `done` checks the declared counts against the parsed content.
            node_names: names whose positions give preferred random variable indices.
        """
        # Map from random variable name to its assigned index.
        self.node_names: Dict[str, int] = {
            name: i
            for i, name in enumerate(node_names)
        }
        # Next index for a random variable whose name is not in `node_names`.
        # Using max(given index) + 1, rather than len(self.node_names), keeps new
        # indices distinct from all given ones even if `node_names` has duplicates.
        # That case would otherwise shrink the dict and cause index collisions.
        self._next_rv_idx: int = max(self.node_names.values(), default=-1) + 1
        self.check_counts = check_counts
        self.lmap = LiteralMap({}, {}, {})
        # Counts declared by `cc$N$` and `cc$v$` lines (None if not declared).
        self._number_of_literals: Optional[int] = None
        self._number_of_rvs: Optional[int] = None

    def number_of_literals(self, raise_f, num):
        self._number_of_literals = num

    def number_of_rvs(self, raise_f, num):
        self._number_of_rvs = num

    def rv(self, raise_f, rv_name, num_states):
        """Record a random variable, assigning it an index (see `__init__`)."""
        if rv_name in self.lmap.rvs:
            raise_f(f'duplicated random variable: {rv_name}')

        idx: Optional[int] = self.node_names.get(rv_name)
        if idx is None:
            idx = self._next_rv_idx
            self._next_rv_idx += 1
            self.node_names[rv_name] = idx

        self.lmap.rvs[rv_name] = LiteralMapRV(rv_name, idx, num_states)

    def indicator(self, raise_f, literal_code, weight, arithmetic_op, rv_name, state):
        """
        Record an indicator literal. Its weight is stored in `lmap.params` as well.
        The random variable must already have been declared.
        """
        rv: Optional[LiteralMapRV] = self.lmap.rvs.get(rv_name)
        if rv is None:
            raise_f(f'unknown random variable: {rv_name}')
        if literal_code in self.lmap.indicators or literal_code in self.lmap.params:
            raise_f(f'duplicated indicator literal: {literal_code}')
        self.lmap.indicators[literal_code] = Indicator(rv.rv_idx, state)
        self.lmap.params[literal_code] = weight

    def parameter(self, raise_f, literal_code, weight, arithmetic_op, extra):
        """Record a parameter literal and its weight."""
        if literal_code in self.lmap.indicators or literal_code in self.lmap.params:
            raise_f(f'duplicated parameter literal: {literal_code}')
        self.lmap.params[literal_code] = weight

    def comment(self, raise_f, message: str) -> None:
        pass

    def number_of_tables(self, raise_f, num_tables: int) -> None:
        pass

    def table(self, raise_f, child_rv_name: str, num_states: int) -> None:
        pass

    def done(self, raise_f) -> None:
        """If `check_counts`, compare the declared counts with what was parsed."""
        if self.check_counts:
            # Perform consistency checks
            if self._number_of_rvs is not None:
                got_rvs: int = len(self.lmap.rvs)
                if got_rvs != self._number_of_rvs:
                    raise_f(f'unexpected number of random variables: expected {self._number_of_rvs}, got {got_rvs}')

            if self._number_of_literals is not None:
                # The number of weighted literals is expected to be twice the declared count.
                got_params: int = len(self.lmap.params)
                expect_params: int = self._number_of_literals * 2
                if got_params != expect_params:
                    raise_f(f'unexpected number of parameters: expected {expect_params}, got: {got_params}')
|
|
ck/in_out/parse_ace_nnf.py
ADDED
|
@@ -0,0 +1,288 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from typing import Tuple, Optional, Dict, List
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
|
|
6
|
+
from ck.circuit import Circuit, CircuitNode, VarNode, ConstValue, ConstNode
|
|
7
|
+
from ck.in_out.parse_ace_lmap import LiteralMap
|
|
8
|
+
from ck.in_out.parser_utils import ParseError, ParserInput
|
|
9
|
+
from ck.pgm_circuit.slot_map import SlotKey, SlotMap
|
|
10
|
+
from ck.utils.np_extras import NDArrayFloat64
|
|
11
|
+
|
|
12
|
+
# Line codes of the "nnf"/"ac" file format, by kind of node.
VAR_NODE = set('lL')  # literal (leaf) node lines
ADD_NODE = set('O+')  # lines handled as addition nodes
MUL_NODE = set('A*')  # lines handled as multiplication nodes
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def read_nnf_with_literal_map(
        input_stream,
        literal_map: LiteralMap,
        *,
        const_parameters: bool = True,
        optimise_ops: bool = True,
        circuit: Optional[Circuit] = None,
        check_header: bool = False,
) -> Tuple[CircuitNode, SlotMap, NDArrayFloat64]:
    """
    Parse an input, as per `read_nnf`, using the given literal map to
    create a slot map with indicator entries.

    See: `read_nnf` and `read_lmap` for more information.

    Args:
        input_stream: to parse, as per `ParserInput` argument.
        literal_map: mapping from literal code to indicators.
        circuit: an optional circuit to reuse.
        check_header: if true, an exception is raised if the number of nodes or arcs is not as expected.
        const_parameters: if true, the potential function parameters will be circuit
            constants, otherwise they will be circuit variables.
        optimise_ops: if true then circuit optimised operations will be used.

    Returns:
        (circuit_top, slot_map, params)
        circuit_top: is the resulting top node from parsing the input.
        slot_map: is a map from indicator to a circuit var index (int).
        params: is a numpy array of parameter values, co-indexed with `circuit.vars[num_indicators:]`
            where num_indicators = len(literal_map.indicators).

    Raises:
        KeyError: if an indicator literal of `literal_map` does not appear in the input.
        ValueError: if an indicator literal is not among the first `num_indicators`
            circuit vars. This happens when some other var literal appears before
            it in the input, and the returned `params` would then be mis-indexed.
    """
    # Set the `const_literals` parameter for `read_nnf`.
    const_literals: Dict[int, ConstValue]
    if const_parameters:
        # Every non-indicator literal with a weight becomes a circuit constant.
        indicator_literal_codes = literal_map.indicators.keys()
        const_literals = {
            literal_code: value
            for literal_code, value in literal_map.params.items()
            if literal_code not in indicator_literal_codes
        }
    else:
        const_literals = {}

    top_node, literal_slot_map = read_nnf(
        input_stream,
        const_literals=const_literals,
        circuit=circuit,
        check_header=check_header,
        optimise_ops=optimise_ops,
    )

    # Build the slot map from indicator to slot.
    slot_map: Dict[SlotKey, int] = {
        indicator: literal_slot_map[literal_code]
        for literal_code, indicator in literal_map.indicators.items()
    }

    # Get the parameter values.
    num_indicators: int = len(literal_map.indicators)
    num_parameters: int = top_node.circuit.number_of_vars - num_indicators
    assert num_parameters == 0 or not const_parameters, 'const_parameters -> num_parameters == 0'

    # `params` is co-indexed with `circuit.vars[num_indicators:]`, which is only
    # correct if the indicator vars occupy exactly the first `num_indicators` slots.
    # Var slots are allocated in input order, so check rather than silently
    # mis-assign parameter values.
    for literal_code in literal_map.indicators.keys():
        indicator_slot: int = literal_slot_map[literal_code]
        if indicator_slot >= num_indicators:
            raise ValueError(
                f'indicator literal {literal_code} has circuit var index {indicator_slot}, '
                f'expected less than {num_indicators}: '
                f'indicator literals must appear before parameter literals in the input'
            )

    params: NDArrayFloat64 = np.zeros(num_parameters, dtype=np.float64)
    for literal_code, value in literal_map.params.items():
        literal_slot: Optional[int] = literal_slot_map.get(literal_code)
        if literal_slot is not None and literal_slot >= num_indicators:
            params[literal_slot - num_indicators] = value

    return top_node, slot_map, params
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def read_nnf(
        input_stream,
        *,
        const_literals: Optional[Dict[int, ConstValue]] = None,
        circuit: Optional[Circuit] = None,
        check_header: bool = False,
        optimise_ops: bool = True,
) -> Tuple[CircuitNode, Dict[int, int]]:
    """
    Parse the input_stream (file or string) as a circuit in "nnf" or "ac" file format.

    The input consists of propositional logical sentences in negation normal form (NNF).
    Both ".ac" and ".nnf" files produced by the software ACE are covered.

    The returned slot map gives the circuit var for each non-constant literal code,
    e.g., `circuit.vars[slot_map[literal_code]]` is the var node of `literal_code`.

    Software may simplify an NNF file by removing arcs without updating the header.
    Such a file is not conformant but is still parsable. Leaving `check_header` false
    (the default) avoids an exception in that case.

    Args:
        input_stream: to parse, as per `ParserInput` argument.
        const_literals: an optional mapping from literal code to constant value.
        circuit: an optional empty circuit to reuse.
        check_header: if true, an exception is raised if the number of nodes or arcs is not as expected.
        optimise_ops: if true then circuit optimised operations will be used.

    Returns:
        (circuit_top, literal_slot_map)
        circuit_top: the last node parsed, or `circuit.zero` if no nodes were parsed.
        literal_slot_map: a mapping from literal code (int) to a circuit var index (int).

    Raises:
        ValueError: if the given circuit already has vars.

    Assumes:
        If a circuit is provided, it is empty.
    """
    target: Circuit = Circuit() if circuit is None else circuit
    if target.number_of_vars != 0:
        raise ValueError('the given circuit must be empty')

    parser = CircuitParser(
        target,
        check_header,
        {} if const_literals is None else const_literals,
        optimise_ops,
    )
    parser.parse(input_stream)

    if parser.nodes:
        top = parser.nodes[-1]
    else:
        top = target.zero
    return top, parser.literal_slot_map
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
class Parser(ABC):
    """
    Abstract event-driven parser for "nnf" / "ac" circuit files.

    `parse` reads the input line by line and calls an abstract handler for each
    recognised line. Every handler receives `raise_f`, a callable that reports
    an error through `ParserInput.raise_error`.

    Recognised input (after stripping whitespace; blank lines are skipped):
        before the header:  c...                                comment
                            nnf <num-nodes> <num-edges> <num-?>  header
        after the header:   l|L <literal-code>                  literal node
                            O <j> <num-args> <args>...          add node (field <j> is skipped)
                            + <num-args> <args>...              add node
                            A|* <num-args> <args>...            mul node
                            %...                                end of circuit; the rest is ignored
    """

    def parse(self, input_stream):
        """
        Parse the given input, calling the handler for each recognised line, then `done`.

        Args:
            input_stream: anything accepted by `ParserInput`.

        Raises:
            Whatever `ParserInput.raise_error` raises (presumably `ParseError`).
            A `ParseError` is propagated unchanged. Any other exception raised
            while parsing is reported through `ParserInput.raise_error`.
        """
        input_stream = ParserInput(input_stream)
        raise_f = lambda msg: input_stream.raise_error(msg)
        try:
            # state 0: before the "nnf" header; state 1: node lines; state 999: finished.
            state = 0
            line = input_stream.readline()
            while line and state < 999:
                line = line.strip()
                if len(line) > 0:
                    if state == 0:
                        if line[0] == 'c':
                            self.comment(raise_f, line)
                        else:
                            line = line.split()
                            if line[0] == 'nnf':
                                if len(line) != 4:
                                    raise_f('expected: nnf <num-nodes> <num-edges> <num-?>')
                                self.header(raise_f, int(line[1]), int(line[2]), int(line[3]))
                                state = 1
                            else:
                                raise_f('expected: "nnf"')
                    elif state == 1:
                        if line[0] == '%':
                            state = 999
                        else:
                            line = line.split()
                            code = line[0]
                            if code in VAR_NODE:
                                if len(line) != 2:
                                    raise_f(f'expected: {code} <literal-code>')
                                self.literal(raise_f, int(line[1]))
                            else:
                                is_add = code in ADD_NODE
                                is_mul = code in MUL_NODE
                                if not (is_add or is_mul):
                                    raise_f(f'unexpected line starting with: {code}')
                                # An 'O' line has one extra field before the argument count,
                                # so it needs at least three fields (others need two).
                                num_args_idx = 2 if code == 'O' else 1
                                if len(line) <= num_args_idx:
                                    usage = '<j> <num_args>' if code == 'O' else '<num_args>'
                                    raise_f(f'expected: {code} {usage} <arguments>...')
                                num_args = int(line[num_args_idx])
                                args = [int(arg) for arg in line[num_args_idx + 1:]]
                                if len(args) != num_args:
                                    raise_f(f'unexpected number of args for: {code}')
                                if is_add:
                                    self.add_node(raise_f, args)
                                else:
                                    self.mul_node(raise_f, args)
                line = input_stream.readline()
            self.done(raise_f)
        except ParseError:
            raise
        except Exception as e:
            input_stream.raise_error(e)

    @abstractmethod
    def comment(self, raise_f, message: str) -> None:
        """Handle a comment line before the header; `message` is the whole stripped line."""
        ...

    @abstractmethod
    def header(self, raise_f, num_nodes: int, num_edges: int, num_: int):
        """Handle the `nnf <num-nodes> <num-edges> <num-?>` header line."""
        ...

    @abstractmethod
    def literal(self, raise_f, literal_code: int) -> None:
        """Handle a literal node line (`l` or `L`)."""
        ...

    @abstractmethod
    def add_node(self, raise_f, args: List[int]) -> None:
        """Handle an add node line (`O` or `+`); `args` are the parsed integer arguments."""
        ...

    @abstractmethod
    def mul_node(self, raise_f, args: List[int]) -> None:
        """Handle a mul node line (`A` or `*`); `args` are the parsed integer arguments."""
        ...

    @abstractmethod
    def done(self, raise_f) -> None:
        """Called once, after parsing stops (end of input or a '%' line)."""
        ...
|
|
221
|
+
|
|
222
|
+
|
|
223
|
+
class CircuitParser(Parser):
    """
    A `Parser` that builds the parsed nodes in a given circuit.

    Attributes (after `parse`):
        nodes: every node created, in input order. The arguments of add/mul
            lines are indices into this list.
        literal_slot_map: maps each literal code that became a circuit var to
            that var's index (`VarNode.idx`). Literal codes found in
            `const_literals` become constant nodes and are not in this map.
    """

    def __init__(
            self,
            circuit: Circuit,
            check_header: bool,
            const_literals: Dict[int, ConstValue],
            optimise_ops: bool,
    ):
        """
        Args:
            circuit: the circuit to build nodes in.
            check_header: if true, `done` checks node and arc counts against the header.
            const_literals: literal codes to be made constant nodes, with their values.
            optimise_ops: if true, use `optimised_add`/`optimised_mul` rather than `add`/`mul`.
        """
        self.check_header: bool = check_header
        self.literal_slot_map: Dict[int, int] = {}
        self.optimise_ops = optimise_ops
        self.circuit = circuit
        self.const_literals: Dict[int, ConstValue] = const_literals
        self.nodes: List[CircuitNode] = []
        # Counts declared by the "nnf" header line (None until the header is seen).
        self.num_nodes: Optional[int] = None
        self.num_edges: Optional[int] = None

    def comment(self, raise_f, message: str) -> None:
        pass

    def literal(self, raise_f, literal_code: int) -> None:
        """
        Makes either a VarNode or a ConstNode.

        A literal code in `const_literals` gives a constant node. Otherwise a new
        circuit var is created and recorded in `literal_slot_map`. A var literal
        code may appear only once.
        """
        const_value: Optional[ConstValue] = self.const_literals.get(literal_code)
        if const_value is not None:
            node: CircuitNode = self.circuit.const(const_value)
        elif literal_code in self.literal_slot_map:
            raise_f(f'duplicated literal code: {literal_code}')
            return
        else:
            var_node: VarNode = self.circuit.new_var()
            self.literal_slot_map[literal_code] = var_node.idx
            node = var_node
        self.nodes.append(node)

    def _arg_nodes(self, raise_f, args: List[int]) -> List[CircuitNode]:
        """
        Resolve node-index arguments to previously parsed nodes.

        Negative or forward references are reported via `raise_f`. Without this
        check, Python's negative indexing would silently pick nodes from the end
        of the list.
        """
        num_nodes: int = len(self.nodes)
        for arg in args:
            if not (0 <= arg < num_nodes):
                raise_f(f'invalid node reference: {arg}')
        return [self.nodes[arg] for arg in args]

    def add_node(self, raise_f, args: List[int]) -> None:
        """
        Makes a AddNode (or other if optimised).
        """
        arg_nodes = self._arg_nodes(raise_f, args)
        if self.optimise_ops:
            self.nodes.append(self.circuit.optimised_add(arg_nodes))
        else:
            self.nodes.append(self.circuit.add(arg_nodes))

    def mul_node(self, raise_f, args: List[int]) -> None:
        """
        Makes a MulNode (or other if optimised).
        """
        arg_nodes = self._arg_nodes(raise_f, args)
        if self.optimise_ops:
            self.nodes.append(self.circuit.optimised_mul(arg_nodes))
        else:
            self.nodes.append(self.circuit.mul(arg_nodes))

    def header(self, raise_f, num_nodes: int, num_edges: int, num_: int) -> None:
        """Record the declared node and edge counts (`num_` is not used)."""
        self.num_nodes = num_nodes
        self.num_edges = num_edges

    def done(self, raise_f) -> None:
        """If `check_header`, compare the declared counts with the built nodes and arcs."""
        if self.check_header:
            if len(self.nodes) != self.num_nodes:
                raise_f(f'unexpected number of nodes: {len(self.nodes)} expected: {self.num_nodes}')
            if self.circuit.number_of_arcs != self.num_edges:
                raise_f(f'unexpected number of arcs: {self.circuit.number_of_arcs} expected: {self.num_edges}')
|