compiled-knowledge 4.0.0a5__cp313-cp313-macosx_10_13_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of compiled-knowledge might be problematic. Click here for more details.

Files changed (167) hide show
  1. ck/__init__.py +0 -0
  2. ck/circuit/__init__.py +13 -0
  3. ck/circuit/circuit.c +38749 -0
  4. ck/circuit/circuit.cpython-313-darwin.so +0 -0
  5. ck/circuit/circuit_py.py +807 -0
  6. ck/circuit/tmp_const.py +74 -0
  7. ck/circuit_compiler/__init__.py +2 -0
  8. ck/circuit_compiler/circuit_compiler.py +26 -0
  9. ck/circuit_compiler/cython_vm_compiler/__init__.py +1 -0
  10. ck/circuit_compiler/cython_vm_compiler/_compiler.c +17373 -0
  11. ck/circuit_compiler/cython_vm_compiler/_compiler.cpython-313-darwin.so +0 -0
  12. ck/circuit_compiler/cython_vm_compiler/cython_vm_compiler.py +96 -0
  13. ck/circuit_compiler/interpret_compiler.py +223 -0
  14. ck/circuit_compiler/llvm_compiler.py +388 -0
  15. ck/circuit_compiler/llvm_vm_compiler.py +546 -0
  16. ck/circuit_compiler/named_circuit_compilers.py +57 -0
  17. ck/circuit_compiler/support/__init__.py +0 -0
  18. ck/circuit_compiler/support/circuit_analyser.py +81 -0
  19. ck/circuit_compiler/support/input_vars.py +148 -0
  20. ck/circuit_compiler/support/llvm_ir_function.py +234 -0
  21. ck/example/__init__.py +53 -0
  22. ck/example/alarm.py +366 -0
  23. ck/example/asia.py +28 -0
  24. ck/example/binary_clique.py +32 -0
  25. ck/example/bow_tie.py +33 -0
  26. ck/example/cancer.py +37 -0
  27. ck/example/chain.py +38 -0
  28. ck/example/child.py +199 -0
  29. ck/example/clique.py +33 -0
  30. ck/example/cnf_pgm.py +39 -0
  31. ck/example/diamond_square.py +68 -0
  32. ck/example/earthquake.py +36 -0
  33. ck/example/empty.py +10 -0
  34. ck/example/hailfinder.py +539 -0
  35. ck/example/hepar2.py +628 -0
  36. ck/example/insurance.py +504 -0
  37. ck/example/loop.py +40 -0
  38. ck/example/mildew.py +38161 -0
  39. ck/example/munin.py +22982 -0
  40. ck/example/pathfinder.py +53674 -0
  41. ck/example/rain.py +39 -0
  42. ck/example/rectangle.py +161 -0
  43. ck/example/run.py +30 -0
  44. ck/example/sachs.py +129 -0
  45. ck/example/sprinkler.py +30 -0
  46. ck/example/star.py +44 -0
  47. ck/example/stress.py +64 -0
  48. ck/example/student.py +43 -0
  49. ck/example/survey.py +46 -0
  50. ck/example/triangle_square.py +54 -0
  51. ck/example/truss.py +49 -0
  52. ck/in_out/__init__.py +3 -0
  53. ck/in_out/parse_ace_lmap.py +216 -0
  54. ck/in_out/parse_ace_nnf.py +288 -0
  55. ck/in_out/parse_net.py +480 -0
  56. ck/in_out/parser_utils.py +185 -0
  57. ck/in_out/pgm_pickle.py +42 -0
  58. ck/in_out/pgm_python.py +268 -0
  59. ck/in_out/render_bugs.py +111 -0
  60. ck/in_out/render_net.py +177 -0
  61. ck/in_out/render_pomegranate.py +184 -0
  62. ck/pgm.py +3494 -0
  63. ck/pgm_circuit/__init__.py +1 -0
  64. ck/pgm_circuit/marginals_program.py +352 -0
  65. ck/pgm_circuit/mpe_program.py +237 -0
  66. ck/pgm_circuit/pgm_circuit.py +75 -0
  67. ck/pgm_circuit/program_with_slotmap.py +234 -0
  68. ck/pgm_circuit/slot_map.py +35 -0
  69. ck/pgm_circuit/support/__init__.py +0 -0
  70. ck/pgm_circuit/support/compile_circuit.py +83 -0
  71. ck/pgm_circuit/target_marginals_program.py +103 -0
  72. ck/pgm_circuit/wmc_program.py +323 -0
  73. ck/pgm_compiler/__init__.py +2 -0
  74. ck/pgm_compiler/ace/__init__.py +1 -0
  75. ck/pgm_compiler/ace/ace.py +252 -0
  76. ck/pgm_compiler/factor_elimination.py +383 -0
  77. ck/pgm_compiler/named_pgm_compilers.py +63 -0
  78. ck/pgm_compiler/pgm_compiler.py +19 -0
  79. ck/pgm_compiler/recursive_conditioning.py +226 -0
  80. ck/pgm_compiler/support/__init__.py +0 -0
  81. ck/pgm_compiler/support/circuit_table/__init__.py +9 -0
  82. ck/pgm_compiler/support/circuit_table/circuit_table.c +16042 -0
  83. ck/pgm_compiler/support/circuit_table/circuit_table.cpython-313-darwin.so +0 -0
  84. ck/pgm_compiler/support/circuit_table/circuit_table_py.py +269 -0
  85. ck/pgm_compiler/support/clusters.py +556 -0
  86. ck/pgm_compiler/support/factor_tables.py +398 -0
  87. ck/pgm_compiler/support/join_tree.py +275 -0
  88. ck/pgm_compiler/support/named_compiler_maker.py +33 -0
  89. ck/pgm_compiler/variable_elimination.py +89 -0
  90. ck/probability/__init__.py +0 -0
  91. ck/probability/empirical_probability_space.py +47 -0
  92. ck/probability/probability_space.py +568 -0
  93. ck/program/__init__.py +3 -0
  94. ck/program/program.py +129 -0
  95. ck/program/program_buffer.py +180 -0
  96. ck/program/raw_program.py +61 -0
  97. ck/sampling/__init__.py +0 -0
  98. ck/sampling/forward_sampler.py +211 -0
  99. ck/sampling/marginals_direct_sampler.py +113 -0
  100. ck/sampling/sampler.py +62 -0
  101. ck/sampling/sampler_support.py +232 -0
  102. ck/sampling/uniform_sampler.py +66 -0
  103. ck/sampling/wmc_direct_sampler.py +169 -0
  104. ck/sampling/wmc_gibbs_sampler.py +147 -0
  105. ck/sampling/wmc_metropolis_sampler.py +159 -0
  106. ck/sampling/wmc_rejection_sampler.py +113 -0
  107. ck/utils/__init__.py +0 -0
  108. ck/utils/iter_extras.py +153 -0
  109. ck/utils/map_list.py +128 -0
  110. ck/utils/map_set.py +128 -0
  111. ck/utils/np_extras.py +51 -0
  112. ck/utils/random_extras.py +64 -0
  113. ck/utils/tmp_dir.py +94 -0
  114. ck_demos/__init__.py +0 -0
  115. ck_demos/ace/__init__.py +0 -0
  116. ck_demos/ace/copy_ace_to_ck.py +15 -0
  117. ck_demos/ace/demo_ace.py +44 -0
  118. ck_demos/all_demos.py +88 -0
  119. ck_demos/circuit/__init__.py +0 -0
  120. ck_demos/circuit/demo_circuit_dump.py +22 -0
  121. ck_demos/circuit/demo_derivatives.py +43 -0
  122. ck_demos/circuit_compiler/__init__.py +0 -0
  123. ck_demos/circuit_compiler/compare_circuit_compilers.py +32 -0
  124. ck_demos/circuit_compiler/show_llvm_program.py +26 -0
  125. ck_demos/pgm/__init__.py +0 -0
  126. ck_demos/pgm/demo_pgm_dump.py +18 -0
  127. ck_demos/pgm/demo_pgm_dump_stress.py +18 -0
  128. ck_demos/pgm/demo_pgm_string_rendering.py +15 -0
  129. ck_demos/pgm/show_examples.py +25 -0
  130. ck_demos/pgm_compiler/__init__.py +0 -0
  131. ck_demos/pgm_compiler/compare_pgm_compilers.py +50 -0
  132. ck_demos/pgm_compiler/demo_compiler_dump.py +50 -0
  133. ck_demos/pgm_compiler/demo_factor_elimination.py +47 -0
  134. ck_demos/pgm_compiler/demo_join_tree.py +25 -0
  135. ck_demos/pgm_compiler/demo_marginals_program.py +53 -0
  136. ck_demos/pgm_compiler/demo_mpe_program.py +55 -0
  137. ck_demos/pgm_compiler/demo_pgm_compiler.py +38 -0
  138. ck_demos/pgm_compiler/demo_recursive_conditioning.py +33 -0
  139. ck_demos/pgm_compiler/demo_variable_elimination.py +33 -0
  140. ck_demos/pgm_compiler/demo_wmc_program.py +29 -0
  141. ck_demos/pgm_inference/__init__.py +0 -0
  142. ck_demos/pgm_inference/demo_inferencing_basic.py +188 -0
  143. ck_demos/pgm_inference/demo_inferencing_mpe_cancer.py +45 -0
  144. ck_demos/pgm_inference/demo_inferencing_wmc_and_mpe_sprinkler.py +154 -0
  145. ck_demos/pgm_inference/demo_inferencing_wmc_student.py +110 -0
  146. ck_demos/programs/__init__.py +0 -0
  147. ck_demos/programs/demo_program_buffer.py +24 -0
  148. ck_demos/programs/demo_program_multi.py +24 -0
  149. ck_demos/programs/demo_program_none.py +19 -0
  150. ck_demos/programs/demo_program_single.py +23 -0
  151. ck_demos/programs/demo_raw_program_interpreted.py +21 -0
  152. ck_demos/programs/demo_raw_program_llvm.py +21 -0
  153. ck_demos/sampling/__init__.py +0 -0
  154. ck_demos/sampling/check_sampler.py +71 -0
  155. ck_demos/sampling/demo_marginal_direct_sampler.py +40 -0
  156. ck_demos/sampling/demo_uniform_sampler.py +38 -0
  157. ck_demos/sampling/demo_wmc_direct_sampler.py +40 -0
  158. ck_demos/utils/__init__.py +0 -0
  159. ck_demos/utils/compare.py +88 -0
  160. ck_demos/utils/convert_network.py +45 -0
  161. ck_demos/utils/sample_model.py +216 -0
  162. ck_demos/utils/stop_watch.py +384 -0
  163. compiled_knowledge-4.0.0a5.dist-info/METADATA +50 -0
  164. compiled_knowledge-4.0.0a5.dist-info/RECORD +167 -0
  165. compiled_knowledge-4.0.0a5.dist-info/WHEEL +5 -0
  166. compiled_knowledge-4.0.0a5.dist-info/licenses/LICENSE.txt +21 -0
  167. compiled_knowledge-4.0.0a5.dist-info/top_level.txt +2 -0
@@ -0,0 +1,323 @@
1
+ from __future__ import annotations
2
+
3
+ import random
4
+ from typing import Sequence, Optional, Tuple
5
+
6
+ from ck.circuit_compiler import CircuitCompiler
7
+ from ck.pgm import RandomVariable
8
+ from ck.pgm_circuit import PGMCircuit
9
+ from ck.pgm_circuit.program_with_slotmap import ProgramWithSlotmap
10
+ from ck.pgm_circuit.support.compile_circuit import compile_results, DEFAULT_CIRCUIT_COMPILER
11
+ from ck.probability.probability_space import ProbabilitySpace, Condition
12
+ from ck.program.program_buffer import ProgramBuffer
13
+ from ck.program.raw_program import RawProgram
14
+ from ck.sampling.sampler import Sampler
15
+ from ck.sampling.sampler_support import SamplerInfo, get_sampler_info
16
+ from ck.sampling.uniform_sampler import UniformSampler
17
+ from ck.sampling.wmc_direct_sampler import WMCDirectSampler
18
+ from ck.sampling.wmc_gibbs_sampler import WMCGibbsSampler
19
+ from ck.sampling.wmc_metropolis_sampler import WMCMetropolisSampler
20
+ from ck.sampling.wmc_rejection_sampler import WMCRejectionSampler
21
+ from ck.utils.np_extras import NDArray
22
+ from ck.utils.random_extras import Random
23
+
24
+
25
class WMCProgram(ProgramWithSlotmap, ProbabilitySpace):
    """
    A class for computing Weighted Model Count (WMC).

    The compiled program has a single result, `pgm_circuit.circuit_top`.
    The first `pgm_circuit.number_of_indicators` input slots are indicator
    slots; when `const_parameters` is False, the remaining slots are loaded
    with `pgm_circuit.parameter_values`.
    """

    def __init__(
            self,
            pgm_circuit: PGMCircuit,
            const_parameters: bool = True,
            compiler: CircuitCompiler = DEFAULT_CIRCUIT_COMPILER,
    ):
        """
        Construct a WMCProgram object.

        Args:
            pgm_circuit: The circuit representing a PGM.
            const_parameters: if True then any circuit variable representing a parameter value will
                be made 'const' in the resulting program.
            compiler: the circuit compiler used to turn `pgm_circuit.circuit_top` into a program.
        """
        raw_program: RawProgram = compile_results(
            pgm_circuit=pgm_circuit,
            results=(pgm_circuit.circuit_top,),
            const_parameters=const_parameters,
            compiler=compiler,
        )
        ProgramWithSlotmap.__init__(
            self,
            ProgramBuffer(raw_program),
            pgm_circuit.slot_map,
            pgm_circuit.rvs,
            pgm_circuit.conditions,
        )
        self._raw_program: RawProgram = raw_program
        self._number_of_indicators: int = pgm_circuit.number_of_indicators
        self._z_cache: Optional[float] = None

        if not const_parameters:
            # Parameters are program inputs (not constants), so load their values
            # into the slots that follow the indicator slots.
            self.vars[pgm_circuit.number_of_indicators:] = pgm_circuit.parameter_values

    def wmc(self, *condition: Condition) -> float:
        """
        Compute the weighted model count under the given condition.

        This calls `self.set_condition(*condition)` before computing and does
        not restore any previously set condition afterwards.

        Returns:
            the single program result as a Python float.
        """
        self.set_condition(*condition)
        return self.compute().item()

    @property
    def z(self) -> float:
        """
        The WMC computed with every indicator slot set to 1.

        The value is computed on first access and cached. The indicator slots
        are restored to their previous values afterwards, even if the
        computation raises.

        NOTE(review): the cache is never invalidated. If parameter slots are
        changed after the first access (possible when const_parameters=False),
        the cached value may be stale. Confirm whether callers rely on this.
        """
        if self._z_cache is None:
            number_of_indicators: int = self._number_of_indicators
            slots: NDArray = self.vars
            old_vals: NDArray = slots[:number_of_indicators].copy()
            slots[:number_of_indicators] = 1
            try:
                self._z_cache = self.compute().item()
            finally:
                # Always put back the caller's indicator values (e.g., the current condition).
                slots[:number_of_indicators] = old_vals

        return self._z_cache

    def sample_uniform(
            self,
            rvs: Optional[RandomVariable | Sequence[RandomVariable]] = None,
            *,
            condition: Condition = (),
            rand: Random = random,
    ) -> Sampler:
        """
        Create a sampler that performs uniform sampling of
        the state space of the given random variables, rvs.

        The sampler will yield state lists, where the state
        values are co-indexed with rvs, or self.rvs if rvs is None.

        This sampler is not affected by and does not affect
        the state of input slots.

        Args:
            rvs: the list of random variables to sample; the
                yielded state vectors are co-indexed with rvs; if None,
                then the self.rvs are used; if rvs is a single
                random variable, then single samples are yielded.
            condition: is a collection of zero or more conditioning indicators.
            rand: provides the stream of random numbers.

        Returns:
            a Sampler object (UniformSampler).
        """
        return UniformSampler(
            rvs=(self.rvs if rvs is None else rvs),
            condition=condition,
            rand=rand,
        )

    def sample_direct(
            self,
            rvs: Optional[RandomVariable | Sequence[RandomVariable]] = None,
            *,
            condition: Condition = (),
            rand: Random = random,
            chain_pairs: Sequence[Tuple[RandomVariable, RandomVariable]] = (),
            initial_chain_condition: Condition = (),
    ) -> Sampler:
        """
        Create an inverse-transform sampler, which uses the fact that
        probabilities are exactly computable using a WMC.

        The sampler will yield state lists, where the state
        values are co-indexed with rvs, or self.rvs if rvs is None.

        Given 'n' random variables, and 'm' number of indicators, for each yielded sample, this method:
            * calls rand.random() once and rand.randrange(...) n times,
            * evaluates the compiled WMC program at least once and <= 1 + m times.

        Args:
            rvs: the list of random variables to sample; the
                yielded state vectors are co-indexed with rvs; if None,
                then the WMC rvs are used; if rvs is a single
                random variable, then single samples are yielded.
            condition: is a collection of zero or more conditioning indicators.
            rand: provides the stream of random numbers.
            chain_pairs: is a collection of pairs of random variables, each random variable
                must be in the given rvs. Given a pair (from_rv, to_rv) the state of from_rv is used
                as a condition for to_rv prior to generating a sample.
            initial_chain_condition: are condition indicators (just like condition)
                for the initialisation of the 'to_rv' random variables mentioned in chain_pairs.

        Returns:
            a Sampler object (WMCDirectSampler).
        """
        sampler_info: SamplerInfo = get_sampler_info(
            program_with_slotmap=self,
            rvs=rvs,
            condition=condition,
            chain_pairs=chain_pairs,
            initial_chain_condition=initial_chain_condition,
        )

        return WMCDirectSampler(
            sampler_info=sampler_info,
            raw_program=self._raw_program,
            rand=rand,
        )

    def sample_rejection(
            self,
            rvs: Optional[RandomVariable | Sequence[RandomVariable]] = None,
            *,
            condition: Condition = (),
            rand: Random = random,
    ) -> Sampler:
        """
        Create a sampler to perform rejection sampling.

        The sampler will yield state lists, where the state
        values are co-indexed with rvs, or self.rvs if rvs is None.

        The method uniformly samples states and uses an adaptive 'max weight'
        to reduce unnecessary rejection.

        After each sample is yielded, the WMC indicator variables will
        be left set as per the yielded states of rvs and conditions.

        Args:
            rvs: the list of random variables to sample; the
                yielded state vectors are co-indexed with rvs; if None,
                then the WMC rvs are used; if rvs is a single
                random variable, then single samples are yielded.
            condition: is a collection of zero or more conditioning indicators.
            rand: provides the stream of random numbers.

        Returns:
            a Sampler object (WMCRejectionSampler).
        """
        sampler_info: SamplerInfo = get_sampler_info(
            program_with_slotmap=self,
            rvs=rvs,
            condition=condition,
        )
        # The WMC under the condition is passed to the sampler as `z`.
        # NOTE(review): `*condition` assumes `condition` is an iterable of
        # indicators; a single indicator or None would not unpack as intended.
        # Confirm against `set_condition`.
        z = self.wmc(*condition)

        return WMCRejectionSampler(
            sampler_info=sampler_info,
            raw_program=self._raw_program,
            rand=rand,
            z=z,
        )

    def sample_gibbs(
            self,
            rvs: Optional[RandomVariable | Sequence[RandomVariable]] = None,
            *,
            condition: Condition = (),
            skip: int = 0,
            burn_in: int = 0,
            pr_restart: float = 0,
            rand: Random = random,
    ) -> Sampler:
        """
        Create a sampler to perform Gibbs sampling.

        The sampler will yield state lists, where the state
        values are co-indexed with rvs, or self.rvs if rvs is None.

        After each sample is yielded, the WMC indicator vars will
        be left set as per the yielded states of rvs and conditions.

        Args:
            rvs: the list of random variables to sample; the
                yielded state vectors are co-indexed with rvs; if None,
                then the WMC rvs are used; if rvs is a single
                random variable, then single samples are yielded.
            condition: is a collection of zero or more conditioning indicators.
            skip: is an integer >= 0 specifying how many samples to discard
                for each sample provided. Values > 0 can be used to de-correlate adjacent samples.
            burn_in: how many iterations to perform after
                initialisation before yielding a sample.
            pr_restart: the chance of re-initialising each
                iteration. If restarted then burn-in is performed again.
            rand: provides the stream of random numbers.

        Returns:
            a Sampler object (WMCGibbsSampler).

        Raises:
            RuntimeError: if skip or burn_in is negative.
        """
        if skip < 0:
            raise RuntimeError('skip must be non-negative')
        if burn_in < 0:
            raise RuntimeError('burn_in must be non-negative')

        sampler_info: SamplerInfo = get_sampler_info(
            program_with_slotmap=self,
            rvs=rvs,
            condition=condition,
        )

        return WMCGibbsSampler(
            sampler_info=sampler_info,
            raw_program=self._raw_program,
            rand=rand,
            skip=skip,
            burn_in=burn_in,
            pr_restart=pr_restart,
        )

    def sample_metropolis(
            self,
            rvs: Optional[RandomVariable | Sequence[RandomVariable]] = None,
            *,
            condition: Condition = (),
            skip: Optional[int] = None,
            burn_in: int = 0,
            pr_restart: float = 0,
            rand: Random = random,
    ) -> Sampler:
        """
        Create a sampler to perform Metropolis-Hastings sampling.

        The sampler will yield state lists, where the state
        values are co-indexed with rvs, or self.rvs if rvs is None.

        After each sample is yielded, the WMC indicator vars will
        be left set as per the yielded states of rvs and conditions.

        Args:
            rvs: the list of random variables to sample; the
                yielded state vectors are co-indexed with rvs; if None,
                then the WMC rvs are used; if rvs is a single
                random variable, then single samples are yielded.
            condition: is a collection of zero or more conditioning indicators.
            skip: is an optional integer >= 0 specifying how many samples to discard
                for each sample provided. Values > 0 can be used to de-correlate adjacent samples.
                Default value = the number of sampled random variables.
            burn_in: how many iterations to perform after initialisation
                before yielding a sample.
            pr_restart: the chance of re-initialising each iteration. If
                restarted then burn-in is performed again.
            rand: provides the stream of random numbers.

        Returns:
            a Sampler object (WMCMetropolisSampler).

        Raises:
            RuntimeError: if skip (when given) or burn_in is negative.
        """
        if skip is not None and skip < 0:
            raise RuntimeError('skip must be non-negative')
        if burn_in < 0:
            raise RuntimeError('burn_in must be non-negative')

        sampler_info: SamplerInfo = get_sampler_info(
            program_with_slotmap=self,
            rvs=rvs,
            condition=condition,
        )

        if skip is None:
            skip = len(sampler_info.sample_rvs)

        return WMCMetropolisSampler(
            sampler_info=sampler_info,
            raw_program=self._raw_program,
            rand=rand,
            skip=skip,
            burn_in=burn_in,
            pr_restart=pr_restart,
        )
@@ -0,0 +1,2 @@
1
+ from .pgm_compiler import PGMCompiler
2
+ from .named_pgm_compilers import NamedPGMCompiler, DEFAULT_PGM_COMPILER
@@ -0,0 +1 @@
1
+ from .ace import compile_pgm, copy_ace_to_default_location, default_ace_location
@@ -0,0 +1,252 @@
1
+ import shutil
2
+ import subprocess
3
+ import sys
4
+ from dataclasses import dataclass
5
+ from pathlib import Path
6
+ from typing import Optional, List, Tuple
7
+
8
+ from ck.circuit import CircuitNode
9
+ from ck.in_out.parse_ace_lmap import read_lmap, LiteralMap
10
+ from ck.in_out.parse_ace_nnf import read_nnf_with_literal_map
11
+ from ck.in_out.render_net import render_bayesian_network
12
+ from ck.pgm import PGM
13
+ from ck.pgm_circuit import PGMCircuit
14
+ from ck.pgm_circuit.slot_map import SlotMap
15
+ from ck.utils.np_extras import NDArrayFloat64
16
+ from ck.utils.tmp_dir import tmp_dir
17
+
18
+
19
def compile_pgm(
        pgm: PGM,
        const_parameters: bool = True,
        *,
        ace_dir: Optional[Path | str] = None,
        jar_dir: Optional[Path | str] = None,
        print_output: bool = False,
        m_bytes: int = 1512,
        check_is_bayesian_network: bool = True,
) -> PGMCircuit:
    """
    Compile the PGM to an arithmetic circuit, using Ace.

    This is a wrapper for Ace.
    Ace compiles a Bayesian network into an Arithmetic Circuit.
    Provided by the Automated Reasoning Group, University of California Los Angeles.
    Ace requires the Java Runtime Environment (JRE) version 8 or higher.
    See http://reasoning.cs.ucla.edu/ace/

    Conforms to the `PGMCompiler` protocol.

    Args:
        pgm: The PGM to compile.
        const_parameters: If true, the potential function parameters will be circuit
            constants, otherwise they will be circuit variables.
        ace_dir: Directory containing Ace. If not provided then the directory this module is in is used.
        jar_dir: Directory containing Ace jar files. If not provided, then `ace_dir` is used.
        print_output: if true, the output from Ace is printed.
        m_bytes: requested megabytes for the Java Virtual Machine (using the java "-Xmx" argument).
        check_is_bayesian_network: if true, then the PGM will be checked to confirm it is a Bayesian network.

    Returns:
        a PGMCircuit object.

    Raises:
        RuntimeError: if Ace files are not found, including a helpful message.
        ValueError: if `check_is_bayesian_network` is true and the PGM is not a Bayesian network.
        CalledProcessError: if executing Ace failed.
        FileNotFoundError: if the java executable cannot be found (raised by `subprocess.run`).
        AssertionError: if the parsed circuit fails the internal consistency check.
    """
    if check_is_bayesian_network and not pgm.check_is_bayesian_network():
        raise ValueError('the given PGM is not a Bayesian network')

    java: str
    classpath_separator: str
    java, classpath_separator = _find_java()
    files: _AceFiles = _find_ace_files(ace_dir, jar_dir)
    net_file_name = 'to_compile.net'
    main_class = 'edu.ucla.belief.ace.AceCompile'
    class_path: str = classpath_separator.join(
        str(f) for f in [files.ace_jar, files.inflib_jar, files.jdom_jar]
    )
    # Ace is told where c2d lives via the ACEC2D system property.
    ace_cmd: List[str] = [
        java,
        '-cp',
        class_path,
        f'-DACEC2D={files.c2d}',
        f'-Xmx{int(m_bytes)}m',
        main_class,
        net_file_name,
    ]

    with tmp_dir():
        # Render the PGM to a .net file to be read by Ace
        with open(net_file_name, 'w') as file:
            node_names: List[str] = render_bayesian_network(pgm, file, check_structure_bayesian=False)

        # Run Ace (argument list, no shell)
        ace_result = subprocess.run(ace_cmd, capture_output=True, text=True)
        if print_output:
            print(ace_result.stdout)
            print(ace_result.stderr)
        if ace_result.returncode != 0:
            raise subprocess.CalledProcessError(
                returncode=ace_result.returncode,
                cmd=' '.join(ace_cmd),
                output=ace_result.stdout,
                stderr=ace_result.stderr,
            )

        # Parse the literal map output from Ace
        with open(f'{net_file_name}.lmap', 'r') as file:
            literal_map: LiteralMap = read_lmap(file, node_names=node_names)

        # Parse the arithmetic circuit output from Ace
        with open(f'{net_file_name}.ac', 'r') as file:
            circuit_top: CircuitNode
            slot_map: SlotMap
            parameter_values: NDArrayFloat64
            circuit_top, slot_map, parameter_values = read_nnf_with_literal_map(
                file,
                literal_map=literal_map,
                const_parameters=const_parameters
            )

    # Consistency checking: the circuit should have one variable per indicator
    # plus one per parameter.
    number_of_indicators: int = pgm.number_of_indicators
    number_of_parameters: int = parameter_values.shape[0]
    assert circuit_top.circuit.number_of_vars == number_of_indicators + number_of_parameters, 'consistency check'

    return PGMCircuit(
        rvs=pgm.rvs,
        conditions=(),
        circuit_top=circuit_top,
        number_of_indicators=number_of_indicators,
        number_of_parameters=number_of_parameters,
        slot_map=slot_map,
        parameter_values=parameter_values,
    )
127
+
128
+
129
def copy_ace_to_default_location(
        ace_dir: Path | str,
        jar_dir: Optional[Path | str] = None,
) -> None:
    """
    Copy Ace files from the given directories into the default directory.

    Copies ace.jar, inflib.jar, jdom.jar and every 'c2d*' file found in `ace_dir`.
    File permission bits are preserved, so copied c2d executables stay executable.

    Args:
        ace_dir: Directory containing Ace.
        jar_dir: Directory containing Ace jar files. If not provided, then `ace_dir` is used.

    Raises:
        RuntimeError: if `ace_dir` is None or refers to the default directory, or if
            Ace files are not found, including a helpful message.
        IOError: if the copy fails.

    Assumes:
        ace_dir exists.
    """
    install_location: Path = default_ace_location()

    # Compare resolved paths: a plain `==` between a `str` and a `Path` is always
    # False, and relative or symlinked spellings of the same directory would also
    # slip through.
    if ace_dir is None or Path(ace_dir).resolve() == install_location.resolve():
        raise RuntimeError('Ace directory cannot be the default directory')

    files: _AceFiles = _find_ace_files(ace_dir, jar_dir)

    to_copy: List[Path] = [files.ace_jar, files.inflib_jar, files.jdom_jar] + files.c2d_options

    for file in to_copy:
        # shutil.copy (unlike shutil.copyfile) also copies the permission mode,
        # which keeps the c2d binaries executable.
        shutil.copy(file, install_location / file.name)
158
+
159
+
160
def default_ace_location() -> Path:
    """
    Return the directory where Ace files are looked for by default.

    This is the directory that contains this module.
    """
    this_module = Path(__file__)
    return this_module.parent
165
+
166
+
167
+ @dataclass
168
+ class _AceFiles:
169
+ ace_jar: Path
170
+ inflib_jar: Path
171
+ jdom_jar: Path
172
+ c2d: Path
173
+ c2d_options: List[Path]
174
+
175
+
176
+ def _find_java() -> Tuple[str, str]:
177
+ """
178
+ What to call the Java executable and classpath separator.
179
+
180
+ Returns:
181
+ (java, classpath_separator)
182
+ """
183
+ if sys.platform == 'win32':
184
+ return 'java.exe', ';'
185
+ elif sys.platform == 'darwin':
186
+ return 'java', ':'
187
+ elif sys.platform.startswith('linux'):
188
+ return 'java', ':'
189
+ else:
190
+ raise RuntimeError(f'cannot infer java for platform {sys.platform!r}')
191
+
192
+
193
def _find_ace_files(
        ace_dir: Optional[Path | str],
        jar_dir: Optional[Path | str],
) -> _AceFiles:
    """
    Look for the needed Ace files.

    Args:
        ace_dir: directory containing Ace and its c2d file(s); if None,
            `default_ace_location()` is used.
        jar_dir: directory containing the Ace jar files; if None, the Ace directory is used.

    Returns:
        an _AceFiles object holding the located files.

    Raises:
        RuntimeError: if not found, including a helpful message.
    """
    # Use new local names rather than rebinding the parameters with a different type.
    ace_path: Path = default_ace_location() if ace_dir is None else Path(ace_dir)
    jar_path: Path = ace_path if jar_dir is None else Path(jar_dir)

    if not ace_path.is_dir():
        raise RuntimeError(f'Ace directory does not exist: {ace_path}')
    if not jar_path.is_dir():
        raise RuntimeError(f'Ace jar directory does not exist: {jar_path}')

    ace_jar = jar_path / 'ace.jar'
    inflib_jar = jar_path / 'inflib.jar'
    jdom_jar = jar_path / 'jdom.jar'

    missing: List[str] = [
        jar.name
        for jar in (ace_jar, inflib_jar, jdom_jar)
        if not jar.is_file()
    ]
    if missing:
        raise RuntimeError(f'Ace jars missing (ensure Ace is properly installed): {", ".join(missing)}')

    # Every file whose name starts with 'c2d' is a candidate; sorted so the
    # result does not depend on directory iteration order.
    c2d_options: List[Path] = sorted(
        file
        for file in ace_path.iterdir()
        if file.is_file() and file.name.startswith('c2d')
    )
    if not c2d_options:
        raise RuntimeError(f'cannot find c2d in the Ace directory: {ace_path}')

    c2d: Path
    if len(c2d_options) == 1:
        # Only one candidate: use it regardless of its name.
        c2d = c2d_options[0]
    elif sys.platform == 'win32':
        c2d = ace_path / 'c2d_windows.exe'
    elif sys.platform == 'darwin':
        c2d = ace_path / 'c2d_osx'
    elif sys.platform.startswith('linux'):
        c2d = ace_path / 'c2d_linux'
    else:
        raise RuntimeError(f'cannot infer c2d executable name for platform {sys.platform!r}')

    if not c2d.is_file():
        raise RuntimeError(f'cannot find c2d: {c2d}')

    return _AceFiles(
        c2d=c2d,
        c2d_options=c2d_options,
        ace_jar=ace_jar,
        inflib_jar=inflib_jar,
        jdom_jar=jdom_jar,
    )