compiled-knowledge 4.0.0a5__cp313-cp313-macosx_10_13_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of compiled-knowledge might be problematic. Click here for more details.

Files changed (167)
  1. ck/__init__.py +0 -0
  2. ck/circuit/__init__.py +13 -0
  3. ck/circuit/circuit.c +38749 -0
  4. ck/circuit/circuit.cpython-313-darwin.so +0 -0
  5. ck/circuit/circuit_py.py +807 -0
  6. ck/circuit/tmp_const.py +74 -0
  7. ck/circuit_compiler/__init__.py +2 -0
  8. ck/circuit_compiler/circuit_compiler.py +26 -0
  9. ck/circuit_compiler/cython_vm_compiler/__init__.py +1 -0
  10. ck/circuit_compiler/cython_vm_compiler/_compiler.c +17373 -0
  11. ck/circuit_compiler/cython_vm_compiler/_compiler.cpython-313-darwin.so +0 -0
  12. ck/circuit_compiler/cython_vm_compiler/cython_vm_compiler.py +96 -0
  13. ck/circuit_compiler/interpret_compiler.py +223 -0
  14. ck/circuit_compiler/llvm_compiler.py +388 -0
  15. ck/circuit_compiler/llvm_vm_compiler.py +546 -0
  16. ck/circuit_compiler/named_circuit_compilers.py +57 -0
  17. ck/circuit_compiler/support/__init__.py +0 -0
  18. ck/circuit_compiler/support/circuit_analyser.py +81 -0
  19. ck/circuit_compiler/support/input_vars.py +148 -0
  20. ck/circuit_compiler/support/llvm_ir_function.py +234 -0
  21. ck/example/__init__.py +53 -0
  22. ck/example/alarm.py +366 -0
  23. ck/example/asia.py +28 -0
  24. ck/example/binary_clique.py +32 -0
  25. ck/example/bow_tie.py +33 -0
  26. ck/example/cancer.py +37 -0
  27. ck/example/chain.py +38 -0
  28. ck/example/child.py +199 -0
  29. ck/example/clique.py +33 -0
  30. ck/example/cnf_pgm.py +39 -0
  31. ck/example/diamond_square.py +68 -0
  32. ck/example/earthquake.py +36 -0
  33. ck/example/empty.py +10 -0
  34. ck/example/hailfinder.py +539 -0
  35. ck/example/hepar2.py +628 -0
  36. ck/example/insurance.py +504 -0
  37. ck/example/loop.py +40 -0
  38. ck/example/mildew.py +38161 -0
  39. ck/example/munin.py +22982 -0
  40. ck/example/pathfinder.py +53674 -0
  41. ck/example/rain.py +39 -0
  42. ck/example/rectangle.py +161 -0
  43. ck/example/run.py +30 -0
  44. ck/example/sachs.py +129 -0
  45. ck/example/sprinkler.py +30 -0
  46. ck/example/star.py +44 -0
  47. ck/example/stress.py +64 -0
  48. ck/example/student.py +43 -0
  49. ck/example/survey.py +46 -0
  50. ck/example/triangle_square.py +54 -0
  51. ck/example/truss.py +49 -0
  52. ck/in_out/__init__.py +3 -0
  53. ck/in_out/parse_ace_lmap.py +216 -0
  54. ck/in_out/parse_ace_nnf.py +288 -0
  55. ck/in_out/parse_net.py +480 -0
  56. ck/in_out/parser_utils.py +185 -0
  57. ck/in_out/pgm_pickle.py +42 -0
  58. ck/in_out/pgm_python.py +268 -0
  59. ck/in_out/render_bugs.py +111 -0
  60. ck/in_out/render_net.py +177 -0
  61. ck/in_out/render_pomegranate.py +184 -0
  62. ck/pgm.py +3494 -0
  63. ck/pgm_circuit/__init__.py +1 -0
  64. ck/pgm_circuit/marginals_program.py +352 -0
  65. ck/pgm_circuit/mpe_program.py +237 -0
  66. ck/pgm_circuit/pgm_circuit.py +75 -0
  67. ck/pgm_circuit/program_with_slotmap.py +234 -0
  68. ck/pgm_circuit/slot_map.py +35 -0
  69. ck/pgm_circuit/support/__init__.py +0 -0
  70. ck/pgm_circuit/support/compile_circuit.py +83 -0
  71. ck/pgm_circuit/target_marginals_program.py +103 -0
  72. ck/pgm_circuit/wmc_program.py +323 -0
  73. ck/pgm_compiler/__init__.py +2 -0
  74. ck/pgm_compiler/ace/__init__.py +1 -0
  75. ck/pgm_compiler/ace/ace.py +252 -0
  76. ck/pgm_compiler/factor_elimination.py +383 -0
  77. ck/pgm_compiler/named_pgm_compilers.py +63 -0
  78. ck/pgm_compiler/pgm_compiler.py +19 -0
  79. ck/pgm_compiler/recursive_conditioning.py +226 -0
  80. ck/pgm_compiler/support/__init__.py +0 -0
  81. ck/pgm_compiler/support/circuit_table/__init__.py +9 -0
  82. ck/pgm_compiler/support/circuit_table/circuit_table.c +16042 -0
  83. ck/pgm_compiler/support/circuit_table/circuit_table.cpython-313-darwin.so +0 -0
  84. ck/pgm_compiler/support/circuit_table/circuit_table_py.py +269 -0
  85. ck/pgm_compiler/support/clusters.py +556 -0
  86. ck/pgm_compiler/support/factor_tables.py +398 -0
  87. ck/pgm_compiler/support/join_tree.py +275 -0
  88. ck/pgm_compiler/support/named_compiler_maker.py +33 -0
  89. ck/pgm_compiler/variable_elimination.py +89 -0
  90. ck/probability/__init__.py +0 -0
  91. ck/probability/empirical_probability_space.py +47 -0
  92. ck/probability/probability_space.py +568 -0
  93. ck/program/__init__.py +3 -0
  94. ck/program/program.py +129 -0
  95. ck/program/program_buffer.py +180 -0
  96. ck/program/raw_program.py +61 -0
  97. ck/sampling/__init__.py +0 -0
  98. ck/sampling/forward_sampler.py +211 -0
  99. ck/sampling/marginals_direct_sampler.py +113 -0
  100. ck/sampling/sampler.py +62 -0
  101. ck/sampling/sampler_support.py +232 -0
  102. ck/sampling/uniform_sampler.py +66 -0
  103. ck/sampling/wmc_direct_sampler.py +169 -0
  104. ck/sampling/wmc_gibbs_sampler.py +147 -0
  105. ck/sampling/wmc_metropolis_sampler.py +159 -0
  106. ck/sampling/wmc_rejection_sampler.py +113 -0
  107. ck/utils/__init__.py +0 -0
  108. ck/utils/iter_extras.py +153 -0
  109. ck/utils/map_list.py +128 -0
  110. ck/utils/map_set.py +128 -0
  111. ck/utils/np_extras.py +51 -0
  112. ck/utils/random_extras.py +64 -0
  113. ck/utils/tmp_dir.py +94 -0
  114. ck_demos/__init__.py +0 -0
  115. ck_demos/ace/__init__.py +0 -0
  116. ck_demos/ace/copy_ace_to_ck.py +15 -0
  117. ck_demos/ace/demo_ace.py +44 -0
  118. ck_demos/all_demos.py +88 -0
  119. ck_demos/circuit/__init__.py +0 -0
  120. ck_demos/circuit/demo_circuit_dump.py +22 -0
  121. ck_demos/circuit/demo_derivatives.py +43 -0
  122. ck_demos/circuit_compiler/__init__.py +0 -0
  123. ck_demos/circuit_compiler/compare_circuit_compilers.py +32 -0
  124. ck_demos/circuit_compiler/show_llvm_program.py +26 -0
  125. ck_demos/pgm/__init__.py +0 -0
  126. ck_demos/pgm/demo_pgm_dump.py +18 -0
  127. ck_demos/pgm/demo_pgm_dump_stress.py +18 -0
  128. ck_demos/pgm/demo_pgm_string_rendering.py +15 -0
  129. ck_demos/pgm/show_examples.py +25 -0
  130. ck_demos/pgm_compiler/__init__.py +0 -0
  131. ck_demos/pgm_compiler/compare_pgm_compilers.py +50 -0
  132. ck_demos/pgm_compiler/demo_compiler_dump.py +50 -0
  133. ck_demos/pgm_compiler/demo_factor_elimination.py +47 -0
  134. ck_demos/pgm_compiler/demo_join_tree.py +25 -0
  135. ck_demos/pgm_compiler/demo_marginals_program.py +53 -0
  136. ck_demos/pgm_compiler/demo_mpe_program.py +55 -0
  137. ck_demos/pgm_compiler/demo_pgm_compiler.py +38 -0
  138. ck_demos/pgm_compiler/demo_recursive_conditioning.py +33 -0
  139. ck_demos/pgm_compiler/demo_variable_elimination.py +33 -0
  140. ck_demos/pgm_compiler/demo_wmc_program.py +29 -0
  141. ck_demos/pgm_inference/__init__.py +0 -0
  142. ck_demos/pgm_inference/demo_inferencing_basic.py +188 -0
  143. ck_demos/pgm_inference/demo_inferencing_mpe_cancer.py +45 -0
  144. ck_demos/pgm_inference/demo_inferencing_wmc_and_mpe_sprinkler.py +154 -0
  145. ck_demos/pgm_inference/demo_inferencing_wmc_student.py +110 -0
  146. ck_demos/programs/__init__.py +0 -0
  147. ck_demos/programs/demo_program_buffer.py +24 -0
  148. ck_demos/programs/demo_program_multi.py +24 -0
  149. ck_demos/programs/demo_program_none.py +19 -0
  150. ck_demos/programs/demo_program_single.py +23 -0
  151. ck_demos/programs/demo_raw_program_interpreted.py +21 -0
  152. ck_demos/programs/demo_raw_program_llvm.py +21 -0
  153. ck_demos/sampling/__init__.py +0 -0
  154. ck_demos/sampling/check_sampler.py +71 -0
  155. ck_demos/sampling/demo_marginal_direct_sampler.py +40 -0
  156. ck_demos/sampling/demo_uniform_sampler.py +38 -0
  157. ck_demos/sampling/demo_wmc_direct_sampler.py +40 -0
  158. ck_demos/utils/__init__.py +0 -0
  159. ck_demos/utils/compare.py +88 -0
  160. ck_demos/utils/convert_network.py +45 -0
  161. ck_demos/utils/sample_model.py +216 -0
  162. ck_demos/utils/stop_watch.py +384 -0
  163. compiled_knowledge-4.0.0a5.dist-info/METADATA +50 -0
  164. compiled_knowledge-4.0.0a5.dist-info/RECORD +167 -0
  165. compiled_knowledge-4.0.0a5.dist-info/WHEEL +5 -0
  166. compiled_knowledge-4.0.0a5.dist-info/licenses/LICENSE.txt +21 -0
  167. compiled_knowledge-4.0.0a5.dist-info/top_level.txt +2 -0
@@ -0,0 +1,234 @@
1
+ from typing import Tuple, Iterator, Sequence, Dict, Iterable
2
+
3
+ from ck.pgm import RandomVariable, rv_instances, Instance, rv_instances_as_indicators, Indicator, ParamId
4
+ from ck.pgm_circuit.slot_map import SlotMap, SlotKey
5
+ from ck.probability.probability_space import Condition, check_condition
6
+ from ck.program.program_buffer import ProgramBuffer
7
+ from ck.utils.np_extras import NDArray, NDArrayNumeric
8
+
9
+
10
class ProgramWithSlotmap:
    """
    A class for bundling a program buffer with a slot-map, where the slot-map maps keys
    (e.g., random variable indicators) to program input slots.
    """

    def __init__(
            self,
            program_buffer: ProgramBuffer,
            slot_map: SlotMap,
            rvs: Sequence[RandomVariable],
            precondition: Sequence[Indicator]
    ):
        """
        Construct a ProgramWithSlotmap object.

        Args:
            program_buffer: a ProgramBuffer object, i.e., a compiled circuit with input and output slots.
            slot_map: a map from a slot key to an input slot of `program_buffer`.
            rvs: the random variables used for setting program input slots. Each rv
                has a length and rv[i] is a unique indicator across all rvs.
                The argument is consumed exactly once, so an iterator is also accepted.
            precondition: conditions on rvs that are compiled into the program.

        Raises:
            ValueError: if the same random variable (judged by `rv.idx`) appears more than once in `rvs`.
        """
        self._program_buffer: ProgramBuffer = program_buffer
        self._slot_map: SlotMap = slot_map
        # Materialise once; every later use reads self._rvs so that iterator input works too.
        self._rvs: Tuple[RandomVariable, ...] = tuple(rvs)
        self._precondition: Sequence[Indicator] = precondition

        if len(self._rvs) != len({rv.idx for rv in self._rvs}):
            raise ValueError('duplicate random variables provided')

        # Given rv = self._rvs[i], then _rvs_slots[i][state_idx] gives the slot for rv[state_idx].
        self._rvs_slots: Tuple[Tuple[int, ...], ...] = tuple(
            tuple(slot_map[ind] for ind in rv)
            for rv in self._rvs
        )

        # Given rv = self._rvs[i], then _indicator_map[rv[j]] = (i, slot), where slot is the slot for rv[j].
        self._indicator_map: Dict[Indicator, Tuple[int, int]] = {
            ind: (i, slot_map[ind])
            for i, rv in enumerate(self._rvs)
            for ind in rv
        }

    @property
    def rvs(self) -> Sequence[RandomVariable]:
        """
        What are the random variables considered as 'inputs'.
        """
        return self._rvs

    @property
    def precondition(self) -> Sequence[Indicator]:
        """
        Condition on `self.rvs` that is compiled into the program.
        """
        return self._precondition

    @property
    def slot_map(self) -> SlotMap:
        """
        The slot map given at construction, mapping slot keys to program input slots.
        """
        return self._slot_map

    def instances(self, flip: bool = False) -> Iterator[Instance]:
        """
        Enumerate instances of the random variables.

        Each instance is a tuple of state indexes, co-indexed with the given random variables.

        The order is the natural index order (i.e., last random variable changing most quickly).

        Args:
            flip: if true, then first random variable changes most quickly.

        Returns:
            an iteration over tuples, each tuple holds state indexes
            co-indexed with the given random variables.
        """
        return rv_instances(*self._rvs, flip=flip)

    def instances_as_indicators(self, flip: bool = False) -> Iterator[Sequence[Indicator]]:
        """
        Enumerate instances of the random variables.

        Each instance is a tuple of indicators, co-indexed with the given random variables.

        The order is the natural index order (i.e., last random variable changing most quickly).

        Args:
            flip: if true, then first random variable changes most quickly.

        Returns:
            an iteration over tuples, each tuple holds random variable indicators
            co-indexed with the given random variables.
        """
        return rv_instances_as_indicators(*self._rvs, flip=flip)

    def compute(self) -> NDArrayNumeric:
        """
        Execute the program to compute and return the result. As per `ProgramBuffer.compute`.

        Warning:
            when returning an array, the array is backed by the program buffer memory, not a copy.
        """
        return self._program_buffer.compute()

    def compute_conditioned(self, *condition: Condition) -> NDArrayNumeric:
        """
        Equivalent to:
            self.set_condition(*condition)
            return self.compute()
        """
        self.set_condition(*condition)
        return self.compute()

    @property
    def results(self) -> NDArrayNumeric:
        """
        Get the results of the last computation.
        As per `ProgramBuffer.results`.

        Warning:
            the array is backed by the program buffer memory, not a copy.
        """
        return self._program_buffer.results

    @property
    def vars(self) -> NDArrayNumeric:
        """
        Return the input variables as a numpy array.
        As per `ProgramBuffer.vars`.

        Warning:
            writing to the returned array will write to the input slots of the program buffer.
        """
        return self._program_buffer.vars

    def __setitem__(self, item: int | slice | SlotKey | Iterable[SlotKey], value: float) -> None:
        """
        Set one or more input slot values.

        Args:
            item: a raw slot index or slice (passed straight to the program buffer),
                a slot key (an Indicator or ParamId, resolved through the slot map),
                or an iterable whose elements are any of these.
            value: the value assigned to every addressed slot.
        """
        if isinstance(item, (int, slice)):
            self._program_buffer[item] = value
        elif isinstance(item, (Indicator, ParamId)):
            self._program_buffer[self._slot_map[item]] = value
        else:
            # Assume it is iterable; each element is dispatched recursively.
            for i in item:
                self[i] = value

    def __getitem__(self, item: int | slice | SlotKey) -> NDArrayNumeric:
        """
        Get an input slot value, identified by a raw slot index/slice or by a slot key.

        Raises:
            IndexError: if `item` is not an int, slice, Indicator or ParamId.
        """
        if isinstance(item, (int, slice)):
            return self._program_buffer[item]
        elif isinstance(item, (Indicator, ParamId)):
            return self._program_buffer[self._slot_map[item]]
        else:
            raise IndexError('unknown index type')

    def set_condition(self, *condition: Condition) -> None:
        """
        Set the input slots of `self.rvs` according to the given condition.

        Specifically:
            each slot corresponding to an indicator in the condition is set to 1;

            if a random variable is mentioned in the condition, then every other slot
            of that random variable is set to 0;

            if a random variable is not mentioned in the condition, then all its
            slots are set to 1 (i.e., it is marginalised).

        Args:
            condition: conditioning indicators, as accepted by `check_condition`.

        Raises:
            KeyError: if the condition mentions an indicator that does not belong
                to a random variable in `self.rvs`.
        """
        condition: Sequence[Indicator] = check_condition(condition)

        # Group the conditioning slots by the (local) index of their random variable.
        ind_slot_groups = [[] for _ in self._rvs_slots]
        for ind in condition:
            rv_idx, slot = self._indicator_map[ind]
            ind_slot_groups[rv_idx].append(slot)

        slots: NDArray = self._program_buffer.vars
        for rv_slots, ind_slots in zip(self._rvs_slots, ind_slot_groups):
            if len(ind_slots) == 0:
                # this rv _is not_ mentioned in the indicators - marginalise it
                for slot in rv_slots:
                    slots[slot] = 1
            else:
                # this rv _is_ mentioned in the indicators - we set the mentioned slots to 1 and others to 0.
                for slot in rv_slots:
                    slots[slot] = 0
                for slot in ind_slots:
                    slots[slot] = 1

    def set_rv(self, rv: RandomVariable, *values: float | int) -> None:
        """
        Set the input values of a random variable, one value per state.

        Args:
            rv: a random variable whose indicators are in the slot map.
            values: the values to set, co-indexed with the states of `rv`.

        Raises:
            ValueError: if len(values) != len(rv); no slot is written in that case.
        """
        # Validate before writing so a bad call never leaves the buffer partially updated.
        if len(values) != len(rv):
            raise ValueError(f'expected {len(rv)} values, got {len(values)}')
        for state_idx, value in enumerate(values):
            self[rv[state_idx]] = value

    def set_rvs_uniform(self, *rvs: RandomVariable) -> None:
        """
        Set the input values for each rv in rvs to 1 / len(rv).

        Args:
            rvs: a collection of random variables whose indicators are in the slot map.
        """
        for rv in rvs:
            value = 1.0 / len(rv)
            for ind in rv:
                self[ind] = value

    def set_all_rvs_uniform(self) -> None:
        """
        Set the input values for each rv in `self.rvs` to 1 / len(rv).
        """
        slots: NDArray = self._program_buffer.vars
        for rv_slots in self._rvs_slots:
            value = 1.0 / len(rv_slots)
            for slot in rv_slots:
                slots[slot] = value
@@ -0,0 +1,35 @@
1
+ from typing import Protocol, Optional, overload, Iterable, Tuple
2
+
3
+ from ck.pgm import Indicator, ParamId
4
+
5
# Type of a slot map key: an indicator or a parameter id.
SlotKey = Indicator | ParamId


class SlotMap(Protocol):
    """
    Structural (duck-typed) interface for objects that map slot keys, i.e.,
    indicators and parameter ids, to slots in a ProgramBuffer.

    Any plain Python `dict[SlotKey, int]` satisfies this protocol.
    """

    def __getitem__(self, slot_key: SlotKey) -> int:
        """Return the slot mapped to `slot_key`."""

    def __len__(self) -> int:
        """Return the number of mapped slot keys."""

    def items(self) -> Iterable[Tuple[SlotKey, int]]:
        """Return the (slot key, slot) pairs of the map."""

    @overload
    def get(self, slot_key: SlotKey, default: None) -> Optional[int]:
        ...

    @overload
    def get(self, slot_key: SlotKey, default: int) -> int:
        ...

    def get(self, slot_key: SlotKey, default: Optional[int]) -> Optional[int]:
        """Return the slot mapped to `slot_key`, or `default` if it is not mapped."""
File without changes
@@ -0,0 +1,83 @@
1
+ from typing import Optional, Sequence
2
+
3
+ from ck.circuit import CircuitNode, TmpConst, Circuit
4
+ from ck.circuit_compiler import CircuitCompiler
5
+ from ck.circuit_compiler.llvm_compiler import DataType, DEFAULT_TYPE_INFO, compile_circuit
6
+ from ck.circuit_compiler import DEFAULT_CIRCUIT_COMPILER
7
+ from ck.pgm_circuit import PGMCircuit
8
+ from ck.program import RawProgram
9
+
10
+
11
def compile_results(
        pgm_circuit: PGMCircuit,
        results: Sequence[CircuitNode],
        const_parameters: bool,
        compiler: CircuitCompiler = DEFAULT_CIRCUIT_COMPILER,
) -> RawProgram:
    """
    Compile the circuit of `pgm_circuit` to a raw program whose outputs are `results`.

    Args:
        pgm_circuit: the circuit (and PGM) that will be compiled to a program.
        results: the circuit nodes that become the results of the returned program.
        const_parameters: if True, then every circuit variable representing a parameter
            is set (via `TmpConst`) to its value from `pgm_circuit.parameter_values`
            while compiling, making it 'const' in the resulting program.
        compiler: function from circuit nodes to raw program.

    Returns:
        the compiled RawProgram.

    Raises:
        ValueError: if not all nodes are from the same circuit.
    """
    circuit: Circuit = pgm_circuit.circuit_top.circuit

    if not const_parameters:
        return compiler(*results, circuit=circuit)

    # Parameter variables occupy the slots immediately following the indicator slots.
    param_values = pgm_circuit.parameter_values
    first_param_slot = pgm_circuit.number_of_indicators
    with TmpConst(circuit) as tmp:
        for slot, value in enumerate(param_values, start=first_param_slot):
            tmp.set_const(slot, value)
        # Compile while the constants are still in force.
        return compiler(*results, circuit=circuit)
45
+
46
+
47
def compile_param_derivatives(
        pgm_circuit: PGMCircuit,
        self_multiply: bool = False,
        params_value: Optional[float | int] = 1,
        data_type: DataType = DEFAULT_TYPE_INFO,
) -> RawProgram:
    """
    Compile the circuit to a program computing the partial derivatives of the parameters.
    The partial derivatives are co-indexed with `pgm_circuit.parameter_values`.

    Typically, this will grow the circuit by the addition of circuit nodes to compute the derivatives.

    This uses the LLVM circuit compiler.

    Args:
        pgm_circuit: the circuit (and PGM) that will be compiled to a program.
        self_multiply: if true then each partial derivative df/dx will be multiplied by x.
        params_value: if not None, then circuit vars representing parameters will be temporarily
            set to this value for compiling the program. Default is 1.
        data_type: what data type to use for arithmetic calculations. Either a DataType member or TypeInfo.

    Returns:
        a RawProgram whose results are the partial derivatives.
    """
    top: CircuitNode = pgm_circuit.circuit_top
    circuit: Circuit = top.circuit

    # Parameter variables occupy the slots immediately following the indicator slots.
    first = pgm_circuit.number_of_indicators
    param_vars = circuit.vars[first:first + pgm_circuit.number_of_parameters]
    derivatives = circuit.partial_derivatives(top, param_vars, self_multiply=self_multiply)

    if params_value is None:
        return compile_circuit(*derivatives, circuit=circuit, data_type=data_type)

    with TmpConst(circuit) as tmp:
        tmp.set_const(param_vars, params_value)
        return compile_circuit(*derivatives, circuit=circuit, data_type=data_type)
@@ -0,0 +1,103 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Optional, Tuple, List
4
+
5
+ from ck.circuit import CircuitNode, Circuit, TmpConst
6
+ from ck.pgm import RandomVariable
7
+ from ck.pgm_circuit import PGMCircuit
8
+ from ck.pgm_circuit.program_with_slotmap import ProgramWithSlotmap
9
+ from ck.pgm_circuit.slot_map import SlotMap
10
+ from ck.pgm_circuit.support.compile_circuit import compile_results
11
+ from ck.probability.probability_space import check_condition, Condition
12
+ from ck.program.program_buffer import ProgramBuffer
13
+ from ck.program.raw_program import RawProgram
14
+ from ck.utils.np_extras import NDArray
15
+
16
+
17
class TargetMarginalsProgram(ProgramWithSlotmap):
    """
    A program, compiled from a PGM circuit, for computing marginal probabilities
    over the states of a single target random variable.

    The target random variable is removed from the program's input random variables
    (`self.rvs`); its indicator variables are made constant 1 while compiling.
    The program results are co-indexed with the states of `self.target_rv`.
    """

    def __init__(
            self,
            pgm_circuit: PGMCircuit,
            target_rv: RandomVariable,
            const_parameters: bool = True,
    ):
        """
        Construct a TargetMarginalsProgram object.

        Compile the given circuit for computing marginal probabilities over the states of `target_rv`.

        Args:
            pgm_circuit: The circuit representing a PGM.
            target_rv: the random variable to compute marginals for.
            const_parameters: if True then any circuit variable representing a parameter value will
                be made 'const' in the resulting program.

        Raises:
            ValueError: if `target_rv` is not one of `pgm_circuit.rvs` (a KeyError may be
                raised earlier if its indicators are not in the slot map).
        """
        top_node: CircuitNode = pgm_circuit.circuit_top
        circuit: Circuit = top_node.circuit
        slot_map: SlotMap = pgm_circuit.slot_map
        input_rvs: List[RandomVariable] = list(pgm_circuit.rvs)

        # One output per target state: the derivative of the top node w.r.t. that state's indicator var.
        target_vars = [circuit.vars[slot_map[ind]] for ind in target_rv]
        cct_outputs = circuit.partial_derivatives(top_node, target_vars)

        # Remove the target rv from the input rvs.
        target_index = input_rvs.index(target_rv)  # will throw if not found
        del input_rvs[target_index]

        with TmpConst(circuit) as tmp:
            tmp.set_const(target_vars, 1)
            raw_program: RawProgram = compile_results(
                pgm_circuit=pgm_circuit,
                results=cct_outputs,
                const_parameters=const_parameters,
            )

        ProgramWithSlotmap.__init__(self, ProgramBuffer(raw_program), slot_map, input_rvs, pgm_circuit.conditions)

        # additional fields
        # Slots of each input rv's indicators, co-indexed with input_rvs (not used within this class).
        self._x_slots: List[List[int]] = [[slot_map[ind] for ind in rv] for rv in input_rvs]
        self._y_size: int = raw_program.number_of_results
        self._target_rv: RandomVariable = target_rv
        # Not used within this class.
        self._number_of_indicators: int = pgm_circuit.number_of_indicators
        # Not used within this class.
        self._z_cache: Optional[float] = None

        # consistency check: one program result per state of the target rv
        assert (self._y_size == len(self._target_rv))

        if not const_parameters:
            # set the parameter slots (they follow the indicator slots)
            self.vars[pgm_circuit.number_of_indicators:] = pgm_circuit.parameter_values

    @property
    def target_rv(self) -> RandomVariable:
        """
        The random variable whose marginal distribution this program computes.
        """
        return self._target_rv

    def map(self, condition: Condition = ()) -> Tuple[float, int]:
        """
        Return the maximum a posteriori (MAP) state of the target variable.

        Args:
            condition: any conditioning indicators. These must not mention `self.target_rv`,
                as it is not one of the program's input random variables.

        Returns:
            (pr, state_idx) where
            pr: is the probability of the MAP state, normalised over the states of `self.target_rv`;
            state_idx: is the MAP state index of `self.target_rv` (the lowest index, if tied).

        Raises:
            KeyError: if `condition` mentions an indicator that is not of an input random variable.
        """
        self.set_condition(*check_condition(condition))
        self.compute()
        results: NDArray = self.results
        # Normalising constant over all target states.
        z: float = results.sum()

        # Strict '>' keeps the first (lowest-index) state among ties.
        best_p = -1
        best_i = -1
        for i in range(self._y_size):
            p = results[i]
            if p > best_p:
                best_p, best_i = p, i

        return best_p / z, best_i