compiled-knowledge 4.0.0a5__cp313-cp313-macosx_10_13_universal2.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of compiled-knowledge might be problematic. Click here for more details.
- ck/__init__.py +0 -0
- ck/circuit/__init__.py +13 -0
- ck/circuit/circuit.c +38749 -0
- ck/circuit/circuit.cpython-313-darwin.so +0 -0
- ck/circuit/circuit_py.py +807 -0
- ck/circuit/tmp_const.py +74 -0
- ck/circuit_compiler/__init__.py +2 -0
- ck/circuit_compiler/circuit_compiler.py +26 -0
- ck/circuit_compiler/cython_vm_compiler/__init__.py +1 -0
- ck/circuit_compiler/cython_vm_compiler/_compiler.c +17373 -0
- ck/circuit_compiler/cython_vm_compiler/_compiler.cpython-313-darwin.so +0 -0
- ck/circuit_compiler/cython_vm_compiler/cython_vm_compiler.py +96 -0
- ck/circuit_compiler/interpret_compiler.py +223 -0
- ck/circuit_compiler/llvm_compiler.py +388 -0
- ck/circuit_compiler/llvm_vm_compiler.py +546 -0
- ck/circuit_compiler/named_circuit_compilers.py +57 -0
- ck/circuit_compiler/support/__init__.py +0 -0
- ck/circuit_compiler/support/circuit_analyser.py +81 -0
- ck/circuit_compiler/support/input_vars.py +148 -0
- ck/circuit_compiler/support/llvm_ir_function.py +234 -0
- ck/example/__init__.py +53 -0
- ck/example/alarm.py +366 -0
- ck/example/asia.py +28 -0
- ck/example/binary_clique.py +32 -0
- ck/example/bow_tie.py +33 -0
- ck/example/cancer.py +37 -0
- ck/example/chain.py +38 -0
- ck/example/child.py +199 -0
- ck/example/clique.py +33 -0
- ck/example/cnf_pgm.py +39 -0
- ck/example/diamond_square.py +68 -0
- ck/example/earthquake.py +36 -0
- ck/example/empty.py +10 -0
- ck/example/hailfinder.py +539 -0
- ck/example/hepar2.py +628 -0
- ck/example/insurance.py +504 -0
- ck/example/loop.py +40 -0
- ck/example/mildew.py +38161 -0
- ck/example/munin.py +22982 -0
- ck/example/pathfinder.py +53674 -0
- ck/example/rain.py +39 -0
- ck/example/rectangle.py +161 -0
- ck/example/run.py +30 -0
- ck/example/sachs.py +129 -0
- ck/example/sprinkler.py +30 -0
- ck/example/star.py +44 -0
- ck/example/stress.py +64 -0
- ck/example/student.py +43 -0
- ck/example/survey.py +46 -0
- ck/example/triangle_square.py +54 -0
- ck/example/truss.py +49 -0
- ck/in_out/__init__.py +3 -0
- ck/in_out/parse_ace_lmap.py +216 -0
- ck/in_out/parse_ace_nnf.py +288 -0
- ck/in_out/parse_net.py +480 -0
- ck/in_out/parser_utils.py +185 -0
- ck/in_out/pgm_pickle.py +42 -0
- ck/in_out/pgm_python.py +268 -0
- ck/in_out/render_bugs.py +111 -0
- ck/in_out/render_net.py +177 -0
- ck/in_out/render_pomegranate.py +184 -0
- ck/pgm.py +3494 -0
- ck/pgm_circuit/__init__.py +1 -0
- ck/pgm_circuit/marginals_program.py +352 -0
- ck/pgm_circuit/mpe_program.py +237 -0
- ck/pgm_circuit/pgm_circuit.py +75 -0
- ck/pgm_circuit/program_with_slotmap.py +234 -0
- ck/pgm_circuit/slot_map.py +35 -0
- ck/pgm_circuit/support/__init__.py +0 -0
- ck/pgm_circuit/support/compile_circuit.py +83 -0
- ck/pgm_circuit/target_marginals_program.py +103 -0
- ck/pgm_circuit/wmc_program.py +323 -0
- ck/pgm_compiler/__init__.py +2 -0
- ck/pgm_compiler/ace/__init__.py +1 -0
- ck/pgm_compiler/ace/ace.py +252 -0
- ck/pgm_compiler/factor_elimination.py +383 -0
- ck/pgm_compiler/named_pgm_compilers.py +63 -0
- ck/pgm_compiler/pgm_compiler.py +19 -0
- ck/pgm_compiler/recursive_conditioning.py +226 -0
- ck/pgm_compiler/support/__init__.py +0 -0
- ck/pgm_compiler/support/circuit_table/__init__.py +9 -0
- ck/pgm_compiler/support/circuit_table/circuit_table.c +16042 -0
- ck/pgm_compiler/support/circuit_table/circuit_table.cpython-313-darwin.so +0 -0
- ck/pgm_compiler/support/circuit_table/circuit_table_py.py +269 -0
- ck/pgm_compiler/support/clusters.py +556 -0
- ck/pgm_compiler/support/factor_tables.py +398 -0
- ck/pgm_compiler/support/join_tree.py +275 -0
- ck/pgm_compiler/support/named_compiler_maker.py +33 -0
- ck/pgm_compiler/variable_elimination.py +89 -0
- ck/probability/__init__.py +0 -0
- ck/probability/empirical_probability_space.py +47 -0
- ck/probability/probability_space.py +568 -0
- ck/program/__init__.py +3 -0
- ck/program/program.py +129 -0
- ck/program/program_buffer.py +180 -0
- ck/program/raw_program.py +61 -0
- ck/sampling/__init__.py +0 -0
- ck/sampling/forward_sampler.py +211 -0
- ck/sampling/marginals_direct_sampler.py +113 -0
- ck/sampling/sampler.py +62 -0
- ck/sampling/sampler_support.py +232 -0
- ck/sampling/uniform_sampler.py +66 -0
- ck/sampling/wmc_direct_sampler.py +169 -0
- ck/sampling/wmc_gibbs_sampler.py +147 -0
- ck/sampling/wmc_metropolis_sampler.py +159 -0
- ck/sampling/wmc_rejection_sampler.py +113 -0
- ck/utils/__init__.py +0 -0
- ck/utils/iter_extras.py +153 -0
- ck/utils/map_list.py +128 -0
- ck/utils/map_set.py +128 -0
- ck/utils/np_extras.py +51 -0
- ck/utils/random_extras.py +64 -0
- ck/utils/tmp_dir.py +94 -0
- ck_demos/__init__.py +0 -0
- ck_demos/ace/__init__.py +0 -0
- ck_demos/ace/copy_ace_to_ck.py +15 -0
- ck_demos/ace/demo_ace.py +44 -0
- ck_demos/all_demos.py +88 -0
- ck_demos/circuit/__init__.py +0 -0
- ck_demos/circuit/demo_circuit_dump.py +22 -0
- ck_demos/circuit/demo_derivatives.py +43 -0
- ck_demos/circuit_compiler/__init__.py +0 -0
- ck_demos/circuit_compiler/compare_circuit_compilers.py +32 -0
- ck_demos/circuit_compiler/show_llvm_program.py +26 -0
- ck_demos/pgm/__init__.py +0 -0
- ck_demos/pgm/demo_pgm_dump.py +18 -0
- ck_demos/pgm/demo_pgm_dump_stress.py +18 -0
- ck_demos/pgm/demo_pgm_string_rendering.py +15 -0
- ck_demos/pgm/show_examples.py +25 -0
- ck_demos/pgm_compiler/__init__.py +0 -0
- ck_demos/pgm_compiler/compare_pgm_compilers.py +50 -0
- ck_demos/pgm_compiler/demo_compiler_dump.py +50 -0
- ck_demos/pgm_compiler/demo_factor_elimination.py +47 -0
- ck_demos/pgm_compiler/demo_join_tree.py +25 -0
- ck_demos/pgm_compiler/demo_marginals_program.py +53 -0
- ck_demos/pgm_compiler/demo_mpe_program.py +55 -0
- ck_demos/pgm_compiler/demo_pgm_compiler.py +38 -0
- ck_demos/pgm_compiler/demo_recursive_conditioning.py +33 -0
- ck_demos/pgm_compiler/demo_variable_elimination.py +33 -0
- ck_demos/pgm_compiler/demo_wmc_program.py +29 -0
- ck_demos/pgm_inference/__init__.py +0 -0
- ck_demos/pgm_inference/demo_inferencing_basic.py +188 -0
- ck_demos/pgm_inference/demo_inferencing_mpe_cancer.py +45 -0
- ck_demos/pgm_inference/demo_inferencing_wmc_and_mpe_sprinkler.py +154 -0
- ck_demos/pgm_inference/demo_inferencing_wmc_student.py +110 -0
- ck_demos/programs/__init__.py +0 -0
- ck_demos/programs/demo_program_buffer.py +24 -0
- ck_demos/programs/demo_program_multi.py +24 -0
- ck_demos/programs/demo_program_none.py +19 -0
- ck_demos/programs/demo_program_single.py +23 -0
- ck_demos/programs/demo_raw_program_interpreted.py +21 -0
- ck_demos/programs/demo_raw_program_llvm.py +21 -0
- ck_demos/sampling/__init__.py +0 -0
- ck_demos/sampling/check_sampler.py +71 -0
- ck_demos/sampling/demo_marginal_direct_sampler.py +40 -0
- ck_demos/sampling/demo_uniform_sampler.py +38 -0
- ck_demos/sampling/demo_wmc_direct_sampler.py +40 -0
- ck_demos/utils/__init__.py +0 -0
- ck_demos/utils/compare.py +88 -0
- ck_demos/utils/convert_network.py +45 -0
- ck_demos/utils/sample_model.py +216 -0
- ck_demos/utils/stop_watch.py +384 -0
- compiled_knowledge-4.0.0a5.dist-info/METADATA +50 -0
- compiled_knowledge-4.0.0a5.dist-info/RECORD +167 -0
- compiled_knowledge-4.0.0a5.dist-info/WHEEL +5 -0
- compiled_knowledge-4.0.0a5.dist-info/licenses/LICENSE.txt +21 -0
- compiled_knowledge-4.0.0a5.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,234 @@
|
|
|
1
|
+
from typing import Tuple, Iterator, Sequence, Dict, Iterable
|
|
2
|
+
|
|
3
|
+
from ck.pgm import RandomVariable, rv_instances, Instance, rv_instances_as_indicators, Indicator, ParamId
|
|
4
|
+
from ck.pgm_circuit.slot_map import SlotMap, SlotKey
|
|
5
|
+
from ck.probability.probability_space import Condition, check_condition
|
|
6
|
+
from ck.program.program_buffer import ProgramBuffer
|
|
7
|
+
from ck.utils.np_extras import NDArray, NDArrayNumeric
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class ProgramWithSlotmap:
    """
    A class for bundling a program buffer with a slot-map, where the slot-map maps keys
    (e.g., random variable indicators) to program input slots.
    """

    def __init__(
            self,
            program_buffer: ProgramBuffer,
            slot_map: SlotMap,
            rvs: Sequence[RandomVariable],
            precondition: Sequence[Indicator]
    ):
        """
        Construct a ProgramWithSlotmap object.

        Args:
            program_buffer: is a ProgramBuffer object which is a compiled circuit with input and output slots.
            slot_map: a map from a slot key to an input slot of `program_buffer`.
            rvs: a sequence of rvs used for setting program input slots, each rv
                has a length and rv[i] is a unique 'indicator' across all rvs.
            precondition: conditions on rvs that are compiled into the program.

        Raises:
            ValueError: if `rvs` contains the same random variable (compared by `rv.idx`) more than once.
        """
        self._program_buffer: ProgramBuffer = program_buffer
        self._slot_map: SlotMap = slot_map
        self._rvs: Tuple[RandomVariable, ...] = tuple(rvs)
        self._precondition: Sequence[Indicator] = precondition

        if len(rvs) != len(set(rv.idx for rv in rvs)):
            raise ValueError('duplicate random variables provided')

        # Given rv = rvs[i], then _rvs_slots[i][state_idx] gives the slot for rv[state_idx].
        self._rvs_slots: Tuple[Tuple[int, ...], ...] = tuple(tuple(self._slot_map[ind] for ind in rv) for rv in rvs)

        # Given rv = rvs[i], then _indicator_map[rv[j]] = (i, slot), where slot is the slot for indicator rv[j].
        self._indicator_map: Dict[Indicator, Tuple[int, int]] = {
            ind: (i, slot_map[ind])
            for i, rv in enumerate(rvs)
            for ind in rv
        }

    @property
    def rvs(self) -> Sequence[RandomVariable]:
        """
        What are the random variables considered as 'inputs'.
        """
        return self._rvs

    @property
    def precondition(self) -> Sequence[Indicator]:
        """
        Condition on `self.rvs` that is compiled into the program.
        """
        return self._precondition

    @property
    def slot_map(self) -> SlotMap:
        """
        The slot map used to find the program input slot for a slot key.
        """
        return self._slot_map

    def instances(self, flip: bool = False) -> Iterator[Instance]:
        """
        Enumerate instances of the random variables.

        Each instance is a tuple of state indexes, co-indexed with the given random variables.

        The order is the natural index order (i.e., last random variable changing most quickly).

        Args:
            flip: if true, then first random variable changes most quickly.

        Returns:
            an iteration over tuples, each tuple holds state indexes
            co-indexed with the given random variables.
        """
        return rv_instances(*self._rvs, flip=flip)

    def instances_as_indicators(self, flip: bool = False) -> Iterator[Sequence[Indicator]]:
        """
        Enumerate instances of the random variables.

        Each instance is a tuple of indicators, co-indexed with the given random variables.

        The order is the natural index order (i.e., last random variable changing most quickly).

        Args:
            flip: if true, then first random variable changes most quickly.

        Returns:
            an iteration over tuples, each tuple holds random variable indicators
            co-indexed with the given random variables.
        """
        return rv_instances_as_indicators(*self._rvs, flip=flip)

    def compute(self) -> NDArrayNumeric:
        """
        Execute the program to compute and return the result. As per `ProgramBuffer.compute`.

        Warning:
            when returning an array, the array is backed by the program buffer memory, not a copy.
        """
        return self._program_buffer.compute()

    def compute_conditioned(self, *condition: Condition) -> NDArrayNumeric:
        """
        Equivalent to:
            self.set_condition(*condition)
            return self.compute()
        """
        self.set_condition(*condition)
        return self.compute()

    @property
    def results(self) -> NDArrayNumeric:
        """
        Get the results of the last computation.
        As per `ProgramBuffer.results`.

        Warning:
            the array is backed by the program buffer memory, not a copy.
        """
        return self._program_buffer.results

    @property
    def vars(self) -> NDArrayNumeric:
        """
        Return the input variables as a numpy array.
        As per `ProgramBuffer.vars`.

        Warning:
            writing to the returned array will write to the input slots of the program buffer.
        """
        return self._program_buffer.vars

    def __setitem__(self, item: int | slice | SlotKey | Iterable[SlotKey], value: float) -> None:
        """
        Set one or more input slot values.

        Args:
            item: a raw slot index or slice, a slot key (Indicator or ParamId),
                or an iterable of slot keys (each is set to `value`).
            value: the value to write.
        """
        if isinstance(item, (int, slice)):
            self._program_buffer[item] = value
        elif isinstance(item, (Indicator, ParamId)):
            self._program_buffer[self._slot_map[item]] = value
        else:
            # Assume it's iterable.
            for i in item:
                self[i] = value

    def __getitem__(self, item: int | slice | SlotKey) -> NDArrayNumeric:
        """
        Get an input slot value, identified by a raw slot index/slice or a slot key.

        Raises:
            IndexError: if `item` is not an int, slice, Indicator or ParamId.
        """
        if isinstance(item, (int, slice)):
            return self._program_buffer[item]
        elif isinstance(item, (Indicator, ParamId)):
            return self._program_buffer[self._slot_map[item]]
        else:
            raise IndexError('unknown index type')

    def set_condition(self, *condition: Condition) -> None:
        """
        Set the input slots of random variables to 1, except where implied to
        be 0 according to the given conditions.

        Specifically:
            each slot corresponding to an indicator in the given condition will be set to 1;

            if a random variable is mentioned in the given condition, then all
            slots for indicators of that random variable will be set to 0, except
            for slots corresponding to an indicator in the given condition;

            if a random variable is not mentioned in the given condition, that random variable
            will have all its slots set to 1.

        Raises:
            KeyError: if the condition mentions an indicator that does not belong to one of `self.rvs`.
        """
        condition: Sequence[Indicator] = check_condition(condition)

        # Group the slots of the condition indicators by the index of their rv within self._rvs.
        ind_slot_groups = [[] for _ in self._rvs_slots]
        for ind in condition:
            rv_idx, slot = self._indicator_map[ind]
            ind_slot_groups[rv_idx].append(slot)

        slots: NDArray = self._program_buffer.vars
        for rv_slots, ind_slots in zip(self._rvs_slots, ind_slot_groups):
            if len(ind_slots) == 0:
                # this rv _is not_ mentioned in the indicators - marginalise it
                for slot in rv_slots:
                    slots[slot] = 1
            else:
                # this rv _is_ mentioned in the indicators - we set the mentioned slots to 1 and others to 0.
                for slot in rv_slots:
                    slots[slot] = 0
                for slot in ind_slots:
                    slots[slot] = 1

    def set_rv(self, rv: RandomVariable, *values: float | int) -> None:
        """
        Set the input values of a random variable.

        Args:
            rv: a random variable whose indicators are in the slot map.
            values: one value per state of `rv`, co-indexed with the indicators of `rv`.

        Raises:
            ValueError: if the number of values differs from the number of states of `rv`.
        """
        # Validate up front: previously extra values were silently ignored and
        # too few values failed part-way through with an opaque IndexError.
        if len(values) != len(rv):
            raise ValueError(f'expected {len(rv)} values for the random variable, got {len(values)}')
        for ind, value in zip(rv, values):
            self[ind] = value

    def set_rvs_uniform(self, *rvs: RandomVariable) -> None:
        """
        Set the input values for each rv in rvs to 1 / len(rv).

        Args:
            rvs: a collection of random variables whose indicators are in the slot map.
        """
        for rv in rvs:
            value = 1.0 / len(rv)
            for ind in rv:
                self[ind] = value

    def set_all_rvs_uniform(self) -> None:
        """
        Set the input values for each rv in `self.rvs` to 1 / len(rv).
        """
        slots: NDArray = self._program_buffer.vars
        for rv_slots in self._rvs_slots:
            value = 1.0 / len(rv_slots)
            for slot in rv_slots:
                slots[slot] = value
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
from typing import Protocol, Optional, overload, Iterable, Tuple
|
|
2
|
+
|
|
3
|
+
from ck.pgm import Indicator, ParamId
|
|
4
|
+
|
|
5
|
+
# Type of a slot map key: either a random variable indicator or a parameter id.
SlotKey = Indicator | ParamId


class SlotMap(Protocol):
    """
    Structural protocol for objects that map slot keys (indicators and
    parameter ids) to input slots of a ProgramBuffer.

    Any plain Python dict[SlotKey, int] satisfies this protocol.
    """

    def __len__(self) -> int:
        """Number of slot keys held by the map."""

    def __getitem__(self, slot_key: SlotKey) -> int:
        """Slot index associated with `slot_key`."""

    @overload
    def get(self, slot_key: SlotKey, default: None) -> Optional[int]:
        ...

    @overload
    def get(self, slot_key: SlotKey, default: int) -> int:
        ...

    def get(self, slot_key: SlotKey, default: Optional[int]) -> Optional[int]:
        """Slot index associated with `slot_key`, or `default` when the key is absent (as for dict.get)."""

    def items(self) -> Iterable[Tuple[SlotKey, int]]:
        """All (slot key, slot index) pairs of the map."""
|
|
File without changes
|
|
@@ -0,0 +1,83 @@
|
|
|
1
|
+
from typing import Optional, Sequence
|
|
2
|
+
|
|
3
|
+
from ck.circuit import CircuitNode, TmpConst, Circuit
|
|
4
|
+
from ck.circuit_compiler import CircuitCompiler
|
|
5
|
+
from ck.circuit_compiler.llvm_compiler import DataType, DEFAULT_TYPE_INFO, compile_circuit
|
|
6
|
+
from ck.circuit_compiler import DEFAULT_CIRCUIT_COMPILER
|
|
7
|
+
from ck.pgm_circuit import PGMCircuit
|
|
8
|
+
from ck.program import RawProgram
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def compile_results(
        pgm_circuit: PGMCircuit,
        results: Sequence[CircuitNode],
        const_parameters: bool,
        compiler: CircuitCompiler = DEFAULT_CIRCUIT_COMPILER,
) -> RawProgram:
    """
    Compile the circuit of a PGMCircuit into a raw program computing the given result nodes.

    Args:
        pgm_circuit: The circuit (and PGM) that will be compiled to a program.
        results: the circuit nodes whose values are the outputs of the returned program.
        const_parameters: if True then each circuit variable representing a parameter value
            (the slots following the indicator slots) is temporarily fixed to its value in
            `pgm_circuit.parameter_values`, making it 'const' in the resulting program.
        compiler: the circuit compiler, called as `compiler(*results, circuit=circuit)`.

    Returns:
        a compiled RawProgram.

    Raises:
        ValueError: if not all nodes are from the same circuit.
    """
    circuit: Circuit = pgm_circuit.circuit_top.circuit

    if not const_parameters:
        return compiler(*results, circuit=circuit)

    values_of_params = pgm_circuit.parameter_values
    # Parameter variables occupy the slots directly after the indicator variables.
    first_param_slot = pgm_circuit.number_of_indicators
    with TmpConst(circuit) as tmp_const:
        for var_slot, param_value in enumerate(values_of_params, start=first_param_slot):
            tmp_const.set_const(var_slot, param_value)
        # Compile while the constants are in force; they are released when the block exits.
        return compiler(*results, circuit=circuit)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def compile_param_derivatives(
        pgm_circuit: PGMCircuit,
        self_multiply: bool = False,
        params_value: Optional[float | int] = 1,
        data_type: DataType = DEFAULT_TYPE_INFO,
) -> RawProgram:
    """
    Compile the circuit to a program computing the partial derivatives of the circuit top
    with respect to each parameter variable. The program results are co-indexed with
    `pgm_circuit.parameter_values`.

    Typically, this will grow the circuit by the addition of circuit nodes to compute the derivatives.

    This uses the LLVM circuit compiler (`compile_circuit`).

    Args:
        pgm_circuit: The circuit (and PGM) that will be compiled to a program.
        self_multiply: if true then each partial derivative df/dx will be multiplied by x.
        params_value: if not None, then circuit vars representing parameters are temporarily
            set to this constant value while compiling the program. Default is 1.
        data_type: What data type to use for arithmetic calculations. Either a DataType member or TypeInfo.

    Returns:
        a compiled RawProgram.
    """
    top: CircuitNode = pgm_circuit.circuit_top
    circuit: Circuit = top.circuit

    # Parameter variables occupy the slots directly after the indicator variables.
    first_slot = pgm_circuit.number_of_indicators
    stop_slot = first_slot + pgm_circuit.number_of_parameters
    param_vars = circuit.vars[first_slot:stop_slot]
    derivatives = circuit.partial_derivatives(top, param_vars, self_multiply=self_multiply)

    if params_value is None:
        return compile_circuit(*derivatives, circuit=circuit, data_type=data_type)

    with TmpConst(circuit) as tmp_const:
        tmp_const.set_const(param_vars, params_value)
        # Compile while the parameter constants are in force.
        return compile_circuit(*derivatives, circuit=circuit, data_type=data_type)
|
|
@@ -0,0 +1,103 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Optional, Tuple, List
|
|
4
|
+
|
|
5
|
+
from ck.circuit import CircuitNode, Circuit, TmpConst
|
|
6
|
+
from ck.pgm import RandomVariable
|
|
7
|
+
from ck.pgm_circuit import PGMCircuit
|
|
8
|
+
from ck.pgm_circuit.program_with_slotmap import ProgramWithSlotmap
|
|
9
|
+
from ck.pgm_circuit.slot_map import SlotMap
|
|
10
|
+
from ck.pgm_circuit.support.compile_circuit import compile_results
|
|
11
|
+
from ck.probability.probability_space import check_condition, Condition
|
|
12
|
+
from ck.program.program_buffer import ProgramBuffer
|
|
13
|
+
from ck.program.raw_program import RawProgram
|
|
14
|
+
from ck.utils.np_extras import NDArray
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class TargetMarginalsProgram(ProgramWithSlotmap):
    """
    A compiled program whose results are co-indexed with the states of a single
    target random variable.

    The program outputs the partial derivatives of the circuit top with respect to the
    circuit variables of the target rv's indicators (those variables are fixed to 1 while
    compiling). `map` normalises these results by their sum.

    The target rv is excluded from the program's input rvs, so conditions passed to
    `set_condition`/`map` may only mention the remaining rvs.
    """

    def __init__(
            self,
            pgm_circuit: PGMCircuit,
            target_rv: RandomVariable,
            const_parameters: bool = True,
    ):
        """
        Construct a TargetMarginalsProgram object.

        Compile the given circuit for computing marginal probabilities over the states of 'target_rv'.

        Args:
            pgm_circuit: The circuit representing a PGM.
            target_rv: the random variable to compute marginals for.
            const_parameters: if True then any circuit variable representing a parameter value will
                be made 'const' in the resulting program.

        Raises:
            ValueError: if `target_rv` is not one of `pgm_circuit.rvs`.
        """
        top_node: CircuitNode = pgm_circuit.circuit_top
        circuit: Circuit = top_node.circuit
        slot_map: SlotMap = pgm_circuit.slot_map
        input_rvs: List[RandomVariable] = list(pgm_circuit.rvs)

        target_vars = [circuit.vars[slot_map[ind]] for ind in target_rv]

        # Remove the target rv from the input rvs.
        # Done before computing derivatives so that an invalid target rv raises
        # (ValueError from list.index) without first growing the circuit.
        target_index = input_rvs.index(target_rv)
        del input_rvs[target_index]

        cct_outputs = circuit.partial_derivatives(top_node, target_vars)

        with TmpConst(circuit) as tmp:
            tmp.set_const(target_vars, 1)
            raw_program: RawProgram = compile_results(
                pgm_circuit=pgm_circuit,
                results=cct_outputs,
                const_parameters=const_parameters,
            )

        super().__init__(ProgramBuffer(raw_program), slot_map, input_rvs, pgm_circuit.conditions)

        # additional fields
        self._x_slots: List[List[int]] = [[slot_map[ind] for ind in rv] for rv in input_rvs]
        self._y_size: int = raw_program.number_of_results
        self._target_rv: RandomVariable = target_rv
        self._number_of_indicators: int = pgm_circuit.number_of_indicators
        # NOTE(review): not read anywhere in this class - confirm whether still needed.
        self._z_cache: Optional[float] = None

        # consistency check: one program result per state of the target rv
        assert (self._y_size == len(self._target_rv))

        if not const_parameters:
            # set the parameter slots (they follow the indicator slots)
            self.vars[pgm_circuit.number_of_indicators:] = pgm_circuit.parameter_values

    @property
    def target_rv(self) -> RandomVariable:
        """
        The random variable whose states the program results are co-indexed with.
        """
        return self._target_rv

    def map(self, condition: Condition = ()) -> Tuple[float, int]:
        """
        Return the maximum a posteriori (MAP) state of the target variable.

        Args:
            condition: any conditioning indicators (on the input rvs, not the target rv).

        Returns:
            (pr, state_idx) where
                pr: is the MAP probability, i.e., the largest result divided by the sum of results;
                state_idx: is the MAP state index of `self.target_rv` (the lowest index on ties).
        """
        self.set_condition(*check_condition(condition))
        self.compute()
        results: NDArray = self.results
        z: float = results.sum()

        # Scan for the first index holding the largest result; the -1 sentinel
        # relies on results being non-negative weights.
        max_p = -1
        max_i = -1
        for i in range(self._y_size):
            p = results[i]
            if p > max_p:
                max_p = p
                max_i = i

        # NOTE(review): when z == 0 (a zero-probability condition) this is 0 / 0 - confirm intended result.
        return max_p / z, max_i
|