compiled-knowledge 4.0.0a20__cp313-cp313-macosx_10_13_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of compiled-knowledge might be problematic. Click here for more details.

Files changed (178)
  1. ck/__init__.py +0 -0
  2. ck/circuit/__init__.py +17 -0
  3. ck/circuit/_circuit_cy.c +37525 -0
  4. ck/circuit/_circuit_cy.cpython-313-darwin.so +0 -0
  5. ck/circuit/_circuit_cy.pxd +32 -0
  6. ck/circuit/_circuit_cy.pyx +768 -0
  7. ck/circuit/_circuit_py.py +836 -0
  8. ck/circuit/tmp_const.py +74 -0
  9. ck/circuit_compiler/__init__.py +2 -0
  10. ck/circuit_compiler/circuit_compiler.py +26 -0
  11. ck/circuit_compiler/cython_vm_compiler/__init__.py +1 -0
  12. ck/circuit_compiler/cython_vm_compiler/_compiler.c +19826 -0
  13. ck/circuit_compiler/cython_vm_compiler/_compiler.cpython-313-darwin.so +0 -0
  14. ck/circuit_compiler/cython_vm_compiler/_compiler.pyx +380 -0
  15. ck/circuit_compiler/cython_vm_compiler/cython_vm_compiler.py +121 -0
  16. ck/circuit_compiler/interpret_compiler.py +223 -0
  17. ck/circuit_compiler/llvm_compiler.py +388 -0
  18. ck/circuit_compiler/llvm_vm_compiler.py +546 -0
  19. ck/circuit_compiler/named_circuit_compilers.py +57 -0
  20. ck/circuit_compiler/support/__init__.py +0 -0
  21. ck/circuit_compiler/support/circuit_analyser/__init__.py +13 -0
  22. ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.c +10620 -0
  23. ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.cpython-313-darwin.so +0 -0
  24. ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.pyx +98 -0
  25. ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_py.py +93 -0
  26. ck/circuit_compiler/support/input_vars.py +148 -0
  27. ck/circuit_compiler/support/llvm_ir_function.py +234 -0
  28. ck/example/__init__.py +53 -0
  29. ck/example/alarm.py +366 -0
  30. ck/example/asia.py +28 -0
  31. ck/example/binary_clique.py +32 -0
  32. ck/example/bow_tie.py +33 -0
  33. ck/example/cancer.py +37 -0
  34. ck/example/chain.py +38 -0
  35. ck/example/child.py +199 -0
  36. ck/example/clique.py +33 -0
  37. ck/example/cnf_pgm.py +39 -0
  38. ck/example/diamond_square.py +68 -0
  39. ck/example/earthquake.py +36 -0
  40. ck/example/empty.py +10 -0
  41. ck/example/hailfinder.py +539 -0
  42. ck/example/hepar2.py +628 -0
  43. ck/example/insurance.py +504 -0
  44. ck/example/loop.py +40 -0
  45. ck/example/mildew.py +38161 -0
  46. ck/example/munin.py +22982 -0
  47. ck/example/pathfinder.py +53747 -0
  48. ck/example/rain.py +39 -0
  49. ck/example/rectangle.py +161 -0
  50. ck/example/run.py +30 -0
  51. ck/example/sachs.py +129 -0
  52. ck/example/sprinkler.py +30 -0
  53. ck/example/star.py +44 -0
  54. ck/example/stress.py +64 -0
  55. ck/example/student.py +43 -0
  56. ck/example/survey.py +46 -0
  57. ck/example/triangle_square.py +54 -0
  58. ck/example/truss.py +49 -0
  59. ck/in_out/__init__.py +3 -0
  60. ck/in_out/parse_ace_lmap.py +216 -0
  61. ck/in_out/parse_ace_nnf.py +322 -0
  62. ck/in_out/parse_net.py +480 -0
  63. ck/in_out/parser_utils.py +185 -0
  64. ck/in_out/pgm_pickle.py +42 -0
  65. ck/in_out/pgm_python.py +268 -0
  66. ck/in_out/render_bugs.py +111 -0
  67. ck/in_out/render_net.py +177 -0
  68. ck/in_out/render_pomegranate.py +184 -0
  69. ck/pgm.py +3475 -0
  70. ck/pgm_circuit/__init__.py +1 -0
  71. ck/pgm_circuit/marginals_program.py +352 -0
  72. ck/pgm_circuit/mpe_program.py +237 -0
  73. ck/pgm_circuit/pgm_circuit.py +79 -0
  74. ck/pgm_circuit/program_with_slotmap.py +236 -0
  75. ck/pgm_circuit/slot_map.py +35 -0
  76. ck/pgm_circuit/support/__init__.py +0 -0
  77. ck/pgm_circuit/support/compile_circuit.py +83 -0
  78. ck/pgm_circuit/target_marginals_program.py +103 -0
  79. ck/pgm_circuit/wmc_program.py +323 -0
  80. ck/pgm_compiler/__init__.py +2 -0
  81. ck/pgm_compiler/ace/__init__.py +1 -0
  82. ck/pgm_compiler/ace/ace.py +299 -0
  83. ck/pgm_compiler/factor_elimination.py +395 -0
  84. ck/pgm_compiler/named_pgm_compilers.py +63 -0
  85. ck/pgm_compiler/pgm_compiler.py +19 -0
  86. ck/pgm_compiler/recursive_conditioning.py +231 -0
  87. ck/pgm_compiler/support/__init__.py +0 -0
  88. ck/pgm_compiler/support/circuit_table/__init__.py +17 -0
  89. ck/pgm_compiler/support/circuit_table/_circuit_table_cy.c +16398 -0
  90. ck/pgm_compiler/support/circuit_table/_circuit_table_cy.cpython-313-darwin.so +0 -0
  91. ck/pgm_compiler/support/circuit_table/_circuit_table_cy.pyx +332 -0
  92. ck/pgm_compiler/support/circuit_table/_circuit_table_py.py +304 -0
  93. ck/pgm_compiler/support/clusters.py +568 -0
  94. ck/pgm_compiler/support/factor_tables.py +406 -0
  95. ck/pgm_compiler/support/join_tree.py +332 -0
  96. ck/pgm_compiler/support/named_compiler_maker.py +43 -0
  97. ck/pgm_compiler/variable_elimination.py +91 -0
  98. ck/probability/__init__.py +0 -0
  99. ck/probability/empirical_probability_space.py +50 -0
  100. ck/probability/pgm_probability_space.py +32 -0
  101. ck/probability/probability_space.py +622 -0
  102. ck/program/__init__.py +3 -0
  103. ck/program/program.py +137 -0
  104. ck/program/program_buffer.py +180 -0
  105. ck/program/raw_program.py +67 -0
  106. ck/sampling/__init__.py +0 -0
  107. ck/sampling/forward_sampler.py +211 -0
  108. ck/sampling/marginals_direct_sampler.py +113 -0
  109. ck/sampling/sampler.py +62 -0
  110. ck/sampling/sampler_support.py +232 -0
  111. ck/sampling/uniform_sampler.py +72 -0
  112. ck/sampling/wmc_direct_sampler.py +171 -0
  113. ck/sampling/wmc_gibbs_sampler.py +153 -0
  114. ck/sampling/wmc_metropolis_sampler.py +165 -0
  115. ck/sampling/wmc_rejection_sampler.py +115 -0
  116. ck/utils/__init__.py +0 -0
  117. ck/utils/iter_extras.py +163 -0
  118. ck/utils/local_config.py +270 -0
  119. ck/utils/map_list.py +128 -0
  120. ck/utils/map_set.py +128 -0
  121. ck/utils/np_extras.py +51 -0
  122. ck/utils/random_extras.py +64 -0
  123. ck/utils/tmp_dir.py +94 -0
  124. ck_demos/__init__.py +0 -0
  125. ck_demos/ace/__init__.py +0 -0
  126. ck_demos/ace/copy_ace_to_ck.py +15 -0
  127. ck_demos/ace/demo_ace.py +49 -0
  128. ck_demos/all_demos.py +88 -0
  129. ck_demos/circuit/__init__.py +0 -0
  130. ck_demos/circuit/demo_circuit_dump.py +22 -0
  131. ck_demos/circuit/demo_derivatives.py +43 -0
  132. ck_demos/circuit_compiler/__init__.py +0 -0
  133. ck_demos/circuit_compiler/compare_circuit_compilers.py +32 -0
  134. ck_demos/circuit_compiler/show_llvm_program.py +26 -0
  135. ck_demos/pgm/__init__.py +0 -0
  136. ck_demos/pgm/demo_pgm_dump.py +18 -0
  137. ck_demos/pgm/demo_pgm_dump_stress.py +18 -0
  138. ck_demos/pgm/demo_pgm_string_rendering.py +15 -0
  139. ck_demos/pgm/show_examples.py +25 -0
  140. ck_demos/pgm_compiler/__init__.py +0 -0
  141. ck_demos/pgm_compiler/compare_pgm_compilers.py +63 -0
  142. ck_demos/pgm_compiler/demo_compiler_dump.py +60 -0
  143. ck_demos/pgm_compiler/demo_factor_elimination.py +47 -0
  144. ck_demos/pgm_compiler/demo_join_tree.py +25 -0
  145. ck_demos/pgm_compiler/demo_marginals_program.py +53 -0
  146. ck_demos/pgm_compiler/demo_mpe_program.py +55 -0
  147. ck_demos/pgm_compiler/demo_pgm_compiler.py +38 -0
  148. ck_demos/pgm_compiler/demo_recursive_conditioning.py +33 -0
  149. ck_demos/pgm_compiler/demo_variable_elimination.py +33 -0
  150. ck_demos/pgm_compiler/demo_wmc_program.py +29 -0
  151. ck_demos/pgm_compiler/time_fe_compiler.py +93 -0
  152. ck_demos/pgm_inference/__init__.py +0 -0
  153. ck_demos/pgm_inference/demo_inferencing_basic.py +188 -0
  154. ck_demos/pgm_inference/demo_inferencing_mpe_cancer.py +45 -0
  155. ck_demos/pgm_inference/demo_inferencing_wmc_and_mpe_sprinkler.py +154 -0
  156. ck_demos/pgm_inference/demo_inferencing_wmc_student.py +110 -0
  157. ck_demos/programs/__init__.py +0 -0
  158. ck_demos/programs/demo_program_buffer.py +24 -0
  159. ck_demos/programs/demo_program_multi.py +24 -0
  160. ck_demos/programs/demo_program_none.py +19 -0
  161. ck_demos/programs/demo_program_single.py +23 -0
  162. ck_demos/programs/demo_raw_program_interpreted.py +21 -0
  163. ck_demos/programs/demo_raw_program_llvm.py +21 -0
  164. ck_demos/sampling/__init__.py +0 -0
  165. ck_demos/sampling/check_sampler.py +71 -0
  166. ck_demos/sampling/demo_marginal_direct_sampler.py +40 -0
  167. ck_demos/sampling/demo_uniform_sampler.py +38 -0
  168. ck_demos/sampling/demo_wmc_direct_sampler.py +40 -0
  169. ck_demos/utils/__init__.py +0 -0
  170. ck_demos/utils/compare.py +120 -0
  171. ck_demos/utils/convert_network.py +45 -0
  172. ck_demos/utils/sample_model.py +216 -0
  173. ck_demos/utils/stop_watch.py +384 -0
  174. compiled_knowledge-4.0.0a20.dist-info/METADATA +50 -0
  175. compiled_knowledge-4.0.0a20.dist-info/RECORD +178 -0
  176. compiled_knowledge-4.0.0a20.dist-info/WHEEL +6 -0
  177. compiled_knowledge-4.0.0a20.dist-info/licenses/LICENSE.txt +21 -0
  178. compiled_knowledge-4.0.0a20.dist-info/top_level.txt +2 -0
@@ -0,0 +1 @@
1
+ from .pgm_circuit import PGMCircuit
@@ -0,0 +1,352 @@
1
+ from __future__ import annotations
2
+
3
+ import random
4
+ from typing import Sequence, Optional, Tuple, List, Iterable, Dict
5
+
6
+ import numpy as np
7
+
8
+ from ck.circuit import CircuitNode, Circuit
9
+ from ck.pgm import RandomVariable, number_of_states, rv_instances_as_indicators
10
+ from ck.pgm_circuit import PGMCircuit
11
+ from ck.pgm_circuit.program_with_slotmap import ProgramWithSlotmap
12
+ from ck.pgm_circuit.slot_map import SlotMap
13
+ from ck.pgm_circuit.support.compile_circuit import compile_results
14
+ from ck.probability.probability_space import ProbabilitySpace, check_condition, Condition
15
+ from ck.program.program_buffer import ProgramBuffer
16
+ from ck.program.raw_program import RawProgram
17
+ from ck.sampling.marginals_direct_sampler import MarginalsDirectSampler
18
+ from ck.sampling.sampler import Sampler
19
+ from ck.sampling.sampler_support import SamplerInfo, get_sampler_info
20
+ from ck.sampling.uniform_sampler import UniformSampler
21
+ from ck.utils.np_extras import NDArray, NDArrayNumeric
22
+ from ck.utils.random_extras import Random
23
+
24
+
25
class MarginalsProgram(ProgramWithSlotmap, ProbabilitySpace):
    """
    A class for computing marginal probability distributions over states of selected output random variables.
    This class provides, for each indicator, the product of indicator value with the derivative
    of the network function with respect to the indicator.

    Compile the circuit for computing marginal probability distributions using the
    so-called 'differential' approach.

    Reference: Darwiche, A. (2003). A differential approach to inference in Bayesian
    networks. Journal of the ACM (JACM), 50(3), 280-305.

    A note about samplers
    ---------------------

    When creating a sampler, a client may request that samples are conditioned
    on provided condition indicators. Also, the MarginalsProgram may have been
    produced with compiled-in conditions, e.g., using const_conditions when the
    PGM circuit was compiled.

    The conditions respected by a sampler are the conjunction of the compiled
    conditions and the sampler conditions. For example, with compiled condition
    (A[0], A[1], A[2]) and sampler condition (A[1], A[2], A[3]) the effective
    condition is (A[1], A[2]), i.e., a sample of A may be 1 or 2.

    Warning:
        if the sampled random variables include conditions, those conditions
        must be provided to the sampler. If a sampled random variable is conditioned
        at compile time, but not passed to the sampler, then the sampler will not
        be aware of the conditions, and unexpected sample values may be produced.
    """

    def __init__(
            self,
            pgm_circuit: PGMCircuit,
            output_rvs: Optional[Sequence[RandomVariable]] = None,
            const_parameters: bool = True,
    ):
        """
        Construct a MarginalsProgram object.

        The compiled program produces marginal outputs in the order
        of output_rvs indicators, followed by the wmc output.

        Args:
            pgm_circuit: The circuit representing a PGM.
            output_rvs: if None, the output rvs are all rvs, otherwise the given rvs.
            const_parameters: if True then any circuit variable representing a parameter value will
                be made 'const' in the resulting program.
        """
        top_node: CircuitNode = pgm_circuit.circuit_top
        circuit: Circuit = top_node.circuit
        slot_map: SlotMap = pgm_circuit.slot_map
        input_rvs: Sequence[RandomVariable] = pgm_circuit.rvs

        output_rvs: Sequence[RandomVariable] = tuple(output_rvs) if output_rvs is not None else input_rvs

        # The input slots (circuit variable indices) of the indicators of each output rv.
        output_rvs_slots: List[List[int]] = [[slot_map[ind] for ind in rv] for rv in output_rvs]

        # Program results: one (self-multiplied) derivative per output indicator,
        # in output rv order, followed by the value of the top node (the wmc).
        flat_out_rv_vars = [circuit.vars[slot] for slots in output_rvs_slots for slot in slots]
        derivatives = circuit.partial_derivatives(top_node, flat_out_rv_vars, self_multiply=True)

        raw_program: RawProgram = compile_results(
            pgm_circuit=pgm_circuit,
            results=derivatives + [top_node],
            const_parameters=const_parameters,
        )

        program_buffer = ProgramBuffer(raw_program)
        ProgramWithSlotmap.__init__(self, program_buffer, slot_map, input_rvs, pgm_circuit.conditions)

        # A map from `RandomVariable.idx` to the offset of its first marginal in the result buffer.
        self._rv_idx_to_result_offset: Dict[int, int] = {}
        # A map from `RandomVariable.idx` to its position in `output_rvs` (and in `self._marginals`).
        self._rv_idx_to_output_index: Dict[int, int] = {}
        # One view per output rv onto the program result buffer (no copying).
        self._marginals: List[NDArrayNumeric] = []

        start: int = 0
        for output_index, (rv, rv_slots) in enumerate(zip(output_rvs, output_rvs_slots)):
            end = start + len(rv_slots)
            self._rv_idx_to_result_offset[rv.idx] = start
            self._rv_idx_to_output_index[rv.idx] = output_index
            self._marginals.append(program_buffer.results[start:end])  # a view onto the same data
            start = end

        # additional fields
        self._raw_program: RawProgram = raw_program
        self._program_buffer: ProgramBuffer = program_buffer
        self._number_of_indicators: int = pgm_circuit.number_of_indicators
        self._output_rvs = output_rvs
        self._output_rvs_slots = output_rvs_slots
        self._z_cache: Optional[float] = None

        if not const_parameters:
            # set the parameter slots
            self.vars[pgm_circuit.number_of_indicators:] = pgm_circuit.parameter_values

    @property
    def output_rvs(self) -> Sequence[RandomVariable]:
        """
        What random variables are included in the marginal probabilities calculations.
        """
        return self._output_rvs

    def wmc(self, *condition: Condition) -> float:
        """
        What is the weight of the world with the given indicators.
        If multiple indicators from the same random variable are mentioned, then it is treated as a disjunction.
        If a random variable is not mentioned in the indicators, that random variable is marginalised out.

        Note: this sets the input slots to the condition and executes the program, but does
        not normalise the marginal results (unlike `compute`).
        """
        self.set_condition(*condition)
        self._program_buffer.compute()
        return self.result_wmc

    @property
    def z(self) -> float:
        """
        The weighted model count with every indicator input set to 1, computed
        on first access and then cached.

        The indicator input slots are restored afterwards, even if the computation
        raises. Computing `z` overwrites the program result buffer.

        NOTE(review): the cache is not invalidated if parameter slots are changed later.
        """
        if self._z_cache is None:
            number_of_indicators: int = self._number_of_indicators
            slots: NDArray = self.vars
            old_vals: NDArray = slots[:number_of_indicators].copy()
            slots[:number_of_indicators] = 1
            try:
                self._program_buffer.compute()
                self._z_cache = self.result_wmc
            finally:
                slots[:number_of_indicators] = old_vals
        return self._z_cache

    def marginal_distribution(self, *rvs: RandomVariable, condition: Condition = ()) -> NDArrayNumeric:
        """
        Compute the (normalised) joint marginal distribution over the given random variables,
        subject to the given condition.

        Returns:
            a new numpy array (not backed by the program buffer) with one entry per
            combination of states of `rvs`. If `rvs` is empty, the result is `[1.0]`,
            or `[nan]` when the condition has zero weight.
        """
        # Check for easy cases.
        if len(rvs) == 0:
            if self.wmc(*condition) == 0:
                return np.array([np.nan])
            return np.array([1.0])
        if len(rvs) == 1:
            # `marginal_for_rv` returns a view onto the program result buffer, which the
            # next computation overwrites; copy so callers get an independent result.
            return self.marginal_for_rv(rvs[0], condition=condition).copy()

        # We try to eliminate searching combinations of probabilities where marginals are zero.
        # If there are no marginal probabilities = 0, then this is equivalent to
        # ProbabilitySpace.marginal_distribution

        condition = check_condition(condition)
        rvs_marginals = self.marginal_for_rvs(rvs, condition=condition)
        # Built before any further program execution overwrites the marginal views.
        zero_indicators = set(
            ind
            for rv, rv_marginal in zip(rvs, rvs_marginals)
            for ind, marginal in zip(rv, rv_marginal)
            if marginal == 0
        )
        raw_wmc = self._get_wmc_for_marginals(rvs, condition)

        if len(zero_indicators) == 0:
            wmc_for = raw_wmc
        else:
            def wmc_for(indicators):
                # Skip program execution for combinations known to have zero weight.
                for ind in indicators:
                    if ind in zero_indicators:
                        return 0
                return raw_wmc(indicators)

        result = np.fromiter(
            (wmc_for(indicators) for indicators in rv_instances_as_indicators(*rvs)),
            count=number_of_states(*rvs),
            dtype=np.float64
        )
        _normalise_marginal(result)
        return result

    def marginal_for_rv(self, rv: RandomVariable, condition: Condition = ()) -> NDArrayNumeric:
        """
        Compute and return marginal distribution over the given random variable.
        The random variable is assumed to be in self.output_rvs.

        Warning:
            the returned array is a view onto the program result buffer, not a copy;
            it is overwritten by the next program computation.

        Returns:
            a numpy array representing the marginal distribution over the states of 'rv'.
        """
        self.compute_conditioned(*condition)
        return self.result_for_rv(rv)

    def marginal_for_rvs(self, rvs: Iterable[RandomVariable], condition: Condition = ()) -> List[NDArrayNumeric]:
        """
        Compute and return marginal distribution over the given random variables.
        Each random variable is assumed to be in self.output_rvs.

        Warning:
            the returned arrays are views onto the program result buffer, not copies;
            they are overwritten by the next program computation.

        Returns:
            a list of numpy arrays representing the marginal distribution over the
            states of each rv in the given random variables, `rvs`.
        """
        self.compute_conditioned(*condition)
        marginals = self._marginals
        rv_idx_to_output_index = self._rv_idx_to_output_index
        return [marginals[rv_idx_to_output_index[rv.idx]] for rv in rvs]

    def compute(self) -> NDArrayNumeric:
        """
        Execute the program, then normalise each output rv's marginals in place.

        Returns:
            the program result buffer (not a copy); its final element is the wmc.
        """
        self._program_buffer.compute()
        for part in self._marginals:
            _normalise_marginal(part)
        return self._program_buffer.results

    @property
    def result_wmc(self) -> float:
        """
        Assuming the result has been computed,
        return the WMC value.
        """
        return self._program_buffer.results.item(-1)

    @property
    def result_marginals(self) -> List[NDArrayNumeric]:
        """
        Assuming the result has been computed,
        return the marginal distributions of each random variable, co-indexed with the
        output random variables, `self.output_rvs`.

        Returns:
            a list of numpy arrays, the list co-indexed with `self.output_rvs`, each numpy array
            representing the marginal distribution over the states of the co-indexed random variable.
            The arrays are views onto the program result buffer.
        """
        return self._marginals

    def result_for_rv(self, rv: RandomVariable) -> NDArrayNumeric:
        """
        Assuming the result has been computed,
        return marginal distribution over the given random variable.
        The random variable is assumed to be in self.output_rvs.

        Returns:
            a numpy array (a view onto the program result buffer) representing
            the marginal distribution over the states of 'rv'.
        """
        return self._marginals[self._rv_idx_to_output_index[rv.idx]]

    def sample_uniform(
            self,
            rvs: Optional[RandomVariable | Sequence[RandomVariable]] = None,
            *,
            condition: Condition = (),
            rand: Random = random,
    ) -> Sampler:
        """
        Create a sampler that performs uniform sampling of
        the state space of the given random variables, rvs.

        The sampler will yield state lists, where the state
        values are co-indexed with rvs, or self.rvs if rvs is None.

        This sampler is not affected by and does not affect
        the state of input slots.

        Args:
            rvs: the list of random variables to sample; the
                yielded state vectors are co-indexed with rvs; if None,
                then the self.rvs are used; if rvs is a single
                random variable, then single samples are yielded.
            condition: is a collection of zero or more conditioning indicators.
            rand: provides the stream of random numbers.

        Returns:
            a Sampler object (UniformSampler).
        """
        return UniformSampler(
            rvs=(self.rvs if rvs is None else rvs),
            condition=condition,
            rand=rand,
        )

    def sample_direct(
            self,
            rvs: Optional[RandomVariable | Sequence[RandomVariable]] = None,
            *,
            condition: Condition = (),
            rand: Random = random,
            chain_pairs: Sequence[Tuple[RandomVariable, RandomVariable]] = (),
            initial_chain_condition: Condition = (),
    ) -> Sampler:
        """
        Create an inverse-transform sampler, which uses the fact that marginal
        probabilities are exactly computable with a single execution of the program.

        The sampler will yield state lists, where the state
        values are co-indexed with rvs, or self.rvs if rvs is None.

        Args:
            rvs: the list of random variables to sample; the
                yielded state vectors are co-indexed with rvs; if None,
                then the self.rvs are used; if rvs is a single
                random variable, then single samples are yielded.
            condition: is a collection of zero or more conditioning indicators.
            rand: provides the stream of random numbers.
            chain_pairs: is a collection of pairs of random variables, each random variable
                must be in the given rvs. Given a pair (from_rv, to_rv) the state of from_rv is used
                as a condition for to_rv prior to generating a sample.
            initial_chain_condition: are condition indicators (just like condition)
                for the initialisation of the 'to_rv' random variables mentioned in chain_pairs.

        Returns:
            a Sampler object (MarginalsDirectSampler).
        """
        sampler_info: SamplerInfo = get_sampler_info(
            program_with_slotmap=self,
            rvs=rvs,
            condition=condition,
            chain_pairs=chain_pairs,
            initial_chain_condition=initial_chain_condition,
        )

        return MarginalsDirectSampler(
            sampler_info=sampler_info,
            raw_program=self._raw_program,
            rand=rand,
            rv_idx_to_result_offset=self._rv_idx_to_result_offset,
        )
341
+
342
+
343
+ def _normalise_marginal(distribution: NDArrayNumeric) -> None:
344
+ """
345
+ Update the values in the given distribution to
346
+ properly represent a marginal distribution.
347
+ """
348
+ total = np.sum(distribution)
349
+ if total <= 0:
350
+ distribution[:] = np.nan
351
+ elif total != 1:
352
+ distribution /= total
@@ -0,0 +1,237 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from functools import partial
5
+ from typing import Sequence, Optional, Tuple, List, Dict, Set
6
+
7
+ from ck.circuit import CircuitNode, Circuit, VarNode, OpNode, ADD, MUL
8
+ from ck.circuit_compiler import llvm_vm_compiler, CircuitCompiler
9
+ from ck.pgm import RandomVariable, Instance
10
+ from ck.pgm_circuit import PGMCircuit
11
+ from ck.pgm_circuit.program_with_slotmap import ProgramWithSlotmap
12
+ from ck.pgm_circuit.slot_map import SlotMap
13
+ from ck.pgm_circuit.support.compile_circuit import compile_results
14
+ from ck.probability.probability_space import check_condition
15
+ from ck.program.program_buffer import ProgramBuffer
16
+ from ck.program.raw_program import RawProgram
17
+ from ck.utils.np_extras import NDArray, NDArrayNumeric
18
+
19
# Sentinel stored for circuit slots that are not an indicator of any traced
# random variable. It is compared by identity (`is`), so this very object must be used.
_NO_TRACE: Tuple[int, int] = (-1, -1)

# The Python module used for compiling an MPE circuit.
_CCT_COMPILER = llvm_vm_compiler
22
+
23
+
24
class MPEProgram(ProgramWithSlotmap):
    """
    A class for computing Most Probable Explanation (MPE). This is equivalent to
    Maximum A Posteriori (MAP) inference when there are no latent random variables.
    """

    def __init__(
            self,
            pgm_circuit: PGMCircuit,
            trace_rvs: Optional[Sequence[RandomVariable]] = None,
            const_parameters: bool = True,
            log_parameters: bool = False,
    ):
        """
        Construct a MPEProgram object.

        Compile the circuit for computing Most Probable Explanation (MPE). This is equivalent to
        Maximum A Posteriori (MAP) inference when there are no latent variables.

        The circuit is compiled with a 'max' data type (MAX_MUL, or MAX_SUM when
        log_parameters is True) so that 'add' nodes are evaluated as 'max'. Every
        op node reachable from the top is a program result, so the computation can
        be traced back to find the MPE states of the trace random variables.
        Assumes that all operator nodes to compute top are either an add or mul node.

        Args:
            pgm_circuit: The circuit representing a PGM.
            trace_rvs: the random variables to compute MPE for, default is all random variables of the PGM.
            const_parameters: if True then any circuit variable representing a parameter value will
                be made 'const' in the resulting program.
            log_parameters: if true, then parameters are taken to be logs, i.e., uses addition instead
                of multiplication.

        Raises:
            ValueError: if a trace random variable is duplicated, or two trace indicators
                share a circuit slot.
        """
        trace_rvs: Tuple[RandomVariable, ...] = pgm_circuit.rvs if trace_rvs is None else tuple(trace_rvs)
        if len(trace_rvs) != len(set(trace_rvs)):
            raise ValueError('duplicated trace random variable detected')

        top: CircuitNode = pgm_circuit.circuit_top
        circuit: Circuit = top.circuit
        slot_map: SlotMap = pgm_circuit.slot_map

        cct_compiler: CircuitCompiler
        if log_parameters:
            cct_compiler = partial(_CCT_COMPILER.compile_circuit, data_type=_CCT_COMPILER.DataType.MAX_SUM)
        else:
            cct_compiler = partial(_CCT_COMPILER.compile_circuit, data_type=_CCT_COMPILER.DataType.MAX_MUL)

        # make inv_trace_blocks
        #
        # inv_trace_blocks[slot] = (rv_trace_idx, state_idx)
        # where
        #   rv_trace_idx is an index into trace_rvs,
        #   state_idx is an index into trace_rvs[rv_trace_idx] indicators,
        #
        # slot = slot_map[ind], where ind = trace_rvs[rv_trace_idx][state_idx].
        #
        inv_trace_blocks: List[Tuple[int, int]] = [_NO_TRACE] * circuit.number_of_vars
        rv_trace_idx: int
        trace_rv: RandomVariable
        for rv_trace_idx, trace_rv in enumerate(trace_rvs):
            for state_idx in trace_rv.state_range():
                slot: int = slot_map[trace_rv[state_idx]]
                if inv_trace_blocks[slot] is not _NO_TRACE:
                    raise ValueError('unexpected reused circuit slot')
                inv_trace_blocks[slot] = (rv_trace_idx, state_idx)

        used_nodes: List[CircuitNode] = list(circuit.reachable_op_nodes(top))

        # Map from id(op node) to the index of its value in the program results.
        mpe_idx: Dict[int, int] = {
            id(used_node): used_node_idx
            for used_node_idx, used_node in enumerate(used_nodes)
        }

        # create a dummy MPE result until compute is called
        dummy_result = MPEResult(float('nan'), tuple(0 for _ in trace_rvs))

        self._trace_rvs: Tuple[RandomVariable, ...] = trace_rvs
        self._inv_trace_blocks = inv_trace_blocks
        self._top: CircuitNode = top
        self._mpe_result: MPEResult = dummy_result

        self._top_idx: Optional[int] = mpe_idx.get(id(top))  # it may be possible that top is not an op node.
        self._used_nodes: List[CircuitNode] = used_nodes
        self._mpe_idx: Dict[int, int] = mpe_idx

        raw_program: RawProgram = compile_results(
            pgm_circuit=pgm_circuit,
            results=used_nodes,
            const_parameters=const_parameters,
            compiler=cct_compiler,
        )
        ProgramWithSlotmap.__init__(self, ProgramBuffer(raw_program), slot_map, pgm_circuit.rvs, pgm_circuit.conditions)

        if not const_parameters:
            # set the parameter slots
            self.vars[pgm_circuit.number_of_indicators:] = pgm_circuit.parameter_values

    def mpe(self, *condition) -> MPEResult:
        """
        What is the MPE, given any conditioning indicators.

        If a traced random variable is not reached when tracing the computation
        (e.g., the solution is indifferent to its state), its state is reported as 0;
        any state of that random variable would still give a valid MPE solution.

        Returns:
            an MPEResult with field `wmc` and `mpe`.
            wmc: is the value of the weighted model count.
            mpe: is an Instance, co-indexed with trace rvs, where mpe[rv_idx] = state_idx.
        """
        condition = check_condition(condition)
        self.compute_conditioned(*condition)
        return self._mpe_result

    @property
    def trace_rvs(self) -> Sequence[RandomVariable]:
        """
        What are the random variables used in an MPE trace.
        """
        return self._trace_rvs

    def compute(self) -> NDArrayNumeric:
        """
        Execute the program to compute and return the result. As per `ProgramBuffer.compute`.
        Also traces the computation to update `mpe_result`.

        Warning:
            when returning an array, the array is backed by the program buffer memory, not a copy.
        """
        program_result: NDArray = self._program_buffer.compute()
        self._trace()
        return program_result

    @property
    def mpe_result(self) -> MPEResult:
        """
        Get the MPEResult of the last program computation.

        Returns:
            an MPEResult object.
        """
        return self._mpe_result

    def _trace(self) -> None:
        """
        Trace the last program computation to determine the wmc and the mpe states.
        """
        top: CircuitNode = self._top
        if self._top_idx is not None:
            wmc: float = self.results.item(self._top_idx)
        elif isinstance(top, VarNode):
            # The top is a single circuit variable: there are no op nodes,
            # so its value comes straight from the input slots.
            wmc = self.vars.item(top.idx)
        else:
            # NOTE(review): top is neither an op node nor a var node (e.g., a constant);
            # nothing can be traced, so the previous MPE result is retained.
            return
        states: List[Optional[int]] = [None for _ in self._trace_rvs]
        seen: Set[int] = set()
        self._trace_r(top, wmc, states, seen)
        mpe = tuple(
            0 if state_idx is None else state_idx
            for state_idx in states
        )
        self._mpe_result = MPEResult(wmc, mpe)

    def _trace_r(self, node: CircuitNode, node_value: float, states: List[Optional[int]], seen: Set[int]) -> None:
        """
        Trace the MPE computation from `node` down, recording traced indicator states in `states`.

        Uses an explicit stack rather than recursion so deep circuits cannot exceed
        Python's recursion limit. Nodes are visited in the same depth-first,
        left-to-right order as a recursive traversal.

        Args:
            node: the node to start tracing from.
            node_value: the computed value of `node`.
            states: updated in place, co-indexed with the trace random variables.
            seen: ids of nodes already traced; updated in place.
        """
        results = self.results
        mpe_idx = self._mpe_idx
        # Stack entries are (node, value). A value of None marks a var child of a mul node,
        # which is traced directly without `seen` bookkeeping.
        stack: List[Tuple[CircuitNode, Optional[float]]] = [(node, node_value)]
        while stack:
            current, current_value = stack.pop()

            if current_value is None:
                self._trace_var(current, states)
                continue

            # A circuit is a DAG, not necessarily a tree.
            # No need to revisit nodes.
            if id(current) in seen:
                continue
            seen.add(id(current))

            if isinstance(current, VarNode):
                self._trace_var(current, states)
            elif isinstance(current, OpNode):
                if current.symbol == ADD:
                    # 'add' is evaluated as 'max': find the first child that led to the
                    # max result, then follow that child only.
                    for child in current.args:
                        if isinstance(child, OpNode):
                            child_value: float = results.item(mpe_idx[id(child)])
                            if child_value == current_value:
                                stack.append((child, child_value))
                                break
                        elif isinstance(child, VarNode):
                            if self.vars.item(child.idx) == current_value:
                                self._trace_var(child, states)
                                break
                    else:
                        # No child value equaled the value for node! We should never get here
                        assert False, 'not reached'
                elif current.symbol == MUL:
                    # Follow every child; push in reverse so they are popped in order.
                    pending: List[Tuple[CircuitNode, Optional[float]]] = []
                    for child in current.args:
                        if isinstance(child, OpNode):
                            pending.append((child, results.item(mpe_idx[id(child)])))
                        elif isinstance(child, VarNode):
                            pending.append((child, None))
                    stack.extend(reversed(pending))

    def _trace_var(self, node: VarNode, states: List[Optional[int]]) -> None:
        """
        If `node` is an indicator of a traced random variable, record its state in `states`.
        """
        trace = self._inv_trace_blocks[node.idx]
        if trace is not _NO_TRACE:
            rv_trace_idx, state_idx = trace
            states[rv_trace_idx] = state_idx
225
+
226
+
227
@dataclass()
class MPEResult:
    """
    The outcome of Most Probable Explanation (MPE) inference.

    Fields:
        wmc: the weighted model count value of the MPE solution.
        mpe: the MPE solution instance, co-indexed with the traced random
            variables, giving one state index per random variable. If there
            are ties then this is just one of the tied instances.
    """
    wmc: float
    mpe: Instance
@@ -0,0 +1,79 @@
1
+ from dataclasses import dataclass
2
+ from typing import Sequence, List, Dict
3
+
4
+ from ck.circuit import CircuitNode, Circuit
5
+ from ck.pgm import RandomVariable, Indicator
6
+ from ck.pgm_circuit.slot_map import SlotMap, SlotKey
7
+ from ck.utils.np_extras import NDArray
8
+
9
+
10
@dataclass
class PGMCircuit:
    """
    The outcome of compiling a PGM into an arithmetic circuit.

    When parameters of the PGM are represented as circuit variables, their values
    are held in `parameter_values`: for a parameter id `param_id`, its value in the
    PGM is `parameter_values[slot_map[param_id] - number_of_indicators]`.

    Fields:
        rvs: the random variables of the PGM as it was compiled, in order.

        conditions: any conditions on `rvs` that were compiled into the circuit.

        circuit_top: the top node of the compiled circuit.

        number_of_indicators: the number of indicators of `rvs`, i.e.,
            `sum(len(rv) for rv in rvs)`. The circuit variable `circuit.vars[i]`
            (where `circuit` is `circuit_top.circuit`) corresponds to the ith
            indicator, with indicators ordered as per `rvs`.

        number_of_parameters: the number of PGM parameters represented as circuit
            variables. This may be zero if the PGM parameters were compiled as constants.

        slot_map: `slot_map[x]` is the index of the circuit variable for `x`, where `x`
            is either a random variable indicator (Indicator) or a parameter id (ParamId).

        parameter_values: the values of the parameters represented as circuit variables.
    """

    rvs: Sequence[RandomVariable]
    conditions: Sequence[Indicator]
    circuit_top: CircuitNode
    number_of_indicators: int
    number_of_parameters: int
    slot_map: SlotMap
    parameter_values: NDArray

    def dump(self, *, prefix: str = '', indent: str = ' ') -> None:
        """
        Print the circuit, labelling each circuit variable as an indicator or a parameter.
        Intended for debugging and demonstrations.

        Args:
            prefix: optional prefix for indenting all lines.
            indent: additional prefix to use for extra indentation.
        """
        circuit: Circuit = self.circuit_top.circuit
        names: List[str] = ['' for _ in range(circuit.number_of_vars)]

        # Indicator variables are labelled with their random variable name, state index and state.
        rv_lookup: Dict[int, RandomVariable] = {}
        for rv in self.rvs:
            rv_lookup[rv.idx] = rv
        indicator_slots = (
            (key, slot)
            for key, slot in self.slot_map.items()
            if isinstance(key, Indicator)
        )
        for indicator, slot in indicator_slots:
            owner = rv_lookup[indicator.rv_idx]
            state = indicator.state_idx
            names[slot] = f'{owner.name!r}[{state}] {owner.states[state]!r}'

        # Parameter variables follow the indicators and are labelled with their value.
        first_param_slot = self.number_of_indicators
        for param_idx, value in enumerate(self.parameter_values):
            names[first_param_slot + param_idx] = f'param[{param_idx}] = {value}'

        circuit.dump(prefix=prefix, indent=indent, var_names=names)