compiled-knowledge 4.0.0__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of compiled-knowledge might be problematic. Click here for more details.

Files changed (182) hide show
  1. ck/__init__.py +0 -0
  2. ck/circuit/__init__.py +17 -0
  3. ck/circuit/_circuit_cy.c +37513 -0
  4. ck/circuit/_circuit_cy.cp313-win_amd64.pyd +0 -0
  5. ck/circuit/_circuit_cy.pxd +32 -0
  6. ck/circuit/_circuit_cy.pyx +768 -0
  7. ck/circuit/_circuit_py.py +836 -0
  8. ck/circuit/tmp_const.py +75 -0
  9. ck/circuit_compiler/__init__.py +2 -0
  10. ck/circuit_compiler/circuit_compiler.py +27 -0
  11. ck/circuit_compiler/cython_vm_compiler/__init__.py +1 -0
  12. ck/circuit_compiler/cython_vm_compiler/_compiler.c +19833 -0
  13. ck/circuit_compiler/cython_vm_compiler/_compiler.cp313-win_amd64.pyd +0 -0
  14. ck/circuit_compiler/cython_vm_compiler/_compiler.pyx +380 -0
  15. ck/circuit_compiler/cython_vm_compiler/cython_vm_compiler.py +128 -0
  16. ck/circuit_compiler/interpret_compiler.py +255 -0
  17. ck/circuit_compiler/llvm_compiler.py +388 -0
  18. ck/circuit_compiler/llvm_vm_compiler.py +552 -0
  19. ck/circuit_compiler/named_circuit_compilers.py +57 -0
  20. ck/circuit_compiler/support/__init__.py +0 -0
  21. ck/circuit_compiler/support/circuit_analyser/__init__.py +13 -0
  22. ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.c +10618 -0
  23. ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.cp313-win_amd64.pyd +0 -0
  24. ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.pyx +98 -0
  25. ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_py.py +93 -0
  26. ck/circuit_compiler/support/input_vars.py +148 -0
  27. ck/circuit_compiler/support/llvm_ir_function.py +251 -0
  28. ck/example/__init__.py +53 -0
  29. ck/example/alarm.py +366 -0
  30. ck/example/asia.py +28 -0
  31. ck/example/binary_clique.py +32 -0
  32. ck/example/bow_tie.py +33 -0
  33. ck/example/cancer.py +37 -0
  34. ck/example/chain.py +38 -0
  35. ck/example/child.py +199 -0
  36. ck/example/clique.py +33 -0
  37. ck/example/cnf_pgm.py +39 -0
  38. ck/example/diamond_square.py +70 -0
  39. ck/example/earthquake.py +36 -0
  40. ck/example/empty.py +10 -0
  41. ck/example/hailfinder.py +539 -0
  42. ck/example/hepar2.py +628 -0
  43. ck/example/insurance.py +504 -0
  44. ck/example/loop.py +40 -0
  45. ck/example/mildew.py +38161 -0
  46. ck/example/munin.py +22982 -0
  47. ck/example/pathfinder.py +53747 -0
  48. ck/example/rain.py +39 -0
  49. ck/example/rectangle.py +161 -0
  50. ck/example/run.py +30 -0
  51. ck/example/sachs.py +129 -0
  52. ck/example/sprinkler.py +30 -0
  53. ck/example/star.py +44 -0
  54. ck/example/stress.py +64 -0
  55. ck/example/student.py +43 -0
  56. ck/example/survey.py +46 -0
  57. ck/example/triangle_square.py +56 -0
  58. ck/example/truss.py +51 -0
  59. ck/in_out/__init__.py +3 -0
  60. ck/in_out/parse_ace_lmap.py +216 -0
  61. ck/in_out/parse_ace_nnf.py +322 -0
  62. ck/in_out/parse_net.py +482 -0
  63. ck/in_out/parser_utils.py +189 -0
  64. ck/in_out/pgm_pickle.py +42 -0
  65. ck/in_out/pgm_python.py +268 -0
  66. ck/in_out/render_bugs.py +111 -0
  67. ck/in_out/render_net.py +177 -0
  68. ck/in_out/render_pomegranate.py +184 -0
  69. ck/pgm.py +3482 -0
  70. ck/pgm_circuit/__init__.py +1 -0
  71. ck/pgm_circuit/marginals_program.py +352 -0
  72. ck/pgm_circuit/mpe_program.py +236 -0
  73. ck/pgm_circuit/pgm_circuit.py +88 -0
  74. ck/pgm_circuit/program_with_slotmap.py +217 -0
  75. ck/pgm_circuit/slot_map.py +35 -0
  76. ck/pgm_circuit/support/__init__.py +0 -0
  77. ck/pgm_circuit/support/compile_circuit.py +78 -0
  78. ck/pgm_circuit/target_marginals_program.py +103 -0
  79. ck/pgm_circuit/wmc_program.py +323 -0
  80. ck/pgm_compiler/__init__.py +2 -0
  81. ck/pgm_compiler/ace/__init__.py +1 -0
  82. ck/pgm_compiler/ace/ace.py +299 -0
  83. ck/pgm_compiler/factor_elimination.py +395 -0
  84. ck/pgm_compiler/named_pgm_compilers.py +60 -0
  85. ck/pgm_compiler/pgm_compiler.py +19 -0
  86. ck/pgm_compiler/recursive_conditioning.py +231 -0
  87. ck/pgm_compiler/support/__init__.py +0 -0
  88. ck/pgm_compiler/support/circuit_table/__init__.py +17 -0
  89. ck/pgm_compiler/support/circuit_table/_circuit_table_cy.c +16396 -0
  90. ck/pgm_compiler/support/circuit_table/_circuit_table_cy.cp313-win_amd64.pyd +0 -0
  91. ck/pgm_compiler/support/circuit_table/_circuit_table_cy.pyx +332 -0
  92. ck/pgm_compiler/support/circuit_table/_circuit_table_py.py +304 -0
  93. ck/pgm_compiler/support/clusters.py +572 -0
  94. ck/pgm_compiler/support/factor_tables.py +406 -0
  95. ck/pgm_compiler/support/join_tree.py +332 -0
  96. ck/pgm_compiler/support/named_compiler_maker.py +43 -0
  97. ck/pgm_compiler/variable_elimination.py +91 -0
  98. ck/probability/__init__.py +0 -0
  99. ck/probability/empirical_probability_space.py +52 -0
  100. ck/probability/pgm_probability_space.py +36 -0
  101. ck/probability/probability_space.py +627 -0
  102. ck/program/__init__.py +3 -0
  103. ck/program/program.py +137 -0
  104. ck/program/program_buffer.py +180 -0
  105. ck/program/raw_program.py +106 -0
  106. ck/sampling/__init__.py +0 -0
  107. ck/sampling/forward_sampler.py +211 -0
  108. ck/sampling/marginals_direct_sampler.py +113 -0
  109. ck/sampling/sampler.py +62 -0
  110. ck/sampling/sampler_support.py +234 -0
  111. ck/sampling/uniform_sampler.py +72 -0
  112. ck/sampling/wmc_direct_sampler.py +171 -0
  113. ck/sampling/wmc_gibbs_sampler.py +153 -0
  114. ck/sampling/wmc_metropolis_sampler.py +165 -0
  115. ck/sampling/wmc_rejection_sampler.py +115 -0
  116. ck/utils/__init__.py +0 -0
  117. ck/utils/iter_extras.py +164 -0
  118. ck/utils/local_config.py +278 -0
  119. ck/utils/map_list.py +128 -0
  120. ck/utils/map_set.py +128 -0
  121. ck/utils/np_extras.py +51 -0
  122. ck/utils/random_extras.py +64 -0
  123. ck/utils/tmp_dir.py +94 -0
  124. ck_demos/__init__.py +0 -0
  125. ck_demos/ace/__init__.py +0 -0
  126. ck_demos/ace/copy_ace_to_ck.py +15 -0
  127. ck_demos/ace/demo_ace.py +49 -0
  128. ck_demos/ace/simple_ace_demo.py +18 -0
  129. ck_demos/all_demos.py +88 -0
  130. ck_demos/circuit/__init__.py +0 -0
  131. ck_demos/circuit/demo_circuit_dump.py +22 -0
  132. ck_demos/circuit/demo_derivatives.py +43 -0
  133. ck_demos/circuit_compiler/__init__.py +0 -0
  134. ck_demos/circuit_compiler/compare_circuit_compilers.py +32 -0
  135. ck_demos/circuit_compiler/show_llvm_program.py +26 -0
  136. ck_demos/getting_started/__init__.py +0 -0
  137. ck_demos/getting_started/simple_demo.py +18 -0
  138. ck_demos/pgm/__init__.py +0 -0
  139. ck_demos/pgm/demo_pgm_dump.py +18 -0
  140. ck_demos/pgm/demo_pgm_dump_stress.py +18 -0
  141. ck_demos/pgm/demo_pgm_string_rendering.py +15 -0
  142. ck_demos/pgm/show_examples.py +25 -0
  143. ck_demos/pgm_compiler/__init__.py +0 -0
  144. ck_demos/pgm_compiler/compare_pgm_compilers.py +63 -0
  145. ck_demos/pgm_compiler/demo_compiler_dump.py +60 -0
  146. ck_demos/pgm_compiler/demo_factor_elimination.py +47 -0
  147. ck_demos/pgm_compiler/demo_join_tree.py +25 -0
  148. ck_demos/pgm_compiler/demo_marginals_program.py +53 -0
  149. ck_demos/pgm_compiler/demo_mpe_program.py +55 -0
  150. ck_demos/pgm_compiler/demo_pgm_compiler.py +38 -0
  151. ck_demos/pgm_compiler/demo_recursive_conditioning.py +33 -0
  152. ck_demos/pgm_compiler/demo_variable_elimination.py +33 -0
  153. ck_demos/pgm_compiler/demo_wmc_program.py +29 -0
  154. ck_demos/pgm_compiler/time_fe_compiler.py +93 -0
  155. ck_demos/pgm_inference/__init__.py +0 -0
  156. ck_demos/pgm_inference/demo_inferencing_basic.py +188 -0
  157. ck_demos/pgm_inference/demo_inferencing_mpe_cancer.py +45 -0
  158. ck_demos/pgm_inference/demo_inferencing_wmc_and_mpe_sprinkler.py +154 -0
  159. ck_demos/pgm_inference/demo_inferencing_wmc_student.py +110 -0
  160. ck_demos/programs/__init__.py +0 -0
  161. ck_demos/programs/demo_program_buffer.py +24 -0
  162. ck_demos/programs/demo_program_multi.py +24 -0
  163. ck_demos/programs/demo_program_none.py +19 -0
  164. ck_demos/programs/demo_program_single.py +23 -0
  165. ck_demos/programs/demo_raw_program_dump.py +17 -0
  166. ck_demos/programs/demo_raw_program_interpreted.py +21 -0
  167. ck_demos/programs/demo_raw_program_llvm.py +21 -0
  168. ck_demos/sampling/__init__.py +0 -0
  169. ck_demos/sampling/check_sampler.py +71 -0
  170. ck_demos/sampling/demo_marginal_direct_sampler.py +40 -0
  171. ck_demos/sampling/demo_uniform_sampler.py +38 -0
  172. ck_demos/sampling/demo_wmc_direct_sampler.py +40 -0
  173. ck_demos/utils/__init__.py +0 -0
  174. ck_demos/utils/compare.py +120 -0
  175. ck_demos/utils/convert_network.py +45 -0
  176. ck_demos/utils/sample_model.py +216 -0
  177. ck_demos/utils/stop_watch.py +384 -0
  178. compiled_knowledge-4.0.0.dist-info/METADATA +50 -0
  179. compiled_knowledge-4.0.0.dist-info/RECORD +182 -0
  180. compiled_knowledge-4.0.0.dist-info/WHEEL +5 -0
  181. compiled_knowledge-4.0.0.dist-info/licenses/LICENSE.txt +21 -0
  182. compiled_knowledge-4.0.0.dist-info/top_level.txt +2 -0
@@ -0,0 +1,217 @@
1
+ from typing import Tuple, Sequence, Dict
2
+
3
+ import numpy as np
4
+
5
+ from ck.pgm import RandomVariable, Indicator, ParamId
6
+ from ck.pgm_circuit.slot_map import SlotMap, SlotKey
7
+ from ck.probability.probability_space import Condition, check_condition
8
+ from ck.program.program_buffer import ProgramBuffer
9
+ from ck.utils.np_extras import NDArray, NDArrayNumeric
10
+
11
+
12
class ProgramWithSlotmap:
    """
    A class for bundling a program buffer with a slot-map, where the slot-map maps keys
    (e.g., random variable indicators) to program input slots.

    The bundled random variables (`rvs`) are the 'inputs' whose indicator slots can be
    set via `set_condition`, `set_rv`, `set_rvs_uniform` and `set_all_rvs_uniform`.
    """

    def __init__(
            self,
            program_buffer: ProgramBuffer,
            slot_map: SlotMap,
            rvs: Sequence[RandomVariable],
            precondition: Sequence[Indicator]
    ):
        """
        Construct a ProgramWithSlotmap object.

        Args:
            program_buffer: a ProgramBuffer object, i.e., a compiled circuit with input and output slots.
            slot_map: a map from a slot key to an input slot of `program_buffer`.
            rvs: a sequence of random variables used for setting program input slots. Each rv
                has a length, and rv[i] is an indicator that is unique across all rvs.
            precondition: conditions on rvs that are compiled into the program.

        Raises:
            ValueError: if rvs contains duplicates (as judged by `rv.idx`).
        """
        self._program_buffer: ProgramBuffer = program_buffer
        self._slot_map: SlotMap = slot_map
        self._rvs: Tuple[RandomVariable, ...] = tuple(rvs)
        self._precondition: Tuple[Indicator, ...] = tuple(precondition)

        if len(self._rvs) != len({rv.idx for rv in self._rvs}):
            raise ValueError('duplicate random variables provided')

        # Given rv = rvs[i], then _rvs_slots[i][state_idx] gives the slot for rv[state_idx].
        self._rvs_slots: Tuple[Tuple[int, ...], ...] = tuple(
            tuple(slot_map[ind] for ind in rv)
            for rv in self._rvs
        )

        # Given rv = rvs[i], then _indicator_map[rv[j]] = (i, slot), where slot is the input slot
        # for indicator rv[j]. Reuses the slots looked up above rather than querying slot_map again.
        self._indicator_map: Dict[Indicator, Tuple[int, int]] = {
            ind: (i, slot)
            for i, (rv, rv_slots) in enumerate(zip(self._rvs, self._rvs_slots))
            for ind, slot in zip(rv, rv_slots)
        }

    @property
    def rvs(self) -> Sequence[RandomVariable]:
        """
        What are the random variables considered as 'inputs'.
        """
        return self._rvs

    @property
    def precondition(self) -> Sequence[Indicator]:
        """
        Condition on `self.rvs` that is compiled into the program.
        """
        return self._precondition

    @property
    def slot_map(self) -> SlotMap:
        """
        The map from slot keys (indicators and parameter ids) to program input slots.
        """
        return self._slot_map

    def compute(self) -> NDArrayNumeric:
        """
        Execute the program to compute and return the result. As per `ProgramBuffer.compute`.

        Warning:
            when returning an array, the array is backed by the program buffer memory, not a copy.
        """
        return self._program_buffer.compute()

    def compute_conditioned(self, *condition: Condition) -> NDArrayNumeric:
        """
        Compute the program value, after setting the given condition.

        Equivalent to::

            self.set_condition(*condition)
            return self.compute()
        """
        self.set_condition(*condition)
        return self.compute()

    @property
    def results(self) -> NDArrayNumeric:
        """
        Get the results of the last computation.
        As per `ProgramBuffer.results`.

        Warning:
            the array is backed by the program buffer memory, not a copy.
        """
        return self._program_buffer.results

    @property
    def vars(self) -> NDArrayNumeric:
        """
        Return the input variables as a numpy array.
        As per `ProgramBuffer.vars`.

        Warning:
            writing to the returned array will write to the input slots of the program buffer.
        """
        return self._program_buffer.vars

    def __setitem__(self, item: int | slice | SlotKey | RandomVariable, value: float) -> None:
        """
        Set input slot value/s.

        Args:
            item: a raw slot index or slice, a slot key (indicator or parameter id),
                or a random variable (in which case every indicator slot of the rv is set).
            value: the value to write.

        Raises:
            IndexError: if `item` is not one of the supported index types.
        """
        if isinstance(item, (int, slice)):
            self._program_buffer[item] = value
        elif isinstance(item, (Indicator, ParamId)):
            self._program_buffer[self._slot_map[item]] = value
        elif isinstance(item, RandomVariable):
            for ind in item:
                self._program_buffer[self._slot_map[ind]] = value
        else:
            raise IndexError(f'unknown index type: {type(item)}')

    def __getitem__(self, item: int | slice | SlotKey | RandomVariable) -> NDArrayNumeric:
        """
        Get input slot value/s.

        Args:
            item: a raw slot index or slice, a slot key (indicator or parameter id),
                or a random variable (in which case the values of all its indicator
                slots are returned as a new array, in state order).

        Raises:
            IndexError: if `item` is not one of the supported index types.
        """
        if isinstance(item, (int, slice)):
            return self._program_buffer[item]
        elif isinstance(item, (Indicator, ParamId)):
            return self._program_buffer[self._slot_map[item]]
        elif isinstance(item, RandomVariable):
            return np.fromiter(
                (self._program_buffer[self._slot_map[ind]] for ind in item),
                dtype=self._program_buffer.dtype,
                count=len(item)
            )
        else:
            raise IndexError(f'unknown index type: {type(item)}')

    def set_condition(self, *condition: Condition) -> None:
        """
        Set the input slots of random variables to 1, except where implied to
        be 0 according to the given conditions.

        Specifically:
            each slot corresponding to an indicator in the given condition will be set to 1;

            if a random variable is mentioned in the given condition, then all
            slots for indicators of that random variable, except for slots corresponding
            to an indicator in the given condition, will be set to 0;

            if a random variable is not mentioned in the given condition, that random variable
            will have all its slots set to 1.

        Raises:
            KeyError: if the condition mentions an indicator that does not belong to one of `self.rvs`.
        """
        indicators: Sequence[Indicator] = check_condition(condition)

        # Group the slots of the conditioning indicators by the index of their rv in self._rvs.
        ind_slot_groups = [[] for _ in self._rvs_slots]
        for ind in indicators:
            rv_idx, slot = self._indicator_map[ind]
            ind_slot_groups[rv_idx].append(slot)

        slots: NDArray = self._program_buffer.vars
        for rv_slots, ind_slots in zip(self._rvs_slots, ind_slot_groups):
            if len(ind_slots) == 0:
                # this rv _is not_ mentioned in the indicators - marginalise it
                for slot in rv_slots:
                    slots[slot] = 1
            else:
                # this rv _is_ mentioned in the indicators - we set the mentioned slots to 1 and others to 0.
                for slot in rv_slots:
                    slots[slot] = 0
                for slot in ind_slots:
                    slots[slot] = 1

    def set_rv(self, rv: RandomVariable, *values: float | int) -> None:
        """
        Set the input values of a random variable.

        Args:
            rv: a random variable whose indicators are in the slot map.
            values: list of values, co-indexed with the states of `rv`.

        Assumes:
            len(values) == len(rv). Too few values raises IndexError; extra values are ignored.
        """
        for i, ind in enumerate(rv):
            self[ind] = values[i]

    def set_rvs_uniform(self, *rvs: RandomVariable) -> None:
        """
        Set the input values for each rv in rvs to 1 / len(rv).

        Args:
            rvs: a collection of random variables whose indicators are in the slot map.
        """
        for rv in rvs:
            value = 1.0 / len(rv)
            for ind in rv:
                self[ind] = value

    def set_all_rvs_uniform(self) -> None:
        """
        Set the input values for each rv in `self.rvs` to 1 / len(rv).
        """
        slots: NDArray = self._program_buffer.vars
        for rv_slots in self._rvs_slots:
            value = 1.0 / len(rv_slots)
            for slot in rv_slots:
                slots[slot] = value
@@ -0,0 +1,35 @@
1
+ from typing import Protocol, Optional, overload, Iterable, Tuple
2
+
3
+ from ck.pgm import Indicator, ParamId
4
+
5
# Type of a slot map key.
SlotKey = Indicator | ParamId


class SlotMap(Protocol):
    """
    A slot map is a protocol for mapping keys (indicators and
    parameter ids) to slots in a ProgramBuffer.

    A Python dict[SlotKey, int] implements the protocol.
    """

    def __len__(self) -> int:
        """
        Return the number of keys in the slot map.
        """
        ...

    @overload
    def get(self, slot_key: SlotKey, default: None = None) -> Optional[int]:
        ...

    @overload
    def get(self, slot_key: SlotKey, default: int) -> int:
        ...

    def get(self, slot_key: SlotKey, default: Optional[int] = None) -> Optional[int]:
        """
        Return the slot for `slot_key`, or `default` if the key is not in the slot map.

        `default` is optional (defaulting to None), mirroring `dict.get`, so
        `slot_map.get(key)` is accepted.
        """
        ...

    def __getitem__(self, slot_key: SlotKey) -> int:
        """
        Return the slot for `slot_key`.
        """
        ...

    def items(self) -> Iterable[Tuple[SlotKey, int]]:
        """
        Iterate over the (slot key, slot) pairs of the slot map.
        """
        ...
File without changes
@@ -0,0 +1,78 @@
1
+ from typing import Optional, Sequence
2
+
3
+ from ck.circuit import CircuitNode, TmpConst, Circuit
4
+ from ck.circuit_compiler import CircuitCompiler
5
+ from ck.circuit_compiler import DEFAULT_CIRCUIT_COMPILER
6
+ from ck.pgm_circuit import PGMCircuit
7
+ from ck.program import RawProgram
8
+
9
+
10
def compile_results(
        pgm_circuit: PGMCircuit,
        results: Sequence[CircuitNode],
        const_parameters: bool,
        compiler: CircuitCompiler = DEFAULT_CIRCUIT_COMPILER,
) -> RawProgram:
    """
    Compile the circuit of `pgm_circuit` into a raw program whose outputs are `results`.

    Args:
        pgm_circuit: The circuit (and PGM) that will be compiled to a program.
        results: the circuit nodes whose values the returned program computes.
        const_parameters: when True (and the PGM circuit has parameter values), every circuit
            var representing a parameter is fixed to its value while compiling, making it
            'const' in the resulting program.
        compiler: function from circuit nodes to raw program.

    Returns:
        the compiled RawProgram.

    Raises:
        ValueError: if not all nodes are from the same circuit.
    """
    circuit: Circuit = pgm_circuit.circuit_top.circuit

    make_params_const: bool = const_parameters and len(pgm_circuit.parameter_values) > 0
    if not make_params_const:
        # Parameter vars stay as ordinary program inputs.
        return compiler(*results, circuit=circuit)

    # Parameter vars are numbered immediately after the indicator vars, so the
    # var index of parameter k is number_of_indicators + k. They are only const
    # for the duration of the compile.
    with TmpConst(circuit) as tmp_const:
        param_var_values = enumerate(pgm_circuit.parameter_values, pgm_circuit.number_of_indicators)
        for var_idx, param_value in param_var_values:
            tmp_const.set_const(var_idx, param_value)
        return compiler(*results, circuit=circuit)
42
+
43
+
44
def compile_param_derivatives(
        pgm_circuit: PGMCircuit,
        self_multiply: bool = False,
        params_value: Optional[float | int] = 1,
        compiler: CircuitCompiler = DEFAULT_CIRCUIT_COMPILER,
) -> RawProgram:
    """
    Compile the circuit to a program that computes the partial derivatives of the
    circuit top with respect to each parameter. The program results are co-indexed
    with pgm_circuit.parameter_values.

    Typically, this will grow the circuit by the addition of circuit nodes to compute the derivatives.

    Args:
        pgm_circuit: The circuit (and PGM) that will be compiled to a program.
        self_multiply: if true then each partial derivative df/dx will be multiplied by x.
        params_value: if not None, then circuit vars representing parameters are
            temporarily fixed to this value while compiling the program. Default is 1.
        compiler: function from circuit nodes to raw program.

    Returns:
        the compiled RawProgram.
    """
    top: CircuitNode = pgm_circuit.circuit_top
    circuit: Circuit = top.circuit

    # The parameter vars are the block of circuit vars directly following the indicator vars.
    first_param: int = pgm_circuit.number_of_indicators
    past_last_param: int = first_param + pgm_circuit.number_of_parameters
    param_vars = circuit.vars[first_param:past_last_param]

    derivatives = circuit.partial_derivatives(top, param_vars, self_multiply=self_multiply)

    if params_value is None:
        return compiler(*derivatives, circuit=circuit)

    with TmpConst(circuit) as tmp_const:
        tmp_const.set_const(param_vars, params_value)
        return compiler(*derivatives, circuit=circuit)
@@ -0,0 +1,103 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Optional, Tuple, List
4
+
5
+ from ck.circuit import CircuitNode, Circuit, TmpConst
6
+ from ck.pgm import RandomVariable
7
+ from ck.pgm_circuit import PGMCircuit
8
+ from ck.pgm_circuit.program_with_slotmap import ProgramWithSlotmap
9
+ from ck.pgm_circuit.slot_map import SlotMap
10
+ from ck.pgm_circuit.support.compile_circuit import compile_results
11
+ from ck.probability.probability_space import check_condition, Condition
12
+ from ck.program.program_buffer import ProgramBuffer
13
+ from ck.program.raw_program import RawProgram
14
+ from ck.utils.np_extras import NDArray
15
+
16
+
17
class TargetMarginalsProgram(ProgramWithSlotmap):
    """
    A program for computing the marginal distribution over the states of a single
    target random variable.

    The program results are the partial derivatives of the circuit top node with
    respect to the circuit vars of the target rv's indicators, where those vars are
    fixed to 1 while compiling. The results are co-indexed with the states of the target rv.
    Every other rv of the PGM circuit remains a program input (see `self.rvs`).
    """

    def __init__(
            self,
            pgm_circuit: PGMCircuit,
            target_rv: RandomVariable,
            const_parameters: bool = True,
    ):
        """
        Construct a TargetMarginalsProgram object.

        Compile the given circuit for computing marginal probabilities over the states of `target_rv`.

        Args:
            pgm_circuit: The circuit representing a PGM.
            target_rv: the random variable to compute marginals for.
            const_parameters: if True then any circuit variable representing a parameter value will
                be made 'const' in the resulting program. If False, the parameter input slots are
                initialised from `pgm_circuit.parameter_values`.

        Raises:
            ValueError: if `target_rv` is not one of `pgm_circuit.rvs`.
        """
        top_node: CircuitNode = pgm_circuit.circuit_top
        circuit: Circuit = top_node.circuit
        slot_map: SlotMap = pgm_circuit.slot_map
        input_rvs: List[RandomVariable] = list(pgm_circuit.rvs)

        target_vars = [circuit.vars[slot_map[ind]] for ind in target_rv]
        cct_outputs = circuit.partial_derivatives(top_node, target_vars)

        # Remove the target rv from the input rvs (list.index raises ValueError if not found).
        target_index = input_rvs.index(target_rv)
        del input_rvs[target_index]

        # The target vars are fixed to 1 only for the duration of the compile.
        with TmpConst(circuit) as tmp:
            tmp.set_const(target_vars, 1)
            raw_program: RawProgram = compile_results(
                pgm_circuit=pgm_circuit,
                results=cct_outputs,
                const_parameters=const_parameters,
            )

        ProgramWithSlotmap.__init__(self, ProgramBuffer(raw_program), slot_map, input_rvs, pgm_circuit.conditions)

        # additional fields
        self._x_slots: List[List[int]] = [[slot_map[ind] for ind in rv] for rv in input_rvs]
        self._y_size: int = raw_program.number_of_results
        self._target_rv: RandomVariable = target_rv
        self._number_of_indicators: int = pgm_circuit.number_of_indicators
        self._z_cache: Optional[float] = None

        # consistency check: one program result per state of the target rv
        assert (self._y_size == len(self._target_rv))

        if not const_parameters:
            # set the parameter slots (they follow the indicator slots)
            self.vars[pgm_circuit.number_of_indicators:] = pgm_circuit.parameter_values

    @property
    def target_rv(self) -> RandomVariable:
        """
        The random variable whose marginal distribution this program computes.
        """
        return self._target_rv

    def map(self, condition: Condition = ()) -> Tuple[float, int]:
        """
        Return the maximum a posteriori (MAP) state of the target variable.

        Args:
            condition: any conditioning indicators.

        Returns:
            (pr, state_idx) where
                pr: is the MAP probability, i.e., the largest program result divided
                    by the sum of all program results;
                state_idx: is the MAP state index of `self.target_rv`.
            When several states share the maximum, the lowest state index is returned.

        Note:
            if the program results sum to zero (e.g., an impossible condition), `pr` is
            computed by dividing by zero. NOTE(review): with numpy results this presumably
            yields nan/inf with a RuntimeWarning rather than raising; confirm.
        """
        self.set_condition(*check_condition(condition))
        self.compute()
        results: NDArray = self.results
        z: float = results.sum()

        # Strict '>' keeps the first (lowest) index on ties.
        max_p = -1
        max_i = -1
        for i, p in enumerate(results[:self._y_size]):
            if p > max_p:
                max_p = p
                max_i = i

        return max_p / z, max_i