compiled-knowledge 4.0.0a20__cp312-cp312-macosx_10_13_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of compiled-knowledge might be problematic. Click here for more details.

Files changed (178) hide show
  1. ck/__init__.py +0 -0
  2. ck/circuit/__init__.py +17 -0
  3. ck/circuit/_circuit_cy.c +37525 -0
  4. ck/circuit/_circuit_cy.cpython-312-darwin.so +0 -0
  5. ck/circuit/_circuit_cy.pxd +32 -0
  6. ck/circuit/_circuit_cy.pyx +768 -0
  7. ck/circuit/_circuit_py.py +836 -0
  8. ck/circuit/tmp_const.py +74 -0
  9. ck/circuit_compiler/__init__.py +2 -0
  10. ck/circuit_compiler/circuit_compiler.py +26 -0
  11. ck/circuit_compiler/cython_vm_compiler/__init__.py +1 -0
  12. ck/circuit_compiler/cython_vm_compiler/_compiler.c +19826 -0
  13. ck/circuit_compiler/cython_vm_compiler/_compiler.cpython-312-darwin.so +0 -0
  14. ck/circuit_compiler/cython_vm_compiler/_compiler.pyx +380 -0
  15. ck/circuit_compiler/cython_vm_compiler/cython_vm_compiler.py +121 -0
  16. ck/circuit_compiler/interpret_compiler.py +223 -0
  17. ck/circuit_compiler/llvm_compiler.py +388 -0
  18. ck/circuit_compiler/llvm_vm_compiler.py +546 -0
  19. ck/circuit_compiler/named_circuit_compilers.py +57 -0
  20. ck/circuit_compiler/support/__init__.py +0 -0
  21. ck/circuit_compiler/support/circuit_analyser/__init__.py +13 -0
  22. ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.c +10620 -0
  23. ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.cpython-312-darwin.so +0 -0
  24. ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.pyx +98 -0
  25. ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_py.py +93 -0
  26. ck/circuit_compiler/support/input_vars.py +148 -0
  27. ck/circuit_compiler/support/llvm_ir_function.py +234 -0
  28. ck/example/__init__.py +53 -0
  29. ck/example/alarm.py +366 -0
  30. ck/example/asia.py +28 -0
  31. ck/example/binary_clique.py +32 -0
  32. ck/example/bow_tie.py +33 -0
  33. ck/example/cancer.py +37 -0
  34. ck/example/chain.py +38 -0
  35. ck/example/child.py +199 -0
  36. ck/example/clique.py +33 -0
  37. ck/example/cnf_pgm.py +39 -0
  38. ck/example/diamond_square.py +68 -0
  39. ck/example/earthquake.py +36 -0
  40. ck/example/empty.py +10 -0
  41. ck/example/hailfinder.py +539 -0
  42. ck/example/hepar2.py +628 -0
  43. ck/example/insurance.py +504 -0
  44. ck/example/loop.py +40 -0
  45. ck/example/mildew.py +38161 -0
  46. ck/example/munin.py +22982 -0
  47. ck/example/pathfinder.py +53747 -0
  48. ck/example/rain.py +39 -0
  49. ck/example/rectangle.py +161 -0
  50. ck/example/run.py +30 -0
  51. ck/example/sachs.py +129 -0
  52. ck/example/sprinkler.py +30 -0
  53. ck/example/star.py +44 -0
  54. ck/example/stress.py +64 -0
  55. ck/example/student.py +43 -0
  56. ck/example/survey.py +46 -0
  57. ck/example/triangle_square.py +54 -0
  58. ck/example/truss.py +49 -0
  59. ck/in_out/__init__.py +3 -0
  60. ck/in_out/parse_ace_lmap.py +216 -0
  61. ck/in_out/parse_ace_nnf.py +322 -0
  62. ck/in_out/parse_net.py +480 -0
  63. ck/in_out/parser_utils.py +185 -0
  64. ck/in_out/pgm_pickle.py +42 -0
  65. ck/in_out/pgm_python.py +268 -0
  66. ck/in_out/render_bugs.py +111 -0
  67. ck/in_out/render_net.py +177 -0
  68. ck/in_out/render_pomegranate.py +184 -0
  69. ck/pgm.py +3475 -0
  70. ck/pgm_circuit/__init__.py +1 -0
  71. ck/pgm_circuit/marginals_program.py +352 -0
  72. ck/pgm_circuit/mpe_program.py +237 -0
  73. ck/pgm_circuit/pgm_circuit.py +79 -0
  74. ck/pgm_circuit/program_with_slotmap.py +236 -0
  75. ck/pgm_circuit/slot_map.py +35 -0
  76. ck/pgm_circuit/support/__init__.py +0 -0
  77. ck/pgm_circuit/support/compile_circuit.py +83 -0
  78. ck/pgm_circuit/target_marginals_program.py +103 -0
  79. ck/pgm_circuit/wmc_program.py +323 -0
  80. ck/pgm_compiler/__init__.py +2 -0
  81. ck/pgm_compiler/ace/__init__.py +1 -0
  82. ck/pgm_compiler/ace/ace.py +299 -0
  83. ck/pgm_compiler/factor_elimination.py +395 -0
  84. ck/pgm_compiler/named_pgm_compilers.py +63 -0
  85. ck/pgm_compiler/pgm_compiler.py +19 -0
  86. ck/pgm_compiler/recursive_conditioning.py +231 -0
  87. ck/pgm_compiler/support/__init__.py +0 -0
  88. ck/pgm_compiler/support/circuit_table/__init__.py +17 -0
  89. ck/pgm_compiler/support/circuit_table/_circuit_table_cy.c +16398 -0
  90. ck/pgm_compiler/support/circuit_table/_circuit_table_cy.cpython-312-darwin.so +0 -0
  91. ck/pgm_compiler/support/circuit_table/_circuit_table_cy.pyx +332 -0
  92. ck/pgm_compiler/support/circuit_table/_circuit_table_py.py +304 -0
  93. ck/pgm_compiler/support/clusters.py +568 -0
  94. ck/pgm_compiler/support/factor_tables.py +406 -0
  95. ck/pgm_compiler/support/join_tree.py +332 -0
  96. ck/pgm_compiler/support/named_compiler_maker.py +43 -0
  97. ck/pgm_compiler/variable_elimination.py +91 -0
  98. ck/probability/__init__.py +0 -0
  99. ck/probability/empirical_probability_space.py +50 -0
  100. ck/probability/pgm_probability_space.py +32 -0
  101. ck/probability/probability_space.py +622 -0
  102. ck/program/__init__.py +3 -0
  103. ck/program/program.py +137 -0
  104. ck/program/program_buffer.py +180 -0
  105. ck/program/raw_program.py +67 -0
  106. ck/sampling/__init__.py +0 -0
  107. ck/sampling/forward_sampler.py +211 -0
  108. ck/sampling/marginals_direct_sampler.py +113 -0
  109. ck/sampling/sampler.py +62 -0
  110. ck/sampling/sampler_support.py +232 -0
  111. ck/sampling/uniform_sampler.py +72 -0
  112. ck/sampling/wmc_direct_sampler.py +171 -0
  113. ck/sampling/wmc_gibbs_sampler.py +153 -0
  114. ck/sampling/wmc_metropolis_sampler.py +165 -0
  115. ck/sampling/wmc_rejection_sampler.py +115 -0
  116. ck/utils/__init__.py +0 -0
  117. ck/utils/iter_extras.py +163 -0
  118. ck/utils/local_config.py +270 -0
  119. ck/utils/map_list.py +128 -0
  120. ck/utils/map_set.py +128 -0
  121. ck/utils/np_extras.py +51 -0
  122. ck/utils/random_extras.py +64 -0
  123. ck/utils/tmp_dir.py +94 -0
  124. ck_demos/__init__.py +0 -0
  125. ck_demos/ace/__init__.py +0 -0
  126. ck_demos/ace/copy_ace_to_ck.py +15 -0
  127. ck_demos/ace/demo_ace.py +49 -0
  128. ck_demos/all_demos.py +88 -0
  129. ck_demos/circuit/__init__.py +0 -0
  130. ck_demos/circuit/demo_circuit_dump.py +22 -0
  131. ck_demos/circuit/demo_derivatives.py +43 -0
  132. ck_demos/circuit_compiler/__init__.py +0 -0
  133. ck_demos/circuit_compiler/compare_circuit_compilers.py +32 -0
  134. ck_demos/circuit_compiler/show_llvm_program.py +26 -0
  135. ck_demos/pgm/__init__.py +0 -0
  136. ck_demos/pgm/demo_pgm_dump.py +18 -0
  137. ck_demos/pgm/demo_pgm_dump_stress.py +18 -0
  138. ck_demos/pgm/demo_pgm_string_rendering.py +15 -0
  139. ck_demos/pgm/show_examples.py +25 -0
  140. ck_demos/pgm_compiler/__init__.py +0 -0
  141. ck_demos/pgm_compiler/compare_pgm_compilers.py +63 -0
  142. ck_demos/pgm_compiler/demo_compiler_dump.py +60 -0
  143. ck_demos/pgm_compiler/demo_factor_elimination.py +47 -0
  144. ck_demos/pgm_compiler/demo_join_tree.py +25 -0
  145. ck_demos/pgm_compiler/demo_marginals_program.py +53 -0
  146. ck_demos/pgm_compiler/demo_mpe_program.py +55 -0
  147. ck_demos/pgm_compiler/demo_pgm_compiler.py +38 -0
  148. ck_demos/pgm_compiler/demo_recursive_conditioning.py +33 -0
  149. ck_demos/pgm_compiler/demo_variable_elimination.py +33 -0
  150. ck_demos/pgm_compiler/demo_wmc_program.py +29 -0
  151. ck_demos/pgm_compiler/time_fe_compiler.py +93 -0
  152. ck_demos/pgm_inference/__init__.py +0 -0
  153. ck_demos/pgm_inference/demo_inferencing_basic.py +188 -0
  154. ck_demos/pgm_inference/demo_inferencing_mpe_cancer.py +45 -0
  155. ck_demos/pgm_inference/demo_inferencing_wmc_and_mpe_sprinkler.py +154 -0
  156. ck_demos/pgm_inference/demo_inferencing_wmc_student.py +110 -0
  157. ck_demos/programs/__init__.py +0 -0
  158. ck_demos/programs/demo_program_buffer.py +24 -0
  159. ck_demos/programs/demo_program_multi.py +24 -0
  160. ck_demos/programs/demo_program_none.py +19 -0
  161. ck_demos/programs/demo_program_single.py +23 -0
  162. ck_demos/programs/demo_raw_program_interpreted.py +21 -0
  163. ck_demos/programs/demo_raw_program_llvm.py +21 -0
  164. ck_demos/sampling/__init__.py +0 -0
  165. ck_demos/sampling/check_sampler.py +71 -0
  166. ck_demos/sampling/demo_marginal_direct_sampler.py +40 -0
  167. ck_demos/sampling/demo_uniform_sampler.py +38 -0
  168. ck_demos/sampling/demo_wmc_direct_sampler.py +40 -0
  169. ck_demos/utils/__init__.py +0 -0
  170. ck_demos/utils/compare.py +120 -0
  171. ck_demos/utils/convert_network.py +45 -0
  172. ck_demos/utils/sample_model.py +216 -0
  173. ck_demos/utils/stop_watch.py +384 -0
  174. compiled_knowledge-4.0.0a20.dist-info/METADATA +50 -0
  175. compiled_knowledge-4.0.0a20.dist-info/RECORD +178 -0
  176. compiled_knowledge-4.0.0a20.dist-info/WHEEL +6 -0
  177. compiled_knowledge-4.0.0a20.dist-info/licenses/LICENSE.txt +21 -0
  178. compiled_knowledge-4.0.0a20.dist-info/top_level.txt +2 -0
@@ -0,0 +1,323 @@
1
+ from __future__ import annotations
2
+
3
+ import random
4
+ from typing import Sequence, Optional, Tuple
5
+
6
+ from ck.circuit_compiler import CircuitCompiler
7
+ from ck.pgm import RandomVariable
8
+ from ck.pgm_circuit import PGMCircuit
9
+ from ck.pgm_circuit.program_with_slotmap import ProgramWithSlotmap
10
+ from ck.pgm_circuit.support.compile_circuit import compile_results, DEFAULT_CIRCUIT_COMPILER
11
+ from ck.probability.probability_space import ProbabilitySpace, Condition
12
+ from ck.program.program_buffer import ProgramBuffer
13
+ from ck.program.raw_program import RawProgram
14
+ from ck.sampling.sampler import Sampler
15
+ from ck.sampling.sampler_support import SamplerInfo, get_sampler_info
16
+ from ck.sampling.uniform_sampler import UniformSampler
17
+ from ck.sampling.wmc_direct_sampler import WMCDirectSampler
18
+ from ck.sampling.wmc_gibbs_sampler import WMCGibbsSampler
19
+ from ck.sampling.wmc_metropolis_sampler import WMCMetropolisSampler
20
+ from ck.sampling.wmc_rejection_sampler import WMCRejectionSampler
21
+ from ck.utils.np_extras import NDArray
22
+ from ck.utils.random_extras import Random
23
+
24
+
25
+ class WMCProgram(ProgramWithSlotmap, ProbabilitySpace):
26
+ """
27
+ A class for computing Weighted Model Count (WMC).
28
+ """
29
+
30
+ def __init__(
31
+ self,
32
+ pgm_circuit: PGMCircuit,
33
+ const_parameters: bool = True,
34
+ compiler: CircuitCompiler = DEFAULT_CIRCUIT_COMPILER,
35
+ ):
36
+ """
37
+ Construct a WMCProgram object.
38
+
39
+ Args:
40
+ pgm_circuit: The circuit representing a PGM.
41
+ const_parameters: if True then any circuit variable representing a parameter value will
42
+ be made 'const' in the resulting program.
43
+ """
44
+ raw_program: RawProgram = compile_results(
45
+ pgm_circuit=pgm_circuit,
46
+ results=(pgm_circuit.circuit_top,),
47
+ const_parameters=const_parameters,
48
+ compiler=compiler,
49
+ )
50
+ ProgramWithSlotmap.__init__(
51
+ self,
52
+ ProgramBuffer(raw_program),
53
+ pgm_circuit.slot_map,
54
+ pgm_circuit.rvs,
55
+ pgm_circuit.conditions,
56
+ )
57
+ self._raw_program: RawProgram = raw_program
58
+ self._number_of_indicators: int = pgm_circuit.number_of_indicators
59
+ self._z_cache: Optional[float] = None
60
+
61
+ if not const_parameters:
62
+ # set the parameter slots
63
+ self.vars[pgm_circuit.number_of_indicators:] = pgm_circuit.parameter_values
64
+
65
+ def wmc(self, *condition: Condition) -> float:
66
+ self.set_condition(*condition)
67
+ return self.compute().item()
68
+
69
+ @property
70
+ def z(self) -> float:
71
+ if self._z_cache is None:
72
+ number_of_indicators: int = self._number_of_indicators
73
+ slots: NDArray = self.vars
74
+ old_vals: NDArray = slots[:number_of_indicators].copy()
75
+ slots[:number_of_indicators] = 1
76
+ self._z_cache = self.compute().item()
77
+ slots[:number_of_indicators] = old_vals
78
+
79
+ return self._z_cache
80
+
81
+ def sample_uniform(
82
+ self,
83
+ rvs: Optional[RandomVariable | Sequence[RandomVariable]] = None,
84
+ *,
85
+ condition: Condition = (),
86
+ rand: Random = random,
87
+ ) -> Sampler:
88
+ """
89
+ Create a sampler that performs uniform sampling of
90
+ the state space of the given random variables, rvs.
91
+
92
+ The sampler will yield state lists, where the state
93
+ values are co-indexed with rvs, or self.rvs if rvs is None.
94
+
95
+ This sampler is not affected by and does not affect
96
+ the state of input slots.
97
+
98
+ Args:
99
+ rvs: the list of random variables to sample; the
100
+ yielded state vectors are co-indexed with rvs; if None,
101
+ then the self.rvs are used; if rvs is a single
102
+ random variable, then single samples are yielded.
103
+ condition: is a collection of zero or more conditioning indicators.
104
+ rand: provides the stream of random numbers.
105
+
106
+ Returns:
107
+ a Sampler object (UniformSampler).
108
+ """
109
+ return UniformSampler(
110
+ rvs=(self.rvs if rvs is None else rvs),
111
+ condition=condition,
112
+ rand=rand,
113
+ )
114
+
115
+ def sample_direct(
116
+ self,
117
+ rvs: Optional[RandomVariable | Sequence[RandomVariable]] = None,
118
+ *,
119
+ condition: Condition = (),
120
+ rand: Random = random,
121
+ chain_pairs: Sequence[Tuple[RandomVariable, RandomVariable]] = (),
122
+ initial_chain_condition: Condition = (),
123
+ ) -> Sampler:
124
+ """
125
+ Create an inverse-transform sampler, which uses the fact that
126
+ probabilities are exactly computable using a WMC.
127
+
128
+ The sampler will yield state lists, where the state
129
+ values are co-indexed with rvs, or self.rvs if rvs is None.
130
+
131
+ Given 'n' random variables, and 'm' number of indicators, for each yielded sample, this method:
132
+ * calls rand.random() once and rand.randrange(...) n times,
133
+ * calls self.program().compute_result() at least once and <= 1 + m.
134
+
135
+ Args:
136
+ rvs: the list of random variables to sample; the
137
+ yielded state vectors are co-indexed with rvs; if None,
138
+ then the WMC rvs are used; if rvs is a single
139
+ random variable, then single samples are yielded.
140
+ condition: is a collection of zero or more conditioning indicators.
141
+ rand: provides the stream of random numbers.
142
+ chain_pairs: is a collection of pairs of random variables, each random variable
143
+ must be in the given rvs. Given a pair (from_rv, to_rv) the state of from_rv is used
144
+ as a condition for to_rv prior to generating a sample.
145
+ initial_chain_condition: are condition indicators (just like condition)
146
+ for the initialisation of the 'to_rv' random variables mentioned in chain_pairs.
147
+
148
+ Returns:
149
+ a Sampler object (WMCDirectSampler).
150
+ """
151
+ sampler_info: SamplerInfo = get_sampler_info(
152
+ program_with_slotmap=self,
153
+ rvs=rvs,
154
+ condition=condition,
155
+ chain_pairs=chain_pairs,
156
+ initial_chain_condition=initial_chain_condition,
157
+ )
158
+
159
+ return WMCDirectSampler(
160
+ sampler_info=sampler_info,
161
+ raw_program=self._raw_program,
162
+ rand=rand,
163
+ )
164
+
165
+ def sample_rejection(
166
+ self,
167
+ rvs: Optional[RandomVariable | Sequence[RandomVariable]] = None,
168
+ *,
169
+ condition: Condition = (),
170
+ rand: Random = random,
171
+ ) -> Sampler:
172
+ """
173
+ Create a sampler to perform rejection sampling.
174
+
175
+ The sampler will yield state lists, where the state
176
+ values are co-indexed with rvs, or self.rvs if rvs is None.
177
+
178
+ The method uniformly samples states and uses an adaptive 'max weight'
179
+ to reduce unnecessary rejection.
180
+
181
+ After each sample is yielded, the WMC indicator variables will
182
+ be left set as per the yielded states of rvs and conditions.
183
+
184
+ Args:
185
+ rvs: the list of random variables to sample; the
186
+ yielded state vectors are co-indexed with rvs; if None,
187
+ then the WMC rvs are used; if rvs is a single
188
+ random variable, then single samples are yielded.
189
+ condition: is a collection of zero or more conditioning indicators.
190
+ rand: provides the stream of random numbers.
191
+
192
+ Returns:
193
+ a Sampler object (WMCRejectionSampler).
194
+ """
195
+ sampler_info: SamplerInfo = get_sampler_info(
196
+ program_with_slotmap=self,
197
+ rvs=rvs,
198
+ condition=condition,
199
+ )
200
+ z = self.wmc(*condition)
201
+
202
+ return WMCRejectionSampler(
203
+ sampler_info=sampler_info,
204
+ raw_program=self._raw_program,
205
+ rand=rand,
206
+ z=z,
207
+ )
208
+
209
+ def sample_gibbs(
210
+ self,
211
+ rvs: Optional[RandomVariable | Sequence[RandomVariable]] = None,
212
+ *,
213
+ condition: Condition = (),
214
+ skip: int = 0,
215
+ burn_in: int = 0,
216
+ pr_restart: float = 0,
217
+ rand: Random = random,
218
+ ) -> Sampler:
219
+ """
220
+ Create a sampler to perform Gibbs sampling.
221
+
222
+ The sampler will yield state lists, where the state
223
+ values are co-indexed with rvs, or self.rvs if rvs is None.
224
+
225
+ After each sample is yielded, the WMC indicator vars will
226
+ be left set as per the yielded states of rvs and conditions.
227
+
228
+ Args:
229
+ rvs: the list of random variables to sample; the
230
+ yielded state vectors are co-indexed with rvs; if None,
231
+ then the WMC rvs are used; if rvs is a single
232
+ random variable, then single samples are yielded.
233
+ condition: is a collection of zero or more conditioning indicators.
234
+ skip: is an integer >= 0 specifying how may samples to discard
235
+ for each sample provided. Values > 0 can be used to de-correlate adjacent samples.
236
+ burn_in: how many iterations to perform after
237
+ initialisation before yielding a sample.
238
+ pr_restart: the chance of re-initialising each
239
+ iteration. If restarted then burn-in is performed again.
240
+ rand: provides the stream of random numbers.
241
+
242
+ Returns:
243
+ a Sampler object (WMCGibbsSampler).
244
+ """
245
+ if skip < 0:
246
+ raise RuntimeError('skip must be non-negative')
247
+ if burn_in < 0:
248
+ raise RuntimeError('burn_in must be non-negative')
249
+
250
+ sampler_info: SamplerInfo = get_sampler_info(
251
+ program_with_slotmap=self,
252
+ rvs=rvs,
253
+ condition=condition,
254
+ )
255
+
256
+ return WMCGibbsSampler(
257
+ sampler_info=sampler_info,
258
+ raw_program=self._raw_program,
259
+ rand=rand,
260
+ skip=skip,
261
+ burn_in=burn_in,
262
+ pr_restart=pr_restart,
263
+ )
264
+
265
+ def sample_metropolis(
266
+ self,
267
+ rvs: Optional[RandomVariable | Sequence[RandomVariable]] = None,
268
+ *,
269
+ condition: Condition = (),
270
+ skip: Optional[int] = None,
271
+ burn_in: int = 0,
272
+ pr_restart: float = 0,
273
+ rand: Random = random,
274
+ ) -> Sampler:
275
+ """
276
+ Create a sampler to perform Metropolis-Hastings sampling.
277
+
278
+ The sampler will yield state lists, where the state
279
+ values are co-indexed with rvs, or self.rvs if rvs is None.
280
+
281
+ After each sample is yielded, the WMC indicator vars will
282
+ be left set as per the yielded states of rvs and conditions.
283
+
284
+ Args:
285
+ rvs: the list of random variables to sample; the
286
+ yielded state vectors are co-indexed with rvs; if None,
287
+ then the WMC rvs are used; if rvs is a single
288
+ random variable, then single samples are yielded.
289
+ condition: is a collection of zero or more conditioning indicators.
290
+ skip: is an optional integer >= 0 specifying how may samples to discard
291
+ for each sample provided. Values > 0 can be used to de-correlate adjacent samples.
292
+ Default value = len(rvs)
293
+ burn_in: how many iterations to perform after initialisation
294
+ before yielding a sample.
295
+ pr_restart: the chance of re-initialising each iteration. If
296
+ restarted then burn-in is performed again.
297
+ rand: provides the stream of random numbers.
298
+
299
+ Returns:
300
+ a Sampler object (WMCMetropolisSampler).
301
+ """
302
+ if skip is not None and skip < 0:
303
+ raise RuntimeError('skip must be non-negative')
304
+ if burn_in < 0:
305
+ raise RuntimeError('burn_in must be non-negative')
306
+
307
+ sampler_info: SamplerInfo = get_sampler_info(
308
+ program_with_slotmap=self,
309
+ rvs=rvs,
310
+ condition=condition,
311
+ )
312
+
313
+ if skip is None:
314
+ skip = len(sampler_info.sample_rvs)
315
+
316
+ return WMCMetropolisSampler(
317
+ sampler_info=sampler_info,
318
+ raw_program=self._raw_program,
319
+ rand=rand,
320
+ skip=skip,
321
+ burn_in=burn_in,
322
+ pr_restart=pr_restart,
323
+ )
@@ -0,0 +1,2 @@
1
+ from .pgm_compiler import PGMCompiler
2
+ from .named_pgm_compilers import NamedPGMCompiler, DEFAULT_PGM_COMPILER
@@ -0,0 +1 @@
1
+ from .ace import compile_pgm, copy_ace_to_default_location, default_ace_location, ace_available
@@ -0,0 +1,299 @@
1
+ import shutil
2
+ import subprocess
3
+ import sys
4
+ from dataclasses import dataclass
5
+ from pathlib import Path
6
+ from typing import Optional, List, Tuple
7
+
8
+ import numpy as np
9
+
10
+ from ck.circuit import CircuitNode, Circuit
11
+ from ck.in_out.parse_ace_lmap import read_lmap, LiteralMap
12
+ from ck.in_out.parse_ace_nnf import read_nnf_with_literal_map
13
+ from ck.in_out.render_net import render_bayesian_network
14
+ from ck.pgm import PGM
15
+ from ck.pgm_circuit import PGMCircuit
16
+ from ck.pgm_circuit.slot_map import SlotMap
17
+ from ck.utils.local_config import config
18
+ from ck.utils.np_extras import NDArrayFloat64
19
+ from ck.utils.tmp_dir import tmp_dir
20
+
21
+
22
+ def compile_pgm(
23
+ pgm: PGM,
24
+ const_parameters: bool = True,
25
+ *,
26
+ ace_dir: Optional[Path | str] = None,
27
+ jar_dir: Optional[Path | str] = None,
28
+ print_output: bool = False,
29
+ m_bytes: int = 1512,
30
+ check_is_bayesian_network: bool = True,
31
+ ) -> PGMCircuit:
32
+ """
33
+ Compile the PGM to an arithmetic circuit, using Ace.
34
+
35
+ This is a wrapper for Ace.
36
+ Ace compiles a Bayesian network into an Arithmetic Circuit.
37
+ Provided by the Automated Reasoning Group, University of California Los Angeles.
38
+ Ace requires the Java Runtime Environment (JRE) version 8 or higher.
39
+ See http://reasoning.cs.ucla.edu/ace/
40
+
41
+ Conforms to the `PGMCompiler` protocol.
42
+
43
+ Args:
44
+ pgm: The PGM to compile.
45
+ const_parameters: If true, the potential function parameters will be circuit
46
+ constants, otherwise they will be circuit variables.
47
+ ace_dir: Directory containing Ace. If not provided then the directory this module is in is used.
48
+ jar_dir: Directory containing Ace jar files. If not provided, then `ace_dir` is used.
49
+ print_output: if true, the output from Ace is printed.
50
+ m_bytes: requested megabytes for the Java Virtual Machine (using the java "-Xmx" argument).
51
+ check_is_bayesian_network: if true, then the PGM will be checked to confirm it is a Bayesian network.
52
+
53
+ Returns:
54
+ a PGMCircuit object.
55
+
56
+ Raises:
57
+ RuntimeError: if Ace files are not found, including a helpful message.
58
+ ValueError: if `check_is_bayesian_network` is true and the PGM is not a Bayesian network.
59
+ CalledProcessError: if executing Ace failed.
60
+ """
61
+ if check_is_bayesian_network and not pgm.check_is_bayesian_network():
62
+ raise ValueError('the given PGM is not a Bayesian network')
63
+
64
+ # ACE cannot deal with the empty PGM even though it is a valid Bayesian network
65
+ if pgm.number_of_factors == 0:
66
+ circuit = Circuit()
67
+ circuit.new_vars(pgm.number_of_indicators)
68
+ parameter_values = np.array([], dtype=np.float64)
69
+ slot_map = {indicator: i for i, indicator in enumerate(pgm.indicators)}
70
+ return PGMCircuit(
71
+ rvs=pgm.rvs,
72
+ conditions=(),
73
+ circuit_top=circuit.const(1),
74
+ number_of_indicators=pgm.number_of_indicators,
75
+ number_of_parameters=0,
76
+ slot_map=slot_map,
77
+ parameter_values=parameter_values,
78
+ )
79
+
80
+ java: str
81
+ classpath_separator: str
82
+ java, classpath_separator = _find_java()
83
+ files: _AceFiles = _find_ace_files(ace_dir, jar_dir)
84
+ net_file_name = 'to_compile.net'
85
+ main_class = 'edu.ucla.belief.ace.AceCompile'
86
+ class_path: str = classpath_separator.join(
87
+ str(f) for f in [files.ace_jar, files.inflib_jar, files.jdom_jar]
88
+ )
89
+ ace_cmd: List[str] = [
90
+ java,
91
+ '-cp',
92
+ class_path,
93
+ f'-DACEC2D={files.c2d}',
94
+ f'-Xmx{int(m_bytes)}m',
95
+ main_class,
96
+ net_file_name,
97
+ ]
98
+
99
+ with tmp_dir():
100
+ # Render the PGM to a .net file to be read by Ace
101
+ with open(net_file_name, 'w') as file:
102
+ node_names: List[str] = render_bayesian_network(pgm, file, check_structure_bayesian=False)
103
+
104
+ # Run Ace
105
+ ace_result: subprocess.CompletedProcess = subprocess.run(ace_cmd, capture_output=(not print_output), text=True)
106
+ if ace_result.returncode != 0:
107
+ raise subprocess.CalledProcessError(
108
+ returncode=ace_result.returncode,
109
+ cmd=' '.join(ace_cmd),
110
+ output=None if print_output else ace_result.stdout,
111
+ stderr=None if print_output else ace_result.stderr,
112
+ )
113
+
114
+ # Parse the literal map output from Ace
115
+ with open(f'{net_file_name}.lmap', 'r') as file:
116
+ literal_map: LiteralMap = read_lmap(file, node_names=node_names)
117
+
118
+ # Parse the arithmetic circuit output from Ace
119
+ with open(f'{net_file_name}.ac', 'r') as file:
120
+ circuit_top: CircuitNode
121
+ slot_map: SlotMap
122
+ parameter_values: NDArrayFloat64
123
+ circuit_top, slot_map, parameter_values = read_nnf_with_literal_map(
124
+ file,
125
+ indicators=pgm.indicators,
126
+ literal_map=literal_map,
127
+ const_parameters=const_parameters,
128
+ )
129
+
130
+ # Consistency checking
131
+ number_of_indicators: int = pgm.number_of_indicators
132
+ number_of_parameters: int = parameter_values.shape[0]
133
+ assert circuit_top.circuit.number_of_vars == number_of_indicators + number_of_parameters, 'consistency check'
134
+
135
+ return PGMCircuit(
136
+ rvs=pgm.rvs,
137
+ conditions=(),
138
+ circuit_top=circuit_top,
139
+ number_of_indicators=number_of_indicators,
140
+ number_of_parameters=number_of_parameters,
141
+ slot_map=slot_map,
142
+ parameter_values=parameter_values,
143
+ )
144
+
145
+
146
+ def ace_available(
147
+ ace_dir: Optional[Path | str] = None,
148
+ jar_dir: Optional[Path | str] = None,
149
+ ) -> bool:
150
+ """
151
+ Returns:
152
+ True if it looks like ACE is available, False otherwise.
153
+ ACE is available if ACE files are in the default location and Java is available.
154
+ """
155
+ try:
156
+ java: str
157
+ java, _ = _find_java()
158
+ _: _AceFiles = _find_ace_files(ace_dir, jar_dir)
159
+
160
+ java_cmd: List[str] = [java, '--version',]
161
+ java_result: subprocess.CompletedProcess = subprocess.run(java_cmd, capture_output=True, text=True)
162
+
163
+ return java_result.returncode == 0
164
+
165
+ except RuntimeError:
166
+ return False
167
+
168
+
169
+ def copy_ace_to_default_location(
170
+ ace_dir: Path | str,
171
+ jar_dir: Optional[Path | str] = None,
172
+ ) -> None:
173
+ """
174
+ Copy Ace files from the given directories into the default directory.
175
+
176
+ Args:
177
+ ace_dir: Directory containing Ace.
178
+ jar_dir: Directory containing Ace jar files. If not provided, then `ace_dir` is used.
179
+
180
+ Raises:
181
+ RuntimeError: if Ace files are not found, including a helpful message .
182
+ IOError: if the copy fails.
183
+
184
+ Assumes:
185
+ ace_dir exists and is not the same as the installation directory.
186
+ """
187
+ install_location: Path = default_ace_location()
188
+
189
+ if ace_dir is None or ace_dir == install_location:
190
+ raise RuntimeError(f'Ace directory cannot be the default directory')
191
+
192
+ files: _AceFiles = _find_ace_files(ace_dir, jar_dir)
193
+
194
+ to_copy = [files.ace_jar, files.inflib_jar, files.jdom_jar] + files.c2d_options
195
+
196
+ for file in to_copy:
197
+ shutil.copyfile(file, install_location / file.name)
198
+
199
+
200
+ def default_ace_location() -> Path:
201
+ """
202
+ Get the default location for Ace files.
203
+
204
+ This function checks the local config for the variable
205
+ CK_ACE_LOCATION. If that is not available, then the
206
+ directory that this Python module is in will be used.
207
+ """
208
+ return Path(config.get('CK_ACE_LOCATION', Path(__file__).parent))
209
+
210
+
211
+ @dataclass
212
+ class _AceFiles:
213
+ ace_jar: Path
214
+ inflib_jar: Path
215
+ jdom_jar: Path
216
+ c2d: Path
217
+ c2d_options: List[Path]
218
+
219
+
220
+ def _find_java() -> Tuple[str, str]:
221
+ """
222
+ What to call the Java executable and classpath separator.
223
+
224
+ Returns:
225
+ (java, classpath_separator)
226
+
227
+ Raises:
228
+ RuntimeError: if not found, including a helpful message.
229
+ """
230
+ if sys.platform == 'win32':
231
+ return 'java.exe', ';'
232
+ elif sys.platform == 'darwin':
233
+ return 'java', ':'
234
+ elif sys.platform.startswith('linux'):
235
+ return 'java', ':'
236
+ else:
237
+ raise RuntimeError(f'cannot infer java for platform {sys.platform!r}')
238
+
239
+
240
+ def _find_ace_files(
241
+ ace_dir: Optional[Path | str],
242
+ jar_dir: Optional[Path | str],
243
+ ) -> _AceFiles:
244
+ """
245
+ Look for the needed Ace files.
246
+
247
+ Raises:
248
+ RuntimeError: if not found, including a helpful message.
249
+ """
250
+ ace_dir: Path = default_ace_location() if ace_dir is None else Path(ace_dir)
251
+ jar_dir: Path = ace_dir if jar_dir is None else Path(jar_dir)
252
+
253
+ if not ace_dir.is_dir():
254
+ raise RuntimeError(f'Ace directory does not exist: {ace_dir}')
255
+ if not jar_dir.is_dir():
256
+ raise RuntimeError(f'Ace jar directory does not exist: {jar_dir}')
257
+
258
+ ace_jar = jar_dir / 'ace.jar'
259
+ inflib_jar = jar_dir / 'inflib.jar'
260
+ jdom_jar = jar_dir / 'jdom.jar'
261
+
262
+ missing: List[str] = [
263
+ jar.name
264
+ for jar in [ace_jar, inflib_jar, jdom_jar]
265
+ if not jar.is_file()
266
+ ]
267
+ if len(missing) > 0:
268
+ raise RuntimeError(f'Ace jars missing (ensure Ace is properly installed): {", ".join(missing)}')
269
+
270
+ c2d_options: List[Path] = [
271
+ file
272
+ for file in ace_dir.iterdir()
273
+ if file.is_file() and file.name.startswith('c2d')
274
+ ]
275
+ c2d: Path
276
+ if len(c2d_options) == 0:
277
+ raise RuntimeError(f'cannot find c2d in the Ace directory: {ace_dir}')
278
+ if len(c2d_options) == 1:
279
+ c2d = next(iter(c2d_options))
280
+ else:
281
+ if sys.platform == 'win32':
282
+ c2d = ace_dir / 'c2d_windows.exe'
283
+ elif sys.platform == 'darwin':
284
+ c2d = ace_dir / 'c2d_osx'
285
+ elif sys.platform.startswith('linux'):
286
+ c2d = ace_dir / 'c2d_linux'
287
+ else:
288
+ raise RuntimeError(f'cannot infer c2d executable name for platform {sys.platform!r}')
289
+
290
+ if not c2d.is_file():
291
+ raise RuntimeError(f'cannot find c2d: {c2d}')
292
+
293
+ return _AceFiles(
294
+ c2d=c2d,
295
+ c2d_options=c2d_options,
296
+ ace_jar=ace_jar,
297
+ inflib_jar=inflib_jar,
298
+ jdom_jar=jdom_jar,
299
+ )