compiled-knowledge 4.0.0a20__cp312-cp312-macosx_10_13_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of compiled-knowledge might be problematic. Click here for more details.
- ck/__init__.py +0 -0
- ck/circuit/__init__.py +17 -0
- ck/circuit/_circuit_cy.c +37525 -0
- ck/circuit/_circuit_cy.cpython-312-darwin.so +0 -0
- ck/circuit/_circuit_cy.pxd +32 -0
- ck/circuit/_circuit_cy.pyx +768 -0
- ck/circuit/_circuit_py.py +836 -0
- ck/circuit/tmp_const.py +74 -0
- ck/circuit_compiler/__init__.py +2 -0
- ck/circuit_compiler/circuit_compiler.py +26 -0
- ck/circuit_compiler/cython_vm_compiler/__init__.py +1 -0
- ck/circuit_compiler/cython_vm_compiler/_compiler.c +19826 -0
- ck/circuit_compiler/cython_vm_compiler/_compiler.cpython-312-darwin.so +0 -0
- ck/circuit_compiler/cython_vm_compiler/_compiler.pyx +380 -0
- ck/circuit_compiler/cython_vm_compiler/cython_vm_compiler.py +121 -0
- ck/circuit_compiler/interpret_compiler.py +223 -0
- ck/circuit_compiler/llvm_compiler.py +388 -0
- ck/circuit_compiler/llvm_vm_compiler.py +546 -0
- ck/circuit_compiler/named_circuit_compilers.py +57 -0
- ck/circuit_compiler/support/__init__.py +0 -0
- ck/circuit_compiler/support/circuit_analyser/__init__.py +13 -0
- ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.c +10620 -0
- ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.cpython-312-darwin.so +0 -0
- ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.pyx +98 -0
- ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_py.py +93 -0
- ck/circuit_compiler/support/input_vars.py +148 -0
- ck/circuit_compiler/support/llvm_ir_function.py +234 -0
- ck/example/__init__.py +53 -0
- ck/example/alarm.py +366 -0
- ck/example/asia.py +28 -0
- ck/example/binary_clique.py +32 -0
- ck/example/bow_tie.py +33 -0
- ck/example/cancer.py +37 -0
- ck/example/chain.py +38 -0
- ck/example/child.py +199 -0
- ck/example/clique.py +33 -0
- ck/example/cnf_pgm.py +39 -0
- ck/example/diamond_square.py +68 -0
- ck/example/earthquake.py +36 -0
- ck/example/empty.py +10 -0
- ck/example/hailfinder.py +539 -0
- ck/example/hepar2.py +628 -0
- ck/example/insurance.py +504 -0
- ck/example/loop.py +40 -0
- ck/example/mildew.py +38161 -0
- ck/example/munin.py +22982 -0
- ck/example/pathfinder.py +53747 -0
- ck/example/rain.py +39 -0
- ck/example/rectangle.py +161 -0
- ck/example/run.py +30 -0
- ck/example/sachs.py +129 -0
- ck/example/sprinkler.py +30 -0
- ck/example/star.py +44 -0
- ck/example/stress.py +64 -0
- ck/example/student.py +43 -0
- ck/example/survey.py +46 -0
- ck/example/triangle_square.py +54 -0
- ck/example/truss.py +49 -0
- ck/in_out/__init__.py +3 -0
- ck/in_out/parse_ace_lmap.py +216 -0
- ck/in_out/parse_ace_nnf.py +322 -0
- ck/in_out/parse_net.py +480 -0
- ck/in_out/parser_utils.py +185 -0
- ck/in_out/pgm_pickle.py +42 -0
- ck/in_out/pgm_python.py +268 -0
- ck/in_out/render_bugs.py +111 -0
- ck/in_out/render_net.py +177 -0
- ck/in_out/render_pomegranate.py +184 -0
- ck/pgm.py +3475 -0
- ck/pgm_circuit/__init__.py +1 -0
- ck/pgm_circuit/marginals_program.py +352 -0
- ck/pgm_circuit/mpe_program.py +237 -0
- ck/pgm_circuit/pgm_circuit.py +79 -0
- ck/pgm_circuit/program_with_slotmap.py +236 -0
- ck/pgm_circuit/slot_map.py +35 -0
- ck/pgm_circuit/support/__init__.py +0 -0
- ck/pgm_circuit/support/compile_circuit.py +83 -0
- ck/pgm_circuit/target_marginals_program.py +103 -0
- ck/pgm_circuit/wmc_program.py +323 -0
- ck/pgm_compiler/__init__.py +2 -0
- ck/pgm_compiler/ace/__init__.py +1 -0
- ck/pgm_compiler/ace/ace.py +299 -0
- ck/pgm_compiler/factor_elimination.py +395 -0
- ck/pgm_compiler/named_pgm_compilers.py +63 -0
- ck/pgm_compiler/pgm_compiler.py +19 -0
- ck/pgm_compiler/recursive_conditioning.py +231 -0
- ck/pgm_compiler/support/__init__.py +0 -0
- ck/pgm_compiler/support/circuit_table/__init__.py +17 -0
- ck/pgm_compiler/support/circuit_table/_circuit_table_cy.c +16398 -0
- ck/pgm_compiler/support/circuit_table/_circuit_table_cy.cpython-312-darwin.so +0 -0
- ck/pgm_compiler/support/circuit_table/_circuit_table_cy.pyx +332 -0
- ck/pgm_compiler/support/circuit_table/_circuit_table_py.py +304 -0
- ck/pgm_compiler/support/clusters.py +568 -0
- ck/pgm_compiler/support/factor_tables.py +406 -0
- ck/pgm_compiler/support/join_tree.py +332 -0
- ck/pgm_compiler/support/named_compiler_maker.py +43 -0
- ck/pgm_compiler/variable_elimination.py +91 -0
- ck/probability/__init__.py +0 -0
- ck/probability/empirical_probability_space.py +50 -0
- ck/probability/pgm_probability_space.py +32 -0
- ck/probability/probability_space.py +622 -0
- ck/program/__init__.py +3 -0
- ck/program/program.py +137 -0
- ck/program/program_buffer.py +180 -0
- ck/program/raw_program.py +67 -0
- ck/sampling/__init__.py +0 -0
- ck/sampling/forward_sampler.py +211 -0
- ck/sampling/marginals_direct_sampler.py +113 -0
- ck/sampling/sampler.py +62 -0
- ck/sampling/sampler_support.py +232 -0
- ck/sampling/uniform_sampler.py +72 -0
- ck/sampling/wmc_direct_sampler.py +171 -0
- ck/sampling/wmc_gibbs_sampler.py +153 -0
- ck/sampling/wmc_metropolis_sampler.py +165 -0
- ck/sampling/wmc_rejection_sampler.py +115 -0
- ck/utils/__init__.py +0 -0
- ck/utils/iter_extras.py +163 -0
- ck/utils/local_config.py +270 -0
- ck/utils/map_list.py +128 -0
- ck/utils/map_set.py +128 -0
- ck/utils/np_extras.py +51 -0
- ck/utils/random_extras.py +64 -0
- ck/utils/tmp_dir.py +94 -0
- ck_demos/__init__.py +0 -0
- ck_demos/ace/__init__.py +0 -0
- ck_demos/ace/copy_ace_to_ck.py +15 -0
- ck_demos/ace/demo_ace.py +49 -0
- ck_demos/all_demos.py +88 -0
- ck_demos/circuit/__init__.py +0 -0
- ck_demos/circuit/demo_circuit_dump.py +22 -0
- ck_demos/circuit/demo_derivatives.py +43 -0
- ck_demos/circuit_compiler/__init__.py +0 -0
- ck_demos/circuit_compiler/compare_circuit_compilers.py +32 -0
- ck_demos/circuit_compiler/show_llvm_program.py +26 -0
- ck_demos/pgm/__init__.py +0 -0
- ck_demos/pgm/demo_pgm_dump.py +18 -0
- ck_demos/pgm/demo_pgm_dump_stress.py +18 -0
- ck_demos/pgm/demo_pgm_string_rendering.py +15 -0
- ck_demos/pgm/show_examples.py +25 -0
- ck_demos/pgm_compiler/__init__.py +0 -0
- ck_demos/pgm_compiler/compare_pgm_compilers.py +63 -0
- ck_demos/pgm_compiler/demo_compiler_dump.py +60 -0
- ck_demos/pgm_compiler/demo_factor_elimination.py +47 -0
- ck_demos/pgm_compiler/demo_join_tree.py +25 -0
- ck_demos/pgm_compiler/demo_marginals_program.py +53 -0
- ck_demos/pgm_compiler/demo_mpe_program.py +55 -0
- ck_demos/pgm_compiler/demo_pgm_compiler.py +38 -0
- ck_demos/pgm_compiler/demo_recursive_conditioning.py +33 -0
- ck_demos/pgm_compiler/demo_variable_elimination.py +33 -0
- ck_demos/pgm_compiler/demo_wmc_program.py +29 -0
- ck_demos/pgm_compiler/time_fe_compiler.py +93 -0
- ck_demos/pgm_inference/__init__.py +0 -0
- ck_demos/pgm_inference/demo_inferencing_basic.py +188 -0
- ck_demos/pgm_inference/demo_inferencing_mpe_cancer.py +45 -0
- ck_demos/pgm_inference/demo_inferencing_wmc_and_mpe_sprinkler.py +154 -0
- ck_demos/pgm_inference/demo_inferencing_wmc_student.py +110 -0
- ck_demos/programs/__init__.py +0 -0
- ck_demos/programs/demo_program_buffer.py +24 -0
- ck_demos/programs/demo_program_multi.py +24 -0
- ck_demos/programs/demo_program_none.py +19 -0
- ck_demos/programs/demo_program_single.py +23 -0
- ck_demos/programs/demo_raw_program_interpreted.py +21 -0
- ck_demos/programs/demo_raw_program_llvm.py +21 -0
- ck_demos/sampling/__init__.py +0 -0
- ck_demos/sampling/check_sampler.py +71 -0
- ck_demos/sampling/demo_marginal_direct_sampler.py +40 -0
- ck_demos/sampling/demo_uniform_sampler.py +38 -0
- ck_demos/sampling/demo_wmc_direct_sampler.py +40 -0
- ck_demos/utils/__init__.py +0 -0
- ck_demos/utils/compare.py +120 -0
- ck_demos/utils/convert_network.py +45 -0
- ck_demos/utils/sample_model.py +216 -0
- ck_demos/utils/stop_watch.py +384 -0
- compiled_knowledge-4.0.0a20.dist-info/METADATA +50 -0
- compiled_knowledge-4.0.0a20.dist-info/RECORD +178 -0
- compiled_knowledge-4.0.0a20.dist-info/WHEEL +6 -0
- compiled_knowledge-4.0.0a20.dist-info/licenses/LICENSE.txt +21 -0
- compiled_knowledge-4.0.0a20.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .pgm_circuit import PGMCircuit
|
|
@@ -0,0 +1,352 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import random
|
|
4
|
+
from typing import Sequence, Optional, Tuple, List, Iterable, Dict
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
|
|
8
|
+
from ck.circuit import CircuitNode, Circuit
|
|
9
|
+
from ck.pgm import RandomVariable, number_of_states, rv_instances_as_indicators
|
|
10
|
+
from ck.pgm_circuit import PGMCircuit
|
|
11
|
+
from ck.pgm_circuit.program_with_slotmap import ProgramWithSlotmap
|
|
12
|
+
from ck.pgm_circuit.slot_map import SlotMap
|
|
13
|
+
from ck.pgm_circuit.support.compile_circuit import compile_results
|
|
14
|
+
from ck.probability.probability_space import ProbabilitySpace, check_condition, Condition
|
|
15
|
+
from ck.program.program_buffer import ProgramBuffer
|
|
16
|
+
from ck.program.raw_program import RawProgram
|
|
17
|
+
from ck.sampling.marginals_direct_sampler import MarginalsDirectSampler
|
|
18
|
+
from ck.sampling.sampler import Sampler
|
|
19
|
+
from ck.sampling.sampler_support import SamplerInfo, get_sampler_info
|
|
20
|
+
from ck.sampling.uniform_sampler import UniformSampler
|
|
21
|
+
from ck.utils.np_extras import NDArray, NDArrayNumeric
|
|
22
|
+
from ck.utils.random_extras import Random
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class MarginalsProgram(ProgramWithSlotmap, ProbabilitySpace):
    """
    A class for computing marginal probability distributions over states of selected output random variables.
    This class provides, for each indicator, the product of indicator value with the derivative
    of the network function with respect to the indicator.

    Compile the circuit for computing marginal probability distributions using the
    so-called 'differential' approach.

    Reference: Darwiche, A. (2003). A differential approach to inference in Bayesian
    networks. Journal of the ACM (JACM), 50(3), 280-305.

    Layout of the program results buffer: one value per indicator of each output
    random variable (in the order of `output_rvs`), followed by a single final
    element holding the WMC value.

    A note about samplers
    ---------------------

    When creating a sampler, a client may request that samples are conditioned
    on provided condition indicators. Also, the MarginalsProgram may have been
    produced with compiled-in conditions, e.g., using const_conditions with
    a call to PGM_cct.wmc(...).

    The conditions respected by a sampler are the conjunction of the compiled
    conditions and the sampler conditions. For example, with compiled condition
    (A[0], A[1], A[2]) and sampler condition (A[1], A[2], A[3]) the effective
    condition is (A[1], A[2]), i.e., a sample of A may be 1 or 2.

    Warning:
        if the sampled random variables include conditions, those conditions
        must be provided to the sampler. If a sampled random variable is conditioned
        at compile time, but not passed to the sampler, then the sampler will not
        be aware of the conditions, and unexpected sample values may be produced.
    """

    def __init__(
            self,
            pgm_circuit: PGMCircuit,
            output_rvs: Optional[Sequence[RandomVariable]] = None,
            const_parameters: bool = True,
    ):
        """
        Construct a MarginalsProgram object.

        The compiled program produces marginal outputs in the order
        of output_rvs indicators, followed by the wmc output.

        Args:
            pgm_circuit: The circuit representing a PGM.
            output_rvs: if None, the output rvs are all rvs, otherwise the given rvs.
            const_parameters: if True then any circuit variable representing a parameter value will
                be made 'const' in the resulting program.
        """
        top_node: CircuitNode = pgm_circuit.circuit_top
        circuit: Circuit = top_node.circuit
        slot_map: SlotMap = pgm_circuit.slot_map
        input_rvs: Sequence[RandomVariable] = pgm_circuit.rvs

        output_rvs: Sequence[RandomVariable] = tuple(output_rvs) if output_rvs is not None else input_rvs

        # Circuit var slots of each indicator of each output rv (computed once, reused below).
        output_rvs_slots = [[slot_map[ind] for ind in rv] for rv in output_rvs]
        flat_out_rv_vars = [circuit.vars[slot] for slots in output_rvs_slots for slot in slots]
        # self_multiply=True: each result is (indicator value) * d(top)/d(indicator),
        # as described in the class docstring.
        derivatives = circuit.partial_derivatives(top_node, flat_out_rv_vars, self_multiply=True)

        # Results: one derivative per output indicator, then the top node (the WMC) last.
        raw_program: RawProgram = compile_results(
            pgm_circuit=pgm_circuit,
            results=derivatives + [top_node],
            const_parameters=const_parameters,
        )

        program_buffer = ProgramBuffer(raw_program)
        ProgramWithSlotmap.__init__(self, program_buffer, slot_map, input_rvs, pgm_circuit.conditions)

        # For each output rv:
        #   _rv_idx_to_result_offset: `RandomVariable.idx` -> offset into the result buffer
        #   _rv_idx_to_output_index:  `RandomVariable.idx` -> position in `output_rvs`
        #   _marginals:               slices of the results buffer (views, not copies),
        #                             so they reflect every later computation.
        self._rv_idx_to_result_offset: Dict[int, int] = {}
        self._rv_idx_to_output_index: Dict[int, int] = {}
        self._marginals: List[NDArrayNumeric] = []
        results = program_buffer.results
        offset: int = 0
        for output_index, (rv, rv_slots) in enumerate(zip(output_rvs, output_rvs_slots)):
            end = offset + len(rv_slots)
            self._rv_idx_to_result_offset[rv.idx] = offset
            self._rv_idx_to_output_index[rv.idx] = output_index
            self._marginals.append(results[offset:end])
            offset = end

        # additional fields
        self._raw_program: RawProgram = raw_program
        self._program_buffer: ProgramBuffer = program_buffer
        self._number_of_indicators: int = pgm_circuit.number_of_indicators
        self._output_rvs = output_rvs
        self._output_rvs_slots = output_rvs_slots
        self._z_cache: Optional[float] = None

        if not const_parameters:
            # set the parameter slots (they follow the indicator slots)
            self.vars[pgm_circuit.number_of_indicators:] = pgm_circuit.parameter_values

    @property
    def output_rvs(self):
        """
        What random variables are included in the marginal probabilities calculations.
        """
        return self._output_rvs

    def wmc(self, *condition: Condition) -> float:
        """
        What is the weight of the world with the given indicators.
        If multiple indicators from the same random variable are mentioned, then it is treated as a disjunction.
        If a random variable is not mentioned in the indicators, that random variable is marginalised out.

        Note: this runs the program buffer directly, so the marginal slices of the
        results buffer are left unnormalised by this call.
        """
        self.set_condition(*condition)
        self._program_buffer.compute()
        return self.result_wmc

    @property
    def z(self):
        """
        The WMC with every indicator slot set to 1, cached after the first call.

        The current indicator values are saved, all indicator slots are set to 1,
        the program is run, then the saved values are restored (even if the
        computation raises).

        Note: after the first call, the results buffer holds the output of that
        all-ones computation (with unnormalised marginals).
        NOTE(review): the cache is never invalidated; if parameter slots are changed
        after the first call (possible when const_parameters=False), `z` is stale.
        """
        if self._z_cache is None:
            number_of_indicators: int = self._number_of_indicators
            slots: NDArray = self.vars
            old_vals: NDArray = slots[:number_of_indicators].copy()
            slots[:number_of_indicators] = 1
            try:
                self._program_buffer.compute()
                self._z_cache = self.result_wmc
            finally:
                # Always restore the caller's indicator values.
                slots[:number_of_indicators] = old_vals
        return self._z_cache

    def marginal_distribution(self, *rvs: RandomVariable, condition: Condition = ()) -> NDArrayNumeric:
        """
        Compute the joint marginal distribution over the states of the given random variables.

        Each random variable is assumed to be in `self.output_rvs`.

        Args:
            rvs: the random variables of the joint distribution.
            condition: zero or more conditioning indicators.

        Returns:
            a new numpy array (never a view onto the program buffer) with one entry per
            combination of states of `rvs`, in the order of `rv_instances_as_indicators(*rvs)`.
            Entries are NaN if the condition has zero weight. With no `rvs`, returns [1.0],
            or [nan] if the condition has zero weight.
        """
        condition = check_condition(condition)

        # Check for easy cases.
        if len(rvs) == 0:
            if self.wmc(*condition) == 0:
                return np.array([np.nan])
            return np.array([1.0])
        if len(rvs) == 1:
            # marginal_for_rv returns a view onto the program buffer, which later
            # computations overwrite; return a copy like the general case below.
            return self.marginal_for_rv(rvs[0], condition=condition).copy()

        # We try to eliminate searching combinations of probabilities where marginals are zero.
        # If there are no marginal probabilities = 0, then this is equivalent to
        # ProbabilitySpace.marginal_distribution

        rvs_marginals = self.marginal_for_rvs(rvs, condition=condition)
        # Built eagerly: `rvs_marginals` are buffer views that later wmc calls overwrite.
        zero_indicators = set(
            ind
            for rv, rv_marginal in zip(rvs, rvs_marginals)
            for ind, marginal in zip(rv, rv_marginal)
            if marginal == 0
        )
        raw_wmc = self._get_wmc_for_marginals(rvs, condition)

        if len(zero_indicators) == 0:
            wmc = raw_wmc
        else:
            def wmc(indicators):
                # Any instance containing a zero-marginal indicator has zero weight.
                for ind in indicators:
                    if ind in zero_indicators:
                        return 0
                return raw_wmc(indicators)

        result = np.fromiter(
            (wmc(indicators) for indicators in rv_instances_as_indicators(*rvs)),
            count=number_of_states(*rvs),
            dtype=np.float64
        )
        _normalise_marginal(result)
        return result

    def marginal_for_rv(self, rv: RandomVariable, condition: Condition = ()) -> NDArrayNumeric:
        """
        Compute and return marginal distribution over the given random variable.
        The random variable is assumed to be in self.output_rvs.

        Returns:
            a numpy array representing the marginal distribution over the states of 'rv'.
            The array is a view onto the program results buffer, so later computations
            overwrite it; copy it if it must be kept.
        """
        self.compute_conditioned(*condition)
        return self.result_for_rv(rv)

    def marginal_for_rvs(self, rvs: Iterable[RandomVariable], condition: Condition = ()) -> List[NDArrayNumeric]:
        """
        Compute and return marginal distribution over the given random variables.
        Each random variable is assumed to be in self.output_rvs.

        Returns:
            a list of numpy arrays representing the marginal distribution over the
            states of each rv in the given random variables, `rvs`. Each array is a
            view onto the program results buffer.
        """
        self.compute_conditioned(*condition)
        marginals = self._marginals
        rv_idx_to_output_index = self._rv_idx_to_output_index
        return [marginals[rv_idx_to_output_index[rv.idx]] for rv in rvs]

    def compute(self) -> NDArrayNumeric:
        """
        Execute the program, normalise each output marginal in place, and return
        the results buffer (not a copy). The final WMC element is not normalised.
        """
        self._program_buffer.compute()
        for part in self._marginals:
            _normalise_marginal(part)
        return self._program_buffer.results

    @property
    def result_wmc(self) -> float:
        """
        Assuming the result has been computed,
        return the WMC value (the last element of the results buffer).
        """
        return self._program_buffer.results.item(-1)

    @property
    def result_marginals(self) -> List[NDArrayNumeric]:
        """
        Assuming the result has been computed,
        return the marginal distributions of each random variable, co-indexed with the
        output random variables, `self.output_rvs`.

        Returns:
            a list of numpy arrays, the list co-indexed with `self.output_rvs`, each numpy array
            representing the marginal distribution over the states of the co-indexed random variable.
        """
        return self._marginals

    def result_for_rv(self, rv: RandomVariable) -> NDArrayNumeric:
        """
        Assuming the result has been computed,
        return marginal distribution over the given random variable.
        The random variable is assumed to be in self.output_rvs.

        Returns:
            a numpy array representing the marginal distribution over the states of 'rv'.
        """
        return self._marginals[self._rv_idx_to_output_index[rv.idx]]

    def sample_uniform(
            self,
            rvs: Optional[RandomVariable | Sequence[RandomVariable]] = None,
            *,
            condition: Condition = (),
            rand: Random = random,
    ) -> Sampler:
        """
        Create a sampler that performs uniform sampling of
        the state space of the given random variables, rvs.

        The sampler will yield state lists, where the state
        values are co-indexed with rvs, or self.rvs if rvs is None.

        This sampler is not affected by and does not affect
        the state of input slots.

        Args:
            rvs: the list of random variables to sample; the
                yielded state vectors are co-indexed with rvs; if None,
                then the self.rvs are used; if rvs is a single
                random variable, then single samples are yielded.
            condition: is a collection of zero or more conditioning indicators.
            rand: provides the stream of random numbers.

        Returns:
            a Sampler object (UniformSampler).
        """
        return UniformSampler(
            rvs=(self.rvs if rvs is None else rvs),
            condition=condition,
            rand=rand,
        )

    def sample_direct(
            self,
            rvs: Optional[RandomVariable | Sequence[RandomVariable]] = None,
            *,
            condition: Condition = (),
            rand: Random = random,
            chain_pairs: Sequence[Tuple[RandomVariable, RandomVariable]] = (),
            initial_chain_condition: Condition = (),
    ) -> Sampler:
        """
        Create an inverse-transform sampler, which uses the fact that marginal
        probabilities are exactly computable with a single execution of the program.

        The sampler will yield state lists, where the state
        values are co-indexed with rvs, or self.rvs if rvs is None.

        Args:
            rvs: the list of random variables to sample; the
                yielded state vectors are co-indexed with rvs; if None,
                then self.rvs are used; if rvs is a single
                random variable, then single samples are yielded.
            condition: is a collection of zero or more conditioning indicators.
            rand: provides the stream of random numbers.
            chain_pairs: is a collection of pairs of random variables, each random variable
                must be in the given rvs. Given a pair (from_rv, to_rv) the state of from_rv is used
                as a condition for to_rv prior to generating a sample.
            initial_chain_condition: are condition indicators (just like condition)
                for the initialisation of the 'to_rv' random variables mentioned in chain_pairs.

        Returns:
            a Sampler object (MarginalsDirectSampler).
        """
        sampler_info: SamplerInfo = get_sampler_info(
            program_with_slotmap=self,
            rvs=rvs,
            condition=condition,
            chain_pairs=chain_pairs,
            initial_chain_condition=initial_chain_condition,
        )

        return MarginalsDirectSampler(
            sampler_info=sampler_info,
            raw_program=self._raw_program,
            rand=rand,
            rv_idx_to_result_offset=self._rv_idx_to_result_offset,
        )
|
|
341
|
+
|
|
342
|
+
|
|
343
|
+
def _normalise_marginal(distribution: NDArrayNumeric) -> None:
    """
    Normalise `distribution` in place so that its entries sum to one.

    If the entries sum to zero or less, every entry is set to NaN, because no
    meaningful distribution exists. An array whose sum is exactly 1 is left as-is.
    """
    mass = np.sum(distribution)
    if mass <= 0:
        distribution[:] = np.nan
        return
    if mass == 1:
        # Already normalised; skip the division.
        return
    distribution /= mass
|
|
@@ -0,0 +1,237 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from functools import partial
|
|
5
|
+
from typing import Sequence, Optional, Tuple, List, Dict, Set
|
|
6
|
+
|
|
7
|
+
from ck.circuit import CircuitNode, Circuit, VarNode, OpNode, ADD, MUL
|
|
8
|
+
from ck.circuit_compiler import llvm_vm_compiler, CircuitCompiler
|
|
9
|
+
from ck.pgm import RandomVariable, Instance
|
|
10
|
+
from ck.pgm_circuit import PGMCircuit
|
|
11
|
+
from ck.pgm_circuit.program_with_slotmap import ProgramWithSlotmap
|
|
12
|
+
from ck.pgm_circuit.slot_map import SlotMap
|
|
13
|
+
from ck.pgm_circuit.support.compile_circuit import compile_results
|
|
14
|
+
from ck.probability.probability_space import check_condition
|
|
15
|
+
from ck.program.program_buffer import ProgramBuffer
|
|
16
|
+
from ck.program.raw_program import RawProgram
|
|
17
|
+
from ck.utils.np_extras import NDArray, NDArrayNumeric
|
|
18
|
+
|
|
19
|
+
# Sentinel placed in `inv_trace_blocks` for circuit slots that are not an indicator
# of any traced random variable. It is compared by identity (`is`), so this one
# object must be the value used to fill the list.
_NO_TRACE = (-1, -1)

# Python module used for compiling an MPE circuit. MPEProgram.__init__ uses its
# `compile_circuit` function and its `DataType.MAX_SUM` / `DataType.MAX_MUL` members.
_CCT_COMPILER = llvm_vm_compiler
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class MPEProgram(ProgramWithSlotmap):
|
|
25
|
+
"""
|
|
26
|
+
A class for computing Most Probable Explanation (MPE). This is equivalent to
|
|
27
|
+
Maximum A Posterior (MAP) inference when there are no latent random variables.
|
|
28
|
+
"""
|
|
29
|
+
|
|
30
|
+
def __init__(
        self,
        pgm_circuit: PGMCircuit,
        trace_rvs: Optional[Sequence[RandomVariable]] = None,
        const_parameters: bool = True,
        log_parameters: bool = False,
):
    """
    Construct a MPEProgram object.

    Compile the circuit for computing Most Probable Explanation (MPE). This is equivalent to
    Maximum A Posteriori (MAP) inference when there are no latent variables.

    The reachable op nodes of the given circuit are compiled with a max-product data type
    (max-sum when `log_parameters` is True), so 'add' nodes act as 'max' nodes in the
    resulting program. Every reachable op node is a program result, so a computation can
    be traced back from the top node to recover the MPE states of `trace_rvs`.

    Assumes that all operator nodes to compute top are either an add or mul node.

    Args:
        pgm_circuit: The circuit representing a PGM.
        trace_rvs: the random variables to compute MPE for, default is all random variables of the PGM.
        const_parameters: if True then any circuit variable representing a parameter value will
            be made 'const' in the resulting program.
        log_parameters: if true, then parameters are taken to be logs, i.e., uses addition instead
            of multiplication.

    Raises:
        ValueError: if `trace_rvs` contains a duplicate, or if two trace indicators map
            to the same circuit slot.
    """
    # Normalise to a tuple in both cases, as the annotation promises.
    trace_rvs: Tuple[RandomVariable, ...] = tuple(pgm_circuit.rvs if trace_rvs is None else trace_rvs)
    if len(trace_rvs) != len(set(trace_rvs)):
        raise ValueError('duplicated trace random variable detected')

    top: CircuitNode = pgm_circuit.circuit_top
    circuit: Circuit = top.circuit
    slot_map: SlotMap = pgm_circuit.slot_map

    data_type = _CCT_COMPILER.DataType.MAX_SUM if log_parameters else _CCT_COMPILER.DataType.MAX_MUL
    cct_compiler: CircuitCompiler = partial(_CCT_COMPILER.compile_circuit, data_type=data_type)

    # make inv_trace_blocks
    #
    # inv_trace_blocks[slot] = (rv_trace_idx, state_idx)
    # where
    #   rv_trace_idx is an index into trace_rvs,
    #   state_idx is an index into trace_rvs[rv_trace_idx] indicators,
    #
    # slot = slot_map[ind], where ind = trace_rvs[rv_trace_idx][state_idx].
    # Slots that are not traced indicators hold the `_NO_TRACE` sentinel.
    inv_trace_blocks: List[Tuple[int, int]] = [_NO_TRACE] * circuit.number_of_vars
    for rv_trace_idx, trace_rv in enumerate(trace_rvs):
        for state_idx in trace_rv.state_range():
            slot: int = slot_map[trace_rv[state_idx]]
            if inv_trace_blocks[slot] is not _NO_TRACE:
                raise ValueError('unexpected reused circuit slot')
            inv_trace_blocks[slot] = (rv_trace_idx, state_idx)

    # Every reachable op node becomes a program result, so tracing can read any node's value.
    used_nodes: List[CircuitNode] = list(circuit.reachable_op_nodes(top))

    # Map from node identity to its index in the program results.
    mpe_idx: Dict[int, int] = {
        id(used_node): used_node_idx
        for used_node_idx, used_node in enumerate(used_nodes)
    }

    # create a dummy MPE result until compute is called
    dummy_result = MPEResult(float('nan'), tuple(0 for _ in trace_rvs))

    self._trace_rvs: Tuple[RandomVariable, ...] = trace_rvs
    self._inv_trace_blocks = inv_trace_blocks
    self._top: CircuitNode = top
    self._mpe_result: MPEResult = dummy_result

    self._top_idx: Optional[int] = mpe_idx.get(id(top))  # it may be possible that top is not an op node.
    self._used_nodes: List[CircuitNode] = used_nodes
    self._mpe_idx: Dict[int, int] = mpe_idx

    raw_program: RawProgram = compile_results(
        pgm_circuit=pgm_circuit,
        results=used_nodes,
        const_parameters=const_parameters,
        compiler=cct_compiler,
    )
    ProgramWithSlotmap.__init__(self, ProgramBuffer(raw_program), slot_map, pgm_circuit.rvs, pgm_circuit.conditions)

    if not const_parameters:
        # set the parameter slots (they follow the indicator slots)
        self.vars[pgm_circuit.number_of_indicators:] = pgm_circuit.parameter_values
|
|
122
|
+
|
|
123
|
+
def mpe(self, *condition) -> MPEResult:
    """
    What is the MPE, given any conditioning indicators.

    A traced random variable whose state is left undetermined by the trace (i.e., the
    solution is indifferent to the state of that random variable) is reported as state 0
    (see `_trace`); any other state of that random variable would equally form a valid
    MPE solution.

    Args:
        condition: zero or more conditioning indicators.

    Returns:
        an MPEResult with field `wmc` and `mpe`.
        wmc: is the value of the weighted model count.
        mpe: is an Instance, co-indexed with trace vars, where mpe[rv_idx] = state_idx.
    """
    condition = check_condition(condition)
    # NOTE(review): relies on compute_conditioned (base class) invoking self.compute(),
    # which traces the results into self._mpe_result.
    self.compute_conditioned(*condition)
    return self._mpe_result
|
|
141
|
+
|
|
142
|
+
@property
def trace_rvs(self) -> Sequence[RandomVariable]:
    """
    The random variables whose MPE states are reported; `MPEResult.mpe`
    is co-indexed with this sequence.
    """
    traced = self._trace_rvs
    return traced
|
|
148
|
+
|
|
149
|
+
def compute(self) -> NDArrayNumeric:
    """
    Run the program, refresh the MPE trace, and return the program result.

    The result is whatever `ProgramBuffer.compute` returns. After the program has run,
    the MPE trace is updated so that `mpe_result` reflects this computation.

    Warning:
        when returning an array, the array is backed by the program buffer memory, not a copy.
    """
    computed = self._program_buffer.compute()
    # Trace after computing, so that the trace reads this run's values.
    self._trace()
    return computed
|
|
159
|
+
|
|
160
|
+
@property
def mpe_result(self) -> MPEResult:
    """
    The MPEResult recorded by the most recent trace of a program computation.

    Returns:
        the stored MPEResult object (not recomputed by this call).
    """
    latest = self._mpe_result
    return latest
|
|
169
|
+
|
|
170
|
+
def _trace(self) -> None:
    """
    Trace the last program computation to determine the wmc and the mpe states.

    Sets `self._mpe_result`. Traced random variables whose state is not fixed by the
    trace are reported as state 0.

    The wmc of the circuit top is read from:
      * the program results, when the top is an op node (it has a results index);
      * the circuit variables, when the top is itself a VarNode.
    For any other kind of top node (presumably a constant), no trace is performed and
    `self._mpe_result` is left unchanged.
    """
    top = self._top
    if self._top_idx is not None:
        wmc: float = self.results.item(self._top_idx)
    elif isinstance(top, VarNode):
        # A var node has no slot in the program results; its value lives in `vars`,
        # just as for var children in `_trace_r`.
        wmc = self.vars.item(top.idx)
    else:
        return

    states: List[Optional[int]] = [None for _ in self._trace_rvs]
    seen: Set[int] = set()
    self._trace_r(top, wmc, states, seen)

    # Any state is valid for an undetermined random variable; use 0.
    mpe = tuple(0 if state_idx is None else state_idx for state_idx in states)
    self._mpe_result = MPEResult(wmc, mpe)
|
|
184
|
+
|
|
185
|
+
def _trace_r(self, node: CircuitNode, node_value: float, states: List[Optional[int]], seen: Set[int]) -> None:
    """
    Trace the circuit below `node`, recording selected indicator states in `states`.

    At an ADD node, only the first child whose value equals the node's value is followed,
    since that child led to the node's value. At a MUL node, every child is followed.
    A circuit is a DAG, not necessarily a tree, so a node whose `id` is already in `seen`
    is not revisited.

    An explicit stack is used instead of recursion, so that deep circuits cannot exceed
    Python's recursion limit. Nodes are processed in the same depth-first, left-to-right
    order that a recursive walk would use.

    Args:
        node: the node to start tracing from.
        node_value: the value of `node` from the last program computation.
        states: per trace random variable state indices; updated in place.
        seen: ids of nodes already traced; updated in place.
    """
    # Stack entries are (node, value). A value of None marks a VarNode child of an op
    # node. Such a child is traced directly, without consulting or updating `seen`,
    # exactly as var children are handled inline by a recursive walk.
    stack: List[tuple] = [(node, node_value)]
    while stack:
        current, value = stack.pop()
        if value is None:
            self._trace_var(current, states)
            continue

        # No need to revisit nodes.
        if id(current) in seen:
            continue
        seen.add(id(current))

        if isinstance(current, VarNode):
            self._trace_var(current, states)
        elif isinstance(current, OpNode):
            if current.symbol == ADD:
                # Find which child node led to the max result, then follow it only.
                for child in current.args:
                    if isinstance(child, OpNode):
                        child_value: float = self.results.item(self._mpe_idx[id(child)])
                        if child_value == value:
                            stack.append((child, child_value))
                            break
                    elif isinstance(child, VarNode):
                        if self.vars.item(child.idx) == value:
                            stack.append((child, None))
                            break
                else:
                    # No child value equaled the value for node! We should never get here
                    assert False, 'not reached'
            elif current.symbol == MUL:
                # Every child of a product node is part of the solution.
                children: List[tuple] = []
                for child in current.args:
                    if isinstance(child, OpNode):
                        children.append((child, self.results.item(self._mpe_idx[id(child)])))
                    elif isinstance(child, VarNode):
                        children.append((child, None))
                # Push in reverse so children are popped left-to-right, preserving the
                # depth-first order of the recursive formulation.
                stack.extend(reversed(children))
|
|
219
|
+
|
|
220
|
+
def _trace_var(self, node: VarNode, states: List[Optional[int]]) -> None:
    """
    Record the state associated with circuit variable `node`, if it has one.

    `_inv_trace_blocks[node.idx]` is either `_NO_TRACE` or a pair
    (trace random variable index, state index). For a pair, `states` is updated in place.
    """
    entry = self._inv_trace_blocks[node.idx]
    if entry is _NO_TRACE:
        return
    trace_position, state = entry
    states[trace_position] = state
|
|
225
|
+
|
|
226
|
+
|
|
227
|
+
@dataclass
class MPEResult:
    """
    An MPE result is the result of MPE inference.

    Fields:
        wmc: the weighted model count value of the MPE solution.
        mpe: The MPE solution instance, co-indexed with the traced random variables.
            If there are ties then this will just be one instance.
    """
    wmc: float
    mpe: Instance
|
|
@@ -0,0 +1,79 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
from typing import Sequence, List, Dict
|
|
3
|
+
|
|
4
|
+
from ck.circuit import CircuitNode, Circuit
|
|
5
|
+
from ck.pgm import RandomVariable, Indicator
|
|
6
|
+
from ck.pgm_circuit.slot_map import SlotMap, SlotKey
|
|
7
|
+
from ck.utils.np_extras import NDArray
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@dataclass
class PGMCircuit:
    """
    A data structure representing the results of compiling a PGM to a circuit.

    If the circuit contains variables to represent parameter values, then `parameter_values`
    holds the values of the parameters. Specifically, given parameter id `param_id`, then
    `parameter_values[slot_map[param_id] - number_of_indicators]` is the value of the
    identified parameter as it was in the PGM.

    Fields:
        rvs: holds the random variables from the PGM as it was compiled, in order.

        conditions: any conditions on `rvs` that were compiled into the circuit.

        circuit_top: the top node of the compiled circuit; `circuit_top.circuit` is the circuit.

        number_of_indicators: is the number of indicators in `rvs` which is
            `sum(len(rv) for rv in rvs)`. Specifically, `circuit.vars[i]` is the circuit variable
            corresponding to the ith indicator, where `circuit` is `circuit_top.circuit` and
            indicators are ordered as per `rvs`.

        number_of_parameters: is the number of parameters from the PGM that are
            represented as circuit variables. This may be zero if parameters from the PGM
            were compiled as constants.

        slot_map[x]: gives the index of the circuit variable corresponding to x,
            where x is either a random variable indicator (Indicator) or a parameter id (ParamId).

        parameter_values: the parameter values, as described above.
    """

    rvs: Sequence[RandomVariable]
    conditions: Sequence[Indicator]
    circuit_top: CircuitNode
    number_of_indicators: int
    number_of_parameters: int
    slot_map: SlotMap
    parameter_values: NDArray

    def dump(self, *, prefix: str = '', indent: str = ' ') -> None:
        """
        Print a dump of the circuit.
        This is intended for debugging and demonstration purposes.

        Each circuit variable is given a name: indicator variables are named by their
        random variable and state, and parameter variables by their parameter index and value.
        Variables that are neither keep an empty name.

        Args:
            prefix: optional prefix for indenting all lines.
            indent: additional prefix to use for extra indentation.
        """

        # We infer names for the circuit variables, either as an indicator or as a parameter.
        # The `var_names` will be passed to `circuit.dump`.

        circuit: Circuit = self.circuit_top.circuit
        var_names: List[str] = [''] * circuit.number_of_vars

        # Name the circuit variables that are indicators
        rvs_by_idx: Dict[int, RandomVariable] = {rv.idx: rv for rv in self.rvs}
        slot_key: SlotKey
        slot: int
        for slot_key, slot in self.slot_map.items():
            if isinstance(slot_key, Indicator):
                rv = rvs_by_idx[slot_key.rv_idx]
                state_idx = slot_key.state_idx
                var_names[slot] = f'{rv.name!r}[{state_idx}] {rv.states[state_idx]!r}'

        # Name the circuit variables that are parameters.
        # Parameter i occupies the circuit variable immediately after all the indicators.
        for i, param_value in enumerate(self.parameter_values):
            slot = i + self.number_of_indicators
            var_names[slot] = f'param[{i}] = {param_value}'

        # Dump the circuit
        circuit.dump(prefix=prefix, indent=indent, var_names=var_names)
|