compiled-knowledge 4.0.0a20__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of compiled-knowledge might be problematic. Click here for more details.

Files changed (178) hide show
  1. ck/__init__.py +0 -0
  2. ck/circuit/__init__.py +17 -0
  3. ck/circuit/_circuit_cy.c +37523 -0
  4. ck/circuit/_circuit_cy.cp313-win_amd64.pyd +0 -0
  5. ck/circuit/_circuit_cy.pxd +32 -0
  6. ck/circuit/_circuit_cy.pyx +768 -0
  7. ck/circuit/_circuit_py.py +836 -0
  8. ck/circuit/tmp_const.py +74 -0
  9. ck/circuit_compiler/__init__.py +2 -0
  10. ck/circuit_compiler/circuit_compiler.py +26 -0
  11. ck/circuit_compiler/cython_vm_compiler/__init__.py +1 -0
  12. ck/circuit_compiler/cython_vm_compiler/_compiler.c +19824 -0
  13. ck/circuit_compiler/cython_vm_compiler/_compiler.cp313-win_amd64.pyd +0 -0
  14. ck/circuit_compiler/cython_vm_compiler/_compiler.pyx +380 -0
  15. ck/circuit_compiler/cython_vm_compiler/cython_vm_compiler.py +121 -0
  16. ck/circuit_compiler/interpret_compiler.py +223 -0
  17. ck/circuit_compiler/llvm_compiler.py +388 -0
  18. ck/circuit_compiler/llvm_vm_compiler.py +546 -0
  19. ck/circuit_compiler/named_circuit_compilers.py +57 -0
  20. ck/circuit_compiler/support/__init__.py +0 -0
  21. ck/circuit_compiler/support/circuit_analyser/__init__.py +13 -0
  22. ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.c +10618 -0
  23. ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.cp313-win_amd64.pyd +0 -0
  24. ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.pyx +98 -0
  25. ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_py.py +93 -0
  26. ck/circuit_compiler/support/input_vars.py +148 -0
  27. ck/circuit_compiler/support/llvm_ir_function.py +234 -0
  28. ck/example/__init__.py +53 -0
  29. ck/example/alarm.py +366 -0
  30. ck/example/asia.py +28 -0
  31. ck/example/binary_clique.py +32 -0
  32. ck/example/bow_tie.py +33 -0
  33. ck/example/cancer.py +37 -0
  34. ck/example/chain.py +38 -0
  35. ck/example/child.py +199 -0
  36. ck/example/clique.py +33 -0
  37. ck/example/cnf_pgm.py +39 -0
  38. ck/example/diamond_square.py +68 -0
  39. ck/example/earthquake.py +36 -0
  40. ck/example/empty.py +10 -0
  41. ck/example/hailfinder.py +539 -0
  42. ck/example/hepar2.py +628 -0
  43. ck/example/insurance.py +504 -0
  44. ck/example/loop.py +40 -0
  45. ck/example/mildew.py +38161 -0
  46. ck/example/munin.py +22982 -0
  47. ck/example/pathfinder.py +53747 -0
  48. ck/example/rain.py +39 -0
  49. ck/example/rectangle.py +161 -0
  50. ck/example/run.py +30 -0
  51. ck/example/sachs.py +129 -0
  52. ck/example/sprinkler.py +30 -0
  53. ck/example/star.py +44 -0
  54. ck/example/stress.py +64 -0
  55. ck/example/student.py +43 -0
  56. ck/example/survey.py +46 -0
  57. ck/example/triangle_square.py +54 -0
  58. ck/example/truss.py +49 -0
  59. ck/in_out/__init__.py +3 -0
  60. ck/in_out/parse_ace_lmap.py +216 -0
  61. ck/in_out/parse_ace_nnf.py +322 -0
  62. ck/in_out/parse_net.py +480 -0
  63. ck/in_out/parser_utils.py +185 -0
  64. ck/in_out/pgm_pickle.py +42 -0
  65. ck/in_out/pgm_python.py +268 -0
  66. ck/in_out/render_bugs.py +111 -0
  67. ck/in_out/render_net.py +177 -0
  68. ck/in_out/render_pomegranate.py +184 -0
  69. ck/pgm.py +3475 -0
  70. ck/pgm_circuit/__init__.py +1 -0
  71. ck/pgm_circuit/marginals_program.py +352 -0
  72. ck/pgm_circuit/mpe_program.py +237 -0
  73. ck/pgm_circuit/pgm_circuit.py +79 -0
  74. ck/pgm_circuit/program_with_slotmap.py +236 -0
  75. ck/pgm_circuit/slot_map.py +35 -0
  76. ck/pgm_circuit/support/__init__.py +0 -0
  77. ck/pgm_circuit/support/compile_circuit.py +83 -0
  78. ck/pgm_circuit/target_marginals_program.py +103 -0
  79. ck/pgm_circuit/wmc_program.py +323 -0
  80. ck/pgm_compiler/__init__.py +2 -0
  81. ck/pgm_compiler/ace/__init__.py +1 -0
  82. ck/pgm_compiler/ace/ace.py +299 -0
  83. ck/pgm_compiler/factor_elimination.py +395 -0
  84. ck/pgm_compiler/named_pgm_compilers.py +63 -0
  85. ck/pgm_compiler/pgm_compiler.py +19 -0
  86. ck/pgm_compiler/recursive_conditioning.py +231 -0
  87. ck/pgm_compiler/support/__init__.py +0 -0
  88. ck/pgm_compiler/support/circuit_table/__init__.py +17 -0
  89. ck/pgm_compiler/support/circuit_table/_circuit_table_cy.c +16396 -0
  90. ck/pgm_compiler/support/circuit_table/_circuit_table_cy.cp313-win_amd64.pyd +0 -0
  91. ck/pgm_compiler/support/circuit_table/_circuit_table_cy.pyx +332 -0
  92. ck/pgm_compiler/support/circuit_table/_circuit_table_py.py +304 -0
  93. ck/pgm_compiler/support/clusters.py +568 -0
  94. ck/pgm_compiler/support/factor_tables.py +406 -0
  95. ck/pgm_compiler/support/join_tree.py +332 -0
  96. ck/pgm_compiler/support/named_compiler_maker.py +43 -0
  97. ck/pgm_compiler/variable_elimination.py +91 -0
  98. ck/probability/__init__.py +0 -0
  99. ck/probability/empirical_probability_space.py +50 -0
  100. ck/probability/pgm_probability_space.py +32 -0
  101. ck/probability/probability_space.py +622 -0
  102. ck/program/__init__.py +3 -0
  103. ck/program/program.py +137 -0
  104. ck/program/program_buffer.py +180 -0
  105. ck/program/raw_program.py +67 -0
  106. ck/sampling/__init__.py +0 -0
  107. ck/sampling/forward_sampler.py +211 -0
  108. ck/sampling/marginals_direct_sampler.py +113 -0
  109. ck/sampling/sampler.py +62 -0
  110. ck/sampling/sampler_support.py +232 -0
  111. ck/sampling/uniform_sampler.py +72 -0
  112. ck/sampling/wmc_direct_sampler.py +171 -0
  113. ck/sampling/wmc_gibbs_sampler.py +153 -0
  114. ck/sampling/wmc_metropolis_sampler.py +165 -0
  115. ck/sampling/wmc_rejection_sampler.py +115 -0
  116. ck/utils/__init__.py +0 -0
  117. ck/utils/iter_extras.py +163 -0
  118. ck/utils/local_config.py +270 -0
  119. ck/utils/map_list.py +128 -0
  120. ck/utils/map_set.py +128 -0
  121. ck/utils/np_extras.py +51 -0
  122. ck/utils/random_extras.py +64 -0
  123. ck/utils/tmp_dir.py +94 -0
  124. ck_demos/__init__.py +0 -0
  125. ck_demos/ace/__init__.py +0 -0
  126. ck_demos/ace/copy_ace_to_ck.py +15 -0
  127. ck_demos/ace/demo_ace.py +49 -0
  128. ck_demos/all_demos.py +88 -0
  129. ck_demos/circuit/__init__.py +0 -0
  130. ck_demos/circuit/demo_circuit_dump.py +22 -0
  131. ck_demos/circuit/demo_derivatives.py +43 -0
  132. ck_demos/circuit_compiler/__init__.py +0 -0
  133. ck_demos/circuit_compiler/compare_circuit_compilers.py +32 -0
  134. ck_demos/circuit_compiler/show_llvm_program.py +26 -0
  135. ck_demos/pgm/__init__.py +0 -0
  136. ck_demos/pgm/demo_pgm_dump.py +18 -0
  137. ck_demos/pgm/demo_pgm_dump_stress.py +18 -0
  138. ck_demos/pgm/demo_pgm_string_rendering.py +15 -0
  139. ck_demos/pgm/show_examples.py +25 -0
  140. ck_demos/pgm_compiler/__init__.py +0 -0
  141. ck_demos/pgm_compiler/compare_pgm_compilers.py +63 -0
  142. ck_demos/pgm_compiler/demo_compiler_dump.py +60 -0
  143. ck_demos/pgm_compiler/demo_factor_elimination.py +47 -0
  144. ck_demos/pgm_compiler/demo_join_tree.py +25 -0
  145. ck_demos/pgm_compiler/demo_marginals_program.py +53 -0
  146. ck_demos/pgm_compiler/demo_mpe_program.py +55 -0
  147. ck_demos/pgm_compiler/demo_pgm_compiler.py +38 -0
  148. ck_demos/pgm_compiler/demo_recursive_conditioning.py +33 -0
  149. ck_demos/pgm_compiler/demo_variable_elimination.py +33 -0
  150. ck_demos/pgm_compiler/demo_wmc_program.py +29 -0
  151. ck_demos/pgm_compiler/time_fe_compiler.py +93 -0
  152. ck_demos/pgm_inference/__init__.py +0 -0
  153. ck_demos/pgm_inference/demo_inferencing_basic.py +188 -0
  154. ck_demos/pgm_inference/demo_inferencing_mpe_cancer.py +45 -0
  155. ck_demos/pgm_inference/demo_inferencing_wmc_and_mpe_sprinkler.py +154 -0
  156. ck_demos/pgm_inference/demo_inferencing_wmc_student.py +110 -0
  157. ck_demos/programs/__init__.py +0 -0
  158. ck_demos/programs/demo_program_buffer.py +24 -0
  159. ck_demos/programs/demo_program_multi.py +24 -0
  160. ck_demos/programs/demo_program_none.py +19 -0
  161. ck_demos/programs/demo_program_single.py +23 -0
  162. ck_demos/programs/demo_raw_program_interpreted.py +21 -0
  163. ck_demos/programs/demo_raw_program_llvm.py +21 -0
  164. ck_demos/sampling/__init__.py +0 -0
  165. ck_demos/sampling/check_sampler.py +71 -0
  166. ck_demos/sampling/demo_marginal_direct_sampler.py +40 -0
  167. ck_demos/sampling/demo_uniform_sampler.py +38 -0
  168. ck_demos/sampling/demo_wmc_direct_sampler.py +40 -0
  169. ck_demos/utils/__init__.py +0 -0
  170. ck_demos/utils/compare.py +120 -0
  171. ck_demos/utils/convert_network.py +45 -0
  172. ck_demos/utils/sample_model.py +216 -0
  173. ck_demos/utils/stop_watch.py +384 -0
  174. compiled_knowledge-4.0.0a20.dist-info/METADATA +50 -0
  175. compiled_knowledge-4.0.0a20.dist-info/RECORD +178 -0
  176. compiled_knowledge-4.0.0a20.dist-info/WHEEL +5 -0
  177. compiled_knowledge-4.0.0a20.dist-info/licenses/LICENSE.txt +21 -0
  178. compiled_knowledge-4.0.0a20.dist-info/top_level.txt +2 -0
@@ -0,0 +1,268 @@
1
+ import datetime
2
+ import importlib
3
+ import importlib.machinery
4
+ import math
5
+ import types
6
+ from pathlib import Path
7
+ from typing import Optional, List, Any, Sequence
8
+
9
+ from ck.pgm import PGM, DensePotentialFunction, SparsePotentialFunction, State, CompactPotentialFunction, \
10
+ ClausePotentialFunction, CPTPotentialFunction, default_pgm_name
11
+
12
+
13
def write_python(
        pgm: PGM,
        pgm_name: str = 'pgm',
        import_module: Optional[str] = 'ck.pgm',
        package_name: Optional[str] = None,
        use_variable_names: bool = False,
        include_potential_functions: bool = True,
        include_comment: bool = False,
        note: Optional[str] = None,
        author: Optional[str] = None,
        file=None
) -> None:
    """
    Print a Python script that would build the PGM.

    The script is emitted in this order: an optional module docstring,
    an optional `__author__` line, the import line, the PGM constructor call,
    one `new_rv` call per random variable, one `new_factor` call per factor,
    and finally (optionally) the potential function of each non-zero factor.

    Args:
        pgm: The PGM to write.
        pgm_name: A Python variable name to use for the PGM object.
        import_module: if not None, then an 'import' command will be included to import the named module.
        package_name: if None, then 'PGM()' is used to create the PGM object,
            otherwise '{package_name}.PGM()' is used.
        use_variable_names: If false, then Python variable names for RandomVariable objects will
            be system generated, otherwise the random variable name will be used.
        include_potential_functions: whether to dump the potential functions or not.
        include_comment: include a Python comment or not.
        note: An explicit comment to include.
        author: an optional author name to go in the comment.
        file: optional file argument to the print function.

    Raises:
        RuntimeError: if a factor has a potential function of a type that
            this function does not know how to write.
    """

    # All output goes through this wrapper so that `file` is applied uniformly.
    def _print(*args, **kwargs) -> None:
        print(*args, file=file, **kwargs)

    # Only pass a name to the constructor when it differs from the default name.
    constructor_args = '' if pgm.name == default_pgm_name(pgm) else repr(pgm.name)
    class_name = PGM.__name__

    # Choose how a random variable is named as a Python variable in the script.
    if use_variable_names:
        def rv_name(rv):
            return rv.name
    else:
        def rv_name(rv):
            return f'{pgm_name}_rv{rv.idx}'

    # The script starts with a docstring if either a summary comment or a note is wanted.
    has_a_comment = include_comment or note is not None
    if has_a_comment:
        _print('"""')
        _print(f'PGM name: {pgm.name}')
        _print(f'{datetime.datetime.now()}')
        _print()
    if note is not None:
        _print(str(note))
        _print()
    if include_comment:
        num_states: int = pgm.number_of_states
        number_of_parameters = sum(factor.function.number_of_parameters for factor in pgm.factors)
        number_of_nz_parameters = sum(function.number_of_parameters for function in pgm.non_zero_functions)
        precision = 3
        _print(f'number of random variables: {pgm.number_of_rvs:,}')
        _print(f'number of indicators: {pgm.number_of_indicators:,}')
        _print(f'number of states: {num_states}')
        _print(f'log 2 of states: {math.log2(num_states):,.{precision}f}')
        _print(f'number of factors: {pgm.number_of_factors:,}')
        _print(f'number of functions: {pgm.number_of_functions:,}')
        _print(f'number of parameters: {number_of_parameters:,}')
        _print(f'number of functions (excluding ZeroPotentialFunction): {pgm.number_of_non_zero_functions:,}')
        _print(f'number of parameters (excluding ZeroPotentialFunction): {number_of_nz_parameters:,}')
        _print(f'Bayesian structure: {pgm.is_structure_bayesian}')
        _print(f'CPT factors: {pgm.factors_are_cpts()}')
        _print()
        # Show how the generated names could be imported from the written script.
        _print('Usage:')
        rv_list = ', '.join([rv_name(rv) for rv in pgm.rvs])
        sep = '' if pgm.number_of_rvs == 0 else ', '
        _print(f'from {pgm.name} import {pgm_name}{sep}{rv_list}')
        _print()
    if has_a_comment:
        _print('"""')
    if author is not None:
        _print(f'__author__ = {author!r}')
    if has_a_comment or author is not None:
        _print()

    if import_module is not None:
        if package_name is None:
            _print(f'from {import_module} import {class_name}')
        else:
            _print(f'import {import_module} as {package_name}')

    # The constructor call must agree with the form of import printed above.
    if package_name is None:
        _print(f'{pgm_name} = {class_name}({constructor_args})')
    else:
        _print(f'{pgm_name} = {package_name}.{class_name}({constructor_args})')

    # Print random variables
    for rv in pgm.rvs:
        if rv.is_default_states():
            # Default states are written as just the number of states.
            states = len(rv.states)
        else:
            states = _repr_states(rv.states)
        _print(f'{rv_name(rv)} = {pgm_name}.new_rv({rv.name!r}, {states})')

    # Print factors
    for factor in pgm.factors:
        rvs = ', '.join([rv_name(rv) for rv in factor.rvs])
        factor_name = f'{pgm_name}_factor{factor.idx}'
        _print(f'{factor_name} = {pgm_name}.new_factor({rvs})')

    # Print potential functions
    if include_potential_functions:
        # Maps each function already written to the script variable holding it,
        # so that factors sharing a function are assigned the same script variable.
        seen_functions = {}
        for factor in pgm.factors:
            if factor.is_zero:
                continue

            _print()
            factor_name = f'{pgm_name}_factor{factor.idx}'
            function = factor.function

            function_name = seen_functions.get(function)
            if function_name is not None:
                # Function was written for an earlier factor: reuse its variable.
                _print(f'{factor_name}.function = {function_name}')
                continue

            function_name = f'{pgm_name}_function{len(seen_functions)}'
            seen_functions[function] = function_name

            if isinstance(function, DensePotentialFunction):
                _print(f'{function_name} = {factor_name}.set_dense()')
                _write_python_dense_function(function_name, function, _print)

            elif isinstance(function, SparsePotentialFunction):
                _print(f'{function_name} = {factor_name}.set_sparse()')
                for key, idx, value in function.keys_with_param:
                    _print(f'{function_name}[{key}] = {value}')

            elif isinstance(function, CompactPotentialFunction):
                _print(f'{function_name} = {factor_name}.set_compact()')
                for key, value in function.items():
                    # Zero-valued entries are not written.
                    if value != 0:
                        _print(f'{function_name}[{key}] = {value}')

            elif isinstance(function, ClausePotentialFunction):
                states = ', '.join(repr(v) for v in function.clause)
                _print(f'{function_name} = {factor_name}.set_clause({states})')

            elif isinstance(function, CPTPotentialFunction):
                _print(f'{function_name} = {factor_name}.set_cpt()')
                for parent_states, cpd in function.cpds():
                    cpd = ', '.join(repr(v) for v in cpd)
                    _print(f'{function_name}.set_cpd({parent_states}, ({cpd}))')

            else:
                raise RuntimeError(f'unimplemented writing of function type {function.__class__.__name__}')
165
+
166
+
167
def read_python(
        source: str | Path,
        var_name: Optional[str] = None,
        module_name: Optional[str] = None,
) -> PGM:
    """
    Load a PGM previously written using `write_python`.

    The source file is executed as a fresh module object (it is not
    registered in `sys.modules`), then the PGM is looked up among the
    module's attributes.

    Warning: the source file is executed as Python code, so only load
    files that come from a trusted origin.

    Args:
        source: The source file name or file path
        var_name: The name of the PGM variable to load, if None, then a name will be found.
        module_name: The name of the module that file will be loaded as, default is the name of the source.

    Returns:
        the loaded PGM.

    Raises:
        RuntimeError: if a unique PGM object is not found (with the given var_name).
    """
    if module_name is None:
        # Path() accepts str and any other path-like object.
        module_name = Path(source).name

    # The loader is given a plain string path, whatever path type the caller used.
    loader = importlib.machinery.SourceFileLoader(module_name, str(source))
    module = types.ModuleType(loader.name)
    loader.exec_module(module)

    if var_name is not None:
        try:
            pgm = getattr(module, var_name)
        except AttributeError as err:
            # Honour the documented contract: a missing object is a RuntimeError.
            raise RuntimeError(f'object {var_name} not found') from err
        if not isinstance(pgm, PGM):
            raise RuntimeError(f'object {var_name} is not a PGM')
        return pgm

    # No name given: accept the file only if it defines exactly one public PGM.
    potentials: List[PGM] = [
        value
        for var, value in vars(module).items()
        if not var.startswith('_') and isinstance(value, PGM)
    ]
    if len(potentials) != 1:
        raise RuntimeError('unique PGM object not found')
    return potentials[0]
210
+
211
+
212
def _write_python_dense_function(function_name: str, function: DensePotentialFunction, _print) -> None:
    """
    Support method for `write_python`.

    Writes a `set_flat(...)` call holding every parameter value of the
    given dense function. Nothing is written if the function has no
    parameters. Long parameter lists are wrapped over several lines.
    """
    total = function.number_of_parameters
    if total <= 0:
        return

    pad = ' '
    per_line = 5
    multi_line = total >= per_line

    _print(f'{function_name}.set_flat(', end='')
    if multi_line:
        # Start the values on their own, indented line.
        _print('', pad, sep='\n', end='')

    for position in range(1, total + 1):
        _print(repr(function.param_value(position - 1)), end='')
        if position == total:
            break
        _print(', ', end='')
        if position % per_line == 0:
            # Wrap after every `per_line` values.
            _print()
            _print(pad, end='')

    if multi_line:
        _print()
    _print(')')
236
+
237
+
238
def _isnan(value: Any) -> bool:
    """
    Check whether a value is a floating point NaN.

    Returns:
        True only if the given value is a float and is NaN.
    """
    if not isinstance(value, float):
        return False
    return math.isnan(value)
244
+
245
+
246
def _repr_states(states: Sequence[State]) -> str:
    """
    Render a sequence of states as the source text of a Python tuple.

    Each state is rendered using `_repr_state`, which avoids the issue
    that repr of a non-finite float is not parsable by Python.

    A single state is rendered with a trailing comma, as otherwise the
    parentheses would just group the state instead of forming a tuple.

    See https://bugs.python.org/issue1732212
    """
    rendered = [_repr_state(state) for state in states]
    body = ', '.join(rendered)
    if len(rendered) == 1:
        body += ','
    return '(' + body + ')'


def _repr_state(state: State) -> str:
    """
    Render a state as a string that Python can parse back.

    If state is float('nan'), float('inf') or float('-inf') then
    write_python needs to avoid the issue that the repr of such a float
    ('nan', 'inf', '-inf') is not parsable by Python.

    See https://bugs.python.org/issue1732212
    """
    if isinstance(state, float):
        if math.isnan(state):
            return "float('nan')"
        if math.isinf(state):
            return "float('inf')" if state > 0 else "float('-inf')"
    return repr(state)
@@ -0,0 +1,111 @@
1
+ import pathlib as _pathlib
2
+ import sys as _sys
3
+ from typing import Sequence
4
+
5
+ from ck.pgm import RandomVariable, Factor, PotentialFunction, PGM
6
+
7
+
8
def render_bayesian_network(
        pgm: PGM,
        out=None,
        *,
        check_structure_bayesian: bool = True,
) -> None:
    """
    Render a Bayesian network PGM as a BUGS model file in ODC format.

    Args:
        pgm: the PGM object to render.
        out: where to write: None for stdout, a file name or path to
            write a new file, otherwise an object with a `write` method.
        check_structure_bayesian: If True, then raise an exception if not pgm.is_structure_bayesian.

    Raises:
        RuntimeError: if `check_structure_bayesian` is true and `pgm.is_structure_bayesian` is false.
    """
    if check_structure_bayesian and not pgm.is_structure_bayesian:
        raise RuntimeError('attempting to render a PGM with non-Bayesian structure')

    if isinstance(out, (str, _pathlib.Path)):
        # A file name or path: open it here and close it when done.
        with open(out, 'w') as file:
            _render_bayesian_network(pgm, file)
        return

    destination = _sys.stdout if out is None else out
    _render_bayesian_network(pgm, destination)
33
+
34
+
35
+ # ============================================================
36
+ # Private support
37
+ # ============================================================
38
+
39
+
40
def _render_bayesian_network(pgm: PGM, out):
    """
    See render_bayesian_network.

    Writes a 'model { ... }' section with one line per factor, then a
    'list( ... )' section with the parameter values of each factor.
    The first random variable of each factor is taken as the child and
    the remaining ones as its parents.

    A PGM with no factors is rendered as an empty model and an empty list.

    Raises:
        RuntimeError: if two factors have the same child random variable.
    """

    def write(*args, sep=' ', end='\n'):
        out.write(sep.join(str(arg) for arg in args))
        out.write(end)

    factors = pgm.factors

    write('model {')
    seen_child_rvs = set()
    for factor in factors:
        child = factor.rvs[0]
        parents = factor.rvs[1:]

        if child in seen_child_rvs:
            raise RuntimeError(f'duplicated child random variable in factors: {child}')
        seen_child_rvs.add(child)

        _render_rv(child, parents, write)
    write('}')

    write('list(')
    # Every factor is followed by a comma, except the last one.
    last_index = len(factors) - 1
    for index, factor in enumerate(factors):
        _render_factor(factor, '' if index == last_index else ',', write)
    write(')')
67
+
68
+
69
def _render_rv(child: RandomVariable, parents: Sequence[RandomVariable], write):
    """
    Write the model line declaring `child` as categorically distributed,
    indexed by the names of its parents.

    Args:
        child: the child random variable.
        parents: the parent random variables of the child.
        write: a print-like function accepting an `end` keyword.
    """
    child_name = child.name
    state_count = len(child)
    write(f' {child_name} ~ dcat(p.{child_name}[', end='')
    for parent_name in (parent.name for parent in parents):
        write(f'{parent_name},', end='')
    write(f'1:{state_count}])')
76
+
77
+
78
def _render_factor(factor: Factor, delim, write):
    """
    Write the parameter values of a factor as an entry named after the
    factor's child (its first random variable).

    A factor with parents is written as a 'structure' whose dimensions
    are the parent sizes followed by the child size.

    Args:
        factor: the factor to render.
        delim: text to write after the entry, e.g., a separating comma.
        write: a print-like function accepting an `end` keyword.
    """
    rvs = factor.rvs
    child = rvs[0]
    write(f' p.{child.name} = ', end='')

    if len(rvs) != 1:
        dims_str = ','.join(str(len(rv)) for rv in (*rvs[1:], child))
        write('structure(.Data = ', end='')
        _write_param_values(factor.function, write)
        write(f', .Dim = c({dims_str}))', end='')
    else:
        _write_param_values(factor.function, write)

    write(delim)
94
+
95
+
96
def _write_param_values(function: PotentialFunction, write):
    """
    Write all values of a potential function as a comma separated 'c(...)' vector.

    Values are written by iterating `function.parent_instances(flip=True)`
    and, for each parent instance, every child state in order.

    Args:
        function: the potential function to write.
        write: a print-like function accepting an `end` keyword.
    """
    write('c(', end='')

    child_state_count = function.shape[0]
    final_parent = function.number_of_parent_states - 1
    final_child = child_state_count - 1

    for parent_pos, parent_key in enumerate(function.parent_instances(flip=True)):
        for child_state in range(child_state_count):
            write(function[(child_state,) + parent_key], end='')

            # A comma follows every value except the very last one.
            is_last = parent_pos == final_parent and child_state == final_child
            if not is_last:
                write(',', end='')

    write(')', end='')
@@ -0,0 +1,177 @@
1
+ import pathlib as _pathlib
2
+ import re as _re
3
+ import sys as _sys
4
+ from typing import Iterable, Set, List
5
+
6
+ from ck.in_out.parser_utils import escape_string
7
+ from ck.pgm import PGM, RandomVariable
8
+
9
+
10
def render_bayesian_network(
        pgm: PGM,
        out=None,
        *,
        check_structure_bayesian: bool = True,
) -> List[str]:
    """
    Render a PGM as a Hugin 'net' file.

    Args:
        pgm: the PGM object to render.
        out: where to write: None for stdout, a file name or path to
            write a new file, otherwise an object with a `write` method.
        check_structure_bayesian: If True, then raise an exception if not pgm.is_structure_bayesian.

    Returns:
        a list of node names used in the Hugin 'net' file, co-indexed with pgm.rvs.

    Raises:
        ValueError: if `check_structure_bayesian` is true and `pgm.is_structure_bayesian` is false.
    """
    if check_structure_bayesian:
        if not pgm.is_structure_bayesian:
            raise ValueError('attempting to render a PGM with non-Bayesian structure')

    node_names: List[str] = _make_node_names(pgm.rvs)

    if isinstance(out, (str, _pathlib.Path)):
        # A file name or path: open it here and close it when done.
        with open(out, 'w') as file:
            _render_bayesian_network(pgm, node_names, file)
    else:
        destination = _sys.stdout if out is None else out
        _render_bayesian_network(pgm, node_names, destination)

    return node_names
44
+
45
+
46
+ # ============================================================
47
+ # Private support
48
+ # ============================================================
49
+
50
def _make_node_names(rvs: Iterable[RandomVariable]) -> List[str]:
    """
    Make a list of unique node names, one per random variable and in the
    same order as `rvs`.

    Each name starts as `_rv_name(rv)`. If that name is already taken, a
    suffix '_2', '_3', ... is appended, using the first one not yet taken.
    """
    taken: Set[str] = set()
    result: List[str] = []
    for rv in rvs:
        base = _rv_name(rv)
        candidate = base
        suffix = 1
        while candidate in taken:
            suffix += 1
            candidate = base + '_' + str(suffix)
        taken.add(candidate)
        result.append(candidate)
    return result
68
+
69
+
70
def _render_bayesian_network(pgm: PGM, node_names: List[str], out):
    """
    Write the 'net' file: a header, a node block for each random
    variable, a blank line, then a potential block for each factor.
    """
    emit = out.write

    emit('net{}\n')
    for rv in pgm.rvs:
        _write_node_block(rv, node_names, out)

    emit('\n')
    for factor in pgm.factors:
        _write_potential_block(factor, node_names, out)
80
+
81
+
82
def _write_node_block(rv, node_names: List[str], out):
    """
    Write the node block of a random variable: its node name, then its
    label and states enclosed in braces.
    """
    node_name = node_names[rv.idx]
    out.write(''.join(('node ', node_name, '\n')))
    out.write('{\n')
    _write_node_block_label(rv, out)
    _write_node_block_states(rv, out)
    out.write('}\n')
88
+
89
+
90
def _write_node_block_label(rv, out):
    """
    Write the 'label' line of a node block, using the label made by `_rv_label`.
    """
    emit = out.write
    emit(' label = "')
    emit(_rv_label(rv))
    emit('";\n')
94
+
95
+
96
def _write_node_block_states(rv, out):
    """
    Write the 'states' line of a node block, with each state of the
    random variable as a double-quoted label made by `_state_label`.
    """
    emit = out.write
    emit(' states = (')
    for state in rv.states:
        emit(''.join((' "', _state_label(state), '"')))
    emit(' );\n')
101
+
102
+
103
def _write_potential_block_link(factor, node_names: List[str], out):
    """
    Write the header line of a potential block: the node name of the
    factor's first random variable, then, if there are more random
    variables, a '|' followed by the node names of the rest.
    """
    rvs = list(factor.rvs)
    head, rest = rvs[:1], rvs[1:]

    out.write('potential (')
    for rv in head:
        out.write(' ' + node_names[rv.idx])
    if rest:
        out.write(' |')
    for rv in rest:
        out.write(' ' + node_names[rv.idx])
    out.write(' )\n')
112
+
113
+
114
def _recursively_write_ordered_data(shape, address_order_map, address_current, current_depth, max_depth, function, out):
    """
    Write the values of a function as nested, parenthesised groups.

    Args:
        shape: the number of states along each axis of the function.
        address_order_map: maps a nesting depth to the axis iterated at that depth.
        address_current: the current address, updated in place as axes are iterated.
        current_depth: the nesting depth of this call.
        max_depth: the depth at which values (rather than groups) are written.
        function: the object providing values, indexed by `address_current`.
        out: an object with a `write` method.
    """
    out.write('( ')

    axis = address_order_map[current_depth]
    at_leaf = current_depth == max_depth
    deeper = current_depth + 1

    for state in range(shape[axis]):
        address_current[axis] = state
        if not at_leaf:
            _recursively_write_ordered_data(
                shape, address_order_map, address_current, deeper, max_depth, function, out
            )
            continue
        out.write(str(function[address_current]) + ' ')

    out.write(") ")
130
+
131
+
132
def _write_potential_block_data(factor, out):
    """
    Write the body of a potential block: the factor's values as a
    nested 'data' entry enclosed in braces.
    """
    out.write('{\n')
    out.write(' data = ')

    function = factor.function
    shape = factor.shape
    number_of_axes = len(shape)

    deepest = number_of_axes - 1
    address_current = [0 for _ in range(number_of_axes)]

    # The ordering of data in a 'net' file is different to the natural order.
    # The axis order used is 1, 2, ..., with axis 0 iterated at the deepest level.
    address_order_map = [(depth + 1) % number_of_axes for depth in range(number_of_axes)]

    _recursively_write_ordered_data(shape, address_order_map, address_current, 0, deepest, function, out)

    out.write(';\n')
    out.write('}\n')
152
+
153
+
154
def _write_potential_block(factor, node_names: List[str], out):
    """
    Write the potential block of a factor: the header line followed by the data.
    """
    for write_part, extra_args in (
            (_write_potential_block_link, (node_names,)),
            (_write_potential_block_data, ()),
    ):
        write_part(factor, *extra_args, out)
157
+
158
+
159
def _rv_label(rv):
    """
    Make a label for a random variable: its name, escaped by
    `escape_string` with `double_quotes=True`.
    """
    name = rv.name
    return escape_string(name, double_quotes=True)
164
+
165
+
166
def _state_label(state) -> str:
    """
    Make a label for a random variable state: the state as a string,
    escaped by `escape_string` with `double_quotes=True`.
    """
    text = str(state)
    return escape_string(text, double_quotes=True)
171
+
172
+
173
def _rv_name(rv: RandomVariable) -> str:
    """
    Make a name for a random variable: its own name, with each run of
    characters other than ASCII letters and digits replaced by one underscore.
    """
    unsafe_run = _re.compile(r'[^0-9a-zA-Z]+')
    return unsafe_run.sub('_', rv.name)