compiled-knowledge 4.0.0a22__cp312-cp312-macosx_11_0_arm64.whl → 4.0.0a24__cp312-cp312-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of compiled-knowledge might be problematic. Click here for more details.
- ck/circuit/_circuit_cy.c +50 -60
- ck/circuit/_circuit_cy.cpython-312-darwin.so +0 -0
- ck/circuit/_circuit_cy.pyx +1 -1
- ck/circuit/_circuit_py.py +1 -1
- ck/circuit_compiler/circuit_compiler.py +3 -2
- ck/circuit_compiler/cython_vm_compiler/_compiler.c +179 -170
- ck/circuit_compiler/cython_vm_compiler/_compiler.cpython-312-darwin.so +0 -0
- ck/circuit_compiler/cython_vm_compiler/_compiler.pyx +3 -3
- ck/circuit_compiler/cython_vm_compiler/cython_vm_compiler.py +14 -7
- ck/circuit_compiler/interpret_compiler.py +35 -4
- ck/circuit_compiler/llvm_vm_compiler.py +9 -3
- ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.c +1 -1
- ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.cpython-312-darwin.so +0 -0
- ck/circuit_compiler/support/llvm_ir_function.py +18 -1
- ck/pgm.py +100 -102
- ck/pgm_compiler/pgm_compiler.py +1 -1
- ck/pgm_compiler/support/circuit_table/_circuit_table_cy.c +1 -1
- ck/pgm_compiler/support/circuit_table/_circuit_table_cy.cpython-312-darwin.so +0 -0
- ck/pgm_compiler/support/join_tree.py +3 -3
- ck/probability/empirical_probability_space.py +4 -3
- ck/probability/pgm_probability_space.py +7 -3
- ck/probability/probability_space.py +21 -15
- ck/program/raw_program.py +40 -7
- ck/sampling/sampler_support.py +8 -5
- ck_demos/ace/simple_ace_demo.py +18 -0
- ck_demos/getting_started/__init__.py +0 -0
- ck_demos/getting_started/simple_demo.py +18 -0
- ck_demos/programs/demo_raw_program_dump.py +17 -0
- {compiled_knowledge-4.0.0a22.dist-info → compiled_knowledge-4.0.0a24.dist-info}/METADATA +1 -1
- {compiled_knowledge-4.0.0a22.dist-info → compiled_knowledge-4.0.0a24.dist-info}/RECORD +33 -29
- {compiled_knowledge-4.0.0a22.dist-info → compiled_knowledge-4.0.0a24.dist-info}/WHEEL +0 -0
- {compiled_knowledge-4.0.0a22.dist-info → compiled_knowledge-4.0.0a24.dist-info}/licenses/LICENSE.txt +0 -0
- {compiled_knowledge-4.0.0a22.dist-info → compiled_knowledge-4.0.0a24.dist-info}/top_level.txt +0 -0
|
Binary file
|
|
@@ -38,7 +38,7 @@ DTYPE_TO_CVM_TYPE: Dict[DTypeNumeric, str] = {
|
|
|
38
38
|
}
|
|
39
39
|
|
|
40
40
|
|
|
41
|
-
def make_function(analysis: CircuitAnalysis, dtype: DTypeNumeric) -> Tuple[RawProgramFunction, int]:
|
|
41
|
+
def make_function(analysis: CircuitAnalysis, dtype: DTypeNumeric) -> Tuple[RawProgramFunction, int, int]:
|
|
42
42
|
"""
|
|
43
43
|
Make a RawProgram function that interprets the circuit.
|
|
44
44
|
|
|
@@ -47,7 +47,7 @@ def make_function(analysis: CircuitAnalysis, dtype: DTypeNumeric) -> Tuple[RawPr
|
|
|
47
47
|
dtype: a numpy data type that must be a key in the dictionary, DTYPE_TO_CVM_TYPE.
|
|
48
48
|
|
|
49
49
|
Returns:
|
|
50
|
-
(function, number_of_tmps)
|
|
50
|
+
(function, number_of_tmps, number_of_instructions)
|
|
51
51
|
"""
|
|
52
52
|
|
|
53
53
|
cdef Instructions instructions
|
|
@@ -128,7 +128,7 @@ def make_function(analysis: CircuitAnalysis, dtype: DTypeNumeric) -> Tuple[RawPr
|
|
|
128
128
|
else:
|
|
129
129
|
raise ValueError(f'cvm_type_name unexpected: {cvm_type_name!r}')
|
|
130
130
|
|
|
131
|
-
return function, len(analysis.op_to_tmp)
|
|
131
|
+
return function, len(analysis.op_to_tmp), len(analysis.op_nodes)
|
|
132
132
|
|
|
133
133
|
# VM instructions
|
|
134
134
|
cdef int ADD = circuit.ADD
|
|
@@ -48,15 +48,16 @@ class CythonRawProgram(RawProgram):
|
|
|
48
48
|
result: Sequence[CircuitNode],
|
|
49
49
|
dtype: DTypeNumeric,
|
|
50
50
|
):
|
|
51
|
-
|
|
52
|
-
self.result = result
|
|
53
|
-
|
|
54
|
-
function, number_of_tmps = _make_function(
|
|
51
|
+
function, number_of_tmps, number_of_instructions = _make_function(
|
|
55
52
|
var_nodes=in_vars,
|
|
56
53
|
result_nodes=result,
|
|
57
54
|
dtype=dtype,
|
|
58
55
|
)
|
|
59
56
|
|
|
57
|
+
self.in_vars = in_vars
|
|
58
|
+
self.result = result
|
|
59
|
+
self.number_of_instructions = number_of_instructions
|
|
60
|
+
|
|
60
61
|
super().__init__(
|
|
61
62
|
function=function,
|
|
62
63
|
dtype=dtype,
|
|
@@ -66,6 +67,10 @@ class CythonRawProgram(RawProgram):
|
|
|
66
67
|
var_indices=tuple(var.idx for var in in_vars),
|
|
67
68
|
)
|
|
68
69
|
|
|
70
|
+
def dump(self, *, prefix: str = '', indent: str = ' ') -> None:
|
|
71
|
+
super().dump(prefix=prefix, indent=indent)
|
|
72
|
+
print(f'{prefix}number of instructions = {self.number_of_instructions}')
|
|
73
|
+
|
|
69
74
|
def __getstate__(self):
|
|
70
75
|
"""
|
|
71
76
|
Support for pickle.
|
|
@@ -94,7 +99,7 @@ class CythonRawProgram(RawProgram):
|
|
|
94
99
|
self.in_vars = state['in_vars']
|
|
95
100
|
self.result = state['result']
|
|
96
101
|
|
|
97
|
-
self.function, _ = _make_function(
|
|
102
|
+
self.function, _, self.number_of_instructions = _make_function(
|
|
98
103
|
var_nodes=self.in_vars,
|
|
99
104
|
result_nodes=self.result,
|
|
100
105
|
dtype=self.dtype,
|
|
@@ -105,7 +110,7 @@ def _make_function(
|
|
|
105
110
|
var_nodes: Sequence[VarNode],
|
|
106
111
|
result_nodes: Sequence[CircuitNode],
|
|
107
112
|
dtype: DTypeNumeric,
|
|
108
|
-
) -> Tuple[RawProgramFunction, int]:
|
|
113
|
+
) -> Tuple[RawProgramFunction, int, int]:
|
|
109
114
|
"""
|
|
110
115
|
Make a RawProgram function that interprets the circuit.
|
|
111
116
|
|
|
@@ -115,7 +120,9 @@ def _make_function(
|
|
|
115
120
|
dtype: a numpy data type that must be a key in the dictionary, DTYPE_TO_CVM_TYPE.
|
|
116
121
|
|
|
117
122
|
Returns:
|
|
118
|
-
(function, number_of_tmps)
|
|
123
|
+
(function, number_of_tmps, number_of_instructions)
|
|
119
124
|
"""
|
|
120
125
|
analysis: CircuitAnalysis = analyze_circuit(var_nodes, result_nodes)
|
|
126
|
+
DEBUG = _compiler.make_function(analysis, dtype)
|
|
127
|
+
|
|
121
128
|
return _compiler.make_function(analysis, dtype)
|
|
@@ -1,17 +1,17 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
+
import ctypes as ct
|
|
3
4
|
from dataclasses import dataclass
|
|
4
5
|
from typing import Sequence, Optional, Dict, List, Tuple, Callable
|
|
5
6
|
|
|
6
7
|
import numpy as np
|
|
7
|
-
import ctypes as ct
|
|
8
8
|
|
|
9
|
+
from .support.circuit_analyser import CircuitAnalysis, analyze_circuit
|
|
10
|
+
from .support.input_vars import InputVars, InferVars, infer_input_vars
|
|
9
11
|
from ..circuit import Circuit, CircuitNode, VarNode, OpNode, ADD, MUL
|
|
10
12
|
from ..program.raw_program import RawProgram, RawProgramFunction
|
|
11
13
|
from ..utils.iter_extras import multiply, first
|
|
12
14
|
from ..utils.np_extras import NDArrayNumeric, DTypeNumeric
|
|
13
|
-
from .support.circuit_analyser import CircuitAnalysis, analyze_circuit
|
|
14
|
-
from .support.input_vars import InputVars, InferVars, infer_input_vars
|
|
15
15
|
|
|
16
16
|
# index to a value array
|
|
17
17
|
_VARS = 0
|
|
@@ -85,6 +85,15 @@ class InterpreterRawProgram(RawProgram):
|
|
|
85
85
|
var_indices=tuple(var.idx for var in in_vars),
|
|
86
86
|
)
|
|
87
87
|
|
|
88
|
+
def dump(self, *, prefix: str = '', indent: str = ' ', show_instructions: bool = True) -> None:
|
|
89
|
+
super().dump(prefix=prefix, indent=indent)
|
|
90
|
+
print(f'{prefix}number of instructions = {len(self.instructions)}')
|
|
91
|
+
if show_instructions:
|
|
92
|
+
print(f'{prefix}instructions:')
|
|
93
|
+
next_prefix: str = prefix + indent
|
|
94
|
+
for instruction in self.instructions:
|
|
95
|
+
print(f'{next_prefix}{instruction.to_str(self.var_indices, self.np_consts)}')
|
|
96
|
+
|
|
88
97
|
def __getstate__(self):
|
|
89
98
|
"""
|
|
90
99
|
Support for pickle.
|
|
@@ -123,7 +132,6 @@ def _make_instructions(
|
|
|
123
132
|
analysis: CircuitAnalysis,
|
|
124
133
|
dtype: DTypeNumeric,
|
|
125
134
|
) -> Tuple[Sequence[_Instruction], NDArrayNumeric]:
|
|
126
|
-
|
|
127
135
|
# Store const values in a numpy array
|
|
128
136
|
node_to_const_idx: Dict[int, int] = {
|
|
129
137
|
id(node): i
|
|
@@ -216,9 +224,32 @@ class _ElementID:
|
|
|
216
224
|
array: int # VARS, TMPS, CONSTS, RESULT
|
|
217
225
|
index: int # index into the array
|
|
218
226
|
|
|
227
|
+
def to_str(self, var_indices: Sequence[int], consts: NDArrayNumeric) -> str:
|
|
228
|
+
if self.array == _VARS:
|
|
229
|
+
return f'var[{var_indices[self.index]}]'
|
|
230
|
+
elif self.array == _TMPS:
|
|
231
|
+
return f'tmp[{self.index}]'
|
|
232
|
+
elif self.array == _CONSTS:
|
|
233
|
+
return str(consts.item(self.index))
|
|
234
|
+
elif self.array == _RESULT:
|
|
235
|
+
return f'result[{self.index}]'
|
|
236
|
+
else:
|
|
237
|
+
return f'?[{self.index}]'
|
|
238
|
+
|
|
219
239
|
|
|
220
240
|
@dataclass
|
|
221
241
|
class _Instruction:
|
|
222
242
|
operation: Callable
|
|
223
243
|
args: Sequence[_ElementID]
|
|
224
244
|
dest: _ElementID
|
|
245
|
+
|
|
246
|
+
def to_str(self, var_indices: Sequence[int], consts: NDArrayNumeric) -> str:
|
|
247
|
+
symbol: str
|
|
248
|
+
if self.operation is multiply:
|
|
249
|
+
symbol = 'mul'
|
|
250
|
+
elif self.operation == sum:
|
|
251
|
+
symbol = 'sum'
|
|
252
|
+
else:
|
|
253
|
+
symbol = '<?>'
|
|
254
|
+
args: str = ' '.join(elem.to_str(var_indices, consts) for elem in self.args)
|
|
255
|
+
return f'{self.dest.to_str(var_indices, consts)} = {symbol} {args}'
|
|
@@ -1,12 +1,12 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
+
import ctypes as ct
|
|
3
4
|
from dataclasses import dataclass
|
|
4
5
|
from typing import Sequence, Optional, Tuple, List, Dict
|
|
5
6
|
|
|
6
7
|
import llvmlite.binding as llvm
|
|
7
8
|
import llvmlite.ir as ir
|
|
8
9
|
import numpy as np
|
|
9
|
-
import ctypes as ct
|
|
10
10
|
|
|
11
11
|
from .support.circuit_analyser import CircuitAnalysis, analyze_circuit
|
|
12
12
|
from .support.input_vars import InputVars, InferVars, infer_input_vars
|
|
@@ -126,6 +126,12 @@ class LLVMRawProgramWithArrays(LLVMRawProgram):
|
|
|
126
126
|
instructions: np.ndarray
|
|
127
127
|
consts: np.ndarray
|
|
128
128
|
|
|
129
|
+
def dump(self, *, prefix: str = '', indent: str = ' ', show_instructions: bool = True) -> None:
|
|
130
|
+
super().dump(prefix=prefix, indent=indent)
|
|
131
|
+
print(f'{prefix}LLVM byte code size = {len(self.instructions)}')
|
|
132
|
+
if show_instructions:
|
|
133
|
+
self.dump_llvm_program(prefix=prefix, indent=indent)
|
|
134
|
+
|
|
129
135
|
def __post_init__(self):
|
|
130
136
|
self._set_globals(self.instructions, _SET_INSTRUCTIONS_FUNCTION_NAME)
|
|
131
137
|
self._set_globals(self.consts, _SET_CONSTS_FUNCTION_NAME)
|
|
@@ -200,7 +206,7 @@ def _make_llvm_program(
|
|
|
200
206
|
consts_global.global_constant = True
|
|
201
207
|
consts_global.initializer = ir.Constant(consts_array_type, const_values)
|
|
202
208
|
data_idx_0 = ir.Constant(data_idx_type, 0)
|
|
203
|
-
consts: ir.Value = builder.gep(consts_global, [data_idx_0,
|
|
209
|
+
consts: ir.Value = builder.gep(consts_global, [data_idx_0, data_idx_0])
|
|
204
210
|
|
|
205
211
|
# Put bytecode into the LLVM module
|
|
206
212
|
instructions_array_type = ir.ArrayType(byte_type, len(byte_code))
|
|
@@ -218,7 +224,7 @@ def _make_llvm_program(
|
|
|
218
224
|
|
|
219
225
|
instructions_ptr_type = byte_type.as_pointer()
|
|
220
226
|
instructions_global = ir.GlobalVariable(module, instructions_ptr_type, name='instructions')
|
|
221
|
-
instructions_global.initializer =ir.Constant(instructions_ptr_type, None)
|
|
227
|
+
instructions_global.initializer = ir.Constant(instructions_ptr_type, None)
|
|
222
228
|
instructions: ir.Value = builder.load(instructions_global)
|
|
223
229
|
|
|
224
230
|
interp = _InterpBuilder(builder, type_info, inst_idx_type, data_idx_bytes, num_args_bytes, consts, instructions)
|
|
@@ -15,7 +15,7 @@
|
|
|
15
15
|
"-O3"
|
|
16
16
|
],
|
|
17
17
|
"include_dirs": [
|
|
18
|
-
"/private/var/folders/y6/nj790rtn62lfktb1sh__79hc0000gn/T/build-env-
|
|
18
|
+
"/private/var/folders/y6/nj790rtn62lfktb1sh__79hc0000gn/T/build-env-gdy6g4em/lib/python3.12/site-packages/numpy/_core/include"
|
|
19
19
|
],
|
|
20
20
|
"name": "ck.circuit_compiler.support.circuit_analyser._circuit_analyser_cy",
|
|
21
21
|
"sources": [
|
|
Binary file
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
import ctypes as ct
|
|
2
2
|
from dataclasses import dataclass
|
|
3
3
|
from enum import Enum
|
|
4
|
-
from typing import Callable, Tuple, Optional
|
|
4
|
+
from typing import Callable, Tuple, Optional, List
|
|
5
5
|
|
|
6
6
|
import llvmlite.binding as llvm
|
|
7
7
|
import llvmlite.ir as ir
|
|
@@ -172,6 +172,23 @@ class LLVMRawProgram(RawProgram):
|
|
|
172
172
|
'opt': self.opt,
|
|
173
173
|
}
|
|
174
174
|
|
|
175
|
+
def dump(self, *, prefix: str = '', indent: str = ' ', show_instructions: bool = True) -> None:
|
|
176
|
+
super().dump(prefix=prefix, indent=indent)
|
|
177
|
+
print(f'{prefix}optimisation level = {self.opt}')
|
|
178
|
+
if show_instructions:
|
|
179
|
+
self.dump_llvm_program(prefix=prefix, indent=indent)
|
|
180
|
+
|
|
181
|
+
def dump_llvm_program(self, *, prefix: str = '', indent: str = ' ') -> None:
|
|
182
|
+
if self.llvm_program is None:
|
|
183
|
+
print(f'{prefix}LLVM program: unavailable')
|
|
184
|
+
else:
|
|
185
|
+
llvm_program: List[str] = self.llvm_program.split('\n')
|
|
186
|
+
print(f'{prefix}LLVM program size = {len(llvm_program)}')
|
|
187
|
+
print(f'{prefix}LLVM program:')
|
|
188
|
+
next_prefix: str = prefix + indent
|
|
189
|
+
for line in llvm_program:
|
|
190
|
+
print(f'{next_prefix}{line}')
|
|
191
|
+
|
|
175
192
|
def __setstate__(self, state):
|
|
176
193
|
"""
|
|
177
194
|
Support for pickle.
|
ck/pgm.py
CHANGED
|
@@ -1,6 +1,3 @@
|
|
|
1
|
-
"""
|
|
2
|
-
For more documentation on this module, refer to the Jupyter notebook docs/4_PGM_advanced.ipynb.
|
|
3
|
-
"""
|
|
4
1
|
from __future__ import annotations
|
|
5
2
|
|
|
6
3
|
import math
|
|
@@ -8,7 +5,7 @@ from abc import ABC, abstractmethod
|
|
|
8
5
|
from dataclasses import dataclass
|
|
9
6
|
from itertools import repeat as _repeat
|
|
10
7
|
from typing import Sequence, Tuple, Dict, Optional, overload, Set, Iterable, List, Union, Callable, \
|
|
11
|
-
Collection, Any, Iterator
|
|
8
|
+
Collection, Any, Iterator, TypeAlias
|
|
12
9
|
|
|
13
10
|
import numpy as np
|
|
14
11
|
|
|
@@ -17,21 +14,34 @@ from ck.utils.iter_extras import (
|
|
|
17
14
|
)
|
|
18
15
|
from ck.utils.np_extras import NDArrayFloat64, NDArrayUInt8
|
|
19
16
|
|
|
20
|
-
|
|
21
|
-
State =
|
|
17
|
+
State: TypeAlias = Union[int, str, bool, float, None]
|
|
18
|
+
State.__doc__ = \
|
|
19
|
+
"""
|
|
20
|
+
The type for a possible state of a random variable.
|
|
21
|
+
"""
|
|
22
22
|
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
23
|
+
Instance: TypeAlias = Sequence[int]
|
|
24
|
+
Instance.__doc__ = \
|
|
25
|
+
"""
|
|
26
|
+
An instance (of a sequence of random variables) is a sequence of integers
|
|
27
|
+
that are state indexes, co-indexed with a known sequence of random variables.
|
|
28
|
+
"""
|
|
26
29
|
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
+
Key: TypeAlias = Union[Instance, int]
|
|
31
|
+
Key.__doc__ = \
|
|
32
|
+
"""
|
|
33
|
+
A key identifies an instance, either as an instance itself or a
|
|
34
|
+
single integer, representing an instance with one dimension.
|
|
35
|
+
"""
|
|
30
36
|
|
|
31
|
-
|
|
32
|
-
|
|
37
|
+
Shape: TypeAlias = Sequence[int]
|
|
38
|
+
Key.__doc__ = \
|
|
39
|
+
"""
|
|
40
|
+
The type for the "shape" of a sequence of random variables.
|
|
41
|
+
That is, the shape of (rv1, rv2, rv3) is (len(rv1), len(rv2), len(rv3)).
|
|
42
|
+
"""
|
|
33
43
|
|
|
34
|
-
|
|
44
|
+
DEFAULT_CPT_TOLERANCE: float = 0.000001 # A tolerance when checking CPT distributions sum to one (or zero).
|
|
35
45
|
|
|
36
46
|
|
|
37
47
|
class PGM:
|
|
@@ -39,11 +49,9 @@ class PGM:
|
|
|
39
49
|
A probabilistic graphical model (PGM) represents a joint probability distribution over
|
|
40
50
|
a set of random variables. Specifically, a PGM is a factor graph with discrete random variables.
|
|
41
51
|
|
|
42
|
-
Add a random variable to a PGM, pgm
|
|
43
|
-
|
|
44
|
-
Add a factor to the PGM, pgm, using `factor = pgm.new_factor(...)`.
|
|
52
|
+
Add a random variable to a PGM, `pgm`, using `rv = pgm.new_rv(...)`.
|
|
45
53
|
|
|
46
|
-
|
|
54
|
+
Add a factor to the PGM, `pgm`, using `factor = pgm.new_factor(...)`.
|
|
47
55
|
"""
|
|
48
56
|
|
|
49
57
|
def __init__(self, name: Optional[str] = None):
|
|
@@ -587,7 +595,7 @@ class PGM:
|
|
|
587
595
|
# All tests passed
|
|
588
596
|
return True
|
|
589
597
|
|
|
590
|
-
def factors_are_cpts(self, tolerance: float =
|
|
598
|
+
def factors_are_cpts(self, tolerance: float = DEFAULT_CPT_TOLERANCE) -> bool:
|
|
591
599
|
"""
|
|
592
600
|
Are all factor potential functions set with parameters values
|
|
593
601
|
conforming to Conditional Probability Tables.
|
|
@@ -603,7 +611,7 @@ class PGM:
|
|
|
603
611
|
"""
|
|
604
612
|
return all(function.is_cpt(tolerance) for function in self.functions)
|
|
605
613
|
|
|
606
|
-
def check_is_bayesian_network(self, tolerance: float =
|
|
614
|
+
def check_is_bayesian_network(self, tolerance: float = DEFAULT_CPT_TOLERANCE) -> bool:
|
|
607
615
|
"""
|
|
608
616
|
Is this PGM a Bayesian network.
|
|
609
617
|
|
|
@@ -648,7 +656,7 @@ class PGM:
|
|
|
648
656
|
This method has the same semantics as `ProbabilitySpace.wmc` without conditioning.
|
|
649
657
|
|
|
650
658
|
Warning:
|
|
651
|
-
this is potentially computationally expensive as it
|
|
659
|
+
this is potentially computationally expensive as it marginalises random
|
|
652
660
|
variables not mentioned in the given indicators.
|
|
653
661
|
|
|
654
662
|
Args:
|
|
@@ -658,29 +666,11 @@ class PGM:
|
|
|
658
666
|
the product of factors, conditioned on the given instance. This is the
|
|
659
667
|
computed value of the PGM, conditioned on the given instance.
|
|
660
668
|
"""
|
|
661
|
-
#
|
|
662
|
-
#
|
|
663
|
-
#
|
|
664
|
-
#
|
|
665
|
-
#
|
|
666
|
-
# # Collect rvs not mentioned - to marginalise
|
|
667
|
-
# for rv, rv_filter in zip(self.rvs, inst_filter):
|
|
668
|
-
# if len(rv_filter) == 0:
|
|
669
|
-
# rv_filter.update(rv.state_range())
|
|
670
|
-
#
|
|
671
|
-
# def _sum_inst(_instance: Instance) -> bool:
|
|
672
|
-
# return all(
|
|
673
|
-
# (_state in _rv_filter)
|
|
674
|
-
# for _state, _rv_filter in zip(_instance, inst_filter)
|
|
675
|
-
# )
|
|
676
|
-
#
|
|
677
|
-
# # Accumulate the result
|
|
678
|
-
# sum_value = 0
|
|
679
|
-
# for instance in self.instances():
|
|
680
|
-
# if _sum_inst(instance):
|
|
681
|
-
# sum_value += self.value_product(instance)
|
|
682
|
-
#
|
|
683
|
-
# return sum_value
|
|
669
|
+
# Rather than naively checking all possible states of the PGM random
|
|
670
|
+
# variables, this method works to define the state space that should
|
|
671
|
+
# be summed over, based on the given indicators. Thus, if the given
|
|
672
|
+
# indicators constrain the state space to a small number of possibilities,
|
|
673
|
+
# then the sum is only performed over those possibilities.
|
|
684
674
|
|
|
685
675
|
# Work out the space to sum over
|
|
686
676
|
sum_space_set: List[Optional[Set[int]]] = [None] * self.number_of_rvs
|
|
@@ -719,11 +709,10 @@ class PGM:
|
|
|
719
709
|
precision: a limit on the render precision of floating point numbers.
|
|
720
710
|
max_state_digits: a limit on the number of digits when showing number of states as an integer.
|
|
721
711
|
"""
|
|
722
|
-
# limit
|
|
712
|
+
# Determine a limit to precision when displaying number of states
|
|
723
713
|
num_states: int = self.number_of_states
|
|
724
714
|
number_of_parameters = sum(function.number_of_parameters for function in self.functions)
|
|
725
715
|
number_of_nz_parameters = sum(function.number_of_parameters for function in self.non_zero_functions)
|
|
726
|
-
|
|
727
716
|
if math.log10(num_states) > max_state_digits:
|
|
728
717
|
log_states = math.log10(num_states)
|
|
729
718
|
exp = int(log_states)
|
|
@@ -731,7 +720,6 @@ class PGM:
|
|
|
731
720
|
num_states_str = f'{man:,.{precision}f}e+{exp}'
|
|
732
721
|
else:
|
|
733
722
|
num_states_str = f'{num_states:,}'
|
|
734
|
-
|
|
735
723
|
log_2_num_states = math.log2(num_states)
|
|
736
724
|
if (
|
|
737
725
|
log_2_num_states == 0
|
|
@@ -820,9 +808,9 @@ class PGM:
|
|
|
820
808
|
|
|
821
809
|
For a factor `f` the value of states[f.idx] is the search state.
|
|
822
810
|
Specifically:
|
|
823
|
-
|
|
824
|
-
|
|
825
|
-
|
|
811
|
+
state 0 => the factor has not been seen yet,
|
|
812
|
+
state 1 => the factor is seen but not fully processed,
|
|
813
|
+
state 2 => the factor is fully processed.
|
|
826
814
|
|
|
827
815
|
Args:
|
|
828
816
|
factor: the current Factor being checked.
|
|
@@ -1040,7 +1028,7 @@ class RandomVariable(Sequence[Indicator]):
|
|
|
1040
1028
|
|
|
1041
1029
|
def state_range(self) -> Iterable[int]:
|
|
1042
1030
|
"""
|
|
1043
|
-
Iterate over the state indexes of this random variable, in order.
|
|
1031
|
+
Iterate over the state indexes of this random variable, in ascending order.
|
|
1044
1032
|
|
|
1045
1033
|
Returns:
|
|
1046
1034
|
range(len(self))
|
|
@@ -1122,18 +1110,19 @@ class RandomVariable(Sequence[Indicator]):
|
|
|
1122
1110
|
|
|
1123
1111
|
def __eq__(self, other) -> bool:
|
|
1124
1112
|
"""
|
|
1125
|
-
Two random
|
|
1113
|
+
Two random variables are equal if they are the same object.
|
|
1126
1114
|
"""
|
|
1127
1115
|
return self is other
|
|
1128
1116
|
|
|
1129
1117
|
def equivalent(self, other: RandomVariable | Sequence[Indicator]) -> bool:
|
|
1130
1118
|
"""
|
|
1131
|
-
Two random variable are equivalent if their indicators are equal.
|
|
1132
|
-
random variable indexes and state indexes are checked.
|
|
1133
|
-
|
|
1119
|
+
Two random variable are equivalent if their indicators are equal.
|
|
1120
|
+
Only random variable indexes and state indexes are checked.
|
|
1134
1121
|
This ignores the names of the random variable and the names of their states.
|
|
1135
|
-
|
|
1136
|
-
|
|
1122
|
+
|
|
1123
|
+
Slot maps operate across `equivalent` random variables.
|
|
1124
|
+
This means indicators of equivalent random variables will work
|
|
1125
|
+
correctly in slot maps, even if from different PGMs.
|
|
1137
1126
|
|
|
1138
1127
|
Args:
|
|
1139
1128
|
other: either a random variable or a sequence of Indicators.
|
|
@@ -1181,7 +1170,8 @@ class RandomVariable(Sequence[Indicator]):
|
|
|
1181
1170
|
"""
|
|
1182
1171
|
Returns the first index of `value`.
|
|
1183
1172
|
Raises ValueError if the value is not present.
|
|
1184
|
-
|
|
1173
|
+
|
|
1174
|
+
This method is contracted by `Sequence[Indicator]`.
|
|
1185
1175
|
|
|
1186
1176
|
Warning:
|
|
1187
1177
|
This method is different to `self.idx`.
|
|
@@ -1198,7 +1188,10 @@ class RandomVariable(Sequence[Indicator]):
|
|
|
1198
1188
|
def count(self, value: Any) -> int:
|
|
1199
1189
|
"""
|
|
1200
1190
|
Returns the number of occurrences of `value`.
|
|
1201
|
-
|
|
1191
|
+
That is, if `value` is an indicator of this random variable
|
|
1192
|
+
then 1 is returned, otherwise 0 is returned.
|
|
1193
|
+
|
|
1194
|
+
This method is contracted by `Sequence[Indicator]`.
|
|
1202
1195
|
"""
|
|
1203
1196
|
if isinstance(value, Indicator):
|
|
1204
1197
|
if value.rv_idx == self._idx and 0 <= value.state_idx < len(self):
|
|
@@ -1210,15 +1203,16 @@ class RVMap(Sequence[RandomVariable]):
|
|
|
1210
1203
|
"""
|
|
1211
1204
|
Wrap a PGM to provide convenient access to PGM random variables.
|
|
1212
1205
|
|
|
1213
|
-
An RVMap of a PGM behaves
|
|
1214
|
-
|
|
1206
|
+
An RVMap of a PGM behaves like the PGM `rvs` property (sequence of
|
|
1207
|
+
RandomVariable objects), with additional access methods for the PGM's
|
|
1208
|
+
random variables.
|
|
1215
1209
|
|
|
1216
1210
|
If the underlying PGM is updated, then the RVMap will automatically update.
|
|
1217
1211
|
|
|
1218
|
-
|
|
1219
|
-
of each random variable.
|
|
1212
|
+
In addition to accessing a random variable by its index, an RVMap enables
|
|
1213
|
+
access to the PGM random variable via the name of each random variable.
|
|
1220
1214
|
|
|
1221
|
-
|
|
1215
|
+
For example, if `pgm.rvs[1]` is a random variable named `xray`, then:
|
|
1222
1216
|
```
|
|
1223
1217
|
rvs = RVMap(pgm)
|
|
1224
1218
|
|
|
@@ -1228,7 +1222,7 @@ class RVMap(Sequence[RandomVariable]):
|
|
|
1228
1222
|
xray = rvs.xray
|
|
1229
1223
|
```
|
|
1230
1224
|
|
|
1231
|
-
To use an RVMap on a PGM, the variable names must be unique across the PGM.
|
|
1225
|
+
To use an RVMap on a PGM, the random variable names must be unique across the PGM.
|
|
1232
1226
|
"""
|
|
1233
1227
|
|
|
1234
1228
|
def __init__(self, pgm: PGM, ignore_case: bool = False):
|
|
@@ -1248,28 +1242,6 @@ class RVMap(Sequence[RandomVariable]):
|
|
|
1248
1242
|
# This may raise an exception.
|
|
1249
1243
|
_ = self._rv_map
|
|
1250
1244
|
|
|
1251
|
-
def _clean_name(self, name: str) -> str:
|
|
1252
|
-
"""
|
|
1253
|
-
Adjust the case of the given name as needed.
|
|
1254
|
-
"""
|
|
1255
|
-
return name.lower() if self._ignore_case else name
|
|
1256
|
-
|
|
1257
|
-
@property
|
|
1258
|
-
def _rv_map(self) -> Dict[str, RandomVariable]:
|
|
1259
|
-
"""
|
|
1260
|
-
Get the cached rv map, updating as needed if the PGM changed.
|
|
1261
|
-
Returns:
|
|
1262
|
-
a mapping from random variable name to random variable
|
|
1263
|
-
"""
|
|
1264
|
-
if len(self.__rv_map) != len(self._pgm.rvs):
|
|
1265
|
-
# There is a difference between the map and the PGM - create a new map.
|
|
1266
|
-
self.__rv_map = {self._clean_name(rv.name): rv for rv in self._pgm.rvs}
|
|
1267
|
-
if len(self.__rv_map) != len(self._pgm.rvs):
|
|
1268
|
-
raise RuntimeError(f'random variable names are not unique')
|
|
1269
|
-
if not self._reserved_names.isdisjoint(self.__rv_map.keys()):
|
|
1270
|
-
raise RuntimeError(f'random variable names clash with reserved names.')
|
|
1271
|
-
return self.__rv_map
|
|
1272
|
-
|
|
1273
1245
|
def new_rv(self, name: str, states: Union[int, Sequence[State]]) -> RandomVariable:
|
|
1274
1246
|
"""
|
|
1275
1247
|
As per `PGM.new_rv`.
|
|
@@ -1304,6 +1276,29 @@ class RVMap(Sequence[RandomVariable]):
|
|
|
1304
1276
|
def __getattr__(self, rv_name: str) -> RandomVariable:
|
|
1305
1277
|
return self(rv_name)
|
|
1306
1278
|
|
|
1279
|
+
@property
|
|
1280
|
+
def _rv_map(self) -> Dict[str, RandomVariable]:
|
|
1281
|
+
"""
|
|
1282
|
+
Get the cached random variable map, updating as needed if the PGM changed.
|
|
1283
|
+
|
|
1284
|
+
Returns:
|
|
1285
|
+
a mapping from random variable name to random variable
|
|
1286
|
+
"""
|
|
1287
|
+
if len(self.__rv_map) != len(self._pgm.rvs):
|
|
1288
|
+
# There is a difference between the map and the PGM - create a new map.
|
|
1289
|
+
self.__rv_map = {self._clean_name(rv.name): rv for rv in self._pgm.rvs}
|
|
1290
|
+
if len(self.__rv_map) != len(self._pgm.rvs):
|
|
1291
|
+
raise RuntimeError(f'random variable names are not unique')
|
|
1292
|
+
if not self._reserved_names.isdisjoint(self.__rv_map.keys()):
|
|
1293
|
+
raise RuntimeError(f'random variable names clash with reserved names.')
|
|
1294
|
+
return self.__rv_map
|
|
1295
|
+
|
|
1296
|
+
def _clean_name(self, name: str) -> str:
|
|
1297
|
+
"""
|
|
1298
|
+
Adjust the case of the given name as needed.
|
|
1299
|
+
"""
|
|
1300
|
+
return name.lower() if self._ignore_case else name
|
|
1301
|
+
|
|
1307
1302
|
|
|
1308
1303
|
class Factor:
|
|
1309
1304
|
"""
|
|
@@ -1544,7 +1539,7 @@ class Factor:
|
|
|
1544
1539
|
self._potential_function = ClausePotentialFunction(self, key)
|
|
1545
1540
|
return self._potential_function
|
|
1546
1541
|
|
|
1547
|
-
def set_cpt(self, tolerance: float =
|
|
1542
|
+
def set_cpt(self, tolerance: float = DEFAULT_CPT_TOLERANCE) -> CPTPotentialFunction:
|
|
1548
1543
|
"""
|
|
1549
1544
|
Set to the potential function to a new `CPTPotentialFunction` object.
|
|
1550
1545
|
|
|
@@ -1820,7 +1815,7 @@ class PotentialFunction(ABC):
|
|
|
1820
1815
|
"""
|
|
1821
1816
|
...
|
|
1822
1817
|
|
|
1823
|
-
def is_cpt(self, tolerance=
|
|
1818
|
+
def is_cpt(self, tolerance=DEFAULT_CPT_TOLERANCE) -> bool:
|
|
1824
1819
|
"""
|
|
1825
1820
|
Is the potential function set with parameters values conforming to a
|
|
1826
1821
|
Conditional Probability Table.
|
|
@@ -2028,7 +2023,7 @@ class ZeroPotentialFunction(PotentialFunction):
|
|
|
2028
2023
|
def param_idx(self, key: Key) -> int:
|
|
2029
2024
|
return _natural_key_idx(self._shape, key)
|
|
2030
2025
|
|
|
2031
|
-
def is_cpt(self, tolerance=
|
|
2026
|
+
def is_cpt(self, tolerance=DEFAULT_CPT_TOLERANCE) -> bool:
|
|
2032
2027
|
return True
|
|
2033
2028
|
|
|
2034
2029
|
|
|
@@ -2836,7 +2831,7 @@ class ClausePotentialFunction(PotentialFunction):
|
|
|
2836
2831
|
else:
|
|
2837
2832
|
return None
|
|
2838
2833
|
|
|
2839
|
-
def is_cpt(self, tolerance=
|
|
2834
|
+
def is_cpt(self, tolerance=DEFAULT_CPT_TOLERANCE) -> bool:
|
|
2840
2835
|
"""
|
|
2841
2836
|
A ClausePotentialFunction can only be a CTP when all entries are zero.
|
|
2842
2837
|
"""
|
|
@@ -2930,7 +2925,7 @@ class CPTPotentialFunction(PotentialFunction):
|
|
|
2930
2925
|
def number_of_parameters(self) -> int:
|
|
2931
2926
|
return len(self._values)
|
|
2932
2927
|
|
|
2933
|
-
def is_cpt(self, tolerance=
|
|
2928
|
+
def is_cpt(self, tolerance=DEFAULT_CPT_TOLERANCE) -> bool:
|
|
2934
2929
|
if tolerance >= self._tolerance:
|
|
2935
2930
|
return True
|
|
2936
2931
|
else:
|
|
@@ -3015,12 +3010,11 @@ class CPTPotentialFunction(PotentialFunction):
|
|
|
3015
3010
|
|
|
3016
3011
|
def cpds(self) -> Iterable[Tuple[Instance, Sequence[float]]]:
|
|
3017
3012
|
"""
|
|
3018
|
-
Iterate over (parent_states, cpd) tuples.
|
|
3019
|
-
|
|
3020
|
-
|
|
3021
|
-
|
|
3022
|
-
|
|
3023
|
-
|
|
3013
|
+
Iterate over (parent_states, cpd) tuples. This will exclude zero CPDs.
|
|
3014
|
+
|
|
3015
|
+
Warning:
|
|
3016
|
+
Do not change CPDs to (or from) zero while iterating over them.
|
|
3017
|
+
|
|
3024
3018
|
Returns:
|
|
3025
3019
|
an iterator over pairs (instance, cpd) where,
|
|
3026
3020
|
instance: is indicates the state of the parent random variables.
|
|
@@ -3288,7 +3282,7 @@ def check_key(shape: Shape, key: Key) -> Instance:
|
|
|
3288
3282
|
A instance from the state space, as a tuple of state indexes, co-indexed with the given shape.
|
|
3289
3283
|
|
|
3290
3284
|
Raises:
|
|
3291
|
-
KeyError if the key is not valid.
|
|
3285
|
+
KeyError if the key is not valid for the given shape.
|
|
3292
3286
|
"""
|
|
3293
3287
|
_key: Instance = _key_to_instance(key)
|
|
3294
3288
|
if len(_key) != len(shape):
|
|
@@ -3336,8 +3330,8 @@ def rv_instances(*rvs: RandomVariable, flip: bool = False) -> Iterable[Instance]
|
|
|
3336
3330
|
flip: if true, then first random variable changes most quickly.
|
|
3337
3331
|
|
|
3338
3332
|
Returns:
|
|
3339
|
-
an iteration over
|
|
3340
|
-
co-indexed with the given random variables.
|
|
3333
|
+
an iteration over instances, each instance is a tuple of state
|
|
3334
|
+
indexes, co-indexed with the given random variables.
|
|
3341
3335
|
"""
|
|
3342
3336
|
shape = [len(rv) for rv in rvs]
|
|
3343
3337
|
return _combos_ranges(shape, flip=not flip)
|
|
@@ -3384,6 +3378,10 @@ def _natural_key_idx(shape: Shape, key: Key) -> int:
|
|
|
3384
3378
|
"""
|
|
3385
3379
|
What is the natural index of the given key, assuming the given shape.
|
|
3386
3380
|
|
|
3381
|
+
The natural index of an instance is defined as the index of the
|
|
3382
|
+
instance if all instances for the shape are enumerated as per
|
|
3383
|
+
`rv_instances`.
|
|
3384
|
+
|
|
3387
3385
|
Args:
|
|
3388
3386
|
shape: the shape defining the state space.
|
|
3389
3387
|
key: a key into the state space.
|
ck/pgm_compiler/pgm_compiler.py
CHANGED
|
@@ -7,7 +7,7 @@ from ck.pgm_circuit import PGMCircuit
|
|
|
7
7
|
class PGMCompiler(Protocol):
|
|
8
8
|
def __call__(self, pgm: PGM, *, const_parameters: bool = True) -> PGMCircuit:
|
|
9
9
|
"""
|
|
10
|
-
A PGM compiler
|
|
10
|
+
A PGM compiler compiles a PGM to an arithmetic circuit.
|
|
11
11
|
|
|
12
12
|
Args:
|
|
13
13
|
pgm: The PGM to compile.
|