compiled-knowledge 4.0.0a9__cp312-cp312-win_amd64.whl → 4.0.0a11__cp312-cp312-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of compiled-knowledge might be problematic. Click here for more details.
- ck/circuit/circuit.cp312-win_amd64.pyd +0 -0
- ck/circuit/circuit.pyx +20 -8
- ck/circuit/circuit_py.py +40 -19
- ck/circuit_compiler/cython_vm_compiler/_compiler.cp312-win_amd64.pyd +0 -0
- ck/pgm.py +111 -130
- ck/pgm_circuit/pgm_circuit.py +13 -9
- ck/pgm_circuit/program_with_slotmap.py +6 -4
- ck/pgm_compiler/ace/ace.py +48 -4
- ck/pgm_compiler/factor_elimination.py +6 -4
- ck/pgm_compiler/recursive_conditioning.py +8 -3
- ck/pgm_compiler/support/circuit_table/circuit_table.cp312-win_amd64.pyd +0 -0
- ck/pgm_compiler/support/clusters.py +1 -1
- ck/pgm_compiler/variable_elimination.py +3 -3
- ck/probability/empirical_probability_space.py +3 -0
- ck/probability/pgm_probability_space.py +32 -0
- ck/probability/probability_space.py +66 -12
- ck/program/program.py +9 -1
- ck/program/raw_program.py +9 -3
- ck/sampling/sampler_support.py +1 -1
- ck/sampling/uniform_sampler.py +10 -4
- ck/sampling/wmc_direct_sampler.py +4 -2
- ck/sampling/wmc_gibbs_sampler.py +6 -0
- ck/sampling/wmc_metropolis_sampler.py +7 -1
- ck/sampling/wmc_rejection_sampler.py +2 -0
- ck/utils/iter_extras.py +9 -6
- {compiled_knowledge-4.0.0a9.dist-info → compiled_knowledge-4.0.0a11.dist-info}/METADATA +16 -12
- {compiled_knowledge-4.0.0a9.dist-info → compiled_knowledge-4.0.0a11.dist-info}/RECORD +30 -29
- {compiled_knowledge-4.0.0a9.dist-info → compiled_knowledge-4.0.0a11.dist-info}/WHEEL +0 -0
- {compiled_knowledge-4.0.0a9.dist-info → compiled_knowledge-4.0.0a11.dist-info}/licenses/LICENSE.txt +0 -0
- {compiled_knowledge-4.0.0a9.dist-info → compiled_knowledge-4.0.0a11.dist-info}/top_level.txt +0 -0
ck/pgm_circuit/pgm_circuit.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
from dataclasses import dataclass
|
|
2
2
|
from typing import Sequence, List, Dict
|
|
3
3
|
|
|
4
|
-
from ck.circuit import CircuitNode
|
|
4
|
+
from ck.circuit import CircuitNode, Circuit
|
|
5
5
|
from ck.pgm import RandomVariable, Indicator
|
|
6
6
|
from ck.pgm_circuit.slot_map import SlotMap, SlotKey
|
|
7
7
|
from ck.utils.np_extras import NDArray
|
|
@@ -46,19 +46,21 @@ class PGMCircuit:
|
|
|
46
46
|
|
|
47
47
|
def dump(self, *, prefix: str = '', indent: str = ' ') -> None:
|
|
48
48
|
"""
|
|
49
|
-
Print a dump of the
|
|
49
|
+
Print a dump of the circuit.
|
|
50
50
|
This is intended for debugging and demonstration purposes.
|
|
51
51
|
|
|
52
|
-
Each cluster is printed as: {separator rvs} | {non-separator rvs}.
|
|
53
|
-
|
|
54
52
|
Args:
|
|
55
53
|
prefix: optional prefix for indenting all lines.
|
|
56
54
|
indent: additional prefix to use for extra indentation.
|
|
57
55
|
"""
|
|
58
|
-
|
|
59
|
-
circuit
|
|
56
|
+
|
|
57
|
+
# We infer names for the circuit variables, either as an indicator or as a parameter.
|
|
58
|
+
# The `var_names` will be passed to `circuit.dump`.
|
|
59
|
+
|
|
60
|
+
circuit: Circuit = self.circuit_top.circuit
|
|
60
61
|
var_names: List[str] = [''] * circuit.number_of_vars
|
|
61
62
|
|
|
63
|
+
# Name the circuit variables that are indicators
|
|
62
64
|
rvs_by_idx: Dict[int, RandomVariable] = {rv.idx: rv for rv in self.rvs}
|
|
63
65
|
slot_key: SlotKey
|
|
64
66
|
slot: int
|
|
@@ -68,8 +70,10 @@ class PGMCircuit:
|
|
|
68
70
|
state_idx = slot_key.state_idx
|
|
69
71
|
var_names[slot] = f'{rv.name!r}[{state_idx}] {rv.states[state_idx]!r}'
|
|
70
72
|
|
|
71
|
-
|
|
72
|
-
|
|
73
|
+
# Name the circuit variables that are parameters
|
|
74
|
+
for i, param_value in enumerate(self.parameter_values):
|
|
75
|
+
slot = i + self.number_of_indicators
|
|
76
|
+
var_names[slot] = f'param[{i}] = {param_value}'
|
|
73
77
|
|
|
74
|
-
#
|
|
78
|
+
# Dump the circuit
|
|
75
79
|
circuit.dump(prefix=prefix, indent=indent, var_names=var_names)
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
from typing import Tuple,
|
|
1
|
+
from typing import Tuple, Sequence, Dict, Iterable
|
|
2
2
|
|
|
3
3
|
from ck.pgm import RandomVariable, rv_instances, Instance, rv_instances_as_indicators, Indicator, ParamId
|
|
4
4
|
from ck.pgm_circuit.slot_map import SlotMap, SlotKey
|
|
@@ -30,11 +30,13 @@ class ProgramWithSlotmap:
|
|
|
30
30
|
has a length and rv[i] is a unique 'indicator' across all rvs.
|
|
31
31
|
precondition: conditions on rvs that are compiled into the program.
|
|
32
32
|
|
|
33
|
+
Raises:
|
|
34
|
+
ValueError: if rvs contains duplicates.
|
|
33
35
|
"""
|
|
34
36
|
self._program_buffer: ProgramBuffer = program_buffer
|
|
35
37
|
self._slot_map: SlotMap = slot_map
|
|
36
38
|
self._rvs: Tuple[RandomVariable, ...] = tuple(rvs)
|
|
37
|
-
self._precondition:
|
|
39
|
+
self._precondition: Tuple[Indicator, ...] = tuple(precondition)
|
|
38
40
|
|
|
39
41
|
if len(rvs) != len(set(rv.idx for rv in rvs)):
|
|
40
42
|
raise ValueError('duplicate random variables provided')
|
|
@@ -67,7 +69,7 @@ class ProgramWithSlotmap:
|
|
|
67
69
|
def slot_map(self) -> SlotMap:
|
|
68
70
|
return self._slot_map
|
|
69
71
|
|
|
70
|
-
def instances(self, flip: bool = False) ->
|
|
72
|
+
def instances(self, flip: bool = False) -> Iterable[Instance]:
|
|
71
73
|
"""
|
|
72
74
|
Enumerate instances of the random variables.
|
|
73
75
|
|
|
@@ -84,7 +86,7 @@ class ProgramWithSlotmap:
|
|
|
84
86
|
"""
|
|
85
87
|
return rv_instances(*self._rvs, flip=flip)
|
|
86
88
|
|
|
87
|
-
def instances_as_indicators(self, flip: bool = False) ->
|
|
89
|
+
def instances_as_indicators(self, flip: bool = False) -> Iterable[Sequence[Indicator]]:
|
|
88
90
|
"""
|
|
89
91
|
Enumerate instances of the random variables.
|
|
90
92
|
|
ck/pgm_compiler/ace/ace.py
CHANGED
|
@@ -5,7 +5,9 @@ from dataclasses import dataclass
|
|
|
5
5
|
from pathlib import Path
|
|
6
6
|
from typing import Optional, List, Tuple
|
|
7
7
|
|
|
8
|
-
|
|
8
|
+
import numpy as np
|
|
9
|
+
|
|
10
|
+
from ck.circuit import CircuitNode, Circuit
|
|
9
11
|
from ck.in_out.parse_ace_lmap import read_lmap, LiteralMap
|
|
10
12
|
from ck.in_out.parse_ace_nnf import read_nnf_with_literal_map
|
|
11
13
|
from ck.in_out.render_net import render_bayesian_network
|
|
@@ -58,6 +60,22 @@ def compile_pgm(
|
|
|
58
60
|
if check_is_bayesian_network and not pgm.check_is_bayesian_network():
|
|
59
61
|
raise ValueError('the given PGM is not a Bayesian network')
|
|
60
62
|
|
|
63
|
+
# ACE cannot deal with the empty PGM even though it is a valid Bayesian network
|
|
64
|
+
if pgm.number_of_factors == 0:
|
|
65
|
+
circuit = Circuit()
|
|
66
|
+
circuit.new_vars(pgm.number_of_indicators)
|
|
67
|
+
parameter_values = np.array([], dtype=np.float64)
|
|
68
|
+
slot_map = {indicator: i for i, indicator in enumerate(pgm.indicators)}
|
|
69
|
+
return PGMCircuit(
|
|
70
|
+
rvs=pgm.rvs,
|
|
71
|
+
conditions=(),
|
|
72
|
+
circuit_top=circuit.const(1),
|
|
73
|
+
number_of_indicators=pgm.number_of_indicators,
|
|
74
|
+
number_of_parameters=0,
|
|
75
|
+
slot_map=slot_map,
|
|
76
|
+
parameter_values=parameter_values,
|
|
77
|
+
)
|
|
78
|
+
|
|
61
79
|
java: str
|
|
62
80
|
classpath_separator: str
|
|
63
81
|
java, classpath_separator = _find_java()
|
|
@@ -69,7 +87,7 @@ def compile_pgm(
|
|
|
69
87
|
)
|
|
70
88
|
ace_cmd: List[str] = [
|
|
71
89
|
java,
|
|
72
|
-
|
|
90
|
+
'-cp',
|
|
73
91
|
class_path,
|
|
74
92
|
f'-DACEC2D={files.c2d}',
|
|
75
93
|
f'-Xmx{int(m_bytes)}m',
|
|
@@ -83,7 +101,7 @@ def compile_pgm(
|
|
|
83
101
|
node_names: List[str] = render_bayesian_network(pgm, file, check_structure_bayesian=False)
|
|
84
102
|
|
|
85
103
|
# Run Ace
|
|
86
|
-
ace_result = subprocess.run(ace_cmd, capture_output=True, text=True)
|
|
104
|
+
ace_result: subprocess.CompletedProcess = subprocess.run(ace_cmd, capture_output=True, text=True)
|
|
87
105
|
if print_output:
|
|
88
106
|
print(ace_result.stdout)
|
|
89
107
|
print(ace_result.stderr)
|
|
@@ -126,6 +144,29 @@ def compile_pgm(
|
|
|
126
144
|
)
|
|
127
145
|
|
|
128
146
|
|
|
147
|
+
def ace_available(
|
|
148
|
+
ace_dir: Optional[Path | str] = None,
|
|
149
|
+
jar_dir: Optional[Path | str] = None,
|
|
150
|
+
) -> bool:
|
|
151
|
+
"""
|
|
152
|
+
Returns:
|
|
153
|
+
True if it looks like ACE is available, False otherwise.
|
|
154
|
+
ACE is available if the ACE files are found (in the given `ace_dir`/`jar_dir`, or in the default location when `ace_dir` is not given) and Java is available.
|
|
155
|
+
"""
|
|
156
|
+
try:
|
|
157
|
+
java: str
|
|
158
|
+
java, _ = _find_java()
|
|
159
|
+
_: _AceFiles = _find_ace_files(ace_dir, jar_dir)
|
|
160
|
+
|
|
161
|
+
java_cmd: List[str] = [java, '--version',]
|
|
162
|
+
java_result: subprocess.CompletedProcess = subprocess.run(java_cmd, capture_output=True, text=True)
|
|
163
|
+
|
|
164
|
+
return java_result.returncode == 0
|
|
165
|
+
|
|
166
|
+
except RuntimeError:
|
|
167
|
+
return False
|
|
168
|
+
|
|
169
|
+
|
|
129
170
|
def copy_ace_to_default_location(
|
|
130
171
|
ace_dir: Path | str,
|
|
131
172
|
jar_dir: Optional[Path | str] = None,
|
|
@@ -179,6 +220,9 @@ def _find_java() -> Tuple[str, str]:
|
|
|
179
220
|
|
|
180
221
|
Returns:
|
|
181
222
|
(java, classpath_separator)
|
|
223
|
+
|
|
224
|
+
Raises:
|
|
225
|
+
RuntimeError: if not found, including a helpful message.
|
|
182
226
|
"""
|
|
183
227
|
if sys.platform == 'win32':
|
|
184
228
|
return 'java.exe', ';'
|
|
@@ -198,7 +242,7 @@ def _find_ace_files(
|
|
|
198
242
|
Look for the needed Ace files.
|
|
199
243
|
|
|
200
244
|
Raises:
|
|
201
|
-
RuntimeError: if not found, including a helpful message
|
|
245
|
+
RuntimeError: if not found, including a helpful message.
|
|
202
246
|
"""
|
|
203
247
|
ace_dir: Path = default_ace_location() if ace_dir is None else Path(ace_dir)
|
|
204
248
|
jar_dir: Path = ace_dir if jar_dir is None else Path(jar_dir)
|
|
@@ -12,19 +12,21 @@ from ck.pgm_compiler.support.join_tree import *
|
|
|
12
12
|
|
|
13
13
|
_NEG_INF = float('-inf')
|
|
14
14
|
|
|
15
|
+
DEFAULT_PRODUCT_SEARCH_LIMIT: int = 1000
|
|
16
|
+
|
|
15
17
|
|
|
16
18
|
def compile_pgm(
|
|
17
19
|
pgm: PGM,
|
|
18
20
|
const_parameters: bool = True,
|
|
19
21
|
*,
|
|
20
22
|
algorithm: JoinTreeAlgorithm = MIN_FILL_THEN_DEGREE,
|
|
21
|
-
limit_product_tree_search: int =
|
|
23
|
+
limit_product_tree_search: int = DEFAULT_PRODUCT_SEARCH_LIMIT,
|
|
22
24
|
pre_prune_factor_tables: bool = True,
|
|
23
25
|
) -> PGMCircuit:
|
|
24
26
|
"""
|
|
25
27
|
Compile the PGM to an arithmetic circuit, using factor elimination.
|
|
26
28
|
|
|
27
|
-
When forming the product of factors
|
|
29
|
+
When forming the product of factors within join tree nodes,
|
|
28
30
|
this method searches all practical binary trees for forming products,
|
|
29
31
|
up to the given limit, `limit_product_tree_search`. The minimum is 1.
|
|
30
32
|
|
|
@@ -57,7 +59,7 @@ def compile_pgm_best_jointree(
|
|
|
57
59
|
pgm: PGM,
|
|
58
60
|
const_parameters: bool = True,
|
|
59
61
|
*,
|
|
60
|
-
limit_product_tree_search: int =
|
|
62
|
+
limit_product_tree_search: int = DEFAULT_PRODUCT_SEARCH_LIMIT,
|
|
61
63
|
pre_prune_factor_tables: bool = True,
|
|
62
64
|
) -> PGMCircuit:
|
|
63
65
|
"""
|
|
@@ -111,7 +113,7 @@ def compile_pgm_best_jointree(
|
|
|
111
113
|
def join_tree_to_circuit(
|
|
112
114
|
join_tree: JoinTree,
|
|
113
115
|
const_parameters: bool = True,
|
|
114
|
-
limit_product_tree_search: int =
|
|
116
|
+
limit_product_tree_search: int = DEFAULT_PRODUCT_SEARCH_LIMIT,
|
|
115
117
|
pre_prune_factor_tables: bool = True,
|
|
116
118
|
) -> PGMCircuit:
|
|
117
119
|
"""
|
|
@@ -52,10 +52,15 @@ def compile_pgm(
|
|
|
52
52
|
multiply_indicators=True,
|
|
53
53
|
pre_prune_factor_tables=pre_prune_factor_tables,
|
|
54
54
|
)
|
|
55
|
-
dtree: _DTree = _make_dtree(elimination_order, factor_tables)
|
|
56
55
|
|
|
57
|
-
|
|
58
|
-
|
|
56
|
+
if pgm.number_of_factors == 0:
|
|
57
|
+
# Deal with special case: no factors
|
|
58
|
+
top: CircuitNode = factor_tables.circuit.const(1)
|
|
59
|
+
else:
|
|
60
|
+
dtree: _DTree = _make_dtree(elimination_order, factor_tables)
|
|
61
|
+
states: List[Sequence[int]] = [tuple(range(len(rv))) for rv in pgm.rvs]
|
|
62
|
+
top: CircuitNode = dtree.make_circuit(states, factor_tables.circuit)
|
|
63
|
+
|
|
59
64
|
top.circuit.remove_unreachable_op_nodes(top)
|
|
60
65
|
|
|
61
66
|
return PGMCircuit(
|
|
Binary file
|
|
@@ -9,7 +9,7 @@ from ck.pgm import PGM
|
|
|
9
9
|
|
|
10
10
|
# A VEObjective is a variable elimination objective function.
|
|
11
11
|
# An objective function is a function from a random variable index (int)
|
|
12
|
-
# to an
|
|
12
|
+
# to an objective value (float or int). This is used to select
|
|
13
13
|
# a random variable to eliminate in `ve_greedy_min`.
|
|
14
14
|
VEObjective = Callable[[int], int | float]
|
|
15
15
|
|
|
@@ -69,13 +69,13 @@ def compile_pgm(
|
|
|
69
69
|
tables_with_rv.append(product(x, y))
|
|
70
70
|
next_tables.append(sum_out(tables_with_rv[0], (rv_idx,)))
|
|
71
71
|
cur_tables = next_tables
|
|
72
|
-
|
|
72
|
+
|
|
73
|
+
# All rvs are now eliminated - all tables should have a single top.
|
|
73
74
|
tops: List[CircuitNode] = [
|
|
74
75
|
table.top()
|
|
75
76
|
for table in cur_tables
|
|
76
|
-
if len(table) > 0
|
|
77
77
|
]
|
|
78
|
-
top = factor_tables.circuit.
|
|
78
|
+
top: CircuitNode = factor_tables.circuit.optimised_mul(tops)
|
|
79
79
|
top.circuit.remove_unreachable_op_nodes(top)
|
|
80
80
|
|
|
81
81
|
return PGMCircuit(
|
|
@@ -7,6 +7,9 @@ from ck.probability.probability_space import ProbabilitySpace, Condition, check_
|
|
|
7
7
|
class EmpiricalProbabilitySpace(ProbabilitySpace):
|
|
8
8
|
def __init__(self, rvs: Sequence[RandomVariable], samples: Iterable[Instance]):
|
|
9
9
|
"""
|
|
10
|
+
Enable probabilistic queries over a sample from a sample space.
|
|
11
|
+
Note that this is not necessarily an efficient approach to calculating probabilities and statistics.
|
|
12
|
+
|
|
10
13
|
Assumes:
|
|
11
14
|
len(sample) == len(rvs), for each sample in samples.
|
|
12
15
|
0 <= sample[i] < len(rvs[i]), for each sample in samples, for i in range(len(rvs)).
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
from typing import Sequence, Iterable, Tuple, Dict, List
|
|
2
|
+
|
|
3
|
+
from ck.pgm import RandomVariable, Indicator, Instance, PGM
|
|
4
|
+
from ck.probability.probability_space import ProbabilitySpace, Condition, check_condition
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class PGMProbabilitySpace(ProbabilitySpace):
|
|
8
|
+
def __init__(self, pgm: PGM):
|
|
9
|
+
"""
|
|
10
|
+
Enable probabilistic queries directly on a PGM.
|
|
11
|
+
Note that this is not necessarily an efficient approach to calculating probabilities and statistics.
|
|
12
|
+
|
|
13
|
+
Args:
|
|
14
|
+
pgm: The PGM to query.
|
|
15
|
+
"""
|
|
16
|
+
self._pgm = pgm
|
|
17
|
+
self._z = None
|
|
18
|
+
|
|
19
|
+
@property
|
|
20
|
+
def rvs(self) -> Sequence[RandomVariable]:
|
|
21
|
+
return self._pgm.rvs
|
|
22
|
+
|
|
23
|
+
def wmc(self, *condition: Condition) -> float:
|
|
24
|
+
condition: Tuple[Indicator, ...] = check_condition(condition)
|
|
25
|
+
return self._pgm.value_product_indicators(*condition)
|
|
26
|
+
|
|
27
|
+
@property
|
|
28
|
+
def z(self) -> float:
|
|
29
|
+
if self._z is None:
|
|
30
|
+
self._z = self._pgm.value_product_indicators()
|
|
31
|
+
return self._z
|
|
32
|
+
|
|
@@ -3,6 +3,7 @@ An abstract class for object providing probabilities.
|
|
|
3
3
|
"""
|
|
4
4
|
import math
|
|
5
5
|
from abc import ABC, abstractmethod
|
|
6
|
+
from itertools import chain
|
|
6
7
|
from typing import Sequence, Tuple, Iterable, Callable
|
|
7
8
|
|
|
8
9
|
import numpy as np
|
|
@@ -75,15 +76,38 @@ class ProbabilitySpace(ABC):
|
|
|
75
76
|
the probability of the given indicators, conditioned on the given conditions.
|
|
76
77
|
"""
|
|
77
78
|
condition: Tuple[Indicator, ...] = check_condition(condition)
|
|
79
|
+
|
|
78
80
|
if len(condition) == 0:
|
|
79
81
|
z = self.z
|
|
82
|
+
if z <= 0:
|
|
83
|
+
return np.nan
|
|
80
84
|
else:
|
|
81
85
|
z = self.wmc(*condition)
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
86
|
+
if z <= 0:
|
|
87
|
+
return np.nan
|
|
88
|
+
|
|
89
|
+
# Combine the indicators with the condition
|
|
90
|
+
# If a variable is mentioned in both the indicators and condition, then
|
|
91
|
+
# we need to take the intersection, and check for contradictions.
|
|
92
|
+
# If a variable is mentioned in the condition but not indicators, then
|
|
93
|
+
# the rv condition needs to be added to the indicators.
|
|
94
|
+
indicator_groups: MapSet[int, Indicator] = _group_indicators(indicators)
|
|
95
|
+
condition_groups: MapSet[int, Indicator] = _group_indicators(condition)
|
|
96
|
+
|
|
97
|
+
for rv_idx, indicators in condition_groups.items():
|
|
98
|
+
indicator_group = indicator_groups.get(rv_idx)
|
|
99
|
+
if indicator_group is None:
|
|
100
|
+
indicator_groups.add_all(rv_idx, indicators)
|
|
101
|
+
else:
|
|
102
|
+
indicator_group.intersection_update(indicators)
|
|
103
|
+
if len(indicator_group) == 0:
|
|
104
|
+
# A contradiction between the indicators and conditions
|
|
105
|
+
return 0.0
|
|
106
|
+
|
|
107
|
+
# Collect all the indicators from the updated indicator_groups
|
|
108
|
+
indicators = chain(*indicator_groups.values())
|
|
109
|
+
|
|
110
|
+
return self.wmc(*indicators) / z
|
|
87
111
|
|
|
88
112
|
def marginal_distribution(self, *rvs: RandomVariable, condition: Condition = ()) -> NDArrayNumeric:
|
|
89
113
|
"""
|
|
@@ -160,9 +184,7 @@ class ProbabilitySpace(ABC):
|
|
|
160
184
|
assert len(rv_indexes) == len(rvs), 'duplicated random variables not allowed'
|
|
161
185
|
|
|
162
186
|
# Group conditioning indicators by random variable.
|
|
163
|
-
conditions_by_rvs =
|
|
164
|
-
for ind in condition:
|
|
165
|
-
conditions_by_rvs.get_set(ind.rv_idx).add(ind.state_idx)
|
|
187
|
+
conditions_by_rvs = _group_states(condition)
|
|
166
188
|
|
|
167
189
|
# See if any MAP random variable is also conditioned.
|
|
168
190
|
# Reduce the state space of any conditioned MAP rv.
|
|
@@ -195,12 +217,12 @@ class ProbabilitySpace(ABC):
|
|
|
195
217
|
# Loop over the state space of the 'loop' rvs
|
|
196
218
|
best_probability = float('-inf')
|
|
197
219
|
best_states = None
|
|
198
|
-
|
|
199
|
-
for
|
|
200
|
-
probability = self.wmc(*(
|
|
220
|
+
indicators: Tuple[Indicator, ...]
|
|
221
|
+
for indicators in _combos(loop_rvs):
|
|
222
|
+
probability = self.wmc(*(indicators + new_conditions))
|
|
201
223
|
if probability > best_probability:
|
|
202
224
|
best_probability = probability
|
|
203
|
-
best_states = tuple(ind.state_idx for ind in
|
|
225
|
+
best_states = tuple(ind.state_idx for ind in indicators)
|
|
204
226
|
condition_probability = self.wmc(*condition)
|
|
205
227
|
return best_probability / condition_probability, best_states
|
|
206
228
|
|
|
@@ -551,6 +573,38 @@ def dtype_for_state_indexes(rvs: Iterable[RandomVariable]) -> DTypeStates:
|
|
|
551
573
|
return dtype_for_number_of_states(max((len(rv) for rv in rvs), default=0))
|
|
552
574
|
|
|
553
575
|
|
|
576
|
+
def _group_indicators(indicators: Iterable[Indicator]) -> MapSet[int, Indicator]:
|
|
577
|
+
"""
|
|
578
|
+
Group the given indicators by rv_idx.
|
|
579
|
+
|
|
580
|
+
Args:
|
|
581
|
+
indicators: the indicators to group.
|
|
582
|
+
|
|
583
|
+
Returns:
|
|
584
|
+
A mapping from rv_idx to set of indicators.
|
|
585
|
+
"""
|
|
586
|
+
groups: MapSet[int, Indicator] = MapSet()
|
|
587
|
+
for indicator in indicators:
|
|
588
|
+
groups.add(indicator.rv_idx, indicator)
|
|
589
|
+
return groups
|
|
590
|
+
|
|
591
|
+
|
|
592
|
+
def _group_states(indicators: Iterable[Indicator]) -> MapSet[int, int]:
|
|
593
|
+
"""
|
|
594
|
+
Group the given indicator states by rv_idx.
|
|
595
|
+
|
|
596
|
+
Args:
|
|
597
|
+
indicators: the indicators to group.
|
|
598
|
+
|
|
599
|
+
Returns:
|
|
600
|
+
A mapping from rv_idx to set of state indexes.
|
|
601
|
+
"""
|
|
602
|
+
groups: MapSet[int, int] = MapSet()
|
|
603
|
+
for indicator in indicators:
|
|
604
|
+
groups.add(indicator.rv_idx, indicator.state_idx)
|
|
605
|
+
return groups
|
|
606
|
+
|
|
607
|
+
|
|
554
608
|
def _normalise_marginal(distribution: NDArrayFloat64) -> None:
|
|
555
609
|
"""
|
|
556
610
|
Update the values in the given distribution to
|
ck/program/program.py
CHANGED
|
@@ -1,3 +1,6 @@
|
|
|
1
|
+
"""
|
|
2
|
+
For more documentation on this module, refer to the Jupyter notebook docs/6_circuits_and_programs.ipynb.
|
|
3
|
+
"""
|
|
1
4
|
from typing import Callable, Sequence
|
|
2
5
|
|
|
3
6
|
import numpy as np
|
|
@@ -8,7 +11,12 @@ from ck.utils.np_extras import DTypeNumeric, NDArrayNumeric
|
|
|
8
11
|
|
|
9
12
|
class Program:
|
|
10
13
|
"""
|
|
11
|
-
A
|
|
14
|
+
A program represents an arithmetic function from input values to output values.
|
|
15
|
+
|
|
16
|
+
Internally a `Program` wraps a `RawProgram` which is the object returned by a circuit compiler.
|
|
17
|
+
|
|
18
|
+
Every `Program` has a numpy `dtype` which defines the numeric data type for input and output values.
|
|
19
|
+
Typically, the `dtype` of a program is a C style double.
|
|
12
20
|
"""
|
|
13
21
|
|
|
14
22
|
def __init__(self, raw_program: RawProgram):
|
ck/program/raw_program.py
CHANGED
|
@@ -9,8 +9,8 @@ from ck.utils.np_extras import NDArrayNumeric, DTypeNumeric
|
|
|
9
9
|
|
|
10
10
|
# RawProgramFunction is a function of three ctypes arrays, returning nothing.
|
|
11
11
|
# Args:
|
|
12
|
-
# [0]: input
|
|
13
|
-
# [1]: working memory,
|
|
12
|
+
# [0]: input values,
|
|
13
|
+
# [1]: temporary working memory,
|
|
14
14
|
# [2]: output values.
|
|
15
15
|
RawProgramFunction = Callable[[ct.POINTER, ct.POINTER, ct.POINTER], None]
|
|
16
16
|
|
|
@@ -18,10 +18,16 @@ RawProgramFunction = Callable[[ct.POINTER, ct.POINTER, ct.POINTER], None]
|
|
|
18
18
|
@dataclass
|
|
19
19
|
class RawProgram:
|
|
20
20
|
"""
|
|
21
|
+
A raw program is returned by a circuit compiler to provide execution of
|
|
22
|
+
the function defined by a compiled circuit.
|
|
23
|
+
|
|
24
|
+
A `RawProgram` is a `Callable` with the signature:
|
|
25
|
+
|
|
26
|
+
|
|
21
27
|
Fields:
|
|
22
28
|
function: is a function of three ctypes arrays, returning nothing.
|
|
23
29
|
dtype: the numpy data type of the array values.
|
|
24
|
-
number_of_vars: the number of input
|
|
30
|
+
number_of_vars: the number of input values (first function argument).
|
|
25
31
|
number_of_tmps: the number of working memory values (second function argument).
|
|
26
32
|
number_of_results: the number of result values (third function argument).
|
|
27
33
|
var_indices: maps the index of inputs (from 0 to self.number_of_vars - 1) to the index
|
ck/sampling/sampler_support.py
CHANGED
|
@@ -13,7 +13,7 @@ from ck.utils.random_extras import Random
|
|
|
13
13
|
# Type of a yield function. Support for a sampler.
|
|
14
14
|
# A yield function may be used to implement a sampler's iterator, thus
|
|
15
15
|
# it provides an Instance or single state index.
|
|
16
|
-
YieldF = Callable[[NDArrayStates],
|
|
16
|
+
YieldF = Callable[[NDArrayStates], int] | Callable[[NDArrayStates], Instance]
|
|
17
17
|
|
|
18
18
|
|
|
19
19
|
@dataclass
|
ck/sampling/uniform_sampler.py
CHANGED
|
@@ -1,15 +1,15 @@
|
|
|
1
|
-
from typing import Set, List, Iterator, Optional, Sequence
|
|
2
1
|
import random
|
|
2
|
+
from typing import Set, List, Iterator, Optional, Sequence
|
|
3
3
|
|
|
4
4
|
import numpy as np
|
|
5
5
|
|
|
6
6
|
from ck.pgm import Instance, RandomVariable, Indicator
|
|
7
7
|
from ck.probability.probability_space import dtype_for_state_indexes, Condition, check_condition
|
|
8
|
-
from .sampler import Sampler
|
|
9
|
-
from .sampler_support import YieldF
|
|
10
8
|
from ck.utils.map_set import MapSet
|
|
11
9
|
from ck.utils.np_extras import DType
|
|
12
10
|
from ck.utils.random_extras import Random
|
|
11
|
+
from .sampler import Sampler
|
|
12
|
+
from .sampler_support import YieldF
|
|
13
13
|
|
|
14
14
|
|
|
15
15
|
class UniformSampler(Sampler):
|
|
@@ -39,11 +39,15 @@ class UniformSampler(Sampler):
|
|
|
39
39
|
conditioned_rvs.add(ind.rv_idx, ind.state_idx)
|
|
40
40
|
|
|
41
41
|
def get_possible_states(_rv: RandomVariable) -> List[int]:
|
|
42
|
+
"""
|
|
43
|
+
Get the allowable states for a given random variable, given
|
|
44
|
+
conditions in `conditioned_rvs`.
|
|
45
|
+
"""
|
|
42
46
|
condition_states: Optional[Set[int]] = conditioned_rvs.get(_rv.idx)
|
|
43
47
|
if condition_states is None:
|
|
44
48
|
return list(range(len(_rv)))
|
|
45
49
|
else:
|
|
46
|
-
return
|
|
50
|
+
return list(condition_states)
|
|
47
51
|
|
|
48
52
|
possible_states: List[List[int]] = [
|
|
49
53
|
get_possible_states(rv)
|
|
@@ -63,4 +67,6 @@ class UniformSampler(Sampler):
|
|
|
63
67
|
for i, l in enumerate(possible_states):
|
|
64
68
|
state_idx = rand.randrange(0, len(l))
|
|
65
69
|
state[i] = l[state_idx]
|
|
70
|
+
# We know the yield function will always provide either ints or Instances
|
|
71
|
+
# noinspection PyTypeChecker
|
|
66
72
|
yield yield_f(state)
|
|
@@ -8,7 +8,7 @@ from ck.program.program_buffer import ProgramBuffer
|
|
|
8
8
|
from ck.program.raw_program import RawProgram
|
|
9
9
|
from ck.sampling.sampler import Sampler
|
|
10
10
|
from ck.sampling.sampler_support import SampleRV, YieldF, SamplerInfo
|
|
11
|
-
from ck.utils.np_extras import NDArrayNumeric
|
|
11
|
+
from ck.utils.np_extras import NDArrayNumeric, NDArrayStates
|
|
12
12
|
from ck.utils.random_extras import Random
|
|
13
13
|
|
|
14
14
|
|
|
@@ -52,7 +52,7 @@ class WMCDirectSampler(Sampler):
|
|
|
52
52
|
return program_buffer.compute().item()
|
|
53
53
|
|
|
54
54
|
# Set up working memory buffers
|
|
55
|
-
states = np.zeros(len(sample_rvs), dtype=self._state_dtype)
|
|
55
|
+
states: NDArrayStates = np.zeros(len(sample_rvs), dtype=self._state_dtype)
|
|
56
56
|
buff_slots = np.zeros(self._max_number_of_states, dtype=np.uintp)
|
|
57
57
|
buff_states = np.zeros(self._max_number_of_states, dtype=self._state_dtype)
|
|
58
58
|
|
|
@@ -153,6 +153,8 @@ class WMCDirectSampler(Sampler):
|
|
|
153
153
|
slots[slot] = 1
|
|
154
154
|
states[sample_rv.index] = state
|
|
155
155
|
|
|
156
|
+
# We know the yield function will always provide either ints or Instances
|
|
157
|
+
# noinspection PyTypeChecker
|
|
156
158
|
yield yield_f(states)
|
|
157
159
|
|
|
158
160
|
# Reset the one slots for the next iteration.
|
ck/sampling/wmc_gibbs_sampler.py
CHANGED
|
@@ -65,12 +65,16 @@ class WMCGibbsSampler(Sampler):
|
|
|
65
65
|
if skip == 0:
|
|
66
66
|
while True:
|
|
67
67
|
self._next_sample_gibbs(sample_rvs, slots_1, program_buffer, prs, state, rand)
|
|
68
|
+
# We know the yield function will always provide either ints or Instances
|
|
69
|
+
# noinspection PyTypeChecker
|
|
68
70
|
yield yield_f(state)
|
|
69
71
|
else:
|
|
70
72
|
while True:
|
|
71
73
|
for _ in range(skip):
|
|
72
74
|
self._next_sample_gibbs(sample_rvs, slots_1, program_buffer, prs, state, rand)
|
|
73
75
|
self._next_sample_gibbs(sample_rvs, slots_1, program_buffer, prs, state, rand)
|
|
76
|
+
# We know the yield function will always provide either ints or Instances
|
|
77
|
+
# noinspection PyTypeChecker
|
|
74
78
|
yield yield_f(state)
|
|
75
79
|
|
|
76
80
|
else:
|
|
@@ -79,6 +83,8 @@ class WMCGibbsSampler(Sampler):
|
|
|
79
83
|
for _ in range(skip):
|
|
80
84
|
self._next_sample_gibbs(sample_rvs, slots_1, program_buffer, prs, state, rand)
|
|
81
85
|
self._next_sample_gibbs(sample_rvs, slots_1, program_buffer, prs, state, rand)
|
|
86
|
+
# We know the yield function will always provide either ints or Instances
|
|
87
|
+
# noinspection PyTypeChecker
|
|
82
88
|
yield yield_f(state)
|
|
83
89
|
if rand.random() < pr_restart:
|
|
84
90
|
# Set an initial system state
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
from typing import Collection, Iterator,
|
|
1
|
+
from typing import Collection, Iterator, Sequence
|
|
2
2
|
|
|
3
3
|
import numpy as np
|
|
4
4
|
|
|
@@ -78,12 +78,16 @@ class WMCMetropolisSampler(Sampler):
|
|
|
78
78
|
if skip == 0:
|
|
79
79
|
while True:
|
|
80
80
|
w = self._next_sample_metropolis(possibles, program_buffer, state, w, rand)
|
|
81
|
+
# We know the yield function will always provide either ints or Instances
|
|
82
|
+
# noinspection PyTypeChecker
|
|
81
83
|
yield yield_f(state)
|
|
82
84
|
else:
|
|
83
85
|
while True:
|
|
84
86
|
for _ in range(skip):
|
|
85
87
|
w = self._next_sample_metropolis(possibles, program_buffer, state, w, rand)
|
|
86
88
|
w = self._next_sample_metropolis(possibles, program_buffer, state, w, rand)
|
|
89
|
+
# We know the yield function will always provide either ints or Instances
|
|
90
|
+
# noinspection PyTypeChecker
|
|
87
91
|
yield yield_f(state)
|
|
88
92
|
|
|
89
93
|
else:
|
|
@@ -92,6 +96,8 @@ class WMCMetropolisSampler(Sampler):
|
|
|
92
96
|
for _ in range(skip):
|
|
93
97
|
w = self._next_sample_metropolis(possibles, program_buffer, state, w, rand)
|
|
94
98
|
w = self._next_sample_metropolis(possibles, program_buffer, state, w, rand)
|
|
99
|
+
# We know the yield function will always provide either ints or Instances
|
|
100
|
+
# noinspection PyTypeChecker
|
|
95
101
|
yield yield_f(state)
|
|
96
102
|
|
|
97
103
|
if rand.random() < pr_restart:
|
|
@@ -91,6 +91,8 @@ class WMCRejectionSampler(Sampler):
|
|
|
91
91
|
w: float = wmc()
|
|
92
92
|
|
|
93
93
|
if rand.random() * self._w_max < w:
|
|
94
|
+
# We know the yield function will always provide either ints or Instances
|
|
95
|
+
# noinspection PyTypeChecker
|
|
94
96
|
yield yield_f(state)
|
|
95
97
|
|
|
96
98
|
# Update w_not_seen and w_high to adapt w_max.
|