compiled-knowledge 4.0.0a20__cp312-cp312-macosx_10_13_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of compiled-knowledge might be problematic. Click here for more details.
- ck/__init__.py +0 -0
- ck/circuit/__init__.py +17 -0
- ck/circuit/_circuit_cy.c +37525 -0
- ck/circuit/_circuit_cy.cpython-312-darwin.so +0 -0
- ck/circuit/_circuit_cy.pxd +32 -0
- ck/circuit/_circuit_cy.pyx +768 -0
- ck/circuit/_circuit_py.py +836 -0
- ck/circuit/tmp_const.py +74 -0
- ck/circuit_compiler/__init__.py +2 -0
- ck/circuit_compiler/circuit_compiler.py +26 -0
- ck/circuit_compiler/cython_vm_compiler/__init__.py +1 -0
- ck/circuit_compiler/cython_vm_compiler/_compiler.c +19826 -0
- ck/circuit_compiler/cython_vm_compiler/_compiler.cpython-312-darwin.so +0 -0
- ck/circuit_compiler/cython_vm_compiler/_compiler.pyx +380 -0
- ck/circuit_compiler/cython_vm_compiler/cython_vm_compiler.py +121 -0
- ck/circuit_compiler/interpret_compiler.py +223 -0
- ck/circuit_compiler/llvm_compiler.py +388 -0
- ck/circuit_compiler/llvm_vm_compiler.py +546 -0
- ck/circuit_compiler/named_circuit_compilers.py +57 -0
- ck/circuit_compiler/support/__init__.py +0 -0
- ck/circuit_compiler/support/circuit_analyser/__init__.py +13 -0
- ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.c +10620 -0
- ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.cpython-312-darwin.so +0 -0
- ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.pyx +98 -0
- ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_py.py +93 -0
- ck/circuit_compiler/support/input_vars.py +148 -0
- ck/circuit_compiler/support/llvm_ir_function.py +234 -0
- ck/example/__init__.py +53 -0
- ck/example/alarm.py +366 -0
- ck/example/asia.py +28 -0
- ck/example/binary_clique.py +32 -0
- ck/example/bow_tie.py +33 -0
- ck/example/cancer.py +37 -0
- ck/example/chain.py +38 -0
- ck/example/child.py +199 -0
- ck/example/clique.py +33 -0
- ck/example/cnf_pgm.py +39 -0
- ck/example/diamond_square.py +68 -0
- ck/example/earthquake.py +36 -0
- ck/example/empty.py +10 -0
- ck/example/hailfinder.py +539 -0
- ck/example/hepar2.py +628 -0
- ck/example/insurance.py +504 -0
- ck/example/loop.py +40 -0
- ck/example/mildew.py +38161 -0
- ck/example/munin.py +22982 -0
- ck/example/pathfinder.py +53747 -0
- ck/example/rain.py +39 -0
- ck/example/rectangle.py +161 -0
- ck/example/run.py +30 -0
- ck/example/sachs.py +129 -0
- ck/example/sprinkler.py +30 -0
- ck/example/star.py +44 -0
- ck/example/stress.py +64 -0
- ck/example/student.py +43 -0
- ck/example/survey.py +46 -0
- ck/example/triangle_square.py +54 -0
- ck/example/truss.py +49 -0
- ck/in_out/__init__.py +3 -0
- ck/in_out/parse_ace_lmap.py +216 -0
- ck/in_out/parse_ace_nnf.py +322 -0
- ck/in_out/parse_net.py +480 -0
- ck/in_out/parser_utils.py +185 -0
- ck/in_out/pgm_pickle.py +42 -0
- ck/in_out/pgm_python.py +268 -0
- ck/in_out/render_bugs.py +111 -0
- ck/in_out/render_net.py +177 -0
- ck/in_out/render_pomegranate.py +184 -0
- ck/pgm.py +3475 -0
- ck/pgm_circuit/__init__.py +1 -0
- ck/pgm_circuit/marginals_program.py +352 -0
- ck/pgm_circuit/mpe_program.py +237 -0
- ck/pgm_circuit/pgm_circuit.py +79 -0
- ck/pgm_circuit/program_with_slotmap.py +236 -0
- ck/pgm_circuit/slot_map.py +35 -0
- ck/pgm_circuit/support/__init__.py +0 -0
- ck/pgm_circuit/support/compile_circuit.py +83 -0
- ck/pgm_circuit/target_marginals_program.py +103 -0
- ck/pgm_circuit/wmc_program.py +323 -0
- ck/pgm_compiler/__init__.py +2 -0
- ck/pgm_compiler/ace/__init__.py +1 -0
- ck/pgm_compiler/ace/ace.py +299 -0
- ck/pgm_compiler/factor_elimination.py +395 -0
- ck/pgm_compiler/named_pgm_compilers.py +63 -0
- ck/pgm_compiler/pgm_compiler.py +19 -0
- ck/pgm_compiler/recursive_conditioning.py +231 -0
- ck/pgm_compiler/support/__init__.py +0 -0
- ck/pgm_compiler/support/circuit_table/__init__.py +17 -0
- ck/pgm_compiler/support/circuit_table/_circuit_table_cy.c +16398 -0
- ck/pgm_compiler/support/circuit_table/_circuit_table_cy.cpython-312-darwin.so +0 -0
- ck/pgm_compiler/support/circuit_table/_circuit_table_cy.pyx +332 -0
- ck/pgm_compiler/support/circuit_table/_circuit_table_py.py +304 -0
- ck/pgm_compiler/support/clusters.py +568 -0
- ck/pgm_compiler/support/factor_tables.py +406 -0
- ck/pgm_compiler/support/join_tree.py +332 -0
- ck/pgm_compiler/support/named_compiler_maker.py +43 -0
- ck/pgm_compiler/variable_elimination.py +91 -0
- ck/probability/__init__.py +0 -0
- ck/probability/empirical_probability_space.py +50 -0
- ck/probability/pgm_probability_space.py +32 -0
- ck/probability/probability_space.py +622 -0
- ck/program/__init__.py +3 -0
- ck/program/program.py +137 -0
- ck/program/program_buffer.py +180 -0
- ck/program/raw_program.py +67 -0
- ck/sampling/__init__.py +0 -0
- ck/sampling/forward_sampler.py +211 -0
- ck/sampling/marginals_direct_sampler.py +113 -0
- ck/sampling/sampler.py +62 -0
- ck/sampling/sampler_support.py +232 -0
- ck/sampling/uniform_sampler.py +72 -0
- ck/sampling/wmc_direct_sampler.py +171 -0
- ck/sampling/wmc_gibbs_sampler.py +153 -0
- ck/sampling/wmc_metropolis_sampler.py +165 -0
- ck/sampling/wmc_rejection_sampler.py +115 -0
- ck/utils/__init__.py +0 -0
- ck/utils/iter_extras.py +163 -0
- ck/utils/local_config.py +270 -0
- ck/utils/map_list.py +128 -0
- ck/utils/map_set.py +128 -0
- ck/utils/np_extras.py +51 -0
- ck/utils/random_extras.py +64 -0
- ck/utils/tmp_dir.py +94 -0
- ck_demos/__init__.py +0 -0
- ck_demos/ace/__init__.py +0 -0
- ck_demos/ace/copy_ace_to_ck.py +15 -0
- ck_demos/ace/demo_ace.py +49 -0
- ck_demos/all_demos.py +88 -0
- ck_demos/circuit/__init__.py +0 -0
- ck_demos/circuit/demo_circuit_dump.py +22 -0
- ck_demos/circuit/demo_derivatives.py +43 -0
- ck_demos/circuit_compiler/__init__.py +0 -0
- ck_demos/circuit_compiler/compare_circuit_compilers.py +32 -0
- ck_demos/circuit_compiler/show_llvm_program.py +26 -0
- ck_demos/pgm/__init__.py +0 -0
- ck_demos/pgm/demo_pgm_dump.py +18 -0
- ck_demos/pgm/demo_pgm_dump_stress.py +18 -0
- ck_demos/pgm/demo_pgm_string_rendering.py +15 -0
- ck_demos/pgm/show_examples.py +25 -0
- ck_demos/pgm_compiler/__init__.py +0 -0
- ck_demos/pgm_compiler/compare_pgm_compilers.py +63 -0
- ck_demos/pgm_compiler/demo_compiler_dump.py +60 -0
- ck_demos/pgm_compiler/demo_factor_elimination.py +47 -0
- ck_demos/pgm_compiler/demo_join_tree.py +25 -0
- ck_demos/pgm_compiler/demo_marginals_program.py +53 -0
- ck_demos/pgm_compiler/demo_mpe_program.py +55 -0
- ck_demos/pgm_compiler/demo_pgm_compiler.py +38 -0
- ck_demos/pgm_compiler/demo_recursive_conditioning.py +33 -0
- ck_demos/pgm_compiler/demo_variable_elimination.py +33 -0
- ck_demos/pgm_compiler/demo_wmc_program.py +29 -0
- ck_demos/pgm_compiler/time_fe_compiler.py +93 -0
- ck_demos/pgm_inference/__init__.py +0 -0
- ck_demos/pgm_inference/demo_inferencing_basic.py +188 -0
- ck_demos/pgm_inference/demo_inferencing_mpe_cancer.py +45 -0
- ck_demos/pgm_inference/demo_inferencing_wmc_and_mpe_sprinkler.py +154 -0
- ck_demos/pgm_inference/demo_inferencing_wmc_student.py +110 -0
- ck_demos/programs/__init__.py +0 -0
- ck_demos/programs/demo_program_buffer.py +24 -0
- ck_demos/programs/demo_program_multi.py +24 -0
- ck_demos/programs/demo_program_none.py +19 -0
- ck_demos/programs/demo_program_single.py +23 -0
- ck_demos/programs/demo_raw_program_interpreted.py +21 -0
- ck_demos/programs/demo_raw_program_llvm.py +21 -0
- ck_demos/sampling/__init__.py +0 -0
- ck_demos/sampling/check_sampler.py +71 -0
- ck_demos/sampling/demo_marginal_direct_sampler.py +40 -0
- ck_demos/sampling/demo_uniform_sampler.py +38 -0
- ck_demos/sampling/demo_wmc_direct_sampler.py +40 -0
- ck_demos/utils/__init__.py +0 -0
- ck_demos/utils/compare.py +120 -0
- ck_demos/utils/convert_network.py +45 -0
- ck_demos/utils/sample_model.py +216 -0
- ck_demos/utils/stop_watch.py +384 -0
- compiled_knowledge-4.0.0a20.dist-info/METADATA +50 -0
- compiled_knowledge-4.0.0a20.dist-info/RECORD +178 -0
- compiled_knowledge-4.0.0a20.dist-info/WHEEL +6 -0
- compiled_knowledge-4.0.0a20.dist-info/licenses/LICENSE.txt +21 -0
- compiled_knowledge-4.0.0a20.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,120 @@
|
|
|
1
|
+
import gc
|
|
2
|
+
from typing import Sequence
|
|
3
|
+
|
|
4
|
+
from ck.circuit_compiler import NamedCircuitCompiler
|
|
5
|
+
from ck.pgm import PGM
|
|
6
|
+
from ck.pgm_circuit import PGMCircuit
|
|
7
|
+
from ck.pgm_circuit.wmc_program import WMCProgram
|
|
8
|
+
from ck.pgm_compiler import NamedPGMCompiler
|
|
9
|
+
from ck_demos.utils.stop_watch import StopWatch
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def compare(
        pgms: Sequence[PGM],
        pgm_compilers: Sequence[NamedPGMCompiler],
        cct_compilers: Sequence[NamedCircuitCompiler],
        *,
        cache_circuits: bool = True,
        break_between_pgms: bool = True,
        comma_numbers: bool = True,
        print_header: bool = True,
        sep: str = ' ',
) -> None:
    """
    For each combination of the given arguments, construct a PGMCircuit (using a
    PGMCompiler) and then a WMCProgram (using a CircuitCompiler). The resulting
    WMCProgram is executed 1000 times to estimate compute time.

    For each PGM, PGM compiler, and circuit compiler, a line is printed showing:
        PGM name,
        PGM compiler name,
        Circuit compiler name,
        number of circuit operations,
        PGMCircuit compile time,
        WMCProgram compile time,
        WMC compute time.

    The compute-time column is `StopWatch.seconds() * 1000` measured over 1000
    calls to `WMCProgram.compute()`; assuming `seconds()` reports seconds, this
    is the mean time per call in microseconds.

    When a cached circuit is reused, the circuit-operations and PGMCircuit
    compile-time columns are left blank. If any step raises an exception, the
    exception's `repr` is printed in place of the remaining columns of that line.

    The print output is formatted using fixed column width.

    Args:
        pgms: a sequence of PGM objects.
        pgm_compilers: a sequence of named PGM compilers.
        cct_compilers: a sequence of named circuit compilers.
        cache_circuits: if true, then circuits are reused across different circuit compilers.
        break_between_pgms: if true, print a blank line between different workload PGMs.
        comma_numbers: if true, commas are used in large numbers.
        print_header: if true, a header line is printed.
        sep: column separator.
    """
    # Work out column widths for names. `default=0` keeps empty inputs from raising.
    col_pgm_name: int = max(3, max((len(pgm.name) for pgm in pgms), default=0))
    col_pgm_compiler_name: int = max(12, max((len(c.name) for c in pgm_compilers), default=0))
    col_cct_compiler_name: int = max(12, max((len(c.name) for c in cct_compilers), default=0))
    col_cct_ops: int = 10
    col_pgm_compile_time: int = 16
    col_cct_compile_time: int = 16
    col_execute_time: int = 10

    # Thousands separator spliced into numeric format specs, e.g. '10,' or '16,.3f'.
    comma: str = ',' if comma_numbers else ''

    if print_header:
        print('PGM'.ljust(col_pgm_name), end=sep)
        print('PGM-compiler'.ljust(col_pgm_compiler_name), end=sep)
        print('CCT-compiler'.ljust(col_cct_compiler_name), end=sep)
        print('CCT-ops'.rjust(col_cct_ops), end=sep)
        print('PGM-compile-time'.rjust(col_pgm_compile_time), end=sep)
        print('CCT-compile-time'.rjust(col_cct_compile_time), end=sep)
        print('Run-time'.rjust(col_execute_time))

    # Cache state for when cache_circuits is true: the most recently compiled
    # circuit and the (pgm, pgm_compiler) pair that produced it. The pair is only
    # recorded after a fully successful PGM compile step.
    prev_pgm = None
    prev_pgm_compiler = None
    pgm_cct = None

    for pgm in pgms:
        pgm_name: str = pgm.name.ljust(col_pgm_name)
        for pgm_compiler in pgm_compilers:
            pgm_compiler_name: str = pgm_compiler.name.ljust(col_pgm_compiler_name)
            for cct_compiler in cct_compilers:
                cct_compiler_name: str = cct_compiler.name.ljust(col_cct_compiler_name)

                print(pgm_name, end=sep)
                print(pgm_compiler_name, end=sep)
                print(cct_compiler_name, end=sep)

                try:
                    stop_watch = StopWatch()

                    if cache_circuits and pgm is prev_pgm and pgm_compiler is prev_pgm_compiler:
                        # Reuse the circuit from an earlier line; blank the PGM columns.
                        print(f'{"":{col_cct_ops}}', end=sep)
                        print(f'{"":{col_pgm_compile_time}}', end=sep)
                    else:
                        # Invalidate the cache first, so a failure part way through
                        # this step can never pair a stale circuit with a cache key.
                        prev_pgm = prev_pgm_compiler = None

                        gc.collect()
                        stop_watch.start()
                        pgm_cct = pgm_compiler(pgm)
                        stop_watch.stop()
                        num_ops: int = pgm_cct.circuit_top.circuit.number_of_operations
                        print(f'{num_ops:{col_cct_ops}{comma}}', end=sep)
                        print(f'{stop_watch.seconds():{col_pgm_compile_time}{comma}.3f}', end=sep)

                        prev_pgm = pgm
                        prev_pgm_compiler = pgm_compiler

                    gc.collect()
                    stop_watch.start()
                    wmc = WMCProgram(pgm_cct, compiler=cct_compiler.compiler)
                    stop_watch.stop()
                    print(f'{stop_watch.seconds():{col_cct_compile_time}{comma}.3f}', end=sep)

                    gc.collect()
                    stop_watch.start()
                    for _ in range(1000):
                        wmc.compute()
                    stop_watch.stop()
                    print(f'{stop_watch.seconds() * 1000:{col_execute_time}{comma}.3f}', end='')
                except Exception as err:
                    # Benchmark boundary: report the failure on this line and carry on.
                    print(repr(err), end='')
                print()
        if break_between_pgms:
            print()
|
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
from pathlib import Path
|
|
2
|
+
|
|
3
|
+
from ck.in_out.parse_net import read_network
|
|
4
|
+
from ck.in_out.pgm_python import write_python
|
|
5
|
+
from ck.pgm import PGM
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def convert_network(network_path: Path, file=None) -> None:
    """
    Read a Hugin 'net' file and write the network out as PGM Python code.

    Before writing, any factor whose function has more than 10 zero-valued
    parameters, where those zeros also make up more than 10% of its parameters,
    is switched to a sparse function. Every key/value of the original function
    is copied across.

    Args:
        network_path: path to a Hugin 'net' file.
        file: destination, as per the `print` function.
    """
    # Parse the Hugin 'net' file into a PGM.
    with open(network_path) as net_file:
        pgm: PGM = read_network(net_file)

    # Convert zero-heavy factor functions to sparse form.
    for factor in pgm.factors:
        dense = factor.function
        param_count: int = dense.number_of_parameters
        zero_count: int = sum(1 for _, param_value in dense.params if param_value == 0)
        if zero_count <= 10 or zero_count / param_count <= 0.1:
            continue
        sparse = factor.set_sparse()
        for key, _, param_value in dense.keys_with_param:
            sparse[key] = param_value

    # Emit the PGM as Python source.
    write_python(pgm, file=file)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def main(
        network_directory: str = r'E:\Dropbox\Research\data\BN\networks',
        network_name: str = 'pathfinder',
) -> None:
    """
    Demo of `convert_network`.

    Converts the Hugin file `<network_directory>/<network_name>.net` using the
    default output destination of `convert_network`. The default arguments
    reproduce the original demo. Its directory is specific to the author's
    machine, so pass other values to convert a network stored elsewhere.

    Args:
        network_directory: directory containing Hugin 'net' files.
        network_name: name of the network file, without the '.net' suffix.
    """
    convert_network(Path(network_directory) / f'{network_name}.net')


if __name__ == '__main__':
    main()
|
|
@@ -0,0 +1,216 @@
|
|
|
1
|
+
import random
|
|
2
|
+
from typing import Optional, Dict, Callable, List
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
|
|
6
|
+
from ck.pgm import rv_instances, PGM, RandomVariable, Indicator
|
|
7
|
+
from ck.pgm_compiler import factor_elimination
|
|
8
|
+
from ck.pgm_circuit.marginals_program import MarginalsProgram
|
|
9
|
+
from ck.pgm_circuit import PGMCircuit
|
|
10
|
+
from ck.pgm_circuit.wmc_program import WMCProgram
|
|
11
|
+
from ck.sampling.forward_sampler import ForwardSampler
|
|
12
|
+
from ck.sampling.sampler import Sampler
|
|
13
|
+
from ck.utils.random_extras import random_permute
|
|
14
|
+
from ck_demos.utils.stop_watch import StopWatch
|
|
15
|
+
|
|
16
|
+
# A sampler factory is called as factory(pgm, wmc, mar, sample_rvs, condition)
# and returns a Sampler.
SamplerFactory = Callable[[PGM, WMCProgram, MarginalsProgram, List[RandomVariable], List[Indicator]], Sampler]

BURN_IN: int = 1000  # Burn in for standard samplers, where needed. Not all samplers use burn in.


def _direct_wmc_factory(pgm, wmc, mar, sample_rvs, condition):
    """Sampler factory: `wmc.sample_direct` over the requested rvs and condition."""
    return wmc.sample_direct(rvs=sample_rvs, condition=condition)


def _direct_mar_factory(pgm, wmc, mar, sample_rvs, condition):
    """Sampler factory: `mar.sample_direct` over the requested rvs and condition."""
    return mar.sample_direct(rvs=sample_rvs, condition=condition)


def _rejection_factory(pgm, wmc, mar, sample_rvs, condition):
    """Sampler factory: `wmc.sample_rejection` over the requested rvs and condition."""
    return wmc.sample_rejection(rvs=sample_rvs, condition=condition)


def _gibbs_factory(pgm, wmc, mar, sample_rvs, condition):
    """Sampler factory: `wmc.sample_gibbs` using the module BURN_IN."""
    return wmc.sample_gibbs(burn_in=BURN_IN, rvs=sample_rvs, condition=condition)


def _metropolis_factory(pgm, wmc, mar, sample_rvs, condition):
    """Sampler factory: `wmc.sample_metropolis` using the module BURN_IN."""
    return wmc.sample_metropolis(burn_in=BURN_IN, rvs=sample_rvs, condition=condition)


def _forward_factory(pgm, wmc, mar, sample_rvs, condition):
    """Sampler factory: a ForwardSampler over `pgm`, checking it is a Bayesian network."""
    return ForwardSampler(pgm, sample_rvs, condition, check_is_bayesian_network=True)


def _uniform_factory(pgm, wmc, mar, sample_rvs, condition):
    """Sampler factory: `wmc.sample_uniform` over the requested rvs and condition."""
    return wmc.sample_uniform(rvs=sample_rvs, condition=condition)


# Standard Samplers (by name)
STANDARD_SAMPLERS: Dict[str, SamplerFactory] = {
    'Direct-wmc': _direct_wmc_factory,
    'Direct-mar': _direct_mar_factory,
    'Rejection': _rejection_factory,
    'Gibbs': _gibbs_factory,
    'Metropolis': _metropolis_factory,
    'Forward': _forward_factory,
    'Uniform': _uniform_factory,
}
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def sample_model(
        pgm: PGM,
        samplers: Dict[str, SamplerFactory],
        num_of_trials: int,
        num_of_samples: int,
        limit_conditioning: Optional[int] = None,
        show_each_analysis: bool = True,
        line: str = '-' * 80,
):
    """
    Evaluate the given samplers on the given PGM.

    Results are printed to standard out.

    Each trial picks a random non-empty subset of the PGM's random variables to
    sample. It also picks a random condition, which is either empty or has
    non-zero probability according to a WMCProgram. Each sampler is built with
    `samplers[name](pgm, wmc, mar, sample_rvs, condition)` and draws
    `num_of_samples` samples. The empirical distribution is then compared with
    `wmc.marginal_distribution(*sample_rvs, condition=condition)`.

    The final summary reports, per sampler:
        overall_max_difference: the largest absolute difference between an
            empirical and exact state probability, over all trials;
        overall_sum_difference: the largest (over trials) per-trial sum of
            absolute differences across states;
        overall_time: total StopWatch time spent drawing samples, over trials
            in which the sampler succeeded;
        errors: number of trials in which the sampler raised
            ValueError, RuntimeError or AssertionError.

    Args:
        pgm: is the model to sample.
        samplers: is a dict from sampler name to factory method. The
            factory method type is (pgm, wmc, mar, sample_rvs, condition) -> Sampler.
        num_of_trials: how many trials to perform.
        num_of_samples: how many samples to draw from each sampler, for each trial.
        limit_conditioning: maximum number of indicators to use when determining
            conditioning for a trial, or None then pgm.number_of_rvs is used.
        show_each_analysis: if True, then extra details are printed.
        line: is the 'line' string to use to delimit trials.
    """
    print(f'Model: {pgm.name}')
    print(f'Number of random variables: {pgm.number_of_rvs}')
    print(f'Number of indicators: {pgm.number_of_indicators}')
    print(f'States space: {pgm.number_of_states:,}')

    # compile
    pgm_cct: PGMCircuit = factor_elimination.compile_pgm(pgm)
    wmc = WMCProgram(pgm_cct)
    mar = MarginalsProgram(pgm_cct)

    rvs = pgm.rvs
    num_of_rvs = len(rvs)
    sampler_names = list(samplers.keys())
    overall_max_difference = {name: 0 for name in sampler_names}
    overall_sum_difference = {name: 0 for name in sampler_names}
    overall_time = {name: 0 for name in sampler_names}
    errors = {name: [] for name in sampler_names}

    # Row-label width: must fit sampler names, rv names and the 'PGM' label.
    name_pad = 1 + max(
        len('PGM'),
        max((len(name) for name in sampler_names), default=0),
        max(len(rv.name) for rv in rvs),
    )

    # Resolve the conditioning limit once, rather than on every trial.
    if limit_conditioning is None:
        limit_conditioning = pgm.number_of_rvs

    for trial in range(1, 1 + num_of_trials):
        print(line)

        # what random variables to sample: a random non-empty subset, in index order
        num_rvs_to_sample = random.randint(1, num_of_rvs)
        sample_rvs = list(rvs)
        random_permute(sample_rvs)
        del sample_rvs[num_rvs_to_sample:]
        sample_rvs.sort(key=(lambda rv: rv.idx))
        rvs_str = ', '.join([str(rv) for rv in sample_rvs])

        # what conditions
        condition, condition_str = _random_condition(pgm, wmc, rvs, limit_conditioning)

        # show the trial parameters
        print(f'trial {trial} of {num_of_trials}: {rvs_str}{condition_str}')

        # create state indexes for printing
        all_states = [tuple(state) for state in rv_instances(*sample_rvs)]
        state_to_index = {state: i for i, state in enumerate(all_states)}

        # print detailed results - header
        for i, rv in enumerate(sample_rvs):
            print(str(rv).ljust(name_pad), end='')
            print(' '.join([f'{str(state[i]).ljust(7)}' for state in all_states]))

        # pgm_stats: the exact marginal distribution of sample_rvs given the condition
        print('PGM'.ljust(name_pad), end='')
        pgm_stats = np.array(wmc.marginal_distribution(*sample_rvs, condition=condition))
        print(' '.join([f'{p:.5f}' for p in pgm_stats]))

        for sampler_name in sampler_names:
            print(sampler_name.ljust(name_pad), end='')

            # sample_stats: the empirical distribution from this sampler
            try:
                sample_stats = np.zeros(len(all_states))
                sampler = samplers[sampler_name](pgm, wmc, mar, sample_rvs, condition)
                stop_watch = StopWatch()
                for state in sampler.take(num_of_samples):
                    sample_stats[state_to_index[tuple(state)]] += 1
                stop_watch.stop()
                num_drawn = np.sum(sample_stats)
                if num_drawn > 0:
                    # Guard against 0/0 (NaN) when no samples were drawn.
                    sample_stats /= num_drawn
            except (ValueError, RuntimeError, AssertionError) as err:
                errors[sampler_name].append(repr(err))
                print(repr(err))
                continue

            # print detailed results - for this sampler
            print(' '.join([f'{p:.5f}' for p in sample_stats]))

            # analyse
            max_difference = 0
            sum_difference = 0
            for pgm_stat, sample_stat in zip(pgm_stats, sample_stats):
                diff = abs(pgm_stat - sample_stat)
                max_difference = max(max_difference, diff)
                sum_difference += diff
            overall_max_difference[sampler_name] = max(overall_max_difference[sampler_name], max_difference)
            # Keeps the worst (largest) per-trial sum, not a running total across trials.
            overall_sum_difference[sampler_name] = max(overall_sum_difference[sampler_name], sum_difference)
            overall_time[sampler_name] += stop_watch.seconds()

            if show_each_analysis:
                print(
                    ' ' * name_pad +
                    f'max_difference = {max_difference}, '
                    f'sum_difference = {sum_difference}, '
                    f'time = {stop_watch.seconds()}'
                )

    print(line)
    sep: str = ', '
    print(' ' * name_pad + sep.join(['overall_max_difference', 'overall_sum_difference', 'overall_time', 'errors']))
    for sampler_name in sampler_names:
        print(
            f'{sampler_name.ljust(name_pad)}'
            f'{overall_max_difference[sampler_name]}{sep}'
            f'{overall_sum_difference[sampler_name]}{sep}'
            f'{overall_time[sampler_name]}{sep}'
            f'{len(errors[sampler_name])}'
        )
    print()


def _random_condition(pgm: PGM, wmc: WMCProgram, rvs, limit_conditioning: int):
    """
    Choose a random condition (a sequence of indicators) for one trial.

    Draws up to `limit_conditioning` indicators. At most len(rv) - 1 of them come
    from any one random variable, presumably so the condition actually
    constrains that variable. Random variables with fewer than two states are
    skipped. The draw is retried until the condition is empty or
    `wmc.probability(*condition) > 0`; each discarded condition is reported.

    Args:
        pgm: the model, used to render the condition string.
        wmc: program used to test that a condition is possible.
        rvs: the random variables that may be conditioned on.
        limit_conditioning: maximum number of indicators in the condition.

    Returns:
        (condition, condition_str), where condition_str is '' for an empty
        condition, else ' | ' followed by `pgm.condition_str(*condition)`.
    """
    if limit_conditioning == 0:
        return (), ''

    while True:
        num_indicators_to_condition = random.randint(0, limit_conditioning)
        rand_rvs = list(rvs)
        random_permute(rand_rvs)
        condition = []
        while len(condition) < num_indicators_to_condition and len(rand_rvs) > 0:
            rv = rand_rvs.pop()
            if len(rv) < 2:
                # Cannot take between 1 and len(rv) - 1 indicators from a single-state rv.
                continue
            max_rv_indicators_to_condition = min(len(rv) - 1, num_indicators_to_condition - len(condition))
            num_rv_indicators_to_condition = random.randint(1, max_rv_indicators_to_condition)
            indicators = list(rv)
            random_permute(indicators)
            condition += sorted(indicators[:num_rv_indicators_to_condition])

        if len(condition) == 0:
            return condition, ''

        condition_str = ' | ' + pgm.condition_str(*condition)

        # only accept the condition if the Pr(condition) > 0
        if wmc.probability(*condition) > 0:
            return condition, condition_str
        print(f'Note: discarded impossible condition{condition_str}')
|