compiled-knowledge 4.0.0a20__cp312-cp312-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of compiled-knowledge might be problematic. Click here for more details.

Files changed (178) hide show
  1. ck/__init__.py +0 -0
  2. ck/circuit/__init__.py +17 -0
  3. ck/circuit/_circuit_cy.c +37525 -0
  4. ck/circuit/_circuit_cy.cpython-312-darwin.so +0 -0
  5. ck/circuit/_circuit_cy.pxd +32 -0
  6. ck/circuit/_circuit_cy.pyx +768 -0
  7. ck/circuit/_circuit_py.py +836 -0
  8. ck/circuit/tmp_const.py +74 -0
  9. ck/circuit_compiler/__init__.py +2 -0
  10. ck/circuit_compiler/circuit_compiler.py +26 -0
  11. ck/circuit_compiler/cython_vm_compiler/__init__.py +1 -0
  12. ck/circuit_compiler/cython_vm_compiler/_compiler.c +19826 -0
  13. ck/circuit_compiler/cython_vm_compiler/_compiler.cpython-312-darwin.so +0 -0
  14. ck/circuit_compiler/cython_vm_compiler/_compiler.pyx +380 -0
  15. ck/circuit_compiler/cython_vm_compiler/cython_vm_compiler.py +121 -0
  16. ck/circuit_compiler/interpret_compiler.py +223 -0
  17. ck/circuit_compiler/llvm_compiler.py +388 -0
  18. ck/circuit_compiler/llvm_vm_compiler.py +546 -0
  19. ck/circuit_compiler/named_circuit_compilers.py +57 -0
  20. ck/circuit_compiler/support/__init__.py +0 -0
  21. ck/circuit_compiler/support/circuit_analyser/__init__.py +13 -0
  22. ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.c +10620 -0
  23. ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.cpython-312-darwin.so +0 -0
  24. ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.pyx +98 -0
  25. ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_py.py +93 -0
  26. ck/circuit_compiler/support/input_vars.py +148 -0
  27. ck/circuit_compiler/support/llvm_ir_function.py +234 -0
  28. ck/example/__init__.py +53 -0
  29. ck/example/alarm.py +366 -0
  30. ck/example/asia.py +28 -0
  31. ck/example/binary_clique.py +32 -0
  32. ck/example/bow_tie.py +33 -0
  33. ck/example/cancer.py +37 -0
  34. ck/example/chain.py +38 -0
  35. ck/example/child.py +199 -0
  36. ck/example/clique.py +33 -0
  37. ck/example/cnf_pgm.py +39 -0
  38. ck/example/diamond_square.py +68 -0
  39. ck/example/earthquake.py +36 -0
  40. ck/example/empty.py +10 -0
  41. ck/example/hailfinder.py +539 -0
  42. ck/example/hepar2.py +628 -0
  43. ck/example/insurance.py +504 -0
  44. ck/example/loop.py +40 -0
  45. ck/example/mildew.py +38161 -0
  46. ck/example/munin.py +22982 -0
  47. ck/example/pathfinder.py +53747 -0
  48. ck/example/rain.py +39 -0
  49. ck/example/rectangle.py +161 -0
  50. ck/example/run.py +30 -0
  51. ck/example/sachs.py +129 -0
  52. ck/example/sprinkler.py +30 -0
  53. ck/example/star.py +44 -0
  54. ck/example/stress.py +64 -0
  55. ck/example/student.py +43 -0
  56. ck/example/survey.py +46 -0
  57. ck/example/triangle_square.py +54 -0
  58. ck/example/truss.py +49 -0
  59. ck/in_out/__init__.py +3 -0
  60. ck/in_out/parse_ace_lmap.py +216 -0
  61. ck/in_out/parse_ace_nnf.py +322 -0
  62. ck/in_out/parse_net.py +480 -0
  63. ck/in_out/parser_utils.py +185 -0
  64. ck/in_out/pgm_pickle.py +42 -0
  65. ck/in_out/pgm_python.py +268 -0
  66. ck/in_out/render_bugs.py +111 -0
  67. ck/in_out/render_net.py +177 -0
  68. ck/in_out/render_pomegranate.py +184 -0
  69. ck/pgm.py +3475 -0
  70. ck/pgm_circuit/__init__.py +1 -0
  71. ck/pgm_circuit/marginals_program.py +352 -0
  72. ck/pgm_circuit/mpe_program.py +237 -0
  73. ck/pgm_circuit/pgm_circuit.py +79 -0
  74. ck/pgm_circuit/program_with_slotmap.py +236 -0
  75. ck/pgm_circuit/slot_map.py +35 -0
  76. ck/pgm_circuit/support/__init__.py +0 -0
  77. ck/pgm_circuit/support/compile_circuit.py +83 -0
  78. ck/pgm_circuit/target_marginals_program.py +103 -0
  79. ck/pgm_circuit/wmc_program.py +323 -0
  80. ck/pgm_compiler/__init__.py +2 -0
  81. ck/pgm_compiler/ace/__init__.py +1 -0
  82. ck/pgm_compiler/ace/ace.py +299 -0
  83. ck/pgm_compiler/factor_elimination.py +395 -0
  84. ck/pgm_compiler/named_pgm_compilers.py +63 -0
  85. ck/pgm_compiler/pgm_compiler.py +19 -0
  86. ck/pgm_compiler/recursive_conditioning.py +231 -0
  87. ck/pgm_compiler/support/__init__.py +0 -0
  88. ck/pgm_compiler/support/circuit_table/__init__.py +17 -0
  89. ck/pgm_compiler/support/circuit_table/_circuit_table_cy.c +16398 -0
  90. ck/pgm_compiler/support/circuit_table/_circuit_table_cy.cpython-312-darwin.so +0 -0
  91. ck/pgm_compiler/support/circuit_table/_circuit_table_cy.pyx +332 -0
  92. ck/pgm_compiler/support/circuit_table/_circuit_table_py.py +304 -0
  93. ck/pgm_compiler/support/clusters.py +568 -0
  94. ck/pgm_compiler/support/factor_tables.py +406 -0
  95. ck/pgm_compiler/support/join_tree.py +332 -0
  96. ck/pgm_compiler/support/named_compiler_maker.py +43 -0
  97. ck/pgm_compiler/variable_elimination.py +91 -0
  98. ck/probability/__init__.py +0 -0
  99. ck/probability/empirical_probability_space.py +50 -0
  100. ck/probability/pgm_probability_space.py +32 -0
  101. ck/probability/probability_space.py +622 -0
  102. ck/program/__init__.py +3 -0
  103. ck/program/program.py +137 -0
  104. ck/program/program_buffer.py +180 -0
  105. ck/program/raw_program.py +67 -0
  106. ck/sampling/__init__.py +0 -0
  107. ck/sampling/forward_sampler.py +211 -0
  108. ck/sampling/marginals_direct_sampler.py +113 -0
  109. ck/sampling/sampler.py +62 -0
  110. ck/sampling/sampler_support.py +232 -0
  111. ck/sampling/uniform_sampler.py +72 -0
  112. ck/sampling/wmc_direct_sampler.py +171 -0
  113. ck/sampling/wmc_gibbs_sampler.py +153 -0
  114. ck/sampling/wmc_metropolis_sampler.py +165 -0
  115. ck/sampling/wmc_rejection_sampler.py +115 -0
  116. ck/utils/__init__.py +0 -0
  117. ck/utils/iter_extras.py +163 -0
  118. ck/utils/local_config.py +270 -0
  119. ck/utils/map_list.py +128 -0
  120. ck/utils/map_set.py +128 -0
  121. ck/utils/np_extras.py +51 -0
  122. ck/utils/random_extras.py +64 -0
  123. ck/utils/tmp_dir.py +94 -0
  124. ck_demos/__init__.py +0 -0
  125. ck_demos/ace/__init__.py +0 -0
  126. ck_demos/ace/copy_ace_to_ck.py +15 -0
  127. ck_demos/ace/demo_ace.py +49 -0
  128. ck_demos/all_demos.py +88 -0
  129. ck_demos/circuit/__init__.py +0 -0
  130. ck_demos/circuit/demo_circuit_dump.py +22 -0
  131. ck_demos/circuit/demo_derivatives.py +43 -0
  132. ck_demos/circuit_compiler/__init__.py +0 -0
  133. ck_demos/circuit_compiler/compare_circuit_compilers.py +32 -0
  134. ck_demos/circuit_compiler/show_llvm_program.py +26 -0
  135. ck_demos/pgm/__init__.py +0 -0
  136. ck_demos/pgm/demo_pgm_dump.py +18 -0
  137. ck_demos/pgm/demo_pgm_dump_stress.py +18 -0
  138. ck_demos/pgm/demo_pgm_string_rendering.py +15 -0
  139. ck_demos/pgm/show_examples.py +25 -0
  140. ck_demos/pgm_compiler/__init__.py +0 -0
  141. ck_demos/pgm_compiler/compare_pgm_compilers.py +63 -0
  142. ck_demos/pgm_compiler/demo_compiler_dump.py +60 -0
  143. ck_demos/pgm_compiler/demo_factor_elimination.py +47 -0
  144. ck_demos/pgm_compiler/demo_join_tree.py +25 -0
  145. ck_demos/pgm_compiler/demo_marginals_program.py +53 -0
  146. ck_demos/pgm_compiler/demo_mpe_program.py +55 -0
  147. ck_demos/pgm_compiler/demo_pgm_compiler.py +38 -0
  148. ck_demos/pgm_compiler/demo_recursive_conditioning.py +33 -0
  149. ck_demos/pgm_compiler/demo_variable_elimination.py +33 -0
  150. ck_demos/pgm_compiler/demo_wmc_program.py +29 -0
  151. ck_demos/pgm_compiler/time_fe_compiler.py +93 -0
  152. ck_demos/pgm_inference/__init__.py +0 -0
  153. ck_demos/pgm_inference/demo_inferencing_basic.py +188 -0
  154. ck_demos/pgm_inference/demo_inferencing_mpe_cancer.py +45 -0
  155. ck_demos/pgm_inference/demo_inferencing_wmc_and_mpe_sprinkler.py +154 -0
  156. ck_demos/pgm_inference/demo_inferencing_wmc_student.py +110 -0
  157. ck_demos/programs/__init__.py +0 -0
  158. ck_demos/programs/demo_program_buffer.py +24 -0
  159. ck_demos/programs/demo_program_multi.py +24 -0
  160. ck_demos/programs/demo_program_none.py +19 -0
  161. ck_demos/programs/demo_program_single.py +23 -0
  162. ck_demos/programs/demo_raw_program_interpreted.py +21 -0
  163. ck_demos/programs/demo_raw_program_llvm.py +21 -0
  164. ck_demos/sampling/__init__.py +0 -0
  165. ck_demos/sampling/check_sampler.py +71 -0
  166. ck_demos/sampling/demo_marginal_direct_sampler.py +40 -0
  167. ck_demos/sampling/demo_uniform_sampler.py +38 -0
  168. ck_demos/sampling/demo_wmc_direct_sampler.py +40 -0
  169. ck_demos/utils/__init__.py +0 -0
  170. ck_demos/utils/compare.py +120 -0
  171. ck_demos/utils/convert_network.py +45 -0
  172. ck_demos/utils/sample_model.py +216 -0
  173. ck_demos/utils/stop_watch.py +384 -0
  174. compiled_knowledge-4.0.0a20.dist-info/METADATA +50 -0
  175. compiled_knowledge-4.0.0a20.dist-info/RECORD +178 -0
  176. compiled_knowledge-4.0.0a20.dist-info/WHEEL +6 -0
  177. compiled_knowledge-4.0.0a20.dist-info/licenses/LICENSE.txt +21 -0
  178. compiled_knowledge-4.0.0a20.dist-info/top_level.txt +2 -0
@@ -0,0 +1,153 @@
1
+ from typing import Collection, Iterator, Sequence, List
2
+
3
+ import numpy as np
4
+
5
+ from ck.pgm import Instance
6
+ from ck.probability.probability_space import dtype_for_state_indexes
7
+ from ck.program.program_buffer import ProgramBuffer
8
+ from ck.program.raw_program import RawProgram
9
+ from ck.sampling.sampler import Sampler
10
+ from ck.sampling.sampler_support import SampleRV, YieldF, uniform_random_sample, SamplerInfo
11
+ from ck.utils.np_extras import NDArrayStates, NDArrayFloat64
12
+ from ck.utils.random_extras import Random, random_permute
13
+
14
+
15
class WMCGibbsSampler(Sampler):
    """
    A sampler that produces states by Gibbs-style sweeps over a program.

    One sweep shuffles the sample random variables, then visits each one in turn:
    the variable's current input slot is switched off, the program is evaluated
    once for each permitted slot of that variable, and a new state is drawn in
    proportion to the values obtained.
    """

    def __init__(
            self,
            sampler_info: SamplerInfo,
            raw_program: RawProgram,
            rand: Random,
            skip: int,
            burn_in: int,
            pr_restart: float,
    ):
        """
        Args:
            sampler_info: supplies the rvs, condition, yield function, sample rvs
                and the `slots_0` / `slots_1` slot collections.
            raw_program: the program to evaluate; wrapped in a private ProgramBuffer.
            rand: the random number generator used for every random choice.
            skip: number of sweeps discarded before each yielded sample.
            burn_in: number of sweeps run after each (re)initialisation of the state.
            pr_restart: if greater than zero, the chance, tested after every yielded
                sample, of re-initialising the state and burning in again.
        """
        super().__init__(sampler_info.rvs, sampler_info.condition)
        self._yield_f: YieldF = sampler_info.yield_f
        self._rand: Random = rand
        self._program_buffer = ProgramBuffer(raw_program)
        # A private list, because each sweep shuffles it in place.
        self._sample_rvs: List[SampleRV] = list(sampler_info.sample_rvs)
        self._state_dtype = dtype_for_state_indexes(self.rvs)
        self._slots_0: Collection[int] = sampler_info.slots_0
        self._slots_1: Collection[int] = sampler_info.slots_1
        self._skip: int = skip
        self._burn_in: int = burn_in
        self._pr_restart: float = pr_restart

    def __iter__(self) -> Iterator[Instance] | Iterator[int]:
        """
        Yield an unbounded stream of samples.

        Every yielded value is whatever the sampler's yield function returns for
        the current state array.
        """
        rvs_to_sample: List[SampleRV] = self._sample_rvs
        rng: Random = self._rand
        emit: YieldF = self._yield_f
        zero_slots: Collection[int] = self._slots_0
        one_slots: Collection[int] = self._slots_1
        buffer: ProgramBuffer = self._program_buffer
        skip: int = self._skip
        burn_in: int = self._burn_in
        pr_restart: float = self._pr_restart
        restart_possible: bool = not (pr_restart <= 0)

        # Working memory: the current state of each rv, and one
        # weight array per rv (one entry per slot of that rv).
        state = np.zeros(len(rvs_to_sample), dtype=self._state_dtype)
        prs: Sequence[NDArrayFloat64] = tuple(
            np.zeros(len(rv.slots), dtype=np.float64)
            for rv in rvs_to_sample
        )

        def sweep(count: int) -> None:
            # Run `count` Gibbs sweeps (none when count is not positive).
            for _ in range(count):
                self._next_sample_gibbs(rvs_to_sample, one_slots, buffer, prs, state, rng)

        def restart() -> None:
            # Draw a fresh initial state, then run the burn in.
            uniform_random_sample(rvs_to_sample, zero_slots, one_slots, buffer.vars, state, rng)
            sweep(burn_in)

        restart()
        while True:
            # Discard `skip` sweeps, then keep the result of one more.
            sweep(skip)
            self._next_sample_gibbs(rvs_to_sample, one_slots, buffer, prs, state, rng)
            # We know the yield function will always provide either ints or Instances
            # noinspection PyTypeChecker
            yield emit(state)
            # The restart test only draws a random number when a restart is possible.
            if restart_possible and rng.random() < pr_restart:
                restart()

    @staticmethod
    def _next_sample_gibbs(
            sample_rvs: List[SampleRV],
            slots_1: Collection[int],
            program_buffer: ProgramBuffer,
            prs: Sequence[NDArrayFloat64],
            state: NDArrayStates,
            rand: Random
    ) -> None:
        """
        Perform one sweep: resample every rv once, in a freshly shuffled order,
        updating `state` and the program inputs in place.

        Args:
            sample_rvs: the rvs to resample; this list is permuted in place.
            slots_1: only slots in this collection are candidates for an rv.
            program_buffer: the program to evaluate and whose inputs are updated.
            prs: per-rv weight arrays, indexed by `sample_rv.index`.
            state: per-rv state array, indexed by `sample_rv.index`.
            rand: the random number generator.
        """
        inputs = program_buffer.vars
        random_permute(sample_rvs, rand=rand)
        for sample_rv in sample_rvs:
            rv_slots = sample_rv.slots
            index = sample_rv.index

            weights: NDArrayFloat64 = prs[index]
            current_state: int = state.item(index)

            # The (state, slot) pairs this rv is permitted to take.
            candidates = [
                (slot_state, slot)
                for slot_state, slot in enumerate(rv_slots)
                if slot in slots_1
            ]
            assert len(candidates) > 0

            # Evaluate the program once per candidate, with only that
            # candidate's slot switched on for this rv.
            inputs[rv_slots[current_state]] = 0
            for slot_state, slot in candidates:
                inputs[slot] = 1
                weights[slot_state] = program_buffer.compute()
                inputs[slot] = 0

            total = np.sum(weights)
            if total == 0.0:
                # Every candidate evaluated to zero, so fall back
                # to a uniform choice among the candidates.
                chosen_state, chosen_slot = candidates[rand.randrange(0, len(candidates))]
            else:
                # Walk the candidates, spending `remaining` until one absorbs it.
                # If none does, the last candidate is the one chosen.
                remaining = rand.random() * total
                chosen_state = None
                chosen_slot = None
                for chosen_state, chosen_slot in candidates:
                    if remaining <= weights[chosen_state]:
                        break
                    remaining -= weights[chosen_state]

            # Record the choice in the state array and in the program inputs.
            state[index] = chosen_state
            inputs[chosen_slot] = 1
@@ -0,0 +1,165 @@
1
+ from typing import Collection, Iterator, Sequence
2
+
3
+ import numpy as np
4
+
5
+ from ck.pgm import Instance
6
+ from ck.probability.probability_space import dtype_for_state_indexes
7
+ from ck.program.program_buffer import ProgramBuffer
8
+ from ck.program.raw_program import RawProgram
9
+ from ck.sampling.sampler import Sampler
10
+ from ck.sampling.sampler_support import SampleRV, YieldF, uniform_random_sample, SamplerInfo
11
+ from ck.utils.np_extras import NDArrayStates, DTypeStates
12
+ from ck.utils.random_extras import Random
13
+
14
+
15
class WMCMetropolisSampler(Sampler):
    """
    A sampler that produces states with single-variable Metropolis proposals.

    Each step picks one sample rv and one of its permitted states at random,
    evaluates the program for the proposed state, and accepts the proposal
    when `rand.random() * current_weight < proposed_weight`.
    """

    def __init__(
            self,
            sampler_info: SamplerInfo,
            raw_program: RawProgram,
            rand: Random,
            skip: int,
            burn_in: int,
            pr_restart: float,
    ):
        """
        Args:
            sampler_info: supplies the rvs, condition, yield function, sample rvs
                and the `slots_0` / `slots_1` slot collections.
            raw_program: the program to evaluate; wrapped in a private ProgramBuffer.
            rand: the random number generator used for every random choice.
            skip: number of steps discarded before each yielded sample.
            burn_in: number of steps run after each (re)initialisation of the state.
            pr_restart: if greater than zero, the chance, tested after every yielded
                sample, of re-initialising the state and burning in again.
        """
        super().__init__(sampler_info.rvs, sampler_info.condition)
        self._yield_f: YieldF = sampler_info.yield_f
        self._rand: Random = rand
        self._program_buffer = ProgramBuffer(raw_program)
        self._sample_rvs: Sequence[SampleRV] = tuple(sampler_info.sample_rvs)
        self._state_dtype: DTypeStates = dtype_for_state_indexes(self.rvs)
        self._slots_0: Collection[int] = sampler_info.slots_0
        self._slots_1: Collection[int] = sampler_info.slots_1
        self._skip: int = skip
        self._burn_in: int = burn_in
        self._pr_restart: float = pr_restart

    def __iter__(self) -> Iterator[Instance] | Iterator[int]:
        """
        Yield an unbounded stream of samples.

        Every yielded value is whatever the sampler's yield function returns for
        the current state array.
        """
        rvs_to_sample = self._sample_rvs
        rng = self._rand
        emit = self._yield_f
        buffer = self._program_buffer
        slots = buffer.vars
        skip = self._skip
        burn_in = self._burn_in
        pr_restart = self._pr_restart
        restart_possible = not (pr_restart <= 0)

        # Working memory: the current state of each rv.
        state: NDArrayStates = np.zeros(len(rvs_to_sample), dtype=self._state_dtype)

        # Write the conditioning into the input slots.
        for slot in self._slots_0:
            slots[slot] = 0
        for slot in self._slots_1:
            slots[slot] = 1

        # For each rv: its index, its slots, and the (state, slot)
        # pairs whose input slot is currently switched on.
        possibles = []
        for sample_rv in rvs_to_sample:
            rv_possibles = [
                (slot_state, slot)
                for slot_state, slot in enumerate(sample_rv.slots)
                if slots[slot] == 1
            ]
            possibles.append((sample_rv.index, sample_rv.slots, rv_possibles))

        def advance(weight: float, count: int) -> float:
            # Run `count` Metropolis steps (none when count is not positive).
            for _ in range(count):
                weight = self._next_sample_metropolis(possibles, buffer, state, weight, rng)
            return weight

        def restart() -> float:
            # Draw a fresh initial state, then run the burn in.
            return advance(self._init_sample_metropolis(state), burn_in)

        w: float = restart()
        while True:
            # Discard `skip` steps, then keep the result of one more.
            w = advance(w, skip)
            w = self._next_sample_metropolis(possibles, buffer, state, w, rng)
            # We know the yield function will always provide either ints or Instances
            # noinspection PyTypeChecker
            yield emit(state)
            # The restart test only draws a random number when a restart is possible.
            if restart_possible and rng.random() < pr_restart:
                w = restart()

    def _init_sample_metropolis(self, state: NDArrayStates) -> float:
        """
        Fill `state` with a uniformly drawn system, configure the program inputs
        to match, and return the program value for that system.

        The draw is repeated until the program value is `>= 0`.
        """
        rvs_to_sample = self._sample_rvs
        rng = self._rand
        zero_slots = self._slots_0
        one_slots = self._slots_1
        buffer = self._program_buffer
        inputs = buffer.vars

        while True:
            uniform_random_sample(rvs_to_sample, zero_slots, one_slots, inputs, state, rng)
            weight: float = buffer.compute().item()
            # NOTE(review): a weight of exactly zero passes this test; confirm
            # whether zero-weight starting states are intended to be accepted.
            if weight >= 0:
                return weight

    @staticmethod
    def _next_sample_metropolis(
            possibles,
            program_buffer: ProgramBuffer,
            state,
            cur_w: float,
            rand: Random,
    ) -> float:
        """
        Propose a new state for one randomly chosen rv and accept or reject it.

        Args:
            possibles: one `(index, slots, [(state, slot), ...])` entry per rv.
            program_buffer: the program to evaluate and whose inputs are updated.
            state: per-rv state array; updated in place when the proposal is accepted.
            cur_w: the program value for the current state.
            rand: the random number generator.

        Returns:
            the program value for the state held on return: the proposed
            state's value if accepted, otherwise `cur_w`.
        """
        inputs = program_buffer.vars

        # Choose the rv to change, and note the slot of its current state.
        idx, rv_slots, rv_possibles = possibles[rand.randrange(0, len(possibles))]
        old_slot = rv_slots[state[idx]]

        # Choose the proposed state from those permitted for the rv.
        proposed_state, proposed_slot = rv_possibles[rand.randrange(0, len(rv_possibles))]

        # Point the program inputs at the proposal and evaluate it.
        inputs[old_slot] = 0
        inputs[proposed_slot] = 1
        new_w: float = program_buffer.compute().item()

        if rand.random() * cur_w < new_w:
            # Accepted: the inputs already match, so only the state needs recording.
            state[idx] = proposed_state
            return new_w

        # Rejected: put the program inputs back as they were.
        inputs[proposed_slot] = 0
        inputs[old_slot] = 1
        return cur_w
@@ -0,0 +1,115 @@
1
+ from typing import Collection, Iterator, Sequence
2
+
3
+ import numpy as np
4
+
5
+ from ck.pgm import Instance
6
+ from ck.probability.probability_space import dtype_for_state_indexes
7
+ from ck.program.program_buffer import ProgramBuffer
8
+ from ck.program.raw_program import RawProgram
9
+ from ck.sampling.sampler import Sampler
10
+ from ck.sampling.sampler_support import SampleRV, YieldF, uniform_random_sample, SamplerInfo
11
+ from ck.utils.np_extras import NDArrayNumeric
12
+ from ck.utils.random_extras import Random
13
+
14
+
15
+ class WMCRejectionSampler(Sampler):  # Rejection sampler: draws uniform random states and yields one when rand.random() * w_max < w.
16
+
17
+ def __init__(
18
+ self,
19
+ sampler_info: SamplerInfo,  # supplies rvs, condition, yield_f, sample_rvs, slots_0 and slots_1
20
+ raw_program: RawProgram,  # wrapped in a private ProgramBuffer below
21
+ rand: Random,  # random number generator used for sampling and for the acceptance test
22
+ z: float,  # starting value of _w_not_seen; presumably the total weight over all states - TODO confirm
23
+ ):
24
+ super().__init__(sampler_info.rvs, sampler_info.condition)
25
+ self._yield_f: YieldF = sampler_info.yield_f
26
+ self._rand: Random = rand
27
+ self._program_buffer = ProgramBuffer(raw_program)
28
+ self._sample_rvs: Sequence[SampleRV] = tuple(sampler_info.sample_rvs)
29
+ self._state_dtype = dtype_for_state_indexes(self.rvs)
30
+ self._slots_0: Collection[int] = sampler_info.slots_0
31
+ self._slots_1: Collection[int] = sampler_info.slots_1
32
+
33
+ # Initialise the fields used for tracking and adapting w_max.
34
+ self._w_max = None # estimated maximum weight for any one world; None until the first call of __iter__
35
+ self._w_not_seen = z # z minus the weight of every distinct state recorded in _samples
36
+ self._w_high = 0.0 # highest instance weight seen so far
37
+ self._samples = set() # distinct states recorded so far; replaced by None once tracking stops
38
+
39
+ def __iter__(self) -> Iterator[Instance] | Iterator[int]:  # unbounded stream of yield_f(state) for accepted states
40
+ sample_rvs = self._sample_rvs
41
+ rand = self._rand
42
+ yield_f = self._yield_f
43
+ slots_0 = self._slots_0
44
+ slots_1 = self._slots_1
45
+ program_buffer = self._program_buffer
46
+ slots: NDArrayNumeric = program_buffer.vars
47
+
48
+ # Calling wmc() evaluates the program for the current state of the input slots and returns the result as a Python scalar.
49
+ def wmc() -> float:
50
+ return program_buffer.compute().item()
51
+
52
+ # Allocate working memory to store a possible world (one state index per sample rv).
53
+ state: NDArrayNumeric = np.zeros(len(sample_rvs), dtype=self._state_dtype)
54
+
55
+ # Initialise w_max to w_max_marginal, if not done yet (state is kept on self, so this happens once per sampler).
56
+ if self._w_max is None:
57
+ w_max_marginal = self._w_not_seen # initially set to z, so a 'large' weight
58
+
59
+ # Set the conditioned input slots: everything in slots_0 to 0 and everything in slots_1 to 1.
60
+ for slot in slots_0:
61
+ slots[slot] = 0
62
+ for slot in slots_1:
63
+ slots[slot] = 1
64
+
65
+ # Loop over the rvs, finding for each the largest program value over its permitted states.
66
+ for sample_rv in sample_rvs:
67
+ rv_slots = sample_rv.slots
68
+ max_for_rv = 0
69
+ # Set all rv slots to 0
70
+ for slot_state, slot in enumerate(rv_slots):
71
+ slots[slot] = 0
72
+ back_to_one = []  # slots of this rv that are in slots_1 and must be restored afterwards
73
+ # Loop over the states of the rv; only slots in slots_1 are evaluated.
74
+ for slot_state, slot in enumerate(rv_slots):
75
+ if slot in slots_1:
76
+ slots[slot] = 1
77
+ w: float = wmc()
78
+ max_for_rv = max(max_for_rv, w)
79
+ slots[slot] = 0
80
+ back_to_one.append(slot)
81
+ # Set rv slots back to 1 as needed (ready for next rv).
82
+ for slot in back_to_one:
83
+ slots[slot] = 1
84
+
85
+ w_max_marginal = min(w_max_marginal, max_for_rv)  # keep the smallest per-rv maximum found
86
+
87
+ self._w_max = w_max_marginal
88
+
89
+ while True:
90
+ uniform_random_sample(sample_rvs, slots_0, slots_1, slots, state, rand)
91
+ w: float = wmc()
92
+
93
+ if rand.random() * self._w_max < w:  # acceptance test against the current w_max estimate
94
+ # We know the yield function will always provide either ints or Instances
95
+ # noinspection PyTypeChecker
96
+ yield yield_f(state)
97
+
98
+ # Update w_not_seen and w_high to adapt w_max.
99
+ # We don't bother tracking seen samples once w_not_seen and w_high
100
+ # are close enough, or we have tracked too many samples.
101
+ if self._samples is not None:  # NOTE(review): indentation is lost in this rendering; confirm whether tracking runs for every draw or only for accepted ones
102
+ s = tuple(state)
103
+ if s not in self._samples:
104
+ self._samples.add(s)
105
+ self._w_not_seen -= w
106
+ self._w_high = max(self._w_high, w)
107
+ w_max_tracked = max(self._w_high, self._w_not_seen)
108
+ self._w_max = min(w_max_tracked, self._w_max)  # the estimate is only ever lowered here
109
+
110
+ # See if we should stop tracking samples.
111
+ if (
112
+ self._w_not_seen - self._w_high < 0.001 # w_not_seen and w_high are close enough
113
+ or len(self._samples) > 1000000 # tracked too many samples
114
+ ):
115
+ self._samples = None
ck/utils/__init__.py ADDED
File without changes
@@ -0,0 +1,163 @@
1
+ """
2
+ A module with extra iteration functions.
3
+ """
4
+ from functools import reduce as _reduce
5
+ from itertools import combinations, chain, islice
6
+ from operator import mul as _mul
7
+ from typing import Iterable, Tuple, Sequence, TypeVar
8
+
9
+ _T = TypeVar('_T')  # Generic element type used in the annotations of the helpers below.
10
+
11
+
12
def flatten(iterables: Iterable[Iterable[_T]]) -> Iterable[_T]:
    """
    Lazily yield every element of every inner iterable, in order.

    Only one level of nesting is removed.
    """
    return chain.from_iterable(iterables)
17
+
18
+
19
def deep_flatten(iterables: Iterable) -> Iterable:
    """
    Lazily yield the leaves of arbitrarily nested iterables, depth first.

    Strings are treated as leaves, not as iterables of characters.
    """
    for item in iterables:
        is_leaf = isinstance(item, str) or not isinstance(item, Iterable)
        if is_leaf:
            yield item
        else:
            yield from deep_flatten(item)
29
+
30
+
31
def combos(list_of_lists: Sequence[Sequence[_T]], flip=False) -> Iterable[Tuple[_T, ...]]:
    """
    Iterate over all combinations of taking one element from each of the lists.

    The order of results has the first element changing most rapidly.
    For example, given [[1,2,3],[4,5],[6,7]], combos yields the following:
    (1,4,6), (2,4,6), (3,4,6), (1,5,6), (2,5,6), (3,5,6),
    (1,4,7), (2,4,7), (3,4,7), (1,5,7), (2,5,7), (3,5,7).

    If flip, then the last changes most rapidly.

    Given no lists at all, the single empty combination () is yielded.
    If any of the lists is empty then there is nothing to choose from
    that list, so no combinations are yielded.
    """
    num = len(list_of_lists)
    if num == 0:
        yield ()
        return
    if any(len(choices) == 0 for choices in list_of_lists):
        # Without this guard the first yield below would index
        # into an empty list and raise an IndexError.
        return
    rng = range(num)
    indexes = [0] * num
    # `start` is the fastest-changing position, `inc` moves towards
    # slower positions, and `end` is one step past the slowest position.
    if flip:
        start = num - 1
        inc = -1
        end = -1
    else:
        start = 0
        inc = 1
        end = num
    while True:
        yield tuple(list_of_lists[i][indexes[i]] for i in rng)
        # Advance the indexes like an odometer, carrying on overflow.
        i = start
        while True:
            indexes[i] += 1
            if indexes[i] < len(list_of_lists[i]):
                break
            indexes[i] = 0
            i += inc
            if i == end:
                # Carried past the slowest position: all combinations done.
                return
67
+
68
+
69
def combos_ranges(list_of_lens: Sequence[int], flip=False) -> Iterable[Tuple[int, ...]]:
    """
    Equivalent to combos([range(l) for l in list_of_lens], flip).

    The order of results has the first element changing most rapidly.
    If flip, then the last changes most rapidly.

    Given no lengths at all, the single empty combination () is yielded.
    If any length is zero (or negative) the corresponding range is empty,
    so no combinations are yielded.
    """
    num = len(list_of_lens)
    if num == 0:
        yield ()
        return
    if any(length <= 0 for length in list_of_lens):
        # Without this guard, tuples holding index 0 for an empty
        # range would be yielded, even though 0 is not in range(0).
        return
    indexes = [0] * num
    # `start` is the fastest-changing position, `inc` moves towards
    # slower positions, and `end` is one step past the slowest position.
    if flip:
        start = num - 1
        inc = -1
        end = -1
    else:
        start = 0
        inc = 1
        end = num
    while True:
        yield tuple(indexes)
        # Advance the indexes like an odometer, carrying on overflow.
        i = start
        while True:
            indexes[i] += 1
            if indexes[i] < list_of_lens[i]:
                break
            indexes[i] = 0
            i += inc
            if i == end:
                # Carried past the slowest position: all combinations done.
                return
100
+
101
+
102
def pairs(elements: Iterable[_T]) -> Iterable[Tuple[_T, _T]]:
    """
    Lazily yield every unordered pair of the given elements.

    Pairs are made by position, so each pair keeps the order
    in which its two elements appear in the input.
    """
    pair_size = 2
    return combinations(elements, pair_size)
107
+
108
+
109
def sequential_pairs(elements: Sequence[_T]) -> Iterable[Tuple[_T, _T]]:
    """
    Lazily yield each element paired with its immediate successor.

    A sequence with fewer than two elements yields nothing.
    """
    for right in range(1, len(elements)):
        left = right - 1
        yield elements[left], elements[right]
115
+
116
+
117
def powerset(iterable: Iterable[_T], min_size: int = 0, max_size: int = None) -> Iterable[Tuple[_T, ...]]:
    """
    Lazily yield the subsets of the given elements, smallest subsets first.

    powerset([1,2,3]) --> () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)

    Only subsets with at least `min_size` and at most `max_size` elements are
    produced. A `min_size` of None means 0, and a `max_size` of None means
    the number of elements.
    """
    # Anything other than a list or tuple is materialised first.
    pool = iterable if isinstance(iterable, (list, tuple)) else list(iterable)
    smallest = 0 if min_size is None else min_size
    largest = len(pool) if max_size is None else max_size
    sizes = range(smallest, largest + 1)
    return chain.from_iterable(map(lambda size: combinations(pool, size), sizes))
131
+
132
+
133
def unzip(xs: Iterable[Tuple[_T]]) -> Tuple[Iterable[_T]]:
    """
    Undo a zip: regroup an iterable of tuples into one tuple per position.

    E.g., a, b, c = unzip(zip(a, b, c))

    The recovered `a`, `b` and `c` are tuples whatever type the originals
    had; only their contents, order and length carry over. The value
    returned is itself the lazy iterator produced by `zip`.
    """
    columns = zip(*xs)
    return columns
143
+
144
+
145
def multiply(items: Iterable[_T], initial: _T = 1) -> _T:
    """
    Multiply all the given items together, starting from `initial`.

    With no items, `initial` itself is returned.
    """
    product = initial
    for item in items:
        product = product * item
    return product
150
+
151
+
152
def first(items: Iterable[_T]) -> _T:
    """
    Return the element that iterating over `items` produces first.

    An empty iterable results in StopIteration being raised.
    """
    iterator = iter(items)
    return next(iterator)
157
+
158
+
159
def take(iterable: Iterable[_T], n: int) -> Iterable[_T]:
    """
    Lazily yield at most the first `n` elements of the iterable.
    """
    start = 0
    return islice(iterable, start, n)