compiled-knowledge 4.0.0a20__cp312-cp312-win32.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of compiled-knowledge might be problematic. Click here for more details.
- ck/__init__.py +0 -0
- ck/circuit/__init__.py +17 -0
- ck/circuit/_circuit_cy.c +37523 -0
- ck/circuit/_circuit_cy.cp312-win32.pyd +0 -0
- ck/circuit/_circuit_cy.pxd +32 -0
- ck/circuit/_circuit_cy.pyx +768 -0
- ck/circuit/_circuit_py.py +836 -0
- ck/circuit/tmp_const.py +74 -0
- ck/circuit_compiler/__init__.py +2 -0
- ck/circuit_compiler/circuit_compiler.py +26 -0
- ck/circuit_compiler/cython_vm_compiler/__init__.py +1 -0
- ck/circuit_compiler/cython_vm_compiler/_compiler.c +19824 -0
- ck/circuit_compiler/cython_vm_compiler/_compiler.cp312-win32.pyd +0 -0
- ck/circuit_compiler/cython_vm_compiler/_compiler.pyx +380 -0
- ck/circuit_compiler/cython_vm_compiler/cython_vm_compiler.py +121 -0
- ck/circuit_compiler/interpret_compiler.py +223 -0
- ck/circuit_compiler/llvm_compiler.py +388 -0
- ck/circuit_compiler/llvm_vm_compiler.py +546 -0
- ck/circuit_compiler/named_circuit_compilers.py +57 -0
- ck/circuit_compiler/support/__init__.py +0 -0
- ck/circuit_compiler/support/circuit_analyser/__init__.py +13 -0
- ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.c +10618 -0
- ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.cp312-win32.pyd +0 -0
- ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.pyx +98 -0
- ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_py.py +93 -0
- ck/circuit_compiler/support/input_vars.py +148 -0
- ck/circuit_compiler/support/llvm_ir_function.py +234 -0
- ck/example/__init__.py +53 -0
- ck/example/alarm.py +366 -0
- ck/example/asia.py +28 -0
- ck/example/binary_clique.py +32 -0
- ck/example/bow_tie.py +33 -0
- ck/example/cancer.py +37 -0
- ck/example/chain.py +38 -0
- ck/example/child.py +199 -0
- ck/example/clique.py +33 -0
- ck/example/cnf_pgm.py +39 -0
- ck/example/diamond_square.py +68 -0
- ck/example/earthquake.py +36 -0
- ck/example/empty.py +10 -0
- ck/example/hailfinder.py +539 -0
- ck/example/hepar2.py +628 -0
- ck/example/insurance.py +504 -0
- ck/example/loop.py +40 -0
- ck/example/mildew.py +38161 -0
- ck/example/munin.py +22982 -0
- ck/example/pathfinder.py +53747 -0
- ck/example/rain.py +39 -0
- ck/example/rectangle.py +161 -0
- ck/example/run.py +30 -0
- ck/example/sachs.py +129 -0
- ck/example/sprinkler.py +30 -0
- ck/example/star.py +44 -0
- ck/example/stress.py +64 -0
- ck/example/student.py +43 -0
- ck/example/survey.py +46 -0
- ck/example/triangle_square.py +54 -0
- ck/example/truss.py +49 -0
- ck/in_out/__init__.py +3 -0
- ck/in_out/parse_ace_lmap.py +216 -0
- ck/in_out/parse_ace_nnf.py +322 -0
- ck/in_out/parse_net.py +480 -0
- ck/in_out/parser_utils.py +185 -0
- ck/in_out/pgm_pickle.py +42 -0
- ck/in_out/pgm_python.py +268 -0
- ck/in_out/render_bugs.py +111 -0
- ck/in_out/render_net.py +177 -0
- ck/in_out/render_pomegranate.py +184 -0
- ck/pgm.py +3475 -0
- ck/pgm_circuit/__init__.py +1 -0
- ck/pgm_circuit/marginals_program.py +352 -0
- ck/pgm_circuit/mpe_program.py +237 -0
- ck/pgm_circuit/pgm_circuit.py +79 -0
- ck/pgm_circuit/program_with_slotmap.py +236 -0
- ck/pgm_circuit/slot_map.py +35 -0
- ck/pgm_circuit/support/__init__.py +0 -0
- ck/pgm_circuit/support/compile_circuit.py +83 -0
- ck/pgm_circuit/target_marginals_program.py +103 -0
- ck/pgm_circuit/wmc_program.py +323 -0
- ck/pgm_compiler/__init__.py +2 -0
- ck/pgm_compiler/ace/__init__.py +1 -0
- ck/pgm_compiler/ace/ace.py +299 -0
- ck/pgm_compiler/factor_elimination.py +395 -0
- ck/pgm_compiler/named_pgm_compilers.py +63 -0
- ck/pgm_compiler/pgm_compiler.py +19 -0
- ck/pgm_compiler/recursive_conditioning.py +231 -0
- ck/pgm_compiler/support/__init__.py +0 -0
- ck/pgm_compiler/support/circuit_table/__init__.py +17 -0
- ck/pgm_compiler/support/circuit_table/_circuit_table_cy.c +16396 -0
- ck/pgm_compiler/support/circuit_table/_circuit_table_cy.cp312-win32.pyd +0 -0
- ck/pgm_compiler/support/circuit_table/_circuit_table_cy.pyx +332 -0
- ck/pgm_compiler/support/circuit_table/_circuit_table_py.py +304 -0
- ck/pgm_compiler/support/clusters.py +568 -0
- ck/pgm_compiler/support/factor_tables.py +406 -0
- ck/pgm_compiler/support/join_tree.py +332 -0
- ck/pgm_compiler/support/named_compiler_maker.py +43 -0
- ck/pgm_compiler/variable_elimination.py +91 -0
- ck/probability/__init__.py +0 -0
- ck/probability/empirical_probability_space.py +50 -0
- ck/probability/pgm_probability_space.py +32 -0
- ck/probability/probability_space.py +622 -0
- ck/program/__init__.py +3 -0
- ck/program/program.py +137 -0
- ck/program/program_buffer.py +180 -0
- ck/program/raw_program.py +67 -0
- ck/sampling/__init__.py +0 -0
- ck/sampling/forward_sampler.py +211 -0
- ck/sampling/marginals_direct_sampler.py +113 -0
- ck/sampling/sampler.py +62 -0
- ck/sampling/sampler_support.py +232 -0
- ck/sampling/uniform_sampler.py +72 -0
- ck/sampling/wmc_direct_sampler.py +171 -0
- ck/sampling/wmc_gibbs_sampler.py +153 -0
- ck/sampling/wmc_metropolis_sampler.py +165 -0
- ck/sampling/wmc_rejection_sampler.py +115 -0
- ck/utils/__init__.py +0 -0
- ck/utils/iter_extras.py +163 -0
- ck/utils/local_config.py +270 -0
- ck/utils/map_list.py +128 -0
- ck/utils/map_set.py +128 -0
- ck/utils/np_extras.py +51 -0
- ck/utils/random_extras.py +64 -0
- ck/utils/tmp_dir.py +94 -0
- ck_demos/__init__.py +0 -0
- ck_demos/ace/__init__.py +0 -0
- ck_demos/ace/copy_ace_to_ck.py +15 -0
- ck_demos/ace/demo_ace.py +49 -0
- ck_demos/all_demos.py +88 -0
- ck_demos/circuit/__init__.py +0 -0
- ck_demos/circuit/demo_circuit_dump.py +22 -0
- ck_demos/circuit/demo_derivatives.py +43 -0
- ck_demos/circuit_compiler/__init__.py +0 -0
- ck_demos/circuit_compiler/compare_circuit_compilers.py +32 -0
- ck_demos/circuit_compiler/show_llvm_program.py +26 -0
- ck_demos/pgm/__init__.py +0 -0
- ck_demos/pgm/demo_pgm_dump.py +18 -0
- ck_demos/pgm/demo_pgm_dump_stress.py +18 -0
- ck_demos/pgm/demo_pgm_string_rendering.py +15 -0
- ck_demos/pgm/show_examples.py +25 -0
- ck_demos/pgm_compiler/__init__.py +0 -0
- ck_demos/pgm_compiler/compare_pgm_compilers.py +63 -0
- ck_demos/pgm_compiler/demo_compiler_dump.py +60 -0
- ck_demos/pgm_compiler/demo_factor_elimination.py +47 -0
- ck_demos/pgm_compiler/demo_join_tree.py +25 -0
- ck_demos/pgm_compiler/demo_marginals_program.py +53 -0
- ck_demos/pgm_compiler/demo_mpe_program.py +55 -0
- ck_demos/pgm_compiler/demo_pgm_compiler.py +38 -0
- ck_demos/pgm_compiler/demo_recursive_conditioning.py +33 -0
- ck_demos/pgm_compiler/demo_variable_elimination.py +33 -0
- ck_demos/pgm_compiler/demo_wmc_program.py +29 -0
- ck_demos/pgm_compiler/time_fe_compiler.py +93 -0
- ck_demos/pgm_inference/__init__.py +0 -0
- ck_demos/pgm_inference/demo_inferencing_basic.py +188 -0
- ck_demos/pgm_inference/demo_inferencing_mpe_cancer.py +45 -0
- ck_demos/pgm_inference/demo_inferencing_wmc_and_mpe_sprinkler.py +154 -0
- ck_demos/pgm_inference/demo_inferencing_wmc_student.py +110 -0
- ck_demos/programs/__init__.py +0 -0
- ck_demos/programs/demo_program_buffer.py +24 -0
- ck_demos/programs/demo_program_multi.py +24 -0
- ck_demos/programs/demo_program_none.py +19 -0
- ck_demos/programs/demo_program_single.py +23 -0
- ck_demos/programs/demo_raw_program_interpreted.py +21 -0
- ck_demos/programs/demo_raw_program_llvm.py +21 -0
- ck_demos/sampling/__init__.py +0 -0
- ck_demos/sampling/check_sampler.py +71 -0
- ck_demos/sampling/demo_marginal_direct_sampler.py +40 -0
- ck_demos/sampling/demo_uniform_sampler.py +38 -0
- ck_demos/sampling/demo_wmc_direct_sampler.py +40 -0
- ck_demos/utils/__init__.py +0 -0
- ck_demos/utils/compare.py +120 -0
- ck_demos/utils/convert_network.py +45 -0
- ck_demos/utils/sample_model.py +216 -0
- ck_demos/utils/stop_watch.py +384 -0
- compiled_knowledge-4.0.0a20.dist-info/METADATA +50 -0
- compiled_knowledge-4.0.0a20.dist-info/RECORD +178 -0
- compiled_knowledge-4.0.0a20.dist-info/WHEEL +5 -0
- compiled_knowledge-4.0.0a20.dist-info/licenses/LICENSE.txt +21 -0
- compiled_knowledge-4.0.0a20.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,153 @@
|
|
|
1
|
+
from typing import Collection, Iterator, Sequence, List
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
from ck.pgm import Instance
|
|
6
|
+
from ck.probability.probability_space import dtype_for_state_indexes
|
|
7
|
+
from ck.program.program_buffer import ProgramBuffer
|
|
8
|
+
from ck.program.raw_program import RawProgram
|
|
9
|
+
from ck.sampling.sampler import Sampler
|
|
10
|
+
from ck.sampling.sampler_support import SampleRV, YieldF, uniform_random_sample, SamplerInfo
|
|
11
|
+
from ck.utils.np_extras import NDArrayStates, NDArrayFloat64
|
|
12
|
+
from ck.utils.random_extras import Random, random_permute
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class WMCGibbsSampler(Sampler):
    """
    A Gibbs sampler whose conditional weights are computed by a program.

    Iterating over the sampler produces an endless stream of samples. Each
    sample is produced by one or more sweeps, where a sweep visits every
    sample random variable once (in a random order) and resamples its state
    from the weights returned by the program with all other random
    variables held at their current states.
    """

    def __init__(
            self,
            sampler_info: SamplerInfo,
            raw_program: RawProgram,
            rand: Random,
            skip: int,
            burn_in: int,
            pr_restart: float,
    ):
        """
        Args:
            sampler_info: provides the random variables, the condition, the
                yield function, the sample random variables and the two
                slot collections `slots_0` and `slots_1`.
            raw_program: the program evaluated (through a `ProgramBuffer`)
                to obtain the weight of the current input slot values.
            rand: the source of randomness.
            skip: number of sweeps performed and discarded before the sweep
                that produces each yielded sample.
            burn_in: number of sweeps performed after each (re)initialisation
                of the chain, before any sample is yielded.
            pr_restart: probability, tested after every yielded sample, of
                re-initialising the chain. A value <= 0 disables restarts.
        """
        super().__init__(sampler_info.rvs, sampler_info.condition)
        self._yield_f: YieldF = sampler_info.yield_f
        self._rand: Random = rand
        self._program_buffer = ProgramBuffer(raw_program)
        # A private list: `_next_sample_gibbs` passes it to `random_permute`.
        self._sample_rvs: List[SampleRV] = list(sampler_info.sample_rvs)
        self._state_dtype = dtype_for_state_indexes(self.rvs)
        self._slots_0: Collection[int] = sampler_info.slots_0
        self._slots_1: Collection[int] = sampler_info.slots_1
        self._skip: int = skip
        self._burn_in: int = burn_in
        self._pr_restart: float = pr_restart

    def __iter__(self) -> Iterator[Instance] | Iterator[int]:
        """
        Yield samples forever.

        The chain is initialised with `uniform_random_sample`, then
        `burn_in` sweeps are run. Each yielded sample is preceded by
        `skip` discarded sweeps plus one further sweep.
        """
        sample_rvs: List[SampleRV] = self._sample_rvs
        rand: Random = self._rand
        yield_f: YieldF = self._yield_f
        slots_0: Collection[int] = self._slots_0
        slots_1: Collection[int] = self._slots_1
        program_buffer: ProgramBuffer = self._program_buffer
        skip: int = self._skip
        burn_in: int = self._burn_in
        pr_restart: float = self._pr_restart

        # Allocate working memory
        state = np.zeros(len(sample_rvs), dtype=self._state_dtype)

        # The per-rv weight buffers are looked up as `prs[sample_rv.index]`,
        # so they must be placed by `index`, not by position in `sample_rvs`.
        # The list is handed to `random_permute` on every sweep, so its order
        # cannot be relied on if this sampler is iterated more than once.
        prs_by_index: List = [None] * len(sample_rvs)
        for sample_rv in sample_rvs:
            prs_by_index[sample_rv.index] = np.zeros(len(sample_rv.slots), dtype=np.float64)
        prs: Sequence[NDArrayFloat64] = tuple(prs_by_index)

        def sweep() -> None:
            self._next_sample_gibbs(sample_rvs, slots_1, program_buffer, prs, state, rand)

        def restart() -> None:
            # Set an initial system state, then run a burn in.
            uniform_random_sample(sample_rvs, slots_0, slots_1, program_buffer.vars, state, rand)
            for _ in range(burn_in):
                sweep()

        restart()

        # When restarts are disabled no random number is drawn for the
        # restart test, so the random stream is consumed only by sweeps.
        can_restart: bool = not (pr_restart <= 0)

        while True:
            for _ in range(skip):
                sweep()
            sweep()
            # We know the yield function will always provide either ints or Instances
            # noinspection PyTypeChecker
            yield yield_f(state)
            if can_restart and rand.random() < pr_restart:
                restart()

    @staticmethod
    def _next_sample_gibbs(
            sample_rvs: List[SampleRV],
            slots_1: Collection[int],
            program_buffer: ProgramBuffer,
            prs: Sequence[NDArrayFloat64],
            state: NDArrayStates,
            rand: Random
    ) -> None:
        """
        Perform one Gibbs sweep, updating `state` and the program inputs.

        Every random variable in `sample_rvs` is visited once, after the
        list has been passed to `random_permute`. For each one, the program
        is evaluated once per candidate state (a state whose slot is in
        `slots_1`) and a new state is drawn in proportion to the results.

        Args:
            sample_rvs: the random variables to resample.
            slots_1: slots that may be set to 1; only states whose slot is
                in this collection are candidates.
            program_buffer: program whose input slots (`vars`) are kept in
                step with `state`.
            prs: working buffers, one per random variable, indexed by
                `sample_rv.index`.
            state: current state of each random variable, indexed by
                `sample_rv.index`; updated in place.
            rand: the source of randomness.
        """
        prog_in = program_buffer.vars
        random_permute(sample_rvs, rand=rand)
        for sample_rv in sample_rvs:
            rv_slots = sample_rv.slots
            index = sample_rv.index

            rv_pr: NDArrayFloat64 = prs[index]
            s: int = state.item(index)

            candidates = [
                (slot_state, slot)
                for slot_state, slot in enumerate(rv_slots)
                if slot in slots_1
            ]
            assert len(candidates) > 0

            # Compute conditioned marginals for the current rv.
            # Entries of rv_pr for non-candidate states are never written,
            # so they keep the zero they were allocated with.
            prog_in[rv_slots[s]] = 0
            for slot_state, slot in candidates:
                prog_in[slot] = 1
                rv_pr[slot_state] = program_buffer.compute()
                prog_in[slot] = 0

            # Pick a new state based on the conditional probabilities
            total = np.sum(rv_pr)
            if total == 0.0:
                # No state of the current rv has a non-zero probability when
                # conditioned on the other random variables states.
                # Pick a random state from a uniform distribution.
                i = rand.randrange(0, len(candidates))
                candidate = candidates[i]
                # update the states array and the wmc input
                state[index] = candidate[0]
                prog_in[candidate[1]] = 1
            else:
                # Pick a state, sampled from the marginal distribution
                r = rand.random() * total
                slot = None
                slot_state = None
                # If rounding prevents the loop from breaking, the loop
                # variables are left at the last candidate, which is used.
                for slot_state, slot in candidates:
                    if r <= rv_pr[slot_state]:
                        break
                    r -= rv_pr[slot_state]
                # update the states array and the wmc input
                state[index] = slot_state
                prog_in[slot] = 1
|
|
@@ -0,0 +1,165 @@
|
|
|
1
|
+
from typing import Collection, Iterator, Sequence
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
from ck.pgm import Instance
|
|
6
|
+
from ck.probability.probability_space import dtype_for_state_indexes
|
|
7
|
+
from ck.program.program_buffer import ProgramBuffer
|
|
8
|
+
from ck.program.raw_program import RawProgram
|
|
9
|
+
from ck.sampling.sampler import Sampler
|
|
10
|
+
from ck.sampling.sampler_support import SampleRV, YieldF, uniform_random_sample, SamplerInfo
|
|
11
|
+
from ck.utils.np_extras import NDArrayStates, DTypeStates
|
|
12
|
+
from ck.utils.random_extras import Random
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class WMCMetropolisSampler(Sampler):
    """
    A Metropolis sampler whose state weights are computed by a program.

    Iterating over the sampler produces an endless stream of samples. Each
    step proposes a new state for one randomly chosen random variable and
    accepts or rejects it by comparing the program's weight for the
    proposal with the weight of the current state.
    """

    def __init__(
            self,
            sampler_info: SamplerInfo,
            raw_program: RawProgram,
            rand: Random,
            skip: int,
            burn_in: int,
            pr_restart: float,
    ):
        """
        Args:
            sampler_info: provides the random variables, the condition, the
                yield function, the sample random variables and the two
                slot collections `slots_0` and `slots_1`.
            raw_program: the program evaluated (through a `ProgramBuffer`)
                to obtain the weight of the current input slot values.
            rand: the source of randomness.
            skip: number of steps performed and discarded before the step
                that produces each yielded sample.
            burn_in: number of steps performed after each (re)initialisation
                of the chain, before any sample is yielded.
            pr_restart: probability, tested after every yielded sample, of
                re-initialising the chain. A value <= 0 disables restarts.
        """
        super().__init__(sampler_info.rvs, sampler_info.condition)
        self._yield_f: YieldF = sampler_info.yield_f
        self._rand: Random = rand
        self._program_buffer = ProgramBuffer(raw_program)
        self._sample_rvs: Sequence[SampleRV] = tuple(sampler_info.sample_rvs)
        self._state_dtype: DTypeStates = dtype_for_state_indexes(self.rvs)
        self._slots_0: Collection[int] = sampler_info.slots_0
        self._slots_1: Collection[int] = sampler_info.slots_1
        self._skip: int = skip
        self._burn_in: int = burn_in
        self._pr_restart: float = pr_restart

    def __iter__(self) -> Iterator[Instance] | Iterator[int]:
        """
        Yield samples forever.

        The chain is initialised with `_init_sample_metropolis`, then
        `burn_in` steps are run. Each yielded sample is preceded by `skip`
        discarded steps plus one further step.
        """
        sample_rvs = self._sample_rvs
        rand = self._rand
        yield_f = self._yield_f
        program_buffer = self._program_buffer
        slots = program_buffer.vars
        skip = self._skip
        burn_in = self._burn_in
        pr_restart = self._pr_restart

        # Allocate working memory
        state: NDArrayStates = np.zeros(len(sample_rvs), dtype=self._state_dtype)

        # set up the input slots to respect conditioning
        for slot in self._slots_0:
            slots[slot] = 0
        for slot in self._slots_1:
            slots[slot] = 1

        # For each sample rv, record (index, slots, possible states), where a
        # possible state is a (state, slot) pair whose input slot is now 1.
        possibles = [
            (
                sample_rv.index,
                sample_rv.slots,
                [
                    (slot_state, slot)
                    for slot_state, slot in enumerate(sample_rv.slots)
                    if slots[slot] == 1
                ],
            )
            for sample_rv in sample_rvs
        ]

        def step(cur_w: float) -> float:
            return self._next_sample_metropolis(possibles, program_buffer, state, cur_w, rand)

        def restart() -> float:
            # Set an initial valid system state, then run a burn in.
            new_w: float = self._init_sample_metropolis(state)
            for _ in range(burn_in):
                new_w = step(new_w)
            return new_w

        w: float = restart()

        # When restarts are disabled no random number is drawn for the
        # restart test, so the random stream is consumed only by steps.
        can_restart: bool = not (pr_restart <= 0)

        while True:
            for _ in range(skip):
                w = step(w)
            w = step(w)
            # We know the yield function will always provide either ints or Instances
            # noinspection PyTypeChecker
            yield yield_f(state)
            if can_restart and rand.random() < pr_restart:
                w = restart()

    def _init_sample_metropolis(self, state: NDArrayStates) -> float:
        """
        Initialise `state` to a valid random system and configure the
        program inputs to match.

        States are drawn with `uniform_random_sample` until one has a
        strictly positive weight. A zero-weight start would be an invalid
        state, and it could be yielded as a sample if the following steps
        reject every proposal.

        NOTE: this does not return if no state reachable by
        `uniform_random_sample` has a positive weight.

        Returns:
            the weight computed by the program for the chosen state.
        """
        sample_rvs = self._sample_rvs
        rand = self._rand
        slots_0 = self._slots_0
        slots_1 = self._slots_1
        program_buffer = self._program_buffer
        slots = program_buffer.vars

        while True:
            uniform_random_sample(sample_rvs, slots_0, slots_1, slots, state, rand)
            w: float = program_buffer.compute().item()
            if w > 0:
                return w

    @staticmethod
    def _next_sample_metropolis(
            possibles,
            program_buffer: ProgramBuffer,
            state,
            cur_w: float,
            rand: Random,
    ) -> float:
        """
        Perform one Metropolis step, updating `state` and the program inputs.

        A random variable and one of its possible states are chosen
        uniformly at random. The proposal is accepted when
        `rand.random() * cur_w < new_w`; otherwise the program inputs are
        restored and `state` is left unchanged.

        Args:
            possibles: one `(index, slots, [(state, slot), ...])` entry per
                sample random variable.
            program_buffer: program whose input slots (`vars`) are kept in
                step with `state`.
            state: current state of each random variable, indexed by the
                `index` held in `possibles`; updated in place on acceptance.
            cur_w: weight of the current state.
            rand: the source of randomness.

        Returns:
            the weight of the state after the step: the proposal's weight
            if it was accepted, otherwise `cur_w`.
        """
        prog_in = program_buffer.vars

        # Generate a proposal.
        # randomly choose a random variable
        i = rand.randrange(0, len(possibles))
        idx, rv_slots, rv_possibles = possibles[i]
        # keep track of the current state slot
        cur_s = state[idx]
        cur_s_slot = rv_slots[cur_s]
        # randomly choose a possible state
        i = rand.randrange(0, len(rv_possibles))
        s, s_slot = rv_possibles[i]

        # set up state and program to compute weight
        prog_in[cur_s_slot] = 0
        prog_in[s_slot] = 1

        # calculate the weight and test it
        new_w: float = program_buffer.compute().item()
        if rand.random() * cur_w < new_w:
            # accept
            state[idx] = s
            return new_w

        # reject: set state and program to what it was before
        prog_in[s_slot] = 0
        prog_in[cur_s_slot] = 1
        return cur_w
|
|
@@ -0,0 +1,115 @@
|
|
|
1
|
+
from typing import Collection, Iterator, Sequence
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
from ck.pgm import Instance
|
|
6
|
+
from ck.probability.probability_space import dtype_for_state_indexes
|
|
7
|
+
from ck.program.program_buffer import ProgramBuffer
|
|
8
|
+
from ck.program.raw_program import RawProgram
|
|
9
|
+
from ck.sampling.sampler import Sampler
|
|
10
|
+
from ck.sampling.sampler_support import SampleRV, YieldF, uniform_random_sample, SamplerInfo
|
|
11
|
+
from ck.utils.np_extras import NDArrayNumeric
|
|
12
|
+
from ck.utils.random_extras import Random
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class WMCRejectionSampler(Sampler):
    """
    A rejection sampler whose state weights are computed by a program.

    States are drawn with `uniform_random_sample` and accepted when
    `rand.random() * w_max < w`, where `w` is the program's weight for the
    state and `w_max` is an estimate of the largest weight of any state.
    The estimate is refined as samples are yielded.
    """

    def __init__(
            self,
            sampler_info: SamplerInfo,
            raw_program: RawProgram,
            rand: Random,
            z: float,
    ):
        """
        Args:
            sampler_info: provides the random variables, the condition, the
                yield function, the sample random variables and the two
                slot collections `slots_0` and `slots_1`.
            raw_program: the program evaluated (through a `ProgramBuffer`)
                to obtain the weight of the current input slot values.
            rand: the source of randomness.
            z: starting value for the weight not yet seen; also the initial
                upper bound used when estimating the maximum weight.
        """
        super().__init__(sampler_info.rvs, sampler_info.condition)
        self._yield_f: YieldF = sampler_info.yield_f
        self._rand: Random = rand
        self._program_buffer = ProgramBuffer(raw_program)
        self._sample_rvs: Sequence[SampleRV] = tuple(sampler_info.sample_rvs)
        self._state_dtype = dtype_for_state_indexes(self.rvs)
        self._slots_0: Collection[int] = sampler_info.slots_0
        self._slots_1: Collection[int] = sampler_info.slots_1

        # Fields for tracking the maximum weight.
        self._w_max = None  # estimated maximum weight for any one world
        self._w_not_seen = z  # z minus the weight of the distinct samples seen
        self._w_high = 0.0  # highest instance weight seen so far
        self._samples = set()  # samples seen so far; None once tracking stops

    def _initial_w_max(self) -> float:
        """
        Estimate the maximum weight of any one state from per-variable maxima.

        For each sample random variable, the program is evaluated once per
        state whose slot is in `slots_1`, with that variable's other slots
        set to 0. The result is the smallest of the per-variable maxima,
        capped by the current value of `_w_not_seen`.
        """
        program_buffer = self._program_buffer
        slots: NDArrayNumeric = program_buffer.vars
        slots_1 = self._slots_1

        estimate = self._w_not_seen  # initially set to z, so a 'large' weight

        # Set up the input slots to 0 or 1 to respect conditioning.
        for slot in self._slots_0:
            slots[slot] = 0
        for slot in slots_1:
            slots[slot] = 1

        for sample_rv in self._sample_rvs:
            rv_slots = sample_rv.slots
            best_for_rv = 0

            # Switch every slot of this rv off.
            for slot in rv_slots:
                slots[slot] = 0

            # Switch each permitted slot on, one at a time, and weigh it.
            permitted = [slot for slot in rv_slots if slot in slots_1]
            for slot in permitted:
                slots[slot] = 1
                w: float = program_buffer.compute().item()
                best_for_rv = max(best_for_rv, w)
                slots[slot] = 0

            # Restore the permitted slots to 1, ready for the next rv.
            for slot in permitted:
                slots[slot] = 1

            estimate = min(estimate, best_for_rv)

        return estimate

    def __iter__(self) -> Iterator[Instance] | Iterator[int]:
        """
        Yield accepted samples forever.

        On first use the maximum-weight estimate is initialised. After each
        yielded sample that has not been seen before, the estimate is
        tightened using the highest weight seen and the weight not yet seen.
        """
        sample_rvs = self._sample_rvs
        rand = self._rand
        yield_f = self._yield_f
        slots_0 = self._slots_0
        slots_1 = self._slots_1
        program_buffer = self._program_buffer
        slots: NDArrayNumeric = program_buffer.vars

        # Working memory to store a possible world
        state: NDArrayNumeric = np.zeros(len(sample_rvs), dtype=self._state_dtype)

        # Initialise w_max, if not done yet.
        if self._w_max is None:
            self._w_max = self._initial_w_max()

        while True:
            uniform_random_sample(sample_rvs, slots_0, slots_1, slots, state, rand)
            # The weighted model count for the state of the current input slots.
            w: float = program_buffer.compute().item()

            if not (rand.random() * self._w_max < w):
                continue

            # We know the yield function will always provide either ints or Instances
            # noinspection PyTypeChecker
            yield yield_f(state)

            # Update w_not_seen and w_high to adapt w_max.
            # Seen samples are no longer tracked once w_not_seen and w_high
            # are close enough, or too many samples have been tracked.
            if self._samples is None:
                continue

            key = tuple(state)
            if key in self._samples:
                continue

            self._samples.add(key)
            self._w_not_seen -= w
            self._w_high = max(self._w_high, w)
            w_max_tracked = max(self._w_high, self._w_not_seen)
            self._w_max = min(w_max_tracked, self._w_max)

            # See if we should stop tracking samples.
            close_enough = self._w_not_seen - self._w_high < 0.001
            if close_enough or len(self._samples) > 1000000:
                self._samples = None
|
ck/utils/__init__.py
ADDED
|
File without changes
|
ck/utils/iter_extras.py
ADDED
|
@@ -0,0 +1,163 @@
|
|
|
1
|
+
"""
|
|
2
|
+
A module with extra iteration functions.
|
|
3
|
+
"""
|
|
4
|
+
from functools import reduce as _reduce
|
|
5
|
+
from itertools import combinations, chain, islice
|
|
6
|
+
from operator import mul as _mul
|
|
7
|
+
from typing import Iterable, Tuple, Sequence, TypeVar
|
|
8
|
+
|
|
9
|
+
_T = TypeVar('_T')
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def flatten(iterables: Iterable[Iterable[_T]]) -> Iterable[_T]:
    """
    Iterate over every element of every iterable in `iterables`.

    Only one level of nesting is removed. Elements are produced lazily, in
    the order they are encountered.
    """
    return chain.from_iterable(iterables)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def deep_flatten(iterables: Iterable) -> Iterable:
    """
    Iterate over the leaves of arbitrarily nested iterables.

    Every element that is itself iterable is flattened recursively.
    Strings are treated as leaves and are yielded whole.
    """
    for item in iterables:
        is_leaf = isinstance(item, str) or not isinstance(item, Iterable)
        if is_leaf:
            yield item
            continue
        yield from deep_flatten(item)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def combos(list_of_lists: Sequence[Sequence[_T]], flip=False) -> Iterable[Tuple[_T, ...]]:
    """
    Iterate over all combinations of taking one element from each of the lists.

    The order of results has the first element changing most rapidly.
    For example, given [[1,2,3],[4,5],[6,7]], combos yields the following:
    (1,4,6), (2,4,6), (3,4,6), (1,5,6), (2,5,6), (3,5,6),
    (1,4,7), (2,4,7), (3,4,7), (1,5,7), (2,5,7), (3,5,7).

    If flip, then the last changes most rapidly.

    If `list_of_lists` is empty, a single empty tuple is yielded.
    If any of the lists is empty, there are no combinations and nothing
    is yielded.
    """
    num = len(list_of_lists)
    if num == 0:
        yield ()
        return
    # One element cannot be taken from an empty list, so the result is empty.
    if any(len(options) == 0 for options in list_of_lists):
        return
    rng = range(num)
    indexes = [0] * num
    # `start` is the position that changes most rapidly; `end` is one step
    # past the position that changes least rapidly.
    if flip:
        start, inc, end = num - 1, -1, -1
    else:
        start, inc, end = 0, 1, num
    while True:
        yield tuple(list_of_lists[i][indexes[i]] for i in rng)
        # Advance like an odometer: increment the fastest position and
        # carry into the next position whenever one wraps around.
        i = start
        while True:
            indexes[i] += 1
            if indexes[i] < len(list_of_lists[i]):
                break
            indexes[i] = 0
            i += inc
            if i == end:
                return
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def combos_ranges(list_of_lens: Sequence[int], flip=False) -> Iterable[Tuple[int, ...]]:
    """
    Iterate over all combinations of taking one index from each of
    range(l) for l in list_of_lens.

    The order of results has the first element changing most rapidly.
    If flip, then the last changes most rapidly.

    If `list_of_lens` is empty, a single empty tuple is yielded.
    If any length is zero or negative, the corresponding range is empty,
    so there are no combinations and nothing is yielded.
    """
    num = len(list_of_lens)
    if num == 0:
        yield ()
        return
    # An empty range contributes no index, so the result is empty.
    if any(length <= 0 for length in list_of_lens):
        return
    indexes = [0] * num
    # `start` is the position that changes most rapidly; `end` is one step
    # past the position that changes least rapidly.
    if flip:
        start, inc, end = num - 1, -1, -1
    else:
        start, inc, end = 0, 1, num
    while True:
        yield tuple(indexes)
        # Advance like an odometer: increment the fastest position and
        # carry into the next position whenever one wraps around.
        i = start
        while True:
            indexes[i] += 1
            if indexes[i] < list_of_lens[i]:
                break
            indexes[i] = 0
            i += inc
            if i == end:
                return
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def pairs(elements: Iterable[_T]) -> Iterable[Tuple[_T, _T]]:
    """
    Iterate over every unordered pair of the given elements.

    Each pair holds two elements taken from different positions, in the
    order in which they appear in `elements`.
    """
    pair_size = 2
    return combinations(elements, pair_size)
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def sequential_pairs(elements: Sequence[_T]) -> Iterable[Tuple[_T, _T]]:
    """
    Iterate over each pair of neighbouring elements, in order.

    For elements [a, b, c] this yields (a, b) then (b, c). Nothing is
    yielded when there are fewer than two elements.
    """
    for right in range(1, len(elements)):
        yield elements[right - 1], elements[right]
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def powerset(iterable: Iterable[_T], min_size: int = 0, max_size: int = None) -> Iterable[Tuple[_T, ...]]:
    """
    Iterate over the subsets of the given elements, smallest first.

    powerset([1,2,3]) --> () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)

    Only subsets with between `min_size` and `max_size` elements
    (inclusive) are produced. A `min_size` of None means 0 and a
    `max_size` of None means the number of elements.
    """
    pool = iterable if isinstance(iterable, (list, tuple)) else list(iterable)
    smallest = 0 if min_size is None else min_size
    largest = len(pool) if max_size is None else max_size
    sizes = range(smallest, largest + 1)
    return chain.from_iterable(combinations(pool, size) for size in sizes)
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
def unzip(xs: Iterable[Tuple[_T]]) -> Tuple[Iterable[_T]]:
    """
    Undo a zip: regroup an iterable of tuples by position.

    E.g., a, b, c = unzip(zip(a, b, c))

    Each result is a tuple, so the Python type of the original `a`, `b`
    and `c` may not be preserved; only their contents, order and length
    are guaranteed.
    """
    rows = xs
    columns = zip(*rows)
    return columns
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
def multiply(items: Iterable[_T], initial: _T = 1) -> _T:
    """
    Return the product of `initial` and every one of the given items.

    The items are multiplied in order, left to right, starting from
    `initial`. If there are no items then `initial` is returned.
    """
    product = initial
    for item in items:
        product = product * item
    return product
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
def first(items: Iterable[_T]) -> _T:
    """
    Return the first element produced by iterating over `items`.

    Raises:
        StopIteration: if `items` produces no elements.
    """
    iterator = iter(items)
    return next(iterator)
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
def take(iterable: Iterable[_T], n: int) -> Iterable[_T]:
    """
    Iterate over at most the first `n` elements of the iterable.

    Fewer than `n` elements are produced if the iterable is shorter.
    """
    return islice(iterable, 0, n)
|