compiled-knowledge 4.1.0a2__cp313-cp313-win32.whl → 4.2.0a1__cp313-cp313-win32.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of compiled-knowledge might be problematic. Click here for more details.
- ck/circuit/_circuit_cy.c +1 -1
- ck/circuit/_circuit_cy.cp313-win32.pyd +0 -0
- ck/circuit_compiler/cython_vm_compiler/_compiler.c +152 -152
- ck/circuit_compiler/cython_vm_compiler/_compiler.cp313-win32.pyd +0 -0
- ck/circuit_compiler/llvm_compiler.py +4 -4
- ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.c +1 -1
- ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.cp313-win32.pyd +0 -0
- ck/circuit_compiler/support/input_vars.py +4 -4
- ck/dataset/cross_table.py +143 -79
- ck/dataset/dataset.py +95 -7
- ck/dataset/dataset_builder.py +11 -4
- ck/dataset/dataset_from_crosstable.py +21 -2
- ck/learning/coalesce_cross_tables.py +403 -0
- ck/learning/model_from_cross_tables.py +296 -0
- ck/learning/parameters.py +117 -0
- ck/learning/train_generative_bn.py +198 -0
- ck/pgm.py +10 -8
- ck/pgm_circuit/marginals_program.py +5 -0
- ck/pgm_circuit/wmc_program.py +5 -0
- ck/pgm_compiler/support/circuit_table/_circuit_table_cy.c +1 -1
- ck/pgm_compiler/support/circuit_table/_circuit_table_cy.cp313-win32.pyd +0 -0
- ck/probability/divergence.py +226 -0
- ck/probability/probability_space.py +43 -19
- ck/utils/map_dict.py +89 -0
- ck_demos/dataset/demo_dataset_from_sampler.py +18 -0
- ck_demos/learning/__init__.py +0 -0
- ck_demos/learning/demo_bayesian_network_from_cross_tables.py +70 -0
- ck_demos/learning/demo_simple_learning.py +55 -0
- ck_demos/sampling/demo_wmc_direct_sampler.py +2 -2
- {compiled_knowledge-4.1.0a2.dist-info → compiled_knowledge-4.2.0a1.dist-info}/METADATA +2 -1
- {compiled_knowledge-4.1.0a2.dist-info → compiled_knowledge-4.2.0a1.dist-info}/RECORD +35 -26
- ck/learning/train_generative.py +0 -149
- /ck/{dataset/cross_table_probabilities.py → probability/cross_table_probability_space.py} +0 -0
- {compiled_knowledge-4.1.0a2.dist-info → compiled_knowledge-4.2.0a1.dist-info}/WHEEL +0 -0
- {compiled_knowledge-4.1.0a2.dist-info → compiled_knowledge-4.2.0a1.dist-info}/licenses/LICENSE.txt +0 -0
- {compiled_knowledge-4.1.0a2.dist-info → compiled_knowledge-4.2.0a1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,117 @@
|
|
|
1
|
+
"""
|
|
2
|
+
General functions for setting the parameter values of a PGM.
|
|
3
|
+
"""
|
|
4
|
+
from typing import List, Tuple, TypeAlias
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
|
|
8
|
+
from ck.dataset.cross_table import CrossTable
|
|
9
|
+
from ck.pgm import PGM, CPTPotentialFunction, Instance, SparsePotentialFunction, DensePotentialFunction, Factor
|
|
10
|
+
from ck.utils.map_list import MapList
|
|
11
|
+
from ck.utils.np_extras import NDArrayFloat64
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
ParameterValues: TypeAlias = List[CrossTable]
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def make_factors(pgm: PGM, parameter_values: List[CrossTable]) -> None:
|
|
18
|
+
for factor in parameter_values:
|
|
19
|
+
pgm.new_factor(*factor.rvs)
|
|
20
|
+
set_potential_functions(pgm, parameter_values)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def set_potential_functions(pgm: PGM, parameter_values: List[CrossTable]) -> None:
|
|
24
|
+
"""
|
|
25
|
+
Set the potential function of each PGM factor to one heuristically chosen,
|
|
26
|
+
using the given parameter values. Then set the parameter values of the potential
|
|
27
|
+
function to those given by `parameter_values`.
|
|
28
|
+
|
|
29
|
+
This function modifies `pgm` in-place, iteratively calling `set_potential_function`.
|
|
30
|
+
|
|
31
|
+
Args:
|
|
32
|
+
pgm (PGM): the PGM to have its potential functions set.
|
|
33
|
+
parameter_values: the parameter values, one cross-table per factor, co-indexed with `pgm.factors`.
|
|
34
|
+
"""
|
|
35
|
+
for factor, factor_parameter_values in zip(pgm.factors, parameter_values):
|
|
36
|
+
set_potential_function(factor, factor_parameter_values)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def set_potential_function(factor: Factor, parameter_values: CrossTable) -> None:
|
|
40
|
+
"""
|
|
41
|
+
Set the potential function of the given factor to one heuristically chosen,
|
|
42
|
+
using the given parameter values. Then set the parameter values of the potential
|
|
43
|
+
function to those given by `parameter_values`.
|
|
44
|
+
|
|
45
|
+
The potential function will be either a ZeroPotentialFunction, DensePotentialFunction,
|
|
46
|
+
or SparsePotentialFunction.
|
|
47
|
+
|
|
48
|
+
This function modifies `factor` in-place.
|
|
49
|
+
|
|
50
|
+
Args:
|
|
51
|
+
factor: The factor to update.
|
|
52
|
+
parameter_values: the parameter values, a cross-table mapping instances of the factor to weights.
|
|
53
|
+
"""
|
|
54
|
+
number_of_parameters: int = len(parameter_values)
|
|
55
|
+
if number_of_parameters == 0:
|
|
56
|
+
factor.set_zero()
|
|
57
|
+
else:
|
|
58
|
+
if number_of_parameters < 100 or number_of_parameters > factor.number_of_states * 0.9:
|
|
59
|
+
pot_function: DensePotentialFunction = factor.set_dense()
|
|
60
|
+
else:
|
|
61
|
+
pot_function: SparsePotentialFunction = factor.set_sparse()
|
|
62
|
+
for instance, weight in parameter_values.items():
|
|
63
|
+
pot_function[instance] = weight
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def set_zero(pgm: PGM) -> None:
|
|
67
|
+
"""
|
|
68
|
+
Set the potential function of each PGM factor to zero.
|
|
69
|
+
"""
|
|
70
|
+
for factor in pgm.factors:
|
|
71
|
+
factor.set_zero()
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def set_dense(pgm: PGM, parameter_values: List[CrossTable]) -> None:
|
|
75
|
+
"""
|
|
76
|
+
Set the potential function of each PGM factor to a DensePotentialFunction,
|
|
77
|
+
using the given parameter values.
|
|
78
|
+
"""
|
|
79
|
+
for factor, cpt in zip(pgm.factors, parameter_values):
|
|
80
|
+
pot_function: DensePotentialFunction = factor.set_dense()
|
|
81
|
+
for instance, weight in cpt.items():
|
|
82
|
+
pot_function[instance] = weight
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def set_sparse(pgm: PGM, parameter_values: List[CrossTable]) -> None:
|
|
86
|
+
"""
|
|
87
|
+
Set the potential function of each PGM factor to a SparsePotentialFunction,
|
|
88
|
+
using the given parameter values.
|
|
89
|
+
"""
|
|
90
|
+
for factor, cpt in zip(pgm.factors, parameter_values):
|
|
91
|
+
pot_function: SparsePotentialFunction = factor.set_sparse()
|
|
92
|
+
for instance, weight in cpt.items():
|
|
93
|
+
pot_function[instance] = weight
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def set_cpt(pgm: PGM, parameter_values: List[CrossTable], normalise_cpds: bool = True) -> None:
|
|
97
|
+
"""
|
|
98
|
+
Set the potential function of each PGM factor to a CPTPotentialFunction,
|
|
99
|
+
using the given parameter values.
|
|
100
|
+
"""
|
|
101
|
+
for factor, cpt in zip(pgm.factors, parameter_values):
|
|
102
|
+
pot_function: CPTPotentialFunction = factor.set_cpt()
|
|
103
|
+
|
|
104
|
+
# Group cpt values by parent instance
|
|
105
|
+
cpds: MapList[Instance, Tuple[int, float]] = MapList()
|
|
106
|
+
for instance, weight in cpt.items():
|
|
107
|
+
cpds.append(instance[1:], (instance[0], weight))
|
|
108
|
+
|
|
109
|
+
# Set the CPDs
|
|
110
|
+
cpd_size = len(cpt.rvs[0]) # size of the child random variable
|
|
111
|
+
for parent_instance, cpd in cpds.items():
|
|
112
|
+
cpd_array: NDArrayFloat64 = np.zeros(cpd_size, dtype=np.float64)
|
|
113
|
+
for child_state_index, weight in cpd:
|
|
114
|
+
cpd_array[child_state_index] = weight
|
|
115
|
+
if normalise_cpds:
|
|
116
|
+
cpd_array /= cpd_array.sum()
|
|
117
|
+
pot_function.set_cpd(parent_instance, cpd_array)
|
|
@@ -0,0 +1,198 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import List, Mapping, Tuple
|
|
4
|
+
|
|
5
|
+
from ck.dataset import SoftDataset, HardDataset
|
|
6
|
+
from ck.dataset.cross_table import CrossTable, cross_table_from_dataset
|
|
7
|
+
from ck.learning.parameters import set_potential_functions, ParameterValues
|
|
8
|
+
from ck.pgm import PGM
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def train_generative_bn(
|
|
12
|
+
pgm: PGM,
|
|
13
|
+
dataset: HardDataset | SoftDataset,
|
|
14
|
+
*,
|
|
15
|
+
dirichlet_prior: float | Mapping[int, float | CrossTable] = 0,
|
|
16
|
+
check_bayesian_network: bool = True,
|
|
17
|
+
) -> None:
|
|
18
|
+
"""
|
|
19
|
+
Maximum-likelihood, generative training for a Bayesian network.
|
|
20
|
+
|
|
21
|
+
The potential functions of the given PGM's factors will be set to new potential functions
|
|
22
|
+
with the learned parameter values.
|
|
23
|
+
|
|
24
|
+
Args:
|
|
25
|
+
pgm: the probabilistic graphical model defining the model structure.
|
|
26
|
+
Potential function values are ignored and need not be set.
|
|
27
|
+
dataset: a dataset of random variable states.
|
|
28
|
+
dirichlet_prior: provides a Dirichlet prior for each factor in `pgm`.
|
|
29
|
+
This can be represented in multiple ways:
|
|
30
|
+
(a) as a uniform prior that is the same for all factors, represented as a float value,
|
|
31
|
+
(b) as a mapping from a factor index to a uniform prior, i.e., a float value,
|
|
32
|
+
(c) as a mapping from a factor index to an arbitrary Dirichlet prior, i.e., a cross-table.
|
|
33
|
+
If there is no entry in the mapping for a factor, then the value 0 will be used for that factor.
|
|
34
|
+
If a cross-table is provided as a prior, then it must have the same random variables as
|
|
35
|
+
the factor it pertains to.
|
|
36
|
+
The default value for `dirichlet_prior` is 0.
|
|
37
|
+
See `CrossTable` for more explanation.
|
|
38
|
+
check_bayesian_network: if true and not `pgm.is_structure_bayesian` an exception will be raised.
|
|
39
|
+
|
|
40
|
+
Raises:
|
|
41
|
+
ValueError: if the given PGM does not have a Bayesian network structure, and check_bayesian_network is True.
|
|
42
|
+
"""
|
|
43
|
+
if check_bayesian_network and not pgm.is_structure_bayesian:
|
|
44
|
+
raise ValueError('the given PGM is not a Bayesian network')
|
|
45
|
+
cpts: List[CrossTable] = get_cpts(
|
|
46
|
+
pgm=pgm,
|
|
47
|
+
dataset=dataset,
|
|
48
|
+
dirichlet_prior=dirichlet_prior,
|
|
49
|
+
)
|
|
50
|
+
set_potential_functions(pgm, cpts)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def get_cpts(
|
|
54
|
+
pgm: PGM,
|
|
55
|
+
dataset: HardDataset | SoftDataset,
|
|
56
|
+
*,
|
|
57
|
+
dirichlet_prior: float | Mapping[int, float | CrossTable] = 0,
|
|
58
|
+
) -> ParameterValues:
|
|
59
|
+
"""
|
|
60
|
+
This function applies `cpt_from_crosstab` to each cross-table from `get_factor_cross_tables`.
|
|
61
|
+
The resulting parameter values are CPTs that can be used directly to update the parameters
|
|
62
|
+
of the given PGM, so long as it has a Bayesian network structure.
|
|
63
|
+
|
|
64
|
+
To update the given PGM from the resulting `cpts` use `set_potential_functions(pgm, cpts)`.
|
|
65
|
+
|
|
66
|
+
Args:
|
|
67
|
+
pgm: the probabilistic graphical model defining the model structure.
|
|
68
|
+
Potential function values are ignored and need not be set.
|
|
69
|
+
dataset: a dataset of random variable states.
|
|
70
|
+
dirichlet_prior: provides a Dirichlet prior for each factor in `pgm`.
|
|
71
|
+
This can be represented in multiple ways:
|
|
72
|
+
(a) as a uniform prior that is the same for all factors, represented as a float value,
|
|
73
|
+
(b) as a mapping from a factor index to a uniform prior, i.e., a float value,
|
|
74
|
+
(c) as a mapping from a factor index to an arbitrary Dirichlet prior, i.e., a cross-table.
|
|
75
|
+
If there is no entry in the mapping for a factor, then the value 0 will be used for that factor.
|
|
76
|
+
If a cross-table is provided as a prior, then it must have the same random variables as
|
|
77
|
+
the factor it pertains to.
|
|
78
|
+
The default value for `dirichlet_prior` is 0.
|
|
79
|
+
See `CrossTable` for more explanation.
|
|
80
|
+
|
|
81
|
+
Returns:
|
|
82
|
+
ParameterValues object, a CPT for each factor in the given PGM, as a list of cross-tables, co-indexed
|
|
83
|
+
with the PGM factors.
|
|
84
|
+
"""
|
|
85
|
+
cross_tables: List[CrossTable] = get_factor_cross_tables(
|
|
86
|
+
pgm=pgm,
|
|
87
|
+
dataset=dataset,
|
|
88
|
+
dirichlet_prior=dirichlet_prior,
|
|
89
|
+
)
|
|
90
|
+
cpts: List[CrossTable] = list(map(cpt_from_crosstab, cross_tables))
|
|
91
|
+
return cpts
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def get_factor_cross_tables(
|
|
95
|
+
pgm: PGM,
|
|
96
|
+
dataset: HardDataset | SoftDataset,
|
|
97
|
+
*,
|
|
98
|
+
dirichlet_prior: float | Mapping[int, float | CrossTable] = 0,
|
|
99
|
+
) -> ParameterValues:
|
|
100
|
+
"""
|
|
101
|
+
Compute a cross-table for each factor of the given PGM, using the data from
|
|
102
|
+
the given dataset.
|
|
103
|
+
|
|
104
|
+
Args:
|
|
105
|
+
pgm: the probabilistic graphical model defining the model structure.
|
|
106
|
+
Potential function values are ignored and need not be set.
|
|
107
|
+
dataset: a dataset of random variable states.
|
|
108
|
+
dirichlet_prior: provides a Dirichlet prior for each factor in `pgm`.
|
|
109
|
+
This can be represented in multiple ways:
|
|
110
|
+
(a) as a uniform prior that is the same for all factors, represented as a float value,
|
|
111
|
+
(b) as a mapping from a factor index to a uniform prior, i.e., a float value,
|
|
112
|
+
(c) as a mapping from a factor index to an arbitrary Dirichlet prior, i.e., a cross-table.
|
|
113
|
+
If there is no entry in the mapping for a factor, then the value 0 will be used for that factor.
|
|
114
|
+
If a cross-table is provided as a prior, then it must have the same random variables as
|
|
115
|
+
the factor it pertains to.
|
|
116
|
+
The default value for `dirichlet_prior` is 0.
|
|
117
|
+
See `CrossTable` for more explanation.
|
|
118
|
+
|
|
119
|
+
Returns:
|
|
120
|
+
ParameterValues object, a crosstable for each factor in the given PGM, as
|
|
121
|
+
per `cross_table_from_dataset`.
|
|
122
|
+
|
|
123
|
+
Assumes:
|
|
124
|
+
every random variable of the PGM is in the dataset.
|
|
125
|
+
"""
|
|
126
|
+
factor_dict: Mapping[int, float | CrossTable]
|
|
127
|
+
default_prior: float
|
|
128
|
+
if isinstance(dirichlet_prior, (float, int)):
|
|
129
|
+
factor_dict = {}
|
|
130
|
+
default_prior = dirichlet_prior
|
|
131
|
+
else:
|
|
132
|
+
factor_dict = dirichlet_prior
|
|
133
|
+
default_prior = 0
|
|
134
|
+
|
|
135
|
+
cross_tables: List[CrossTable] = [
|
|
136
|
+
cross_table_from_dataset(
|
|
137
|
+
dataset,
|
|
138
|
+
factor.rvs,
|
|
139
|
+
dirichlet_prior=factor_dict.get(factor.idx, default_prior),
|
|
140
|
+
)
|
|
141
|
+
for factor in pgm.factors
|
|
142
|
+
]
|
|
143
|
+
return cross_tables
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
def cpt_from_crosstab(crosstab: CrossTable) -> CrossTable:
|
|
147
|
+
"""
|
|
148
|
+
Convert the given cross-table to a conditional probability table (CPT),
|
|
149
|
+
assuming the first random variable of the cross-table is the child
|
|
150
|
+
and remaining random variables are the parents.
|
|
151
|
+
|
|
152
|
+
Args:
|
|
153
|
+
crosstab: a CrossTable representing the weight of unique instances.
|
|
154
|
+
|
|
155
|
+
Returns:
|
|
156
|
+
A cross-table that is a conditional probability table.
|
|
157
|
+
|
|
158
|
+
Assumes:
|
|
159
|
+
the first random variable in `crosstab.rvs` is the child random variable.
|
|
160
|
+
"""
|
|
161
|
+
return cpt_and_parent_sums_from_crosstab(crosstab)[0]
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
def cpt_and_parent_sums_from_crosstab(crosstab: CrossTable) -> Tuple[CrossTable, CrossTable]:
|
|
165
|
+
"""
|
|
166
|
+
Convert the given cross-table to a conditional probability table (CPT),
|
|
167
|
+
assuming the first random variable of the cross-table is the child
|
|
168
|
+
and remaining random variables are the parents.
|
|
169
|
+
|
|
170
|
+
Args:
|
|
171
|
+
crosstab: a CrossTable representing the weight of unique instances.
|
|
172
|
+
|
|
173
|
+
Returns:
|
|
174
|
+
A cross-table that is a conditional probability table.
|
|
175
|
+
A cross-table of the parent sums that were divided out of `crosstab`
|
|
176
|
+
|
|
177
|
+
Assumes:
|
|
178
|
+
the first random variable in `crosstab.rvs` is the child random variable.
|
|
179
|
+
"""
|
|
180
|
+
# Get the sum of weights for parent states
|
|
181
|
+
parent_sums: CrossTable = CrossTable(
|
|
182
|
+
rvs=crosstab.rvs[1:],
|
|
183
|
+
update=(
|
|
184
|
+
(instance[1:], weight)
|
|
185
|
+
for instance, weight in crosstab.items()
|
|
186
|
+
)
|
|
187
|
+
)
|
|
188
|
+
|
|
189
|
+
# Construct the normalised cross-tables, i.e., the CPTs.
|
|
190
|
+
cpt = CrossTable(
|
|
191
|
+
rvs=crosstab.rvs,
|
|
192
|
+
update=(
|
|
193
|
+
(instance, weight / parent_sums[instance[1:]])
|
|
194
|
+
for instance, weight in crosstab.items()
|
|
195
|
+
)
|
|
196
|
+
)
|
|
197
|
+
|
|
198
|
+
return cpt, parent_sums
|
ck/pgm.py
CHANGED
|
@@ -596,9 +596,11 @@ class PGM:
|
|
|
596
596
|
|
|
597
597
|
# Factors form a DAG
|
|
598
598
|
states: NDArrayUInt8 = np.zeros(self.number_of_factors, dtype=np.uint8)
|
|
599
|
-
|
|
600
|
-
|
|
601
|
-
|
|
599
|
+
if any(
|
|
600
|
+
self._has_cycle(factor, child_to_factor, states)
|
|
601
|
+
for factor in self._factors
|
|
602
|
+
):
|
|
603
|
+
return False
|
|
602
604
|
|
|
603
605
|
# All tests passed
|
|
604
606
|
return True
|
|
@@ -778,7 +780,7 @@ class PGM:
|
|
|
778
780
|
next_prefix: str = prefix + indent
|
|
779
781
|
next_next_prefix: str = next_prefix + indent
|
|
780
782
|
|
|
781
|
-
print(f'{prefix}PGM id={id(self)}
|
|
783
|
+
print(f'{prefix}PGM id={id(self)}')
|
|
782
784
|
self.dump_synopsis(prefix=next_prefix, precision=precision, max_state_digits=max_state_digits)
|
|
783
785
|
|
|
784
786
|
print(f'{prefix}random variables ({self.number_of_rvs})')
|
|
@@ -792,16 +794,16 @@ class PGM:
|
|
|
792
794
|
|
|
793
795
|
print(f'{prefix}factors ({self.number_of_factors})')
|
|
794
796
|
for factor in self.factors:
|
|
795
|
-
|
|
797
|
+
factor_rvs = ', '.join(repr(rv.name) for rv in factor.rvs)
|
|
796
798
|
if factor.is_zero:
|
|
797
|
-
function_ref = '<
|
|
799
|
+
function_ref = '<ZeroPotentialFunction>'
|
|
798
800
|
else:
|
|
799
801
|
function = factor.function
|
|
800
802
|
function_ref = f'{id(function)}: {function.__class__.__name__}'
|
|
801
803
|
|
|
802
|
-
print(f'{next_prefix}{factor.idx:>3} rvs={
|
|
804
|
+
print(f'{next_prefix}{factor.idx:>3} rvs=({factor_rvs}) function={function_ref}')
|
|
803
805
|
|
|
804
|
-
print(f'{prefix}functions ({self.
|
|
806
|
+
print(f'{prefix}functions, excluding ZeroPotentialFunction ({sum(1 for _ in self.non_zero_functions)})')
|
|
805
807
|
for function in sorted(self.non_zero_functions, key=lambda f: id(f)):
|
|
806
808
|
print(f'{next_prefix}{id(function):>13}: {function.__class__.__name__}')
|
|
807
809
|
function.dump(prefix=next_next_prefix, show_function_values=show_function_values, show_id_class=False)
|
|
@@ -308,6 +308,11 @@ class MarginalsProgram(ProgramWithSlotmap, ProbabilitySpace):
|
|
|
308
308
|
The sampler will yield state lists, where the state
|
|
309
309
|
values are co-indexed with rvs, or self.rvs if rvs is None.
|
|
310
310
|
|
|
311
|
+
For more information about this sampler, see the publication:
|
|
312
|
+
Suresh, S., Drake, B. (2025). Sampling of Large Probabilistic Graphical Models
|
|
313
|
+
Using Arithmetic Circuits. AI 2024: Advances in Artificial Intelligence. AI 2024.
|
|
314
|
+
Lecture Notes in Computer Science, vol 15443. https://doi.org/10.1007/978-981-96-0351-0_13.
|
|
315
|
+
|
|
311
316
|
Args:
|
|
312
317
|
rvs: the list of random variables to sample; the
|
|
313
318
|
yielded state vectors are co-indexed with rvs; if None,
|
ck/pgm_circuit/wmc_program.py
CHANGED
|
@@ -132,6 +132,11 @@ class WMCProgram(ProgramWithSlotmap, ProbabilitySpace):
|
|
|
132
132
|
* calls rand.random() once and rand.randrange(...) n times,
|
|
133
133
|
* calls self.program().compute_result() at least once and <= 1 + m.
|
|
134
134
|
|
|
135
|
+
For more information about this sampler, see the publication:
|
|
136
|
+
Suresh, S., Drake, B. (2025). Sampling of Large Probabilistic Graphical Models
|
|
137
|
+
Using Arithmetic Circuits. AI 2024: Advances in Artificial Intelligence. AI 2024.
|
|
138
|
+
Lecture Notes in Computer Science, vol 15443. https://doi.org/10.1007/978-981-96-0351-0_13.
|
|
139
|
+
|
|
135
140
|
Args:
|
|
136
141
|
rvs: the list of random variables to sample; the
|
|
137
142
|
yielded state vectors are co-indexed with rvs; if None,
|
|
@@ -13,7 +13,7 @@
|
|
|
13
13
|
"/O2"
|
|
14
14
|
],
|
|
15
15
|
"include_dirs": [
|
|
16
|
-
"C:\\Users\\runneradmin\\AppData\\Local\\Temp\\build-env-
|
|
16
|
+
"C:\\Users\\runneradmin\\AppData\\Local\\Temp\\build-env-75r5_lk9\\Lib\\site-packages\\numpy\\_core\\include"
|
|
17
17
|
],
|
|
18
18
|
"name": "ck.pgm_compiler.support.circuit_table._circuit_table_cy",
|
|
19
19
|
"sources": [
|
|
Binary file
|
|
@@ -0,0 +1,226 @@
|
|
|
1
|
+
"""
|
|
2
|
+
This module implements several divergences which measure the difference
|
|
3
|
+
between two distributions.
|
|
4
|
+
"""
|
|
5
|
+
import math
|
|
6
|
+
from typing import Sequence
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
|
|
10
|
+
from ck.pgm import RandomVariable, rv_instances_as_indicators, PGM
|
|
11
|
+
from ck.probability.probability_space import ProbabilitySpace
|
|
12
|
+
|
|
13
|
+
_NAN: float = np.nan # Not-a-number (i.e., the result of an invalid calculation).
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def kl(p: ProbabilitySpace, q: ProbabilitySpace) -> float:
|
|
17
|
+
"""
|
|
18
|
+
Compute the Kullback-Leibler divergence between p & q,
|
|
19
|
+
where p is the true distribution.
|
|
20
|
+
|
|
21
|
+
This implementation uses logarithms, base 2.
|
|
22
|
+
|
|
23
|
+
Args:
|
|
24
|
+
p: a probability space to compare to.
|
|
25
|
+
q: the other probability space.
|
|
26
|
+
|
|
27
|
+
Returns:
|
|
28
|
+
the Kullback–Leibler (KL) divergence of p & q, where p is
|
|
29
|
+
the true distribution.
|
|
30
|
+
|
|
31
|
+
Raises:
|
|
32
|
+
ValueError: if `p` and `q` do not have compatible random variables. Specifically, the following must hold:
|
|
33
|
+
* `len(p.rvs) == len(q.rvs)`
|
|
34
|
+
* `len(q.rvs[i]) == len(p.rvs[i])` for all `i`
|
|
35
|
+
* `q.rvs[i].idx == p.rvs[i].idx` for all `i`.
|
|
36
|
+
|
|
37
|
+
Warning:
|
|
38
|
+
this method will enumerate the whole probability space.
|
|
39
|
+
"""
|
|
40
|
+
if not _compatible_rvs(p.rvs, q.rvs):
|
|
41
|
+
raise ValueError('incompatible random variables')
|
|
42
|
+
|
|
43
|
+
total = 0.0
|
|
44
|
+
for x in rv_instances_as_indicators(*p.rvs):
|
|
45
|
+
p_x = p.probability(*x)
|
|
46
|
+
q_x = q.probability(*x)
|
|
47
|
+
if p_x <= 0 or q_x <= 0:
|
|
48
|
+
return _NAN
|
|
49
|
+
total += p_x * math.log2(p_x / q_x)
|
|
50
|
+
return total
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def pseudo_kl(p: ProbabilitySpace, q: ProbabilitySpace) -> float:
|
|
54
|
+
"""
|
|
55
|
+
A kind of KL divergence, factored by the structure of `p`.
|
|
56
|
+
This is an experimental measure.
|
|
57
|
+
|
|
58
|
+
This implementation uses logarithms, base 2.
|
|
59
|
+
|
|
60
|
+
Args:
|
|
61
|
+
p: a probability space to compare to.
|
|
62
|
+
q: the other probability space.
|
|
63
|
+
|
|
64
|
+
Returns:
|
|
65
|
+
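the pseudo KL divergence: the sum of KL-style terms over every state of each factor of the PGM of `p`; NaN if any probability involved is not positive, or if `p` has no random variables.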
|
|
66
|
+
|
|
67
|
+
Raises:
|
|
68
|
+
ValueError: if `p` and `q` do not have compatible random variables. Specifically, the following must hold:
|
|
69
|
+
* `len(p.rvs) == len(q.rvs)`
|
|
70
|
+
* `len(q.rvs[i]) == len(p.rvs[i])` for all `i`
|
|
71
|
+
* `q.rvs[i].idx == p.rvs[i].idx` for all `i`.
|
|
72
|
+
ValueError: if not all random variables of `p` are from a single PGM, which must
|
|
73
|
+
have a Bayesian network structure.
|
|
74
|
+
"""
|
|
75
|
+
p_rvs: Sequence[RandomVariable] = p.rvs
|
|
76
|
+
q_rvs: Sequence[RandomVariable] = q.rvs
|
|
77
|
+
|
|
78
|
+
if not _compatible_rvs(p_rvs, q_rvs):
|
|
79
|
+
raise ValueError('incompatible random variables')
|
|
80
|
+
|
|
81
|
+
if len(p_rvs) == 0:
|
|
82
|
+
return _NAN
|
|
83
|
+
|
|
84
|
+
pgm: PGM = p_rvs[0].pgm
|
|
85
|
+
if any(rv.pgm is not pgm for rv in p_rvs):
|
|
86
|
+
raise ValueError('p random variables are not from a single PGM.')
|
|
87
|
+
if not pgm.is_structure_bayesian:
|
|
88
|
+
raise ValueError('p does not have Bayesian network structure.')
|
|
89
|
+
|
|
90
|
+
# Across the two spaces, corresponding random variables are equivalent;
|
|
91
|
+
# i.e., same number of states and same `idx` values. Therefore,
|
|
92
|
+
# indicators from either one space can be used in both spaces.
|
|
93
|
+
|
|
94
|
+
total: float = 0
|
|
95
|
+
for factor in pgm.factors:
|
|
96
|
+
for x in rv_instances_as_indicators(*factor.rvs): # every possible state of factor rvs
|
|
97
|
+
p_x = p.probability(*x)
|
|
98
|
+
q_x = q.probability(*x)
|
|
99
|
+
if p_x <= 0 or q_x <= 0:
|
|
100
|
+
return _NAN
|
|
101
|
+
total += p_x * math.log2(p_x / q_x)
|
|
102
|
+
return total
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def hi(p: ProbabilitySpace, q: ProbabilitySpace) -> float:
|
|
106
|
+
"""
|
|
107
|
+
Compute the histogram intersection between the probability spaces `p` and `q`.
|
|
108
|
+
|
|
109
|
+
The histogram intersection between two probability spaces P and Q,
|
|
110
|
+
with state spaces X, is defined as:
|
|
111
|
+
HI(P, Q) = sum(min(P(x), Q(x)) for x in X)
|
|
112
|
+
|
|
113
|
+
Args:
|
|
114
|
+
p: a probability space to compare to.
|
|
115
|
+
q: the other probability space.
|
|
116
|
+
|
|
117
|
+
Returns:
|
|
118
|
+
the histogram intersection between the two probability spaces.
|
|
119
|
+
|
|
120
|
+
Raises:
|
|
121
|
+
ValueError: if `p` and `q` do not have compatible random variables. Specifically, the following must hold:
|
|
122
|
+
* `len(p.rvs) == len(q.rvs)`
|
|
123
|
+
* `len(q.rvs[i]) == len(p.rvs[i])` for all `i`
|
|
124
|
+
* `q.rvs[i].idx == p.rvs[i].idx` for all `i`.
|
|
125
|
+
|
|
126
|
+
Warning:
|
|
127
|
+
this method will enumerate the whole probability space.
|
|
128
|
+
|
|
129
|
+
"""
|
|
130
|
+
p_rvs: Sequence[RandomVariable] = p.rvs
|
|
131
|
+
q_rvs: Sequence[RandomVariable] = q.rvs
|
|
132
|
+
|
|
133
|
+
if not _compatible_rvs(p_rvs, q_rvs):
|
|
134
|
+
raise ValueError('incompatible random variables')
|
|
135
|
+
|
|
136
|
+
# Across the two spaces, corresponding random variables are equivalent;
|
|
137
|
+
# i.e., same number of states and same `idx` values. Therefore,
|
|
138
|
+
# indicators from either one space can be used in both spaces.
|
|
139
|
+
|
|
140
|
+
return sum(
|
|
141
|
+
min(p.probability(*x), q.probability(*x))
|
|
142
|
+
for x in rv_instances_as_indicators(*p_rvs)
|
|
143
|
+
)
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
def fhi(p: ProbabilitySpace, q: ProbabilitySpace) -> float:
|
|
147
|
+
"""
|
|
148
|
+
Compute the factored histogram intersection between the probability spaces `p` and `q`.
|
|
149
|
+
|
|
150
|
+
The factored histogram intersection between two probability spaces P and Q,
|
|
151
|
+
with state spaces X and factorisation F, is defined as:
|
|
152
|
+
FHI(P, Q) = 1/n sum(P(Y=y) CHI(P, Q, X | Y=y) for each CPT (X | Y) in F, for each parent state y)
|
|
153
|
+
where:
|
|
154
|
+
CHI(P, Q, X | Y=y) = HI(P(X | Y=y), Q(X | Y=y))
|
|
155
|
+
HI(P, Q) = sum(min(P(X=x), Q(X=x)) for x in X)
|
|
156
|
+
|
|
157
|
+
The value of _n_ is the sum of P(Y=y) over all CPT rows. However,
|
|
158
|
+
this always equals the number of CPTs, i.e., the number of random
|
|
159
|
+
variables.
|
|
160
|
+
|
|
161
|
+
The factorisation F is taken from `p`.
|
|
162
|
+
|
|
163
|
+
For more information about factored histogram intersection, see the publication:
|
|
164
|
+
Suresh, S., Drake, B. (2025). Sampling of Large Probabilistic Graphical Models
|
|
165
|
+
Using Arithmetic Circuits. AI 2024: Advances in Artificial Intelligence. AI 2024.
|
|
166
|
+
Lecture Notes in Computer Science, vol 15443. https://doi.org/10.1007/978-981-96-0351-0_13.
|
|
167
|
+
|
|
168
|
+
Args:
|
|
169
|
+
p: a probability space to compare to.
|
|
170
|
+
q: the other probability space.
|
|
171
|
+
|
|
172
|
+
Returns:
|
|
173
|
+
the factored histogram intersection between the two probability spaces.
|
|
174
|
+
|
|
175
|
+
Raises:
|
|
176
|
+
ValueError: if `p` and `q` do not have compatible random variables. Specifically, the following must hold:
|
|
177
|
+
* `len(p.rvs) == len(q.rvs)`
|
|
178
|
+
* `len(q.rvs[i]) == len(p.rvs[i])` for all `i`
|
|
179
|
+
* `q.rvs[i].idx == p.rvs[i].idx` for all `i`.
|
|
180
|
+
ValueError: if not all random variables of `p` are from a single PGM, which must
|
|
181
|
+
have a Bayesian network structure.
|
|
182
|
+
"""
|
|
183
|
+
p_rvs: Sequence[RandomVariable] = p.rvs
|
|
184
|
+
q_rvs: Sequence[RandomVariable] = q.rvs
|
|
185
|
+
|
|
186
|
+
if not _compatible_rvs(p_rvs, q_rvs):
|
|
187
|
+
raise ValueError('incompatible random variables')
|
|
188
|
+
|
|
189
|
+
if len(p_rvs) == 0:
|
|
190
|
+
return 0
|
|
191
|
+
|
|
192
|
+
pgm: PGM = p_rvs[0].pgm
|
|
193
|
+
if any(rv.pgm is not pgm for rv in p_rvs):
|
|
194
|
+
raise ValueError('p random variables are not from a single PGM.')
|
|
195
|
+
if not pgm.is_structure_bayesian:
|
|
196
|
+
raise ValueError('p does not have Bayesian network structure.')
|
|
197
|
+
|
|
198
|
+
# Across the two spaces, corresponding random variables are equivalent;
|
|
199
|
+
# i.e., same number of states and same `idx` values. Therefore,
|
|
200
|
+
# indicators from either one space can be used in both spaces.
|
|
201
|
+
|
|
202
|
+
# Loop over all CPTs, accumulating the total
|
|
203
|
+
total: float = 0
|
|
204
|
+
for factor in pgm.factors:
|
|
205
|
+
child: RandomVariable = factor.rvs[0]
|
|
206
|
+
parents: Sequence[RandomVariable] = factor.rvs[1:]
|
|
207
|
+
# Loop over all rows of the CPT
|
|
208
|
+
for parent_indicators in rv_instances_as_indicators(*parents):
|
|
209
|
+
p_marginal = p.marginal_distribution(child, condition=parent_indicators)
|
|
210
|
+
q_marginal = q.marginal_distribution(child, condition=parent_indicators)
|
|
211
|
+
row_hi = np.minimum(p_marginal, q_marginal).sum().item()
|
|
212
|
+
pr_row = p.probability(*parent_indicators)
|
|
213
|
+
total += pr_row * row_hi
|
|
214
|
+
|
|
215
|
+
return total / len(p_rvs)
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
def _compatible_rvs(rvs1: Sequence[RandomVariable], rvs2: Sequence[RandomVariable]) -> bool:
|
|
219
|
+
"""
|
|
220
|
+
The rvs are compatible if they have the same number of random variables
|
|
221
|
+
and the corresponding indicators are equal.
|
|
222
|
+
"""
|
|
223
|
+
return (
|
|
224
|
+
len(rvs1) == len(rvs2)
|
|
225
|
+
and all(len(rv1) == len(rv2) and rv1.idx == rv2.idx for rv1, rv2 in zip(rvs1, rvs2))
|
|
226
|
+
)
|