compiled-knowledge 4.0.0a20__cp313-cp313-win32.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of compiled-knowledge might be problematic. Click here for more details.

Files changed (178) hide show
  1. ck/__init__.py +0 -0
  2. ck/circuit/__init__.py +17 -0
  3. ck/circuit/_circuit_cy.c +37523 -0
  4. ck/circuit/_circuit_cy.cp313-win32.pyd +0 -0
  5. ck/circuit/_circuit_cy.pxd +32 -0
  6. ck/circuit/_circuit_cy.pyx +768 -0
  7. ck/circuit/_circuit_py.py +836 -0
  8. ck/circuit/tmp_const.py +74 -0
  9. ck/circuit_compiler/__init__.py +2 -0
  10. ck/circuit_compiler/circuit_compiler.py +26 -0
  11. ck/circuit_compiler/cython_vm_compiler/__init__.py +1 -0
  12. ck/circuit_compiler/cython_vm_compiler/_compiler.c +19824 -0
  13. ck/circuit_compiler/cython_vm_compiler/_compiler.cp313-win32.pyd +0 -0
  14. ck/circuit_compiler/cython_vm_compiler/_compiler.pyx +380 -0
  15. ck/circuit_compiler/cython_vm_compiler/cython_vm_compiler.py +121 -0
  16. ck/circuit_compiler/interpret_compiler.py +223 -0
  17. ck/circuit_compiler/llvm_compiler.py +388 -0
  18. ck/circuit_compiler/llvm_vm_compiler.py +546 -0
  19. ck/circuit_compiler/named_circuit_compilers.py +57 -0
  20. ck/circuit_compiler/support/__init__.py +0 -0
  21. ck/circuit_compiler/support/circuit_analyser/__init__.py +13 -0
  22. ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.c +10618 -0
  23. ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.cp313-win32.pyd +0 -0
  24. ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.pyx +98 -0
  25. ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_py.py +93 -0
  26. ck/circuit_compiler/support/input_vars.py +148 -0
  27. ck/circuit_compiler/support/llvm_ir_function.py +234 -0
  28. ck/example/__init__.py +53 -0
  29. ck/example/alarm.py +366 -0
  30. ck/example/asia.py +28 -0
  31. ck/example/binary_clique.py +32 -0
  32. ck/example/bow_tie.py +33 -0
  33. ck/example/cancer.py +37 -0
  34. ck/example/chain.py +38 -0
  35. ck/example/child.py +199 -0
  36. ck/example/clique.py +33 -0
  37. ck/example/cnf_pgm.py +39 -0
  38. ck/example/diamond_square.py +68 -0
  39. ck/example/earthquake.py +36 -0
  40. ck/example/empty.py +10 -0
  41. ck/example/hailfinder.py +539 -0
  42. ck/example/hepar2.py +628 -0
  43. ck/example/insurance.py +504 -0
  44. ck/example/loop.py +40 -0
  45. ck/example/mildew.py +38161 -0
  46. ck/example/munin.py +22982 -0
  47. ck/example/pathfinder.py +53747 -0
  48. ck/example/rain.py +39 -0
  49. ck/example/rectangle.py +161 -0
  50. ck/example/run.py +30 -0
  51. ck/example/sachs.py +129 -0
  52. ck/example/sprinkler.py +30 -0
  53. ck/example/star.py +44 -0
  54. ck/example/stress.py +64 -0
  55. ck/example/student.py +43 -0
  56. ck/example/survey.py +46 -0
  57. ck/example/triangle_square.py +54 -0
  58. ck/example/truss.py +49 -0
  59. ck/in_out/__init__.py +3 -0
  60. ck/in_out/parse_ace_lmap.py +216 -0
  61. ck/in_out/parse_ace_nnf.py +322 -0
  62. ck/in_out/parse_net.py +480 -0
  63. ck/in_out/parser_utils.py +185 -0
  64. ck/in_out/pgm_pickle.py +42 -0
  65. ck/in_out/pgm_python.py +268 -0
  66. ck/in_out/render_bugs.py +111 -0
  67. ck/in_out/render_net.py +177 -0
  68. ck/in_out/render_pomegranate.py +184 -0
  69. ck/pgm.py +3475 -0
  70. ck/pgm_circuit/__init__.py +1 -0
  71. ck/pgm_circuit/marginals_program.py +352 -0
  72. ck/pgm_circuit/mpe_program.py +237 -0
  73. ck/pgm_circuit/pgm_circuit.py +79 -0
  74. ck/pgm_circuit/program_with_slotmap.py +236 -0
  75. ck/pgm_circuit/slot_map.py +35 -0
  76. ck/pgm_circuit/support/__init__.py +0 -0
  77. ck/pgm_circuit/support/compile_circuit.py +83 -0
  78. ck/pgm_circuit/target_marginals_program.py +103 -0
  79. ck/pgm_circuit/wmc_program.py +323 -0
  80. ck/pgm_compiler/__init__.py +2 -0
  81. ck/pgm_compiler/ace/__init__.py +1 -0
  82. ck/pgm_compiler/ace/ace.py +299 -0
  83. ck/pgm_compiler/factor_elimination.py +395 -0
  84. ck/pgm_compiler/named_pgm_compilers.py +63 -0
  85. ck/pgm_compiler/pgm_compiler.py +19 -0
  86. ck/pgm_compiler/recursive_conditioning.py +231 -0
  87. ck/pgm_compiler/support/__init__.py +0 -0
  88. ck/pgm_compiler/support/circuit_table/__init__.py +17 -0
  89. ck/pgm_compiler/support/circuit_table/_circuit_table_cy.c +16396 -0
  90. ck/pgm_compiler/support/circuit_table/_circuit_table_cy.cp313-win32.pyd +0 -0
  91. ck/pgm_compiler/support/circuit_table/_circuit_table_cy.pyx +332 -0
  92. ck/pgm_compiler/support/circuit_table/_circuit_table_py.py +304 -0
  93. ck/pgm_compiler/support/clusters.py +568 -0
  94. ck/pgm_compiler/support/factor_tables.py +406 -0
  95. ck/pgm_compiler/support/join_tree.py +332 -0
  96. ck/pgm_compiler/support/named_compiler_maker.py +43 -0
  97. ck/pgm_compiler/variable_elimination.py +91 -0
  98. ck/probability/__init__.py +0 -0
  99. ck/probability/empirical_probability_space.py +50 -0
  100. ck/probability/pgm_probability_space.py +32 -0
  101. ck/probability/probability_space.py +622 -0
  102. ck/program/__init__.py +3 -0
  103. ck/program/program.py +137 -0
  104. ck/program/program_buffer.py +180 -0
  105. ck/program/raw_program.py +67 -0
  106. ck/sampling/__init__.py +0 -0
  107. ck/sampling/forward_sampler.py +211 -0
  108. ck/sampling/marginals_direct_sampler.py +113 -0
  109. ck/sampling/sampler.py +62 -0
  110. ck/sampling/sampler_support.py +232 -0
  111. ck/sampling/uniform_sampler.py +72 -0
  112. ck/sampling/wmc_direct_sampler.py +171 -0
  113. ck/sampling/wmc_gibbs_sampler.py +153 -0
  114. ck/sampling/wmc_metropolis_sampler.py +165 -0
  115. ck/sampling/wmc_rejection_sampler.py +115 -0
  116. ck/utils/__init__.py +0 -0
  117. ck/utils/iter_extras.py +163 -0
  118. ck/utils/local_config.py +270 -0
  119. ck/utils/map_list.py +128 -0
  120. ck/utils/map_set.py +128 -0
  121. ck/utils/np_extras.py +51 -0
  122. ck/utils/random_extras.py +64 -0
  123. ck/utils/tmp_dir.py +94 -0
  124. ck_demos/__init__.py +0 -0
  125. ck_demos/ace/__init__.py +0 -0
  126. ck_demos/ace/copy_ace_to_ck.py +15 -0
  127. ck_demos/ace/demo_ace.py +49 -0
  128. ck_demos/all_demos.py +88 -0
  129. ck_demos/circuit/__init__.py +0 -0
  130. ck_demos/circuit/demo_circuit_dump.py +22 -0
  131. ck_demos/circuit/demo_derivatives.py +43 -0
  132. ck_demos/circuit_compiler/__init__.py +0 -0
  133. ck_demos/circuit_compiler/compare_circuit_compilers.py +32 -0
  134. ck_demos/circuit_compiler/show_llvm_program.py +26 -0
  135. ck_demos/pgm/__init__.py +0 -0
  136. ck_demos/pgm/demo_pgm_dump.py +18 -0
  137. ck_demos/pgm/demo_pgm_dump_stress.py +18 -0
  138. ck_demos/pgm/demo_pgm_string_rendering.py +15 -0
  139. ck_demos/pgm/show_examples.py +25 -0
  140. ck_demos/pgm_compiler/__init__.py +0 -0
  141. ck_demos/pgm_compiler/compare_pgm_compilers.py +63 -0
  142. ck_demos/pgm_compiler/demo_compiler_dump.py +60 -0
  143. ck_demos/pgm_compiler/demo_factor_elimination.py +47 -0
  144. ck_demos/pgm_compiler/demo_join_tree.py +25 -0
  145. ck_demos/pgm_compiler/demo_marginals_program.py +53 -0
  146. ck_demos/pgm_compiler/demo_mpe_program.py +55 -0
  147. ck_demos/pgm_compiler/demo_pgm_compiler.py +38 -0
  148. ck_demos/pgm_compiler/demo_recursive_conditioning.py +33 -0
  149. ck_demos/pgm_compiler/demo_variable_elimination.py +33 -0
  150. ck_demos/pgm_compiler/demo_wmc_program.py +29 -0
  151. ck_demos/pgm_compiler/time_fe_compiler.py +93 -0
  152. ck_demos/pgm_inference/__init__.py +0 -0
  153. ck_demos/pgm_inference/demo_inferencing_basic.py +188 -0
  154. ck_demos/pgm_inference/demo_inferencing_mpe_cancer.py +45 -0
  155. ck_demos/pgm_inference/demo_inferencing_wmc_and_mpe_sprinkler.py +154 -0
  156. ck_demos/pgm_inference/demo_inferencing_wmc_student.py +110 -0
  157. ck_demos/programs/__init__.py +0 -0
  158. ck_demos/programs/demo_program_buffer.py +24 -0
  159. ck_demos/programs/demo_program_multi.py +24 -0
  160. ck_demos/programs/demo_program_none.py +19 -0
  161. ck_demos/programs/demo_program_single.py +23 -0
  162. ck_demos/programs/demo_raw_program_interpreted.py +21 -0
  163. ck_demos/programs/demo_raw_program_llvm.py +21 -0
  164. ck_demos/sampling/__init__.py +0 -0
  165. ck_demos/sampling/check_sampler.py +71 -0
  166. ck_demos/sampling/demo_marginal_direct_sampler.py +40 -0
  167. ck_demos/sampling/demo_uniform_sampler.py +38 -0
  168. ck_demos/sampling/demo_wmc_direct_sampler.py +40 -0
  169. ck_demos/utils/__init__.py +0 -0
  170. ck_demos/utils/compare.py +120 -0
  171. ck_demos/utils/convert_network.py +45 -0
  172. ck_demos/utils/sample_model.py +216 -0
  173. ck_demos/utils/stop_watch.py +384 -0
  174. compiled_knowledge-4.0.0a20.dist-info/METADATA +50 -0
  175. compiled_knowledge-4.0.0a20.dist-info/RECORD +178 -0
  176. compiled_knowledge-4.0.0a20.dist-info/WHEEL +5 -0
  177. compiled_knowledge-4.0.0a20.dist-info/licenses/LICENSE.txt +21 -0
  178. compiled_knowledge-4.0.0a20.dist-info/top_level.txt +2 -0
@@ -0,0 +1,113 @@
1
+ from typing import Collection, Iterator, Dict, Sequence
2
+
3
+ import numpy as np
4
+
5
+ from ck.pgm import Instance
6
+ from ck.probability.probability_space import dtype_for_state_indexes
7
+ from ck.program.program_buffer import ProgramBuffer
8
+ from ck.program.raw_program import RawProgram
9
+ from ck.sampling.sampler import Sampler
10
+ from ck.sampling.sampler_support import SampleRV, YieldF, SamplerInfo
11
+ from ck.utils.np_extras import NDArray, NDArrayNumeric
12
+ from ck.utils.random_extras import Random
13
+
14
+
15
class MarginalsDirectSampler(Sampler):
    """
    A sampler that draws the state of each sample random variable in turn,
    using a program whose results hold marginal weights for every sample
    random variable, with the weighted model count (WMC) as the last result.

    After a state is chosen for a random variable, its input slots are fixed
    to that state, so the following random variables are sampled conditioned
    on the choices already made.
    """

    def __init__(
            self,
            sampler_info: SamplerInfo,
            raw_program: RawProgram,
            rand: Random,
            rv_idx_to_result_offset: Dict[int, int],
    ):
        """
        Args:
            sampler_info: standard sampler set-up (sample rvs, condition, yield
                function and the initial zero/one input slots).
            raw_program: the program to execute; it is wrapped in a ProgramBuffer.
            rand: the source of random numbers (`rand.random()` is used).
            rv_idx_to_result_offset: maps `rv.idx` of each sample random variable
                to the offset in the program results where its `len(rv)`
                marginal values start.
        """
        super().__init__(sampler_info.rvs, sampler_info.condition)
        self._yield_f: YieldF = sampler_info.yield_f
        self._rand: Random = rand
        self._program_buffer = ProgramBuffer(raw_program)
        self._sample_rvs: Sequence[SampleRV] = tuple(sampler_info.sample_rvs)
        self._chain_rvs: Sequence[SampleRV] = tuple(
            sample_rv for sample_rv in sampler_info.sample_rvs if sample_rv.copy_index is not None)
        self._state_dtype = dtype_for_state_indexes(self.rvs)
        self._max_number_of_states: int = max((len(rv) for rv in self.rvs), default=0)
        self._slots_1: Collection[int] = sampler_info.slots_1

        # One slice of the program results per sample rv, co-indexed with the sample rvs.
        # The slices are taken once here; `__iter__` relies on them reflecting the
        # results of the most recent `compute()` call.
        self._marginals: Sequence[NDArrayNumeric] = tuple(
            self._program_buffer.results[
                rv_idx_to_result_offset[sample_rv.rv.idx]
                :
                rv_idx_to_result_offset[sample_rv.rv.idx] + len(sample_rv.rv)
            ]
            for sample_rv in sampler_info.sample_rvs
        )

        # Set up the input slots to 0 or 1 to respect conditioning and initial Markov chain states.
        slots: NDArray = self._program_buffer.vars
        for slot in sampler_info.slots_0:
            slots[slot] = 0
        for slot in sampler_info.slots_1:
            slots[slot] = 1

    def __iter__(self) -> Iterator[Instance] | Iterator[int]:
        """
        Provide an unlimited series of samples.

        Raises:
            RuntimeError: if the marginal weights of a random variable do not
                have a positive sum ('zero probability').
        """
        yield_f = self._yield_f
        rand = self._rand
        sample_rvs = self._sample_rvs
        chain_rvs = self._chain_rvs
        program_buffer = self._program_buffer
        slots: NDArray = program_buffer.vars
        marginals = self._marginals
        slots_1 = self._slots_1

        # Set up working memory buffer
        states = np.zeros(len(sample_rvs), dtype=self._state_dtype)

        def compute() -> float:
            # Compute the program results based on the current input slot values.
            # Return the WMC.
            return program_buffer.compute().item(-1)

        while True:
            wmc: float = compute()
            rnd: float = rand.random() * wmc

            for sample_rv in sample_rvs:
                index: int = sample_rv.index
                if index > 0:
                    # No need to execute the program on the first time through
                    # as it was done just before entering the loop.
                    wmc = compute()

                rv_dist: NDArray = marginals[index]

                rv_dist_sum: float = rv_dist.sum()
                if rv_dist_sum <= 0:
                    raise RuntimeError('zero probability')
                # Rescale (in place) so the weights of this rv sum to the current WMC.
                rv_dist *= wmc / rv_dist_sum

                # Walk the states, reducing rnd by the weight of every skipped state,
                # until rnd falls inside the weight of a state.
                state_index: int = -1
                last_positive: int = -1
                for i in range(len(sample_rv.rv)):
                    w = rv_dist.item(i)
                    if w > 0:
                        last_positive = i
                    if rnd < w:
                        state_index = i
                        break
                    rnd -= w

                if state_index < 0:
                    # Floating point rounding can leave rnd at or marginally above the
                    # total weight. Fall back to the last state that has any weight
                    # rather than relying on an assert (which is stripped under -O).
                    # rnd keeps its small non-negative residual for the next rv.
                    if last_positive < 0:
                        raise RuntimeError('zero probability')
                    state_index = last_positive

                # Condition the program inputs on the selected state.
                for slot in sample_rv.slots:
                    slots[slot] = 0
                slots[sample_rv.slots[state_index]] = 1
                states[index] = state_index

            yield yield_f(states)

            # Reset the one slots for the next iteration.
            for slot in slots_1:
                slots[slot] = 1

            # Copy chain pairs for next iteration.
            # (This writes over any initial chain conditions from slots_1.)
            for sample_rv in chain_rvs:
                rv_slots = sample_rv.slots
                prev_state_idx: int = states.item(sample_rv.copy_index)
                for slot in rv_slots:
                    slots[slot] = 0
                slots[rv_slots[prev_state_idx]] = 1
ck/sampling/sampler.py ADDED
@@ -0,0 +1,62 @@
1
+ from abc import ABC, abstractmethod
2
+ from itertools import islice
3
+ from typing import Sequence, Iterator
4
+
5
+ from ck.pgm import RandomVariable, Instance, Indicator
6
+
7
+
8
class Sampler(ABC):
    """
    Abstract source of an endless stream of samples for one or more random variables.

    The sampled random variables are exposed, as a tuple, through the `rvs` property.

    Depending on the implementation, iterating a Sampler produces either Instance
    objects (each co-indexed with `self.rvs`) or single state indexes. When single
    state indexes are produced, `len(self.rvs) == 1`.
    """
    __slots__ = ('_rvs', '_condition')

    def __init__(self, rvs: Sequence[RandomVariable], condition: Sequence[Indicator]):
        """
        Args:
            rvs: the random variables being sampled; every sample from
                `iter(self)` is co-indexed with this collection.
            condition: the condition on `rvs` that is compiled into the sampler.
        """
        # Both collections are stored as tuples, independent of the caller's containers.
        self._condition: Sequence[Indicator] = tuple(condition)
        self._rvs: Sequence[RandomVariable] = tuple(rvs)

    @property
    def rvs(self) -> Sequence[RandomVariable]:
        """
        The random variables being sampled.

        Returns:
            the sampled random variables, co-indexed with each sample from `iter(self)`.
        """
        return self._rvs

    @property
    def condition(self) -> Sequence[Indicator]:
        """
        The condition on `self.rvs` that is compiled into the sampler.
        """
        return self._condition

    @abstractmethod
    def __iter__(self) -> Iterator[Instance] | Iterator[int]:
        """
        An endless series of samples from a random process.
        Every sample is co-indexed with the random variables given by `self.rvs`.
        """
        ...

    def take(self, number_of_samples: int) -> Iterator[Instance] | Iterator[int]:
        """
        Provide at most `number_of_samples` samples from `iter(self)`.

        Args:
            number_of_samples: the maximum number of samples to provide.
        """
        limited = islice(self, number_of_samples)
        return limited
@@ -0,0 +1,232 @@
1
+ from dataclasses import dataclass
2
+ from itertools import count
3
+ from typing import Callable, Sequence, Optional, Set, Tuple, Dict, Collection
4
+
5
+ from ck.pgm import Instance, RandomVariable, Indicator
6
+ from ck.pgm_circuit.program_with_slotmap import ProgramWithSlotmap
7
+ from ck.pgm_circuit.slot_map import SlotMap
8
+ from ck.probability.probability_space import Condition, check_condition
9
+ from ck.utils.map_set import MapSet
10
+ from ck.utils.np_extras import NDArrayStates, NDArrayNumeric
11
+ from ck.utils.random_extras import Random
12
+
13
+ # Type of a yield function. Support for a sampler.
14
+ # A yield function may be used to implement a sampler's iterator, thus
15
+ # it provides an Instance or single state index.
16
+ YieldF = Callable[[NDArrayStates], int] | Callable[[NDArrayStates], Instance]
17
+
18
+
19
@dataclass
class SampleRV:
    """
    Sampler support: the bookkeeping record for one sampled random variable.

    Attributes:
        index: index into the sequence of sample rvs.
        rv: the random variable being sampled.
        slots: program input slots for the indicators of the random variable
            (co-indexed with rv.states).
        copy_index: for Markov chains, the index of the previous sample rv
            whose state should be copied (None when there is nothing to copy).
    """
    index: int
    rv: RandomVariable
    slots: Sequence[int]
    copy_index: Optional[int]
29
+
30
+
31
@dataclass
class SamplerInfo:
    """
    Sampler support: the standard information kept when a sampler uses a Program.

    Attributes:
        sample_rvs: one SampleRV record per sampled random variable.
        condition: the condition indicators applying to the sampler.
        yield_f: the function used to turn a state array into a yielded sample.
        slots_0: program input slots to be set to zero.
        slots_1: program input slots to be set to one.
    """
    sample_rvs: Sequence[SampleRV]
    condition: Sequence[Indicator]
    yield_f: YieldF
    slots_0: Set[int]
    slots_1: Set[int]

    @property
    def rvs(self) -> Tuple[RandomVariable, ...]:
        """
        The RandomVariable objects of `self.sample_rvs`, in the same order.
        """
        return tuple(entry.rv for entry in self.sample_rvs)
49
+
50
+
51
def get_sampler_info(
        program_with_slotmap: ProgramWithSlotmap,
        rvs: Optional[RandomVariable | Sequence[RandomVariable]],
        condition: Condition,
        chain_pairs: Sequence[Tuple[RandomVariable, RandomVariable]] = (),
        initial_chain_condition: Condition = (),
) -> SamplerInfo:
    """
    Helper for samplers.

    Determines:
    (1) the slots for sampling rvs,
    (2) Markov chaining rvs,
    (3) the function to use for yielding an Instance or state index.

    If parameter `rvs` is a RandomVariable, then the yield function will
    provide a state index. If parameter `rvs` is a Sequence, then the
    yield function will provide an Instance. If `rvs` is None, then the
    random variables of `program_with_slotmap` are used.

    Args:
        program_with_slotmap: the program and slotmap being referenced.
        rvs: the random variables to sample. It may be either a sequence of
            random variables, or a single random variable.
        condition: is a collection of zero or more conditioning indicators.
        chain_pairs: is a collection of pairs of random variables, each random variable
            must be in the given rvs. Given a pair (from_rv, to_rv) the state of from_rv is used
            as a condition for to_rv prior to generating a sample.
        initial_chain_condition: are condition indicators (just like condition)
            for the initialisation of the 'to_rv' random variables mentioned in chain_pairs.

    Raises:
        ValueError: if preconditions of `program_with_slotmap` are incompatible with the given condition.
        ValueError: if the sample random variables are not available in the program or contain duplicates.
        ValueError: if `chain_pairs` or `initial_chain_condition` are inconsistent with
            the sample random variables or the condition.

    Returns:
        a SamplerInfo structure.
    """
    if rvs is None:
        rvs = program_with_slotmap.rvs
    if isinstance(rvs, RandomVariable):
        # a single rv
        rvs = (rvs,)
        yield_f = lambda x: x.item()
    else:
        # a sequence of rvs
        rvs = tuple(rvs)
        yield_f = lambda x: x.tolist()

    # Group condition indicators by `rv_idx`.
    conditioned_rvs: MapSet[int, Indicator] = MapSet()
    for ind in check_condition(condition):
        conditioned_rvs.add(ind.rv_idx, ind)
    del condition

    # Group precondition indicators by `rv_idx`.
    preconditioned_rvs: MapSet[int, Indicator] = MapSet()
    for ind in program_with_slotmap.precondition:
        preconditioned_rvs.add(ind.rv_idx, ind)

    # Rationalise conditioned_rvs with preconditioned_rvs
    rv_idx: int
    precondition_set: Set[Indicator]
    for rv_idx, precondition_set in preconditioned_rvs.items():
        condition_set = conditioned_rvs.get(rv_idx)
        if condition_set is None:
            # A preconditioned rv was not mentioned in the explicit conditions
            conditioned_rvs.add_all(rv_idx, precondition_set)
        else:
            # A preconditioned rv was also mentioned in the explicit conditions
            condition_set.intersection_update(precondition_set)
            if len(condition_set) == 0:
                # The conflicting rv need not be one of the sampled rvs, so look it up
                # among all the program's rvs (falling back to its index) to be sure
                # the documented ValueError is raised rather than a KeyError.
                rv_label: str = next(
                    (str(_rv) for _rv in program_with_slotmap.rvs if _rv.idx == rv_idx),
                    f'#{rv_idx}',
                )
                raise ValueError(f'conditions on rv {rv_label} are disjoint from preconditions')
    del preconditioned_rvs

    # Group initial chain indicators by `rv_idx`.
    initial_chain_condition: Sequence[Indicator] = check_condition(initial_chain_condition)
    initial_chain_conditioned_rvs: MapSet[int, Indicator] = MapSet()
    for ind in initial_chain_condition:
        initial_chain_conditioned_rvs.add(ind.rv_idx, ind)

    # Check sample rvs are valid and without duplicates.
    rvs_set: Set[RandomVariable] = set(rvs)
    if not rvs_set.issubset(program_with_slotmap.rvs):
        raise ValueError('sample random variables not available')
    if len(rvs) != len(rvs_set):
        raise ValueError('duplicate sample random variables requested')

    # Check chain_pairs rvs are being sampled
    if not rvs_set.issuperset(pair[0] for pair in chain_pairs):
        raise ValueError('a random variable appears in chain_pairs but not in sample rvs')
    if not rvs_set.issuperset(pair[1] for pair in chain_pairs):
        raise ValueError('a random variable appears in chain_pairs but not in sample rvs')

    # Check chain_pairs source and destination rvs are disjoint
    if not {pair[0] for pair in chain_pairs}.isdisjoint(pair[1] for pair in chain_pairs):
        raise ValueError('chain_pairs sources and destinations are not disjoint')

    # Check no chain_pairs destination rv is a conditioned rv
    if any(pair[1].idx in conditioned_rvs.keys() for pair in chain_pairs):
        raise ValueError('a chain_pairs destination is conditioned')

    # Check chain initial conditions relate to chain_pairs destination rvs
    chain_dest_rv_idxs: Set[int] = {pair[1].idx for pair in chain_pairs}
    if not all(_rv_idx in chain_dest_rv_idxs for _rv_idx in initial_chain_conditioned_rvs.keys()):
        raise ValueError('a chain initial condition is not a chain destination rv')

    # Convert chain_pairs for registering with `sample_rvs`.
    # rv_position maps id(rv) to the position of rv in rvs (duplicates were rejected above).
    # copy_idx maps id(rv) to the position in rvs that rv is copied from for Markov chaining.
    rv_position: Dict[int, int] = {id(rv): i for i, rv in enumerate(rvs)}
    copy_idx: Dict[int, int] = {id(rv): rv_position[id(prev_rv)] for prev_rv, rv in chain_pairs}

    # Get rv state slots, rvs_slots is co-indexed with rvs
    slot_map: SlotMap = program_with_slotmap.slot_map
    rvs_slots = tuple(tuple(slot_map[ind] for ind in rv) for rv in rvs)

    sample_rvs: Sequence[SampleRV] = tuple(
        SampleRV(idx, rv, rv_slots, copy_idx.get(id(rv)))
        for idx, rv, rv_slots in zip(count(), rvs, rvs_slots)
    )

    # Process the condition to get zero and one slots
    slots_0: Set[int] = set()
    slots_1: Set[int] = set()
    for rv in program_with_slotmap.rvs:
        # Explicit conditions (merged with preconditions) take priority.
        conditioning: Optional[Set[Indicator]] = conditioned_rvs.get(rv.idx)
        if conditioning is not None:
            slots_1.update(slot_map[ind] for ind in conditioning)
            slots_0.update(slot_map[ind] for ind in rv if ind not in conditioning)
            continue

        # Otherwise use any initial Markov chain condition for the rv.
        conditioning = initial_chain_conditioned_rvs.get(rv.idx)
        if conditioning is not None:
            slots_1.update(slot_map[ind] for ind in conditioning)
            slots_0.update(slot_map[ind] for ind in rv if ind not in conditioning)
            continue

        # default: an unconditioned rv has all its indicator slots set to one.
        slots_1.update(slot_map[ind] for ind in rv)

    return SamplerInfo(
        sample_rvs=sample_rvs,
        condition=tuple(ind for condition_set in conditioned_rvs.values() for ind in condition_set),
        yield_f=yield_f,
        slots_0=slots_0,
        slots_1=slots_1,
    )
199
+
200
+
201
def uniform_random_sample(
        sample_rvs: Sequence[SampleRV],
        slots_0: Collection[int],
        slots_1: Collection[int],
        slots: NDArrayNumeric,
        state: NDArrayStates,
        rand: Random,
):
    """
    Helper for samplers.

    Overwrites `state` with a random instance and sets `slots` to agree with it.
    Each random variable's state is chosen uniformly, using `rand.randrange`,
    from the states whose slot is 1 once the conditioning slots are applied.
    """

    # Apply the conditioning to the input slots: zeros first, then ones.
    for zero_slot in slots_0:
        slots[zero_slot] = 0
    for one_slot in slots_1:
        slots[one_slot] = 1

    for sample_rv in sample_rvs:
        # Gather the (state index, slot) pairs that are currently switched on,
        # switching each one off as it is gathered.
        available = []
        for state_idx, rv_slot in enumerate(sample_rv.slots):
            if slots[rv_slot] != 1:
                continue
            slots[rv_slot] = 0
            available.append((state_idx, rv_slot))

        # Choose one of the available states and switch only its slot back on.
        pick = rand.randrange(0, len(available))
        chosen_state, chosen_slot = available[pick]
        state[sample_rv.index] = chosen_state
        slots[chosen_slot] = 1
@@ -0,0 +1,72 @@
1
+ import random
2
+ from typing import Set, List, Iterator, Optional, Sequence
3
+
4
+ import numpy as np
5
+
6
+ from ck.pgm import Instance, RandomVariable, Indicator
7
+ from ck.probability.probability_space import dtype_for_state_indexes, Condition, check_condition
8
+ from ck.utils.map_set import MapSet
9
+ from ck.utils.np_extras import DType
10
+ from ck.utils.random_extras import Random
11
+ from .sampler import Sampler
12
+ from .sampler_support import YieldF
13
+
14
+
15
class UniformSampler(Sampler):
    """
    A sampler that picks the state of every random variable independently and
    uniformly from the states permitted by the condition.
    """

    def __init__(
            self,
            rvs: RandomVariable | Sequence[RandomVariable],
            condition: Condition = (),
            rand: Random = random,
    ):
        """
        Args:
            rvs: the random variables to sample. A single RandomVariable makes
                the sampler yield state indexes; a sequence makes it yield Instances.
            condition: zero or more conditioning indicators.
            rand: the source of random numbers (`rand.randrange` is used).
        """
        checked_condition: Sequence[Indicator] = check_condition(condition)

        single_rv: bool = isinstance(rvs, RandomVariable)
        if single_rv:
            rvs = (rvs,)
        # A single rv yields a state index, a sequence of rvs yields an Instance.
        self._yield_f: YieldF = (lambda x: x.item()) if single_rv else (lambda x: x.tolist())

        super().__init__(rvs, checked_condition)

        # Collect the conditioned state indexes, keyed by `rv_idx`.
        states_by_rv: MapSet[int, int] = MapSet()
        for indicator in checked_condition:
            states_by_rv.add(indicator.rv_idx, indicator.state_idx)

        # For each sampled rv (co-indexed with self.rvs), list the states it may take:
        # the conditioned states if there are any, otherwise every state.
        possible_states: List[List[int]] = []
        for rv in self.rvs:
            allowed: Optional[Set[int]] = states_by_rv.get(rv.idx)
            if allowed is None:
                possible_states.append(list(range(len(rv))))
            else:
                possible_states.append(list(allowed))

        self._possible_states: List[List[int]] = possible_states
        self._rand: Random = rand
        self._state_dtype: DType = dtype_for_state_indexes(self.rvs)

    def __iter__(self) -> Iterator[Instance] | Iterator[int]:
        """
        Provide an unlimited series of uniformly drawn samples.
        """
        choices_per_rv = self._possible_states
        emit = self._yield_f
        rng = self._rand
        state = np.zeros(len(choices_per_rv), dtype=self._state_dtype)
        while True:
            for position, choices in enumerate(choices_per_rv):
                state[position] = choices[rng.randrange(0, len(choices))]
            # We know the yield function will always provide either ints or Instances
            # noinspection PyTypeChecker
            yield emit(state)
@@ -0,0 +1,171 @@
1
+ from typing import Collection, Iterator, Sequence
2
+
3
+ import numpy as np
4
+
5
+ from ck.pgm import Instance
6
+ from ck.probability.probability_space import dtype_for_state_indexes
7
+ from ck.program.program_buffer import ProgramBuffer
8
+ from ck.program.raw_program import RawProgram
9
+ from ck.sampling.sampler import Sampler
10
+ from ck.sampling.sampler_support import SampleRV, YieldF, SamplerInfo
11
+ from ck.utils.np_extras import NDArrayNumeric, NDArrayStates
12
+ from ck.utils.random_extras import Random
13
+
14
+
15
class WMCDirectSampler(Sampler):
    """
    A sampler that draws the state of each sample random variable in turn by
    searching over weighted model counts (WMC) computed by a program for
    different settings of the program's input slots.
    """

    def __init__(
            self,
            sampler_info: SamplerInfo,
            raw_program: RawProgram,
            rand: Random,
    ):
        """
        Args:
            sampler_info: standard sampler set-up (sample rvs, condition, yield
                function and the initial zero/one input slots).
            raw_program: the program to execute; it is wrapped in a ProgramBuffer.
                Its result is read as a single value, the WMC.
            rand: the source of random numbers (`rand.random()` is used).
        """
        super().__init__(sampler_info.rvs, sampler_info.condition)
        self._yield_f: YieldF = sampler_info.yield_f
        self._rand: Random = rand
        self._program_buffer = ProgramBuffer(raw_program)
        self._sample_rvs: Sequence[SampleRV] = tuple(sampler_info.sample_rvs)
        self._chain_rvs: Sequence[SampleRV] = tuple(
            sample_rv for sample_rv in sampler_info.sample_rvs if sample_rv.copy_index is not None)
        self._state_dtype = dtype_for_state_indexes(self.rvs)
        self._max_number_of_states: int = max((len(rv) for rv in self.rvs), default=0)
        self._slots_1: Collection[int] = sampler_info.slots_1

        # Set up the input slots to 0 or 1 to respect conditioning and initial Markov chain states.
        slots: NDArrayNumeric = self._program_buffer.vars
        for slot in sampler_info.slots_0:
            slots[slot] = 0
        for slot in sampler_info.slots_1:
            slots[slot] = 1

    def __iter__(self) -> Iterator[Instance] | Iterator[int]:
        """
        Provide an unlimited series of samples.

        Raises:
            RuntimeError: if the current conditions have zero probability, i.e. the
                WMC is not positive or a random variable has no possible state.
        """
        yield_f = self._yield_f
        rand = self._rand
        sample_rvs = self._sample_rvs
        chain_rvs = self._chain_rvs
        slots_1 = self._slots_1
        program_buffer = self._program_buffer
        slots: NDArrayNumeric = program_buffer.vars

        # Calling wmc() will give the weighted model count for the state of the current input slots.
        def wmc() -> float:
            return program_buffer.compute().item()

        # Set up working memory buffers
        states: NDArrayStates = np.zeros(len(sample_rvs), dtype=self._state_dtype)
        buff_slots = np.zeros(self._max_number_of_states, dtype=np.uintp)
        buff_states = np.zeros(self._max_number_of_states, dtype=self._state_dtype)

        # Binary search is used while more than this many candidate states remain.
        THRESHOLD = 2

        while True:
            # Consider all possible instantiations given the conditions, c, where the instantiations are ordered.
            # Let awmc(i|c) be the accumulated WMC of the ith instantiation.
            # We want to find the smallest instantiation i such that
            #    rnd <= awmc(i|c)
            # where rnd is in [0, 1) * wmc().

            total: float = wmc()
            if total <= 0:
                # Nothing can be sampled; without this check the searches below
                # would fall through and select an arbitrary state.
                raise RuntimeError('zero probability')
            rnd: float = rand.random() * total

            for sample_rv in sample_rvs:
                # Prepare to loop over random variable states.
                # Keep track of the non-zero slots in buff_slots and buff_states.
                num_possible_states: int = 0
                for j, slot in enumerate(sample_rv.slots):
                    if slots[slot] != 0:
                        buff_slots[num_possible_states] = slot
                        buff_states[num_possible_states] = j
                        num_possible_states += 1

                if num_possible_states == 0:
                    raise RuntimeError('zero probability')

                # Try each possible state of the current random variable.
                # Once a state is selected, then the following is true:
                #     states[sample_rv.index] = state
                #     slots set up to include condition rv = state.
                #     rnd is reduced to account for the states skipped.
                #
                # We can do this either by sequentially checking each state or by doing
                # a binary search. Here we start with binary search then finish sequentially
                # once the candidates size falls below 'THRESHOLD'.

                # Binary search
                lo: int = 0
                hi: int = num_possible_states
                w_0_mark: int = 0
                w: float = 0
                while lo + THRESHOLD < hi:
                    mid: int = (lo + hi) // 2

                    for i in range(mid, hi):
                        slots[buff_slots[i]] = 0

                    w = wmc()
                    w_0_mark = mid
                    if w < rnd:
                        # wmc() is too low, the desired state is >= buff_states[mid]
                        for i in range(mid, hi):
                            slots[buff_slots[i]] = 1
                        lo = mid
                    else:
                        # wmc() is too high, the desired state is < buff_states[mid]
                        hi = mid

                # Now the state we want is between lo (inclusive) and hi (exclusive).
                # Slots at least up to lo will be set to 1.

                # clear top slots, lo and up.
                for k in range(lo, num_possible_states):
                    slots[buff_slots[k]] = 0

                # Adjust rnd to account for lo > 0.
                if lo == 0:
                    # The chances of this case may be low, but if so, then
                    # slots[buff_slots[lo]] = 0 which implies wmc() == 0,
                    # so we can save a call to wmc().
                    pass
                elif w_0_mark == lo:
                    # We can use the last wmc() call, stored in w.
                    # This saves a call to wmc().
                    rnd -= w
                else:
                    rnd -= wmc()

                # Clear remaining slots
                for k in range(0, lo):
                    slots[buff_slots[k]] = 0

                # Sequential search
                fallback: int = -1  # last candidate examined that had a positive weight
                k = lo
                while k < hi:
                    slot = buff_slots[k]
                    slots[slot] = 1
                    w = wmc()
                    if rnd < w:
                        break
                    if w > 0:
                        fallback = k
                    slots[slot] = 0
                    rnd -= w
                    k += 1

                if k >= num_possible_states:
                    # Floating point rounding can leave rnd at or marginally above the
                    # accumulated weight of every candidate. buff_slots[k] would then be
                    # a stale or out-of-range entry, so use a valid candidate instead.
                    k = fallback if fallback >= 0 else num_possible_states - 1

                slot = buff_slots[k]
                state = buff_states[k]
                slots[slot] = 1
                states[sample_rv.index] = state

            # We know the yield function will always provide either ints or Instances
            # noinspection PyTypeChecker
            yield yield_f(states)

            # Reset the one slots for the next iteration.
            for slot in slots_1:
                slots[slot] = 1

            # Copy chain pairs for next iteration.
            # (This writes over any initial chain conditions from slots_1.)
            for sample_rv in chain_rvs:
                rv_slots = sample_rv.slots
                prev_state_idx: int = states.item(sample_rv.copy_index)
                for slot in rv_slots:
                    slots[slot] = 0
                slots[rv_slots[prev_state_idx]] = 1