compiled-knowledge 4.0.0a20__cp313-cp313-macosx_10_13_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of compiled-knowledge might be problematic. Click here for more details.
- ck/__init__.py +0 -0
- ck/circuit/__init__.py +17 -0
- ck/circuit/_circuit_cy.c +37525 -0
- ck/circuit/_circuit_cy.cpython-313-darwin.so +0 -0
- ck/circuit/_circuit_cy.pxd +32 -0
- ck/circuit/_circuit_cy.pyx +768 -0
- ck/circuit/_circuit_py.py +836 -0
- ck/circuit/tmp_const.py +74 -0
- ck/circuit_compiler/__init__.py +2 -0
- ck/circuit_compiler/circuit_compiler.py +26 -0
- ck/circuit_compiler/cython_vm_compiler/__init__.py +1 -0
- ck/circuit_compiler/cython_vm_compiler/_compiler.c +19826 -0
- ck/circuit_compiler/cython_vm_compiler/_compiler.cpython-313-darwin.so +0 -0
- ck/circuit_compiler/cython_vm_compiler/_compiler.pyx +380 -0
- ck/circuit_compiler/cython_vm_compiler/cython_vm_compiler.py +121 -0
- ck/circuit_compiler/interpret_compiler.py +223 -0
- ck/circuit_compiler/llvm_compiler.py +388 -0
- ck/circuit_compiler/llvm_vm_compiler.py +546 -0
- ck/circuit_compiler/named_circuit_compilers.py +57 -0
- ck/circuit_compiler/support/__init__.py +0 -0
- ck/circuit_compiler/support/circuit_analyser/__init__.py +13 -0
- ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.c +10620 -0
- ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.cpython-313-darwin.so +0 -0
- ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.pyx +98 -0
- ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_py.py +93 -0
- ck/circuit_compiler/support/input_vars.py +148 -0
- ck/circuit_compiler/support/llvm_ir_function.py +234 -0
- ck/example/__init__.py +53 -0
- ck/example/alarm.py +366 -0
- ck/example/asia.py +28 -0
- ck/example/binary_clique.py +32 -0
- ck/example/bow_tie.py +33 -0
- ck/example/cancer.py +37 -0
- ck/example/chain.py +38 -0
- ck/example/child.py +199 -0
- ck/example/clique.py +33 -0
- ck/example/cnf_pgm.py +39 -0
- ck/example/diamond_square.py +68 -0
- ck/example/earthquake.py +36 -0
- ck/example/empty.py +10 -0
- ck/example/hailfinder.py +539 -0
- ck/example/hepar2.py +628 -0
- ck/example/insurance.py +504 -0
- ck/example/loop.py +40 -0
- ck/example/mildew.py +38161 -0
- ck/example/munin.py +22982 -0
- ck/example/pathfinder.py +53747 -0
- ck/example/rain.py +39 -0
- ck/example/rectangle.py +161 -0
- ck/example/run.py +30 -0
- ck/example/sachs.py +129 -0
- ck/example/sprinkler.py +30 -0
- ck/example/star.py +44 -0
- ck/example/stress.py +64 -0
- ck/example/student.py +43 -0
- ck/example/survey.py +46 -0
- ck/example/triangle_square.py +54 -0
- ck/example/truss.py +49 -0
- ck/in_out/__init__.py +3 -0
- ck/in_out/parse_ace_lmap.py +216 -0
- ck/in_out/parse_ace_nnf.py +322 -0
- ck/in_out/parse_net.py +480 -0
- ck/in_out/parser_utils.py +185 -0
- ck/in_out/pgm_pickle.py +42 -0
- ck/in_out/pgm_python.py +268 -0
- ck/in_out/render_bugs.py +111 -0
- ck/in_out/render_net.py +177 -0
- ck/in_out/render_pomegranate.py +184 -0
- ck/pgm.py +3475 -0
- ck/pgm_circuit/__init__.py +1 -0
- ck/pgm_circuit/marginals_program.py +352 -0
- ck/pgm_circuit/mpe_program.py +237 -0
- ck/pgm_circuit/pgm_circuit.py +79 -0
- ck/pgm_circuit/program_with_slotmap.py +236 -0
- ck/pgm_circuit/slot_map.py +35 -0
- ck/pgm_circuit/support/__init__.py +0 -0
- ck/pgm_circuit/support/compile_circuit.py +83 -0
- ck/pgm_circuit/target_marginals_program.py +103 -0
- ck/pgm_circuit/wmc_program.py +323 -0
- ck/pgm_compiler/__init__.py +2 -0
- ck/pgm_compiler/ace/__init__.py +1 -0
- ck/pgm_compiler/ace/ace.py +299 -0
- ck/pgm_compiler/factor_elimination.py +395 -0
- ck/pgm_compiler/named_pgm_compilers.py +63 -0
- ck/pgm_compiler/pgm_compiler.py +19 -0
- ck/pgm_compiler/recursive_conditioning.py +231 -0
- ck/pgm_compiler/support/__init__.py +0 -0
- ck/pgm_compiler/support/circuit_table/__init__.py +17 -0
- ck/pgm_compiler/support/circuit_table/_circuit_table_cy.c +16398 -0
- ck/pgm_compiler/support/circuit_table/_circuit_table_cy.cpython-313-darwin.so +0 -0
- ck/pgm_compiler/support/circuit_table/_circuit_table_cy.pyx +332 -0
- ck/pgm_compiler/support/circuit_table/_circuit_table_py.py +304 -0
- ck/pgm_compiler/support/clusters.py +568 -0
- ck/pgm_compiler/support/factor_tables.py +406 -0
- ck/pgm_compiler/support/join_tree.py +332 -0
- ck/pgm_compiler/support/named_compiler_maker.py +43 -0
- ck/pgm_compiler/variable_elimination.py +91 -0
- ck/probability/__init__.py +0 -0
- ck/probability/empirical_probability_space.py +50 -0
- ck/probability/pgm_probability_space.py +32 -0
- ck/probability/probability_space.py +622 -0
- ck/program/__init__.py +3 -0
- ck/program/program.py +137 -0
- ck/program/program_buffer.py +180 -0
- ck/program/raw_program.py +67 -0
- ck/sampling/__init__.py +0 -0
- ck/sampling/forward_sampler.py +211 -0
- ck/sampling/marginals_direct_sampler.py +113 -0
- ck/sampling/sampler.py +62 -0
- ck/sampling/sampler_support.py +232 -0
- ck/sampling/uniform_sampler.py +72 -0
- ck/sampling/wmc_direct_sampler.py +171 -0
- ck/sampling/wmc_gibbs_sampler.py +153 -0
- ck/sampling/wmc_metropolis_sampler.py +165 -0
- ck/sampling/wmc_rejection_sampler.py +115 -0
- ck/utils/__init__.py +0 -0
- ck/utils/iter_extras.py +163 -0
- ck/utils/local_config.py +270 -0
- ck/utils/map_list.py +128 -0
- ck/utils/map_set.py +128 -0
- ck/utils/np_extras.py +51 -0
- ck/utils/random_extras.py +64 -0
- ck/utils/tmp_dir.py +94 -0
- ck_demos/__init__.py +0 -0
- ck_demos/ace/__init__.py +0 -0
- ck_demos/ace/copy_ace_to_ck.py +15 -0
- ck_demos/ace/demo_ace.py +49 -0
- ck_demos/all_demos.py +88 -0
- ck_demos/circuit/__init__.py +0 -0
- ck_demos/circuit/demo_circuit_dump.py +22 -0
- ck_demos/circuit/demo_derivatives.py +43 -0
- ck_demos/circuit_compiler/__init__.py +0 -0
- ck_demos/circuit_compiler/compare_circuit_compilers.py +32 -0
- ck_demos/circuit_compiler/show_llvm_program.py +26 -0
- ck_demos/pgm/__init__.py +0 -0
- ck_demos/pgm/demo_pgm_dump.py +18 -0
- ck_demos/pgm/demo_pgm_dump_stress.py +18 -0
- ck_demos/pgm/demo_pgm_string_rendering.py +15 -0
- ck_demos/pgm/show_examples.py +25 -0
- ck_demos/pgm_compiler/__init__.py +0 -0
- ck_demos/pgm_compiler/compare_pgm_compilers.py +63 -0
- ck_demos/pgm_compiler/demo_compiler_dump.py +60 -0
- ck_demos/pgm_compiler/demo_factor_elimination.py +47 -0
- ck_demos/pgm_compiler/demo_join_tree.py +25 -0
- ck_demos/pgm_compiler/demo_marginals_program.py +53 -0
- ck_demos/pgm_compiler/demo_mpe_program.py +55 -0
- ck_demos/pgm_compiler/demo_pgm_compiler.py +38 -0
- ck_demos/pgm_compiler/demo_recursive_conditioning.py +33 -0
- ck_demos/pgm_compiler/demo_variable_elimination.py +33 -0
- ck_demos/pgm_compiler/demo_wmc_program.py +29 -0
- ck_demos/pgm_compiler/time_fe_compiler.py +93 -0
- ck_demos/pgm_inference/__init__.py +0 -0
- ck_demos/pgm_inference/demo_inferencing_basic.py +188 -0
- ck_demos/pgm_inference/demo_inferencing_mpe_cancer.py +45 -0
- ck_demos/pgm_inference/demo_inferencing_wmc_and_mpe_sprinkler.py +154 -0
- ck_demos/pgm_inference/demo_inferencing_wmc_student.py +110 -0
- ck_demos/programs/__init__.py +0 -0
- ck_demos/programs/demo_program_buffer.py +24 -0
- ck_demos/programs/demo_program_multi.py +24 -0
- ck_demos/programs/demo_program_none.py +19 -0
- ck_demos/programs/demo_program_single.py +23 -0
- ck_demos/programs/demo_raw_program_interpreted.py +21 -0
- ck_demos/programs/demo_raw_program_llvm.py +21 -0
- ck_demos/sampling/__init__.py +0 -0
- ck_demos/sampling/check_sampler.py +71 -0
- ck_demos/sampling/demo_marginal_direct_sampler.py +40 -0
- ck_demos/sampling/demo_uniform_sampler.py +38 -0
- ck_demos/sampling/demo_wmc_direct_sampler.py +40 -0
- ck_demos/utils/__init__.py +0 -0
- ck_demos/utils/compare.py +120 -0
- ck_demos/utils/convert_network.py +45 -0
- ck_demos/utils/sample_model.py +216 -0
- ck_demos/utils/stop_watch.py +384 -0
- compiled_knowledge-4.0.0a20.dist-info/METADATA +50 -0
- compiled_knowledge-4.0.0a20.dist-info/RECORD +178 -0
- compiled_knowledge-4.0.0a20.dist-info/WHEEL +6 -0
- compiled_knowledge-4.0.0a20.dist-info/licenses/LICENSE.txt +21 -0
- compiled_knowledge-4.0.0a20.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,323 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import random
|
|
4
|
+
from typing import Sequence, Optional, Tuple
|
|
5
|
+
|
|
6
|
+
from ck.circuit_compiler import CircuitCompiler
|
|
7
|
+
from ck.pgm import RandomVariable
|
|
8
|
+
from ck.pgm_circuit import PGMCircuit
|
|
9
|
+
from ck.pgm_circuit.program_with_slotmap import ProgramWithSlotmap
|
|
10
|
+
from ck.pgm_circuit.support.compile_circuit import compile_results, DEFAULT_CIRCUIT_COMPILER
|
|
11
|
+
from ck.probability.probability_space import ProbabilitySpace, Condition
|
|
12
|
+
from ck.program.program_buffer import ProgramBuffer
|
|
13
|
+
from ck.program.raw_program import RawProgram
|
|
14
|
+
from ck.sampling.sampler import Sampler
|
|
15
|
+
from ck.sampling.sampler_support import SamplerInfo, get_sampler_info
|
|
16
|
+
from ck.sampling.uniform_sampler import UniformSampler
|
|
17
|
+
from ck.sampling.wmc_direct_sampler import WMCDirectSampler
|
|
18
|
+
from ck.sampling.wmc_gibbs_sampler import WMCGibbsSampler
|
|
19
|
+
from ck.sampling.wmc_metropolis_sampler import WMCMetropolisSampler
|
|
20
|
+
from ck.sampling.wmc_rejection_sampler import WMCRejectionSampler
|
|
21
|
+
from ck.utils.np_extras import NDArray
|
|
22
|
+
from ck.utils.random_extras import Random
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class WMCProgram(ProgramWithSlotmap, ProbabilitySpace):
    """
    A class for computing Weighted Model Count (WMC).

    The program is compiled from `pgm_circuit.circuit_top`. The class also offers
    factory methods for several samplers over the PGM's random variables.
    """

    def __init__(
            self,
            pgm_circuit: PGMCircuit,
            const_parameters: bool = True,
            compiler: CircuitCompiler = DEFAULT_CIRCUIT_COMPILER,
    ):
        """
        Construct a WMCProgram object.

        Args:
            pgm_circuit: The circuit representing a PGM.
            const_parameters: if True then any circuit variable representing a parameter value will
                be made 'const' in the resulting program.
            compiler: the circuit compiler used to turn the circuit into a raw program.
        """
        raw_program: RawProgram = compile_results(
            pgm_circuit=pgm_circuit,
            results=(pgm_circuit.circuit_top,),
            const_parameters=const_parameters,
            compiler=compiler,
        )
        ProgramWithSlotmap.__init__(
            self,
            ProgramBuffer(raw_program),
            pgm_circuit.slot_map,
            pgm_circuit.rvs,
            pgm_circuit.conditions,
        )
        self._raw_program: RawProgram = raw_program
        self._number_of_indicators: int = pgm_circuit.number_of_indicators
        self._z_cache: Optional[float] = None

        if not const_parameters:
            # Parameters are program inputs rather than constants, so load their
            # values into the slots that follow the indicator slots.
            self.vars[pgm_circuit.number_of_indicators:] = pgm_circuit.parameter_values

    def wmc(self, *condition: Condition) -> float:
        """
        Compute the weighted model count under the given condition.

        The indicator slots are set from `condition` (via `set_condition`) and the
        program is run; the slots are left in that state afterwards.

        Args:
            condition: zero or more conditioning indicators.

        Returns:
            the program result as a Python float.
        """
        self.set_condition(*condition)
        return self.compute().item()

    @property
    def z(self) -> float:
        """
        The total weighted model count: the program result with every indicator slot set to 1.

        The current indicator slot values are saved before, and restored after, the
        computation (even if it raises), so calling this does not disturb any condition
        already set.

        NOTE: the value is cached on first access. Nothing in this class clears the cache,
        so if parameter slots are modified later (possible when const_parameters is False)
        the cached value will not reflect the change.
        """
        if self._z_cache is None:
            number_of_indicators: int = self._number_of_indicators
            slots: NDArray = self.vars
            old_vals: NDArray = slots[:number_of_indicators].copy()
            slots[:number_of_indicators] = 1
            try:
                self._z_cache = self.compute().item()
            finally:
                # Restore the caller's indicator settings even if compute() fails.
                slots[:number_of_indicators] = old_vals

        return self._z_cache

    def sample_uniform(
            self,
            rvs: Optional[RandomVariable | Sequence[RandomVariable]] = None,
            *,
            condition: Condition = (),
            rand: Random = random,
    ) -> Sampler:
        """
        Create a sampler that performs uniform sampling of
        the state space of the given random variables, rvs.

        The sampler will yield state lists, where the state
        values are co-indexed with rvs, or self.rvs if rvs is None.

        This sampler is not affected by and does not affect
        the state of input slots.

        Args:
            rvs: the list of random variables to sample; the
                yielded state vectors are co-indexed with rvs; if None,
                then the self.rvs are used; if rvs is a single
                random variable, then single samples are yielded.
            condition: is a collection of zero or more conditioning indicators.
            rand: provides the stream of random numbers.

        Returns:
            a Sampler object (UniformSampler).
        """
        return UniformSampler(
            rvs=(self.rvs if rvs is None else rvs),
            condition=condition,
            rand=rand,
        )

    def sample_direct(
            self,
            rvs: Optional[RandomVariable | Sequence[RandomVariable]] = None,
            *,
            condition: Condition = (),
            rand: Random = random,
            chain_pairs: Sequence[Tuple[RandomVariable, RandomVariable]] = (),
            initial_chain_condition: Condition = (),
    ) -> Sampler:
        """
        Create an inverse-transform sampler, which uses the fact that
        probabilities are exactly computable using a WMC.

        The sampler will yield state lists, where the state
        values are co-indexed with rvs, or self.rvs if rvs is None.

        Given 'n' random variables, and 'm' number of indicators, for each yielded sample, this method:
        * calls rand.random() once and rand.randrange(...) n times,
        * calls self.program().compute_result() at least once and <= 1 + m.

        Args:
            rvs: the list of random variables to sample; the
                yielded state vectors are co-indexed with rvs; if None,
                then the WMC rvs are used; if rvs is a single
                random variable, then single samples are yielded.
            condition: is a collection of zero or more conditioning indicators.
            rand: provides the stream of random numbers.
            chain_pairs: is a collection of pairs of random variables, each random variable
                must be in the given rvs. Given a pair (from_rv, to_rv) the state of from_rv is used
                as a condition for to_rv prior to generating a sample.
            initial_chain_condition: are condition indicators (just like condition)
                for the initialisation of the 'to_rv' random variables mentioned in chain_pairs.

        Returns:
            a Sampler object (WMCDirectSampler).
        """
        sampler_info: SamplerInfo = get_sampler_info(
            program_with_slotmap=self,
            rvs=rvs,
            condition=condition,
            chain_pairs=chain_pairs,
            initial_chain_condition=initial_chain_condition,
        )

        return WMCDirectSampler(
            sampler_info=sampler_info,
            raw_program=self._raw_program,
            rand=rand,
        )

    def sample_rejection(
            self,
            rvs: Optional[RandomVariable | Sequence[RandomVariable]] = None,
            *,
            condition: Condition = (),
            rand: Random = random,
    ) -> Sampler:
        """
        Create a sampler to perform rejection sampling.

        The sampler will yield state lists, where the state
        values are co-indexed with rvs, or self.rvs if rvs is None.

        The method uniformly samples states and uses an adaptive 'max weight'
        to reduce unnecessary rejection.

        After each sample is yielded, the WMC indicator variables will
        be left set as per the yielded states of rvs and conditions.

        Args:
            rvs: the list of random variables to sample; the
                yielded state vectors are co-indexed with rvs; if None,
                then the WMC rvs are used; if rvs is a single
                random variable, then single samples are yielded.
            condition: is a collection of zero or more conditioning indicators.
            rand: provides the stream of random numbers.

        Returns:
            a Sampler object (WMCRejectionSampler).
        """
        sampler_info: SamplerInfo = get_sampler_info(
            program_with_slotmap=self,
            rvs=rvs,
            condition=condition,
        )
        # The WMC under the condition is passed to the sampler as `z`.
        wmc_given_condition: float = self.wmc(*condition)

        return WMCRejectionSampler(
            sampler_info=sampler_info,
            raw_program=self._raw_program,
            rand=rand,
            z=wmc_given_condition,
        )

    def sample_gibbs(
            self,
            rvs: Optional[RandomVariable | Sequence[RandomVariable]] = None,
            *,
            condition: Condition = (),
            skip: int = 0,
            burn_in: int = 0,
            pr_restart: float = 0,
            rand: Random = random,
    ) -> Sampler:
        """
        Create a sampler to perform Gibbs sampling.

        The sampler will yield state lists, where the state
        values are co-indexed with rvs, or self.rvs if rvs is None.

        After each sample is yielded, the WMC indicator vars will
        be left set as per the yielded states of rvs and conditions.

        Args:
            rvs: the list of random variables to sample; the
                yielded state vectors are co-indexed with rvs; if None,
                then the WMC rvs are used; if rvs is a single
                random variable, then single samples are yielded.
            condition: is a collection of zero or more conditioning indicators.
            skip: is an integer >= 0 specifying how many samples to discard
                for each sample provided. Values > 0 can be used to de-correlate adjacent samples.
            burn_in: how many iterations to perform after
                initialisation before yielding a sample.
            pr_restart: the chance of re-initialising each
                iteration. If restarted then burn-in is performed again.
            rand: provides the stream of random numbers.

        Returns:
            a Sampler object (WMCGibbsSampler).

        Raises:
            RuntimeError: if skip or burn_in is negative.
        """
        if skip < 0:
            raise RuntimeError('skip must be non-negative')
        if burn_in < 0:
            raise RuntimeError('burn_in must be non-negative')

        sampler_info: SamplerInfo = get_sampler_info(
            program_with_slotmap=self,
            rvs=rvs,
            condition=condition,
        )

        return WMCGibbsSampler(
            sampler_info=sampler_info,
            raw_program=self._raw_program,
            rand=rand,
            skip=skip,
            burn_in=burn_in,
            pr_restart=pr_restart,
        )

    def sample_metropolis(
            self,
            rvs: Optional[RandomVariable | Sequence[RandomVariable]] = None,
            *,
            condition: Condition = (),
            skip: Optional[int] = None,
            burn_in: int = 0,
            pr_restart: float = 0,
            rand: Random = random,
    ) -> Sampler:
        """
        Create a sampler to perform Metropolis-Hastings sampling.

        The sampler will yield state lists, where the state
        values are co-indexed with rvs, or self.rvs if rvs is None.

        After each sample is yielded, the WMC indicator vars will
        be left set as per the yielded states of rvs and conditions.

        Args:
            rvs: the list of random variables to sample; the
                yielded state vectors are co-indexed with rvs; if None,
                then the WMC rvs are used; if rvs is a single
                random variable, then single samples are yielded.
            condition: is a collection of zero or more conditioning indicators.
            skip: is an optional integer >= 0 specifying how many samples to discard
                for each sample provided. Values > 0 can be used to de-correlate adjacent samples.
                Default value = the number of sampled random variables.
            burn_in: how many iterations to perform after initialisation
                before yielding a sample.
            pr_restart: the chance of re-initialising each iteration. If
                restarted then burn-in is performed again.
            rand: provides the stream of random numbers.

        Returns:
            a Sampler object (WMCMetropolisSampler).

        Raises:
            RuntimeError: if skip (when given) or burn_in is negative.
        """
        if skip is not None and skip < 0:
            raise RuntimeError('skip must be non-negative')
        if burn_in < 0:
            raise RuntimeError('burn_in must be non-negative')

        sampler_info: SamplerInfo = get_sampler_info(
            program_with_slotmap=self,
            rvs=rvs,
            condition=condition,
        )

        if skip is None:
            skip = len(sampler_info.sample_rvs)

        return WMCMetropolisSampler(
            sampler_info=sampler_info,
            raw_program=self._raw_program,
            rand=rand,
            skip=skip,
            burn_in=burn_in,
            pr_restart=pr_restart,
        )
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .ace import compile_pgm, copy_ace_to_default_location, default_ace_location, ace_available
|
|
@@ -0,0 +1,299 @@
|
|
|
1
|
+
import shutil
|
|
2
|
+
import subprocess
|
|
3
|
+
import sys
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Optional, List, Tuple
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
|
|
10
|
+
from ck.circuit import CircuitNode, Circuit
|
|
11
|
+
from ck.in_out.parse_ace_lmap import read_lmap, LiteralMap
|
|
12
|
+
from ck.in_out.parse_ace_nnf import read_nnf_with_literal_map
|
|
13
|
+
from ck.in_out.render_net import render_bayesian_network
|
|
14
|
+
from ck.pgm import PGM
|
|
15
|
+
from ck.pgm_circuit import PGMCircuit
|
|
16
|
+
from ck.pgm_circuit.slot_map import SlotMap
|
|
17
|
+
from ck.utils.local_config import config
|
|
18
|
+
from ck.utils.np_extras import NDArrayFloat64
|
|
19
|
+
from ck.utils.tmp_dir import tmp_dir
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def compile_pgm(
        pgm: PGM,
        const_parameters: bool = True,
        *,
        ace_dir: Optional[Path | str] = None,
        jar_dir: Optional[Path | str] = None,
        print_output: bool = False,
        m_bytes: int = 1512,
        check_is_bayesian_network: bool = True,
) -> PGMCircuit:
    """
    Compile the PGM to an arithmetic circuit, using Ace.

    This is a wrapper for Ace, which compiles a Bayesian network into an
    arithmetic circuit. Ace is provided by the Automated Reasoning Group,
    University of California Los Angeles, and requires the Java Runtime
    Environment (JRE) version 8 or higher.
    See http://reasoning.cs.ucla.edu/ace/

    Conforms to the `PGMCompiler` protocol.

    Args:
        pgm: The PGM to compile.
        const_parameters: If true, the potential function parameters will be circuit
            constants, otherwise they will be circuit variables.
        ace_dir: Directory containing Ace. If not provided then the default Ace location is used.
        jar_dir: Directory containing Ace jar files. If not provided, then `ace_dir` is used.
        print_output: if true, the output from Ace is printed (not captured).
        m_bytes: requested megabytes for the Java Virtual Machine (using the java "-Xmx" argument).
        check_is_bayesian_network: if true, then the PGM will be checked to confirm it is a Bayesian network.

    Returns:
        a PGMCircuit object.

    Raises:
        RuntimeError: if Ace files are not found, including a helpful message.
        ValueError: if `check_is_bayesian_network` is true and the PGM is not a Bayesian network.
        CalledProcessError: if executing Ace failed.
        OSError: (e.g. FileNotFoundError) if the Java executable cannot be launched.
    """
    if check_is_bayesian_network and not pgm.check_is_bayesian_network():
        raise ValueError('the given PGM is not a Bayesian network')

    # Ace cannot process the empty PGM (even though it is a valid Bayesian network),
    # so build the trivial circuit directly: constant 1 over the indicator variables.
    if pgm.number_of_factors == 0:
        trivial_circuit = Circuit()
        trivial_circuit.new_vars(pgm.number_of_indicators)
        return PGMCircuit(
            rvs=pgm.rvs,
            conditions=(),
            circuit_top=trivial_circuit.const(1),
            number_of_indicators=pgm.number_of_indicators,
            number_of_parameters=0,
            slot_map={indicator: slot for slot, indicator in enumerate(pgm.indicators)},
            parameter_values=np.array([], dtype=np.float64),
        )

    java_exe, cp_separator = _find_java()
    ace_files: _AceFiles = _find_ace_files(ace_dir, jar_dir)
    net_file = 'to_compile.net'
    jars = (ace_files.ace_jar, ace_files.inflib_jar, ace_files.jdom_jar)
    command: List[str] = [
        java_exe,
        '-cp', cp_separator.join(map(str, jars)),
        f'-DACEC2D={ace_files.c2d}',
        f'-Xmx{int(m_bytes)}m',
        'edu.ucla.belief.ace.AceCompile',
        net_file,
    ]
    capture: bool = not print_output

    with tmp_dir():
        # Write the PGM in .net format for Ace to read.
        with open(net_file, 'w') as net_out:
            node_names: List[str] = render_bayesian_network(pgm, net_out, check_structure_bayesian=False)

        completed: subprocess.CompletedProcess = subprocess.run(command, capture_output=capture, text=True)
        if completed.returncode != 0:
            raise subprocess.CalledProcessError(
                returncode=completed.returncode,
                cmd=' '.join(command),
                output=completed.stdout if capture else None,
                stderr=completed.stderr if capture else None,
            )

        # Ace writes a literal map (.lmap) and an arithmetic circuit (.ac) next to the input.
        with open(f'{net_file}.lmap', 'r') as lmap_in:
            literal_map: LiteralMap = read_lmap(lmap_in, node_names=node_names)

        with open(f'{net_file}.ac', 'r') as ac_in:
            circuit_top, slot_map, parameter_values = read_nnf_with_literal_map(
                ac_in,
                indicators=pgm.indicators,
                literal_map=literal_map,
                const_parameters=const_parameters,
            )

    # Every circuit variable must be either an indicator or a parameter.
    number_of_indicators: int = pgm.number_of_indicators
    number_of_parameters: int = parameter_values.shape[0]
    assert circuit_top.circuit.number_of_vars == number_of_indicators + number_of_parameters, 'consistency check'

    return PGMCircuit(
        rvs=pgm.rvs,
        conditions=(),
        circuit_top=circuit_top,
        number_of_indicators=number_of_indicators,
        number_of_parameters=number_of_parameters,
        slot_map=slot_map,
        parameter_values=parameter_values,
    )
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
def ace_available(
        ace_dir: Optional[Path | str] = None,
        jar_dir: Optional[Path | str] = None,
) -> bool:
    """
    Check whether it looks like Ace can be run.

    Args:
        ace_dir: Directory containing Ace. If not provided then the default Ace location is used.
        jar_dir: Directory containing Ace jar files. If not provided, then `ace_dir` is used.

    Returns:
        True if it looks like ACE is available, False otherwise.
        ACE is considered available if the Ace files can be found (in the given
        directories, or the default location) and the Java executable runs successfully.
    """
    try:
        java, _ = _find_java()
        _find_ace_files(ace_dir, jar_dir)

        # `-version` (single dash) is accepted by every Java release, whereas
        # `--version` only exists from Java 9 onwards; Ace supports Java 8.
        java_cmd: List[str] = [java, '-version']
        java_result: subprocess.CompletedProcess = subprocess.run(java_cmd, capture_output=True, text=True)
    except RuntimeError:
        # Unknown platform, or Ace files not found.
        return False
    except OSError:
        # The Java executable could not be launched (e.g. FileNotFoundError: not installed).
        return False

    return java_result.returncode == 0
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
def copy_ace_to_default_location(
        ace_dir: Path | str,
        jar_dir: Optional[Path | str] = None,
) -> None:
    """
    Copy Ace files from the given directories into the default directory.

    The Ace jars and every c2d file found in `ace_dir` are copied, together with
    their permission bits, so the c2d executables remain executable.

    Args:
        ace_dir: Directory containing Ace.
        jar_dir: Directory containing Ace jar files. If not provided, then `ace_dir` is used.

    Raises:
        RuntimeError: if Ace files are not found, or `ace_dir` is the default directory,
            including a helpful message.
        IOError: if the copy fails.

    Assumes:
        ace_dir exists and is not the same as the installation directory.
    """
    install_location: Path = default_ace_location()

    # Compare resolved paths: `ace_dir` may be a str (which never compares equal to
    # a Path), or a relative/differently spelled path to the same directory.
    if ace_dir is None or Path(ace_dir).resolve() == install_location.resolve():
        raise RuntimeError('Ace directory cannot be the default directory')

    files: _AceFiles = _find_ace_files(ace_dir, jar_dir)

    to_copy: List[Path] = [files.ace_jar, files.inflib_jar, files.jdom_jar, *files.c2d_options]

    for file in to_copy:
        # shutil.copy (unlike shutil.copyfile) also copies the permission mode bits,
        # which the c2d executables need.
        shutil.copy(file, install_location / file.name)
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
def default_ace_location() -> Path:
    """
    Return the directory where Ace files are expected by default.

    The local config variable CK_ACE_LOCATION takes precedence. When the
    config provides no such entry, the directory containing this Python
    module is used instead.
    """
    module_dir: Path = Path(__file__).parent
    configured = config.get('CK_ACE_LOCATION', module_dir)
    return Path(configured)
|
|
209
|
+
|
|
210
|
+
|
|
211
|
+
@dataclass
|
|
212
|
+
class _AceFiles:
|
|
213
|
+
ace_jar: Path
|
|
214
|
+
inflib_jar: Path
|
|
215
|
+
jdom_jar: Path
|
|
216
|
+
c2d: Path
|
|
217
|
+
c2d_options: List[Path]
|
|
218
|
+
|
|
219
|
+
|
|
220
|
+
def _find_java() -> Tuple[str, str]:
|
|
221
|
+
"""
|
|
222
|
+
What to call the Java executable and classpath separator.
|
|
223
|
+
|
|
224
|
+
Returns:
|
|
225
|
+
(java, classpath_separator)
|
|
226
|
+
|
|
227
|
+
Raises:
|
|
228
|
+
RuntimeError: if not found, including a helpful message.
|
|
229
|
+
"""
|
|
230
|
+
if sys.platform == 'win32':
|
|
231
|
+
return 'java.exe', ';'
|
|
232
|
+
elif sys.platform == 'darwin':
|
|
233
|
+
return 'java', ':'
|
|
234
|
+
elif sys.platform.startswith('linux'):
|
|
235
|
+
return 'java', ':'
|
|
236
|
+
else:
|
|
237
|
+
raise RuntimeError(f'cannot infer java for platform {sys.platform!r}')
|
|
238
|
+
|
|
239
|
+
|
|
240
|
+
def _find_ace_files(
        ace_dir: Optional[Path | str],
        jar_dir: Optional[Path | str],
) -> _AceFiles:
    """
    Look for the needed Ace files.

    Args:
        ace_dir: directory containing Ace and c2d; if None, `default_ace_location()` is used.
        jar_dir: directory containing the Ace jar files; if None, `ace_dir` is used.

    Returns:
        an `_AceFiles` object. The directories are resolved to absolute paths, so the
        returned paths stay valid if the process's working directory changes before
        they are used (compile_pgm uses them inside `tmp_dir()`, where Ace's output
        files are opened by relative name).

    Raises:
        RuntimeError: if not found, including a helpful message.
    """
    ace_path: Path = (default_ace_location() if ace_dir is None else Path(ace_dir)).resolve()
    jar_path: Path = ace_path if jar_dir is None else Path(jar_dir).resolve()

    if not ace_path.is_dir():
        raise RuntimeError(f'Ace directory does not exist: {ace_path}')
    if not jar_path.is_dir():
        raise RuntimeError(f'Ace jar directory does not exist: {jar_path}')

    ace_jar = jar_path / 'ace.jar'
    inflib_jar = jar_path / 'inflib.jar'
    jdom_jar = jar_path / 'jdom.jar'

    missing: List[str] = [
        jar.name
        for jar in [ace_jar, inflib_jar, jdom_jar]
        if not jar.is_file()
    ]
    if len(missing) > 0:
        raise RuntimeError(f'Ace jars missing (ensure Ace is properly installed): {", ".join(missing)}')

    # Any file whose name starts with 'c2d' is a candidate c2d executable.
    c2d_options: List[Path] = [
        file
        for file in ace_path.iterdir()
        if file.is_file() and file.name.startswith('c2d')
    ]
    c2d: Path
    if len(c2d_options) == 0:
        raise RuntimeError(f'cannot find c2d in the Ace directory: {ace_path}')
    if len(c2d_options) == 1:
        c2d = c2d_options[0]
    else:
        # Several candidates: choose the one conventionally named for this platform.
        if sys.platform == 'win32':
            c2d = ace_path / 'c2d_windows.exe'
        elif sys.platform == 'darwin':
            c2d = ace_path / 'c2d_osx'
        elif sys.platform.startswith('linux'):
            c2d = ace_path / 'c2d_linux'
        else:
            raise RuntimeError(f'cannot infer c2d executable name for platform {sys.platform!r}')

        if not c2d.is_file():
            raise RuntimeError(f'cannot find c2d: {c2d}')

    return _AceFiles(
        c2d=c2d,
        c2d_options=c2d_options,
        ace_jar=ace_jar,
        inflib_jar=inflib_jar,
        jdom_jar=jdom_jar,
    )
|