compiled-knowledge 4.0.0a20__cp313-cp313-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of compiled-knowledge might be problematic. Click here for more details.
- ck/__init__.py +0 -0
- ck/circuit/__init__.py +17 -0
- ck/circuit/_circuit_cy.c +37525 -0
- ck/circuit/_circuit_cy.cpython-313-darwin.so +0 -0
- ck/circuit/_circuit_cy.pxd +32 -0
- ck/circuit/_circuit_cy.pyx +768 -0
- ck/circuit/_circuit_py.py +836 -0
- ck/circuit/tmp_const.py +74 -0
- ck/circuit_compiler/__init__.py +2 -0
- ck/circuit_compiler/circuit_compiler.py +26 -0
- ck/circuit_compiler/cython_vm_compiler/__init__.py +1 -0
- ck/circuit_compiler/cython_vm_compiler/_compiler.c +19826 -0
- ck/circuit_compiler/cython_vm_compiler/_compiler.cpython-313-darwin.so +0 -0
- ck/circuit_compiler/cython_vm_compiler/_compiler.pyx +380 -0
- ck/circuit_compiler/cython_vm_compiler/cython_vm_compiler.py +121 -0
- ck/circuit_compiler/interpret_compiler.py +223 -0
- ck/circuit_compiler/llvm_compiler.py +388 -0
- ck/circuit_compiler/llvm_vm_compiler.py +546 -0
- ck/circuit_compiler/named_circuit_compilers.py +57 -0
- ck/circuit_compiler/support/__init__.py +0 -0
- ck/circuit_compiler/support/circuit_analyser/__init__.py +13 -0
- ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.c +10620 -0
- ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.cpython-313-darwin.so +0 -0
- ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.pyx +98 -0
- ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_py.py +93 -0
- ck/circuit_compiler/support/input_vars.py +148 -0
- ck/circuit_compiler/support/llvm_ir_function.py +234 -0
- ck/example/__init__.py +53 -0
- ck/example/alarm.py +366 -0
- ck/example/asia.py +28 -0
- ck/example/binary_clique.py +32 -0
- ck/example/bow_tie.py +33 -0
- ck/example/cancer.py +37 -0
- ck/example/chain.py +38 -0
- ck/example/child.py +199 -0
- ck/example/clique.py +33 -0
- ck/example/cnf_pgm.py +39 -0
- ck/example/diamond_square.py +68 -0
- ck/example/earthquake.py +36 -0
- ck/example/empty.py +10 -0
- ck/example/hailfinder.py +539 -0
- ck/example/hepar2.py +628 -0
- ck/example/insurance.py +504 -0
- ck/example/loop.py +40 -0
- ck/example/mildew.py +38161 -0
- ck/example/munin.py +22982 -0
- ck/example/pathfinder.py +53747 -0
- ck/example/rain.py +39 -0
- ck/example/rectangle.py +161 -0
- ck/example/run.py +30 -0
- ck/example/sachs.py +129 -0
- ck/example/sprinkler.py +30 -0
- ck/example/star.py +44 -0
- ck/example/stress.py +64 -0
- ck/example/student.py +43 -0
- ck/example/survey.py +46 -0
- ck/example/triangle_square.py +54 -0
- ck/example/truss.py +49 -0
- ck/in_out/__init__.py +3 -0
- ck/in_out/parse_ace_lmap.py +216 -0
- ck/in_out/parse_ace_nnf.py +322 -0
- ck/in_out/parse_net.py +480 -0
- ck/in_out/parser_utils.py +185 -0
- ck/in_out/pgm_pickle.py +42 -0
- ck/in_out/pgm_python.py +268 -0
- ck/in_out/render_bugs.py +111 -0
- ck/in_out/render_net.py +177 -0
- ck/in_out/render_pomegranate.py +184 -0
- ck/pgm.py +3475 -0
- ck/pgm_circuit/__init__.py +1 -0
- ck/pgm_circuit/marginals_program.py +352 -0
- ck/pgm_circuit/mpe_program.py +237 -0
- ck/pgm_circuit/pgm_circuit.py +79 -0
- ck/pgm_circuit/program_with_slotmap.py +236 -0
- ck/pgm_circuit/slot_map.py +35 -0
- ck/pgm_circuit/support/__init__.py +0 -0
- ck/pgm_circuit/support/compile_circuit.py +83 -0
- ck/pgm_circuit/target_marginals_program.py +103 -0
- ck/pgm_circuit/wmc_program.py +323 -0
- ck/pgm_compiler/__init__.py +2 -0
- ck/pgm_compiler/ace/__init__.py +1 -0
- ck/pgm_compiler/ace/ace.py +299 -0
- ck/pgm_compiler/factor_elimination.py +395 -0
- ck/pgm_compiler/named_pgm_compilers.py +63 -0
- ck/pgm_compiler/pgm_compiler.py +19 -0
- ck/pgm_compiler/recursive_conditioning.py +231 -0
- ck/pgm_compiler/support/__init__.py +0 -0
- ck/pgm_compiler/support/circuit_table/__init__.py +17 -0
- ck/pgm_compiler/support/circuit_table/_circuit_table_cy.c +16398 -0
- ck/pgm_compiler/support/circuit_table/_circuit_table_cy.cpython-313-darwin.so +0 -0
- ck/pgm_compiler/support/circuit_table/_circuit_table_cy.pyx +332 -0
- ck/pgm_compiler/support/circuit_table/_circuit_table_py.py +304 -0
- ck/pgm_compiler/support/clusters.py +568 -0
- ck/pgm_compiler/support/factor_tables.py +406 -0
- ck/pgm_compiler/support/join_tree.py +332 -0
- ck/pgm_compiler/support/named_compiler_maker.py +43 -0
- ck/pgm_compiler/variable_elimination.py +91 -0
- ck/probability/__init__.py +0 -0
- ck/probability/empirical_probability_space.py +50 -0
- ck/probability/pgm_probability_space.py +32 -0
- ck/probability/probability_space.py +622 -0
- ck/program/__init__.py +3 -0
- ck/program/program.py +137 -0
- ck/program/program_buffer.py +180 -0
- ck/program/raw_program.py +67 -0
- ck/sampling/__init__.py +0 -0
- ck/sampling/forward_sampler.py +211 -0
- ck/sampling/marginals_direct_sampler.py +113 -0
- ck/sampling/sampler.py +62 -0
- ck/sampling/sampler_support.py +232 -0
- ck/sampling/uniform_sampler.py +72 -0
- ck/sampling/wmc_direct_sampler.py +171 -0
- ck/sampling/wmc_gibbs_sampler.py +153 -0
- ck/sampling/wmc_metropolis_sampler.py +165 -0
- ck/sampling/wmc_rejection_sampler.py +115 -0
- ck/utils/__init__.py +0 -0
- ck/utils/iter_extras.py +163 -0
- ck/utils/local_config.py +270 -0
- ck/utils/map_list.py +128 -0
- ck/utils/map_set.py +128 -0
- ck/utils/np_extras.py +51 -0
- ck/utils/random_extras.py +64 -0
- ck/utils/tmp_dir.py +94 -0
- ck_demos/__init__.py +0 -0
- ck_demos/ace/__init__.py +0 -0
- ck_demos/ace/copy_ace_to_ck.py +15 -0
- ck_demos/ace/demo_ace.py +49 -0
- ck_demos/all_demos.py +88 -0
- ck_demos/circuit/__init__.py +0 -0
- ck_demos/circuit/demo_circuit_dump.py +22 -0
- ck_demos/circuit/demo_derivatives.py +43 -0
- ck_demos/circuit_compiler/__init__.py +0 -0
- ck_demos/circuit_compiler/compare_circuit_compilers.py +32 -0
- ck_demos/circuit_compiler/show_llvm_program.py +26 -0
- ck_demos/pgm/__init__.py +0 -0
- ck_demos/pgm/demo_pgm_dump.py +18 -0
- ck_demos/pgm/demo_pgm_dump_stress.py +18 -0
- ck_demos/pgm/demo_pgm_string_rendering.py +15 -0
- ck_demos/pgm/show_examples.py +25 -0
- ck_demos/pgm_compiler/__init__.py +0 -0
- ck_demos/pgm_compiler/compare_pgm_compilers.py +63 -0
- ck_demos/pgm_compiler/demo_compiler_dump.py +60 -0
- ck_demos/pgm_compiler/demo_factor_elimination.py +47 -0
- ck_demos/pgm_compiler/demo_join_tree.py +25 -0
- ck_demos/pgm_compiler/demo_marginals_program.py +53 -0
- ck_demos/pgm_compiler/demo_mpe_program.py +55 -0
- ck_demos/pgm_compiler/demo_pgm_compiler.py +38 -0
- ck_demos/pgm_compiler/demo_recursive_conditioning.py +33 -0
- ck_demos/pgm_compiler/demo_variable_elimination.py +33 -0
- ck_demos/pgm_compiler/demo_wmc_program.py +29 -0
- ck_demos/pgm_compiler/time_fe_compiler.py +93 -0
- ck_demos/pgm_inference/__init__.py +0 -0
- ck_demos/pgm_inference/demo_inferencing_basic.py +188 -0
- ck_demos/pgm_inference/demo_inferencing_mpe_cancer.py +45 -0
- ck_demos/pgm_inference/demo_inferencing_wmc_and_mpe_sprinkler.py +154 -0
- ck_demos/pgm_inference/demo_inferencing_wmc_student.py +110 -0
- ck_demos/programs/__init__.py +0 -0
- ck_demos/programs/demo_program_buffer.py +24 -0
- ck_demos/programs/demo_program_multi.py +24 -0
- ck_demos/programs/demo_program_none.py +19 -0
- ck_demos/programs/demo_program_single.py +23 -0
- ck_demos/programs/demo_raw_program_interpreted.py +21 -0
- ck_demos/programs/demo_raw_program_llvm.py +21 -0
- ck_demos/sampling/__init__.py +0 -0
- ck_demos/sampling/check_sampler.py +71 -0
- ck_demos/sampling/demo_marginal_direct_sampler.py +40 -0
- ck_demos/sampling/demo_uniform_sampler.py +38 -0
- ck_demos/sampling/demo_wmc_direct_sampler.py +40 -0
- ck_demos/utils/__init__.py +0 -0
- ck_demos/utils/compare.py +120 -0
- ck_demos/utils/convert_network.py +45 -0
- ck_demos/utils/sample_model.py +216 -0
- ck_demos/utils/stop_watch.py +384 -0
- compiled_knowledge-4.0.0a20.dist-info/METADATA +50 -0
- compiled_knowledge-4.0.0a20.dist-info/RECORD +178 -0
- compiled_knowledge-4.0.0a20.dist-info/WHEEL +6 -0
- compiled_knowledge-4.0.0a20.dist-info/licenses/LICENSE.txt +21 -0
- compiled_knowledge-4.0.0a20.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,113 @@
|
|
|
1
|
+
from typing import Collection, Iterator, Dict, Sequence
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
from ck.pgm import Instance
|
|
6
|
+
from ck.probability.probability_space import dtype_for_state_indexes
|
|
7
|
+
from ck.program.program_buffer import ProgramBuffer
|
|
8
|
+
from ck.program.raw_program import RawProgram
|
|
9
|
+
from ck.sampling.sampler import Sampler
|
|
10
|
+
from ck.sampling.sampler_support import SampleRV, YieldF, SamplerInfo
|
|
11
|
+
from ck.utils.np_extras import NDArray, NDArrayNumeric
|
|
12
|
+
from ck.utils.random_extras import Random
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class MarginalsDirectSampler(Sampler):
    """
    A Sampler that builds each sample one random variable at a time.

    For each sample random variable (in order) the program is run, the
    slice of program results registered for that random variable is read as
    an unnormalised distribution over its states, a state is selected, and
    the program input slots are updated so that later random variables are
    sampled given the states already selected.
    """

    def __init__(
            self,
            sampler_info: SamplerInfo,
            raw_program: RawProgram,
            rand: Random,
            rv_idx_to_result_offset: Dict[int, int],
    ):
        """
        Args:
            sampler_info: the sample rvs, condition, yield function and the
                initial zero/one input slots.
            raw_program: the program to run; it is wrapped in a ProgramBuffer.
                The last program result is read as the WMC.
            rand: the source of random numbers (`rand.random()` is used).
            rv_idx_to_result_offset: maps `rv.idx` to the offset in the program
                results where the `len(rv)` values for that rv start.
        """
        super().__init__(sampler_info.rvs, sampler_info.condition)
        self._yield_f: YieldF = sampler_info.yield_f
        self._rand: Random = rand
        self._program_buffer = ProgramBuffer(raw_program)
        self._sample_rvs: Sequence[SampleRV] = tuple(sampler_info.sample_rvs)
        self._chain_rvs: Sequence[SampleRV] = tuple(
            sample_rv for sample_rv in sampler_info.sample_rvs if sample_rv.copy_index is not None)
        self._state_dtype = dtype_for_state_indexes(self.rvs)
        self._max_number_of_states: int = max((len(rv) for rv in self.rvs), default=0)
        self._slots_1: Collection[int] = sampler_info.slots_1

        # One slice of the program results per sample rv, co-indexed with the sample rvs.
        self._marginals: Sequence[NDArrayNumeric] = tuple(
            self._program_buffer.results[
                rv_idx_to_result_offset[sample_rv.rv.idx]
                :
                rv_idx_to_result_offset[sample_rv.rv.idx] + len(sample_rv.rv)
            ]
            for sample_rv in sampler_info.sample_rvs
        )

        # Set up the input slots to 0 or 1 to respect conditioning and initial Markov chain states.
        slots: NDArray = self._program_buffer.vars
        for slot in sampler_info.slots_0:
            slots[slot] = 0
        for slot in sampler_info.slots_1:
            slots[slot] = 1

    def __iter__(self) -> Iterator[Instance] | Iterator[int]:
        """
        An unlimited series of samples, each produced by `self._yield_f`
        from the selected state indexes (co-indexed with the sample rvs).

        Raises:
            RuntimeError: if no state of a random variable has positive weight.
        """
        yield_f = self._yield_f
        rand = self._rand
        sample_rvs = self._sample_rvs
        chain_rvs = self._chain_rvs
        program_buffer = self._program_buffer
        slots: NDArray = program_buffer.vars
        marginals = self._marginals
        slots_1 = self._slots_1

        # Set up working memory buffer
        states = np.zeros(len(sample_rvs), dtype=self._state_dtype)

        def compute() -> float:
            # Compute the program results based on the current input slot values.
            # Return the WMC.
            return program_buffer.compute().item(-1)

        while True:
            wmc: float = compute()
            rnd: float = rand.random() * wmc

            for sample_rv in sample_rvs:
                index: int = sample_rv.index
                if index > 0:
                    # No need to execute the program on the first time through
                    # as it was done just before entering the loop.
                    wmc = compute()

                rv_dist: NDArray = marginals[index]

                rv_dist_sum: float = rv_dist.sum()
                if rv_dist_sum <= 0:
                    raise RuntimeError('zero probability')
                # Rescale (in place, in the results buffer) so the weights sum to the WMC.
                rv_dist *= wmc / rv_dist_sum

                state_index: int = -1
                last_positive: int = -1
                for i in range(len(sample_rv.rv)):
                    w = rv_dist.item(i)
                    if rnd < w:
                        state_index = i
                        break
                    if w > 0:
                        last_positive = i
                    rnd -= w

                if state_index < 0:
                    # Floating point rounding can leave `rnd` at or just above the
                    # accumulated weights, so no state was selected by the loop.
                    # Use the last state with positive weight rather than relying
                    # on an `assert` (which is removed under `python -O`, where a
                    # state index of -1 would then be used silently).
                    # `rnd` is left as its small non-negative remainder.
                    if last_positive < 0:
                        raise RuntimeError('zero probability')
                    state_index = last_positive

                # Condition the program on the selected state of this rv.
                for slot in sample_rv.slots:
                    slots[slot] = 0
                slots[sample_rv.slots[state_index]] = 1
                states[index] = state_index

            yield yield_f(states)

            # Reset the one slots for the next iteration.
            for slot in slots_1:
                slots[slot] = 1

            # Copy chain pairs for next iteration.
            # (This writes over any initial chain conditions from slots_1.)
            for sample_rv in chain_rvs:
                rv_slots = sample_rv.slots
                prev_state_idx: int = states.item(sample_rv.copy_index)
                for slot in rv_slots:
                    slots[slot] = 0
                slots[rv_slots[prev_state_idx]] = 1
|
ck/sampling/sampler.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from itertools import islice
|
|
3
|
+
from typing import Sequence, Iterator
|
|
4
|
+
|
|
5
|
+
from ck.pgm import RandomVariable, Instance, Indicator
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class Sampler(ABC):
    """
    Abstract source of an endless stream of samples over some random variables.

    The sampled random variables are available, as a tuple, from the `rvs` property.

    Depending on the implementation, iterating a Sampler gives either Instance
    objects (each co-indexed with `self.rvs`) or single state indexes. In the
    single state index case, `len(self.rvs) == 1`.
    """
    __slots__ = ('_rvs', '_condition')

    def __init__(self, rvs: Sequence[RandomVariable], condition: Sequence[Indicator]):
        """
        Args:
            rvs: the random variables being sampled; each sample from
                `iter(self)` is co-indexed with this collection.
            condition: the condition on `rvs` that is compiled into the sampler.
        """
        # Both collections are frozen as tuples.
        frozen_rvs: Sequence[RandomVariable] = tuple(rvs)
        frozen_condition: Sequence[Indicator] = tuple(condition)
        self._rvs = frozen_rvs
        self._condition = frozen_condition

    @property
    def rvs(self) -> Sequence[RandomVariable]:
        """
        The random variables being sampled.

        Returns:
            the sampled random variables, co-indexed with each sample from `iter(self)`.
        """
        return self._rvs

    @property
    def condition(self) -> Sequence[Indicator]:
        """
        The condition on `self.rvs` that is compiled into the sampler.
        """
        return self._condition

    @abstractmethod
    def __iter__(self) -> Iterator[Instance] | Iterator[int]:
        """
        An endless series of samples from a random process.
        Every sample is co-indexed with the random variables of `self.rvs`.
        """
        ...

    def take(self, number_of_samples: int) -> Iterator[Instance] | Iterator[int]:
        """
        Provide at most `number_of_samples` samples from `iter(self)`.

        Args:
            number_of_samples: the maximum number of samples to provide.
        """
        limited = islice(self, number_of_samples)
        return limited
|
|
@@ -0,0 +1,232 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
from itertools import count
|
|
3
|
+
from typing import Callable, Sequence, Optional, Set, Tuple, Dict, Collection
|
|
4
|
+
|
|
5
|
+
from ck.pgm import Instance, RandomVariable, Indicator
|
|
6
|
+
from ck.pgm_circuit.program_with_slotmap import ProgramWithSlotmap
|
|
7
|
+
from ck.pgm_circuit.slot_map import SlotMap
|
|
8
|
+
from ck.probability.probability_space import Condition, check_condition
|
|
9
|
+
from ck.utils.map_set import MapSet
|
|
10
|
+
from ck.utils.np_extras import NDArrayStates, NDArrayNumeric
|
|
11
|
+
from ck.utils.random_extras import Random
|
|
12
|
+
|
|
13
|
+
# Type of a yield function. Support for a sampler.
|
|
14
|
+
# A yield function may be used to implement a sampler's iterator, thus
|
|
15
|
+
# it provides an Instance or single state index.
|
|
16
|
+
YieldF = Callable[[NDArrayStates], int] | Callable[[NDArrayStates], Instance]
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@dataclass
class SampleRV:
    """
    Sampler support: the bookkeeping for a single sampled random variable.

    Fields:
        index: position of this entry in the sequence of sample rvs.
        rv: the random variable being sampled.
        slots: the program input slots for the indicators of `rv`
            (co-indexed with `rv.states`).
        copy_index: for Markov chains, the index of the previous sample rv
            whose state should be copied; None when there is nothing to copy.
    """
    index: int
    rv: RandomVariable
    slots: Sequence[int]
    copy_index: Optional[int]
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
@dataclass
class SamplerInfo:
    """
    Sampler support: the standard information a sampler needs when it is driven by a Program.

    Fields:
        sample_rvs: one SampleRV per sampled random variable.
        condition: the conditioning indicators.
        yield_f: converts an array of state indexes into what the sampler yields.
        slots_0: program input slots to be set to zero.
        slots_1: program input slots to be set to one.
    """
    sample_rvs: Sequence[SampleRV]
    condition: Sequence[Indicator]
    yield_f: YieldF
    slots_0: Set[int]
    slots_1: Set[int]

    @property
    def rvs(self) -> Tuple[RandomVariable, ...]:
        """
        The RandomVariable objects of `self.sample_rvs`, in the same order.
        """
        collected = [entry.rv for entry in self.sample_rvs]
        return tuple(collected)
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def get_sampler_info(
        program_with_slotmap: ProgramWithSlotmap,
        rvs: Optional[RandomVariable | Sequence[RandomVariable]],
        condition: Condition,
        chain_pairs: Sequence[Tuple[RandomVariable, RandomVariable]] = (),
        initial_chain_condition: Condition = (),
) -> SamplerInfo:
    """
    Helper for samplers.

    Determines:
    (1) the slots for sampling rvs,
    (2) Markov chaining rvs,
    (3) the function to use for yielding an Instance or state index.

    If parameter `rvs` is a RandomVariable, then the yield function will
    provide a state index. If parameter `rvs` is a Sequence, then the
    yield function will provide an Instance. If `rvs` is None, then the
    random variables of `program_with_slotmap` are sampled.

    Args:
        program_with_slotmap: the program and slotmap being referenced.
        rvs: the random variables to sample. It may be either a sequence of
            random variables, or a single random variable.
        condition: is a collection of zero or more conditioning indicators.
        chain_pairs: is a collection of pairs of random variables, each random variable
            must be in the given rvs. Given a pair (from_rv, to_rv) the state of from_rv is used
            as a condition for to_rv prior to generating a sample.
        initial_chain_condition: are condition indicators (just like condition)
            for the initialisation of the 'to_rv' random variables mentioned in chain_pairs.

    Raises:
        ValueError: if preconditions of `program_with_slotmap` are incompatible with the given condition.
        ValueError: if the sample rvs are not available from the program or contain duplicates.
        ValueError: if `chain_pairs` or `initial_chain_condition` are inconsistent
            with the sample rvs or with the condition.

    Returns:
        a SamplerInfo structure.
    """
    if rvs is None:
        rvs = program_with_slotmap.rvs
    if isinstance(rvs, RandomVariable):
        # a single rv
        rvs = (rvs,)
        yield_f = lambda x: x.item()
    else:
        # a sequence of rvs
        rvs = tuple(rvs)
        yield_f = lambda x: x.tolist()

    # Group condition indicators by `rv_idx`.
    conditioned_rvs: MapSet[int, Indicator] = MapSet()
    for ind in check_condition(condition):
        conditioned_rvs.add(ind.rv_idx, ind)
    del condition

    # Group precondition indicators by `rv_idx`.
    preconditioned_rvs: MapSet[int, Indicator] = MapSet()
    for ind in program_with_slotmap.precondition:
        preconditioned_rvs.add(ind.rv_idx, ind)

    # Rationalise conditioned_rvs with preconditioned_rvs
    rv_idx: int
    precondition_set: Set[Indicator]
    for rv_idx, precondition_set in preconditioned_rvs.items():
        condition_set = conditioned_rvs.get(rv_idx)
        if condition_set is None:
            # A preconditioned rv was not mentioned in the explicit conditions
            conditioned_rvs.add_all(rv_idx, precondition_set)
        else:
            # A preconditioned rv was also mentioned in the explicit conditions
            condition_set.intersection_update(precondition_set)
            if not condition_set:
                # The conditioned rv need not be one of the sampled rvs, so it is
                # looked up among the program rvs; indexing only the sampled rvs
                # would raise KeyError here instead of the intended ValueError.
                rv_description = next(
                    (rv for rv in program_with_slotmap.rvs if rv.idx == rv_idx),
                    f'with index {rv_idx}',
                )
                raise ValueError(f'conditions on rv {rv_description} are disjoint from preconditions')
    del preconditioned_rvs

    # Group initial chain indicators by `rv_idx`.
    initial_chain_condition: Sequence[Indicator] = check_condition(initial_chain_condition)
    initial_chain_conditioned_rvs: MapSet[int, Indicator] = MapSet()
    for ind in initial_chain_condition:
        initial_chain_conditioned_rvs.add(ind.rv_idx, ind)

    # Check sample rvs are valid and without duplicates.
    rvs_set: Set[RandomVariable] = set(rvs)
    if not rvs_set.issubset(program_with_slotmap.rvs):
        raise ValueError('sample random variables not available')
    if len(rvs) != len(rvs_set):
        raise ValueError('duplicate sample random variables requested')

    # Check chain_pairs rvs are being sampled
    if not rvs_set.issuperset(pair[0] for pair in chain_pairs):
        raise ValueError('a random variable appears in chain_pairs but not in sample rvs')
    if not rvs_set.issuperset(pair[1] for pair in chain_pairs):
        raise ValueError('a random variable appears in chain_pairs but not in sample rvs')

    # Check chain_pairs source and destination rvs are disjoint
    if not {pair[0] for pair in chain_pairs}.isdisjoint(pair[1] for pair in chain_pairs):
        raise ValueError('chain_pairs sources and destinations are not disjoint')

    # Check no chain_pairs destination rv is a conditioned rv
    if any(pair[1].idx in conditioned_rvs.keys() for pair in chain_pairs):
        raise ValueError('a chain_pairs destination is conditioned')

    # Check chain initial conditions relate to chain_pairs destination rvs
    chain_dest_rv_idxs: Set[int] = {pair[1].idx for pair in chain_pairs}
    if not all(idx in chain_dest_rv_idxs for idx in initial_chain_conditioned_rvs.keys()):
        raise ValueError('a chain initial condition is not a chain destination rv')

    # Convert chain_pairs for registering with `sample_rvs`.
    # rv_position maps RandomVariable id to its position in rvs.
    # copy_idx maps RandomVariable id to a position in rvs that it can be copied from for Markov chaining.
    rv_position: Dict[int, int] = {id(rv): i for i, rv in enumerate(rvs)}
    copy_idx: Dict[int, int] = {id(rv): rv_position[id(prev_rv)] for prev_rv, rv in chain_pairs}

    # Get rv state slots, rvs_slots is co-indexed with rvs
    slot_map: SlotMap = program_with_slotmap.slot_map
    rvs_slots = tuple(tuple(slot_map[ind] for ind in rv) for rv in rvs)

    sample_rvs: Sequence[SampleRV] = tuple(
        SampleRV(idx, rv, rv_slots, copy_idx.get(id(rv)))
        for idx, (rv, rv_slots) in enumerate(zip(rvs, rvs_slots))
    )

    # Process the condition to get zero and one slots
    slots_0: Set[int] = set()
    slots_1: Set[int] = set()
    for rv in program_with_slotmap.rvs:
        # An explicit (or pre-) condition takes priority over an initial chain condition.
        conditioning: Optional[Set[Indicator]] = conditioned_rvs.get(rv.idx)
        if conditioning is None:
            conditioning = initial_chain_conditioned_rvs.get(rv.idx)

        if conditioning is None:
            # default: every state of the rv is possible
            slots_1.update(slot_map[ind] for ind in rv)
        else:
            slots_1.update(slot_map[ind] for ind in conditioning)
            slots_0.update(slot_map[ind] for ind in rv if ind not in conditioning)

    return SamplerInfo(
        sample_rvs=sample_rvs,
        condition=tuple(ind for condition_set in conditioned_rvs.values() for ind in condition_set),
        yield_f=yield_f,
        slots_0=slots_0,
        slots_1=slots_1,
    )
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
def uniform_random_sample(
        sample_rvs: Sequence[SampleRV],
        slots_0: Collection[int],
        slots_1: Collection[int],
        slots: NDArrayNumeric,
        state: NDArrayStates,
        rand: Random,
):
    """
    Helper for samplers.

    Overwrites `state` with a random instance and sets `slots` to agree with it.
    Each state is chosen uniformly, using `rand.randrange`, from the states
    whose slot is 1 once the conditioning slots have been applied.
    """

    # Apply the conditioning: zero slots first, then one slots.
    for slot_value, slot_group in ((0, slots_0), (1, slots_1)):
        for slot in slot_group:
            slots[slot] = slot_value

    for sample_rv in sample_rvs:
        # Collect the allowed (state, slot) pairs, clearing each allowed slot on the way.
        allowed = []
        for state_idx, state_slot in enumerate(sample_rv.slots):
            if slots[state_slot] != 1:
                continue
            slots[state_slot] = 0
            allowed.append((state_idx, state_slot))

        # Choose one of the allowed states for sample_rv and switch its slot back on.
        chosen = allowed[rand.randrange(0, len(allowed))]
        state[sample_rv.index] = chosen[0]
        slots[chosen[1]] = 1
|
|
@@ -0,0 +1,72 @@
|
|
|
1
|
+
import random
|
|
2
|
+
from typing import Set, List, Iterator, Optional, Sequence
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
|
|
6
|
+
from ck.pgm import Instance, RandomVariable, Indicator
|
|
7
|
+
from ck.probability.probability_space import dtype_for_state_indexes, Condition, check_condition
|
|
8
|
+
from ck.utils.map_set import MapSet
|
|
9
|
+
from ck.utils.np_extras import DType
|
|
10
|
+
from ck.utils.random_extras import Random
|
|
11
|
+
from .sampler import Sampler
|
|
12
|
+
from .sampler_support import YieldF
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class UniformSampler(Sampler):
    """
    A Sampler where the state of every random variable is chosen independently
    and uniformly from the states permitted by the condition.
    """

    def __init__(
            self,
            rvs: RandomVariable | Sequence[RandomVariable],
            condition: Condition = (),
            rand: Random = random,
    ):
        """
        Args:
            rvs: the random variables to sample. A single RandomVariable makes
                the sampler yield state indexes; a sequence makes it yield
                one list of state indexes per sample.
            condition: conditioning indicators, passed through `check_condition`.
            rand: the source of random numbers (`rand.randrange` is used).
        """
        condition: Sequence[Indicator] = check_condition(condition)

        self._yield_f: YieldF
        single_rv: bool = isinstance(rvs, RandomVariable)
        if single_rv:
            rvs = (rvs,)
            self._yield_f = lambda x: x.item()
        else:
            self._yield_f = lambda x: x.tolist()

        super().__init__(rvs, condition)

        # Permitted state indexes, keyed by `rv_idx`, for the conditioned rvs.
        conditioned_rvs: MapSet[int, int] = MapSet()
        for indicator in condition:
            conditioned_rvs.add(indicator.rv_idx, indicator.state_idx)

        def get_possible_states(_rv: RandomVariable) -> List[int]:
            """
            The state indexes `_rv` may take: those recorded in
            `conditioned_rvs`, or every state when `_rv` is not conditioned.
            """
            permitted: Optional[Set[int]] = conditioned_rvs.get(_rv.idx)
            if permitted is not None:
                return list(permitted)
            return list(range(len(_rv)))

        self._possible_states: List[List[int]] = [get_possible_states(rv) for rv in self.rvs]
        self._rand: Random = rand
        self._state_dtype: DType = dtype_for_state_indexes(self.rvs)

    def __iter__(self) -> Iterator[Instance] | Iterator[int]:
        """
        An unlimited series of uniformly drawn samples.
        """
        rand = self._rand
        yield_f = self._yield_f
        possible_states = self._possible_states
        state = np.zeros(len(possible_states), dtype=self._state_dtype)
        while True:
            for position, choices in enumerate(possible_states):
                picked = rand.randrange(0, len(choices))
                state[position] = choices[picked]
            # We know the yield function will always provide either ints or Instances
            # noinspection PyTypeChecker
            yield yield_f(state)
|
|
@@ -0,0 +1,171 @@
|
|
|
1
|
+
from typing import Collection, Iterator, Sequence
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
from ck.pgm import Instance
|
|
6
|
+
from ck.probability.probability_space import dtype_for_state_indexes
|
|
7
|
+
from ck.program.program_buffer import ProgramBuffer
|
|
8
|
+
from ck.program.raw_program import RawProgram
|
|
9
|
+
from ck.sampling.sampler import Sampler
|
|
10
|
+
from ck.sampling.sampler_support import SampleRV, YieldF, SamplerInfo
|
|
11
|
+
from ck.utils.np_extras import NDArrayNumeric, NDArrayStates
|
|
12
|
+
from ck.utils.random_extras import Random
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class WMCDirectSampler(Sampler):
|
|
16
|
+
|
|
17
|
+
def __init__(
        self,
        sampler_info: SamplerInfo,
        raw_program: RawProgram,
        rand: Random,
):
    """
    Args:
        sampler_info: the sample rvs, condition, yield function and the
            initial zero/one input slots.
        raw_program: the program to run; it is wrapped in a ProgramBuffer.
        rand: the source of random numbers.
    """
    super().__init__(sampler_info.rvs, sampler_info.condition)
    self._yield_f: YieldF = sampler_info.yield_f
    self._rand: Random = rand
    self._program_buffer = ProgramBuffer(raw_program)

    all_sample_rvs: Sequence[SampleRV] = tuple(sampler_info.sample_rvs)
    self._sample_rvs: Sequence[SampleRV] = all_sample_rvs
    # Chain rvs are the sample rvs that copy their state from a previous sample rv.
    self._chain_rvs: Sequence[SampleRV] = tuple(
        [candidate for candidate in sampler_info.sample_rvs if candidate.copy_index is not None]
    )

    self._state_dtype = dtype_for_state_indexes(self.rvs)
    self._max_number_of_states: int = max(map(len, self.rvs), default=0)
    self._slots_1: Collection[int] = sampler_info.slots_1

    # Initialise the input slots (zeros first, then ones) to respect
    # conditioning and initial Markov chain states.
    slots: NDArrayNumeric = self._program_buffer.vars
    for slot_value, slot_group in ((0, sampler_info.slots_0), (1, sampler_info.slots_1)):
        for slot in slot_group:
            slots[slot] = slot_value
|
|
40
|
+
|
|
41
|
+
def __iter__(self) -> Iterator[Instance] | Iterator[int]:
|
|
42
|
+
yield_f = self._yield_f
|
|
43
|
+
rand = self._rand
|
|
44
|
+
sample_rvs = self._sample_rvs
|
|
45
|
+
chain_rvs = self._chain_rvs
|
|
46
|
+
slots_1 = self._slots_1
|
|
47
|
+
program_buffer = self._program_buffer
|
|
48
|
+
slots: NDArrayNumeric = program_buffer.vars
|
|
49
|
+
|
|
50
|
+
# Calling wmc() will give the weighted model count for the state of the current input slots.
|
|
51
|
+
def wmc() -> float:
|
|
52
|
+
return program_buffer.compute().item()
|
|
53
|
+
|
|
54
|
+
# Set up working memory buffers
|
|
55
|
+
states: NDArrayStates = np.zeros(len(sample_rvs), dtype=self._state_dtype)
|
|
56
|
+
buff_slots = np.zeros(self._max_number_of_states, dtype=np.uintp)
|
|
57
|
+
buff_states = np.zeros(self._max_number_of_states, dtype=self._state_dtype)
|
|
58
|
+
|
|
59
|
+
while True:
|
|
60
|
+
# Consider all possible instantiations given the conditions, c, where the instantiations are ordered.
|
|
61
|
+
# Let awmc(i|c) be the accumulated WMC of the ith instantiation.
|
|
62
|
+
# We want to find the smallest instantiation i such that
|
|
63
|
+
# rnd <= awmc(i|c)
|
|
64
|
+
# where rnd is in [0, 1) * wmc().
|
|
65
|
+
|
|
66
|
+
rnd: float = rand.random() * wmc()
|
|
67
|
+
|
|
68
|
+
for sample_rv in sample_rvs:
|
|
69
|
+
# Prepare to loop over random variable states.
|
|
70
|
+
# Keep track of the non-zero slots in buff_slots and buff_states.
|
|
71
|
+
num_possible_states: int = 0
|
|
72
|
+
for j, slot in enumerate(sample_rv.slots):
|
|
73
|
+
if slots[slot] != 0:
|
|
74
|
+
buff_slots[num_possible_states] = slot
|
|
75
|
+
buff_states[num_possible_states] = j
|
|
76
|
+
num_possible_states += 1
|
|
77
|
+
|
|
78
|
+
if num_possible_states == 0:
|
|
79
|
+
raise RuntimeError('zero probability')
|
|
80
|
+
|
|
81
|
+
# Try each possible state of the current random variable.
|
|
82
|
+
# Once a state is selected, then the following is true:
|
|
83
|
+
# states[rv_position] = state
|
|
84
|
+
# m_prev_states[rv_position] = state
|
|
85
|
+
# slots set up to include condition rv = state.
|
|
86
|
+
# rnd is reduced to account for the states skipped.
|
|
87
|
+
#
|
|
88
|
+
# We can do this either by sequentially checking each state or by doing
|
|
89
|
+
# a binary search. Here we start with binary search then finish sequentially
|
|
90
|
+
# once the candidates size falls below 'THRESHOLD'.
|
|
91
|
+
|
|
92
|
+
# Binary search
|
|
93
|
+
THRESHOLD = 2
|
|
94
|
+
lo: int = 0
|
|
95
|
+
hi: int = num_possible_states
|
|
96
|
+
w_0_mark: int = 0
|
|
97
|
+
w: float = 0
|
|
98
|
+
while lo + THRESHOLD < hi:
|
|
99
|
+
mid: int = (lo + hi) // 2
|
|
100
|
+
|
|
101
|
+
for i in range(mid, hi):
|
|
102
|
+
slots[buff_slots[i]] = 0
|
|
103
|
+
|
|
104
|
+
w = wmc()
|
|
105
|
+
w_0_mark = mid
|
|
106
|
+
if w < rnd:
|
|
107
|
+
# wmc() is too low, the desired state is >= buff_states[mid]
|
|
108
|
+
for i in range(mid, hi):
|
|
109
|
+
slots[buff_slots[i]] = 1
|
|
110
|
+
lo = mid
|
|
111
|
+
else:
|
|
112
|
+
# wmc() is too high, the desired state is < buff_states[mid]
|
|
113
|
+
hi = mid
|
|
114
|
+
|
|
115
|
+
# Now the state we want is between lo (inclusive) and hi (exclusive).
|
|
116
|
+
# Slots at least up to lo will be set to 1.
|
|
117
|
+
|
|
118
|
+
# clear top slots, lo and up.
|
|
119
|
+
for k in range(lo, num_possible_states):
|
|
120
|
+
slots[buff_slots[k]] = 0
|
|
121
|
+
|
|
122
|
+
# Adjust rnd to account for lo > 0.
|
|
123
|
+
if lo == 0:
|
|
124
|
+
# The chances of this case may be low, but if so, then
|
|
125
|
+
# slots[m_buff_slots[lo]] = 0 which implies wmc() == 0,
|
|
126
|
+
# so we can save a call to wmc().
|
|
127
|
+
pass
|
|
128
|
+
elif w_0_mark == lo:
|
|
129
|
+
# We can use the last wmc() call, stored in w.
|
|
130
|
+
# This saves a call to wmc().
|
|
131
|
+
rnd -= w
|
|
132
|
+
else:
|
|
133
|
+
rnd -= wmc()
|
|
134
|
+
|
|
135
|
+
# Clear remaining slots
|
|
136
|
+
for k in range(0, lo):
|
|
137
|
+
slots[buff_slots[k]] = 0
|
|
138
|
+
|
|
139
|
+
# Sequential search
|
|
140
|
+
k = lo
|
|
141
|
+
while k < hi:
|
|
142
|
+
slot = buff_slots[k]
|
|
143
|
+
slots[slot] = 1
|
|
144
|
+
w = wmc()
|
|
145
|
+
if rnd < w:
|
|
146
|
+
break
|
|
147
|
+
slots[slot] = 0
|
|
148
|
+
rnd -= w
|
|
149
|
+
k += 1
|
|
150
|
+
|
|
151
|
+
slot = buff_slots[k]
|
|
152
|
+
state = buff_states[k]
|
|
153
|
+
slots[slot] = 1
|
|
154
|
+
states[sample_rv.index] = state
|
|
155
|
+
|
|
156
|
+
# We know the yield function will always provide either ints or Instances
|
|
157
|
+
# noinspection PyTypeChecker
|
|
158
|
+
yield yield_f(states)
|
|
159
|
+
|
|
160
|
+
# Reset the one slots for the next iteration.
|
|
161
|
+
for slot in slots_1:
|
|
162
|
+
slots[slot] = 1
|
|
163
|
+
|
|
164
|
+
# Copy chain pairs for next iteration.
|
|
165
|
+
# (This writes over any initial chain conditions from slots_1.)
|
|
166
|
+
for sample_rv in chain_rvs:
|
|
167
|
+
rv_slots = sample_rv.slots
|
|
168
|
+
prev_state_idx: int = states.item(sample_rv.copy_index)
|
|
169
|
+
for slot in rv_slots:
|
|
170
|
+
slots[slot] = 0
|
|
171
|
+
slots[rv_slots[prev_state_idx]] = 1
|