compiled-knowledge 4.0.0a5__cp313-cp313-macosx_10_13_universal2.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of compiled-knowledge might be problematic. Click here for more details.
- ck/__init__.py +0 -0
- ck/circuit/__init__.py +13 -0
- ck/circuit/circuit.c +38749 -0
- ck/circuit/circuit.cpython-313-darwin.so +0 -0
- ck/circuit/circuit_py.py +807 -0
- ck/circuit/tmp_const.py +74 -0
- ck/circuit_compiler/__init__.py +2 -0
- ck/circuit_compiler/circuit_compiler.py +26 -0
- ck/circuit_compiler/cython_vm_compiler/__init__.py +1 -0
- ck/circuit_compiler/cython_vm_compiler/_compiler.c +17373 -0
- ck/circuit_compiler/cython_vm_compiler/_compiler.cpython-313-darwin.so +0 -0
- ck/circuit_compiler/cython_vm_compiler/cython_vm_compiler.py +96 -0
- ck/circuit_compiler/interpret_compiler.py +223 -0
- ck/circuit_compiler/llvm_compiler.py +388 -0
- ck/circuit_compiler/llvm_vm_compiler.py +546 -0
- ck/circuit_compiler/named_circuit_compilers.py +57 -0
- ck/circuit_compiler/support/__init__.py +0 -0
- ck/circuit_compiler/support/circuit_analyser.py +81 -0
- ck/circuit_compiler/support/input_vars.py +148 -0
- ck/circuit_compiler/support/llvm_ir_function.py +234 -0
- ck/example/__init__.py +53 -0
- ck/example/alarm.py +366 -0
- ck/example/asia.py +28 -0
- ck/example/binary_clique.py +32 -0
- ck/example/bow_tie.py +33 -0
- ck/example/cancer.py +37 -0
- ck/example/chain.py +38 -0
- ck/example/child.py +199 -0
- ck/example/clique.py +33 -0
- ck/example/cnf_pgm.py +39 -0
- ck/example/diamond_square.py +68 -0
- ck/example/earthquake.py +36 -0
- ck/example/empty.py +10 -0
- ck/example/hailfinder.py +539 -0
- ck/example/hepar2.py +628 -0
- ck/example/insurance.py +504 -0
- ck/example/loop.py +40 -0
- ck/example/mildew.py +38161 -0
- ck/example/munin.py +22982 -0
- ck/example/pathfinder.py +53674 -0
- ck/example/rain.py +39 -0
- ck/example/rectangle.py +161 -0
- ck/example/run.py +30 -0
- ck/example/sachs.py +129 -0
- ck/example/sprinkler.py +30 -0
- ck/example/star.py +44 -0
- ck/example/stress.py +64 -0
- ck/example/student.py +43 -0
- ck/example/survey.py +46 -0
- ck/example/triangle_square.py +54 -0
- ck/example/truss.py +49 -0
- ck/in_out/__init__.py +3 -0
- ck/in_out/parse_ace_lmap.py +216 -0
- ck/in_out/parse_ace_nnf.py +288 -0
- ck/in_out/parse_net.py +480 -0
- ck/in_out/parser_utils.py +185 -0
- ck/in_out/pgm_pickle.py +42 -0
- ck/in_out/pgm_python.py +268 -0
- ck/in_out/render_bugs.py +111 -0
- ck/in_out/render_net.py +177 -0
- ck/in_out/render_pomegranate.py +184 -0
- ck/pgm.py +3494 -0
- ck/pgm_circuit/__init__.py +1 -0
- ck/pgm_circuit/marginals_program.py +352 -0
- ck/pgm_circuit/mpe_program.py +237 -0
- ck/pgm_circuit/pgm_circuit.py +75 -0
- ck/pgm_circuit/program_with_slotmap.py +234 -0
- ck/pgm_circuit/slot_map.py +35 -0
- ck/pgm_circuit/support/__init__.py +0 -0
- ck/pgm_circuit/support/compile_circuit.py +83 -0
- ck/pgm_circuit/target_marginals_program.py +103 -0
- ck/pgm_circuit/wmc_program.py +323 -0
- ck/pgm_compiler/__init__.py +2 -0
- ck/pgm_compiler/ace/__init__.py +1 -0
- ck/pgm_compiler/ace/ace.py +252 -0
- ck/pgm_compiler/factor_elimination.py +383 -0
- ck/pgm_compiler/named_pgm_compilers.py +63 -0
- ck/pgm_compiler/pgm_compiler.py +19 -0
- ck/pgm_compiler/recursive_conditioning.py +226 -0
- ck/pgm_compiler/support/__init__.py +0 -0
- ck/pgm_compiler/support/circuit_table/__init__.py +9 -0
- ck/pgm_compiler/support/circuit_table/circuit_table.c +16042 -0
- ck/pgm_compiler/support/circuit_table/circuit_table.cpython-313-darwin.so +0 -0
- ck/pgm_compiler/support/circuit_table/circuit_table_py.py +269 -0
- ck/pgm_compiler/support/clusters.py +556 -0
- ck/pgm_compiler/support/factor_tables.py +398 -0
- ck/pgm_compiler/support/join_tree.py +275 -0
- ck/pgm_compiler/support/named_compiler_maker.py +33 -0
- ck/pgm_compiler/variable_elimination.py +89 -0
- ck/probability/__init__.py +0 -0
- ck/probability/empirical_probability_space.py +47 -0
- ck/probability/probability_space.py +568 -0
- ck/program/__init__.py +3 -0
- ck/program/program.py +129 -0
- ck/program/program_buffer.py +180 -0
- ck/program/raw_program.py +61 -0
- ck/sampling/__init__.py +0 -0
- ck/sampling/forward_sampler.py +211 -0
- ck/sampling/marginals_direct_sampler.py +113 -0
- ck/sampling/sampler.py +62 -0
- ck/sampling/sampler_support.py +232 -0
- ck/sampling/uniform_sampler.py +66 -0
- ck/sampling/wmc_direct_sampler.py +169 -0
- ck/sampling/wmc_gibbs_sampler.py +147 -0
- ck/sampling/wmc_metropolis_sampler.py +159 -0
- ck/sampling/wmc_rejection_sampler.py +113 -0
- ck/utils/__init__.py +0 -0
- ck/utils/iter_extras.py +153 -0
- ck/utils/map_list.py +128 -0
- ck/utils/map_set.py +128 -0
- ck/utils/np_extras.py +51 -0
- ck/utils/random_extras.py +64 -0
- ck/utils/tmp_dir.py +94 -0
- ck_demos/__init__.py +0 -0
- ck_demos/ace/__init__.py +0 -0
- ck_demos/ace/copy_ace_to_ck.py +15 -0
- ck_demos/ace/demo_ace.py +44 -0
- ck_demos/all_demos.py +88 -0
- ck_demos/circuit/__init__.py +0 -0
- ck_demos/circuit/demo_circuit_dump.py +22 -0
- ck_demos/circuit/demo_derivatives.py +43 -0
- ck_demos/circuit_compiler/__init__.py +0 -0
- ck_demos/circuit_compiler/compare_circuit_compilers.py +32 -0
- ck_demos/circuit_compiler/show_llvm_program.py +26 -0
- ck_demos/pgm/__init__.py +0 -0
- ck_demos/pgm/demo_pgm_dump.py +18 -0
- ck_demos/pgm/demo_pgm_dump_stress.py +18 -0
- ck_demos/pgm/demo_pgm_string_rendering.py +15 -0
- ck_demos/pgm/show_examples.py +25 -0
- ck_demos/pgm_compiler/__init__.py +0 -0
- ck_demos/pgm_compiler/compare_pgm_compilers.py +50 -0
- ck_demos/pgm_compiler/demo_compiler_dump.py +50 -0
- ck_demos/pgm_compiler/demo_factor_elimination.py +47 -0
- ck_demos/pgm_compiler/demo_join_tree.py +25 -0
- ck_demos/pgm_compiler/demo_marginals_program.py +53 -0
- ck_demos/pgm_compiler/demo_mpe_program.py +55 -0
- ck_demos/pgm_compiler/demo_pgm_compiler.py +38 -0
- ck_demos/pgm_compiler/demo_recursive_conditioning.py +33 -0
- ck_demos/pgm_compiler/demo_variable_elimination.py +33 -0
- ck_demos/pgm_compiler/demo_wmc_program.py +29 -0
- ck_demos/pgm_inference/__init__.py +0 -0
- ck_demos/pgm_inference/demo_inferencing_basic.py +188 -0
- ck_demos/pgm_inference/demo_inferencing_mpe_cancer.py +45 -0
- ck_demos/pgm_inference/demo_inferencing_wmc_and_mpe_sprinkler.py +154 -0
- ck_demos/pgm_inference/demo_inferencing_wmc_student.py +110 -0
- ck_demos/programs/__init__.py +0 -0
- ck_demos/programs/demo_program_buffer.py +24 -0
- ck_demos/programs/demo_program_multi.py +24 -0
- ck_demos/programs/demo_program_none.py +19 -0
- ck_demos/programs/demo_program_single.py +23 -0
- ck_demos/programs/demo_raw_program_interpreted.py +21 -0
- ck_demos/programs/demo_raw_program_llvm.py +21 -0
- ck_demos/sampling/__init__.py +0 -0
- ck_demos/sampling/check_sampler.py +71 -0
- ck_demos/sampling/demo_marginal_direct_sampler.py +40 -0
- ck_demos/sampling/demo_uniform_sampler.py +38 -0
- ck_demos/sampling/demo_wmc_direct_sampler.py +40 -0
- ck_demos/utils/__init__.py +0 -0
- ck_demos/utils/compare.py +88 -0
- ck_demos/utils/convert_network.py +45 -0
- ck_demos/utils/sample_model.py +216 -0
- ck_demos/utils/stop_watch.py +384 -0
- compiled_knowledge-4.0.0a5.dist-info/METADATA +50 -0
- compiled_knowledge-4.0.0a5.dist-info/RECORD +167 -0
- compiled_knowledge-4.0.0a5.dist-info/WHEEL +5 -0
- compiled_knowledge-4.0.0a5.dist-info/licenses/LICENSE.txt +21 -0
- compiled_knowledge-4.0.0a5.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,232 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
from itertools import count
|
|
3
|
+
from typing import Callable, Sequence, Optional, Set, Tuple, Dict, Collection
|
|
4
|
+
|
|
5
|
+
from ck.pgm import Instance, RandomVariable, Indicator
|
|
6
|
+
from ck.pgm_circuit.program_with_slotmap import ProgramWithSlotmap
|
|
7
|
+
from ck.pgm_circuit.slot_map import SlotMap
|
|
8
|
+
from ck.probability.probability_space import Condition, check_condition
|
|
9
|
+
from ck.utils.map_set import MapSet
|
|
10
|
+
from ck.utils.np_extras import NDArrayStates, NDArrayNumeric
|
|
11
|
+
from ck.utils.random_extras import Random
|
|
12
|
+
|
|
13
|
+
# YieldF is the type of a sampler's yield function (sampler support).
# It turns the sampler's working array of state indices into the value the
# sampler's iterator produces: a single state index (int) or an Instance.
YieldF = Callable[[NDArrayStates], int | Instance]
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@dataclass
class SampleRV:
    """
    Sampler support: bookkeeping for one sampled random variable.

    Records where the variable sits among the sample rvs, which program
    input slots carry its indicators, and, for Markov chaining, which
    previous sample rv its state is copied from.
    """
    # Position of this entry within the sequence of sample rvs.
    index: int
    # The random variable that is sampled.
    rv: RandomVariable
    # Program input slots of the rv's indicators, co-indexed with rv.states.
    slots: Sequence[int]
    # Markov chains: index of the previous sample rv to copy from, else None.
    copy_index: Optional[int]
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
@dataclass
class SamplerInfo:
    """
    Sampler support: the standard information a sampler needs when it
    drives a Program.

    Bundles the sample rvs, the effective conditioning indicators, the
    yield function, and the program input slots that the sampler holds
    at 0 (`slots_0`) or 1 (`slots_1`).
    """
    sample_rvs: Sequence[SampleRV]
    condition: Sequence[Indicator]
    yield_f: YieldF
    slots_0: Set[int]
    slots_1: Set[int]

    @property
    def rvs(self) -> Tuple[RandomVariable, ...]:
        """
        The RandomVariable objects of `self.sample_rvs`, in the same order.
        """
        return tuple([entry.rv for entry in self.sample_rvs])
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def get_sampler_info(
        program_with_slotmap: ProgramWithSlotmap,
        rvs: Optional[RandomVariable | Sequence[RandomVariable]],
        condition: Condition,
        chain_pairs: Sequence[Tuple[RandomVariable, RandomVariable]] = (),
        initial_chain_condition: Condition = (),
) -> SamplerInfo:
    """
    Helper for samplers.

    Determines:
    (1) the slots for sampling rvs,
    (2) Markov chaining rvs,
    (3) the function to use for yielding an Instance or state index.

    If parameter `rvs` is a RandomVariable, then the yield function will
    provide a state index. If parameter `rvs` is a Sequence, then the
    yield function will provide an Instance.

    Args:
        program_with_slotmap: the program and slotmap being referenced.
        rvs: the random variables to sample. It may be either a sequence of
            random variables, or a single random variable. If None, all the
            random variables of `program_with_slotmap` are sampled.
        condition: a collection of zero or more conditioning indicators.
        chain_pairs: a collection of pairs of random variables, each random variable
            must be in the given rvs. Given a pair (from_rv, to_rv) the state of from_rv is used
            as a condition for to_rv prior to generating a sample.
        initial_chain_condition: condition indicators (just like condition)
            for the initialisation of the 'to_rv' random variables mentioned in chain_pairs.

    Raises:
        ValueError: if the conditions on a random variable are disjoint from the
            preconditions of `program_with_slotmap`; if a sample random variable is
            not available in the program or is requested more than once; or if
            `chain_pairs` / `initial_chain_condition` are inconsistent with the
            sample random variables or the condition (see the checks below).

    Returns:
        a SamplerInfo structure.
    """
    if rvs is None:
        rvs = program_with_slotmap.rvs
    yield_f: YieldF
    if isinstance(rvs, RandomVariable):
        # A single rv: the yield function provides its state index.
        rvs = (rvs,)
        yield_f = lambda x: x.item()
    else:
        # A sequence of rvs: the yield function provides a list of state indexes.
        rvs = tuple(rvs)
        yield_f = lambda x: x.tolist()

    # chain_pairs is traversed several times below; materialise it once so that
    # a one-shot iterator is not silently exhausted by the first check.
    chain_pairs = tuple(chain_pairs)

    # Group condition indicators by `rv_idx`.
    conditioned_rvs: MapSet[int, Indicator] = MapSet()
    for ind in check_condition(condition):
        conditioned_rvs.add(ind.rv_idx, ind)
    del condition

    # Group precondition indicators by `rv_idx`.
    preconditioned_rvs: MapSet[int, Indicator] = MapSet()
    for ind in program_with_slotmap.precondition:
        preconditioned_rvs.add(ind.rv_idx, ind)

    # Rationalise conditioned_rvs with preconditioned_rvs.
    rv_idx: int
    precondition_set: Set[Indicator]
    for rv_idx, precondition_set in preconditioned_rvs.items():
        condition_set = conditioned_rvs.get(rv_idx)
        if condition_set is None:
            # A preconditioned rv was not mentioned in the explicit conditions.
            conditioned_rvs.add_all(rv_idx, precondition_set)
        else:
            # A preconditioned rv was also mentioned in the explicit conditions:
            # only states allowed by both remain.
            condition_set.intersection_update(precondition_set)
            if len(condition_set) == 0:
                # Look the rv up among all program rvs; a conditioned rv is
                # usually not one of the sample rvs.
                bad_rv = next(
                    (prog_rv for prog_rv in program_with_slotmap.rvs if prog_rv.idx == rv_idx),
                    rv_idx,
                )
                raise ValueError(f'conditions on rv {bad_rv} are disjoint from preconditions')
    del preconditioned_rvs

    # Group initial chain indicators by `rv_idx`.
    initial_chain_conditioned_rvs: MapSet[int, Indicator] = MapSet()
    for ind in check_condition(initial_chain_condition):
        initial_chain_conditioned_rvs.add(ind.rv_idx, ind)
    del initial_chain_condition

    # Check sample rvs are valid and without duplicates.
    rvs_set: Set[RandomVariable] = set(rvs)
    if not rvs_set.issubset(program_with_slotmap.rvs):
        raise ValueError('sample random variables not available')
    if len(rvs) != len(rvs_set):
        raise ValueError('duplicate sample random variables requested')

    chain_sources: Tuple[RandomVariable, ...] = tuple(pair[0] for pair in chain_pairs)
    chain_destinations: Tuple[RandomVariable, ...] = tuple(pair[1] for pair in chain_pairs)

    # Check chain_pairs rvs are being sampled.
    if not rvs_set.issuperset(chain_sources):
        raise ValueError('a random variable appears in chain_pairs but not in sample rvs')
    if not rvs_set.issuperset(chain_destinations):
        raise ValueError('a random variable appears in chain_pairs but not in sample rvs')

    # Check chain_pairs source and destination rvs are disjoint.
    if not set(chain_sources).isdisjoint(chain_destinations):
        raise ValueError('chain_pairs sources and destinations are not disjoint')

    # Check no chain_pairs destination rv is a conditioned rv.
    if any(dest.idx in conditioned_rvs.keys() for dest in chain_destinations):
        raise ValueError('a chain_pairs destination is conditioned')

    # Check chain initial conditions relate to chain_pairs destination rvs.
    chain_dest_rv_idxs: Set[int] = {dest.idx for dest in chain_destinations}
    if not all(idx in chain_dest_rv_idxs for idx in initial_chain_conditioned_rvs.keys()):
        raise ValueError('a chain initial condition is not a chain destination rv')

    # Convert chain_pairs for registering with `sample_rvs`.
    # rv_position maps a RandomVariable's id to its position in rvs.
    # copy_idx maps a destination rv's id to the position in rvs it is copied from.
    rv_position: Dict[int, int] = {id(rv): i for i, rv in enumerate(rvs)}
    copy_idx: Dict[int, int] = {id(rv): rv_position[id(prev_rv)] for prev_rv, rv in chain_pairs}

    # Get rv state slots; rvs_slots is co-indexed with rvs.
    slot_map: SlotMap = program_with_slotmap.slot_map
    rvs_slots = tuple(tuple(slot_map[ind] for ind in rv) for rv in rvs)

    sample_rvs: Sequence[SampleRV] = tuple(
        SampleRV(idx, rv, rv_slots, copy_idx.get(id(rv)))
        for idx, rv, rv_slots in zip(count(), rvs, rvs_slots)
    )

    # Process the conditions to get the slots held at zero and at one.
    # Explicit (and pre-) conditions take priority over initial chain conditions;
    # an rv with neither has all its slots set to one.
    slots_0: Set[int] = set()
    slots_1: Set[int] = set()
    for rv in program_with_slotmap.rvs:
        conditioning: Optional[Set[Indicator]] = conditioned_rvs.get(rv.idx)
        if conditioning is None:
            conditioning = initial_chain_conditioned_rvs.get(rv.idx)
        if conditioning is None:
            slots_1.update(slot_map[ind] for ind in rv)
        else:
            slots_1.update(slot_map[ind] for ind in conditioning)
            slots_0.update(slot_map[ind] for ind in rv if ind not in conditioning)

    return SamplerInfo(
        sample_rvs=sample_rvs,
        condition=tuple(ind for condition_set in conditioned_rvs.values() for ind in condition_set),
        yield_f=yield_f,
        slots_0=slots_0,
        slots_1=slots_1,
    )
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
def uniform_random_sample(
        sample_rvs: Sequence[SampleRV],
        slots_0: Collection[int],
        slots_1: Collection[int],
        slots: NDArrayNumeric,
        state: NDArrayStates,
        rand: Random,
):
    """
    Helper for samplers.

    First resets the program input slots to respect conditioning: every
    slot in `slots_0` becomes 0 and every slot in `slots_1` becomes 1.
    Then, for each sample rv, one of its enabled states (slot value 1) is
    chosen uniformly at random using `rand.randrange`. The choice is stored
    in `state[sample_rv.index]`, and only that state's slot is left enabled.
    """
    for zero_slot in slots_0:
        slots[zero_slot] = 0
    for one_slot in slots_1:
        slots[one_slot] = 1

    for sample_rv in sample_rvs:
        # Disable every enabled slot of this rv, remembering (state, slot) pairs.
        enabled = []
        for state_idx, rv_slot in enumerate(sample_rv.slots):
            if slots[rv_slot] != 1:
                continue
            slots[rv_slot] = 0
            enabled.append((state_idx, rv_slot))

        # Re-enable exactly one of them, picked uniformly.
        chosen_state, chosen_slot = enabled[rand.randrange(0, len(enabled))]
        state[sample_rv.index] = chosen_state
        slots[chosen_slot] = 1
|
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
from typing import Set, List, Iterator, Optional, Sequence
|
|
2
|
+
import random
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
|
|
6
|
+
from ck.pgm import Instance, RandomVariable, Indicator
|
|
7
|
+
from ck.probability.probability_space import dtype_for_state_indexes, Condition, check_condition
|
|
8
|
+
from .sampler import Sampler
|
|
9
|
+
from .sampler_support import YieldF
|
|
10
|
+
from ck.utils.map_set import MapSet
|
|
11
|
+
from ck.utils.np_extras import DType
|
|
12
|
+
from ck.utils.random_extras import Random
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class UniformSampler(Sampler):
    """
    A sampler that draws the state of each random variable independently and
    uniformly at random, restricted to the states permitted by the condition.
    """

    def __init__(
            self,
            rvs: RandomVariable | Sequence[RandomVariable],
            condition: Condition = (),
            rand: Random = random,
    ):
        """
        Args:
            rvs: the random variables to sample. If a single RandomVariable, the
                iterator yields state indexes; if a sequence, it yields Instances.
            condition: zero or more conditioning indicators. A random variable with
                conditioning indicators is only sampled from the indicated states.
            rand: the source of random numbers (default: the `random` module).
        """
        condition: Sequence[Indicator] = check_condition(condition)

        self._yield_f: YieldF
        if isinstance(rvs, RandomVariable):
            # a single rv
            rvs = (rvs,)
            self._yield_f = lambda x: x.item()
        else:
            # a sequence of rvs
            self._yield_f = lambda x: x.tolist()

        super().__init__(rvs, condition)

        # Group condition state indexes by `rv_idx`.
        conditioned_rvs: MapSet[int, int] = MapSet()
        for ind in condition:
            conditioned_rvs.add(ind.rv_idx, ind.state_idx)

        def get_possible_states(_rv: RandomVariable) -> List[int]:
            # The states a random variable may take: all of them when it is
            # unconditioned, otherwise only those named by its condition indicators.
            condition_states: Optional[Set[int]] = conditioned_rvs.get(_rv.idx)
            if condition_states is None:
                return list(range(len(_rv)))
            return [state_idx for state_idx in range(len(_rv)) if state_idx in condition_states]

        # possible_states[i] lists the permitted state indexes of self.rvs[i].
        self._possible_states: List[List[int]] = [get_possible_states(rv) for rv in self.rvs]
        self._rand: Random = rand
        self._state_dtype: DType = dtype_for_state_indexes(self.rvs)

    def __iter__(self) -> Iterator[Instance] | Iterator[int]:
        """
        Generate samples indefinitely.
        """
        possible_states = self._possible_states
        yield_f = self._yield_f
        rand = self._rand
        state = np.zeros(len(possible_states), dtype=self._state_dtype)
        while True:
            for i, rv_states in enumerate(possible_states):
                state_idx = rand.randrange(0, len(rv_states))
                state[i] = rv_states[state_idx]
            yield yield_f(state)
|
|
@@ -0,0 +1,169 @@
|
|
|
1
|
+
from typing import Collection, Iterator, Sequence
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
from ck.pgm import Instance
|
|
6
|
+
from ck.probability.probability_space import dtype_for_state_indexes
|
|
7
|
+
from ck.program.program_buffer import ProgramBuffer
|
|
8
|
+
from ck.program.raw_program import RawProgram
|
|
9
|
+
from ck.sampling.sampler import Sampler
|
|
10
|
+
from ck.sampling.sampler_support import SampleRV, YieldF, SamplerInfo
|
|
11
|
+
from ck.utils.np_extras import NDArrayNumeric
|
|
12
|
+
from ck.utils.random_extras import Random
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class WMCDirectSampler(Sampler):
    """
    A sampler that draws independent samples directly from the distribution
    given by a weighted model count (WMC) program.

    For each sample, a random number `rnd` is drawn uniformly from
    [0, total WMC). The sample rvs are then fixed one at a time. For the
    current rv, the state is located whose accumulated WMC first exceeds `rnd`:
    a binary search narrows the candidate range, then a sequential scan
    finishes. That state is fixed in the program inputs, and `rnd` is reduced
    by the WMC of the states skipped.
    """

    def __init__(
            self,
            sampler_info: SamplerInfo,
            raw_program: RawProgram,
            rand: Random,
    ):
        """
        Args:
            sampler_info: the sample rvs, condition, yield function and the input
                slots to hold at 0 / 1 (e.g., as produced by `get_sampler_info`).
            raw_program: a program whose single result is the WMC of its input slots.
            rand: the source of random numbers.
        """
        super().__init__(sampler_info.rvs, sampler_info.condition)
        self._yield_f: YieldF = sampler_info.yield_f
        self._rand: Random = rand
        self._program_buffer = ProgramBuffer(raw_program)
        self._sample_rvs: Sequence[SampleRV] = tuple(sampler_info.sample_rvs)
        self._chain_rvs: Sequence[SampleRV] = tuple(
            sample_rv for sample_rv in sampler_info.sample_rvs if sample_rv.copy_index is not None)
        self._state_dtype = dtype_for_state_indexes(self.rvs)
        self._max_number_of_states: int = max((len(rv) for rv in self.rvs), default=0)
        self._slots_1: Collection[int] = sampler_info.slots_1

        # Set up the input slots to 0 or 1 to respect conditioning and initial Markov chain states.
        slots: NDArrayNumeric = self._program_buffer.vars
        for slot in sampler_info.slots_0:
            slots[slot] = 0
        for slot in sampler_info.slots_1:
            slots[slot] = 1

    def __iter__(self) -> Iterator[Instance] | Iterator[int]:
        """
        Generate samples indefinitely.

        Raises:
            RuntimeError: 'zero probability' if the current conditions admit no
                sample (a sample rv has no enabled state, or the total WMC is not positive).
        """
        yield_f = self._yield_f
        rand = self._rand
        sample_rvs = self._sample_rvs
        chain_rvs = self._chain_rvs
        slots_1 = self._slots_1
        program_buffer = self._program_buffer
        slots: NDArrayNumeric = program_buffer.vars

        # Calling wmc() gives the weighted model count for the current state of the input slots.
        def wmc() -> float:
            return program_buffer.compute().item()

        # Working memory buffers (reused for every random variable).
        states = np.zeros(len(sample_rvs), dtype=self._state_dtype)
        buff_slots = np.zeros(self._max_number_of_states, dtype=np.uintp)
        buff_states = np.zeros(self._max_number_of_states, dtype=self._state_dtype)

        # Binary search hands over to sequential search once the candidate
        # range is no larger than this.
        THRESHOLD = 2

        while True:
            # Consider all possible instantiations given the conditions, c, in order.
            # Let awmc(i|c) be the accumulated WMC of instantiations 0..i.
            # We want the smallest instantiation i such that
            #     rnd < awmc(i|c)
            # where rnd is drawn uniformly from [0, 1) * wmc().
            total: float = wmc()
            if total <= 0:
                raise RuntimeError('zero probability')
            rnd: float = rand.random() * total

            for sample_rv in sample_rvs:
                # Collect the enabled (non-zero) slots of this rv, and their states,
                # into buff_slots[0:num_possible_states] and buff_states[0:num_possible_states].
                num_possible_states: int = 0
                for j, slot in enumerate(sample_rv.slots):
                    if slots[slot] != 0:
                        buff_slots[num_possible_states] = slot
                        buff_states[num_possible_states] = j
                        num_possible_states += 1

                if num_possible_states == 0:
                    raise RuntimeError('zero probability')

                # Once a state is selected, the following is true:
                #   states[sample_rv.index] = state
                #   slots are set up to include the condition rv = state
                #   rnd is reduced to account for the states skipped.

                # Binary search. Invariant: candidates [0, hi) are enabled,
                # candidates [hi, num_possible_states) are disabled, and the
                # wanted candidate is in [lo, hi).
                lo: int = 0
                hi: int = num_possible_states
                w_0_mark: int = 0  # the value of `mid` when `w` was last computed
                w: float = 0
                while lo + THRESHOLD < hi:
                    mid: int = (lo + hi) // 2

                    for i in range(mid, hi):
                        slots[buff_slots[i]] = 0

                    w = wmc()  # accumulated WMC of candidates [0, mid)
                    w_0_mark = mid
                    if w <= rnd:
                        # Candidates below mid do not exceed rnd: the wanted one is >= mid.
                        for i in range(mid, hi):
                            slots[buff_slots[i]] = 1
                        lo = mid
                    else:
                        # Candidates below mid already exceed rnd: the wanted one is < mid.
                        hi = mid

                # Disable candidates lo and up.
                for k in range(lo, num_possible_states):
                    slots[buff_slots[k]] = 0

                # Reduce rnd by the accumulated WMC of candidates [0, lo).
                if lo == 0:
                    # No candidates below lo, so nothing to subtract (saves a wmc() call).
                    pass
                elif w_0_mark == lo:
                    # The last binary-search wmc() call measured exactly candidates [0, lo).
                    rnd -= w
                else:
                    rnd -= wmc()

                # Disable the remaining candidates, below lo.
                for k in range(0, lo):
                    slots[buff_slots[k]] = 0

                # Sequential search over candidates [lo, hi).
                fallback: int = hi - 1
                k = lo
                while k < hi:
                    slot = buff_slots[k]
                    slots[slot] = 1
                    w = wmc()
                    if rnd < w:
                        break
                    if w > 0:
                        fallback = k
                    slots[slot] = 0
                    rnd -= w
                    k += 1
                else:
                    # Floating-point rounding left rnd at or above the summed weights.
                    # Never step past the candidate range (buff_slots beyond it holds
                    # stale slots); take the last candidate seen with non-zero weight.
                    k = fallback

                slot = buff_slots[k]
                state = buff_states[k]
                slots[slot] = 1
                states[sample_rv.index] = state

            yield yield_f(states)

            # Reset the one slots for the next iteration.
            for slot in slots_1:
                slots[slot] = 1

            # Copy chain pairs for the next iteration.
            # (This writes over any initial chain conditions from slots_1.)
            for sample_rv in chain_rvs:
                rv_slots = sample_rv.slots
                prev_state_idx: int = states.item(sample_rv.copy_index)
                for slot in rv_slots:
                    slots[slot] = 0
                slots[rv_slots[prev_state_idx]] = 1
|
|
@@ -0,0 +1,147 @@
|
|
|
1
|
+
from typing import Collection, Iterator, Sequence, List
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
from ck.pgm import Instance
|
|
6
|
+
from ck.probability.probability_space import dtype_for_state_indexes
|
|
7
|
+
from ck.program.program_buffer import ProgramBuffer
|
|
8
|
+
from ck.program.raw_program import RawProgram
|
|
9
|
+
from ck.sampling.sampler import Sampler
|
|
10
|
+
from ck.sampling.sampler_support import SampleRV, YieldF, uniform_random_sample, SamplerInfo
|
|
11
|
+
from ck.utils.np_extras import NDArrayStates, NDArrayFloat64
|
|
12
|
+
from ck.utils.random_extras import Random, random_permute
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class WMCGibbsSampler(Sampler):
    """
    A Gibbs sampler over the distribution given by a weighted model count (WMC) program.

    The system state starts from a uniformly random instance that respects the
    condition. Each Gibbs step visits the sample rvs in a random order and
    resamples each one from its WMC conditioned on the current states of the others.
    """

    def __init__(
            self,
            sampler_info: SamplerInfo,
            raw_program: RawProgram,
            rand: Random,
            skip: int,
            burn_in: int,
            pr_restart: float,
    ):
        """
        Args:
            sampler_info: the sample rvs, condition, yield function and the input
                slots to hold at 0 / 1 (e.g., as produced by `get_sampler_info`).
            raw_program: a program computing the WMC of its input slots.
            rand: the source of random numbers.
            skip: number of extra Gibbs steps taken between yielded samples.
            burn_in: number of Gibbs steps taken after each (re)start, before sampling.
            pr_restart: probability, tested after each yielded sample, of restarting
                from a fresh uniformly random state (followed by a burn in).
                A value <= 0 disables restarts.
        """
        super().__init__(sampler_info.rvs, sampler_info.condition)
        self._yield_f: YieldF = sampler_info.yield_f
        self._rand: Random = rand
        self._program_buffer = ProgramBuffer(raw_program)
        self._sample_rvs: List[SampleRV] = list(sampler_info.sample_rvs)
        self._state_dtype = dtype_for_state_indexes(self.rvs)
        self._slots_0: Collection[int] = sampler_info.slots_0
        self._slots_1: Collection[int] = sampler_info.slots_1
        self._skip: int = skip
        self._burn_in: int = burn_in
        self._pr_restart: float = pr_restart

    def __iter__(self) -> Iterator[Instance] | Iterator[int]:
        """
        Generate samples indefinitely.
        """
        sample_rvs: List[SampleRV] = self._sample_rvs
        rand: Random = self._rand
        yield_f: YieldF = self._yield_f
        slots_0: Collection[int] = self._slots_0
        slots_1: Collection[int] = self._slots_1
        program_buffer: ProgramBuffer = self._program_buffer
        skip: int = self._skip
        burn_in: int = self._burn_in
        pr_restart: float = self._pr_restart
        next_sample_gibbs = self._next_sample_gibbs

        # Allocate working memory.
        state = np.zeros(len(sample_rvs), dtype=self._state_dtype)
        # prs[i] is scratch space for the sample rv whose `index` is i. Build it in
        # index order: `_next_sample_gibbs` permutes `sample_rvs` in place, so its
        # current order need not match the indexes (e.g., on a second call to __iter__).
        prs: Sequence[NDArrayFloat64] = tuple(
            np.zeros(len(sample_rv.slots), dtype=np.float64)
            for sample_rv in sorted(sample_rvs, key=lambda sample_rv: sample_rv.index)
        )

        def step() -> None:
            # One Gibbs step over all sample rvs.
            next_sample_gibbs(sample_rvs, slots_1, program_buffer, prs, state, rand)

        def restart() -> None:
            # Set a uniformly random initial system state, then run a burn in.
            uniform_random_sample(sample_rvs, slots_0, slots_1, program_buffer.vars, state, rand)
            for _ in range(burn_in):
                step()

        restart()
        may_restart: bool = pr_restart > 0
        while True:
            for _ in range(skip):
                step()
            step()
            yield yield_f(state)
            # Only draw a random number when restarts are possible.
            if may_restart and rand.random() < pr_restart:
                restart()

    @staticmethod
    def _next_sample_gibbs(
            sample_rvs: List[SampleRV],
            slots_1: Collection[int],
            program_buffer: ProgramBuffer,
            prs: Sequence[NDArrayFloat64],
            state: NDArrayStates,
            rand: Random
    ) -> None:
        """
        One Gibbs step: updates the states to a random system and reconfigures
        program inputs to match.

        The sample rvs are visited in a random order (`sample_rvs` is permuted in
        place). Each is resampled from its WMC conditioned on the current states of
        all the others, restricted to the states whose slots are in `slots_1`.

        Args:
            sample_rvs: the sample rvs (permuted in place).
            slots_1: the program input slots that conditioning allows to be 1.
            program_buffer: the WMC program and its input slots.
            prs: scratch arrays; prs[i] serves the sample rv with index i and has one
                entry per slot of that rv.
            state: current state index of each sample rv, indexed by `index`; updated.
            rand: the source of random numbers.

        Raises:
            RuntimeError: 'zero probability' if a sample rv has no slot in `slots_1`.
        """
        prog_in = program_buffer.vars
        random_permute(sample_rvs, rand=rand)
        for sample_rv in sample_rvs:
            rv_slots = sample_rv.slots
            index = sample_rv.index

            rv_pr: NDArrayFloat64 = prs[index]
            s: int = state.item(index)

            # The (state, slot) pairs permitted by conditioning.
            candidates = [
                (slot_state, slot)
                for slot_state, slot in enumerate(rv_slots)
                if slot in slots_1
            ]
            if len(candidates) == 0:
                raise RuntimeError('zero probability')

            # Compute conditioned marginals for the current rv. Entries of rv_pr
            # for non-candidate states are never written, so they stay zero.
            prog_in[rv_slots[s]] = 0
            for slot_state, slot in candidates:
                prog_in[slot] = 1
                rv_pr[slot_state] = program_buffer.compute().item()
                prog_in[slot] = 0

            # Pick a new state based on the conditional probabilities.
            total = np.sum(rv_pr)
            if total == 0.0:
                # No state of the current rv has a non-zero probability when
                # conditioned on the other random variables' states.
                # Pick a random state from a uniform distribution.
                slot_state, slot = candidates[rand.randrange(0, len(candidates))]
            else:
                # Pick a state, sampled from the marginal distribution.
                # Using strict `<` means a zero-probability state is never chosen.
                r = rand.random() * total
                chosen = None
                last_positive = None
                for candidate in candidates:
                    pr = rv_pr[candidate[0]]
                    if r < pr:
                        chosen = candidate
                        break
                    if pr > 0:
                        last_positive = candidate
                    r -= pr
                if chosen is None:
                    # Floating-point rounding left r at or above the summed probabilities.
                    chosen = last_positive if last_positive is not None else candidates[-1]
                slot_state, slot = chosen

            # Update the states array and the WMC input.
            state[index] = slot_state
            prog_in[slot] = 1
|