compiled-knowledge 4.0.0a20__cp312-cp312-win32.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of compiled-knowledge might be problematic. Click here for more details.

Files changed (178) hide show
  1. ck/__init__.py +0 -0
  2. ck/circuit/__init__.py +17 -0
  3. ck/circuit/_circuit_cy.c +37523 -0
  4. ck/circuit/_circuit_cy.cp312-win32.pyd +0 -0
  5. ck/circuit/_circuit_cy.pxd +32 -0
  6. ck/circuit/_circuit_cy.pyx +768 -0
  7. ck/circuit/_circuit_py.py +836 -0
  8. ck/circuit/tmp_const.py +74 -0
  9. ck/circuit_compiler/__init__.py +2 -0
  10. ck/circuit_compiler/circuit_compiler.py +26 -0
  11. ck/circuit_compiler/cython_vm_compiler/__init__.py +1 -0
  12. ck/circuit_compiler/cython_vm_compiler/_compiler.c +19824 -0
  13. ck/circuit_compiler/cython_vm_compiler/_compiler.cp312-win32.pyd +0 -0
  14. ck/circuit_compiler/cython_vm_compiler/_compiler.pyx +380 -0
  15. ck/circuit_compiler/cython_vm_compiler/cython_vm_compiler.py +121 -0
  16. ck/circuit_compiler/interpret_compiler.py +223 -0
  17. ck/circuit_compiler/llvm_compiler.py +388 -0
  18. ck/circuit_compiler/llvm_vm_compiler.py +546 -0
  19. ck/circuit_compiler/named_circuit_compilers.py +57 -0
  20. ck/circuit_compiler/support/__init__.py +0 -0
  21. ck/circuit_compiler/support/circuit_analyser/__init__.py +13 -0
  22. ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.c +10618 -0
  23. ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.cp312-win32.pyd +0 -0
  24. ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.pyx +98 -0
  25. ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_py.py +93 -0
  26. ck/circuit_compiler/support/input_vars.py +148 -0
  27. ck/circuit_compiler/support/llvm_ir_function.py +234 -0
  28. ck/example/__init__.py +53 -0
  29. ck/example/alarm.py +366 -0
  30. ck/example/asia.py +28 -0
  31. ck/example/binary_clique.py +32 -0
  32. ck/example/bow_tie.py +33 -0
  33. ck/example/cancer.py +37 -0
  34. ck/example/chain.py +38 -0
  35. ck/example/child.py +199 -0
  36. ck/example/clique.py +33 -0
  37. ck/example/cnf_pgm.py +39 -0
  38. ck/example/diamond_square.py +68 -0
  39. ck/example/earthquake.py +36 -0
  40. ck/example/empty.py +10 -0
  41. ck/example/hailfinder.py +539 -0
  42. ck/example/hepar2.py +628 -0
  43. ck/example/insurance.py +504 -0
  44. ck/example/loop.py +40 -0
  45. ck/example/mildew.py +38161 -0
  46. ck/example/munin.py +22982 -0
  47. ck/example/pathfinder.py +53747 -0
  48. ck/example/rain.py +39 -0
  49. ck/example/rectangle.py +161 -0
  50. ck/example/run.py +30 -0
  51. ck/example/sachs.py +129 -0
  52. ck/example/sprinkler.py +30 -0
  53. ck/example/star.py +44 -0
  54. ck/example/stress.py +64 -0
  55. ck/example/student.py +43 -0
  56. ck/example/survey.py +46 -0
  57. ck/example/triangle_square.py +54 -0
  58. ck/example/truss.py +49 -0
  59. ck/in_out/__init__.py +3 -0
  60. ck/in_out/parse_ace_lmap.py +216 -0
  61. ck/in_out/parse_ace_nnf.py +322 -0
  62. ck/in_out/parse_net.py +480 -0
  63. ck/in_out/parser_utils.py +185 -0
  64. ck/in_out/pgm_pickle.py +42 -0
  65. ck/in_out/pgm_python.py +268 -0
  66. ck/in_out/render_bugs.py +111 -0
  67. ck/in_out/render_net.py +177 -0
  68. ck/in_out/render_pomegranate.py +184 -0
  69. ck/pgm.py +3475 -0
  70. ck/pgm_circuit/__init__.py +1 -0
  71. ck/pgm_circuit/marginals_program.py +352 -0
  72. ck/pgm_circuit/mpe_program.py +237 -0
  73. ck/pgm_circuit/pgm_circuit.py +79 -0
  74. ck/pgm_circuit/program_with_slotmap.py +236 -0
  75. ck/pgm_circuit/slot_map.py +35 -0
  76. ck/pgm_circuit/support/__init__.py +0 -0
  77. ck/pgm_circuit/support/compile_circuit.py +83 -0
  78. ck/pgm_circuit/target_marginals_program.py +103 -0
  79. ck/pgm_circuit/wmc_program.py +323 -0
  80. ck/pgm_compiler/__init__.py +2 -0
  81. ck/pgm_compiler/ace/__init__.py +1 -0
  82. ck/pgm_compiler/ace/ace.py +299 -0
  83. ck/pgm_compiler/factor_elimination.py +395 -0
  84. ck/pgm_compiler/named_pgm_compilers.py +63 -0
  85. ck/pgm_compiler/pgm_compiler.py +19 -0
  86. ck/pgm_compiler/recursive_conditioning.py +231 -0
  87. ck/pgm_compiler/support/__init__.py +0 -0
  88. ck/pgm_compiler/support/circuit_table/__init__.py +17 -0
  89. ck/pgm_compiler/support/circuit_table/_circuit_table_cy.c +16396 -0
  90. ck/pgm_compiler/support/circuit_table/_circuit_table_cy.cp312-win32.pyd +0 -0
  91. ck/pgm_compiler/support/circuit_table/_circuit_table_cy.pyx +332 -0
  92. ck/pgm_compiler/support/circuit_table/_circuit_table_py.py +304 -0
  93. ck/pgm_compiler/support/clusters.py +568 -0
  94. ck/pgm_compiler/support/factor_tables.py +406 -0
  95. ck/pgm_compiler/support/join_tree.py +332 -0
  96. ck/pgm_compiler/support/named_compiler_maker.py +43 -0
  97. ck/pgm_compiler/variable_elimination.py +91 -0
  98. ck/probability/__init__.py +0 -0
  99. ck/probability/empirical_probability_space.py +50 -0
  100. ck/probability/pgm_probability_space.py +32 -0
  101. ck/probability/probability_space.py +622 -0
  102. ck/program/__init__.py +3 -0
  103. ck/program/program.py +137 -0
  104. ck/program/program_buffer.py +180 -0
  105. ck/program/raw_program.py +67 -0
  106. ck/sampling/__init__.py +0 -0
  107. ck/sampling/forward_sampler.py +211 -0
  108. ck/sampling/marginals_direct_sampler.py +113 -0
  109. ck/sampling/sampler.py +62 -0
  110. ck/sampling/sampler_support.py +232 -0
  111. ck/sampling/uniform_sampler.py +72 -0
  112. ck/sampling/wmc_direct_sampler.py +171 -0
  113. ck/sampling/wmc_gibbs_sampler.py +153 -0
  114. ck/sampling/wmc_metropolis_sampler.py +165 -0
  115. ck/sampling/wmc_rejection_sampler.py +115 -0
  116. ck/utils/__init__.py +0 -0
  117. ck/utils/iter_extras.py +163 -0
  118. ck/utils/local_config.py +270 -0
  119. ck/utils/map_list.py +128 -0
  120. ck/utils/map_set.py +128 -0
  121. ck/utils/np_extras.py +51 -0
  122. ck/utils/random_extras.py +64 -0
  123. ck/utils/tmp_dir.py +94 -0
  124. ck_demos/__init__.py +0 -0
  125. ck_demos/ace/__init__.py +0 -0
  126. ck_demos/ace/copy_ace_to_ck.py +15 -0
  127. ck_demos/ace/demo_ace.py +49 -0
  128. ck_demos/all_demos.py +88 -0
  129. ck_demos/circuit/__init__.py +0 -0
  130. ck_demos/circuit/demo_circuit_dump.py +22 -0
  131. ck_demos/circuit/demo_derivatives.py +43 -0
  132. ck_demos/circuit_compiler/__init__.py +0 -0
  133. ck_demos/circuit_compiler/compare_circuit_compilers.py +32 -0
  134. ck_demos/circuit_compiler/show_llvm_program.py +26 -0
  135. ck_demos/pgm/__init__.py +0 -0
  136. ck_demos/pgm/demo_pgm_dump.py +18 -0
  137. ck_demos/pgm/demo_pgm_dump_stress.py +18 -0
  138. ck_demos/pgm/demo_pgm_string_rendering.py +15 -0
  139. ck_demos/pgm/show_examples.py +25 -0
  140. ck_demos/pgm_compiler/__init__.py +0 -0
  141. ck_demos/pgm_compiler/compare_pgm_compilers.py +63 -0
  142. ck_demos/pgm_compiler/demo_compiler_dump.py +60 -0
  143. ck_demos/pgm_compiler/demo_factor_elimination.py +47 -0
  144. ck_demos/pgm_compiler/demo_join_tree.py +25 -0
  145. ck_demos/pgm_compiler/demo_marginals_program.py +53 -0
  146. ck_demos/pgm_compiler/demo_mpe_program.py +55 -0
  147. ck_demos/pgm_compiler/demo_pgm_compiler.py +38 -0
  148. ck_demos/pgm_compiler/demo_recursive_conditioning.py +33 -0
  149. ck_demos/pgm_compiler/demo_variable_elimination.py +33 -0
  150. ck_demos/pgm_compiler/demo_wmc_program.py +29 -0
  151. ck_demos/pgm_compiler/time_fe_compiler.py +93 -0
  152. ck_demos/pgm_inference/__init__.py +0 -0
  153. ck_demos/pgm_inference/demo_inferencing_basic.py +188 -0
  154. ck_demos/pgm_inference/demo_inferencing_mpe_cancer.py +45 -0
  155. ck_demos/pgm_inference/demo_inferencing_wmc_and_mpe_sprinkler.py +154 -0
  156. ck_demos/pgm_inference/demo_inferencing_wmc_student.py +110 -0
  157. ck_demos/programs/__init__.py +0 -0
  158. ck_demos/programs/demo_program_buffer.py +24 -0
  159. ck_demos/programs/demo_program_multi.py +24 -0
  160. ck_demos/programs/demo_program_none.py +19 -0
  161. ck_demos/programs/demo_program_single.py +23 -0
  162. ck_demos/programs/demo_raw_program_interpreted.py +21 -0
  163. ck_demos/programs/demo_raw_program_llvm.py +21 -0
  164. ck_demos/sampling/__init__.py +0 -0
  165. ck_demos/sampling/check_sampler.py +71 -0
  166. ck_demos/sampling/demo_marginal_direct_sampler.py +40 -0
  167. ck_demos/sampling/demo_uniform_sampler.py +38 -0
  168. ck_demos/sampling/demo_wmc_direct_sampler.py +40 -0
  169. ck_demos/utils/__init__.py +0 -0
  170. ck_demos/utils/compare.py +120 -0
  171. ck_demos/utils/convert_network.py +45 -0
  172. ck_demos/utils/sample_model.py +216 -0
  173. ck_demos/utils/stop_watch.py +384 -0
  174. compiled_knowledge-4.0.0a20.dist-info/METADATA +50 -0
  175. compiled_knowledge-4.0.0a20.dist-info/RECORD +178 -0
  176. compiled_knowledge-4.0.0a20.dist-info/WHEEL +5 -0
  177. compiled_knowledge-4.0.0a20.dist-info/licenses/LICENSE.txt +21 -0
  178. compiled_knowledge-4.0.0a20.dist-info/top_level.txt +2 -0
@@ -0,0 +1,120 @@
1
+ import gc
2
+ from typing import Sequence
3
+
4
+ from ck.circuit_compiler import NamedCircuitCompiler
5
+ from ck.pgm import PGM
6
+ from ck.pgm_circuit import PGMCircuit
7
+ from ck.pgm_circuit.wmc_program import WMCProgram
8
+ from ck.pgm_compiler import NamedPGMCompiler
9
+ from ck_demos.utils.stop_watch import StopWatch
10
+
11
+
12
def compare(
        pgms: Sequence[PGM],
        pgm_compilers: Sequence[NamedPGMCompiler],
        cct_compilers: Sequence[NamedCircuitCompiler],
        *,
        cache_circuits: bool = True,
        break_between_pgms: bool = True,
        comma_numbers: bool = True,
        print_header: bool = True,
        sep: str = ' ',
) -> None:
    """
    For each combination of the given arguments, construct a PGMCircuit (using a
    PGMCompiler) and then a WMCProgram (using a CircuitCompiler). The resulting
    WMCProgram is executed 1000 times to estimate compute time.

    For each PGM, PGM compiler, and circuit compiler, a line is printed showing:
        PGM,
        PGM compiler name,
        Circuit compiler name,
        number of circuit operations,
        PGMCircuit compile time (as reported by `StopWatch.seconds()`),
        WMCProgram compile time (as reported by `StopWatch.seconds()`),
        WMC compute time (`StopWatch.seconds() * 1000` over the 1000 runs, i.e.,
            total milliseconds, or equivalently microseconds per run).

    If any step raises an exception, the exception's repr is printed in place
    of the remaining columns of that line and the comparison continues.

    The print output is formatted using fixed column width. Empty input
    sequences are accepted; in that case only the header (if requested) and
    any blank separator lines are printed.

    Args:
        pgms: a sequence of PGM objects.
        pgm_compilers: a sequence of named PGM compilers.
        cct_compilers: a sequence of named circuit compilers.
        cache_circuits: if true, then circuits are reused across different circuit compilers.
        break_between_pgms: if true, print a blank line between different workload PGMs.
        comma_numbers: if true, commas are used in large numbers.
        print_header: if true, a header line is printed.
        sep: column separator.
    """
    # Work out column widths for names. `default=0` lets empty sequences
    # through instead of raising ValueError from `max()`.
    col_pgm_name: int = max(3, max((len(pgm.name) for pgm in pgms), default=0))
    col_pgm_compiler_name: int = max(12, max((len(c.name) for c in pgm_compilers), default=0))
    col_cct_compiler_name: int = max(12, max((len(c.name) for c in cct_compilers), default=0))
    col_cct_ops: int = 10
    col_pgm_compile_time: int = 16
    col_cct_compile_time: int = 16
    col_execute_time: int = 10

    # Print formatting
    comma: str = ',' if comma_numbers else ''

    if print_header:
        print('PGM'.ljust(col_pgm_name), end=sep)
        print('PGM-compiler'.ljust(col_pgm_compiler_name), end=sep)
        print('CCT-compiler'.ljust(col_cct_compiler_name), end=sep)
        print('CCT-ops'.rjust(col_cct_ops), end=sep)
        print('PGM-compile-time'.rjust(col_pgm_compile_time), end=sep)
        print('CCT-compile-time'.rjust(col_cct_compile_time), end=sep)
        print('Run-time'.rjust(col_execute_time))

    # Cache state for when cache_circuits is true. Invariant: whenever
    # prev_pgm/prev_pgm_compiler are not None, `pgm_cct` is the circuit that
    # prev_pgm_compiler produced for prev_pgm.
    prev_pgm = None
    prev_pgm_compiler = None
    pgm_cct = None

    for pgm in pgms:
        pgm_name: str = pgm.name.ljust(col_pgm_name)
        for pgm_compiler in pgm_compilers:
            pgm_compiler_name: str = pgm_compiler.name.ljust(col_pgm_compiler_name)
            for cct_compiler in cct_compilers:
                cct_compiler_name: str = cct_compiler.name.ljust(col_cct_compiler_name)

                print(pgm_name, end=sep)
                print(pgm_compiler_name, end=sep)
                print(cct_compiler_name, end=sep)

                try:
                    stop_watch = StopWatch()

                    if cache_circuits and pgm is prev_pgm and pgm_compiler is prev_pgm_compiler:
                        # Reuse the cached circuit; leave ops and PGM-compile columns blank.
                        print(f'{"":{col_cct_ops}}', end=sep)
                        print(f'{"":{col_pgm_compile_time}}', end=sep)
                    else:
                        gc.collect()
                        stop_watch.start()
                        pgm_cct = pgm_compiler(pgm)
                        stop_watch.stop()
                        # Record the cache key immediately after a successful compile so
                        # that a later failure on this line cannot leave the key pointing
                        # at a different circuit than `pgm_cct`.
                        prev_pgm = pgm
                        prev_pgm_compiler = pgm_compiler
                        num_ops: int = pgm_cct.circuit_top.circuit.number_of_operations
                        print(f'{num_ops:{col_cct_ops}{comma}}', end=sep)
                        print(f'{stop_watch.seconds():{col_pgm_compile_time}{comma}.3f}', end=sep)

                    gc.collect()
                    stop_watch.start()
                    wmc = WMCProgram(pgm_cct, compiler=cct_compiler.compiler)
                    stop_watch.stop()
                    print(f'{stop_watch.seconds():{col_cct_compile_time}{comma}.3f}', end=sep)

                    gc.collect()
                    stop_watch.start()
                    for _ in range(1000):
                        wmc.compute()
                    stop_watch.stop()
                    # Total seconds for 1000 runs * 1000 => total ms (= microseconds per run).
                    print(f'{stop_watch.seconds() * 1000:{col_execute_time}{comma}.3f}', end='')
                except Exception as err:
                    # Benchmark boundary: report the failure on this row and keep going.
                    print(repr(err), end='')
                print()
        if break_between_pgms:
            print()
@@ -0,0 +1,45 @@
1
+ from pathlib import Path
2
+
3
+ from ck.in_out.parse_net import read_network
4
+ from ck.in_out.pgm_python import write_python
5
+ from ck.pgm import PGM
6
+
7
+
8
def convert_network(network_path: Path, file=None) -> None:
    """
    Read a Hugin 'net' file and write it out as PGM Python source code.

    Before writing, any factor whose function has more than 10 zero-valued
    parameters, where those zeros also make up more than a tenth of the
    function's `number_of_parameters`, is switched to a sparse function. Every
    (key, value) of the original function is copied into the sparse one.

    Args:
        network_path: path to a Hugin 'net' file.
        file: destination, as per the `print` function (passed through to
            `write_python`).
    """
    with open(network_path) as net_stream:
        pgm: PGM = read_network(net_stream)

    for factor in pgm.factors:
        dense = factor.function
        n_params: int = dense.number_of_parameters

        n_zeros: int = 0
        for _, param_value in dense.params:
            if param_value == 0:
                n_zeros += 1

        # Guard clause: keep the dense function unless it is worth sparsifying.
        if n_zeros <= 10 or n_zeros / n_params <= 0.1:
            continue

        sparse = factor.set_sparse()
        for param_key, _, param_value in dense.keys_with_param:
            sparse[param_key] = param_value

    write_python(pgm, file=file)
32
+
33
+
34
def main() -> None:
    """
    Demo of `convert_network`: convert one named network found in a fixed
    local directory, printing the generated Python to standard out.
    """
    # NOTE(review): machine-specific location; adjust before running elsewhere.
    network_directory = r'E:\Dropbox\Research\data\BN\networks'
    network_name = 'pathfinder'

    network_path = Path(network_directory, network_name + '.net')
    convert_network(network_path)


if __name__ == '__main__':
    main()
@@ -0,0 +1,216 @@
1
+ import random
2
+ from typing import Optional, Dict, Callable, List
3
+
4
+ import numpy as np
5
+
6
+ from ck.pgm import rv_instances, PGM, RandomVariable, Indicator
7
+ from ck.pgm_compiler import factor_elimination
8
+ from ck.pgm_circuit.marginals_program import MarginalsProgram
9
+ from ck.pgm_circuit import PGMCircuit
10
+ from ck.pgm_circuit.wmc_program import WMCProgram
11
+ from ck.sampling.forward_sampler import ForwardSampler
12
+ from ck.sampling.sampler import Sampler
13
+ from ck.utils.random_extras import random_permute
14
+ from ck_demos.utils.stop_watch import StopWatch
15
+
16
# A sampler factory is called as (pgm, wmc, mar, sample_rvs, condition) and returns a Sampler.
SamplerFactory = Callable[[PGM, WMCProgram, MarginalsProgram, List[RandomVariable], List[Indicator]], Sampler]

# Burn-in handed to the standard samplers that take one (Gibbs and Metropolis below).
BURN_IN: int = 1000


def _direct_wmc_factory(pgm, wmc, mar, sample_rvs, condition):
    """Direct sampling via the WMC program."""
    return wmc.sample_direct(rvs=sample_rvs, condition=condition)


def _direct_mar_factory(pgm, wmc, mar, sample_rvs, condition):
    """Direct sampling via the marginals program."""
    return mar.sample_direct(rvs=sample_rvs, condition=condition)


def _rejection_factory(pgm, wmc, mar, sample_rvs, condition):
    """Rejection sampling via the WMC program."""
    return wmc.sample_rejection(rvs=sample_rvs, condition=condition)


def _gibbs_factory(pgm, wmc, mar, sample_rvs, condition):
    """Gibbs sampling via the WMC program, with BURN_IN (read at call time)."""
    return wmc.sample_gibbs(burn_in=BURN_IN, rvs=sample_rvs, condition=condition)


def _metropolis_factory(pgm, wmc, mar, sample_rvs, condition):
    """Metropolis sampling via the WMC program, with BURN_IN (read at call time)."""
    return wmc.sample_metropolis(burn_in=BURN_IN, rvs=sample_rvs, condition=condition)


def _forward_factory(pgm, wmc, mar, sample_rvs, condition):
    """Forward sampling directly from the PGM, requiring it to be a Bayesian network."""
    return ForwardSampler(pgm, sample_rvs, condition, check_is_bayesian_network=True)


def _uniform_factory(pgm, wmc, mar, sample_rvs, condition):
    """Uniform sampling via the WMC program."""
    return wmc.sample_uniform(rvs=sample_rvs, condition=condition)


# Standard samplers, keyed by display name (insertion order is the report order).
STANDARD_SAMPLERS: Dict[str, SamplerFactory] = {
    'Direct-wmc': _direct_wmc_factory,
    'Direct-mar': _direct_mar_factory,
    'Rejection': _rejection_factory,
    'Gibbs': _gibbs_factory,
    'Metropolis': _metropolis_factory,
    'Forward': _forward_factory,
    'Uniform': _uniform_factory,
}
51
+
52
+
53
def _choose_condition(pgm, rvs, wmc, limit_conditioning):
    """
    Randomly choose conditioning indicators whose probability is non-zero.

    Returns a pair (condition, condition_str). When `limit_conditioning` is 0,
    the condition is `()` and the string is ''. Otherwise the condition is a
    (possibly empty) list of indicators. The string is ' | ' followed by
    `pgm.condition_str(...)`, or '' when the list is empty. Impossible
    conditions (`wmc.probability(...) == 0`) are reported and re-drawn.
    """
    if limit_conditioning == 0:
        return (), ''
    while True:
        num_indicators_to_condition = random.randint(0, limit_conditioning)
        rand_rvs = list(rvs)
        random_permute(rand_rvs)
        condition = []
        while len(condition) < num_indicators_to_condition and len(rand_rvs) > 0:
            rv = rand_rvs.pop()
            if len(rv) < 2:
                # A single-state variable leaves no valid count of indicators to
                # draw (at most len(rv) - 1 == 0), so it cannot be conditioned.
                continue
            max_rv_indicators_to_condition = min(len(rv) - 1, num_indicators_to_condition - len(condition))
            assert max_rv_indicators_to_condition >= 1, 'assumption check'
            num_rv_indicators_to_condition = random.randint(1, max_rv_indicators_to_condition)
            indicators = list(rv)
            random_permute(indicators)
            condition += sorted(indicators[:num_rv_indicators_to_condition])

        if len(condition) == 0:
            return condition, ''

        condition_str = ' | ' + pgm.condition_str(*condition)

        # only accept the condition if the Pr(condition) > 0
        if wmc.probability(*condition) > 0:
            return condition, condition_str
        print(f'Note: discarded impossible condition{condition_str}')


def sample_model(
        pgm: PGM,
        samplers: Dict[str, SamplerFactory],
        num_of_trials: int,
        num_of_samples: int,
        limit_conditioning: Optional[int] = None,
        show_each_analysis: bool = True,
        line: str = '-' * 80,
):
    """
    Evaluate the given samplers on the given PGM.

    Results are printed to standard out. The PGM is compiled with factor
    elimination into a WMC program and a marginals program. Each trial picks
    a random non-empty subset of the random variables to sample and a random
    condition with non-zero probability. It then compares each sampler's
    empirical distribution against `wmc.marginal_distribution` for those
    variables under that condition.

    Args:
        pgm: is the model to sample.
        samplers: is a dict from sampler name to factory function. The
            factory function type is (pgm, wmc, mar, sample_rvs, condition) -> Sampler.
        num_of_trials: how many trials to perform.
        num_of_samples: how many samples to draw from each sampler, for each trial.
        limit_conditioning: maximum number of indicators to use when determining
            conditioning for a trial, or if None then pgm.number_of_rvs is used.
        show_each_analysis: if True, then extra details are printed.
        line: is the 'line' string to use to delimit trials.
    """
    print(f'Model: {pgm.name}')
    print(f'Number of random variables: {pgm.number_of_rvs}')
    print(f'Number of indicators: {pgm.number_of_indicators}')
    print(f'States space: {pgm.number_of_states:,}')

    # compile
    pgm_cct: PGMCircuit = factor_elimination.compile_pgm(pgm)
    wmc = WMCProgram(pgm_cct)
    mar = MarginalsProgram(pgm_cct)

    rvs = pgm.rvs
    num_of_rvs = len(rvs)
    if limit_conditioning is None:
        limit_conditioning = pgm.number_of_rvs
    sampler_names = list(samplers.keys())
    overall_max_difference = {name: 0 for name in sampler_names}
    overall_sum_difference = {name: 0 for name in sampler_names}
    overall_time = {name: 0 for name in sampler_names}
    errors = {name: [] for name in sampler_names}

    # `default=0` keeps an empty `samplers` dict from crashing `max()`.
    name_pad = max(
        max((len(name) for name in sampler_names), default=0) + 1,
        max((len(rv.name) for rv in rvs), default=0) + 1,
    )

    for trial in range(1, 1 + num_of_trials):
        print(line)

        # what random variables to sample
        num_rvs_to_sample = random.randint(1, num_of_rvs)
        sample_rvs = list(rvs)
        random_permute(sample_rvs)
        del sample_rvs[num_rvs_to_sample:]
        sample_rvs.sort(key=(lambda rv: rv.idx))
        rvs_str = ', '.join([str(rv) for rv in sample_rvs])

        # what conditions
        condition, condition_str = _choose_condition(pgm, rvs, wmc, limit_conditioning)

        # show the trial parameters
        print(f'trial {trial} of {num_of_trials}: {rvs_str}{condition_str}')

        # create state indexes for printing
        state_to_index = {}
        all_states = []
        for i, state in enumerate(rv_instances(*sample_rvs)):
            state = tuple(state)
            all_states.append(state)
            state_to_index[state] = i

        # print detailed results - header
        for i, rv in enumerate(sample_rvs):
            print(str(rv).ljust(name_pad), end='')
            print(' '.join([f'{str(state[i]).ljust(7)}' for state in all_states]))

        # pgm_stats
        print('PGM'.ljust(name_pad), end='')
        pgm_stats = np.array(wmc.marginal_distribution(*sample_rvs, condition=condition))
        print(' '.join([f'{p:.5f}' for p in pgm_stats]))

        for sampler_name in sampler_names:
            print(sampler_name.ljust(name_pad), end='')

            # sample_stats
            try:
                sample_stats = np.zeros(len(all_states))
                sampler = samplers[sampler_name](pgm, wmc, mar, sample_rvs, condition)
                stop_watch = StopWatch()
                for state in sampler.take(num_of_samples):
                    i = state_to_index[tuple(state)]
                    sample_stats[i] += 1
                stop_watch.stop()
                sample_stats /= np.sum(sample_stats)
            except (ValueError, RuntimeError, AssertionError) as err:
                errors[sampler_name].append(repr(err))
                print(repr(err))
                continue

            # print detailed results - for this sampler
            print(' '.join([f'{p:.5f}' for p in sample_stats]))

            # analyse
            max_difference = 0
            sum_difference = 0
            for pgm_stat, sample_stat in zip(pgm_stats, sample_stats):
                diff = abs(pgm_stat - sample_stat)
                max_difference = max(max_difference, diff)
                sum_difference += diff
            overall_max_difference[sampler_name] = max(overall_max_difference[sampler_name], max_difference)
            # NOTE(review): this keeps the worst per-trial sum of differences (max),
            # not a running total across trials; confirm which was intended.
            overall_sum_difference[sampler_name] = max(overall_sum_difference[sampler_name], sum_difference)
            overall_time[sampler_name] += stop_watch.seconds()

            if show_each_analysis:
                print(
                    ' ' * name_pad +
                    f'max_difference = {max_difference}, '
                    f'sum_difference = {sum_difference}, '
                    f'time = {stop_watch.seconds()}'
                )

    print(line)
    sep: str = ', '
    print(' ' * name_pad + sep.join(['overall_max_difference', 'overall_sum_difference', 'overall_time', 'errors']))
    for sampler_name in sampler_names:
        print(
            f'{sampler_name.ljust(name_pad)}'
            f'{overall_max_difference[sampler_name]}{sep}'
            f'{overall_sum_difference[sampler_name]}{sep}'
            f'{overall_time[sampler_name]}{sep}'
            f'{len(errors[sampler_name])}'
        )
    print()